├── demo-SAME-BETA-main.ipynb ├── demo_beta_TDCA_withSAME.mat ├── demo_beta_TDCA_withoutSAME.mat ├── demo_beta_eTRCA_withSAME.mat ├── demo_beta_eTRCA_withoutSAME.mat └── readme.md /demo-SAME-BETA-main.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# load module\n", 10 | "import numpy as np\n", 11 | "import numpy.matlib\n", 12 | "import math\n", 13 | "import scipy.io as sio\n", 14 | "import warnings\n", 15 | "from scipy import signal\n", 16 | "warnings.filterwarnings('default')" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": 2, 22 | "metadata": {}, 23 | "outputs": [ 24 | { 25 | "name": "stderr", 26 | "output_type": "stream", 27 | "text": [ 28 | "/root/miniconda3/envs/myconda/lib/python3.8/site-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. 
Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n", 29 | " and should_run_async(code)\n" 30 | ] 31 | } 32 | ], 33 | "source": [ 34 | "def corr2(a, b):\n", 35 | " \"\"\" Solving for the two-dimensional correlation coefficient\n", 36 | " :param a: \n", 37 | " :param b: \n", 38 | "\n", 39 | " :return: correlation coefficient\n", 40 | " \"\"\"\n", 41 | " \n", 42 | " a = a - np.sum(a) / np.size(a)\n", 43 | " b = b - np.sum(b) / np.size(b)\n", 44 | " r = (a * b).sum() / math.sqrt((a * a).sum() * (b * b).sum())\n", 45 | " return r\n", 46 | "\n", 47 | "def acc_calculate(predict):\n", 48 | " \"\"\" Calculate accuracy\n", 49 | " :param predict: (n_trial,n_event)\n", 50 | "\n", 51 | " :return: acc\n", 52 | " \"\"\"\n", 53 | " \n", 54 | " [nTrials, nEvents] = predict.shape \n", 55 | " label_target = np.ones((nTrials, 1)) * np.arange(0, nEvents, 1, int)\n", 56 | " logical_right = (label_target == predict)\n", 57 | " acc_num = np.sum(logical_right != 0)\n", 58 | " acc = acc_num / nTrials / nEvents\n", 59 | " return acc\n", 60 | "\n", 61 | "class PreProcessing_BETA():\n", 62 | " \"\"\"\n", 63 | " Adapted from Orion Han\n", 64 | " https://github.com/OrionHH/BrainOn-an-online-brain-computer-interface-BCI-framework\n", 65 | " \"\"\"\n", 66 | " \n", 67 | " CHANNELS = [\n", 68 | " 'FP1','FPZ','FP2','AF3','AF4','F7','F5','F3',\n", 69 | " 'F1','FZ','F2','F4','F6','F8','FT7','FC5',\n", 70 | " 'FC3','FC1','FCZ','FC2','FC4','FC6','FC8','T7',\n", 71 | " 'C5','C3','C1','CZ','C2','C4','C6','T8',\n", 72 | " 'M1','TP7','CP5','CP3','CP1','CPZ','CP2','CP4',\n", 73 | " 'CP6','TP8','M2','P7','P5','P3','P1','PZ',\n", 74 | " 'P2','P4','P6','P8','PO7','PO5','PO3','POZ',\n", 75 | " 'PO4','PO6','PO8','CB1','O1','OZ','O2','CB2'\n", 76 | " ] # M1: 33. 
M2: 43.\n", 77 | "\n", 78 | " def __init__(self, filepath, t_begin, t_end, n_classes=40, fs_down=250, chans=None, num_filter=1):\n", 79 | "\n", 80 | " self.filepath = filepath\n", 81 | " self.fs_down = fs_down\n", 82 | " self.t_begin = t_begin\n", 83 | " self.t_end = t_end\n", 84 | " self.chans = chans\n", 85 | " self.n_classes = n_classes\n", 86 | " self.num_filter = num_filter\n", 87 | "\n", 88 | " def load_data(self):\n", 89 | " '''\n", 90 | " Application: load data and selected channels by chans.\n", 91 | "\n", 92 | " :param chans: list | None\n", 93 | " :return: raw_data: 4-D, numpy\n", 94 | " n_chans * n_samples * n_classes * n_trials\n", 95 | "\n", 96 | " '''\n", 97 | " raw_mat = sio.loadmat(self.filepath)\n", 98 | " raw_data11 = raw_mat['data'] \n", 99 | " data = raw_data11[0,0]['EEG']\n", 100 | " raw_data = np.transpose(data,[0,1,3,2]) # n_chans * n_samples * n_classes * n_trials\n", 101 | "\n", 102 | " idx_loc = list()\n", 103 | " if isinstance(self.chans, list):\n", 104 | " for _, char_value in enumerate(self.chans):\n", 105 | " idx_loc.append(self.CHANNELS.index(char_value.upper()))\n", 106 | "\n", 107 | " raw_data = raw_data[idx_loc, : , : , :] if idx_loc else raw_data\n", 108 | "\n", 109 | " self.raw_fs = 250 # .mat sampling rate\n", 110 | "\n", 111 | " return raw_data\n", 112 | "\n", 113 | " def resample_data(self, raw_data):\n", 114 | " '''\n", 115 | " :param raw_data: from method load_data.\n", 116 | " :return: raw_data_resampled, 4-D, numpy\n", 117 | " n_chans * n_samples * n_classes * n_trials\n", 118 | " '''\n", 119 | " if self.raw_fs > self.fs_down:\n", 120 | " raw_data_resampled = signal.resample(raw_data, round(self.fs_down*raw_data.shape[1]/self.raw_fs), axis=1)\n", 121 | " elif self.raw_fs < self.fs_down:\n", 122 | " warnings.warn('You are up-sampling, no recommended')\n", 123 | " raw_data_resampled = signal.resample(raw_data, round(self.fs_down*raw_data.shape[1]/self.raw_fs), axis=1)\n", 124 | " else:\n", 125 | " raw_data_resampled = 
raw_data\n", 126 | "\n", 127 | " return raw_data_resampled\n", 128 | "\n", 129 | " def _get_iir_sos_band(self, w_pass, w_stop):\n", 130 | " '''\n", 131 | " Get second-order sections (like 'ba') of Chebyshev type I filter.\n", 132 | " :param w_pass: list, 2 elements\n", 133 | " :param w_stop: list, 2 elements\n", 134 | " :return: sos_system\n", 135 | " i.e the filter coefficients.\n", 136 | " '''\n", 137 | " if len(w_pass) != 2 or len(w_stop) != 2:\n", 138 | " raise ValueError('w_pass and w_stop must be a list with 2 elements.')\n", 139 | "\n", 140 | " if w_pass[0] > w_pass[1] or w_stop[0] > w_stop[1]:\n", 141 | " raise ValueError('Element 1 must be greater than Element 0 for w_pass and w_stop.')\n", 142 | "\n", 143 | " if w_pass[0] < w_stop[0] or w_pass[1] > w_stop[1]:\n", 144 | " raise ValueError('It\\'s a band-pass iir filter, please check the values between w_pass and w_stop.')\n", 145 | "\n", 146 | " wp = [2 * w_pass[0] / self.fs_down, 2 * w_pass[1] / self.fs_down]\n", 147 | " ws = [2 * w_stop[0] / self.fs_down, 2 * w_stop[1] / self.fs_down]\n", 148 | " gpass = 4 \n", 149 | " gstop = 30 # dB\n", 150 | "\n", 151 | " N, wn = signal.cheb1ord(wp, ws, gpass=gpass, gstop=gstop)\n", 152 | " sos_system = signal.cheby1(N, rp=0.5, Wn=wn, btype='bandpass', output='sos')\n", 153 | "\n", 154 | " return sos_system\n", 155 | "\n", 156 | "\n", 157 | " def filtered_data_iir111(self, w_pass_2d, w_stop_2d, data):\n", 158 | " '''\n", 159 | " filter data by IIR, which parameters are set by method _get_iir_sos_band in BasePreProcessing class.\n", 160 | " :param w_pass_2d: 2-d, numpy,\n", 161 | " w_pass_2d[0, :]: w_pass[0] of method _get_iir_sos_band,\n", 162 | " w_pass_2d[1, :]: w_pass[1] of method _get_iir_sos_band.\n", 163 | " :param w_stop_2d: 2-d, numpy,\n", 164 | " w_stop_2d[0, :]: w_stop[0] of method _get_iir_sos_band,\n", 165 | " w_stop_2d[1, :]: w_stop[1] of method _get_iir_sos_band.\n", 166 | " :param data: 4-d, numpy, from method load_data or resample_data.\n", 167 | " 
n_chans * n_samples * n_classes * n_trials\n", 168 | " :return: filtered_data: dict,\n", 169 | " {'bank1': values1, 'bank2': values2, ...,'bank'+str(num_filter): values}\n", 170 | " values1, values2,...: 4-D, numpy, n_chans * n_samples * n_classes * n_trials.\n", 171 | " e.g.\n", 172 | " w_pass_2d = np.array([[5, 14, 22, 30, 38, 46, 54],[70, 70, 70, 70, 70, 70, 70]])\n", 173 | " w_stop_2d = np.array([[3, 12, 20, 28, 36, 44, 52],[72, 72, 72, 72, 72, 72, 72]])\n", 174 | " '''\n", 175 | " if w_pass_2d.shape != w_stop_2d.shape:\n", 176 | " raise ValueError('The shape of w_pass_2d and w_stop_2d should be equal.')\n", 177 | " if self.num_filter > w_pass_2d.shape[1]:\n", 178 | " raise ValueError('num_filter should be less than or equal to w_pass_2d.shape[1]')\n", 179 | "\n", 180 | " begin_point, end_point = int(np.ceil(self.t_begin * self.fs_down)), int(np.ceil(self.t_end * self.fs_down) + 1)\n", 181 | " data = data[:,begin_point:end_point,:,:]\n", 182 | "\n", 183 | " sos_system = dict()\n", 184 | " filtered_data = dict()\n", 185 | " for idx_filter in range(self.num_filter):\n", 186 | " sos_system['filter'+str(idx_filter+1)] = self._get_iir_sos_band(w_pass=[w_pass_2d[0, idx_filter], w_pass_2d[1, idx_filter]],\n", 187 | " w_stop=[w_stop_2d[0, idx_filter],\n", 188 | " w_stop_2d[1, idx_filter]])\n", 189 | " filter_data = signal.sosfiltfilt(sos_system['filter' + str(idx_filter + 1)], data, axis=1)\n", 190 | " filtered_data['bank'+str(idx_filter+1)] = filter_data\n", 191 | "\n", 192 | " return filtered_data" 193 | ] 194 | }, 195 | { 196 | "cell_type": "markdown", 197 | "metadata": {}, 198 | "source": [ 199 | "### SAME" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 3, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "def TRCs_estimation(data, mean_target):\n", 209 | " \"\"\" source signal estimation\n", 210 | "\n", 211 | " :param data: n_channel_1, n_times\n", 212 | " :param mean_target: n_channel_2, n_times\n", 213 | "\n", 214 | 
" :return: data_after: n_channel_2, n_times\n", 215 | " \"\"\"\n", 216 | " \n", 217 | " nChannel, nTimes = np.shape(data)\n", 218 | " X_a = data\n", 219 | " X = mean_target\n", 220 | "\n", 221 | " # solve PT\n", 222 | " PT = X @ X_a.T @ np.linalg.inv(X_a @ X_a.T)\n", 223 | " data_after = PT @ X_a\n", 224 | "\n", 225 | " return data_after\n", 226 | "\n", 227 | "def get_augment_fb_noiseAfter(fs, f, Nh_start, Nh_end, ntrail_noise, mean_temp):\n", 228 | " \"\"\" Artificially generated signals by SAME\n", 229 | "\n", 230 | " :param fs: Sampling rate\n", 231 | " :param f: the frequency of signal\n", 232 | " :param Nh_start: Minimum number of harmonics\n", 233 | " :param Nh_end: Maximum number of harmonics\n", 234 | " :param ntrial_noise: Number of generated signals\n", 235 | " :param mean_temp: n_channel, n_times\n", 236 | "\n", 237 | " :return: data_aug: n_channel, n_times, ntrial_noise\n", 238 | " \"\"\"\n", 239 | " \n", 240 | " Nh_step = np.arange(Nh_start,Nh_end+1,1)\n", 241 | " Nh = Nh_step.shape[-1]\n", 242 | " nChannel, nTime = mean_temp.shape\n", 243 | " # Generate reference signal Yf\n", 244 | " Ts = 1 / fs\n", 245 | " n = np.arange(nTime) * Ts\n", 246 | " Yf = np.zeros((nTime, 2*Nh)) \n", 247 | " flag = 0\n", 248 | " for iNh in Nh_step:\n", 249 | " y_sin = np.sin(2 * np.pi * f * iNh * n)\n", 250 | " Yf[:, flag * 2] = y_sin\n", 251 | " y_cos = np.cos(2 * np.pi * f * iNh * n)\n", 252 | " Yf[:, flag * 2 + 1] = y_cos\n", 253 | " flag = flag + 1\n", 254 | "\n", 255 | " Z = TRCs_estimation(Yf.T, mean_temp)\n", 256 | " # get vars of Z\n", 257 | " vars = np.zeros((Z.shape[0],Z.shape[0]))\n", 258 | " for i_c in range(nChannel):\n", 259 | " vars[i_c,i_c] = np.var(Z[i_c,:])\n", 260 | "\n", 261 | " # add noise\n", 262 | " data_aug = np.zeros((nChannel, nTime, ntrail_noise))\n", 263 | " for i_aug in range(ntrail_noise):\n", 264 | " # Randomly generated noise\n", 265 | " Datanosie = np.random.multivariate_normal(mean=np.zeros((nChannel)), cov=vars, size = nTime)\n", 266 | " 
data_aug[:,:,i_aug] = Z + 0.05 * Datanosie.T\n", 267 | "\n", 268 | " return data_aug" 269 | ] 270 | }, 271 | { 272 | "cell_type": "markdown", 273 | "metadata": {}, 274 | "source": [ 275 | "### eTRCA" 276 | ] 277 | }, 278 | { 279 | "cell_type": "code", 280 | "execution_count": 4, 281 | "metadata": {}, 282 | "outputs": [], 283 | "source": [ 284 | "##TRCA\n", 285 | "def trca_matrix(data):\n", 286 | " \"\"\" Task-related component analysis (TRCA)\n", 287 | " :param data: Multi-trial EEG signals under the same task\n", 288 | " ndarray(n_channels, n_sample_points, n_trials)\n", 289 | "\n", 290 | " :return: w: spatial filter\n", 291 | " ndarray(n_channels, 1)\n", 292 | " \"\"\"\n", 293 | " \n", 294 | " X = data\n", 295 | "\n", 296 | " # X_mean = X.mean(axis=1, keepdims=True)\n", 297 | " # X = X - X_mean\n", 298 | "\n", 299 | " nChans = X.shape[0]\n", 300 | " nTimes = X.shape[1]\n", 301 | " nTrial = X.shape[2]\n", 302 | " # solve S\n", 303 | " S = np.zeros((nChans, nChans))\n", 304 | " for i in range(nTrial):\n", 305 | " for j in range(nTrial):\n", 306 | " if (i != j):\n", 307 | " x_i = X[:, :, i]\n", 308 | " x_j = X[:, :, j]\n", 309 | " S = S + np.dot(x_i, (x_j.T))\n", 310 | " # solve Q\n", 311 | " X1 = X.reshape([nChans, nTimes * nTrial], order='F') \n", 312 | " Q = X1 @ X1.T\n", 313 | " \n", 314 | " # get eigenvector\n", 315 | " b = np.dot(np.linalg.inv(Q), S)\n", 316 | " [eig_value, eig_w] = np.linalg.eig(b) # in matlab:a/b = inv(a)*b\n", 317 | "\n", 318 | " # Descending order\n", 319 | " eig_w = eig_w[:, eig_value.argsort()[::-1]] # return indices in ascending order and reverse\n", 320 | " eig_value.sort()\n", 321 | " eig_value = eig_value[::-1] # sort in descending\n", 322 | "\n", 323 | " w = eig_w[:, 0]\n", 324 | " return w.real\n", 325 | "\n", 326 | "def TRCA_train(trainData):\n", 327 | " \"\"\" Get TRCA spatial filters and average templates for all classes\n", 328 | " :param trainData: training data of all events\n", 329 | " ndarray(n_channels, n_sample_points, 
n_events, n_trials)\n", 330 | "\n", 331 | " :return: w: (n_channels, n_events)\n", 332 | " mean_temp (n_channels, n_sample_points, n_events)\n", 333 | " \"\"\"\n", 334 | " \n", 335 | " [nChannels, nTimes, nEvents, nTrials] = trainData.shape \n", 336 | " # get w of event class\n", 337 | " w = np.zeros((nChannels, nEvents))\n", 338 | " for i in range(nEvents):\n", 339 | " w_data = trainData[:, :, i, :]\n", 340 | " w1 = trca_matrix(w_data)\n", 341 | " w[:, i] = w1\n", 342 | " # get mean temps\n", 343 | " mean_temp = np.zeros((nChannels, nTimes, nEvents))\n", 344 | " mean_temp = np.mean(trainData, -1)\n", 345 | " return w, mean_temp\n", 346 | "\n", 347 | "def TRCA_test(testData, w, mean_temp, ensemble):\n", 348 | " \"\"\"\n", 349 | " :param testData: test_data of multi trials\n", 350 | " ndarray(n_channels, n_sample_points, n_trials(equals to n_events))\n", 351 | " :param w: Spatial Filters\n", 352 | " ndarray (n_channels, n_events)\n", 353 | " :param mean_temp: Average template\n", 354 | " ndarray (n_channels, n_sample_points, n_events)\n", 355 | " :param ensemble: bool\n", 356 | "\n", 357 | " :return: rr: correlation coefficients of a single block (row: test trial, column: class template)\n", 358 | " ndarray(n_trials, n_classes)\n", 359 | " \"\"\"\n", 360 | " \n", 361 | " [nChannels, nTimes, nEvents] = testData.shape\n", 362 | " rr = np.zeros((nEvents, nEvents))\n", 363 | " for m in range(nEvents): # the m-th test data\n", 364 | " test = testData[:, :, m]\n", 365 | " # Calculate the vector of correlation coefficients\n", 366 | " r = np.zeros(nEvents)\n", 367 | " for n in range(nEvents): # the n-th train model\n", 368 | " train = mean_temp[:, :, n]\n", 369 | " if ensemble is True:\n", 370 | " r[n] = corr2(train.T @ w, test.T @ w)\n", 371 | " else:\n", 372 | " r[n] = corr2(train.T @ w[:, n], test.T @ w[:, n])\n", 373 | " rr[m, :] = r\n", 374 | " return rr" 375 | ] 376 | }, 377 | { 378 | "cell_type": "markdown", 379 | "metadata": {}, 380 | "source": [ 381 | "### TDCA" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 |
"execution_count": 5, 387 | "metadata": {}, 388 | "outputs": [], 389 | "source": [ 390 | "##TDCA\n", 391 | "def get_P(f_list, Nh, sTime, sfreq):\n", 392 | " \"\"\" Get the projection matrix P for all classes\n", 393 | " :param f_list: the frequency of all events\n", 394 | " :param Nh: number of harmonics\n", 395 | " :param sTime: signal duration\n", 396 | " :param sfreq: sampling rate\n", 397 | "\n", 398 | " :return: P: the projection matrix P for all classes\n", 399 | " ndarray(n_Times, n_Times, n_Events)\n", 400 | " \"\"\"\n", 401 | " \n", 402 | " nEvent = f_list.shape[0]\n", 403 | " P = np.zeros((int(sTime * sfreq), int(sTime * sfreq), nEvent))\n", 404 | " for iievent in range(nEvent):\n", 405 | " # Generate reference signal Yf\n", 406 | " f = f_list[iievent]\n", 407 | " nTime = int(sTime * sfreq)\n", 408 | " Ts = 1 / sfreq\n", 409 | " n = np.arange(nTime) * Ts\n", 410 | " Yf = np.zeros((nTime, 2 * Nh)) \n", 411 | " for iNh in range(Nh):\n", 412 | " y_sin = np.sin(2 * np.pi * f * (iNh + 1) * n)\n", 413 | " Yf[:, iNh * 2] = y_sin\n", 414 | " y_cos = np.cos(2 * np.pi * f * (iNh + 1) * n)\n", 415 | " Yf[:, iNh * 2 + 1] = y_cos\n", 416 | " q, _ = np.linalg.qr(Yf, mode='reduced')\n", 417 | " P[:,:,iievent] = q @ q.T\n", 418 | "\n", 419 | " return P\n", 420 | "\n", 421 | "def tdca_matrix(data, Nk):\n", 422 | " \"\"\" Task-discriminant component analysis (TDCA)\n", 423 | " :param data: training data of all events\n", 424 | " ndarray(n_channels * (l + 1), 2*n_points, n_events, n_trials)\n", 425 | " :param Nk: the number of subspaces\n", 426 | " int\n", 427 | "\n", 428 | " :return: w: Spatial Filters\n", 429 | " ndarray(n_channels * (l + 1), Nk)\n", 430 | " \"\"\"\n", 431 | " \n", 432 | " X_aug_2 = np.transpose(data,[2,3,0,1]) # nEvents, nTrials, nChannels * (l + 1), 2*npoints\n", 433 | " [n_events, n_trials, _, _] = X_aug_2.shape\n", 434 | " # get Sb\n", 435 | " class_center = X_aug_2.mean(axis=1) # # nEvents , nChannels * (l + 1), 2*npoints\n", 436 | " total_center = 
class_center.mean(axis=0,keepdims=True) # # 1, nChannels * (l + 1), 2*npoints\n", 437 | " Hb = class_center - total_center # Broadcasting in numpy\n", 438 | " Sb = np.einsum('ecp, ehp->ch', Hb, Hb)\n", 439 | " Sb /= n_events\n", 440 | " # get Sw\n", 441 | " class_center = np.expand_dims(class_center, 1) # nEvents , 1, nChannels * (l + 1), 2*npoints\n", 442 | " Hw = X_aug_2 - np.tile(class_center,[1, n_trials,1,1]) # nEvents , nTrials, nChannels * (l + 1), 2*npoints\n", 443 | " Sw = np.einsum('etcp, ethp->ch', Hw, Hw)\n", 444 | " Sw /= (n_events * n_trials)\n", 445 | " Sw = 0.001 * np.eye(Hw.shape[2]) + Sw # regularization\n", 446 | " # get eigenvector\n", 447 | " b = np.dot(np.linalg.inv(Sw), Sb)\n", 448 | " [eig_value, eig_w] = np.linalg.eig(b) \n", 449 | "\n", 450 | " # Descending order\n", 451 | " eig_w = eig_w[:, eig_value.argsort()[::-1]] # return indices in ascending order and reverse\n", 452 | " eig_value.sort()\n", 453 | " eig_value = eig_value[::-1] # sort in descending\n", 454 | " w = eig_w[:, :Nk]\n", 455 | "\n", 456 | " return w \n", 457 | "\n", 458 | "def TDCA_train(trainData, P ,l , Nk ):\n", 459 | " \"\"\" Get TDCA spatial filters and average templates for all classes\n", 460 | " :param trainData: training data of all events\n", 461 | " ndarray(n_channels, (n_sample_points + l), n_events, n_trials)\n", 462 | " :param P: projection matrix for all classes\n", 463 | " ndarray(n_sample_points, n_sample_points, n_events)\n", 464 | " :param l: delay point\n", 465 | " :param Nk: the number of subspaces\n", 466 | "\n", 467 | " :return: w: Spatial Filters\n", 468 | " ndarray(n_channels * (l + 1), Nk)\n", 469 | " mean_temp: average templates\n", 470 | " ndarray(Nk, 2n_sample_points, n_events)\n", 471 | "\n", 472 | " \"\"\"\n", 473 | " \n", 474 | " [nChannels, nTimes, nEvents, nTrials] = trainData.shape \n", 475 | " npoints = nTimes - l\n", 476 | "\n", 477 | "\n", 478 | " data_aug_2 = np.zeros((nChannels * (l + 1), 2*npoints, nEvents, nTrials))\n", 479 | " for 
ievent in range(nEvents):\n", 480 | " dat = trainData[:, :, ievent,:]\n", 481 | " # first\n", 482 | " dat_aug_1 = np.zeros((nChannels*(l+1), npoints, nTrials))\n", 483 | " for il in range(l+1):\n", 484 | " dat_aug_1[il*(nChannels):(il+1)*nChannels, : ,: ] = dat[:,il:(il+npoints) ,:]\n", 485 | " # second\n", 486 | " dat_p = np.zeros_like((dat_aug_1))\n", 487 | " for itrial in range(nTrials):\n", 488 | " dat_p[:, :, itrial] = dat_aug_1[:, :, itrial] @ P[:, :, ievent] # projection\n", 489 | " dat_aug_2 = np.concatenate((dat_aug_1, dat_p), axis=1, out=None)\n", 490 | " #\n", 491 | " data_aug_2[..., ievent, :] = dat_aug_2\n", 492 | "\n", 493 | " # get w\n", 494 | " w = tdca_matrix(data_aug_2 , Nk=Nk)\n", 495 | " # get mean temps Nk * 2 Num of sample points * num of events\n", 496 | " mean_tem = np.zeros((Nk, npoints*2, nEvents))\n", 497 | " mean_data = np.mean(data_aug_2, -1)\n", 498 | " for i in range((nEvents)):\n", 499 | " mean_tem[:,:,i] = w.T @ mean_data[:,:,i]\n", 500 | "\n", 501 | " return w, mean_tem\n", 502 | "\n", 503 | "def TDCA_test(testData, w, mean_temp ,P , l ):\n", 504 | " \"\"\"\n", 505 | " :param testData: test_data of multi trials\n", 506 | " ndarray(n_channels, n_sample_points, n_trials(equals to n_events))\n", 507 | " :param w: Spatial Filters\n", 508 | " ndarray(n_channels * (l + 1), Nk)\n", 509 | " :param mean_temp: Average template\n", 510 | " ndarray(Nk, 2n_sample_points, n_events)\n", 511 | " :param P: projection matrix for all classes\n", 512 | " ndarray(n_sample_points, n_sample_points, n_events)\n", 513 | " :param l: delay point\n", 514 | "\n", 515 | " :return: rr: correlation coefficients of a single block (row: test trial, column: class template)\n", 516 | " ndarray(n_trials, n_classes)\n", 517 | " \"\"\"\n", 518 | " \n", 519 | " [nChannels, nTimes, nEvents] = testData.shape\n", 520 | " rr = np.zeros((nEvents, nEvents))\n", 521 | " for m in range(nEvents): # the m-th test data\n", 522 | " test = testData[:, :, m]\n", 523 | " # first\n", 524 | " test_aug_1 = np.zeros((nChannels * (l + 1), nTimes))\n", 525
| " aug_zero = np.zeros((nChannels, l)) # Splice 0 matrix\n", 526 | " test = np.concatenate((test, aug_zero), axis=1, out=None) # nChannels, nTimes + l\n", 527 | " for il in range(l + 1):\n", 528 | " test_aug_1[il * (nChannels):(il + 1) * nChannels, :] = test[:, il:(il + nTimes)]\n", 529 | " # Calculate the vector of correlation coefficients\n", 530 | " r = np.zeros(nEvents)\n", 531 | " for n in range(nEvents): # the n-th train model\n", 532 | " # second\n", 533 | " dat_p = test_aug_1 @ P[:, :, n]\n", 534 | " test_aug_2 = np.concatenate((test_aug_1, dat_p), axis=1, out=None)\n", 535 | " # slove rr\n", 536 | " train = mean_temp[:, :, n]\n", 537 | " r[n] = corr2(train, w.T @ test_aug_2)\n", 538 | " rr[m, :] = r\n", 539 | " return rr" 540 | ] 541 | }, 542 | { 543 | "cell_type": "markdown", 544 | "metadata": {}, 545 | "source": [ 546 | "### COUNT ACC" 547 | ] 548 | }, 549 | { 550 | "cell_type": "code", 551 | "execution_count": 6, 552 | "metadata": {}, 553 | "outputs": [], 554 | "source": [ 555 | "import os, sys\n", 556 | "import scipy.io as sio\n", 557 | "import warnings\n", 558 | "warnings.filterwarnings('ignore')\n", 559 | "import time\n", 560 | "from joblib import Parallel, delayed\n", 561 | "\n", 562 | "import numpy as np\n", 563 | "from sklearn.model_selection import ShuffleSplit,LeaveOneOut\n", 564 | "\n", 565 | "\n", 566 | "def beta_TDCA_Aug(idx_num, n_train, t_task ,n_Aug):\n", 567 | "\n", 568 | " # setting\n", 569 | " f_list = [8.6, 8.8,\n", 570 | " 9, 9.2, 9.4, 9.6, 9.8,\n", 571 | " 10, 10.2, 10.4, 10.6, 10.8,\n", 572 | " 11, 11.2, 11.4, 11.6, 11.8,\n", 573 | " 12, 12.2, 12.4, 12.6, 12.8,\n", 574 | " 13, 13.2, 13.4, 13.6, 13.8,\n", 575 | " 14, 14.2, 14.4, 14.6, 14.8,\n", 576 | " 15, 15.2, 15.4, 15.6, 15.8,\n", 577 | " 8, 8.2, 8.4, ]\n", 578 | " f_list = np.array(f_list)\n", 579 | " subject_id = ['S'+'{:02d}'.format(idx_subject+1) for idx_subject in range(70)] # S01,S02,.....S70\n", 580 | "\n", 581 | " idx_num = idx_num\n", 582 | " idx_subject = 
subject_id[idx_num]\n", 583 | " sfreq = 250\n", 584 | " filepath = r'Beta' \n", 585 | " filepath = os.path.join(filepath, str(idx_subject) + '.mat')\n", 586 | " num_filter = 5\n", 587 | " preEEG = PreProcessing_BETA(filepath, t_begin=0.5, t_end=0.5 + 0.13 + t_task + 3/sfreq,\n", 588 | " fs_down=250, chans=['POZ', 'PZ', 'PO3', 'PO5', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],\n", 589 | " num_filter=num_filter)\n", 590 | "\n", 591 | " raw_data = preEEG.load_data()\n", 592 | " w_pass_2d = np.array([[5, 14, 22, 30, 38], [90, 90, 90, 90, 90]]) # 70\n", 593 | " w_stop_2d = np.array([[3, 12, 20, 28, 36], [92, 92, 92, 92, 92]]) # 72\n", 594 | " filtered_data = preEEG.filtered_data_iir111(w_pass_2d, w_stop_2d, raw_data)\n", 595 | "\n", 596 | " \"\"\"\n", 597 | " Cross-validation parameter setting\n", 598 | " \"\"\"\n", 599 | " nBlock = 4\n", 600 | " nEvent = 40\n", 601 | " train_size = n_train \n", 602 | " n_splits = 4\n", 603 | " if train_size == nBlock - 1:\n", 604 | " kf = LeaveOneOut()\n", 605 | " else:\n", 606 | " kf = ShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=idx_num + 1)\n", 607 | "\n", 608 | " \"\"\"\n", 609 | " TDCA parameter setting\n", 610 | " \"\"\"\n", 611 | " l = 3 # delay point \n", 612 | " t = t_task\n", 613 | " train_point = np.arange(int((0.13) * sfreq), int((0.13 + t) * sfreq)+l)\n", 614 | " test_point = np.arange(int((0.13) * sfreq), int((0.13 + t) * sfreq))\n", 615 | " # get P of all classes\n", 616 | " P = get_P(f_list=f_list, Nh=5, sTime=t, sfreq=sfreq)\n", 617 | "\n", 618 | " # Cross-validation\n", 619 | " acc_s = 0\n", 620 | " for train, test in kf.split(np.arange(nBlock)):\n", 621 | "\n", 622 | " # train : get ensembleW of banks\n", 623 | " train_w = dict()\n", 624 | " train_meantemp = dict()\n", 625 | " for idx_filter in range(num_filter):\n", 626 | " idx_filter += 1\n", 627 | " bank_data = filtered_data['bank' + str(idx_filter)]\n", 628 | " train_data11 = bank_data[:, :, :, train] # \n", 629 | " train_data = train_data11[:, 
train_point, :, :] # n_channel * n_times * n_events * n_trials\n", 630 | " \n", 631 | " if n_Aug == 0:\n", 632 | " trainData_pt = train_data.copy()\n", 633 | " else: \n", 634 | " # Data augmentation\n", 635 | " ntrail_noise = n_Aug\n", 636 | " data_augment = np.zeros((train_data.shape[0], train_data.shape[1], train_data.shape[2], ntrail_noise))\n", 637 | " for ievent in range(nEvent):\n", 638 | " # get Nh_strat\n", 639 | " f = f_list[ievent]\n", 640 | " for ih in range(5):\n", 641 | " ih = ih+1\n", 642 | " if ih*f >= 8*idx_filter:\n", 643 | " Nh_start = ih\n", 644 | " break\n", 645 | " data_augment[:, :, ievent, :] = get_augment_fb_noiseAfter(fs=sfreq, f=f_list[ievent],Nh_start=Nh_start, Nh_end=5,\n", 646 | " ntrail_noise=ntrail_noise,\n", 647 | " mean_temp=np.mean(train_data,-1)[:, :, ievent])\n", 648 | " trainData_pt = np.concatenate((train_data, data_augment), axis=3)\n", 649 | "\n", 650 | " # train\n", 651 | " w, mean_temp_TDCA = TDCA_train(trainData_pt, P=P, l=l, Nk=9)\n", 652 | " #\n", 653 | " train_w['bank' + str(idx_filter)] = w\n", 654 | " train_meantemp['bank' + str(idx_filter)] = mean_temp_TDCA\n", 655 | "\n", 656 | " # test:\n", 657 | " predictAll = np.zeros((test.shape[0], nEvent))\n", 658 | " flag = 0\n", 659 | " for isplit in test:\n", 660 | " rrall = np.zeros((nEvent, nEvent))\n", 661 | " for idx_filter in range(num_filter):\n", 662 | " idx_filter += 1\n", 663 | " bank_data = filtered_data['bank' + str(idx_filter)]\n", 664 | " test_data111 = bank_data[:, :, :, isplit]\n", 665 | " test_data = test_data111[:,test_point,:]\n", 666 | " rr = TDCA_test(test_data, train_w['bank' + str(idx_filter)], train_meantemp['bank' + str(idx_filter)],\n", 667 | " P=P, l=l)\n", 668 | " rrall += np.multiply(np.sign(rr), (rr ** 2)) * (idx_filter ** (-1.25) + 0.25)\n", 669 | " predict = np.argmax(rrall, -1)\n", 670 | " predictAll[flag, :] = predict\n", 671 | " flag += 1\n", 672 | " acc_s = acc_calculate(predictAll) + acc_s\n", 673 | " acc = acc_s / n_splits\n", 674 | " # 
print('sub', idx_num + 1, ', acc = ', acc_s / n_splits)\n", 675 | " return acc\n", 676 | "\n", 677 | "def beta_eTRCA_Aug(idx_num, n_train, t_task ,n_Aug):\n", 678 | "\n", 679 | " # setting\n", 680 | " f_list = [8.6, 8.8,\n", 681 | " 9, 9.2, 9.4, 9.6, 9.8,\n", 682 | " 10, 10.2, 10.4, 10.6, 10.8,\n", 683 | " 11, 11.2, 11.4, 11.6, 11.8,\n", 684 | " 12, 12.2, 12.4, 12.6, 12.8,\n", 685 | " 13, 13.2, 13.4, 13.6, 13.8,\n", 686 | " 14, 14.2, 14.4, 14.6, 14.8,\n", 687 | " 15, 15.2, 15.4, 15.6, 15.8,\n", 688 | " 8, 8.2, 8.4, ]\n", 689 | " subject_id = ['S'+'{:02d}'.format(idx_subject+1) for idx_subject in range(70)] # S01,S02,...S70\n", 690 | "\n", 691 | " idx_num = idx_num\n", 692 | " idx_subject = subject_id[idx_num]\n", 693 | " sfreq = 250\n", 694 | " filepath = r'Beta' \n", 695 | " filepath = os.path.join(filepath, str(idx_subject) + '.mat')\n", 696 | " num_filter = 5\n", 697 | " preEEG = PreProcessing_BETA(filepath, t_begin=0.5, t_end=0.5 + 0.13 + t_task, # t_begin=0.5+0.13, t_end=0.5+0.13+0.3\n", 698 | " fs_down=250, chans=['POZ', 'PZ', 'PO3', 'PO5', 'PO4', 'PO6', 'O1', 'OZ', 'O2'],\n", 699 | " num_filter=num_filter)\n", 700 | "\n", 701 | " raw_data = preEEG.load_data()\n", 702 | " w_pass_2d = np.array([[5, 14, 22, 30, 38], [90, 90, 90, 90, 90]]) # 70\n", 703 | " w_stop_2d = np.array([[3, 12, 20, 28, 36], [92, 92, 92, 92, 92]]) # 72\n", 704 | " filtered_data = preEEG.filtered_data_iir111(w_pass_2d, w_stop_2d, raw_data)\n", 705 | "\n", 706 | " \"\"\"\n", 707 | " Cross-validation parameter setting\n", 708 | " \"\"\"\n", 709 | " nBlock = 4\n", 710 | " nEvent = 40\n", 711 | " train_size = n_train \n", 712 | " n_splits = 4\n", 713 | " if train_size == nBlock - 1:\n", 714 | " kf = LeaveOneOut()\n", 715 | " else:\n", 716 | " kf = ShuffleSplit(n_splits=n_splits, train_size=train_size, random_state=idx_num + 1)\n", 717 | "\n", 718 | " t = t_task\n", 719 | " task_point = np.arange(int((0.13) * sfreq), int((0.13 + t) * sfreq))\n", 720 | "\n", 721 | " # train : get ensembleW of 
banks\n", 722 | " acc_s = 0\n", 723 | " for train, test in kf.split(np.arange(nBlock)):\n", 724 | "\n", 725 | " # train : get ensembleW of banks\n", 726 | " train_w = dict()\n", 727 | " train_meantemp = dict()\n", 728 | " for idx_filter in range(num_filter):\n", 729 | " idx_filter += 1\n", 730 | " bank_data = filtered_data['bank' + str(idx_filter)]\n", 731 | " train_data11 = bank_data[:, :, :, train] \n", 732 | " train_data = train_data11[:, task_point, :, :] # n_channel * n_times * n_events * n_trials\n", 733 | " \n", 734 | " if n_Aug == 0:\n", 735 | " trainData_pt = train_data.copy()\n", 736 | " else: \n", 737 | " # Data augmentation\n", 738 | " ntrail_noise = n_Aug\n", 739 | " data_augment = np.zeros((train_data.shape[0], train_data.shape[1], train_data.shape[2], ntrail_noise))\n", 740 | " for ievent in range(nEvent):\n", 741 | " # get Nh_strat\n", 742 | " f = f_list[ievent]\n", 743 | " for ih in range(5):\n", 744 | " ih = ih+1\n", 745 | " if ih*f >= 8*idx_filter:\n", 746 | " Nh_start = ih\n", 747 | " break\n", 748 | " data_augment[:, :, ievent, :] = get_augment_fb_noiseAfter(fs=sfreq, f=f_list[ievent],Nh_start=Nh_start, Nh_end=5,\n", 749 | " ntrail_noise=ntrail_noise,\n", 750 | " mean_temp=np.mean(train_data,-1)[:, :, ievent])\n", 751 | " trainData_pt = np.concatenate((train_data, data_augment), axis=3)\n", 752 | "\n", 753 | " # train\n", 754 | " w, mean_temp = TRCA_train(trainData_pt)\n", 755 | " #\n", 756 | " train_w['bank' + str(idx_filter)] = w\n", 757 | " train_meantemp['bank' + str(idx_filter)] = mean_temp\n", 758 | "\n", 759 | " # test:\n", 760 | " predictAll = np.zeros((test.shape[0], nEvent))\n", 761 | " flag = 0\n", 762 | " for isplit in test:\n", 763 | " rrall = np.zeros((nEvent, nEvent))\n", 764 | " for idx_filter in range(num_filter):\n", 765 | " idx_filter += 1\n", 766 | " bank_data = filtered_data['bank' + str(idx_filter)]\n", 767 | " test_data = bank_data[:, :, :, isplit]\n", 768 | " test_data = test_data[:, task_point, :]\n", 769 | " rr = 
TRCA_test(test_data, train_w['bank' + str(idx_filter)],\n", 770 | " train_meantemp['bank' + str(idx_filter)], True)\n", 771 | " rrall += np.multiply(np.sign(rr), (rr ** 2)) * (idx_filter ** (-1.25) + 0.25)\n", 772 | " predict = np.argmax(rrall, -1)\n", 773 | " predictAll[flag, :] = predict\n", 774 | " flag += 1\n", 775 | " acc_s = acc_calculate(predictAll) + acc_s\n", 776 | " acc = acc_s / n_splits\n", 777 | " # print('sub', idx_num + 1, ', acc = ', acc_s / n_splits)\n", 778 | "\n", 779 | " return acc" 780 | ] 781 | }, 782 | { 783 | "cell_type": "markdown", 784 | "metadata": {}, 785 | "source": [ 786 | "## Result" 787 | ] 788 | }, 789 | { 790 | "cell_type": "markdown", 791 | "metadata": {}, 792 | "source": [ 793 | "### eTRCA(w/oSAME)" 794 | ] 795 | }, 796 | { 797 | "cell_type": "code", 798 | "execution_count": 7, 799 | "metadata": {}, 800 | "outputs": [ 801 | { 802 | "name": "stdout", 803 | "output_type": "stream", 804 | "text": [ 805 | "train_size= 1\n", 806 | "time = 106.3842568397522 s\n", 807 | "[0.11458333 0.05208333 0.34166667 0.10833333 0.09791667 0.06041667\n", 808 | " 0.16666667 0.07916667 0.21458333 0.12916667 0.03333333 0.2375\n", 809 | " 0.36666667 0.06666667 0.03125 0.05416667 0.01666667 0.18125\n", 810 | " 0.08958333 0.05833333 0.13541667 0.06041667 0.31458333 0.10208333\n", 811 | " 0.17916667 0.06875 0.15 0.25416667 0.07708333 0.10416667\n", 812 | " 0.1625 0.04375 0.05625 0.09375 0.15416667 0.20416667\n", 813 | " 0.07916667 0.03333333 0.08333333 0.0375 0.07083333 0.23333333\n", 814 | " 0.05 0.04791667 0.04166667 0.06041667 0.03333333 0.13125\n", 815 | " 0.1125 0.05 0.10625 0.07291667 0.325 0.05625\n", 816 | " 0.03333333 0.40625 0.20416667 0.24166667 0.01041667 0.225\n", 817 | " 0.02708333 0.11875 0.15833333 0.11875 0.08333333 0.13958333\n", 818 | " 0.1 0.07916667 0.03541667 0.06666667]\n", 819 | "mean_acc: 0.11904761904761903\n", 820 | "train_size= 2\n", 821 | "time = 95.07941007614136 s\n", 822 | "[0.85 0.88125 0.68125 0.6125 0.665625 0.615625 
0.621875 0.415625\n", 823 | " 0.865625 0.403125 0.125 0.684375 0.728125 0.584375 0.45625 0.446875\n", 824 | " 0.109375 0.878125 0.753125 0.309375 0.615625 0.709375 0.86875 0.540625\n", 825 | " 0.54375 0.165625 0.56875 0.5875 0.415625 0.715625 0.240625 0.184375\n", 826 | " 0.28125 0.73125 0.66875 0.746875 0.690625 0.084375 0.31875 0.584375\n", 827 | " 0.190625 0.60625 0.29375 0.165625 0.39375 0.275 0.1625 0.746875\n", 828 | " 0.734375 0.43125 0.403125 0.628125 0.503125 0.378125 0.20625 0.45625\n", 829 | " 0.81875 0.66875 0.0875 0.63125 0.103125 0.640625 0.815625 0.440625\n", 830 | " 0.14375 0.70625 0.79375 0.75 0.590625 0.76875 ]\n", 831 | "mean_acc: 0.5212053571428571\n", 832 | "train_size= 3\n", 833 | "time = 94.33243155479431 s\n", 834 | "[0.9 0.90625 0.775 0.70625 0.85 0.68125 0.73125 0.54375 0.91875\n", 835 | " 0.46875 0.2 0.75625 0.7875 0.7125 0.64375 0.55625 0.15625 0.90625\n", 836 | " 0.83125 0.44375 0.75625 0.75625 0.93125 0.65625 0.63125 0.20625 0.69375\n", 837 | " 0.65625 0.55625 0.825 0.38125 0.25625 0.375 0.8125 0.7875 0.80625\n", 838 | " 0.775 0.09375 0.525 0.675 0.2875 0.6875 0.425 0.2125 0.525\n", 839 | " 0.4875 0.23125 0.8 0.79375 0.54375 0.45 0.775 0.7 0.50625\n", 840 | " 0.24375 0.44375 0.81875 0.78125 0.1375 0.6875 0.1375 0.73125 0.9\n", 841 | " 0.58125 0.15 0.74375 0.88125 0.80625 0.70625 0.85 ]\n", 842 | "mean_acc: 0.609375\n" 843 | ] 844 | } 845 | ], 846 | "source": [ 847 | "Aug_size = [0,0,0]\n", 848 | "acc_all = np.zeros((70,3))\n", 849 | "for i_train in range(3):\n", 850 | " print('train_size=',i_train+1)\n", 851 | " c = time.time()\n", 852 | " acc = Parallel(n_jobs=-1)(delayed(beta_eTRCA_Aug)(idx_num, n_train=i_train+1, t_task=0.5, n_Aug=Aug_size[i_train]) for idx_num in range(70))\n", 853 | " acc = np.array(acc)\n", 854 | " acc_all[:,i_train] = acc\n", 855 | " e = time.time()\n", 856 | " print('time =', e - c,' s')\n", 857 | " print(acc)\n", 858 | " print('mean_acc:',np.mean(acc))\n", 859 | "sio.savemat(r'demo_beta_eTRCA_withoutSAME.mat', 
{'acc': acc_all})" 860 | ] 861 | }, 862 | { 863 | "cell_type": "markdown", 864 | "metadata": {}, 865 | "source": [ 866 | "### eTRCA(W/SAME)" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": 8, 872 | "metadata": {}, 873 | "outputs": [ 874 | { 875 | "name": "stdout", 876 | "output_type": "stream", 877 | "text": [ 878 | "train_size= 1\n", 879 | "time = 99.01571726799011 s\n", 880 | "[0.83958333 0.83333333 0.67916667 0.5625 0.65416667 0.56041667\n", 881 | " 0.51041667 0.35625 0.86666667 0.41041667 0.11041667 0.66875\n", 882 | " 0.73958333 0.50833333 0.4875 0.53958333 0.14583333 0.83541667\n", 883 | " 0.64166667 0.32083333 0.62708333 0.72291667 0.81458333 0.5125\n", 884 | " 0.5625 0.21041667 0.61041667 0.56041667 0.45 0.725\n", 885 | " 0.3375 0.23125 0.2625 0.68125 0.63541667 0.7375\n", 886 | " 0.69791667 0.10208333 0.4625 0.50416667 0.1875 0.59375\n", 887 | " 0.3625 0.2 0.46666667 0.36875 0.1625 0.68541667\n", 888 | " 0.72708333 0.37083333 0.41041667 0.63541667 0.52291667 0.45208333\n", 889 | " 0.13125 0.4875 0.76875 0.59166667 0.17291667 0.6\n", 890 | " 0.15 0.65 0.84375 0.51041667 0.16875 0.6875\n", 891 | " 0.81458333 0.7375 0.66041667 0.72083333]\n", 892 | "mean_acc: 0.5222916666666666\n", 893 | "train_size= 2\n", 894 | "time = 97.92097282409668 s\n", 895 | "[0.909375 0.896875 0.746875 0.728125 0.83125 0.6875 0.70625 0.525\n", 896 | " 0.921875 0.509375 0.203125 0.74375 0.821875 0.66875 0.58125 0.5875\n", 897 | " 0.240625 0.875 0.853125 0.528125 0.75 0.828125 0.95 0.59375\n", 898 | " 0.646875 0.25625 0.71875 0.690625 0.640625 0.81875 0.4125 0.3125\n", 899 | " 0.40625 0.771875 0.828125 0.790625 0.78125 0.146875 0.615625 0.653125\n", 900 | " 0.3 0.6625 0.496875 0.271875 0.621875 0.575 0.24375 0.746875\n", 901 | " 0.75 0.53125 0.578125 0.753125 0.69375 0.553125 0.271875 0.4875\n", 902 | " 0.84375 0.7125 0.2 0.7125 0.175 0.70625 0.903125 0.65625\n", 903 | " 0.1875 0.7625 0.821875 0.80625 0.734375 0.85625 ]\n", 904 | "mean_acc: 0.625625\n", 905 
| "train_size= 3\n", 906 | "time = 86.72654867172241 s\n", 907 | "[0.9125 0.93125 0.78125 0.80625 0.89375 0.7625 0.79375 0.6125 0.9375\n", 908 | " 0.575 0.2875 0.76875 0.85625 0.7375 0.6625 0.6625 0.31875 0.91875\n", 909 | " 0.8875 0.63125 0.81875 0.84375 0.95625 0.7125 0.68125 0.30625 0.78125\n", 910 | " 0.70625 0.71875 0.8625 0.44375 0.4 0.5 0.86875 0.80625 0.8125\n", 911 | " 0.83125 0.21875 0.6375 0.69375 0.3375 0.69375 0.53125 0.31875 0.68125\n", 912 | " 0.6875 0.325 0.8125 0.8 0.61875 0.58125 0.8 0.8 0.6125\n", 913 | " 0.3 0.4625 0.8625 0.7875 0.275 0.7625 0.2375 0.775 0.9375\n", 914 | " 0.71875 0.2375 0.80625 0.8875 0.81875 0.7625 0.91875]\n", 915 | "mean_acc: 0.6783928571428572\n" 916 | ] 917 | } 918 | ], 919 | "source": [ 920 | "Aug_size = [3,5,6]\n", 921 | "acc_all = np.zeros((70,3))\n", 922 | "for i_train in range(3):\n", 923 | " print('train_size=',i_train+1)\n", 924 | " c = time.time()\n", 925 | " acc = Parallel(n_jobs=-1)(delayed(beta_eTRCA_Aug)(idx_num, n_train=i_train+1, t_task=0.5, n_Aug=Aug_size[i_train]) for idx_num in range(70))\n", 926 | " acc = np.array(acc)\n", 927 | " acc_all[:,i_train] = acc\n", 928 | " e = time.time()\n", 929 | " print('time =', e - c,' s')\n", 930 | " print(acc)\n", 931 | " print('mean_acc:',np.mean(acc))\n", 932 | "sio.savemat(r'demo_beta_eTRCA_withSAME.mat', {'acc': acc_all})" 933 | ] 934 | }, 935 | { 936 | "cell_type": "markdown", 937 | "metadata": {}, 938 | "source": [ 939 | "### TDCA(w/oSAME)" 940 | ] 941 | }, 942 | { 943 | "cell_type": "code", 944 | "execution_count": 9, 945 | "metadata": {}, 946 | "outputs": [ 947 | { 948 | "name": "stdout", 949 | "output_type": "stream", 950 | "text": [ 951 | "train_size= 1\n", 952 | "time = 303.67482328414917 s\n", 953 | "[0.32291667 0.17083333 0.33958333 0.11666667 0.14791667 0.10416667\n", 954 | " 0.1375 0.14375 0.43333333 0.2 0.09166667 0.4125\n", 955 | " 0.44583333 0.08958333 0.08125 0.10833333 0.02083333 0.43125\n", 956 | " 0.23125 0.07916667 0.22291667 0.17083333 0.44583333 
0.21666667\n", 957 | " 0.20416667 0.08541667 0.27083333 0.25833333 0.11458333 0.2375\n", 958 | " 0.2 0.01666667 0.09583333 0.22708333 0.22291667 0.32083333\n", 959 | " 0.225 0.0375 0.1125 0.09375 0.0875 0.30833333\n", 960 | " 0.07916667 0.09791667 0.0875 0.07916667 0.06666667 0.32291667\n", 961 | " 0.2875 0.06041667 0.19375 0.13958333 0.25625 0.07083333\n", 962 | " 0.05833333 0.39166667 0.27708333 0.25 0.07916667 0.29375\n", 963 | " 0.0375 0.15416667 0.26666667 0.18333333 0.07916667 0.18958333\n", 964 | " 0.19583333 0.2125 0.05833333 0.1625 ]\n", 965 | "mean_acc: 0.1844940476190476\n", 966 | "train_size= 2\n", 967 | "time = 222.89798712730408 s\n", 968 | "[0.925 0.915625 0.740625 0.721875 0.81875 0.634375 0.71875 0.540625\n", 969 | " 0.9375 0.46875 0.1125 0.690625 0.79375 0.671875 0.55625 0.51875\n", 970 | " 0.215625 0.890625 0.834375 0.546875 0.696875 0.796875 0.925 0.64375\n", 971 | " 0.671875 0.24375 0.675 0.628125 0.64375 0.834375 0.375 0.2875\n", 972 | " 0.328125 0.78125 0.778125 0.775 0.7 0.11875 0.575 0.628125\n", 973 | " 0.271875 0.61875 0.384375 0.215625 0.55625 0.51875 0.2625 0.759375\n", 974 | " 0.76875 0.503125 0.54375 0.73125 0.690625 0.50625 0.259375 0.803125\n", 975 | " 0.821875 0.7625 0.18125 0.690625 0.13125 0.678125 0.871875 0.65\n", 976 | " 0.15625 0.7625 0.815625 0.803125 0.684375 0.815625]\n", 977 | "mean_acc: 0.6082589285714286\n", 978 | "train_size= 3\n", 979 | "time = 155.24060249328613 s\n", 980 | "[0.9375 0.9375 0.78125 0.8 0.9125 0.75625 0.775 0.66875 0.95625\n", 981 | " 0.5625 0.225 0.7625 0.8375 0.7375 0.68125 0.6 0.325 0.90625\n", 982 | " 0.9125 0.7125 0.78125 0.84375 0.95625 0.7375 0.76875 0.34375 0.75\n", 983 | " 0.69375 0.7125 0.9125 0.44375 0.4 0.48125 0.85 0.83125 0.81875\n", 984 | " 0.8125 0.21875 0.66875 0.725 0.3625 0.675 0.50625 0.25 0.6625\n", 985 | " 0.6375 0.30625 0.84375 0.8125 0.6375 0.60625 0.8 0.8 0.5625\n", 986 | " 0.30625 0.9 0.88125 0.7875 0.25 0.75 0.2 0.7625 0.93125\n", 987 | " 0.7 0.18125 0.8125 0.9 0.86875 0.7125 
0.86875]\n", 988 | "mean_acc: 0.6830357142857142\n" 989 | ] 990 | } 991 | ], 992 | "source": [ 993 | "Aug_size = [0,0,0]\n", 994 | "acc_all = np.zeros((70,3))\n", 995 | "for i_train in range(3):\n", 996 | " print('train_size=',i_train+1)\n", 997 | " c = time.time()\n", 998 | " acc = Parallel(n_jobs=-1)(delayed(beta_TDCA_Aug)(idx_num, n_train=i_train+1, t_task=0.5, n_Aug=Aug_size[i_train]) for idx_num in range(70))\n", 999 | " acc = np.array(acc)\n", 1000 | " acc_all[:,i_train] = acc\n", 1001 | " e = time.time()\n", 1002 | " print('time =', e - c,' s')\n", 1003 | " print(acc)\n", 1004 | " print('mean_acc:',np.mean(acc))\n", 1005 | "sio.savemat(r'demo_beta_TDCA_withoutSAME.mat', {'acc': acc_all})" 1006 | ] 1007 | }, 1008 | { 1009 | "cell_type": "markdown", 1010 | "metadata": {}, 1011 | "source": [ 1012 | "### TDCA(w/SAME)" 1013 | ] 1014 | }, 1015 | { 1016 | "cell_type": "code", 1017 | "execution_count": 10, 1018 | "metadata": {}, 1019 | "outputs": [ 1020 | { 1021 | "name": "stdout", 1022 | "output_type": "stream", 1023 | "text": [ 1024 | "train_size= 1\n", 1025 | "time = 303.69347047805786 s\n", 1026 | "[0.80625 0.83125 0.68958333 0.61041667 0.75833333 0.52291667\n", 1027 | " 0.66041667 0.36875 0.88958333 0.41041667 0.11875 0.66666667\n", 1028 | " 0.7375 0.49375 0.47708333 0.46875 0.20833333 0.81666667\n", 1029 | " 0.6875 0.37291667 0.64583333 0.64583333 0.82916667 0.55833333\n", 1030 | " 0.53333333 0.26041667 0.59583333 0.56666667 0.46875 0.79791667\n", 1031 | " 0.33125 0.23958333 0.29166667 0.70416667 0.64375 0.75416667\n", 1032 | " 0.70416667 0.13541667 0.48958333 0.46041667 0.18541667 0.55\n", 1033 | " 0.3375 0.19375 0.48958333 0.33958333 0.16458333 0.66458333\n", 1034 | " 0.70416667 0.43958333 0.43333333 0.6125 0.62291667 0.42916667\n", 1035 | " 0.18958333 0.67291667 0.70833333 0.6125 0.14583333 0.58958333\n", 1036 | " 0.18333333 0.575 0.80833333 0.55833333 0.15625 0.69166667\n", 1037 | " 0.81666667 0.68541667 0.54583333 0.72291667]\n", 1038 | "mean_acc: 
0.5297321428571431\n", 1039 | "train_size= 2\n", 1040 | "time = 235.44340467453003 s\n", 1041 | "[0.93125 0.921875 0.74375 0.73125 0.85 0.6625 0.703125 0.5125\n", 1042 | " 0.91875 0.540625 0.16875 0.725 0.80625 0.628125 0.575 0.56875\n", 1043 | " 0.253125 0.8625 0.878125 0.565625 0.746875 0.825 0.95625 0.659375\n", 1044 | " 0.684375 0.275 0.725 0.68125 0.6625 0.8875 0.434375 0.28125\n", 1045 | " 0.39375 0.78125 0.796875 0.78125 0.809375 0.20625 0.63125 0.65\n", 1046 | " 0.303125 0.646875 0.471875 0.259375 0.63125 0.515625 0.284375 0.771875\n", 1047 | " 0.78125 0.575 0.575 0.71875 0.7125 0.534375 0.25625 0.759375\n", 1048 | " 0.85 0.765625 0.23125 0.70625 0.20625 0.7 0.903125 0.709375\n", 1049 | " 0.196875 0.759375 0.825 0.80625 0.65 0.840625]\n", 1050 | "mean_acc: 0.6337499999999999\n", 1051 | "train_size= 3\n", 1052 | "time = 165.51607608795166 s\n", 1053 | "[0.925 0.9375 0.7875 0.7875 0.9 0.7375 0.7875 0.6 0.9625\n", 1054 | " 0.5875 0.24375 0.7625 0.84375 0.725 0.675 0.65 0.35625 0.91875\n", 1055 | " 0.9 0.66875 0.81875 0.83125 0.975 0.75625 0.78125 0.35625 0.76875\n", 1056 | " 0.7375 0.75625 0.93125 0.49375 0.375 0.5625 0.8375 0.825 0.80625\n", 1057 | " 0.85625 0.2625 0.66875 0.71875 0.34375 0.7 0.53125 0.3 0.70625\n", 1058 | " 0.6375 0.35625 0.85625 0.825 0.6625 0.6125 0.7875 0.8 0.61875\n", 1059 | " 0.3125 0.8875 0.86875 0.78125 0.25 0.7625 0.3 0.76875 0.9375\n", 1060 | " 0.73125 0.225 0.8375 0.9 0.85625 0.76875 0.8875 ]\n", 1061 | "mean_acc: 0.6955357142857143\n" 1062 | ] 1063 | } 1064 | ], 1065 | "source": [ 1066 | "Aug_size = [3,5,6]\n", 1067 | "acc_all = np.zeros((70,3))\n", 1068 | "for i_train in range(3):\n", 1069 | " print('train_size=',i_train+1)\n", 1070 | " c = time.time()\n", 1071 | " acc = Parallel(n_jobs=-1)(delayed(beta_TDCA_Aug)(idx_num, n_train=i_train+1, t_task=0.5, n_Aug=Aug_size[i_train]) for idx_num in range(70))\n", 1072 | " acc = np.array(acc)\n", 1073 | " acc_all[:,i_train] = acc\n", 1074 | " e = time.time()\n", 1075 | " print('time =', 
e - c,' s')\n", 1076 | " print(acc)\n", 1077 | " print('mean_acc:',np.mean(acc))\n", 1078 | "sio.savemat(r'demo_beta_TDCA_withSAME.mat', {'acc': acc_all})" 1079 | ] 1080 | }, 1081 | { 1082 | "cell_type": "code", 1083 | "execution_count": null, 1084 | "metadata": {}, 1085 | "outputs": [], 1086 | "source": [] 1087 | } 1088 | ], 1089 | "metadata": { 1090 | "kernelspec": { 1091 | "display_name": "myconda", 1092 | "language": "python", 1093 | "name": "myconda" 1094 | }, 1095 | "language_info": { 1096 | "codemirror_mode": { 1097 | "name": "ipython", 1098 | "version": 3 1099 | }, 1100 | "file_extension": ".py", 1101 | "mimetype": "text/x-python", 1102 | "name": "python", 1103 | "nbconvert_exporter": "python", 1104 | "pygments_lexer": "ipython3", 1105 | "version": "3.8.5" 1106 | } 1107 | }, 1108 | "nbformat": 4, 1109 | "nbformat_minor": 4 1110 | } 1111 | -------------------------------------------------------------------------------- /demo_beta_TDCA_withSAME.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuixinLuo/Source-Aliasing-Matrix-Estimation-DataAugmentation-SAME-SSVEP/11cc82c8a0b169668cbb27abe52d5651d5134c00/demo_beta_TDCA_withSAME.mat -------------------------------------------------------------------------------- /demo_beta_TDCA_withoutSAME.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuixinLuo/Source-Aliasing-Matrix-Estimation-DataAugmentation-SAME-SSVEP/11cc82c8a0b169668cbb27abe52d5651d5134c00/demo_beta_TDCA_withoutSAME.mat -------------------------------------------------------------------------------- /demo_beta_eTRCA_withSAME.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuixinLuo/Source-Aliasing-Matrix-Estimation-DataAugmentation-SAME-SSVEP/11cc82c8a0b169668cbb27abe52d5651d5134c00/demo_beta_eTRCA_withSAME.mat 
-------------------------------------------------------------------------------- /demo_beta_eTRCA_withoutSAME.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/RuixinLuo/Source-Aliasing-Matrix-Estimation-DataAugmentation-SAME-SSVEP/11cc82c8a0b169668cbb27abe52d5651d5134c00/demo_beta_eTRCA_withoutSAME.mat -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | > # Data augmentation for SSVEPs using Source Aliasing Matrix Estimation 2 | 3 | ## Source Aliasing Matrix Estimation (SAME), based on paper [1]. 4 | 5 | We propose a data augmentation method named Source Aliasing Matrix Estimation (SAME) [1] to enhance the performance of state-of-the-art spatial filtering methods (i.e., eTRCA and TDCA) for SSVEP-BCIs. Based on the superposition model of SSVEPs, the task-related components are reconstructed by estimating the source aliasing matrices. After noise is added, multiple artificial signals are generated and then added to the calibration data in an appropriate proportion. 6 | 7 | This demo shows an example of using SAME for SSVEP-BCIs. Two state-of-the-art algorithms for SSVEP-BCIs, i.e., ensemble task-related component analysis (eTRCA) [2] and task-discriminant component analysis (TDCA) [3], are used to compare the performance with and without SAME (w/SAME and w/oSAME, respectively). 8 | 9 | > [1] Luo R., et al. "Data augmentation of SSVEPs using source aliasing matrix estimation for brain-computer interfaces". *IEEE Trans. Biomed. Eng.*, 2022. DOI: 10.1109/TBME.2022.3227036 10 | > 11 | > [2] Nakanishi M., et al. "Enhancing detection of SSVEPs for a high-speed brain speller using task-related component analysis". *IEEE Trans. Biomed. Eng.*, 2018, 65(1), 104-112. 12 | > 13 | > [3] Liu B., et al.
"Improving the Performance of Individually Calibrated SSVEP-BCI by Task-Discriminant Component Analysis". *IEEE Trans. Neural Syst. Rehabil. Eng.*, 2021, 29, 1998-2007. 14 | 15 | ## The main steps of SAME 16 | 17 | 1. The SSVEP template, averaged across trials, is first obtained. 18 | 19 | ![](https://latex.codecogs.com/svg.image?&space;&space;&space;\overline{\boldsymbol{X}}_{n}=\frac{1}{N_t}&space;\sum_{j=1}^{N_t}&space;\boldsymbol{X}_{n}^{(j)}) 20 | 21 | 2. The source signal is reconstructed by estimating the aliasing matrix that maps the sine-cosine reference signals onto the template. 22 | 23 | ![](https://latex.codecogs.com/svg.image?\begin{aligned}&space;&space;&space;&&space;\widehat{\boldsymbol{\Phi}}={\operatorname&space;{&space;a&space;r&space;g&space;}&space;\operatorname&space;{&space;m&space;i&space;n&space;}}\left\|\overline{\boldsymbol{X}}_{n}-\boldsymbol{\Phi}&space;Y_{n}\right\|_F^2&space;&space;&space;\end{aligned}) 24 | 25 | ![](https://latex.codecogs.com/svg.image?\begin{aligned}&space;&space;&space;&&space;\widehat{\boldsymbol{S}}_{n}=\widehat{\boldsymbol{\Phi}}&space;Y_{n}&space;&space;&space;\end{aligned}) 26 | 27 | 3. Random noise is added to obtain multiple artificially generated signals. 28 | 29 | ![](https://latex.codecogs.com/svg.image?\boldsymbol{Z}_{n}^{(k)}=\hat{\boldsymbol{S}}_{n}+\alpha&space;\boldsymbol{W}_{n&space;o&space;i&space;s&space;e}^{(k)}) 30 | 31 | α controls the intensity of the noise; it is set to 0.05 in this study. The random noise is added only to increase the number of generated signals and carries no substantive information. 32 | 33 | ## Dataset 34 | 35 | #### BETA dataset [4] from Tsinghua University. 36 | 37 | > [4] Liu B., et al. "BETA: A Large Benchmark Database Toward SSVEP-BCI Application". *Frontiers in Neuroscience*, 2020, 14. 38 | 39 | After downloading the dataset, we renamed S1, ..., S9 to S01, ..., S09 for our own convenience.
If you don't want to do that, you can change the variable *subject_id* in the functions *beta_TDCA_Aug()* and *beta_eTRCA_Aug()*. 40 | 41 | ## Results 42 | 43 | - #### demo-SAME-BETA-main.ipynb 44 | 45 | This file is used to calculate the classification accuracy. We have run it and obtained the results for all subjects with a 0.5 s time window; they are stored in *demo_beta_eTRCA_withoutSAME.mat*, *demo_beta_eTRCA_withSAME.mat*, *demo_beta_TDCA_withoutSAME.mat*, and *demo_beta_TDCA_withSAME.mat*. 46 | 47 | The average accuracy across all subjects for different numbers of training trials (Nt) is listed below: 48 | 49 | | | Nt=1 | Nt=2 | Nt=3 | 50 | | -------------- | ----- | ----- | ----- | 51 | | eTRCA(w/oSAME) | 0.119 | 0.521 | 0.609 | 52 | | eTRCA(w/SAME) | 0.522 | 0.626 | 0.678 | 53 | | TDCA(w/oSAME) | 0.184 | 0.608 | 0.683 | 54 | | TDCA(w/SAME) | 0.530 | 0.634 | 0.696 | 55 | 56 | The baseline accuracies of eTRCA and TDCA obtained by our code are similar to the results reported in paper [3]. 57 | 58 | We used a machine with 40 vCPUs to compute the results for different subjects in parallel. On a machine with fewer cores, running this code will take longer. 59 | 60 | ## Acknowledgement 61 | 62 | Thanks to Liu B., the author of papers [3] and [4], for his patience in answering my questions about TDCA. 63 | 64 | ## Email 65 | 66 | email: ruixin_luo@tju.edu.cn 67 | 68 | This is my first paper. It may not be outstanding or perfect, but it encouraged me, an unconfident girl. I hope I can become braver and more confident in the future. 69 | 70 | --------------------------------------------------------------------------------