├── LICENSE
├── README.md
├── locs_orig.mat
├── model.py
├── multi_processing.py
├── pic
│   └── model.png
└── utils.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 SCUT-IEL

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Low-Latency Auditory Spatial Attention Detection Based on Spectro-Spatial Features from EEG

This repository contains the Python scripts developed as part of the work presented in the paper "Low-Latency Auditory Spatial Attention Detection Based on Spectro-Spatial Features from EEG".

## Getting Started

### Dataset

The public [KUL dataset](https://zenodo.org/record/3997352#.YUGaZdP7R6q) is used in the paper. The dataset comes with its own MATLAB preprocessing scripts; adjust them according to your own needs.

### Prerequisites

- python 3.7.9
- tensorflow 2.2.0
- keras 2.4.3

### Run the Code

1. Download the preprocessed data from [here](https://mailscuteducn-my.sharepoint.com/:f:/g/personal/202021058399_mail_scut_edu_cn/Evu3JoynOJxJlYtpKft2UfIBcZuNbkSrbymvDHLNdpiK9w?e=gWx9J0).

2. Modify the default `data_document_path` argument of `main()` in model.py so that it points to the downloaded data folder.

3. Run the model:

```powershell
python model.py
```

4. To run multiple subjects in parallel, modify the variable `path` in multi_processing.py and run:

```powershell
python multi_processing.py
```

## Paper

![model](./pic/model.png)

Paper Link: [**Low-Latency Auditory Spatial Attention Detection Based on Spectro-Spatial Features from EEG**](https://arxiv.org/abs/2103.03621)

The paper proposes a convolutional neural network (CNN) operating on spectro-spatial features (SSF) for auditory spatial attention detection, referred to as the SSF-CNN model. The SSF-CNN network is trained to output two values, i.e., 0 and 1, indicating the spatial location of the attended speaker.
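
For a quick sense of what the spectro-spatial features are, the sketch below condenses the feature-extraction steps implemented in `to_alpha` and `gen_images` of model.py for a single 1-second, 64-channel EEG window: FFT of the window, alpha-band (8-13 Hz) power per channel, and cubic interpolation of the channel powers onto a 2-D scalp grid. This is only an illustrative sketch, not the repository's code path: `eeg_window` and `locs_2d` are hypothetical stand-ins (in the repository the 2-D electrode coordinates come from locs_orig.mat via the azimuthal equidistant projection), and the real pipeline additionally standardizes the images before feeding them to the CNN.

```python
import numpy as np
from scipy.interpolate import griddata


def alpha_topography(eeg_window, locs_2d, fs=128, band=(8, 13), image_size=32):
    """Turn one EEG window into a 2-D alpha-power map (condensed version of
    the to_alpha + gen_images steps in model.py)."""
    n = eeg_window.shape[0]                                 # window length in samples
    spectrum = np.abs(np.fft.fft(eeg_window, axis=0)) / n   # magnitude spectrum per channel
    res = fs / n                                            # frequency resolution in Hz per bin
    lo, hi = int(np.ceil(band[0] / res)), int(np.ceil(band[1] / res)) + 1
    alpha_power = np.sum(spectrum[lo:hi, :] ** 2, axis=0)   # one alpha-power value per channel

    # interpolate the per-channel powers onto a regular 2-D grid spanning the electrode layout
    grid_x, grid_y = np.mgrid[
        locs_2d[:, 0].min():locs_2d[:, 0].max():image_size * 1j,
        locs_2d[:, 1].min():locs_2d[:, 1].max():image_size * 1j]
    return griddata(locs_2d, alpha_power, (grid_x, grid_y), method='cubic', fill_value=0.0)


# hypothetical inputs: 1 s of 64-channel EEG at 128 Hz and made-up 2-D electrode coordinates
eeg_window = np.random.randn(128, 64)
locs_2d = np.random.randn(64, 2)
print(alpha_topography(eeg_window, locs_2d).shape)  # (32, 32)
```

In model.py the resulting maps, one per 1-second window, are standardized, given a trailing channel dimension, and fed to the SSF-CNN, which ends in a two-unit softmax layer over the two spatial-attention classes.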

Please cite our paper if you find our work useful for your research:

```tex
@article{cai2021low,
  title={Low-latency auditory spatial attention detection based on spectro-spatial features from EEG},
  author={Cai, Siqi and Sun, Pengcheng and Schultz, Tanja and Li, Haizhou},
  journal={arXiv preprint arXiv:2103.03621},
  year={2021}
}
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Contact

Siqi Cai, Pengcheng Sun, Tanja Schultz, and Haizhou Li

Siqi Cai, Pengcheng Sun and Haizhou Li are with the Department of Electrical and Computer Engineering, National University of Singapore, Singapore.

Tanja Schultz is with the Cognitive Systems Lab, University of Bremen, Germany.

--------------------------------------------------------------------------------
/locs_orig.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-IEL/SSF-CNN/3cebe2ad0a4c567746b2e0f9a536a46932415239/locs_orig.mat

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   model.py

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
2021/9/15 1:13    lintean    1.0         None
'''

import math
import time
import random
import numpy as np
import pandas as pd
from dotmap import DotMap
from utils import cart2sph, pol2cart, makePath
from keras.layers import Input, Dense, Activation, ZeroPadding2D, BatchNormalization, Flatten, Conv2D
from keras.layers import AveragePooling2D, MaxPooling2D, Dropout, GlobalMaxPooling2D, GlobalAveragePooling2D
from keras.models import Sequential
import keras.backend as K
from sklearn.preprocessing import scale
from scipy.interpolate import griddata
from keras.utils import np_utils
from scipy.io import loadmat
import keras
import os
from importlib import reload

np.set_printoptions(suppress=True)


def get_logger(name, log_path):
    import logging
    reload(logging)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    logfile = makePath(log_path) + "/Train_" + name + ".log"
    fh = logging.FileHandler(logfile, mode='w')
    fh.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s: %(message)s")
    fh.setFormatter(formatter)

    logger.addHandler(fh)

    # also log to the console when writing to the default test location
    if log_path == "./result/test":
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        ch.setFormatter(formatter)
        logger.addHandler(ch)

    return logger


def azim_proj(pos):
    """
    Computes the Azimuthal Equidistant Projection of an input point in 3D Cartesian coordinates.
    Imagine a plane being placed against (tangent to) a globe. If
    a light source inside the globe projects the graticule onto
    the plane, the result would be a planar, or azimuthal, map
    projection.

    :param pos: position in 3D Cartesian coordinates
    :return: projected coordinates using the Azimuthal Equidistant Projection
    """
    [r, elev, az] = cart2sph(pos[0], pos[1], pos[2])
    return pol2cart(az, math.pi / 2 - elev)


def gen_images(data, args):
    # project the 3D electrode locations from locs_orig.mat onto a 2D plane
    locs = loadmat('locs_orig.mat')
    locs_3d = locs['data']
    locs_2d = []
    for e in locs_3d:
        locs_2d.append(azim_proj(e))

    locs_2d_final = np.array(locs_2d)
    grid_x, grid_y = np.mgrid[
        min(np.array(locs_2d)[:, 0]):max(np.array(locs_2d)[:, 0]):args.image_size * 1j,
        min(np.array(locs_2d)[:, 1]):max(np.array(locs_2d)[:, 1]):args.image_size * 1j]

    # interpolate each sample's per-channel values onto the image grid
    images = []
    for i in range(data.shape[0]):
        images.append(griddata(locs_2d_final, data[i, :], (grid_x, grid_y), method='cubic', fill_value=np.nan))
    images = np.stack(images, axis=0)

    # standardize the interpolated values and replace NaNs (outside the electrode hull) with zeros
    images[~np.isnan(images)] = scale(images[~np.isnan(images)])
    images = np.nan_to_num(images)
    return images


def read_prepared_data(args):
    data = []

    for l in range(len(args.ConType)):
        for k in range(args.trail_number):
            filename = args.data_document_path + "/" + args.ConType[l] + "/" + args.name + "Tra" + str(k + 1) + ".csv"
            data_pf = pd.read_csv(filename, header=None)
            eeg_data = data_pf.iloc[:, 2 * args.audio_channel:]  # drop the audio columns, keep the EEG channels

            data.append(eeg_data)

    data = pd.concat(data, axis=0, ignore_index=True)
    return data


# output shape: [(time, feature) (window, feature) (window, feature)]
def window_split(data, args):
    random.seed(args.random_seed)
    # init
    test_percent = args.test_percent
    window_lap = args.window_length * (1 - args.overlap)
    overlap_distance = max(0, math.floor(1 / (1 - args.overlap)) - 1)

    train_set = []
    test_set = []

    for l in range(len(args.ConType)):
        label = pd.read_csv(args.data_document_path + "/csv/" + args.name + args.ConType[l] + ".csv")

        # split trial
        for k in range(args.trail_number):
            # the number of windows in a trial
            window_number = math.floor(
                (args.cell_number - args.window_length) / window_lap) + 1

            # the number of consecutive windows reserved for the test set
            test_window_length = math.floor(
                (args.cell_number * test_percent - args.window_length) / window_lap)
            test_window_length = test_window_length if test_percent == 0 else max(
                0, test_window_length)
            test_window_length = test_window_length + 1

            test_window_left = random.randint(0, window_number - test_window_length)
            test_window_right = test_window_left + test_window_length - 1
            target = label.iloc[k, args.label_col]

            # split window
            for i in range(window_number):
                left = math.floor(k * args.cell_number + i * window_lap)
                right = math.floor(left + args.window_length)
                # train set or test set; windows overlapping the test block are dropped
                if test_window_left > test_window_right or test_window_left - i > overlap_distance or i - test_window_right > overlap_distance:
                    train_set.append(np.array([left, right, target, len(train_set), k, args.subject_number]))
                elif test_window_left <= i <= test_window_right:
                    test_set.append(np.array([left, right, target, len(test_set), k, args.subject_number]))

    # concat
    train_set = np.stack(train_set, axis=0)
    test_set = np.stack(test_set, axis=0) if len(test_set) > 1 else None

    return np.array(data), train_set, test_set


def to_alpha(data, window, args):
    alpha_data = []
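    # For each window: take the FFT over the window samples, convert to a
    # magnitude spectrum, and sum the squared magnitudes of the bins between
    # args.point_low and args.point_high (the alpha band, roughly 8-13 Hz for
    # fs = 128 and a 1-second window), giving one alpha-power value per channel.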
    for window_index in range(window.shape[0]):
        start = window[window_index][args.window_metadata.start]
        end = window[window_index][args.window_metadata.end]
        window_data = np.fft.fft(data[start:end, :], n=args.window_length, axis=0)
        window_data = np.abs(window_data) / args.window_length
        window_data = np.sum(np.power(window_data[args.point_low:args.point_high, :], 2), axis=0)
        alpha_data.append(window_data)
    alpha_data = np.stack(alpha_data, axis=0)
    return alpha_data


def main(name="S1", data_document_path="D:\\eegdata\\KUL_single_single3"):
    args = DotMap()
    args.name = name
    args.subject_number = int(args.name[1:])
    args.data_document_path = data_document_path
    args.ConType = ["No"]
    args.fs = 128
    args.window_length = math.ceil(args.fs * 1)
    args.overlap = 0.8
    args.batch_size = 32
    args.max_epoch = 200
    args.random_seed = time.time()
    args.image_size = 32
    args.people_number = 16
    args.eeg_channel = 64
    args.audio_channel = 1
    args.channel_number = args.eeg_channel + args.audio_channel * 2
    args.trail_number = 8
    args.cell_number = 46080
    args.test_percent = 0.1
    args.vali_percent = 0.1
    args.label_col = 0
    args.alpha_low = 8
    args.alpha_high = 13
    args.log_path = "./result"
    args.frequency_resolution = args.fs / args.window_length
    args.point_low = math.ceil(args.alpha_low / args.frequency_resolution)
    args.point_high = math.ceil(args.alpha_high / args.frequency_resolution) + 1
    args.window_metadata = DotMap(start=0, end=1, target=2, index=3, trail_number=4, subject_number=5)
    os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5,6,7"
    logger = get_logger(args.name, args.log_path)

    # load the data and labels
    data = read_prepared_data(args)

    # split into windows and carve out the test set
    data, train_window, test_window = window_split(data, args)
    train_label = train_window[:, args.window_metadata.target]
    test_label = test_window[:, args.window_metadata.target]

    # fft: alpha-band power per window and channel
    train_data = to_alpha(data, train_window, args)
    test_data = to_alpha(data, test_window, args)
    del data

    # to images: interpolate the channel powers onto 2D topographic maps
    train_data = gen_images(train_data, args)
    test_data = gen_images(test_data, args)

    train_data = np.expand_dims(train_data, axis=-1)
    test_data = np.expand_dims(test_data, axis=-1)
    train_label = np_utils.to_categorical(train_label - 1, 2)
    test_label = np_utils.to_categorical(test_label - 1, 2)

    # train
    model = Sequential()
    model.add(Conv2D(32, (3, 3), padding='same', input_shape=train_data.shape[1:],
                     kernel_regularizer=keras.regularizers.l2(0.01), data_format="channels_last"))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(AveragePooling2D(pool_size=(2, 2), data_format="channels_last"))
    model.add(Dropout(0.1))

    model.add(Flatten())

    model.add(Dense(512))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dropout(0.3))

    model.add(Dense(32))
    model.add(BatchNormalization())
    model.add(Activation('relu'))

    # Output layer
    model.add(Dense(2))
    model.add(Activation('softmax'))

    # Output the parameter status of each layer of the model
    model.summary()

    opt = keras.optimizers.RMSprop(lr=0.0003, decay=3e-4)
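    # Compile with categorical cross-entropy, train with an internal validation
    # split (args.vali_percent), then evaluate once on the held-out test windows
    # and write the final loss and accuracy to the per-subject log.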
    model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
    # plot_model(model, to_file='model.png', show_shapes=True)

    history = model.fit(train_data, train_label, batch_size=args.batch_size, epochs=args.max_epoch, validation_split=args.vali_percent, verbose=2)
    loss, accuracy = model.evaluate(test_data, test_label)
    print(loss, accuracy)
    logger.info(loss)
    logger.info(accuracy)


if __name__ == "__main__":
    main()

--------------------------------------------------------------------------------
/multi_processing.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
'''
@File    :   multi_processing.py

@Modify Time      @Author    @Version    @Description
------------      -------    --------    -----------
2021/9/15 1:15    lintean    1.0         None
'''

from multiprocessing import Process
from model import main
import utils as util

if __name__ == "__main__":
    multiple = 1
    process = []
    path = "/document/data/eeg/KUL_single_single3"
    names = ['S' + str(i + 1) for i in range(0, 16)]
    for name in names:
        p = Process(target=main, args=(name, path,))  # the trailing comma is required for a one-element args tuple
        p.start()
        process.append(p)
    util.monitor(process, multiple, 60)

--------------------------------------------------------------------------------
/pic/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-IEL/SSF-CNN/3cebe2ad0a4c567746b2e0f9a536a46932415239/pic/model.png

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import numpy as np
import math
np.set_printoptions(suppress=True)
import os
import time


def cart2sph(x, y, z):
    """
    Transform Cartesian coordinates to spherical
    :param x: X coordinate
    :param y: Y coordinate
    :param z: Z coordinate
    :return: radius, elevation, azimuth
    """
    x2_y2 = x**2 + y**2
    r = math.sqrt(x2_y2 + z**2)             # radius
    elev = math.atan2(z, math.sqrt(x2_y2))  # elevation
    az = math.atan2(y, x)                   # azimuth
    return r, elev, az


def pol2cart(theta, rho):
    """
    Transform polar coordinates to Cartesian
    :param theta: angle value
    :param rho: radius value
    :return: X, Y
    """
    return rho * math.cos(theta), rho * math.sin(theta)


def makePath(path):
    if not os.path.isdir(path):
        os.makedirs(path)
    return path


def monitor(process, multiple, second):
    # block until fewer than `multiple` of the given processes are still alive,
    # polling every `second` seconds
    while True:
        alive = 0
        for ps in process:
            if ps.is_alive():
                alive += 1
        if alive < multiple:
            break
        else:
            time.sleep(second)

--------------------------------------------------------------------------------