├── preprocessing
│   ├── preprocessing_order.txt
│   ├── load_mat.py
│   ├── preprocessing.py
│   ├── epoch.py
│   ├── window.py
│   └── to3d.py
├── README.md
├── run_sta_net.py
└── sta.py

--------------------------------------------------------------------------------
/preprocessing/preprocessing_order.txt:
--------------------------------------------------------------------------------
load_mat-preprocessing-epoch-to3d-window

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# STA-Net

STA-Net: Spatial–temporal alignment network for hybrid EEG-fNIRS decoding [[Paper](https://www.sciencedirect.com/science/article/pii/S156625352500096X)]

## Requirements
- Python 3.9.7
- TensorFlow 2.10 (**Note**: please make sure your TensorFlow version is **2.10**. In more recent versions the expected input shape of some layers, such as `Conv3D`, has changed, which may cause bugs.)

## Dataset

- [Open access dataset for simultaneous EEG and NIRS brain-computer interface (BCI)](https://doc.ml.tu-berlin.de/hBCI/contactthanks.php)
- [Simultaneous acquisition of EEG and NIRS during cognitive tasks for an open access dataset](https://doc.ml.tu-berlin.de/simultaneous_EEG_NIRS/)

## Citation
```
@article{liu2025sta,
  title={STA-Net: Spatial--temporal alignment network for hybrid EEG-fNIRS decoding},
  author={Liu, Mutian and Yang, Banghua and Meng, Lin and Zhang, Yonghuai and Gao, Shouwei and Zan, Peng and Xia, Xinxing},
  journal={Information Fusion},
  pages={103023},
  year={2025},
  publisher={Elsevier}
}
```

## Contact
If you have any questions, please contact us at shulmt@shu.edu.cn

--------------------------------------------------------------------------------
/preprocessing/load_mat.py:
--------------------------------------------------------------------------------
import scipy.io as io
import numpy as np
import os

# collect subject IDs (e.g. 'VP001') from the EEG folder names
subject_list = []

folder_path = r'E:\IF\review\new_dataset\EEG_01-26_MATLAB'
for filename in os.listdir(folder_path):
    subject_no = filename.split('-')[0]
    subject_list.append(subject_no)

for name in subject_list:
    eeg_data = io.loadmat(r'E:\IF\review\new_dataset\EEG_01-26_MATLAB\{}-EEG\cnt_wg.mat'.format(name))
    eeg_mrk_data = io.loadmat(r'E:\IF\review\new_dataset\EEG_01-26_MATLAB\{}-EEG\mrk_wg.mat'.format(name))

    fnirs_data = io.loadmat(r'E:\IF\review\new_dataset\NIRS_01-26_MATLAB\{}-NIRS\cnt_wg.mat'.format(name))
    fnirs_mrk_data = io.loadmat(r'E:\IF\review\new_dataset\NIRS_01-26_MATLAB\{}-NIRS\mrk_wg.mat'.format(name))

    # continuous signals, transposed to (channels, samples)
    eeg = eeg_data['cnt_wg'][0,0][3].T
    eeg_time = eeg_mrk_data['mrk_wg'][0,0][0]

    hbo = fnirs_data['cnt_wg']['oxy'][0,0][0,0][5].T
    hbr = fnirs_data['cnt_wg']['deoxy'][0,0][0,0][5].T
    fnirs_time = fnirs_mrk_data['mrk_wg'][0,0][0]

    label = eeg_mrk_data['mrk_wg'][0,0][1]

    print(eeg.shape)
    print(eeg_time.shape)

    print(hbo.shape)
    print(hbr.shape)
    print(fnirs_time.shape)

    print(label.shape)

    save_dict = {
        'eeg':eeg,
        'eeg_time':eeg_time,
        'hbo':hbo,
        'hbr':hbr,
        'fnirs_time':fnirs_time,
        'label':label
    }

    save_dir = r'E:\IF\review\new_dataset\mat2array'
    save_name = name

    np.savez(os.path.join(save_dir,save_name),**save_dict)
    print('\n==============save {} success=============\n'.format(save_name))
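Before moving on to preprocessing.py, it is worth reloading one of the archives written above and confirming the array shapes. A minimal sketch; the subject ID `VP001` is a placeholder for whichever files actually exist in `mat2array`:

```python
import numpy as np

# 'VP001' is a hypothetical subject ID; substitute a file written by load_mat.py
with np.load(r'E:\IF\review\new_dataset\mat2array\VP001.npz') as data:
    for key in ('eeg', 'eeg_time', 'hbo', 'hbr', 'fnirs_time', 'label'):
        print(key, data[key].shape)
```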
--------------------------------------------------------------------------------
/preprocessing/preprocessing.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import mne
import matplotlib.pyplot as plt

# eeg info
eeg_chn_names = ['Fp1','AFF5h','AFz','F1','FC5','FC1','T7','C3','Cz','CP5','CP1','P7','P3','Pz','POz','O1','Fp2',
                 'AFF6h','F2','FC2','FC6','C4','T8','CP2','CP6','P4','P8','O2']
eeg_info = mne.create_info(ch_names=eeg_chn_names, sfreq=200, ch_types='eeg')
eeg_info.set_montage('standard_1005')

# fnirs info
fnirs_chn_names = ['AF7','AFF5','AFp7','AF5h','AFp3','AFF3h','AF1','AFFz','AFpz','AF2','AFp4','FCC3','C3h','C5h','CCP3','CPP3','P3h','P5h','PPO3','AFF4h','AF6h','AFF6','AFp8','AF8','FCC4','C6h','C4h','CCP4','CPP4','P6h','P4h','PPO4','PPOz','PO1','PO2','POOz']
fnirs_info = mne.create_info(ch_names=fnirs_chn_names, sfreq=10, ch_types='eeg')
fnirs_info.set_montage('standard_1005')

# begin preprocess
subject_no = 'VP026'
with np.load(r'E:\IF\review\new_dataset\mat2array\{}.npz'.format(subject_no)) as data:
    eeg = data['eeg']
    eeg_time = data['eeg_time']
    hbo = data['hbo']
    hbr = data['hbr']
    fnirs_time = data['fnirs_time']
    label = data['label']


# eeg (the last two rows are not EEG channels and are dropped)
raw = mne.io.RawArray(data=eeg[:-2,:], info=eeg_info)

# 50 Hz notch, then 0.5-50 Hz band-pass (order-6 Butterworth IIR)
raw_notch = raw.notch_filter(np.arange(50, 100, 50))
raw_filtered = raw_notch.filter(0.5, 50., method='iir', iir_params=dict(order=6, ftype='butter'))

raw_avg_ref = raw_filtered.set_eeg_reference(ref_channels="average")

raw_avg_ref.load_data()

# filtering just for ICA (a 1 Hz high-pass improves the ICA decomposition)
filt_ica_raw = raw_avg_ref.copy().filter(l_freq=1., h_freq=None)

ica = mne.preprocessing.ICA(n_components=20)
ica.fit(filt_ica_raw)

ica.plot_sources(raw_avg_ref)
plt.show()

# artifact components are picked by hand, entered space-separated, e.g. '0 3 7'
input_str = input('exclude components:')
exclude_list = [int(j) for j in input_str.split()]

ica.exclude = exclude_list
print(ica.exclude)
raw_icaed = ica.apply(raw_avg_ref)

eeg_processed = raw_icaed.get_data()


# fnirs: 0.01-0.1 Hz band-pass to isolate the hemodynamic response
hbo_raw = mne.io.RawArray(data=hbo, info=fnirs_info)
hbr_raw = mne.io.RawArray(data=hbr, info=fnirs_info)

hbo_filtered = hbo_raw.filter(0.01, 0.1, method='iir', iir_params=dict(order=6, ftype='butter'))
hbr_filtered = hbr_raw.filter(0.01, 0.1, method='iir', iir_params=dict(order=6, ftype='butter'))

hbo_processed = hbo_filtered.get_data()
hbr_processed = hbr_filtered.get_data()

save_dict = {
    'eeg':eeg_processed,
    'eeg_time':eeg_time,
    'hbo':hbo_processed,
    'hbr':hbr_processed,
    'fnirs_time':fnirs_time,
    'label':label
}

save_dir = r'E:\IF\review\new_dataset\preprocessed'
save_name = subject_no

np.savez(os.path.join(save_dir,save_name),**save_dict)
print('\n==============save {} success=============\n'.format(save_name))
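The EEG band-pass above delegates to an order-6 Butterworth IIR design. For readers who want to eyeball the nominal frequency response, here is a sketch of the equivalent one-pass design in SciPy (not part of the pipeline; note that MNE applies IIR filters forward and backward by default, so the effective attenuation is roughly the square of this single-pass response):

```python
import numpy as np
from scipy import signal

# same nominal design as the MNE call: order-6 Butterworth, 0.5-50 Hz at fs=200
sos = signal.butter(6, [0.5, 50.], btype='bandpass', fs=200., output='sos')
w, h = signal.sosfreqz(sos, worN=4096, fs=200.)
for f in (0.5, 25., 50., 60.):
    print(f, 'Hz gain ~', round(float(abs(h[np.argmin(np.abs(w - f))])), 3))
```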
--------------------------------------------------------------------------------
/preprocessing/epoch.py:
--------------------------------------------------------------------------------
import numpy as np
import os
import mne

task_period = 10

eeg_sample_rate = 200
eeg_pre_onset = 5
eeg_post_onset = task_period

fnirs_sample_rate = 10
fnirs_pre_onset = 5
fnirs_post_onset = task_period + 12  # keep 12 s after the task so the lagged fNIRS windows (see window.py) fit

fnirs_chn_names = ['AF7','AFF5','AFp7','AF5h','AFp3','AFF3h','AF1','AFFz','AFpz','AF2','AFp4','FCC3','C3h','C5h','CCP3','CPP3','P3h','P5h','PPO3','AFF4h','AF6h','AFF6','AFp8','AF8','FCC4','C6h','C4h','CCP4','CPP4','P6h','P4h','PPO4','PPOz','PO1','PO2','POOz']
fnirs_info = mne.create_info(ch_names=fnirs_chn_names, sfreq=10, ch_types='eeg')
fnirs_info.set_montage('standard_1005')

subject_path = r'E:\IF\review\new_dataset\preprocessed'
subject_list = os.listdir(subject_path)

for subject in subject_list:
    with np.load(os.path.join(subject_path,subject)) as data:
        eeg = data['eeg']
        eeg_time = data['eeg_time']
        hbo = data['hbo']
        hbr = data['hbr']
        fnirs_time = data['fnirs_time']
        label = data['label']

    # epoch: 60 trials per subject
    eeg_epoch = np.ones((60, 28, eeg_sample_rate*(eeg_pre_onset+eeg_post_onset)), dtype=np.float64)
    hbo_epoch = np.ones((60, 36, fnirs_sample_rate*(fnirs_pre_onset+fnirs_post_onset)), dtype=np.float64)
    hbr_epoch = np.ones((60, 36, fnirs_sample_rate*(fnirs_pre_onset+fnirs_post_onset)), dtype=np.float64)

    for t in range(60):
        # eeg (marker times are in ms)
        eeg_start_indice = int((eeg_time[0, t]/1000.-eeg_pre_onset)*eeg_sample_rate)
        eeg_end_indice = int(eeg_start_indice + (eeg_pre_onset+eeg_post_onset)*eeg_sample_rate)

        eeg_one_epoch = eeg[:, eeg_start_indice:eeg_end_indice]
        eeg_epoch[t,] = eeg_one_epoch

        # fnirs
        fnirs_start_indice = int((fnirs_time[0, t]/1000.-fnirs_pre_onset)*fnirs_sample_rate)
        fnirs_end_indice = int(fnirs_start_indice + (fnirs_pre_onset+fnirs_post_onset)*fnirs_sample_rate)

        hbo_one_epoch = hbo[:, fnirs_start_indice:fnirs_end_indice]
        hbr_one_epoch = hbr[:, fnirs_start_indice:fnirs_end_indice]
        hbo_epoch[t,] = hbo_one_epoch
        hbr_epoch[t,] = hbr_one_epoch

    # fnirs baseline correction: EpochsArray defaults to tmin=0, so epoch times run
    # from epoch start and (None, 3.) uses the first 3 s (pre-onset) as baseline
    hbo_raw_bc = mne.EpochsArray(data=hbo_epoch, info=fnirs_info, baseline=(None, 3.))
    hbr_raw_bc = mne.EpochsArray(data=hbr_epoch, info=fnirs_info, baseline=(None, 3.))
    hbo_epoch_bc = hbo_raw_bc.get_data()
    hbr_epoch_bc = hbr_raw_bc.get_data()

    print(eeg_epoch.shape)
    print(hbo_epoch_bc.shape)
    print(hbr_epoch_bc.shape)
    print(label.shape)

    save_dict = {
        'eeg':eeg_epoch,
        'hbo':hbo_epoch_bc,
        'hbr':hbr_epoch_bc,
        'label':label
    }

    save_dir = r'E:\IF\review\new_dataset\epoch'
    save_name = subject

    np.savez(os.path.join(save_dir,save_name),**save_dict)
    print('\n==============save {} success=============\n'.format(save_name))

--------------------------------------------------------------------------------
/preprocessing/window.py:
--------------------------------------------------------------------------------
import numpy as np
import os

# stage 1 (run first, then commented out): slice each 3-D epoch into 3 s
# sliding windows with a 1 s step
'''
win_step = 1
win_length = 3

eeg_segments_number = 10
fnirs_segments_number = 22

eeg_srate = 200
fnirs_srate = 10

subject_path = r'E:\IF\review\new_dataset\d3'
subject_list = os.listdir(subject_path)

for subject in subject_list:
    with np.load(os.path.join(subject_path, subject)) as data:
        eeg = data['eeg']
        hbo = data['hbo']
        hbr = data['hbr']
        label = data['label']

    eeg_window = np.ones((60, eeg_segments_number, 16, 16, win_length*eeg_srate))
    hbo_window = np.ones((60, fnirs_segments_number, 16, 16, win_length*fnirs_srate))
    hbr_window = np.ones((60, fnirs_segments_number, 16, 16, win_length*fnirs_srate))

    for e in range(60):
        # the first 10 fNIRS windows share the EEG windows' start times
        for w in range(eeg_segments_number):
            eeg_start_indice = (3+w)*eeg_srate
            eeg_end_indice = eeg_start_indice + win_length*eeg_srate

            eeg_segment = eeg[e, :, :, eeg_start_indice:eeg_end_indice]

            eeg_window[e, w, :, :, :] = eeg_segment

        for fw in range(fnirs_segments_number):
            fnirs_start_indice = (3+fw)*fnirs_srate
            fnirs_end_indice = fnirs_start_indice + win_length*fnirs_srate

            hbo_segment = hbo[e, :, :, fnirs_start_indice:fnirs_end_indice]
            hbr_segment = hbr[e, :, :, fnirs_start_indice:fnirs_end_indice]

            hbo_window[e, fw, :, :, :] = hbo_segment
            hbr_window[e, fw, :, :, :] = hbr_segment

    print(eeg_window.shape)
    print(hbo_window.shape)
    print(hbr_window.shape)
    print(label.shape)

    save_dict = {
        'eeg':eeg_window,
        'hbo':hbo_window,
        'hbr':hbr_window,
        'label':label
    }

    save_dir = r'E:\IF\review\new_dataset\window'
    save_name = subject

    np.savez(os.path.join(save_dir,save_name),**save_dict)
    print('\n==============save {} success=============\n'.format(save_name))
'''


# stage 2: build model inputs; each EEG window is paired with 11 lagged fNIRS windows
fnirs_lag_length = 11  # lags 0-10, including the window aligned with the EEG window itself

subject_path = r'E:\IF\review\new_dataset\window'
subject_list = os.listdir(subject_path)

for subject in subject_list:
    with np.load(os.path.join(subject_path, subject)) as data:
        eeg = data['eeg']
        hbo = data['hbo']
        hbr = data['hbr']
        label = data['label']

    # eeg: (60, 10, 16, 16, 600) -> (600, 16, 16, 600, 1)
    eeg_session_dataset = np.expand_dims(eeg, axis=-1)
    eeg_input = eeg_session_dataset.reshape(600, 16, 16, 600, 1)

    # fnirs: stack HbO/HbR as two channels, 11 lagged windows per EEG window
    fnirs_session_dataset = np.ones((60, 10, fnirs_lag_length, 16, 16, 30, 2))

    for e in range(60):
        for w in range(10):
            # fNIRS windows w..w+10 start at or after EEG window w
            hbo_sample = hbo[e, w:(w+fnirs_lag_length),]
            hbr_sample = hbr[e, w:(w+fnirs_lag_length),]

            hbo_sample = np.expand_dims(hbo_sample, axis=-1)
            hbr_sample = np.expand_dims(hbr_sample, axis=-1)

            fnirs_sample = np.concatenate((hbo_sample, hbr_sample), axis=-1)

            fnirs_session_dataset[e, w,] = fnirs_sample

    fnirs_input = fnirs_session_dataset.reshape(600, 11, 16, 16, 30, 2)

    # label: repeat each trial label for its 10 windows
    label_session_dataset = label.T
    label_input = np.repeat(label_session_dataset, repeats=10, axis=0)

    print(eeg_input.shape)
    print(fnirs_input.shape)
    print(label_input.shape)

    save_dict = {
        'eeg':eeg_input,
        'fnirs':fnirs_input,
        'label':label_input
    }

    save_dir = r'E:\IF\review\new_dataset\model_input'
    save_name = subject

    np.savez(os.path.join(save_dir,save_name),**save_dict)
    print('\n==============save {} success=============\n'.format(save_name))
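The window bookkeeping implied by the constants above, spelled out: epochs begin 5 s before onset, and 3 s windows slide in 1 s steps starting 3 s into the epoch. A small sketch (the helper `n_windows` is hypothetical, for illustration only):

```python
def n_windows(epoch_s, start_s, win_s, step_s=1):
    """Number of win_s-second windows sliding by step_s from start_s."""
    return (epoch_s - start_s - win_s) // step_s + 1

print(n_windows(15, 3, 3))  # EEG epoch: 5 s pre + 10 s task      -> 10 windows
print(n_windows(27, 3, 3))  # fNIRS epoch: 5 s pre + 10 s + 12 s  -> 22 windows

# Each EEG window w is paired with fNIRS windows w..w+10 (11 lags), so the
# last EEG window (w=9) uses fNIRS windows 9..19, within the 22 available.
```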
--------------------------------------------------------------------------------
/run_sta_net.py:
--------------------------------------------------------------------------------
from sta import sta_net

import numpy as np
import tensorflow as tf
from tensorflow import keras
import os


class targetacccallback(keras.callbacks.Callback):
    """Stop stage-two training once the training loss reaches the target value
    recorded at the best validation epoch of stage one."""
    def __init__(self, target_acc):
        super().__init__()
        self.target_acc = target_acc

    def on_epoch_end(self, epoch, logs=None):
        if logs and logs['class_output_loss'] <= self.target_acc:
            print("\nReached target loss value {} so cancelling training!\n".format(self.target_acc))
            self.model.stop_training = True


subject_path = r'E:\IF\dataset\model_input'
subject_list = os.listdir(subject_path)

for subject in subject_list:
    with np.load(os.path.join(subject_path, subject)) as data:
        eeg = data['eeg']
        fnirs = data['fnirs']
        label = data['label']

    fnirs *= 1e3  # rescale fNIRS to a magnitude comparable with the EEG input

    label = label.astype(float)

    # leave-one-session-out: each of the 3 sessions holds 200 samples (20 trials x 10 windows)
    for session in range(3):
        all_eeg = np.delete(eeg, slice(session*200, (session+1)*200), 0)
        all_fnirs = np.delete(fnirs, slice(session*200, (session+1)*200), 0)
        all_label = np.delete(label, slice(session*200, (session+1)*200), 0)

        # stage-two training set: all non-test samples
        second_train_dataset = tf.data.Dataset.from_tensor_slices(
            (
                {"eeg_input": all_eeg, "fnirs_input": all_fnirs},
                {"class_output": all_label, 'eeg_output':all_label}
            )
        )
        second_train_dataset = second_train_dataset.shuffle(buffer_size=128).batch(32)

        eeg_test = eeg[session*200:(session+1)*200,]
        fnirs_test = fnirs[session*200:(session+1)*200,]
        label_test = label[session*200:(session+1)*200,]

        test_dataset = tf.data.Dataset.from_tensor_slices(
            (
                {"eeg_input": eeg_test, "fnirs_input": fnirs_test},
                {"class_output": label_test, 'eeg_output':label_test}
            )
        )
        test_dataset = test_dataset.batch(32)

        # stage-one split: hold out 80 samples for validation
        np.random.seed(42)
        indices = np.random.choice(all_eeg.shape[0], size=80, replace=False)

        eeg_train = np.delete(all_eeg, indices, axis=0)
        fnirs_train = np.delete(all_fnirs, indices, axis=0)
        label_train = np.delete(all_label, indices, axis=0)
        first_train_dataset = tf.data.Dataset.from_tensor_slices(
            (
                {"eeg_input": eeg_train, "fnirs_input": fnirs_train},
                {"class_output": label_train, 'eeg_output':label_train}
            )
        )
        first_train_dataset = first_train_dataset.shuffle(buffer_size=128).batch(32)

        eeg_val = all_eeg[indices]
        fnirs_val = all_fnirs[indices]
        label_val = all_label[indices]
        val_dataset = tf.data.Dataset.from_tensor_slices(
            (
                {"eeg_input": eeg_val, "fnirs_input": fnirs_val},
                {"class_output": label_val, 'eeg_output':label_val}
            )
        )
        val_dataset = val_dataset.batch(32)

        print('eeg_train shape:', eeg_train.shape)
        print('fnirs_train shape:', fnirs_train.shape)
        print('label_train shape:', label_train.shape)

        print('eeg_val shape:', eeg_val.shape)
        print('fnirs_val shape:', fnirs_val.shape)
        print('label_val shape:', label_val.shape)

        print(subject)
        print(session)

        tf.keras.backend.clear_session()
        model = sta_net()

        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

        stopping = tf.keras.callbacks.EarlyStopping(monitor='val_class_output_loss', patience=50, restore_best_weights=True, verbose=1)

        print('begin first train')
        first_history = model.fit(first_train_dataset, epochs=300,
                                  verbose=2, validation_data=val_dataset,
                                  callbacks=[stopping])

        # the training loss at the best validation epoch becomes the stage-two target
        min_val_class_output_loss = min(first_history.history['val_class_output_loss'])
        min_val_class_output_loss_epoch = first_history.history['val_class_output_loss'].index(min_val_class_output_loss)
        target_acc = first_history.history['class_output_loss'][min_val_class_output_loss_epoch]

        print('begin second train')
        model.fit(second_train_dataset, epochs=200,
                  verbose=2, callbacks=[targetacccallback(target_acc)])

        print('begin test')
        test_results = model.evaluate(test_dataset)


print('all done')
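The two-stage scheme above records the training loss at the best validation epoch, then retrains on all non-test data until that loss is reached again. A toy, self-contained sketch of the same stop-at-target idea (the real script monitors `class_output_loss` because sta_net has two outputs; this single-output toy uses the plain `loss` key, and all data here is random):

```python
import numpy as np
import tensorflow as tf

toy = tf.keras.Sequential([tf.keras.layers.Dense(2, activation='softmax', input_shape=(4,))])
toy.compile(loss='categorical_crossentropy', optimizer='adam')

x = np.random.randn(64, 4).astype('float32')
y = tf.keras.utils.to_categorical(np.random.randint(0, 2, 64), 2)

class StopAtLoss(tf.keras.callbacks.Callback):
    """Stop once the training loss falls to the target recorded in stage one."""
    def __init__(self, target):
        super().__init__()
        self.target = target
    def on_epoch_end(self, epoch, logs=None):
        if logs and logs['loss'] <= self.target:
            self.model.stop_training = True

toy.fit(x, y, epochs=100, verbose=0, callbacks=[StopAtLoss(0.6)])
```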
--------------------------------------------------------------------------------
/preprocessing/to3d.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy.interpolate import griddata
import os

# all points of the 16*16 grid
x = np.arange(16)
y = np.arange(16)
xx, yy = np.meshgrid(x, y)
all_points = np.column_stack((xx.ravel(), yy.ravel()))

# eeg interpolate
# coordinates are [row, col] on the 16*16 grid (row 0 = frontal)
known_eeg_point_coordinates = np.array([[0., 6.],    #Fp1
                                        [2., 5.],    #AFF5h
                                        [2., 8.],    #AFz
                                        [3., 7.],    #F1
                                        [5., 2.],    #FC5
                                        [5., 6.],    #FC1
                                        [7., 1.],    #T7
                                        [7., 4.],    #C3
                                        [7., 8.],    #Cz
                                        [9., 2.],    #CP5
                                        [9., 6.],    #CP1
                                        [11., 2.],   #P7
                                        [11., 5.],   #P3
                                        [11., 8.],   #Pz
                                        [13., 8.],   #POz
                                        [14., 6.],   #O1
                                        [0., 10.],   #Fp2
                                        [2., 11.],   #AFF6h
                                        [3., 9.],    #F2
                                        [5., 10.],   #FC2
                                        [5., 14.],   #FC6
                                        [7., 12.],   #C4
                                        [7., 15.],   #T8
                                        [9., 10.],   #CP2
                                        [9., 14.],   #CP6
                                        [11., 11.],  #P4
                                        [11., 14.],  #P8
                                        [14., 10.]   #O2
                                        ])

unknown_eeg_point_coordinates = np.array([coord for coord in all_points if coord.tolist() not in known_eeg_point_coordinates.tolist()])
unknown_eeg_point_coordinates = unknown_eeg_point_coordinates.astype(float)

# fnirs interpolate
known_fnirs_point_coordinates = np.array([[2., 4.],    #AF7
                                          [3., 4.],    #AFF5
                                          [1., 5.],    #AFp7
                                          [2., 5.],    #AF5h
                                          [1., 7.],    #AFp3
                                          [3., 6.],    #AFF3h
                                          [2., 7.],    #AF1
                                          [3., 8.],    #AFFz
                                          [1., 8.],    #AFpz
                                          [2., 9.],    #AF2
                                          [1., 9.],    #AFp4
                                          [6., 4.],    #FCC3
                                          [7., 5.],    #C3h
                                          [7., 3.],    #C5h
                                          [8., 4.],    #CCP3
                                          [10., 5.],   #CPP3
                                          [11., 6.],   #P3h
                                          [11., 4.],   #P5h
                                          [12., 5.],   #PPO3
                                          [3., 10.],   #AFF4h
                                          [2., 11.],   #AF6h
                                          [3., 12.],   #AFF6
                                          [1., 11.],   #AFp8
                                          [2., 12.],   #AF8
                                          [6., 12.],   #FCC4
                                          [7., 13.],   #C6h
                                          [7., 11.],   #C4h
                                          [8., 12.],   #CCP4
                                          [10., 11.],  #CPP4
                                          [11., 12.],  #P6h
                                          [11., 10.],  #P4h
                                          [12., 11.],  #PPO4
                                          [12., 8.],   #PPOz
                                          [13., 7.],   #PO1
                                          [13., 9.],   #PO2
                                          [14., 8.]    #POOz
                                          ])

unknown_fnirs_point_coordinates = np.array([coord for coord in all_points if coord.tolist() not in known_fnirs_point_coordinates.tolist()])
unknown_fnirs_point_coordinates = unknown_fnirs_point_coordinates.astype(float)

n_epoch = 60

subject_path = r'E:\IF\review\new_dataset\epoch'
subject_list = os.listdir(subject_path)

for subject in subject_list:
    with np.load(os.path.join(subject_path, subject)) as data:
        eeg = data['eeg']
        hbo = data['hbo']
        hbr = data['hbr']
        label = data['label']

    eeg_3dtensor = np.ones((eeg.shape[0], 16, 16, eeg.shape[-1]))
    fnirs_hbo_3dtensor = np.ones((hbo.shape[0], 16, 16, hbo.shape[-1]))
    fnirs_hbr_3dtensor = np.ones((hbr.shape[0], 16, 16, hbr.shape[-1]))

    assert eeg.shape[0] == hbo.shape[0] == hbr.shape[0] == n_epoch
    for e in range(n_epoch):
        # 3D eeg
        for t in range(eeg.shape[-1]):
            known_eeg_point_values = eeg[e, :, t]

            # create 16*16 array (every cell is overwritten below)
            eeg_2dimage = np.ones((16, 16))

            # 2-D cubic interpolation; cells outside the electrodes' convex hull come back NaN
            eeg_interpolated_values = griddata(points=known_eeg_point_coordinates,
                                               values=known_eeg_point_values,
                                               xi=unknown_eeg_point_coordinates,
                                               method='cubic')

            # first fill in the known points (coordinates index as [row, col])
            assert known_eeg_point_values.shape[0] == known_eeg_point_coordinates.shape[0] == 28
            for k in range(28):
                eeg_2dimage[int(known_eeg_point_coordinates[k, 0]), int(known_eeg_point_coordinates[k, 1])] = known_eeg_point_values[k]

            # then the cubic-interpolated points
            assert eeg_interpolated_values.shape[0] == unknown_eeg_point_coordinates.shape[0] == 228
            for u in range(228):
                eeg_2dimage[int(unknown_eeg_point_coordinates[u, 0]), int(unknown_eeg_point_coordinates[u, 1])] = eeg_interpolated_values[u]

            # nearest-neighbour pass to fill the NaN cells left by the cubic pass
            aftcub_known_eeg_point_values = eeg_2dimage[~np.isnan(eeg_2dimage)]
            aftcub_known_eeg_point_coordinates = np.argwhere(~np.isnan(eeg_2dimage))
            nan_eeg_point_coordinates = np.argwhere(np.isnan(eeg_2dimage))

            nan_eeg_interpolated_values = griddata(points=aftcub_known_eeg_point_coordinates,
                                                   values=aftcub_known_eeg_point_values,
                                                   xi=nan_eeg_point_coordinates,
                                                   method='nearest')

            for ne in range(nan_eeg_point_coordinates.shape[0]):
                eeg_2dimage[nan_eeg_point_coordinates[ne, 0], nan_eeg_point_coordinates[ne, 1]] = nan_eeg_interpolated_values[ne]

            eeg_3dtensor[e, :, :, t] = eeg_2dimage

        # 3D fnirs
        assert hbo.shape[-1] == hbr.shape[-1]
        for ft in range(hbo.shape[-1]):
            known_hbo_point_values = hbo[e, :, ft]
            known_hbr_point_values = hbr[e, :, ft]

            # create 16*16 arrays (every cell is overwritten below)
            hbo_2dimage = np.ones((16, 16))
            hbr_2dimage = np.ones((16, 16))

            # 2-D cubic interpolation
            hbo_interpolated_values = griddata(points=known_fnirs_point_coordinates,
                                               values=known_hbo_point_values,
                                               xi=unknown_fnirs_point_coordinates,
                                               method='cubic')

            hbr_interpolated_values = griddata(points=known_fnirs_point_coordinates,
                                               values=known_hbr_point_values,
                                               xi=unknown_fnirs_point_coordinates,
                                               method='cubic')

            # first fill in the known points
            assert known_hbo_point_values.shape[0] == known_hbr_point_values.shape[0] == known_fnirs_point_coordinates.shape[0] == 36
            for fk in range(36):
                hbo_2dimage[int(known_fnirs_point_coordinates[fk, 0]), int(known_fnirs_point_coordinates[fk, 1])] = known_hbo_point_values[fk]
                hbr_2dimage[int(known_fnirs_point_coordinates[fk, 0]), int(known_fnirs_point_coordinates[fk, 1])] = known_hbr_point_values[fk]

            # then the cubic-interpolated points
            assert hbo_interpolated_values.shape[0] == hbr_interpolated_values.shape[0] == unknown_fnirs_point_coordinates.shape[0] == 220
            for fu in range(220):
                hbo_2dimage[int(unknown_fnirs_point_coordinates[fu, 0]), int(unknown_fnirs_point_coordinates[fu, 1])] = hbo_interpolated_values[fu]
                hbr_2dimage[int(unknown_fnirs_point_coordinates[fu, 0]), int(unknown_fnirs_point_coordinates[fu, 1])] = hbr_interpolated_values[fu]

            # nearest-neighbour pass to fill the NaN cells left by the cubic pass
            aftcub_known_hbo_point_values = hbo_2dimage[~np.isnan(hbo_2dimage)]
            aftcub_known_hbo_point_coordinates = np.argwhere(~np.isnan(hbo_2dimage))
            nan_hbo_point_coordinates = np.argwhere(np.isnan(hbo_2dimage))

            aftcub_known_hbr_point_values = hbr_2dimage[~np.isnan(hbr_2dimage)]
            aftcub_known_hbr_point_coordinates = np.argwhere(~np.isnan(hbr_2dimage))
            nan_hbr_point_coordinates = np.argwhere(np.isnan(hbr_2dimage))

            nan_hbo_interpolated_values = griddata(points=aftcub_known_hbo_point_coordinates,
                                                   values=aftcub_known_hbo_point_values,
                                                   xi=nan_hbo_point_coordinates,
                                                   method='nearest')

            nan_hbr_interpolated_values = griddata(points=aftcub_known_hbr_point_coordinates,
                                                   values=aftcub_known_hbr_point_values,
                                                   xi=nan_hbr_point_coordinates,
                                                   method='nearest')

            assert nan_hbo_point_coordinates.shape[0] == nan_hbr_point_coordinates.shape[0]
            for nf in range(nan_hbo_point_coordinates.shape[0]):
                hbo_2dimage[nan_hbo_point_coordinates[nf, 0], nan_hbo_point_coordinates[nf, 1]] = nan_hbo_interpolated_values[nf]
                hbr_2dimage[nan_hbr_point_coordinates[nf, 0], nan_hbr_point_coordinates[nf, 1]] = nan_hbr_interpolated_values[nf]

            fnirs_hbo_3dtensor[e, :, :, ft] = hbo_2dimage
            fnirs_hbr_3dtensor[e, :, :, ft] = hbr_2dimage

    print(eeg_3dtensor.shape)
    print(fnirs_hbo_3dtensor.shape)
    print(fnirs_hbr_3dtensor.shape)
    print(label.shape)

    save_dict = {
        'eeg':eeg_3dtensor,
        'hbo':fnirs_hbo_3dtensor,
        'hbr':fnirs_hbr_3dtensor,
        'label':label
    }

    save_dir = r'E:\IF\review\new_dataset\d3'
    save_name = subject

    np.savez(os.path.join(save_dir,save_name),**save_dict)
    print('\n==============save {} success=============\n'.format(save_name))
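Why to3d.py interpolates twice: `griddata(..., method='cubic')` returns NaN for grid cells outside the convex hull of the electrode coordinates, and the second `nearest` pass fills exactly those cells. A standalone sketch on a toy grid (coordinates and values are illustrative, not from the dataset):

```python
import numpy as np
from scipy.interpolate import griddata

known_xy = np.array([[1., 1.], [1., 3.], [3., 1.], [3., 3.]])
known_v = np.array([1.0, 2.0, 3.0, 4.0])

gx, gy = np.meshgrid(np.arange(5), np.arange(5), indexing='ij')
grid = np.column_stack((gx.ravel(), gy.ravel())).astype(float)

cubic = griddata(known_xy, known_v, grid, method='cubic')  # NaN outside the hull
print('NaN cells after cubic pass:', int(np.isnan(cubic).sum()))

filled = np.where(np.isnan(cubic),
                  griddata(known_xy, known_v, grid, method='nearest'),
                  cubic)
assert not np.isnan(filled).any()
```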
--------------------------------------------------------------------------------
/sta.py:
--------------------------------------------------------------------------------
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras import layers

# If you have any questions, please contact us at shulmt@shu.edu.cn

def pearson_r(eeg, fnirs):
    """Mean absolute Pearson correlation between paired feature vectors."""
    mx = tf.math.reduce_mean(eeg, axis=1, keepdims=True)
    my = tf.math.reduce_mean(fnirs, axis=1, keepdims=True)
    xm, ym = eeg-mx, fnirs-my
    r_num = tf.math.reduce_mean(tf.multiply(xm,ym), axis=1)
    r_den = tf.math.reduce_std(xm, axis=1) * tf.math.reduce_std(ym, axis=1) + 1e-6
    plcc = r_num / r_den
    plcc = tf.math.abs(plcc)
    plcc_meanbatch = tf.math.reduce_mean(plcc)

    return plcc_meanbatch


class pos_embedding(layers.Layer):
    """Learnable positional embedding added to the fNIRS lag sequence."""
    def __init__(self):
        super(pos_embedding, self).__init__()

    def build(self, input_shape):
        self.pos_embedding = self.add_weight(name='pos_embedding',
                                             shape=(1, input_shape[-2], input_shape[-1]),
                                             initializer=tf.keras.initializers.HeUniform(),
                                             trainable=True)

    def call(self, inputs):
        return inputs + self.pos_embedding


# sinusoidal positional encoding (kept for reference; sta_net uses the
# learnable pos_embedding above instead)
def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
    return pos * angle_rates

def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                            np.arange(d_model)[np.newaxis, :],
                            d_model)

    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)


class e_f_attention(keras.layers.Layer):
    """Cross-attention from the EEG-fusion feature (query) to the 11 lagged
    fNIRS features (keys/values), with a PLCC alignment loss."""
    def __init__(self, emb_size, d_model, heads, drop):
        super(e_f_attention, self).__init__()

        self.q_flat = layers.Flatten()
        self.q_proj = layers.Dense(emb_size)

        self.fusion_proj = layers.Dense(emb_size)

        self.k_flat = layers.Reshape((11, -1))
        self.k_proj = layers.Dense(emb_size)
        self.pos = pos_embedding()

        self.dot_product_attention = layers.MultiHeadAttention(num_heads=heads, key_dim=d_model, dropout=drop)

        self.ef_plcc_tracker = keras.metrics.Mean(name="ef_plcc")

    def call(self, inputs):
        eeg, fnirs = inputs

        q_eeg = self.q_flat(eeg)

        fusion_output = self.fusion_proj(q_eeg)

        q_eeg = self.q_proj(q_eeg)
        q_eeg = tf.expand_dims(q_eeg, axis=1)

        k_fnirs = self.k_flat(fnirs)
        k_fnirs = self.pos(k_fnirs)
        k_fnirs = self.k_proj(k_fnirs)

        fnirs_weighted, attention_weights = self.dot_product_attention(q_eeg, k_fnirs, return_attention_scores=True)
        attention_weights = tf.math.reduce_mean(attention_weights, axis=(1, 2))

        q_eeg = tf.math.reduce_mean(q_eeg, axis=1)
        fnirs_weighted = tf.math.reduce_mean(fnirs_weighted, axis=1)

        # maximise the correlation between the query and the attended fNIRS feature
        ef_loss = pearson_r(q_eeg, fnirs_weighted)

        self.add_loss(1-ef_loss)

        self.ef_plcc_tracker.update_state(ef_loss)

        return fusion_output, fnirs_weighted, attention_weights


class gap(keras.layers.Layer):
    """Global average pooling over the temporal axis (axis kept)."""
    def __init__(self):
        super(gap, self).__init__()

    def call(self, inputs):
        return tf.reduce_mean(inputs, axis=-2, keepdims=True)


class fga(keras.layers.Layer):
    """fNIRS-guided attention: a spatial map pooled from the fNIRS branch
    re-weights the EEG-fusion feature, with a PLCC alignment loss."""
    def __init__(self, tem_kernel_size, fga_loss_name):
        super(fga, self).__init__()

        self.channel_pooling = layers.Conv3D(filters=1, kernel_size=(3, 3, tem_kernel_size), strides=(1, 1, 1), padding='same')

        self.tap_fnirs = gap()

        self.residual_para = self.add_weight(name='residual_para', initializer="zeros", trainable=True)

        self.add_eeg = layers.Add()
        self.add = layers.Add()

        self.eeg_flatten = layers.Flatten()
        self.fnirs_flatten = layers.Flatten()
        self.fusion_flatten = layers.Flatten()

        self.fga_loss_tracker = keras.metrics.Mean(name=fga_loss_name)

    def call(self, inputs):
        eeg_fusion, eeg, fnirs = inputs

        fnirs_attention = self.channel_pooling(fnirs)

        fnirs_attention_map = self.tap_fnirs(fnirs_attention)
        fnirs_attention_map = tf.math.reduce_mean(fnirs_attention_map, axis=1)

        fnirs_attention_map_norm = keras.activations.sigmoid(fnirs_attention_map)

        eeg_fusion_guided = tf.math.multiply(eeg_fusion, fnirs_attention_map_norm)

        # learnable residual mixing of the EEG and fusion branches
        residual_para_norm = keras.activations.sigmoid(self.residual_para)
        eeg_add = self.add_eeg([residual_para_norm*eeg, (1-residual_para_norm)*eeg_fusion])

        fga_feature = self.add([eeg_fusion_guided, eeg_add])


        eeg_plcc = tf.math.reduce_mean(eeg, axis=(-1, -2))
        eeg_plcc = self.eeg_flatten(eeg_plcc)

        fnirs_attention_map_norm_plcc = self.fnirs_flatten(fnirs_attention_map_norm)

        fga_loss = pearson_r(eeg_plcc, fnirs_attention_map_norm_plcc)

        self.add_loss(1-fga_loss)

        self.fga_loss_tracker.update_state(fga_loss)


        return fga_feature


class conv_block(keras.layers.Layer):
    """Parallel Conv3D-BN-ELU stems for the EEG, fNIRS and EEG-fusion branches,
    followed by fNIRS-guided attention on the fusion branch."""
    def __init__(self, eeg_filter, eeg_size, eeg_stride,
                 fnirs_filter, fnirs_size, fnirs_stride,
                 eegfusion_filter, eegfusion_size, eegfusion_stride,
                 tem_kernel_size, fga_loss_name, padding):
        super(conv_block, self).__init__()

        self.eeg_conv = layers.Conv3D(filters=eeg_filter, kernel_size=eeg_size, strides=eeg_stride, padding=padding)
        self.eeg_act = layers.Activation('elu')
        self.eeg_bn = layers.BatchNormalization()

        self.fnirs_conv = layers.Conv3D(filters=fnirs_filter, kernel_size=fnirs_size, strides=fnirs_stride, padding=padding)
        self.fnirs_act = layers.Activation('elu')
        self.fnirs_bn = layers.BatchNormalization()

        self.eegfusion_conv = layers.Conv3D(filters=eegfusion_filter, kernel_size=eegfusion_size, strides=eegfusion_stride, padding=padding)
        self.eegfusion_act = layers.Activation('elu')
        self.eegfusion_bn = layers.BatchNormalization()

        self.fga = fga(tem_kernel_size, fga_loss_name)

    def call(self, inputs):
        eegfusion, eeg, fnirs = inputs

        eeg_feature = self.eeg_conv(eeg)
        eeg_feature = self.eeg_bn(eeg_feature)
        eeg_feature = self.eeg_act(eeg_feature)

        fnirs_feature = self.fnirs_conv(fnirs)
        fnirs_feature = self.fnirs_bn(fnirs_feature)
        fnirs_feature = self.fnirs_act(fnirs_feature)

        eegfusion_feature = self.eegfusion_conv(eegfusion)
        eegfusion_feature = self.eegfusion_bn(eegfusion_feature)
        eegfusion_feature = self.eegfusion_act(eegfusion_feature)

        eegfusion_fga = self.fga((eegfusion_feature, eeg_feature, fnirs_feature))

        return eegfusion_fga, eeg_feature, fnirs_feature


class reduce_sum_layer(keras.layers.Layer):
    def __init__(self, axis, keepaxis, name=None, **kwargs):
        super(reduce_sum_layer, self).__init__(name=name, **kwargs)

        self.axis = axis
        self.keepaxis = keepaxis

    def call(self, inputs):
        return tf.math.reduce_sum(inputs, axis=self.axis, keepdims=self.keepaxis)


class reduce_mean_layer(keras.layers.Layer):
    def __init__(self, axis, keepaxis, name=None, **kwargs):
        super(reduce_mean_layer, self).__init__(name=name, **kwargs)

        self.axis = axis
        self.keepaxis = keepaxis

    def call(self, inputs):
        return tf.math.reduce_mean(inputs, axis=self.axis, keepdims=self.keepaxis)


class expand_dims_layer(keras.layers.Layer):
    def __init__(self, axis, **kwargs):
        super(expand_dims_layer, self).__init__(**kwargs)

        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, axis=self.axis)


class prediction_weight_layer(keras.layers.Layer):
    """Softmax-weighted sum of two stacked predictions (kept for reference;
    sta_net builds the prediction weights with Dense layers instead)."""
    def __init__(self, name=None):
        super(prediction_weight_layer, self).__init__(name=name)

        self.p_weight = self.add_weight(name='p_weight', shape=(2,1), initializer="zeros", trainable=True)

    def call(self, inputs):
        p_weights_softmax = tf.nn.softmax(self.p_weight, axis=0)

        p_weights_softmax = tf.expand_dims(p_weights_softmax, axis=0)

        the_prediction = tf.math.multiply(inputs, p_weights_softmax)
        the_prediction = tf.math.reduce_sum(the_prediction, axis=1)

        return the_prediction


def sta_net():
    eeg_input = keras.Input(shape=(16, 16, 600, 1), name="eeg_input")
    fnirs_input = keras.Input(shape=(11, 16, 16, 30, 2), name="fnirs_input")

    eegfusion1, eeg1, fnirs1 = conv_block(eeg_filter=16, eeg_size=(2, 2, 13), eeg_stride=(2, 2, 6),
                                          fnirs_filter=16, fnirs_size=(2, 2, 5), fnirs_stride=(2, 2, 2),
                                          eegfusion_filter=16, eegfusion_size=(2, 2, 13), eegfusion_stride=(2, 2, 6),
                                          tem_kernel_size=5, fga_loss_name='fgsa1_plcc', padding='same')((eeg_input, eeg_input, fnirs_input))
    eegfusion1 = layers.Dropout(0.5)(eegfusion1)
    eeg1 = layers.Dropout(0.5)(eeg1)
    fnirs1 = layers.Dropout(0.5)(fnirs1)

    eegfusion2, eeg2, fnirs2 = conv_block(eeg_filter=32, eeg_size=(2, 2, 5), eeg_stride=(2, 2, 2),
                                          fnirs_filter=32, fnirs_size=(2, 2, 3), fnirs_stride=(2, 2, 2),
                                          eegfusion_filter=32, eegfusion_size=(2, 2, 5), eegfusion_stride=(2, 2, 2),
                                          tem_kernel_size=3, fga_loss_name='fgsa2_plcc', padding='same')((eegfusion1, eeg1, fnirs1))
    eegfusion2 = gap()(eegfusion2)
    eeg2 = gap()(eeg2)
    fnirs2 = gap()(fnirs2)

    eegfusion2 = layers.Dropout(0.5)(eegfusion2)
    eeg2 = layers.Dropout(0.5)(eeg2)
    fnirs2 = layers.Dropout(0.5)(fnirs2)

    eegfusion_feature, fnirs_feature, _ = e_f_attention(emb_size=256, d_model=256, heads=10, drop=0.5)((eegfusion2, fnirs2))
    eegfusion_feature_pweight = layers.Activation('elu')(eegfusion_feature)
    fnirs_feature_pweight = layers.Activation('elu')(fnirs_feature)

    eegfusion_feature_pweight = layers.Dense(256, activation='elu')(eegfusion_feature_pweight)
    fnirs_feature_pweight = layers.Dense(256, activation='elu')(fnirs_feature_pweight)

    eeg_feature = layers.Flatten()(eeg2)
    eeg_feature = layers.Dense(256, activation='elu')(eeg_feature)

    eegfusion_pred = layers.Dense(2)(eegfusion_feature_pweight)
    fnirs_pred = layers.Dense(2)(fnirs_feature_pweight)
    eeg_pred = layers.Dense(2)(eeg_feature)

    eeg_pred = layers.Activation('softmax', name='eeg_output')(eeg_pred)
    eegfusion_pred = layers.Activation('softmax')(eegfusion_pred)
    fnirs_pred = layers.Activation('softmax')(fnirs_pred)

    eegfusion_pred = expand_dims_layer(axis=1)(eegfusion_pred)
    fnirs_pred = expand_dims_layer(axis=1)(fnirs_pred)

    the_pred = layers.Concatenate(axis=1)([eegfusion_pred, fnirs_pred])

    # learned per-sample weights to combine the fusion and fNIRS predictions
    fnirs_p_weight = layers.Dense(1)(fnirs_feature_pweight)
    eegfusion_p_weight = layers.Dense(1)(eegfusion_feature_pweight)

    p_weight = layers.Concatenate()([eegfusion_p_weight, fnirs_p_weight])
    p_weight = layers.Activation('softmax')(p_weight)
    p_weight = expand_dims_layer(axis=-1)(p_weight)

    the_pred = layers.Multiply()([the_pred, p_weight])
    the_pred = reduce_sum_layer(axis=1, keepaxis=False, name='class_output')(the_pred)

    model = keras.Model(inputs=[eeg_input, fnirs_input], outputs=[the_pred, eeg_pred])

    return model

--------------------------------------------------------------------------------
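A quick smoke test of the assembled model (assumes TensorFlow 2.10 as noted in the README; the batch size and random inputs are arbitrary):

```python
import numpy as np
from sta import sta_net

model = sta_net()

# random inputs matching the declared input shapes
eeg = np.random.randn(4, 16, 16, 600, 1).astype('float32')
fnirs = np.random.randn(4, 11, 16, 16, 30, 2).astype('float32')

class_pred, eeg_pred = model([eeg, fnirs], training=False)
print(class_pred.shape, eeg_pred.shape)  # expected: (4, 2) (4, 2)
```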