├── HA_prep_functions.py
├── README.md
├── benchmarks.py
├── budapest_utils.py
├── drive_hyperalignment_cross_validation.py
├── environment.yml
├── hybrid_hyperalignment.py
├── permutation_tests.ipynb
├── raiders_utils.py
├── run_anatomical_benchmarks.py
└── whiplash_utils.py
/HA_prep_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | HA_prep_functions.py
3 | An amalgamation of functions that were useful for setting up hyperalignment for the H2A analyses.
4 |
5 | see :ref: `Busch, Slipski, et al., NeuroImage (2021)` for details.
6 |
7 | February 2021
8 | @author: Erica L. Busch
9 | """
10 | import os
11 | import numpy as np
12 | from mvpa2.datasets.base import Dataset
13 | from mvpa2.mappers.fxy import FxyMapper
14 | from mvpa2.misc.surfing.queryengine import SurfaceQueryEngine
15 | from mvpa2.measures.searchlight import Searchlight
16 | from mvpa2.measures.base import Measure
17 | from mvpa2.mappers.zscore import zscore
18 | import scipy.stats
19 |
   | # NOTE: these helpers expect a dataset-specific utils module (budapest_utils, raiders_utils,
   | # or whiplash_utils) to be importable as `utils`; budapest is assumed here, swap in the one you need.
   | import budapest_utils as utils
   | basedir = utils.basedir
20 | MASKS = {'l': np.load(os.path.join(basedir, 'fsaverage_lh_mask.npy')), 'r': np.load(os.path.join(basedir, 'fsaverage_rh_mask.npy'))}
21 |
22 |
23 | class MeanFeatureMeasure(Measure):
24 |     """Mean group feature measure.
25 |     Because the vanilla one doesn't want to work, I adapted this from Swaroop.
26 |     Debugging is hard. Accepting the struggles of someone smarter than me is easy.
27 |     """
28 |     is_trained = True
29 |
30 |     def __init__(self, **kwargs):
31 |         Measure.__init__(self, **kwargs)
32 |
33 |     def _call(self, dataset):
34 |         return Dataset(samples=np.mean(dataset.samples, axis=1))
35 |
36 |
37 | def compute_seed_means(measure, queryengine, ds, roi_ids):
38 |     """
39 |     Parameters:
40 |     ----------
41 |     measure: a PyMVPA measure passed to the Searchlight.
42 |     queryengine: a trained PyMVPA SurfaceQueryEngine.
43 |     ds: a single PyMVPA dataset (samples: timeseries, features: vertices)
44 |     roi_ids: the vertex indices where each searchlight will be centered.
45 |
46 |     Returns:
47 |     --------
48 |     seed_data: dataset with the mean timeseries for a searchlight centered on each roi_id.
49 |
50 |     """
51 |
52 |     # Seed data is the mean timeseries for each searchlight
53 |     seed_data = Searchlight(measure, queryengine=queryengine,
54 |                             nproc=1, roi_ids=roi_ids.copy())
55 |     if isinstance(ds, np.ndarray):
56 |         ds = Dataset(ds)
57 |     seed_data = seed_data(ds)
58 |     zscore(seed_data.samples, chunks_attr=None)
59 |     return seed_data
60 |
61 | def compute_connectomes(datasets, queryengine, target_indices):
62 |     """
63 |     Parameters:
64 |     -----------
65 |     datasets: a list of PyMVPA datasets, one per subject.
66 |     queryengine: the trained PyMVPA SurfaceQueryEngine, trained on the surface defining this data
67 |                  and the searchlight radius matching your analysis.
68 |     target_indices: the indices of the vertices where each searchlight for a connectivity target will be centered.
69 |
70 |     Returns:
71 |     -------
72 |     connectomes: a list of connectivity matrices, where each entry is the correlation between the timeseries of
73 |                  the searchlights centered on each connectivity seed and each connectivity target.
74 |     """
75 |
76 |     conn_metric = lambda x,y: np.dot(x.samples, y.samples)/x.nsamples
77 |     connectivity_mapper = FxyMapper(conn_metric)
78 |     mean_feature_measure = MeanFeatureMeasure()
79 |
80 |     # compute means for aligning seed features
81 |     conn_means = [compute_seed_means(MeanFeatureMeasure(), queryengine, ds, target_indices) for ds in datasets]
82 |
83 |     conn_targets = []
84 |     for csm in conn_means:
85 |         zscore(csm, chunks_attr=None)
86 |         conn_targets.append(csm)
87 |
88 |     connectomes = []
89 |     for target, ds in zip(conn_targets, datasets):
90 |         connectivity_mapper.train(target)
91 |         connectome = connectivity_mapper.forward(ds)
92 |         connectome.fa = ds.fa
93 |         zscore(connectome, chunks_attr=None)
94 |         connectomes.append(connectome)
95 |     return connectomes
96 |
97 |
98 | def get_node_indices(hemi, surface_res=None):
99 |     """
100 |     Parameters:
101 |     -----------
102 |     hemi: hemisphere (can be "r", "l", or "b" for both hemis together)
103 |     surface_res: defaults to the full resolution of your surface (in this case, 10242).
104 |
105 |     Returns:
106 |     --------
107 |     idx: the indices of nodes at the desired surface resolution that are not in the medial wall (for "b", a list of [lh, rh] index arrays, with the rh indices offset by the per-hemisphere node count).
108 |     """
109 |
110 |     if surface_res is None:
111 |         surface_res = utils.SURFACE_RESOLUTION
112 |     if hemi == 'b':
113 |         r = get_node_indices('r', surface_res=surface_res)
114 |         l = get_node_indices('l', surface_res=surface_res)
115 |         r = r + utils.SURFACE_RESOLUTION  # offset rh indices by the number of nodes per hemisphere
116 |         return [l,r]
117 |     mask = MASKS[hemi]
118 |     idx = np.where(mask[:surface_res])[0]
119 |     return idx
120 |
121 |
122 | def get_freesurfer_surfaces(hemi):
123 |     """
124 |     Parameters:
125 |     -----------
126 |     hemi: hemisphere (can be 'r','l','b')
127 |
128 |     Returns:
129 |     surf: a freesurfer surface built from the .white and .pial files (the node-wise average of the two).
130 |     """
131 |
132 |     import nibabel as nib
133 |     from mvpa2.support.nibabel.surf import Surface
134 |     if hemi == 'b':
135 |         lh = get_freesurfer_surfaces('l')
136 |         rh = get_freesurfer_surfaces('r')
137 |         return lh.merge(rh)
138 |     coords1, faces1 = nib.freesurfer.read_geometry(os.path.join(utils.basedir,'{lr}h.white'.format(lr=hemi)))
139 |     coords2, faces2 = nib.freesurfer.read_geometry(os.path.join(utils.basedir,'{lr}h.pial'.format(lr=hemi)))
140 |     np.testing.assert_array_equal(faces1, faces2)
141 |     surf = Surface((coords1 + coords2) * 0.5, faces1)
142 |     return surf
143 |
144 |
145 |
146 | def get_searchlights(hemi, radius=None):
147 |     """
148 |     Parameters:
149 |     -----------
150 |     hemi: hemisphere (can be 'r','l','b')
151 |     radius: radius of the searchlights you want; defaults to utils.SEARCHLIGHT_RADIUS.
152 |
153 |     Returns:
154 |     --------
155 |     searchlights: a list of lists, where each sub-list is all the nodes within a searchlight of
156 |                   the given radius centered on the first node in the list. Uses the PyMVPA surface query engine to do this.
157 |
158 |     """
159 |     if radius is None:
160 |         radius = utils.SEARCHLIGHT_RADIUS
161 |     savename = os.path.join(utils.basedir,'{R}mm_searchlights_{S}h.npy'.format(R=radius,S=hemi))
162 |     try:
163 |         return np.load(savename)
164 |     except IOError:
    |         pass
165 |     # get the data for just the first participant
166 |     node_indices = get_node_indices(hemi)
    |     if hemi == 'b':
    |         node_indices = np.concatenate(node_indices)
167 |     surf = get_freesurfer_surfaces(hemi)
168 |     subj = utils.subjects[0]
169 |     # get one run of one subject
170 |     ds = utils.get_train_data(hemi, [1], num_subjects=1)[0]
171 |     ds.fa['node_indices'] = node_indices.copy()
172 |     qe = SurfaceQueryEngine(surf, radius)
173 |     qe.train(ds)
174 |
175 |     searchlights = []
176 |     for idx in node_indices:
177 |         sl = qe.query_byid(idx)
178 |         searchlights.append(sl)
179 |
180 |     np.save(savename, searchlights)
181 |     return searchlights
182 |
183 |
184 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Hybrid hyperalignment: A single high-dimensional model of shared information embedded in cortical patterns of response and functional connectivity
2 | ## H2A Pipeline & Analysis Scripts
3 |
4 | ### Erica Busch, March 2021
5 |
6 | [![DOI](https://zenodo.org/badge/DOI/10.5281/zenodo.7246743.svg)](https://doi.org/10.5281/zenodo.7246743)
7 |
8 |
9 | This directory contains the hyperalignment procedures for response hyperalignment, connectivity hyperalignment, and hybrid hyperalignment used in the H2A manuscript. It also contains all scripts used in the final analyses of the Whiplash, Raiders, and Budapest datasets.
10 |
11 | For more information, check out our [paper](https://www.sciencedirect.com/science/article/pii/S1053811921002524?via%3Dihub), now published in NeuroImage.
12 |
13 | ### Scripts & Functions
14 | - `[dataset]_utils.py` : One for each of the three datasets used here (Raiders, Whiplash, & Budapest). These scripts include relevant paths, subject IDs, and dataset-specific information, as well as functions for accessing each dataset.
15 |
16 | - `HA_prep_functions.py` : An amalgamation of functions used to prepare input data for hyperalignment.
17 |
18 | - `drive_hyperalignment_cross_validation.py` : Drives the training of the three flavors of hyperalignment models tested in this paper in a leave-one-run-out cross-validation scheme. It takes two required command-line arguments (the type of hyperalignment model to train and the dataset) and an optional third (the number of the held-out test run); if the third argument is omitted, it iterates through all train-test combinations. After training the HA model on the training runs, it derives mappers from each subject's data into the trained common space, saves those mappers with `save_transformations`, transforms and saves the test data in the common model space with `save_transformed_data`, and then calls `run_benchmarks` on the transformed test data. A stripped-down sketch of the training workflow appears after this list.
19 | - for example: `$ python drive_hyperalignment_cross_validation.py cha whiplash 1`
20 |
21 | - `benchmarks.py` : includes source code for vertex-wise intersubject correlations, connectivity profile intersubject correlations, and sliding-window between-subject multivariate pattern classifications.
22 |
23 | - `run_anatomical_benchmarks.py` : drives the control analyses (running each of the benchmarks on the anatomically aligned data, pre-hyperalignment).
24 |
25 | - `hybrid_hyperalignment.py` : the class defining the hybrid hyperalignment method and source code.
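
### Example: training an H2A model on one cross-validation fold

The sketch below pulls the core H2A training steps out of `drive_hyperalignment_cross_validation.py` and the `HybridHyperalignment` class docstring for clarity. It is a minimal illustration rather than a drop-in script: it assumes the Budapest utilities are configured for your paths, that each training dataset keeps one column per node returned by `get_node_indices`, and it skips the mapper-saving and benchmarking steps the driver performs.

```python
# minimal sketch of one H2A training fold (assumes budapest_utils points at your data)
import numpy as np
import budapest_utils as utils
import HA_prep_functions as prep
from hybrid_hyperalignment import HybridHyperalignment

train_runs = [1, 2, 3, 4]                          # hold out run 5 for testing
dss_train = utils.get_train_data('b', train_runs)  # one PyMVPA dataset per subject, both hemispheres

# surface geometry and node indices (medial wall removed), as in the driver script
surface = prep.get_freesurfer_surfaces('b')
node_indices = prep.get_node_indices('b', surface_res=utils.SURFACE_RESOLUTION)  # [lh_idx, rh_idx]
target_indices = prep.get_node_indices('b', surface_res=642)  # sparse connectivity targets (SPARSE_NODES in the driver)

# HybridHyperalignment expects each dataset to carry a `node_indices` feature attribute
flat_node_indices = np.concatenate(node_indices)
for ds in dss_train:
    ds.fa['node_indices'] = flat_node_indices.copy()

ha = HybridHyperalignment(ref_ds=dss_train[0],
                          surface=surface,
                          mask_node_indices=node_indices,
                          seed_indices=node_indices,
                          target_indices=target_indices,
                          target_radius=utils.HYPERALIGNMENT_RADIUS,
                          seed_radius=utils.SEARCHLIGHT_RADIUS)
mappers = ha(dss_train)  # one mapper per subject into the common model space
```

The resulting mappers can then be applied to the held-out run and evaluated with the functions in `benchmarks.py`, which is what `run_benchmarks` does inside the driver.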
26 |
--------------------------------------------------------------------------------
/benchmarks.py:
--------------------------------------------------------------------------------
1 | # benchmarks.py
2 | # erica busch, 6/2020
3 | import numpy as np
4 | import os, glob, sys
5 | from scipy.spatial.distance import cdist
  | from HA_prep_functions import get_searchlights
  | # NOTE: these benchmarks expect a dataset-specific utils module (budapest_utils, raiders_utils,
  | # or whiplash_utils) to be importable as `utils`; budapest is assumed here, swap in the one you need.
  | import budapest_utils as utils
6 |
7 |
8 | def vertex_isc(data):
9 |     """
10 |     Performs an intersubject correlation of vertex response profiles, comparing each subject's response profile at each vertex
11 |     to the mean response profile of all the other subjects.
12 |
13 |     Parameters:
14 |     ----------
15 |     data: a numpy array of shape (n_subjects, n_timepoints, n_features) upon which to perform ISC.
16 |
17 |     Returns:
18 |     -------
19 |     all_results: a numpy array of shape (n_subjects, n_features) of ISC values.
20 |     """
21 |     all_results = np.ndarray((data.shape[0],data.shape[2]), dtype=float)
22 |     all_subjs = np.arange(data.shape[0])
23 |     for v in np.arange(data.shape[2]):
24 |         data_v = data[:,:,v]
25 |         # hold out one subject; compare with average of remaining subjects
26 |         for subj, ds in enumerate(data_v):
27 |             group = np.setdiff1d(all_subjs, subj)
28 |             group_avg = np.mean(data_v[group,:], axis=0).ravel()
29 |             r = np.corrcoef(group_avg, ds.ravel())[0,1]
30 |             all_results[subj, v] = r
31 |     return np.array(all_results)
32 |
33 |
34 | def dense_connectivity_profile_isc(data):
35 |     """
36 |     Takes the data and creates a vertex-by-vertex full connectivity matrix for each subject, then performs ISC on the
37 |     connectivity profiles.
38 |
39 |     Parameters:
40 |     ----------
41 |     data: a numpy array of shape (n_subjects, n_timepoints, n_features) from which to compute the connectivity matrices.
42 |
43 |     Returns:
44 |     -------
45 |     all_results: a numpy array of shape (n_subjects, n_features) of ISC values.
46 |
47 |     """
48 |     from mvpa2.datasets.base import Dataset
49 |     from mvpa2.mappers.fxy import FxyMapper
50 |
51 |     conn_metric = lambda x,y: np.dot(x.samples, y.samples)/x.nsamples
52 |     connectivity_mapper = FxyMapper(conn_metric)
53 |     connectomes = np.ndarray((data.shape[0], data.shape[2], data.shape[2]), dtype=float)
54 |     for i,ds in enumerate(data):
55 |         d = Dataset(ds)
56 |         conn_targets = Dataset(samples=ds.T)
57 |         connectivity_mapper.train(conn_targets)
58 |         connectomes[i]=connectivity_mapper.forward(d)
59 |         del conn_targets,d
60 |     results = vertex_isc(connectomes)
61 |     return results
62 |
63 | ## all of this runs the between-subject multivariate time segment classification
64 | def searchlight_timepoint_clf(data, window_size=5, buffer_size=10, NPROC=16):
65 |     """
66 |     Performs a sliding-window between-subject multivariate classification of each time segment.
67 |
68 |     Parameters:
69 |     -----------
70 |     data: a numpy array of shape (n_subjects, n_timepoints, n_features) of aligned response data.
71 |     window_size: defaults to 5. The number of TRs to be considered in each classification.
72 |     buffer_size: defaults to 10. The number of TRs to be excluded from the classification before and after the window.
73 |     NPROC: defaults to 16. The number of parallel processes you can use.
74 |
75 |     Returns:
76 |     --------
77 |     results: a (n_subjects, n_searchlights) array of classification accuracies.
78 |     """
79 |     from joblib import Parallel, delayed
80 |     searchlights = get_searchlights('b', utils.SEARCHLIGHT_RADIUS)
81 |     results = []
82 |     for test_subj, sub_id in enumerate(utils.subjects):
83 |         train_subj = np.setdiff1d(range(len(utils.subjects)), test_subj)
84 |         ds_train = np.mean(data[train_subj],axis=0)
85 |         ds_test = data[test_subj]
86 |         results.append(get_subj_accuracy(sub_id, ds_train, ds_test, searchlights, window_size, buffer_size, NPROC))
87 |     results = np.stack(results)
88 |     return results
89 |
90 | def get_subj_accuracy(subj_id, ds_train, ds_test, searchlights, window_size, buffer_size, NPROC):
91 |     jobs = []
92 |     n_timepoints = ds_train.shape[0]
93 |     for sl in searchlights:
94 |         train_ds_sl = ds_train[:,sl]
95 |         test_ds_sl = ds_test[:,sl]
96 |         jobs.append(delayed(run_clf_job)(train_ds_sl, test_ds_sl, n_timepoints, window_size, buffer_size))
97 |     with Parallel(n_jobs=NPROC) as parallel:
98 |         accuracy = np.array(parallel(jobs))
99 |     return accuracy
100 |
101 | def run_clf_job(train_ds_sl, test_ds_sl, n_timepoints, window_size, buffer_size):
102 |     clf_correct = []
103 |     for t0 in np.arange(n_timepoints - window_size):
104 |         foil_startpoints = get_foil_startpoints(n_timepoints, t0, window_size, buffer_size)
105 |         target_index = foil_startpoints.index(t0)
106 |         # spatiotemporal patterns for the target segment and all foil segments in this searchlight
107 |         # (each segment is the raveled window of TRs); correct if the test segment is closest to its matching training segment
108 |         train_ = np.stack([np.ravel(train_ds_sl[t:t+window_size]) for t in foil_startpoints])
109 |         test_ = np.ravel(test_ds_sl[t0: t0+window_size])
110 |         dist = cdist(train_,test_[np.newaxis,:],metric='correlation')
111 |         winner = np.argmin(dist)
112 |         clf_correct.append(int(winner == target_index))
113 |     return np.mean(np.array(clf_correct))
114 |
115 | def get_foil_startpoints(n_timepoints, t0, window_size, buffer_size):
116 |     pre_target, post_target = get_foil_boundaries(np.arange(n_timepoints),t0, window_size, buffer_size)
117 |     foil_startpoints = [t0]
118 |     if pre_target:
119 |         foil_startpoints += range(0, pre_target)
120 |     if post_target:
121 |         foil_startpoints += range(post_target, n_timepoints - window_size)
122 |     return sorted(foil_startpoints)
123 |
124 | # this returns the final possible start point of a foil segment before our target segment
125 | # and the first possible start point after the target segment.
126 | # it will return None if there are no valid foil segments before or after a given startpoint.
127 | def get_foil_boundaries(timepoint_arr, tstart, window_size, buffer_size): 128 | end_of_first_buffer, start_of_second_buffer = None, None 129 | if tstart > window_size + buffer_size: 130 | end_of_first_buffer = np.argmin(abs(timepoint_arr - (tstart - window_size - buffer_size))) 131 | if (tstart + window_size * 2 + buffer_size) < len(timepoint_arr): 132 | start_of_second_buffer = np.argmin(abs(timepoint_arr - (tstart + window_size + buffer_size))) 133 | return end_of_first_buffer, start_of_second_buffer 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /budapest_utils.py: -------------------------------------------------------------------------------- 1 | # budapest_utils.py 2 | # utils file specific to your dataset; all dataset-specific code will be kept here 3 | # erica busch, 2020 4 | 5 | import numpy as np 6 | from scipy.stats import zscore 7 | import os 8 | from mvpa2.datasets.base import Dataset 9 | 10 | basedir = '/dartfs/rc/lab/D/DBIC/DBIC/f002d44/h2a' 11 | home_dir = '/dartfs-hpc/rc/home/4/f002d44/h2a' 12 | connhyper_dir = os.path.join(basedir, 'connhyper','budapest') 13 | resphyper_dir = os.path.join(basedir, 'response_hyper','budapest') 14 | h2a_dir = os.path.join(basedir, 'mixed_hyper','budapest') 15 | datadir = os.path.join(basedir, 'data','budapest') 16 | orig_datadir = os.path.join(datadir, 'original') 17 | connectome_dir = os.path.join(datadir, 'connectomes') 18 | results_dir = os.path.join(basedir, 'results','budapest') 19 | iterative_HA_dir = os.path.join(basedir, 'iterative_hyper') 20 | 21 | sub_nums = [5, 7, 9, 10, 13, 20, 21, 24, 29, 34, 52, 114, 120, 134, 142, 22 | 278, 416, 499, 522, 535, 560] 23 | subjects = ['{:0>6}'.format(subid) for subid in sub_nums] 24 | SURFACE_RESOLUTION = 10242 25 | SEARCHLIGHT_RADIUS = 13 26 | HYPERALIGNMENT_RADIUS = 20 27 | TOT_RUNS= 5 28 | NNODES_LH=9372 29 | 30 | midstr = '_ses-budapest_task-movie_run-' 31 | endstr = '_space-fsaverage-icoorder5_hemi-' 32 | 33 | def get_RHA_data(runs, num_subjects=21, training=False): 34 | ds_lh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(runs[0]),'data', 'dss_lh.npy')) 35 | ds_rh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(runs[0]),'data', 'dss_rh.npy')) 36 | dss= np.concatenate((ds_lh,ds_rh),axis=2) 37 | if len(runs)>1: 38 | for run in runs[1:]: 39 | ds_lh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(run),'data', 'dss_lh.npy')) 40 | ds_rh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(run),'data', 'dss_rh.npy')) 41 | arr = np.concatenate((ds_lh, ds_rh),axis=2) 42 | dss=np.concatenate((dss,arr),axis=1) 43 | dss = format_for_training(dss,num_subjects) 44 | return dss 45 | 46 | def format_for_training(dss, num_subjects): 47 | dss_formatted = [] 48 | for ds,subj in zip(dss,subjects[:num_subjects]): 49 | data = zscore(ds, axis=0) 50 | dss_formatted.append(Dataset(data)) 51 | return dss_formatted 52 | 53 | # runs indicates which runs we want to return. 54 | # this will be useful for datafolding. 55 | def get_train_data(side, runs, num_subjects=21, z=True, mask=False): 56 | dss = [] 57 | for subj in subjects[:num_subjects]: 58 | data = _get_budapest_data(subj, side.upper(), runs, z, mask) 59 | ds = Dataset(data) 60 | idx = np.where(np.logical_not(np.all(ds.samples == 0, axis=0)))[0] 61 | ds = ds[:, idx] 62 | dss.append(ds) 63 | return dss 64 | 65 | # specific formatting for budapets data; only gets called internally. 
66 | def _get_budapest_data(subject, side, runs, z, mask):
67 |     LR = side.upper()
68 |     run_list = ['{:0>2}'.format(r) for r in runs]
69 |     if LR == 'B':
70 |         return np.hstack([_get_budapest_data(subject, 'L', runs, z, mask),
71 |                           _get_budapest_data(subject, 'R', runs, z, mask)])
72 |     fns = ['{d}/sub-sid{s}{m}{r}{e}{LR}.func.npy'.format(d=orig_datadir, s=subject, m=midstr,
73 |                                                          r=i,e=endstr,LR=LR.upper()) for i in run_list]
74 |     ds = [zscore(np.load(fn),axis=0) for fn in fns]
75 |     dss = np.concatenate(ds,axis=0)
76 |     return dss
77 |
78 | # don't want to return this as a pymvpa dataset; takes too long & is unnecessary
79 | def get_test_data(side, runs, num_subjects=21, z=True, mask=False):
80 |     dss=[]
81 |     for subj in subjects[:num_subjects]:
82 |         ds = _get_budapest_data(subj, side.upper(), runs, z, mask)
83 |         dss.append(ds)
84 |     return np.array(dss)
85 |
--------------------------------------------------------------------------------
/drive_hyperalignment_cross_validation.py:
--------------------------------------------------------------------------------
1 | # drive_hyperalignment_cross_validation.py
2 | # erica busch, 2020
3 | # this script takes 2 (optionally 3) command line arguments and runs leave-one-run-out
4 | # cross validation on the hyperalignment training.
5 | # 1) type of hyperalignment [either rha, h2a, or cha]
6 | # 2) dataset to use [either raiders, whiplash, or budapest]
7 | # 3) the run to test on (which means you're training the HA model on the other runs)
8 |
9 | import os, sys, itertools
10 | import numpy as np
11 | from scipy.sparse import load_npz, save_npz
12 | from mvpa2.base.hdf5 import h5save, h5load
13 | from scipy.stats import zscore
14 | import HA_prep_functions as prep
15 | import hybrid_hyperalignment as h2a
16 | from benchmarks import searchlight_timepoint_clf, vertex_isc, dense_connectivity_profile_isc
   | from mvpa2.misc.surfing.queryengine import SurfaceQueryEngine
   | from mvpa2.algorithms.searchlight_hyperalignment import SearchlightHyperalignment
17 |
18 | os.environ['TMPDIR'] = '/dartfs-hpc/scratch/f002d44/temp'
19 | os.environ['TEMP'] = '/dartfs-hpc/scratch/f002d44/temp'
20 | os.environ['TMP'] = '/dartfs-hpc/scratch/f002d44/temp'
21 | N_LH_NODES_MASKED = 9372
22 | N_JOBS=16
23 | N_BLOCKS=128
24 | TOTAL_NODES=10242
25 | SPARSE_NODES=642
26 | HYPERALIGNMENT_RADIUS=20
27 |
28 | def save_transformations(transformations, outdir):
29 |     if not os.path.isdir(outdir):
30 |         os.makedirs(outdir)
31 |     h5save(outdir+'/all_subjects_mappers.hdf5', transformations)
32 |     for T, s in zip(transformations, utils.subjects):
33 |         save_npz(outdir+"/subj{}_mapper.npz".format(s), T.proj)
34 |
35 | # apply the HA transformations to the testing data, split into hemispheres, and save
36 | def save_transformed_data(transformations, data, outdir):
37 |     if not os.path.isdir(outdir):
38 |         os.makedirs(outdir)
39 |
40 |     dss_lh,dss_rh=[],[]
41 |     for T, d, sub in zip(transformations, data, utils.subjects):
42 |         aligned = np.nan_to_num(zscore((np.asmatrix(d)*T.proj).A, axis=0))
43 |         ar, al = aligned[:,N_LH_NODES_MASKED:], aligned[:,:N_LH_NODES_MASKED]
44 |         dss_rh.append(ar)
45 |         dss_lh.append(al)
46 |     np.save(outdir+'/dss_lh.npy', np.array(dss_lh))
47 |     np.save(outdir+'/dss_rh.npy', np.array(dss_rh))
48 |     print('saved at {}'.format(outdir))
49 |
50 | # run the benchmarks on the transformed test data of one fold and save the results
51 | def run_benchmarks(ha_type, test_run, fold_basedir):
52 |     results_dir, data_dir = os.path.join(fold_basedir, 'results'), os.path.join(fold_basedir, 'data')
   |     if not os.path.isdir(results_dir):
   |         os.makedirs(results_dir)
53 |     dss_lh, dss_rh = np.load(data_dir+'/dss_lh.npy'), np.load(data_dir+'/dss_rh.npy')
54 |     lh_res = vertex_isc(dss_lh)
55 |     rh_res = vertex_isc(dss_rh)
56 |     np.save(os.path.join(results_dir, 'vertex_isc_lh.npy'), lh_res)
57 |     np.save(os.path.join(results_dir, 'vertex_isc_rh.npy'), rh_res)
58 |     dss = np.concatenate((dss_lh, dss_rh),axis=2)
59 |     cnx_results = dense_connectivity_profile_isc(dss)
60 |     np.save(os.path.join(results_dir, 'dense_connectivity_profile_isc.npy'), cnx_results)
61 |     clf_results = searchlight_timepoint_clf(dss,window_size=5, buffer_size=10, NPROC=16)
62 |     np.save(os.path.join(results_dir, 'time_segment_clf_accuracy.npy'), clf_results)
63 |
64 |
65 |
66 |
67 | # perform leave-one-run-out cross-validation on the hyperalignment training,
68 | # driven by the command-line arguments described at the top of this file
69 | if __name__ == '__main__':
70 |     ha_type = sys.argv[1]
71 |     dataset = sys.argv[2]
72 |     if dataset == 'budapest':
73 |         import budapest_utils as utils
74 |     elif dataset == 'raiders':
75 |         import raiders_utils as utils
76 |     elif dataset == 'whiplash':
77 |         import whiplash_utils as utils
78 |     else:
79 |         print('dataset must be one of [whiplash,raiders,budapest]')
80 |         sys.exit()
81 |     print('running {a} on {b}'.format(a=ha_type,b=dataset))
82 |     all_runs = np.arange(1, utils.TOT_RUNS+1)
83 |
84 |     # check if you specified which run you wanted to hold out;
85 |     # otherwise, iterate through all train/test combos
86 |     if len(sys.argv) > 3:
87 |         test = [int(sys.argv[3])]
88 |         train = np.setdiff1d(all_runs, test)
89 |         train_run_combos = [train]
90 |     else:
91 |         train_run_combos = list(itertools.combinations(all_runs, utils.TOT_RUNS-1))
92 |
93 |     for train in train_run_combos:
94 |         test = np.setdiff1d(all_runs, train)
95 |         print('training on runs {r}; testing on run {n}'.format(r=train, n=test))
96 |
97 |         # separate testing and training data
98 |         dss_train = utils.get_train_data('b',train)
99 |         dss_test = utils.get_test_data('b', test)
100 |
101 |         # get the node indices to run SL HA, both hemis
102 |         node_indices = np.concatenate(prep.get_node_indices('b', surface_res=TOTAL_NODES))
103 |         # get the surfaces for both hemis
104 |         surface = prep.get_freesurfer_surfaces('b')
105 |         # make the surface QE
106 |         qe = SurfaceQueryEngine(surface, HYPERALIGNMENT_RADIUS)
107 |
108 |         # prepare the connectivity matrices and run HA if we are running CHA
109 |         if ha_type == 'cha':
110 |             target_indices = prep.get_node_indices('b', surface_res=SPARSE_NODES)
111 |             dss_train = prep.compute_connectomes(dss_train, qe, target_indices)
112 |             ha = SearchlightHyperalignment(queryengine=qe,
113 |                                            nproc=N_JOBS,
114 |                                            nblocks=N_BLOCKS,
115 |                                            mask_node_ids=node_indices,
116 |                                            dtype ='float64')
117 |             Ts = ha(dss_train)
118 |             outdir = os.path.join(utils.connhyper_dir, 'fold_{}/'.format(int(test[0])))
119 |
120 |         # run response-based searchlight hyperalignment
121 |         elif ha_type == 'rha':
122 |             outdir = os.path.join(utils.resphyper_dir, 'fold_{}/'.format(int(test[0])))
123 |             ha = SearchlightHyperalignment(queryengine=qe,
124 |                                            nproc=N_JOBS,
125 |                                            nblocks=N_BLOCKS,
126 |                                            mask_node_ids=node_indices,
127 |                                            dtype ='float64')
128 |             Ts = ha(dss_train)
129 |
130 |         # run hybrid hyperalignment
131 |         elif ha_type == 'h2a':
132 |             outdir = os.path.join(utils.h2a_dir, 'fold_{}/'.format(int(test[0])))
    |             target_indices = prep.get_node_indices('b', surface_res=SPARSE_NODES)
133 |             ha = h2a.HybridHyperalignment(ref_ds=dss_train[0],
134 |                                           mask_node_indices=node_indices,
135 |                                           seed_indices=node_indices,
136 |                                           target_indices=target_indices,
137 |                                           target_radius=utils.HYPERALIGNMENT_RADIUS,
138 |                                           surface=surface)
139 |             Ts = ha(dss_train)
140 |         else:
141 |             print('first argument must be one of h2a, cha, rha')
142 |             sys.exit()
143 |
144 |         save_transformations(Ts, os.path.join(outdir, 'transformations'))
145 |         save_transformed_data(Ts, dss_test, os.path.join(outdir,'data') )
146 |
run_benchmarks(ha_type, test[0], outdir) 147 | 148 | 149 | 150 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: comp_meth_env 2 | channels: 3 | - defaults 4 | dependencies: 5 | - _libgcc_mutex=0.1=main 6 | - alabaster=0.7.12=py27_0 7 | - anaconda=2019.07=py27_0 8 | - anaconda-client=1.7.2=py27_0 9 | - anaconda-project=0.8.3=py_0 10 | - asn1crypto=0.24.0=py27_0 11 | - astroid=1.6.5=py27_0 12 | - astropy=2.0.9=py27hdd07704_0 13 | - atomicwrites=1.3.0=py27_1 14 | - attrs=19.1.0=py27_1 15 | - babel=2.7.0=py_0 16 | - backports=1.0=py_2 17 | - backports.functools_lru_cache=1.5=py_2 18 | - backports.os=0.1.1=py27_0 19 | - backports.shutil_get_terminal_size=1.0.0=py27_2 20 | - backports_abc=0.5=py27_0 21 | - beautifulsoup4=4.7.1=py27_1 22 | - bitarray=0.9.3=py27h7b6447c_0 23 | - bkcharts=0.2=py27_0 24 | - blas=1.0=mkl 25 | - bleach=3.1.0=py27_0 26 | - blosc=1.16.3=hd408876_0 27 | - bokeh=1.2.0=py27_0 28 | - boto=2.49.0=py27_0 29 | - bottleneck=1.2.1=py27h035aef0_1 30 | - bzip2=1.0.8=h7b6447c_0 31 | - ca-certificates=2019.5.15=0 32 | - cairo=1.14.12=h8948797_3 33 | - cdecimal=2.3=py27h14c3975_3 34 | - certifi=2019.6.16=py27_0 35 | - cffi=1.12.3=py27h2e261b9_0 36 | - chardet=3.0.4=py27_1 37 | - click=7.0=py27_0 38 | - cloudpickle=1.2.1=py_0 39 | - clyent=1.2.2=py27_1 40 | - colorama=0.4.1=py27_0 41 | - configparser=3.7.4=py27_0 42 | - contextlib2=0.5.5=py27_0 43 | - cryptography=2.7=py27h1ba5d50_0 44 | - curl=7.65.2=hbc83047_0 45 | - cycler=0.10.0=py27_0 46 | - cython=0.29.12=py27he6710b0_0 47 | - cytoolz=0.10.0=py27h7b6447c_0 48 | - dask=1.2.2=py_0 49 | - dask-core=1.2.2=py_0 50 | - dbus=1.13.6=h746ee38_0 51 | - decorator=4.4.0=py27_1 52 | - defusedxml=0.6.0=py_0 53 | - distributed=1.28.1=py27_0 54 | - docutils=0.14=py27_0 55 | - entrypoints=0.3=py27_0 56 | - enum34=1.1.6=py27_1 57 | - et_xmlfile=1.0.1=py27_0 58 | - expat=2.2.6=he6710b0_0 59 | - fastcache=1.1.0=py27h7b6447c_0 60 | - filelock=3.0.12=py_0 61 | - flask=1.1.1=py_0 62 | - fontconfig=2.13.0=h9420a91_0 63 | - freetype=2.9.1=h8a8886c_1 64 | - fribidi=1.0.5=h7b6447c_0 65 | - funcsigs=1.0.2=py27_0 66 | - functools32=3.2.3.2=py27_1 67 | - future=0.17.1=py27_0 68 | - futures=3.3.0=py27_0 69 | - get_terminal_size=1.0.0=haa9412d_0 70 | - gevent=1.4.0=py27h7b6447c_0 71 | - glib=2.56.2=hd408876_0 72 | - glob2=0.7=py_0 73 | - gmp=6.1.2=h6c8ec71_1 74 | - gmpy2=2.0.8=py27h10f8cd9_2 75 | - graphite2=1.3.13=h23475e2_0 76 | - greenlet=0.4.15=py27h7b6447c_0 77 | - grin=1.2.1=py27_4 78 | - gst-plugins-base=1.14.0=hbbd80ab_1 79 | - gstreamer=1.14.0=hb453b48_1 80 | - h5py=2.9.0=py27h7918eee_0 81 | - harfbuzz=1.8.8=hffaf4a1_0 82 | - hdf5=1.10.4=hb1b8bf9_0 83 | - heapdict=1.0.0=py27_2 84 | - html5lib=1.0.1=py27_0 85 | - icu=58.2=h9c2bf20_1 86 | - idna=2.8=py27_0 87 | - imageio=2.5.0=py27_0 88 | - imagesize=1.1.0=py27_0 89 | - importlib_metadata=0.17=py27_1 90 | - intel-openmp=2019.4=243 91 | - ipaddress=1.0.22=py27_0 92 | - ipykernel=4.10.0=py27_0 93 | - ipython=5.8.0=py27_0 94 | - ipython_genutils=0.2.0=py27_0 95 | - ipywidgets=7.5.0=py_0 96 | - isort=4.3.21=py27_0 97 | - itsdangerous=1.1.0=py27_0 98 | - jbig=2.1=hdba287a_0 99 | - jdcal=1.4.1=py_0 100 | - jedi=0.13.3=py27_0 101 | - jinja2=2.10.1=py27_0 102 | - jpeg=9b=h024ee3a_2 103 | - jsonschema=3.0.1=py27_0 104 | - jupyter=1.0.0=py27_7 105 | - jupyter_client=5.3.1=py_0 106 | - jupyter_console=5.2.0=py27_1 107 | - jupyter_core=4.5.0=py_0 108 | - 
jupyterlab=0.33.11=py27_0 109 | - jupyterlab_launcher=0.11.2=py27h28b3542_0 110 | - kiwisolver=1.1.0=py27he6710b0_0 111 | - krb5=1.16.1=h173b8e3_7 112 | - lazy-object-proxy=1.4.1=py27h7b6447c_0 113 | - libarchive=3.3.3=h5d8350f_5 114 | - libcurl=7.65.2=h20c2e04_0 115 | - libedit=3.1.20181209=hc058e9b_0 116 | - libffi=3.2.1=hd88cf55_4 117 | - libgcc-ng=9.1.0=hdf63c60_0 118 | - libgfortran-ng=7.3.0=hdf63c60_0 119 | - liblief=0.9.0=h7725739_2 120 | - libpng=1.6.37=hbc83047_0 121 | - libsodium=1.0.16=h1bed415_0 122 | - libssh2=1.8.2=h1ba5d50_0 123 | - libstdcxx-ng=9.1.0=hdf63c60_0 124 | - libtiff=4.0.10=h2733197_2 125 | - libtool=2.4.6=h7b6447c_5 126 | - libuuid=1.0.3=h1bed415_2 127 | - libxcb=1.13=h1bed415_1 128 | - libxml2=2.9.9=hea5a465_1 129 | - libxslt=1.1.33=h7d1a2b0_0 130 | - linecache2=1.0.0=py27_0 131 | - llvmlite=0.29.0=py27hd408876_0 132 | - locket=0.2.0=py27_1 133 | - lxml=4.3.4=py27hefd8a0e_0 134 | - lz4-c=1.8.1.2=h14c3975_0 135 | - lzo=2.10=h49e0be7_2 136 | - markupsafe=1.1.1=py27h7b6447c_0 137 | - matplotlib=2.2.3=py27hb69df0a_0 138 | - mccabe=0.6.1=py27_1 139 | - mistune=0.8.4=py27h7b6447c_0 140 | - mkl=2019.4=243 141 | - mkl-service=2.0.2=py27h7b6447c_0 142 | - mkl_fft=1.0.12=py27ha843d7b_0 143 | - mkl_random=1.0.2=py27hd81dba3_0 144 | - mock=3.0.5=py27_0 145 | - more-itertools=5.0.0=py27_0 146 | - mpc=1.1.0=h10f8cd9_1 147 | - mpfr=4.0.1=hdf1c602_3 148 | - mpmath=1.1.0=py27_0 149 | - msgpack-python=0.6.1=py27hfd86e86_1 150 | - multipledispatch=0.6.0=py27_0 151 | - nbconvert=5.5.0=py_0 152 | - nbformat=4.4.0=py27_0 153 | - ncurses=6.1=he6710b0_1 154 | - networkx=2.2=py27_1 155 | - nltk=3.4.4=py27_0 156 | - nose=1.3.7=py27_2 157 | - notebook=5.7.8=py27_0 158 | - numba=0.44.1=py27h962f231_0 159 | - numexpr=2.6.9=py27h9e4a6bb_0 160 | - numpy=1.16.4=py27h7e9f1db_0 161 | - numpy-base=1.16.4=py27hde5b4d6_0 162 | - numpydoc=0.9.1=py_0 163 | - olefile=0.46=py27_0 164 | - openpyxl=2.6.2=py_0 165 | - openssl=1.1.1c=h7b6447c_1 166 | - packaging=19.0=py27_0 167 | - pandas=0.24.2=py27he6710b0_0 168 | - pandoc=2.2.3.2=0 169 | - pandocfilters=1.4.2=py27_1 170 | - pango=1.42.4=h049681c_0 171 | - parso=0.5.0=py_0 172 | - partd=1.0.0=py_0 173 | - patchelf=0.9=he6710b0_3 174 | - path.py=11.5.0=py27_0 175 | - pathlib2=2.3.4=py27_0 176 | - patsy=0.5.1=py27_0 177 | - pcre=8.43=he6710b0_0 178 | - pep8=1.7.1=py27_0 179 | - pexpect=4.7.0=py27_0 180 | - pickleshare=0.7.5=py27_0 181 | - pillow=6.1.0=py27h34e0f95_0 182 | - pip=19.1.1=py27_0 183 | - pixman=0.38.0=h7b6447c_0 184 | - pkginfo=1.5.0.1=py27_0 185 | - pluggy=0.12.0=py_0 186 | - ply=3.11=py27_0 187 | - prometheus_client=0.7.1=py_0 188 | - prompt_toolkit=1.0.15=py27_0 189 | - psutil=5.6.3=py27h7b6447c_0 190 | - ptyprocess=0.6.0=py27_0 191 | - py=1.8.0=py27_0 192 | - py-lief=0.9.0=py27h7725739_2 193 | - pycairo=1.18.1=py27h2a1e443_0 194 | - pycodestyle=2.5.0=py27_0 195 | - pycosat=0.6.3=py27h14c3975_0 196 | - pycparser=2.19=py27_0 197 | - pycrypto=2.6.1=py27h14c3975_9 198 | - pycurl=7.43.0.3=py27h1ba5d50_0 199 | - pyflakes=2.1.1=py27_0 200 | - pygments=2.4.2=py_0 201 | - pylint=1.9.2=py27_0 202 | - pyodbc=4.0.26=py27he6710b0_0 203 | - pyopenssl=19.0.0=py27_0 204 | - pyparsing=2.4.0=py_0 205 | - pyqt=5.9.2=py27h05f1152_2 206 | - pyrsistent=0.14.11=py27h7b6447c_0 207 | - pysocks=1.7.0=py27_0 208 | - pytables=3.5.2=py27h71ec239_1 209 | - pytest=4.6.2=py27_0 210 | - python=2.7.16=h9bab390_0 211 | - python-dateutil=2.8.0=py27_0 212 | - python-libarchive-c=2.8=py27_11 213 | - pytz=2019.1=py_0 214 | - pywavelets=1.0.3=py27hdd07704_1 215 | - 
pyyaml=5.1.1=py27h7b6447c_0 216 | - pyzmq=18.0.0=py27he6710b0_0 217 | - qt=5.9.7=h5867ecd_1 218 | - qtawesome=0.5.7=py27_1 219 | - qtconsole=4.5.1=py_0 220 | - qtpy=1.8.0=py_0 221 | - readline=7.0=h7b6447c_5 222 | - requests=2.22.0=py27_0 223 | - rope=0.14.0=py_0 224 | - ruamel_yaml=0.15.46=py27h14c3975_0 225 | - scandir=1.10.0=py27h7b6447c_0 226 | - scikit-image=0.14.2=py27he6710b0_0 227 | - scikit-learn=0.20.3=py27hd81dba3_0 228 | - scipy=1.2.1=py27h7c811a0_0 229 | - seaborn=0.9.0=py27_0 230 | - send2trash=1.5.0=py27_0 231 | - setuptools=41.0.1=py27_0 232 | - simplegeneric=0.8.1=py27_2 233 | - singledispatch=3.4.0.3=py27_0 234 | - sip=4.19.8=py27hf484d3e_0 235 | - six=1.12.0=py27_0 236 | - snappy=1.1.7=hbae5bb6_3 237 | - snowballstemmer=1.9.0=py_0 238 | - sortedcollections=1.1.2=py27_0 239 | - sortedcontainers=2.1.0=py27_0 240 | - soupsieve=1.8=py27_0 241 | - sphinx=1.8.5=py27_0 242 | - sphinxcontrib=1.0=py27_1 243 | - sphinxcontrib-websupport=1.1.2=py_0 244 | - spyder=3.3.6=py27_0 245 | - spyder-kernels=0.5.1=py27_0 246 | - sqlalchemy=1.3.5=py27h7b6447c_0 247 | - sqlite=3.29.0=h7b6447c_0 248 | - ssl_match_hostname=3.7.0.1=py27_0 249 | - statsmodels=0.10.0=py27hdd07704_0 250 | - subprocess32=3.5.4=py27h7b6447c_0 251 | - sympy=1.4=py27_0 252 | - tblib=1.4.0=py_0 253 | - terminado=0.8.2=py27_0 254 | - testpath=0.4.2=py27_0 255 | - tk=8.6.8=hbc83047_0 256 | - toolz=0.10.0=py_0 257 | - tornado=5.1.1=py27h7b6447c_0 258 | - tqdm=4.32.1=py_0 259 | - traceback2=1.4.0=py27_0 260 | - traitlets=4.3.2=py27_0 261 | - typing=3.7.4=py27_0 262 | - unicodecsv=0.14.1=py27_0 263 | - unittest2=1.1.0=py27_0 264 | - unixodbc=2.3.7=h14c3975_0 265 | - urllib3=1.24.2=py27_0 266 | - wcwidth=0.1.7=py27_0 267 | - webencodings=0.5.1=py27_1 268 | - werkzeug=0.15.4=py_0 269 | - wheel=0.33.4=py27_0 270 | - widgetsnbextension=3.5.0=py27_0 271 | - wrapt=1.11.2=py27h7b6447c_0 272 | - wurlitzer=1.0.2=py27_0 273 | - xlrd=1.2.0=py27_0 274 | - xlsxwriter=1.1.8=py_0 275 | - xlwt=1.3.0=py27_0 276 | - xz=5.2.4=h14c3975_4 277 | - yaml=0.1.7=had09818_2 278 | - zeromq=4.3.1=he6710b0_3 279 | - zict=1.0.0=py_0 280 | - zipp=0.5.1=py_0 281 | - zlib=1.2.11=h7b6447c_3 282 | - zstd=1.3.7=h0b5b093_0 283 | - pip: 284 | - nilearn==0.5.2 285 | - pprocess==0.5.3 286 | prefix: /dartfs-hpc/rc/home/4/f002d44/.conda/envs/comp_meth_env 287 | -------------------------------------------------------------------------------- /hybrid_hyperalignment.py: -------------------------------------------------------------------------------- 1 | """ hybrid hyperalignment on surface. 2 | 3 | see :ref: `Busch, Slipski, et al., NeuroImage (2021)` for details. 4 | 5 | February 2021 6 | @author: Erica L. Busch 7 | 8 | """ 9 | import os 10 | import numpy as np 11 | import scipy.stats 12 | from mvpa2.mappers.zscore import zscore 13 | 14 | from mvpa2.base import debug 15 | from mvpa2.measures.base import Measure 16 | from mvpa2.base.hdf5 import h5save, h5load 17 | 18 | from mvpa2.misc.surfing.queryengine import SurfaceQueryEngine 19 | from mvpa2.measures.searchlight import Searchlight 20 | from mvpa2.algorithms.searchlight_hyperalignment import SearchlightHyperalignment 21 | from mvpa2.datasets.base import Dataset 22 | from mvpa2.mappers.fxy import FxyMapper 23 | 24 | class MeanFeatureMeasure(Measure): 25 | """Mean group feature measure 26 | Because the vanilla one doesn't want to work for Swaroop and I adapted this from Swaroop. 27 | Debugging is hard. Accepting the struggles of someone smarter than me is easy. 
28 | """ 29 | is_trained = True 30 | 31 | def __init__(self, **kwargs): 32 | Measure.__init__(self, **kwargs) 33 | 34 | def _call(self, dataset): 35 | return Dataset(samples=np.mean(dataset.samples, axis=1)) 36 | 37 | 38 | 39 | class HybridHyperalignment(): 40 | """ 41 | Given a list of datasets, provide a list of mappers into common space based on hybrid hyperalignment. 42 | 43 | 1) Input datasets should be PyMVPA datasets. 44 | 2) They should be of the same size (nsamples, nfeatures) 45 | and be aligned anatomically. 46 | 3) All features in the datasets should be zscored. 47 | 4) Datasets should all have a feature attribute `node_indices` containing the location of the feature 48 | on the surface. 49 | 50 | Parameters 51 | ---------- 52 | mask_ids: Default is none, type is list or array of integers. Specify a mask within which to compute 53 | searchlight hyperalignment. If none, set equal to seed_indices. One of the two is required. 54 | 55 | seed_indices: Default is none, type is list or array of integers. Node indices that correspond to seed 56 | centers for connectivity seeds. If none, set equal to mask_ids. One of the two is required. 57 | 58 | target_indices: Default is none, type is list or array of integers. Node indices corresponding to the center 59 | of connectivity targets. If none, set equal to seed_indices (will be dense!). 60 | 61 | surface: Required. The freesurfer surface defining your data. 62 | 63 | queryengine: Required. A single pymvpa query engine (or list of pymvpa queryengines, one per dataset) to be used by 64 | searchlight hyperalignment. If none, will be defined from surface and seed radius. 65 | 66 | target_radius: default is 20mm ala H2A paper. Minimum is 1. Radius for target searchlight. 67 | 68 | seed_radius: default is 13 ala H2A paper. Radius for connectivity seed searchlight. 69 | 70 | conn_metric: Connectivity metric between features. Default is the dot product of samples (which on zscored data 71 | becomes correlation if you normalize by nsamples. 72 | 73 | dtype: default is 'float64'. 74 | 75 | nproc: default is 1. 76 | 77 | nblocks: Number of blocks to divide to process. Higher number means lower memory consumption. default is 1. 78 | 79 | get_all_mappers: do you want to return both the mappers from AA -> iter1 space, and iter1_space -> H2A space? 80 | defualts to false. 81 | 82 | Returns 83 | ------- 84 | if get_all_mappers, returns iteration1 mappers, iteration2 mappers, and the final mappers in a list for each subject. 85 | otherwise, returns only the final mappers. 86 | 87 | """ 88 | 89 | 90 | 91 | def __init__(self, ref_ds, surface, mask_node_indices=None, seed_indices=None, target_indices=None, queryengine=None, target_radius=20, seed_radius=13, dtype='float64', nproc=1, nblocks=1, get_all_mappers=False): 92 | self.ref_ds = ref_ds 93 | self.mask_node_indices = mask_node_indices 94 | self.seed_indices = seed_indices 95 | self.target_indices = target_indices 96 | self.surface = surface 97 | self.queryengine = queryengine 98 | self.target_radius = target_radius 99 | self.seed_radius = seed_radius 100 | self.conn_metric = lambda x, y: np.dot(x.samples.T, y.samples)/x.nsamples 101 | self.dtype = dtype 102 | self.nproc = nproc #"""Number of blocks to divide to process. Higher number means lower memory consumption.""" 103 | self.nblocks = nblocks #"""Number of blocks to divide to process. 
Higher number means lower memory consumption.""" 104 | self.target_queryengine = None 105 | self.get_all_mappers = get_all_mappers 106 | 107 | if self.seed_indices is None: 108 | self.seed_indices = np.arange(ref_ds.shape[-1]) 109 | 110 | if self.target_indices is None: 111 | self.target_indices = np.arange(ref_ds.shape[-1]) 112 | 113 | if self.mask_node_indices is None: 114 | self.mask_node_indices = self.seed_indices.copy() 115 | 116 | if self.queryengine is None: 117 | self.queryengine = SurfaceQueryEngine(self.surface, self.seed_radius) 118 | self.queryengine.train(ref_ds) 119 | 120 | if self.target_queryengine is None: 121 | self.target_queryengine = SurfaceQueryEngine(self.surface, self.target_radius) 122 | self.target_queryengine.train(ref_ds) 123 | 124 | def _seed_means(self, measure, queryengine, ds, seed_indices): 125 | # Seed data is the mean timeseries for each searchlight 126 | seed_data = Searchlight(measure, queryengine=queryengine, 127 | nproc=self.nproc, roi_ids=np.concatenate(seed_indices).copy()) 128 | if isinstance(ds,np.ndarray): 129 | ds = Dataset(ds) 130 | seed_data = seed_data(ds) 131 | zscore(seed_data.samples, chunks_attr=None) 132 | return seed_data 133 | 134 | def _get_connectomes(self, datasets): 135 | conn_mapper = FxyMapper(self.conn_metric) 136 | mean_feature_measure = MeanFeatureMeasure() 137 | qe = self.queryengine 138 | 139 | roi_ids = self.target_indices 140 | # compute means for aligning seed features 141 | conn_means = [self._seed_means(MeanFeatureMeasure(), qe, ds, roi_ids) for ds in datasets] 142 | 143 | conn_targets = [] 144 | for csm in conn_means: 145 | zscore(csm, chunks_attr=None) 146 | conn_targets.append(csm) 147 | 148 | connectomes = [] 149 | for target, ds in zip(conn_targets, datasets): 150 | conn_mapper.train(target) 151 | connectome = conn_mapper.forward(ds) 152 | connectome.fa = ds.fa 153 | zscore(connectome, chunks_attr=None) 154 | connectomes.append(connectome) 155 | return connectomes 156 | 157 | def _apply_mappers(self, datasets, mappers): 158 | aligned_datasets = [d.get_mapped(M) for d,M in zip(datasets, mappers)] 159 | return aligned_datasets 160 | 161 | def _frobenius_norm_and_merge(self, dss_connectomes, dss_response, node_indices): 162 | # figure out which of the two types of data are larger 163 | if dss_response[0].shape[0] > dss_connectomes[0].shape[0]: 164 | larger = dss_response 165 | smaller = dss_connectomes 166 | else: 167 | larger = dss_connectomes 168 | smaller = dss_response 169 | node_ids = node_indices 170 | # find the normalization ratio based on which is larger 171 | norm_ratios = [] 172 | for la, sm in zip(larger, smaller): 173 | laN = np.linalg.norm(la, ord='fro') 174 | smN = np.linalg.norm(sm, ord='fro') 175 | v = laN / smN 176 | norm_ratios.append(v) 177 | 178 | # normalize the smaller one and then merge the datasets 179 | merged_dss = [] 180 | for la, sm, norm in zip(larger, smaller, norm_ratios): 181 | d_sm = sm.samples * norm 182 | merged = np.vstack((d_sm, la.samples)) 183 | merged = Dataset(samples=merged) 184 | merged.fa['node_indices'] = node_ids.copy() 185 | merged_dss.append(merged) 186 | return merged_dss 187 | 188 | def _prep_h2a_data(self, response_data, node_indices): 189 | for d in response_data: 190 | if isinstance(d, np.ndarray): 191 | d = Dataset(d) 192 | d.fa['node_indices']= node_indices.copy() 193 | 194 | connectivity_data = self._get_connectomes(response_data) 195 | h2a_input_data = self._frobenius_norm_and_merge(connectivity_data, response_data, node_indices) 196 | for d in h2a_input_data: 
197 | d.fa['node_indices'] = node_indices.copy() 198 | zscore(d, chunks_attr=None) 199 | return h2a_input_data 200 | 201 | def __call__(self, datasets): 202 | """ estimate mappers for each dataset. 203 | Parameters 204 | ---------- 205 | datasets : list of datasets 206 | 207 | Returns 208 | ------- 209 | mappers_iter1: mappers from the first HA iteration 210 | mappers_iter2: mappers from the second HA iteration 211 | 212 | """ 213 | debug.active += ['SHPAL', 'SLC'] 214 | 215 | mask_node_indices = np.concatenate(self.mask_node_indices) 216 | qe = self.queryengine 217 | nproc = self.nproc 218 | dtype = self.dtype 219 | nblocks = self.nblocks 220 | 221 | ha_iter1 = SearchlightHyperalignment(queryengine=qe, 222 | nproc=nproc, 223 | nblocks=nblocks, 224 | mask_node_ids=mask_node_indices, 225 | dtype=dtype) 226 | 227 | mappers_iter1 = ha_iter1(datasets) 228 | aligned_iter1_datasets = self._apply_mappers(datasets, mappers_iter1) 229 | 230 | 231 | h2a_input_data = self._prep_h2a_data(aligned_iter1_datasets, mask_node_indices) 232 | ha_iter2 = SearchlightHyperalignment(queryengine=qe, 233 | nproc=nproc, 234 | nblocks=nblocks, 235 | mask_node_ids=mask_node_indices, 236 | dtype=dtype) 237 | mappers_iter2 = ha_iter2(h2a_input_data) 238 | # push the original data through the trained model 239 | mappers_final = ha_iter2(datasets) 240 | 241 | if self.get_all_mappers: 242 | return mappers_iter1, mappers_iter2, mappers_final 243 | return mappers_final 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | ` 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | -------------------------------------------------------------------------------- /permutation_tests.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 24, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import os,sys,glob\n", 10 | "import numpy as np\n", 11 | "import random\n", 12 | "from itertools import combinations\n", 13 | "import pandas as pd\n", 14 | "from scipy.stats import percentileofscore" 15 | ] 16 | }, 17 | { 18 | "cell_type": "markdown", 19 | "metadata": {}, 20 | "source": [ 21 | "### Average final results across data folds\n", 22 | "goes from (n_runs, n_subjects, n_features) -> (n_subjects, n_features)\n", 23 | "and we can average across features to get an average value (barplots) or across subjects to get a feature map" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": { 30 | "collapsed": true 31 | }, 32 | "outputs": [ 33 | { 34 | "ename": "NameError", 35 | "evalue": "name 'os' is not defined", 36 | "output_type": "error", 37 | "traceback": [ 38 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 39 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 40 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mrha_dir\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'../../response_hyper/{}/'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mmetric\u001b[0m \u001b[0;32min\u001b[0m 
\u001b[0;34m[\u001b[0m\u001b[0;34m\"dense\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\"vertex_isc\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\"clf\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0moutfn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'../final_results/{}/rha_{}.npy'\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0mg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mglob\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglob\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrha_dir\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"fold*\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\"results\"\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\"*{}*.npy\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmetric\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mmetric_results\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mload\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 41 | "\u001b[0;31mNameError\u001b[0m: name 'os' is not defined" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "for dataset in ['budapest','raiders','whiplash']:\n", 47 | " rha_dir = '../../response_hyper/{}/'.format(dataset)\n", 48 | " for metric in [\"dense\",\"vertex_isc\",\"clf\"]:\n", 49 | " outfn = os.path.join('../final_results/{}/rha_{}.npy'.format(dataset,metric))\n", 50 | " g = glob.glob(os.path.join(rha_dir, \"fold*\",\"results\",\"*{}*.npy\".format(metric)))\n", 51 | " metric_results = np.stack([np.load(b) for b in g])\n", 52 | " avg_result = np.mean(metric_results, axis=0)\n", 53 | " np.save(outfn, avg_result)\n", 54 | "\n", 55 | " cha_dir = '../../conn_hyper/{}/'.format(dataset)\n", 56 | " for metric in [\"dense\",\"vertex_isc\",\"clf\"]:\n", 57 | " outfn = os.path.join('../final_results/{}/cha_{}.npy'.format(dataset,metric))\n", 58 | " g = glob.glob(os.path.join(cha_dir, \"fold*\",\"results\",\"*{}*.npy\".format(metric)))\n", 59 | " metric_results = np.stack([np.load(b) for b in g])\n", 60 | " avg_result = np.mean(metric_results, axis=0)\n", 61 | " np.save(outfn, avg_result)\n", 62 | "\n", 63 | " h2a_dir = '../../response_hyper/{}/'.format(dataset)\n", 64 | " for metric in [\"dense\",\"vertex_isc\",\"clf\"]:\n", 65 | " outfn = os.path.join('../final_results/{}/h2a_{}.npy'.format(dataset,metric))\n", 66 | " g = glob.glob(os.path.join(h2a_dir, \"fold*\",\"results\",\"*{}*.npy\".format(metric)))\n", 67 | " metric_results = np.stack([np.load(b) for b in g])\n", 68 | " avg_result = np.mean(metric_results, axis=0)\n", 69 | " np.save(outfn, avg_result)\n", 70 | "\n", 71 | 
" aa_dir = '../../response_hyper/{}/'.format(dataset)\n", 72 | " for metric in [\"dense\",\"vertex_isc\",\"clf\"]:\n", 73 | " outfn = os.path.join('../final_results/{}/aa_{}.npy'.format(dataset,metric))\n", 74 | " g = glob.glob(os.path.join(aa_dir, \"fold*\",\"results\",\"*{}*.npy\".format(metric)))\n", 75 | " metric_results = np.stack([np.load(b) for b in g])\n", 76 | " avg_result = np.mean(metric_results, axis=0)\n", 77 | " np.save(outfn, avg_result)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 25, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "results_dict = {'dataset':[],'metric':[],'comparison':[],'pvalue':[]}" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 27, 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "# observed_difference = model1_metricA - model2_metricA\n", 96 | "def permutation_test(observed_difference, n_iterations):\n", 97 | " obtained_mean = observed_difference.mean()\n", 98 | " null_distribution_of_means = np.ndarray((n_iterations))\n", 99 | " # flip sign permutation test\n", 100 | " for i in range(n_iterations):\n", 101 | " weights = [random.choice([1,-1]) for d in range(len(observed_difference))]\n", 102 | " null_distribution_of_means[i]=(weights * observed_difference).mean()\n", 103 | " percentile = percentileofscore(null_distribution_of_means, obtained_mean)\n", 104 | " return percentile # this returns the percentile of obtained score versus the null distribution\n", 105 | "# of scores. Then we compute the pvalue depending upon if it's a 1 tailed or 2 tailed value.\n", 106 | "\n", 107 | "def run_permutations(data, n_iterations):\n", 108 | " pvalues = {}\n", 109 | " combos = combinations(data.keys(),2) \n", 110 | " for combo in combos:\n", 111 | " combo_label = str(combo[0])+'_'+str(combo[1])\n", 112 | " obtained_difference = data[combo[0]] - data[combo[1]]\n", 113 | " percentile = permutation_test(obtained_difference, n_iterations)\n", 114 | " if (obtained_difference>0).sum() < 20:\n", 115 | " percentile=100.-percentile\n", 116 | " pvalues[combo_label] = (100.-percentile)/100.\n", 117 | " return pvalues" 118 | ] 119 | }, 120 | { 121 | "cell_type": "markdown", 122 | "metadata": {}, 123 | "source": [ 124 | "### Budapest significance testing" 125 | ] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "execution_count": 6, 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "budapest_results = '../../final_results/budapest/'\n", 134 | "isc={}\n", 135 | "for A in ['aa','cha','rha','h2a']:\n", 136 | " isc[A] = np.mean(np.nan_to_num(np.load(os.path.join(budapest_results, '{A}_vertex_isc.npy'.format(A=A)))),axis=0)\n", 137 | "\n", 138 | "clf={}\n", 139 | "for A in ['aa','cha','rha','h2a']:\n", 140 | " clf[A] = np.mean(np.nan_to_num(np.load(os.path.join(budapest_results, '{A}_clf.npy'.format(A=A)))),axis=1)\n", 141 | "\n", 142 | "cnx={}\n", 143 | "for A in ['aa','cha','rha','h2a']:\n", 144 | " cnx[A] = np.mean(np.nan_to_num(np.load(os.path.join(budapest_results, '{A}_cnx_isc.npy'.format(A=A)))),axis=0)\n", 145 | "\n", 146 | "isc_p, clf_p, cnx_p = [run_permutations(d, 10000) for d in [isc,clf, cnx]]\n", 147 | "for d,typ in zip([isc_p,clf_p,cnx_p],['isc','clf','cnx']):\n", 148 | " for key in d.keys():\n", 149 | " results_dict['dataset'].append('budapest')\n", 150 | " results_dict['metric'].append(typ)\n", 151 | " results_dict['comparison'].append(key)\n", 152 | " # these are always one-sided because we expect HA to outperform AA\n", 153 | " if 'aa' in key:\n", 154 | " 
d[key]=2*d[key]\n", 155 | " results_dict['pvalue'].append(d[key])\n", 156 | " " 157 | ] 158 | }, 159 | { 160 | "cell_type": "markdown", 161 | "metadata": {}, 162 | "source": [ 163 | "### Raiders " 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 8, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "results = '../../final_results/raiders/'\n", 173 | "isc={}\n", 174 | "for A in ['aa','cha','rha','h2a']:\n", 175 | " isc[A] = np.mean(np.nan_to_num(np.load(os.path.join(results, '{A}_vertex_isc.npy'.format(A=A)))),axis=0)\n", 176 | "\n", 177 | "clf={}\n", 178 | "for A in ['aa','cha','rha','h2a']:\n", 179 | " clf[A] = np.mean(np.nan_to_num(np.load(os.path.join(results, '{A}_clf.npy'.format(A=A)))),axis=1)\n", 180 | "\n", 181 | "cnx={}\n", 182 | "for A in ['aa','cha','rha','h2a']:\n", 183 | " cnx[A] = np.mean(np.nan_to_num(np.load(os.path.join(results, '{A}_cnx_isc.npy'.format(A=A)))),axis=0)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 9, 189 | "metadata": {}, 190 | "outputs": [], 191 | "source": [ 192 | "isc_p, clf_p, cnx_p = [run_permutations(d, 10000) for d in [isc,clf, cnx]]\n", 193 | "for d,typ in zip([isc_p,clf_p,cnx_p],['isc','clf','cnx']):\n", 194 | " for key in d.keys():\n", 195 | " results_dict['dataset'].append('raiders')\n", 196 | " results_dict['metric'].append(typ)\n", 197 | " results_dict['comparison'].append(key)\n", 198 | " # these are always one-sided because we expect HA to outperform AA\n", 199 | " if 'aa' in key:\n", 200 | " d[key]=2*d[key]\n", 201 | " results_dict['pvalue'].append(d[key])" 202 | ] 203 | }, 204 | { 205 | "cell_type": "markdown", 206 | "metadata": {}, 207 | "source": [ 208 | "### Whiplash" 209 | ] 210 | }, 211 | { 212 | "cell_type": "code", 213 | "execution_count": 29, 214 | "metadata": {}, 215 | "outputs": [], 216 | "source": [ 217 | "results = '../../final_results/whiplash/'\n", 218 | "isc={}\n", 219 | "for A in ['aa','cha','rha','h2a']:\n", 220 | " isc[A] = np.mean(np.nan_to_num(np.load(os.path.join(results, '{A}_vertex_isc.npy'.format(A=A)))),axis=0)\n", 221 | "\n", 222 | "clf={}\n", 223 | "for A in ['aa','cha','rha','h2a']:\n", 224 | " clf[A] = np.mean(np.nan_to_num(np.load(os.path.join(results, '{A}_clf.npy'.format(A=A)))),axis=1)\n", 225 | "clf['aa'] = np.mean(np.nan_to_num(np.load(os.path.join(results, '{A}_clf.npy'.format(A='aa')))),axis=1)\n", 226 | "\n", 227 | "cnx={}\n", 228 | "for A in ['aa','cha','rha','h2a']:\n", 229 | " cnx[A] = np.mean(np.nan_to_num(np.load(os.path.join(results, '{A}_cnx_isc.npy'.format(A=A)))),axis=0)\n", 230 | "\n", 231 | "isc_p, clf_p, cnx_p = [run_permutations(d, 10000) for d in [isc,clf,cnx]]\n", 232 | "for d,typ in zip([isc_p,clf_p,cnx_p],['isc','clf','cnx']):\n", 233 | " for key in d.keys():\n", 234 | " results_dict['dataset'].append('whiplash')\n", 235 | " results_dict['metric'].append(typ)\n", 236 | " results_dict['comparison'].append(key)\n", 237 | " # these are always one-sided because we expect HA to outperform AA\n", 238 | " if 'aa' in key:\n", 239 | " d[key]=2*d[key]\n", 240 | " results_dict['pvalue'].append(d[key]) \n", 241 | " \n", 242 | "res = pd.DataFrame(results_dict)\n", 243 | "res.to_csv('../../final_results/significance_all_datasets_final.csv') " 244 | ] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "Python 2", 250 | "language": "python", 251 | "name": "python2" 252 | }, 253 | "language_info": { 254 | "codemirror_mode": { 255 | "name": "ipython", 256 | "version": 3 257 | }, 258 | 
"file_extension": ".py", 259 | "mimetype": "text/x-python", 260 | "name": "python", 261 | "nbconvert_exporter": "python", 262 | "pygments_lexer": "ipython3", 263 | "version": "3.7.6" 264 | } 265 | }, 266 | "nbformat": 4, 267 | "nbformat_minor": 2 268 | } 269 | -------------------------------------------------------------------------------- /raiders_utils.py: -------------------------------------------------------------------------------- 1 | # raiders_utils.py 2 | # utils file specific to raiders_utils; all dataset-specific code will be kept here 3 | # erica busch, 2020 4 | 5 | import numpy as np 6 | from scipy.stats import zscore 7 | import os 8 | from mvpa2.datasets.base import Dataset 9 | 10 | basedir = '/dartfs/rc/lab/D/DBIC/DBIC/f002d44/h2a' 11 | home_dir = '/dartfs-hpc/rc/home/4/f002d44/h2a' 12 | connhyper_dir = os.path.join(basedir, 'connhyper','raiders') 13 | resphyper_dir = os.path.join(basedir, 'response_hyper','raiders') 14 | h2a_dir = os.path.join(basedir, 'mixed_hyper','raiders') 15 | datadir = os.path.join(basedir, 'data','raiders') 16 | orig_datadir = os.path.join(datadir, 'original') 17 | connectome_dir = os.path.join(datadir, 'connectomes') 18 | results_dir = os.path.join(basedir, 'results','raiders') 19 | iterative_HA_dir = os.path.join(basedir, 'iterative_hyper') 20 | 21 | sub_nums = [5, 7, 9, 10, 12, 13, 20, 21, 24, 29, 34, 52, 102, 114, 120, 134, 142, 22 | 278, 416, 433, 499, 522, 535] 23 | subjects = ['{:0>6}'.format(subid) for subid in sub_nums] 24 | SURFACE_RESOLUTION = 10242 25 | SEARCHLIGHT_RADIUS = 13 26 | HYPERALIGNMENT_RADIUS = 20 27 | TOT_RUNS= 4 28 | MASKS = {'l':np.load(basedir+'/fsaverage_lh_mask.npy')[:SURFACE_RESOLUTION], 29 | 'r':np.load(basedir+'/fsaverage_rh_mask.npy')[:SURFACE_RESOLUTION]} 30 | NNODES_LH = 9372 31 | 32 | def get_RHA_data(runs, num_subjects=21): 33 | dss=[] 34 | for run in runs: 35 | ds_lh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(run), 'dss_lh.npy')) 36 | ds_rh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(run), 'dss_rh.npy')) 37 | dss.append(np.concatenate((ds_lh, ds_rh),axis=2)) 38 | return np.array(dss) 39 | 40 | # runs indicates which runs we want to return. 41 | # this will be useful for datafolding. 42 | def get_train_data(side, runs, num_subjects=23, z=True, mask=True): 43 | dss = [] 44 | for subj in subjects[:num_subjects]: 45 | data = _get_raiders_data(subj, side.upper(), runs, z, mask) 46 | ds = Dataset(data) 47 | idx = np.where(np.logical_not(np.all(ds.samples == 0, axis=0)))[0] 48 | ds = ds[:, idx] 49 | dss.append(ds) 50 | return dss 51 | 52 | # specific formatting for raiders data; only gets called internally. 
--------------------------------------------------------------------------------
/raiders_utils.py:
--------------------------------------------------------------------------------
# raiders_utils.py
# utils file specific to the raiders dataset; all dataset-specific code is kept here
# erica busch, 2020

import os
import numpy as np
from scipy.stats import zscore
from mvpa2.datasets.base import Dataset

basedir = '/dartfs/rc/lab/D/DBIC/DBIC/f002d44/h2a'
home_dir = '/dartfs-hpc/rc/home/4/f002d44/h2a'
connhyper_dir = os.path.join(basedir, 'connhyper', 'raiders')
resphyper_dir = os.path.join(basedir, 'response_hyper', 'raiders')
h2a_dir = os.path.join(basedir, 'mixed_hyper', 'raiders')
datadir = os.path.join(basedir, 'data', 'raiders')
orig_datadir = os.path.join(datadir, 'original')
connectome_dir = os.path.join(datadir, 'connectomes')
results_dir = os.path.join(basedir, 'results', 'raiders')
iterative_HA_dir = os.path.join(basedir, 'iterative_hyper')

sub_nums = [5, 7, 9, 10, 12, 13, 20, 21, 24, 29, 34, 52, 102, 114, 120, 134, 142,
            278, 416, 433, 499, 522, 535]
subjects = ['{:0>6}'.format(subid) for subid in sub_nums]
SURFACE_RESOLUTION = 10242
SEARCHLIGHT_RADIUS = 13
HYPERALIGNMENT_RADIUS = 20
TOT_RUNS = 4
MASKS = {'l': np.load(basedir + '/fsaverage_lh_mask.npy')[:SURFACE_RESOLUTION],
         'r': np.load(basedir + '/fsaverage_rh_mask.npy')[:SURFACE_RESOLUTION]}
NNODES_LH = 9372

def get_RHA_data(runs, num_subjects=21):
    # load response-hyperaligned data for each requested fold and stack the two hemispheres
    dss = []
    for run in runs:
        ds_lh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(run), 'dss_lh.npy'))
        ds_rh = np.load(os.path.join(resphyper_dir, 'fold_{0}'.format(run), 'dss_rh.npy'))
        dss.append(np.concatenate((ds_lh, ds_rh), axis=2))
    return np.array(dss)

# runs indicates which runs we want to return.
# this will be useful for datafolding.
def get_train_data(side, runs, num_subjects=23, z=True, mask=True):
    dss = []
    for subj in subjects[:num_subjects]:
        data = _get_raiders_data(subj, side, runs, z, mask)
        ds = Dataset(data)
        # drop vertices whose timeseries is all zeros
        idx = np.where(np.logical_not(np.all(ds.samples == 0, axis=0)))[0]
        ds = ds[:, idx]
        dss.append(ds)
    return dss

# specific formatting for raiders data; only gets called internally.
def _get_raiders_data(subject, side, runs, z, mask):
    side = side.lower()
    if side in ('lh', 'rh'):
        # accept both 'l'/'r' and 'lh'/'rh' hemisphere labels (run_anatomical_benchmarks.py passes 'lh'/'rh')
        side = side[0]
    run_list = ['{:0>2}'.format(r) for r in runs]
    if side == 'b':
        return np.hstack([_get_raiders_data(subject, 'l', runs, z, mask),
                          _get_raiders_data(subject, 'r', runs, z, mask)])
    fns = ['{d}/sid{s}_{h}h_movie_{r}.npy'.format(d=orig_datadir, s=subject, h=side, r=i) for i in run_list]
    ds = []
    for fn in fns:
        d = np.load(fn)
        if mask:
            d = d[:, MASKS[side]]
        if z:
            d = zscore(d, axis=0)
        ds.append(d)
    dss = np.concatenate(ds, axis=0)
    return dss

# don't want to return this as a pymvpa dataset; takes too long & is unnecessary
def get_test_data(side, runs, num_subjects=23, z=True, mask=True):
    dss = []
    for subj in subjects[:num_subjects]:
        ds = _get_raiders_data(subj, side, runs, z, mask)
        dss.append(ds)
    return np.array(dss)
--------------------------------------------------------------------------------
/run_anatomical_benchmarks.py:
--------------------------------------------------------------------------------
# apply tests of fit to anatomically aligned data as a control
# EB 7.2020

import sys, os, glob
import numpy as np
import benchmarks


if __name__ == '__main__':
    dataset = sys.argv[1]
    if dataset == 'budapest':
        import budapest_utils as utils
    elif dataset == 'whiplash':
        import whiplash_utils as utils
    elif dataset == 'raiders':
        import raiders_utils as utils
    else:
        print('dataset must be one of [raiders, whiplash, budapest]')
        sys.exit(2)
    test_run = int(sys.argv[2])
    outdir = os.path.join(utils.basedir, 'anatomical', dataset, 'fold_{x}'.format(x=test_run))
    if not os.path.isdir(outdir):
        os.makedirs(outdir)
    # held-out test data for this fold, one hemisphere at a time
    test_lh = utils.get_test_data('lh', [test_run])
    test_rh = utils.get_test_data('rh', [test_run])
    lh_res, rh_res = benchmarks.vertex_isc(test_lh), benchmarks.vertex_isc(test_rh)
    np.save(outdir + '/vertex_isc_lh.npy', lh_res)
    np.save(outdir + '/vertex_isc_rh.npy', rh_res)
    test_data = np.concatenate((test_lh, test_rh), axis=2)
    res = benchmarks.dense_connectivity_profile_isc(test_data)
    np.save(outdir + '/dense_connectome_isc.npy', res)
    res = benchmarks.searchlight_timepoint_clf(test_data)
    # persist the searchlight classification result as well (filename is assumed here,
    # following the naming pattern of the saves above)
    np.save(outdir + '/searchlight_clf.npy', res)
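A typical invocation of the script above is `python run_anatomical_benchmarks.py raiders 1` (dataset name, then the held-out test run). The sketch below shows one way its per-fold outputs could be averaged across folds, in the spirit of the averaging cell earlier in permutation_tests.ipynb; it is an illustration only. The relative path assumes the working directory is the project basedir, and the file name follows the np.save calls above.

    import glob, os
    import numpy as np

    dataset = 'raiders'  # or 'budapest' / 'whiplash'
    # collect the left-hemisphere vertexwise ISC written by each fold
    fold_files = sorted(glob.glob(os.path.join('anatomical', dataset, 'fold_*', 'vertex_isc_lh.npy')))
    per_fold = np.stack([np.load(f) for f in fold_files])  # shape: (n_folds, ...)
    mean_isc = per_fold.mean(axis=0)                       # average across test folds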
--------------------------------------------------------------------------------
/whiplash_utils.py:
--------------------------------------------------------------------------------
# whiplash_utils.py
# utils file specific to the whiplash dataset; all dataset-specific code is kept here
# erica busch, 2020
import os, glob
import pandas as pd
import numpy as np
from scipy.stats import zscore
from mvpa2.datasets.base import Dataset

basedir = '/dartfs/rc/lab/D/DBIC/DBIC/f002d44/h2a'
connhyper_dir = os.path.join(basedir, 'connhyper', 'whiplash')
resphyper_dir = os.path.join(basedir, 'response_hyper', 'whiplash')
datadir = os.path.join(basedir, 'data', 'whiplash')
orig_datadir = os.path.join(datadir, 'whiplash')
connectome_dir = os.path.join(datadir, 'connectomes')
results_dir = os.path.join(basedir, 'results', 'whiplash')
h2a_dir = os.path.join(basedir, 'iterative_hyper')

# make sure the output directories exist
for dn in [connhyper_dir, resphyper_dir, h2a_dir, results_dir]:
    if not os.path.isdir(dn):
        os.makedirs(dn)
        print('made ' + str(dn))

subj_df = pd.read_csv(os.path.join(datadir, 'whiplash_subjects.csv'))['subject_id']
subjects = [s.split('sub-sid')[1] for s in sorted(list(subj_df))]
num_subjects = len(subjects)

SURFACE_RESOLUTION = 10242
SEARCHLIGHT_RADIUS = 13
HYPERALIGNMENT_RADIUS = 20
TOT_TRs = 1770
TOT_RUNS = 4
NNODES_LH = 9372
MASKS = {'lh': np.load(basedir + '/fsaverage_lh_mask.npy')[:SURFACE_RESOLUTION],
         'rh': np.load(basedir + '/fsaverage_rh_mask.npy')[:SURFACE_RESOLUTION]}

midstr = '_ses-3-task-movie_run-02'
endstr = '_space-fsaverage_hemi-'

# this dataset was collected in a single run, so we manually divide the session into 4 pseudo-runs.
# each of these 'runs' is 443 TRs long (the last chunk gets the remaining 441 TRs).
n = int(round(TOT_TRs / TOT_RUNS) + 1)
TR_run_chunks = [np.arange(TOT_TRs)[i:i + n] for i in range(0, TOT_TRs, n)]

def get_RHA_data(runs):
    # load response-hyperaligned data for each requested fold and stack the two hemispheres
    dss = []
    for run in runs:
        ds_lh = np.load(os.path.join(resphyper_dir, 'data', 'fold_{0}'.format(run), 'dss_lh.npy'))
        ds_rh = np.load(os.path.join(resphyper_dir, 'data', 'fold_{0}'.format(run), 'dss_rh.npy'))
        dss.append(np.concatenate((ds_lh, ds_rh), axis=2))
    return np.array(dss)

# runs indicates which runs we want to return.
# this will be useful for datafolding.
def get_train_data(side, runs, num_subjects=num_subjects, z=True, mask=True):
    dss = []
    TR_train = np.concatenate([TR_run_chunks[i - 1] for i in runs])
    for subj in subjects[:num_subjects]:
        data = _get_whiplash_data(subj, side, z, mask)
        data = data[TR_train, :]
        ds = Dataset(data)
        dss.append(ds)
    return dss

# specific formatting for whiplash data; only gets called internally.
def _get_whiplash_data(subject, side, z, mask):
    LR = side.lower()
    if LR == 'b':
        return np.hstack([_get_whiplash_data(subject, 'lh', z, mask),
                          _get_whiplash_data(subject, 'rh', z, mask)])
    a = orig_datadir + '/*{s}_{LR}*.npy'.format(s=subject, LR=LR)
    fn = glob.glob(a)[0]
    d = np.nan_to_num(np.load(fn))
    if mask:
        d = d[:, MASKS[LR]]
    if z:
        d = zscore(d)
    return d

# don't want to return this as a pymvpa dataset; takes too long & is unnecessary
def get_test_data(side, runs, num_subjects=29, z=True, mask=True):
    dss = []
    TR_test = np.concatenate([TR_run_chunks[i - 1] for i in runs])
    for subj in subjects[:num_subjects]:
        ds = _get_whiplash_data(subj, side, z, mask)
        ds = ds[TR_test, :]
        dss.append(ds)
    return np.array(dss)
--------------------------------------------------------------------------------
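As a quick sanity check on the pseudo-run split defined in whiplash_utils.py, the standalone snippet below reproduces the chunking arithmetic; it illustrates the logic above and is not part of the module.

    import numpy as np

    TOT_TRs, TOT_RUNS = 1770, 4
    n = int(round(TOT_TRs / TOT_RUNS) + 1)    # 443 TRs per pseudo-run
    chunks = [np.arange(TOT_TRs)[i:i + n] for i in range(0, TOT_TRs, n)]
    print([len(c) for c in chunks])           # [443, 443, 443, 441]
    print(sum(len(c) for c in chunks))        # 1770 -- every TR lands in exactly one pseudo-run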