├── extra_images ├── readme.md ├── sim_1.png ├── sim_2.png ├── sim_3.png └── sim_4.png ├── concatenate.py ├── README.md ├── extended_simulations.md ├── run_simulation_sweep.py ├── simulation_experiment.py ├── visualize_sim.py ├── ridge_tools.py ├── variance_partitioning.ipynb ├── simulate.ipynb ├── stacking.py └── stacking_tutorial.ipynb /extra_images/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /extra_images/sim_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_1.png -------------------------------------------------------------------------------- /extra_images/sim_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_2.png -------------------------------------------------------------------------------- /extra_images/sim_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_3.png -------------------------------------------------------------------------------- /extra_images/sim_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_4.png -------------------------------------------------------------------------------- /concatenate.py: -------------------------------------------------------------------------------- 1 | # from cvxopt import matrix, solvers 2 | import numpy as np 3 | from scipy.stats import zscore 4 | from ridge_tools import cross_val_ridge, R2, ridge 5 | from stacking import feat_ridge_CV, get_cv_indices 6 | 7 | 8 | def concatenate_CV_fmri(data, features, method="cross_val_ridge", n_folds=5, 
score_f=R2): 9 | """ 10 | A function that concatenates feature spaces to predict fMRI signal. 11 | 12 | Args: 13 | - data (ndarray): A matrix of fMRI signal data with dimensions n_time x n_voxels. 14 | - features (list): A list of length n_features containing arrays of predictors 15 | with dimensions n_time x n_dim. 16 | - method (str): A string indicating the method to use to train the model. Default is "cross_val_ridge". 17 | - n_folds (int): An integer indicating the number of cross-validation folds to use. Default is 5. 18 | - score_f (function): A function to use for scoring the model. Default is R2. 19 | 20 | Returns: 21 | - A tuple containing the following element: 22 | - concat_r2s (float): The R2 score for the concatenated model predictions. 23 | 24 | """ 25 | 26 | n_time, n_voxels = data.shape 27 | n_features = len(features) 28 | 29 | ind = get_cv_indices(n_time, n_folds=n_folds) 30 | 31 | # create arrays to store predictions 32 | concat_pred = np.zeros((n_time, n_voxels)) 33 | 34 | # perform cross-validation by fold 35 | for ind_num in range(n_folds): 36 | # split data into training and testing sets 37 | train_ind = ind != ind_num 38 | test_ind = ind == ind_num 39 | train_data = data[train_ind] 40 | train_features = [F[train_ind] for F in features] 41 | test_data = data[test_ind] 42 | test_features = [F[test_ind] for F in features] 43 | 44 | # normalize data 45 | train_data = np.nan_to_num(zscore(train_data)) 46 | test_data = np.nan_to_num(zscore(test_data)) 47 | 48 | train_features = [np.nan_to_num(zscore(F)) for F in train_features] 49 | test_features = [np.nan_to_num(zscore(F)) for F in test_features] 50 | 51 | # Store predictions 52 | __,__, concat_pred[test_ind], __,__ = feat_ridge_CV(np.hstack(train_features), train_data, np.hstack(test_features), 53 | method=method) 54 | 55 | 56 | # Compute overall performance metrics 57 | data_zscored = zscore(data) 58 | 59 | concat_r2s = score_f(concat_pred, data_zscored) 60 | 61 | # return the results 62 | 
return ( 63 | concat_r2s, 64 | ) 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 🧠 Stacked regressions and structured variance partitioning for interpretable brain maps 2 | 3 | ## Overview 4 | 5 | This is a Python package that provides an implementation of stacked regression for functional MRI (fMRI) data. The package uses ridge regression to train models on multiple feature spaces and combines the predictions from these models using a weighted linear combination. The weights are learned using quadratic programming. 6 | 7 | > Here we present an approach for brain mapping based on two proposed methods: stacking different encoding models and structured variance partitioning. This package is useful for researchers interested in aligning brain activity with different layers of a neural network, or with other types of correlated feature spaces. 8 | 9 | > Relating brain activity associated with a complex stimulus to different attributes of that stimulus is a powerful approach for constructing functional brain maps. However, when stimuli are naturalistic, their attributes are often correlated. These different attributes can act as confounders for each other and complicate the interpretability of brain maps. Correlations between attributes also impact the robustness of statistical estimators. 10 | 11 | > Each encoding model uses as input a feature space that describes a different stimulus attribute. The algorithm learns to predict the activity of a voxel as a linear combination of the individual encoding models. We show that the resulting unified model can predict held-out brain activity better than, or at least as well as, the individual encoding models. Further, the weights of the linear combination are readily interpretable; they show the importance of each feature space for predicting a voxel.
12 | 13 | > We build on our stacking models to introduce a new variant of variance partitioning in which we rely on the known relationships between features during hypothesis testing. This approach, which we term structured variance partitioning, constrains the size of the hypothesis space and allows us to ask targeted questions about the similarity between feature spaces and brain regions even in the presence of correlations between the feature spaces. 14 | 15 | > We validate our approach in simulation, showcase its brain mapping potential on fMRI data, and release a Python package. 16 | 17 | ## Installation 18 | To use this code, simply install with pip: 19 | 20 | 21 | ```bash 22 | pip install stacking_fmri 23 | ``` 24 | 25 | 26 | ## Usage 27 | Here is an example of how to use the `stacking_fmri` function: 28 | ```python 29 | import numpy as np 30 | from sklearn.datasets import make_regression 31 | from stacking_fmri import stacking_fmri 32 | # Generate synthetic data 33 | X_train, y_train = make_regression(n_samples=50, n_features=1000, random_state=42) 34 | X_test, y_test = make_regression(n_samples=50, n_features=1000, random_state=43) 35 | 36 | # Generate random feature spaces 37 | n_features = 5 38 | train_features = [np.random.randn(X_train.shape[0], 10) for _ in range(n_features)] 39 | test_features = [np.random.randn(X_test.shape[0], 10) for _ in range(n_features)] 40 | 41 | # Train and test the model 42 | ( 43 | r2s, 44 | stacked_r2s, 45 | r2s_weighted, 46 | r2s_train, 47 | stacked_train_r2s, 48 | S, 49 | ) = stacking_fmri( 50 | X_train, 51 | X_test, 52 | train_features, 53 | test_features, 54 | method="cross_val_ridge", 55 | score_f=lambda pred, real: np.mean((pred - real) ** 2, axis=0), 56 | ) 57 | 58 | print("R2 scores for each feature and voxel:") 59 | print(r2s) 60 | print("\nWeighted R2 scores for each feature and voxel:") 61 | print(r2s_weighted) 62 | print("\nUnweighted R2 scores for the stacked predictions:") 63 | print(stacked_r2s) 64 | print("\nStacking weights:") 65 | print(S) 66 | ``` 67 |
68 | We also provide examples of how to use the package in Jupyter notebooks: 69 | 70 | - stacking_tutorial.ipynb 71 | 72 | - variance_partitioning.ipynb 73 | 74 | 75 | 77 | 78 | 79 | ## Contributions 80 | Contributions are welcome! Please feel free to submit a pull request with your changes or open an issue to report a bug or suggest a new feature. 81 | 82 | 83 | ## Contact 84 | Created by [@lrg1213] - feel free to contact me! 85 | 86 | 87 | ## References 88 | [1] 89 | Ruogu Lin, Thomas Naselaris, Kendrick Kay, and Leila Wehbe (2023). 90 | Stacked regressions and structured variance partitioning for interpretable brain maps. 91 | 92 | 93 | 94 | -------------------------------------------------------------------------------- /extended_simulations.md: -------------------------------------------------------------------------------- 1 | ## Correction and Extended Simulation Study 2 | 3 | ### Correction to the Original Paper 4 | 5 | In the original paper, we stated that the simulation experiment used a **correlation of 0.2 between the feature spaces**. However, due to a bug in our simulation code (which is now fixed), the results in the published article actually correspond to a **correlation of 0**. 6 | 7 | We now provide **two corrected examples** below, generated using different correlation levels. 8 | 9 | --- 10 | 11 | ### New Examples with Correlated Feature Spaces 12 | 13 | #### **Example 1** 14 | ![Figure 1](./extra_images/sim_1.png) 15 | 16 | #### **Example 2** 17 | 18 | ![Figure 2](./extra_images/sim_2.png) 19 | 20 | 21 | As can be seen, in some parameter settings **stacking outperforms concatenation**, while in others **concatenation performs better**. 22 | 23 | As we emphasized in the paper, *these simulations cannot be taken as proof of either model's superiority*: there are infinitely many simulation configurations and the behaviors of stacking vs. concatenation depend heavily on the structure of the data.
24 | 25 | To deepen our understanding of this specific simulation setup, we performed a substantially more **comprehensive sweep** of parameters. 26 | 27 | --- 28 | 29 | ## Extended Simulation Sweep 30 | 31 | Below we report results from a broad grid of simulation settings (576 in total). 32 | For each configuration, we generate **100 independent datasets** using the same simulation framework as in the original study. 33 | 34 | This sweep varies the following parameters: 35 | 36 | - **Number of samples (n)**: 50, 100, 200, 400 37 | - **Dimensions of feature spaces (d1,d2,d3,d4)**: [10, 10, 10, 10], [10, 100, 100, 100], [100, 100, 100, 100] 38 | - **Feature importance (alpha)**: [0.3, 0.3, 0.3, 0.1], [0.5, 0.2, 0.2, 0.1], [0.7, 0.1, 0.1, 0.1] 39 | - **Noise (sigma)**: 0, 0.5, 1, 1.5 40 | - **Feature correlation (correl)**: 0, 0.1, 0.2, 0.3 41 | 42 | Together, these produce 4 × 3 × 3 × 4 × 4 = 576 simulation settings. Each setting is run with 100 simulated datasets. 43 | 44 | ### Results 45 | Below we group the simulation results by parameter families to understand which factors most strongly influence whether stacking or concatenation performs better. 46 | 47 | In the figure below, we focus on one parameter at a time and count the number of simulation settings (over all other parameters) for which stacking has reliably lower estimation error than concatenation, and the number for which concatenation has reliably lower estimation error than stacking (as in the right-hand panels of the two plots above, we count the number of stars and filled circles). 48 | 49 | ![Figure 3](./extra_images/sim_3.png) 50 | 51 | We see that smaller sample size n, higher noise sigma, low correlation, higher values of alpha1 and lower values of d1/d2 lead to better performance of stacking across the different settings of the other parameters.
52 | 53 | Instead of counting the number of experiments in which one method is reliably better than the other, we compute the average error across experiments and settings, taking two parameters at a time: 54 | 55 | ![Figure 4](./extra_images/sim_4.png) 56 | 57 | This visualization of the difference in average error highlights a relationship between n, sigma and alpha1. Small n and low sigma lead to concatenation performing better. Large alpha1 leads to stacking being better, but not at high n. 58 | 59 | ### Conclusion 60 | Across this **specific** simulation grid, we observe the following patterns: 61 | - Stacking can have lower error in lower sample sizes (n). 62 | - Stacking can have lower error in higher noise settings (sigma). 63 | - Stacking can have lower error when d1 is smaller than the other feature sizes. 64 | - Stacking can have lower error when alpha1 is large. 65 | 66 | Importantly, when we compare these findings to the real-data experiment in our paper: 67 | - Our dataset has 7 × 1024 = 7168 total features, which is on the same order as the ∼8000 training samples. 68 | - fMRI data has considerable noise. 69 | - Potentially, there could be a preference for a feature space (one of the alphas is large). 70 | 71 | Thus, the empirical conditions of the real dataset (and of data from experiments similar to the one we use) appear to correspond to simulation settings under which stacking tends to have an advantage, suggesting that stacking could be useful for building fMRI encoding models in naturalistic experiments. As we mentioned in the paper, even when the experimenter thinks that concatenation might be better, they can include an additional feature space with the concatenated features.
72 | -------------------------------------------------------------------------------- /run_simulation_sweep.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import argparse 4 | import itertools 5 | import numpy as np 6 | import pandas as pd 7 | 8 | from simulation_experiment import run_one_simulation 9 | 10 | # Number of samples (n) 11 | N_LIST = [ 12 | 50,100,200,400 13 | ] 14 | 15 | # Feature dimensions for each of the 4 feature spaces 16 | DS_LIST = [ 17 | [10, 10, 10, 10], 18 | [10, 100, 100, 100], 19 | [100,100,100,100] 20 | ] 21 | 22 | # Alpha for each of the 4 feature spaces 23 | ALPHA_LIST = [ 24 | [0.3, 0.3, 0.3, 0.1], 25 | [0.5, 0.2, 0.2, 0.1], 26 | [0.7, 0.1, 0.1, 0.1], 27 | ] 28 | 29 | # Output noise (y_noise / sigma) 30 | NOISE_LIST = [ 31 | 0, 0.5, 1, 1.5 32 | ] 33 | 34 | # Correlation between feature spaces 35 | CORREL_LIST = [ 36 | 0, 0.1, 0.2, 0.3 37 | ] 38 | 39 | # Fixed parameters (you can change these too if you want) 40 | SCALE = 0.5 41 | Y_DIM = 2 42 | 43 | # How many Monte Carlo runs per parameter setting? 44 | N_RUNS_PER_SETTING = 50 45 | 46 | 47 | def build_tasks(): 48 | """ 49 | Flatten all (parameter combination × run) into a list of tasks. 50 | Each task gets a unique global_index so you can shard across machines. 51 | """ 52 | tasks = [] 53 | idx = 0 54 | 55 | for n, ds, alpha, noise, correl in itertools.product( 56 | N_LIST, DS_LIST, ALPHA_LIST, NOISE_LIST, CORREL_LIST 57 | ): 58 | for run_id in range(N_RUNS_PER_SETTING): 59 | tasks.append( 60 | dict( 61 | global_index=idx, 62 | n=int(n), 63 | ds=list(ds), 64 | alpha=list(alpha), 65 | noise=float(noise), 66 | correl=float(correl), 67 | run=run_id, 68 | ) 69 | ) 70 | idx += 1 71 | 72 | return tasks 73 | 74 | 75 | def main(): 76 | parser = argparse.ArgumentParser( 77 | description="Run a slice of the parameter sweep for the stacking vs concatenation simulations." 
78 | ) 79 | parser.add_argument( 80 | "--start-index", 81 | type=int, 82 | required=True, 83 | help="First task index (inclusive) to run.", 84 | ) 85 | parser.add_argument( 86 | "--end-index", 87 | type=int, 88 | required=True, 89 | help="Last task index (inclusive) to run.", 90 | ) 91 | parser.add_argument( 92 | "--output", 93 | type=str, 94 | required=True, 95 | help="Output pickle file name, e.g. results_0_99.pkl", 96 | ) 97 | parser.add_argument( 98 | "--print-task-count", 99 | action="store_true", 100 | help="If set, only print total number of tasks and exit.", 101 | ) 102 | 103 | args = parser.parse_args() 104 | 105 | tasks = build_tasks() 106 | total_tasks = len(tasks) 107 | 108 | if args.print_task_count: 109 | print(f"Total tasks: {total_tasks}") 110 | return 111 | 112 | if total_tasks == 0: 113 | raise ValueError("No tasks were generated. Did you fill in the parameter lists?") 114 | 115 | start = max(0, args.start_index) 116 | end = min(total_tasks - 1, args.end_index) 117 | 118 | if start > end: 119 | raise ValueError( 120 | f"Invalid range: start_index={args.start_index}, end_index={args.end_index}, total_tasks={total_tasks}" 121 | ) 122 | 123 | selected_tasks = [t for t in tasks if start <= t["global_index"] <= end] 124 | 125 | print( 126 | f"Total tasks: {total_tasks}. 
" 127 | f"Running tasks {start}..{end} ({len(selected_tasks)} tasks) " 128 | f"into {args.output}" 129 | ) 130 | 131 | all_dfs = [] 132 | 133 | for t in selected_tasks: 134 | idx = t["global_index"] 135 | n = t["n"] 136 | ds = t["ds"] 137 | alpha = t["alpha"] 138 | noise = t["noise"] 139 | correl = t["correl"] 140 | run_id = t["run"] 141 | 142 | print( 143 | f"\n=== Task {idx} ===\n" 144 | f"n={n}, ds={ds}, alpha={alpha}, noise={noise}, correl={correl}, run={run_id}" 145 | ) 146 | 147 | # For reproducibility, seed based on global task index 148 | np.random.seed(idx) 149 | 150 | df = run_one_simulation( 151 | samples=n, 152 | ds=ds, 153 | scale=SCALE, 154 | correl=correl, 155 | alpha=alpha, 156 | y_dim=Y_DIM, 157 | y_noise=noise, 158 | ) 159 | 160 | # Attach parameter metadata to each row in this simulation 161 | df["n"] = n 162 | df["d1"], df["d2"], df["d3"], df["d4"] = ds 163 | df["alpha1"], df["alpha2"], df["alpha3"], df["alpha4"] = alpha 164 | df["sigma"] = noise 165 | df["correl"] = correl 166 | df["run"] = run_id 167 | df["task_index"] = idx 168 | 169 | all_dfs.append(df) 170 | 171 | if not all_dfs: 172 | print("No tasks were selected; nothing to save.") 173 | return 174 | 175 | result = pd.concat(all_dfs, ignore_index=True) 176 | result.to_pickle(args.output) 177 | print(f"\nSaved {len(result)} rows to {args.output}") 178 | 179 | 180 | if __name__ == "__main__": 181 | main() 182 | -------------------------------------------------------------------------------- /simulation_experiment.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import pandas as pd 4 | from ridge_tools import cross_val_ridge, fit_predict, R2 5 | import sys 6 | from stacking import stacking_CV_fmri 7 | from concatenate import concatenate_CV_fmri 8 | from scipy.stats import zscore, multivariate_normal #, wishart 9 | from scipy.linalg import toeplitz 10 | import time 11 | 12 | 13 | def 
toeplitz_cov(n, scale=1): 14 | return toeplitz(np.exp(-(np.arange(n*1.0))**2/(n*scale))) 15 | 16 | def feat_sample(n,ds,scale): 17 | Xs = [] 18 | for di in ds: 19 | X = multivariate_normal.rvs(np.zeros(di),cov=toeplitz_cov(di,scale),size=n)#reshape([n,di]) 20 | Xs.append(X) 21 | return Xs 22 | 23 | def data_sample(Xs,correl,ds,scale,alpha,data_dim,noise=0): 24 | assert len(Xs) == len(alpha) 25 | ts = [] 26 | y = 0 27 | d = sum(ds) 28 | cnt = 0 29 | wtot = multivariate_normal.rvs(mean=np.zeros(d),cov=toeplitz_cov(d,scale), 30 | size=data_dim).T#reshape([d,data_dim]) 31 | for iX, X in enumerate(Xs): 32 | w = wtot[cnt:cnt+ds[iX],:] 33 | cnt += ds[iX] 34 | t = zscore(X.dot(w)) 35 | ts.append(t) 36 | y += alpha[iX]*t 37 | # add correlated component 38 | if correl>0: 39 | if iX < len(Xs)-1: # add a correlated component from iX+1 to every iX 40 | X_tmp = Xs[iX+1].dot(multivariate_normal.rvs(np.zeros(ds[iX]),cov=toeplitz_cov(ds[iX],scale),size=ds[iX+1])) 41 | t_tmp = correl*zscore(X_tmp.dot(w)) 42 | Xs[iX] += correl*X_tmp 43 | y += alpha[iX]*t_tmp 44 | else: # increase contribution of last feature space to keep alphas meaningful 45 | y += alpha[iX]*t*correl 46 | Xs[iX] *= (1+ correl) 47 | y = zscore(y) 48 | ns = noise*np.random.randn(y.shape[0], y.shape[1]) 49 | y_orig = y 50 | y+= ns 51 | var_X = [R2(alpha[i]*ts[i],y) for i in range(len(Xs))] 52 | return y, var_X, Xs 53 | 54 | def sample_all_at_once(n,ds,scale,correl, alpha,data_dim,y_noise=0): 55 | Xs = feat_sample(n,ds,scale) 56 | y,var_X, Xs = data_sample(Xs,correl,ds,scale,alpha,data_dim,y_noise) 57 | return Xs, y, var_X 58 | 59 | # Experiment Functions 60 | 61 | import time 62 | 63 | def synexp(runs,sim_type, samples_settings,ds_settings,y_dim,alpha_settings,correl = 0,scale = 1, 64 | y_noise_settings=0): 65 | 66 | Results = pd.DataFrame() 67 | 68 | start = time.time() 69 | for run in range(runs): 70 | print('iteration number {}'.format(run+1)) 71 | if sim_type == 'Feat_Dim_ratio': # Vary the dimensionality of X1 
with respect to other feature spaces 72 | for ds in ds_settings: 73 | df = run_one_simulation(samples_settings,ds,scale,correl,alpha_settings,y_dim, 74 | y_noise_settings) 75 | df['Feat_Dim_ratio'] = ds[0] 76 | Results = pd.concat([Results,df],ignore_index=True) 77 | elif sim_type == 'Cond': # Vary the weight of X1 with respect to other feature spaces 78 | for alpha in alpha_settings: 79 | df = run_one_simulation(samples_settings,ds_settings,scale,correl,alpha,y_dim, 80 | y_noise_settings) 81 | df['Cond'] = alpha[0] 82 | Results = pd.concat([Results,df],ignore_index=True) 83 | elif sim_type == 'Sample_Dim_ratio': # Vary the number of samples 84 | for samples in samples_settings: 85 | df = run_one_simulation(samples,ds_settings,scale,correl,alpha_settings,y_dim, 86 | y_noise_settings) 87 | df['Sample_Dim_ratio'] = samples 88 | Results = pd.concat([Results,df],ignore_index=True) 89 | elif sim_type == 'noise': # Vary the noise level 90 | for y_noise in y_noise_settings: 91 | df = run_one_simulation(samples_settings,ds_settings,scale,correl,alpha_settings,y_dim, 92 | y_noise) 93 | df['noise'] = y_noise 94 | Results = pd.concat([Results,df],ignore_index=True) 95 | elif sim_type == 'correl': # Vary the feature space correlation level 96 | for correl_v in correl: 97 | df = run_one_simulation(samples_settings,ds_settings,scale,correl_v,alpha_settings,y_dim, 98 | y_noise_settings) 99 | df['correl'] = correl_v 100 | Results = pd.concat([Results,df],ignore_index=True) 101 | 102 | if run==0: 103 | time_int = (time.time() - start) 104 | print("first iteration time: {}, total {}".format(int(time_int), int(time_int*runs))) 105 | 106 | 107 | 108 | return Results 109 | 110 | 111 | 112 | def run_one_simulation(samples,ds,scale,correl,alpha,y_dim,y_noise): 113 | Xs, y, var_X = sample_all_at_once(samples,ds,scale,correl,alpha,y_dim,y_noise) 114 | print('data sampled') 115 | time_begin = time.time() 116 | y = zscore(y) 117 | 118 | concat_X = np.hstack(Xs) 119 | 120 | result = 
stacking_CV_fmri(y,Xs, method = 'cross_val_ridge',n_folds = 4) 121 | result2 = stacking_CV_fmri(y,Xs[1:], method = 'cross_val_ridge',n_folds = 4) 122 | 123 | print('time for stacking: {}'.format(time.time()-time_begin)) 124 | 125 | df = pd.DataFrame() 126 | 127 | df['stacked'] = result[1] 128 | df['concat'] = concatenate_CV_fmri(y,Xs, method = 'cross_val_ridge',n_folds = 4)[0] 129 | df['max'] = np.max(result[0][0:2,:],axis=0) 130 | df['r2_X0'] = result[0][0] 131 | 132 | df['varpar_X0_concat'] = df['concat'] - concatenate_CV_fmri(y,Xs[1:], method = 'cross_val_ridge',n_folds = 4)[0] 133 | 134 | df['varpar_X0_stacked'] = df['stacked'] - result2[1] 135 | 136 | df['varpar_X0_real'] = var_X[0] 137 | 138 | df['weight_0'] = result[5][:,0] 139 | df['alpha_0'] = alpha[0] 140 | 141 | print('time for iteration: {}'.format(time.time()-time_begin)) 142 | 143 | 144 | return df 145 | 146 | -------------------------------------------------------------------------------- /visualize_sim.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | 3 | font = {'family' : 'sans-serif', 4 | 'weight' : 'normal', 5 | 'size' : 18} 6 | import matplotlib 7 | matplotlib.rc('font', **font) 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | 12 | def box_plot(ax, data, edge_color, fill_color, positions=None): 13 | bp = ax.boxplot(data, patch_artist=True) 14 | 15 | for element in ['boxes', 'whiskers', 'fliers', 'means', 'medians', 'caps']: 16 | plt.setp(bp[element], color=edge_color) 17 | 18 | for patch in bp['boxes']: 19 | patch.set(facecolor=fill_color) 20 | 21 | return bp 22 | 23 | 24 | def sim_plots(Results,name,settings,filename='',var_dict={}, ylim0 = [-0.1, 1], 25 | ylim1=[0,0.5],ylim2=[0,0.5]): # shift=0.2,: 26 | 27 | import math 28 | 29 | var = sorted(list(set(Results[name].tolist()))) 30 | print(var) 31 | 32 | stack_avg = [] 33 | concat_avg = [] 34 | max_avg = [] 35 | varpar_X0_concat = [] 36 | varpar_X0_stacked = [] 
37 | varpar_X0_real = [] 38 | 39 | 40 | for px in var: 41 | stack_avg.append(Results[Results[name]==px]['stacked']) 42 | concat_avg.append(Results[Results[name]==px]['concat']) 43 | max_avg.append(Results[Results[name]==px]['max']) 44 | varpar_X0_real.append(Results[Results[name]==px]['varpar_X0_real']) 45 | varpar_X0_concat.append(Results[Results[name]==px]['varpar_X0_concat']) 46 | varpar_X0_stacked.append(Results[Results[name]==px]['varpar_X0_stacked']) 47 | 48 | c_fill = 'white' 49 | pos = None 50 | 51 | fig, axs = plt.subplots(1, 2,figsize=(20,5)) 52 | 53 | 54 | varpar_X0_concat = np.array(varpar_X0_concat) 55 | varpar_X0_stacked = np.array(varpar_X0_stacked) 56 | varpar_X0_real = np.array(varpar_X0_real) 57 | 58 | import pandas as pd 59 | vec = [[np.abs(varpar_X0_stacked[k,i] - varpar_X0_real[k,i]),k,'stack'] for i in range(varpar_X0_stacked.shape[1]) 60 | for k in range(varpar_X0_stacked.shape[0])] 61 | vec_concat = [[np.abs(varpar_X0_concat[k,i] - varpar_X0_real[k,i]),k,'concat'] for i in range(varpar_X0_concat.shape[1]) 62 | for k in range(varpar_X0_concat.shape[0])] 63 | r2_results = pd.concat([pd.DataFrame(vec,columns = ['Err','setting','method']), 64 | pd.DataFrame(vec_concat,columns = ['Err','setting','method'])]) 65 | 66 | sns.boxplot(ax=axs[1], x="setting", y="Err", hue="method",data=r2_results, palette="Set2") 67 | sns.move_legend(axs[1], "upper left", bbox_to_anchor=(1, 1)) 68 | 69 | 70 | axs[1].set_ylim(ylim1) 71 | if name=='noise': 72 | axs[1].set_xlabel(r'$\sigma$') 73 | if name=='Feat_Dim_ratio': 74 | axs[1].set_xlabel(r'$d_1$') 75 | if name=='Cond': 76 | axs[1].set_xlabel(r'$\alpha_1$') 77 | if name=='Sample_Dim_ratio': 78 | axs[1].set_xlabel(r'$n$') 79 | 80 | 81 | axs[1].set_ylabel(r'Var. par. 
error for ${\bf x}_1$') 82 | axs[1].set_ylim(ylim2) 83 | axs[1].set_xticks(range(len(settings))) 84 | axs[1].set_xticklabels(settings,rotation=45) 85 | 86 | from scipy.stats import ttest_rel 87 | 88 | for idx,s in enumerate(settings): 89 | setting_res = r2_results[r2_results['setting']==idx] 90 | stack_res = [e for e in setting_res[setting_res['method']=='stack']['Err']] 91 | concat_res = [e for e in setting_res[setting_res['method']=='concat']['Err']] 92 | for i in np.arange(len(stack_res)): 93 | axs[1].plot([idx-0.22,idx+0.22], [stack_res[i],concat_res[i]],'-',linewidth=0.25,color='black') 94 | tstat,pval = ttest_rel( np.array(concat_res), np.array(stack_res),alternative='two-sided') 95 | if pval<0.05/20: 96 | if tstat>0: 97 | axs[1].plot(idx,0.45,'*',color = 'black',markersize=10) 98 | else: 99 | axs[1].plot(idx,0.45,'o',color = 'black',markersize=10) 100 | # blabla 101 | 102 | stack_avg = np.array(stack_avg) 103 | max_avg = np.array(max_avg) 104 | concat_avg = np.array(concat_avg) 105 | 106 | import pandas as pd 107 | vec = [[stack_avg[k,i] ,k,'stack'] for i in range(stack_avg.shape[1]) 108 | for k in range(stack_avg.shape[0])] 109 | vec_concat = [[concat_avg[k,i] ,k,'concat'] for i in range(concat_avg.shape[1]) 110 | for k in range(concat_avg.shape[0])] 111 | vec_max = [[max_avg[k,i] ,k,'max'] for i in range(max_avg.shape[1]) 112 | for k in range(max_avg.shape[0])] 113 | r2_results = pd.concat([pd.DataFrame(vec,columns = ['R2','setting','method']), 114 | pd.DataFrame(vec_concat,columns = ['R2','setting','method']), 115 | pd.DataFrame(vec_max,columns = ['R2','setting','method'])]) 116 | 117 | sns.boxplot(ax=axs[0], x="setting", y="R2", hue="method",data=r2_results, palette="hot_r") 118 | sns.move_legend(axs[0], "upper right", bbox_to_anchor=(-0.18, 1)) 119 | 120 | mean_stack = np.mean(stack_avg) 121 | if name=='noise': 122 | axs[0].set_xlabel(r'$\sigma$') 123 | if name=='Feat_Dim_ratio': 124 | axs[0].set_xlabel(r'$d_1$') 125 | if name=='Cond': 126 | 
axs[0].set_xlabel(r'$\alpha_1$') 127 | if name=='Sample_Dim_ratio': 128 | axs[0].set_xlabel(r'$n$') 129 | if name=='correl': 130 | axs[0].set_xlabel(r'$\rho$') 131 | 132 | for idx,s in enumerate(settings): 133 | for i in np.arange(stack_avg.shape[1]): 134 | axs[0].plot([idx-0.25,idx,idx+0.25], [stack_avg[idx,i],concat_avg[idx,i],max_avg[idx,i]], 135 | '-',linewidth=0.15,color='black') 136 | # bla 137 | 138 | axs[0].set_ylabel(r'$R^2$') 139 | axs[0].set_ylim(ylim0) 140 | axs[0].set_xticks(range(len(settings))) 141 | axs[0].set_xticklabels(settings,rotation=45) 142 | 143 | 144 | 145 | 146 | if name=='noise': 147 | plt.suptitle(r'Vary $\sigma:~\alpha = {},~d={},~n={}$'.format(var_dict['alphas'],var_dict['ds'],var_dict['n'])) 148 | if name=='Feat_Dim_ratio': 149 | plt.suptitle(r'Vary $d_1:~\sigma = {},~\alpha = {},~d_2+d_3+d_4={},~n={}$'.format(var_dict['sigma'], var_dict['alphas'], var_dict['d_sum'], var_dict['n'])) 150 | if name=='Cond': 151 | plt.suptitle(r'Vary $\alpha_1:~\sigma = {},~d={},~n={}$'.format(var_dict['sigma'], var_dict['ds'],var_dict['n'])) 152 | if name=='Sample_Dim_ratio': 153 | plt.suptitle(r'Vary $n:~\sigma = {},~\alpha = {},~d={}$'.format(var_dict['sigma'],var_dict['alpha'],var_dict['ds'])) 154 | if name=='correl': 155 | plt.suptitle(r'Vary $\rho:~\sigma = {},~\alpha = {},~d={}, ~n={}$'.format(var_dict['sigma'],var_dict['alpha'],var_dict['ds'], 156 | var_dict['n'])) 157 | 158 | 159 | 160 | import time 161 | timestr = time.strftime("%Y%m%d-%H%M%S") 162 | print(timestr) 163 | plt.tight_layout() 164 | plt.savefig(filename+timestr+'.jpg') 165 | plt.show() 166 | 167 | -------------------------------------------------------------------------------- /ridge_tools.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import time 4 | import numpy as np 5 | from scipy.stats import zscore 6 | from numpy.linalg import inv, svd 7 | from sklearn.model_selection import KFold 8 | from 
sklearn.linear_model import Ridge, RidgeCV 9 | 10 | 11 | def corr(X, Y, axis=0): 12 | """Compute correlation coefficient.""" 13 | return np.mean(zscore(X) * zscore(Y), axis) 14 | 15 | 16 | def R2(Pred, Real): 17 | """Compute coefficient of determination (R^2).""" 18 | SSres = np.mean((Real - Pred) ** 2, 0) 19 | SStot = np.var(Real, 0) 20 | return np.nan_to_num(1 - SSres / SStot) 21 | 22 | 23 | def fit_predict(data, features, method="plain", n_folds=10): 24 | """ 25 | Fit and predict using cross-validated Ridge regression. 26 | 27 | Args: 28 | data (numpy.ndarray): The data array. 29 | features (numpy.ndarray): The features array. 30 | method (str): The Ridge regression method. Defaults to 'plain'. 31 | n_folds (int): The number of folds for cross-validation. Defaults to 10. 32 | 33 | Returns: 34 | tuple: Tuple containing the correlation and R^2 values. 35 | """ 36 | n, v = data.shape 37 | p = features.shape[1] 38 | corrs = np.zeros((n_folds, v)) 39 | R2s = np.zeros((n_folds, v)) 40 | ind = CV_ind(n, n_folds) 41 | preds_all = np.zeros_like(data) 42 | 43 | for i in range(n_folds): 44 | train_data = np.nan_to_num(zscore(data[ind != i])) 45 | train_features = np.nan_to_num(zscore(features[ind != i])) 46 | test_data = np.nan_to_num(zscore(data[ind == i])) 47 | test_features = np.nan_to_num(zscore(features[ind == i])) 48 | weights, __ = cross_val_ridge(train_features, train_data, method=method) 49 | preds = np.dot(test_features, weights) 50 | preds_all[ind == i] = preds 51 | 52 | corrs = corr(preds_all, data) 53 | R2s = R2(preds_all, data) 54 | 55 | return corrs, R2s 56 | 57 | 58 | def CV_ind(n, n_folds): 59 | """Generate cross-validation indices.""" 60 | ind = np.zeros((n)) 61 | n_items = int(np.floor(n / n_folds)) 62 | 63 | for i in range(0, n_folds - 1): 64 | ind[i * n_items : (i + 1) * n_items] = i 65 | 66 | ind[(n_folds - 1) * n_items :] = n_folds - 1 67 | 68 | return ind 69 | 70 | 71 | def R2r(Pred, Real): 72 | """Compute square root of R^2.""" 73 | R2rs = R2(Pred, 
Real) 74 | ind_neg = R2rs < 0 75 | R2rs = np.abs(R2rs) 76 | R2rs = np.sqrt(R2rs) 77 | R2rs[ind_neg] *= -1 78 | 79 | return R2rs 80 | 81 | 82 | def ridge(X, Y, lmbda): 83 | """Compute ridge regression weights.""" 84 | return np.dot(inv(X.T.dot(X) + lmbda * np.eye(X.shape[1])), X.T.dot(Y)) 85 | 86 | 87 | def ridge_by_lambda(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])): 88 | """Compute validation errors for ridge regression with different lambda values.""" 89 | error = np.zeros((lambdas.shape[0], Y.shape[1])) 90 | for idx, lmbda in enumerate(lambdas): 91 | weights = ridge(X, Y, lmbda) 92 | error[idx] = 1 - R2(np.dot(Xval, weights), Yval) 93 | return error 94 | 95 | def ridge_sk(X, Y, lmbda): 96 | """Compute ridge regression weights using scikit-learn.""" 97 | rd = Ridge(alpha=lmbda) 98 | rd.fit(X, Y) 99 | return rd.coef_.T 100 | 101 | 102 | def ridgeCV_sk(X, Y, lmbdas): 103 | """Compute ridge regression weights using scikit-learn with cross-validation.""" 104 | rd = RidgeCV(alphas=lmbdas, solver="svd") 105 | rd.fit(X, Y) 106 | return rd.coef_.T 107 | 108 | 109 | def ridge_by_lambda_sk(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])): 110 | """Compute validation errors for ridge regression with different lambda values using scikit-learn.""" 111 | error = np.zeros((lambdas.shape[0], Y.shape[1])) 112 | for idx, lmbda in enumerate(lambdas): 113 | weights = ridge_sk(X, Y, lmbda) 114 | error[idx] = 1 - R2(np.dot(Xval, weights), Yval) 115 | return error 116 | 117 | 118 | def ridge_svd(X, Y, lmbda): 119 | """ 120 | Ridge regression using singular value decomposition (SVD). 121 | """ 122 | U, s, Vt = svd(X, full_matrices=False) 123 | d = s / (s**2 + lmbda) 124 | return np.dot(Vt, np.diag(d).dot(U.T.dot(Y))) 125 | 126 | 127 | def ridge_by_lambda_svd(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])): 128 | """ 129 | Calculate the validation error of ridge regression using SVD for different lambdas. 
130 | """ 131 | error = np.zeros((lambdas.shape[0], Y.shape[1])) 132 | U, s, Vt = svd(X, full_matrices=False) 133 | for idx, lmbda in enumerate(lambdas): 134 | d = s / (s**2 + lmbda) 135 | weights = np.dot(Vt, np.diag(d).dot(U.T.dot(Y))) 136 | error[idx] = 1 - R2(np.dot(Xval, weights), Yval) 137 | return error 138 | 139 | 140 | def cross_val_ridge( 141 | train_features, 142 | train_data, 143 | n_splits=10, 144 | lambdas=np.array([10**i for i in range(-6, 10)]), 145 | method="plain", 146 | do_plot=False, 147 | ): 148 | """ 149 | Cross validation for ridge regression. 150 | 151 | Args: 152 | train_features (array): Array of training features. 153 | train_data (array): Array of training data. 154 | lambdas (array): Array of lambda values for Ridge regression. 155 | Default is [10^i for i in range(-6, 10)]. 156 | 157 | Returns: 158 | weightMatrix (array): Array of weights for the Ridge regression. 159 | r (array): Array of regularization parameters. 160 | 161 | """ 162 | 163 | ridge_1 = { 164 | "plain": ridge_by_lambda, 165 | "svd": ridge_by_lambda_svd, 166 | "ridge_sk": ridge_by_lambda_sk, 167 | }[ 168 | method 169 | ] # loss of the regressor 170 | 171 | ridge_2 = {"plain": ridge, "svd": ridge_svd, "ridge_sk": ridge_sk,}[ 172 | method 173 | ] # solver for the weights 174 | 175 | n_voxels = train_data.shape[1] # get number of voxels from data 176 | nL = lambdas.shape[0] # get number of hyperparameter (lambdas) from setting 177 | r_cv = np.zeros((nL, train_data.shape[1])) # loss matrix 178 | 179 | kf = KFold(n_splits=n_splits) # set up dataset for cross validation 180 | start_t = time.time() # record start time 181 | for icv, (trn, val) in enumerate(kf.split(train_data)): 182 | cost = ridge_1( 183 | zscore(train_features[trn]), 184 | zscore(train_data[trn]), 185 | zscore(train_features[val]), 186 | zscore(train_data[val]), 187 | lambdas=lambdas, 188 | ) # loss of regressor 1 189 | 190 | if do_plot: 191 | import matplotlib.pyplot as plt 192 | 193 | plt.figure() 194 | 
plt.imshow(cost, aspect="auto") 195 | 196 | r_cv += cost 197 | 198 | if do_plot: # show loss 199 | plt.figure() 200 | plt.imshow(r_cv, aspect="auto", cmap="RdBu_r") 201 | 202 | argmin_lambda = np.argmin(r_cv, axis=0) # pick the best lambda 203 | weights = np.zeros( 204 | (train_features.shape[1], train_data.shape[1]) 205 | ) # initialize the weight 206 | for idx_lambda in range( 207 | lambdas.shape[0] 208 | ): # this is much faster than iterating over voxels! 209 | idx_vox = argmin_lambda == idx_lambda 210 | if np.sum(idx_vox*1.0)>0: 211 | weights[:, idx_vox] = ridge_2( 212 | train_features, train_data[:, idx_vox], lambdas[idx_lambda] 213 | ) 214 | 215 | if do_plot: # show the weights 216 | plt.figure() 217 | plt.imshow(weights, aspect="auto", cmap="RdBu_r", vmin=-0.5, vmax=0.5) 218 | 219 | return weights, np.array([lambdas[i] for i in argmin_lambda]) 220 | -------------------------------------------------------------------------------- /variance_partitioning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "2ed69d89", 6 | "metadata": {}, 7 | "source": [ 8 | "### This notebook demonstrates the variance partitioning process for stacking or concatenation. The experiment is conducted in either a forward or backward style, followed by loading the results and applying the \"95% criteria\" to determine feature attribution." 
9 | ] 10 | }, 11 | { 12 | "cell_type": "markdown", 13 | "id": "de68d00f", 14 | "metadata": {}, 15 | "source": [ 16 | "## Forward" 17 | ] 18 | }, 19 | { 20 | "cell_type": "code", 21 | "execution_count": null, 22 | "id": "b25746b2", 23 | "metadata": {}, 24 | "outputs": [], 25 | "source": [ 26 | "import pickle\n", 27 | "subj = 1\n", 28 | "r2s = []\n", 29 | "r2sr = []\n", 30 | "\n", 31 | "\n", 32 | "# Get R2 for using only the first layer (Conv1 in AlexNet) and the last layer(FC7 in AlexNet)\n", 33 | "r2 = []\n", 34 | "for i in range(1,23):\n", 35 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_r2s_{}.npy'.format(0,i)).squeeze())\n", 36 | "print(np.hstack(r2).shape)\n", 37 | "r2_1 = np.hstack(r2)[0,:]\n", 38 | "r2_7 = np.hstack(r2)[-1,:]\n", 39 | "\n", 40 | "\n", 41 | "# Loading forward experiments results\n", 42 | "for k in range(6):\n", 43 | " r2 = []\n", 44 | " for i in range(1,23):\n", 45 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_stacked_r2s_{}.npy'.format(k,i)).squeeze())\n", 46 | " r2 = np.hstack(r2)\n", 47 | " r2s.append(r2)\n", 48 | "r2s.append(r2_7) \n", 49 | "r2s = np.array(r2s)\n", 50 | "print(r2s.shape)\n", 51 | " \n", 52 | "\n", 53 | "# Calculate variance by difference of R2\n", 54 | "vp = np.zeros((r2s.shape[0],r2s.shape[1]))\n", 55 | "vp_square = np.zeros((r2s.shape[0],r2s.shape[1]))\n", 56 | "for j in range(vp.shape[0]-1):\n", 57 | " vp[j,:] = np.sqrt(r2s[j,:] - r2s[j+1,:])\n", 58 | " vp_square[j,:] = (r2s[j,:] - r2s[j+1,:])\n", 59 | " \n", 60 | "vp[-1] = np.sqrt(r2_7)\n", 61 | "vp_square[-1,:] = r2_7\n", 62 | "\n", 63 | "\n", 64 | "print(vp.shape)\n", 65 | "print(vp_square.shape)\n" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "id": "f71d0411", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "# Apply our 95% Criteria in forward style: \n", 76 | "# If removing feature from one specific layer will cause the encoding performance lower to less 
than 95% of the performance using features from all layers,\n", 77 | "# We will assign that voxel to this layer as our feature attribution result\n", 78 | "\n", 79 | "\n", 80 | "vp_sel_layer = np.zeros(r2s.shape[1])\n", 81 | "\n", 82 | "\n", 83 | "for i in range(r2s.shape[1]):\n", 84 | " if r2s[0,i]<=0:\n", 85 | " vp_sel_layer[i] = float('nan')\n", 86 | " continue\n", 87 | " if r2s[1,i]<0.95*r2s[0,i]:\n", 88 | " vp_sel_layer[i] = 1\n", 89 | " continue\n", 90 | " if r2s[2,i]<0.95*r2s[0,i]:\n", 91 | " vp_sel_layer[i] = 2\n", 92 | " continue\n", 93 | " if r2s[3,i]<0.95*r2s[0,i]:\n", 94 | " vp_sel_layer[i] = 3\n", 95 | " continue\n", 96 | " if r2s[4,i]<0.95*r2s[0,i]:\n", 97 | " vp_sel_layer[i] = 4\n", 98 | " continue\n", 99 | " if r2s[5,i]<0.95*r2s[0,i]:\n", 100 | " vp_sel_layer[i] = 5\n", 101 | " continue\n", 102 | " if r2s[6,i]<0.95*r2s[0,i]:\n", 103 | " vp_sel_layer[i] = 6\n", 104 | " continue\n", 105 | " vp_sel_layer[i] = 7\n", 106 | " \n", 107 | " \n", 108 | "print(vp_sel_layer)\n" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": null, 114 | "id": "a6396b02", 115 | "metadata": {}, 116 | "outputs": [], 117 | "source": [] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "id": "c91cc1cb", 123 | "metadata": {}, 124 | "outputs": [], 125 | "source": [] 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "id": "ff7877ef", 130 | "metadata": {}, 131 | "source": [ 132 | "## Backward" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": null, 138 | "id": "ee6bf576", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "import pickle\n", 143 | "subj = 1\n", 144 | "r2sr = []\n", 145 | "\n", 146 | "\n", 147 | "# Get R2 for using only the first layer (Conv1 in AlexNet) and the last layer(FC7 in AlexNet)\n", 148 | "r2 = []\n", 149 | "for i in range(1,23):\n", 150 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_r2s_{}.npy'.format(0,i)).squeeze())\n", 151 | 
"print(np.hstack(r2).shape)\n", 152 | "r2_1 = np.hstack(r2)[0,:]\n", 153 | "r2_7 = np.hstack(r2)[-1,:]\n", 154 | "\n", 155 | "\n", 156 | "\n", 157 | "# Loading backward experiments results\n", 158 | "for k in range(6):\n", 159 | " r2 = []\n", 160 | " for i in range(1,23):\n", 161 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_r_stacked_r2s_{}.npy'.format(k,i)).squeeze())\n", 162 | " r2 = np.hstack(r2)\n", 163 | " r2sr.append(r2)\n", 164 | "r2sr.append(r2_1) \n", 165 | "r2sr = np.array(r2sr)\n", 166 | "\n", 167 | "print(r2sr.shape)\n", 168 | " \n", 169 | "\n", 170 | "# Calculate variance by difference of R2\n", 171 | "vpr = np.zeros((r2sr.shape[0],r2sr.shape[1]))\n", 172 | "vpr_square = np.zeros((r2sr.shape[0],r2sr.shape[1]))\n", 173 | "for j in range(vpr.shape[0]-1):\n", 174 | " vpr[j,:] = np.sqrt(r2sr[j,:] - r2sr[j+1,:])\n", 175 | " vpr_square[j,:] = (r2sr[j,:] - r2sr[j+1,:])\n", 176 | "\n", 177 | "vpr[-1] = np.sqrt(r2_1)\n", 178 | "vpr_square[-1,:] = r2_1\n", 179 | "\n", 180 | "print(vpr.shape)\n", 181 | "print(vpr_square.shape)\n" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "760923a9", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "# Apply our 95% Criteria in backward style: \n", 192 | "# If removing feature from one specific layer will cause the encoding performance lower to less than 95% of the performance using features from all layers,\n", 193 | "# We will assign that voxel to this layer as our feature attribution result\n", 194 | "\n", 195 | "\n", 196 | "vpr_sel_layer = np.zeros(r2sr.shape[1])\n", 197 | "\n", 198 | "for i in range(r2sr.shape[1]):\n", 199 | " if r2sr[0,i]<=0:\n", 200 | " vpr_sel_layer[i] = float('nan')\n", 201 | " continue\n", 202 | " if r2sr[1,i]<0.95*r2sr[0,i]:\n", 203 | " vpr_sel_layer[i] = 7\n", 204 | " continue\n", 205 | " if r2sr[2,i]<0.95*r2sr[0,i]:\n", 206 | " vpr_sel_layer[i] = 6\n", 207 | " continue\n", 208 | " if 
r2sr[3,i]<0.95*r2sr[0,i]:\n", 209 | " vpr_sel_layer[i] = 5\n", 210 | " continue\n", 211 | " if r2sr[4,i]<0.95*r2sr[0,i]:\n", 212 | " vpr_sel_layer[i] = 4\n", 213 | " continue\n", 214 | " if r2sr[5,i]<0.95*r2sr[0,i]:\n", 215 | " vpr_sel_layer[i] = 3\n", 216 | " continue\n", 217 | " if r2sr[6,i]<0.95*r2sr[0,i]:\n", 218 | " vpr_sel_layer[i] = 2\n", 219 | " continue\n", 220 | " vpr_sel_layer[i] = 1\n", 221 | " \n", 222 | "print(vpr_sel_layer)\n" 223 | ] 224 | } 225 | ], 226 | "metadata": { 227 | "kernelspec": { 228 | "display_name": "Python 3 (ipykernel)", 229 | "language": "python", 230 | "name": "python3" 231 | }, 232 | "language_info": { 233 | "codemirror_mode": { 234 | "name": "ipython", 235 | "version": 3 236 | }, 237 | "file_extension": ".py", 238 | "mimetype": "text/x-python", 239 | "name": "python", 240 | "nbconvert_exporter": "python", 241 | "pygments_lexer": "ipython3", 242 | "version": "3.8.13" 243 | } 244 | }, 245 | "nbformat": 4, 246 | "nbformat_minor": 5 247 | } 248 | -------------------------------------------------------------------------------- /simulate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "7e439802", 6 | "metadata": {}, 7 | "source": [ 8 | "# simulate stacking experiment" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "f2850c2e-1883-4126-8140-901d5898eb7e", 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import numpy as np\n", 19 | "from simulation_experiment import synexp\n", 20 | "from visualize_sim import sim_plots\n", 21 | "import pandas as pd\n", 22 | "\n", 23 | "n_runs = 50 # how many simulation experiments\n", 24 | "run_sim = True # run experiment?\n", 25 | "plot_sim = True # plot results?\n", 26 | "\n", 27 | "type_sim = 1 # what parameter to vary?\n", 28 | "# 1 - Vary the dimensionality of X1 with respect to other feature spaces\n", 29 | "# 2 - Vary the weight of X1 with 
respect to other feature spaces\n", 30 | "# 3 - Vary the noise level\n", 31 | "# 4 - Vary the number of samples\n", 32 | "# 5 - Vary feature space correlation\n", 33 | "\n", 34 | "# Note, the experiments below are from the paper, you can change the parameters as you want\n", 35 | "# It might be more feasible to make a script and call synexp non-interactively, especially for large experiments" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": null, 41 | "id": "b8faefb0-2b21-46b2-9d2b-6b03b03b5251", 42 | "metadata": {}, 43 | "outputs": [ 44 | { 45 | "name": "stdout", 46 | "output_type": "stream", 47 | "text": [ 48 | "iteration number 1\n", 49 | "data sampled\n", 50 | "time for stacking: 5.576988935470581\n", 51 | "time for iteration: 16.36051297187805\n", 52 | "data sampled\n", 53 | "time for stacking: 5.68248987197876\n", 54 | "time for iteration: 17.94617486000061\n", 55 | "data sampled\n", 56 | "time for stacking: 10.574278831481934\n", 57 | "time for iteration: 24.7062509059906\n", 58 | "data sampled\n", 59 | "time for stacking: 7.50465989112854\n", 60 | "time for iteration: 21.93230676651001\n", 61 | "data sampled\n", 62 | "time for stacking: 8.925495147705078\n", 63 | "time for iteration: 29.840813159942627\n", 64 | "first iteration time: 111, total 5556\n", 65 | "iteration number 2\n", 66 | "data sampled\n", 67 | "time for stacking: 13.184641122817993\n", 68 | "time for iteration: 35.82826232910156\n", 69 | "data sampled\n", 70 | "time for stacking: 15.410903215408325\n", 71 | "time for iteration: 36.77866816520691\n", 72 | "data sampled\n", 73 | "time for stacking: 12.556435108184814\n", 74 | "time for iteration: 30.479616165161133\n", 75 | "data sampled\n", 76 | "time for stacking: 12.968762159347534\n", 77 | "time for iteration: 31.550554275512695\n", 78 | "data sampled\n", 79 | "time for stacking: 9.648847818374634\n", 80 | "time for iteration: 26.950864791870117\n", 81 | "iteration number 3\n", 82 | "data sampled\n", 83 | "time for 
stacking: 9.796719074249268\n", 84 | "time for iteration: 26.401182889938354\n", 85 | "data sampled\n", 86 | "time for stacking: 8.374092102050781\n", 87 | "time for iteration: 23.136788845062256\n", 88 | "data sampled\n", 89 | "time for stacking: 8.078734874725342\n", 90 | "time for iteration: 30.22462296485901\n", 91 | "data sampled\n", 92 | "time for stacking: 13.18555498123169\n", 93 | "time for iteration: 34.938925981521606\n", 94 | "data sampled\n", 95 | "time for stacking: 13.350465059280396\n", 96 | "time for iteration: 33.98402714729309\n", 97 | "iteration number 4\n", 98 | "data sampled\n", 99 | "time for stacking: 11.148561239242554\n", 100 | "time for iteration: 29.419289112091064\n", 101 | "data sampled\n", 102 | "time for stacking: 10.736790180206299\n", 103 | "time for iteration: 30.648571968078613\n", 104 | "data sampled\n", 105 | "time for stacking: 12.025328159332275\n", 106 | "time for iteration: 28.129180908203125\n", 107 | "data sampled\n", 108 | "time for stacking: 9.363476276397705\n", 109 | "time for iteration: 22.97734808921814\n", 110 | "data sampled\n", 111 | "time for stacking: 7.306324005126953\n", 112 | "time for iteration: 20.82240104675293\n", 113 | "iteration number 5\n", 114 | "data sampled\n", 115 | "time for stacking: 7.735720157623291\n", 116 | "time for iteration: 24.28347420692444\n", 117 | "data sampled\n", 118 | "time for stacking: 8.583776950836182\n", 119 | "time for iteration: 23.722854137420654\n", 120 | "data sampled\n", 121 | "time for stacking: 8.153543710708618\n", 122 | "time for iteration: 24.09500002861023\n", 123 | "data sampled\n", 124 | "time for stacking: 12.154590129852295\n" 125 | ] 126 | } 127 | ], 128 | "source": [ 129 | "import random\n", 130 | "random.seed(10)\n", 131 | "\n", 132 | "type_sim = 5 # what parameter to vary?\n", 133 | "\n", 134 | "\n", 135 | "version = 'v5'\n", 136 | "\n", 137 | "# ds_settings - dimentionality of each feature space\n", 138 | "# alpha - stacking weights for simulated data\n", 
139 | "# n - number of samples \n", 140 | "# sigma - noise level\n", 141 | "\n", 142 | "if type_sim ==1: # Vary the dimensionality of X1 with respect to other feature spaces\n", 143 | "\n", 144 | " ds_settings = [[2,10,10,10],[5,10,10,10],[10,10,10,10],[20,10,10,10],[40,10,10,10]]\n", 145 | " n = 100\n", 146 | " sigma = 0.2\n", 147 | " correl = 0.2\n", 148 | " alpha = [0.5,0.2,0.2,0.1] \n", 149 | "\n", 150 | " print('setup')\n", 151 | "\n", 152 | " if run_sim:\n", 153 | " print('running')\n", 154 | " R_d3 = synexp(runs = n_runs,sim_type = 'Feat_Dim_ratio',samples_settings=n,ds_settings=ds_settings,y_dim=2,\n", 155 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n", 156 | " \n", 157 | " R_d3.to_pickle('sweep_d_cluster_{}.npy'.format(version))\n", 158 | " \n", 159 | " if plot_sim:\n", 160 | " R_d3 = pd.read_pickle('sweep_d_cluster_{}.npy'.format(version))\n", 161 | " var = dict(sigma = sigma, d_sum = sum(ds_settings[0][1:]), n = n,alphas = alpha)\n", 162 | " ratio = [d[0] for d in ds_settings]\n", 163 | " sim_plots(R_d3,'Feat_Dim_ratio', ratio, filename = 'sweep_d',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n", 164 | " ylim2=[0,0.5])\n", 165 | "\n", 166 | "elif type_sim ==2: # Vary the weight of X1 with respect to other feature spaces\n", 167 | "\n", 168 | " ds = [10,10,10,10]\n", 169 | " n = 100\n", 170 | " correl = 0.2\n", 171 | " sigma = 0.2\n", 172 | " alpha = [[0.1,0.2,0.5,0.2],[0.3,0.2,0.3,0.2],[0.5,0.2,0.2,0.1],[0.7,0.1,0,0.2],[0.9,0.1,0.,0.]] \n", 173 | " for a in alpha:\n", 174 | " assert np.round(sum(a),2)==1\n", 175 | "\n", 176 | " if run_sim:\n", 177 | " R_C3 = synexp(runs = n_runs,sim_type = 'Cond',samples_settings=n,ds_settings=ds,y_dim=2,\n", 178 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n", 179 | " \n", 180 | " R_C3.to_pickle('sweep_c_cluster_{}.npy'.format(version))\n", 181 | " \n", 182 | " if plot_sim:\n", 183 | " R_C3 = 
pd.read_pickle('sweep_c_cluster_{}.npy'.format(version))\n", 184 | " var = dict(sigma = sigma, ds = ds, n = n,alphas = alpha)\n", 185 | " \n", 186 | " ratio = [a[0] for a in alpha]\n", 187 | " sim_plots(R_C3,'Cond', ratio, filename = 'sweep_alpha',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n", 188 | " ylim2=[0,0.5])\n", 189 | "\n", 190 | "\n", 191 | "elif type_sim ==3: # Vary the noise level\n", 192 | "\n", 193 | " ds = [10,10,10,10]\n", 194 | " n = 100\n", 195 | " sigma = [0,0.2,0.5,1,1.5]\n", 196 | " alpha = [0.5,0.2,0.2,0.1]\n", 197 | " correl = 0.2\n", 198 | " assert np.round(sum(alpha),2)==1\n", 199 | "\n", 200 | " if run_sim:\n", 201 | " R_sigma3 =synexp(runs = n_runs,sim_type = 'noise',samples_settings=n,ds_settings=ds,y_dim=2,\n", 202 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n", 203 | " \n", 204 | " R_sigma3.to_pickle('sweep_sigma_cluster_{}.npy'.format(version))\n", 205 | " \n", 206 | " if plot_sim:\n", 207 | " R_sigma3 = pd.read_pickle('sweep_sigma_cluster_{}.npy'.format(version))\n", 208 | " var = dict( ds = ds, alphas = alpha,n = n)\n", 209 | " sim_plots(R_sigma3,'noise', sigma, filename = 'sweep_sigma',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n", 210 | " ylim2=[0,0.5])\n", 211 | "\n", 212 | "elif type_sim ==4: # Vary the number of samples\n", 213 | " \n", 214 | " # ds = [10,10,10,10]\n", 215 | " ds = [100,100,100,100]\n", 216 | " # n = [30,40,60,100,200]\n", 217 | " n = [100,200,400, 800]\n", 218 | " sigma = 0.2\n", 219 | " alpha = [0.5,0.2,0.2,0.1]\n", 220 | " correl = 0.2\n", 221 | " assert np.round(sum(alpha),2)==1\n", 222 | "\n", 223 | " if run_sim:\n", 224 | " R_n3 = synexp(runs = n_runs,sim_type = 'Sample_Dim_ratio',samples_settings=n,ds_settings=ds,y_dim=2,\n", 225 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n", 226 | " \n", 227 | " R_n3.to_pickle('sweep_n_cluster_{}.npy'.format(version))\n", 228 | " \n", 229 | " if plot_sim:\n", 230 | " R_n3 = 
pd.read_pickle('sweep_n_cluster_{}.npy'.format(version))\n", 231 | " \n", 232 | " var = dict(sigma = sigma, ds = ds, alpha = alpha)\n", 233 | " sim_plots(R_n3,'Sample_Dim_ratio', n, filename = 'sweep_n',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n", 234 | " ylim2=[0,0.5])\n", 235 | "\n", 236 | " \n", 237 | "elif type_sim ==5: # Vary the correlation\n", 238 | " \n", 239 | " # ds = [10,10,10,10] \n", 240 | " ds = [100,100,100,100] \n", 241 | " # n = 100\n", 242 | " n = 400\n", 243 | " sigma = 0.2\n", 244 | " alpha = [0.5,0.2,0.2,0.1]\n", 245 | " correl = [0,0.1,0.2, 0.5, 0.8]\n", 246 | " assert np.round(sum(alpha),2)==1\n", 247 | "\n", 248 | " if run_sim:\n", 249 | " R_rho = synexp(runs = n_runs,sim_type = 'correl',samples_settings=n,ds_settings=ds,y_dim=2,\n", 250 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n", 251 | " \n", 252 | " R_rho.to_pickle('sweep_rho_cluster_{}.npy'.format(version))\n", 253 | " \n", 254 | " if plot_sim:\n", 255 | " R_rho = pd.read_pickle('sweep_rho_cluster_{}.npy'.format(version))\n", 256 | " \n", 257 | " var = dict(sigma = sigma, ds = ds, alpha = alpha, n = n)\n", 258 | " sim_plots(R_rho,'correl', correl, filename = 'sweep_rho',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n", 259 | " ylim2=[0,0.5])" 260 | ] 261 | }, 262 | { 263 | "cell_type": "code", 264 | "execution_count": null, 265 | "id": "a58dd81b-8f9d-443e-8302-4c6c1bad5f04", 266 | "metadata": {}, 267 | "outputs": [], 268 | "source": [] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": null, 273 | "id": "8dba0b9e-57f0-47fc-bc8e-cf741a92a2a2", 274 | "metadata": {}, 275 | "outputs": [], 276 | "source": [] 277 | } 278 | ], 279 | "metadata": { 280 | "kernelspec": { 281 | "display_name": "Python 3 (ipykernel)", 282 | "language": "python", 283 | "name": "python3" 284 | }, 285 | "language_info": { 286 | "codemirror_mode": { 287 | "name": "ipython", 288 | "version": 3 289 | }, 290 | "file_extension": ".py", 291 | "mimetype": 
"text/x-python", 292 | "name": "python", 293 | "nbconvert_exporter": "python", 294 | "pygments_lexer": "ipython3", 295 | "version": "3.9.21" 296 | }, 297 | "toc": { 298 | "base_numbering": 1, 299 | "nav_menu": {}, 300 | "number_sections": false, 301 | "sideBar": true, 302 | "skip_h1_title": true, 303 | "title_cell": "Table of Contents", 304 | "title_sidebar": "Contents", 305 | "toc_cell": false, 306 | "toc_position": {}, 307 | "toc_section_display": true, 308 | "toc_window_display": false 309 | } 310 | }, 311 | "nbformat": 4, 312 | "nbformat_minor": 5 313 | } 314 | -------------------------------------------------------------------------------- /stacking.py: -------------------------------------------------------------------------------- 1 | from cvxopt import matrix, solvers 2 | import numpy as np 3 | from scipy.stats import zscore 4 | from ridge_tools import cross_val_ridge, R2, ridge 5 | 6 | # Set option to not show progress in CVXOPT solver 7 | solvers.options["show_progress"] = False 8 | 9 | 10 | def get_cv_indices(n_samples, n_folds): 11 | """Generate cross-validation indices. 12 | 13 | Args: 14 | n_samples (int): Number of samples to generate indices for. 15 | n_folds (int): Number of folds to use in cross-validation. 16 | 17 | Returns: 18 | numpy.ndarray: Array of cross-validation indices with shape (n_samples,). 19 | """ 20 | cv_indices = np.zeros((n_samples)) 21 | n_items = int(np.floor(n_samples / n_folds)) # number of items in one fold 22 | for i in range(0, n_folds - 1): 23 | cv_indices[i * n_items : (i + 1) * n_items] = i 24 | cv_indices[(n_folds - 1) * n_items :] = n_folds - 1 25 | return cv_indices 26 | 27 | 28 | def feat_ridge_CV( 29 | train_features, 30 | train_targets, 31 | test_features, 32 | method="cross_val_ridge", 33 | n_folds=5, 34 | score_function=R2, 35 | ): 36 | """Train a ridge regression model with cross-validation and predict on test_features. 
37 | 38 | Args: 39 | train_features (numpy.ndarray): Array of shape (n_samples, n_features) containing the training features. 40 | train_targets (numpy.ndarray): Array of shape (n_samples, n_targets) containing the training targets. 41 | test_features (numpy.ndarray): Array of shape (n_test_samples, n_features) containing the test features. 42 | method (str): Method to use for ridge regression. Options are "simple_ridge" and "cross_val_ridge". 43 | Defaults to "cross_val_ridge". 44 | n_folds (int): Number of folds to use in cross-validation. Defaults to 5. 45 | score_function (callable): Scoring function to use for cross-validation. Defaults to R2. 46 | 47 | Returns: 48 | tuple: Tuple containing: 49 | - preds_train (numpy.ndarray): Array of shape (n_samples, n_targets) containing the training set predictions. 50 | - err (numpy.ndarray): Array of shape (n_samples, n_targets) containing the training set errors. 51 | - preds_test (numpy.ndarray): Array of shape (n_test_samples, n_targets) containing the test set predictions. 52 | - r2s_train_fold (numpy.ndarray): Array of shape (n_folds,) containing the cross-validation scores. 53 | - var_train_fold (numpy.ndarray): Array of shape (n_targets,) containing the variances of the training set predictions. 
54 | """ 55 | 56 | if np.all(train_features == 0): 57 | # If there are no predictors, return zero weights and zero predictions 58 | weights = np.zeros((train_features.shape[1], train_targets.shape[1])) 59 | train_preds = np.zeros_like(train_targets) 60 | else: 61 | # Use cross-validation to train the model 62 | cv_indices = get_cv_indices(train_targets.shape[0], n_folds=n_folds) 63 | train_preds = np.zeros_like(train_targets) 64 | 65 | for i_cv in range(n_folds): 66 | train_targets_cv = np.nan_to_num(zscore(train_targets[cv_indices != i_cv])) 67 | train_features_cv = np.nan_to_num( 68 | zscore(train_features[cv_indices != i_cv]) 69 | ) 70 | test_features_cv = np.nan_to_num(zscore(train_features[cv_indices == i_cv])) 71 | 72 | if method == "simple_ridge": 73 | # Use a fixed regularization parameter to train the model 74 | weights = ridge(train_features, train_targets, 100) 75 | elif method == "cross_val_ridge": 76 | # Use cross-validation to select the best regularization parameter 77 | lambdas = np.array([10**i for i in range(-6, 10)]) 78 | if train_features.shape[1] > train_features.shape[0]: 79 | weights, __ = cross_val_ridge( 80 | train_features_cv, 81 | train_targets_cv, 82 | n_splits=5, 83 | lambdas=lambdas, 84 | do_plot=False, 85 | method="plain", 86 | ) 87 | else: 88 | weights, __ = cross_val_ridge( 89 | train_features_cv, 90 | train_targets_cv, 91 | n_splits=5, 92 | lambdas=lambdas, 93 | do_plot=False, 94 | method="plain", 95 | ) 96 | 97 | # Make predictions on the current fold of the data 98 | train_preds[cv_indices == i_cv] = test_features_cv.dot(weights) 99 | 100 | # Calculate prediction error on the training set 101 | train_err = train_targets - train_preds 102 | 103 | # Retrain the model on all of the training data 104 | lambdas = np.array([10**i for i in range(-6, 10)]) 105 | weights, __ = cross_val_ridge( 106 | train_features, 107 | train_targets, 108 | n_splits=5, 109 | lambdas=lambdas, 110 | do_plot=False, 111 | method="plain", 112 | ) 113 | 114 | 
# Make predictions on the test set using the retrained model 115 | test_preds = np.dot(test_features, weights) 116 | 117 | # Calculate the score on the training set 118 | train_scores = score_function(train_preds, train_targets) 119 | train_variances = np.var(train_preds, axis=0) 120 | 121 | return train_preds, train_err, test_preds, train_scores, train_variances 122 | 123 | 124 | import numpy as np 125 | from cvxopt import matrix, solvers 126 | 127 | 128 | def stacking_fmri( 129 | train_data, 130 | test_data, 131 | train_features, 132 | test_features, 133 | method="cross_val_ridge", 134 | score_f=R2, 135 | ): 136 | """ 137 | Stacks predictions from different feature spaces and uses them to make final predictions. 138 | 139 | Args: 140 | train_data (ndarray): Training data of shape (n_time_train, n_voxels) 141 | test_data (ndarray): Testing data of shape (n_time_test, n_voxels) 142 | train_features (list): List of training feature spaces, each of shape (n_time_train, n_dims) 143 | test_features (list): List of testing feature spaces, each of shape (n_time_test, n_dims) 144 | method (str): Name of the method used for training. Default is 'cross_val_ridge'. 145 | score_f (callable): Scikit-learn scoring function to use for evaluation. Default is mean_squared_error. 
146 | 147 | Returns: 148 | Tuple of ndarrays: 149 | - r2s: Array of shape (n_features, n_voxels) containing unweighted R2 scores for each feature space and voxel 150 | - stacked_r2s: Array of shape (n_voxels,) containing R2 scores for the stacked predictions of each voxel 151 | - r2s_weighted: Array of shape (n_features, n_voxels) containing R2 scores for each feature space weighted by stacking weights 152 | - r2s_train: Array of shape (n_features, n_voxels) containing R2 scores for each feature space and voxel in the training set 153 | - stacked_train_r2s: Array of shape (n_voxels,) containing R2 scores for the stacked predictions of each voxel in the training set 154 | - S: Array of shape (n_voxels, n_features) containing the stacking weights for each voxel 155 | """ 156 | 157 | # Number of time points in the test set 158 | n_time_test = test_data.shape[0] 159 | 160 | # Check that the number of voxels is the same in the training and test sets 161 | assert train_data.shape[1] == test_data.shape[1] 162 | n_voxels = train_data.shape[1] 163 | 164 | # Check that the number of feature spaces is the same in the training and test sets 165 | assert len(train_features) == len(test_features) 166 | n_features = len(train_features) 167 | 168 | # Array to store R2 scores for each feature space and voxel 169 | r2s = np.zeros((n_features, n_voxels)) 170 | # Array to store R2 scores for each feature space and voxel in the training set 171 | r2s_train = np.zeros((n_features, n_voxels)) 172 | # Array to store variance explained by the model for each feature space and voxel in the training set 173 | var_train = np.zeros((n_features, n_voxels)) 174 | # Array to store R2 scores for each feature space weighted by stacking weights 175 | r2s_weighted = np.zeros((n_features, n_voxels)) 176 | 177 | # Array to store stacked predictions for each voxel 178 | stacked_pred = np.zeros((n_time_test, n_voxels)) 179 | # Dictionary to store predictions for each feature space and voxel in the 
training set 180 | preds_train = {} 181 | # Dictionary to store predictions for each feature space and voxel in the test set 182 | preds_test = np.zeros((n_features, n_time_test, n_voxels)) 183 | # Array to store weighted predictions for each feature space and voxel in the test set 184 | weighted_pred = np.zeros((n_features, n_time_test, n_voxels)) 185 | 186 | # normalize data by TRAIN/TEST 187 | train_data = np.nan_to_num(zscore(train_data)) 188 | test_data = np.nan_to_num(zscore(test_data)) 189 | 190 | train_features = [np.nan_to_num(zscore(F)) for F in train_features] 191 | test_features = [np.nan_to_num(zscore(F)) for F in test_features] 192 | 193 | # initialize an error dictionary to store errors for each feature 194 | err = dict() 195 | preds_train = dict() 196 | 197 | # iterate over each feature and train a model using feature ridge regression 198 | for FEATURE in range(n_features): 199 | ( 200 | preds_train[FEATURE], 201 | error, 202 | preds_test[FEATURE, :, :], 203 | r2s_train[FEATURE, :], 204 | var_train[FEATURE, :], 205 | ) = feat_ridge_CV( 206 | train_features[FEATURE], train_data, test_features[FEATURE], method=method 207 | ) 208 | err[FEATURE] = error 209 | 210 | # calculate error matrix for stacking 211 | P = np.zeros((n_voxels, n_features, n_features)) 212 | for i in range(n_features): 213 | for j in range(n_features): 214 | P[:, i, j] = np.mean(err[i] * err[j], 0) 215 | 216 | # solve the quadratic programming problem to obtain the weights for stacking 217 | q = matrix(np.zeros((n_features))) 218 | G = matrix(-np.eye(n_features, n_features)) 219 | h = matrix(np.zeros(n_features)) 220 | A = matrix(np.ones((1, n_features))) 221 | b = matrix(np.ones(1)) 222 | 223 | S = np.zeros((n_voxels, n_features)) 224 | stacked_pred_train = np.zeros_like(train_data) 225 | 226 | for i in range(0, n_voxels): 227 | PP = matrix(P[i]) 228 | # solve for stacking weights for every voxel 229 | S[i, :] = np.array(solvers.qp(PP, q, G, h, A, b)["x"]).reshape(n_features) 230 | 
231 | # combine the predictions from the individual feature spaces for voxel i 232 | z_test = np.array( 233 | [preds_test[feature_j, :, i] for feature_j in range(n_features)] 234 | ) 235 | z_train = np.array( 236 | [preds_train[feature_j][:, i] for feature_j in range(n_features)] 237 | ) 238 | # multiply the predictions by S[i,:] 239 | stacked_pred[:, i] = np.dot(S[i, :], z_test) 240 | # combine the training predictions from the individual feature spaces for voxel i 241 | stacked_pred_train[:, i] = np.dot(S[i, :], z_train) 242 | 243 | # compute the R2 score for the stacked predictions on the training data 244 | stacked_train_r2s = score_f(stacked_pred_train, train_data) 245 | 246 | # compute the R2 scores for each individual feature and the weighted feature predictions 247 | for FEATURE in range(n_features): 248 | # weight the predictions according to S: 249 | # weighted single feature space predictions, computed over a fold 250 | weighted_pred[FEATURE, :] = preds_test[FEATURE, :] * S[:, FEATURE] 251 | 252 | for FEATURE in range(n_features): 253 | r2s[FEATURE, :] = score_f(preds_test[FEATURE], test_data) 254 | r2s_weighted[FEATURE, :] = score_f(weighted_pred[FEATURE], test_data) 255 | 256 | # compute the R2 score for the stacked predictions on the test data 257 | stacked_r2s = score_f(stacked_pred, test_data) 258 | 259 | # return the results 260 | return ( 261 | r2s, 262 | stacked_r2s, 263 | r2s_weighted, 264 | r2s_train, 265 | stacked_train_r2s, 266 | S, 267 | ) 268 | 269 | 270 | def stacking_CV_fmri(data, features, method="cross_val_ridge", n_folds=5, score_f=R2): 271 | """ 272 | Cross-validated feature stacking to predict fMRI signal from several feature spaces. 273 | In each fold, one model per feature space is fit with feat_ridge_CV; their predictions are then combined per voxel with non-negative weights that sum to 1, found by a quadratic program on the errors feat_ridge_CV returns for the training split. 274 | 275 | Args: 276 | - data (ndarray): A matrix of fMRI signal data with dimensions n_time x n_voxels. 277 | - features (list): A list of length n_features containing arrays of predictors 278 | with dimensions n_time x n_dim.
279 | - method (str): A string indicating the method to use to train the model. Default is "cross_val_ridge". Passed through to feat_ridge_CV. 280 | - n_folds (int): An integer indicating the number of cross-validation folds to use. Default is 5. 281 | - score_f (function): A function to use for scoring the model. Default is R2. Called as score_f(predictions, targets); expected to return one score per voxel (column). 282 | 283 | Returns: 284 | - A tuple, in exactly this order: 285 | - r2s (ndarray): shape (n_features, n_voxels); score of each single feature space's out-of-fold predictions. 286 | - stacked_r2s (ndarray): per-voxel score of the stacked out-of-fold predictions (not a single float). 287 | - r2s_weighted (ndarray): shape (n_features, n_voxels); score of each feature space's predictions 288 | after multiplying them by that fold's stacking weight. 289 | - r2s_train (ndarray): shape (n_features, n_voxels); training-set scores reported by feat_ridge_CV, 290 | averaged over folds. 291 | - stacked_train (ndarray): per-voxel score of the stacked training predictions, averaged over folds. 292 | - S_average (ndarray): shape (n_voxels, n_features); stacking weights averaged over folds. 293 | Each fold's weights are >= 0 and sum to 1 for every voxel, so the average does too. 294 | Test scores compare the pooled out-of-fold predictions against zscore(data).
295 | 296 | """ 297 | 298 | n_time, n_voxels = data.shape 299 | n_features = len(features) 300 | 301 | ind = get_cv_indices(n_time, n_folds=n_folds)  # fold label per time point; ind == k selects fold k 302 | 303 | # create arrays to store results (out-of-fold predictions are written into the test_ind rows of each fold) 304 | r2s = np.zeros((n_features, n_voxels)) 305 | r2s_train_folds = np.zeros((n_folds, n_features, n_voxels)) 306 | var_train_folds = np.zeros((n_folds, n_features, n_voxels))  # filled from feat_ridge_CV's 5th output but never read afterwards 307 | r2s_weighted = np.zeros((n_features, n_voxels)) 308 | stacked_train_r2s_fold = np.zeros((n_folds, n_voxels)) 309 | stacked_pred = np.zeros((n_time, n_voxels)) 310 | preds_test = np.zeros((n_features, n_time, n_voxels)) 311 | weighted_pred = np.zeros((n_features, n_time, n_voxels)) 312 | S_average = np.zeros((n_voxels, n_features)) 313 | 314 | # perform cross-validation by fold 315 | for ind_num in range(n_folds): 316 | # split data into training and testing sets 317 | train_ind = ind != ind_num 318 | test_ind = ind == ind_num 319 | train_data = data[train_ind] 320 | train_features = [F[train_ind] for F in features] 321 | test_data = data[test_ind] 322 | test_features = [F[test_ind] for F in features] 323 | 324 | # normalize data: each split is z-scored with its own statistics (test is not scaled with train statistics); nan_to_num turns NaNs (e.g. from zero-variance columns) into 0 325 | train_data = np.nan_to_num(zscore(train_data)) 326 | test_data = np.nan_to_num(zscore(test_data)) 327 | 328 | train_features = [np.nan_to_num(zscore(F)) for F in train_features] 329 | test_features = [np.nan_to_num(zscore(F)) for F in test_features] 330 | 331 | # Store prediction errors and training predictions for each feature (error = 2nd output of feat_ridge_CV, used below to build the stacking QP) 332 | err = dict() 333 | preds_train = dict() 334 | for FEATURE in range(n_features): 335 | ( 336 | preds_train[FEATURE], 337 | error, 338 | preds_test[FEATURE, test_ind], 339 | r2s_train_folds[ind_num, FEATURE, :], 340 | var_train_folds[ind_num, FEATURE, :], 341 | ) = feat_ridge_CV( 342 | train_features[FEATURE], 343 | train_data, 344 | test_features[FEATURE], 345 | method=method, 346 | ) 347 | err[FEATURE] = error 348 | 349 | # calculate error matrix for stacking: P[v, i, j] = mean over time of err[i] * err[j] for voxel v, so s' P[v] s is the mean squared value of the s-weighted combination of errors 350 | P = np.zeros((n_voxels, n_features, n_features)) 351 | for i in
range(n_features): 352 | for j in range(n_features): 353 | P[:, i, j] = np.mean(err[i] * err[j], axis=0) 354 | 355 | # Set optimization parameters for computing stacking weights. Assuming matrix/solvers are cvxopt's, solvers.qp minimizes 1/2 x'Px + q'x subject to Gx <= h and Ax = b; with q = 0, G = -I, h = 0, A = ones, b = 1 this yields weights s >= 0 with sum(s) == 1 that minimize s' P[v] s 356 | q = matrix(np.zeros((n_features))) 357 | G = matrix(-np.eye(n_features, n_features)) 358 | h = matrix(np.zeros(n_features)) 359 | A = matrix(np.ones((1, n_features))) 360 | b = matrix(np.ones(1)) 361 | 362 | S = np.zeros((n_voxels, n_features)) 363 | stacked_pred_train = np.zeros_like(train_data) 364 | 365 | # Compute stacking weights and combined predictions for each voxel 366 | for i in range(n_voxels): 367 | PP = matrix(P[i]) 368 | # solve for stacking weights for every voxel (one QP per voxel, using that voxel's error matrix P[i]) 369 | S[i, :] = np.array(solvers.qp(PP, q, G, h, A, b)["x"]).reshape( 370 | n_features, 371 | ) 372 | # combine the predictions from the individual feature spaces for voxel i 373 | z = np.array( 374 | [preds_test[feature_j, test_ind, i] for feature_j in range(n_features)] 375 | ) 376 | # multiply the predictions by S[i,:] 377 | stacked_pred[test_ind, i] = np.dot(S[i, :], z) 378 | # combine the training predictions from the individual feature spaces for voxel i 379 | z = np.array( 380 | [preds_train[feature_j][:, i] for feature_j in range(n_features)] 381 | ) 382 | stacked_pred_train[:, i] = np.dot(S[i, :], z) 383 | 384 | S_average += S  # summed over folds; divided by n_folds after the loop 385 | stacked_train_r2s_fold[ind_num, :] = score_f(stacked_pred_train, train_data) 386 | 387 | # Compute weighted single feature space predictions, computed over a fold (S[:, FEATURE] holds one weight per voxel and broadcasts over time) 388 | for FEATURE in range(n_features): 389 | weighted_pred[FEATURE, test_ind] = ( 390 | preds_test[FEATURE, test_ind] * S[:, FEATURE] 391 | ) 392 | 393 | # Compute overall performance metrics on the pooled out-of-fold predictions 394 | data_zscored = zscore(data)  # whole-dataset z-score; unlike the per-fold arrays, not passed through nan_to_num 395 | for FEATURE in range(n_features): 396 | r2s[FEATURE, :] = score_f(preds_test[FEATURE], data_zscored) 397 | r2s_weighted[FEATURE, :] = score_f(weighted_pred[FEATURE], data_zscored) 398 | 399 | stacked_r2s = score_f(stacked_pred, data_zscored) 400 | 401 | r2s_train = 
r2s_train_folds.mean(0) 402 | stacked_train = stacked_train_r2s_fold.mean(0) 403 | S_average = S_average / n_folds  # mean of the per-fold stacking weights 404 | 405 | # return the results 406 | return ( 407 | r2s, 408 | stacked_r2s, 409 | r2s_weighted, 410 | r2s_train, 411 | stacked_train, 412 | S_average, 413 | ) 414 | -------------------------------------------------------------------------------- /stacking_tutorial.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 19, 6 | "id": "extra-bearing", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import numpy as np\n", 11 | "from stacking_fmri import stacking_CV_fmri, stacking_fmri\n", 12 | "from ridge_tools import R2\n", 13 | "import matplotlib.pyplot as plt\n", 14 | "import seaborn as sns" 15 | ] 16 | }, 17 | { 18 | "cell_type": "code", 19 | "execution_count": 2, 20 | "id": "confidential-punch", 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "### Set up or load features\n", 25 | "\n", 26 | "N_sample = 1000\n", 27 | "dim_X1 = 50\n", 28 | "dim_X2 = 100\n", 29 | "dim_X3 = 25\n", 30 | "\n", 31 | "X1 = np.random.randn(N_sample, dim_X1)\n", 32 | "X2 = np.random.randn(N_sample, dim_X2)\n", 33 | "X3 = np.random.randn(N_sample, dim_X3)\n", 34 | "\n", 35 | "#X1 = np.load('....')" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 32, 41 | "id": "boring-occurrence", 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "### Set up or load brain data (fMRI, EEG, ....)\n", 46 | "\n", 47 | "dim_Y = 10\n", 48 | "\n", 49 | "# Y = np.random.randn(N_sample, dim_Y)\n", 50 | "Y = 0.3 * X1.dot(np.random.randn(dim_X1, dim_Y)) + \\\n", 51 | " 0.3 * X2.dot(np.random.randn(dim_X2, dim_Y)) + \\\n", 52 | " 0.4 * X3.dot(np.random.randn(dim_X3, dim_Y))\n", 53 | "\n", 54 | "#Y = np.load('....')" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 33, 60 | "id": "buried-charger", 61 | "metadata":
{}, 62 | "outputs": [ 63 | { 64 | "name": "stdout", 65 | "output_type": "stream", 66 | "text": [ 67 | "(3, 250)\n", 68 | "(3, 250)\n", 69 | "(3, 250)\n", 70 | "(3, 250)\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "### Run stacking using multiple features (Xs) and Y\n", 76 | "\n", 77 | "\n", 78 | "## with the outermost cross-validation\n", 79 | "r2s, stacked_r2s, _, _, _, S_average = stacking_CV_fmri(Y, [X1,X2,X3], method = 'cross_val_ridge',n_folds = 4,score_f=R2)", 80 | "\n", 81 | "\n", 82 | "## simple train-test setting (without the outermost cross-validation)\n", 83 | "# r2s, stacked_r2s, _, _, _, S_average = stacking_fmri(Y[0:700], Y[700:], [X1[0:700],X2[0:700],X3[0:700]], [X1[700:],X2[700:],X3[700:]], method = 'cross_val_ridge',score_f=R2)" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 34, 89 | "id": "chinese-sussex", 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "shape of r2s is (number of features, dim_Y), that is (3, 10)\n", 97 | "shape of stacked_r2s is (dim_Y, ), that is (10,)\n", 98 | "shape of S_average is (dim_Y, num of features), that is (10, 3)\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "### Results\n", 104 | "\n", 105 | "## r2s: voxelwise R2(predictions using only one feature, data)\n", 106 | "print('shape of r2s is (number of features, dim_Y), that is', r2s.shape)\n", 107 | "\n", 108 | "## stacked_r2s: voxelwise R2(stacking predictions using all features, data)\n", 109 | "print('shape of stacked_r2s is (dim_Y, ), that is', stacked_r2s.shape)\n", 110 | "\n", 111 | "## S_average: optimized voxelwise stacking weights showing how different features are combined\n", 112 | "print('shape of S_average is (dim_Y, num of features), that is', S_average.shape)\n", 113 | "\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 37, 119 | "id": "curious-blanket", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "data": { 124
| "text/plain": [ 125 | "Text(0.5, 1.0, 'Prediction Performance')" 126 | ] 127 | }, 128 | "execution_count": 37, 129 | "metadata": {}, 130 | "output_type": "execute_result" 131 | }, 132 | { 133 | "data": { 134 | "image/png": "\n", 135 | "text/plain": [ 136 | "
" 137 | ] 138 | }, 139 | "metadata": { 140 | "needs_background": "light" 141 | }, 142 | "output_type": "display_data" 143 | } 144 | ], 145 | "source": [ 146 | "plt.figure(figsize=(10,6))\n", 147 | "\n", 148 | "bar_width = 0.2\n", 149 | "index_0 = np.arange(dim_Y)\n", 150 | "index_1 = index_0 + bar_width\n", 151 | "index_2 = index_1 + bar_width\n", 152 | "index_3 = index_2 + bar_width\n", 153 | "\n", 154 | "\n", 155 | "plt.bar(index_0, stacked_r2s, width=bar_width, label='Stacking of X1,2,3')\n", 156 | "plt.bar(index_1, r2s[0,:], width=bar_width, label='X1')\n", 157 | "plt.bar(index_2, r2s[1,:], width=bar_width, label='X2')\n", 158 | "plt.bar(index_3, r2s[2,:], width=bar_width, label='X3')\n", 159 | "plt.legend()\n", 160 | "plt.xlabel('Voxel ID')\n", 161 | "plt.ylabel('R2')\n", 162 | "plt.title('Prediction Performance')" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 45, 168 | "id": "smoking-merchant", 169 | "metadata": {}, 170 | "outputs": [ 171 | { 172 | "data": { 173 | "text/plain": [ 174 | "([,\n", 175 | " ,\n", 176 | " ],\n", 177 | " [Text(0, 0.5, 'X1'), Text(0, 1.5, 'X2'), Text(0, 2.5, 'X3')])" 178 | ] 179 | }, 180 | "execution_count": 45, 181 | "metadata": {}, 182 | "output_type": "execute_result" 183 | }, 184 | { 185 | "data": { 186 | "image/png": "\n", 187 | "text/plain": [ 188 | "
" 189 | ] 190 | }, 191 | "metadata": { 192 | "needs_background": "light" 193 | }, 194 | "output_type": "display_data" 195 | } 196 | ], 197 | "source": [ 198 | "sns.heatmap(S_average.T, vmin=0, vmax=1)\n", 199 | "plt.xlabel('Voxel ID')\n", 200 | "plt.ylabel('Feature')\n", 201 | "plt.yticks([0.5,1.5,2.5],['X1', 'X2', 'X3'])" 202 | ] 203 | }, 204 | { 205 | "cell_type": "code", 206 | "execution_count": null, 207 | "id": "answering-fireplace", 208 | "metadata": {}, 209 | "outputs": [], 210 | "source": [] 211 | } 212 | ], 213 | "metadata": { 214 | "kernelspec": { 215 | "display_name": "Python 3", 216 | "language": "python", 217 | "name": "python3" 218 | }, 219 | "language_info": { 220 | "codemirror_mode": { 221 | "name": "ipython", 222 | "version": 3 223 | }, 224 | "file_extension": ".py", 225 | "mimetype": "text/x-python", 226 | "name": "python", 227 | "nbconvert_exporter": "python", 228 | "pygments_lexer": "ipython3", 229 | "version": "3.7.9" 230 | } 231 | }, 232 | "nbformat": 4, 233 | "nbformat_minor": 5 234 | } 235 | --------------------------------------------------------------------------------