├── MixedNet_main.py
├── MyFunctions.py
├── MyModule.py
├── PowerAndConneMixedNet.py
├── README.md
└── environment.yml

/MixedNet_main.py:
--------------------------------------------------------------------------------
1 | # Demo of the FBSTCNet-M network on the SEED dataset.
2 | 
3 | # Organized SEED data structure:
4 | # File name: Dataset1_sub(#subject)_s(#session).mat
5 | # Data structure:
6 | # data: 1 × 15 cell, each element contains an array of channels × timepoints;
7 | # fs = 200;
8 | # label: 1 × 15 double. (1 : positive, 2 : negative, 3 : neutral)
9 | 
10 | # Data from each SEED subset was first divided into five equal-sized folds in chronological order.
11 | # The first three folds were then used as the training set, and the remaining two folds were used for testing.
12 | 
13 | # Reference:
14 | # "W. Huang, W. Wang, Y. Li, W. Wu. FBSTCNet: A Spatio-Temporal Convolutional Network Integrating Power and Connectivity Features for EEG-Based Emotion Decoding. IEEE Transactions on Affective Computing, 15(4), 1906-1918, 2024."
15 | 
16 | 
17 | import sys
18 | import os
19 | import time
20 | from datetime import datetime
21 | import scipy.io as scio
22 | import numpy as np
23 | from torch.utils.data import Dataset
24 | import torch
25 | import logging
26 | import csv
27 | 
28 | from PowerAndConneMixedNet import PowerAndConneMixedNet
29 | from skorch.helper import predefined_split
30 | from skorch.callbacks import LRScheduler
31 | from braindecode import EEGClassifier
32 | from torch.optim import AdamW
33 | from braindecode.training import CroppedLoss
34 | from braindecode.util import set_random_seeds
35 | from braindecode.models import get_output_shape
36 | from sklearn.metrics import confusion_matrix
37 | 
38 | def WindowCutting(X, window_length):  # cut one trial (channels × timepoints) into non-overlapping windows
39 |     numChannel, numPoint = X.shape
40 |     numSamples = int(numPoint / window_length)
41 |     Samples = np.zeros([numSamples, numChannel, window_length], dtype='float32')
42 |     for i in range(numSamples):
43 |         Samples[i, :, :] = X[:, window_length * i:window_length * (i + 1)]
44 |     return Samples, numSamples
45 | 
46 | class SEED_DATASET(Dataset):
47 |     def __init__(self, X, y):
48 |         assert len(X) == len(y), "n_samples dimension mismatch"
49 |         self.X = X
50 |         self.y = y
51 | 
52 |     def __len__(self):
53 |         return len(self.X)
54 | 
55 |     def __getitem__(self, item):
56 |         return self.X[item], self.y[item]
57 | 
58 | 
59 | 
60 | now = datetime.now()
61 | timestr = now.strftime("%Y%m%d%H%M")
62 | dir = os.getcwd() + '/Results/Dat1_FBSTCNet-M_' + timestr
63 | if not os.path.exists(dir):
64 |     os.makedirs(dir)
65 | f_log = open(dir + '/log.txt', "w+")
66 | sys.stdout = f_log
67 | datapath = '../Data/emotion/'
68 | 
69 | for SubID in range(1, 16):
70 |     for SessionID in range(1, 4):
71 |         print('...... | Subject: %d | Session: %d | ......' % (SubID, SessionID))
72 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
73 |         print("Load and split dataset...")
74 |         filename = 'Dataset1_sub' + str(SubID) + '_s' + str(SessionID) + '.mat'
75 | 
76 | 
77 |         data = scio.loadmat(datapath + filename)
78 |         label = data['label'][0] - 1
79 |         index_pos = np.where(label == 0)[0]
80 |         index_neg = np.where(label == 1)[0]
81 |         index_neu = np.where(label == 2)[0]
82 |         index_all = np.concatenate((index_pos, index_neg, index_neu), axis=0)
83 |         sfreq = int(data['fs'][0])
84 |         n_classes = 3
85 |         input_window_time = 5
86 |         n_epochs = 50
87 |         input_window_samples = sfreq * input_window_time
88 |         ConfM = np.zeros([n_epochs, 3, 3])
89 |         ch_names = ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7',
90 |                     'FC5',
91 |                     'FC3', 'FC1', 'FCZ',
92 |                     'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5',
93 |                     'CP3',
94 |                     'CP1', 'CPZ', 'CP2',
95 |                     'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3',
96 |                     'POZ',
97 |                     'PO4', 'PO6', 'PO8',
98 |                     'CB1', 'O1', 'OZ', 'O2', 'CB2']
99 |         index_test = np.concatenate((index_pos[[3, 4]], index_neg[[3, 4]], index_neu[[3, 4]]), axis=0)
100 |         index_cv = np.setdiff1d(index_all, index_test)
101 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
102 |         print("Cross-validation start...")
103 |         start_time_cv = time.perf_counter()
104 |         device = 'cpu'
105 | 
106 |         torch.set_default_tensor_type('torch.FloatTensor')
107 |         for ifold in range(0, 3):
108 |             print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
109 |             print('Fold %d' % (ifold + 1))
110 |             index_valid = np.concatenate(([index_pos[ifold]], [index_neg[ifold]], [index_neu[ifold]]), axis=0)
111 |             index_train = np.setdiff1d(index_cv, index_valid)
112 | 
113 | 
114 |             for itrial in range(len(index_train)):
115 |                 if itrial == 0:
116 |                     X_train, nTrial = WindowCutting(data['data'][0][index_train[itrial]], input_window_samples)
117 |                     Y_train = label[index_train[itrial]] * np.ones(nTrial, dtype=int)
118 |                 else:
119 |                     X_tmp, nTrial = WindowCutting(data['data'][0][index_train[itrial]], input_window_samples)
120 |                     X_train = np.concatenate((X_train, X_tmp), axis=0)
121 |                     Y_train = np.concatenate((Y_train, label[index_train[itrial]] * np.ones(nTrial, dtype=int)))
122 |             Dataset_Train = SEED_DATASET(X_train, Y_train)
123 | 
124 |             for itrial in range(len(index_valid)):
125 |                 if itrial == 0:
126 |                     X_test, nTrial = WindowCutting(data['data'][0][index_valid[itrial]], input_window_samples)
127 |                     Y_test = label[index_valid[itrial]] * np.ones(nTrial, dtype=int)
128 |                 else:
129 |                     X_tmp, nTrial = WindowCutting(data['data'][0][index_valid[itrial]], input_window_samples)
130 |                     X_test = np.concatenate((X_test, X_tmp), axis=0)
131 |                     Y_test = np.concatenate((Y_test, label[index_valid[itrial]] * np.ones(nTrial, dtype=int)))
132 |             Dataset_Valid = SEED_DATASET(X_test, Y_test)
133 | 
134 |             seed = 20220930
135 |             set_random_seeds(seed=seed, cuda=False)
136 | 
137 |             confusion_mat_group = np.zeros([n_epochs, 3, 3])
138 | 
139 | 
140 | 
141 |             n_chans = X_train.shape[1]
142 |             filterRange = [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24), (24, 28), (28, 32), (32, 36), (36, 40),
143 |                            (40, 44), (44, 48), (48, 52)]
144 |             model = PowerAndConneMixedNet(
145 |                 n_chans,
146 |                 n_classes,
147 |                 fs=sfreq,
148 |                 filterRange=filterRange,
149 |                 input_window_samples=input_window_samples,
150 |                 same_filters_for_features=False,
151 |             )
152 | 
153 |             lr = 0.0625 * 0.01
154 |             weight_decay = 0
155 |             batch_size = 16
156 | 
157 | 
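            # Cropped decoding note: with cropped=True the network emits one prediction per
            # temporal crop of each 5 s window; braindecode's CroppedLoss applies nll_loss to
            # the log-probabilities averaged across crops, and in cropped mode predict()
            # likewise aggregates across crops, so each confusion-matrix entry below counts
            # whole windows rather than individual crops.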
158 | 
159 |             clf = EEGClassifier(
160 |                 model,
161 |                 cropped=True,  # cropped decoding
162 |                 criterion=CroppedLoss,
163 |                 criterion__loss_function=torch.nn.functional.nll_loss,
164 |                 optimizer=torch.optim.AdamW,
165 |                 train_split=None,
166 |                 optimizer__lr=lr,
167 |                 optimizer__weight_decay=weight_decay,
168 |                 iterator_train__shuffle=True,
169 |                 batch_size=batch_size,
170 |                 device=device,
171 |             )
172 | 
173 | 
174 |             for iep in range(n_epochs):
175 |                 clf.partial_fit(Dataset_Train, y=None, epochs=1)
176 |                 y_pred = clf.predict(Dataset_Valid)
177 |                 confusion_mat_group[iep, :, :] = confusion_matrix(Dataset_Valid.y, y_pred)
178 |                 ConfM[iep, :, :] = ConfM[iep, :, :] + confusion_matrix(Dataset_Valid.y, y_pred)
179 |             print("Saving Results...")
180 | 
181 |             savename = dir + '/Sub' + str(SubID) + '_s' + str(SessionID) + '_fold' + str(ifold + 1) + '.npz'
182 |             np.savez(savename, confusion_mat_group=confusion_mat_group)
183 |             print("Finish Saving!")
184 | 
185 |             #torch.save(clf,dir+'/Sub' + str(SubID) + '_S' + str(SessionID) + '_fold' + str(ifold+1) + '_model.pth')
186 | 
187 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
188 |         print("Finding the best parameters...")
189 | 
190 |         correct_all = np.zeros(n_epochs)
191 |         for iepc in range(n_epochs):
192 |             correct_all[iepc] = ConfM[iepc][0][0] + ConfM[iepc][1][1] + ConfM[iepc][2][2]
193 |         best_epoch = np.argmax(correct_all) + 1
194 |         best_correct = np.max(correct_all)
195 | 
196 |         end_time_cv = time.perf_counter()
197 |         cost_time_cv = end_time_cv - start_time_cv
198 | 
199 |         savename = dir + '/Result_Sub' + str(SubID) + '_s' + str(SessionID) + '_cv.npz'
200 |         np.savez(savename, ConfM=ConfM, correct_all=correct_all, cost_time_cv=cost_time_cv, best_epoch=best_epoch, best_correct=best_correct)
201 | 
202 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
203 |         print("Finish cross validation, cost %.5f seconds" % (cost_time_cv))
204 | 
205 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
206 |         print("Split training and testing dataset...")
207 | 
208 |         for itrial in range(len(index_cv)):
209 |             if itrial == 0:
210 |                 X_train, nTrial = WindowCutting(data['data'][0][index_cv[itrial]], input_window_samples)
211 |                 Y_train = label[index_cv[itrial]] * np.ones(nTrial, dtype=int)
212 |             else:
213 |                 X_tmp, nTrial = WindowCutting(data['data'][0][index_cv[itrial]], input_window_samples)
214 |                 X_train = np.concatenate((X_train, X_tmp), axis=0)
215 |                 Y_train = np.concatenate((Y_train, label[index_cv[itrial]] * np.ones(nTrial, dtype=int)))
216 |         Dataset_Train = SEED_DATASET(X_train, Y_train)
217 | 
218 |         for itrial in range(len(index_test)):
219 |             if itrial == 0:
220 |                 X_test, nTrial = WindowCutting(data['data'][0][index_test[itrial]], input_window_samples)
221 |                 Y_test = label[index_test[itrial]] * np.ones(nTrial, dtype=int)
222 |             else:
223 |                 X_tmp, nTrial = WindowCutting(data['data'][0][index_test[itrial]], input_window_samples)
224 |                 X_test = np.concatenate((X_test, X_tmp), axis=0)
225 |                 Y_test = np.concatenate((Y_test, label[index_test[itrial]] * np.ones(nTrial, dtype=int)))
226 |         Dataset_Test = SEED_DATASET(X_test, Y_test)
227 | 
228 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
229 |         print("Start training model...")
230 |         print('Total %d epochs' % (best_epoch))
231 |         start_time_train = time.perf_counter()
232 | 
233 |         seed = 20220930
234 |         set_random_seeds(seed=seed, cuda=False)
235 | 
236 | 
237 |         n_chans = X_train.shape[1]
238 |         filterRange = [(4, 8), (8, 12), (12, 16), (16, 20), (20, 24), (24, 28), (28, 32), (32, 36), (36, 40),
239 |                        (40, 44), (44, 48), (48, 52)]
240 | 
241 |         model = PowerAndConneMixedNet(
242 |             n_chans,
243 |             n_classes,
244 |             fs=sfreq,
245 |             filterRange=filterRange,
246 |             input_window_samples=input_window_samples,
247 |             same_filters_for_features=False,
248 |         )
249 | 
250 |         lr = 0.0625 * 0.01
251 |         weight_decay = 0
252 |         batch_size = 16
253 | 
254 |         confusion_mat = np.zeros([3, 3])
255 | 
256 |         clf = EEGClassifier(
257 |             model,
258 |             cropped=True,
259 |             criterion=CroppedLoss,
260 |             criterion__loss_function=torch.nn.functional.nll_loss,
261 |             optimizer=torch.optim.AdamW,
262 |             train_split=None,
263 |             optimizer__lr=lr,
264 |             optimizer__weight_decay=weight_decay,
265 |             iterator_train__shuffle=True,
266 |             batch_size=batch_size,
267 |             device=device,
268 |         )
269 |         clf.fit(Dataset_Train, y=None, epochs=best_epoch)
270 | 
271 |         end_time_train = time.perf_counter()
272 |         cost_time_train = end_time_train - start_time_train
273 | 
274 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
275 |         print('Finish training, cost %.5f seconds' % (cost_time_train))
276 |         print("Saving model...")
277 | 
278 |         torch.save(model, dir + '/Sub' + str(SubID) + '_S' + str(SessionID) + '_train_model.pth')
279 |         torch.save(clf, dir + '/Sub' + str(SubID) + '_S' + str(SessionID) + '_train_classifier.pth')
280 | 
281 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
282 |         print("Start prediction...")
283 |         start_time_test = time.perf_counter()
284 | 
285 |         y_pred = clf.predict(Dataset_Test)
286 |         confusion_mat = confusion_matrix(Dataset_Test.y, y_pred)
287 | 
288 |         test_acc = (confusion_mat[0, 0] + confusion_mat[1, 1] + confusion_mat[2, 2]) / np.sum(confusion_mat)
289 |         end_time_test = time.perf_counter()
290 |         cost_time_test = end_time_test - start_time_test
291 | 
292 |         print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
293 |         print("Finish prediction, cost %.5f seconds" % (cost_time_test))
294 |         print("Saving results...")
295 | 
296 |         np.savez(dir + '/Result_Sub' + str(SubID) + '_ses' + str(SessionID) + '_test.npz',
297 |                  test_acc=test_acc, confusion_mat=confusion_mat,
298 |                  cost_time_test=cost_time_test, cost_time_train=cost_time_train)
299 | 
300 | 
301 | 
302 | 
303 | 
304 | 
305 | 
306 | 
307 | 
308 | 
309 | 
310 | 
311 | 
312 | 
313 | 
--------------------------------------------------------------------------------
/MyFunctions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import scipy.signal as signal
4 | from scipy.signal import cheb2ord
5 | 
6 | def coherence(x):  # connectivity nonlinearity: per-trial, per-slice Pearson correlation between channels
7 |     if x.dim() == 4:  # expects trials × channels × timepoints × slices
8 |         y = torch.zeros(x.size()[0], x.size()[1], x.size()[1], x.size()[3])
9 |         for i in range(0, x.size()[0]):
10 |             for j in range(0, x.size()[3]):
11 |                 temp = x[i, :, :, j].squeeze(-1).squeeze(0)
12 |                 y[i, :, :, j] = torch.corrcoef(temp)
13 | 
14 |     return y
15 | 
16 | 
17 | 
18 | 
19 | def squeeze_final_output(x):
20 |     assert x.size()[3] == 1
21 |     x = x[:, :, :, 0]
22 |     return x
23 | 
24 | def squeeze_final_output_2d(x):
25 |     assert x.size()[3] == 1
26 |     x = x[:, :, :, 0]
27 |     if x.size()[2] == 1:
28 |         x = x[:, :, 0]
29 |     return x
30 | 
31 | 
32 | def squeeze_3rd_dim_output(x):
33 |     assert x.size()[2] == 1
34 |     x = torch.squeeze(x, 2)
35 |     return x
36 | 
37 | def shift_3rd_dim_output(x):
38 |     assert x.size()[2] == 1
39 |     x = torch.transpose(x, 2, 3)
40 |     return x
41 | 
42 | 
43 | 
44 | 
45 | 
--------------------------------------------------------------------------------
/MyModule.py:
--------------------------------------------------------------------------------
1 | from warnings import warn
2 | import numpy as np
3 | import pandas as pd
4 | import scipy.signal as signal
5 | from scipy.signal import cheb2ord
6 | import torch
7 | import torchaudio
8 | from torch import nn
9 | 
10 | class filterbank(nn.Module):
11 |     def __init__(self, fs, frequency_bands, filterStop=None, f_trans=1, gpass=3, gstop=30):
12 |         super(filterbank, self).__init__()
13 |         self.fs = fs
14 |         self.f_trans = f_trans
15 |         self.frequency_bands = frequency_bands
16 |         self.filterStop = filterStop
17 |         self.gpass = gpass
18 |         self.gstop = gstop
19 |         self.Nyquist_freq = self.fs / 2
20 |         self.nFilter = len(self.frequency_bands)
21 | 
22 | 
23 |     def forward(self, x):
24 |         while (len(x.shape) < 4):
25 |             x = x.unsqueeze(-1)
26 |         (n_trials, n_channels, n_samples, temp) = x.size()
27 |         all_filtered = torch.Tensor(np.zeros((n_trials, n_channels, n_samples, self.nFilter)))
28 | 
29 |         for i in range(self.nFilter):
30 |             (l_freq, h_freq) = self.frequency_bands[i]
31 |             f_pass = np.asarray([l_freq, h_freq])
32 |             if self.filterStop is not None:
33 |                 f_stop = np.asarray(self.filterStop[i])
34 |             else:
35 |                 f_stop = np.asarray([l_freq - self.f_trans, h_freq + self.f_trans])
36 |             wp = f_pass / self.Nyquist_freq
37 |             ws = f_stop / self.Nyquist_freq
38 |             order, wn = cheb2ord(wp, ws, self.gpass, self.gstop)
39 |             b, a = signal.cheby2(order, self.gstop, ws, btype='bandpass')  # Chebyshev II band-pass from the stopband edges
40 |             data = x[:, :, :, 0]
41 | 
42 |             torch_a = torch.as_tensor(a, dtype=data.dtype)
43 |             torch_b = torch.as_tensor(b, dtype=data.dtype)
44 |             for j in range(n_trials):
45 |                 all_filtered[j, :, :, i] = torchaudio.functional.lfilter(data[j, :, :], torch_a, torch_b)  # lfilter takes (waveform, a_coeffs, b_coeffs)
46 | 
47 |         return all_filtered
48 | 
49 | class coherence_cropped(nn.Module):
50 |     def __init__(self, time_length="auto", time_stride=1):
51 |         super(coherence_cropped, self).__init__()
52 |         self.time_length = time_length
53 |         self.time_stride = time_stride
54 | 
55 |     def forward(self, x):
56 |         while (len(x.shape) < 4):
57 |             x = x.unsqueeze(-1)
58 |         (n_trials, n_channels, n_samples, n_slice) = x.size()
59 |         if self.time_length == "auto":
60 |             self.time_length = n_samples  # resolved once, on the first forward call
61 |         n_windows_per_slice = int((n_samples - self.time_length) / self.time_stride) + 1
62 |         y = torch.zeros(n_trials, n_channels, n_channels, n_slice * n_windows_per_slice)
63 |         for i_trial in range(n_trials):
64 |             for i_slice in range(n_slice):
65 |                 for i in range(n_windows_per_slice):
66 |                     temp = x[i_trial, :, self.time_stride * i:self.time_stride * i + self.time_length, i_slice].squeeze(-1).squeeze(0)
67 |                     y[i_trial, :, :, i_slice * n_windows_per_slice + i] = torch.corrcoef(temp)
68 |         return y
69 | 
70 | def get_padding(kernel_size, stride=1, dilation=1, **_):
71 |     if isinstance(kernel_size, tuple):
72 |         kernel_size = max(kernel_size)
73 |     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
74 |     return padding
75 | 
76 | class group_temporal_filter(nn.Module):
77 |     def __init__(self,
78 |                  n_filterbank,
79 |                  n_filters_time,
80 |                  kernel_size_group,
81 |                  stride_size=1
82 |                  ):
83 |         super(group_temporal_filter, self).__init__()
84 |         self.n_filterbank = n_filterbank
85 |         self.n_filters_time = n_filters_time
86 |         self.kernel_size_group = kernel_size_group if isinstance(kernel_size_group, list) else [kernel_size_group]
87 |         self.n_group = len(self.kernel_size_group)
88 |         self.stride_size = stride_size
89 |         self.filter_list = nn.ModuleList([nn.Conv2d(
90 |             self.n_filterbank,
91 |             self.n_filters_time,
92 |             self.kernel_size_group[i],
93 |             stride=self.stride_size,
94 |             padding=(get_padding(self.kernel_size_group[i], self.stride_size), 0)
95 |         ) for i in range(self.n_group)])
96 | 
97 |         for layer in self.filter_list:
98 |             torch.nn.init.xavier_uniform_(layer.weight, gain=1)
99 |             torch.nn.init.constant_(layer.bias, 0)
100 | 
101 |     def forward(self, x):
102 |         # apply each temporal-filter group and concatenate the outputs along the channel dimension
103 |         outputs = []
104 |         for layer in self.filter_list:
105 |             outputs.append(layer(x))
106 |         y = torch.cat(outputs, dim=1)
107 |         return y
108 | 
109 | 
110 | 
111 | 
--------------------------------------------------------------------------------
/PowerAndConneMixedNet.py:
--------------------------------------------------------------------------------
1 | # This is the PyTorch implementation of the FBSTCNet-M architecture for EEG-based emotion classification.
2 | 
3 | # Reference:
4 | # "W. Huang, W. Wang, Y. Li, W. Wu. FBSTCNet: A Spatio-Temporal Convolutional Network Integrating Power and Connectivity Features for EEG-Based Emotion Decoding. IEEE Transactions on Affective Computing, 15(4), 1906-1918, 2024."
5 | 
6 | 
7 | import torch
8 | import numpy as np
9 | from torch import nn
10 | from torch.nn import init
11 | 
12 | from braindecode.util import np_to_th
13 | from braindecode.models.modules import Expression, Ensure4d
14 | from braindecode.models.functions import (
15 |     safe_log, square, transpose_time_to_spat
16 | )
17 | from MyModule import coherence_cropped, filterbank
18 | from MyFunctions import coherence, squeeze_final_output, shift_3rd_dim_output
19 | 
20 | 
21 | def get_padding(kernel_size, stride=1, dilation=1, **_):
22 |     if isinstance(kernel_size, tuple):
23 |         kernel_size = max(kernel_size)
24 |     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
25 |     return padding
26 | 
27 | class PowerAndConneMixedNet(nn.Module):
28 | 
29 | 
30 |     def __init__(
31 |         self,
32 |         in_chans,
33 |         n_classes,
34 |         fs=100,
35 |         f_trans=2,
36 |         filterRange=None,
37 |         filterStop=None,
38 |         input_window_samples=None,
39 |         n_filters_time=72,
40 |         filter_time_length=25,
41 |         n_filters_spat=72,
42 |         n_filters_power=36,
43 |         pool_time_length=80,
44 |         pool_time_stride=5,
45 |         final_conv_length=35,
46 |         final_conv_stride=25,
47 |         conn_nonlin=coherence,
48 |         pool_mode="mean",
49 |         pool_nonlin=safe_log,
50 |         split_first_layer=True,
51 |         batch_norm=True,
52 |         same_filters_for_features=True,
53 |         batch_norm_alpha=0.1,
54 |         drop_prob=0.5,
55 |     ):
56 |         super().__init__()
57 |         if final_conv_length == "auto":
58 |             assert input_window_samples is not None
59 |         self.in_chans = in_chans
60 |         self.n_classes = n_classes
61 |         self.input_window_samples = input_window_samples
62 |         self.n_filters_time = n_filters_time
63 |         self.filter_time_length = filter_time_length
64 |         self.n_filters_spat = n_filters_spat
65 |         self.pool_time_length = pool_time_length
66 |         self.pool_time_stride = pool_time_stride
67 |         self.final_conv_length = final_conv_length
68 |         self.final_conv_stride = final_conv_stride
69 |         self.same_filters_for_features = same_filters_for_features
70 |         if (self.same_filters_for_features):
71 |             self.n_filters_power = self.n_filters_spat
72 |             self.n_filters_coherence = self.n_filters_spat
73 |         else:
74 |             self.n_filters_power = n_filters_power
75 |             self.n_filters_coherence = self.n_filters_spat - self.n_filters_power
76 | 
77 |         self.conn_nonlin = conn_nonlin
78 |         self.pool_mode = pool_mode
79 |         self.pool_nonlin = pool_nonlin
80 |         self.split_first_layer = split_first_layer
81 |         self.batch_norm = batch_norm
82 |         self.batch_norm_alpha = batch_norm_alpha
83 |         self.drop_prob = drop_prob
84 |         self.filterRange = filterRange
85 |         self.n_filterbank = len(self.filterRange) if self.filterRange is not None else 1
86 |         self.fs = fs
87 |         self.f_trans = f_trans
88 |         self.filterStop = filterStop
89 | 
90 |         # filter bank
91 |         if self.filterRange is not None:
self.add_module("filterbank", filterbank(fs=self.fs,frequency_bands=self.filterRange,filterStop=self.filterStop,f_trans=self.f_trans)) 93 | else: 94 | self.add_module("filterbank", Ensure4d()) # [numSample × numChannel × numPoint × 1] 95 | pool_class = dict(max=nn.MaxPool2d, mean=nn.AvgPool2d)[self.pool_mode] 96 | padding_size = get_padding((self.filter_time_length, 1)) 97 | 98 | self.add_module("dimshuffle", Expression(transpose_time_to_spat)) # [numSample × 1 × numPoint × numChannel] 99 | 100 | # temporal convolution 101 | self.add_module( 102 | "conv_time", 103 | nn.Conv2d( 104 | self.n_filterbank, 105 | self.n_filters_time, 106 | (self.filter_time_length, 1), 107 | stride=1, 108 | padding=(padding_size,0), 109 | ), 110 | ) 111 | self.add_module("conv_nonlin_exp", Expression(square)) 112 | self.add_module( 113 | "poolfunc", 114 | pool_class( 115 | kernel_size=(pool_time_length, 1), 116 | stride=(pool_time_stride, 1), 117 | ), 118 | ) 119 | 120 | # spatial convolutions for power and connectivity-based network, respectively 121 | self.add_module( 122 | "conv_spat_power", 123 | nn.Conv2d( 124 | self.n_filters_power, 125 | self.n_filters_power, 126 | (1, self.in_chans), 127 | stride=1, 128 | groups=self.n_filters_power, 129 | bias=not self.batch_norm, 130 | ), 131 | ) 132 | self.add_module( 133 | "conv_spat_conne", 134 | nn.Conv2d( 135 | self.n_filters_coherence, 136 | self.n_filters_coherence, 137 | (1, self.in_chans), 138 | stride=1, 139 | groups=self.n_filters_coherence, 140 | bias=not self.batch_norm, 141 | ), 142 | ) 143 | if self.batch_norm: 144 | self.add_module( 145 | "bnorm_power", 146 | nn.BatchNorm2d( 147 | self.n_filters_power, momentum=self.batch_norm_alpha, affine=True 148 | ), 149 | ) 150 | self.add_module( 151 | "bnorm_conne", 152 | nn.BatchNorm2d( 153 | self.n_filters_coherence, momentum=self.batch_norm_alpha, affine=True 154 | ), 155 | ) 156 | #power-based feature extraction 157 | self.add_module("pool_nonlin_exp", Expression(safe_log)) 158 | 159 | #connectivity-based feature extraction 160 | self.connectivity_exp = coherence_cropped(time_length=200, time_stride=100) 161 | self.add_module("power_drop", nn.Dropout(p=self.drop_prob)) 162 | self.add_module("conne_drop", nn.Dropout(p=self.drop_prob)) 163 | 164 | #final convolution 165 | self.add_module( 166 | "conv_power_classifier", 167 | nn.Conv2d( 168 | self.n_filters_power, 169 | self.n_classes, 170 | (self.final_conv_length, 1), 171 | stride=(self.final_conv_stride, 1), 172 | bias=True, 173 | ), 174 | ) 175 | self.add_module( 176 | "conv_conn_classifier", 177 | nn.Conv2d( 178 | self.n_filters_coherence, 179 | self.n_classes, 180 | (self.n_filters_coherence, 1), 181 | bias=True, 182 | ), 183 | ) 184 | self.add_module("conne_shift", Expression(shift_3rd_dim_output)) 185 | 186 | self.add_module("softmax", nn.LogSoftmax(dim=1)) 187 | self.add_module("squeeze", Expression(squeeze_final_output)) 188 | 189 | 190 | init.xavier_uniform_(self.conv_time.weight, gain=1) 191 | init.constant_(self.conv_time.bias, 0) 192 | init.xavier_uniform_(self.conv_conn_classifier.weight, gain=1) 193 | init.constant_(self.conv_conn_classifier.bias, 0) 194 | init.xavier_uniform_(self.conv_power_classifier.weight, gain=1) 195 | init.constant_(self.conv_power_classifier.bias, 0) 196 | if self.batch_norm: 197 | init.constant_(self.bnorm_power.weight, 1) 198 | init.constant_(self.bnorm_power.bias, 0) 199 | init.constant_(self.bnorm_conne.weight, 1) 200 | init.constant_(self.bnorm_conne.bias, 0) 201 | init.xavier_uniform_(self.conv_spat_conne.weight, 
202 |         if not self.batch_norm:
203 |             init.constant_(self.conv_spat_conne.bias, 0)
204 |         init.xavier_uniform_(self.conv_spat_power.weight, gain=1)
205 |         if not self.batch_norm:
206 |             init.constant_(self.conv_spat_power.bias, 0)
207 | 
208 |     def forward(self, x):
209 | 
210 |         x = self.filterbank(x)
211 |         x = self.dimshuffle(x)
212 |         x = self.conv_time(x)
213 |         if self.n_filters_power > 0:
214 |             x1 = x[:, 0:self.n_filters_power, :, :]
215 |             x1 = self.conv_spat_power(x1)
216 |             if self.batch_norm:
217 |                 x1 = self.bnorm_power(x1)
218 |             x1 = self.conv_nonlin_exp(x1)
219 |             x1 = self.poolfunc(x1)
220 |             x1 = self.pool_nonlin_exp(x1)
221 |             x1 = self.power_drop(x1)
222 |             x1 = self.conv_power_classifier(x1)
223 | 
224 |         if self.n_filters_coherence > 1:
225 |             if self.same_filters_for_features:
226 |                 x2 = x[:, 0:self.n_filters_coherence, :, :]
227 |             else:
228 |                 x2 = x[:, self.n_filters_power:self.n_filters_power + self.n_filters_coherence, :, :]
229 |             x2 = self.conv_spat_conne(x2)
230 |             if self.batch_norm:
231 |                 x2 = self.bnorm_conne(x2)
232 |             x2 = self.connectivity_exp(x2)
233 |             x2 = self.conv_conn_classifier(x2)
234 |             x2 = self.conne_shift(x2)
235 | 
236 |         if self.n_filters_power > 0:
237 |             if self.n_filters_coherence > 1:
238 |                 xout = torch.cat((x1, x2), dim=2)
239 |             else:
240 |                 xout = x1
241 |         else:
242 |             xout = x2
243 |         xout = self.softmax(xout)
244 |         xout = self.squeeze(xout)
245 | 
246 |         return xout
247 | 
248 | 
249 | 
250 | 
251 | 
252 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FBSTCNet
2 | This is the PyTorch implementation of the FBSTCNet-M architecture for EEG-based emotion classification.
3 | 
4 | ## File Descriptions
5 | 
6 | * [PowerAndConneMixedNet.py](https://github.com/TimeSpacerRob/FBSTCNet/blob/main/PowerAndConneMixedNet.py) — Code for the FBSTCNet-M network.
7 | 
8 | * [MixedNet_main.py](https://github.com/TimeSpacerRob/FBSTCNet/blob/main/MixedNet_main.py) — Example code for classifying the SEED dataset with FBSTCNet-M.
9 | 
10 | # Reference
11 | Weichen Huang, Wenlong Wang, Yuanqing Li, Wei Wu, "FBSTCNet: A Spatio-Temporal Convolutional Network Integrating Power and Connectivity Features for EEG-Based Emotion Decoding", _IEEE Transactions on Affective Computing_, 15(4), 1906-1918, 2024, DOI: [10.1109/TAFFC.2024.3385651](http://dx.doi.org/10.1109/TAFFC.2024.3385651).
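
## Usage

Below is a minimal, illustrative sketch of constructing the network with the same SEED settings used in `MixedNet_main.py` (62 channels, 3 classes, 200 Hz sampling, 5 s windows); the random input tensor stands in for real EEG windows and is for shape-checking only.

```python
import torch
from PowerAndConneMixedNet import PowerAndConneMixedNet

# 4 Hz-wide bands from 4 to 52 Hz, as used in MixedNet_main.py
filterRange = [(f, f + 4) for f in range(4, 52, 4)]

model = PowerAndConneMixedNet(
    62,                         # SEED montage: 62 EEG channels
    3,                          # positive / negative / neutral
    fs=200,
    filterRange=filterRange,
    input_window_samples=1000,  # 5 s at 200 Hz
    same_filters_for_features=False,
)

x = torch.randn(2, 62, 1000)    # (batch, channels, timepoints), dummy data
out = model(x)                  # per-crop class log-probabilities, shape (batch, classes, n_crops)
print(out.shape)
```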
12 | 
13 | 
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | # This file may be used to create an Anaconda environment
2 | name: eeg_ml
3 | channels:
4 |   - pytorch
5 |   - defaults
6 | dependencies:
7 |   - blas=1.0=mkl
8 |   - bottleneck=1.3.4=py39h080aedc_0
9 |   - brotli=1.0.9=ha925a31_2
10 |   - brotlipy=0.7.0=py39h2bbff1b_1003
11 |   - ca-certificates=2022.10.11=haa95532_0
12 |   - captum=0.5.0=0
13 |   - certifi=2022.6.15=py39haa95532_0
14 |   - cffi=1.15.0=py39h2bbff1b_1
15 |   - cpuonly=2.0=0
16 |   - cryptography=36.0.0=py39h21b164f_0
17 |   - cycler=0.11.0=pyhd3eb1b0_0
18 |   - fonttools=4.25.0=pyhd3eb1b0_0
19 |   - freetype=2.10.4=hd328e21_0
20 |   - h5py=3.6.0=py39h3de5c98_0
21 |   - hdf5=1.10.6=h7ebc959_0
22 |   - icc_rt=2019.0.0=h0cc432a_1
23 |   - icu=58.2=ha925a31_3
24 |   - idna=3.3=pyhd3eb1b0_0
25 |   - intel-openmp=2021.4.0=haa95532_3556
26 |   - joblib=1.1.0=pyhd3eb1b0_0
27 |   - jpeg=9d=h2bbff1b_0
28 |   - kiwisolver=1.3.2=py39hd77b12b_0
29 |   - libpng=1.6.37=h2a8f88b_0
30 |   - libtiff=4.2.0=hd0e1b90_0
31 |   - libuv=1.40.0=he774522_0
32 |   - libwebp=1.2.2=h2bbff1b_0
33 |   - lz4-c=1.9.3=h2bbff1b_1
34 |   - matplotlib=3.5.1=py39haa95532_1
35 |   - matplotlib-base=3.5.1=py39hd77b12b_1
36 |   - mkl=2021.4.0=haa95532_640
37 |   - mkl-service=2.4.0=py39h2bbff1b_0
38 |   - mkl_fft=1.3.1=py39h277e83a_0
39 |   - mkl_random=1.2.2=py39hf11a4ad_0
40 |   - munkres=1.1.4=py_0
41 |   - numexpr=2.8.1=py39hb80d3ca_0
42 |   - numpy=1.21.5=py39h7a0a035_1
43 |   - numpy-base=1.21.5=py39hca35cd5_1
44 |   - openssl=1.1.1s=h2bbff1b_0
45 |   - packaging=21.3=pyhd3eb1b0_0
46 |   - pandas=1.4.1=py39hd77b12b_1
47 |   - pillow=9.0.1=py39hdc2b20a_0
48 |   - pip=21.2.4=py39haa95532_0
49 |   - pycparser=2.21=pyhd3eb1b0_0
50 |   - pyopenssl=22.0.0=pyhd3eb1b0_0
51 |   - pyparsing=3.0.4=pyhd3eb1b0_0
52 |   - pyqt=5.9.2=py39hd77b12b_6
53 |   - pyreadline=2.1=py39haa95532_1
54 |   - pysocks=1.7.1=py39haa95532_0
55 |   - python=3.9.12=h6244533_0
56 |   - python-dateutil=2.8.2=pyhd3eb1b0_0
57 |   - pytorch=1.11.0=py3.9_cpu_0
58 |   - pytorch-mutex=1.0=cpu
59 |   - pytz=2021.3=pyhd3eb1b0_0
60 |   - qt=5.9.7=vc14h73c81de_0
61 |   - requests=2.27.1=pyhd3eb1b0_0
62 |   - scipy=1.7.3=py39h0a974cb_0
63 |   - sip=4.19.13=py39hd77b12b_0
64 |   - six=1.16.0=pyhd3eb1b0_1
65 |   - sqlite=3.38.2=h2bbff1b_0
66 |   - tk=8.6.11=h2bbff1b_0
67 |   - torchaudio=0.11.0=py39_cpu
68 |   - torchvision=0.12.0=py39_cpu
69 |   - tornado=6.1=py39h2bbff1b_0
70 |   - typing_extensions=4.1.1=pyh06a4308_0
71 |   - tzdata=2022a=hda174b7_0
72 |   - vc=14.2=h21ff451_1
73 |   - vs2015_runtime=14.27.29016=h5e58377_2
74 |   - wheel=0.37.1=pyhd3eb1b0_0
75 |   - win_inet_pton=1.1.0=py39haa95532_0
76 |   - wincertstore=0.2=py39haa95532_2
77 |   - xlwt=1.3.0=py39haa95532_0
78 |   - xz=5.2.5=h62dcd97_0
79 |   - zlib=1.2.12=h8cc25b3_1
80 |   - zstd=1.4.9=h19a0ad4_0
81 |   - pip:
82 |     - appdirs==1.4.4
83 |     - braindecode==0.6
84 |     - charset-normalizer==2.0.12
85 |     - colorama==0.4.4
86 |     - coverage==5.5
87 |     - dalib==0.2
88 |     - decorator==5.1.1
89 |     - jinja2==3.1.1
90 |     - llvmlite==0.39.1
91 |     - markupsafe==2.1.1
92 |     - mne==1.0.2
93 |     - moabb==0.4.6
94 |     - numba==0.56.2
95 |     - pooch==1.6.0
96 |     - prettytable==3.4.1
97 |     - pyecharts==1.9.1
98 |     - pyriemann==0.2.7
99 |     - pyyaml==5.4.1
100 |     - qpsolvers==2.7.1
101 |     - resampy==0.4.2
102 |     - scikit-learn==1.0.2
103 |     - seaborn==0.11.2
104 |     - setuptools==59.8.0
105 |     - simplejson==3.17.6
106 |     - skorch==0.11.0
107 |     - tabulate==0.8.9
108 |     - threadpoolctl==3.1.0
109 |     - torchsummary==1.5.1
110 |     - tqdm==4.64.0
111 |     - urllib3==1.26.9
112 |     - wcwidth==0.2.5
113 | prefix: D:\anaconda3\envs\eeg_ml
114 | 
--------------------------------------------------------------------------------