├── __pycache__ └── CorrCA.cpython-37.pyc ├── README.md └── CorrCA.py /__pycache__/CorrCA.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/renzocom/CorrCA/HEAD/__pycache__/CorrCA.cpython-37.pyc -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CorrCA: Correlated Component Analysis 2 | 3 | This repository contains an implementation of Correlated Component Analysis (CorrCA) based on the [original Matlab code](https://www.parralab.org/corrca/) from Parra's lab. 4 | 5 | ## Usage 6 | Example script demonstrating how to compute CorrCA on EEG evoked data. 7 | ``` 8 | import numpy as np 9 | import CorrCA 10 | 11 | # Load your preprocessed EEG data as a NumPy array 12 | # epochs: shape (n_epochs, n_channels, n_times) 13 | # times: shape (n_times,) 14 | epochs = np.load('path/to/your/epochs.npy') 15 | times = np.load('path/to/your/times.npy') 16 | 17 | # Define CorrCA parameters 18 | params = {'baseline_window': (-0.3, -0.05), 'response_window': (0., 0.6), 'gamma': 0, 'K': 60, 'stats': True, 'n_surrogates': 500, 'alpha': 0.01} 19 | 20 | # Perform CorrCA 21 | W, ISC, A, Y, Yfull, ISC_thr = CorrCA.calc_corrca(epochs, times, **params) 22 | ``` 23 | 24 | For other use cases, look inside `calc_corrca()` to see how the main functions are called. 25 | -------------------------------------------------------------------------------- /CorrCA.py: -------------------------------------------------------------------------------- 1 | # 2 | # Renzo Comolatti (renzo.com@gmail.com) 3 | # 4 | # Module with the Correlated Component Analysis (CorrCA) method based on the 5 | # original Matlab code from Parra's lab (https://www.parralab.org/corrca/).
#
# started 18/10/2019

import numpy as np
from scipy import linalg as sp_linalg

try:
    # ``scipy.diag`` was only a deprecated alias of ``numpy.diag`` and is not
    # available in newer SciPy releases; keep the old name importable anyway.
    from scipy import diag as sp_diag
except ImportError:
    sp_diag = np.diag


def calc_corrca(epochs, times, **par):
    """
    Calculate Correlated Component Analysis (CorrCA) on given epochs and times.

    Parameters
    ----------
    epochs : ndarray of shape (n_epochs, n_channels, n_times)
        Input signal data.
    times : array-like of shape (n_times,)
        Time points corresponding to the last axis of ``epochs``.
    **par : dict
        Parameters for the analysis. Expected keys are:
        - 'response_window' : tuple of float
            Start and end time of the window used to fit CorrCA. Both ends are
            mapped to the nearest sample; the end sample is excluded.
        - 'gamma' : float
            Shrinkage regularization of the within-subject covariance matrix.
        - 'K' : int or None
            Number of components to retain (None keeps all).
        - 'stats' : bool, optional
            Whether to threshold components with surrogate statistics.
            Defaults to True when the key is absent.
        - 'n_surrogates' : int
            Number of surrogate datasets (read only when 'stats' is true).
        - 'alpha' : float
            Significance level (read only when 'stats' is true).

    Returns
    -------
    W : ndarray of shape (n_channels, n_components)
        Backward model (signal to components).
    ISC : ndarray of shape (n_components,)
        Inter-subject correlation values.
    A : ndarray of shape (n_channels, n_components)
        Forward model (components to signal).
    Y : ndarray of shape (n_epochs, n_components, n_window_times)
        Transformed signal within the response window.
    Yfull : ndarray of shape (n_epochs, n_components, n_times)
        Transformed signal for the entire epoch duration.
    ISC_thr : float or None
        ISC threshold from surrogate data, or None when 'stats' is false.
        When computed, only components whose ISC exceeds it are returned.
    """
    epochs = np.asarray(epochs)
    ini_ix = time2ix(times, par['response_window'][0])
    end_ix = time2ix(times, par['response_window'][1])
    X = epochs[..., ini_ix:end_ix]

    W, ISC, A = fit(X, gamma=par['gamma'], k=par['K'])

    ISC_thr = None
    # The previous code tested ``if stats:``, i.e. the truthiness of the
    # ``stats`` *function*, so statistics always ran and par['stats'] was
    # ignored. A missing key keeps that historical "always compute" default.
    if par.get('stats', True):
        print('Calculating statistics...')
        ISC_thr, _ = stats(X, par['gamma'], par['K'], par['n_surrogates'], par['alpha'])
        # fit() orders components by decreasing eigenvalue (~ decreasing ISC),
        # so the significant components form a leading prefix.
        n_components = int(np.sum(ISC > ISC_thr))
        W, ISC, A = W[:, :n_components], ISC[:n_components], A[:, :n_components]

    Y = transform(X, W)
    Yfull = transform(epochs, W)
    return W, ISC, A, Y, Yfull, ISC_thr

###########
# HELPERS #
###########
def _within_between_covariances(X):
    '''
    Within-subject (Rw) and between-subject (Rb) covariances of X.

    X has shape (n_subj, n_dim, n_times). Sums and means ignore NaN entries so
    that a channel with a bad trial does not turn every covariance into NaN.
    '''
    N = X.shape[0]
    Rw = np.nansum([np.cov(X[n, ...]) for n in range(N)], axis=0)
    Rt = N**2 * np.cov(np.nanmean(X, axis=0))
    Rb = (Rt - Rw) / (N - 1)
    return Rw, Rb


def _pca_inverse(R, k):
    '''PCA-regularized inverse of a symmetric matrix R using its k leading components.'''
    U, S, Vh = np.linalg.svd(R)
    return U[:, :k].dot(np.diag(1 / S[:k])).dot(Vh[:k, :])


def _sort_descending(evals, evecs):
    '''Reorder eigenvalues (and eigenvector columns) by decreasing real part.

    ``scipy.linalg.eig`` returns eigenvalues in no particular order.
    '''
    order = np.argsort(-np.real(evals), kind='stable')
    return evals[order], evecs[:, order]


def _forward_model(Rw, W, full_rank):
    '''Forward model A from within-subject covariance Rw and backward model W.'''
    M = W.T.dot(Rw).dot(W)
    if full_rank:
        return Rw.dot(W).dot(sp_linalg.inv(M))
    return Rw.dot(W).dot(np.diag(1 / np.diag(M)))

##################
# MAIN FUNCTIONS #
##################
def fit(X, version=2, gamma=0, k=None):
    '''
    Correlated Component Analysis (CorrCA).

    Parameters
    ----------
    X : ndarray of shape = (n_subj, n_dim, n_times)
        Signal to calculate CorrCA.
    version : int
        Unused; kept for backward compatibility.
    gamma : float
        Shrinkage regularization of the within-subject covariance
        ((1 - gamma) * Rw + gamma * mean(diag(Rw)) * I). Ignored (set to 0)
        when ``k`` is given.
    k : int or None
        Number of components to keep. When it is smaller than n_dim, Rw is
        inverted using only its k leading principal components. It is also
        capped at rank(Rw). None keeps all n_dim components.

    Returns
    -------
    W : ndarray of shape = (n_dim, n_components)
        Backward model (signal to components), ordered by decreasing eigenvalue.
    ISC : ndarray of shape = (n_components,)
        Inter-subject Correlation values.
    A : ndarray of shape = (n_dim, n_components)
        Forward model (components to signal).
    '''
    # TODO: implement case 3, tsvd truncation
    X = np.asarray(X)
    _, D, _ = X.shape  # subj x dim x times (instead of times x dim x subj)

    if k is not None:  # truncate eigenvalues using SVD
        gamma = 0
    else:
        k = D

    Rw, Rb = _within_between_covariances(X)

    rank = np.linalg.matrix_rank(Rw)
    if rank < D and gamma != 0:
        print('Warning: data is rank deficient (gamma not used).')

    k = min(k, rank)  # handle rank deficient data.
    if k < D:
        evals, W = sp_linalg.eig(_pca_inverse(Rw, k).dot(Rb))
        evals, W = _sort_descending(evals, W)
        W = W[:, :k]
    else:
        Rw_reg = (1 - gamma) * Rw + gamma * Rw.diagonal().mean() * np.identity(D)
        evals, W = sp_linalg.eig(Rb, Rw_reg)
        evals, W = _sort_descending(evals, W)

    ISC = np.diagonal(W.T.dot(Rb).dot(W)) / np.diag(W.T.dot(Rw).dot(W))
    ISC, W = np.real(ISC), np.real(W)

    A = _forward_model(Rw, W, full_rank=(k == D))
    return W, ISC, A


def transform(X, W):
    '''
    Get CorrCA components from signal(X), e.g. epochs or evoked, using backward model (W).

    Parameters
    ----------
    X : ndarray of shape = (n_subj, n_dim, n_times) or (n_dim, n_times)
        Signal to transform.
    W : ndarray of shape = (n_dim, n_components)
        Backward model (signal to components).

    Returns
    -------
    Y : ndarray of shape = (n_subj, n_components, n_times) or (n_components, n_times)
        CorrCA components (float64).
    '''
    X = np.asarray(X)
    single = X.ndim == 2
    if single:
        X = X[np.newaxis, ...]
    # W.T (K, D) broadcasts against every (D, T) slice of X.
    Y = np.matmul(W.T, X).astype(float)
    return Y[0] if single else Y


def get_ISC(X, W):
    '''
    Get ISC values from signal (X) and backward model (W).

    Parameters
    ----------
    X : ndarray of shape = (n_subj, n_dim, n_times)
        Signal to calculate CorrCA.
    W : ndarray of shape = (n_dim, n_components)
        Backward model (signal to components).

    Returns
    -------
    ISC : ndarray of shape = (n_components,)
        Inter-subject Correlation values.
    '''
    Rw, Rb = _within_between_covariances(np.asarray(X))
    ISC = np.diagonal(W.T.dot(Rb).dot(W)) / np.diag(W.T.dot(Rw).dot(W))
    return np.real(ISC)


def get_forwardmodel(X, W):
    '''
    Get forward model from signal(X) and backward model (W).

    Parameters
    ----------
    X : ndarray of shape = (n_subj, n_dim, n_times)
        Signal to transform.
    W : ndarray of shape = (n_dim, n_components)
        Backward model (signal to components).

    Returns
    -------
    A : ndarray of shape = (n_dim, n_components)
        Forward model (components to signal).
    '''
    X = np.asarray(X)
    D = X.shape[1]
    Rw, _ = _within_between_covariances(X)
    return _forward_model(Rw, W, full_rank=(np.linalg.matrix_rank(Rw) == D))


def reconstruct(Y, A):
    '''
    Reconstruct signal(X) from components (Y) and forward model (A).

    Parameters
    ----------
    Y : ndarray of shape = (n_subj, n_components, n_times) or (n_components, n_times)
        CorrCA components.
    A : ndarray of shape = (n_dim, n_components)
        Forward model (components to signal).

    Returns
    -------
    X : ndarray of shape = (n_subj, n_dim, n_times) or (n_dim, n_times)
        Signal (float64).
    '''
    Y = np.asarray(Y)
    single = Y.ndim == 2
    if single:
        Y = Y[np.newaxis, ...]
    X = np.matmul(A, Y).astype(float)
    return X[0] if single else X


def stats(X, gamma=0, k=None, n_surrogates=200, alpha=0.05):
    '''
    Compute ISC statistical threshold using circular shift surrogates.

    Each surrogate circularly shifts every subject/trial independently in time
    (see ``circular_shift``), CorrCA is refit on it, and its maximum ISC is
    recorded.

    Parameters
    ----------
    X : ndarray of shape = (n_subj, n_dim, n_times)
        Signal to calculate CorrCA.
    gamma : float
        Regularization passed to ``fit``.
    k : int or None
        Number of components passed to ``fit``.
    n_surrogates : int
        Number of surrogate datasets.
    alpha : float
        Significance level; the threshold is the (1 - alpha) quantile of the null.

    Returns
    -------
    thr : float
        ISC threshold.
    ISC_null : ndarray of shape = (n_surrogates,)
        Maximum ISC of each surrogate fit.
    '''
    ISC_null = np.empty(n_surrogates)
    for n in range(n_surrogates):
        if n % 10 == 0:
            print('#', end='')
        _, ISC, _ = fit(circular_shift(X), gamma=gamma, k=k)
        ISC_null[n] = np.max(ISC)
    thr = np.percentile(ISC_null, (1 - alpha) * 100)
    print('')
    return thr, ISC_null


def circular_shift(X):
    '''
    Circularly shift each repetition of X along time by an independent random lag.

    X has shape (n_reps, n_dims, n_times); lags are drawn uniformly from
    [0, n_times) using NumPy's global random state.
    '''
    n_reps, _, n_times = X.shape
    shifts = np.random.choice(n_times, n_reps, replace=True)
    surrogate = np.zeros_like(X)
    for i, shift in enumerate(shifts):
        surrogate[i, ...] = np.roll(X[i, ...], shift, axis=1)
    return surrogate


def time2ix(times, t):
    '''Index of the sample in ``times`` closest to time ``t``.'''
    return np.abs(np.asarray(times) - t).argmin()


def get_id(params):
    '''Identifier string built from the CorrCA parameters.'''
    CCA_id = 'CorrCA_{}_{}'.format(params['response_window'][0], params['response_window'][1])
    if params['stats']:
        CCA_id += '_stats_K_{}_surr_{}_alpha_{}_gamma_{}'.format(params['K'], params['n_surrogates'], params['alpha'], params['gamma'])
    return CCA_id

############
# PLOTTING #
############
def plot_CCA(CCA, plot_trials=True, plot_evk=False, plot_signal=False, collapse=False, xlim=(-0.3, 0.6), ylim=(-7, 5), norm=True, trials_alpha=0.5, width=10):
    '''
    Plot CorrCA component time courses and their forward-model topographies.

    Parameters
    ----------
    CCA : dict
        Reads keys 'epochs', 'W', 'ISC', 'A', 'times', 'info', and 'evoked'
        (the latter only when ``plot_signal`` or ``plot_evk`` is set).
    plot_trials : bool
        Draw single-trial component time courses (non-collapsed layout only).
    plot_evk : bool
        Overlay the evoked signal in grey.
    plot_signal : bool
        Add an evoked-signal panel on top (via ``plot_evoked``).
    collapse : bool
        Draw all component means in one axis, with up to 8 topographies below.
    xlim, ylim, norm, trials_alpha, width :
        Plot styling options.

    Returns
    -------
    fig : figure created with ``plt.figure``.
    '''
    # NOTE(review): ``plt``, ``mne`` and ``plot_evoked`` are neither defined nor
    # imported in this module; this function raises NameError unless they are
    # made available as module globals. TODO: import them explicitly.
    Y = transform(CCA['epochs'], CCA['W'])
    Ymean = np.mean(Y, axis=0)

    ISC, A, times, info = CCA['ISC'], CCA['A'], CCA['times'], CCA['info']
    n_CC = Y.shape[-2]

    height = 6 if plot_signal else 0
    height += 12 if collapse else 2.5 * n_CC

    fig = plt.figure(figsize=(width, height))

    if plot_signal:
        plot_evoked(CCA['evoked'], times, info, fig=fig, xlim=xlim, ylim=ylim, norm=norm)

    if CCA['W'].shape[1] != 0:
        if collapse:
            n_topo = min(8, n_CC)
            gs = fig.add_gridspec(3, n_topo, top=0.49, hspace=0.5)
            ax = fig.add_subplot(gs[:2, :])
            if plot_evk:
                ax.plot(times, CCA['evoked'].T, color='tab:grey', linewidth=0.3)

            for n in range(n_CC):
                ax.plot(times, Ymean[n, :], label='Component {} - ISC = {:.2f}'.format(n+1, ISC[n]), linewidth=1.8)

            ax.legend(loc='lower left')
            ax.set_xlim(xlim)

            vmax = np.max(np.abs(A))  # shared colour scale for all topographies
            for n in range(n_topo):
                ax2 = fig.add_subplot(gs[2, n])
                im, _ = mne.viz.plot_topomap(A[:, n], pos=info, axes=ax2, show=False, vmax=vmax, vmin=-vmax)
                ax2.set_title('Component {}'.format(n+1))

                # Colorbar next to the last *drawn* topography (previously it
                # was skipped whenever more than 8 components existed).
                if n == n_topo - 1:
                    plt.colorbar(im, ax=ax2, fraction=0.04, pad=0.04)
        else:
            top = 0.49 if plot_signal else 0.88
            gs = fig.add_gridspec(n_CC, 3, top=top, hspace=0.3)
            for i in range(n_CC):
                ax = fig.add_subplot(gs[i, :2])

                if plot_trials:
                    ax.plot(times, Y[:, i, :].T, linewidth=0.5, color='tab:blue', alpha=trials_alpha)

                if plot_evk:
                    ax.plot(times, CCA['evoked'].T, color='tab:grey', linewidth=0.3)

                ax.plot(times, Ymean[i], color='black')

                ax.set_xlim(xlim)
                ax.set_title('Component {} - ISC = {:.2f}'.format(i+1, ISC[i]))

                ax2 = fig.add_subplot(gs[i, 2])
                mne.viz.plot_topomap(A[:, i], pos=info, axes=ax2, show=False)

    return fig


# Translation of original matlab function by Parra
def CorrCA_matlab(X, W=None, version=2, gamma=0, k=None):
    '''
    Correlated Component Analysis (translation of Parra's Matlab function).

    Parameters
    ----------
    X : array, shape (n_subj, n_dim, n_times)
    W : ndarray of shape (n_dim, n_components) or None
        Precomputed backward model. When given, no eigen-decomposition is done
        and ISC, Y and A are computed for this W.
    version : int
        Covariance computation: 1 = covariance of all subjects stacked,
        2 = sum of per-subject covariances. Other values are not implemented.
    gamma : float
        Shrinkage regularization of Rw (ignored when ``k`` is given).
    k : int or None
        Truncates eigenvalues on the Kth component.

    Returns
    -------
    W : ndarray of shape (n_dim, n_components)
    ISC : ndarray of shape (n_components,)
    Y : ndarray of shape (n_subj, n_components, n_times)
    A : ndarray of shape (n_dim, n_components)
    '''
    # TODO: implement case 3, tsvd truncation
    X = np.asarray(X)
    N, D, T = X.shape  # subj x dim x times (instead of times x dim x subj)

    if k is not None:  # truncate eigenvalues using SVD
        gamma = 0
    else:
        k = D

    # Compute within- and between-subject covariances
    if version == 1:
        Xcat = X.reshape((N * D, T))  # (N * D) x T; dimensions vary first, then subjects
        Rkl = np.cov(Xcat).reshape((N, D, N, D)).swapaxes(1, 2)
        Rw = Rkl[range(N), range(N), ...].sum(axis=0)  # Sum within subject covariances
        Rt = Rkl.reshape(N*N, D, D).sum(axis=0)
        Rb = (Rt - Rw) / (N-1)
    elif version == 2:
        Rw = sum(np.cov(X[n, ...]) for n in range(N))
        Rt = N**2 * np.cov(X.mean(axis=0))
        Rb = (Rt - Rw) / (N-1)
    else:
        raise NotImplementedError('CorrCA_matlab: version {} is not implemented (use 1 or 2).'.format(version))

    if W is None:
        k = min(k, np.linalg.matrix_rank(Rw))  # handle rank deficient data.
        if k < D:
            evals, W = sp_linalg.eig(_pca_inverse(Rw, k).dot(Rb))
            evals, W = _sort_descending(evals, W)
            W = W[:, :k]
        else:
            Rw_reg = (1-gamma) * Rw + gamma * Rw.diagonal().mean() * np.identity(D)
            evals, W = sp_linalg.eig(Rb, Rw_reg)
            evals, W = _sort_descending(evals, W)

    ISC = np.diagonal(W.T.dot(Rb).dot(W)) / np.diag(W.T.dot(Rw).dot(W))
    ISC, W = np.real(ISC), np.real(W)

    # Size Y from W itself so a user-supplied W with fewer columns works.
    Y = transform(X, W)

    A = _forward_model(Rw, W, full_rank=(k == D))

    return W, ISC, Y, A