├── extra_images
├── readme.md
├── sim_1.png
├── sim_2.png
├── sim_3.png
└── sim_4.png
├── concatenate.py
├── README.md
├── extended_simulations.md
├── run_simulation_sweep.py
├── simulation_experiment.py
├── visualize_sim.py
├── ridge_tools.py
├── variance_partitioning.ipynb
├── simulate.ipynb
├── stacking.py
└── stacking_tutorial.ipynb
/extra_images/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/extra_images/sim_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_1.png
--------------------------------------------------------------------------------
/extra_images/sim_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_2.png
--------------------------------------------------------------------------------
/extra_images/sim_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_3.png
--------------------------------------------------------------------------------
/extra_images/sim_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/brainML/Stacking/HEAD/extra_images/sim_4.png
--------------------------------------------------------------------------------
/concatenate.py:
--------------------------------------------------------------------------------
1 | # from cvxopt import matrix, solvers
2 | import numpy as np
3 | from scipy.stats import zscore
4 | from ridge_tools import cross_val_ridge, R2, ridge
5 | from stacking import feat_ridge_CV, get_cv_indices
6 |
7 |
def concatenate_CV_fmri(data, features, method="cross_val_ridge", n_folds=5, score_f=R2):
    """
    Predict fMRI signal from the column-wise concatenation of several feature spaces.

    For every cross-validation fold the training and test parts of the data and
    of each feature space are z-scored separately, the feature spaces are
    stacked horizontally into a single design matrix, and ``feat_ridge_CV`` is
    used to obtain predictions for the held-out time points. The collected
    predictions are finally scored against the z-scored data.

    Args:
        - data (ndarray): A matrix of fMRI signal data with dimensions n_time x n_voxels.
        - features (list): A list of arrays of predictors, each with dimensions
            n_time x n_dim.
        - method (str): Passed through to ``feat_ridge_CV``. Default is "cross_val_ridge".
        - n_folds (int): Number of cross-validation folds. Default is 5.
        - score_f (function): Called as ``score_f(predictions, zscore(data))``. Default is R2.

    Returns:
        - A one-element tuple ``(concat_r2s,)`` holding whatever ``score_f``
          returns for the concatenated model (one value per voxel with the
          default ``R2``).
    """
    n_time, n_voxels = data.shape

    # presumably one fold label per time point — confirm against get_cv_indices
    ind = get_cv_indices(n_time, n_folds=n_folds)

    # held-out predictions for every time point, filled fold by fold
    concat_pred = np.zeros((n_time, n_voxels))

    for ind_num in range(n_folds):
        train_ind = ind != ind_num
        test_ind = ind == ind_num

        # z-score each split on its own; NaNs (e.g. from constant columns) become 0
        train_data = np.nan_to_num(zscore(data[train_ind]))
        train_features = [np.nan_to_num(zscore(F[train_ind])) for F in features]
        test_features = [np.nan_to_num(zscore(F[test_ind])) for F in features]

        # only the third return value (the test predictions) is needed here
        __, __, concat_pred[test_ind], __, __ = feat_ridge_CV(
            np.hstack(train_features),
            train_data,
            np.hstack(test_features),
            method=method,
        )

    # score the cross-validated predictions against the globally z-scored data
    data_zscored = zscore(data)
    concat_r2s = score_f(concat_pred, data_zscored)

    return (
        concat_r2s,
    )
65 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 🧠 Stacked regressions and structured variance partitioning for interpretable brain maps
2 |
3 | ## Overview
4 |
5 | This is a Python package that provides an implementation of stacked regression for functional MRI (fMRI) data. The package uses ridge regression to train models on multiple feature spaces and combines the predictions from these models using a weighted linear combination. The weights are learned using quadratic programming.
6 |
7 | > Here we present an approach for brain mapping based on two proposed methods: stacking different encoding models and structured variance partitioning. This package is useful for researchers interested in aligning brain activity with different layers of a neural network, or with other types of correlated feature spaces.
8 |
9 | > Relating brain activity associated with a complex stimulus to different attributes of that stimulus is a powerful approach for constructing functional brain maps. However, when stimuli are naturalistic, their attributes are often correlated. These different attributes can act as confounders for each other and complicate the interpretability of brain maps. Correlations between attributes also impact the robustness of statistical estimators.
10 |
11 | > Each encoding model uses as input a feature space that describes a different stimulus attribute. The algorithm learns to predict the activity of a voxel as a linear combination of the individual encoding models. We show that the resulting unified model can predict held-out brain activity better or at least as well as the individual encoding models. Further, the weights of the linear combination are readily interpretable; they show the importance of each feature space for predicting a voxel.
12 |
13 | > We build on our stacking models to introduce a new variant of variance partitioning in which we rely on the known relationships between features during hypothesis testing. This approach, which we term structured variance partitioning, constrains the size of the hypothesis space and allows us to ask targeted questions about the similarity between feature spaces and brain regions even in the presence of correlations between the feature spaces.
14 |
15 | > We validate our approach in simulation, showcase its brain mapping potential on fMRI data, and release a Python package.
16 |
17 | ## Installation
18 | To use this code, simply install with pip:
19 |
20 |
21 | ```bash
22 | pip install stacking_fmri
23 | ```
24 |
25 |
26 | ## Usage
27 | Here is an example of how to use the `stacking_fmri` function:
28 | ```python
import numpy as np
from sklearn.datasets import make_regression

from stacking_fmri import stacking_fmri

# Generate synthetic data
X_train, y_train = make_regression(n_samples=50, n_features=1000, random_state=42)
X_test, y_test = make_regression(n_samples=50, n_features=1000, random_state=43)

# Generate random feature spaces
n_features = 5
train_features = [np.random.randn(X_train.shape[0], 10) for _ in range(n_features)]
test_features = [np.random.randn(X_test.shape[0], 10) for _ in range(n_features)]

# Train and test the model.
# score_f is omitted so the package's default scorer is used
# (np.mean_squared_error does not exist in NumPy).
(
    r2s,
    stacked_r2s,
    r2s_weighted,
    r2s_train,
    stacked_train_r2s,
    S,
) = stacking_fmri(
    X_train,
    X_test,
    train_features,
    test_features,
    method="cross_val_ridge",
)

print("R2 scores for each feature and voxel:")
print(r2s)
print("\nWeighted R2 scores for each feature and voxel:")
print(r2s_weighted)
print("\nUnweighted R2 scores for the stacked predictions:")
print(stacked_r2s)
print("\nStacking weights:")
print(S)
66 | ```
67 |
68 | We also provide examples of how to use the package in jupyter notebooks:
69 |
70 | - stacking_tutorial.ipynb
71 |
72 | - variance_partitioning.ipynb
73 |
74 |
75 |
77 |
78 |
79 | ## Contributions
80 | Contributions are welcome! Please feel free to submit a pull request with your changes or open an issue to report a bug or suggest a new feature.
81 |
82 |
83 | ## Contact
84 | Created by [@lrg1213] - feel free to contact me!
85 |
86 |
87 | ## References
88 | [1]
89 | Ruogu Lin, Thomas Naselaris, Kendrick Kay, and Leila Wehbe (2023).
90 | Stacked regressions and structured variance partitioning for interpretable brain maps.
91 |
92 |
93 |
94 |
--------------------------------------------------------------------------------
/extended_simulations.md:
--------------------------------------------------------------------------------
1 | ## Correction and Extended Simulation Study
2 |
3 | ### Correction to the Original Paper
4 |
5 | In the original paper, we stated that the simulation experiment used a **correlation of 0.2 between the feature spaces**. However, due to a bug in our simulation code (which is now fixed), the results in the published article actually correspond to a **correlation of 0**.
6 |
7 | We now provide **two corrected examples** below, generated using different correlation levels.
8 |
9 | ---
10 |
11 | ### New Examples with Correlated Feature Spaces
12 |
13 | #### **Example 1**
14 | 
15 |
16 | #### **Example 2**
17 |
18 | 
19 |
20 |
21 | As can be seen, in some parameter settings **stacking outperforms concatenation**, while in others **concatenation performs better**.
22 |
23 | As we emphasized in the paper, *these simulations cannot be taken as proof of either model's superiority*: there are infinitely many simulation configurations and the behaviors of stacking vs. concatenation depend heavily on the structure of the data.
24 |
25 | To deepen our understanding of this specific simulation setup, we performed a substantially more **comprehensive sweep** of parameters.
26 |
27 | ---
28 |
29 | ## Extended Simulation Sweep
30 |
31 | Below we report results from a broad grid of simulation settings (576 in total).
32 | For each configuration, we generate **100 independent datasets** using the same simulation framework as in the original study.
33 |
34 | This sweep varies the following parameters:
35 |
36 | - **Number of samples (n)**: 50, 100, 200, 400
37 | - **Dimensions of feature spaces (d1,d2,d3,d4)**: [10, 10, 10, 10], [10, 100, 100, 100], [100, 100, 100, 100]
38 | - **Feature importance (alpha)**: [0.3, 0.3, 0.3, 0.1], [0.5, 0.2, 0.2, 0.1], [0.7, 0.1, 0.1, 0.1]
39 | - **Noise (sigma)**: 0, 0.5, 1, 1.5
40 | - **Feature correlation (correl)**: 0, 0.1, 0.2, 0.3
41 |
42 | Together, this produces 576 simulation settings. Each setting is run with 100 simulated datasets.
43 |
44 | ### Results
45 | Below we group the simulation results by parameter families to understand which factors most strongly influence whether stacking or concatenation performs better.
46 |
47 | In the figure below, we focus on one parameter at a time, and count the number of simulation settings (over all other parameters) for which stacking has reliably lower estimation error than concatenation, and the ones for which concatenation has reliably lower estimation error than stacking (similar to the right part of the two plots above, we count the number of stars and filled circles).
48 |
49 | 
50 |
51 | We see that smaller sample size n, higher noise sigma, low correlation, higher values of alpha1 and lower values of d1/d2 lead to better performance of stacking across the different settings of the other parameters.
52 |
53 | Instead of counting the number of experiments where one method is reliably better than the other, we compute the average error across experiments and settings, picking pairs of parameters at a time:
54 |
55 | 
56 |
57 | This visualization of the difference in average error highlights a relationship between n, sigma and alpha1. Small n and low sigma lead to concatenation performing better. Large alpha1 leads to stacking being better, but not at high n.
58 |
59 | ### Conclusion
60 | Across this **specific** simulation grid, we observe the following patterns:
61 | - Stacking can have lower error in lower sample sizes (n).
62 | - Stacking can have lower error in higher noise settings (sigma).
63 | - Stacking can have lower error when d1 is smaller than the other feature sizes.
64 | - Stacking can have lower error when alpha1 is large.
65 |
66 | Importantly, when we compare these findings to the real-data experiment in our paper:
67 | - Our dataset has 7 × 1024 = 7168 total features, which is on the same order as the ∼8000 training samples.
68 | - fMRI data has considerable noise.
69 | - Potentially, there could be a preference for a feature space (one of the alphas is large).
70 |
71 | Thus, the empirical conditions of the real dataset (and of data from experiments similar to the one we use) appear to correspond to simulation settings under which stacking tends to have an advantage, suggesting that stacking could be useful for building fMRI encoding models in naturalistic experiments. As we mentioned in the paper, even in the event where the experimenter thinks that concatenation might be better, they can include an additional feature space with the concatenated features.
72 |
--------------------------------------------------------------------------------
/run_simulation_sweep.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | import argparse
4 | import itertools
5 | import numpy as np
6 | import pandas as pd
7 |
8 | from simulation_experiment import run_one_simulation
9 |
# Sample sizes (n) explored by the sweep.
N_LIST = [50, 100, 200, 400]

# Dimensionality of each of the 4 feature spaces, one row per configuration.
DS_LIST = [
    [10, 10, 10, 10],
    [10, 100, 100, 100],
    [100, 100, 100, 100],
]

# Importance (alpha) of each of the 4 feature spaces, one row per configuration.
ALPHA_LIST = [
    [0.3, 0.3, 0.3, 0.1],
    [0.5, 0.2, 0.2, 0.1],
    [0.7, 0.1, 0.1, 0.1],
]

# Output noise levels (y_noise / sigma).
NOISE_LIST = [0, 0.5, 1, 1.5]

# Correlation levels between feature spaces.
CORREL_LIST = [0, 0.1, 0.2, 0.3]

# Parameters held fixed across the whole sweep (edit here to change them).
SCALE = 0.5
Y_DIM = 2

# Monte Carlo repetitions per parameter setting.
# NOTE(review): extended_simulations.md reports 100 datasets per setting — confirm
# which value was used for the published sweep.
N_RUNS_PER_SETTING = 50
45 |
46 |
def build_tasks():
    """
    Expand the parameter grid into a flat task list.

    Every combination of the module-level parameter lists is repeated
    N_RUNS_PER_SETTING times. Tasks are numbered consecutively from 0 in
    ``global_index`` so that a range of indices can be handed to each machine.
    """
    grid = itertools.product(N_LIST, DS_LIST, ALPHA_LIST, NOISE_LIST, CORREL_LIST)
    # the run counter varies fastest, the grid setting slowest
    expanded = (
        (setting, run_id)
        for setting in grid
        for run_id in range(N_RUNS_PER_SETTING)
    )
    return [
        {
            "global_index": task_number,
            "n": int(n),
            "ds": list(ds),
            "alpha": list(alpha),
            "noise": float(noise),
            "correl": float(correl),
            "run": run_id,
        }
        for task_number, ((n, ds, alpha, noise, correl), run_id) in enumerate(expanded)
    ]
73 |
74 |
def main():
    """
    Command-line entry point: run one contiguous slice of the simulation sweep.

    The full task list is rebuilt on every invocation; tasks whose
    ``global_index`` lies in ``[--start-index, --end-index]`` (clamped to the
    valid range) are simulated one after another, each seeded with its own
    index, and the concatenated results are written as a pickle to ``--output``.

    ``--print-task-count`` prints the number of tasks and exits; in that mode
    the other options are not needed.

    Raises:
        ValueError: if no tasks exist or the clamped index range is empty.
        SystemExit: (via argparse) if a required option is missing.
    """
    parser = argparse.ArgumentParser(
        description="Run a slice of the parameter sweep for the stacking vs concatenation simulations."
    )
    # These three options are checked manually below instead of using
    # required=True, so that --print-task-count can be used on its own.
    parser.add_argument(
        "--start-index",
        type=int,
        default=None,
        help="First task index (inclusive) to run.",
    )
    parser.add_argument(
        "--end-index",
        type=int,
        default=None,
        help="Last task index (inclusive) to run.",
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output pickle file name, e.g. results_0_99.pkl",
    )
    parser.add_argument(
        "--print-task-count",
        action="store_true",
        help="If set, only print total number of tasks and exit.",
    )

    args = parser.parse_args()

    tasks = build_tasks()
    total_tasks = len(tasks)

    if args.print_task_count:
        print(f"Total tasks: {total_tasks}")
        return

    # Same message and exit status argparse uses for required=True options.
    required = {
        "--start-index": args.start_index,
        "--end-index": args.end_index,
        "--output": args.output,
    }
    missing = [flag for flag, value in required.items() if value is None]
    if missing:
        parser.error("the following arguments are required: " + ", ".join(missing))

    if total_tasks == 0:
        raise ValueError("No tasks were generated. Did you fill in the parameter lists?")

    # Clamp the requested range to the indices that actually exist.
    start = max(0, args.start_index)
    end = min(total_tasks - 1, args.end_index)

    if start > end:
        raise ValueError(
            f"Invalid range: start_index={args.start_index}, end_index={args.end_index}, total_tasks={total_tasks}"
        )

    selected_tasks = [t for t in tasks if start <= t["global_index"] <= end]

    print(
        f"Total tasks: {total_tasks}. "
        f"Running tasks {start}..{end} ({len(selected_tasks)} tasks) "
        f"into {args.output}"
    )

    all_dfs = []

    for t in selected_tasks:
        idx = t["global_index"]
        n = t["n"]
        ds = t["ds"]
        alpha = t["alpha"]
        noise = t["noise"]
        correl = t["correl"]
        run_id = t["run"]

        print(
            f"\n=== Task {idx} ===\n"
            f"n={n}, ds={ds}, alpha={alpha}, noise={noise}, correl={correl}, run={run_id}"
        )

        # For reproducibility, seed based on global task index
        np.random.seed(idx)

        df = run_one_simulation(
            samples=n,
            ds=ds,
            scale=SCALE,
            correl=correl,
            alpha=alpha,
            y_dim=Y_DIM,
            y_noise=noise,
        )

        # Attach parameter metadata to each row in this simulation
        df["n"] = n
        df["d1"], df["d2"], df["d3"], df["d4"] = ds
        df["alpha1"], df["alpha2"], df["alpha3"], df["alpha4"] = alpha
        df["sigma"] = noise
        df["correl"] = correl
        df["run"] = run_id
        df["task_index"] = idx

        all_dfs.append(df)

    if not all_dfs:
        print("No tasks were selected; nothing to save.")
        return

    result = pd.concat(all_dfs, ignore_index=True)
    result.to_pickle(args.output)
    print(f"\nSaved {len(result)} rows to {args.output}")


if __name__ == "__main__":
    main()
182 |
--------------------------------------------------------------------------------
/simulation_experiment.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib.pyplot as plt
3 | import pandas as pd
4 | from ridge_tools import cross_val_ridge, fit_predict, R2
5 | import sys
6 | from stacking import stacking_CV_fmri
7 | from concatenate import concatenate_CV_fmri
8 | from scipy.stats import zscore, multivariate_normal #, wishart
9 | from scipy.linalg import toeplitz
10 | import time
11 |
12 |
def toeplitz_cov(n, scale=1):
    """
    Return an n x n symmetric Toeplitz matrix with squared-exponential decay.

    The entry at distance k from the diagonal is ``exp(-k**2 / (n * scale))``.
    """
    lags = np.arange(n * 1.0)
    first_column = np.exp(-(lags ** 2) / (n * scale))
    return toeplitz(first_column)
15 |
def feat_sample(n, ds, scale):
    """
    Draw one Gaussian feature matrix per requested dimensionality.

    Args:
        n (int): number of samples (rows) per feature space.
        ds (iterable of int): number of columns of each feature space.
        scale (float): passed to ``toeplitz_cov`` to build each covariance.

    Returns:
        list: one array of shape (n, di) for every ``di`` in ``ds``.
    """
    Xs = []
    for di in ds:
        X = multivariate_normal.rvs(np.zeros(di), cov=toeplitz_cov(di, scale), size=n)
        # rvs squeezes singleton dimensions (di == 1 or n == 1); restore the
        # 2-D (n, di) layout so downstream matrix products keep working.
        Xs.append(np.reshape(X, (n, di)))
    return Xs
22 |
def data_sample(Xs,correl,ds,scale,alpha,data_dim,noise=0):
    """
    Build a simulated response from a list of feature spaces.

    Each feature space is projected through its own slice of a jointly drawn
    weight matrix, z-scored, scaled by its entry in ``alpha`` and added to the
    response. When ``correl > 0`` a component derived from the next feature
    space is mixed into both the current feature space and the response (the
    last feature space is rescaled instead). The summed response is z-scored
    and Gaussian noise scaled by ``noise`` is added.

    Args:
        Xs (list): feature matrices; entries are modified in place when
            ``correl > 0``.
        correl (float): strength of the correlated component; no mixing is
            done unless it is > 0.
        ds (list): number of columns of each feature space, used to slice
            the weight matrix.
        scale (float): passed to ``toeplitz_cov`` for every covariance here.
        alpha (list): one weight per feature space; must match ``len(Xs)``.
        data_dim (int): number of weight vectors drawn (``size`` of the draw).
        noise (float): multiplier of the standard-normal noise added to y.

    Returns:
        tuple: ``(y, var_X, Xs)`` where ``var_X[i]`` is
        ``R2(alpha[i] * ts[i], y)`` computed on the noisy response.
    """
    assert len(Xs) == len(alpha)
    ts = []
    y = 0
    d = sum(ds)
    cnt = 0
    # weights for all feature dimensions are drawn jointly, then sliced per feature space
    wtot = multivariate_normal.rvs(mean=np.zeros(d),cov=toeplitz_cov(d,scale),
                                   size=data_dim).T#reshape([d,data_dim])
    for iX, X in enumerate(Xs):
        # rows of wtot belonging to feature space iX
        w = wtot[cnt:cnt+ds[iX],:]
        cnt += ds[iX]
        t = zscore(X.dot(w))
        ts.append(t)
        y += alpha[iX]*t
        # add correlated component
        if correl>0:
            if iX < len(Xs)-1: # add a correlated component from iX+1 to every iX
                # random linear map from feature space iX+1 into the column space of iX
                X_tmp = Xs[iX+1].dot(multivariate_normal.rvs(np.zeros(ds[iX]),cov=toeplitz_cov(ds[iX],scale),size=ds[iX+1]))
                t_tmp = correl*zscore(X_tmp.dot(w))
                # in-place update: the caller's array is changed as well
                Xs[iX] += correl*X_tmp
                y += alpha[iX]*t_tmp
            else: # increase contribution of last feature space to keep alphas meaningful
                y += alpha[iX]*t*correl
                Xs[iX] *= (1+ correl)
    y = zscore(y)
    ns = noise*np.random.randn(y.shape[0], y.shape[1])
    # NOTE: y_orig is the same array object as y, so it also receives the
    # noise added on the next line; it is not used afterwards.
    y_orig = y
    y+= ns
    var_X = [R2(alpha[i]*ts[i],y) for i in range(len(Xs))]
    return y, var_X, Xs
53 |
def sample_all_at_once(n, ds, scale, correl, alpha, data_dim, y_noise=0):
    """
    Draw feature spaces and the matching simulated response in one call.

    Returns:
        tuple: ``(Xs, y, var_X)`` — the feature spaces as returned by
        ``data_sample``, the response, and the per-feature-space R2 list.
    """
    feature_spaces = feat_sample(n, ds, scale)
    response, explained, feature_spaces = data_sample(
        feature_spaces, correl, ds, scale, alpha, data_dim, y_noise
    )
    return feature_spaces, response, explained
58 |
59 | # Experiment Functions
60 |
61 | import time
62 |
def synexp(runs,sim_type, samples_settings,ds_settings,y_dim,alpha_settings,correl = 0,scale = 1,
           y_noise_settings=0):
    """
    Run a family of simulations in which one parameter is varied.

    ``sim_type`` selects which argument is treated as a list of settings to
    iterate over; every other argument is passed to ``run_one_simulation``
    unchanged:

        'Feat_Dim_ratio'   -> iterate ``ds_settings``       (column stores ds[0])
        'Cond'             -> iterate ``alpha_settings``    (column stores alpha[0])
        'Sample_Dim_ratio' -> iterate ``samples_settings``
        'noise'            -> iterate ``y_noise_settings``
        'correl'           -> iterate ``correl``

    Any other ``sim_type`` runs no simulation.

    Args:
        runs (int): number of repetitions of the whole settings loop.
        sim_type (str): one of the keys listed above.
        samples_settings, ds_settings, alpha_settings, correl,
        y_noise_settings: a single value, or a list of values for the
            parameter selected by ``sim_type``.
        y_dim: passed to ``run_one_simulation``.
        scale: passed to ``run_one_simulation``.

    Returns:
        pandas.DataFrame: all per-simulation frames concatenated with a fresh
        index, plus one column named after ``sim_type`` holding the varied
        value. Empty if nothing was run.
    """
    # Collect frames and concatenate once at the end; growing a DataFrame
    # with pd.concat inside the loop copies all previous rows every time.
    frames = []

    start = time.time()
    for run in range(runs):
        print('iteration number {}'.format(run+1))
        if sim_type == 'Feat_Dim_ratio': # Vary the dimensionality of X1 with respect to other feature spaces
            for ds in ds_settings:
                df = run_one_simulation(samples_settings,ds,scale,correl,alpha_settings,y_dim,
                                        y_noise_settings)
                df['Feat_Dim_ratio'] = ds[0]
                frames.append(df)
        elif sim_type == 'Cond': # Vary the weight of X1 with respect to other feature spaces
            for alpha in alpha_settings:
                df = run_one_simulation(samples_settings,ds_settings,scale,correl,alpha,y_dim,
                                        y_noise_settings)
                df['Cond'] = alpha[0]
                frames.append(df)
        elif sim_type == 'Sample_Dim_ratio': # Vary the number of samples
            for samples in samples_settings:
                df = run_one_simulation(samples,ds_settings,scale,correl,alpha_settings,y_dim,
                                        y_noise_settings)
                df['Sample_Dim_ratio'] = samples
                frames.append(df)
        elif sim_type == 'noise': # Vary the noise level
            for y_noise in y_noise_settings:
                df = run_one_simulation(samples_settings,ds_settings,scale,correl,alpha_settings,y_dim,
                                        y_noise)
                df['noise'] = y_noise
                frames.append(df)
        elif sim_type == 'correl': # Vary the feature space correlation level
            for correl_v in correl:
                df = run_one_simulation(samples_settings,ds_settings,scale,correl_v,alpha_settings,y_dim,
                                        y_noise_settings)
                df['correl'] = correl_v
                frames.append(df)

        if run==0:
            # rough estimate of the total run time from the first repetition
            time_int = (time.time() - start)
            print("first iteration time: {}, total {}".format(int(time_int), int(time_int*runs)))

    if not frames:
        # pd.concat raises on an empty list; keep returning an empty frame
        return pd.DataFrame()

    return pd.concat(frames, ignore_index=True)
109 |
110 |
111 |
def run_one_simulation(samples,ds,scale,correl,alpha,y_dim,y_noise):
    """
    Simulate one dataset and compare stacking with concatenation on it.

    Data are drawn with ``sample_all_at_once``; the response is z-scored and
    both models are fitted twice with 4-fold cross-validation: once on all
    feature spaces and once without the first one. The difference between the
    two fits is each method's variance-partitioning estimate for the first
    feature space.

    Args:
        samples, ds, scale, correl, alpha, y_dim, y_noise: forwarded to
            ``sample_all_at_once`` (``y_dim`` as ``data_dim``).

    Returns:
        pandas.DataFrame with columns 'stacked', 'concat', 'max', 'r2_X0',
        'varpar_X0_concat', 'varpar_X0_stacked', 'varpar_X0_real',
        'weight_0' and 'alpha_0'.
    """
    Xs, y, var_X = sample_all_at_once(samples,ds,scale,correl,alpha,y_dim,y_noise)
    print('data sampled')
    time_begin = time.time()
    y = zscore(y)

    # stacking on all feature spaces, and on all but the first one
    result = stacking_CV_fmri(y,Xs, method = 'cross_val_ridge',n_folds = 4)
    result2 = stacking_CV_fmri(y,Xs[1:], method = 'cross_val_ridge',n_folds = 4)

    print('time for stacking: {}'.format(time.time()-time_begin))

    df = pd.DataFrame()

    # presumably result[0] holds per-feature-space scores, result[1] the stacked
    # score and result[5] the stacking weights — confirm against stacking_CV_fmri
    df['stacked'] = result[1]
    df['concat'] = concatenate_CV_fmri(y,Xs, method = 'cross_val_ridge',n_folds = 4)[0]
    # NOTE(review): only the first two rows of result[0] enter the maximum — confirm intended
    df['max'] = np.max(result[0][0:2,:],axis=0)
    df['r2_X0'] = result[0][0]

    # full model minus the model without the first feature space
    df['varpar_X0_concat'] = df['concat'] - concatenate_CV_fmri(y,Xs[1:], method = 'cross_val_ridge',n_folds = 4)[0]

    df['varpar_X0_stacked'] = df['stacked'] - result2[1]

    df['varpar_X0_real'] = var_X[0]

    df['weight_0'] = result[5][:,0]
    df['alpha_0'] = alpha[0]

    print('time for iteration: {}'.format(time.time()-time_begin))

    return df
145 |
146 |
--------------------------------------------------------------------------------
/visualize_sim.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 |
# Font settings applied to matplotlib's global rc configuration when this
# module is imported.
font = {'family' : 'sans-serif',
        'weight' : 'normal',
        'size' : 18}
import matplotlib
matplotlib.rc('font', **font)
8 | import matplotlib.pyplot as plt
9 | import numpy as np
10 |
11 |
def box_plot(ax, data, edge_color, fill_color, positions=None):
    """
    Draw a box plot on ``ax`` with uniform edge and fill colours.

    Args:
        ax: matplotlib axes to draw on.
        data: passed to ``ax.boxplot``.
        edge_color: colour applied to every line element of the boxes.
        fill_color: face colour of the boxes.
        positions: passed to ``ax.boxplot``; ``None`` keeps matplotlib's
            default placement.

    Returns:
        The dictionary returned by ``ax.boxplot``.
    """
    # `positions` used to be accepted but silently ignored
    bp = ax.boxplot(data, patch_artist=True, positions=positions)

    for element in ['boxes', 'whiskers', 'fliers', 'means', 'medians', 'caps']:
        plt.setp(bp[element], color=edge_color)

    # set after the edge colour so the fill is not overwritten by it
    for patch in bp['boxes']:
        patch.set(facecolor=fill_color)

    return bp
22 |
23 |
def sim_plots(Results,name,settings,filename='',var_dict={}, ylim0 = [-0.1, 1],
              ylim1=[0,0.5],ylim2=[0,0.5]):
    """
    Plot a two-panel summary of one simulation family and save it as a JPEG.

    Left panel: box plots of the 'stacked', 'concat' and 'max' columns of
    ``Results`` for every distinct value of ``Results[name]``, with thin lines
    joining the values of the same run.
    Right panel: box plots of the absolute difference between each method's
    variance-partitioning estimate and 'varpar_X0_real', with a paired
    two-sided t-test per setting. A star marks settings where the
    concatenation error is significantly larger, a filled circle where the
    stacking error is significantly larger.

    Args:
        Results (DataFrame): simulation results; must contain column ``name``
            and the result columns listed above.
        name (str): column that was varied ('noise', 'Feat_Dim_ratio', 'Cond',
            'Sample_Dim_ratio' or 'correl' get dedicated labels and titles).
        settings (list): tick labels, one per distinct value of the column.
        filename (str): prefix of the output file; a timestamp and '.jpg' are
            appended.
        var_dict (dict): fixed parameter values inserted in the title. The
            keys read depend on ``name`` ('alphas' for 'noise' and
            'Feat_Dim_ratio', 'alpha' for 'Sample_Dim_ratio' and 'correl').
        ylim0: y-limits of the left panel.
        ylim1: applied to the right panel first, then replaced by ``ylim2``.
        ylim2: final y-limits of the right panel.
    """
    # Kept local: this module does not import these names at file level.
    import time
    import pandas as pd
    from scipy.stats import ttest_rel

    # axis label for the varied parameter; other names keep seaborn's default
    x_labels = {
        'noise': r'$\sigma$',
        'Feat_Dim_ratio': r'$d_1$',
        'Cond': r'$\alpha_1$',
        'Sample_Dim_ratio': r'$n$',
        'correl': r'$\rho$',
    }

    var = sorted(list(set(Results[name].tolist())))
    print(var)

    def per_setting(column):
        # one row per distinct setting value, one column per run
        return np.array([Results[Results[name]==px][column] for px in var])

    def long_table(arrays_by_method, value_name):
        # long-format frame (value, setting index, method) for seaborn;
        # runs vary slowest so rows of different methods stay aligned
        frames = []
        for method, values in arrays_by_method:
            rows = [[values[k,i],k,method] for i in range(values.shape[1])
                    for k in range(values.shape[0])]
            frames.append(pd.DataFrame(rows,columns = [value_name,'setting','method']))
        return pd.concat(frames)

    stack_avg = per_setting('stacked')
    concat_avg = per_setting('concat')
    max_avg = per_setting('max')
    varpar_X0_real = per_setting('varpar_X0_real')
    varpar_X0_concat = per_setting('varpar_X0_concat')
    varpar_X0_stacked = per_setting('varpar_X0_stacked')

    fig, axs = plt.subplots(1, 2,figsize=(20,5))

    # ---- right panel: variance-partitioning error ----
    err_results = long_table(
        [('stack', np.abs(varpar_X0_stacked - varpar_X0_real)),
         ('concat', np.abs(varpar_X0_concat - varpar_X0_real))],
        'Err')

    sns.boxplot(ax=axs[1], x="setting", y="Err", hue="method",data=err_results, palette="Set2")
    sns.move_legend(axs[1], "upper left", bbox_to_anchor=(1, 1))

    axs[1].set_ylim(ylim1)
    if name in x_labels:
        # 'correl' was previously labelled on the left panel only
        axs[1].set_xlabel(x_labels[name])

    axs[1].set_ylabel(r'Var. par. error for ${\bf x}_1$')
    axs[1].set_ylim(ylim2)
    axs[1].set_xticks(range(len(settings)))
    axs[1].set_xticklabels(settings,rotation=45)

    for idx,s in enumerate(settings):
        setting_res = err_results[err_results['setting']==idx]
        stack_res = [e for e in setting_res[setting_res['method']=='stack']['Err']]
        concat_res = [e for e in setting_res[setting_res['method']=='concat']['Err']]
        # join the two errors of every run
        for i in np.arange(len(stack_res)):
            axs[1].plot([idx-0.22,idx+0.22], [stack_res[i],concat_res[i]],'-',linewidth=0.25,color='black')
        tstat,pval = ttest_rel( np.array(concat_res), np.array(stack_res),alternative='two-sided')
        # significance threshold 0.05 divided by 20
        if pval<0.05/20:
            if tstat>0:
                axs[1].plot(idx,0.45,'*',color = 'black',markersize=10)
            else:
                axs[1].plot(idx,0.45,'o',color = 'black',markersize=10)

    # ---- left panel: prediction performance ----
    r2_results = long_table(
        [('stack', stack_avg), ('concat', concat_avg), ('max', max_avg)],
        'R2')

    sns.boxplot(ax=axs[0], x="setting", y="R2", hue="method",data=r2_results, palette="hot_r")
    sns.move_legend(axs[0], "upper right", bbox_to_anchor=(-0.18, 1))

    if name in x_labels:
        axs[0].set_xlabel(x_labels[name])

    for idx,s in enumerate(settings):
        for i in np.arange(stack_avg.shape[1]):
            axs[0].plot([idx-0.25,idx,idx+0.25], [stack_avg[idx,i],concat_avg[idx,i],max_avg[idx,i]],
                        '-',linewidth=0.15,color='black')

    axs[0].set_ylabel(r'$R^2$')
    axs[0].set_ylim(ylim0)
    axs[0].set_xticks(range(len(settings)))
    axs[0].set_xticklabels(settings,rotation=45)

    # var_dict keys are read only for the matching `name`
    if name=='noise':
        plt.suptitle(r'Vary $\sigma:~\alpha = {},~d={},~n={}$'.format(var_dict['alphas'],var_dict['ds'],var_dict['n']))
    if name=='Feat_Dim_ratio':
        plt.suptitle(r'Vary $d_1:~\sigma = {},~\alpha = {},~d_2+d_3+d_4={},~n={}$'.format(var_dict['sigma'], var_dict['alphas'], var_dict['d_sum'], var_dict['n']))
    if name=='Cond':
        plt.suptitle(r'Vary $\alpha_1:~\sigma = {},~d={},~n={}$'.format(var_dict['sigma'], var_dict['ds'],var_dict['n']))
    if name=='Sample_Dim_ratio':
        plt.suptitle(r'Vary $n:~\sigma = {},~\alpha = {},~d={}$'.format(var_dict['sigma'],var_dict['alpha'],var_dict['ds']))
    if name=='correl':
        plt.suptitle(r'Vary $\rho:~\sigma = {},~\alpha = {},~d={}, ~n={}$'.format(var_dict['sigma'],var_dict['alpha'],var_dict['ds'],
                                                                                  var_dict['n']))

    timestr = time.strftime("%Y%m%d-%H%M%S")
    print(timestr)
    plt.tight_layout()
    plt.savefig(filename+timestr+'.jpg')
    plt.show()
166 |
167 |
--------------------------------------------------------------------------------
/ridge_tools.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 | import time
4 | import numpy as np
5 | from scipy.stats import zscore
6 | from numpy.linalg import inv, svd
7 | from sklearn.model_selection import KFold
8 | from sklearn.linear_model import Ridge, RidgeCV
9 |
10 |
def corr(X, Y, axis=0):
    """Return the mean along ``axis`` of the product of z-scored X and Y."""
    standardized_product = zscore(X) * zscore(Y)
    return np.mean(standardized_product, axis)
14 |
15 |
def R2(Pred, Real):
    """
    Return the coefficient of determination (R^2) along the first axis.

    NaN results are replaced via ``np.nan_to_num``.
    """
    residual_power = np.mean(np.square(Real - Pred), 0)
    total_variance = np.var(Real, 0)
    score = 1 - residual_power / total_variance
    return np.nan_to_num(score)
21 |
22 |
def fit_predict(data, features, method="plain", n_folds=10):
    """
    Fit and predict using cross-validated Ridge regression.

    For every fold the training and test parts of ``data`` and ``features``
    are z-scored separately, weights are obtained from ``cross_val_ridge`` on
    the training part, and predictions for the held-out rows are stored.

    Args:
        data (numpy.ndarray): The data array (rows are samples).
        features (numpy.ndarray): The features array (same number of rows).
        method (str): Passed to ``cross_val_ridge``. Defaults to 'plain'.
        n_folds (int): The number of folds for cross-validation. Defaults to 10.

    Returns:
        tuple: ``(corrs, R2s)`` computed between the collected predictions
        and ``data``, one value per column of ``data``.
    """
    n = data.shape[0]
    ind = CV_ind(n, n_folds)

    # Always use a floating-point buffer: np.zeros_like(data) would truncate
    # the predictions if `data` had an integer dtype.
    preds_all = np.zeros(data.shape, dtype=np.result_type(data.dtype, np.float32))

    for i in range(n_folds):
        held_out = ind == i
        train_data = np.nan_to_num(zscore(data[~held_out]))
        train_features = np.nan_to_num(zscore(features[~held_out]))
        test_features = np.nan_to_num(zscore(features[held_out]))
        weights, __ = cross_val_ridge(train_features, train_data, method=method)
        preds_all[held_out] = np.dot(test_features, weights)

    # NOTE(review): predictions come from z-scored training data but are
    # scored against the raw `data` — confirm this is intended.
    corrs = corr(preds_all, data)
    R2s = R2(preds_all, data)

    return corrs, R2s
56 |
57 |
def CV_ind(n, n_folds):
    """
    Assign each of ``n`` consecutive samples to a contiguous fold.

    The first ``n_folds - 1`` folds get ``n // n_folds`` samples each; the last
    fold takes all remaining samples. Returns a float array of fold labels.
    """
    fold_size = n // n_folds
    # start with every sample in the last fold, then overwrite the earlier folds
    labels = np.full(n, n_folds - 1, dtype=float)
    for fold in range(n_folds - 1):
        labels[fold * fold_size:(fold + 1) * fold_size] = fold
    return labels
69 |
70 |
def R2r(Pred, Real):
    """Compute the signed square root of R^2.

    Negative R^2 values are mapped to ``-sqrt(|R^2|)`` so the sign of the
    score is preserved.

    Args:
        Pred (numpy.ndarray): Predicted values.
        Real (numpy.ndarray): Observed values, same shape as Pred.

    Returns:
        numpy.ndarray: ``sign(R2) * sqrt(|R2|)`` for each R^2 value.
    """
    scores = R2(Pred, Real)
    # sign * sqrt(|x|) needs no in-place masked assignment, so it also works
    # when R2 returns a scalar (1-D inputs).
    return np.sign(scores) * np.sqrt(np.abs(scores))
80 |
81 |
def ridge(X, Y, lmbda):
    """Compute ridge regression weights.

    Solves ``(X^T X + lmbda * I) W = X^T Y`` for ``W``.

    Args:
        X (numpy.ndarray): Design matrix of shape (n_samples, n_features).
        Y (numpy.ndarray): Targets with n_samples rows.
        lmbda (float): Regularization strength.

    Returns:
        numpy.ndarray: Weights with n_features rows.
    """
    gram = X.T.dot(X) + lmbda * np.eye(X.shape[1])
    # Solve the regularised normal equations directly instead of forming an
    # explicit inverse: cheaper and numerically more stable.
    return np.linalg.solve(gram, X.T.dot(Y))
85 |
86 |
def ridge_by_lambda(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])):
    """Compute validation errors of ridge regression for several lambdas.

    Args:
        X (numpy.ndarray): Training features.
        Y (numpy.ndarray): Training targets, 2-D with one column per target.
        Xval (numpy.ndarray): Validation features.
        Yval (numpy.ndarray): Validation targets.
        lambdas (numpy.ndarray): Regularization strengths to evaluate.

    Returns:
        numpy.ndarray: Array of shape (n_lambdas, n_targets) holding
        ``1 - R2`` on the validation set for each lambda and target.
    """
    n_lambdas, n_targets = lambdas.shape[0], Y.shape[1]
    error = np.zeros((n_lambdas, n_targets))
    for row, penalty in enumerate(lambdas):
        fitted = ridge(X, Y, penalty)
        val_pred = np.dot(Xval, fitted)
        error[row] = 1 - R2(val_pred, Yval)
    return error
94 |
def ridge_sk(X, Y, lmbda):
    """Compute ridge regression weights using scikit-learn.

    Only ``alpha`` is set on the ``Ridge`` estimator; every other option is
    left at its scikit-learn default.

    Args:
        X (numpy.ndarray): Training features.
        Y (numpy.ndarray): Training targets.
        lmbda (float): Regularization strength, passed as ``alpha``.

    Returns:
        numpy.ndarray: Transposed ``coef_`` of the fitted estimator.
    """
    model = Ridge(alpha=lmbda).fit(X, Y)
    return model.coef_.T
100 |
101 |
def ridgeCV_sk(X, Y, lmbdas):
    """Compute ridge regression weights using scikit-learn with cross-validation.

    Args:
        X (numpy.ndarray): Training features.
        Y (numpy.ndarray): Training targets.
        lmbdas (array-like): Candidate regularization strengths, passed as
            ``alphas``.

    Returns:
        numpy.ndarray: Transposed ``coef_`` of the fitted estimator.
    """
    # RidgeCV takes no `solver` argument (passing one raises TypeError); the
    # SVD-based strategy is selected through `gcv_mode`.
    rd = RidgeCV(alphas=lmbdas, gcv_mode="svd")
    rd.fit(X, Y)
    return rd.coef_.T
107 |
108 |
def ridge_by_lambda_sk(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])):
    """Compute validation errors of scikit-learn ridge regression for several lambdas.

    Args:
        X (numpy.ndarray): Training features.
        Y (numpy.ndarray): Training targets, 2-D with one column per target.
        Xval (numpy.ndarray): Validation features.
        Yval (numpy.ndarray): Validation targets.
        lambdas (numpy.ndarray): Regularization strengths to evaluate.

    Returns:
        numpy.ndarray: Array of shape (n_lambdas, n_targets) holding
        ``1 - R2`` on the validation set for each lambda and target.
    """
    error = np.zeros((lambdas.shape[0], Y.shape[1]))
    for position in range(lambdas.shape[0]):
        coefs = ridge_sk(X, Y, lambdas[position])
        validation_r2 = R2(np.dot(Xval, coefs), Yval)
        error[position] = 1 - validation_r2
    return error
116 |
117 |
def ridge_svd(X, Y, lmbda):
    """
    Ridge regression using singular value decomposition (SVD).

    With ``X = U diag(s) V^T`` the ridge solution is
    ``W = V diag(s / (s**2 + lmbda)) U^T Y``.

    Args:
        X (numpy.ndarray): Design matrix of shape (n_samples, n_features).
        Y (numpy.ndarray): Targets with n_samples rows.
        lmbda (float): Regularization strength.

    Returns:
        numpy.ndarray: Weights with n_features rows.
    """
    U, s, Vt = svd(X, full_matrices=False)
    d = s / (s**2 + lmbda)
    # svd returns V transposed; mapping back to feature space needs V
    # (= Vt.T). Using Vt here gives wrong weights and fails for n < p.
    return np.dot(Vt.T, np.diag(d).dot(U.T.dot(Y)))
125 |
126 |
def ridge_by_lambda_svd(X, Y, Xval, Yval, lambdas=np.array([0.1, 1, 10, 100, 1000])):
    """
    Calculate the validation error of ridge regression using SVD for different lambdas.

    The SVD of ``X`` is computed once and reused for every lambda.

    Args:
        X (numpy.ndarray): Training features of shape (n_samples, n_features).
        Y (numpy.ndarray): Training targets, 2-D with one column per target.
        Xval (numpy.ndarray): Validation features.
        Yval (numpy.ndarray): Validation targets.
        lambdas (numpy.ndarray): Regularization strengths to evaluate.

    Returns:
        numpy.ndarray: Array of shape (n_lambdas, n_targets) holding
        ``1 - R2`` on the validation set for each lambda and target.
    """
    error = np.zeros((lambdas.shape[0], Y.shape[1]))
    U, s, Vt = svd(X, full_matrices=False)
    UtY = U.T.dot(Y)  # does not depend on lambda
    for idx, lmbda in enumerate(lambdas):
        d = s / (s**2 + lmbda)
        # svd returns V transposed; the ridge weights are V diag(d) U^T Y.
        weights = np.dot(Vt.T, d[:, np.newaxis] * UtY)
        error[idx] = 1 - R2(np.dot(Xval, weights), Yval)
    return error
138 |
139 |
def cross_val_ridge(
    train_features,
    train_data,
    n_splits=10,
    lambdas=np.array([10**i for i in range(-6, 10)]),
    method="plain",
    do_plot=False,
):
    """
    Cross validation for ridge regression.

    For each K-fold split, the validation loss of every lambda is computed
    on z-scored folds and summed over splits. Each target (voxel) then gets
    the lambda with the smallest summed loss, and the final weights are
    fitted on the full ``train_features`` / ``train_data`` as passed in
    (they are not z-scored again here).

    Args:
        train_features (array): Array of training features,
            shape (n_samples, n_features).
        train_data (array): Array of training data, shape (n_samples, n_voxels).
        n_splits (int): Number of K-fold splits. Default is 10.
        lambdas (array): Array of lambda values for Ridge regression.
            Default is [10^i for i in range(-6, 10)].
        method (str): One of "plain", "svd" or "ridge_sk"; selects the loss
            evaluator and the weight solver. Default is "plain".
        do_plot (bool): If True, show the loss matrices and the weights with
            matplotlib. Default is False.

    Returns:
        weightMatrix (array): Array of weights for the Ridge regression,
            shape (n_features, n_voxels).
        r (array): Array with the selected regularization parameter of each
            voxel.

    Raises:
        KeyError: If ``method`` is not one of the supported names.
    """
    # Loss of the regressor for every lambda on a train/validation split.
    loss_by_lambda = {
        "plain": ridge_by_lambda,
        "svd": ridge_by_lambda_svd,
        "ridge_sk": ridge_by_lambda_sk,
    }[method]
    # Solver for the final weights.
    solve_weights = {"plain": ridge, "svd": ridge_svd, "ridge_sk": ridge_sk}[method]

    if do_plot:
        # Imported lazily so matplotlib is only required when plotting.
        import matplotlib.pyplot as plt

    # Summed validation loss, one row per lambda and one column per voxel.
    r_cv = np.zeros((lambdas.shape[0], train_data.shape[1]))

    for trn, val in KFold(n_splits=n_splits).split(train_data):
        cost = loss_by_lambda(
            zscore(train_features[trn]),
            zscore(train_data[trn]),
            zscore(train_features[val]),
            zscore(train_data[val]),
            lambdas=lambdas,
        )

        if do_plot:
            plt.figure()
            plt.imshow(cost, aspect="auto")

        r_cv += cost

    if do_plot:  # show loss
        plt.figure()
        plt.imshow(r_cv, aspect="auto", cmap="RdBu_r")

    argmin_lambda = np.argmin(r_cv, axis=0)  # best lambda index per voxel
    weights = np.zeros((train_features.shape[1], train_data.shape[1]))
    # Fit once per selected lambda for all voxels sharing it; this is much
    # faster than iterating over voxels.
    for idx_lambda in np.unique(argmin_lambda):
        idx_vox = argmin_lambda == idx_lambda
        weights[:, idx_vox] = solve_weights(
            train_features, train_data[:, idx_vox], lambdas[idx_lambda]
        )

    if do_plot:  # show the weights
        plt.figure()
        plt.imshow(weights, aspect="auto", cmap="RdBu_r", vmin=-0.5, vmax=0.5)

    return weights, lambdas[argmin_lambda]
220 |
--------------------------------------------------------------------------------
/variance_partitioning.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "2ed69d89",
6 | "metadata": {},
7 | "source": [
8 | "### This notebook demonstrates the variance partitioning process for stacking or concatenation. The experiment is conducted in either a forward or backward style, followed by loading the results and applying the \"95% criteria\" to determine feature attribution."
9 | ]
10 | },
11 | {
12 | "cell_type": "markdown",
13 | "id": "de68d00f",
14 | "metadata": {},
15 | "source": [
16 | "## Forward"
17 | ]
18 | },
19 | {
20 | "cell_type": "code",
21 | "execution_count": null,
22 | "id": "b25746b2",
23 | "metadata": {},
24 | "outputs": [],
25 | "source": [
26 | "import pickle\n",
27 | "subj = 1\n",
28 | "r2s = []\n",
29 | "r2sr = []\n",
30 | "\n",
31 | "\n",
32 | "# Get R2 for using only the first layer (Conv1 in AlexNet) and the last layer(FC7 in AlexNet)\n",
33 | "r2 = []\n",
34 | "for i in range(1,23):\n",
35 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_r2s_{}.npy'.format(0,i)).squeeze())\n",
36 | "print(np.hstack(r2).shape)\n",
37 | "r2_1 = np.hstack(r2)[0,:]\n",
38 | "r2_7 = np.hstack(r2)[-1,:]\n",
39 | "\n",
40 | "\n",
41 | "# Loading forward experiments results\n",
42 | "for k in range(6):\n",
43 | " r2 = []\n",
44 | " for i in range(1,23):\n",
45 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_stacked_r2s_{}.npy'.format(k,i)).squeeze())\n",
46 | " r2 = np.hstack(r2)\n",
47 | " r2s.append(r2)\n",
48 | "r2s.append(r2_7) \n",
49 | "r2s = np.array(r2s)\n",
50 | "print(r2s.shape)\n",
51 | " \n",
52 | "\n",
53 | "# Calculate variance by difference of R2\n",
54 | "vp = np.zeros((r2s.shape[0],r2s.shape[1]))\n",
55 | "vp_square = np.zeros((r2s.shape[0],r2s.shape[1]))\n",
56 | "for j in range(vp.shape[0]-1):\n",
57 | " vp[j,:] = np.sqrt(r2s[j,:] - r2s[j+1,:])\n",
58 | " vp_square[j,:] = (r2s[j,:] - r2s[j+1,:])\n",
59 | " \n",
60 | "vp[-1] = np.sqrt(r2_7)\n",
61 | "vp_square[-1,:] = r2_7\n",
62 | "\n",
63 | "\n",
64 | "print(vp.shape)\n",
65 | "print(vp_square.shape)\n"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "id": "f71d0411",
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "# Apply our 95% criterion in forward style:\n",
76 | "# if removing the features of one specific layer lowers the encoding performance to less than 95% of the performance obtained with features from all layers,\n",
77 | "# we assign the voxel to that layer as our feature attribution result.\n",
78 | "\n",
79 | "\n",
80 | "vp_sel_layer = np.zeros(r2s.shape[1])\n",
81 | "\n",
82 | "\n",
83 | "for i in range(r2s.shape[1]):\n",
84 | " if r2s[0,i]<=0:\n",
85 | " vp_sel_layer[i] = float('nan')\n",
86 | " continue\n",
87 | " if r2s[1,i]<0.95*r2s[0,i]:\n",
88 | " vp_sel_layer[i] = 1\n",
89 | " continue\n",
90 | " if r2s[2,i]<0.95*r2s[0,i]:\n",
91 | " vp_sel_layer[i] = 2\n",
92 | " continue\n",
93 | " if r2s[3,i]<0.95*r2s[0,i]:\n",
94 | " vp_sel_layer[i] = 3\n",
95 | " continue\n",
96 | " if r2s[4,i]<0.95*r2s[0,i]:\n",
97 | " vp_sel_layer[i] = 4\n",
98 | " continue\n",
99 | " if r2s[5,i]<0.95*r2s[0,i]:\n",
100 | " vp_sel_layer[i] = 5\n",
101 | " continue\n",
102 | " if r2s[6,i]<0.95*r2s[0,i]:\n",
103 | " vp_sel_layer[i] = 6\n",
104 | " continue\n",
105 | " vp_sel_layer[i] = 7\n",
106 | " \n",
107 | " \n",
108 | "print(vp_sel_layer)\n"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": null,
114 | "id": "a6396b02",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": []
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": null,
122 | "id": "c91cc1cb",
123 | "metadata": {},
124 | "outputs": [],
125 | "source": []
126 | },
127 | {
128 | "cell_type": "markdown",
129 | "id": "ff7877ef",
130 | "metadata": {},
131 | "source": [
132 | "## Backward"
133 | ]
134 | },
135 | {
136 | "cell_type": "code",
137 | "execution_count": null,
138 | "id": "ee6bf576",
139 | "metadata": {},
140 | "outputs": [],
141 | "source": [
142 | "import pickle\n",
143 | "subj = 1\n",
144 | "r2sr = []\n",
145 | "\n",
146 | "\n",
147 | "# Get R2 for using only the first layer (Conv1 in AlexNet) and the last layer(FC7 in AlexNet)\n",
148 | "r2 = []\n",
149 | "for i in range(1,23):\n",
150 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_r2s_{}.npy'.format(0,i)).squeeze())\n",
151 | "print(np.hstack(r2).shape)\n",
152 | "r2_1 = np.hstack(r2)[0,:]\n",
153 | "r2_7 = np.hstack(r2)[-1,:]\n",
154 | "\n",
155 | "\n",
156 | "\n",
157 | "# Loading backward experiments results\n",
158 | "for k in range(6):\n",
159 | " r2 = []\n",
160 | " for i in range(1,23):\n",
161 | " r2.append(np.load('/home/lrg1213/DATA1/subj1_stack_alexnet_vp_1024/{}_r_stacked_r2s_{}.npy'.format(k,i)).squeeze())\n",
162 | " r2 = np.hstack(r2)\n",
163 | " r2sr.append(r2)\n",
164 | "r2sr.append(r2_1) \n",
165 | "r2sr = np.array(r2sr)\n",
166 | "\n",
167 | "print(r2sr.shape)\n",
168 | " \n",
169 | "\n",
170 | "# Calculate variance by difference of R2\n",
171 | "vpr = np.zeros((r2sr.shape[0],r2sr.shape[1]))\n",
172 | "vpr_square = np.zeros((r2sr.shape[0],r2sr.shape[1]))\n",
173 | "for j in range(vpr.shape[0]-1):\n",
174 | " vpr[j,:] = np.sqrt(r2sr[j,:] - r2sr[j+1,:])\n",
175 | " vpr_square[j,:] = (r2sr[j,:] - r2sr[j+1,:])\n",
176 | "\n",
177 | "vpr[-1] = np.sqrt(r2_1)\n",
178 | "vpr_square[-1,:] = r2_1\n",
179 | "\n",
180 | "print(vpr.shape)\n",
181 | "print(vpr_square.shape)\n"
182 | ]
183 | },
184 | {
185 | "cell_type": "code",
186 | "execution_count": null,
187 | "id": "760923a9",
188 | "metadata": {},
189 | "outputs": [],
190 | "source": [
191 | "# Apply our 95% criterion in backward style:\n",
192 | "# if removing the features of one specific layer lowers the encoding performance to less than 95% of the performance obtained with features from all layers,\n",
193 | "# we assign the voxel to that layer as our feature attribution result.\n",
194 | "\n",
195 | "\n",
196 | "vpr_sel_layer = np.zeros(r2sr.shape[1])\n",
197 | "\n",
198 | "for i in range(r2sr.shape[1]):\n",
199 | " if r2sr[0,i]<=0:\n",
200 | " vpr_sel_layer[i] = float('nan')\n",
201 | " continue\n",
202 | " if r2sr[1,i]<0.95*r2sr[0,i]:\n",
203 | " vpr_sel_layer[i] = 7\n",
204 | " continue\n",
205 | " if r2sr[2,i]<0.95*r2sr[0,i]:\n",
206 | " vpr_sel_layer[i] = 6\n",
207 | " continue\n",
208 | " if r2sr[3,i]<0.95*r2sr[0,i]:\n",
209 | " vpr_sel_layer[i] = 5\n",
210 | " continue\n",
211 | " if r2sr[4,i]<0.95*r2sr[0,i]:\n",
212 | " vpr_sel_layer[i] = 4\n",
213 | " continue\n",
214 | " if r2sr[5,i]<0.95*r2sr[0,i]:\n",
215 | " vpr_sel_layer[i] = 3\n",
216 | " continue\n",
217 | " if r2sr[6,i]<0.95*r2sr[0,i]:\n",
218 | " vpr_sel_layer[i] = 2\n",
219 | " continue\n",
220 | " vpr_sel_layer[i] = 1\n",
221 | " \n",
222 | "print(vpr_sel_layer)\n"
223 | ]
224 | }
225 | ],
226 | "metadata": {
227 | "kernelspec": {
228 | "display_name": "Python 3 (ipykernel)",
229 | "language": "python",
230 | "name": "python3"
231 | },
232 | "language_info": {
233 | "codemirror_mode": {
234 | "name": "ipython",
235 | "version": 3
236 | },
237 | "file_extension": ".py",
238 | "mimetype": "text/x-python",
239 | "name": "python",
240 | "nbconvert_exporter": "python",
241 | "pygments_lexer": "ipython3",
242 | "version": "3.8.13"
243 | }
244 | },
245 | "nbformat": 4,
246 | "nbformat_minor": 5
247 | }
248 |
--------------------------------------------------------------------------------
/simulate.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "7e439802",
6 | "metadata": {},
7 | "source": [
8 | "# simulate stacking experiment"
9 | ]
10 | },
11 | {
12 | "cell_type": "code",
13 | "execution_count": 1,
14 | "id": "f2850c2e-1883-4126-8140-901d5898eb7e",
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import numpy as np\n",
19 | "from simulation_experiment import synexp\n",
20 | "from visualize_sim import sim_plots\n",
21 | "import pandas as pd\n",
22 | "\n",
23 | "n_runs = 50 # how many simulation experiments\n",
24 | "run_sim = True # run experiment?\n",
25 | "plot_sim = True # plot results?\n",
26 | "\n",
27 | "type_sim = 1 # what parameter to vary?\n",
28 | "# 1 - Vary the dimensionality of X1 with respect to other feature spaces\n",
29 | "# 2 - Vary the weight of X1 with respect to other feature spaces\n",
30 | "# 3 - Vary the noise level\n",
31 | "# 4 - Vary the number of samples\n",
32 | "# 5 - Vary feature space correlation\n",
33 | "\n",
34 | "# Note, the experiments below are from the paper, you can change the parameters as you want\n",
35 | "# It might be more feasible to make a script and call synexp non-interactively, especially for large experiments"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "id": "b8faefb0-2b21-46b2-9d2b-6b03b03b5251",
42 | "metadata": {},
43 | "outputs": [
44 | {
45 | "name": "stdout",
46 | "output_type": "stream",
47 | "text": [
48 | "iteration number 1\n",
49 | "data sampled\n",
50 | "time for stacking: 5.576988935470581\n",
51 | "time for iteration: 16.36051297187805\n",
52 | "data sampled\n",
53 | "time for stacking: 5.68248987197876\n",
54 | "time for iteration: 17.94617486000061\n",
55 | "data sampled\n",
56 | "time for stacking: 10.574278831481934\n",
57 | "time for iteration: 24.7062509059906\n",
58 | "data sampled\n",
59 | "time for stacking: 7.50465989112854\n",
60 | "time for iteration: 21.93230676651001\n",
61 | "data sampled\n",
62 | "time for stacking: 8.925495147705078\n",
63 | "time for iteration: 29.840813159942627\n",
64 | "first iteration time: 111, total 5556\n",
65 | "iteration number 2\n",
66 | "data sampled\n",
67 | "time for stacking: 13.184641122817993\n",
68 | "time for iteration: 35.82826232910156\n",
69 | "data sampled\n",
70 | "time for stacking: 15.410903215408325\n",
71 | "time for iteration: 36.77866816520691\n",
72 | "data sampled\n",
73 | "time for stacking: 12.556435108184814\n",
74 | "time for iteration: 30.479616165161133\n",
75 | "data sampled\n",
76 | "time for stacking: 12.968762159347534\n",
77 | "time for iteration: 31.550554275512695\n",
78 | "data sampled\n",
79 | "time for stacking: 9.648847818374634\n",
80 | "time for iteration: 26.950864791870117\n",
81 | "iteration number 3\n",
82 | "data sampled\n",
83 | "time for stacking: 9.796719074249268\n",
84 | "time for iteration: 26.401182889938354\n",
85 | "data sampled\n",
86 | "time for stacking: 8.374092102050781\n",
87 | "time for iteration: 23.136788845062256\n",
88 | "data sampled\n",
89 | "time for stacking: 8.078734874725342\n",
90 | "time for iteration: 30.22462296485901\n",
91 | "data sampled\n",
92 | "time for stacking: 13.18555498123169\n",
93 | "time for iteration: 34.938925981521606\n",
94 | "data sampled\n",
95 | "time for stacking: 13.350465059280396\n",
96 | "time for iteration: 33.98402714729309\n",
97 | "iteration number 4\n",
98 | "data sampled\n",
99 | "time for stacking: 11.148561239242554\n",
100 | "time for iteration: 29.419289112091064\n",
101 | "data sampled\n",
102 | "time for stacking: 10.736790180206299\n",
103 | "time for iteration: 30.648571968078613\n",
104 | "data sampled\n",
105 | "time for stacking: 12.025328159332275\n",
106 | "time for iteration: 28.129180908203125\n",
107 | "data sampled\n",
108 | "time for stacking: 9.363476276397705\n",
109 | "time for iteration: 22.97734808921814\n",
110 | "data sampled\n",
111 | "time for stacking: 7.306324005126953\n",
112 | "time for iteration: 20.82240104675293\n",
113 | "iteration number 5\n",
114 | "data sampled\n",
115 | "time for stacking: 7.735720157623291\n",
116 | "time for iteration: 24.28347420692444\n",
117 | "data sampled\n",
118 | "time for stacking: 8.583776950836182\n",
119 | "time for iteration: 23.722854137420654\n",
120 | "data sampled\n",
121 | "time for stacking: 8.153543710708618\n",
122 | "time for iteration: 24.09500002861023\n",
123 | "data sampled\n",
124 | "time for stacking: 12.154590129852295\n"
125 | ]
126 | }
127 | ],
128 | "source": [
129 | "import random\n",
130 | "random.seed(10)\n",
131 | "\n",
132 | "type_sim = 5 # what parameter to vary?\n",
133 | "\n",
134 | "\n",
135 | "version = 'v5'\n",
136 | "\n",
137 | "# ds_settings - dimensionality of each feature space\n",
138 | "# alpha - stacking weights for simulated data\n",
139 | "# n - number of samples \n",
140 | "# sigma - noise level\n",
141 | "\n",
142 | "if type_sim ==1: # Vary the dimensionality of X1 with respect to other feature spaces\n",
143 | "\n",
144 | " ds_settings = [[2,10,10,10],[5,10,10,10],[10,10,10,10],[20,10,10,10],[40,10,10,10]]\n",
145 | " n = 100\n",
146 | " sigma = 0.2\n",
147 | " correl = 0.2\n",
148 | " alpha = [0.5,0.2,0.2,0.1] \n",
149 | "\n",
150 | " print('setup')\n",
151 | "\n",
152 | " if run_sim:\n",
153 | " print('running')\n",
154 | " R_d3 = synexp(runs = n_runs,sim_type = 'Feat_Dim_ratio',samples_settings=n,ds_settings=ds_settings,y_dim=2,\n",
155 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n",
156 | " \n",
157 | " R_d3.to_pickle('sweep_d_cluster_{}.npy'.format(version))\n",
158 | " \n",
159 | " if plot_sim:\n",
160 | " R_d3 = pd.read_pickle('sweep_d_cluster_{}.npy'.format(version))\n",
161 | " var = dict(sigma = sigma, d_sum = sum(ds_settings[0][1:]), n = n,alphas = alpha)\n",
162 | " ratio = [d[0] for d in ds_settings]\n",
163 | " sim_plots(R_d3,'Feat_Dim_ratio', ratio, filename = 'sweep_d',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n",
164 | " ylim2=[0,0.5])\n",
165 | "\n",
166 | "elif type_sim ==2: # Vary the weight of X1 with respect to other feature spaces\n",
167 | "\n",
168 | " ds = [10,10,10,10]\n",
169 | " n = 100\n",
170 | " correl = 0.2\n",
171 | " sigma = 0.2\n",
172 | " alpha = [[0.1,0.2,0.5,0.2],[0.3,0.2,0.3,0.2],[0.5,0.2,0.2,0.1],[0.7,0.1,0,0.2],[0.9,0.1,0.,0.]] \n",
173 | " for a in alpha:\n",
174 | " assert np.round(sum(a),2)==1\n",
175 | "\n",
176 | " if run_sim:\n",
177 | " R_C3 = synexp(runs = n_runs,sim_type = 'Cond',samples_settings=n,ds_settings=ds,y_dim=2,\n",
178 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n",
179 | " \n",
180 | " R_C3.to_pickle('sweep_c_cluster_{}.npy'.format(version))\n",
181 | " \n",
182 | " if plot_sim:\n",
183 | " R_C3 = pd.read_pickle('sweep_c_cluster_{}.npy'.format(version))\n",
184 | " var = dict(sigma = sigma, ds = ds, n = n,alphas = alpha)\n",
185 | " \n",
186 | " ratio = [a[0] for a in alpha]\n",
187 | " sim_plots(R_C3,'Cond', ratio, filename = 'sweep_alpha',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n",
188 | " ylim2=[0,0.5])\n",
189 | "\n",
190 | "\n",
191 | "elif type_sim ==3: # Vary the noise level\n",
192 | "\n",
193 | " ds = [10,10,10,10]\n",
194 | " n = 100\n",
195 | " sigma = [0,0.2,0.5,1,1.5]\n",
196 | " alpha = [0.5,0.2,0.2,0.1]\n",
197 | " correl = 0.2\n",
198 | " assert np.round(sum(alpha),2)==1\n",
199 | "\n",
200 | " if run_sim:\n",
201 | " R_sigma3 =synexp(runs = n_runs,sim_type = 'noise',samples_settings=n,ds_settings=ds,y_dim=2,\n",
202 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n",
203 | " \n",
204 | " R_sigma3.to_pickle('sweep_sigma_cluster_{}.npy'.format(version))\n",
205 | " \n",
206 | " if plot_sim:\n",
207 | " R_sigma3 = pd.read_pickle('sweep_sigma_cluster_{}.npy'.format(version))\n",
208 | " var = dict( ds = ds, alphas = alpha,n = n)\n",
209 | " sim_plots(R_sigma3,'noise', sigma, filename = 'sweep_sigma',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n",
210 | " ylim2=[0,0.5])\n",
211 | "\n",
212 | "elif type_sim ==4: # Vary the number of samples\n",
213 | " \n",
214 | " # ds = [10,10,10,10]\n",
215 | " ds = [100,100,100,100]\n",
216 | " # n = [30,40,60,100,200]\n",
217 | " n = [100,200,400, 800]\n",
218 | " sigma = 0.2\n",
219 | " alpha = [0.5,0.2,0.2,0.1]\n",
220 | " correl = 0.2\n",
221 | " assert np.round(sum(alpha),2)==1\n",
222 | "\n",
223 | " if run_sim:\n",
224 | " R_n3 = synexp(runs = n_runs,sim_type = 'Sample_Dim_ratio',samples_settings=n,ds_settings=ds,y_dim=2,\n",
225 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n",
226 | " \n",
227 | " R_n3.to_pickle('sweep_n_cluster_{}.npy'.format(version))\n",
228 | " \n",
229 | " if plot_sim:\n",
230 | " R_n3 = pd.read_pickle('sweep_n_cluster_{}.npy'.format(version))\n",
231 | " \n",
232 | " var = dict(sigma = sigma, ds = ds, alpha = alpha)\n",
233 | " sim_plots(R_n3,'Sample_Dim_ratio', n, filename = 'sweep_n',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n",
234 | " ylim2=[0,0.5])\n",
235 | "\n",
236 | " \n",
237 | "elif type_sim ==5: # Vary the correlation\n",
238 | " \n",
239 | " # ds = [10,10,10,10] \n",
240 | " ds = [100,100,100,100] \n",
241 | " # n = 100\n",
242 | " n = 400\n",
243 | " sigma = 0.2\n",
244 | " alpha = [0.5,0.2,0.2,0.1]\n",
245 | " correl = [0,0.1,0.2, 0.5, 0.8]\n",
246 | " assert np.round(sum(alpha),2)==1\n",
247 | "\n",
248 | " if run_sim:\n",
249 | " R_rho = synexp(runs = n_runs,sim_type = 'correl',samples_settings=n,ds_settings=ds,y_dim=2,\n",
250 | " correl = correl, alpha_settings = alpha,scale = 0.5,y_noise_settings=sigma)\n",
251 | " \n",
252 | " R_rho.to_pickle('sweep_rho_cluster_{}.npy'.format(version))\n",
253 | " \n",
254 | " if plot_sim:\n",
255 | " R_rho = pd.read_pickle('sweep_rho_cluster_{}.npy'.format(version))\n",
256 | " \n",
257 | " var = dict(sigma = sigma, ds = ds, alpha = alpha, n = n)\n",
258 | " sim_plots(R_rho,'correl', correl, filename = 'sweep_rho',var_dict = var,ylim0=[-0.1,1],ylim1=[-.2,0.6],\n",
259 | " ylim2=[0,0.5])"
260 | ]
261 | },
262 | {
263 | "cell_type": "code",
264 | "execution_count": null,
265 | "id": "a58dd81b-8f9d-443e-8302-4c6c1bad5f04",
266 | "metadata": {},
267 | "outputs": [],
268 | "source": []
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": null,
273 | "id": "8dba0b9e-57f0-47fc-bc8e-cf741a92a2a2",
274 | "metadata": {},
275 | "outputs": [],
276 | "source": []
277 | }
278 | ],
279 | "metadata": {
280 | "kernelspec": {
281 | "display_name": "Python 3 (ipykernel)",
282 | "language": "python",
283 | "name": "python3"
284 | },
285 | "language_info": {
286 | "codemirror_mode": {
287 | "name": "ipython",
288 | "version": 3
289 | },
290 | "file_extension": ".py",
291 | "mimetype": "text/x-python",
292 | "name": "python",
293 | "nbconvert_exporter": "python",
294 | "pygments_lexer": "ipython3",
295 | "version": "3.9.21"
296 | },
297 | "toc": {
298 | "base_numbering": 1,
299 | "nav_menu": {},
300 | "number_sections": false,
301 | "sideBar": true,
302 | "skip_h1_title": true,
303 | "title_cell": "Table of Contents",
304 | "title_sidebar": "Contents",
305 | "toc_cell": false,
306 | "toc_position": {},
307 | "toc_section_display": true,
308 | "toc_window_display": false
309 | }
310 | },
311 | "nbformat": 4,
312 | "nbformat_minor": 5
313 | }
314 |
--------------------------------------------------------------------------------
/stacking.py:
--------------------------------------------------------------------------------
1 | from cvxopt import matrix, solvers
2 | import numpy as np
3 | from scipy.stats import zscore
4 | from ridge_tools import cross_val_ridge, R2, ridge
5 |
6 | # Set option to not show progress in CVXOPT solver
7 | solvers.options["show_progress"] = False
8 |
9 |
def get_cv_indices(n_samples, n_folds):
    """Label every sample with the contiguous cross-validation fold it belongs to.

    All folds but the last contain ``floor(n_samples / n_folds)`` samples;
    the last fold absorbs whatever remains.

    Args:
        n_samples (int): Number of samples to generate indices for.
        n_folds (int): Number of folds to use in cross-validation.

    Returns:
        numpy.ndarray: Float array of fold labels with shape (n_samples,).
    """
    fold_size = int(np.floor(n_samples / n_folds))  # samples per regular fold
    last_fold = n_folds - 1

    cv_indices = np.zeros((n_samples))
    # Everything from the end of the regular folds onward is the last fold.
    cv_indices[last_fold * fold_size :] = last_fold
    for fold in range(last_fold):
        begin = fold * fold_size
        cv_indices[begin : begin + fold_size] = fold
    return cv_indices
26 |
27 |
def feat_ridge_CV(
    train_features,
    train_targets,
    test_features,
    method="cross_val_ridge",
    n_folds=5,
    score_function=R2,
):
    """Train a ridge regression model with cross-validation and predict on test_features.

    Out-of-fold predictions for the training set are produced with
    ``n_folds`` contiguous folds; each fold's model is fitted on the z-scored
    remaining samples. The model used for the test predictions is then
    refitted on all training data with ``cross_val_ridge``, whatever
    ``method`` is.

    Args:
        train_features (numpy.ndarray): Array of shape (n_samples, n_features) containing the training features.
        train_targets (numpy.ndarray): Array of shape (n_samples, n_targets) containing the training targets.
        test_features (numpy.ndarray): Array of shape (n_test_samples, n_features) containing the test features.
        method (str): Method to use for the per-fold ridge regression. Options are "simple_ridge"
            (fixed regularization of 100) and "cross_val_ridge". Defaults to "cross_val_ridge".
        n_folds (int): Number of folds to use in cross-validation. Defaults to 5.
        score_function (callable): Scoring function called as ``score_function(predictions, targets)``.
            Defaults to R2.

    Returns:
        tuple: Tuple containing:
            - train_preds (numpy.ndarray): Array of shape (n_samples, n_targets) containing the
              out-of-fold training set predictions.
            - train_err (numpy.ndarray): Array of shape (n_samples, n_targets) containing the
              training set errors (targets minus predictions).
            - test_preds (numpy.ndarray): Array of shape (n_test_samples, n_targets) containing the
              test set predictions.
            - train_scores: Result of ``score_function`` on the out-of-fold training predictions.
            - train_variances (numpy.ndarray): Array of shape (n_targets,) containing the variances
              of the training set predictions.

    Raises:
        ValueError: If ``method`` is not a supported name (only checked when
            there are non-zero features to fit).
    """
    lambdas = np.array([10**i for i in range(-6, 10)])
    train_preds = np.zeros_like(train_targets)

    if np.all(train_features == 0):
        # If there are no predictors, use zero weights and zero predictions;
        # there is nothing to fit.
        weights = np.zeros((train_features.shape[1], train_targets.shape[1]))
    else:
        # Use cross-validation to obtain out-of-fold training predictions
        cv_indices = get_cv_indices(train_targets.shape[0], n_folds=n_folds)

        for i_cv in range(n_folds):
            held_out = cv_indices == i_cv
            train_targets_cv = np.nan_to_num(zscore(train_targets[~held_out]))
            train_features_cv = np.nan_to_num(zscore(train_features[~held_out]))
            test_features_cv = np.nan_to_num(zscore(train_features[held_out]))

            if method == "simple_ridge":
                # Fixed regularization parameter. Fit on this fold's training
                # split only, so the held-out samples do not leak into the
                # model that predicts them.
                weights = ridge(train_features_cv, train_targets_cv, 100)
            elif method == "cross_val_ridge":
                # Use cross-validation to select the best regularization parameter
                weights, __ = cross_val_ridge(
                    train_features_cv,
                    train_targets_cv,
                    n_splits=5,
                    lambdas=lambdas,
                    do_plot=False,
                    method="plain",
                )
            else:
                raise ValueError(
                    "Unknown method {!r}; expected 'simple_ridge' or "
                    "'cross_val_ridge'".format(method)
                )

            # Make predictions on the current fold of the data
            train_preds[held_out] = test_features_cv.dot(weights)

        # Retrain the model on all of the training data
        weights, __ = cross_val_ridge(
            train_features,
            train_targets,
            n_splits=5,
            lambdas=lambdas,
            do_plot=False,
            method="plain",
        )

    # Calculate prediction error on the training set
    train_err = train_targets - train_preds

    # Make predictions on the test set using the retrained model
    test_preds = np.dot(test_features, weights)

    # Calculate the score on the training set
    train_scores = score_function(train_preds, train_targets)
    train_variances = np.var(train_preds, axis=0)

    return train_preds, train_err, test_preds, train_scores, train_variances
122 |
123 |
124 | import numpy as np
125 | from cvxopt import matrix, solvers
126 |
127 |
128 | def stacking_fmri(
129 | train_data,
130 | test_data,
131 | train_features,
132 | test_features,
133 | method="cross_val_ridge",
134 | score_f=R2,
135 | ):
136 | """
137 | Stacks predictions from different feature spaces and uses them to make final predictions.
138 |
139 | Args:
140 | train_data (ndarray): Training data of shape (n_time_train, n_voxels)
141 | test_data (ndarray): Testing data of shape (n_time_test, n_voxels)
142 | train_features (list): List of training feature spaces, each of shape (n_time_train, n_dims)
143 | test_features (list): List of testing feature spaces, each of shape (n_time_test, n_dims)
144 | method (str): Name of the method used for training. Default is 'cross_val_ridge'.
145 | score_f (callable): Scikit-learn scoring function to use for evaluation. Default is mean_squared_error.
146 |
147 | Returns:
148 | Tuple of ndarrays:
149 | - r2s: Array of shape (n_features, n_voxels) containing unweighted R2 scores for each feature space and voxel
150 | - stacked_r2s: Array of shape (n_voxels,) containing R2 scores for the stacked predictions of each voxel
151 | - r2s_weighted: Array of shape (n_features, n_voxels) containing R2 scores for each feature space weighted by stacking weights
152 | - r2s_train: Array of shape (n_features, n_voxels) containing R2 scores for each feature space and voxel in the training set
153 | - stacked_train_r2s: Array of shape (n_voxels,) containing R2 scores for the stacked predictions of each voxel in the training set
154 | - S: Array of shape (n_voxels, n_features) containing the stacking weights for each voxel
155 | """
156 |
157 | # Number of time points in the test set
158 | n_time_test = test_data.shape[0]
159 |
160 | # Check that the number of voxels is the same in the training and test sets
161 | assert train_data.shape[1] == test_data.shape[1]
162 | n_voxels = train_data.shape[1]
163 |
164 | # Check that the number of feature spaces is the same in the training and test sets
165 | assert len(train_features) == len(test_features)
166 | n_features = len(train_features)
167 |
168 | # Array to store R2 scores for each feature space and voxel
169 | r2s = np.zeros((n_features, n_voxels))
170 | # Array to store R2 scores for each feature space and voxel in the training set
171 | r2s_train = np.zeros((n_features, n_voxels))
172 | # Array to store variance explained by the model for each feature space and voxel in the training set
173 | var_train = np.zeros((n_features, n_voxels))
174 | # Array to store R2 scores for each feature space weighted by stacking weights
175 | r2s_weighted = np.zeros((n_features, n_voxels))
176 |
177 | # Array to store stacked predictions for each voxel
178 | stacked_pred = np.zeros((n_time_test, n_voxels))
179 | # Dictionary to store predictions for each feature space and voxel in the training set
180 | preds_train = {}
181 | # Dictionary to store predictions for each feature space and voxel in the test set
182 | preds_test = np.zeros((n_features, n_time_test, n_voxels))
183 | # Array to store weighted predictions for each feature space and voxel in the test set
184 | weighted_pred = np.zeros((n_features, n_time_test, n_voxels))
185 |
186 | # normalize data by TRAIN/TEST
187 | train_data = np.nan_to_num(zscore(train_data))
188 | test_data = np.nan_to_num(zscore(test_data))
189 |
190 | train_features = [np.nan_to_num(zscore(F)) for F in train_features]
191 | test_features = [np.nan_to_num(zscore(F)) for F in test_features]
192 |
193 | # initialize an error dictionary to store errors for each feature
194 | err = dict()
195 | preds_train = dict()
196 |
197 | # iterate over each feature and train a model using feature ridge regression
198 | for FEATURE in range(n_features):
199 | (
200 | preds_train[FEATURE],
201 | error,
202 | preds_test[FEATURE, :, :],
203 | r2s_train[FEATURE, :],
204 | var_train[FEATURE, :],
205 | ) = feat_ridge_CV(
206 | train_features[FEATURE], train_data, test_features[FEATURE], method=method
207 | )
208 | err[FEATURE] = error
209 |
210 | # calculate error matrix for stacking
211 | P = np.zeros((n_voxels, n_features, n_features))
212 | for i in range(n_features):
213 | for j in range(n_features):
214 | P[:, i, j] = np.mean(err[i] * err[j], 0)
215 |
216 | # solve the quadratic programming problem to obtain the weights for stacking
217 | q = matrix(np.zeros((n_features)))
218 | G = matrix(-np.eye(n_features, n_features))
219 | h = matrix(np.zeros(n_features))
220 | A = matrix(np.ones((1, n_features)))
221 | b = matrix(np.ones(1))
222 |
223 | S = np.zeros((n_voxels, n_features))
224 | stacked_pred_train = np.zeros_like(train_data)
225 |
226 | for i in range(0, n_voxels):
227 | PP = matrix(P[i])
228 | # solve for stacking weights for every voxel
229 | S[i, :] = np.array(solvers.qp(PP, q, G, h, A, b)["x"]).reshape(n_features)
230 |
231 | # combine the predictions from the individual feature spaces for voxel i
232 | z_test = np.array(
233 | [preds_test[feature_j, :, i] for feature_j in range(n_features)]
234 | )
235 | z_train = np.array(
236 | [preds_train[feature_j][:, i] for feature_j in range(n_features)]
237 | )
238 | # multiply the predictions by S[i,:]
239 | stacked_pred[:, i] = np.dot(S[i, :], z_test)
240 | # combine the training predictions from the individual feature spaces for voxel i
241 | stacked_pred_train[:, i] = np.dot(S[i, :], z_train)
242 |
243 | # compute the R2 score for the stacked predictions on the training data
244 | stacked_train_r2s = score_f(stacked_pred_train, train_data)
245 |
246 | # compute the R2 scores for each individual feature and the weighted feature predictions
247 | for FEATURE in range(n_features):
248 | # weight the predictions according to S:
249 | # weighted single feature space predictions, computed over a fold
250 | weighted_pred[FEATURE, :] = preds_test[FEATURE, :] * S[:, FEATURE]
251 |
252 | for FEATURE in range(n_features):
253 | r2s[FEATURE, :] = score_f(preds_test[FEATURE], test_data)
254 | r2s_weighted[FEATURE, :] = score_f(weighted_pred[FEATURE], test_data)
255 |
256 | # compute the R2 score for the stacked predictions on the test data
257 | stacked_r2s = score_f(stacked_pred, test_data)
258 |
259 | # return the results
260 | return (
261 | r2s,
262 | stacked_r2s,
263 | r2s_weighted,
264 | r2s_train,
265 | stacked_train_r2s,
266 | S,
267 | )
268 |
269 |
270 | def stacking_CV_fmri(data, features, method="cross_val_ridge", n_folds=5, score_f=R2):
271 | """
272 | A function that performs cross-validated feature stacking to predict fMRI
273 | signal from a set of predictors.
274 |
275 | Args:
276 | - data (ndarray): A matrix of fMRI signal data with dimensions n_time x n_voxels.
277 | - features (list): A list of length n_features containing arrays of predictors
278 | with dimensions n_time x n_dim.
279 | - method (str): A string indicating the method to use to train the model. Default is "cross_val_ridge".
280 | - n_folds (int): An integer indicating the number of cross-validation folds to use. Default is 5.
281 | - score_f (function): A function to use for scoring the model. Default is R2.
282 |
283 | Returns:
284 | - A tuple containing the following elements:
285 | - r2s (ndarray): An array of shape (n_features, n_voxels) containing the R2 scores
286 | for each feature and voxel.
287 | - r2s_weighted (ndarray): An array of shape (n_features, n_voxels) containing the R2 scores
288 | for each feature and voxel, weighted by stacking weights.
289 | - stacked_r2s (float): The R2 score for the stacked predictions.
290 | - r2s_train (ndarray): An array of shape (n_features, n_voxels) containing the R2 scores
291 | for each feature and voxel for the training set.
292 | - stacked_train (float): The R2 score for the stacked predictions for the training set.
293 | - S_average (ndarray): An array of shape (n_features, n_voxels) containing the stacking weights
294 | for each feature and voxel.
295 |
296 | """
297 |
298 | n_time, n_voxels = data.shape
299 | n_features = len(features)
300 |
301 | ind = get_cv_indices(n_time, n_folds=n_folds)
302 |
303 | # create arrays to store results
304 | r2s = np.zeros((n_features, n_voxels))
305 | r2s_train_folds = np.zeros((n_folds, n_features, n_voxels))
306 | var_train_folds = np.zeros((n_folds, n_features, n_voxels))
307 | r2s_weighted = np.zeros((n_features, n_voxels))
308 | stacked_train_r2s_fold = np.zeros((n_folds, n_voxels))
309 | stacked_pred = np.zeros((n_time, n_voxels))
310 | preds_test = np.zeros((n_features, n_time, n_voxels))
311 | weighted_pred = np.zeros((n_features, n_time, n_voxels))
312 | S_average = np.zeros((n_voxels, n_features))
313 |
314 | # perform cross-validation by fold
315 | for ind_num in range(n_folds):
316 | # split data into training and testing sets
317 | train_ind = ind != ind_num
318 | test_ind = ind == ind_num
319 | train_data = data[train_ind]
320 | train_features = [F[train_ind] for F in features]
321 | test_data = data[test_ind]
322 | test_features = [F[test_ind] for F in features]
323 |
324 | # normalize data
325 | train_data = np.nan_to_num(zscore(train_data))
326 | test_data = np.nan_to_num(zscore(test_data))
327 |
328 | train_features = [np.nan_to_num(zscore(F)) for F in train_features]
329 | test_features = [np.nan_to_num(zscore(F)) for F in test_features]
330 |
331 | # Store prediction errors and training predictions for each feature
332 | err = dict()
333 | preds_train = dict()
334 | for FEATURE in range(n_features):
335 | (
336 | preds_train[FEATURE],
337 | error,
338 | preds_test[FEATURE, test_ind],
339 | r2s_train_folds[ind_num, FEATURE, :],
340 | var_train_folds[ind_num, FEATURE, :],
341 | ) = feat_ridge_CV(
342 | train_features[FEATURE],
343 | train_data,
344 | test_features[FEATURE],
345 | method=method,
346 | )
347 | err[FEATURE] = error
348 |
349 | # calculate error matrix for stacking
350 | P = np.zeros((n_voxels, n_features, n_features))
351 | for i in range(n_features):
352 | for j in range(n_features):
353 | P[:, i, j] = np.mean(err[i] * err[j], axis=0)
354 |
355 | # Set optimization parameters for computing stacking weights
356 | q = matrix(np.zeros((n_features)))
357 | G = matrix(-np.eye(n_features, n_features))
358 | h = matrix(np.zeros(n_features))
359 | A = matrix(np.ones((1, n_features)))
360 | b = matrix(np.ones(1))
361 |
362 | S = np.zeros((n_voxels, n_features))
363 | stacked_pred_train = np.zeros_like(train_data)
364 |
365 | # Compute stacking weights and combined predictions for each voxel
366 | for i in range(n_voxels):
367 | PP = matrix(P[i])
368 | # solve for stacking weights for every voxel
369 | S[i, :] = np.array(solvers.qp(PP, q, G, h, A, b)["x"]).reshape(
370 | n_features,
371 | )
372 | # combine the predictions from the individual feature spaces for voxel i
373 | z = np.array(
374 | [preds_test[feature_j, test_ind, i] for feature_j in range(n_features)]
375 | )
376 | # multiply the predictions by S[i,:]
377 | stacked_pred[test_ind, i] = np.dot(S[i, :], z)
378 | # combine the training predictions from the individual feature spaces for voxel i
379 | z = np.array(
380 | [preds_train[feature_j][:, i] for feature_j in range(n_features)]
381 | )
382 | stacked_pred_train[:, i] = np.dot(S[i, :], z)
383 |
384 | S_average += S
385 | stacked_train_r2s_fold[ind_num, :] = score_f(stacked_pred_train, train_data)
386 |
387 | # Compute weighted single feature space predictions, computed over a fold
388 | for FEATURE in range(n_features):
389 | weighted_pred[FEATURE, test_ind] = (
390 | preds_test[FEATURE, test_ind] * S[:, FEATURE]
391 | )
392 |
393 | # Compute overall performance metrics
394 | data_zscored = zscore(data)
395 | for FEATURE in range(n_features):
396 | r2s[FEATURE, :] = score_f(preds_test[FEATURE], data_zscored)
397 | r2s_weighted[FEATURE, :] = score_f(weighted_pred[FEATURE], data_zscored)
398 |
399 | stacked_r2s = score_f(stacked_pred, data_zscored)
400 |
401 | r2s_train = r2s_train_folds.mean(0)
402 | stacked_train = stacked_train_r2s_fold.mean(0)
403 | S_average = S_average / n_folds
404 |
405 | # return the results
406 | return (
407 | r2s,
408 | stacked_r2s,
409 | r2s_weighted,
410 | r2s_train,
411 | stacked_train,
412 | S_average,
413 | )
414 |
--------------------------------------------------------------------------------
/stacking_tutorial.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 19,
6 | "id": "extra-bearing",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import numpy as np\n",
11 | "from stacking_fmri import stacking_CV_fmri, stacking_fmri\n",
12 | "from ridge_tools import R2\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import seaborn as sns"
15 | ]
16 | },
17 | {
18 | "cell_type": "code",
19 | "execution_count": 2,
20 | "id": "confidential-punch",
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "### Set up or load features\n",
25 | "\n",
26 | "N_sample = 1000\n",
27 | "dim_X1 = 50\n",
28 | "dim_X2 = 100\n",
29 | "dim_X3 = 25\n",
30 | "\n",
31 | "X1 = np.random.randn(N_sample, dim_X1)\n",
32 | "X2 = np.random.randn(N_sample, dim_X2)\n",
33 | "X3 = np.random.randn(N_sample, dim_X3)\n",
34 | "\n",
35 | "#X1 = np.load('....')"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": 32,
41 | "id": "boring-occurrence",
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "### Set up or load brain data (fMRI, EEG, ....)\n",
46 | "\n",
47 | "dim_Y = 10\n",
48 | "\n",
49 | "# Y = np.random.randn(N_sample, dim_Y)\n",
50 | "Y = 0.3 * X1.dot(np.random.randn(dim_X1, dim_Y)) + \\\n",
51 | " 0.3 * X2.dot(np.random.randn(dim_X2, dim_Y)) + \\\n",
52 | " 0.4 * X3.dot(np.random.randn(dim_X3, dim_Y))\n",
53 | "\n",
54 | "#Y = np.load('....')"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": 33,
60 | "id": "buried-charger",
61 | "metadata": {},
62 | "outputs": [
63 | {
64 | "name": "stdout",
65 | "output_type": "stream",
66 | "text": [
67 | "(3, 250)\n",
68 | "(3, 250)\n",
69 | "(3, 250)\n",
70 | "(3, 250)\n"
71 | ]
72 | }
73 | ],
74 | "source": [
75 | "### Run stacking using multiple features (Xs) and Y\n",
76 | "\n",
77 | "\n",
78 | "## with the outermost cross-validation\n",
79 | "r2s, stacked_r2s, _, _, _, S_average = stacking_CV_fmri(Y, [X1,X2,X3], method = 'cross_val_ridge',n_folds = 4,score_f=R2)",
80 | "\n",
81 | "\n",
82 | "## simple train-test setting (without the outermost cross-validation)\n",
83 | "# r2s, stacked_r2s, _, _, _, S_average = stacking_fmri(Y[0:700], Y[700:], [X1[0:700],X2[0:700],X3[0:700]], [X1[700:],X2[700:],X3[700:]], method = 'cross_val_ridge',score_f=R2)"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 34,
89 | "id": "chinese-sussex",
90 | "metadata": {},
91 | "outputs": [
92 | {
93 | "name": "stdout",
94 | "output_type": "stream",
95 | "text": [
96 | "shape of r2s is (number of features, dim_Y), that is (3, 10)\n",
97 | "shape of stacked_r2s is (dim_Y, ), that is (10,)\n",
98 | "shape of S_average is (dim_Y, num of features), that is (10, 3)\n"
99 | ]
100 | }
101 | ],
102 | "source": [
103 | "### Results\n",
104 | "\n",
105 | "## r2s: voxelwise R2(predictions using only one feature, data)\n",
106 | "print('shape of r2s is (number of features, dim_Y), that is', r2s.shape)\n",
107 | "\n",
108 | "## stacked_r2s: voxelwise R2(stacking predictions using all features, data)\n",
109 | "print('shape of stacked_r2s is (dim_Y, ), that is', stacked_r2s.shape)\n",
110 | "\n",
111 | "## S_average: optimized voxelwise stacking weights showing how different features are combined\n",
112 | "print('shape of S_average is (dim_Y, num of features), that is', S_average.shape)\n",
113 | "\n"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": 37,
119 | "id": "curious-blanket",
120 | "metadata": {},
121 | "outputs": [
122 | {
123 | "data": {
124 | "text/plain": [
125 | "Text(0.5, 1.0, 'Prediction Performance')"
126 | ]
127 | },
128 | "execution_count": 37,
129 | "metadata": {},
130 | "output_type": "execute_result"
131 | },
132 | {
133 | "data": {
134 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAmEAAAGDCAYAAABjkcdfAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAn4UlEQVR4nO3dfZhV1Znn/e+dAtQYfBkxHRBKKsZXFFHwLYmttq1ia2tMj0o6iZp0pkZszcQRRjLPk9Z0x1Ev8rSVqAnBl0gyTtRJNG0UosY2KjEoYBN8AQwxmiohioJRggrE+/njHCplUVAF1GFVFd/PdXF59t7rrHPvcy7051pr7x2ZiSRJkrau95UuQJIkaVtkCJMkSSrAECZJklSAIUySJKkAQ5gkSVIBhjBJkqQCDGGSulVE3BIRX6u+PjoiFm1mP1Mi4ivdW133i4ivRcSrEfH70rVI6l0MYdI2KCJeiIi3ImJlRLwcEd+NiA909+dk5qOZuW8X6jkvIma2e+/5mfkv3V1TRFweEWuq5/56RDwWEUdtZl/DgEuAAzLzQ91bqaS+zhAmbbv+NjM/ABwKHAb8v+0bRES/rV7V1nF79dx3B2YCd0ZEbEoH1e9mT+C1zHxlUwvow9+tpC4yhEnbuMx8CZgBHAgQERkR/xgRvwZ+Xd13akTMazNyNHLd+yPikIh4MiLejIjbge3bHDs2IlrabA+LiDsjYllEvBYR10XE/sAU4Kh1o1PVtq3TmtXt/xIRiyNieUTcHRFD2hzLiDg/In4dESsi4vquhKrMXANMAz4E7BYRO0fETRGxNCJeqk411lU/47yI+EVEXBMRy4GfAw8AQ6p131Jtd1pEPFP9rn5ePb91db4QEZdGxHzgjxHxkWrtn4uI5mrt50fEYRExv9rHdW3ev1dE/Hv1u3s1Im6NiF3a9T+h+t4/RMTtEdH29zi9+ju+ERG/iYix1f0bPG9JtWMIk7Zx1Sm1vwH+o83uTwBHAAdExKHAzcB/BXYDvgPcHRHbRcQA4MfA94H/BPxf4O828Dl1wD3Ai8BwYA/gtsxcAJwP/DIzP5CZu3Tw3r8CrgTOAgZX+7itXbNTqYzoHVxtd1IXzn074DygJTNfpRLI1gIfAQ4BTgS+0OYtRwDPAx8ETgBOBpZU6z4vIvYBfgB8icoo23TgJ9XvaZ1PAacAu1Q/a12/ewNnA03A/wP8NTACOCsijllXcvV7GALsDwwDLm93WmcBY4EGYGT1/IiIw4HvAROrn/2XwAvV93R23pJqwBAmbbt+XB11mgk8DPyvNseuzMzlmfkW8F+A72Tm45n5p8ycBrwDHFn90x9oysw1mflDYPYGPu9wKuFhYmb+MTPfzsyZG2jb3qeBmzPzycx8B/gylZGz4W3aXJWZr2fm74CHgFEb6e+s6rk3A6OBT0TEX1AJVV+q1vcKcA0wrs37lmTmtZm5tvrdtHc2cG9mPlAdZfs6sAPw0TZtvpmZze3e/y/V7+N+4I/ADzLzleoo5aNUghGZubja9zuZuQz4V+AY3uubmbkkM5cDP2nzPfxD9Tt8IDPfzcyXMnNhF89bUg24JkHadn0iM3+2gWPNbV7vCZwbERe12TeASqBK4KXMzDbHXtxAn8OAFzNz7QaOb8wQ4Ml1G5m5MiJeozKa9kJ1d9urE1cBG7vQ4I7M/EzbHdWRov7A0jYzme/jvd9F29cbqrP1/DPz3Yhorta5sT5ebvP6rQ62P1Ct8YPAN4GjgYHV+la066v997Bu2nYYlZG59vak8/OWVAOGMEkdaRuqmoErMvOK9o2q02R7RES0CWL1wG866LMZqI+Ifh0EseygfVtLqISFdZ+7I5Wp0Zc6ed+maKYywjdoI0GxK3UetG6jui5tGO+ts7M+NubK6vtHZuZrEfEJ4LqNv6VVM7DXBvZ3dt6SasDpSEmduQE4PyKOiIodI+KUiBgI/JLKWqIvRkS/iPgklWnHjjwBLAWu
qvaxfUR8rHrsZWBou7VTbf0f4HMRMaq6jut/AY9n5gvddI5k5lLgfuD/i4idIuJ91YXw7af7NuYO4JSIOD4i+lO5fcU7wGPdVOZAYCXwekTsQWV9V1fdROU7PL56bntExH7ddN6SNoMhTNJGZeYcKuvCrqMy9bWY6mLvzFwNfLK6vYLKmqg7N9DPn4C/pbL4+3dAS7U9wL8DzwC/j4hXO3jvg8BXgB9RCXJ7UZs1S+dQmWp9lsr5/JDKhQBdkpmLgM8A1wKvUjnfv61+T93hq1RuKfIH4F428F1voLYngM9RWe/1ByrrANeNLm7ReUvaPPHepRySJEnaGhwJkyRJKsAQJkmSVIAhTJIkqQBDmCRJUgGGMEmSpAJ63c1aBw0alMOHDy9dhiRJUqfmzp37ambu3tGxXhfChg8fzpw5c0qXIUmS1KmI2NCj3JyOlCRJKsEQJkmSVIAhTJIkqYBetyasI2vWrKGlpYW33367dCnairbffnuGDh1K//79S5ciSdImq2kIi4ixwDeAOuDGzLyqgzbHAk1Af+DVzDxmUz+npaWFgQMHMnz4cCJii2pW75CZvPbaa7S0tNDQ0FC6HEmSNlnNpiMjog64HjgZOAD4VEQc0K7NLsC3gNMycwRw5uZ81ttvv81uu+1mANuGRAS77babo5+SpF6rlmvCDgcWZ+bzmbkauA04vV2bvwfuzMzfAWTmK5v7YQawbY+/uSSpN6tlCNsDaG6z3VLd19Y+wK4R8fOImBsR53TUUUQ0RsSciJizbNmyGpW7Za644gpGjBjByJEjGTVqFI8//jgATU1NrFq1arP6vOWWW7jwwgvX2z9lyhS+973vbVG9nXn00UcZMWIEo0aN4q233mrd39zcTENDA8uXLwdgxYoVNDQ08OKLldugjB07ll122YVTTz11g31PnDiR/fbbj5EjR3LGGWfw+uuvr9fmxRdfZPTo0YwaNYoRI0YwZcqU7j1BSZIKq+WasI6GKbKDzx8NHA/sAPwyImZl5nPveVPmVGAqwJgxY9r3sZ7hk+7drII35IWrTtno8V/+8pfcc889PPnkk2y33Xa8+uqrrF69GqiEsM985jO8//3v77Z6zj///G7ra0NuvfVWJkyYwOc+97n37B82bBjjx49n0qRJTJ06lUmTJtHY2Miee+4JVALWqlWr+M53vrPBvk844QSuvPJK+vXrx6WXXsqVV17J1Vdf/Z42gwcP5rHHHmO77bZj5cqVHHjggZx22mkMGTKk+09WkqQCajkS1gIMa7M9FFjSQZufZuYfM/NV4BHg4BrWVBNLly5l0KBBbLfddgAMGjSIIUOG8M1vfpMlS5Zw3HHHcdxxxwEwfvx4xowZw4gRI7jsssta+5g9ezYf/ehHOfjggzn88MN588033/MZ9957L0cddRSvvvoql19+OV//+tcBOPbYY7n00ks5/PDD2WeffXj00UcBWLVqFWeddRYjR47k7LPP5ogjjujwSQMPPvgghxxyCAcddBCf//zneeedd7jxxhu54447+Od//mc+/elPr/eeiy++mFmzZtHU1MTMmTO55JJLWo8df/zxDBw4cKPf14knnki/fpX8f+SRR9LS0rJemwEDBrR+n++88w7vvvvuRvuUJKm3qWUImw3sHRENETEAGAfc3a7NvwFHR0S/iHg/cASwoIY11cSJJ55Ic3Mz++yzDxdccAEPP/wwAF/84hcZMmQIDz30EA899BBQmbacM2cO8+fP5+GHH2b+/PmsXr2as88+m2984xv86le/4mc/+xk77LBDa/933XUXV111FdOnT2fQoEHrff7atWt54oknaGpq4qtf/SoA3/rWt9h1112ZP38+X/nKV5g7d+5673v77bc577zzuP3223nqqadYu3Yt3/72t/nCF77AaaedxuTJk7n11lvXe1///v2ZPHkyF198MU1NTQwYMGCzv7ubb76Zk08+ucNjzc3NjBw5kmHDhnHppZc6CiZJ6lNqFsIy
cy1wIXAflWB1R2Y+ExHnR8T51TYLgJ8C84EnqNzG4ula1VQrH/jAB5g7dy5Tp05l99135+yzz+aWW27psO0dd9zBoYceyiGHHMIzzzzDs88+y6JFixg8eDCHHXYYADvttFPrSNFDDz3E1Vdfzb333suuu+7aYZ+f/OQnARg9ejQvvPACADNnzmTcuHEAHHjggYwcOXK99y1atIiGhgb22WcfAM4991weeeSRLp3zjBkzGDx4ME8/vfk/1xVXXEG/fv06HG2DytTn/PnzWbx4MdOmTePll1/e7M+SJKmnqekd8zNzembuk5l7ZeYV1X1TMnNKmzaTM/OAzDwwM5tqWU8t1dXVceyxx/LVr36V6667jh/96Efrtfntb3/L17/+dR588EHmz5/PKaecwttvv01mbvBKvw9/+MO8+eabPPfccx0eB1qn7erq6li7di1QuY9WZ7rSpiPz5s3jgQceYNasWVxzzTUsXbp0k/uYNm0a99xzD7feemunVzkOGTKEESNGtE61SpLUF/jYom6waNEifv3rX7duz5s3r3Wh+sCBA1vXd73xxhvsuOOO7Lzzzrz88svMmDEDgP32248lS5Ywe/ZsAN58883WMLXnnnty5513cs455/DMM890uaaPf/zj3HHHHQA8++yzPPXUU+u12W+//XjhhRdYvHgxAN///vc55piN3ys3Mxk/fjxNTU3U19czceJEJkyY0Gk9X/7yl7nrrrsA+OlPf8rVV1/N3Xff/Z4LFl566SWOP/54oHID3nVXZa5YsYJf/OIX7Lvvvl04c0kbMnzSvV36I2nr6BOPLSpt5cqVXHTRRbz++uv069ePj3zkI0ydOhWAxsZGTj75ZAYPHsxDDz3EIYccwogRI/jwhz/Mxz72MaCyCP3222/noosu4q233mKHHXbgZz/7WWv/++67L7feeitnnnkmP/nJT7pU0wUXXMC5557LyJEjOeSQQxg5ciQ777zze9psv/32fPe73+XMM89k7dq1HHbYYZ1eeXnDDTdQX1/PCSec0Po5t9xyCw8//DDHHHMMRx99NAsXLmTlypUMHTqUm266iZNOOomnnnqK0047DYALL7yQd955p7WPI488kilTprB06dLWadgFCxZwySWXEBFkJhMmTOCggw7q0rlLktQbxOZOSZUyZsyYbH+V34IFC9h///0LVdQz/elPf2LNmjVsv/32/OY3v+H444/nueee26JF9FvipJNO4r777ttom+uuu476+vrWsNYV/vZS13V1lKuz2/JI6rqImJuZYzo65khYH7Vq1SqOO+441qxZQ2by7W9/u1gAAzoNYECHN6aVStqUqTmDi6RNZQjrowYOHNjhfcEkSVLP4MJ8SZKkAgxhkiRJBRjCJEmSCjCESZIkFWAI6wbNzc00NDSwfPlyoHJz0YaGBl588UXGjh3LLrvswqmnnlq4SkmS1JP0zasjL9+58zab1N8fNnp42LBhjB8/nkmTJjF16lQmTZpEY2Mje+65JxMnTmTVqlV85zvf6d6aJElSr+ZIWDe5+OKLmTVrFk1NTcycOZNLLrkEgOOPP56BAwcWrk6SJPU0fXMkrID+/fszefJkxo4dy/3331/0xqiSJKnncySsG82YMYPBgwfz9NNPly5FkiT1cI6EdZN58+bxwAMPMGvWLD7+8Y8zbtw4Bg8eXLqsPml+y+utr19e8RYnb+DRMj5GRpLUkzkS1g0yk/Hjx9PU1ER9fT0TJ05kwoQJpcuSJEk9mCGsG9xwww3U19dzwgknAHDBBRewcOFCHn74YY4++mjOPPNMHnzwQYYOHdqlB1lLkqS+r29OR3ZyS4nu1tjYSGNjY+t2XV0dc+fOBeDRRx/dqrVI0rZi+AaWIrTn0gT1VI6ESZIkFWAIkyRJKsAQJkmSVIAhTJIkqYC+uTBfkiQV09WLJmDbvnDCkTBJkqQCDGHdoLm5mYaGBpYvXw7AihUraGho4OGHH+aoo45ixIgRjBw5kttv
v71wpZIkqafok9ORB007qFv7e+rcpzZ6fNiwYYwfP55JkyYxdepUJk2aRGNjI4MHD+Z73/see++9N0uWLGH06NGcdNJJ7LLLLt1anyRJ6n36ZAgr4eKLL2b06NE0NTUxc+ZMrr32WgYMGNB6fMiQIXzwgx9k2bJlhjBJkmQI6y79+/dn8uTJjB07lvvvv/89AQzgiSeeYPXq1ey1116FKlRP5V2/JWnbZAjrRjNmzGDw4ME8/fTTrc+RBFi6dCmf/exnmTZtGu97n8vwJPUdXV3+0dmyDmlbZCLoJvPmzeOBBx5g1qxZXHPNNSxduhSAN954g1NOOYWvfe1rHHnkkYWrlCRJPYUhrBtkJuPHj6epqYn6+nomTpzIhAkTWL16NWeccQbnnHMOZ555ZukyJUlSD2II6wY33HAD9fX1rVOQF1xwAQsXLuTKK6/kkUce4ZZbbmHUqFGMGjWKefPmlS1WkiT1CH1yTdjWXnvQ2NhIY2Nj63ZdXR1z584F4LLLLtuqtUiSpN7BkTBJkqQC+uRImN5rfsvrXW47cuguNatDklcTasO8Xc22x5EwSZKkAgxhkiRJBTgdKeEUkSRp63MkTJIkqQBDWDdobm6moaGB5cuXA7BixQoaGhp48cUXGT16NKNGjWLEiBFMmTKlcKWSJKmn6JPTkQv2279b+9t/4YKNHh82bBjjx49n0qRJTJ06lUmTJtHY2MjgwYN57LHH2G677Vi5ciUHHnggp512GkOGDOnW+iRJUu/TJ0NYCRdffDGjR4+mqamJmTNncu211zJgwIDW4++88w7vvvtuwQolSVJPYgjrJv3792fy5MmMHTuW+++/vzWANTc3c8opp7B48WImT57sKJgkSQJqvCYsIsZGxKKIWBwRkzo4fmxE/CEi5lX//FMt66m1GTNmMHjwYJ5++unWfcOGDWP+/PksXryYadOm8fLLLxesUJIk9RQ1C2ERUQdcD5wMHAB8KiIO6KDpo5k5qvrnn2tVT63NmzePBx54gFmzZnHNNdewdOnS9xwfMmQII0aM4NFHHy1UoSRJ6klqORJ2OLA4M5/PzNXAbcDpNfy8YjKT8ePH09TURH19PRMnTmTChAm0tLTw1ltvAZUrJn/xi1+w7777Fq5WkiT1BLUMYXsAzW22W6r72jsqIn4VETMiYkQN66mZG264gfr6ek444QQALrjgAhYuXMhNN93EEUccwcEHH8wxxxzDhAkTOOigrt0UVJIk9W21XJgfHezLdttPAntm5sqI+Bvgx8De63UU0Qg0AtTX13f6wZ3dUqK7NTY20tjY2LpdV1fH3LlzAbjsssu2ai2SJKl3qGUIawGGtdkeCixp2yAz32jzenpEfCsiBmXmq+3aTQWmAowZM6Z9kJPUgwyfdG+X2r1w1Sk1rkTadD7CTFtTLacjZwN7R0RDRAwAxgF3t20QER+KiKi+Prxaz2s1rEmSJKlHqNlIWGaujYgLgfuAOuDmzHwmIs6vHp8C/GdgfESsBd4CxmWmI12SJKnPq+nNWjNzOjC93b4pbV5fB1zXTZ9FdVBN24jMJNdbZihJ6ov64lRxn3iA9/bbb89rr72Gg2jbjsxk7ao3ePH1NaVLkSRps/SJxxYNHTqUlpYWli1b1m19tqx4q2ufvesO3faZtfJyF88FYMGbveN8kuTF19dw7eMrSpcjSdJm6RMhrH///jQ0NHRrnyf3oSu8unou0PfOR5KknqpPTEdKkiT1NoYwSZKkAgxhkiRJBRjCJEmSCugTC/O19fXF+7VIkrQ1ORImSZJUgCFMkiSpAEOYJElSAYYwSZKkAgxhkiRJBRjCJEmSCjCESZIkFeB9wiRJ6oO6ej9H8J6OpTgSJkmSVIAhTJIkqQCnI6U+yMdKSVLP50iYJElSAYYwSZKkAgxhkiRJBRjCJEmSCjCESZIkFWAIkyRJKsAQJkmSVIAhTJIkqQBDmCRJUgGGMEmSpAIMYZIkSQUYwiRJkgowhEmSJBVg
CJMkSSrAECZJklSAIUySJKkAQ5gkSVIBhjBJkqQCDGGSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwBAmSZJUgCFMkiSpAEOYJElSATUNYRExNiIWRcTiiJi0kXaHRcSfIuI/17IeSZKknqJmISwi6oDrgZOBA4BPRcQBG2h3NXBfrWqRJEnqaWo5EnY4sDgzn8/M1cBtwOkdtLsI+BHwSg1rkSRJ6lFqGcL2AJrbbLdU97WKiD2AM4ApNaxDkiSpx6llCIsO9mW77Sbg0sz800Y7imiMiDkRMWfZsmXdVZ8kSVIx/WrYdwswrM32UGBJuzZjgNsiAmAQ8DcRsTYzf9y2UWZOBaYCjBkzpn2QkyRJ6nVqGcJmA3tHRAPwEjAO+Pu2DTKzYd3riLgFuKd9AJMkSeqLahbCMnNtRFxI5arHOuDmzHwmIs6vHncdmCRJ2mbVciSMzJwOTG+3r8PwlZnn1bIWSZKknsQ75kuSJBVgCJMkSSrAECZJklSAIUySJKkAQ5gkSVIBhjBJkqQCDGGSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwBAmSZJUgCFMkiSpAEOYJElSAYYwSZKkAvqVLmBbcdC0g7rc9qlzn6phJZIkqSdwJEySJKkAQ5gkSVIBhjBJkqQCDGGSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwPuESerRvMeepL7KkTBJkqQCDGGSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwBAmSZJUgCFMkiSpAEOYJElSAYYwSZKkAgxhkiRJBRjCJEmSCjCESZIkFWAIkyRJKsAQJkmSVIAhTJIkqQBDmCRJUgGdhrCI2Cki9upg/8jalCRJktT3bTSERcRZwELgRxHxTEQc1ubwLbUsTJIkqS/rbCTsfwKjM3MU8Dng+xHxyeqxqGVhkiRJfVm/To7XZeZSgMx8IiKOA+6JiKFA1rw6SZKkPqqzkbA3264HqwayY4HTgRE1rEuSJKlP6yyEjW/fJjPfBMYCn++s84gYGxGLImJxREzq4PjpETE/IuZFxJyI+PimFC9JktRbbXQ6MjN/tYFD73bWcUTUAdcDJwAtwOyIuDszn23T7EHg7szM6tWWdwD7dalySZKkXqyzqyN3iogvR8R1EXFiVFwEPA+c1UnfhwOLM/P5zFwN3EZlGrNVZq7MzHVry3bEdWaSJGkb0dnC/O8DK4BfAl8AJgIDgNMzc14n790DaG6z3QIc0b5RRJwBXAl8EDilo44iohFoBKivr+/kYyVJknq+zkLYhzPzIICIuBF4FaivrgvrTEe3sFhvpCsz7wLuioi/BP4F+OsO2kwFpgKMGTPG0TJJktTrdbYwf826F5n5J+C3XQxgUBn5GtZmeyiwZEONM/MRYK+IGNTF/iVJknqtzkbCDo6IN6qvA9ihuh1AZuZOG3nvbGDviGgAXgLGAX/ftkFEfAT4TXVh/qFUpjpf24zzkCRJ6lU6uzqybnM7zsy1EXEhcB9QB9ycmc9ExPnV41OAvwPOiYg1wFvA2W0W6kuSJPVZnY2EbZHMnA5Mb7dvSpvXVwNX17IGSZKknqizNWGSJEmqAUOYJElSAYYwSZKkAmq6JkySNujynbvWrsEbNEvqmxwJkyRJKsAQJkmSVIAhTJIkqQBDmCRJUgGGMEmSpAIMYZIkSQUYwiRJkgrwPmFbynsdSZKkzeBImCRJUgGGMEmSpAIMYZIkSQUYwiRJkgowhEmSJBVgCJMkSSrAECZJklSAIUySJKkAQ5gkSVIBhjBJkqQCfGyRpD5jwX77d6nd/gsX1LgSSeqcI2GSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwBAmSZJUgCFMkiSpAEOYJElSAYYwSZKkArxZq9RbXL5z19s21NeuDklSt3AkTJIkqQBHwiRJ2kQ+IkvdwZEwSZKkAhwJ64H8PyxJkvo+Q5gkSeozujqQAeUHM5yOlCRJKsAQJkmSVIAhTJIkqQDXhElSd+jqzXS9ka6kKkfC
JEmSCjCESZIkFWAIkyRJKqCmISwixkbEoohYHBGTOjj+6YiYX/3zWEQcXMt6JEmSeoqaLcyPiDrgeuAEoAWYHRF3Z+azbZr9FjgmM1dExMnAVOCIWtUkbSmfZiBJ6i61vDrycGBxZj4PEBG3AacDrSEsMx9r034WMLSG9UiSuqKrV3qCV3tKW6CW05F7AM1ttluq+zbkH4AZHR2IiMaImBMRc5YtW9aNJUqSJJVRyxAWHezLDhtGHEclhF3a0fHMnJqZYzJzzO67796NJUqSJJVRy+nIFmBYm+2hwJL2jSJiJHAjcHJmvlbDeiRJknqMWo6EzQb2joiGiBgAjAPubtsgIuqBO4HPZuZzNaxFkiSpR6nZSFhmro2IC4H7gDrg5sx8JiLOrx6fAvwTsBvwrYgAWJuZY2pVkyRJUk9R02dHZuZ0YHq7fVPavP4C8IVa1iBJktQTecd8SZKkAmo6EiZJkno+b0RdhiNhkiRJBRjCJEmSCnA6UpKk3qSrj5XykVI9niNhkiRJBRjCJEmSCnA6UjXV1StuwKtuJEnbFkfCJEmSCjCESZIkFeB0pCT1QE7lS32fI2GSJEkFGMIkSZIKcDpS7+VNACVJ2iocCZMkSSrAECZJklSAIUySJKkAQ5gkSVIBhjBJkqQCDGGSJEkFeIsKSZJUzjZ8ayRHwiRJkgowhEmSJBXgdKT6rq4OcUOfHOaWJPVsjoRJkiQVYAiTJEkqwBAmSZJUgGvCpG3Ygv3271K7/RcuqHElkrTtMYRJkvo2L9JRD+V0pCRJUgGGMEmSpAIMYZIkSQUYwiRJkgowhEmSJBVgCJMkSSrAW1RIkmquq/ekA+9Lp22HI2GSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwBAmSZJUgCFMkiSpAEOYJElSAYYwSZKkAmoawiJibEQsiojFETGpg+P7RcQvI+KdiJhQy1okSZJ6kprdMT8i6oDrgROAFmB2RNydmc+2abYc+CLwiVrVIUmS1BPVciTscGBxZj6fmauB24DT2zbIzFcyczawpoZ1SJIk9Ti1DGF7AM1ttluq+zZZRDRGxJyImLNs2bJuKU6SJKmkWoaw6GBfbk5HmTk1M8dk5pjdd999C8uSJEkqr5YhrAUY1mZ7KLCkhp8nSZLUa9QyhM0G9o6IhogYAIwD7q7h50mSJPUaNbs6MjPXRsSFwH1AHXBzZj4TEedXj0+JiA8Bc4CdgHcj4kvAAZn5Rq3qkiRJ6glqFsIAMnM6ML3dviltXv+eyjSlJEnSNsU75kuSJBVgCJMkSSrAECZJklSAIUySJKkAQ5gkSVIBhjBJkqQCDGGSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwBAmSZJUgCFMkiSpAEOYJElSAYYwSZKkAgxhkiRJBRjCJEmSCjCESZIkFWAIkyRJKsAQJkmSVIAhTJIkqQBDmCRJUgGGMEmSpAIMYZIkSQUYwiRJkgowhEmSJBVgCJMkSSrAECZJklSAIUySJKkAQ5gkSVIBhjBJkqQCDGGSJEkFGMIkSZIKMIRJkiQVYAiTJEkqwBAmSZJUgCFMkiSpAEOYJElSAYYwSZKkAgxhkiRJBRjCJEmSCjCESZIkFWAIkyRJKsAQJkmSVEBNQ1hEjI2IRRGxOCImdXA8IuKb1ePzI+LQWtYjSZLUU9QshEVEHXA9cDJwAPCpiDigXbOTgb2rfxqBb9eqHkmSpJ6kliNhhwOLM/P5zFwN3Aac3q7N6cD3smIWsEtEDK5hTZIkST1CLUPYHkBzm+2W6r5NbSNJktTn9Kth39HBvtyMNkREI5XpSoCVEbFoC2vbXIOAV9vu6OgEOvZ0l1u2n7PdoOj6p3e5yy637Nr5dPlcYEvPZwt+G+j28yn620APPB//7rTy706nuvn38e9O5/y7s1nW+202YM8NHahlCGsBhrXZHgos2Yw2ZOZUYGp3F7ipImJOZo4pXYfW52/Ts/n79Fz+Nj2bv0/P1R2/TS2nI2cD
e0dEQ0QMAMYBd7drczdwTvUqySOBP2Tm0hrWJEmS1CPUbCQsM9dGxIXAfUAdcHNmPhMR51ePTwGmA38DLAZWAZ+rVT2SJEk9SS2nI8nM6VSCVtt9U9q8TuAfa1lDNys+JaoN8rfp2fx9ei5/m57N36fn2uLfJio5SJIkSVuTjy2SJEkqwBDWBZ09fknlRMSwiHgoIhZExDMR8d9K16T3ioi6iPiPiLindC16r4jYJSJ+GBELq3+Hjipdkyoi4uLqv9OejogfRMT2pWvalkXEzRHxSkQ83Wbff4qIByLi19V/7rqp/RrCOtHFxy+pnLXAJZm5P3Ak8I/+Pj3OfwMWlC5CHfoG8NPM3A84GH+nHiEi9gC+CIzJzAOpXNw2rmxV27xbgLHt9k0CHszMvYEHq9ubxBDWua48fkmFZObSzHyy+vpNKv8R8akLPUREDAVOAW4sXYveKyJ2Av4SuAkgM1dn5utFi1Jb/YAdIqIf8H46uIemtp7MfARY3m736cC06utpwCc2tV9DWOd8tFIvERHDgUOAxwuXoj9rAv4H8G7hOrS+DwPLgO9Wp4tvjIgdSxclyMyXgK8DvwOWUrmH5v1lq1IH/mLdvU2r//zgpnZgCOtclx6tpLIi4gPAj4AvZeYbpesRRMSpwCuZObd0LepQP+BQ4NuZeQjwRzZjOkXdr7q26HSgARgC7BgRnylblWrBENa5Lj1aSeVERH8qAezWzLyzdD1q9THgtIh4gco0/l9FxP8uW5LaaAFaMnPdyPEPqYQylffXwG8zc1lmrgHuBD5auCat7+WIGAxQ/ecrm9qBIaxzXXn8kgqJiKCypmVBZv5r6Xr0Z5n55cwcmpnDqfy9+ffM9P/me4jM/D3QHBH7VncdDzxbsCT92e+AIyPi/dV/xx2PF030RHcD51Zfnwv826Z2UNM75vcFG3r8UuGy9GcfAz4LPBUR86r7/mf1aQ2SNu4i4Nbq/2A+j4+O6xEy8/GI+CHwJJUrwP8D75xfVET8ADgWGBQRLcBlwFXAHRHxD1SC85mb3K93zJckSdr6nI6UJEkqwBAmSZJUgCFMkiSpAEOYJElSAYYwSZKkAgxhknqliPh5RJzUbt+XIuJb3dT/eRFx3cb2R8TlEfFSRMyLiF9HxJ0+QF5SVxnCJPVWP6ByE9i2xlX3b03XZOaozNwbuB3494jYfSvXIKkXMoRJ6q1+CJwaEdtB6wPchwAzI+JTEfFURDwdEVdXj58RET+LisER8VxEfCgido+IH0XE7Oqfj21uQZl5O3A/8PdbfnqS+jpDmKReKTNfA54AxlZ3jaMyEjUYuBr4K2AUcFhEfCIz7wJ+D/wjcANwWfXRPd+gMpp1GPB3wI1bWNqTwH5b2IekbYCPLZLUm62bkvy36j8/DxwG/DwzlwFExK3AXwI/pvKYnqeBWZm5btryr4EDKo/oA2CniBi4BTVF500kyRAmqXf7MfCvEXEosENmPhkR9RtpvwfwLvAXEfG+zHyXyozAUZn5VtuGbULZpjoEmLO5b5a07XA6UlKvlZkrgZ8DN/PnBfmPA8dExKCIqAM+BTwcEf2A71JZr7UA+O/V9vcDF67rMyJGbW49EfF3wIls/YsDJPVCjoRJ6u1+ANxJ9UrJzFwaEV8GHqIyNTg9M/8tIv4JeDQzH42IecDsiLgX+CJwfUTMp/LvxEeA8zfh8y+OiM8AO1KZ6vyrdVOhkrQxkZmla5AkSdrmOB0pSZJUgCFMkiSpAEOYJElSAYYwSZKkAgxhkiRJBRjCJEmSCjCESZIkFWAIkyRJKuD/B/x6F5iSPgcDAAAAAElFTkSuQmCC\n",
135 | "text/plain": [
136 | ""
137 | ]
138 | },
139 | "metadata": {
140 | "needs_background": "light"
141 | },
142 | "output_type": "display_data"
143 | }
144 | ],
145 | "source": [
146 | "plt.figure(figsize=(10,6))\n",
147 | "\n",
148 | "bar_width = 0.2\n",
149 | "index_0 = np.arange(dim_Y)\n",
150 | "index_1 = index_0 + bar_width\n",
151 | "index_2 = index_1 + bar_width\n",
152 | "index_3 = index_2 + bar_width\n",
153 | "\n",
154 | "\n",
155 | "plt.bar(index_0, stacked_r2s, width=bar_width, label='Stacking of X1,2,3')\n",
156 | "plt.bar(index_1, r2s[0,:], width=bar_width, label='X1')\n",
157 | "plt.bar(index_2, r2s[1,:], width=bar_width, label='X2')\n",
158 | "plt.bar(index_3, r2s[2,:], width=bar_width, label='X3')\n",
159 | "plt.legend()\n",
160 | "plt.xlabel('Voxel ID')\n",
161 | "plt.ylabel('R2')\n",
162 | "plt.title('Prediction Performance')"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 45,
168 | "id": "smoking-merchant",
169 | "metadata": {},
170 | "outputs": [
171 | {
172 | "data": {
173 | "text/plain": [
174 | "([,\n",
175 | " ,\n",
176 | " ],\n",
177 | " [Text(0, 0.5, 'X1'), Text(0, 1.5, 'X2'), Text(0, 2.5, 'X3')])"
178 | ]
179 | },
180 | "execution_count": 45,
181 | "metadata": {},
182 | "output_type": "execute_result"
183 | },
184 | {
185 | "data": {
186 | "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWsAAAEKCAYAAADU7nSHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAXiUlEQVR4nO3dfZRlVXnn8e+PRgZQsQ0gIo0DYkcHnYjImxoRUWJjjMQ4icBEl0RtmQUKutZETdZo5uUfV0bHyQqIrRA1UdBB1I5pBFkKOCMIiIg0CDagULyILwgGolBVz/xxbpObSr3c6rrn1j3V3w/rrL7nnnP3sy8NT+16zj77pKqQJI23HZa7A5KkhZmsJakDTNaS1AEma0nqAJO1JHWAyVqSOsBkLUlDluScJPcluWGO40nyV0m2JLk+ycELtWmylqTh+wSwbp7jxwJre9t64CMLNWiylqQhq6rLgZ/Pc8pxwKeqcSWwOsne87W54zA7OEyn7vf6kdxaec0jPx5FGAB+Pf3oyGIdufO+I4v1lFo1sljX89DIYo3K5Q/cMrJYh+32jJHF+taDt44s1r2/uClLbePRn942cM7Zac8D3kYzIt5qQ1VtWES4fYA7+/Yneu/dM9cHxjZZS9K46iXmxSTnmWb74TLvDwuTtSQBTE+NMtoE0P/r7xrg7vk+YM1akgCmJgfflm4j8MberJAjgAeqas4SCDiyliQAqqaH1laSc4GjgD2STADvBx7XxKmzgE3Aq4AtwMPASQu1abKWJIDp4SXrqjphgeMFnLKYNk3WkgQwxJF1G0zWkgSjvsC4aCZrSQJH1pLUBTWcWR6tMVlLEgz1AmMbTNaSBJZBJKkTvMAoSR3gyFqSOsALjJLUAV5glKTxV2XNWpLGnzVrSeoAyyCS1AGOrCWpA6ZG94zUbWGyliSwDCJJnWAZRJI6wJG1JHWAyVqSxl95gVGSOsCatSR1gGUQSeoAR9aS1AGOrCWpAxxZS1IHTPrwAUkaf46sJakDrFlLUgc4spakDhjzkfUOowyW5JhRxpOkgdX04NsyGGmyBs6e72CS9UmuSXLN5l/eOqo+SVIzG2TQbRkMvQySZONch4Dd5/tsVW0ANgCcut/ra8hdk6S51XinnDZq1i8B/hj4xxnvBzishXiStHRjXrNuI1lfCTxcVZfNPJDk5hbiSdLSjXmybqNmvb6qvj7HsT9vIZ4kLd0QLzAmWZfk5iRbkrxnluNPSvL3Sb6bZHOSkxZqs41kfVmSP03y2Kg9yV5J/g74UAvxJGnppqYG3+aRZBVwBnAscCBwQpIDZ5x2CnBjVT0POAr4YJKd5mu3jWT9AuAA4DtJjk5yGnAVcAVweAvxJGnppqcH3+Z3GLClqm6rqkeA84DjZpxTwBOTBHgC8HNg3mkmQ69ZV9X9wNt6SfoS4G7giKqaGHYsSRqaRdSsk6wH1ve9taE3mw1gH+DOvmMT/OuB6l8DG2ny4xOB11fNX19pY+reauADvc6tA14FXJjktKr62rDjSdJQLOJml/5pxrPIbB+Zsf9K4DrgaJpKxFeTfKOqHpwrZhtlkGuBHwCHVNXFVXU68AbgfyQ5t4V4krRkNV0DbwuYAPbt219DM4LudxJwQTW2ALcDz56v0Tam7h05s+RRVdcBL0ry1hbiSdLSDW/q3tXA2iT7A3cBxwMnzjjnDuDlwDeS7AU8C7htvkbbqFnPWZuuqo8NO54kDcUCszwGVVWTSU4FLgJWAedU1eYkJ/eOnwX8d+ATSb5HUzZ5d1X9dL52XXVPkmCoN8VU1SZg04z3zup7fTfwO4tp02QtSTD2dzCarCUJtsuFnCSpexxZS1IHLDwlb1mZrCUJhjYbpC0ma0kCyjKIJHWAZRBJ6oBlehDuoEzWkgSOrCWpEya9wChJ488yiCR1gGUQSRp/Tt2TpC5wZC1JHWCylqQO8HZzSRp/AzxbcVmZrCUJLINIUic4G0SSOsCR
tSR1gMlaksZfTVkG2SbvX3vvyGI9/vf//WgCTU6OJg4wfe9PRhbrpxf+bGSx3nnSAaMJtMMOo4kD8Os9RhaqfvWrkcXKrr89slhD4ch6vI0sUUsaa07dk6QuMFlLUgeMd8naZC1JADU53tnaZC1J4MhakrrAC4yS1AWOrCVp/DmylqQucGQtSeOvRneD8TYxWUsSUGM+sh7hAgiSNMamF7EtIMm6JDcn2ZLkPXOcc1SS65JsTnLZQm06spYkhjeyTrIKOAM4BpgArk6ysapu7DtnNXAmsK6q7kjylIXadWQtSTTJetBtAYcBW6rqtqp6BDgPOG7GOScCF1TVHQBVdd9CjZqsJQmoqQy8JVmf5Jq+bX1fU/sAd/btT/Te6/ebwJOTXJrk20neuFD/LINIEosrg1TVBmDDHIcz20dm7O8IvAB4ObALcEWSK6vqlrlimqwlCajp2XLsNpkA9u3bXwPcPcs5P62qh4CHklwOPA+YM1lbBpEkhlqzvhpYm2T/JDsBxwMbZ5zzJeAlSXZMsitwOHDTfI06spYkoGo4I+uqmkxyKnARsAo4p6o2Jzm5d/ysqropyVeA62kmA368qm6Yr12TtSQx3JtiqmoTsGnGe2fN2P9L4C8HbXPgZJ1kF+DpVXXzoJ+RpK6YnhpazboVA9Wsk/wecB3wld7+QUlm1mAkqbNqOgNvy2HQC4x/QTPR+xcAVXUdsF8bHZKk5TDuyXrQMshkVT2QjPevCZK0rWq8l7MeOFnfkOREYFWStcA7gG+21y1JGq3lGjEPatAyyNuB5wC/Bj4DPACc3lKfJGnkqjLwthwWHFn3VpDaWFWvAP68/S5J0uhNjflskAWTdVVNJXk4yZOq6oFRdEqSRm25RsyDGrRm/Svge0m+Cjy09c2qekcrvZKkERv3mvWgyfofepskrUgrYjZIVX2y7Y5I0nJaESPrJLfzr9djpaqeMfQeSdIymJoe70VIBy2DHNL3emfgD4HfGH53JGl5jHsZZKAfJVX1s77trqr6MHB0u12TpNGZrgy8LYdByyAH9+3uQDPSfmIrPZKkZbBSpu59sO/1JHA78EdznZxkN2DPqrp1xvu/VVXXL7qXktSyFVEGAd5cVS/rbcdU1XrgkdlOTPJHwPeBzyfZnOTQvsOfmC9I/xODPzVxz4Bdk6SlG/cyyKDJ+vwB3wP4M+AFVXUQcBLwt0n+oHds3m9ZVRuq6pCqOuSNa/YesGuStHRT0zsMvC2HecsgSZ5Ns4DTk/oSLsBuNLNCZm2zqu4BqKqrkrwM+HKSNcwy/U+SxsG4J6eFatbPAl4NrAZ+r+/9XwJvneMzDyY5YGu9uqruSXIU8EWaxC9JY2e5yhuDmjdZV9WXgC8leWFVXTFgm+9mRrmjqn6ZZB3w3m3rpiS1a6XMBvlOklNoRsaPlT+q6k9mOfeTwEeTfLCqJgGS7EUzo+RZwH9bWpclafiG+HDzVgxaKf9b4KnAK4HLgDU0pZDZvADYnybBH53kNOAq4Arg8KV1V5LaUWTgbTkMOrJ+ZlX9YZLjquqTST4DXDTbiVV1P3ByL0lfAtwNHFFVE8PpsiQN3+SYl0EGHVk/2vvzF0meCzyJOZ5unmR1ko/STNtbRzPF78Ik3p4uaWytlJH1hiRPBv4LsBF4AvC+Oc69FjgTOKVXs744yUHAmUl+VFUnLLHPkjR0416zHnQ964/3Xl4GLLQs6pEzSx5VdR3woiRzTfeTpGW1XCPmQQ1UBkmyV5Kzk1zY2z8wyZtnO3e+2nRVfWzbuilJ7ZpexLYcBq1Zf4LmguLTevu3AKe30B9JWhZTZOBtOQyarPeoqs/R+6HSq0VPtdYrSRqx6Qy+LYdBLzA+lGR3erfPJzkCeKC1XknSiE2Pec160GT9LppZIAck+X/AnsB/aK1XkjRinV7IKcnTq+qOqro2yUtpbhcPcHNVPTrfZyWpS7o+de+LwNZHen22ql7XbnckaXlMp9tlkP7eLzS/WpI6
a9xnTCw0G6TmeC1JK8owZ4MkWZfk5iRbkrxnnvMOTTKVZMFrgAuNrJ+X5EGaEfYuvdf09quqdlu425I0/oY1GyTJKuAM4BhgArg6ycaqunGW8z7AHIvizbTQwwdWbVt3Jalbhlg6OAzYUlW3ASQ5DzgOuHHGeW8HPg8cygCW58mPkjRmFlMGSbI+yTV92/q+pvYB7uzbn+i995gk+wCvBc4atH+DzrOWpBVtMVP3qmoDsGGOw7PVU2YO3D8MvLuqpjLgLBSTtSQBU8ObuTcB7Nu3v4bmISz9DgHO6yXqPYBXJZmsqi/O1ajJWpIY6k0xVwNrk+wP3AUcD5zYf0JV7b/1dZJPAF+eL1GDyVqSgOEl66qaTHIqzSyPVcA5VbU5ycm94wPXqfuZrCUJGOYjGKtqE7BpxnuzJumqetMgbZqsJYnurw0iSduFcb/d3GQtSSzfQwUGZbKWJCyDSFInmKwlqQPGfVlRk7UkYc1akjrB2SDb6KjrfzWSOA98+9KRxAHYddXOI4t12wP3jCzWM1c/bWSxbv3Pfz+SOKP8lfilT3nOyGJd/8sfjSzW83bbb2SxLnnX0tuYHvNCyNgma0kaJS8wSlIHjPe42mQtSYAja0nqhMmM99jaZC1JWAaRpE6wDCJJHeDUPUnqgPFO1SZrSQIsg0hSJ0yN+djaZC1JOLKWpE4oR9aSNP4cWUtSBzh1T5I6YLxTtclakgCYHPN0bbKWJLzAKEmd4AVGSeoAR9aS1AGOrCWpA6bKkbUkjT3nWUtSB1izlqQOsGYtSR0w7mWQHZa7A5I0DmoR/ywkybokNyfZkuQ9sxz/j0mu723fTPK8hdp0ZC1JDG82SJJVwBnAMcAEcHWSjVV1Y99ptwMvrar7kxwLbAAOn69dk7UkMdQyyGHAlqq6DSDJecBxwGPJuqq+2Xf+lcCahRq1DCJJNBcYB92SrE9yTd+2vq+pfYA7+/Yneu/N5c3AhQv1z5G1JLG4qXtVtYGmdDGbzNr8bCcmL6NJ1r+9UEyTtSQx1DLIBLBv3/4a4O6ZJyX5LeDjwLFV9bOFGrUMIklAVQ28LeBqYG2S/ZPsBBwPbOw/IcnTgQuAN1TVLYP0z5G1JAFTQxpZV9VkklOBi4BVwDlVtTnJyb3jZwHvA3YHzkwCMFlVh8zXrslakhjuTTFVtQnYNOO9s/pevwV4y2LaNFlLEgxS3lhWrSTrJE8FqKp7k+wJvAS4uao2txFPkpZqu7vdPMnbgCuAK5P8J+DLwKuBC5K8eYHPPjZ38f5/um/YXZOkOQ3zdvM2tDGyPhV4DrAL8CPgmb0R9pOBrwNnz/XB/rmLz9nr8PH+MSdpRdkeHz7waFU9DDyc5Naquhegdw/8eP/bkLTdGvcySBvJejrJ46rqUeB3t76ZZGec1y1pTG2Pyfq19G6trKqJvvd3B85vIZ4kLdm4zwZpY6R7GfCuJI/9IEiyF/AB4DUtxJOkJZumBt6WQxvJ+gXAAcB3khyd5DTgKpoZIvOu1ypJy2W7mw1SVfcDb+sl6UtoFjA5YkZJRJLGylSN91MY25hnvTrJR4GTgHU0deoLkxw97FiSNCxDXMipFW1cYLwWOBM4paomgYuTHESzYMmPquqEFmJK0pJsj7NBjpxZ8qiq64AXJXlrC/EkacmWqxY9qDZq1nPWpqvqY8OOJ0nDMD3mU/dcdU+S2A5H1pLUReM+G8RkLUlYBpGkTrAMIkkd4MhakjrAkbUkdcBUTS13F+ZlspYkxn+JVJO1JLF93m4uSZ3jyFqSOsDZIJLUAc4GkaQO8HZzSeoAa9aS1AHWrCWpAxxZS1IHOM9akjrAkbUkdYCzQSSpA7zAKEkdMO5lkB2WuwOSNA5qEf8sJMm6JDcn2ZLkPbMcT5K/6h2/PsnBC7VpspYkmpH1oNt8kqwCzgCOBQ4ETkhy4IzTjgXW9rb1wEcW6p/JWpJoataDbgs4DNhSVbdV1SPAecBxM845DvhUNa4EVifZe75G
x7ZmvfnH38q2fC7J+qraMOz+LFccY3Ur1kr8Tis5Vr/JR+4aOOckWU8zIt5qQ1+f9wHu7Ds2ARw+o4nZztkHuGeumCtxZL1+4VM6FcdY3Yq1Er/TSo61TapqQ1Ud0rf1/3CZLenPHI4Pcs6/sBKTtSQtpwlg3779NcDd23DOv2CylqThuhpYm2T/JDsBxwMbZ5yzEXhjb1bIEcADVTVnCQTGuGa9BKOqdY2ypmas7sRaid9pJccauqqaTHIqcBGwCjinqjYnObl3/CxgE/AqYAvwMHDSQu1m3CeCS5Isg0hSJ5isJakDVkyyXuj2ziHGOSfJfUluaCtGX6x9k3w9yU1JNic5raU4Oye5Ksl3e3H+axtxZsRcleQ7Sb7ccpwfJvlekuuSXNNyrNVJzk/y/d7f2QtbivOs3vfZuj2Y5PSWYr2z99/EDUnOTbJzG3F6sU7rxdnc1vfptMXcYjmuG00R/1bgGcBOwHeBA1uKdSRwMHDDCL7X3sDBvddPBG5p43vRzPl8Qu/144BvAUe0/N3eBXwG+HLLcX4I7NH231Uv1ieBt/Re7wSsHkHMVcC9wL9toe19gNuBXXr7nwPe1NL3eC5wA7ArzcSHS4C1o/h768q2UkbWg9zeORRVdTnw8zbaniXWPVV1be/1L4GbaP4HGnacqqp/7O0+rre1duU5yRrgd4GPtxVj1JLsRvOD/GyAqnqkqn4xgtAvB26tqh+11P6OwC5JdqRJpPPOBV6CfwdcWVUPV9UkcBnw2pZiddJKSdZz3bq5YiTZD3g+zai3jfZXJbkOuA/4alW1Eqfnw8CfAqNY7b2Ai5N8u3eLcFueAfwE+JteeefjSR7fYrytjgfObaPhqroL+J/AHTS3QT9QVRe3EYtmVH1kkt2T7EozrW3fBT6zXVkpyXrRt252SZInAJ8HTq+qB9uIUVVTVXUQzZ1UhyV5bhtxkrwauK+qvt1G+7N4cVUdTLPK2SlJjmwpzo405bGPVNXzgYeA1q6dAPRuuHgN8H9aav/JNL+h7g88DXh8kj9uI1ZV3QR8APgq8BWaUuZkG7G6aqUk60XfutkVSR5Hk6g/XVUXtB2v96v7pcC6lkK8GHhNkh/SlKuOTvJ3LcWiqu7u/Xkf8AWaklkbJoCJvt9IzqdJ3m06Fri2qn7cUvuvAG6vqp9U1aPABcCLWopFVZ1dVQdX1ZE0pcYftBWri1ZKsh7k9s7OSRKaGuhNVfWhFuPsmWR17/UuNP+Tfr+NWFX13qpaU1X70fw9fa2qWhmtJXl8kidufQ38Ds2v20NXVfcCdyZ5Vu+tlwM3thGrzwm0VALpuQM4Ismuvf8WX05z3aQVSZ7S+/PpwB/Q7nfrnBVxu3nNcXtnG7GSnAscBeyRZAJ4f1Wd3UYsmlHoG4Dv9erJAH9WVZuGHGdv4JO9RdN3AD5XVa1OqRuRvYAvNHmGHYHPVNVXWoz3duDTvQHDbQxwC/G26tV1jwHe1laMqvpWkvOBa2lKEt+h3VvBP59kd+BR4JSqur/FWJ3j7eaS1AErpQwiSSuayVqSOsBkLUkdYLKWpA4wWUtSB5is1aoklyZ55Yz3Tk9y5pDaf1OSv57v/SR/keSu3gp1P0hyQZIDhxFfGhWTtdp2Ls3NL/1aW89iHv+rqg6qqrXAZ4GvJdlzxH2QtpnJWm07H3h1kn8Djy1I9TTg/yY5obfW9A1JPtA7/tokl/QeJLp3kluSPLV3l+Xnk1zd2168rR2qqs8CFwMnLv3rSaNhslarqupnwFX881ojx9OMbPemWbjnaOAg4NAkv19VX6BZn/kU4GM0d4jeC/xvmtHxocDrWPryqtcCz15iG9LIrIjbzTX2tpZCvtT780+AQ4FLq+onAEk+TbMe9Bdpbtu+gWZ9463lklcAB/ZuHQfYbeu6H9totpUapbFlstYofBH4UJKDaZ46cm1vsZ657EOz1vVeSXaoqmma3wJfWFX/1H9iX/JerOcDrT7mSxomyyBqXe8p
NJcC5/DPFxa/Bbw0yR69BaROAC7rPZHkb2jqyTfRPP4LmhrzqVvbTHLQtvYnyetoVuBzVTd1hiNrjcq5NOshHw/NI8uSvBf4Ok1JYlNVfSnJ+4BvVNU3eisNXp3kH4B3AGckuZ7mv9vLgZMXEf+dvYXzH09TYjl6awlG6gJX3ZOkDrAMIkkdYLKWpA4wWUtSB5isJakDTNaS1AEma0nqAJO1JHXA/wfKPu2b5DndvAAAAABJRU5ErkJggg==\n",
187 | "text/plain": [
188 | ""
189 | ]
190 | },
191 | "metadata": {
192 | "needs_background": "light"
193 | },
194 | "output_type": "display_data"
195 | }
196 | ],
197 | "source": [
198 | "sns.heatmap(S_average.T, vmin=0, vmax=1)\n",
199 | "plt.xlabel('Voxel ID')\n",
200 | "plt.ylabel('Feature')\n",
201 | "plt.yticks([0.5,1.5,2.5],['X1', 'X2', 'X3'])"
202 | ]
203 | },
204 | {
205 | "cell_type": "code",
206 | "execution_count": null,
207 | "id": "answering-fireplace",
208 | "metadata": {},
209 | "outputs": [],
210 | "source": []
211 | }
212 | ],
213 | "metadata": {
214 | "kernelspec": {
215 | "display_name": "Python 3",
216 | "language": "python",
217 | "name": "python3"
218 | },
219 | "language_info": {
220 | "codemirror_mode": {
221 | "name": "ipython",
222 | "version": 3
223 | },
224 | "file_extension": ".py",
225 | "mimetype": "text/x-python",
226 | "name": "python",
227 | "nbconvert_exporter": "python",
228 | "pygments_lexer": "ipython3",
229 | "version": "3.7.9"
230 | }
231 | },
232 | "nbformat": 4,
233 | "nbformat_minor": 5
234 | }
235 |
--------------------------------------------------------------------------------