├── Dataset2_L1_FootTongue_test.mat ├── Dataset2_L1_FootTongue_train.mat ├── EEGInception.py ├── EEGInception_main.py ├── EEGModels.py ├── EEGNet_main.py ├── EEGSym_DataAugmentation.py ├── EEGSym_architecture.py ├── EEGSym_main.py ├── README.md ├── SBLEST.m ├── SBLEST_main.m ├── SBLEST_main.py ├── SBLEST_model.py ├── dCNN_main.py ├── sCNN_main.py ├── signal_target.py └── splitters.py /Dataset2_L1_FootTongue_test.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EEGdecoding/Code-SBLEST/7465f55116b0e4a9577b7e354169fb2027704428/Dataset2_L1_FootTongue_test.mat -------------------------------------------------------------------------------- /Dataset2_L1_FootTongue_train.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/EEGdecoding/Code-SBLEST/7465f55116b0e4a9577b7e354169fb2027704428/Dataset2_L1_FootTongue_train.mat -------------------------------------------------------------------------------- /EEGInception.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Activation, Input, Flatten 2 | from tensorflow.keras.layers import Dropout, BatchNormalization 3 | from tensorflow.keras.layers import Conv2D, AveragePooling2D, DepthwiseConv2D 4 | from tensorflow.keras.layers import Dense 5 | from tensorflow.keras.constraints import max_norm 6 | from tensorflow import keras 7 | 8 | 9 | def EEGInception(input_time=1000, fs=128, ncha=8, filters_per_branch=8, 10 | scales_time=(500, 250, 125), dropout_rate=0.25, 11 | activation='elu', n_classes=2, learning_rate=0.001): 12 | """Keras implementation of EEG-Inception. All hyperparameters and 13 | architectural choices are explained in the original article: 14 | 15 | https://doi.org/10.1109/TNSRE.2020.3048106 16 | 17 | Parameters 18 | ---------- 19 | input_time : int 20 | EEG epoch time in milliseconds 21 | fs : int 22 | Sample rate of the EEG 23 | ncha : 24 | Number of input channels 25 | filters_per_branch : int 26 | Number of filters in each Inception branch 27 | scales_time : list 28 | Temporal scale (ms) of the convolutions on each Inception module. 
29 | This parameter determines the kernel sizes of the filters 30 | dropout_rate : float 31 | Dropout rate 32 | activation : str 33 | Activation 34 | n_classes : int 35 | Number of output classes 36 | learning_rate : float 37 | Learning rate 38 | 39 | Returns 40 | ------- 41 | model : keras.models.Model 42 | Keras model already compiled and ready to work 43 | 44 | """ 45 | 46 | # ============================= CALCULATIONS ============================= # 47 | input_samples = int(input_time * fs / 1000) 48 | scales_samples = [int(s * fs / 1000) for s in scales_time] 49 | 50 | # ================================ INPUT ================================= # 51 | input_layer = Input((input_samples, ncha, 1)) 52 | 53 | # ========================== BLOCK 1: INCEPTION ========================== # 54 | b1_units = list() 55 | for i in range(len(scales_samples)): 56 | unit = Conv2D(filters=filters_per_branch, 57 | kernel_size=(scales_samples[i], 1), 58 | kernel_initializer='he_normal', 59 | padding='same')(input_layer) 60 | unit = BatchNormalization()(unit) 61 | unit = Activation(activation)(unit) 62 | unit = Dropout(dropout_rate)(unit) 63 | 64 | unit = DepthwiseConv2D((1, ncha), 65 | use_bias=False, 66 | depth_multiplier=2, 67 | depthwise_constraint=max_norm(1.))(unit) 68 | unit = BatchNormalization()(unit) 69 | unit = Activation(activation)(unit) 70 | unit = Dropout(dropout_rate)(unit) 71 | 72 | b1_units.append(unit) 73 | 74 | # Concatenation 75 | b1_out = keras.layers.concatenate(b1_units, axis=3) 76 | b1_out = AveragePooling2D((4, 1))(b1_out) 77 | 78 | # ========================== BLOCK 2: INCEPTION ========================== # 79 | b2_units = list() 80 | for i in range(len(scales_samples)): 81 | unit = Conv2D(filters=filters_per_branch, 82 | kernel_size=(int(scales_samples[i]/4), 1), 83 | kernel_initializer='he_normal', 84 | use_bias=False, 85 | padding='same')(b1_out) 86 | unit = BatchNormalization()(unit) 87 | unit = Activation(activation)(unit) 88 | unit = Dropout(dropout_rate)(unit) 89 | 90 | b2_units.append(unit) 91 | 92 | # Concatenate + Average pooling 93 | b2_out = keras.layers.concatenate(b2_units, axis=3) 94 | b2_out = AveragePooling2D((2, 1))(b2_out) 95 | 96 | # ============================ BLOCK 3: OUTPUT =========================== # 97 | b3_u1 = Conv2D(filters=int(filters_per_branch*len(scales_samples)/2), 98 | kernel_size=(8, 1), 99 | kernel_initializer='he_normal', 100 | use_bias=False, 101 | padding='same')(b2_out) 102 | b3_u1 = BatchNormalization()(b3_u1) 103 | b3_u1 = Activation(activation)(b3_u1) 104 | b3_u1 = AveragePooling2D((2, 1))(b3_u1) 105 | b3_u1 = Dropout(dropout_rate)(b3_u1) 106 | 107 | b3_u2 = Conv2D(filters=int(filters_per_branch*len(scales_samples)/4), 108 | kernel_size=(4, 1), 109 | kernel_initializer='he_normal', 110 | use_bias=False, 111 | padding='same')(b3_u1) 112 | b3_u2 = BatchNormalization()(b3_u2) 113 | b3_u2 = Activation(activation)(b3_u2) 114 | b3_u2 = AveragePooling2D((2, 1))(b3_u2) 115 | b3_out = Dropout(dropout_rate)(b3_u2) 116 | 117 | # Output layer 118 | output_layer = Flatten()(b3_out) 119 | output_layer = Dense(n_classes, activation='softmax')(output_layer) 120 | 121 | # ================================ MODEL ================================= # 122 | model = keras.models.Model(inputs=input_layer, outputs=output_layer) 123 | optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, 124 | beta_2=0.999, amsgrad=False) 125 | model.compile(loss='binary_crossentropy', optimizer=optimizer, 126 | metrics=['accuracy']) 127 | return model 128 
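(Editor's addition — a hedged usage sketch, not part of the original EEGInception.py. Parameter values mirror EEGInception_main.py below: 60-channel EEG, 3-second epochs at 250 Hz, two classes.)

    # Build EEG-Inception; with these settings input_samples = 3000 * 250 / 1000 = 750,
    # so the model expects input tensors of shape (n_trials, 750, 60, 1).
    model = EEGInception(input_time=3000, fs=250, ncha=60, filters_per_branch=8,
                         scales_time=(500, 250, 125), dropout_rate=0.5,
                         activation='elu', n_classes=2, learning_rate=0.0001)
    model.summary()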
| 129 | 130 | 131 | -------------------------------------------------------------------------------- /EEGInception_main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path 3 | from scipy.io import loadmat, savemat 4 | import numpy as np 5 | import logging 6 | import sys 7 | import random 8 | from signal_target import SignalAndTarget, convert_numbers_to_one_hot 9 | from splitters import split_into_two_sets 10 | 11 | from keras.callbacks import EarlyStopping 12 | import tensorflow as tf 13 | import keras.backend as ka 14 | import keras as k 15 | import sys 16 | from EEGInception import EEGInception 17 | 18 | time_start = time.time() 19 | 20 | # TensorFlow configuration for GPU usage 21 | config = tf.compat.v1.ConfigProto() 22 | config.gpu_options.allow_growth = True 23 | session = tf.compat.v1.Session(config=config) 24 | 25 | # Fix random seed 26 | seed = 20190706 27 | random.seed(seed) 28 | os.environ['PYTHONHASHSEED'] = str(seed) 29 | np.random.seed(seed) 30 | tf.random.set_seed(seed) 31 | 32 | log = logging.getLogger(__name__) 33 | logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s', 34 | level=logging.DEBUG, stream=sys.stdout) 35 | 36 | # Data folder where the datasets are located 37 | data_folder = 'C:/Users/Administrator/Desktop/Code-SBLEST-main' # The folder downloaded from https://github.com/EEGdecoding/Code-SBLEST 38 | 39 | 40 | # Fraction of data to be used as validation set 41 | valid_set_fraction = 0.2 42 | 43 | # Initialize train and test datasets 44 | X = np.zeros([1]) # np.ndarray([]) 45 | y = np.zeros([1]) # np.ndarray([]) 46 | train_set = SignalAndTarget(X, y) 47 | test_set = SignalAndTarget(X, y) 48 | 49 | # Load train and test datasets from .mat files 50 | train_filename = 'Dataset2_L1_FootTongue_train.mat' 51 | test_filename = 'Dataset2_L1_FootTongue_test.mat' 52 | train_filepath = os.path.join(data_folder, train_filename) 53 | test_filepath = os.path.join(data_folder, test_filename) 54 | train = loadmat(train_filepath) 55 | test = loadmat(test_filepath) 56 | 57 | # Prepare train and test datasets 58 | label_1d_train = train['Y_train'].astype(np.int32) 59 | label_1d_test = test['Y_test'].astype(np.int32) 60 | train_set.y = convert_numbers_to_one_hot(label_1d_train) 61 | test_set.y = convert_numbers_to_one_hot(label_1d_test) 62 | train_set.X = np.transpose(train['X_train'], (2, 1, 0)).astype(np.float32) 63 | test_set.X = np.transpose(test['X_test'], (2, 1, 0)).astype(np.float32) 64 | 65 | # Split train set into train and validation set 66 | train_set, valid_set = split_into_two_sets( 67 | train_set, first_set_fraction = 1 - valid_set_fraction 68 | ) 69 | 70 | # Prepare data for model training and evaluation 71 | X_train = np.expand_dims(train_set.X, axis=3) 72 | X_validate = np.expand_dims(valid_set.X, axis=3) 73 | X_test = np.expand_dims(test_set.X, axis=3) 74 | Y_train = train_set.y 75 | Y_valid = valid_set.y 76 | Y_test = test_set.y 77 | # Get number of channels and samples from input data 78 | chans = X_train.shape[1] 79 | samples = X_train.shape[2] 80 | print('X_train shape:', X_train.shape) 81 | print(X_train.shape[0], 'train samples') 82 | print(X_test.shape[0], 'test samples') 83 | 84 | # Create and compile the EEG-Inception model 85 | model = EEGInception( 86 | input_time=3000, fs=250, ncha=60, filters_per_branch=8, 87 | scales_time=(500, 250, 125), dropout_rate=0.5, 88 | activation='elu', n_classes=2, learning_rate=0.0001) 89 | 90 | model.compile(loss='categorical_crossentropy', optimizer=
'adam', 91 | metrics=['accuracy']) 92 | 93 | model.summary() 94 | 95 | # Callbacks 96 | early_stopping = EarlyStopping( 97 | monitor='loss', min_delta=0.0001, 98 | mode='min', patience=50, verbose=1, 99 | restore_best_weights=True) 100 | 101 | # Train the model 102 | fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=500, 103 | verbose=2, validation_data=(X_validate, Y_valid), callbacks=[early_stopping] 104 | ) 105 | 106 | # Predict the labels for the test set using the trained model and print the calculated classification accuracy 107 | probs = model.predict(X_test) 108 | preds = probs.argmax(axis=-1) 109 | acc = np.mean(preds == Y_test.argmax(axis=-1)) 110 | print("Classification accuracy: %f " % (acc)) 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /EEGModels.py: -------------------------------------------------------------------------------- 1 | """ 2 | ARL_EEGModels - A collection of Convolutional Neural Network models for EEG 3 | Signal Processing and Classification, using Keras and TensorFlow 4 | 5 | Requirements: 6 | (1) tensorflow == 2.X (as of this writing, 2.0 - 2.3 have been verified 7 | as working) 8 | 9 | To run the EEG/MEG ERP classification sample script, you will also need 10 | 11 | (2) mne >= 0.17.1 12 | (3) PyRiemann >= 0.2.5 13 | (4) scikit-learn >= 0.20.1 14 | (5) matplotlib >= 2.2.3 15 | 16 | To use: 17 | 18 | (1) Place this file on your PYTHONPATH (e.g., via your IDE, such as Spyder) 19 | (2) Import the model as 20 | 21 | from EEGModels import EEGNet 22 | 23 | model = EEGNet(nb_classes = ..., Chans = ..., Samples = ...) 24 | 25 | (3) Then compile and fit the model 26 | 27 | model.compile(loss = ..., optimizer = ..., metrics = ...) 28 | fitted = model.fit(...) 29 | predicted = model.predict(...) 30 | 31 | Portions of this project are works of the United States Government and are not 32 | subject to domestic copyright protection under 17 USC Sec. 105. Those 33 | portions are released world-wide under the terms of the Creative Commons Zero 34 | 1.0 (CC0) license. 35 | 36 | Other portions of this project are subject to domestic copyright protection 37 | under 17 USC Sec. 105. Those portions are licensed under the Apache 2.0 38 | license. The complete text of the license governing this material is in 39 | the file labeled LICENSE.TXT that is a part of this project's official 40 | distribution. 41 | """ 42 | 43 | from tensorflow.keras.models import Model 44 | from tensorflow.keras.layers import Dense, Activation, Permute, Dropout 45 | from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D 46 | from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D 47 | from tensorflow.keras.layers import BatchNormalization 48 | from tensorflow.keras.layers import SpatialDropout2D 49 | from tensorflow.keras.regularizers import l1_l2 50 | from tensorflow.keras.layers import Input, Flatten 51 | from tensorflow.keras.constraints import max_norm 52 | from tensorflow.keras import backend as K 53 | 54 | 55 | def EEGNet(nb_classes, Chans=60, Samples=250, 56 | dropoutRate=0.5, kernLength=125, F1=8, 57 | D=2, F2=16, norm_rate=0.25, dropoutType='Dropout'): 58 | """ Keras Implementation of EEGNet 59 | http://iopscience.iop.org/article/10.1088/1741-2552/aace8c/meta 60 | 61 | Note that this implements the newest version of EEGNet and NOT the earlier 62 | versions (v1 and v2 on arXiv).
We strongly recommend using this 63 | architecture as it performs much better and has nicer properties than 64 | our earlier version. For example: 65 | 66 | 1. Depthwise Convolutions to learn spatial filters within a 67 | temporal convolution. The use of the depth_multiplier option maps 68 | exactly to the number of spatial filters learned within a temporal 69 | filter. This matches the setup of algorithms like FBCSP which learn 70 | spatial filters within each filter in a filter-bank. This also limits 71 | the number of free parameters to fit when compared to a fully-connected 72 | convolution. 73 | 74 | 2. Separable Convolutions to learn how to optimally combine spatial 75 | filters across temporal bands. Separable Convolutions are Depthwise 76 | Convolutions followed by (1x1) Pointwise Convolutions. 77 | 78 | 79 | While the original paper used Dropout, we found that SpatialDropout2D 80 | sometimes produced slightly better results for classification of ERP 81 | signals. However, SpatialDropout2D significantly reduced performance 82 | on the Oscillatory dataset (SMR, BCI-IV Dataset 2A). We recommend using 83 | the default Dropout in most cases. 84 | 85 | Assumes the input signal is sampled at 128Hz. If you want to use this model 86 | for any other sampling rate you will need to modify the lengths of temporal 87 | kernels and average pooling size in blocks 1 and 2 as needed (double the 88 | kernel lengths for double the sampling rate, etc). Note that we haven't 89 | tested the model performance with this rule so this may not work well. 90 | 91 | The model with default parameters gives the EEGNet-8,2 model as discussed 92 | in the paper. This model should do pretty well in general, although it is 93 | advised to do some model searching to get optimal performance on your 94 | particular dataset. 95 | 96 | We set F2 = F1 * D (number of input filters = number of output filters) for 97 | the SeparableConv2D layer. We haven't extensively tested other values of this 98 | parameter (say, F2 < F1 * D for compressed learning, and F2 > F1 * D for 99 | overcomplete). We believe the main parameters to focus on are F1 and D. 100 | 101 | Inputs: 102 | 103 | nb_classes : int, number of classes to classify 104 | Chans, Samples : number of channels and time points in the EEG data 105 | dropoutRate : dropout fraction 106 | kernLength : length of temporal convolution in first layer. We found 107 | that setting this to be half the sampling rate worked 108 | well in practice. For the SMR dataset in particular 109 | since the data was high-passed at 4Hz we used a kernel 110 | length of 32. 111 | F1, F2 : number of temporal filters (F1) and number of pointwise 112 | filters (F2) to learn. Default: F1 = 8, F2 = F1 * D. 113 | D : number of spatial filters to learn within each temporal 114 | convolution. Default: D = 2 115 | dropoutType : Either SpatialDropout2D or Dropout, passed as a string. 
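    Example (editor's addition — a hedged sketch, assuming 60-channel trials of
    750 samples as in the EEGNet_main.py script of this repository):

        model = EEGNet(nb_classes=2, Chans=60, Samples=750,
                       dropoutRate=0.5, kernLength=125, F1=8, D=2, F2=16,
                       dropoutType='Dropout')
        # The model expects inputs of shape (None, Chans, Samples, 1)
        model.compile(loss='categorical_crossentropy', optimizer='adam',
                      metrics=['accuracy'])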
116 | 117 | """ 118 | 119 | if dropoutType == 'SpatialDropout2D': 120 | dropoutType = SpatialDropout2D 121 | elif dropoutType == 'Dropout': 122 | dropoutType = Dropout 123 | else: 124 | raise ValueError('dropoutType must be one of SpatialDropout2D ' 125 | 'or Dropout, passed as a string.') 126 | 127 | input1 = Input(shape=(Chans, Samples, 1)) 128 | 129 | ################################################################## 130 | block1 = Conv2D(F1, (1, kernLength), padding='same', 131 | input_shape=(Chans, Samples, 1), 132 | use_bias=False)(input1) 133 | block1 = BatchNormalization()(block1) 134 | block1 = DepthwiseConv2D((Chans, 1), use_bias=False, 135 | depth_multiplier=D, 136 | depthwise_constraint=max_norm(1.))(block1) 137 | block1 = BatchNormalization()(block1) 138 | block1 = Activation('elu')(block1) 139 | block1 = AveragePooling2D((1, 8))(block1) 140 | block1 = dropoutType(dropoutRate)(block1) 141 | 142 | block2 = SeparableConv2D(F2, (1, 32), 143 | use_bias=False, padding='same')(block1) 144 | block2 = BatchNormalization()(block2) 145 | block2 = Activation('elu')(block2) 146 | block2 = AveragePooling2D((1, 16))(block2) 147 | block2 = dropoutType(dropoutRate)(block2) 148 | 149 | flatten = Flatten(name='flatten')(block2) 150 | 151 | dense = Dense(nb_classes, name='dense', 152 | kernel_constraint=max_norm(norm_rate))(flatten) 153 | softmax = Activation('softmax', name='softmax')(dense) 154 | 155 | return Model(inputs=input1, outputs=softmax) 156 | 157 | 158 | 159 | 160 | def EEGNet_SSVEP(nb_classes = 12, Chans = 8, Samples = 256, 161 | dropoutRate = 0.5, kernLength = 256, F1 = 96, 162 | D = 1, F2 = 96, dropoutType = 'Dropout'): 163 | """ SSVEP Variant of EEGNet, as used in [1]. 164 | 165 | Inputs: 166 | 167 | nb_classes : int, number of classes to classify 168 | Chans, Samples : number of channels and time points in the EEG data 169 | dropoutRate : dropout fraction 170 | kernLength : length of temporal convolution in first layer 171 | F1, F2 : number of temporal filters (F1) and number of pointwise 172 | filters (F2) to learn. 173 | D : number of spatial filters to learn within each temporal 174 | convolution. 175 | dropoutType : Either SpatialDropout2D or Dropout, passed as a string. 176 | 177 | 178 | [1]. Waytowich, N. et. al. (2018). Compact Convolutional Neural Networks 179 | for Classification of Asynchronous Steady-State Visual Evoked Potentials. 180 | Journal of Neural Engineering vol. 15(6). 
181 | http://iopscience.iop.org/article/10.1088/1741-2552/aae5d8 182 | 183 | """ 184 | 185 | if dropoutType == 'SpatialDropout2D': 186 | dropoutType = SpatialDropout2D 187 | elif dropoutType == 'Dropout': 188 | dropoutType = Dropout 189 | else: 190 | raise ValueError('dropoutType must be one of SpatialDropout2D ' 191 | 'or Dropout, passed as a string.') 192 | 193 | input1 = Input(shape = (Chans, Samples, 1)) 194 | 195 | ################################################################## 196 | block1 = Conv2D(F1, (1, kernLength), padding = 'same', 197 | input_shape = (Chans, Samples, 1), 198 | use_bias = False)(input1) 199 | block1 = BatchNormalization()(block1) 200 | block1 = DepthwiseConv2D((Chans, 1), use_bias = False, 201 | depth_multiplier = D, 202 | depthwise_constraint = max_norm(1.))(block1) 203 | block1 = BatchNormalization()(block1) 204 | block1 = Activation('elu')(block1) 205 | block1 = AveragePooling2D((1, 4))(block1) 206 | block1 = dropoutType(dropoutRate)(block1) 207 | 208 | block2 = SeparableConv2D(F2, (1, 16), 209 | use_bias = False, padding = 'same')(block1) 210 | block2 = BatchNormalization()(block2) 211 | block2 = Activation('elu')(block2) 212 | block2 = AveragePooling2D((1, 8))(block2) 213 | block2 = dropoutType(dropoutRate)(block2) 214 | 215 | flatten = Flatten(name = 'flatten')(block2) 216 | 217 | dense = Dense(nb_classes, name = 'dense')(flatten) 218 | softmax = Activation('softmax', name = 'softmax')(dense) 219 | 220 | return Model(inputs=input1, outputs=softmax) 221 | 222 | 223 | 224 | def EEGNet_old(nb_classes, Chans = 64, Samples = 128, regRate = 0.0001, 225 | dropoutRate = 0.25, kernels = [(2, 32), (8, 4)], strides = (2, 4)): 226 | """ Keras Implementation of EEGNet_v1 (https://arxiv.org/abs/1611.08024v2) 227 | 228 | This model is the original EEGNet model proposed on arxiv 229 | https://arxiv.org/abs/1611.08024v2 230 | 231 | with a few modifications: we use striding instead of max-pooling as this 232 | helped slightly in classification performance while also providing a 233 | computational speed-up. 234 | 235 | Note that we no longer recommend the use of this architecture, as the new 236 | version of EEGNet performs much better overall and has nicer properties. 
237 | 238 | Inputs: 239 | 240 | nb_classes : total number of final categories 241 | Chans, Samples : number of EEG channels and samples, respectively 242 | regRate : regularization rate for L1 and L2 regularizations 243 | dropoutRate : dropout fraction 244 | kernels : the 2nd and 3rd layer kernel dimensions (default is 245 | the [2, 32] x [8, 4] configuration) 246 | strides : the stride size (note that this replaces the max-pool 247 | used in the original paper) 248 | 249 | """ 250 | 251 | # start the model 252 | input_main = Input((Chans, Samples)) 253 | layer1 = Conv2D(16, (Chans, 1), input_shape=(Chans, Samples, 1), 254 | kernel_regularizer = l1_l2(l1=regRate, l2=regRate))(input_main) 255 | layer1 = BatchNormalization()(layer1) 256 | layer1 = Activation('elu')(layer1) 257 | layer1 = Dropout(dropoutRate)(layer1) 258 | 259 | permute_dims = 2, 1, 3 260 | permute1 = Permute(permute_dims)(layer1) 261 | 262 | layer2 = Conv2D(4, kernels[0], padding = 'same', 263 | kernel_regularizer=l1_l2(l1=0.0, l2=regRate), 264 | strides = strides)(permute1) 265 | layer2 = BatchNormalization()(layer2) 266 | layer2 = Activation('elu')(layer2) 267 | layer2 = Dropout(dropoutRate)(layer2) 268 | 269 | layer3 = Conv2D(4, kernels[1], padding = 'same', 270 | kernel_regularizer=l1_l2(l1=0.0, l2=regRate), 271 | strides = strides)(layer2) 272 | layer3 = BatchNormalization()(layer3) 273 | layer3 = Activation('elu')(layer3) 274 | layer3 = Dropout(dropoutRate)(layer3) 275 | 276 | flatten = Flatten(name = 'flatten')(layer3) 277 | 278 | dense = Dense(nb_classes, name = 'dense')(flatten) 279 | softmax = Activation('softmax', name = 'softmax')(dense) 280 | 281 | return Model(inputs=input_main, outputs=softmax) 282 | 283 | 284 | 285 | def DeepConvNet(nb_classes, Chans = 64, Samples = 256, 286 | dropoutRate = 0.5): 287 | """ Keras implementation of the Deep Convolutional Network as described in 288 | Schirrmeister et. al. (2017), Human Brain Mapping. 289 | 290 | This implementation assumes the input is a 2-second EEG signal sampled at 291 | 128Hz, as opposed to signals sampled at 250Hz as described in the original 292 | paper. We also perform temporal convolutions of length (1, 5) as opposed 293 | to (1, 10) due to this sampling rate difference. 294 | 295 | Note that we use the max_norm constraint on all convolutional layers, as 296 | well as the classification layer. We also change the defaults for the 297 | BatchNormalization layer. We used this based on a personal communication 298 | with the original authors. 299 | 300 | ours original paper 301 | pool_size 1, 2 1, 3 302 | strides 1, 2 1, 3 303 | conv filters 1, 5 1, 10 304 | 305 | Note that this implementation has not been verified by the original 306 | authors. 
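    Example (editor's addition — a hedged sketch, assuming the 2-second,
    128 Hz input this implementation was written for):

        model = DeepConvNet(nb_classes=2, Chans=64, Samples=256, dropoutRate=0.5)
        # The model expects inputs of shape (None, Chans, Samples, 1)
        model.compile(loss='categorical_crossentropy', optimizer='adam',
                      metrics=['accuracy'])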
307 | 308 | """ 309 | 310 | # start the model 311 | input_main = Input((Chans, Samples, 1)) 312 | block1 = Conv2D(25, (1, 5), 313 | input_shape=(Chans, Samples, 1), 314 | kernel_constraint = max_norm(2., axis=(0,1,2)))(input_main) 315 | block1 = Conv2D(25, (Chans, 1), 316 | kernel_constraint = max_norm(2., axis=(0,1,2)))(block1) 317 | block1 = BatchNormalization(epsilon=1e-05, momentum=0.1)(block1) 318 | block1 = Activation('elu')(block1) 319 | block1 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block1) 320 | block1 = Dropout(dropoutRate)(block1) 321 | 322 | block2 = Conv2D(50, (1, 5), 323 | kernel_constraint = max_norm(2., axis=(0,1,2)))(block1) 324 | block2 = BatchNormalization(epsilon=1e-05, momentum=0.1)(block2) 325 | block2 = Activation('elu')(block2) 326 | block2 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block2) 327 | block2 = Dropout(dropoutRate)(block2) 328 | 329 | block3 = Conv2D(100, (1, 5), 330 | kernel_constraint = max_norm(2., axis=(0,1,2)))(block2) 331 | block3 = BatchNormalization(epsilon=1e-05, momentum=0.1)(block3) 332 | block3 = Activation('elu')(block3) 333 | block3 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block3) 334 | block3 = Dropout(dropoutRate)(block3) 335 | 336 | block4 = Conv2D(200, (1, 5), 337 | kernel_constraint = max_norm(2., axis=(0,1,2)))(block3) 338 | block4 = BatchNormalization(epsilon=1e-05, momentum=0.1)(block4) 339 | block4 = Activation('elu')(block4) 340 | block4 = MaxPooling2D(pool_size=(1, 2), strides=(1, 2))(block4) 341 | block4 = Dropout(dropoutRate)(block4) 342 | 343 | flatten = Flatten()(block4) 344 | 345 | dense = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten) 346 | softmax = Activation('softmax')(dense) 347 | 348 | return Model(inputs=input_main, outputs=softmax) 349 | 350 | 351 | # need these for ShallowConvNet 352 | def square(x): 353 | return K.square(x) 354 | 355 | def log(x): 356 | return K.log(K.clip(x, min_value = 1e-7, max_value = 10000)) 357 | 358 | 359 | def ShallowConvNet(nb_classes, Chans = 64, Samples = 128, dropoutRate = 0.5): 360 | """ Keras implementation of the Shallow Convolutional Network as described 361 | in Schirrmeister et. al. (2017), Human Brain Mapping. 362 | 363 | Assumes the input is a 2-second EEG signal sampled at 128Hz. Note that in 364 | the original paper, they do temporal convolutions of length 25 for EEG 365 | data sampled at 250Hz. We instead use length 13 since the sampling rate is 366 | roughly half of the 250Hz which the paper used. The pool_size and stride 367 | in later layers is also approximately half of what is used in the paper. 368 | 369 | Note that we use the max_norm constraint on all convolutional layers, as 370 | well as the classification layer. We also change the defaults for the 371 | BatchNormalization layer. We used this based on a personal communication 372 | with the original authors. 373 | 374 | ours original paper 375 | pool_size 1, 35 1, 75 376 | strides 1, 7 1, 15 377 | conv filters 1, 13 1, 25 378 | 379 | Note that this implementation has not been verified by the original 380 | authors. We do note that this implementation reproduces the results in the 381 | original paper with minor deviations. 
382 | """ 383 | 384 | # start the model 385 | input_main = Input((Chans, Samples, 1)) 386 | block1 = Conv2D(40, (1, 13), 387 | input_shape=(Chans, Samples, 1), 388 | kernel_constraint = max_norm(2., axis=(0,1,2)))(input_main) 389 | block1 = Conv2D(40, (Chans, 1), use_bias=False, 390 | kernel_constraint = max_norm(2., axis=(0,1,2)))(block1) 391 | block1 = BatchNormalization(epsilon=1e-05, momentum=0.1)(block1) 392 | block1 = Activation(square)(block1) 393 | block1 = AveragePooling2D(pool_size=(1, 35), strides=(1, 7))(block1) 394 | block1 = Activation(log)(block1) 395 | block1 = Dropout(dropoutRate)(block1) 396 | flatten = Flatten()(block1) 397 | dense = Dense(nb_classes, kernel_constraint = max_norm(0.5))(flatten) 398 | softmax = Activation('softmax')(dense) 399 | 400 | return Model(inputs=input_main, outputs=softmax) 401 | 402 | 403 | def EEGNet2(nb_classes, Chans=64, Samples=128, 404 | dropoutRate=0.5, kernLength=125, F1=8, 405 | D=2, F2=16, norm_rate=0.25, dropoutType='Dropout'): 406 | """ Keras Implementation of EEGNet 407 | http://iopscience.iop.org/article/10.1088/1741-2552/aace8c/meta 408 | 409 | Note that this implements the newest version of EEGNet and NOT the earlier 410 | version (version v1 and v2 on arxiv). We strongly recommend using this 411 | architecture as it performs much better and has nicer properties than 412 | our earlier version. For example: 413 | 414 | 1. Depthwise Convolutions to learn spatial filters within a 415 | temporal convolution. The use of the depth_multiplier option maps 416 | exactly to the number of spatial filters learned within a temporal 417 | filter. This matches the setup of algorithms like FBCSP which learn 418 | spatial filters within each filter in a filter-bank. This also limits 419 | the number of free parameters to fit when compared to a fully-connected 420 | convolution. 421 | 422 | 2. Separable Convolutions to learn how to optimally combine spatial 423 | filters across temporal bands. Separable Convolutions are Depthwise 424 | Convolutions followed by (1x1) Pointwise Convolutions. 425 | 426 | 427 | While the original paper used Dropout, we found that SpatialDropout2D 428 | sometimes produced slightly better results for classification of ERP 429 | signals. However, SpatialDropout2D significantly reduced performance 430 | on the Oscillatory dataset (SMR, BCI-IV Dataset 2A). We recommend using 431 | the default Dropout in most cases. 432 | 433 | Assumes the input signal is sampled at 128Hz. If you want to use this model 434 | for any other sampling rate you will need to modify the lengths of temporal 435 | kernels and average pooling size in blocks 1 and 2 as needed (double the 436 | kernel lengths for double the sampling rate, etc). Note that we haven't 437 | tested the model performance with this rule so this may not work well. 438 | 439 | The model with default parameters gives the EEGNet-8,2 model as discussed 440 | in the paper. This model should do pretty well in general, although it is 441 | advised to do some model searching to get optimal performance on your 442 | particular dataset. 443 | 444 | We set F2 = F1 * D (number of input filters = number of output filters) for 445 | the SeparableConv2D layer. We haven't extensively tested other values of this 446 | parameter (say, F2 < F1 * D for compressed learning, and F2 > F1 * D for 447 | overcomplete). We believe the main parameters to focus on are F1 and D. 
448 | 449 | Inputs: 450 | 451 | nb_classes : int, number of classes to classify 452 | Chans, Samples : number of channels and time points in the EEG data 453 | dropoutRate : dropout fraction 454 | kernLength : length of temporal convolution in first layer. We found 455 | that setting this to be half the sampling rate worked 456 | well in practice. For the SMR dataset in particular 457 | since the data was high-passed at 4Hz we used a kernel 458 | length of 32. 459 | F1, F2 : number of temporal filters (F1) and number of pointwise 460 | filters (F2) to learn. Default: F1 = 8, F2 = F1 * D. 461 | D : number of spatial filters to learn within each temporal 462 | convolution. Default: D = 2 463 | dropoutType : Either SpatialDropout2D or Dropout, passed as a string. 464 | 465 | """ 466 | 467 | if dropoutType == 'SpatialDropout2D': 468 | dropoutType = SpatialDropout2D 469 | elif dropoutType == 'Dropout': 470 | dropoutType = Dropout 471 | else: 472 | raise ValueError('dropoutType must be one of SpatialDropout2D ' 473 | 'or Dropout, passed as a string.') 474 | 475 | input1 = Input(shape=(Chans, Samples, 1)) 476 | 477 | ################################################################## 478 | block1 = Conv2D(F1, (1, kernLength), padding='same', 479 | input_shape=(Chans, Samples, 1), 480 | use_bias=False)(input1) 481 | block1 = BatchNormalization()(block1) 482 | block1 = DepthwiseConv2D((Chans, 1), use_bias=False, 483 | depth_multiplier=D, 484 | depthwise_constraint=max_norm(1.))(block1) 485 | block1 = BatchNormalization()(block1) 486 | block1 = Activation('elu')(block1) 487 | block1 = AveragePooling2D((1, 4))(block1) 488 | block1 = dropoutType(dropoutRate)(block1) 489 | 490 | block2 = SeparableConv2D(F2, (1, 16), 491 | use_bias=False, padding='same')(block1) 492 | block2 = BatchNormalization()(block2) 493 | block2 = Activation('elu')(block2) 494 | block2 = AveragePooling2D((1, 8))(block2) 495 | block2 = dropoutType(dropoutRate)(block2) 496 | 497 | flatten = Flatten(name='flatten')(block2) 498 | 499 | dense = Dense(nb_classes, name='dense', 500 | kernel_constraint=max_norm(norm_rate))(flatten) 501 | softmax = Activation('softmax', name='softmax')(dense) 502 | 503 | return Model(inputs=input1, outputs=softmax) 504 | -------------------------------------------------------------------------------- /EEGNet_main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path 3 | from scipy.io import loadmat, savemat 4 | import numpy as np 5 | import logging 6 | import sys 7 | import random 8 | from signal_target import SignalAndTarget,convert_numbers_to_one_hot 9 | from splitters import split_into_two_sets 10 | 11 | from keras.callbacks import EarlyStopping 12 | import tensorflow as tf 13 | import keras.backend as ka 14 | import keras as k 15 | import sys 16 | from EEGModels import EEGNet, ShallowConvNet, DeepConvNet 17 | 18 | time_start = time.time() 19 | 20 | # TensorFlow configuration for GPU usage 21 | config = tf.compat.v1.ConfigProto() 22 | config.gpu_options.allow_growth = True 23 | session = tf.compat.v1.Session(config=config) 24 | 25 | # Fix random seed 26 | seed=20190706 27 | random.seed(seed) 28 | os.environ['PYTHONHASHSEED'] = str(seed) 29 | np.random.seed(seed) 30 | tf.random.set_seed(seed) 31 | 32 | log = logging.getLogger(__name__) 33 | logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s', 34 | level=logging.DEBUG, stream=sys.stdout) 35 | 36 | # Data folder where the datasets are located 37 | data_folder = 
'C:/Users/Administrator/Desktop/Code-SBLEST-main' # The folder you download from https://github.com/EEGdecoding/Code-SBLEST 38 | 39 | 40 | # Fraction of data to be used as validation set 41 | valid_set_fraction= 0.2 42 | 43 | # Initialize train and test datasets 44 | X = np.zeros([1]) # np.ndarray([]) 45 | y= np.zeros([1]) # np.ndarray([]) 46 | train_set = SignalAndTarget(X, y) 47 | test_set = SignalAndTarget(X, y) 48 | 49 | # Load train and test datasets from .mat files 50 | train_filename = 'Dataset2_L1_FootTongue_train.mat' 51 | test_filename = 'Dataset2_L1_FootTongue_test.mat' 52 | train_filepath = os.path.join(data_folder, train_filename) 53 | test_filepath = os.path.join(data_folder, test_filename) 54 | train = loadmat(train_filepath) 55 | test = loadmat(test_filepath) 56 | 57 | # Prepare train and test datasets 58 | label_1d_train = train['Y_train'].astype(np.int32) 59 | label_1d_test = test['Y_test'].astype(np.int32) 60 | train_set.y =convert_numbers_to_one_hot(label_1d_train) 61 | test_set.y =convert_numbers_to_one_hot(label_1d_test) 62 | train_set.X = np.transpose(train['X_train'], (2, 0, 1)).astype(np.float32) 63 | test_set.X = np.transpose(test['X_test'], (2, 0, 1)).astype(np.float32) 64 | 65 | # Split train set into train and validation set 66 | train_set, valid_set = split_into_two_sets( 67 | train_set, first_set_fraction = 1 - valid_set_fraction 68 | ) 69 | 70 | # Prepare data for model training and evaluation 71 | X_train = np.expand_dims(train_set.X, axis=3) 72 | X_validate = np.expand_dims(valid_set.X, axis=3) 73 | X_test = np.expand_dims(test_set.X, axis=3) 74 | Y_train = train_set.y 75 | Y_valid = valid_set.y 76 | Y_test = test_set.y 77 | 78 | # Get number of channels and samples from input data 79 | chans = X_train.shape[1] 80 | samples = X_train.shape[2] 81 | print('X_train shape:', X_train.shape) 82 | print(X_train.shape[0], 'train samples') 83 | print(X_test.shape[0], 'test samples') 84 | 85 | # Create and compile the EEGNet model 86 | model = EEGNet(nb_classes=2, Chans=chans, Samples=samples, 87 | dropoutRate=0.5, kernLength=125, F1=8, D=2, F2=16, 88 | dropoutType='Dropout') 89 | 90 | model.compile(loss='categorical_crossentropy', optimizer= 'adam', 91 | metrics=['accuracy']) 92 | 93 | model.summary() 94 | 95 | # Callbacks 96 | early_stopping = EarlyStopping( 97 | monitor='loss', min_delta=0.0001, 98 | mode='min', patience=50, verbose=1, 99 | restore_best_weights=True) 100 | 101 | # Train the model 102 | fittedModel = model.fit(X_train, Y_train, batch_size=16, epochs=500, 103 | verbose=2, validation_data=(X_validate, Y_valid),callbacks=[early_stopping] 104 | ) 105 | 106 | # Predict the labels for the test set using the trained model and print the calculated classification accuracy 107 | probs = model.predict(X_test) 108 | preds = probs.argmax(axis=-1) 109 | acc = np.mean(preds == Y_test.argmax(axis=-1)) 110 | print("Classification accuracy: %f " % (acc)) 111 | 112 | 113 | 114 | -------------------------------------------------------------------------------- /EEGSym_DataAugmentation.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | def preprocessing_function(augmentation=True): 6 | """Custom Data Augmentation for EEGSym. 7 | 8 | Parameters 9 | ---------- 10 | augmentation : Bool 11 | If the augmentation is performed to the input. 
12 | 13 | Returns 14 | ------- 15 | data_augmentation : function 16 | Function that applies the chosen augmentation to each trial 17 | """ 18 | 19 | def data_augmentation(trial): 20 | """Custom Data Augmentation for EEGSym. 21 | 22 | Parameters 23 | ---------- 24 | trial : tf.tensor 25 | Input trial to augment. 26 | 27 | Returns 28 | ------- 29 | trial : array-like 30 | The input trial with the selected augmentation applied. 31 | """ 32 | 33 | samples, ncha, _ = trial.shape 34 | 35 | augmentations = dict() 36 | augmentations["patch_perturbation"] = 0 37 | augmentations["random_shift"] = 0 38 | augmentations["hemisphere_perturbation"] = 0 39 | augmentations["no_augmentation"] = 0 40 | 41 | selectionables = ["patch_perturbation", "random_shift", 42 | "hemisphere_perturbation", "no_augmentation"] 43 | probabilities = None 44 | 45 | if augmentation: 46 | selection = np.random.choice(selectionables, p=probabilities) 47 | augmentations[selection] = 1 48 | 49 | method = np.random.choice((0, 2)) 50 | std = 'self' 51 | # elif data_augmentation == 1: # Random shift 52 | for _ in range(augmentations["random_shift"]): # Random shift 53 | # Select position where to erase that timeframe 54 | position = 0 55 | if position == 0: 56 | samples_shifted = np.random.randint(low=1, high=int( 57 | samples * 0.5 / 3)) 58 | else: 59 | samples_shifted = np.random.randint(low=1, high=int( 60 | samples * 0.1 / 3)) 61 | 62 | if method == 0: 63 | shifted_samples = np.zeros((samples_shifted, ncha, 1)) 64 | else: 65 | if std == 'self': 66 | std_applied = np.std(trial) 67 | else: 68 | std_applied = std 69 | center = 0 70 | shifted_samples = np.random.normal(center, std_applied, 71 | (samples_shifted, ncha, 72 | 1)) 73 | if position == 0: 74 | trial = np.concatenate((shifted_samples, trial), 75 | axis=0)[:samples] 76 | else: 77 | trial = np.concatenate((trial, shifted_samples), 78 | axis=0)[samples_shifted:] 79 | 80 | for _ in range( 81 | augmentations["patch_perturbation"]): # Patch perturbation 82 | channels_affected = np.random.randint(low=1, high=ncha - 1) 83 | pct_max = 1 84 | pct_min = 0.2 85 | pct_erased = np.random.uniform(low=pct_min, high=pct_max) 86 | # Select time to be erased according to pct_erased 87 | # samples_erased = np.min((int(samples*ncha*pct_erased//channels_affected), samples))#np.random.randint(low=1, high=samples//3) 88 | samples_erased = int(samples * pct_erased) 89 | # Select position where to erase that timeframe 90 | if samples_erased != samples: 91 | samples_idx = np.arange(samples_erased) + np.random.randint( 92 | samples - samples_erased) 93 | else: 94 | samples_idx = np.arange(samples_erased) 95 | # Select indexes to erase (always keep at least a channel) 96 | channel_idx = np.random.permutation(np.arange(ncha))[ 97 | :channels_affected] 98 | channel_idx.sort() 99 | for channel in channel_idx: 100 | if method == 0: 101 | trial[samples_idx, channel] = 0 102 | else: 103 | if std == 'self': 104 | std_applied = np.std(trial[:, channel]) \ 105 | * np.random.uniform(low=0.01, high=2) 106 | else: 107 | std_applied = std 108 | center = 0 109 | trial[samples_idx, channel] += \ 110 | np.random.normal(center, std_applied, 111 | trial[samples_idx, channel, 112 | :].shape) 113 | # Standardize the channel again after the change 114 | temp_trial_ch_mean = np.mean(trial[:, channel], axis=0) 115 | temp_trial_ch_std = np.std(trial[:, channel], axis=0) 116 | trial[:, channel] = (trial[:, 117 | channel] - temp_trial_ch_mean) / temp_trial_ch_std 118 | 119 | for _ in range(augmentations["hemisphere_perturbation"]): 120 | # Select side
to mix/change for noise 121 | left_right = np.random.choice((0, 1)) 122 | if method == 0: 123 | if left_right == 1: 124 | channel_idx = np.arange(ncha)[:int((ncha / 2) - 1)] 125 | channel_mix = np.random.permutation(channel_idx.copy()) 126 | else: 127 | channel_idx = np.arange(ncha)[-int((ncha / 2) - 1):] 128 | channel_mix = np.random.permutation(channel_idx.copy()) 129 | temp_trial = trial.copy() 130 | for channel, channel_mixed in zip(channel_idx, channel_mix): 131 | temp_trial[:, channel] = trial[:, channel_mixed] 132 | trial = temp_trial 133 | else: 134 | if left_right == 1: 135 | channel_idx = np.arange(ncha)[:int((ncha / 2) - 1)] 136 | else: 137 | channel_idx = np.arange(ncha)[-int((ncha / 2) - 1):] 138 | for channel in channel_idx: 139 | trial[:, channel] = np.random.normal(0, 1, 140 | trial[:, 141 | channel].shape) 142 | 143 | return trial 144 | 145 | return data_augmentation 146 | 147 | 148 | def trial_iterator(X, y, batch_size=32, shuffle=True, augmentation=True): 149 | """Custom trial iterator to pretrain EEGSym. 150 | 151 | Parameters 152 | ---------- 153 | X : tf.tensor 154 | Input tensor of EEG features. 155 | y : tf.tensor 156 | Input tensor of labels. 157 | batch_size : int 158 | Number of features in each batch. 159 | shuffle : Bool 160 | Whether the features are shuffled at each training epoch. 161 | augmentation : Bool 162 | Whether augmentation is applied to the input. 163 | 164 | Returns 165 | ------- 166 | trial_iterator : tf.keras.preprocessing.image.NumpyArrayIterator 167 | Iterator used to train the model. 168 | """ 169 | 170 | trial_data_generator = tf.keras.preprocessing.image.ImageDataGenerator( 171 | preprocessing_function=preprocessing_function( 172 | augmentation=augmentation)) 173 | 174 | trial_iterator = tf.keras.preprocessing.image.NumpyArrayIterator( 175 | X, y, trial_data_generator, batch_size=batch_size, shuffle=shuffle, 176 | sample_weight=None, 177 | seed=None, data_format=None, save_to_dir=None, save_prefix='', 178 | save_format='png', subset=None, dtype=None 179 | ) 180 | return trial_iterator 181 | -------------------------------------------------------------------------------- /EEGSym_architecture.py: -------------------------------------------------------------------------------- 1 | from tensorflow.keras.layers import Activation, Input, Flatten 2 | from tensorflow.keras.layers import Dropout, BatchNormalization 3 | from tensorflow.keras.layers import Conv3D, Add, AveragePooling3D 4 | from tensorflow.keras.layers import Dense 5 | import tensorflow.keras as keras 6 | import tensorflow as tf 7 | import numpy as np 8 | 9 | def EEGSym(input_time=3000, fs=128, ncha=8, filters_per_branch=8, 10 | scales_time=(500, 250, 125), dropout_rate=0.25, activation='elu', 11 | n_classes=2, learning_rate=0.001, ch_lateral=3, 12 | spatial_resnet_repetitions=1, residual=True, symmetric=True): 13 | 14 | """Keras implementation of EEGSym. 15 | 16 | This model was initially designed for MI decoding of either 17 | the left or right hand. 18 | Hyperparameters and architectural choices are explained in the 19 | original article. 20 | 21 | Parameters 22 | ---------- 23 | input_time : int 24 | EEG epoch time in milliseconds. 25 | fs : int 26 | Sample rate of the EEG. 27 | ncha : 28 | Number of input channels. 29 | filters_per_branch : int 30 | Number of filters in each Inception branch. The number should be 31 | a multiple of 8. 32 | scales_time : list 33 | Temporal scale of the temporal convolutions on the first Inception module.
34 | This parameter determines the kernel sizes of the filters. 35 | dropout_rate : float 36 | Dropout rate 37 | activation : str 38 | Activation 39 | n_classes : int 40 | Number of output classes 41 | learning_rate : float 42 | Learning rate 43 | ch_lateral : int 44 | Number of channels that are attributed to one hemisphere of the head. 45 | spatial_resnet_repetitions: int 46 | Number of repetitions of the operations of spatial analysis at each 47 | step of the spatiotemporal analysis. In the original publication this 48 | value was set to 1 and its variations were not tested. 49 | residual : Bool 50 | Whether the residual operations are present in the EEGSym architecture. 51 | symmetric : Bool 52 | Whether the architecture uses the parameter ch_lateral to create two 53 | symmetric inputs from the electrodes. 54 | 55 | Returns 56 | ------- 57 | model : keras.models.Model 58 | Keras model already compiled and ready to work 59 | """ 60 | 61 | # ======================================================================== # 62 | # ================== GENERAL INCEPTION/RESIDUAL MODULE =================== # 63 | def general_module(input, scales_samples, filters_per_branch, ncha, 64 | activation, dropout_rate, average, 65 | spatial_resnet_repetitions=1, residual=True, 66 | init=False): 67 | """General Inception/residual module. 68 | 69 | This function returns the input with the operations of an 70 | Inception or residual module from the publication applied. 71 | 72 | Parameters 73 | ---------- 74 | input : list 75 | List of input blocks to the module. 76 | scales_samples : list 77 | List of kernel sizes (in samples) of the temporal operations. 78 | filters_per_branch : int 79 | Number of filters in each Inception branch. The number should be 80 | a multiple of 8. 81 | ncha : 82 | Number of input channels. 83 | activation : str 84 | Activation 85 | dropout_rate : float 86 | Dropout rate 87 | spatial_resnet_repetitions: int 88 | Number of repetitions of the operations of spatial analysis at 89 | each step of the spatiotemporal analysis. In the original 90 | publication this value was set to 1 and its variations were not 91 | tested. 92 | residual : Bool 93 | Whether the residual operations are present in the EEGSym architecture. 94 | init : Bool 95 | Whether the module is the first one applied to the input; if so, a 96 | channel-merging operation is applied when the architecture does not 97 | include residual operations.
98 | 99 | Returns 100 | ------- 101 | block_out : list 102 | List of outputs modules 103 | """ 104 | block_units = list() 105 | unit_conv_t = list() 106 | unit_batchconv_t = list() 107 | 108 | for i in range(len(scales_samples)): 109 | unit_conv_t.append(Conv3D(filters=filters_per_branch, 110 | kernel_size=(1, scales_samples[i], 1), 111 | kernel_initializer='he_normal', 112 | padding='same')) 113 | unit_batchconv_t.append(BatchNormalization()) 114 | 115 | if ncha != 1: 116 | unit_dconv = list() 117 | unit_batchdconv = list() 118 | unit_conv_s = list() 119 | unit_batchconv_s = list() 120 | for i in range(spatial_resnet_repetitions): 121 | # 3D Implementation of DepthwiseConv 122 | unit_dconv.append(Conv3D(kernel_size=(1, 1, ncha), 123 | filters=filters_per_branch * len( 124 | scales_samples), 125 | groups=filters_per_branch * len( 126 | scales_samples), 127 | use_bias=False, 128 | padding='valid')) 129 | unit_batchdconv.append(BatchNormalization()) 130 | 131 | unit_conv_s.append(Conv3D(kernel_size=(1, 1, ncha), 132 | filters=filters_per_branch, 133 | # groups=filters_per_branch, 134 | use_bias=False, 135 | strides=(1, 1, 1), 136 | kernel_initializer='he_normal', 137 | padding='valid')) 138 | unit_batchconv_s.append(BatchNormalization()) 139 | 140 | unit_conv_1 = Conv3D(kernel_size=(1, 1, 1), 141 | filters=filters_per_branch, 142 | use_bias=False, 143 | kernel_initializer='he_normal', 144 | padding='valid') 145 | unit_batchconv_1 = BatchNormalization() 146 | 147 | for j in range(len(input)): 148 | block_side_units = list() 149 | for i in range(len(scales_samples)): 150 | unit = input[j] 151 | unit = unit_conv_t[i](unit) 152 | 153 | unit = unit_batchconv_t[i](unit) 154 | unit = Activation(activation)(unit) 155 | unit = Dropout(dropout_rate)(unit) 156 | 157 | block_side_units.append(unit) 158 | block_units.append(block_side_units) 159 | # Concatenation 160 | block_out = list() 161 | for j in range(len(input)): 162 | if len(block_units[j]) != 1: 163 | block_out.append( 164 | keras.layers.concatenate(block_units[j], axis=-1)) 165 | else: 166 | block_out.append(block_units[j][0]) 167 | 168 | if residual: 169 | if len(block_units[j]) != 1: 170 | block_out_temp = input[j] 171 | else: 172 | block_out_temp = input[j] 173 | block_out_temp = unit_conv_1(block_out_temp) 174 | 175 | block_out_temp = unit_batchconv_1(block_out_temp) 176 | block_out_temp = Activation(activation)(block_out_temp) 177 | block_out_temp = Dropout(dropout_rate)(block_out_temp) 178 | 179 | block_out[j] = Add()([block_out[j], block_out_temp]) 180 | 181 | if average != 1: 182 | block_out[j] = AveragePooling3D((1, average, 1))(block_out[j]) 183 | 184 | if ncha != 1: 185 | for i in range(spatial_resnet_repetitions): 186 | block_out_temp = list() 187 | for j in range(len(input)): 188 | if len(scales_samples) != 1: 189 | if residual: 190 | block_out_temp.append(block_out[j]) 191 | 192 | block_out_temp[j] = unit_dconv[i](block_out_temp[j]) 193 | 194 | block_out_temp[j] = unit_batchdconv[i]( 195 | block_out_temp[j]) 196 | block_out_temp[j] = Activation(activation)( 197 | block_out_temp[j]) 198 | block_out_temp[j] = Dropout(dropout_rate)( 199 | block_out_temp[j]) 200 | 201 | block_out[j] = Add()( 202 | [block_out[j], block_out_temp[j]]) 203 | 204 | elif init: 205 | block_out[j] = unit_dconv[i](block_out[j]) 206 | block_out[j] = unit_batchdconv[i](block_out[j]) 207 | block_out[j] = Activation(activation)(block_out[j]) 208 | block_out[j] = Dropout(dropout_rate)(block_out[j]) 209 | else: 210 | if residual: 211 | 
block_out_temp.append(block_out[j]) 212 | 213 | block_out_temp[j] = unit_conv_s[i]( 214 | block_out_temp[j]) 215 | block_out_temp[j] = unit_batchconv_s[i]( 216 | block_out_temp[j]) 217 | block_out_temp[j] = Activation(activation)( 218 | block_out_temp[j]) 219 | block_out_temp[j] = Dropout(dropout_rate)( 220 | block_out_temp[j]) 221 | 222 | block_out[j] = Add()( 223 | [block_out[j], block_out_temp[j]]) 224 | return block_out 225 | # ============================= CALCULATIONS ============================= # 226 | input_samples = int(input_time * fs / 1000) 227 | scales_samples = [int(s * fs / 1000) for s in scales_time] 228 | 229 | # ================================ INPUT ================================= # 230 | input_layer = Input((input_samples, ncha, 1)) 231 | input = tf.expand_dims(input_layer, axis=1) 232 | if symmetric: 233 | superposition = False 234 | if ch_lateral < ncha // 2: 235 | superposition = True 236 | ncha = ncha - ch_lateral 237 | 238 | left_idx = list(range(ch_lateral)) 239 | ch_left = tf.gather(input, indices=left_idx, axis=-2) 240 | right_idx = list(np.array(left_idx) + int(ncha)) 241 | ch_right = tf.gather(input, indices=right_idx, axis=-2) 242 | 243 | if superposition: 244 | # ch_central = crop(3, self.ch_lateral, -self.ch_lateral)(input) 245 | central_idx = list( 246 | np.array(range(ncha - ch_lateral)) + ch_lateral) 247 | ch_central = tf.gather(input, indices=central_idx, axis=-2) 248 | 249 | left_init = keras.layers.concatenate((ch_left, ch_central), 250 | axis=-2) 251 | right_init = keras.layers.concatenate((ch_right, ch_central), 252 | axis=-2) 253 | else: 254 | left_init = ch_left 255 | right_init = ch_right 256 | 257 | input = keras.layers.concatenate((left_init, right_init), axis=1) 258 | division = 2 259 | else: 260 | division = 1 261 | # ======================== TEMPOSPATIAL ANALYSIS ========================= # 262 | # ============================ Inception (x2) ============================ # 263 | b1_out = general_module([input], 264 | scales_samples=scales_samples, 265 | filters_per_branch=filters_per_branch, 266 | ncha=ncha, 267 | activation=activation, 268 | dropout_rate=dropout_rate, average=2, 269 | spatial_resnet_repetitions=spatial_resnet_repetitions, 270 | residual=residual, init=True) 271 | 272 | b2_out = general_module(b1_out, scales_samples=[int(x / 4) for x in 273 | scales_samples], 274 | filters_per_branch=filters_per_branch, 275 | ncha=ncha, 276 | activation=activation, 277 | dropout_rate=dropout_rate, average=2, 278 | spatial_resnet_repetitions=spatial_resnet_repetitions, 279 | residual=residual) 280 | # ============================== Residual (x3) =========================== # 281 | b3_u1 = general_module(b2_out, scales_samples=[16], 282 | filters_per_branch=int( 283 | filters_per_branch * len( 284 | scales_samples) / 2), 285 | ncha=ncha, 286 | activation=activation, 287 | dropout_rate=dropout_rate, average=2, 288 | spatial_resnet_repetitions=spatial_resnet_repetitions, 289 | residual=residual) 290 | b3_u1 = general_module(b3_u1, 291 | scales_samples=[8], 292 | filters_per_branch=int( 293 | filters_per_branch * len( 294 | scales_samples) / 2), 295 | 296 | ncha=ncha, 297 | activation=activation, 298 | dropout_rate=dropout_rate, average=2, 299 | spatial_resnet_repetitions=spatial_resnet_repetitions, 300 | residual=residual) 301 | b3_u2 = general_module(b3_u1, scales_samples=[4], 302 | filters_per_branch=int( 303 | filters_per_branch * len( 304 | scales_samples) / 4), 305 | ncha=ncha, 306 | activation=activation, 307 | 
dropout_rate=dropout_rate, average=2, 308 | spatial_resnet_repetitions=spatial_resnet_repetitions, 309 | residual=residual) 310 | # ========================== TEMPORAL REDUCTION ========================== # 311 | t_red = b3_u2[0] 312 | for _ in range(1): 313 | t_red_temp = t_red 314 | t_red_temp = Conv3D(kernel_size=(1, 4, 1), 315 | filters=int(filters_per_branch * len( 316 | scales_samples) / 4), 317 | use_bias=False, 318 | strides=(1, 1, 1), 319 | kernel_initializer='he_normal', 320 | padding='same')(t_red_temp) 321 | t_red_temp = BatchNormalization()(t_red_temp) 322 | t_red_temp = Activation(activation)(t_red_temp) 323 | t_red_temp = Dropout(dropout_rate)(t_red_temp) 324 | 325 | if residual: 326 | t_red = Add()([t_red, t_red_temp]) 327 | else: 328 | t_red = t_red_temp 329 | 330 | t_red = AveragePooling3D((1, 2, 1))(t_red) 331 | 332 | # =========================== CHANNEL MERGING ============================ # 333 | ch_merg = t_red 334 | if residual: 335 | for _ in range(2): 336 | ch_merg_temp = ch_merg 337 | ch_merg_temp = Conv3D(kernel_size=(division, 1, ncha), 338 | filters=int(filters_per_branch * len( 339 | scales_samples) / 4), 340 | use_bias=False, 341 | strides=(1, 1, 1), 342 | kernel_initializer='he_normal', 343 | padding='valid')(ch_merg_temp) 344 | ch_merg_temp = BatchNormalization()(ch_merg_temp) 345 | ch_merg_temp = Activation(activation)(ch_merg_temp) 346 | ch_merg_temp = Dropout(dropout_rate)(ch_merg_temp) 347 | 348 | ch_merg = Add()([ch_merg, ch_merg_temp]) 349 | 350 | ch_merg = Conv3D(kernel_size=(division, 1, ncha), 351 | filters=int( 352 | filters_per_branch * len(scales_samples) / 4), 353 | groups=int( 354 | filters_per_branch * len(scales_samples) / 8), 355 | use_bias=False, 356 | padding='valid')(ch_merg) 357 | ch_merg = BatchNormalization()(ch_merg) 358 | ch_merg = Activation(activation)(ch_merg) 359 | ch_merg = Dropout(dropout_rate)(ch_merg) 360 | else: 361 | if symmetric: 362 | ch_merg = Conv3D(kernel_size=(division, 1, 1), 363 | filters=int( 364 | filters_per_branch * len( 365 | scales_samples) / 4), 366 | groups=int( 367 | filters_per_branch * len( 368 | scales_samples) / 8), 369 | use_bias=False, 370 | padding='valid')(ch_merg) 371 | ch_merg = BatchNormalization()(ch_merg) 372 | ch_merg = Activation(activation)(ch_merg) 373 | ch_merg = Dropout(dropout_rate)(ch_merg) 374 | # ========================== TEMPORAL MERGING ============================ # 375 | t_merg = ch_merg 376 | for _ in range(1): 377 | if residual: 378 | t_merg_temp = t_merg 379 | t_merg_temp = Conv3D(kernel_size=(1, input_samples // 64, 1), 380 | filters=int(filters_per_branch * len( 381 | scales_samples) / 4), 382 | use_bias=False, 383 | strides=(1, 1, 1), 384 | kernel_initializer='he_normal', 385 | padding='valid')(t_merg_temp) 386 | t_merg_temp = BatchNormalization()(t_merg_temp) 387 | t_merg_temp = Activation(activation)(t_merg_temp) 388 | t_merg_temp = Dropout(dropout_rate)(t_merg_temp) 389 | 390 | t_merg = Add()([t_merg, t_merg_temp]) 391 | else: 392 | t_merg_temp = t_merg 393 | t_merg_temp = Conv3D(kernel_size=(1, input_samples // 64, 1), 394 | filters=int(filters_per_branch * len( 395 | scales_samples) / 4), 396 | use_bias=False, 397 | strides=(1, 1, 1), 398 | kernel_initializer='he_normal', 399 | padding='same')(t_merg_temp) 400 | t_merg_temp = BatchNormalization()(t_merg_temp) 401 | t_merg_temp = Activation(activation)(t_merg_temp) 402 | t_merg_temp = Dropout(dropout_rate)(t_merg_temp) 403 | t_merg = t_merg_temp 404 | 405 | t_merg = Conv3D(kernel_size=(1, input_samples // 64, 1), 406 
| filters=int( 407 | filters_per_branch * len(scales_samples) / 4) * 2, 408 | groups=int( 409 | filters_per_branch * len(scales_samples) / 4), 410 | use_bias=False, 411 | padding='valid')(t_merg) 412 | t_merg = BatchNormalization()(t_merg) 413 | t_merg = Activation(activation)(t_merg) 414 | t_merg = Dropout(dropout_rate)(t_merg) 415 | # =============================== OUTPUT ================================= # 416 | output = t_merg 417 | for _ in range(4): 418 | output_temp = output 419 | output_temp = Conv3D(kernel_size=(1, 1, 1), 420 | filters=int( 421 | filters_per_branch * len( 422 | scales_samples) / 2), 423 | use_bias=False, 424 | strides=(1, 1, 1), 425 | kernel_initializer='he_normal', 426 | padding='valid')(output_temp) 427 | output_temp = BatchNormalization()(output_temp) 428 | output_temp = Activation(activation)(output_temp) 429 | output_temp = Dropout(dropout_rate)(output_temp) 430 | if residual: 431 | output = Add()([output, output_temp]) 432 | else: 433 | output = output_temp 434 | output = Flatten()(output) 435 | output_layer = Dense(n_classes, activation='softmax')(output) 436 | # Create and compile model 437 | model = keras.models.Model(inputs=input_layer, outputs=output_layer) 438 | optimizer = keras.optimizers.Adam(learning_rate=learning_rate, 439 | beta_1=0.9, beta_2=0.999, 440 | amsgrad=False) 441 | model.compile(loss='categorical_crossentropy', optimizer=optimizer, 442 | metrics=['accuracy']) 443 | return model -------------------------------------------------------------------------------- /EEGSym_main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path 3 | from scipy.io import loadmat, savemat 4 | import numpy as np 5 | import logging 6 | import sys 7 | from signal_target import SignalAndTarget,convert_numbers_to_one_hot 8 | from splitters import split_into_two_sets 9 | 10 | import tensorflow as tf 11 | import keras.backend as ka 12 | import keras as k 13 | import sys 14 | import h5py 15 | import random 16 | 17 | from EEGSym_architecture import EEGSym 18 | from EEGSym_DataAugmentation import trial_iterator 19 | from tensorflow.keras.callbacks import EarlyStopping as kerasEarlyStopping 20 | time_start = time.time() 21 | 22 | # TensorFlow configuration for GPU usage 23 | config = tf.compat.v1.ConfigProto() 24 | config.gpu_options.allow_growth = True 25 | session = tf.compat.v1.Session(config=config) 26 | 27 | # Fix random seed 28 | seed=20190706 29 | random.seed(seed) 30 | os.environ['PYTHONHASHSEED'] = str(seed) 31 | np.random.seed(seed) 32 | tf.random.set_seed(seed) 33 | 34 | log = logging.getLogger(__name__) 35 | logging.basicConfig(format='%(asctime)s %(levelname)s : %(message)s', 36 | level=logging.DEBUG, stream=sys.stdout) 37 | 38 | 39 | # Data folder where the datasets are located 40 | data_folder = 'C:/Users/Administrator/Desktop/Code-SBLEST-main' # The folder you download from https://github.com/EEGdecoding/Code-SBLEST 41 | 42 | # Fraction of data to be used as validation set 43 | valid_set_fraction= 0.2 44 | 45 | # Initialize train and test datasets 46 | X = np.zeros([1]) # np.ndarray([]) 47 | y= np.zeros([1]) # np.ndarray([]) 48 | train_set = SignalAndTarget(X, y) 49 | test_set = SignalAndTarget(X, y) 50 | 51 | bs_EEGSym = 16 # 52 | os.environ['CUDA_VISIBLE_DEVICES'] = " 0" 53 | pretrained = False # Parameter to load pre-trained weight values 54 | ncha=60 55 | hyperparameters = dict() 56 | hyperparameters["ncha"] = ncha 57 | hyperparameters["dropout_rate"] = 0.5 58 | 
hyperparameters["activation"] = 'elu' 59 | hyperparameters["n_classes"] = 2 60 | hyperparameters["learning_rate"] = 0.0001 # 1e-3 for pretraining and 1e-4 61 | # for fine-tuning 62 | hyperparameters["fs"] = 250 63 | hyperparameters["input_time"] = 3*1000 64 | hyperparameters["scales_time"] = np.tile([125, 250, 500], 1) 65 | hyperparameters['filters_per_branch'] = 8 66 | hyperparameters['ch_lateral'] = int((ncha / 2) - 1) # 67 | hyperparameters['residual'] = True 68 | hyperparameters['symmetric'] = True 69 | 70 | # Load train and test datasets from .mat files 71 | train_filename = 'Dataset2_L1_FootTongue_train.mat' 72 | test_filename = 'Dataset2_L1_FootTongue_test.mat' 73 | train_filepath = os.path.join(data_folder, train_filename) 74 | test_filepath = os.path.join(data_folder, test_filename) 75 | train = loadmat(train_filepath) 76 | test = loadmat(test_filepath) 77 | 78 | # Prepare train and test datasets 79 | 80 | label_1d_train = train['Y_train'].astype(np.int32) 81 | label_1d_test = test['Y_test'].astype(np.int32) 82 | train_set.y =convert_numbers_to_one_hot(label_1d_train) 83 | test_set.y =convert_numbers_to_one_hot(label_1d_test) 84 | train_set.X = np.transpose(train['X_train'], (2, 1, 0)).astype(np.float32) 85 | test_set.X = np.transpose(test['X_test'], (2, 1, 0)).astype(np.float32) 86 | 87 | 88 | # Split train set into train and validation set 89 | train_set, valid_set = split_into_two_sets( 90 | train_set, first_set_fraction = 1 - valid_set_fraction 91 | ) 92 | 93 | # Prepare data for model training and evaluation 94 | X_train = np.expand_dims(train_set.X, axis=3) 95 | X_validate = np.expand_dims(valid_set.X, axis=3) 96 | X_test = np.expand_dims(test_set.X, axis=3) 97 | Y_train = train_set.y 98 | Y_valid = valid_set.y 99 | Y_test = test_set.y 100 | 101 | # Get number of channels and samples from input data 102 | chans = X_train.shape[1] 103 | samples = X_train.shape[2] 104 | print('X_train shape:', X_train.shape) 105 | print(X_train.shape[0], 'train samples') 106 | print(X_test.shape[0], 'test samples') 107 | 108 | # Create and compile the EEGNet model 109 | model = EEGSym(**hyperparameters) 110 | model.summary() 111 | 112 | # Select if the data augmentation is performed 113 | augmentation = True # Parameter to activate or deactivate DA 114 | 115 | # Load pre-trained weight values 116 | if pretrained: 117 | model.load_weights('EEGSym_pretrained_weights_{}_electrode.h5'.format(ncha)) 118 | 119 | # Early stopping 120 | early_stopping = [(kerasEarlyStopping(mode='auto', monitor='val_loss', 121 | min_delta=0.0001, patience=50, verbose=1, 122 | restore_best_weights=True))] 123 | # %% OPTIONAL: Train the model 124 | if pretrained: 125 | for layer in model.layers[:-1]: 126 | layer.trainable = False 127 | 128 | fittedModel = model.fit(trial_iterator(X_train, Y_train, 129 | batch_size=bs_EEGSym, shuffle=True, 130 | augmentation=augmentation), 131 | steps_per_epoch=X_train.shape[0] / bs_EEGSym, 132 | epochs=500, validation_data=(X_validate, Y_valid), 133 | callbacks=[early_stopping]) 134 | 135 | # %% Obtain the accuracies of the trained model 136 | probs_test = model.predict(X_test) 137 | pred_test = probs_test.argmax(axis=-1) 138 | accuracy = (pred_test == Y_test.argmax(axis=-1)) 139 | 140 | 141 | # Predict the labels for the test set using the trained model and print the calculated classification accuracy 142 | probs = model.predict(X_test) 143 | preds = probs.argmax(axis=-1) 144 | acc = np.mean(preds == Y_test.argmax(axis=-1)) 145 | print("Classification accuracy: %f " % (acc)) 146 | 147 
148 | 
149 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Code-SBLEST
2 | 
3 | This repo contains Matlab and Python code for the SBLEST (Sparse Bayesian Learning for End-to-End Spatio-Temporal-Filtering-Based Single-Trial EEG Classification) algorithm, as well as implementations of the Convolutional Neural Networks (CNNs) used in the paper. Detailed information about the algorithms and CNN implementations can be found in [W. Wang, F. Qi, D. Wipf, C. Cai, T. Yu, Z. Gu, Y. Li, Z. Yu, W. Wu. Sparse Bayesian Learning for End-to-End EEG Decoding, IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 45, no. 12, pp. 15632-15649, 2023](https://doi.org/10.1109/tpami.2023.3299568).
4 | 
5 | ## Data
6 | The data used in this repository is from Subject L1 (foot vs. tongue) in Dataset II, as described in the referenced paper.
7 | 
8 | ### File Descriptions
9 | 
10 | * [Dataset2_L1_FootTongue_train.mat](https://github.com/EEGdecoding/Code-SBLEST/blob/main/Dataset2_L1_FootTongue_train.mat) — This file contains the training data used in this repository.
11 | * [Dataset2_L1_FootTongue_test.mat](https://github.com/EEGdecoding/Code-SBLEST/blob/main/Dataset2_L1_FootTongue_test.mat) — This file contains the test data used in this repository.
12 | 
13 | ## Matlab code for SBLEST
14 | 
15 | The MATLAB scripts provided in this section implement the SBLEST algorithm and have been tested with MATLAB R2018b.
16 | 
17 | ### File Descriptions
18 | 
19 | * [SBLEST.m](https://github.com/EEGdecoding/Code-SBLEST/blob/main/SBLEST.m) — Matlab code for the SBLEST algorithm.
20 | 
21 | * [SBLEST_main.m](https://github.com/EEGdecoding/Code-SBLEST/blob/main/SBLEST_main.m) — An example script for classifying single-trial EEG data using SBLEST in Matlab.
22 | 
23 | ### Usage
24 | 
25 | 1. To run the code, download and extract the files into a folder of your choice, and navigate to this folder within MATLAB.
26 | 
27 | 2. At the MATLAB command line, type
28 | ```
29 | SBLEST_main
30 | ```
31 | 
32 | 
33 | ## Python code for SBLEST
34 | 
35 | The Python scripts for SBLEST are implemented in PyTorch and have been fully tested with Python 3.9.
36 | 
37 | ### File Descriptions
38 | 
39 | * [SBLEST_model.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/SBLEST_model.py) — Python code for the SBLEST algorithm.
40 | 
41 | * [SBLEST_main.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/SBLEST_main.py) — An example script for classifying single-trial EEG data using SBLEST in Python.
42 | 
43 | 
44 | 
45 | 
46 | ## Python Implementations of sCNN, dCNN, EEGNet, EEG-Inception, and EEGSym
47 | 
48 | sCNN and dCNN are implemented in PyTorch using the braindecode package, which is provided at https://github.com/robintibor/braindecode.
49 | 
50 | EEGNet is implemented in TensorFlow using the Keras API, with the model provided at https://github.com/vlawhern/arl-eegmodels.
51 | 
52 | EEG-Inception and EEGSym are also implemented in TensorFlow, with the models provided at https://github.com/esantamariavazquez/EEGInception and https://github.com/Serpeve/EEGSym, respectively.
53 | 
54 | ### File Descriptions
55 | 
56 | * [sCNN_main.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/sCNN_main.py) — An example script for classifying single-trial EEG data using sCNN.
57 | 
58 | * [dCNN_main.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/dCNN_main.py) — An example script for classifying single-trial EEG data using dCNN.
59 | 
60 | * [EEGNet_main.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/EEGNet_main.py) — An example script for classifying single-trial EEG data using EEGNet.
61 | * [EEGModels.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/EEGModels.py) — A model file used in the EEGNet implementation.
62 | 
63 | * [EEGInception_main.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/EEGInception_main.py) — An example script for classifying single-trial EEG data using EEG-Inception.
64 | * [EEGInception.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/EEGInception.py) — A model file used in the EEG-Inception implementation.
65 | 
66 | * [EEGSym_main.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/EEGSym_main.py) — An example script for classifying single-trial EEG data using EEGSym.
67 | * [EEGSym_architecture.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/EEGSym_architecture.py) — A model file used in the EEGSym implementation.
68 | * [EEGSym_DataAugmentation.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/EEGSym_DataAugmentation.py) — A Python file for data augmentation used in the EEGSym implementation.
69 | 
70 | * [signal_target.py](https://github.com/EEGdecoding/Code-SBLEST/blob/main/signal_target.py) — Code for preprocessing the signals and targets, used in all the CNN implementations.
71 | 
72 | 
73 | 
74 | 
--------------------------------------------------------------------------------
/SBLEST.m:
--------------------------------------------------------------------------------
1 | function [W, alpha, V, Wh] = SBLEST(X, Y, K, tau)
2 | % SBLEST : Sparse Bayesian Learning for End-to-end Spatio-Temporal-filtering-based single-trial EEG classification
3 | %
4 | % Syntax:
5 | %   [W, alpha, V, Wh] = SBLEST(X, Y, K, tau)
6 | %
7 | % --- Inputs ---
8 | % Y         : True label vector. [M, 1].
9 | % X         : M trials of C (channel) x T (time) EEG signals. [C, T, M].
10 | % K         : Order of FIR filter.
11 | % tau       : Time delay parameter.
12 | %
13 | % --- Outputs ---
14 | % W         : Estimated low-rank weight matrix. [K*C, K*C].
15 | % alpha     : Classifier weights. [L, 1].
16 | % V         : Spatio-temporal filter matrix. [K*C, L].
17 | %             Each column of V represents a spatio-temporal filter.
18 | % Wh        : Whitening matrix for enhancing covariance matrices (required for prediction on the test set). [K*C, K*C].
19 | 
20 | % Reference:
21 | % "W. Wang, F. Qi, D. Wipf, C. Cai, T. Yu, Z. Gu, Y. Li, Z. Yu, W. Wu. Sparse Bayesian Learning for End-to-End EEG Decoding,
22 | % IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 45, no. 12, pp. 15632-15649, 2023."
23 | %
24 | % Wenlong Wang, Feifei Qi, Wei Wu, 2023.
25 | % Email: 201710102248@mail.scut.edu.cn
26 | 
27 | % ************************************************************************
28 | % Compute enhanced covariance matrices and whitening matrix
29 | [R_train, Wh] = Enhanced_cov_train(X, K, tau);
30 | 
31 | %% Check properties of R
32 | [M, D_R] = size(R_train); % M: # of samples; D_R: dimension of vec(R_m)
33 | KC = round(sqrt(D_R));
34 | Loss_old = 1e12;
35 | threshold = 0.05; % threshold for selecting spatio-temporal filters
36 | if (D_R ~= KC^2)
37 |     disp('ERROR: Columns of A do not align with square matrix');
38 |     return;
39 | end
40 | 
41 | % Check if R is symmetric
42 | for c = 1:M
43 |     row_cov = reshape(R_train(c,:), KC, KC);
44 |     if ( norm(row_cov - row_cov','fro') > 1e-4 )
45 |         disp('ERROR: Measurement row does not form symmetric matrix');
46 |         return
47 |     end
48 | end
49 | 
50 | %% Initializations
51 | U = zeros(KC, KC); % estimated low-rank matrix W initialized to zeros
52 | Psi = eye(KC);     % covariance matrix of the Gaussian prior, initialized to the identity matrix
53 | lambda = 1;        % variance of the additive noise, set to 1
54 | 
55 | %% Optimization loop
56 | for i = 1:5000
57 |     %% Update U
58 |     RPR = zeros(M, M);
59 |     B = zeros(KC^2, M);
60 |     for c = 1:KC
61 |         start = (c-1)*KC + 1; stop = start + KC - 1;
62 |         Temp = Psi*R_train(:, start:stop)';
63 |         B(start:stop,:) = Temp;
64 |         RPR = RPR + R_train(:, start:stop)*Temp;
65 |     end
66 | 
67 |     Sigma_y = RPR + lambda*eye(M);
68 |     uc = B*(Sigma_y\Y); % maximum a posteriori estimate of uc
69 |     Uc = reshape(uc, KC, KC);
70 |     U = (Uc + Uc')/2;
71 |     u = U(:);
72 |     %% Update Phi (dual variable of Psi)
73 |     Phi = cell(1, KC);
74 |     SR = Sigma_y\R_train;
75 |     for c = 1:KC
76 |         start = (c-1)*KC + 1; stop = start + KC - 1;
77 |         Phi{1,c} = Psi - Psi * ( R_train(:,start:stop)' * SR(:,start:stop) ) * Psi;
78 |     end
79 | 
80 |     %% Update Psi
81 |     PHI = 0;
82 |     UU = 0;
83 |     for c = 1:KC
84 |         PHI = PHI + Phi{1, c};
85 |         UU = UU + U(:,c) * U(:,c)';
86 |     end
87 |     Psi = ((UU + UU')/2 + (PHI + PHI')/2 )/KC; % make sure Psi is symmetric
88 | 
89 |     %% Update theta (dual variable of lambda) and lambda
90 |     theta = 0;
91 |     for c = 1:KC
92 |         start = (c-1)*KC + 1; stop = start + KC - 1;
93 |         theta = theta + trace(Phi{1,c}*R_train(:,start:stop)'*R_train(:,start:stop));
94 |     end
95 |     lambda = (sum((Y-R_train*u).^2) + theta)/M;
96 | 
97 |     %% Convergence check
98 |     logdet_Sigma_y = calculate_log_det(Sigma_y);
99 |     Loss = Y'*(Sigma_y\Y) + logdet_Sigma_y; % backslash instead of an explicit inverse
100 |     delta_loss = abs(Loss_old-Loss)/abs(Loss_old);
101 |     if (delta_loss < 2e-4)
102 |         disp('EXIT: Change in loss below threshold');
103 |         break;
104 |     end
105 |     Loss_old = Loss;
106 |     if (~rem(i,100))
107 |         disp(['Iterations: ', num2str(i), ' lambda: ', num2str(lambda),' Loss: ', num2str(Loss), ' Delta_Loss: ', num2str(delta_loss)]);
108 |     end
109 | end
110 | %% Eigendecomposition of W
111 | W = U;
112 | [~, D, V_all] = eig(W); % each column of V_all represents a spatio-temporal filter
113 | alpha_all = diag(D);    % classifier weights
114 | %% Determine spatio-temporal filters V and classifier weights alpha
115 | d = abs(diag(D)); d_max = max(d);
116 | w_norm = d/d_max; % normalize eigenvalues of W by the maximum eigenvalue
117 | index = find(w_norm > threshold); % indices of selected V according to a pre-defined threshold, e.g., 0.05
118 | V = V_all(:,index); alpha = alpha_all(index);
119 | end
120 | 
121 | %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
122 | function [R_train, Wh] = Enhanced_cov_train(X, K, tau)
123 | % Compute enhanced covariance matrices
124 | %
125 | % Inputs :
126 | %     X   : M trials of C (channel) x T (time) EEG signals. [C, T, M].
127 | %     K   : Order of FIR filter
128 | %     tau : Time delay parameter
129 | 
130 | 
131 | % Outputs :
132 | %     R_train : Enhanced covariance matrices. [M, (K*C)^2].
133 | %     Wh      : Whitening matrix. [K*C, K*C].
134 | 
135 | % ************************************************************************
136 | 
137 | % Initialization
138 | [C, T, M] = size(X);
139 | KC = K*C; % [KC, KC]: dimension of augmented covariance matrix
140 | Cov = cell(1, M);
141 | Sig_Cov = zeros(KC, KC);
142 | for m = 1:M
143 |     X_m = X(:,:,m);
144 |     X_m_hat = [];
145 | 
146 |     % Generate augmented EEG data
147 |     for k = 1 : K
148 |         n_delay = (k-1)*tau;
149 |         if n_delay == 0
150 |             X_order_k = X_m;
151 |         else
152 |             X_order_k(:,1:n_delay) = 0;
153 |             X_order_k(:,n_delay+1:T) = X_m(:,1:T-n_delay);
154 |         end
155 |         X_m_hat = cat(1, X_m_hat, X_order_k);
156 |     end
157 | 
158 |     % Compute covariance matrices with trace normalization
159 |     Cov{1,m} = X_m_hat*X_m_hat';
160 |     Cov{1,m} = Cov{1,m}/trace(Cov{1,m});
161 |     Sig_Cov = Sig_Cov + Cov{1,m};
162 | end
163 | 
164 | % Compute whitening matrix
165 | Wh = Sig_Cov/M;
166 | 
167 | % Whitening, logarithm transform, and vectorization
168 | Cov_whiten = zeros(M, KC, KC);
169 | for m = 1:M
170 |     temp_cov = Wh^(-1/2)*Cov{1,m}*Wh^(-1/2); % whitening
171 |     Cov_whiten(m,:,:) = (temp_cov + temp_cov')/2;
172 |     R_m = logm(squeeze(Cov_whiten(m,:,:))); % logarithm transform
173 |     R_m = R_m(:); % column-wise vectorization
174 |     R_train(m,:) = R_m';
175 | end
176 | end
177 | 
178 | function log_det_X = calculate_log_det(X)
179 | % This function calculates the log determinant of a matrix X
180 | % by normalizing its diagonal elements to avoid infinite values.
181 | n = size(X,1); % Get the size of matrix X
182 | c = 10^floor(log10(X(1,1))); % Extract the scaling factor c as a power of 10
183 | A = X / c; % Normalize the matrix by the scaling factor
184 | L = chol(A, 'lower'); % Perform Cholesky decomposition on A
185 | log_det_A = 2 * sum(log(diag(L))); % Compute the log determinant of the normalized matrix via L
186 | % log_det_A = log(det(A));
187 | log_det_X = n*log(c) + log_det_A; % Combine the results to get the log determinant of the original matrix
188 | end
189 | 
190 | 
--------------------------------------------------------------------------------
/SBLEST_main.m:
--------------------------------------------------------------------------------
1 | %%% An example script for classifying single-trial EEG data using SBLEST
2 | clc; clear; close all;
3 | 
4 | %% Load data: subject L1 from Dataset II ("foot" vs "tongue")
5 | load('Dataset2_L1_FootTongue_train.mat');
6 | load('Dataset2_L1_FootTongue_test.mat');
7 | tau_selected = 1; % determined by 10-fold cross-validation on the training set
8 | 
9 | %% Initialization
10 | tau = tau_selected;
11 | if tau == 0
12 |     K = 1;
13 | else
14 |     K = 2;
15 | end
16 | %% Training stage: run SBLEST on the training set
17 | disp(['FIR filter order: ', num2str(K), ' Time delay: ', num2str(tau)]);
18 | disp('Running SBLEST : update W, Psi and lambda');
19 | [W, alpha, V, Wh] = SBLEST(X_train, Y_train, K, tau);
20 | 
21 | %% Test stage: predict labels in the test set
22 | R_test = Enhanced_cov_test(X_test, K, tau, Wh);
23 | predict_Y = R_test*W(:);
24 | accuracy = compute_acc(predict_Y, Y_test);
25 | disp(['Test Accuracy: ', num2str(accuracy)]);
26 | 
27 | 
28 | %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
29 | function R_test = Enhanced_cov_test(X, K, tau, Wh)
30 | % Compute enhanced covariance matrices of the test set
31 | %
32 | % Inputs :
33 | %     X   : EEG signals of test set
34 | %     K   : Order of FIR filter
35 | %     tau : Time delay parameter
36 | %     Wh  : Whitening matrix for enhancing covariance matrices
37 | 
38 | % Outputs :
39 | %     R_test : Enhanced covariance matrices of test set
40 | % ************************************************************************
41 | [C, T, M] = size(X);
42 | KC = K*C; % [KC, KC]: dimension of augmented covariance matrix
43 | Cov = cell(1, M);
44 | Sig_Cov = zeros(KC, KC);
45 | for m = 1:M
46 |     X_m = X(:,:,m);
47 |     X_m_hat = [];
48 |     % Generate augmented EEG data
49 |     for k = 1:K
50 |         n_delay = (k-1)*tau;
51 |         if n_delay == 0
52 |             X_order_k = X_m;
53 |         else
54 |             X_order_k(:,1:n_delay) = 0;
55 |             X_order_k(:,n_delay+1:T) = X_m(:,1:T-n_delay);
56 |         end
57 |         X_m_hat = cat(1,X_m_hat,X_order_k);
58 |     end
59 |     % Compute covariance and trace normalization
60 |     Cov{1,m} = X_m_hat*X_m_hat';
61 |     Cov{1,m} = Cov{1,m}/trace(Cov{1,m});
62 |     Sig_Cov = Sig_Cov + Cov{1,m};
63 | end
64 | 
65 | % Whitening, logarithm transform, and vectorization
66 | Cov_whiten = zeros(M, KC, KC);
67 | for m = 1:M
68 |     temp_cov = Wh^(-1/2)*Cov{1,m}*Wh^(-1/2);
69 |     Cov_whiten(m,:,:) = (temp_cov + temp_cov')/2;
70 |     R_m = logm(squeeze(Cov_whiten(m,:,:))); % logarithm transform
71 |     R_m = R_m(:); % column-wise vectorization
72 |     R_test(m,:) = R_m';
73 | end
74 | end
75 | 
76 | %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
77 | function accuracy = compute_acc(predict_Y, Y_test)
78 | % Compute classification accuracy for test set
79 | Y_predict = zeros(length(predict_Y),1);
80 | for i = 1:length(predict_Y)
81 |     if (predict_Y(i) > 0)
82 |         Y_predict(i) = 1;
83 |     else
84 |         Y_predict(i) = -1;
85 |     end
86 | end
87 | % Count misclassified trials
88 | error_num = 0;
89 | total_num = length(predict_Y);
90 | for i = 1:total_num
91 |     if (Y_predict(i) ~= Y_test(i))
92 |         error_num = error_num + 1;
93 |     end
94 | end
95 | accuracy = (total_num-error_num)/total_num;
96 | end
97 | 
98 | 
--------------------------------------------------------------------------------
/SBLEST_main.py:
--------------------------------------------------------------------------------
1 | # An example script for classifying single-trial EEG data using SBLEST
2 | from SBLEST_model import SBLEST, computer_acc, Enhanced_cov
3 | import torch
4 | from scipy.io import loadmat
5 | from torch import DoubleTensor
6 | 
7 | # Initialization
8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9 | tau = 1
10 | K = 2
11 | Epoch = 5000
12 | 
13 | if __name__ == '__main__':
14 |     # Load data: subject L1 from Dataset II ("foot" vs "tongue")
15 | 
16 |     file_train = './Dataset2_L1_FootTongue_train.mat'
17 |     file_test = './Dataset2_L1_FootTongue_test.mat'
18 | 
19 |     data_train = loadmat(file_train, mat_dtype=True)
20 |     data_test = loadmat(file_test, mat_dtype=True)
21 | 
22 |     X_train = DoubleTensor(data_train['X_train']).to(device)
23 |     Y_train = DoubleTensor(data_train['Y_train']).to(device)
24 |     X_test = DoubleTensor(data_test['X_test']).to(device)
25 |     Y_test = DoubleTensor(data_test['Y_test']).to(device)
26 | 
27 |     # Training stage: run SBLEST on the training set
28 |     print('\n', 'FIR filter order: ', str(K), ' Time delay: ', str(tau))
29 |     W, alpha, V, Wh = SBLEST(X_train, Y_train, K, tau, Epoch)
30 | 
31 |     # Test stage: predict labels in the test set
32 |     R_test, _ = Enhanced_cov(X_test, K, tau, Wh, train=0)
33 |     vec_W = W.T.flatten()  # vec operation (Torch)
34 |     predict_Y = R_test @ vec_W
35 |     accuracy = computer_acc(predict_Y, Y_test)
36 |     print('Test Accuracy: ', str(accuracy))
37 | 
38 | 
39 | 
40 | 
--------------------------------------------------------------------------------
/SBLEST_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import warnings
3 | import numpy as np
4 | from torch import reshape, norm, zeros, eye, float64, mm, inverse, log, det
5 | from torch import linalg, diag, DoubleTensor
6 | 
7 | 
8 | 
9 | warnings.filterwarnings('ignore')
10 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11 | 
12 | 
13 | def SBLEST(X, Y, K, tau, Epoch=5000, epoch_print=100):
14 |     """
15 |     SBLEST : Sparse Bayesian Learning for End-to-end Spatio-Temporal-filtering-based single-trial EEG classification
16 | 
17 |     --- Parameters ---
18 |     Y : True label vector. [M, 1].
19 |     X : M trials of C (channel) x T (time) EEG signals. [C, T, M].
20 |     K : Order of FIR filter.
21 |     tau : Time delay parameter.
22 | 
23 |     --- Returns ---
24 |     W : Estimated low-rank weight matrix. [K*C, K*C].
25 |     alpha : Classifier weights. [L, 1].
26 |     V : Spatio-temporal filter matrix. [K*C, L].
27 |         Each column of V represents a spatio-temporal filter.
28 |     Wh : Whitening matrix for enhancing covariance matrices (required for prediction on the test set). [K*C, K*C].
29 | 
30 |     Reference:
31 |     "W. Wang, F. Qi, D. Wipf, C. Cai, T. Yu, Z. Gu, Y. Li, Z. Yu, W. Wu. Sparse Bayesian Learning for End-to-End EEG Decoding,
32 |     IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. 45, no. 12, pp. 15632-15649, 2023."
33 | 
34 |     Wenlong Wang, Feifei Qi, Wei Wu, 2023.
35 |     Email: 201710102248@mail.scut.edu.cn
36 |     """
37 | 
38 |     # Compute enhanced covariance matrices and whitening matrix
39 |     R_train, Wh = Enhanced_cov(X, K, tau, train=1)
40 | 
41 | 
42 |     # Check properties of R
43 |     M, D_R = R_train.shape  # M: number of samples; D_R: dimension of vec(R_m)
44 |     KC = round(np.sqrt(D_R))
45 |     Loss_old = 1e12
46 |     threshold = 0.05
47 | 
48 | 
49 |     assert D_R == KC ** 2, "ERROR: Columns of A do not align with square matrix"
50 | 
51 |     # Check if R is symmetric
52 |     for j in range(M):
53 |         row_cov = reshape(R_train[j, :], (KC, KC))
54 |         row_cov = (row_cov + row_cov.T) / 2
55 |         assert norm(row_cov - row_cov.T) < 1e-4, "ERROR: Measurement row does not form symmetric matrix"
56 | 
57 |     # Initializations
58 |     U = zeros(KC, KC, dtype=float64).to(device)  # estimated low-rank matrix W initialized to zeros
59 |     Psi = eye(KC, dtype=float64).to(device)  # covariance matrix of the Gaussian prior, initialized to the identity matrix
60 |     lambda_noise = 1.  # variance of the additive noise, set to 1
61 | 
62 |     # Optimization loop
63 |     for i in range(Epoch+1):
64 | 
65 |         # Update B, Sigma_y, and u
66 |         RPR = zeros(M, M, dtype=float64).to(device)
67 |         B = zeros(KC ** 2, M, dtype=float64).to(device)
68 |         for j in range(KC):
69 |             start = j * KC
70 |             stop = start + KC
71 |             Temp = mm(Psi, R_train[:, start:stop].T)
72 |             B[start:stop, :] = Temp
73 |             RPR = RPR + mm(R_train[:, start:stop], Temp)
74 |         Sigma_y = RPR + lambda_noise * eye(M, dtype=float64).to(device)
75 |         uc = mm(mm(B, inverse(Sigma_y)), Y)  # maximum a posteriori estimate of uc
76 |         Uc = reshape(uc, (KC, KC))
77 |         U = (Uc + Uc.T) / 2
78 |         u = U.T.flatten()  # vec operation (Torch)
79 | 
80 |         # Update Phi (dual variable of Psi)
81 |         Phi = []
82 |         SR = mm(inverse(Sigma_y), R_train)
83 |         for j in range(KC):
84 |             start = j * KC
85 |             stop = start + KC
86 |             Phi_temp = Psi - Psi @ R_train[:, start:stop].T @ SR[:, start:stop] @ Psi
87 |             Phi.append(Phi_temp)
88 | 
89 |         # Update Psi
90 |         PHI = 0
91 |         UU = 0
92 |         for j in range(KC):
93 |             PHI = PHI + Phi[j]
94 |             UU = UU + U[:, j].reshape(-1, 1) @ U[:, j].reshape(-1, 1).T
95 |         # UU = U @ U.T
96 |         Psi = ((UU + UU.T) / 2 + (PHI + PHI.T) / 2) / KC  # make sure Psi is symmetric
97 | 
98 |         # Update theta (dual variable of lambda)
99 |         theta = 0
100 |         for j in range(KC):
101 |             start = j * KC
102 |             stop = start + KC
103 |             theta = theta + (Phi[j] @ R_train[:, start:stop].T @ R_train[:, start:stop]).trace()
104 | 
105 |         # Update lambda
106 |         lambda_noise = ((norm(Y - (R_train @ u).reshape(-1, 1), p=2) ** 2).sum() + theta) / M
107 | 
108 |         # Convergence check
109 |         Loss = Y.T @ inverse(Sigma_y) @ Y + log(det(Sigma_y))
110 |         delta_loss = abs(Loss_old - Loss.cpu().numpy()) / abs(Loss_old)
111 |         if delta_loss < 2e-4:
112 |             print('EXIT: Change in loss below threshold')
113 |             break
114 |         Loss_old = Loss.cpu().numpy()
115 |         if i % epoch_print == 99:
116 |             print('Iterations: ', str(i+1), ' lambda: ', str(lambda_noise.cpu().numpy()), ' Loss: ', float(Loss.cpu().numpy()),
117 |                   ' Delta_Loss: ', float(delta_loss))
118 | 
119 |     # Eigen-decomposition of W
120 |     W = U
121 |     D, V_all = torch.linalg.eig(W)
122 |     D, V_all = D.double().cpu().numpy(), V_all.double().cpu().numpy()
123 |     idx = D.argsort()
124 |     D = D[idx]
125 |     V_all = V_all[:, idx]  # each column of V_all represents a spatio-temporal filter
126 |     alpha_all = D
127 | 
128 |     # Determine spatio-temporal filters V and classifier weights alpha
129 |     d = np.abs(alpha_all)
130 |     d_max = np.max(d)
131 |     w_norm = d / d_max  # normalize eigenvalues of W by the maximum eigenvalue
132 |     index = np.where(w_norm > threshold)[0]  # indices of selected V according to a pre-defined threshold, e.g., 0.05
133 |     V = V_all[:, index]
134 |     alpha = alpha_all[index]
135 | 
136 |     return W, alpha, V, Wh
137 | 
138 | 
139 | def matrix_operations(A):
140 |     """Calculate the -1/2 power of matrix A"""
141 | 
142 |     V, Q = linalg.eig(A)
143 |     V_inverse = diag(V ** (-0.5))
144 |     A_inverse = mm(mm(Q, V_inverse), linalg.inv(Q))
145 | 
146 |     return A_inverse.double()
147 | 
148 | 
149 | def logm(A):
150 |     """Calculate the matrix logarithm of matrix A"""
151 | 
152 |     V, Q = linalg.eig(A)  # V: eigenvalues, Q: eigenvectors
153 |     V_log = diag(log(V))
154 |     A_logm = mm(mm(Q, V_log), linalg.inv(Q))
155 | 
156 |     return A_logm.double()
157 | 
158 | 
159 | def computer_acc(predict_Y, Y_test):
160 |     """Compute classification accuracy for test set"""
161 | 
162 |     predict_Y = predict_Y.cpu().numpy()
163 |     Y_test = torch.squeeze(Y_test).cpu().numpy()
164 |     total_num = len(predict_Y)
165 |     error_num = 0
166 | 
167 |     # Convert predicted scores to ±1 labels
168 |     Y_predict = np.zeros(total_num)
169 |     for i in range(total_num):
170 |         if predict_Y[i] > 0:
171 |             Y_predict[i] = 1
172 |         else:
173 |             Y_predict[i] = -1
174 | 
175 |     # Count misclassified trials
176 |     for i in range(total_num):
177 |         if Y_predict[i] != Y_test[i]:
178 |             error_num = error_num + 1
179 | 
180 |     accuracy = (total_num - error_num) / total_num
181 |     return accuracy
182 | 
183 | 
184 | def Enhanced_cov(X, K, tau, Wh=None, train=1):
185 |     """
186 |     Compute enhanced covariance matrices
187 | 
188 |     --- Parameters ---
189 |     X : M trials of C (channel) x T (time) EEG signals. [C, T, M].
190 |     K : Order of FIR filter
191 |     tau : Time delay parameter
192 |     Wh : Whitening matrix for enhancing covariance matrices.
193 |          In training mode (train=1), Wh is computed from the training data and returned.
194 |          In testing mode (train=0), the Wh obtained during training must be passed in.
195 |     train : train = 1 denotes training mode; train = 0 denotes testing mode.
196 | 
197 |     --- Returns ---
198 |     R_train : Enhanced covariance matrices. [M, (K*C)^2].
199 |     Wh : Whitening matrix. [K*C, K*C].
200 |     """
201 | 
202 |     # Initialization, [KC, KC]: dimension of augmented covariance matrix
203 |     X_order_k = None
204 |     C, T, M = X.shape
205 |     Cov = []
206 |     Sig_Cov = zeros(K * C, K * C).to(device)
207 | 
208 |     for m in range(M):
209 |         X_m = X[:, :, m]
210 |         X_m_hat = DoubleTensor().to(device)
211 | 
212 |         # Generate augmented EEG data
213 |         for k in range(K):
214 |             n_delay = k * tau
215 |             if n_delay == 0:
216 |                 X_order_k = X_m.clone()
217 |             else:
218 |                 X_order_k[:, 0:n_delay] = 0
219 |                 X_order_k[:, n_delay:T] = X_m[:, 0:T - n_delay].clone()
220 |             X_m_hat = torch.cat((X_m_hat, X_order_k), 0)
221 | 
222 |         # Compute covariance matrices
223 |         R_m = mm(X_m_hat, X_m_hat.T)
224 | 
225 |         # Trace normalization
226 |         R_m = R_m / R_m.trace()
227 |         Cov.append(R_m)
228 | 
229 |         Sig_Cov = Sig_Cov + R_m
230 | 
231 |     # Compute whitening matrix (Rp)
232 |     if train == 1:
233 |         Wh = Sig_Cov / M
234 | 
235 |     # Whitening, logarithm transform, and vectorization
236 |     Cov_whiten = zeros(M, K * C, K * C, dtype=float64).to(device)
237 |     R_train = zeros(M, K * C * K * C, dtype=float64).to(device)
238 |     Wh_inverse = matrix_operations(Wh)  # Rp^(-1/2); computed once, reused for every trial
239 | 
240 |     for m in range(M):
241 |         # whitening
242 |         temp_cov = Wh_inverse @ Cov[m] @ Wh_inverse
243 |         Cov_whiten[m, :, :] = (temp_cov + temp_cov.T) / 2
244 |         R_m = logm(Cov_whiten[m, :, :])  # logarithm transform
245 |         R_m = R_m.reshape(R_m.numel())  # column-wise vectorization
246 |         R_train[m, :] = R_m
247 | 
248 | 
249 | 
250 |     return R_train, Wh
--------------------------------------------------------------------------------
/dCNN_main.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os.path
3 | import time
4 | import sys
5 | from scipy.io import loadmat
6 | from braindecode.datautil.signal_target import SignalAndTarget
7 | import numpy as np
8 | 
9 | import random
10 | import torch
11 | 
12 | import torch.nn.functional as F
13 | from torch import optim
14 | import torch as th
15 | 
16 | from braindecode.models.deep4 import Deep4Net
17 | from braindecode.models.util import to_dense_prediction_model
18 | from braindecode.experiments.experiment import Experiment
19 | from braindecode.experiments.monitors import LossMonitor, MisclassMonitor, \
20 |     RuntimeMonitor, CroppedTrialMisclassMonitor
21 | from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
22 | from braindecode.datautil.iterators import CropsFromTrialsIterator
23 | from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
24 | from braindecode.datautil.splitters import split_into_two_sets
25 | from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
26 | from braindecode.torch_ext.util import set_random_seeds, np_to_var
27 | log = logging.getLogger(__name__)
28 | 
29 | 
30 | # Set fixed random seed
31 | seed = 20190706
32 | torch.manual_seed(seed)
33 | torch.cuda.manual_seed(seed)
34 | torch.cuda.manual_seed_all(seed)
35 | np.random.seed(seed)
36 | random.seed(seed)
37 | torch.backends.cudnn.benchmark = False
38 | torch.backends.cudnn.deterministic = True
39 | 
40 | 
41 | def run_exp(data_folder, model, cuda):
42 |     # Parameter initialization
43 |     input_time_length = 750  # 1000
44 |     max_epochs = 1600
45 |     max_increase_epochs = 160
46 |     batch_size = 16
47 |     valid_set_fraction = 0.2
48 |     # Initialize train and test dataset containers
49 |     X = np.zeros([1])
50 |     y = np.zeros([1])
51 |     train_set = SignalAndTarget(X, y)
52 |     test_set = SignalAndTarget(X, y)
53 | 
54 |     # Load train and test datasets from .mat files
55 |     train_filename = 'Dataset2_L1_FootTongue_train.mat'
56 |     test_filename = 'Dataset2_L1_FootTongue_test.mat'
57 |     train_filepath = os.path.join(data_folder, train_filename)
58 |     test_filepath = os.path.join(data_folder, test_filename)
59 |     train = loadmat(train_filepath)
60 |     test = loadmat(test_filepath)
61 | 
62 | 
63 |     train_set.X = np.transpose(train['X_train'], (2, 0, 1)).astype(np.float32)
64 |     test_set.X = np.transpose(test['X_test'], (2, 0, 1)).astype(np.float32)
65 |     train['Y_train'] = np.where(train['Y_train'] == -1, 0, train['Y_train'])  # remap labels -1 -> 0
66 |     test['Y_test'] = np.where(test['Y_test'] == -1, 0, test['Y_test'])
67 |     train_set.y = train['Y_train'].astype(np.int64)
68 |     test_set.y = test['Y_test'].astype(np.int64)
69 |     train_set.y = train_set.y.reshape(np.size(train_set.y, 0))
70 |     test_set.y = test_set.y.reshape(np.size(test_set.y, 0))
71 | 
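    # --- Editor's sketch (optional; not in the original script): sanity-check
    # the data layout braindecode expects here, i.e. trials x channels x time
    # with 0/1 labels after the remapping above.
    assert train_set.X.ndim == 3, "expected (trials, channels, time)"
    assert set(np.unique(train_set.y)).issubset({0, 1}), "labels should be 0/1"
    assert test_set.X.shape[1:] == train_set.X.shape[1:], "channel/time dims must match the train set"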
72 |     # Split training data into train and validation sets
73 |     train_set, valid_set = split_into_two_sets(
74 |         train_set, first_set_fraction=1-valid_set_fraction)
75 | 
76 |     n_classes = 2
77 |     n_chans = int(train_set.X.shape[1])
78 |     if model == 'shallow':
79 |         model = ShallowFBCSPNet(n_chans, n_classes, input_time_length=input_time_length,
80 |                                 final_conv_length=30).create_network()
81 |     elif model == 'deep':
82 |         model = Deep4Net(n_chans, n_classes, input_time_length=input_time_length,
83 |                          final_conv_length=2).create_network()
84 | 
85 | 
86 |     to_dense_prediction_model(model)
87 |     if cuda:
88 |         model.cuda()
89 | 
90 |     log.info("Model: \n{:s}".format(str(model)))
91 |     dummy_input = np_to_var(train_set.X[:1, :, :, None])
92 |     if cuda:
93 |         dummy_input = dummy_input.cuda()
94 |     out = model(dummy_input)
95 | 
96 |     n_preds_per_input = out.cpu().data.numpy().shape[2]
97 | 
98 |     optimizer = optim.Adam(model.parameters())
99 | 
100 |     iterator = CropsFromTrialsIterator(batch_size=batch_size,
101 |                                        input_time_length=input_time_length,
102 |                                        n_preds_per_input=n_preds_per_input)
103 | 
104 |     stop_criterion = Or([MaxEpochs(max_epochs),
105 |                          NoDecrease('valid_misclass', max_increase_epochs)])
106 | 
107 |     monitors = [LossMonitor(), MisclassMonitor(col_suffix='sample_misclass'),
108 |                 CroppedTrialMisclassMonitor(
109 |                     input_time_length=input_time_length), RuntimeMonitor()]
110 | 
111 |     model_constraint = MaxNormDefaultConstraint()
112 | 
113 |     loss_function = lambda preds, targets: F.nll_loss(
114 |         th.mean(preds, dim=2, keepdim=False), targets)
115 | 
116 |     exp = Experiment(model, train_set, valid_set, test_set, iterator=iterator,
117 |                      loss_function=loss_function, optimizer=optimizer,
118 |                      model_constraint=model_constraint,
119 |                      monitors=monitors,
120 |                      stop_criterion=stop_criterion,
121 |                      remember_best_column='valid_misclass',
122 |                      run_after_early_stop=True, cuda=cuda)
123 |     exp.run()
124 |     return exp
125 | 
126 | if __name__ == '__main__':
127 |     logging.basicConfig(
128 |         format='%(asctime)s %(levelname)s : %(message)s',
129 |         level=logging.DEBUG,
130 |         stream=sys.stdout,
131 |     )
132 | 
133 |     # Data folder where the datasets are located
134 |     data_folder = 'C:/Users/Administrator/Desktop/Code-SBLEST-main'  # The folder you download from https://github.com/EEGdecoding/Code-SBLEST
135 | 
136 |     model = "deep"  # 'shallow' or 'deep'
137 |     cuda = True  # True or False
138 |     time_start = time.time()
139 |     exp = run_exp(data_folder, model, cuda)
140 |     log.info("Last 5 epochs")
141 |     log.info("\n" + str(exp.epochs_df.iloc[-5:]))
142 |     Accuracy = 1 - exp.epochs_df.test_misclass.iloc[-1]
143 |     time_end = time.time()
144 | 
145 |     print('time cost', time_end - time_start, 's')
146 |     print('test accuracy', Accuracy)
--------------------------------------------------------------------------------
/sCNN_main.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | import numpy as np
4 | import torch.nn.functional as F
5 | from torch import optim
6 | from braindecode.datautil.signal_target import SignalAndTarget
7 | import logging
8 | import sys
9 | import os
10 | from scipy.io import loadmat
11 | from braindecode.models.deep4 import Deep4Net
12 | from braindecode.experiments.experiment import Experiment
13 | from braindecode.experiments.monitors import (
14 |     LossMonitor,
15 |     MisclassMonitor,
16 |     RuntimeMonitor,
17 | )
18 | from braindecode.experiments.stopcriteria import MaxEpochs, NoDecrease, Or
19 | from braindecode.datautil.iterators import BalancedBatchSizeIterator
20 | from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
21 | from braindecode.datautil.splitters import split_into_two_sets
22 | from braindecode.torch_ext.constraints import MaxNormDefaultConstraint
23 | 
24 | log = logging.getLogger(__name__)
25 | 
26 | # Set fixed random seed
27 | seed = 20190706
28 | torch.manual_seed(seed)
29 | torch.cuda.manual_seed(seed)
30 | torch.cuda.manual_seed_all(seed)
31 | np.random.seed(seed)
32 | random.seed(seed)
33 | torch.backends.cudnn.benchmark = False
34 | torch.backends.cudnn.deterministic = True
35 | 
36 | 
37 | def run_exp(data_folder, model, cuda):
38 |     input_time_length = 750  # 1000
39 |     max_epochs = 1600
40 |     max_increase_epochs = 160
41 |     batch_size = 16
42 | 
43 |     valid_set_fraction = 0.2
44 |     # Initialize train and test dataset containers
45 |     X = np.zeros([1])
46 |     y = np.zeros([1])
47 |     train_set = SignalAndTarget(X, y)
48 |     test_set = SignalAndTarget(X, y)
49 | 
50 |     # Load train and test datasets from .mat files
51 |     train_filename = 'Dataset2_L1_FootTongue_train.mat'
52 |     test_filename = 'Dataset2_L1_FootTongue_test.mat'
53 |     train_filepath = os.path.join(data_folder, train_filename)
54 |     test_filepath = os.path.join(data_folder, test_filename)
55 |     train = loadmat(train_filepath)
56 |     test = loadmat(test_filepath)
57 | 
58 | 
59 |     train_set.X = np.transpose(train['X_train'], (2, 0, 1)).astype(np.float32)
60 |     test_set.X = np.transpose(test['X_test'], (2, 0, 1)).astype(np.float32)
61 |     train['Y_train'] = np.where(train['Y_train'] == -1, 0, train['Y_train'])  # remap labels -1 -> 0
62 |     test['Y_test'] = np.where(test['Y_test'] == -1, 0, test['Y_test'])
63 |     train_set.y = train['Y_train'].astype(np.int64)
64 |     test_set.y = test['Y_test'].astype(np.int64)
65 |     train_set.y = train_set.y.reshape(np.size(train_set.y, 0))
66 |     test_set.y = test_set.y.reshape(np.size(test_set.y, 0))
67 | 
68 |     # Split training data into train and validation sets
69 |     train_set, valid_set = split_into_two_sets(
70 |         train_set, first_set_fraction=1-valid_set_fraction)
71 | 
72 |     n_classes = 2
73 |     n_chans = int(train_set.X.shape[1])
74 |     input_time_length = train_set.X.shape[2]
75 |     if model == "shallow":
76 |         model = ShallowFBCSPNet(
77 |             n_chans,
78 |             n_classes,
79 |             input_time_length=input_time_length,
80 |             final_conv_length="auto",
81 |         ).create_network()
82 |     elif model == "deep":
83 |         model = Deep4Net(
84 |             n_chans,
85 |             n_classes,
86 |             input_time_length=input_time_length,
87 |             final_conv_length="auto",
88 |         ).create_network()
89 |     if cuda:
90 |         model.cuda()
91 |     log.info("Model: \n{:s}".format(str(model)))
92 | 
93 |     optimizer = optim.Adam(model.parameters())
94 | 
95 | 
96 | 
97 | 
98 |     iterator = BalancedBatchSizeIterator(batch_size=batch_size)
99 | 
100 |     stop_criterion = Or(
101 |         [
102 |             MaxEpochs(max_epochs),
103 |             NoDecrease("valid_misclass", max_increase_epochs),
104 |         ]
105 |     )
106 | 
107 |     monitors = [LossMonitor(), MisclassMonitor(), RuntimeMonitor()]
108 | 
109 |     model_constraint = MaxNormDefaultConstraint()
110 | 
111 |     exp = Experiment(
112 |         model,
113 |         train_set,
114 |         valid_set,
115 |         test_set,
116 |         iterator=iterator,
117 |         loss_function=F.nll_loss,
118 |         optimizer=optimizer,
119 |         model_constraint=model_constraint,
120 |         monitors=monitors,
121 |         stop_criterion=stop_criterion,
122 |         remember_best_column="valid_misclass",
123 |         run_after_early_stop=True,
124 |         cuda=cuda,
125 |     )
126 |     exp.run()
127 |     return exp  # return the experiment so the caller can read epochs_df
128 | 
129 | if __name__ == '__main__':
130 |     logging.basicConfig(
131 |         format='%(asctime)s %(levelname)s : %(message)s',
132 |         level=logging.DEBUG,
133 |         stream=sys.stdout,
134 |     )
135 | 
136 |     # Data folder where the datasets are located
137 |     data_folder = 'C:/Users/Administrator/Desktop/Code-SBLEST-main'  # The folder you download from https://github.com/EEGdecoding/Code-SBLEST
138 | 
139 |     model = "shallow"  # 'shallow' or 'deep'
140 |     cuda = True  # True or False
141 |     exp = run_exp(data_folder, model, cuda)
142 |     log.info("Last 5 epochs")
143 |     log.info("\n" + str(exp.epochs_df.iloc[-5:]))
144 |     Accuracy = 1 - exp.epochs_df.test_misclass.iloc[-1]
145 |     print('test accuracy', Accuracy)
--------------------------------------------------------------------------------
/signal_target.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | class SignalAndTarget(object):
3 |     """
4 |     Simple data container class.
5 | 
6 |     Parameters
7 |     ----------
8 |     X: 3darray or list of 2darrays
9 |         The input signal per trial.
10 |     y: 1darray or list
11 |         Labels for each trial.
12 |     """
13 | 
14 |     def __init__(self, X, y):
15 |         assert len(X) == len(y)
16 |         self.X = X
17 |         self.y = y
18 | 
19 | 
20 | def apply_to_X_y(fn, *sets):
21 |     """
22 |     Apply a function to all `X` and `y` attributes of all given sets.
23 | 
24 |     Applies function to list of X arrays and to list of y arrays separately.
25 | 
26 |     Parameters
27 |     ----------
28 |     fn: function
29 |         Function to apply
30 |     sets: :class:`.SignalAndTarget` objects
31 | 
32 |     Returns
33 |     -------
34 |     result_set: :class:`.SignalAndTarget`
35 |         Dataset with X and y as the result of the
36 |         application of the function.
37 |     """
38 |     X = fn(*[s.X for s in sets])
39 |     y = fn(*[s.y for s in sets])
40 |     return SignalAndTarget(X, y)
41 | 
42 | def convert_numbers_to_one_hot(arr):
43 |     """
44 |     Convert each 1 in the input 1-D array to [1, 0] and each -1 to [0, 1],
45 |     returning a 2-D one-hot array.
46 | 
47 |     Args:
48 |         arr (ndarray): Input 1-D array of +1/-1 labels
49 | 
50 |     Returns:
51 |         ndarray: Converted 2-D one-hot array
52 |     """
53 |     # Convert the input to a NumPy array
54 |     arr = np.array(arr)
55 | 
56 |     # Create an all-zeros array of shape (len(arr), 2)
57 |     one_hot = np.zeros((arr.shape[0], 2))
58 | 
59 |     # Iterate over the input array
60 |     for i, num in enumerate(arr):
61 |         if num == 1:
62 |             one_hot[i, 0] = 1
63 |         elif num == -1:
64 |             one_hot[i, 1] = 1
65 | 
66 |     return one_hot
--------------------------------------------------------------------------------
/splitters.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | 
3 | from iterators import get_balanced_batches
4 | from signal_target import apply_to_X_y, SignalAndTarget
5 | 
6 | 
7 | def concatenate_sets(sets):
8 |     """
9 |     Concatenate all sets together.
10 | 
11 |     Parameters
12 |     ----------
13 |     sets: list of :class:`.SignalAndTarget`
14 | 
15 |     Returns
16 |     -------
17 |     concatenated_set: :class:`.SignalAndTarget`
18 |     """
19 |     concatenated_set = sets[0]
20 |     for s in sets[1:]:
21 |         concatenated_set = concatenate_two_sets(concatenated_set, s)
22 |     return concatenated_set
23 | 
24 | 
25 | def concatenate_two_sets(set_a, set_b):
26 |     """
27 |     Concatenate two sets together.
28 | 
29 |     Parameters
30 |     ----------
31 |     set_a, set_b: :class:`.SignalAndTarget`
32 | 
33 |     Returns
34 |     -------
35 |     concatenated_set: :class:`.SignalAndTarget`
36 |     """
37 |     new_X = concatenate_np_array_or_add_lists(set_a.X, set_b.X)
38 |     new_y = concatenate_np_array_or_add_lists(set_a.y, set_b.y)
39 |     return SignalAndTarget(new_X, new_y)
40 | 
41 | 
42 | def concatenate_np_array_or_add_lists(a, b):
43 |     if hasattr(a, "ndim") and hasattr(b, "ndim"):
44 |         new = np.concatenate((a, b), axis=0)
45 |     else:
46 |         if hasattr(a, "ndim"):
47 |             a = a.tolist()
48 |         if hasattr(b, "ndim"):
49 |             b = b.tolist()
50 |         new = a + b
51 |     return new
52 | 
53 | 
54 | def split_into_two_sets(dataset, first_set_fraction=None, n_first_set=None):
55 |     """
56 |     Split set into two sets either by fraction of first set or by number
57 |     of trials in first set.
58 | 
59 |     Parameters
60 |     ----------
61 |     dataset: :class:`.SignalAndTarget`
62 |     first_set_fraction: float, optional
63 |         Fraction of trials in first set.
64 |     n_first_set: int, optional
65 |         Number of trials in first set
66 | 
67 |     Returns
68 |     -------
69 |     first_set, second_set: :class:`.SignalAndTarget`
70 |         The two split sets.
71 |     """
72 |     assert (first_set_fraction is None) != (
73 |         n_first_set is None
74 |     ), "Pass either first_set_fraction or n_first_set"
75 |     if n_first_set is None:
76 |         n_first_set = int(round(len(dataset.X) * first_set_fraction))
77 |     assert n_first_set < len(dataset.X)
78 |     first_set = apply_to_X_y(lambda a: a[:n_first_set], dataset)
79 |     second_set = apply_to_X_y(lambda a: a[n_first_set:], dataset)
80 |     return first_set, second_set
81 | 
82 | 
83 | def select_examples(dataset, indices):
84 |     """
85 |     Select examples from dataset.
86 | 
87 |     Parameters
88 |     ----------
89 |     dataset: :class:`.SignalAndTarget`
90 |     indices: list of int, 1d-array of int
91 |         Indices to select
92 | 
93 |     Returns
94 |     -------
95 |     reduced_set: :class:`.SignalAndTarget`
96 |         Dataset with only examples selected.
97 |     """
98 |     # probably not necessary
99 |     indices = np.array(indices)
100 |     if hasattr(dataset.X, "ndim"):
101 |         # numpy array
102 |         new_X = np.array(dataset.X)[indices]
103 |     else:
104 |         # list
105 |         new_X = [dataset.X[i] for i in indices]
106 |     new_y = np.asarray(dataset.y)[indices]
107 |     return SignalAndTarget(new_X, new_y)
108 | 
109 | 
110 | def split_into_train_valid_test(dataset, n_folds, i_test_fold, rng=None):
111 |     """
112 |     Split dataset into folds, select one fold for validation, one for testing, and merge the rest into the training fold.
113 | 
114 |     Parameters
115 |     ----------
116 |     dataset: :class:`.SignalAndTarget`
117 |     n_folds: int
118 |         Number of folds to split dataset into.
119 |     i_test_fold: int
120 |         Index of the test fold (0-based). The validation fold is the immediately preceding fold.
121 |     rng: `numpy.random.RandomState`, optional
122 |         Random Generator for shuffling, None means no shuffling
123 | 
124 |     Returns
125 |     -------
126 |     train_set, valid_set, test_set: :class:`.SignalAndTarget`
127 |         The train, validation, and test folds.
128 |     """
129 |     n_trials = len(dataset.X)
130 |     if n_trials < n_folds:
131 |         raise ValueError(
132 |             "Fewer trials ({:d}) than folds ({:d})".format(n_trials, n_folds)
133 |         )
134 |     shuffle = rng is not None
135 |     folds = get_balanced_batches(n_trials, rng, shuffle, n_batches=n_folds)
136 |     test_inds = folds[i_test_fold]
137 |     valid_inds = folds[i_test_fold - 1]
138 |     all_inds = list(range(n_trials))
139 |     train_inds = np.setdiff1d(all_inds, np.union1d(test_inds, valid_inds))
140 |     assert np.intersect1d(train_inds, valid_inds).size == 0
141 |     assert np.intersect1d(train_inds, test_inds).size == 0
142 |     assert np.intersect1d(valid_inds, test_inds).size == 0
143 |     assert np.array_equal(
144 |         np.sort(np.union1d(train_inds, np.union1d(valid_inds, test_inds))),
145 |         all_inds,
146 |     )
147 | 
148 |     train_set = select_examples(dataset, train_inds)
149 |     valid_set = select_examples(dataset, valid_inds)
150 |     test_set = select_examples(dataset, test_inds)
151 | 
152 |     return train_set, valid_set, test_set
153 | 
154 | 
155 | def split_into_train_test(dataset, n_folds, i_test_fold, rng=None):
156 |     """
157 |     Split dataset into folds, select one fold for testing, and merge the rest into the training fold.
158 | 
159 |     Parameters
160 |     ----------
161 |     dataset: :class:`.SignalAndTarget`
162 |     n_folds: int
163 |         Number of folds to split dataset into.
164 |     i_test_fold: int
165 |         Index of the test fold (0-based)
166 |     rng: `numpy.random.RandomState`, optional
167 |         Random Generator for shuffling, None means no shuffling
168 | 
169 |     Returns
170 |     -------
171 |     train_set, test_set: :class:`.SignalAndTarget`
172 |         The train and test folds.
173 |     """
174 |     n_trials = len(dataset.X)
175 |     if n_trials < n_folds:
176 |         raise ValueError(
177 |             "Fewer trials ({:d}) than folds ({:d})".format(n_trials, n_folds)
178 |         )
179 |     shuffle = rng is not None
180 |     folds = get_balanced_batches(n_trials, rng, shuffle, n_batches=n_folds)
181 |     test_inds = folds[i_test_fold]
182 |     all_inds = list(range(n_trials))
183 |     train_inds = np.setdiff1d(all_inds, test_inds)
184 |     assert np.intersect1d(train_inds, test_inds).size == 0
185 |     assert np.array_equal(np.sort(np.union1d(train_inds, test_inds)), all_inds)
186 | 
187 |     train_set = select_examples(dataset, train_inds)
188 |     test_set = select_examples(dataset, test_inds)
189 |     return train_set, test_set
190 | 
--------------------------------------------------------------------------------
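# --- Editor's usage sketch (hypothetical; not a file in this repository). ---
# It shows how the splitters above could be combined with SignalAndTarget.
# Note: splitters.py imports `get_balanced_batches` from an `iterators` module
# (e.g. from braindecode) that is not included in this snapshot, so that
# module must be on the path for the import below to succeed. Shapes are
# illustrative only.
import numpy as np
from signal_target import SignalAndTarget
from splitters import split_into_two_sets

X = np.random.randn(100, 60, 750).astype(np.float32)  # trials x channels x time
y = np.random.randint(0, 2, size=100)                 # illustrative binary labels
dataset = SignalAndTarget(X, y)

# Hold out the last 20% of trials as a validation set, as the example scripts do
train_set, valid_set = split_into_two_sets(dataset, first_set_fraction=0.8)
print(len(train_set.X), len(valid_set.X))  # -> 80 20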