├── figure1.png ├── figure2.png ├── figure3.png ├── figure4.png ├── BraVL_EEG ├── utils │ ├── .text.py.swp │ ├── filehandling.py │ ├── BaseExperiment.py │ ├── utils.py │ ├── TBLogger.py │ ├── BaseFlags.py │ └── BaseMMVae.py ├── brain_image_text │ ├── constants.py │ ├── networks │ │ ├── __pycache__ │ │ │ ├── QNET.cpython-37.pyc │ │ │ ├── MLP_Text.cpython-37.pyc │ │ │ ├── MLP_Brain.cpython-37.pyc │ │ │ ├── MLP_Image.cpython-37.pyc │ │ │ └── VAEtrimodal.cpython-37.pyc │ │ ├── VAEtrimodal.py │ │ ├── QNET.py │ │ ├── MLP_Image.py │ │ ├── MLP_Text.py │ │ └── MLP_Brain.py │ ├── flags.py │ └── experiment.py ├── alphabet.json ├── modalities │ └── Modality.py ├── main_trimodal.py ├── divergence_measures │ ├── mmd.py │ ├── kl_div.py │ └── mm_div.py └── job_trimodal ├── BraVL_fMRI ├── utils │ ├── .text.py.swp │ ├── filehandling.py │ ├── BaseExperiment.py │ ├── utils.py │ ├── BaseFlags.py │ ├── TBLogger.py │ └── BaseMMVae.py ├── brain_image_text │ ├── constants.py │ ├── networks │ │ ├── __pycache__ │ │ │ ├── QNET.cpython-37.pyc │ │ │ ├── MLP_Brain.cpython-37.pyc │ │ │ ├── MLP_Image.cpython-37.pyc │ │ │ ├── MLP_Text.cpython-37.pyc │ │ │ └── VAEtrimodal.cpython-37.pyc │ │ ├── VAEtrimodal.py │ │ ├── QNET.py │ │ ├── MLP_Image.py │ │ ├── MLP_Text.py │ │ └── MLP_Brain.py │ ├── flags.py │ └── experiment.py ├── alphabet.json ├── modalities │ └── Modality.py ├── main_trimodal.py ├── divergence_measures │ ├── mmd.py │ ├── kl_div.py │ └── mm_div.py ├── stability_selection.py ├── job_trimodal ├── extract_fea_with_timm.py ├── data_prepare_with_aug_GOD_Wiki.py └── data_prepare_with_aug_DIR_Wiki.py ├── environment.yml ├── LICENSE └── README.md /figure1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/figure1.png -------------------------------------------------------------------------------- /figure2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/figure2.png -------------------------------------------------------------------------------- /figure3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/figure3.png -------------------------------------------------------------------------------- /figure4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/figure4.png -------------------------------------------------------------------------------- /BraVL_EEG/utils/.text.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_EEG/utils/.text.py.swp -------------------------------------------------------------------------------- /BraVL_fMRI/utils/.text.py.swp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_fMRI/utils/.text.py.swp -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/constants.py: -------------------------------------------------------------------------------- 1 | 2 | indices = {'img_mnist': 0, 'img_svhn': 1, 'text': 2}; 3 | modalities = ['img_mnist', 'img_svhn', 'text']; 4 | -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/constants.py: 
-------------------------------------------------------------------------------- 1 | 2 | indices = {'img_mnist': 0, 'img_svhn': 1, 'text': 2}; 3 | modalities = ['img_mnist', 'img_svhn', 'text']; 4 | -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/__pycache__/QNET.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_EEG/brain_image_text/networks/__pycache__/QNET.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/__pycache__/QNET.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_fMRI/brain_image_text/networks/__pycache__/QNET.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/__pycache__/MLP_Text.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_EEG/brain_image_text/networks/__pycache__/MLP_Text.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/__pycache__/MLP_Brain.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_EEG/brain_image_text/networks/__pycache__/MLP_Brain.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/__pycache__/MLP_Image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_EEG/brain_image_text/networks/__pycache__/MLP_Image.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/__pycache__/MLP_Brain.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_fMRI/brain_image_text/networks/__pycache__/MLP_Brain.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/__pycache__/MLP_Image.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_fMRI/brain_image_text/networks/__pycache__/MLP_Image.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/__pycache__/MLP_Text.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_fMRI/brain_image_text/networks/__pycache__/MLP_Text.cpython-37.pyc -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/__pycache__/VAEtrimodal.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_EEG/brain_image_text/networks/__pycache__/VAEtrimodal.cpython-37.pyc -------------------------------------------------------------------------------- 
/BraVL_fMRI/brain_image_text/networks/__pycache__/VAEtrimodal.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ChangdeDu/BraVL/HEAD/BraVL_fMRI/brain_image_text/networks/__pycache__/VAEtrimodal.cpython-37.pyc -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: BraVL 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | - bioconda 8 | dependencies: 9 | - matplotlib=3.4.2 10 | - seaborn=0.11.2 11 | - numpy=1.19.1 12 | - python=3.7.0 13 | - pytorch=1.9.0 14 | - scikit-learn=0.24.2 15 | - scipy=1.7.1 16 | - tensorboardx=2.4 17 | - torchvision=0.10.0 18 | - transformers=4.9.2 19 | - timm=0.4.12 20 | -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/VAEtrimodal.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from utils import utils 7 | from utils.BaseMMVae import BaseMMVae 8 | 9 | 10 | class VAEtrimodal(BaseMMVae, nn.Module): 11 | def __init__(self, flags, modalities, subsets): 12 | super().__init__(flags, modalities, subsets) 13 | 14 | class VAEbimodal(BaseMMVae, nn.Module): 15 | def __init__(self, flags, modalities, subsets): 16 | super().__init__(flags, modalities, subsets) 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/VAEtrimodal.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from utils import utils 7 | from utils.BaseMMVae import BaseMMVae 8 | 9 | 10 | class VAEtrimodal(BaseMMVae, nn.Module): 11 | def __init__(self, flags, modalities, subsets): 12 | super().__init__(flags, modalities, subsets) 13 | 14 | class VAEbimodal(BaseMMVae, nn.Module): 15 | def __init__(self, flags, modalities, subsets): 16 | super().__init__(flags, modalities, subsets) 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/QNET.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | class QNet(nn.Module): 5 | def __init__(self, input_dim,latent_dim): 6 | super(QNet, self).__init__() 7 | self.fc1 = nn.Linear(input_dim,512) 8 | self.fc21 = nn.Linear(512, latent_dim) 9 | self.fc22 = nn.Linear(512, latent_dim) 10 | 11 | def forward(self, x): 12 | e = F.relu(self.fc1(x)) 13 | mu = self.fc21(e) 14 | lv = self.fc22(e) 15 | # return mu,lv.mul(0.5).exp_() 16 | return mu,torch.tensor(0.75).cuda() -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/QNET.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | import torch 4 | class QNet(nn.Module): 5 | def __init__(self, input_dim,latent_dim): 6 | super(QNet, self).__init__() 7 | self.fc1 = nn.Linear(input_dim,512) 8 | self.fc21 = nn.Linear(512, latent_dim) 9 | self.fc22 = nn.Linear(512, latent_dim) 10 | 11 | def forward(self, x): 12 | e = F.relu(self.fc1(x)) 13 | mu = self.fc21(e) 14 | lv = self.fc22(e) 15 | # return 
mu,lv.mul(0.5).exp_() 16 | return mu,torch.tensor(0.75).cuda() -------------------------------------------------------------------------------- /BraVL_EEG/alphabet.json: -------------------------------------------------------------------------------- 1 | [ 2 | "a", 3 | "b", 4 | "c", 5 | "d", 6 | "e", 7 | "f", 8 | "g", 9 | "h", 10 | "i", 11 | "j", 12 | "k", 13 | "l", 14 | "m", 15 | "n", 16 | "o", 17 | "p", 18 | "q", 19 | "r", 20 | "s", 21 | "t", 22 | "u", 23 | "v", 24 | "w", 25 | "x", 26 | "y", 27 | "z", 28 | 29 | "0", 30 | "1", 31 | "2", 32 | "3", 33 | "4", 34 | "5", 35 | "6", 36 | "7", 37 | "8", 38 | "9", 39 | 40 | "-", 41 | ",", 42 | ";", 43 | ".", 44 | "!", 45 | "?", 46 | ":", 47 | "'", 48 | "\"", 49 | "\\", 50 | "/", 51 | "|", 52 | "_", 53 | "@", 54 | "#", 55 | "$", 56 | "%", 57 | "^", 58 | "&", 59 | "*", 60 | "~", 61 | "`", 62 | "+", 63 | "-", 64 | "=", 65 | "<", 66 | ">", 67 | "(", 68 | ")", 69 | "[", 70 | "]", 71 | "{", 72 | "}", 73 | " ", 74 | "\n" 75 | ] -------------------------------------------------------------------------------- /BraVL_fMRI/alphabet.json: -------------------------------------------------------------------------------- 1 | [ 2 | "a", 3 | "b", 4 | "c", 5 | "d", 6 | "e", 7 | "f", 8 | "g", 9 | "h", 10 | "i", 11 | "j", 12 | "k", 13 | "l", 14 | "m", 15 | "n", 16 | "o", 17 | "p", 18 | "q", 19 | "r", 20 | "s", 21 | "t", 22 | "u", 23 | "v", 24 | "w", 25 | "x", 26 | "y", 27 | "z", 28 | 29 | "0", 30 | "1", 31 | "2", 32 | "3", 33 | "4", 34 | "5", 35 | "6", 36 | "7", 37 | "8", 38 | "9", 39 | 40 | "-", 41 | ",", 42 | ";", 43 | ".", 44 | "!", 45 | "?", 46 | ":", 47 | "'", 48 | "\"", 49 | "\\", 50 | "/", 51 | "|", 52 | "_", 53 | "@", 54 | "#", 55 | "$", 56 | "%", 57 | "^", 58 | "&", 59 | "*", 60 | "~", 61 | "`", 62 | "+", 63 | "-", 64 | "=", 65 | "<", 66 | ">", 67 | "(", 68 | ")", 69 | "[", 70 | "]", 71 | "{", 72 | "}", 73 | " ", 74 | "\n" 75 | ] -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Changde Du 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /BraVL_EEG/utils/filehandling.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from datetime import datetime 4 | 5 | def create_dir(dir_name): 6 | if not os.path.exists(dir_name): 7 | os.makedirs(dir_name) 8 | # else: 9 | # shutil.rmtree(dir_name, ignore_errors=True) 10 | # os.makedirs(dir_name) 11 | 12 | 13 | def get_str_experiments(flags): 14 | dateTimeObj = datetime.now() 15 | dateStr = dateTimeObj.strftime("%Y_%m_%d") 16 | str_experiments = flags.dataset + '_' + dateStr; 17 | return str_experiments 18 | 19 | def create_dir_structure(flags, train=True): 20 | if train: 21 | str_experiments = get_str_experiments(flags) 22 | flags.dir_experiment_run = os.path.join(flags.dir_experiment, str_experiments) 23 | flags.str_experiment = str_experiments; 24 | else: 25 | flags.dir_experiment_run = flags.dir_experiment; 26 | 27 | print(flags.dir_experiment_run) 28 | if train: 29 | create_dir(flags.dir_experiment_run) 30 | 31 | flags.dir_checkpoints = os.path.join(flags.dir_experiment_run, 'checkpoints') 32 | if train: 33 | create_dir(flags.dir_checkpoints) 34 | 35 | flags.dir_logs = os.path.join(flags.dir_experiment_run, 'logs') 36 | if train: 37 | create_dir(flags.dir_logs) 38 | print(flags.dir_logs) 39 | return flags; 40 | -------------------------------------------------------------------------------- /BraVL_fMRI/utils/filehandling.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from datetime import datetime 4 | 5 | def create_dir(dir_name): 6 | if not os.path.exists(dir_name): 7 | os.makedirs(dir_name) 8 | # else: 9 | # shutil.rmtree(dir_name, ignore_errors=True) 10 | # os.makedirs(dir_name) 11 | 12 | 13 | def get_str_experiments(flags): 14 | dateTimeObj = datetime.now() 15 | dateStr = dateTimeObj.strftime("%Y_%m_%d") 16 | str_experiments = flags.dataset + '_' + dateStr; 17 | return str_experiments 18 | 19 | def create_dir_structure(flags, train=True): 20 | if train: 21 | str_experiments = get_str_experiments(flags) 22 | flags.dir_experiment_run = os.path.join(flags.dir_experiment, str_experiments) 23 | flags.str_experiment = str_experiments; 24 | else: 25 | flags.dir_experiment_run = flags.dir_experiment; 26 | 27 | print(flags.dir_experiment_run) 28 | if train: 29 | create_dir(flags.dir_experiment_run) 30 | 31 | flags.dir_checkpoints = os.path.join(flags.dir_experiment_run, 'checkpoints') 32 | if train: 33 | create_dir(flags.dir_checkpoints) 34 | 35 | flags.dir_logs = os.path.join(flags.dir_experiment_run, 'logs') 36 | if train: 37 | create_dir(flags.dir_logs) 38 | print(flags.dir_logs) 39 | return flags; 40 | -------------------------------------------------------------------------------- /BraVL_EEG/modalities/Modality.py: -------------------------------------------------------------------------------- 1 | 2 | from abc import ABC, abstractmethod 3 | import os 4 | 5 | import torch 6 | import torch.distributions as dist 7 | 8 | class Modality(ABC): 9 | def __init__(self, name, enc, dec, class_dim, style_dim, lhood_name): 10 | self.name = name; 11 | self.encoder = enc; 12 | self.decoder = dec; 13 | self.class_dim = class_dim; 14 | self.style_dim = style_dim; 15 | self.likelihood_name = lhood_name; 16 | self.likelihood = self.get_likelihood(lhood_name); 17 | 18 | 19 | def get_likelihood(self, name): 20 | if name == 'laplace': 21 | pz = dist.Laplace; 22 | elif name == 'bernoulli': 23 | pz = 
dist.Bernoulli; 24 | elif name == 'normal': 25 | pz = dist.Normal; 26 | elif name == 'categorical': 27 | pz = dist.OneHotCategorical; 28 | else: 29 | print('likelihood not implemented') 30 | pz = None; 31 | return pz; 32 | 33 | 34 | 35 | 36 | 37 | def calc_log_prob(self, out_dist, target, norm_value): 38 | log_prob = out_dist.log_prob(target).sum(); 39 | mean_val_logprob = log_prob/norm_value; 40 | return mean_val_logprob; 41 | 42 | 43 | def save_networks(self, dir_checkpoints): 44 | torch.save(self.encoder.state_dict(), os.path.join(dir_checkpoints, 45 | 'enc_' + self.name)) 46 | torch.save(self.decoder.state_dict(), os.path.join(dir_checkpoints, 47 | 'dec_' + self.name)) 48 | 49 | -------------------------------------------------------------------------------- /BraVL_fMRI/modalities/Modality.py: -------------------------------------------------------------------------------- 1 | 2 | from abc import ABC, abstractmethod 3 | import os 4 | 5 | import torch 6 | import torch.distributions as dist 7 | 8 | class Modality(ABC): 9 | def __init__(self, name, enc, dec, class_dim, style_dim, lhood_name): 10 | self.name = name; 11 | self.encoder = enc; 12 | self.decoder = dec; 13 | self.class_dim = class_dim; 14 | self.style_dim = style_dim; 15 | self.likelihood_name = lhood_name; 16 | self.likelihood = self.get_likelihood(lhood_name); 17 | 18 | 19 | def get_likelihood(self, name): 20 | if name == 'laplace': 21 | pz = dist.Laplace; 22 | elif name == 'bernoulli': 23 | pz = dist.Bernoulli; 24 | elif name == 'normal': 25 | pz = dist.Normal; 26 | elif name == 'categorical': 27 | pz = dist.OneHotCategorical; 28 | else: 29 | print('likelihood not implemented') 30 | pz = None; 31 | return pz; 32 | 33 | 34 | 35 | 36 | 37 | def calc_log_prob(self, out_dist, target, norm_value): 38 | log_prob = out_dist.log_prob(target).sum(); 39 | mean_val_logprob = log_prob/norm_value; 40 | return mean_val_logprob; 41 | 42 | 43 | def save_networks(self, dir_checkpoints): 44 | torch.save(self.encoder.state_dict(), os.path.join(dir_checkpoints, 45 | 'enc_' + self.name)) 46 | torch.save(self.decoder.state_dict(), os.path.join(dir_checkpoints, 47 | 'dec_' + self.name)) 48 | 49 | -------------------------------------------------------------------------------- /BraVL_EEG/main_trimodal.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '5' 4 | import json 5 | import torch 6 | from run_epochs_trimodal import run_epochs_trimodal 7 | from utils.filehandling import create_dir_structure 8 | from brain_image_text.flags import parser 9 | from brain_image_text.experiment import BrainImageText 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | if __name__ == '__main__': 12 | FLAGS = parser.parse_args() 13 | use_cuda = torch.cuda.is_available() 14 | FLAGS.device = torch.device('cuda' if use_cuda else 'cpu') 15 | 16 | if FLAGS.method == 'poe': 17 | FLAGS.modality_poe=True 18 | elif FLAGS.method == 'moe': 19 | FLAGS.modality_moe=True 20 | elif FLAGS.method == 'jsd': 21 | FLAGS.modality_jsd=True 22 | elif FLAGS.method == 'joint_elbo': 23 | FLAGS.joint_elbo=True 24 | else: 25 | print('method implemented...exit!') 26 | sys.exit() 27 | print(FLAGS.modality_poe) 28 | print(FLAGS.modality_moe) 29 | print(FLAGS.modality_jsd) 30 | print(FLAGS.joint_elbo) 31 | 32 | FLAGS.alpha_modalities = [FLAGS.div_weight_uniform_content, FLAGS.div_weight_m1_content, 33 | FLAGS.div_weight_m2_content, FLAGS.div_weight_m3_content] 34 | 35 | FLAGS = 
create_dir_structure(FLAGS) 36 | alphabet_path = os.path.join(os.getcwd(), 'alphabet.json') 37 | with open(alphabet_path) as alphabet_file: 38 | alphabet = str(''.join(json.load(alphabet_file))) 39 | 40 | mst = BrainImageText(FLAGS, alphabet) 41 | mst.set_optimizer() 42 | total_params = sum(p.numel() for p in mst.mm_vae.parameters()) 43 | print('num parameters model: ' + str(total_params)) 44 | run_epochs_trimodal(mst) 45 | -------------------------------------------------------------------------------- /BraVL_fMRI/main_trimodal.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | os.environ['CUDA_VISIBLE_DEVICES'] = '1' 4 | import json 5 | import torch 6 | from run_epochs_trimodal import run_epochs_trimodal 7 | from utils.filehandling import create_dir_structure 8 | from brain_image_text.flags import parser 9 | from brain_image_text.experiment import BrainImageText 10 | torch.set_default_tensor_type(torch.DoubleTensor) 11 | if __name__ == '__main__': 12 | FLAGS = parser.parse_args() 13 | use_cuda = torch.cuda.is_available() 14 | FLAGS.device = torch.device('cuda' if use_cuda else 'cpu') 15 | 16 | if FLAGS.method == 'poe': 17 | FLAGS.modality_poe=True 18 | elif FLAGS.method == 'moe': 19 | FLAGS.modality_moe=True 20 | elif FLAGS.method == 'jsd': 21 | FLAGS.modality_jsd=True 22 | elif FLAGS.method == 'joint_elbo': 23 | FLAGS.joint_elbo=True 24 | else: 25 | print('method implemented...exit!') 26 | sys.exit() 27 | print(FLAGS.modality_poe) 28 | print(FLAGS.modality_moe) 29 | print(FLAGS.modality_jsd) 30 | print(FLAGS.joint_elbo) 31 | 32 | FLAGS.alpha_modalities = [FLAGS.div_weight_uniform_content, FLAGS.div_weight_m1_content, 33 | FLAGS.div_weight_m2_content, FLAGS.div_weight_m3_content] 34 | 35 | FLAGS = create_dir_structure(FLAGS) 36 | alphabet_path = os.path.join(os.getcwd(), 'alphabet.json') 37 | with open(alphabet_path) as alphabet_file: 38 | alphabet = str(''.join(json.load(alphabet_file))) 39 | 40 | mst = BrainImageText(FLAGS, alphabet) 41 | mst.set_optimizer() 42 | total_params = sum(p.numel() for p in mst.mm_vae.parameters()) 43 | print('num parameters model: ' + str(total_params)) 44 | run_epochs_trimodal(mst) 45 | -------------------------------------------------------------------------------- /BraVL_EEG/divergence_measures/mmd.py: -------------------------------------------------------------------------------- 1 | def mmd_loss(z_tilde, z, z_var): 2 | r"""Calculate maximum mean discrepancy described in the WAE paper. 3 | Args: 4 | z_tilde (Tensor): samples from deterministic non-random encoder Q(Z|X). 5 | 2D Tensor(batch_size x dimension). 6 | z (Tensor): samples from prior distributions. same shape with z_tilde. 7 | z_var (Number): scalar variance of isotropic gaussian prior P(Z). 8 | """ 9 | assert z_tilde.size() == z.size() 10 | assert z.ndimension() == 2 11 | 12 | n = z.size(0) 13 | im_kernel_z_z = im_kernel_sum(z, z, z_var, exclude_diag=True).div(n*(n-1)) 14 | im_kernel_ztilde_ztilde = im_kernel_sum(z_tilde, z_tilde, z_var, exclude_diag=True).div(n*(n-1)) 15 | im_kernel_z_ztilde = im_kernel_sum(z, z_tilde, z_var, exclude_diag=False).div(n*n).mul(2) 16 | out = im_kernel_z_z + im_kernel_ztilde_ztilde - im_kernel_z_ztilde 17 | return out, im_kernel_z_z, im_kernel_ztilde_ztilde, im_kernel_z_ztilde 18 | 19 | 20 | def im_kernel_sum(z1, z2, z_var, exclude_diag=True): 21 | r"""Calculate sum of sample-wise measures of inverse multiquadratics kernel described in the WAE paper. 
22 | Args: 23 | z1 (Tensor): batch of samples from a multivariate gaussian distribution \ 24 | with scalar variance of z_var. 25 | z2 (Tensor): batch of samples from another multivariate gaussian distribution \ 26 | with scalar variance of z_var. 27 | exclude_diag (bool): whether to exclude diagonal kernel measures before sum it all. 28 | """ 29 | assert z1.size() == z2.size() 30 | assert z1.ndimension() == 2 31 | 32 | z_dim = z1.size(1) 33 | C = 2*z_dim*z_var 34 | 35 | z11 = z1.unsqueeze(1).repeat(1, z2.size(0), 1) 36 | z22 = z2.unsqueeze(0).repeat(z1.size(0), 1, 1) 37 | 38 | kernel_matrix = C/(1e-9+C+(z11-z22).pow(2).sum(2)) 39 | kernel_sum = kernel_matrix.sum() 40 | # numerically identical to the formulation. but.. 41 | if exclude_diag: 42 | kernel_sum -= kernel_matrix.diag().sum() 43 | return kernel_sum -------------------------------------------------------------------------------- /BraVL_fMRI/divergence_measures/mmd.py: -------------------------------------------------------------------------------- 1 | def mmd_loss(z_tilde, z, z_var): 2 | r"""Calculate maximum mean discrepancy described in the WAE paper. 3 | Args: 4 | z_tilde (Tensor): samples from deterministic non-random encoder Q(Z|X). 5 | 2D Tensor(batch_size x dimension). 6 | z (Tensor): samples from prior distributions. same shape with z_tilde. 7 | z_var (Number): scalar variance of isotropic gaussian prior P(Z). 8 | """ 9 | assert z_tilde.size() == z.size() 10 | assert z.ndimension() == 2 11 | 12 | n = z.size(0) 13 | im_kernel_z_z = im_kernel_sum(z, z, z_var, exclude_diag=True).div(n*(n-1)) 14 | im_kernel_ztilde_ztilde = im_kernel_sum(z_tilde, z_tilde, z_var, exclude_diag=True).div(n*(n-1)) 15 | im_kernel_z_ztilde = im_kernel_sum(z, z_tilde, z_var, exclude_diag=False).div(n*n).mul(2) 16 | out = im_kernel_z_z + im_kernel_ztilde_ztilde - im_kernel_z_ztilde 17 | return out, im_kernel_z_z, im_kernel_ztilde_ztilde, im_kernel_z_ztilde 18 | 19 | 20 | def im_kernel_sum(z1, z2, z_var, exclude_diag=True): 21 | r"""Calculate sum of sample-wise measures of inverse multiquadratics kernel described in the WAE paper. 22 | Args: 23 | z1 (Tensor): batch of samples from a multivariate gaussian distribution \ 24 | with scalar variance of z_var. 25 | z2 (Tensor): batch of samples from another multivariate gaussian distribution \ 26 | with scalar variance of z_var. 27 | exclude_diag (bool): whether to exclude diagonal kernel measures before sum it all. 28 | """ 29 | assert z1.size() == z2.size() 30 | assert z1.ndimension() == 2 31 | 32 | z_dim = z1.size(1) 33 | C = 2*z_dim*z_var 34 | 35 | z11 = z1.unsqueeze(1).repeat(1, z2.size(0), 1) 36 | z22 = z2.unsqueeze(0).repeat(z1.size(0), 1, 1) 37 | 38 | kernel_matrix = C/(1e-9+C+(z11-z22).pow(2).sum(2)) 39 | kernel_sum = kernel_matrix.sum() 40 | # numerically identical to the formulation. but.. 
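# (exclude_diag=True removes the k(z_i, z_i) self-similarity terms, so dividing by n*(n-1) in mmd_loss gives the unbiased U-statistic estimate of the MMD)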
41 | if exclude_diag: 42 | kernel_sum -= kernel_matrix.diag().sum() 43 | return kernel_sum -------------------------------------------------------------------------------- /BraVL_EEG/utils/BaseExperiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from itertools import chain, combinations 4 | 5 | class BaseExperiment(ABC): 6 | def __init__(self, flags): 7 | self.flags = flags 8 | self.name = flags.dataset 9 | 10 | self.modalities = None 11 | self.num_modalities = None 12 | self.subsets = None 13 | self.dataset_train = None 14 | self.dataset_test = None 15 | self.Q1, self.Q2, self.Q3 = None,None,None 16 | self.mm_vae = None 17 | self.clfs = None 18 | self.optimizer = None 19 | self.rec_weights = None 20 | self.style_weights = None 21 | 22 | self.test_samples = None 23 | self.paths_fid = None 24 | 25 | 26 | @abstractmethod 27 | def set_model(self): 28 | pass 29 | 30 | @abstractmethod 31 | def set_Qmodel(self): 32 | pass 33 | 34 | @abstractmethod 35 | def set_modalities(self): 36 | pass 37 | 38 | @abstractmethod 39 | def set_dataset(self): 40 | pass 41 | 42 | 43 | @abstractmethod 44 | def set_optimizer(self): 45 | pass 46 | 47 | @abstractmethod 48 | def set_rec_weights(self): 49 | pass 50 | 51 | @abstractmethod 52 | def set_style_weights(self): 53 | pass 54 | 55 | 56 | def set_subsets(self): 57 | num_mods = len(list(self.modalities.keys())) 58 | 59 | """ 60 | powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) 61 | (1,2,3) 62 | """ 63 | xs = list(self.modalities) 64 | # note we return an iterator rather than a list 65 | subsets_list = chain.from_iterable(combinations(xs, n) for n in 66 | range(len(xs)+1)) 67 | subsets = dict() 68 | for k, mod_names in enumerate(subsets_list): 69 | mods = [] 70 | for l, mod_name in enumerate(sorted(mod_names)): 71 | mods.append(self.modalities[mod_name]) 72 | key = '_'.join(sorted(mod_names)) 73 | subsets[key] = mods 74 | return subsets 75 | -------------------------------------------------------------------------------- /BraVL_fMRI/utils/BaseExperiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | from abc import ABC, abstractmethod 3 | from itertools import chain, combinations 4 | 5 | class BaseExperiment(ABC): 6 | def __init__(self, flags): 7 | self.flags = flags 8 | self.name = flags.dataset 9 | 10 | self.modalities = None 11 | self.num_modalities = None 12 | self.subsets = None 13 | self.dataset_train = None 14 | self.dataset_test = None 15 | self.Q1, self.Q2, self.Q3 = None,None,None 16 | self.mm_vae = None 17 | self.clfs = None 18 | self.optimizer = None 19 | self.rec_weights = None 20 | self.style_weights = None 21 | 22 | self.test_samples = None 23 | self.paths_fid = None 24 | 25 | 26 | @abstractmethod 27 | def set_model(self): 28 | pass 29 | 30 | @abstractmethod 31 | def set_Qmodel(self): 32 | pass 33 | 34 | @abstractmethod 35 | def set_modalities(self): 36 | pass 37 | 38 | @abstractmethod 39 | def set_dataset(self): 40 | pass 41 | 42 | 43 | @abstractmethod 44 | def set_optimizer(self): 45 | pass 46 | 47 | @abstractmethod 48 | def set_rec_weights(self): 49 | pass 50 | 51 | @abstractmethod 52 | def set_style_weights(self): 53 | pass 54 | 55 | 56 | def set_subsets(self): 57 | num_mods = len(list(self.modalities.keys())) 58 | 59 | """ 60 | powerset([1,2,3]) --> () (1,) (2,) (3,) (1,2) (1,3) (2,3) 61 | (1,2,3) 62 | """ 63 | xs = list(self.modalities) 64 | # note we return an iterator rather than 
a list 65 | subsets_list = chain.from_iterable(combinations(xs, n) for n in 66 | range(len(xs)+1)) 67 | subsets = dict() 68 | for k, mod_names in enumerate(subsets_list): 69 | mods = [] 70 | for l, mod_name in enumerate(sorted(mod_names)): 71 | mods.append(self.modalities[mod_name]) 72 | key = '_'.join(sorted(mod_names)) 73 | subsets[key] = mods 74 | return subsets 75 | -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/MLP_Image.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class EncoderImage(nn.Module): 7 | def __init__(self, flags): 8 | super(EncoderImage, self).__init__() 9 | self.flags = flags; 10 | self.hidden_dim = 256; 11 | 12 | modules = [] 13 | modules.append(nn.Sequential(nn.Linear(flags.m2_dim, self.hidden_dim), nn.ReLU(True))) 14 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 15 | for _ in range(flags.num_hidden_layers - 1)]) 16 | self.enc = nn.Sequential(*modules) 17 | self.relu = nn.ReLU(); 18 | self.hidden_mu = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 19 | self.hidden_logvar = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 20 | 21 | 22 | def forward(self, x): 23 | h = self.enc(x); 24 | h = h.view(h.size(0), -1); 25 | latent_space_mu = self.hidden_mu(h); 26 | latent_space_logvar = self.hidden_logvar(h); 27 | latent_space_mu = latent_space_mu.view(latent_space_mu.size(0), -1); 28 | latent_space_logvar = latent_space_logvar.view(latent_space_logvar.size(0), -1); 29 | return None, None, latent_space_mu, latent_space_logvar; 30 | 31 | 32 | 33 | class DecoderImage(nn.Module): 34 | def __init__(self, flags): 35 | super(DecoderImage, self).__init__(); 36 | self.flags = flags; 37 | self.hidden_dim = 256; 38 | modules = [] 39 | 40 | modules.append(nn.Sequential(nn.Linear(flags.class_dim, self.hidden_dim), nn.ReLU(True))) 41 | 42 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 43 | for _ in range(flags.num_hidden_layers - 1)]) 44 | self.dec = nn.Sequential(*modules) 45 | self.fc3 = nn.Linear(self.hidden_dim, flags.m2_dim, bias=True) 46 | 47 | 48 | def forward(self, style_latent_space, class_latent_space): 49 | z = class_latent_space; 50 | x_hat = self.dec(z); 51 | x_hat = self.fc3(x_hat); 52 | return x_hat, torch.tensor(0.75).to(z.device); -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/MLP_Image.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class EncoderImage(nn.Module): 7 | def __init__(self, flags): 8 | super(EncoderImage, self).__init__() 9 | self.flags = flags; 10 | self.hidden_dim = 2048; 11 | 12 | modules = [] 13 | modules.append(nn.Sequential(nn.Linear(flags.m2_dim, self.hidden_dim), nn.ReLU(True))) 14 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 15 | for _ in range(flags.num_hidden_layers - 1)]) 16 | self.enc = nn.Sequential(*modules) 17 | self.relu = nn.ReLU(); 18 | self.hidden_mu = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 19 | self.hidden_logvar = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 20 | 21 | 22 | def forward(self, x): 23 | h = self.enc(x); 24 | h = 
h.view(h.size(0), -1); 25 | latent_space_mu = self.hidden_mu(h); 26 | latent_space_logvar = self.hidden_logvar(h); 27 | latent_space_mu = latent_space_mu.view(latent_space_mu.size(0), -1); 28 | latent_space_logvar = latent_space_logvar.view(latent_space_logvar.size(0), -1); 29 | return None, None, latent_space_mu, latent_space_logvar; 30 | 31 | 32 | 33 | class DecoderImage(nn.Module): 34 | def __init__(self, flags): 35 | super(DecoderImage, self).__init__(); 36 | self.flags = flags; 37 | self.hidden_dim = 2048; 38 | modules = [] 39 | 40 | modules.append(nn.Sequential(nn.Linear(flags.class_dim, self.hidden_dim), nn.ReLU(True))) 41 | 42 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 43 | for _ in range(flags.num_hidden_layers - 1)]) 44 | self.dec = nn.Sequential(*modules) 45 | self.fc3 = nn.Linear(self.hidden_dim, flags.m2_dim, bias=True) 46 | 47 | 48 | def forward(self, style_latent_space, class_latent_space): 49 | z = class_latent_space; 50 | x_hat = self.dec(z); 51 | x_hat = self.fc3(x_hat); 52 | return x_hat, torch.tensor(0.75).to(z.device); -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/MLP_Text.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class EncoderText(nn.Module): 7 | def __init__(self, flags): 8 | super(EncoderText, self).__init__() 9 | self.flags = flags; 10 | self.hidden_dim = 256; 11 | 12 | modules = [] 13 | modules.append(nn.Sequential(nn.Linear(flags.m3_dim, self.hidden_dim), nn.ReLU(True))) 14 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 15 | for _ in range(flags.num_hidden_layers - 1)]) 16 | self.enc = nn.Sequential(*modules) 17 | self.relu = nn.ReLU(); 18 | self.hidden_mu = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 19 | self.hidden_logvar = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 20 | 21 | 22 | def forward(self, x): 23 | h = self.enc(x); 24 | h = h.view(h.size(0), -1); 25 | latent_space_mu = self.hidden_mu(h); 26 | latent_space_logvar = self.hidden_logvar(h); 27 | latent_space_mu = latent_space_mu.view(latent_space_mu.size(0), -1); 28 | latent_space_logvar = latent_space_logvar.view(latent_space_logvar.size(0), -1); 29 | return None, None, latent_space_mu, latent_space_logvar; 30 | 31 | 32 | 33 | class DecoderText(nn.Module): 34 | def __init__(self, flags): 35 | super(DecoderText, self).__init__(); 36 | self.flags = flags; 37 | self.hidden_dim = 256; 38 | modules = [] 39 | 40 | modules.append(nn.Sequential(nn.Linear(flags.class_dim, self.hidden_dim), nn.ReLU(True))) 41 | 42 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 43 | for _ in range(flags.num_hidden_layers - 1)]) 44 | self.dec = nn.Sequential(*modules) 45 | self.fc3 = nn.Linear(self.hidden_dim, flags.m3_dim) 46 | self.relu = nn.ReLU(); 47 | 48 | 49 | def forward(self, style_latent_space, class_latent_space): 50 | z = class_latent_space; 51 | x_hat = self.dec(z); 52 | x_hat = self.fc3(x_hat); 53 | return x_hat, torch.tensor(0.75).to(z.device); -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/MLP_Text.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class EncoderText(nn.Module): 
7 | def __init__(self, flags): 8 | super(EncoderText, self).__init__() 9 | self.flags = flags; 10 | self.hidden_dim = 512; 11 | 12 | modules = [] 13 | modules.append(nn.Sequential(nn.Linear(flags.m3_dim, self.hidden_dim), nn.ReLU(True))) 14 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 15 | for _ in range(flags.num_hidden_layers - 1)]) 16 | self.enc = nn.Sequential(*modules) 17 | self.relu = nn.ReLU(); 18 | self.hidden_mu = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 19 | self.hidden_logvar = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 20 | 21 | 22 | def forward(self, x): 23 | h = self.enc(x); 24 | h = h.view(h.size(0), -1); 25 | latent_space_mu = self.hidden_mu(h); 26 | latent_space_logvar = self.hidden_logvar(h); 27 | latent_space_mu = latent_space_mu.view(latent_space_mu.size(0), -1); 28 | latent_space_logvar = latent_space_logvar.view(latent_space_logvar.size(0), -1); 29 | return None, None, latent_space_mu, latent_space_logvar; 30 | 31 | 32 | 33 | class DecoderText(nn.Module): 34 | def __init__(self, flags): 35 | super(DecoderText, self).__init__(); 36 | self.flags = flags; 37 | self.hidden_dim = 512; 38 | modules = [] 39 | 40 | modules.append(nn.Sequential(nn.Linear(flags.class_dim, self.hidden_dim), nn.ReLU(True))) 41 | 42 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 43 | for _ in range(flags.num_hidden_layers - 1)]) 44 | self.dec = nn.Sequential(*modules) 45 | self.fc3 = nn.Linear(self.hidden_dim, flags.m3_dim) 46 | self.relu = nn.ReLU(); 47 | 48 | 49 | def forward(self, style_latent_space, class_latent_space): 50 | z = class_latent_space; 51 | x_hat = self.dec(z); 52 | x_hat = self.fc3(x_hat); 53 | return x_hat, torch.tensor(0.75).to(z.device); -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/networks/MLP_Brain.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | class EncoderBrain(nn.Module): 6 | def __init__(self, flags): 7 | super(EncoderBrain, self).__init__() 8 | self.flags = flags; 9 | self.hidden_dim = 256; 10 | 11 | modules = [] 12 | modules.append(nn.Sequential(nn.Linear(flags.m1_dim, self.hidden_dim), nn.ReLU(True))) 13 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 14 | for _ in range(flags.num_hidden_layers - 1)]) 15 | self.enc = nn.Sequential(*modules) 16 | self.relu = nn.ReLU(); 17 | 18 | self.hidden_mu = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 19 | self.hidden_logvar = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 20 | 21 | 22 | def forward(self, x): 23 | h = self.enc(x); 24 | h = h.view(h.size(0), -1); 25 | latent_space_mu = self.hidden_mu(h); 26 | latent_space_logvar = self.hidden_logvar(h); 27 | latent_space_mu = latent_space_mu.view(latent_space_mu.size(0), -1); 28 | latent_space_logvar = latent_space_logvar.view(latent_space_logvar.size(0), -1); 29 | return None, None, latent_space_mu, latent_space_logvar; 30 | 31 | 32 | 33 | class DecoderBrain(nn.Module): 34 | def __init__(self, flags): 35 | super(DecoderBrain, self).__init__(); 36 | self.flags = flags; 37 | self.hidden_dim = 256; 38 | modules = [] 39 | 40 | modules.append(nn.Sequential(nn.Linear(flags.class_dim, self.hidden_dim), nn.ReLU(True))) 41 | 42 | 
modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 43 | for _ in range(flags.num_hidden_layers - 1)]) 44 | self.dec = nn.Sequential(*modules) 45 | self.fc3 = nn.Linear(self.hidden_dim, flags.m1_dim) 46 | self.relu = nn.ReLU(); 47 | 48 | 49 | def forward(self, style_latent_space, class_latent_space): 50 | z = class_latent_space; 51 | x_hat = self.dec(z); 52 | x_hat = self.fc3(x_hat); 53 | return x_hat, torch.tensor(0.75).to(z.device); 54 | -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/networks/MLP_Brain.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class EncoderBrain(nn.Module): 7 | def __init__(self, flags): 8 | super(EncoderBrain, self).__init__() 9 | self.flags = flags; 10 | self.hidden_dim = 512; 11 | 12 | modules = [] 13 | modules.append(nn.Sequential(nn.Linear(flags.m1_dim, self.hidden_dim), nn.ReLU(True))) 14 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 15 | for _ in range(flags.num_hidden_layers - 1)]) 16 | self.enc = nn.Sequential(*modules) 17 | self.relu = nn.ReLU(); 18 | 19 | self.hidden_mu = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 20 | self.hidden_logvar = nn.Linear(in_features=self.hidden_dim, out_features=flags.class_dim, bias=True) 21 | 22 | 23 | def forward(self, x): 24 | h = self.enc(x); 25 | h = h.view(h.size(0), -1); 26 | latent_space_mu = self.hidden_mu(h); 27 | latent_space_logvar = self.hidden_logvar(h); 28 | latent_space_mu = latent_space_mu.view(latent_space_mu.size(0), -1); 29 | latent_space_logvar = latent_space_logvar.view(latent_space_logvar.size(0), -1); 30 | return None, None, latent_space_mu, latent_space_logvar; 31 | 32 | 33 | 34 | class DecoderBrain(nn.Module): 35 | def __init__(self, flags): 36 | super(DecoderBrain, self).__init__(); 37 | self.flags = flags; 38 | self.hidden_dim = 512; 39 | modules = [] 40 | 41 | modules.append(nn.Sequential(nn.Linear(flags.class_dim, self.hidden_dim), nn.ReLU(True))) 42 | 43 | modules.extend([nn.Sequential(nn.Linear(self.hidden_dim, self.hidden_dim), nn.ReLU(True)) 44 | for _ in range(flags.num_hidden_layers - 1)]) 45 | self.dec = nn.Sequential(*modules) 46 | self.fc3 = nn.Linear(self.hidden_dim, flags.m1_dim) 47 | self.relu = nn.ReLU(); 48 | 49 | 50 | def forward(self, style_latent_space, class_latent_space): 51 | z = class_latent_space; 52 | x_hat = self.dec(z); 53 | x_hat = self.fc3(x_hat); 54 | return x_hat, torch.tensor(0.75).to(z.device); 55 | -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/flags.py: -------------------------------------------------------------------------------- 1 | from utils.BaseFlags import parser as parser 2 | 3 | # DATASET NAME 4 | parser.add_argument('--dataset', type=str, default='Brain_Image_Text', help="name of the dataset") 5 | # DATA DEPENDENT 6 | # to be set by experiments themselves 7 | parser.add_argument('--style_m1_dim', type=int, default=0, help="dimension of varying factor latent space") 8 | parser.add_argument('--style_m2_dim', type=int, default=0, help="dimension of varying factor latent space") 9 | parser.add_argument('--style_m3_dim', type=int, default=0, help="dimension of varying factor latent space") 10 | 11 | parser.add_argument('--num_hidden_layers', type=int, default=2, help="number of channels in images") 12 | 
parser.add_argument('--likelihood_m1', type=str, default='laplace', help="output distribution") 13 | parser.add_argument('--likelihood_m2', type=str, default='laplace', help="output distribution") 14 | parser.add_argument('--likelihood_m3', type=str, default='laplace', help="output distribution") 15 | 16 | # LOSS TERM WEIGHTS 17 | parser.add_argument('--beta_m1_style', type=float, default=1.0, help="default weight divergence term style modality 1") 18 | parser.add_argument('--beta_m2_style', type=float, default=1.0, help="default weight divergence term style modality 2") 19 | parser.add_argument('--beta_m3_style', type=float, default=1.0, help="default weight divergence term style modality 3") 20 | parser.add_argument('--beta_m1_rec', type=float, default=1.0, help="default weight reconstruction modality 1") 21 | parser.add_argument('--beta_m2_rec', type=float, default=1.0, help="default weight reconstruction modality 2") 22 | parser.add_argument('--beta_m3_rec', type=float, default=1.0, help="default weight reconstruction modality 3") 23 | parser.add_argument('--div_weight_m1_content', type=float, default=0.25, help="default weight divergence term content modality 1") 24 | parser.add_argument('--div_weight_m2_content', type=float, default=0.25, help="default weight divergence term content modality 2") 25 | parser.add_argument('--div_weight_m3_content', type=float, default=0.25, help="default weight divergence term content modality 2") 26 | parser.add_argument('--div_weight_uniform_content', type=float, default=0.25, help="default weight divergence term prior") 27 | # 28 | -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/flags.py: -------------------------------------------------------------------------------- 1 | from utils.BaseFlags import parser as parser 2 | 3 | # DATASET NAME 4 | parser.add_argument('--dataset', type=str, default='Brain_Image_Text', help="name of the dataset") 5 | # DATA DEPENDENT 6 | # to be set by experiments themselves 7 | parser.add_argument('--style_m1_dim', type=int, default=0, help="dimension of varying factor latent space") 8 | parser.add_argument('--style_m2_dim', type=int, default=0, help="dimension of varying factor latent space") 9 | parser.add_argument('--style_m3_dim', type=int, default=0, help="dimension of varying factor latent space") 10 | 11 | parser.add_argument('--num_hidden_layers', type=int, default=2, help="number of channels in images") 12 | parser.add_argument('--likelihood_m1', type=str, default='laplace', help="output distribution") 13 | parser.add_argument('--likelihood_m2', type=str, default='laplace', help="output distribution") 14 | parser.add_argument('--likelihood_m3', type=str, default='laplace', help="output distribution") 15 | 16 | # LOSS TERM WEIGHTS 17 | parser.add_argument('--beta_m1_style', type=float, default=1.0, help="default weight divergence term style modality 1") 18 | parser.add_argument('--beta_m2_style', type=float, default=1.0, help="default weight divergence term style modality 2") 19 | parser.add_argument('--beta_m3_style', type=float, default=1.0, help="default weight divergence term style modality 3") 20 | parser.add_argument('--beta_m1_rec', type=float, default=1.0, help="default weight reconstruction modality 1") 21 | parser.add_argument('--beta_m2_rec', type=float, default=1.0, help="default weight reconstruction modality 2") 22 | parser.add_argument('--beta_m3_rec', type=float, default=1.0, help="default weight reconstruction modality 3") 23 | 
parser.add_argument('--div_weight_m1_content', type=float, default=0.25, help="default weight divergence term content modality 1") 24 | parser.add_argument('--div_weight_m2_content', type=float, default=0.25, help="default weight divergence term content modality 2") 25 | parser.add_argument('--div_weight_m3_content', type=float, default=0.25, help="default weight divergence term content modality 3") 26 | parser.add_argument('--div_weight_uniform_content', type=float, default=0.25, help="default weight divergence term prior") 27 | # 28 | -------------------------------------------------------------------------------- /BraVL_fMRI/stability_selection.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from itertools import combinations 3 | 4 | def stability_selection(data, n=None): 5 | """Return the indices of the n voxels with best stability 6 | 7 | Given repeated fMRI measurements on a set of stimuli, return the indices of 8 | the voxels that demonstrate the best stability across the repetitions. This 9 | stability is quantified for each voxel as the mean Pearson correlation 10 | coefficient across all pairwise combinations of the repetitions. 11 | 12 | Parameters 13 | ---------- 14 | data : 3D array (n_repetitions, n_items, n_voxels) 15 | The fMRI images 16 | n : int | None 17 | If specified, the indices of the top N most stable voxels are 18 | returned. Otherwise the indices of all voxels are returned. 19 | 20 | Returns 21 | ------- 22 | top_indices : 1D array (n_voxels) 23 | The indices of the voxels, ordered by stability in decreasing order 24 | (the first index corresponds to the voxel with the highest stability). 25 | If the n parameter is specified, at most N indices are returned. 26 | """ 27 | n_repetitions, n_items, n_voxels = data.shape 28 | 29 | if n is None: 30 | n = n_voxels 31 | elif n > n_voxels: 32 | raise ValueError('n must be a number between 0 and ' + str(n_voxels)) 33 | 34 | # Keep only the voxels that contain no NaN's for any item 35 | non_nan_mask = ~np.any(np.any(np.isnan(data), axis=1), axis=0) 36 | non_nan_indices = np.flatnonzero(non_nan_mask) 37 | data_trimmed = data[:, :, non_nan_mask] 38 | 39 | data_means = data_trimmed.mean(axis=1) 40 | data_stds = data_trimmed.std(axis=1) 41 | 42 | # Loop over all pairwise combinations and compute correlations 43 | stability_scores = [] 44 | for x, y in combinations(range(n_repetitions), 2): 45 | x1 = (data_trimmed[x] - data_means[x]) / data_stds[x] 46 | y1 = (data_trimmed[y] - data_means[y]) / data_stds[y] 47 | stability_scores.append(np.sum(x1 * y1, axis=0) / n_items) 48 | 49 | # Compute the N best voxels 50 | best_voxels = np.mean(stability_scores, axis=0).argsort()[-n:] 51 | 52 | # Return the (original) indices of the best voxels in decreasing order 53 | return non_nan_indices[best_voxels][::-1] 54 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BraVL 2 | This is the official code for the paper "Decoding Visual Neural Representations by Multimodal Learning of Brain-Visual-Linguistic Features", IEEE Transactions on Pattern Analysis and Machine Intelligence (TPAMI 2023) (https://ieeexplore.ieee.org/document/10089190).
3 | > - Authors: Changde Du, Kaicheng Fu, Jinpeng Li, Huiguang He 4 | 5 | ![Dual coding of knowledge](figure1.png) 6 | ![Illustration of the trimodal data](figure2.png) 7 | ![Brain-Image-Text joint representation learning](figure3.png) 8 | ![The voxel stability maps in the visual cortex](figure4.png) 9 | 10 | ## Preliminaries 11 | 12 | This code was developed and tested with: 13 | - Python version 3.7.0 14 | - PyTorch version 1.9.0 15 | - CUDA version 11.2 16 | - The conda environment defined in `environment.yml` 17 | 18 | First, set up the conda environment as follows: 19 | ```bash 20 | conda env create -f environment.yml # create conda env 21 | conda activate BraVL # activate conda env 22 | ``` 23 | ## Download data 24 | Second, download the pre-processed trimodal data from https://figshare.com/articles/dataset/BraVL/17024591, unzip them, and put them in the corresponding "./data" directories: 25 | ```bash 26 | unzip DIR-Wiki.zip -d BraVL_fMRI/data/ 27 | unzip GOD-Wiki.zip -d BraVL_fMRI/data/ 28 | unzip ThingsEEG-Text.zip -d BraVL_EEG/data/ 29 | ``` 30 | Note that the raw (image and brain fMRI/EEG) data are not included here because they are too large. Raw ImageNet images and brain fMRI data can be downloaded from the corresponding official sites. We provide Python scripts for feature extraction and data preprocessing. 31 | 32 | ## Experiments 33 | 34 | Experiments can be started by running the `job_trimodal` script; a direct single-run example is given at the end of this README. 35 | 36 | 37 | ### Running BraVL on the Image-Text-fMRI datasets 38 | ``` 39 | cd BraVL_fMRI 40 | bash job_trimodal 41 | ``` 42 | ### Running BraVL on the Image-Text-EEG datasets 43 | ``` 44 | cd BraVL_EEG 45 | bash job_trimodal 46 | ``` 47 | ## Cite 48 | 49 | Please cite our paper if you use this code in your own work: 50 | ``` 51 | @article{du2023decoding, 52 | title={Decoding Visual Neural Representations by Multimodal Learning of Brain-Visual-Linguistic Features}, 53 | author={Du, Changde and Fu, Kaicheng and Li, Jinpeng and He, Huiguang}, 54 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 55 | year={2023}, 56 | publisher={IEEE} 57 | } 58 | ``` 59 | 60 | If you have any questions about the code or the paper, we are happy to help!
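## Appendix: running a single subject directly

For reference, `job_trimodal` is simply a sequence of `python main_trimodal.py` calls, one per subject. A single subject can also be launched directly; the sketch below mirrors the first call (Subject 1, DIR-Wiki) in `BraVL_fMRI/job_trimodal` with its shell variables expanded. Adjust the subject, dataset, ROI and feature-model flags to your own setup (see `brain_image_text/flags.py` and `utils/BaseFlags.py` for the flag definitions):
```bash
cd BraVL_fMRI
python main_trimodal.py \
    --end_epoch=100 \
    --method=joint_elbo \
    --aug_type='image_text' \
    --sbj='sub-01' \
    --dataname='DIR-Wiki' \
    --stability_ratio='' \
    --roi='LVC_HVC_IT' \
    --image_model='pytorch/repvgg_b3g4' \
    --text_model='GPTNeo'
```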
61 | -------------------------------------------------------------------------------- /BraVL_fMRI/job_trimodal: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EPOCH=100 4 | 5 | DATASET='DIR-Wiki' 6 | RATIO='' 7 | ####################################### 8 | # Subject 1 9 | ####################################### 10 | SUBJECT='sub-01' 11 | FUSION="joint_elbo" 12 | python main_trimodal.py \ 13 | --end_epoch=$EPOCH \ 14 | --method=$FUSION \ 15 | --aug_type='image_text' \ 16 | --sbj=$SUBJECT\ 17 | --dataname=$DATASET \ 18 | --stability_ratio=$RATIO \ 19 | --roi='LVC_HVC_IT'\ 20 | --image_model='pytorch/repvgg_b3g4'\ 21 | --text_model='GPTNeo' 22 | ####################################### 23 | # Subject 2 24 | ####################################### 25 | SUBJECT='sub-02' 26 | FUSION="joint_elbo" 27 | python main_trimodal.py \ 28 | --end_epoch=$EPOCH \ 29 | --method=$FUSION \ 30 | --aug_type='image_text' \ 31 | --sbj=$SUBJECT\ 32 | --dataname=$DATASET \ 33 | --stability_ratio=$RATIO \ 34 | --roi='LVC_HVC_IT'\ 35 | --image_model='pytorch/repvgg_b3g4'\ 36 | --text_model='GPTNeo' 37 | ####################################### 38 | # Subject 3 39 | ####################################### 40 | SUBJECT='sub-03' 41 | FUSION="joint_elbo" 42 | python main_trimodal.py \ 43 | --end_epoch=$EPOCH \ 44 | --method=$FUSION \ 45 | --aug_type='image_text' \ 46 | --sbj=$SUBJECT\ 47 | --dataname=$DATASET \ 48 | --stability_ratio=$RATIO \ 49 | --roi='LVC_HVC_IT'\ 50 | --image_model='pytorch/repvgg_b3g4'\ 51 | --text_model='GPTNeo' 52 | 53 | 54 | DATASET='GOD-Wiki' 55 | ####################################### 56 | # Subject 1 57 | ####################################### 58 | SUBJECT='sub-01' 59 | FUSION="joint_elbo" 60 | python main_trimodal.py \ 61 | --end_epoch=$EPOCH \ 62 | --method=$FUSION \ 63 | --aug_type='image_text' \ 64 | --sbj=$SUBJECT\ 65 | --dataname=$DATASET \ 66 | --roi='VC'\ 67 | --image_model='pytorch/repvgg_b3g4'\ 68 | --text_model='GPTNeo' 69 | ####################################### 70 | # Subject 2 71 | ####################################### 72 | SUBJECT='sub-02' 73 | FUSION="joint_elbo" 74 | python main_trimodal.py \ 75 | --end_epoch=$EPOCH \ 76 | --method=$FUSION \ 77 | --aug_type='image_text' \ 78 | --sbj=$SUBJECT\ 79 | --dataname=$DATASET \ 80 | --roi='VC'\ 81 | --image_model='pytorch/repvgg_b3g4'\ 82 | --text_model='GPTNeo' 83 | ####################################### 84 | # Subject 3 85 | ####################################### 86 | SUBJECT='sub-03' 87 | FUSION="joint_elbo" 88 | python main_trimodal.py \ 89 | --end_epoch=$EPOCH \ 90 | --method=$FUSION \ 91 | --aug_type='image_text' \ 92 | --sbj=$SUBJECT\ 93 | --dataname=$DATASET \ 94 | --roi='VC'\ 95 | --image_model='pytorch/repvgg_b3g4'\ 96 | --text_model='GPTNeo' 97 | ####################################### 98 | # Subject 4 99 | ####################################### 100 | SUBJECT='sub-04' 101 | FUSION="joint_elbo" 102 | python main_trimodal.py \ 103 | --end_epoch=$EPOCH \ 104 | --method=$FUSION \ 105 | --aug_type='image_text' \ 106 | --sbj=$SUBJECT\ 107 | --dataname=$DATASET \ 108 | --roi='VC'\ 109 | --image_model='pytorch/repvgg_b3g4'\ 110 | --text_model='GPTNeo' 111 | ####################################### 112 | # Subject 5 113 | ####################################### 114 | SUBJECT='sub-05' 115 | FUSION="joint_elbo" 116 | python main_trimodal.py \ 117 | --end_epoch=$EPOCH \ 118 | --method=$FUSION \ 119 | --aug_type='image_text' \ 120 | --sbj=$SUBJECT\ 121 | --dataname=$DATASET \ 
122 | --roi='VC'\ 123 | --image_model='pytorch/repvgg_b3g4'\ 124 | --text_model='GPTNeo' -------------------------------------------------------------------------------- /BraVL_EEG/job_trimodal: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | EPOCH=500 4 | 5 | DATASET='ThingsEEG-Wiki' 6 | TEXTMODAL='CLIPText' 7 | IMAGEMODAL='pytorch/cornet_s_fine_tuned' 8 | AUGTYPE='no_aug' 9 | CHANNELS='17channels' 10 | ######################################## 11 | ## Subject 1 12 | ######################################## 13 | #SUBJECT='sub-01' 14 | #FUSION="joint_elbo" 15 | #python main_trimodal.py \ 16 | # --end_epoch=$EPOCH \ 17 | # --method=$FUSION \ 18 | # --aug_type=$AUGTYPE \ 19 | # --sbj=$SUBJECT\ 20 | # --dataname=$DATASET \ 21 | # --roi=$CHANNELS\ 22 | # --image_model=$IMAGEMODAL\ 23 | # --text_model=$TEXTMODAL 24 | # 25 | ######################################## 26 | ## Subject 2 27 | ######################################## 28 | #SUBJECT='sub-02' 29 | #FUSION="joint_elbo" 30 | #python main_trimodal.py \ 31 | # --end_epoch=$EPOCH \ 32 | # --method=$FUSION \ 33 | # --aug_type=$AUGTYPE \ 34 | # --sbj=$SUBJECT\ 35 | # --dataname=$DATASET \ 36 | # --roi=$CHANNELS\ 37 | # --image_model=$IMAGEMODAL\ 38 | # --text_model=$TEXTMODAL 39 | # 40 | ######################################## 41 | ## Subject 3 42 | ######################################## 43 | #SUBJECT='sub-03' 44 | #FUSION="joint_elbo" 45 | #python main_trimodal.py \ 46 | # --end_epoch=$EPOCH \ 47 | # --method=$FUSION \ 48 | # --aug_type=$AUGTYPE \ 49 | # --sbj=$SUBJECT\ 50 | # --dataname=$DATASET \ 51 | # --roi=$CHANNELS\ 52 | # --image_model=$IMAGEMODAL\ 53 | # --text_model=$TEXTMODAL 54 | # 55 | ######################################## 56 | ## Subject 4 57 | ######################################## 58 | #SUBJECT='sub-04' 59 | #FUSION="joint_elbo" 60 | #python main_trimodal.py \ 61 | # --end_epoch=$EPOCH \ 62 | # --method=$FUSION \ 63 | # --aug_type=$AUGTYPE \ 64 | # --sbj=$SUBJECT\ 65 | # --dataname=$DATASET \ 66 | # --roi=$CHANNELS\ 67 | # --image_model=$IMAGEMODAL\ 68 | # --text_model=$TEXTMODAL 69 | # 70 | ######################################## 71 | ## Subject 5 72 | ######################################## 73 | #SUBJECT='sub-05' 74 | #FUSION="joint_elbo" 75 | #python main_trimodal.py \ 76 | # --end_epoch=$EPOCH \ 77 | # --method=$FUSION \ 78 | # --aug_type=$AUGTYPE \ 79 | # --sbj=$SUBJECT\ 80 | # --dataname=$DATASET \ 81 | # --roi=$CHANNELS\ 82 | # --image_model=$IMAGEMODAL\ 83 | # --text_model=$TEXTMODAL 84 | # 85 | ######################################## 86 | ## Subject 6 87 | ######################################## 88 | #SUBJECT='sub-06' 89 | #FUSION="joint_elbo" 90 | #python main_trimodal.py \ 91 | # --end_epoch=$EPOCH \ 92 | # --method=$FUSION \ 93 | # --aug_type=$AUGTYPE \ 94 | # --sbj=$SUBJECT\ 95 | # --dataname=$DATASET \ 96 | # --roi=$CHANNELS\ 97 | # --image_model=$IMAGEMODAL\ 98 | # --text_model=$TEXTMODAL 99 | # 100 | ######################################## 101 | ## Subject 7 102 | ######################################## 103 | #SUBJECT='sub-07' 104 | #FUSION="joint_elbo" 105 | #python main_trimodal.py \ 106 | # --end_epoch=$EPOCH \ 107 | # --method=$FUSION \ 108 | # --aug_type=$AUGTYPE \ 109 | # --sbj=$SUBJECT\ 110 | # --dataname=$DATASET \ 111 | # --roi=$CHANNELS\ 112 | # --image_model=$IMAGEMODAL\ 113 | # --text_model=$TEXTMODAL 114 | 115 | ####################################### 116 | # Subject 8 117 | ####################################### 118 | 
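# NOTE: the blocks for subjects 1-7 and 9 above are commented out, so as
# committed this script only launches training for sub-08 and sub-10 on
# ThingsEEG-Wiki. Uncomment the corresponding blocks (or copy the pattern
# below with a different SUBJECT value) to train the remaining subjects.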
SUBJECT='sub-08' 119 | FUSION="joint_elbo" 120 | python main_trimodal.py \ 121 | --end_epoch=$EPOCH \ 122 | --method=$FUSION \ 123 | --aug_type=$AUGTYPE \ 124 | --sbj=$SUBJECT\ 125 | --dataname=$DATASET \ 126 | --roi=$CHANNELS\ 127 | --image_model=$IMAGEMODAL\ 128 | --text_model=$TEXTMODAL 129 | 130 | ######################################## 131 | ## Subject 9 132 | ######################################## 133 | #SUBJECT='sub-09' 134 | #FUSION="joint_elbo" 135 | #python main_trimodal.py \ 136 | # --end_epoch=$EPOCH \ 137 | # --method=$FUSION \ 138 | # --aug_type=$AUGTYPE \ 139 | # --sbj=$SUBJECT\ 140 | # --dataname=$DATASET \ 141 | # --roi=$CHANNELS\ 142 | # --image_model=$IMAGEMODAL\ 143 | # --text_model=$TEXTMODAL 144 | 145 | ####################################### 146 | # Subject 10 147 | ####################################### 148 | SUBJECT='sub-10' 149 | FUSION="joint_elbo" 150 | python main_trimodal.py \ 151 | --end_epoch=$EPOCH \ 152 | --method=$FUSION \ 153 | --aug_type=$AUGTYPE \ 154 | --sbj=$SUBJECT\ 155 | --dataname=$DATASET \ 156 | --roi=$CHANNELS\ 157 | --image_model=$IMAGEMODAL\ 158 | --text_model=$TEXTMODAL -------------------------------------------------------------------------------- /BraVL_EEG/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | # Print iterations progress 5 | def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'): 6 | """ 7 | Call in a loop to create terminal progress bar 8 | @params: 9 | iteration - Required : current iteration (Int) 10 | total - Required : total iterations (Int) 11 | prefix - Optional : prefix string (Str) 12 | suffix - Optional : suffix string (Str) 13 | decimals - Optional : positive number of decimals in percent complete (Int) 14 | length - Optional : character length of bar (Int) 15 | fill - Optional : bar fill character (Str) 16 | """ 17 | percent = ("{0:." 
+ str(decimals) + "f}").format(100 * (iteration / float(total))) 18 | filledLength = int(length * iteration // total) 19 | bar = fill * filledLength + '-' * (length - filledLength) 20 | print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r') 21 | # Print New Line on Complete 22 | if iteration == total: 23 | print() 24 | 25 | def get_likelihood(str): 26 | if str == 'laplace': 27 | pz = dist.Laplace 28 | elif str == 'bernoulli': 29 | pz = dist.Bernoulli 30 | elif str == 'normal': 31 | pz = dist.Normal 32 | elif str == 'categorical': 33 | pz = dist.OneHotCategorical 34 | else: 35 | print('likelihood not implemented') 36 | pz = None 37 | return pz 38 | 39 | 40 | def reweight_weights(w): 41 | w = w / w.sum() 42 | return w 43 | 44 | 45 | def mixture_component_selection(flags, mus, logvars, w_modalities=None): 46 | #if not defined, take pre-defined weights 47 | num_components = mus.shape[0] 48 | num_samples = mus.shape[1] 49 | if w_modalities is None: 50 | w_modalities = torch.Tensor(flags.alpha_modalities).to(flags.device) 51 | idx_start = [] 52 | idx_end = [] 53 | for k in range(0, num_components): 54 | if k == 0: 55 | i_start = 0 56 | else: 57 | i_start = int(idx_end[k-1]) 58 | if k == w_modalities.shape[0]-1: 59 | i_end = num_samples 60 | else: 61 | i_end = i_start + int(torch.floor(num_samples*w_modalities[k])) 62 | idx_start.append(i_start) 63 | idx_end.append(i_end) 64 | idx_end[-1] = num_samples 65 | mu_sel = torch.cat([mus[k, idx_start[k]:idx_end[k], :] for k in range(w_modalities.shape[0])]) 66 | logvar_sel = torch.cat([logvars[k, idx_start[k]:idx_end[k], :] for k in range(w_modalities.shape[0])]) 67 | return [mu_sel, logvar_sel] 68 | 69 | 70 | def calc_elbo(exp, modality, recs, klds): 71 | flags = exp.flags 72 | mods = exp.modalities 73 | s_weights = exp.style_weights 74 | r_weights = exp.rec_weights 75 | kld_content = klds['content'] 76 | if modality == 'joint': 77 | w_style_kld = 0.0 78 | w_rec = 0.0 79 | klds_style = klds['style'] 80 | for k, m_key in enumerate(mods.keys()): 81 | w_style_kld += s_weights[m_key] * klds_style[m_key] 82 | w_rec += r_weights[m_key] * recs[m_key] 83 | kld_style = w_style_kld 84 | rec_error = w_rec 85 | else: 86 | beta_style_mod = s_weights[modality] 87 | #rec_weight_mod = r_weights[modality] 88 | rec_weight_mod = 1.0 89 | kld_style = beta_style_mod * klds['style'][modality] 90 | rec_error = rec_weight_mod * recs[modality] 91 | div = flags.beta_content * kld_content + flags.beta_style * kld_style 92 | elbo = rec_error + flags.beta * div 93 | return elbo 94 | 95 | 96 | def save_and_log_flags(flags): 97 | #filename_flags = os.path.join(flags.dir_experiment_run, 'flags.json') 98 | #with open(filename_flags, 'w') as f: 99 | # json.dump(flags.__dict__, f, indent=2, sort_keys=True) 100 | 101 | filename_flags_rar = os.path.join(flags.dir_experiment_run, 'flags.rar') 102 | torch.save(flags, filename_flags_rar) 103 | str_args = '' 104 | for k, key in enumerate(sorted(flags.__dict__.keys())): 105 | str_args = str_args + '\n' + key + ': ' + str(flags.__dict__[key]) 106 | return str_args 107 | 108 | 109 | class Flatten(torch.nn.Module): 110 | def forward(self, x): 111 | return x.view(x.size(0), -1) 112 | 113 | 114 | class Unflatten(torch.nn.Module): 115 | def __init__(self, ndims): 116 | super(Unflatten, self).__init__() 117 | self.ndims = ndims 118 | 119 | def forward(self, x): 120 | return x.view(x.size(0), *self.ndims) 121 | -------------------------------------------------------------------------------- /BraVL_fMRI/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | # Print iterations progress 5 | def printProgressBar (iteration, total, prefix = '', suffix = '', decimals = 1, length = 100, fill = '█'): 6 | """ 7 | Call in a loop to create terminal progress bar 8 | @params: 9 | iteration - Required : current iteration (Int) 10 | total - Required : total iterations (Int) 11 | prefix - Optional : prefix string (Str) 12 | suffix - Optional : suffix string (Str) 13 | decimals - Optional : positive number of decimals in percent complete (Int) 14 | length - Optional : character length of bar (Int) 15 | fill - Optional : bar fill character (Str) 16 | """ 17 | percent = ("{0:." + str(decimals) + "f}").format(100 * (iteration / float(total))) 18 | filledLength = int(length * iteration // total) 19 | bar = fill * filledLength + '-' * (length - filledLength) 20 | print('\r%s |%s| %s%% %s' % (prefix, bar, percent, suffix), end = '\r') 21 | # Print New Line on Complete 22 | if iteration == total: 23 | print() 24 | 25 | def get_likelihood(str): 26 | if str == 'laplace': 27 | pz = dist.Laplace 28 | elif str == 'bernoulli': 29 | pz = dist.Bernoulli 30 | elif str == 'normal': 31 | pz = dist.Normal 32 | elif str == 'categorical': 33 | pz = dist.OneHotCategorical 34 | else: 35 | print('likelihood not implemented') 36 | pz = None 37 | return pz 38 | 39 | 40 | def reweight_weights(w): 41 | w = w / w.sum() 42 | return w 43 | 44 | 45 | def mixture_component_selection(flags, mus, logvars, w_modalities=None): 46 | #if not defined, take pre-defined weights 47 | num_components = mus.shape[0] 48 | num_samples = mus.shape[1] 49 | if w_modalities is None: 50 | w_modalities = torch.Tensor(flags.alpha_modalities).to(flags.device) 51 | idx_start = [] 52 | idx_end = [] 53 | for k in range(0, num_components): 54 | if k == 0: 55 | i_start = 0 56 | else: 57 | i_start = int(idx_end[k-1]) 58 | if k == w_modalities.shape[0]-1: 59 | i_end = num_samples 60 | else: 61 | i_end = i_start + int(torch.floor(num_samples*w_modalities[k])) 62 | idx_start.append(i_start) 63 | idx_end.append(i_end) 64 | idx_end[-1] = num_samples 65 | mu_sel = torch.cat([mus[k, idx_start[k]:idx_end[k], :] for k in range(w_modalities.shape[0])]) 66 | logvar_sel = torch.cat([logvars[k, idx_start[k]:idx_end[k], :] for k in range(w_modalities.shape[0])]) 67 | return [mu_sel, logvar_sel] 68 | 69 | 70 | def calc_elbo(exp, modality, recs, klds): 71 | flags = exp.flags 72 | mods = exp.modalities 73 | s_weights = exp.style_weights 74 | r_weights = exp.rec_weights 75 | kld_content = klds['content'] 76 | if modality == 'joint': 77 | w_style_kld = 0.0 78 | w_rec = 0.0 79 | klds_style = klds['style'] 80 | for k, m_key in enumerate(mods.keys()): 81 | w_style_kld += s_weights[m_key] * klds_style[m_key] 82 | w_rec += r_weights[m_key] * recs[m_key] 83 | kld_style = w_style_kld 84 | rec_error = w_rec 85 | else: 86 | beta_style_mod = s_weights[modality] 87 | #rec_weight_mod = r_weights[modality] 88 | rec_weight_mod = 1.0 89 | kld_style = beta_style_mod * klds['style'][modality] 90 | rec_error = rec_weight_mod * recs[modality] 91 | div = flags.beta_content * kld_content + flags.beta_style * kld_style 92 | elbo = rec_error + flags.beta * div 93 | return elbo 94 | 95 | 96 | def save_and_log_flags(flags): 97 | #filename_flags = os.path.join(flags.dir_experiment_run, 'flags.json') 98 | #with open(filename_flags, 'w') as f: 99 | # json.dump(flags.__dict__, f, indent=2, sort_keys=True) 100 | 101 | filename_flags_rar = 
os.path.join(flags.dir_experiment_run, 'flags.rar') 102 | torch.save(flags, filename_flags_rar) 103 | str_args = '' 104 | for k, key in enumerate(sorted(flags.__dict__.keys())): 105 | str_args = str_args + '\n' + key + ': ' + str(flags.__dict__[key]) 106 | return str_args 107 | 108 | 109 | class Flatten(torch.nn.Module): 110 | def forward(self, x): 111 | return x.view(x.size(0), -1) 112 | 113 | 114 | class Unflatten(torch.nn.Module): 115 | def __init__(self, ndims): 116 | super(Unflatten, self).__init__() 117 | self.ndims = ndims 118 | 119 | def forward(self, x): 120 | return x.view(x.size(0), *self.ndims) 121 | -------------------------------------------------------------------------------- /BraVL_fMRI/utils/BaseFlags.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torch 4 | import scipy.io as sio 5 | parser = argparse.ArgumentParser() 6 | 7 | # TRAINING 8 | parser.add_argument('--batch_size', type=int, default=512, help="batch size for training") 9 | parser.add_argument('--initial_learning_rate', type=float, default=0.0001, help="starting learning rate") 10 | parser.add_argument('--beta_1', type=float, default=0.9, help="default beta_1 val for adam") 11 | parser.add_argument('--beta_2', type=float, default=0.999, help="default beta_2 val for adam") 12 | parser.add_argument('--start_epoch', type=int, default=0, help="flag to set the starting epoch for training") 13 | parser.add_argument('--end_epoch', type=int, default=100, help="flag to indicate the final epoch of training") 14 | 15 | # DATA DEPENDENT 16 | parser.add_argument('--class_dim', type=int, default=32, help="dimension of common factor latent space") 17 | # SAVE and LOAD 18 | parser.add_argument('--mm_vae_save', type=str, default='mm_vae', help="model save for vae_bimodal") 19 | parser.add_argument('--load_saved', type=bool, default=False, help="flag to indicate if a saved model will be loaded") 20 | 21 | # DIRECTORIES 22 | # experiments 23 | parser.add_argument('--dir_experiment', type=str, default='./logs', help="directory to save logs in") 24 | parser.add_argument('--dataname', type=str, default='DIR-Wiki', help="dataset") 25 | parser.add_argument('--sbj', type=str, default='sub-03', help="fmri subject") 26 | parser.add_argument('--roi', type=str, default='LVC_HVC_IT', help="ROI") 27 | parser.add_argument('--text_model', type=str, default='GPTNeo', help="text embedding model") 28 | parser.add_argument('--image_model', type=str, default='pytorch/repvgg_b3g4', help="image embedding model") 29 | parser.add_argument('--stability_ratio', type=str, default='', help="stability_ratio") 30 | parser.add_argument('--test_type', type=str, default='zsl', help='normal or zsl') 31 | parser.add_argument('--aug_type', type=str, default='image_text', help='no_aug, image_text, image_only, text_only') 32 | #multimodal 33 | parser.add_argument('--method', type=str, default='joint_elbo', help='choose method for training the model') 34 | parser.add_argument('--modality_jsd', type=bool, default=False, help="modality_jsd") 35 | parser.add_argument('--modality_poe', type=bool, default=False, help="modality_poe") 36 | parser.add_argument('--modality_moe', type=bool, default=False, help="modality_moe") 37 | parser.add_argument('--joint_elbo', type=bool, default=False, help="modality_moe") 38 | parser.add_argument('--poe_unimodal_elbos', type=bool, default=True, help="unimodal_klds") 39 | parser.add_argument('--factorized_representation', action='store_true', default=False, 
help="factorized_representation") 40 | 41 | # LOSS TERM WEIGHTS 42 | parser.add_argument('--beta', type=float, default=0.0, help="default initial weight of sum of weighted divergence terms") 43 | parser.add_argument('--beta_style', type=float, default=1.0, help="default weight of sum of weighted style divergence terms") 44 | parser.add_argument('--beta_content', type=float, default=1.0, help="default weight of sum of weighted content divergence terms") 45 | parser.add_argument('--lambda1', type=float, default=0.001, help="default weight of intra_mi terms") 46 | parser.add_argument('--lambda2', type=float, default=0.001, help="default weight of inter_mi terms") 47 | 48 | 49 | FLAGS = parser.parse_args() 50 | data_dir_root = os.path.join('./data', FLAGS.dataname) 51 | brain_dir = os.path.join(data_dir_root, 'brain_feature', FLAGS.roi, FLAGS.sbj) 52 | image_dir_train = os.path.join(data_dir_root, 'visual_feature/ImageNetTraining', FLAGS.image_model+'-PCA', FLAGS.sbj) 53 | text_dir_train = os.path.join(data_dir_root, 'textual_feature/ImageNetTraining/text', FLAGS.text_model, FLAGS.sbj) 54 | 55 | train_brain = sio.loadmat(os.path.join(brain_dir, 'fmri_train_data'+FLAGS.stability_ratio+'.mat'))['data'].astype('double') 56 | train_image = sio.loadmat(os.path.join(image_dir_train, 'feat_pca_train.mat'))['data'].astype('double')#[:,0:3000] 57 | train_text = sio.loadmat(os.path.join(text_dir_train, 'text_feat_train.mat'))['data'].astype('double') 58 | train_brain = torch.from_numpy(train_brain) 59 | train_image = torch.from_numpy(train_image) 60 | train_text = torch.from_numpy(train_text) 61 | dim_brain = train_brain.shape[1] 62 | dim_image = train_image.shape[1] 63 | dim_text = train_text.shape[1] 64 | 65 | parser.add_argument('--m1_dim', type=int, default=dim_brain, help="dimension of modality brain") 66 | parser.add_argument('--m2_dim', type=int, default=dim_image, help="dimension of modality image") 67 | parser.add_argument('--m3_dim', type=int, default=dim_text, help="dimension of modality text") 68 | parser.add_argument('--data_dir_root', type=str, default=data_dir_root, help="data dir") 69 | 70 | FLAGS = parser.parse_args() 71 | print(FLAGS) 72 | -------------------------------------------------------------------------------- /BraVL_EEG/utils/TBLogger.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | class TBLogger(): 8 | def __init__(self, name, writer): 9 | self.name = name; 10 | self.writer = writer; 11 | self.training_prefix = 'train'; 12 | self.testing_prefix = 'test'; 13 | self.step = 0; 14 | 15 | 16 | def write_log_probs(self, name, log_probs): self.writer.add_scalars('%s/LogProb' % name, 17 | log_probs, 18 | self.step) 19 | 20 | 21 | def write_klds(self, name, klds): 22 | self.writer.add_scalars('%s/KLD' % name, 23 | klds, 24 | self.step) 25 | 26 | 27 | def write_group_div(self, name, group_div): 28 | self.writer.add_scalars('%s/group_divergence' % name, 29 | {'group_div': group_div.item()}, 30 | self.step) 31 | 32 | def write_latent_distr(self, name, latents): 33 | l_mods = latents['modalities']; 34 | for k, key in enumerate(l_mods.keys()): 35 | if not l_mods[key][0] is None: 36 | self.writer.add_scalars('%s/mu' % name, 37 | {key: l_mods[key][0].mean().item()}, 38 | self.step) 39 | # if not l_mods[key][1] is None: 40 | # self.writer.add_scalars('%s/logvar' % name, 41 | # {key: l_mods[key][1].mean().item()}, 42 | # self.step) 43 | 44 | 45 | def write_lr_eval(self, lr_eval): 46 | for s, l_key in 
enumerate(sorted(lr_eval.keys())): 47 | self.writer.add_scalars('Latent Representation/%s'%(l_key), 48 | lr_eval[l_key], 49 | self.step) 50 | 51 | 52 | def write_coherence_logs(self, gen_eval): 53 | for j, l_key in enumerate(sorted(gen_eval['cond'].keys())): 54 | for k, s_key in enumerate(gen_eval['cond'][l_key].keys()): 55 | self.writer.add_scalars('Generation/%s/%s' % 56 | (l_key, s_key), 57 | gen_eval['cond'][l_key][s_key], 58 | self.step) 59 | self.writer.add_scalars('Generation/Random', 60 | gen_eval['random'], 61 | self.step) 62 | 63 | 64 | def write_lhood_logs(self, lhoods): 65 | for k, key in enumerate(sorted(lhoods.keys())): 66 | self.writer.add_scalars('Likelihoods/%s'% 67 | (key), 68 | lhoods[key], 69 | self.step) 70 | 71 | def write_prd_scores(self, prd_scores): 72 | self.writer.add_scalars('PRD', 73 | prd_scores, 74 | self.step) 75 | 76 | 77 | def write_plots(self, plots, epoch): 78 | for k, p_key in enumerate(plots.keys()): 79 | ps = plots[p_key]; 80 | for l, name in enumerate(ps.keys()): 81 | fig = ps[name]; 82 | self.writer.add_image(p_key + '_' + name, 83 | fig, 84 | epoch, 85 | dataformats="HWC"); 86 | 87 | 88 | 89 | def add_basic_logs(self, name, results, loss, log_probs, klds,inter_mi): 90 | self.writer.add_scalars('%s/Loss' % name, 91 | {'loss': loss.data.item()}, 92 | self.step) 93 | # self.writer.add_scalars( '%s/intra_mi' % name, 94 | # {'intra_mi': intra_mi}, 95 | # self.step) 96 | self.writer.add_scalars( '%s/inter_mi' % name, 97 | {'inter_mi': inter_mi}, 98 | self.step) 99 | self.write_log_probs(name, log_probs); 100 | self.write_klds(name, klds); 101 | self.write_group_div(name, results['joint_divergence']); 102 | self.write_latent_distr(name, results['latents']); 103 | 104 | 105 | def write_training_logs(self, results, loss, log_probs, klds,inter_mi): 106 | self.add_basic_logs(self.training_prefix, results, loss, log_probs, klds,inter_mi); 107 | self.step += 1; 108 | 109 | 110 | def write_testing_logs(self, results, loss, log_probs, klds,inter_mi): 111 | self.add_basic_logs(self.testing_prefix, results, loss, log_probs, klds,inter_mi); 112 | self.step += 1; 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /BraVL_fMRI/utils/TBLogger.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | class TBLogger(): 8 | def __init__(self, name, writer): 9 | self.name = name; 10 | self.writer = writer; 11 | self.training_prefix = 'train'; 12 | self.testing_prefix = 'test'; 13 | self.step = 0; 14 | 15 | 16 | def write_log_probs(self, name, log_probs): self.writer.add_scalars('%s/LogProb' % name, 17 | log_probs, 18 | self.step) 19 | 20 | 21 | def write_klds(self, name, klds): 22 | self.writer.add_scalars('%s/KLD' % name, 23 | klds, 24 | self.step) 25 | 26 | 27 | def write_group_div(self, name, group_div): 28 | self.writer.add_scalars('%s/group_divergence' % name, 29 | {'group_div': group_div.item()}, 30 | self.step) 31 | 32 | def write_latent_distr(self, name, latents): 33 | l_mods = latents['modalities']; 34 | for k, key in enumerate(l_mods.keys()): 35 | if not l_mods[key][0] is None: 36 | self.writer.add_scalars('%s/mu' % name, 37 | {key: l_mods[key][0].mean().item()}, 38 | self.step) 39 | # if not l_mods[key][1] is None: 40 | # self.writer.add_scalars('%s/logvar' % name, 41 | # {key: l_mods[key][1].mean().item()}, 42 | # self.step) 43 | 44 | 45 | def write_lr_eval(self, lr_eval): 46 | for s, l_key in enumerate(sorted(lr_eval.keys())): 47 | 
self.writer.add_scalars('Latent Representation/%s'%(l_key), 48 | lr_eval[l_key], 49 | self.step) 50 | 51 | 52 | def write_coherence_logs(self, gen_eval): 53 | for j, l_key in enumerate(sorted(gen_eval['cond'].keys())): 54 | for k, s_key in enumerate(gen_eval['cond'][l_key].keys()): 55 | self.writer.add_scalars('Generation/%s/%s' % 56 | (l_key, s_key), 57 | gen_eval['cond'][l_key][s_key], 58 | self.step) 59 | self.writer.add_scalars('Generation/Random', 60 | gen_eval['random'], 61 | self.step) 62 | 63 | 64 | def write_lhood_logs(self, lhoods): 65 | for k, key in enumerate(sorted(lhoods.keys())): 66 | self.writer.add_scalars('Likelihoods/%s'% 67 | (key), 68 | lhoods[key], 69 | self.step) 70 | 71 | def write_prd_scores(self, prd_scores): 72 | self.writer.add_scalars('PRD', 73 | prd_scores, 74 | self.step) 75 | 76 | 77 | def write_plots(self, plots, epoch): 78 | for k, p_key in enumerate(plots.keys()): 79 | ps = plots[p_key]; 80 | for l, name in enumerate(ps.keys()): 81 | fig = ps[name]; 82 | self.writer.add_image(p_key + '_' + name, 83 | fig, 84 | epoch, 85 | dataformats="HWC"); 86 | 87 | 88 | 89 | def add_basic_logs(self, name, results, loss, log_probs, klds,inter_mi): 90 | self.writer.add_scalars('%s/Loss' % name, 91 | {'loss': loss.data.item()}, 92 | self.step) 93 | # self.writer.add_scalars( '%s/intra_mi' % name, 94 | # {'intra_mi': intra_mi}, 95 | # self.step) 96 | self.writer.add_scalars( '%s/inter_mi' % name, 97 | {'inter_mi': inter_mi}, 98 | self.step) 99 | self.write_log_probs(name, log_probs); 100 | self.write_klds(name, klds); 101 | self.write_group_div(name, results['joint_divergence']); 102 | self.write_latent_distr(name, results['latents']); 103 | 104 | 105 | def write_training_logs(self, results, loss, log_probs, klds,inter_mi): 106 | self.add_basic_logs(self.training_prefix, results, loss, log_probs, klds,inter_mi); 107 | self.step += 1; 108 | 109 | 110 | def write_testing_logs(self, results, loss, log_probs, klds,inter_mi): 111 | self.add_basic_logs(self.testing_prefix, results, loss, log_probs, klds,inter_mi); 112 | self.step += 1; 113 | 114 | 115 | 116 | 117 | 118 | -------------------------------------------------------------------------------- /BraVL_EEG/utils/BaseFlags.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | import numpy as np 5 | import torch 6 | import scipy.io as sio 7 | parser = argparse.ArgumentParser() 8 | 9 | # TRAINING 10 | parser.add_argument('--batch_size', type=int, default=1024, help="batch size for training") 11 | parser.add_argument('--initial_learning_rate', type=float, default=0.0001, help="starting learning rate") 12 | parser.add_argument('--beta_1', type=float, default=0.9, help="default beta_1 val for adam") 13 | parser.add_argument('--beta_2', type=float, default=0.999, help="default beta_2 val for adam") 14 | parser.add_argument('--start_epoch', type=int, default=0, help="flag to set the starting epoch for training") 15 | parser.add_argument('--end_epoch', type=int, default=100, help="flag to indicate the final epoch of training") 16 | 17 | # DATA DEPENDENT 18 | parser.add_argument('--class_dim', type=int, default=32, help="dimension of common factor latent space") 19 | # SAVE and LOAD 20 | parser.add_argument('--mm_vae_save', type=str, default='mm_vae', help="model save for vae_bimodal") 21 | parser.add_argument('--load_saved', type=bool, default=False, help="flag to indicate if a saved model will be loaded") 22 | 23 | # DIRECTORIES 24 | # experiments 25 | 
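# The directory-related flags below determine where the pre-extracted features
# are loaded from. With the default flag values, the loading code further down
# resolves to a layout roughly like:
#   ./data/ThingsEEG-Text/brain_feature/17channels/sub-01/eeg_train_data_within.mat
#   ./data/ThingsEEG-Text/visual_feature/ThingsTrain/pytorch/cornet_s/sub-01/feat_pca_train.mat
#   ./data/ThingsEEG-Text/textual_feature/ThingsTrain/text/CLIPText/sub-01/text_feat_train.mat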
parser.add_argument('--dir_experiment', type=str, default='./logs', help="directory to save logs in") 26 | parser.add_argument('--dataname', type=str, default='ThingsEEG-Text', help="dataset") 27 | parser.add_argument('--sbj', type=str, default='sub-01', help="eeg subject") 28 | parser.add_argument('--roi', type=str, default='17channels', help="ROI") 29 | parser.add_argument('--text_model', type=str, default='CLIPText', help="text embedding model") 30 | parser.add_argument('--image_model', type=str, default='pytorch/cornet_s', help="image embedding model") 31 | 32 | parser.add_argument('--test_type', type=str, default='zsl', help='normal or zsl') 33 | parser.add_argument('--aug_type', type=str, default='no_aug', help='no_aug, image_text_ilsvrc2012_val') 34 | parser.add_argument('--unimodal', type=str, default='image', help='image, text') 35 | #multimodal 36 | parser.add_argument('--method', type=str, default='joint_elbo', help='choose method for training the model') 37 | parser.add_argument('--modality_jsd', type=bool, default=False, help="modality_jsd") 38 | parser.add_argument('--modality_poe', type=bool, default=False, help="modality_poe") 39 | parser.add_argument('--modality_moe', type=bool, default=False, help="modality_moe") 40 | parser.add_argument('--joint_elbo', type=bool, default=False, help="modality_moe") 41 | parser.add_argument('--poe_unimodal_elbos', type=bool, default=True, help="unimodal_klds") 42 | parser.add_argument('--factorized_representation', action='store_true', default=False, help="factorized_representation") 43 | 44 | # LOSS TERM WEIGHTS 45 | parser.add_argument('--beta', type=float, default=0.0, help="default initial weight of sum of weighted divergence terms") 46 | parser.add_argument('--beta_style', type=float, default=1.0, help="default weight of sum of weighted style divergence terms") 47 | parser.add_argument('--beta_content', type=float, default=1.0, help="default weight of sum of weighted content divergence terms") 48 | parser.add_argument('--lambda1', type=float, default=0.001, help="default weight of intra_mi terms") 49 | parser.add_argument('--lambda2', type=float, default=0.001, help="default weight of inter_mi terms") 50 | 51 | 52 | FLAGS = parser.parse_args() 53 | data_dir_root = os.path.join('./data', FLAGS.dataname) 54 | brain_dir = os.path.join(data_dir_root, 'brain_feature', FLAGS.roi, FLAGS.sbj) 55 | image_dir_train = os.path.join(data_dir_root, 'visual_feature/ThingsTrain', FLAGS.image_model, FLAGS.sbj) 56 | text_dir_train = os.path.join(data_dir_root, 'textual_feature/ThingsTrain/text', FLAGS.text_model, FLAGS.sbj) 57 | 58 | train_brain = sio.loadmat(os.path.join(brain_dir, 'eeg_train_data_within.mat'))['data'].astype('double') 59 | train_brain = train_brain[:,:,27:60] # 70ms-400ms 60 | train_brain = np.reshape(train_brain,(train_brain.shape[0],-1)) 61 | train_image = sio.loadmat(os.path.join(image_dir_train, 'feat_pca_train.mat'))['data'].astype('double') 62 | train_text = sio.loadmat(os.path.join(text_dir_train, 'text_feat_train.mat'))['data'].astype('double') 63 | train_image = train_image[:,0:100] # top 100 PCs 64 | 65 | 66 | train_brain = torch.from_numpy(train_brain) 67 | train_image = torch.from_numpy(train_image) 68 | train_text = torch.from_numpy(train_text) 69 | dim_brain = train_brain.shape[1] 70 | dim_image = train_image.shape[1] 71 | dim_text = train_text.shape[1] 72 | 73 | parser.add_argument('--m1_dim', type=int, default=dim_brain, help="dimension of modality brain") 74 | parser.add_argument('--m2_dim', type=int, 
default=dim_image, help="dimension of modality image") 75 | parser.add_argument('--m3_dim', type=int, default=dim_text, help="dimension of modality text") 76 | parser.add_argument('--data_dir_root', type=str, default=data_dir_root, help="data dir") 77 | 78 | FLAGS = parser.parse_args() 79 | print(FLAGS) 80 | -------------------------------------------------------------------------------- /BraVL_EEG/divergence_measures/kl_div.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from utils.utils import reweight_weights 5 | 6 | 7 | def calc_kl_divergence(mu0, logvar0, mu1=None, logvar1=None, norm_value=None): 8 | if mu1 is None or logvar1 is None: 9 | KLD = -0.5 * torch.sum(1 - logvar0.exp() - mu0.pow(2) + logvar0) 10 | else: 11 | KLD = -0.5 * (torch.sum(1 - logvar0.exp()/logvar1.exp() - (mu0-mu1).pow(2)/logvar1.exp() + logvar0 - logvar1)) 12 | if norm_value is not None: 13 | KLD = KLD / float(norm_value); 14 | return KLD 15 | 16 | 17 | def calc_gaussian_scaling_factor(PI, mu1, logvar1, mu2=None, logvar2=None, norm_value=None): 18 | d = mu1.shape[1]; 19 | if mu2 is None or logvar2 is None: 20 | # print('S_11: ' + str(torch.sum(1/((2*PI*(logvar1.exp() + 1)).pow(0.5))))) 21 | # print('S_12: ' + str(torch.sum(torch.exp(-0.5*(mu1.pow(2)/(logvar1.exp()+1)))))) 22 | S_pre = (1/(2*PI).pow(d/2))*torch.sum((logvar1.exp() + 1), dim=1).pow(0.5); 23 | S = S_pre*torch.sum((-0.5*(mu1.pow(2)/(logvar1.exp()+1))).exp(), dim=1); 24 | S = torch.sum(S) 25 | else: 26 | # print('S_21: ' + str(torch.sum(1/((2*PI).pow(d/2)*(logvar1.exp()+logvar2.exp()).pow(0.5))))); 27 | # print('S_22: ' + str(torch.sum(torch.exp(-0.5 * ((mu1 - mu2).pow(2) / (logvar1.exp() + logvar2.exp())))))); 28 | S_pre = torch.sum(1/((2*PI).pow(d/2)*(logvar1.exp()+logvar2.exp())), dim=1).pow(0.5) 29 | S = S_pre*torch.sum(torch.exp(-0.5*((mu1-mu2).pow(2)/(logvar1.exp()+logvar2.exp()))), dim=1); 30 | S = torch.sum(S) 31 | if norm_value is not None: 32 | S = S / float(norm_value); 33 | # print('S: ' + str(S)) 34 | return S 35 | 36 | 37 | def calc_gaussian_scaling_factor_self(PI, logvar1, norm_value=None): 38 | d = logvar1.shape[1]; 39 | S = (1/(2*PI).pow(d/2))*torch.sum(logvar1.exp(), dim=1).pow(0.5); 40 | S = torch.sum(S); 41 | # S = torch.sum(1 / (2*(PI*torch.exp(logvar1)).pow(0.5))); 42 | if norm_value is not None: 43 | S = S / float(norm_value); 44 | # print('S self: ' + str(S)) 45 | return S 46 | 47 | 48 | #def calc_kl_divergence_lb_gauss_mixture(flags, index, mu1, logvar1, mus, logvars, norm_value=None): 49 | # klds = torch.zeros(mus.shape[0]+1) 50 | # if flags.cuda: 51 | # klds = klds.cuda(); 52 | # 53 | # klds[0] = calc_kl_divergence(mu1, logvar1, norm_value=norm_value); 54 | # for k in range(0, mus.shape[0]): 55 | # if k == index: 56 | # kld = 0.0; 57 | # else: 58 | # kld = calc_kl_divergence(mu1, logvar1, mus[k], logvars[k], norm_value=norm_value); 59 | # klds[k+1] = kld; 60 | # kld_mixture = klds.mean(); 61 | # return kld_mixture; 62 | 63 | def calc_kl_divergence_lb_gauss_mixture(flags, index, mu1, logvar1, mus, logvars, norm_value=None): 64 | PI = torch.Tensor([math.pi]); 65 | w_modalities = torch.Tensor(flags.alpha_modalities); 66 | if flags.cuda: 67 | PI = PI.cuda(); 68 | w_modalities = w_modalities.cuda(); 69 | w_modalities = reweight_weights(w_modalities); 70 | 71 | denom = w_modalities[0]*calc_gaussian_scaling_factor(PI, mu1, logvar1, norm_value=norm_value); 72 | for k in range(0, len(mus)): 73 | if index == k: 74 | denom += 
w_modalities[k+1]*calc_gaussian_scaling_factor_self(PI, logvar1, norm_value=norm_value); 75 | else: 76 | denom += w_modalities[k+1]*calc_gaussian_scaling_factor(PI, mu1, logvar1, mus[k], logvars[k], norm_value=norm_value) 77 | lb = -torch.log(denom); 78 | return lb; 79 | 80 | 81 | def calc_kl_divergence_ub_gauss_mixture(flags, index, mu1, logvar1, mus, logvars, entropy, norm_value=None): 82 | PI = torch.Tensor([math.pi]); 83 | w_modalities = torch.Tensor(flags.alpha_modalities); 84 | if flags.cuda: 85 | PI = PI.cuda(); 86 | w_modalities = w_modalities.cuda(); 87 | w_modalities = reweight_weights(w_modalities); 88 | 89 | nom = calc_gaussian_scaling_factor_self(PI, logvar1, norm_value=norm_value); 90 | kl_div = calc_kl_divergence(mu1, logvar1, norm_value=norm_value); 91 | print('kl div uniform: ' + str(kl_div)) 92 | denom = w_modalities[0]*torch.min(torch.Tensor([kl_div.exp(), 100000])); 93 | for k in range(0, len(mus)): 94 | if index == k: 95 | denom += w_modalities[k+1]; 96 | else: 97 | kl_div = calc_kl_divergence(mu1, logvar1, mus[k], logvars[k], norm_value=norm_value) 98 | print('kl div ' + str(k) + ': ' + str(kl_div)) 99 | denom += w_modalities[k+1]*torch.min(torch.Tensor([kl_div.exp(), 100000])); 100 | ub = torch.log(nom) - torch.log(denom) + entropy; 101 | return ub; 102 | 103 | 104 | def calc_entropy_gauss(flags, logvar, norm_value=None): 105 | PI = torch.Tensor([math.pi]); 106 | if flags.cuda: 107 | PI = PI.cuda(); 108 | ent = 0.5*torch.sum(torch.log(2*PI) + logvar + 1) 109 | if norm_value is not None: 110 | ent = ent / norm_value; 111 | return ent; -------------------------------------------------------------------------------- /BraVL_fMRI/divergence_measures/kl_div.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | from utils.utils import reweight_weights 5 | 6 | 7 | def calc_kl_divergence(mu0, logvar0, mu1=None, logvar1=None, norm_value=None): 8 | if mu1 is None or logvar1 is None: 9 | KLD = -0.5 * torch.sum(1 - logvar0.exp() - mu0.pow(2) + logvar0) 10 | else: 11 | KLD = -0.5 * (torch.sum(1 - logvar0.exp()/logvar1.exp() - (mu0-mu1).pow(2)/logvar1.exp() + logvar0 - logvar1)) 12 | if norm_value is not None: 13 | KLD = KLD / float(norm_value); 14 | return KLD 15 | 16 | 17 | def calc_gaussian_scaling_factor(PI, mu1, logvar1, mu2=None, logvar2=None, norm_value=None): 18 | d = mu1.shape[1]; 19 | if mu2 is None or logvar2 is None: 20 | # print('S_11: ' + str(torch.sum(1/((2*PI*(logvar1.exp() + 1)).pow(0.5))))) 21 | # print('S_12: ' + str(torch.sum(torch.exp(-0.5*(mu1.pow(2)/(logvar1.exp()+1)))))) 22 | S_pre = (1/(2*PI).pow(d/2))*torch.sum((logvar1.exp() + 1), dim=1).pow(0.5); 23 | S = S_pre*torch.sum((-0.5*(mu1.pow(2)/(logvar1.exp()+1))).exp(), dim=1); 24 | S = torch.sum(S) 25 | else: 26 | # print('S_21: ' + str(torch.sum(1/((2*PI).pow(d/2)*(logvar1.exp()+logvar2.exp()).pow(0.5))))); 27 | # print('S_22: ' + str(torch.sum(torch.exp(-0.5 * ((mu1 - mu2).pow(2) / (logvar1.exp() + logvar2.exp())))))); 28 | S_pre = torch.sum(1/((2*PI).pow(d/2)*(logvar1.exp()+logvar2.exp())), dim=1).pow(0.5) 29 | S = S_pre*torch.sum(torch.exp(-0.5*((mu1-mu2).pow(2)/(logvar1.exp()+logvar2.exp()))), dim=1); 30 | S = torch.sum(S) 31 | if norm_value is not None: 32 | S = S / float(norm_value); 33 | # print('S: ' + str(S)) 34 | return S 35 | 36 | 37 | def calc_gaussian_scaling_factor_self(PI, logvar1, norm_value=None): 38 | d = logvar1.shape[1]; 39 | S = (1/(2*PI).pow(d/2))*torch.sum(logvar1.exp(), dim=1).pow(0.5); 40 | S = torch.sum(S); 41 
| # S = torch.sum(1 / (2*(PI*torch.exp(logvar1)).pow(0.5))); 42 | if norm_value is not None: 43 | S = S / float(norm_value); 44 | # print('S self: ' + str(S)) 45 | return S 46 | 47 | 48 | #def calc_kl_divergence_lb_gauss_mixture(flags, index, mu1, logvar1, mus, logvars, norm_value=None): 49 | # klds = torch.zeros(mus.shape[0]+1) 50 | # if flags.cuda: 51 | # klds = klds.cuda(); 52 | # 53 | # klds[0] = calc_kl_divergence(mu1, logvar1, norm_value=norm_value); 54 | # for k in range(0, mus.shape[0]): 55 | # if k == index: 56 | # kld = 0.0; 57 | # else: 58 | # kld = calc_kl_divergence(mu1, logvar1, mus[k], logvars[k], norm_value=norm_value); 59 | # klds[k+1] = kld; 60 | # kld_mixture = klds.mean(); 61 | # return kld_mixture; 62 | 63 | def calc_kl_divergence_lb_gauss_mixture(flags, index, mu1, logvar1, mus, logvars, norm_value=None): 64 | PI = torch.Tensor([math.pi]); 65 | w_modalities = torch.Tensor(flags.alpha_modalities); 66 | if flags.cuda: 67 | PI = PI.cuda(); 68 | w_modalities = w_modalities.cuda(); 69 | w_modalities = reweight_weights(w_modalities); 70 | 71 | denom = w_modalities[0]*calc_gaussian_scaling_factor(PI, mu1, logvar1, norm_value=norm_value); 72 | for k in range(0, len(mus)): 73 | if index == k: 74 | denom += w_modalities[k+1]*calc_gaussian_scaling_factor_self(PI, logvar1, norm_value=norm_value); 75 | else: 76 | denom += w_modalities[k+1]*calc_gaussian_scaling_factor(PI, mu1, logvar1, mus[k], logvars[k], norm_value=norm_value) 77 | lb = -torch.log(denom); 78 | return lb; 79 | 80 | 81 | def calc_kl_divergence_ub_gauss_mixture(flags, index, mu1, logvar1, mus, logvars, entropy, norm_value=None): 82 | PI = torch.Tensor([math.pi]); 83 | w_modalities = torch.Tensor(flags.alpha_modalities); 84 | if flags.cuda: 85 | PI = PI.cuda(); 86 | w_modalities = w_modalities.cuda(); 87 | w_modalities = reweight_weights(w_modalities); 88 | 89 | nom = calc_gaussian_scaling_factor_self(PI, logvar1, norm_value=norm_value); 90 | kl_div = calc_kl_divergence(mu1, logvar1, norm_value=norm_value); 91 | print('kl div uniform: ' + str(kl_div)) 92 | denom = w_modalities[0]*torch.min(torch.Tensor([kl_div.exp(), 100000])); 93 | for k in range(0, len(mus)): 94 | if index == k: 95 | denom += w_modalities[k+1]; 96 | else: 97 | kl_div = calc_kl_divergence(mu1, logvar1, mus[k], logvars[k], norm_value=norm_value) 98 | print('kl div ' + str(k) + ': ' + str(kl_div)) 99 | denom += w_modalities[k+1]*torch.min(torch.Tensor([kl_div.exp(), 100000])); 100 | ub = torch.log(nom) - torch.log(denom) + entropy; 101 | return ub; 102 | 103 | 104 | def calc_entropy_gauss(flags, logvar, norm_value=None): 105 | PI = torch.Tensor([math.pi]); 106 | if flags.cuda: 107 | PI = PI.cuda(); 108 | ent = 0.5*torch.sum(torch.log(2*PI) + logvar + 1) 109 | if norm_value is not None: 110 | ent = ent / norm_value; 111 | return ent; -------------------------------------------------------------------------------- /BraVL_EEG/divergence_measures/mm_div.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | from divergence_measures.kl_div import calc_kl_divergence 6 | from divergence_measures.kl_div import calc_kl_divergence_lb_gauss_mixture 7 | from divergence_measures.kl_div import calc_kl_divergence_ub_gauss_mixture 8 | from divergence_measures.kl_div import calc_entropy_gauss 9 | 10 | from utils.utils import reweight_weights 11 | 12 | 13 | def poe(mu, logvar, eps=1e-8): 14 | var = torch.exp(logvar) + eps 15 | # precision of i-th Gaussian expert at point x 16 | T = 1. 
/ var 17 | pd_mu = torch.sum(mu * T, dim=0) / torch.sum(T, dim=0) 18 | pd_var = 1. / torch.sum(T, dim=0) 19 | pd_logvar = torch.log(pd_var) 20 | return pd_mu, pd_logvar 21 | 22 | 23 | def alpha_poe(alpha, mu, logvar, eps=1e-8): 24 | var = torch.exp(logvar) + eps 25 | # precision of i-th Gaussian expert at point x 26 | if var.dim() == 3: 27 | alpha_expanded = alpha.unsqueeze(-1).unsqueeze(-1); 28 | elif var.dim() == 4: 29 | alpha_expanded = alpha.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1); 30 | 31 | T = 1 / var; 32 | pd_var = 1. / torch.sum(alpha_expanded * T, dim=0) 33 | pd_mu = pd_var * torch.sum(alpha_expanded * mu * T, dim=0) 34 | pd_logvar = torch.log(pd_var) 35 | return pd_mu, pd_logvar; 36 | 37 | 38 | def calc_alphaJSD_modalities_mixture(m1_mu, m1_logvar, m2_mu, m2_logvar, flags): 39 | klds = torch.zeros(2); 40 | entropies_mixture = torch.zeros(2); 41 | w_modalities = torch.Tensor(flags.alpha_modalities[1:]); 42 | if flags.cuda: 43 | w_modalities = w_modalities.cuda(); 44 | klds = klds.cuda(); 45 | entropies_mixture = entropies_mixture.cuda(); 46 | w_modalities = reweight_weights(w_modalities); 47 | 48 | mus = [m1_mu, m2_mu] 49 | logvars = [m1_logvar, m2_logvar] 50 | for k in range(0, len(mus)): 51 | ent = calc_entropy_gauss(flags, logvars[k], norm_value=flags.batch_size); 52 | # print('entropy: ' + str(ent)) 53 | # print('lb: ' ) 54 | kld_lb = calc_kl_divergence_lb_gauss_mixture(flags, k, mus[k], logvars[k], mus, logvars, 55 | norm_value=flags.batch_size); 56 | print('kld_lb: ' + str(kld_lb)) 57 | # print('ub: ') 58 | kld_ub = calc_kl_divergence_ub_gauss_mixture(flags, k, mus[k], logvars[k], mus, logvars, ent, 59 | norm_value=flags.batch_size); 60 | print('kld_ub: ' + str(kld_ub)) 61 | # kld_mean = (kld_lb+kld_ub)/2; 62 | entropies_mixture[k] = ent.clone(); 63 | klds[k] = 0.5*(kld_lb + kld_ub); 64 | # klds[k] = kld_ub; 65 | summed_klds = (w_modalities * klds).sum(); 66 | # print('summed klds: ' + str(summed_klds)); 67 | return summed_klds, klds, entropies_mixture; 68 | 69 | def calc_alphaJSD_modalities(flags, mus, logvars, weights, normalization=None): 70 | num_mods = mus.shape[0]; 71 | num_samples = mus.shape[1]; 72 | alpha_mu, alpha_logvar = alpha_poe(weights, mus, logvars) 73 | if normalization is not None: 74 | klds = torch.zeros(num_mods); 75 | else: 76 | klds = torch.zeros(num_mods, num_samples); 77 | klds = klds.to(flags.device); 78 | 79 | for k in range(0, num_mods): 80 | kld = calc_kl_divergence(mus[k,:,:], logvars[k,:,:], alpha_mu, 81 | alpha_logvar, norm_value=normalization); 82 | if normalization is not None: 83 | klds[k] = kld; 84 | else: 85 | klds[k,:] = kld; 86 | if normalization is None: 87 | weights = weights.unsqueeze(1).repeat(1, num_samples); 88 | group_div = (weights * klds).sum(dim=0); 89 | return group_div, klds, [alpha_mu, alpha_logvar]; 90 | 91 | 92 | def calc_group_divergence_moe(flags, mus, logvars, weights, normalization=None): 93 | num_mods = mus.shape[0]; 94 | num_samples = mus.shape[1]; 95 | if normalization is not None: 96 | klds = torch.zeros(num_mods); 97 | else: 98 | klds = torch.zeros(num_mods, num_samples); 99 | klds = klds.to(flags.device); 100 | weights = weights.to(flags.device); 101 | for k in range(0, num_mods): 102 | kld_ind = calc_kl_divergence(mus[k,:,:], logvars[k,:,:], 103 | norm_value=normalization); 104 | if normalization is not None: 105 | klds[k] = kld_ind; 106 | else: 107 | klds[k,:] = kld_ind; 108 | if normalization is None: 109 | weights = weights.unsqueeze(1).repeat(1, num_samples); 110 | group_div = (weights*klds).sum(dim=0); 
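    # For reference: each klds[k] above is KL(q_k(z|x_k) || N(0, I)) for one unimodal
    # posterior, i.e. -0.5 * sum(1 + logvar_k - mu_k^2 - exp(logvar_k)) (see
    # calc_kl_divergence in kl_div.py), and group_div is their modality-weighted sum.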
111 | return group_div, klds; 112 | 113 | 114 | def calc_group_divergence_poe(flags, mus, logvars, norm=None): 115 | num_mods = mus.shape[0]; 116 | poe_mu, poe_logvar = poe(mus, logvars) 117 | kld_poe = calc_kl_divergence(poe_mu, poe_logvar, norm_value=norm); 118 | klds = torch.zeros(num_mods).to(flags.device); 119 | for k in range(0, num_mods): 120 | kld_ind = calc_kl_divergence(mus[k,:,:], logvars[k,:,:], 121 | norm_value=norm); 122 | klds[k] = kld_ind; 123 | return kld_poe, klds, [poe_mu, poe_logvar]; 124 | 125 | 126 | def calc_modality_divergence(m1_mu, m1_logvar, m2_mu, m2_logvar, flags): 127 | if flags.modality_poe: 128 | kld_batch = calc_kl_divergence(m1_mu, m1_logvar, m2_mu, m2_logvar, norm_value=flags.batch_size).sum(); 129 | return kld_batch; 130 | else: 131 | uniform_mu = torch.zeros(m1_mu.shape) 132 | uniform_logvar = torch.zeros(m1_logvar.shape) 133 | klds = torch.zeros(3,3) 134 | klds_modonly = torch.zeros(2,2) 135 | if flags.cuda: 136 | klds = klds.cuda(); 137 | klds_modonly = klds_modonly.cuda(); 138 | uniform_mu = uniform_mu.cuda(); 139 | uniform_logvar = uniform_logvar.cuda(); 140 | 141 | mus = [uniform_mu, m1_mu, m2_mu] 142 | logvars = [uniform_logvar, m1_logvar, m2_logvar] 143 | for i in range(1, len(mus)): # CAREFUL: index starts from one, not zero 144 | for j in range(0, len(mus)): 145 | kld = calc_kl_divergence(mus[i], logvars[i], mus[j], logvars[j], norm_value=flags.batch_size); 146 | klds[i,j] = kld; 147 | if i >= 1 and j >= 1: 148 | klds_modonly[i-1,j-1] = kld; 149 | klds = klds.sum()/(len(mus)*(len(mus)-1)) 150 | klds_modonly = klds_modonly.sum()/((len(mus)-1)*(len(mus)-1)); 151 | return [klds, klds_modonly]; 152 | -------------------------------------------------------------------------------- /BraVL_fMRI/divergence_measures/mm_div.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | 5 | from divergence_measures.kl_div import calc_kl_divergence 6 | from divergence_measures.kl_div import calc_kl_divergence_lb_gauss_mixture 7 | from divergence_measures.kl_div import calc_kl_divergence_ub_gauss_mixture 8 | from divergence_measures.kl_div import calc_entropy_gauss 9 | 10 | from utils.utils import reweight_weights 11 | 12 | 13 | def poe(mu, logvar, eps=1e-8): 14 | var = torch.exp(logvar) + eps 15 | # precision of i-th Gaussian expert at point x 16 | T = 1. / var 17 | pd_mu = torch.sum(mu * T, dim=0) / torch.sum(T, dim=0) 18 | pd_var = 1. / torch.sum(T, dim=0) 19 | pd_logvar = torch.log(pd_var) 20 | return pd_mu, pd_logvar 21 | 22 | 23 | def alpha_poe(alpha, mu, logvar, eps=1e-8): 24 | var = torch.exp(logvar) + eps 25 | # precision of i-th Gaussian expert at point x 26 | if var.dim() == 3: 27 | alpha_expanded = alpha.unsqueeze(-1).unsqueeze(-1); 28 | elif var.dim() == 4: 29 | alpha_expanded = alpha.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1); 30 | 31 | T = 1 / var; 32 | pd_var = 1. 
/ torch.sum(alpha_expanded * T, dim=0) 33 | pd_mu = pd_var * torch.sum(alpha_expanded * mu * T, dim=0) 34 | pd_logvar = torch.log(pd_var) 35 | return pd_mu, pd_logvar; 36 | 37 | 38 | def calc_alphaJSD_modalities_mixture(m1_mu, m1_logvar, m2_mu, m2_logvar, flags): 39 | klds = torch.zeros(2); 40 | entropies_mixture = torch.zeros(2); 41 | w_modalities = torch.Tensor(flags.alpha_modalities[1:]); 42 | if flags.cuda: 43 | w_modalities = w_modalities.cuda(); 44 | klds = klds.cuda(); 45 | entropies_mixture = entropies_mixture.cuda(); 46 | w_modalities = reweight_weights(w_modalities); 47 | 48 | mus = [m1_mu, m2_mu] 49 | logvars = [m1_logvar, m2_logvar] 50 | for k in range(0, len(mus)): 51 | ent = calc_entropy_gauss(flags, logvars[k], norm_value=flags.batch_size); 52 | # print('entropy: ' + str(ent)) 53 | # print('lb: ' ) 54 | kld_lb = calc_kl_divergence_lb_gauss_mixture(flags, k, mus[k], logvars[k], mus, logvars, 55 | norm_value=flags.batch_size); 56 | print('kld_lb: ' + str(kld_lb)) 57 | # print('ub: ') 58 | kld_ub = calc_kl_divergence_ub_gauss_mixture(flags, k, mus[k], logvars[k], mus, logvars, ent, 59 | norm_value=flags.batch_size); 60 | print('kld_ub: ' + str(kld_ub)) 61 | # kld_mean = (kld_lb+kld_ub)/2; 62 | entropies_mixture[k] = ent.clone(); 63 | klds[k] = 0.5*(kld_lb + kld_ub); 64 | # klds[k] = kld_ub; 65 | summed_klds = (w_modalities * klds).sum(); 66 | # print('summed klds: ' + str(summed_klds)); 67 | return summed_klds, klds, entropies_mixture; 68 | 69 | def calc_alphaJSD_modalities(flags, mus, logvars, weights, normalization=None): 70 | num_mods = mus.shape[0]; 71 | num_samples = mus.shape[1]; 72 | alpha_mu, alpha_logvar = alpha_poe(weights, mus, logvars) 73 | if normalization is not None: 74 | klds = torch.zeros(num_mods); 75 | else: 76 | klds = torch.zeros(num_mods, num_samples); 77 | klds = klds.to(flags.device); 78 | 79 | for k in range(0, num_mods): 80 | kld = calc_kl_divergence(mus[k,:,:], logvars[k,:,:], alpha_mu, 81 | alpha_logvar, norm_value=normalization); 82 | if normalization is not None: 83 | klds[k] = kld; 84 | else: 85 | klds[k,:] = kld; 86 | if normalization is None: 87 | weights = weights.unsqueeze(1).repeat(1, num_samples); 88 | group_div = (weights * klds).sum(dim=0); 89 | return group_div, klds, [alpha_mu, alpha_logvar]; 90 | 91 | 92 | def calc_group_divergence_moe(flags, mus, logvars, weights, normalization=None): 93 | num_mods = mus.shape[0]; 94 | num_samples = mus.shape[1]; 95 | if normalization is not None: 96 | klds = torch.zeros(num_mods); 97 | else: 98 | klds = torch.zeros(num_mods, num_samples); 99 | klds = klds.to(flags.device); 100 | weights = weights.to(flags.device); 101 | for k in range(0, num_mods): 102 | kld_ind = calc_kl_divergence(mus[k,:,:], logvars[k,:,:], 103 | norm_value=normalization); 104 | if normalization is not None: 105 | klds[k] = kld_ind; 106 | else: 107 | klds[k,:] = kld_ind; 108 | if normalization is None: 109 | weights = weights.unsqueeze(1).repeat(1, num_samples); 110 | group_div = (weights*klds).sum(dim=0); 111 | return group_div, klds; 112 | 113 | 114 | def calc_group_divergence_poe(flags, mus, logvars, norm=None): 115 | num_mods = mus.shape[0]; 116 | poe_mu, poe_logvar = poe(mus, logvars) 117 | kld_poe = calc_kl_divergence(poe_mu, poe_logvar, norm_value=norm); 118 | klds = torch.zeros(num_mods).to(flags.device); 119 | for k in range(0, num_mods): 120 | kld_ind = calc_kl_divergence(mus[k,:,:], logvars[k,:,:], 121 | norm_value=norm); 122 | klds[k] = kld_ind; 123 | return kld_poe, klds, [poe_mu, poe_logvar]; 124 | 125 | 126 | def 
calc_modality_divergence(m1_mu, m1_logvar, m2_mu, m2_logvar, flags): 127 | if flags.modality_poe: 128 | kld_batch = calc_kl_divergence(m1_mu, m1_logvar, m2_mu, m2_logvar, norm_value=flags.batch_size).sum(); 129 | return kld_batch; 130 | else: 131 | uniform_mu = torch.zeros(m1_mu.shape) 132 | uniform_logvar = torch.zeros(m1_logvar.shape) 133 | klds = torch.zeros(3,3) 134 | klds_modonly = torch.zeros(2,2) 135 | if flags.cuda: 136 | klds = klds.cuda(); 137 | klds_modonly = klds_modonly.cuda(); 138 | uniform_mu = uniform_mu.cuda(); 139 | uniform_logvar = uniform_logvar.cuda(); 140 | 141 | mus = [uniform_mu, m1_mu, m2_mu] 142 | logvars = [uniform_logvar, m1_logvar, m2_logvar] 143 | for i in range(1, len(mus)): # CAREFUL: index starts from one, not zero 144 | for j in range(0, len(mus)): 145 | kld = calc_kl_divergence(mus[i], logvars[i], mus[j], logvars[j], norm_value=flags.batch_size); 146 | klds[i,j] = kld; 147 | if i >= 1 and j >= 1: 148 | klds_modonly[i-1,j-1] = kld; 149 | klds = klds.sum()/(len(mus)*(len(mus)-1)) 150 | klds_modonly = klds_modonly.sum()/((len(mus)-1)*(len(mus)-1)); 151 | return [klds, klds_modonly]; 152 | -------------------------------------------------------------------------------- /BraVL_fMRI/extract_fea_with_timm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from scipy import io 4 | os.environ['CUDA_VISIBLE_DEVICES'] = '3' 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as cudnn 7 | import torch.optim 8 | import torch.utils.data 9 | import torch.utils.data.distributed 10 | import torchvision.transforms as transforms 11 | import torchvision.datasets as datasets 12 | import PIL 13 | import torch 14 | import timm 15 | 16 | # python extract_fea_with_timm.py --data ./data/GenericObjectDecoding-v2/images/training --save_dir ./data/GOD-Wiki/visual_feature/ImageNetTraining --model repvgg_b3g4 --resolution 224 17 | 18 | parser = argparse.ArgumentParser(description='PyTorch ImageNet Test') 19 | parser.add_argument('-i', '--data', metavar='./data/GenericObjectDecoding-v2/images/training', 20 | help='path to dataset') 21 | parser.add_argument('-o', '--save_dir', metavar='./data/GOD-Wiki/visual_feature/ImageNetTraining', 22 | help='path to save') 23 | parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', 24 | help='number of data loading workers (default: 4)') 25 | parser.add_argument('-b', '--batch-size', default=1, type=int, 26 | metavar='N', 27 | help='mini-batch size (default: 100) for test') 28 | parser.add_argument('-r', '--resolution', default=224, type=int, 29 | metavar='R', help='resolution (default: 224) for test') 30 | parser.add_argument('-m', '--model', default='resnet50', type=str, 31 | metavar='M', help='pretrained model for test') 32 | 33 | args = parser.parse_args() 34 | root_dir = args.save_dir+'/pytorch/'+ args.model +'/' 35 | if not os.path.exists(root_dir): 36 | os.makedirs(root_dir) 37 | 38 | def get_default_val_trans(args): 39 | if (not hasattr(args, 'resolution')) or args.resolution == 224: 40 | trans = transforms.Compose([ 41 | transforms.Resize(256), 42 | transforms.CenterCrop(224), 43 | transforms.ToTensor(), 44 | ]) 45 | else: 46 | trans = transforms.Compose([ 47 | transforms.Resize(args.resolution, interpolation=PIL.Image.BILINEAR), 48 | transforms.CenterCrop(args.resolution), 49 | transforms.ToTensor(), 50 | ]) 51 | return trans 52 | 53 | def get_ImageNet_val_dataset(args, trans): 54 | val_dataset = datasets.ImageFolder(args.data, trans) 55 | 
return val_dataset 56 | 57 | def get_default_ImageNet_val_loader_withpath(args): 58 | val_trans = get_default_val_trans(args) 59 | val_dataset = get_ImageNet_val_dataset(args, val_trans) 60 | val_loader = torch.utils.data.DataLoader( 61 | val_dataset, 62 | batch_size=args.batch_size, shuffle=False, 63 | num_workers=args.workers, pin_memory=True) 64 | return val_loader, val_dataset 65 | 66 | 67 | def extract(val_loader, val_dataset, model_final, model_linear, model_multiscale, use_gpu): 68 | def save_feature(feat, flag): 69 | feature_name = feature 70 | l = feature_name.split('_') 71 | if 'out' in l: 72 | l.remove('out') 73 | if 'list' in l: 74 | l.remove('list') 75 | feature_name = '_'.join(l) 76 | 77 | feat = feat.cpu().numpy() 78 | if flag == 'list': 79 | dir1 = '{}/{}_{}'.format(root_dir, feature_name, i) 80 | else: 81 | dir1 = '{}/{}'.format(root_dir, feature_name) 82 | if not os.path.exists(dir1): 83 | os.makedirs(dir1) 84 | filename = '{}.mat'.format(imid) 85 | io.savemat(dir1 + '/' + filename, {'feat': feat}) 86 | 87 | # switch to evaluate mode 88 | model_final.eval() 89 | model_linear.eval() 90 | model_multiscale.eval() 91 | # 对应文件夹的label 92 | # print(val_dataset.class_to_idx) 93 | with torch.no_grad(): 94 | for i, images in enumerate(val_loader): 95 | if use_gpu: 96 | images = images[0].cuda(non_blocking=True) 97 | final = model_final(images) 98 | print(f'Original shape: {final.shape}') 99 | linear = model_linear(images) 100 | print(f'Pooled shape: {linear.shape}') 101 | Conv = model_multiscale(images) 102 | # Conv = [Conv[-4],Conv[-3],Conv[-2],Conv[-1]] 103 | for x in Conv: 104 | print(x.shape) 105 | 106 | wnid = val_dataset.imgs[i][0].split("/")[-2] 107 | imid = val_dataset.imgs[i][0].split("/")[-1].split('.')[0] 108 | print(wnid, imid) 109 | feature_list = ['final','linear','Conv'] 110 | 111 | for feature in feature_list: 112 | feat = eval(feature) 113 | if type(feat) == list: 114 | for i in range(len(feat)): 115 | save_feature(feat[i], 'list') 116 | else: 117 | save_feature(feat, 'single') 118 | 119 | def inference(): 120 | model_final = timm.create_model(args.model, pretrained=True) 121 | model_linear = timm.create_model(args.model, pretrained=True, num_classes=0) 122 | model_multiscale = timm.create_model(args.model, pretrained=True, features_only=True) 123 | if not torch.cuda.is_available(): 124 | print('using CPU, this will be slow') 125 | use_gpu = False 126 | else: 127 | model_final = model_final.cuda() 128 | model_linear = model_linear.cuda() 129 | model_multiscale = model_multiscale.cuda() 130 | use_gpu = True 131 | 132 | cudnn.benchmark = True 133 | 134 | val_loader, val_dataset = get_default_ImageNet_val_loader_withpath(args) 135 | 136 | extract(val_loader, val_dataset, model_final, model_linear, model_multiscale, use_gpu) 137 | 138 | def extract_no_conv(val_loader, val_dataset, model_final, model_linear, use_gpu): 139 | def save_feature(feat, flag): 140 | feature_name = feature 141 | l = feature_name.split('_') 142 | if 'out' in l: 143 | l.remove('out') 144 | if 'list' in l: 145 | l.remove('list') 146 | feature_name = '_'.join(l) 147 | 148 | feat = feat.cpu().numpy() 149 | if flag == 'list': 150 | dir1 = '{}/{}_{}'.format(root_dir, feature_name, i) 151 | else: 152 | dir1 = '{}/{}'.format(root_dir, feature_name) 153 | if not os.path.exists(dir1): 154 | os.makedirs(dir1) 155 | filename = '{}.mat'.format(imid) 156 | io.savemat(dir1 + '/' + filename, {'feat': feat}) 157 | 158 | # switch to evaluate mode 159 | model_final.eval() 160 | model_linear.eval() 161 | with 
torch.no_grad(): 162 | for i, images in enumerate(val_loader): 163 | if use_gpu: 164 | images = images[0].cuda(non_blocking=True) 165 | final = model_final(images) 166 | print(f'Original shape: {final.shape}') 167 | linear = model_linear(images) 168 | print(f'Pooled shape: {linear.shape}') 169 | 170 | wnid = val_dataset.imgs[i][0].split("/")[-2] 171 | imid = val_dataset.imgs[i][0].split("/")[-1].split('.')[0] 172 | print(wnid, imid) 173 | feature_list = ['final','linear'] 174 | 175 | for feature in feature_list: 176 | feat = eval(feature) 177 | if type(feat) == list: 178 | for i in range(len(feat)): 179 | save_feature(feat[i], 'list') 180 | else: 181 | save_feature(feat, 'single') 182 | 183 | def inference_no_conv(): 184 | model_final = timm.create_model(args.model, pretrained=True) 185 | model_linear = timm.create_model(args.model, pretrained=True, num_classes=0) 186 | if not torch.cuda.is_available(): 187 | print('using CPU, this will be slow') 188 | use_gpu = False 189 | else: 190 | model_final = model_final.cuda() 191 | model_linear = model_linear.cuda() 192 | use_gpu = True 193 | 194 | cudnn.benchmark = True 195 | 196 | val_loader, val_dataset = get_default_ImageNet_val_loader_withpath(args) 197 | 198 | extract_no_conv(val_loader, val_dataset, model_final, model_linear, use_gpu) 199 | 200 | if __name__ == '__main__': 201 | inference() 202 | # inference_no_conv() 203 | 204 | -------------------------------------------------------------------------------- /BraVL_EEG/brain_image_text/experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import itertools 4 | import scipy.io as sio 5 | import torch 6 | import torch.optim as optim 7 | from sklearn.metrics import accuracy_score 8 | from sklearn.model_selection import train_test_split 9 | from torch.utils.data import TensorDataset 10 | from modalities.Modality import Modality 11 | from brain_image_text.networks.VAEtrimodal import VAEtrimodal,VAEbimodal 12 | from brain_image_text.networks.QNET import QNet 13 | from brain_image_text.networks.MLP_Brain import EncoderBrain, DecoderBrain 14 | from brain_image_text.networks.MLP_Image import EncoderImage, DecoderImage 15 | from brain_image_text.networks.MLP_Text import EncoderText, DecoderText 16 | from utils.BaseExperiment import BaseExperiment 17 | 18 | 19 | class BrainImageText(BaseExperiment): 20 | def __init__(self, flags, alphabet): 21 | super().__init__(flags) 22 | 23 | self.modalities = self.set_modalities() 24 | self.num_modalities = len(self.modalities.keys()) 25 | self.subsets = self.set_subsets() 26 | self.dataset_train = None 27 | self.dataset_test = None 28 | 29 | self.set_dataset() 30 | self.mm_vae = self.set_model() 31 | self.optimizer = None 32 | self.rec_weights = self.set_rec_weights() 33 | self.style_weights = self.set_style_weights() 34 | self.Q1,self.Q2,self.Q3 = self.set_Qmodel() 35 | self.eval_metric = accuracy_score 36 | 37 | self.labels = ['digit'] 38 | 39 | 40 | def set_model(self): 41 | model = VAEtrimodal(self.flags, self.modalities, self.subsets) 42 | model = model.to(self.flags.device) 43 | return model 44 | 45 | def set_modalities(self): 46 | mod1 = Modality('brain', EncoderBrain(self.flags), DecoderBrain(self.flags), 47 | self.flags.class_dim, self.flags.style_m1_dim, 'normal') 48 | mod2 = Modality('image', EncoderImage(self.flags), DecoderImage(self.flags), 49 | self.flags.class_dim, self.flags.style_m2_dim, 'normal') 50 | mod3 = Modality('text', EncoderText(self.flags), 
DecoderText(self.flags), 51 | self.flags.class_dim, self.flags.style_m3_dim, 'normal') 52 | mods = {mod1.name: mod1, mod2.name: mod2, mod3.name: mod3} 53 | return mods 54 | 55 | def set_dataset(self): 56 | # load data 57 | data_dir_root = self.flags.data_dir_root 58 | sbj = self.flags.sbj 59 | image_model = self.flags.image_model 60 | text_model = self.flags.text_model 61 | roi = self.flags.roi 62 | brain_dir = os.path.join(data_dir_root, 'brain_feature', roi, sbj) 63 | image_dir_train = os.path.join(data_dir_root, 'visual_feature/ThingsTrain', image_model, sbj) 64 | image_dir_test = os.path.join(data_dir_root, 'visual_feature/ThingsTest', image_model, sbj) 65 | text_dir_train = os.path.join(data_dir_root, 'textual_feature/ThingsTrain/text', text_model, sbj) 66 | text_dir_test = os.path.join(data_dir_root, 'textual_feature/ThingsTest/text', text_model, sbj) 67 | 68 | train_brain = sio.loadmat(os.path.join(brain_dir, 'eeg_train_data_within.mat'))['data'].astype('double') * 2.0 69 | # train_brain = sio.loadmat(os.path.join(brain_dir, 'eeg_train_data_between.mat'))['data'].astype('double')*2.0 70 | train_brain = train_brain[:,:,27:60] # 70ms-400ms 71 | train_brain = np.reshape(train_brain, (train_brain.shape[0], -1)) 72 | train_image = sio.loadmat(os.path.join(image_dir_train, 'feat_pca_train.mat'))['data'].astype('double')*50.0 73 | train_text = sio.loadmat(os.path.join(text_dir_train, 'text_feat_train.mat'))['data'].astype('double')*2.0 74 | train_label = sio.loadmat(os.path.join(brain_dir, 'eeg_train_data_within.mat'))['class_idx'].T.astype('int') 75 | train_image = train_image[:,0:100] 76 | 77 | # test_brain = sio.loadmat(os.path.join(brain_dir, 'eeg_test_data_unique.mat'))['data'].astype('double')*2.0 78 | # test_brain = test_brain[:, :, 27:60] 79 | # test_brain = np.reshape(test_brain, (test_brain.shape[0], -1)) 80 | # test_image = sio.loadmat(os.path.join(image_dir_test, 'feat_pca_test_unique.mat'))['data'].astype('double')*50.0 81 | # test_text = sio.loadmat(os.path.join(text_dir_test, 'text_feat_test_unique.mat'))['data'].astype('double')*2.0 82 | # test_label = sio.loadmat(os.path.join(brain_dir, 'eeg_test_data_unique.mat'))['class_idx'].T.astype('int') 83 | # train_image = train_image[:, 0:100] 84 | 85 | 86 | test_brain = sio.loadmat(os.path.join(brain_dir, 'eeg_test_data.mat'))['data'].astype('double')*2.0 87 | test_brain = test_brain[:, :, 27:60] 88 | test_brain = np.reshape(test_brain, (test_brain.shape[0], -1)) 89 | test_image = sio.loadmat(os.path.join(image_dir_test, 'feat_pca_test.mat'))['data'].astype('double')*50.0 90 | test_text = sio.loadmat(os.path.join(text_dir_test, 'text_feat_test.mat'))['data'].astype('double')*2.0 91 | test_label = sio.loadmat(os.path.join(brain_dir, 'eeg_test_data.mat'))['class_idx'].T.astype('int') 92 | test_image = test_image[:, 0:100] 93 | 94 | if self.flags.aug_type == 'image_text_ilsvrc2012_val': 95 | image_dir_aug = os.path.join(data_dir_root, 'visual_feature/Aug_ILSVRC2012_val', image_model, sbj) 96 | text_dir_aug = os.path.join(data_dir_root, 'textual_feature/Aug_ILSVRC2012_val/text', text_model, sbj) 97 | aug_image = sio.loadmat(os.path.join(image_dir_aug, 'feat_pca_aug_ilsvrc2012_val.mat'))['data'].astype('double') 98 | aug_image = aug_image[:, 0:100] 99 | aug_text = sio.loadmat(os.path.join(text_dir_aug, 'text_feat_aug_ilsvrc2012_val.mat'))['data'].astype('double') 100 | aug_image = torch.from_numpy(aug_image) 101 | aug_text = torch.from_numpy(aug_text) 102 | print('aug_image=', aug_image.shape) 103 | print('aug_text=', aug_text.shape) 
104 | elif self.flags.aug_type == 'no_aug': 105 | print('no augmentation') 106 | 107 | if self.flags.test_type=='normal': 108 | train_label_stratify = train_label 109 | train_brain, val_brain, train_label, val_label = train_test_split(train_brain, train_label_stratify, test_size=0.2, stratify=train_label_stratify) 110 | train_image, val_image, train_label, val_label = train_test_split(train_image, train_label_stratify, test_size=0.2, stratify=train_label_stratify) 111 | train_text, val_text, train_label, val_label = train_test_split(train_text, train_label_stratify, test_size=0.2, stratify=train_label_stratify) 112 | 113 | val_brain = torch.from_numpy(val_brain) 114 | val_image = torch.from_numpy(val_image) 115 | val_text = torch.from_numpy(val_text) 116 | val_label = torch.from_numpy(val_label) 117 | print('val_brain=', val_brain.shape) 118 | print('val_image=', val_image.shape) 119 | print('val_text=', val_text.shape) 120 | 121 | train_brain = torch.from_numpy(train_brain) 122 | test_brain = torch.from_numpy(test_brain) 123 | train_image = torch.from_numpy(train_image) 124 | test_image = torch.from_numpy(test_image) 125 | train_text = torch.from_numpy(train_text) 126 | test_text = torch.from_numpy(test_text) 127 | train_label = torch.from_numpy(train_label) 128 | test_label = torch.from_numpy(test_label) 129 | 130 | 131 | print('train_brain=', train_brain.shape) 132 | print('train_image=', train_image.shape) 133 | print('train_text=', train_text.shape) 134 | print('test_brain=', test_brain.shape) 135 | print('test_image=', test_image.shape) 136 | print('test_text=', test_text.shape) 137 | 138 | self.m1_dim = train_brain.shape[1] 139 | self.m2_dim = train_image.shape[1] 140 | self.m3_dim = train_text.shape[1] 141 | 142 | train_dataset = torch.utils.data.TensorDataset(train_brain, train_image, train_text, train_label) 143 | test_dataset = torch.utils.data.TensorDataset(test_brain, test_image, test_text,test_label) 144 | 145 | self.dataset_train = train_dataset 146 | self.dataset_test = test_dataset 147 | 148 | if self.flags.test_type == 'normal': 149 | val_dataset = torch.utils.data.TensorDataset(val_brain, val_image, val_text, val_label) 150 | self.dataset_val = val_dataset 151 | 152 | if 'image_text' in self.flags.aug_type: 153 | aug_dataset = torch.utils.data.TensorDataset(aug_image, aug_text) 154 | self.dataset_aug = aug_dataset 155 | elif self.flags.aug_type == 'no_aug': 156 | print('no augmentation') 157 | 158 | 159 | def set_optimizer(self): 160 | optimizer = optim.Adam( 161 | itertools.chain(self.mm_vae.parameters(),self.Q1.parameters(),self.Q2.parameters(),self.Q3.parameters()), 162 | lr=self.flags.initial_learning_rate, 163 | betas=(self.flags.beta_1, self.flags.beta_2)) 164 | optimizer_mvae = optim.Adam( 165 | list(self.mm_vae.parameters()), 166 | lr=self.flags.initial_learning_rate, 167 | betas=(self.flags.beta_1, self.flags.beta_2)) 168 | optimizer_Qnet = optim.Adam( 169 | itertools.chain(self.Q1.parameters(),self.Q2.parameters(),self.Q3.parameters()), 170 | lr=self.flags.initial_learning_rate, 171 | betas=(self.flags.beta_1, self.flags.beta_2)) 172 | self.optimizer = {'mvae':optimizer_mvae,'Qnet':optimizer_Qnet,'all':optimizer} 173 | scheduler_mvae = optim.lr_scheduler.StepLR(optimizer_mvae, step_size=20, gamma=1.0) 174 | scheduler_Qnet = optim.lr_scheduler.StepLR(optimizer_Qnet, step_size=20, gamma=1.0) 175 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=1.0) 176 | self.scheduler = {'mvae': scheduler_mvae, 'Qnet': scheduler_Qnet, 'all': scheduler} 177 
| 178 | 179 | def set_Qmodel(self): 180 | Q1 = QNet(input_dim=self.flags.m1_dim, latent_dim=self.flags.class_dim).cuda() 181 | Q2 = QNet(input_dim=self.flags.m2_dim, latent_dim=self.flags.class_dim).cuda() 182 | Q3 = QNet(input_dim=self.flags.m3_dim, latent_dim=self.flags.class_dim).cuda() 183 | return Q1, Q2 ,Q3 184 | 185 | def set_rec_weights(self): 186 | weights = dict() 187 | weights['brain'] = self.flags.beta_m1_rec 188 | weights['image'] = self.flags.beta_m2_rec 189 | weights['text'] = self.flags.beta_m3_rec 190 | return weights 191 | 192 | def set_style_weights(self): 193 | weights = dict() 194 | weights['brain'] = self.flags.beta_m1_style 195 | weights['image'] = self.flags.beta_m2_style 196 | weights['text'] = self.flags.beta_m3_style 197 | return weights 198 | -------------------------------------------------------------------------------- /BraVL_fMRI/brain_image_text/experiment.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import itertools 4 | import scipy.io as sio 5 | import torch 6 | import torch.optim as optim 7 | from sklearn.metrics import accuracy_score 8 | from sklearn.model_selection import train_test_split 9 | from torch.utils.data import TensorDataset 10 | from modalities.Modality import Modality 11 | from brain_image_text.networks.VAEtrimodal import VAEtrimodal,VAEbimodal 12 | from brain_image_text.networks.QNET import QNet 13 | from brain_image_text.networks.MLP_Brain import EncoderBrain, DecoderBrain 14 | from brain_image_text.networks.MLP_Image import EncoderImage, DecoderImage 15 | from brain_image_text.networks.MLP_Text import EncoderText, DecoderText 16 | from utils.BaseExperiment import BaseExperiment 17 | 18 | 19 | class BrainImageText(BaseExperiment): 20 | def __init__(self, flags, alphabet): 21 | super().__init__(flags) 22 | 23 | self.modalities = self.set_modalities() 24 | self.num_modalities = len(self.modalities.keys()) 25 | self.subsets = self.set_subsets() 26 | self.dataset_train = None 27 | self.dataset_test = None 28 | 29 | self.set_dataset() 30 | self.mm_vae = self.set_model() 31 | self.optimizer = None 32 | self.rec_weights = self.set_rec_weights() 33 | self.style_weights = self.set_style_weights() 34 | self.Q1,self.Q2,self.Q3 = self.set_Qmodel() 35 | self.eval_metric = accuracy_score 36 | 37 | self.labels = ['digit'] 38 | 39 | 40 | def set_model(self): 41 | model = VAEtrimodal(self.flags, self.modalities, self.subsets) 42 | model = model.to(self.flags.device) 43 | return model 44 | 45 | def set_modalities(self): 46 | mod1 = Modality('brain', EncoderBrain(self.flags), DecoderBrain(self.flags), 47 | self.flags.class_dim, self.flags.style_m1_dim, 'normal') 48 | mod2 = Modality('image', EncoderImage(self.flags), DecoderImage(self.flags), 49 | self.flags.class_dim, self.flags.style_m2_dim, 'normal') 50 | mod3 = Modality('text', EncoderText(self.flags), DecoderText(self.flags), 51 | self.flags.class_dim, self.flags.style_m3_dim, 'normal') 52 | mods = {mod1.name: mod1, mod2.name: mod2, mod3.name: mod3} 53 | return mods 54 | 55 | def set_dataset(self): 56 | # load data 57 | data_dir_root = self.flags.data_dir_root 58 | sbj = self.flags.sbj 59 | stability_ratio = self.flags.stability_ratio 60 | image_model = self.flags.image_model 61 | text_model = self.flags.text_model 62 | roi = self.flags.roi 63 | brain_dir = os.path.join(data_dir_root, 'brain_feature', roi, sbj) 64 | image_dir_train = os.path.join(data_dir_root, 'visual_feature/ImageNetTraining', image_model + '-PCA', sbj) 
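# NOTE: the directory layout assumed here (hedged summary, matching the os.path.join calls below)
# expects the feature .mat files to have been generated beforehand, e.g. by extract_fea_with_timm.py
# followed by one of the data_prepare_with_aug_*_Wiki.py scripts:
#     <data_dir_root>/brain_feature/<roi>/<sbj>/fmri_{train,test}_data<stability_ratio>.mat
#     <data_dir_root>/visual_feature/ImageNet{Training,Test}/<image_model>-PCA/<sbj>/feat_pca_{train,test}.mat
#     <data_dir_root>/textual_feature/ImageNet{Training,Test}/text/<text_model>/<sbj>/text_feat_{train,test}.mat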
65 | image_dir_test = os.path.join(data_dir_root, 'visual_feature/ImageNetTest', image_model + '-PCA', sbj) 66 | text_dir_train = os.path.join(data_dir_root, 'textual_feature/ImageNetTraining/text', text_model, sbj) 67 | text_dir_test = os.path.join(data_dir_root, 'textual_feature/ImageNetTest/text', text_model, sbj) 68 | 69 | train_brain = sio.loadmat(os.path.join(brain_dir, 'fmri_train_data'+stability_ratio+'.mat'))['data'].astype('double') 70 | train_image = sio.loadmat(os.path.join(image_dir_train, 'feat_pca_train.mat'))['data'].astype('double') 71 | train_text = sio.loadmat(os.path.join(text_dir_train, 'text_feat_train.mat'))['data'].astype('double') 72 | train_label = sio.loadmat(os.path.join(brain_dir, 'fmri_train_data'+stability_ratio+'.mat'))['class_idx'].T.astype('int') 73 | 74 | # test_brain = sio.loadmat(os.path.join(brain_dir, 'fmri_test_data_unique.mat'))['data'].astype('double') 75 | # test_image = sio.loadmat(os.path.join(image_dir_test, 'feat_pca_test_unique.mat'))['data'].astype('double') 76 | # test_text = sio.loadmat(os.path.join(text_dir_test, 'text_feat_test_unique.mat'))['data'].astype('double') 77 | # test_label = sio.loadmat(os.path.join(brain_dir, 'fmri_test_data_unique.mat'))['class_idx'].T.astype('int') 78 | 79 | test_brain = sio.loadmat(os.path.join(brain_dir, 'fmri_test_data'+stability_ratio+'.mat'))['data'].astype('double') 80 | test_image = sio.loadmat(os.path.join(image_dir_test, 'feat_pca_test.mat'))['data'].astype('double') 81 | test_text = sio.loadmat(os.path.join(text_dir_test, 'text_feat_test.mat'))['data'].astype('double') 82 | test_label = sio.loadmat(os.path.join(brain_dir, 'fmri_test_data'+stability_ratio+'.mat'))['class_idx'].T.astype('int') 83 | 84 | if self.flags.aug_type == 'image_text': 85 | image_dir_aug = os.path.join(data_dir_root, 'visual_feature/Aug_1000', image_model + '-PCA', sbj) 86 | text_dir_aug = os.path.join(data_dir_root, 'textual_feature/Aug_1000/text', text_model, sbj) 87 | aug_image = sio.loadmat(os.path.join(image_dir_aug, 'feat_pca_aug.mat'))['data'].astype('double') 88 | aug_text = sio.loadmat(os.path.join(text_dir_aug, 'text_feat_aug.mat'))['data'].astype('double') 89 | aug_image = torch.from_numpy(aug_image) 90 | aug_text = torch.from_numpy(aug_text) 91 | print('aug_image=', aug_image.shape) 92 | print('aug_text=', aug_text.shape) 93 | elif self.flags.aug_type == 'text_only': 94 | text_dir_aug = os.path.join(data_dir_root, 'textual_feature/Aug_1000/text', text_model, sbj) 95 | aug_text = sio.loadmat(os.path.join(text_dir_aug, 'text_feat_aug.mat'))['data'].astype('double') 96 | aug_text = aug_text 97 | aug_text = torch.from_numpy(aug_text) 98 | print('aug_text=', aug_text.shape) 99 | 100 | elif self.flags.aug_type == 'image_only': 101 | image_dir_aug = os.path.join(data_dir_root, 'visual_feature/Aug_1000', image_model + '-PCA', sbj) 102 | aug_image = sio.loadmat(os.path.join(image_dir_aug, 'feat_pca_aug.mat'))['data'].astype('double') 103 | aug_image = torch.from_numpy(aug_image) 104 | print('aug_image=', aug_image.shape) 105 | elif self.flags.aug_type == 'no_aug': 106 | print('no augmentation') 107 | 108 | if self.flags.test_type=='normal': 109 | train_label_stratify = train_label 110 | train_brain, val_brain, train_label, val_label = train_test_split(train_brain, train_label_stratify, test_size=0.2, stratify=train_label_stratify) 111 | train_image, val_image, train_label, val_label = train_test_split(train_image, train_label_stratify, test_size=0.2, stratify=train_label_stratify) 112 | train_text, val_text, train_label, 
val_label = train_test_split(train_text, train_label_stratify, test_size=0.2, stratify=train_label_stratify) 113 | 114 | val_brain = torch.from_numpy(val_brain) 115 | val_image = torch.from_numpy(val_image) 116 | val_text = torch.from_numpy(val_text) 117 | val_label = torch.from_numpy(val_label) 118 | print('val_brain=', val_brain.shape) 119 | print('val_image=', val_image.shape) 120 | print('val_text=', val_text.shape) 121 | 122 | train_brain = torch.from_numpy(train_brain) 123 | test_brain = torch.from_numpy(test_brain) 124 | train_image = torch.from_numpy(train_image) 125 | test_image = torch.from_numpy(test_image) 126 | train_text = torch.from_numpy(train_text) 127 | test_text = torch.from_numpy(test_text) 128 | train_label = torch.from_numpy(train_label) 129 | test_label = torch.from_numpy(test_label) 130 | 131 | 132 | 133 | print('train_brain=', train_brain.shape) 134 | print('train_image=', train_image.shape) 135 | print('train_text=', train_text.shape) 136 | print('test_brain=', test_brain.shape) 137 | print('test_image=', test_image.shape) 138 | print('test_text=', test_text.shape) 139 | 140 | self.m1_dim = train_brain.shape[1] 141 | self.m2_dim = train_image.shape[1] 142 | self.m3_dim = train_text.shape[1] 143 | 144 | train_dataset = torch.utils.data.TensorDataset(train_brain, train_image, train_text, train_label) 145 | test_dataset = torch.utils.data.TensorDataset(test_brain, test_image, test_text,test_label) 146 | 147 | self.dataset_train = train_dataset 148 | self.dataset_test = test_dataset 149 | 150 | if self.flags.test_type == 'normal': 151 | val_dataset = torch.utils.data.TensorDataset(val_brain, val_image, val_text, val_label) 152 | self.dataset_val = val_dataset 153 | 154 | if self.flags.aug_type == 'image_text': 155 | aug_dataset = torch.utils.data.TensorDataset(aug_image, aug_text) 156 | self.dataset_aug = aug_dataset 157 | elif self.flags.aug_type == 'text_only': 158 | aug_dataset = torch.utils.data.TensorDataset(aug_text) 159 | self.dataset_aug = aug_dataset 160 | elif self.flags.aug_type == 'image_only': 161 | aug_image = torch.utils.data.TensorDataset(aug_image) 162 | self.dataset_aug = aug_image 163 | elif self.flags.aug_type == 'no_aug': 164 | print('no augmentation') 165 | 166 | 167 | def set_optimizer(self): 168 | optimizer = optim.Adam( 169 | itertools.chain(self.mm_vae.parameters(),self.Q1.parameters(),self.Q2.parameters(),self.Q3.parameters()), 170 | lr=self.flags.initial_learning_rate, 171 | betas=(self.flags.beta_1, self.flags.beta_2)) 172 | optimizer_mvae = optim.Adam( 173 | list(self.mm_vae.parameters()), 174 | lr=self.flags.initial_learning_rate, 175 | betas=(self.flags.beta_1, self.flags.beta_2)) 176 | optimizer_Qnet = optim.Adam( 177 | itertools.chain(self.Q1.parameters(),self.Q2.parameters(),self.Q3.parameters()), 178 | lr=self.flags.initial_learning_rate, 179 | betas=(self.flags.beta_1, self.flags.beta_2)) 180 | self.optimizer = {'mvae':optimizer_mvae,'Qnet':optimizer_Qnet,'all':optimizer} 181 | scheduler_mvae = optim.lr_scheduler.StepLR(optimizer_mvae, step_size=20, gamma=1.0) 182 | scheduler_Qnet = optim.lr_scheduler.StepLR(optimizer_Qnet, step_size=20, gamma=1.0) 183 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=1.0) 184 | self.scheduler = {'mvae': scheduler_mvae, 'Qnet': scheduler_Qnet, 'all': scheduler} 185 | 186 | 187 | def set_Qmodel(self): 188 | Q1 = QNet(input_dim=self.flags.m1_dim, latent_dim=self.flags.class_dim).cuda() 189 | Q2 = QNet(input_dim=self.flags.m2_dim, latent_dim=self.flags.class_dim).cuda() 190 | Q3 = 
QNet(input_dim=self.flags.m3_dim, latent_dim=self.flags.class_dim).cuda() 191 | return Q1, Q2 ,Q3 192 | 193 | def set_rec_weights(self): 194 | weights = dict() 195 | weights['brain'] = self.flags.beta_m1_rec 196 | weights['image'] = self.flags.beta_m2_rec 197 | weights['text'] = self.flags.beta_m3_rec 198 | return weights 199 | 200 | def set_style_weights(self): 201 | weights = dict() 202 | weights['brain'] = self.flags.beta_m1_style 203 | weights['image'] = self.flags.beta_m2_style 204 | weights['text'] = self.flags.beta_m3_style 205 | return weights 206 | -------------------------------------------------------------------------------- /BraVL_EEG/utils/BaseMMVae.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import os 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import torch.distributions as dist 9 | from divergence_measures.mm_div import calc_alphaJSD_modalities 10 | from divergence_measures.mm_div import calc_group_divergence_moe 11 | from divergence_measures.mm_div import poe 12 | 13 | from utils import utils 14 | 15 | 16 | class BaseMMVae(ABC, nn.Module): 17 | def __init__(self, flags, modalities, subsets): 18 | super(BaseMMVae, self).__init__() 19 | self.num_modalities = len(modalities.keys()); 20 | self.flags = flags; 21 | self.modalities = modalities; 22 | self.subsets = subsets; 23 | self.set_fusion_functions(); 24 | 25 | encoders = nn.ModuleDict(); 26 | decoders = nn.ModuleDict(); 27 | lhoods = dict(); 28 | for m, m_key in enumerate(sorted(modalities.keys())): 29 | encoders[m_key] = modalities[m_key].encoder; 30 | decoders[m_key] = modalities[m_key].decoder; 31 | lhoods[m_key] = modalities[m_key].likelihood; 32 | self.encoders = encoders; 33 | self.decoders = decoders; 34 | self.lhoods = lhoods; 35 | 36 | 37 | def reparameterize(self, mu, logvar): 38 | std = logvar.mul(0.5).exp_() 39 | eps = Variable(std.data.new(std.size()).normal_()) 40 | return eps.mul(std).add_(mu) 41 | 42 | 43 | def set_fusion_functions(self): 44 | weights = utils.reweight_weights(torch.Tensor(self.flags.alpha_modalities)); 45 | self.weights = weights.to(self.flags.device); 46 | if self.flags.modality_moe: 47 | self.modality_fusion = self.moe_fusion; 48 | self.fusion_condition = self.fusion_condition_moe; self.calc_joint_divergence = self.divergence_static_prior; 49 | elif self.flags.modality_jsd: 50 | self.modality_fusion = self.moe_fusion; 51 | self.fusion_condition = self.fusion_condition_moe; 52 | self.calc_joint_divergence = self.divergence_dynamic_prior; 53 | elif self.flags.modality_poe: 54 | self.modality_fusion = self.poe_fusion; 55 | self.fusion_condition = self.fusion_condition_poe; 56 | self.calc_joint_divergence = self.divergence_static_prior; 57 | elif self.flags.joint_elbo: 58 | self.modality_fusion = self.poe_fusion; 59 | self.fusion_condition = self.fusion_condition_joint; 60 | self.calc_joint_divergence = self.divergence_static_prior; 61 | 62 | 63 | def divergence_static_prior(self, mus, logvars, weights=None): 64 | if weights is None: 65 | weights=self.weights; 66 | weights = weights.clone(); 67 | weights = utils.reweight_weights(weights); 68 | div_measures = calc_group_divergence_moe(self.flags, 69 | mus, 70 | logvars, 71 | weights, 72 | normalization=self.flags.batch_size); 73 | divs = dict(); 74 | divs['joint_divergence'] = div_measures[0]; divs['individual_divs'] = div_measures[1]; divs['dyn_prior'] = None; 75 | return divs; 76 | 77 | 78 | def 
divergence_dynamic_prior(self, mus, logvars, weights=None): 79 | if weights is None: 80 | weights = self.weights; 81 | div_measures = calc_alphaJSD_modalities(self.flags, 82 | mus, 83 | logvars, 84 | weights, 85 | normalization=self.flags.batch_size); 86 | divs = dict(); 87 | divs['joint_divergence'] = div_measures[0]; 88 | divs['individual_divs'] = div_measures[1]; 89 | divs['dyn_prior'] = div_measures[2]; 90 | return divs; 91 | 92 | 93 | def moe_fusion(self, mus, logvars, weights=None): 94 | if weights is None: 95 | weights = self.weights; 96 | weights = utils.reweight_weights(weights); 97 | #mus = torch.cat(mus, dim=0); 98 | #logvars = torch.cat(logvars, dim=0); 99 | mu_moe, logvar_moe = utils.mixture_component_selection(self.flags, 100 | mus, 101 | logvars, 102 | weights); 103 | return [mu_moe, logvar_moe]; 104 | 105 | 106 | def poe_fusion(self, mus, logvars, weights=None): 107 | if (self.flags.modality_poe or mus.shape[0] == 108 | len(self.modalities.keys())): 109 | num_samples = mus[0].shape[0]; 110 | mus = torch.cat((mus, torch.zeros(1, num_samples, 111 | self.flags.class_dim).to(self.flags.device)), 112 | dim=0); 113 | logvars = torch.cat((logvars, torch.zeros(1, num_samples, 114 | self.flags.class_dim).to(self.flags.device)), 115 | dim=0); 116 | #mus = torch.cat(mus, dim=0); 117 | #logvars = torch.cat(logvars, dim=0); 118 | mu_poe, logvar_poe = poe(mus, logvars); 119 | return [mu_poe, logvar_poe]; 120 | 121 | 122 | def fusion_condition_moe(self, subset, input_batch=None): 123 | if len(subset) == 1: 124 | return True; 125 | else: 126 | return False; 127 | 128 | 129 | def fusion_condition_poe(self, subset, input_batch=None): 130 | if len(subset) == len(input_batch.keys()): 131 | return True; 132 | else: 133 | return False; 134 | 135 | 136 | def fusion_condition_joint(self, subset, input_batch=None): 137 | return True; 138 | 139 | 140 | def forward(self, input_batch,K=1): 141 | latents = self.inference(input_batch); 142 | results = dict(); 143 | results['latents'] = latents; 144 | results['group_distr'] = latents['joint']; 145 | class_embeddings = self.reparameterize(latents['joint'][0], 146 | latents['joint'][1]); 147 | #### For CUBO #### 148 | qz_x = dist.Normal(latents['joint'][0],latents['joint'][1].mul(0.5).exp_()) 149 | zss = qz_x.rsample(torch.Size([K])) 150 | 151 | div = self.calc_joint_divergence(latents['mus'], 152 | latents['logvars'], 153 | latents['weights']); 154 | for k, key in enumerate(div.keys()): 155 | results[key] = div[key]; 156 | 157 | results_rec = dict(); 158 | px_zs = dict(); 159 | enc_mods = latents['modalities']; 160 | for m, m_key in enumerate(self.modalities.keys()): 161 | if m_key in input_batch.keys(): 162 | m_s_mu, m_s_logvar = enc_mods[m_key + '_style']; 163 | if self.flags.factorized_representation: 164 | m_s_embeddings = self.reparameterize(mu=m_s_mu, logvar=m_s_logvar); 165 | else: 166 | m_s_embeddings = None; 167 | m_rec = self.lhoods[m_key](*self.decoders[m_key](m_s_embeddings, class_embeddings)); 168 | px_z = self.lhoods[m_key](*self.decoders[m_key](m_s_embeddings, zss)); 169 | results_rec[m_key] = m_rec; 170 | px_zs[m_key] = px_z 171 | results['rec'] = results_rec; 172 | results['class_embeddings'] = class_embeddings 173 | results['qz_x'] = qz_x 174 | results['zss'] = zss 175 | results['px_zs'] = px_zs 176 | return results; 177 | 178 | def encode(self, input_batch): 179 | latents = dict(); 180 | for m, m_key in enumerate(self.modalities.keys()): 181 | if m_key in input_batch.keys(): 182 | i_m = input_batch[m_key]; 183 | l = 
self.encoders[m_key](i_m) 184 | latents[m_key + '_style'] = l[:2] 185 | latents[m_key] = l[2:] 186 | else: 187 | latents[m_key + '_style'] = [None, None]; 188 | latents[m_key] = [None, None]; 189 | return latents; 190 | 191 | 192 | def inference(self, input_batch, num_samples=None): 193 | if num_samples is None: 194 | num_samples = self.flags.batch_size; 195 | latents = dict(); 196 | enc_mods = self.encode(input_batch); 197 | latents['modalities'] = enc_mods; 198 | mus = torch.Tensor().to(self.flags.device); 199 | logvars = torch.Tensor().to(self.flags.device); 200 | distr_subsets = dict(); 201 | for k, s_key in enumerate(self.subsets.keys()): 202 | if s_key != '': 203 | mods = self.subsets[s_key]; 204 | mus_subset = torch.Tensor().to(self.flags.device); 205 | logvars_subset = torch.Tensor().to(self.flags.device); 206 | mods_avail = True 207 | for m, mod in enumerate(mods): 208 | if mod.name in input_batch.keys(): 209 | mus_subset = torch.cat((mus_subset, 210 | enc_mods[mod.name][0].unsqueeze(0)), 211 | dim=0); 212 | logvars_subset = torch.cat((logvars_subset, 213 | enc_mods[mod.name][1].unsqueeze(0)), 214 | dim=0); 215 | else: 216 | mods_avail = False; 217 | if mods_avail: 218 | weights_subset = ((1/float(len(mus_subset)))* 219 | torch.ones(len(mus_subset)).to(self.flags.device)); 220 | s_mu, s_logvar = self.modality_fusion(mus_subset, 221 | logvars_subset, 222 | weights_subset); #子集内部POE# 223 | distr_subsets[s_key] = [s_mu, s_logvar]; 224 | if self.fusion_condition(mods, input_batch): 225 | mus = torch.cat((mus, s_mu.unsqueeze(0)), dim=0); 226 | logvars = torch.cat((logvars, s_logvar.unsqueeze(0)), 227 | dim=0); 228 | if self.flags.modality_jsd: 229 | num_samples = mus[0].shape[0] 230 | mus = torch.cat((mus, torch.zeros(1, num_samples, 231 | self.flags.class_dim).to(self.flags.device)), 232 | dim=0); 233 | logvars = torch.cat((logvars, torch.zeros(1, num_samples, 234 | self.flags.class_dim).to(self.flags.device)), 235 | dim=0); 236 | #weights = (1/float(len(mus)))*torch.ones(len(mus)).to(self.flags.device); 237 | weights = (1/float(mus.shape[0]))*torch.ones(mus.shape[0]).to(self.flags.device); 238 | joint_mu, joint_logvar = self.moe_fusion(mus, logvars, weights); #子集之间MOE# 239 | #mus = torch.cat(mus, dim=0); 240 | #logvars = torch.cat(logvars, dim=0); 241 | latents['mus'] = mus; 242 | latents['logvars'] = logvars; 243 | latents['weights'] = weights; 244 | latents['joint'] = [joint_mu, joint_logvar]; 245 | latents['subsets'] = distr_subsets; 246 | return latents; 247 | 248 | 249 | def generate(self, num_samples=None): 250 | if num_samples is None: 251 | num_samples = self.flags.batch_size; 252 | 253 | mu = torch.zeros(num_samples, 254 | self.flags.class_dim).to(self.flags.device); 255 | logvar = torch.zeros(num_samples, 256 | self.flags.class_dim).to(self.flags.device); 257 | z_class = self.reparameterize(mu, logvar); 258 | z_styles = self.get_random_styles(num_samples); 259 | random_latents = {'content': z_class, 'style': z_styles}; 260 | random_samples = self.generate_from_latents(random_latents); 261 | return random_samples; 262 | 263 | 264 | def generate_sufficient_statistics_from_latents(self, latents): 265 | suff_stats = dict(); 266 | content = latents['content'] 267 | for m, m_key in enumerate(self.modalities.keys()): 268 | s = latents['style'][m_key]; 269 | cg = self.lhoods[m_key](*self.decoders[m_key](s, content)); 270 | suff_stats[m_key] = cg; 271 | return suff_stats; 272 | 273 | 274 | def generate_from_latents(self, latents): 275 | suff_stats = 
self.generate_sufficient_statistics_from_latents(latents); 276 | cond_gen = dict(); 277 | for m, m_key in enumerate(latents['style'].keys()): 278 | cond_gen_m = suff_stats[m_key].mean; 279 | cond_gen[m_key] = cond_gen_m; 280 | return cond_gen; 281 | 282 | 283 | def cond_generation(self, latent_distributions, num_samples=None): 284 | if num_samples is None: 285 | num_samples = self.flags.batch_size; 286 | 287 | style_latents = self.get_random_styles(num_samples); 288 | cond_gen_samples = dict(); 289 | for k, key in enumerate(latent_distributions.keys()): 290 | [mu, logvar] = latent_distributions[key]; 291 | content_rep = self.reparameterize(mu=mu, logvar=logvar); 292 | latents = {'content': content_rep, 'style': style_latents} 293 | cond_gen_samples[key] = self.generate_from_latents(latents); 294 | return cond_gen_samples; 295 | 296 | 297 | def get_random_style_dists(self, num_samples): 298 | styles = dict(); 299 | for k, m_key in enumerate(self.modalities.keys()): 300 | mod = self.modalities[m_key]; 301 | s_mu = torch.zeros(num_samples, 302 | mod.style_dim).to(self.flags.device) 303 | s_logvar = torch.zeros(num_samples, 304 | mod.style_dim).to(self.flags.device); 305 | styles[m_key] = [s_mu, s_logvar]; 306 | return styles; 307 | 308 | 309 | def get_random_styles(self, num_samples): 310 | styles = dict(); 311 | for k, m_key in enumerate(self.modalities.keys()): 312 | if self.flags.factorized_representation: 313 | mod = self.modalities[m_key]; 314 | z_style = torch.randn(num_samples, mod.style_dim); 315 | z_style = z_style.to(self.flags.device); 316 | else: 317 | z_style = None; 318 | styles[m_key] = z_style; 319 | return styles; 320 | 321 | 322 | def save_networks(self): 323 | for k, m_key in enumerate(self.modalities.keys()): 324 | torch.save(self.encoders[m_key].state_dict(), 325 | os.path.join(self.flags.dir_checkpoints, 'enc_' + 326 | self.modalities[m_key].name)) 327 | torch.save(self.decoders[m_key].state_dict(), 328 | os.path.join(self.flags.dir_checkpoints, 'dec_' + 329 | self.modalities[m_key].name)) 330 | 331 | 332 | 333 | 334 | -------------------------------------------------------------------------------- /BraVL_fMRI/utils/BaseMMVae.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import os 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.autograd import Variable 8 | import torch.distributions as dist 9 | from divergence_measures.mm_div import calc_alphaJSD_modalities 10 | from divergence_measures.mm_div import calc_group_divergence_moe 11 | from divergence_measures.mm_div import poe 12 | 13 | from utils import utils 14 | 15 | 16 | class BaseMMVae(ABC, nn.Module): 17 | def __init__(self, flags, modalities, subsets): 18 | super(BaseMMVae, self).__init__() 19 | self.num_modalities = len(modalities.keys()); 20 | self.flags = flags; 21 | self.modalities = modalities; 22 | self.subsets = subsets; 23 | self.set_fusion_functions(); 24 | 25 | encoders = nn.ModuleDict(); 26 | decoders = nn.ModuleDict(); 27 | lhoods = dict(); 28 | for m, m_key in enumerate(sorted(modalities.keys())): 29 | encoders[m_key] = modalities[m_key].encoder; 30 | decoders[m_key] = modalities[m_key].decoder; 31 | lhoods[m_key] = modalities[m_key].likelihood; 32 | self.encoders = encoders; 33 | self.decoders = decoders; 34 | self.lhoods = lhoods; 35 | 36 | 37 | def reparameterize(self, mu, logvar): 38 | std = logvar.mul(0.5).exp_() 39 | eps = Variable(std.data.new(std.size()).normal_()) 40 | return eps.mul(std).add_(mu) 
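# reparameterize() applies the standard VAE reparameterization trick:
#     z = mu + exp(0.5 * logvar) * eps,   with eps ~ N(0, I),
# so the sample remains differentiable w.r.t. mu and logvar. A minimal equivalent
# sketch (assuming mu and logvar are torch tensors of equal shape):
#     std = torch.exp(0.5 * logvar)
#     eps = torch.randn_like(std)
#     z = mu + eps * std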
41 | 42 | 43 | def set_fusion_functions(self): 44 | weights = utils.reweight_weights(torch.Tensor(self.flags.alpha_modalities)); 45 | self.weights = weights.to(self.flags.device); 46 | if self.flags.modality_moe: 47 | self.modality_fusion = self.moe_fusion; 48 | self.fusion_condition = self.fusion_condition_moe; self.calc_joint_divergence = self.divergence_static_prior; 49 | elif self.flags.modality_jsd: 50 | self.modality_fusion = self.moe_fusion; 51 | self.fusion_condition = self.fusion_condition_moe; 52 | self.calc_joint_divergence = self.divergence_dynamic_prior; 53 | elif self.flags.modality_poe: 54 | self.modality_fusion = self.poe_fusion; 55 | self.fusion_condition = self.fusion_condition_poe; 56 | self.calc_joint_divergence = self.divergence_static_prior; 57 | elif self.flags.joint_elbo: 58 | self.modality_fusion = self.poe_fusion; 59 | self.fusion_condition = self.fusion_condition_joint; 60 | self.calc_joint_divergence = self.divergence_static_prior; 61 | 62 | 63 | def divergence_static_prior(self, mus, logvars, weights=None): 64 | if weights is None: 65 | weights=self.weights; 66 | weights = weights.clone(); 67 | weights = utils.reweight_weights(weights); 68 | div_measures = calc_group_divergence_moe(self.flags, 69 | mus, 70 | logvars, 71 | weights, 72 | normalization=self.flags.batch_size); 73 | divs = dict(); 74 | divs['joint_divergence'] = div_measures[0]; divs['individual_divs'] = div_measures[1]; divs['dyn_prior'] = None; 75 | return divs; 76 | 77 | 78 | def divergence_dynamic_prior(self, mus, logvars, weights=None): 79 | if weights is None: 80 | weights = self.weights; 81 | div_measures = calc_alphaJSD_modalities(self.flags, 82 | mus, 83 | logvars, 84 | weights, 85 | normalization=self.flags.batch_size); 86 | divs = dict(); 87 | divs['joint_divergence'] = div_measures[0]; 88 | divs['individual_divs'] = div_measures[1]; 89 | divs['dyn_prior'] = div_measures[2]; 90 | return divs; 91 | 92 | 93 | def moe_fusion(self, mus, logvars, weights=None): 94 | if weights is None: 95 | weights = self.weights; 96 | weights = utils.reweight_weights(weights); 97 | #mus = torch.cat(mus, dim=0); 98 | #logvars = torch.cat(logvars, dim=0); 99 | mu_moe, logvar_moe = utils.mixture_component_selection(self.flags, 100 | mus, 101 | logvars, 102 | weights); 103 | return [mu_moe, logvar_moe]; 104 | 105 | 106 | def poe_fusion(self, mus, logvars, weights=None): 107 | if (self.flags.modality_poe or mus.shape[0] == 108 | len(self.modalities.keys())): 109 | num_samples = mus[0].shape[0]; 110 | mus = torch.cat((mus, torch.zeros(1, num_samples, 111 | self.flags.class_dim).to(self.flags.device)), 112 | dim=0); 113 | logvars = torch.cat((logvars, torch.zeros(1, num_samples, 114 | self.flags.class_dim).to(self.flags.device)), 115 | dim=0); 116 | #mus = torch.cat(mus, dim=0); 117 | #logvars = torch.cat(logvars, dim=0); 118 | mu_poe, logvar_poe = poe(mus, logvars); 119 | return [mu_poe, logvar_poe]; 120 | 121 | 122 | def fusion_condition_moe(self, subset, input_batch=None): 123 | if len(subset) == 1: 124 | return True; 125 | else: 126 | return False; 127 | 128 | 129 | def fusion_condition_poe(self, subset, input_batch=None): 130 | if len(subset) == len(input_batch.keys()): 131 | return True; 132 | else: 133 | return False; 134 | 135 | 136 | def fusion_condition_joint(self, subset, input_batch=None): 137 | return True; 138 | 139 | 140 | def forward(self, input_batch,K=1): 141 | latents = self.inference(input_batch); 142 | results = dict(); 143 | results['latents'] = latents; 144 | results['group_distr'] = 
latents['joint']; 145 | class_embeddings = self.reparameterize(latents['joint'][0], 146 | latents['joint'][1]); 147 | #### For CUBO #### 148 | qz_x = dist.Normal(latents['joint'][0],latents['joint'][1].mul(0.5).exp_()) 149 | zss = qz_x.rsample(torch.Size([K])) 150 | 151 | div = self.calc_joint_divergence(latents['mus'], 152 | latents['logvars'], 153 | latents['weights']); 154 | for k, key in enumerate(div.keys()): 155 | results[key] = div[key]; 156 | 157 | results_rec = dict(); 158 | px_zs = dict(); 159 | enc_mods = latents['modalities']; 160 | for m, m_key in enumerate(self.modalities.keys()): 161 | if m_key in input_batch.keys(): 162 | m_s_mu, m_s_logvar = enc_mods[m_key + '_style']; 163 | if self.flags.factorized_representation: 164 | m_s_embeddings = self.reparameterize(mu=m_s_mu, logvar=m_s_logvar); 165 | else: 166 | m_s_embeddings = None; 167 | m_rec = self.lhoods[m_key](*self.decoders[m_key](m_s_embeddings, class_embeddings)); 168 | px_z = self.lhoods[m_key](*self.decoders[m_key](m_s_embeddings, zss)); 169 | results_rec[m_key] = m_rec; 170 | px_zs[m_key] = px_z 171 | results['rec'] = results_rec; 172 | results['class_embeddings'] = class_embeddings 173 | results['qz_x'] = qz_x 174 | results['zss'] = zss 175 | results['px_zs'] = px_zs 176 | return results; 177 | 178 | def encode(self, input_batch): 179 | latents = dict(); 180 | for m, m_key in enumerate(self.modalities.keys()): 181 | if m_key in input_batch.keys(): 182 | i_m = input_batch[m_key]; 183 | l = self.encoders[m_key](i_m) 184 | latents[m_key + '_style'] = l[:2] 185 | latents[m_key] = l[2:] 186 | else: 187 | latents[m_key + '_style'] = [None, None]; 188 | latents[m_key] = [None, None]; 189 | return latents; 190 | 191 | 192 | def inference(self, input_batch, num_samples=None): 193 | if num_samples is None: 194 | num_samples = self.flags.batch_size; 195 | latents = dict(); 196 | enc_mods = self.encode(input_batch); 197 | latents['modalities'] = enc_mods; 198 | mus = torch.Tensor().to(self.flags.device); 199 | logvars = torch.Tensor().to(self.flags.device); 200 | distr_subsets = dict(); 201 | for k, s_key in enumerate(self.subsets.keys()): 202 | if s_key != '': 203 | mods = self.subsets[s_key]; 204 | mus_subset = torch.Tensor().to(self.flags.device); 205 | logvars_subset = torch.Tensor().to(self.flags.device); 206 | mods_avail = True 207 | for m, mod in enumerate(mods): 208 | if mod.name in input_batch.keys(): 209 | mus_subset = torch.cat((mus_subset, 210 | enc_mods[mod.name][0].unsqueeze(0)), 211 | dim=0); 212 | logvars_subset = torch.cat((logvars_subset, 213 | enc_mods[mod.name][1].unsqueeze(0)), 214 | dim=0); 215 | else: 216 | mods_avail = False; 217 | if mods_avail: 218 | weights_subset = ((1/float(len(mus_subset)))* 219 | torch.ones(len(mus_subset)).to(self.flags.device)); 220 | s_mu, s_logvar = self.modality_fusion(mus_subset, 221 | logvars_subset, 222 | weights_subset); #子集内部POE# 223 | distr_subsets[s_key] = [s_mu, s_logvar]; 224 | if self.fusion_condition(mods, input_batch): 225 | mus = torch.cat((mus, s_mu.unsqueeze(0)), dim=0); 226 | logvars = torch.cat((logvars, s_logvar.unsqueeze(0)), 227 | dim=0); 228 | if self.flags.modality_jsd: 229 | num_samples = mus[0].shape[0] 230 | mus = torch.cat((mus, torch.zeros(1, num_samples, 231 | self.flags.class_dim).to(self.flags.device)), 232 | dim=0); 233 | logvars = torch.cat((logvars, torch.zeros(1, num_samples, 234 | self.flags.class_dim).to(self.flags.device)), 235 | dim=0); 236 | #weights = (1/float(len(mus)))*torch.ones(len(mus)).to(self.flags.device); 237 | weights = 
(1/float(mus.shape[0]))*torch.ones(mus.shape[0]).to(self.flags.device); 238 | joint_mu, joint_logvar = self.moe_fusion(mus, logvars, weights); #子集之间MOE# 239 | #mus = torch.cat(mus, dim=0); 240 | #logvars = torch.cat(logvars, dim=0); 241 | latents['mus'] = mus; 242 | latents['logvars'] = logvars; 243 | latents['weights'] = weights; 244 | latents['joint'] = [joint_mu, joint_logvar]; 245 | latents['subsets'] = distr_subsets; 246 | return latents; 247 | 248 | 249 | def generate(self, num_samples=None): 250 | if num_samples is None: 251 | num_samples = self.flags.batch_size; 252 | 253 | mu = torch.zeros(num_samples, 254 | self.flags.class_dim).to(self.flags.device); 255 | logvar = torch.zeros(num_samples, 256 | self.flags.class_dim).to(self.flags.device); 257 | z_class = self.reparameterize(mu, logvar); 258 | z_styles = self.get_random_styles(num_samples); 259 | random_latents = {'content': z_class, 'style': z_styles}; 260 | random_samples = self.generate_from_latents(random_latents); 261 | return random_samples; 262 | 263 | 264 | def generate_sufficient_statistics_from_latents(self, latents): 265 | suff_stats = dict(); 266 | content = latents['content'] 267 | for m, m_key in enumerate(self.modalities.keys()): 268 | s = latents['style'][m_key]; 269 | cg = self.lhoods[m_key](*self.decoders[m_key](s, content)); 270 | suff_stats[m_key] = cg; 271 | return suff_stats; 272 | 273 | 274 | def generate_from_latents(self, latents): 275 | suff_stats = self.generate_sufficient_statistics_from_latents(latents); 276 | cond_gen = dict(); 277 | for m, m_key in enumerate(latents['style'].keys()): 278 | cond_gen_m = suff_stats[m_key].mean; 279 | cond_gen[m_key] = cond_gen_m; 280 | return cond_gen; 281 | 282 | 283 | def cond_generation(self, latent_distributions, num_samples=None): 284 | if num_samples is None: 285 | num_samples = self.flags.batch_size; 286 | 287 | style_latents = self.get_random_styles(num_samples); 288 | cond_gen_samples = dict(); 289 | for k, key in enumerate(latent_distributions.keys()): 290 | [mu, logvar] = latent_distributions[key]; 291 | content_rep = self.reparameterize(mu=mu, logvar=logvar); 292 | latents = {'content': content_rep, 'style': style_latents} 293 | cond_gen_samples[key] = self.generate_from_latents(latents); 294 | return cond_gen_samples; 295 | 296 | 297 | def get_random_style_dists(self, num_samples): 298 | styles = dict(); 299 | for k, m_key in enumerate(self.modalities.keys()): 300 | mod = self.modalities[m_key]; 301 | s_mu = torch.zeros(num_samples, 302 | mod.style_dim).to(self.flags.device) 303 | s_logvar = torch.zeros(num_samples, 304 | mod.style_dim).to(self.flags.device); 305 | styles[m_key] = [s_mu, s_logvar]; 306 | return styles; 307 | 308 | 309 | def get_random_styles(self, num_samples): 310 | styles = dict(); 311 | for k, m_key in enumerate(self.modalities.keys()): 312 | if self.flags.factorized_representation: 313 | mod = self.modalities[m_key]; 314 | z_style = torch.randn(num_samples, mod.style_dim); 315 | z_style = z_style.to(self.flags.device); 316 | else: 317 | z_style = None; 318 | styles[m_key] = z_style; 319 | return styles; 320 | 321 | 322 | def save_networks(self): 323 | for k, m_key in enumerate(self.modalities.keys()): 324 | torch.save(self.encoders[m_key].state_dict(), 325 | os.path.join(self.flags.dir_checkpoints, 'enc_' + 326 | self.modalities[m_key].name)) 327 | torch.save(self.decoders[m_key].state_dict(), 328 | os.path.join(self.flags.dir_checkpoints, 'dec_' + 329 | self.modalities[m_key].name)) 330 | 331 | 332 | 333 | 334 | 
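# inference() above implements the MoPoE-style fusion shared by both BaseMMVae copies:
# each non-empty modality subset is first fused with modality_fusion() (a product of
# experts, PoE, under the joint_elbo / modality_poe flags), and the resulting subset
# posteriors are then mixed across subsets with moe_fusion() (a mixture of experts, MoE)
# to obtain the joint posterior. For Gaussian experts, the PoE has precision equal to
# the sum of the experts' precisions and a precision-weighted mean. A minimal sketch of
# that computation (the actual poe() lives in divergence_measures/mm_div.py and may
# differ in details such as the extra unit-Gaussian prior expert appended in poe_fusion()):
#     T = torch.exp(-logvars)                     # per-expert precisions, shape (M, B, D)
#     joint_var = 1.0 / T.sum(dim=0)
#     joint_mu = joint_var * (mus * T).sum(dim=0)
#     joint_logvar = torch.log(joint_var)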
-------------------------------------------------------------------------------- /BraVL_fMRI/data_prepare_with_aug_GOD_Wiki.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from itertools import product 3 | import os 4 | import pickle 5 | import bdpy 6 | from bdpy.dataform import Features 7 | from bdpy.util import dump_info, makedir_ifnot 8 | import numpy as np 9 | from sklearn.decomposition import PCA 10 | from scipy import io 11 | 12 | # Settings ################################################################### 13 | seed = 42 14 | TINY = 1e-8 15 | # Python RNG 16 | np.random.seed(seed) 17 | 18 | subject_set=['subject1','subject2','subject3','subject4','subject5'] 19 | for subject in subject_set: 20 | if subject == 'subject1': 21 | subjects_list = { 22 | 'sub-01': 'Subject1.h5', 23 | } 24 | 25 | subjects_list_test = { 26 | 'sub-01': 'Subject1.h5', 27 | } 28 | elif subject == 'subject2': 29 | subjects_list = { 30 | 'sub-02': 'Subject2.h5', 31 | } 32 | 33 | subjects_list_test = { 34 | 'sub-02': 'Subject2.h5', 35 | } 36 | elif subject == 'subject3': 37 | subjects_list = { 38 | 'sub-03': 'Subject3.h5', 39 | } 40 | 41 | subjects_list_test = { 42 | 'sub-03': 'Subject3.h5', 43 | } 44 | elif subject == 'subject4': 45 | subjects_list = { 46 | 'sub-04': 'Subject4.h5', 47 | } 48 | 49 | subjects_list_test = { 50 | 'sub-04': 'Subject4.h5', 51 | } 52 | elif subject == 'subject5': 53 | subjects_list = { 54 | 'sub-05': 'Subject5.h5', 55 | } 56 | 57 | subjects_list_test = { 58 | 'sub-05': 'Subject5.h5', 59 | } 60 | 61 | text_embedding_list = [ 62 | 'GPTNeo', 63 | 'ALBERT', 64 | # 'GPTNeo_phrases', 65 | # 'ALBERT_phrases' 66 | ] 67 | 68 | 69 | rois_list = { 70 | # 'VC': 'ROI_VC = 1', 71 | # 'LVC': 'ROI_LVC = 1', 72 | # 'HVC': 'ROI_HVC = 1', 73 | 'V1': 'ROI_V1 = 1', 74 | 'V2': 'ROI_V2 = 1', 75 | 'V3': 'ROI_V3 = 1', 76 | 'V4': 'ROI_V4 = 1', 77 | 'LOC': 'ROI_LOC = 1', 78 | 'FFA': 'ROI_FFA = 1', 79 | 'PPA': 'ROI_PPA = 1', 80 | } 81 | 82 | network = 'pytorch/repvgg_b3g4' 83 | features_list = [ # 'Conv_0', 84 | # 'Conv_1', 85 | 'Conv_2', 86 | 'Conv_3', 87 | 'Conv_4', 88 | 'linear', 89 | 'final'] 90 | 91 | features_list = features_list[::-1] # Start training from deep layers 92 | 93 | # Brain data 94 | brain_dir = './data/GenericObjectDecoding-v2' 95 | # Image features 96 | timm_extracted_visual_features = './data/GOD-Wiki/visual_feature/ImageNetTraining/' + network 97 | timm_extracted_visual_features_test = './data/GOD-Wiki/visual_feature/ImageNetTest/' + network 98 | timm_extracted_visual_features_aug = './data/GOD-Wiki/visual_feature/Aug_1000/' + network 99 | print('DNN feature') 100 | print(timm_extracted_visual_features) 101 | # Text features 102 | model_extracted_textual_features = './data/Wiki_articles_features' 103 | 104 | # Results directory 105 | results_dir_root = './data/GOD-Wiki/visual_feature/ImageNetTraining/' + network + '-PCA' 106 | results_dir_root_test = './data/GOD-Wiki/visual_feature/ImageNetTest/' + network + '-PCA' 107 | results_dir_root_aug = './data/GOD-Wiki/visual_feature/Aug_1000/' + network + '-PCA' 108 | results_fmri_root = './data/GOD-Wiki/brain_feature/LVC_HVC_IT' 109 | results_text_root = './data/GOD-Wiki/textual_feature/ImageNetTraining/text' 110 | results_text_root_test = './data/GOD-Wiki/textual_feature/ImageNetTest/text' 111 | results_text_root_aug = './data/GOD-Wiki/textual_feature/Aug_1000/text' 112 | 113 | # Main ####################################################################### 114 | 
analysis_basename = os.path.splitext(os.path.basename(__file__))[0] 115 | # Print info ----------------------------------------------------------------- 116 | print('Subjects: %s' % subjects_list.keys()) 117 | print('ROIs: %s' % rois_list.keys()) 118 | print('Target features: %s' % network.split('/')[-1]) 119 | print('Layers: %s' % features_list) 120 | print('') 121 | 122 | # Load data ------------------------------------------------------------------ 123 | print('----------------------------------------') 124 | print('Loading data') 125 | 126 | data_brain = {sbj: bdpy.BData(os.path.join(brain_dir, dat_file)) 127 | for sbj, dat_file in subjects_list.items()} 128 | data_features = Features(os.path.join(timm_extracted_visual_features, network)) 129 | 130 | data_brain_test = {sbj: bdpy.BData(os.path.join(brain_dir, dat_file)) 131 | for sbj, dat_file in subjects_list_test.items()} 132 | data_features_test = Features(os.path.join(timm_extracted_visual_features_test, network)) 133 | data_features_aug = Features(os.path.join(timm_extracted_visual_features_aug, network)) 134 | 135 | # Initialize directories ----------------------------------------------------- 136 | makedir_ifnot(results_dir_root) 137 | makedir_ifnot(results_dir_root_test) 138 | makedir_ifnot(results_dir_root_aug) 139 | makedir_ifnot(results_text_root) 140 | makedir_ifnot(results_text_root_test) 141 | makedir_ifnot(results_text_root_aug) 142 | 143 | # Save runtime information --------------------------------------------------- 144 | info_dir = results_dir_root 145 | runtime_params = { 146 | 'fMRI data': [os.path.abspath(os.path.join(brain_dir, v)) for v in subjects_list.values()], 147 | 'ROIs': rois_list.keys(), 148 | 'target DNN': network.split('/')[-1], 149 | 'target DNN features': os.path.abspath(timm_extracted_visual_features), 150 | 'target DNN layers': features_list, 151 | } 152 | dump_info(info_dir, script=__file__, parameters=runtime_params) 153 | 154 | 155 | ####################################### 156 | # Original 157 | ####################################### 158 | first = 1 159 | for sbj, roi in product(subjects_list ,rois_list): 160 | print('--------------------') 161 | print('VC ROI: %s' % roi) 162 | # Brain data 163 | # data_brain[sbj].show_metadata() 164 | x= data_brain[sbj].select(rois_list[roi]) 165 | x_labels = data_brain[sbj].select('image_index').flatten() # Label (image index) 166 | print('roi_shape=', x.shape) 167 | 168 | x_test= data_brain_test[sbj].select(rois_list[roi]) # Brain data 169 | x_labels_test = data_brain_test[sbj].select('image_index').flatten() # Label (image index) 170 | 171 | if first: 172 | best_roi_sel = x 173 | best_roi_sel_test = x_test 174 | first = 0 175 | else: 176 | best_roi_sel = np.append(best_roi_sel, x, axis=1) 177 | best_roi_sel_test = np.append(best_roi_sel_test, x_test, axis=1) 178 | 179 | 180 | ####################################### 181 | # Save brain and image feature data 182 | ####################################### 183 | # Analysis loop -------------------------------------------------------------- 184 | print('----------------------------------------') 185 | print('Analysis loop') 186 | first = 1 187 | for feat, sbj in product(features_list, subjects_list): 188 | print('--------------------') 189 | print('Feature: %s' % feat) 190 | print('Subject: %s' % sbj) 191 | 192 | results_dir_alllayer_pca = os.path.join(results_dir_root, sbj) 193 | results_dir_alllayer_pca_test = os.path.join(results_dir_root_test, sbj) 194 | results_dir_alllayer_pca_aug = 
os.path.join(results_dir_root_aug, sbj) 195 | 196 | results_fmri_dir = os.path.join(results_fmri_root, sbj) 197 | # Preparing data 198 | print('Preparing data') 199 | 200 | # Brain data 201 | x = best_roi_sel[0:1200] # Brain data 202 | x_labels = data_brain[sbj].select('image_index').flatten() # Label (image index) 203 | x_labels = x_labels[0:1200] # Label (image index) 204 | x_class = data_brain[sbj].select('Label') # Label (class index) 205 | WordNetID = data_brain[sbj].select('stimulus_id') # Label (class index) 206 | WordNetID = WordNetID[0:1200,0] 207 | class_idx = x_class[0:1200, 1] 208 | 209 | x_test = best_roi_sel_test[1200:2950] # Brain data 210 | x_labels_test = data_brain_test[sbj].select('image_index').flatten() # Label (image index) 211 | x_labels_test = x_labels_test[1200:2950] # Label (image index) 212 | x_class_test = data_brain_test[sbj].select('Label') # Label (class index) 213 | WordNetID_test = data_brain_test[sbj].select('stimulus_id') # Label (class index) 214 | WordNetID_test = WordNetID_test[1200:2950,0] 215 | class_idx_test = x_class_test[1200:2950, 1] 216 | 217 | # Averaging test brain data 218 | x_labels_test_unique, indices = np.unique(x_labels_test, return_index=True) 219 | x_test_unique = np.vstack([np.mean(x_test[(np.array(x_labels_test) == lb).flatten(), :], axis=0) for lb in x_labels_test_unique]) 220 | WordNetID_test_unique = WordNetID_test[indices] 221 | class_idx_test_unique = class_idx_test[indices] 222 | 223 | # Target features and image labels (file names) 224 | y = data_features.get_features(feat) # Target DNN features 225 | y_labels = data_features.index # Label (image index) 226 | y = np.reshape(y,(y.shape[0],-1)) 227 | 228 | y_test = data_features_test.get_features(feat) # Target DNN features 229 | y_labels_test = data_features_test.index # Label (image index) 230 | y_test = np.reshape(y_test,(y_test.shape[0],-1)) 231 | 232 | y_aug = data_features_aug.get_features(feat) # Target DNN features 233 | y_labels_aug_temp = data_features_aug.labels # Label (image index) 234 | y_labels_aug = [] 235 | for it in y_labels_aug_temp: 236 | y_labels_aug.append(int(it.split('_')[0][1:])) 237 | y_labels_aug = np.array(y_labels_aug) 238 | y_aug = np.reshape(y_aug,(y_aug.shape[0],-1)) 239 | 240 | # Calculate normalization parameters 241 | # Normalize X (fMRI data) 242 | x_mean = np.mean(x, axis=0)[np.newaxis, :] # np.newaxis was added to match Matlab outputs 243 | x_norm = np.std(x, axis=0, ddof=1)[np.newaxis, :] 244 | 245 | # Normalize Y (DNN features) 246 | y_mean = np.mean(y, axis=0)[np.newaxis, :] 247 | y_norm = np.std(y, axis=0, ddof=1)[np.newaxis, :] 248 | 249 | # Y index to sort Y by X (matching samples) 250 | y_index = np.array([np.where(np.array(y_labels) == xl) for xl in x_labels]).flatten() 251 | y_index_test = np.array([np.where(np.array(y_labels_test) == xl) for xl in x_labels_test]).flatten() 252 | y_index_test_unique = np.array([np.where(np.array(y_labels_test) == xl) for xl in x_labels_test_unique]).flatten() 253 | 254 | # X preprocessing 255 | print('Normalizing X') 256 | x = (x - x_mean) / (x_norm+TINY) 257 | x[np.isinf(x)] = 0 258 | 259 | x_test = (x_test - x_mean) / (x_norm+TINY) 260 | x_test[np.isinf(x_test)] = 0 261 | x_test_unique = (x_test_unique - x_mean) / (x_norm+TINY) 262 | x_test_unique[np.isinf(x_test_unique)] = 0 263 | 264 | print('Doing PCA') 265 | ipca = PCA(n_components=0.99, random_state=seed) 266 | ipca.fit(x) 267 | x = ipca.transform(x) 268 | x_test = ipca.transform(x_test) 269 | x_test_unique = ipca.transform(x_test_unique) 270 
| print(x.shape) 271 | 272 | # Y preprocessing 273 | print('Normalizing Y') 274 | y = (y - y_mean) / (y_norm+TINY) 275 | y[np.isinf(y)] = 0 276 | y_test = (y_test - y_mean) / (y_norm+TINY) 277 | y_test[np.isinf(y_test)] = 0 278 | y_aug = (y_aug - y_mean) / (y_norm+TINY) 279 | y_aug[np.isinf(y_aug)] = 0 280 | 281 | print('Doing PCA') 282 | ipca = PCA(n_components=0.99, random_state=seed) 283 | ipca.fit(y) 284 | # ipca.fit(y_aug) 285 | y = ipca.transform(y) 286 | y_test = ipca.transform(y_test) 287 | y_aug = ipca.transform(y_aug) 288 | print(y.shape) 289 | 290 | print('Sorting Y') 291 | y = y[y_index, :] 292 | y_test = y_test[y_index_test, :] 293 | y_test_unique = y_test[y_index_test_unique, :] 294 | 295 | if first: 296 | feat_pca_train = y 297 | feat_pca_test = y_test 298 | feat_pca_aug = y_aug 299 | feat_pca_test_unique = y_test_unique 300 | first = 0 301 | else: 302 | feat_pca_train = np.concatenate((feat_pca_train, y), axis=1) 303 | feat_pca_test = np.concatenate((feat_pca_test, y_test), axis=1) 304 | feat_pca_aug = np.concatenate((feat_pca_aug, y_aug), axis=1) 305 | feat_pca_test_unique = np.concatenate((feat_pca_test_unique, y_test_unique), axis=1) 306 | print(feat_pca_test_unique.shape) 307 | 308 | 309 | makedir_ifnot(results_dir_alllayer_pca) 310 | makedir_ifnot(results_dir_alllayer_pca_test) 311 | makedir_ifnot(results_dir_alllayer_pca_aug) 312 | results_dir_alllayer_pca_path = os.path.join(results_dir_alllayer_pca, "feat_pca_train.mat") 313 | io.savemat(results_dir_alllayer_pca_path, {"data":feat_pca_train}) 314 | results_dir_alllayer_pca_test_path = os.path.join(results_dir_alllayer_pca_test, "feat_pca_test.mat") 315 | io.savemat(results_dir_alllayer_pca_test_path, {"data":feat_pca_test}) 316 | results_dir_alllayer_pca_aug_path = os.path.join(results_dir_alllayer_pca_aug, "feat_pca_aug.mat") 317 | io.savemat(results_dir_alllayer_pca_aug_path, {"data":feat_pca_aug}) 318 | results_dir_alllayer_pca_test_path = os.path.join(results_dir_alllayer_pca_test, "feat_pca_test_unique.mat") 319 | io.savemat(results_dir_alllayer_pca_test_path, {"data":feat_pca_test_unique}) 320 | 321 | 322 | makedir_ifnot(results_fmri_dir) 323 | results_fmri_dir_path = os.path.join(results_fmri_dir, "fmri_train_data.mat") 324 | io.savemat(results_fmri_dir_path, {"data":x, "image_idx":x_labels, "WordNetID":WordNetID, "class_idx":class_idx}) 325 | results_fmri_dir_path = os.path.join(results_fmri_dir, "fmri_test_data.mat") 326 | io.savemat(results_fmri_dir_path, {"data":x_test, "image_idx":x_labels_test, "WordNetID":WordNetID_test, "class_idx":class_idx_test}) 327 | results_fmri_dir_path = os.path.join(results_fmri_dir, "fmri_test_data_unique.mat") 328 | io.savemat(results_fmri_dir_path, {"data":x_test_unique, "image_idx":x_labels_test_unique, "WordNetID":WordNetID_test_unique, "class_idx":class_idx_test_unique}) 329 | 330 | ####################################### 331 | # Save text feature data 332 | ####################################### 333 | 334 | for feat, sbj in product(text_embedding_list, subjects_list): 335 | print('--------------------') 336 | print('Feature: %s' % feat) 337 | print('Subject: %s' % sbj) 338 | 339 | results_dir_text_fea = os.path.join(results_text_root, feat, sbj) 340 | results_dir_text_fea_test = os.path.join(results_text_root_test, feat, sbj) 341 | results_dir_text_fea_aug = os.path.join(results_text_root_aug, feat, sbj) 342 | # Preparing data 343 | print('Preparing data') 344 | 345 | # Brain data 346 | x_class = data_brain[sbj].select('Label')[0:1200] # Label (class index) 347 | 
WordNetID = data_brain[sbj].select('stimulus_id') # Label (class index) 348 | WordNetID = WordNetID[0:1200,0] 349 | class_idx = x_class[0:1200, 1] 350 | 351 | x_labels_test = data_brain_test[sbj].select('image_index').flatten() # Label (image index) 352 | x_labels_test = x_labels_test[1200:2950] # Label (image index) 353 | x_class_test = data_brain_test[sbj].select('Label') # Label (class index) 354 | WordNetID_test = data_brain_test[sbj].select('stimulus_id') # Label (class index) 355 | WordNetID_test = WordNetID_test[1200:2950,0] 356 | class_idx_test = x_class_test[1200:2950, 1] 357 | 358 | # Averaging test brain data 359 | x_labels_test_unique, indices = np.unique(x_labels_test, return_index=True) 360 | WordNetID_test_unique = WordNetID_test[indices] 361 | class_idx_test_unique = class_idx_test[indices] 362 | 363 | # Target text features and wnid 364 | name = 'ImageNet_class200_' + feat + '.pkl' 365 | full = os.path.join(model_extracted_textual_features, name) 366 | dictionary = pickle.load(open(full, 'rb')) 367 | 368 | firstfeat = 1 369 | firstlabel = 1 370 | for key, value in dictionary.items(): 371 | for k, v in value.items(): 372 | # print(k, v) 373 | if k == 'wnid': 374 | # print(v) 375 | v = int(v[1:]) 376 | if firstlabel: 377 | text_label = np.array([v]) 378 | firstlabel = 0 379 | else: 380 | text_label = np.concatenate((text_label, np.array([v])), axis=0) 381 | 382 | elif k == 'feats': 383 | v = np.expand_dims(v, axis=0) 384 | if firstfeat: 385 | text_feat = v 386 | firstfeat = 0 387 | else: 388 | text_feat = np.concatenate((text_feat, v), axis=0) 389 | 390 | # Target text features and wnid 391 | name = 'ImageNet_trainval_classes_' + feat + '.pkl' 392 | full = os.path.join(model_extracted_textual_features, name) 393 | dictionary = pickle.load(open(full, 'rb')) 394 | 395 | firstfeat = 1 396 | firstlabel = 1 397 | for key, value in dictionary.items(): 398 | for k, v in value.items(): 399 | # print(k, v) 400 | if k == 'wnid': 401 | # print(v) 402 | v = int(v[1:]) 403 | if firstlabel: 404 | text_label_aug = np.array([v]) 405 | firstlabel = 0 406 | else: 407 | text_label_aug = np.concatenate((text_label_aug, np.array([v])), axis=0) 408 | 409 | elif k == 'feats': 410 | v = np.expand_dims(v, axis=0) 411 | if firstfeat: 412 | text_feat_aug = v 413 | firstfeat = 0 414 | else: 415 | text_feat_aug = np.concatenate((text_feat_aug, v), axis=0) 416 | 417 | 418 | # t index to sort t by X (matching samples) 419 | t_index = np.array([np.where(np.array(text_label) == xl) for xl in WordNetID.astype(int)]).flatten() 420 | t_index_test = np.array([np.where(np.array(text_label) == xl) for xl in WordNetID_test.astype(int)]).flatten() 421 | t_index_test_unique = np.array([np.where(np.array(text_label) == xl) for xl in WordNetID_test_unique.astype(int)]).flatten() 422 | t_index_aug = np.array([np.where(np.array(text_label_aug) == xl) for xl in y_labels_aug]).flatten() 423 | 424 | print('Sorting text') 425 | t = text_feat[t_index, :] 426 | t_test = text_feat[t_index_test, :] 427 | t_aug = text_feat_aug[t_index_aug, :] 428 | t_test_unique = text_feat[t_index_test_unique, :] 429 | 430 | print(t.shape) 431 | print(t_test.shape) 432 | print(t_aug.shape) 433 | print(t_test_unique.shape) 434 | 435 | makedir_ifnot(results_dir_text_fea) 436 | makedir_ifnot(results_dir_text_fea_test) 437 | makedir_ifnot(results_dir_text_fea_aug) 438 | 439 | results_dir_text_fea_path = os.path.join(results_dir_text_fea, "text_feat_train.mat") 440 | io.savemat(results_dir_text_fea_path, {"data": t}) 441 | 442 | 
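# Note added for clarity (not in the original script): besides the
# text_feat_train.mat just written (t, one row per training trial), the
# remaining savemat calls in this loop write text_feat_test.mat (t_test, one
# row per test trial), text_feat_aug.mat (t_aug, features for the extra
# ImageNet classes used as augmentation data), and text_feat_test_unique.mat
# (t_test_unique, one row per unique test image), all under the per-model,
# per-subject results directories created above.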
results_dir_text_fea_test_path = os.path.join(results_dir_text_fea_test, "text_feat_test.mat") 443 | io.savemat(results_dir_text_fea_test_path, {"data": t_test}) 444 | 445 | results_dir_text_fea_aug_path = os.path.join(results_dir_text_fea_aug, "text_feat_aug.mat") 446 | io.savemat(results_dir_text_fea_aug_path, {"data": t_aug}) 447 | 448 | results_dir_text_fea_test_path = os.path.join(results_dir_text_fea_test, "text_feat_test_unique.mat") 449 | io.savemat(results_dir_text_fea_test_path, {"data": t_test_unique}) 450 | 451 | print('%s finished.' % analysis_basename) -------------------------------------------------------------------------------- /BraVL_fMRI/data_prepare_with_aug_DIR_Wiki.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | from itertools import product 3 | import os 4 | import pickle 5 | import bdpy 6 | from bdpy.dataform import Features 7 | from bdpy.util import dump_info, makedir_ifnot 8 | import numpy as np 9 | from stability_selection import stability_selection 10 | from sklearn.decomposition import PCA 11 | from scipy import io 12 | 13 | # Settings ################################################################### 14 | seed = 42 15 | TINY = 1e-8 16 | # Python RNG 17 | np.random.seed(seed) 18 | 19 | subject_set=['subject1','subject2','subject3'] 20 | for subject in subject_set: 21 | if subject == 'subject1': 22 | subjects_list = { 23 | 'sub-01': 'sub-01_perceptionNaturalImageTraining_VC_v2.h5', 24 | } 25 | 26 | subjects_list_test = { 27 | 'sub-01': 'sub-01_perceptionNaturalImageTest_VC_v2.h5', 28 | } 29 | elif subject == 'subject2': 30 | subjects_list = { 31 | 'sub-02': 'sub-02_perceptionNaturalImageTraining_VC_v2.h5', 32 | } 33 | 34 | subjects_list_test = { 35 | 'sub-02': 'sub-02_perceptionNaturalImageTest_VC_v2.h5', 36 | } 37 | elif subject == 'subject3': 38 | subjects_list = { 39 | 'sub-03': 'sub-03_perceptionNaturalImageTraining_VC_v2.h5', 40 | } 41 | 42 | subjects_list_test = { 43 | 'sub-03': 'sub-03_perceptionNaturalImageTest_VC_v2.h5', 44 | } 45 | 46 | 47 | text_model_list = [ 48 | 'GPTNeo', 49 | 'ALBERT', 50 | # 'GPTNeo_phrases', 51 | # 'ALBERT_phrases' 52 | ] 53 | 54 | rois_list = { 55 | # 'VC': 'ROI_VC = 1', 56 | 'LVC': 'ROI_LVC = 1', 57 | 'HVC': 'ROI_HVC = 1', 58 | # 'V1': 'ROI_V1 = 1', 59 | # 'V2': 'ROI_V2 = 1', 60 | # 'V3': 'ROI_V3 = 1', 61 | # 'V4': 'ROI_V4 = 1', 62 | # 'LOC': 'ROI_LOC = 1', 63 | # 'FFA': 'ROI_FFA = 1', 64 | # 'PPA': 'ROI_PPA = 1', 65 | 'IT': 'ROI_IT = 1', 66 | } 67 | 68 | network = 'pytorch/repvgg_b3g4' 69 | features_list = [#'Conv_0', 70 | # 'Conv_1', 71 | 'Conv_2', 72 | 'Conv_3', 73 | 'Conv_4', 74 | 'linear', 75 | 'final'] 76 | 77 | features_list = features_list[::-1] # Start training from deep layers 78 | 79 | # Brain data 80 | brain_dir = './data/DeepImageReconstruction/data/fmri' 81 | # Image features 82 | timm_extracted_visual_features = './data/DIR-Wiki/visual_feature/ImageNetTraining/'+network 83 | timm_extracted_visual_features_test = './data/DIR-Wiki/visual_feature/ImageNetTest/'+network 84 | timm_extracted_visual_features_aug = './data/DIR-Wiki/visual_feature/Aug_1000/'+network 85 | print('DNN feature') 86 | print(timm_extracted_visual_features) 87 | # Text features 88 | model_extracted_textual_features = './data/Wiki_articles_features' 89 | 90 | # Results directory 91 | results_dir_root = './data/DIR-Wiki/visual_feature/ImageNetTraining/'+network+'-PCA' 92 | results_dir_root_test = './data/DIR-Wiki/visual_feature/ImageNetTest/'+network+'-PCA' 93 | 
results_dir_root_aug = './data/DIR-Wiki/visual_feature/Aug_1000/'+network+'-PCA' 94 | results_fmri_root = './data/DIR-Wiki/brain_feature/LVC_HVC_IT' 95 | results_text_root = './data/DIR-Wiki/textual_feature/ImageNetTraining/text' 96 | results_text_root_test = './data/DIR-Wiki/textual_feature/ImageNetTest/text' 97 | results_text_root_aug = './data/DIR-Wiki/textual_feature/Aug_1000/text' 98 | 99 | # Main ####################################################################### 100 | analysis_basename = os.path.splitext(os.path.basename(__file__))[0] 101 | # Print info ----------------------------------------------------------------- 102 | print('Subjects: %s' % subjects_list.keys()) 103 | print('ROIs: %s' % rois_list.keys()) 104 | print('Target features: %s' % network.split('/')[-1]) 105 | print('Layers: %s' % features_list) 106 | print('') 107 | 108 | # Load data ------------------------------------------------------------------ 109 | print('----------------------------------------') 110 | print('Loading data') 111 | 112 | data_brain = {sbj: bdpy.BData(os.path.join(brain_dir, dat_file)) 113 | for sbj, dat_file in subjects_list.items()} 114 | data_features = Features(os.path.join(timm_extracted_visual_features, network)) 115 | 116 | data_brain_test = {sbj: bdpy.BData(os.path.join(brain_dir, dat_file)) 117 | for sbj, dat_file in subjects_list_test.items()} 118 | data_features_test = Features(os.path.join(timm_extracted_visual_features_test, network)) 119 | data_features_aug = Features(os.path.join(timm_extracted_visual_features_aug, network)) 120 | 121 | # Initialize directories ----------------------------------------------------- 122 | makedir_ifnot(results_dir_root) 123 | makedir_ifnot(results_dir_root_test) 124 | makedir_ifnot(results_dir_root_aug) 125 | makedir_ifnot(results_text_root) 126 | makedir_ifnot(results_text_root_test) 127 | makedir_ifnot(results_text_root_aug) 128 | makedir_ifnot('tmp') 129 | 130 | # Save runtime information --------------------------------------------------- 131 | info_dir = results_dir_root 132 | runtime_params = { 133 | 'fMRI data': [os.path.abspath(os.path.join(brain_dir, v)) for v in subjects_list.values()], 134 | 'ROIs': rois_list.keys(), 135 | 'target DNN': network.split('/')[-1], 136 | 'target DNN features': os.path.abspath(timm_extracted_visual_features), 137 | 'target DNN layers': features_list, 138 | } 139 | dump_info(info_dir, script=__file__, parameters=runtime_params) 140 | 141 | ####################################### 142 | # Stability selection 143 | ####################################### 144 | select_ratio = 0.15 145 | totalnum = 0 146 | first = 1 147 | best_roi_sel = [] 148 | num_voxel = dict() 149 | for sbj, roi in product(subjects_list ,rois_list): 150 | print('--------------------') 151 | print('VC ROI: %s' % roi) 152 | trial1 = [] 153 | l1 = [] 154 | trial2 = [] 155 | l2 = [] 156 | trial3 = [] 157 | l3 = [] 158 | trial4 = [] 159 | l4 = [] 160 | trial5 = [] 161 | l5 = [] 162 | # Brain data 163 | x = data_brain[sbj].select(rois_list[roi]) # Brain data 164 | x_labels = data_brain[sbj].select('image_index').flatten() # Label (image index) 165 | 166 | x_test = data_brain_test[sbj].select(rois_list[roi]) # Brain data 167 | x_labels_test = data_brain_test[sbj].select('image_index').flatten() # Label (image index) 168 | 169 | for l in range(1,int(len(x_labels)/5)+1): 170 | n = np.where(x_labels==l) 171 | #trial1 172 | l1.append(l) 173 | trial1.append(x[n[0][0]]) 174 | #trial2 175 | l2.append(l) 176 | trial2.append(x[n[0][1]]) 177 | #trial3 178 | 
l3.append(l) 179 | trial3.append(x[n[0][2]]) 180 | #trial4 181 | l4.append(l) 182 | trial4.append(x[n[0][3]]) 183 | #trial5 184 | l5.append(l) 185 | trial5.append(x[n[0][4]]) 186 | #reshape to select 187 | sel_input = np.array([trial1]) 188 | sel_input = np.append(sel_input, np.array([trial2]), axis=0) 189 | sel_input = np.append(sel_input, np.array([trial3]), axis=0) 190 | sel_input = np.append(sel_input, np.array([trial4]), axis=0) 191 | sel_input = np.append(sel_input, np.array([trial5]), axis=0) 192 | select_num = int(select_ratio * (x.shape)[1]) 193 | num_voxel.update({roi:select_num}) 194 | 195 | print('roi_shape=',x.shape) 196 | sel_idx = stability_selection(sel_input, select_num) 197 | #save as best_roi_sel mat 198 | if first: 199 | best_roi_sel = np.array(x[:,sel_idx]) 200 | best_roi_sel_test = np.array(x_test[:, sel_idx]) 201 | first = 0 202 | else: 203 | best_roi_sel = np.append(best_roi_sel, x[:,sel_idx], axis=1) 204 | best_roi_sel_test = np.append(best_roi_sel_test, x_test[:,sel_idx], axis=1) 205 | 206 | totalnum_voxel = (best_roi_sel.shape)[1] 207 | print('total_selected_voxel=', totalnum_voxel) 208 | print(num_voxel) 209 | 210 | print('best_roi_sel_shape=',best_roi_sel.shape) 211 | print('x_labels_shape=',x_labels.shape) 212 | 213 | print('best_roi_sel_test_shape=',best_roi_sel_test.shape) 214 | print('x_labels_test_shape=',x_labels_test.shape) 215 | 216 | 217 | ####################################### 218 | # Save brain and image feature data 219 | ####################################### 220 | # Analysis loop -------------------------------------------------------------- 221 | print('----------------------------------------') 222 | print('Analysis loop') 223 | first = 1 224 | for feat, sbj in product(features_list, subjects_list): 225 | print('--------------------') 226 | print('Feature: %s' % feat) 227 | print('Subject: %s' % sbj) 228 | 229 | results_dir_alllayer_pca = os.path.join(results_dir_root, sbj) 230 | results_dir_alllayer_pca_test = os.path.join(results_dir_root_test, sbj) 231 | results_dir_alllayer_pca_aug = os.path.join(results_dir_root_aug, sbj) 232 | 233 | results_fmri_dir = os.path.join(results_fmri_root, sbj) 234 | # Preparing data 235 | # -------------- 236 | print('Preparing data') 237 | 238 | # Brain data 239 | x = best_roi_sel # Brain data 240 | x_labels = x_labels # Label (image index) 241 | x_class = data_brain[sbj].select('Label') # Label (class index) 242 | WordNetID = x_class[:, 2] 243 | if sbj == 'sub-03': 244 | class_idx = data_brain[sbj].select('image_index').flatten() 245 | else: 246 | class_idx = x_class[:, 1] 247 | 248 | x_test = best_roi_sel_test # Brain data 249 | x_labels_test = x_labels_test # Label (image index) 250 | x_class_test = data_brain_test[sbj].select('Label') # Label (class index) 251 | WordNetID_test = x_class_test[:, 2] 252 | if sbj == 'sub-03': 253 | class_idx_test = data_brain_test[sbj].select('image_index').flatten() 254 | else: 255 | class_idx_test = x_class_test[:, 1] 256 | 257 | # Averaging test brain data 258 | x_labels_test_unique, indices = np.unique(x_labels_test, return_index=True) 259 | x_test_unique = np.vstack([np.mean(x_test[(np.array(x_labels_test) == lb).flatten(), :], axis=0) for lb in x_labels_test_unique]) 260 | WordNetID_test_unique = WordNetID_test[indices] 261 | class_idx_test_unique = class_idx_test[indices] 262 | 263 | # Target features and image labels (file names) 264 | y = data_features.get_features(feat) # Target DNN features 265 | y_labels = data_features.index # Label (image index) 266 | y = 
np.reshape(y,(y.shape[0],-1)) 267 | 268 | y_test = data_features_test.get_features(feat) # Target DNN features 269 | y_labels_test = data_features_test.index # Label (image index) 270 | y_test = np.reshape(y_test,(y_test.shape[0],-1)) 271 | 272 | y_aug = data_features_aug.get_features(feat) # Target DNN features 273 | y_labels_aug_temp = data_features_aug.labels # Label (image index) 274 | y_labels_aug = [] 275 | for it in y_labels_aug_temp: 276 | y_labels_aug.append(int(it.split('_')[0][1:])) 277 | y_labels_aug = np.array(y_labels_aug) 278 | y_aug = np.reshape(y_aug,(y_aug.shape[0],-1)) 279 | 280 | # Calculate normalization parameters 281 | # Normalize X (fMRI data) 282 | x_mean = np.mean(x, axis=0)[np.newaxis, :] # np.newaxis was added to match Matlab outputs 283 | x_norm = np.std(x, axis=0, ddof=1)[np.newaxis, :] 284 | 285 | # Normalize Y (DNN features) 286 | y_mean = np.mean(y, axis=0)[np.newaxis, :] 287 | y_norm = np.std(y, axis=0, ddof=1)[np.newaxis, :] 288 | 289 | # Y index to sort Y by X (matching samples) 290 | y_index = np.array([np.where(np.array(y_labels) == xl) for xl in x_labels]).flatten() 291 | y_index_test = np.array([np.where(np.array(y_labels_test) == xl) for xl in x_labels_test]).flatten() 292 | y_index_test_unique = np.array([np.where(np.array(y_labels_test) == xl) for xl in x_labels_test_unique]).flatten() 293 | 294 | # X preprocessing 295 | print('Normalizing X') 296 | x = (x - x_mean) / (x_norm+TINY) 297 | x[np.isinf(x)] = 0 298 | 299 | x_test = (x_test - x_mean) / (x_norm+TINY) 300 | x_test[np.isinf(x_test)] = 0 301 | x_test_unique = (x_test_unique - x_mean) / (x_norm+TINY) 302 | x_test_unique[np.isinf(x_test_unique)] = 0 303 | 304 | print('Doing PCA') 305 | ipca = PCA(n_components=0.99, random_state=seed) 306 | ipca.fit(x) 307 | x = ipca.transform(x) 308 | x_test = ipca.transform(x_test) 309 | x_test_unique = ipca.transform(x_test_unique) 310 | print(x.shape) 311 | 312 | # Y preprocessing 313 | print('Normalizing Y') 314 | y = (y - y_mean) / (y_norm+TINY) 315 | y[np.isinf(y)] = 0 316 | y_test = (y_test - y_mean) / (y_norm+TINY) 317 | y_test[np.isinf(y_test)] = 0 318 | y_aug = (y_aug - y_mean) / (y_norm+TINY) 319 | y_aug[np.isinf(y_aug)] = 0 320 | 321 | print('Doing PCA') 322 | ipca = PCA(n_components=0.99, random_state=seed) 323 | ipca.fit(y) 324 | # ipca.fit(y_aug) 325 | y = ipca.transform(y) 326 | y_test = ipca.transform(y_test) 327 | y_aug = ipca.transform(y_aug) 328 | print(y.shape) 329 | 330 | print('Sorting Y') 331 | y = y[y_index, :] 332 | y_test = y_test[y_index_test, :] 333 | y_test_unique = y_test[y_index_test_unique, :] 334 | 335 | if first: 336 | feat_pca_train = y 337 | feat_pca_test = y_test 338 | feat_pca_aug = y_aug 339 | feat_pca_test_unique = y_test_unique 340 | first = 0 341 | else: 342 | feat_pca_train = np.concatenate((feat_pca_train, y), axis=1) 343 | feat_pca_test = np.concatenate((feat_pca_test, y_test), axis=1) 344 | feat_pca_aug = np.concatenate((feat_pca_aug, y_aug), axis=1) 345 | feat_pca_test_unique = np.concatenate((feat_pca_test_unique, y_test_unique), axis=1) 346 | print(feat_pca_test_unique.shape) 347 | 348 | 349 | makedir_ifnot(results_dir_alllayer_pca) 350 | makedir_ifnot(results_dir_alllayer_pca_test) 351 | makedir_ifnot(results_dir_alllayer_pca_aug) 352 | results_dir_alllayer_pca_path = os.path.join(results_dir_alllayer_pca, "feat_pca_train.mat") 353 | io.savemat(results_dir_alllayer_pca_path, {"data":feat_pca_train}) 354 | results_dir_alllayer_pca_test_path = os.path.join(results_dir_alllayer_pca_test, "feat_pca_test.mat") 355 
| io.savemat(results_dir_alllayer_pca_test_path, {"data":feat_pca_test}) 356 | results_dir_alllayer_pca_aug_path = os.path.join(results_dir_alllayer_pca_aug, "feat_pca_aug.mat") 357 | io.savemat(results_dir_alllayer_pca_aug_path, {"data":feat_pca_aug}) 358 | results_dir_alllayer_pca_test_path = os.path.join(results_dir_alllayer_pca_test, "feat_pca_test_unique.mat") 359 | io.savemat(results_dir_alllayer_pca_test_path, {"data":feat_pca_test_unique}) 360 | 361 | 362 | makedir_ifnot(results_fmri_dir) 363 | results_fmri_dir_path = os.path.join(results_fmri_dir, "fmri_train_data.mat") 364 | io.savemat(results_fmri_dir_path, {"data":x, "image_idx":x_labels, "WordNetID":WordNetID, "class_idx":class_idx}) 365 | results_fmri_dir_path = os.path.join(results_fmri_dir, "fmri_test_data.mat") 366 | io.savemat(results_fmri_dir_path, {"data":x_test, "image_idx":x_labels_test, "WordNetID":WordNetID_test, "class_idx":class_idx_test}) 367 | results_fmri_dir_path = os.path.join(results_fmri_dir, "fmri_test_data_unique.mat") 368 | io.savemat(results_fmri_dir_path, {"data":x_test_unique, "image_idx":x_labels_test_unique, "WordNetID":WordNetID_test_unique, "class_idx":class_idx_test_unique}) 369 | 370 | ####################################### 371 | # Save text feature data 372 | ####################################### 373 | 374 | for feat, sbj in product(text_model_list, subjects_list): 375 | print('--------------------') 376 | print('Feature: %s' % feat) 377 | print('Subject: %s' % sbj) 378 | 379 | results_dir_text_fea = os.path.join(results_text_root, feat, sbj) 380 | results_dir_text_fea_test = os.path.join(results_text_root_test, feat, sbj) 381 | results_dir_text_fea_aug = os.path.join(results_text_root_aug, feat, sbj) 382 | # Preparing data 383 | # -------------- 384 | print('Preparing data') 385 | 386 | # Brain data 387 | x_class = data_brain[sbj].select('Label') # Label (class index) 388 | WordNetID = x_class[:, 2] 389 | class_idx = x_class[:, 1] 390 | 391 | x_labels_test = x_labels_test # Label (image index) 392 | x_class_test = data_brain_test[sbj].select('Label') # Label (class index) 393 | WordNetID_test = x_class_test[:, 2] 394 | class_idx_test = x_class_test[:, 1] 395 | 396 | # Averaging test brain data 397 | x_labels_test_unique, indices = np.unique(x_labels_test, return_index=True) 398 | WordNetID_test_unique = WordNetID_test[indices] 399 | class_idx_test_unique = class_idx_test[indices] 400 | 401 | # Target text features and wnid 402 | name = 'ImageNet_class200_' + feat + '.pkl' 403 | full = os.path.join(model_extracted_textual_features, name) 404 | dictionary = pickle.load(open(full, 'rb')) 405 | 406 | firstfeat = 1 407 | firstlabel = 1 408 | for key, value in dictionary.items(): 409 | for k, v in value.items(): 410 | # print(k, v) 411 | if k == 'wnid': 412 | # print(v) 413 | v = int(v[1:]) 414 | if firstlabel: 415 | text_label = np.array([v]) 416 | firstlabel = 0 417 | else: 418 | text_label = np.concatenate((text_label, np.array([v])), axis=0) 419 | 420 | elif k == 'feats': 421 | v = np.expand_dims(v, axis=0) 422 | if firstfeat: 423 | text_feat = v 424 | firstfeat = 0 425 | else: 426 | text_feat = np.concatenate((text_feat, v), axis=0) 427 | 428 | # Extra text features and wnid 429 | name = 'ImageNet_trainval_classes_' + feat + '.pkl' 430 | full = os.path.join(model_extracted_textual_features, name) 431 | dictionary = pickle.load(open(full, 'rb')) 432 | 433 | firstfeat = 1 434 | firstlabel = 1 435 | for key, value in dictionary.items(): 436 | for k, v in value.items(): 437 | # print(k, v) 438 | 
if k == 'wnid': 439 | # print(v) 440 | v = int(v[1:]) 441 | if firstlabel: 442 | text_label_aug = np.array([v]) 443 | firstlabel = 0 444 | else: 445 | text_label_aug = np.concatenate((text_label_aug, np.array([v])), axis=0) 446 | 447 | elif k == 'feats': 448 | v = np.expand_dims(v, axis=0) 449 | if firstfeat: 450 | text_feat_aug = v 451 | firstfeat = 0 452 | else: 453 | text_feat_aug = np.concatenate((text_feat_aug, v), axis=0) 454 | 455 | # t index to sort t by X (matching samples) 456 | t_index = np.array([np.where(np.array(text_label) == xl) for xl in WordNetID.astype(int)]).flatten() 457 | t_index_test = np.array([np.where(np.array(text_label) == xl) for xl in WordNetID_test.astype(int)]).flatten() 458 | t_index_test_unique = np.array([np.where(np.array(text_label) == xl) for xl in WordNetID_test_unique.astype(int)]).flatten() 459 | t_index_aug = np.array([np.where(np.array(text_label_aug) == xl) for xl in y_labels_aug]).flatten() 460 | 461 | print('Sorting text') 462 | t = text_feat[t_index, :] 463 | t_test = text_feat[t_index_test, :] 464 | t_aug = text_feat_aug[t_index_aug, :] 465 | t_test_unique = text_feat[t_index_test_unique, :] 466 | 467 | print(t.shape) 468 | print(t_test.shape) 469 | print(t_aug.shape) 470 | print(t_test_unique.shape) 471 | 472 | makedir_ifnot(results_dir_text_fea) 473 | makedir_ifnot(results_dir_text_fea_test) 474 | makedir_ifnot(results_dir_text_fea_aug) 475 | 476 | results_dir_text_fea_path = os.path.join(results_dir_text_fea, "text_feat_train.mat") 477 | io.savemat(results_dir_text_fea_path, {"data": t}) 478 | 479 | results_dir_text_fea_test_path = os.path.join(results_dir_text_fea_test, "text_feat_test.mat") 480 | io.savemat(results_dir_text_fea_test_path, {"data": t_test}) 481 | 482 | results_dir_text_fea_aug_path = os.path.join(results_dir_text_fea_aug, "text_feat_aug.mat") 483 | io.savemat(results_dir_text_fea_aug_path, {"data": t_aug}) 484 | 485 | results_dir_text_fea_test_path = os.path.join(results_dir_text_fea_test, "text_feat_test_unique.mat") 486 | io.savemat(results_dir_text_fea_test_path, {"data": t_test_unique}) 487 | 488 | print('%s finished.' % analysis_basename) --------------------------------------------------------------------------------
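
As a quick sanity check of what the two preparation scripts above produce, the minimal sketch below loads a few of the saved .mat files back with scipy.io.loadmat and prints their shapes. It assumes the default output locations of data_prepare_with_aug_DIR_Wiki.py; 'sub-01' and 'GPTNeo' are example choices taken from the settings in that script, and the dictionary keys ('data', 'image_idx', 'WordNetID', 'class_idx') mirror the io.savemat calls shown above. This is an illustrative sketch, not part of the repository.

# Minimal sketch, assuming the default paths written by data_prepare_with_aug_DIR_Wiki.py.
import os
from scipy import io

root = './data/DIR-Wiki'
sbj = 'sub-01'                      # example subject
network = 'pytorch/repvgg_b3g4'     # visual backbone used in the script

# Brain data: stability-selected, normalized, PCA-reduced voxel responses.
fmri = io.loadmat(os.path.join(root, 'brain_feature', 'LVC_HVC_IT', sbj, 'fmri_train_data.mat'))
print(fmri['data'].shape)       # trials x PCA components
print(fmri['image_idx'].shape)  # presented image index per trial
print(fmri['WordNetID'].shape)  # WordNet ID per trial
print(fmri['class_idx'].shape)  # class index per trial

# Visual features: per-layer PCA features concatenated across layers, sorted to match the fMRI trials.
visual = io.loadmat(os.path.join(root, 'visual_feature', 'ImageNetTraining', network + '-PCA', sbj, 'feat_pca_train.mat'))
print(visual['data'].shape)

# Text features: class-level embeddings, one row per training trial.
text = io.loadmat(os.path.join(root, 'textual_feature', 'ImageNetTraining', 'text', 'GPTNeo', sbj, 'text_feat_train.mat'))
print(text['data'].shape)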