├── src ├── __init__.py ├── main.py ├── data.py ├── options.py ├── recon.py ├── eval.py ├── utils.py ├── models.py ├── nsd_access.py └── trainer.py ├── assets ├── galaxy_brain.gif ├── MindBridge_method.png └── MindBridge_teaser.png ├── requirements.txt ├── .gitignore ├── scripts ├── train_single.sh ├── train_bridge.sh ├── inference.sh └── adapt_bridge.sh └── README.md /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/galaxy_brain.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlepure2333/MindBridge/HEAD/assets/galaxy_brain.gif -------------------------------------------------------------------------------- /assets/MindBridge_method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlepure2333/MindBridge/HEAD/assets/MindBridge_method.png -------------------------------------------------------------------------------- /assets/MindBridge_teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/littlepure2333/MindBridge/HEAD/assets/MindBridge_teaser.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.1 2 | torchvision==0.15.2 3 | diffusers==0.13.0 4 | kornia 5 | tqdm 6 | pandas 7 | scipy 8 | accelerate 9 | deepspeed 10 | torchsnooper 11 | matplotlib 12 | pycocotools 13 | h5py 14 | nibabel 15 | urllib3 16 | numpy 17 | wandb 18 | pillow 19 | scikit-image 20 | clip 21 | clip-retrieval 22 | transformers 23 | gpustat -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Project specific files 2 | data/ 3 | train_logs/ 4 | weights/ 5 | wandb 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__ 9 | *.py[cod] 10 | *$py.class 11 | 12 | # Distribution / packaging 13 | dist/ 14 | build/ 15 | *.egg-info/ 16 | *.egg 17 | 18 | # Virtual environments 19 | venv/ 20 | env/ 21 | .env 22 | 23 | # IDE specific files 24 | .idea/ 25 | .vscode/ 26 | 27 | # Compiled Python files 28 | *.pyc 29 | 30 | # Logs and databases 31 | *.log 32 | *.sqlite3 33 | 34 | # OS generated files 35 | .DS_Store 36 | Thumbs.db -------------------------------------------------------------------------------- /scripts/train_single.sh: -------------------------------------------------------------------------------- 1 | batch_size=50 2 | val_batch_size=50 3 | num_epochs=600 4 | mse_mult=10000 5 | subj=1 6 | model_name="mindbridge_subj"$subj 7 | 8 | cd src/ 9 | 10 | # CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \ 11 | accelerate launch --multi_gpu --num_processes 2 --gpu_ids 0,1 --main_process_port 29501 \ 12 | main.py \ 13 | --model_name $model_name --subj_list $subj \ 14 | --num_epochs $num_epochs --batch_size $batch_size --val_batch_size $val_batch_size \ 15 | --h_size 2048 --n_blocks 4 --pool_type max --pool_num 8192 \ 16 | --mse_mult $mse_mult \ 17 | --eval_interval 10 --ckpt_interval 10 \ 18 | --max_lr 3e-4 --num_workers 4 -------------------------------------------------------------------------------- /scripts/train_bridge.sh: 
-------------------------------------------------------------------------------- 1 | batch_size=50 2 | val_batch_size=50 3 | num_epochs=600 4 | mse_mult=10000 5 | rec_mult=1 6 | cyc_mul=1 7 | model_name="mindbridge_subj1257" 8 | 9 | cd src/ 10 | 11 | # CUDA_VISIBLE_DEVICES=4 python -W ignore \ 12 | accelerate launch --multi_gpu --num_processes 8 --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 29502 \ 13 | main.py \ 14 | --model_name $model_name --subj_list 1 2 5 7 \ 15 | --num_epochs $num_epochs --batch_size $batch_size --val_batch_size $val_batch_size \ 16 | --h_size 2048 --n_blocks 4 --pool_type max --pool_num 8192 \ 17 | --mse_mult $mse_mult --rec_mult $rec_mult --cyc_mul $cyc_mul \ 18 | --eval_interval 10 --ckpt_interval 10 \ 19 | --max_lr 1.5e-4 --num_workers 8 20 | -------------------------------------------------------------------------------- /scripts/inference.sh: -------------------------------------------------------------------------------- 1 | subj_load=1 2 | subj_test=1 3 | model_name="mindbridge_subj125" 4 | ckpt_from="last" 5 | text_image_ratio=0.5 6 | guidance=5 7 | gpu_id=0 8 | 9 | cd src/ 10 | 11 | CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \ 12 | recon.py \ 13 | --model_name $model_name --ckpt_from $ckpt_from \ 14 | --h_size 2048 --n_blocks 4 --pool_type max \ 15 | --subj_load $subj_load --subj_test $subj_test \ 16 | --text_image_ratio $text_image_ratio --guidance $guidance \ 17 | --recons_per_sample 8 18 | # --test_end 2 \ 19 | # --test_start 0 \ 20 | 21 | 22 | results_path="../train_logs/"$model_name"/recon_on_subj"$subj_test 23 | 24 | CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \ 25 | eval.py --results_path $results_path 26 | -------------------------------------------------------------------------------- /scripts/adapt_bridge.sh: -------------------------------------------------------------------------------- 1 | batch_size=50 2 | val_batch_size=50 3 | mse_mult=10000 4 | rec_mult=1 5 | cyc_mult=1 6 | num_epochs=200 7 | subj_source="1 2 5" 8 | subj_target=7 9 | length=4000 10 | model_name="mindbridge_subj125_a"$subj_target"_l"$length 11 | load_from="../train_logs/mindbridge_subj125/last.pth" 12 | gpu_id=0 13 | 14 | cd src/ 15 | 16 | # accelerate launch --multi_gpu --num_processes 4 --gpu_ids 0,1,2,3 --main_process_port 29503 \ 17 | CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \ 18 | main.py \ 19 | --model_name $model_name --subj_target $subj_target --subj_source $subj_source \ 20 | --num_epochs $num_epochs --batch_size $batch_size --val_batch_size $val_batch_size \ 21 | --h_size 2048 --n_blocks 4 --pool_type max --pool_num 8192 \ 22 | --mse_mult $mse_mult --rec_mult $rec_mult --cyc_mult $cyc_mult \ 23 | --eval_interval 10 --ckpt_interval 10 \ 24 | --load_from $load_from --num_workers 8 \ 25 | --max_lr 1.5e-4 --adapting --length $length 26 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from accelerate import Accelerator, DeepSpeedPlugin 5 | 6 | # tf32 data type is faster than standard float32 7 | torch.backends.cuda.matmul.allow_tf32 = True 8 | 9 | # Custom models and functions # 10 | from models import Clipper, MindBridge, MindSingle 11 | from nsd_access import NSDAccess 12 | from trainer import * 13 | from options import args 14 | import utils 15 | 16 | def config_multi_gpu(): 17 | # Multi-GPU config 18 | deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_clipping=1.0) 19 | 
accelerator = Accelerator(split_batches=False, mixed_precision='no', deepspeed_plugin=deepspeed_plugin) 20 | accelerator.print("PID of this process =",os.getpid()) 21 | device = accelerator.device 22 | accelerator.print("device:",device) 23 | num_devices = torch.cuda.device_count() 24 | if num_devices==0: num_devices = 1 25 | accelerator.print(accelerator.state) 26 | local_rank = accelerator.state.local_process_index 27 | world_size = accelerator.state.num_processes 28 | distributed = not accelerator.state.distributed_type == 'NO' 29 | accelerator.print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size) 30 | 31 | return accelerator, device, local_rank 32 | 33 | def prepare_CLIP(args, device): 34 | # Prepare CLIP 35 | clip_sizes = {"RN50": 1024, "ViT-L/14": 768, "ViT-B/32": 512, "ViT-H-14": 1024} 36 | clip_size = clip_sizes[args.clip_variant] 37 | 38 | print("Using hidden layer CLIP space (Versatile Diffusion)") 39 | if not args.norm_embs: 40 | print("WARNING: YOU WANT NORMED EMBEDDINGS FOR VERSATILE DIFFUSION!") 41 | clip_extractor = Clipper(args.clip_variant, device=device, hidden_state=True, norm_embs=args.norm_embs) 42 | 43 | out_dim_image = 257 * clip_size # 257*768 = 197376 44 | out_dim_text = 77 * clip_size # 77*768 = 59136 45 | 46 | print("clip_extractor loaded.") 47 | print("out_dim_image:",out_dim_image) 48 | print("out_dim_text:", out_dim_text) 49 | 50 | return clip_extractor, out_dim_image, out_dim_text 51 | 52 | def prepare_voxel2clip(args, out_dim_image, out_dim_text, device): 53 | # Prepare voxel2clip 54 | if args.adapting: 55 | args.subj_list = args.subj_source + [args.subj_target] 56 | 57 | voxel2clip_kwargs = dict( 58 | in_dim=args.pool_num, out_dim_image=out_dim_image, out_dim_text=out_dim_text, 59 | h=args.h_size, n_blocks=args.n_blocks, subj_list=args.subj_list, adapting=args.adapting) 60 | if len(args.subj_list) == 1: # Single subject does not need "brain builder" 61 | voxel2clip_kwargs.pop("adapting") # Single subject does not need "adapting" 62 | voxel2clip = MindSingle(**voxel2clip_kwargs).to(device) 63 | else: 64 | voxel2clip = MindBridge(**voxel2clip_kwargs).to(device) 65 | 66 | if args.adapting: # reset-tuning 67 | # Only let the parameters of embedder and builder in the voxel2clip trainable, keeping other parameters frozen 68 | voxel2clip.requires_grad_(False) 69 | voxel2clip.embedder[str(args.subj_target)].requires_grad_(True) 70 | voxel2clip.builder[str(args.subj_target)].requires_grad_(True) 71 | 72 | print("voxel2clip loaded.") 73 | print("params of voxel2clip:") 74 | utils.count_params(voxel2clip) 75 | 76 | return voxel2clip 77 | 78 | def prepare_coco(args): 79 | # Preload coco captions 80 | nsda = NSDAccess(args.data_path) 81 | coco_73k = list(range(0, 73000)) 82 | prompts_list = nsda.read_image_coco_info(coco_73k,info_type='captions') 83 | 84 | print("coco captions loaded.") 85 | 86 | return prompts_list 87 | 88 | def prepare_trainer(args, accelerator, voxel2clip, clip_extractor, prompts_list, device): 89 | if args.adapting: 90 | trainer = Trainer_adapt(args, accelerator, voxel2clip, clip_extractor, prompts_list, device) 91 | elif len(args.subj_list) == 1: 92 | trainer = Trainer_single(args, accelerator, voxel2clip, clip_extractor, prompts_list, device) 93 | else: 94 | trainer = Trainer_bridge(args, accelerator, voxel2clip, clip_extractor, prompts_list, device) 95 | 96 | return trainer 97 | 98 | 99 | def main(): 100 | accelerator, device, local_rank = config_multi_gpu() 101 | if local_rank != 
0: # suppress print for non-local_rank=0 102 | sys.stdout = open(os.devnull, 'w') 103 | 104 | # need non-deterministic CuDNN for conv3D to work 105 | utils.seed_everything(args.seed, cudnn_deterministic=False) 106 | 107 | # learning rate will be changed by "acclerate" based on number of processes(GPUs) 108 | args.max_lr *= accelerator.num_processes 109 | 110 | # Prepare CLIP 111 | clip_extractor, out_dim_image, out_dim_text = prepare_CLIP(args, device) 112 | 113 | # Prepare voxel2clip 114 | voxel2clip = prepare_voxel2clip(args, out_dim_image, out_dim_text, device) 115 | 116 | # Prepare coco captions 117 | prompts_list = prepare_coco(args) 118 | 119 | # Init Trainer 120 | trainer = prepare_trainer(args, accelerator, voxel2clip, clip_extractor, prompts_list, device) 121 | trainer.prepare_wandb(local_rank, args) 122 | # trainer.prepare_multi_gpu() 123 | 124 | # Train or Adapt 125 | trainer.train(local_rank) 126 | 127 | print("\n===Finished!===\n") 128 | 129 | if __name__ == '__main__': 130 | main() -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset, DataLoader 8 | import utils 9 | import kornia 10 | from kornia.augmentation.container import AugmentationSequential 11 | 12 | 13 | img_augment = AugmentationSequential( 14 | kornia.augmentation.RandomResizedCrop((224,224), (0.8,1), p=0.3), 15 | kornia.augmentation.Resize((224, 224)), 16 | kornia.augmentation.RandomBrightness(brightness=(0.8, 1.2), clip_output=True, p=0.2), 17 | kornia.augmentation.RandomContrast(contrast=(0.8, 1.2), clip_output=True, p=0.2), 18 | kornia.augmentation.RandomGamma((0.8, 1.2), (1.0, 1.3), p=0.2), 19 | kornia.augmentation.RandomSaturation((0.8,1.2), p=0.2), 20 | kornia.augmentation.RandomHue((-0.1,0.1), p=0.2), 21 | kornia.augmentation.RandomSharpness((0.8, 1.2), p=0.2), 22 | kornia.augmentation.RandomGrayscale(p=0.2), 23 | data_keys=["input"], 24 | ) 25 | 26 | class NSDDataset(Dataset): 27 | def __init__(self, root_dir, extensions=None, pool_num=8192, pool_type="max", length=None): 28 | self.root_dir = root_dir 29 | self.extensions = extensions if extensions else [] 30 | self.pool_num = pool_num 31 | self.pool_type = pool_type 32 | self.samples = self._load_samples() 33 | self.samples_keys = sorted(self.samples.keys()) 34 | self.length = length 35 | if length is not None: 36 | if length > len(self.samples_keys): 37 | pass # enlarge the dataset 38 | elif length > 0: 39 | self.samples_keys = self.samples_keys[:length] 40 | elif length < 0: 41 | self.samples_keys = self.samples_keys[length:] 42 | elif length == 0: 43 | raise ValueError("length must be a non-zero value!") 44 | else: 45 | self.length = len(self.samples_keys) 46 | 47 | def _load_samples(self): 48 | files = os.listdir(self.root_dir) 49 | samples = {} 50 | for file in files: 51 | file_path = os.path.join(self.root_dir, file) 52 | sample_id, ext = file.split(".",maxsplit=1) 53 | if ext in self.extensions: 54 | if sample_id in samples.keys(): 55 | samples[sample_id][ext] = file_path 56 | else: 57 | samples[sample_id]={"subj": file_path} 58 | samples[sample_id][ext] = file_path 59 | # print(samples) 60 | return samples 61 | 62 | def _load_image(self, image_path): 63 | image = Image.open(image_path).convert('RGB') 64 | image = np.array(image).astype(np.float32) 
/ 255.0 65 | image = torch.from_numpy(image.transpose(2, 0, 1)) 66 | return image 67 | 68 | def _load_npy(self, npy_path): 69 | array = np.load(npy_path) 70 | array = torch.from_numpy(array) 71 | return array 72 | 73 | def vox_process(self, x): 74 | if self.pool_num is not None: 75 | x = pool_voxels(x, self.pool_num, self.pool_type) 76 | return x 77 | 78 | def subj_process(self, key): 79 | id = int(key.split("/")[-2].split("subj")[-1]) 80 | return id 81 | 82 | def aug_process(self, brain3d): 83 | return brain3d 84 | 85 | def __len__(self): 86 | # return len(self.samples_keys) 87 | return self.length 88 | 89 | def __getitem__(self, idx): 90 | idx = idx % len(self.samples_keys) 91 | sample_key = self.samples_keys[idx] 92 | sample = self.samples[sample_key] 93 | items = [] 94 | for ext in self.extensions: 95 | if ext == "jpg": 96 | items.append(self._load_image(sample[ext])) 97 | elif ext == "nsdgeneral.npy": 98 | voxel = self._load_npy(sample[ext]) 99 | items.append(self.vox_process(voxel)) 100 | elif ext == "coco73k.npy": 101 | items.append(self._load_npy(sample[ext])) 102 | elif ext == "subj": 103 | items.append(self.subj_process(sample[ext])) 104 | elif ext == "wholebrain_3d.npy": 105 | brain3d = self._load_npy(sample[ext]) 106 | items.append(self.aug_process(brain3d, )) 107 | 108 | return items 109 | 110 | def pool_voxels(voxels, pool_num, pool_type): 111 | voxels = voxels.float() 112 | if pool_type == 'avg': 113 | voxels = nn.AdaptiveAvgPool1d(pool_num)(voxels) 114 | elif pool_type == 'max': 115 | voxels = nn.AdaptiveMaxPool1d(pool_num)(voxels) 116 | elif pool_type == "resize": 117 | voxels = voxels.unsqueeze(1) # Add a dimension to make it (B, 1, L) 118 | voxels = F.interpolate(voxels, size=pool_num, mode='linear', align_corners=False) 119 | voxels = voxels.squeeze(1) 120 | 121 | return voxels 122 | 123 | def get_dataloader( 124 | root_dir, 125 | batch_size, 126 | num_workers=1, 127 | seed=42, 128 | is_shuffle=True, 129 | extensions=['nsdgeneral.npy', "jpg", 'coco73k.npy', "subj"], 130 | pool_type=None, 131 | pool_num=None, 132 | length=None, 133 | ): 134 | utils.seed_everything(seed) 135 | dataset = NSDDataset(root_dir=root_dir, extensions=extensions, pool_num=pool_num, pool_type=pool_type, length=length) 136 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=is_shuffle) 137 | 138 | return dataloader 139 | 140 | def get_dls(subject, data_path, batch_size, val_batch_size, num_workers, pool_type, pool_num, length, seed): 141 | train_path = "{}/webdataset_avg_split/train/subj0{}".format(data_path, subject) 142 | val_path = "{}/webdataset_avg_split/val/subj0{}".format(data_path, subject) 143 | extensions = ['nsdgeneral.npy', "jpg", 'coco73k.npy', "subj"] 144 | 145 | train_dl = get_dataloader( 146 | train_path, 147 | batch_size=batch_size, 148 | num_workers=num_workers, 149 | seed=seed, 150 | extensions=extensions, 151 | pool_type=pool_type, 152 | pool_num=pool_num, 153 | is_shuffle=True, 154 | length=length, 155 | ) 156 | 157 | val_dl = get_dataloader( 158 | val_path, 159 | batch_size=val_batch_size, 160 | num_workers=num_workers, 161 | seed=seed, 162 | extensions=extensions, 163 | pool_type=pool_type, 164 | pool_num=pool_num, 165 | is_shuffle=False, 166 | ) 167 | 168 | num_train=len(train_dl.dataset) 169 | num_val=len(val_dl.dataset) 170 | print(train_path,"\n",val_path) 171 | print("number of train data:", num_train) 172 | print("batch_size", batch_size) 173 | print("number of val data:", num_val) 174 | print("val_batch_size", 
val_batch_size) 175 | 176 | return train_dl, val_dl 177 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MindBridge: A Cross-Subject Brain Decoding Framework 2 | 3 | ![teasor](assets/MindBridge_teaser.png) 4 | 5 | [Shizun Wang](https://littlepure2333.github.io/home/), [Songhua Liu](http://121.37.94.87/), [Zhenxiong Tan](https://github.com/Yuanshi9815), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/) 6 | National University of Singapore 7 | 8 | **CVPR 2024 Highlight** 9 | [Project](https://littlepure2333.github.io/MindBridge/) | [Arxiv](https://arxiv.org/abs/2404.07850) 10 | 11 | ## News 12 | **[2024.04.12]** MindBridge's paper, project and code are released. 13 | **[2024.04.05]** MindBridge is selected as **CVPR 2024 Highlight** paper! 14 | **[2024.02.27]** MindBridge is accepted by **CVPR 2024**! 15 | 16 | ## Overview 17 | ![method](assets/MindBridge_method.png) 18 | 19 | > We present a novel approach, MindBridge, that achieves **cross-subject brain decoding by employing only one model**. Our proposed framework establishes a generic paradigm capable of addressing these challenges: **1) the inherent variability** in input dimensions across subjects due to differences in brain size; **2) the unique intrinsic neural patterns**, influencing how different individuals perceive and process sensory information; **3) limited data availability for new subjects** in real-world scenarios hampers the performance of decoding models. 20 | Notably, by cycle reconstruction, MindBridge can enable **novel brain signals synthesis**, which also can serve as pseudo data augmentation. Within the framework, we can **adapt** a pretrained MindBridge to a **new subject** using less data. 21 | 22 | ## Installation 23 | 24 | 1. Agree to the Natural Scenes Dataset's [Terms and Conditions](https://cvnlab.slite.page/p/IB6BSeW_7o/Terms-and-Conditions) and fill out the [NSD Data Access form](https://forms.gle/xue2bCdM9LaFNMeb7) 25 | 26 | 2. Download this repository: ``git clone https://github.com/littlepure2333/MindBridge.git`` 27 | 28 | 3. Create a conda environment and install the packages necessary to run the code. 29 | 30 | ```bash 31 | conda create -n mindbridge python=3.10.8 -y 32 | conda activate mindbridge 33 | pip install -r requirements.txt 34 | ``` 35 | 36 | ## Preparation 37 | 38 | ### Data 39 | 40 | Download the essential files we used from [NSD dataset](https://natural-scenes-dataset.s3.amazonaws.com/index.html), which contains `nsd_stim_info_merged.csv`. Also download COCO captions from [this link](http://images.cocodataset.org/annotations/annotations_trainval2017.zip) which contains `captions_train2017.json` and `captions_val2017.json`. 41 | We use the same preprocessed data as [MindEye's](https://github.com/MedARC-AI/fMRI-reconstruction-NSD), which can be downloaded from [Hugging Face](https://huggingface.co/datasets/pscotti/naturalscenesdataset/tree/main/webdataset_avg_split), and extract all files from the compressed tar files. 42 | Then organize the data as following: 43 | 44 |
45 | 46 | Data Organization 47 | 48 | ``` 49 | data/natural-scenes-dataset 50 | ├── nsddata 51 | │ └── experiments 52 | │ └── nsd 53 | │ └── nsd_stim_info_merged.csv 54 | ├── nsddata_stimuli 55 | │ └── stimuli 56 | │ └── nsd 57 | │ └── annotations 58 | │ ├── captions_train2017.json 59 | │ └── captions_val2017.json 60 | └── webdataset_avg_split 61 | ├── test 62 | │ ├── subj01 63 | │ │ ├── sample000000349.coco73k.npy 64 | │ │ ├── sample000000349.jpg 65 | │ │ ├── sample000000349.nsdgeneral.npy 66 | │ │ └── ... 67 | │ └── ... 68 | ├── train 69 | │ ├── subj01 70 | │ │ ├── sample000000300.coco73k.npy 71 | │ │ ├── sample000000300.jpg 72 | │ │ ├── sample000000300.nsdgeneral.npy 73 | │ │ └── ... 74 | │ └── ... 75 | └── val 76 | ├── subj01 77 | │ ├── sample000000000.coco73k.npy 78 | │ ├── sample000000000.jpg 79 | │ ├── sample000000000.nsdgeneral.npy 80 | │ └── ... 81 | └── ... 82 | ``` 83 | 84 |
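Before training, it can help to sanity-check that every sample provides the three files the dataloader pairs together (`jpg`, `nsdgeneral.npy`, `coco73k.npy`; see `NSDDataset._load_samples` in `src/data.py`). The snippet below is a minimal, hypothetical sketch (not part of the repository), assuming the `data/natural-scenes-dataset` root shown above:

```python
import os
from collections import defaultdict

# Hypothetical helper: verify the webdataset_avg_split layout described above.
root = "data/natural-scenes-dataset/webdataset_avg_split"  # adjust to your data root
required = ("jpg", "nsdgeneral.npy", "coco73k.npy")        # extensions used by src/data.py

for split in ("train", "val", "test"):
    split_dir = os.path.join(root, split)
    for subj in sorted(os.listdir(split_dir)):
        samples = defaultdict(set)
        for fname in os.listdir(os.path.join(split_dir, subj)):
            # same "sample_id.extension" split as NSDDataset._load_samples
            sample_id, ext = fname.split(".", maxsplit=1)
            samples[sample_id].add(ext)
        incomplete = [s for s, exts in samples.items()
                      if not all(r in exts for r in required)]
        print(f"{split}/{subj}: {len(samples)} samples, {len(incomplete)} incomplete")
```
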
85 | 86 | ### Checkpoints 87 | You can download our pretrained MindBridge checkpoints for "subject01, 02, 05, 07" from [Hugging Face](https://huggingface.co/littlepure2333/MindBridge/tree/main). And place the folders containing checkpoints under the directory `./train_logs/`. 88 | 89 | ## Training 90 | 91 | The training commands are described in the `./scripts` folder. You can check the command options in the `./src/options.py` file. For example, you can resume training by adding the `--resume` option to the command. The training progress can be monitored through [wandb](https://wandb.ai/). 92 | 93 | ### Training on single subject 94 | This script contains training the per-subject-per-model version of MindBridge (which refers to "Vanilla" in the paper) on one subject (e.g. subj01). You can also indicate which subject in the script. 95 | 96 | ```bash 97 | bash scripts/train_single.sh 98 | ``` 99 | 100 | ### Training on multi-subjects 101 | This script contains training MindBridge on multi-subjects (e.g. subj01, 02, 05, 07). You can also indicate which subjects in the script. 102 | 103 | ```bash 104 | bash scripts/train_bridge.sh 105 | ``` 106 | 107 | 108 | ### Adapting to a new subject 109 | Once the MindBridge is trained on some known "source subjects" (e.g. subj01, 02, 05), you can adapt the MindBridge to a new "target subject" (e.g. subj07) based on limited data volume (e.g. 4000 data points). You can also indicate which source subjects, which target subject, or data volume (length) in the script. 110 | 111 | ```bash 112 | bash scripts/adapt_bridge.sh 113 | ``` 114 | 115 | ## Reconstructing and evaluating 116 | This script will reconstruct one subject's images (e.g. subj01) on the test set from a MindBridge model (e.g. subj01, 02, 05, 07), then calculate all the metrics. The evaluated metrics will be saved in a csv file. You can indicate which MindBridge model and which subject in the script. 117 | 118 | ```bash 119 | bash scripts/inference.sh 120 | ``` 121 | 122 | 123 | ## TODO List 124 | - [x] Release pretrained checkpoints. 125 | - [ ] Training MindBridge on all 8 subjects in NSD dataset. 126 | 127 | ## Citation 128 | ``` 129 | @inproceedings{wang2024mindbridge, 130 | title={Mindbridge: A cross-subject brain decoding framework}, 131 | author={Wang, Shizun and Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao}, 132 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 133 | pages={11333--11342}, 134 | year={2024} 135 | } 136 | ``` 137 | 138 | ## Acknowledgement 139 | We extend our gratitude to [MindEye](https://github.com/MedARC-AI/fMRI-reconstruction-NSD) and [nsd_access](https://github.com/tknapen/nsd_access) for generously sharing their codebase, upon which ours is built. We are indebted to the [NSD dataset](https://natural-scenes-dataset.s3.amazonaws.com/index.html) for providing access to high-quality, publicly available data. 140 | Our appreciation also extends to the [Accelerate](https://huggingface.co/docs/accelerate/index) and [DeepSpeed](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for simplifying the process of efficient multi-GPU training, enabling us to train on the 24GB vRAM GPU, NVIDIA A5000. 141 | Special thanks to [Xingyi Yang](https://adamdad.github.io/) and [Yifan Zhang](https://sites.google.com/view/yifan-zhang) for their invaluable discussions. 142 | 143 |
144 | galaxy brain 145 |
146 | 147 | 148 | -------------------------------------------------------------------------------- /src/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description="MindBridge Configuration") 4 | parser.add_argument( 5 | "--model_name", type=str, default="testing", 6 | help="name of model, used for ckpt saving and wandb logging (if enabled)", 7 | ) 8 | parser.add_argument( 9 | "--data_path", type=str, default="../data/natural-scenes-dataset", 10 | help="Path to where NSD data is stored / where to download it to", 11 | ) 12 | parser.add_argument( 13 | "--subj_list",type=int, default=[1], choices=[1,2,3,4,5,6,7,8], nargs='+', 14 | help="Subject index to train on", 15 | ) 16 | parser.add_argument( 17 | "--subj_source",type=int, default=[1], choices=[1,2,3,4,5,6,7,8], nargs='+', 18 | help="Source subject index to be adapted from (Can be multiple subjects)", 19 | ) 20 | parser.add_argument( 21 | "--subj_target",type=int, default=[1], choices=[1,2,3,4,5,6,7,8], 22 | help="Target subject index to be adapted to (Only one subject)", 23 | ) 24 | parser.add_argument( 25 | "--adapting",action=argparse.BooleanOptionalAction,default=False, 26 | help="Whether to adapt from source to target subject", 27 | ) 28 | parser.add_argument( 29 | "--batch_size", type=int, default=50, 30 | help="Batch size per GPU", 31 | ) 32 | parser.add_argument( 33 | "--val_batch_size", type=int, default=50, 34 | help="Validation batch size per GPU", 35 | ) 36 | parser.add_argument( 37 | "--clip_variant",type=str,default="ViT-L/14",choices=["RN50", "ViT-L/14", "ViT-B/32", "RN50x64"], 38 | help='OpenAI clip variant', 39 | ) 40 | parser.add_argument( 41 | "--wandb_log",action=argparse.BooleanOptionalAction,default=True, 42 | help="Whether to log to wandb", 43 | ) 44 | parser.add_argument( 45 | "--wandb_project",type=str,default="MindBridge", 46 | help="Wandb project name", 47 | ) 48 | parser.add_argument( 49 | "--resume",action=argparse.BooleanOptionalAction,default=False, 50 | help="Resume training from latest checkpoint, can't do it with --load_from at the same time", 51 | ) 52 | parser.add_argument( 53 | "--resume_id",type=str,default=None, 54 | help="Run id for wandb resume", 55 | ) 56 | parser.add_argument( 57 | "--load_from",type=str,default=None, 58 | help="load model and restart, can't do it with --resume at the same time", 59 | ) 60 | parser.add_argument( 61 | "--norm_embs",action=argparse.BooleanOptionalAction,default=True, 62 | help="Do l2-norming of CLIP embeddings", 63 | ) 64 | parser.add_argument( 65 | "--use_image_aug",action=argparse.BooleanOptionalAction,default=True, 66 | help="whether to use image augmentation", 67 | ) 68 | parser.add_argument( 69 | "--num_epochs",type=int,default=2000, 70 | help="number of epochs of training", 71 | ) 72 | parser.add_argument( 73 | "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'], 74 | help="Type of learning rate scheduler", 75 | ) 76 | parser.add_argument( 77 | "--ckpt_interval",type=int,default=10, 78 | help="Save backup ckpt and reconstruct every x epochs", 79 | ) 80 | parser.add_argument( 81 | "--eval_interval",type=int,default=10, 82 | help="Evaluate the model every x epochs", 83 | ) 84 | parser.add_argument( 85 | "--h_size",type=int,default=2048, 86 | help="Hidden size of MLP", 87 | ) 88 | parser.add_argument( 89 | "--n_blocks",type=int,default=2, 90 | help="Number of Hidden layers in MLP", 91 | ) 92 | parser.add_argument( 93 | 
"--seed",type=int,default=42, 94 | help="Seed for reproducibility", 95 | ) 96 | parser.add_argument( 97 | "--num_workers",type=int,default=5, 98 | help="Number of workers in dataloader" 99 | ) 100 | parser.add_argument( 101 | "--max_lr",type=float,default=3e-4, 102 | help="Max learning rate", 103 | ) 104 | parser.add_argument( 105 | "--pool_num", type=int, default=8192, 106 | help="Number of pooling", 107 | ) 108 | parser.add_argument( 109 | "--pool_type", type=str, default='max', 110 | help="Type of pooling: avg, max", 111 | ) 112 | parser.add_argument( 113 | "--mse_mult", type=float, default=1e4, 114 | help="The weight of mse loss", 115 | ) 116 | parser.add_argument( 117 | "--rec_mult", type=float, default=0, 118 | help="The weight of brain reconstruction loss", 119 | ) 120 | parser.add_argument( 121 | "--cyc_mult", type=float, default=0, 122 | help="The weight of cycle loss", 123 | ) 124 | parser.add_argument( 125 | # "--length", type=int, default=8559, 126 | "--length", type=int, default=None, 127 | help="Indicate dataset length", 128 | ) 129 | parser.add_argument( 130 | "--autoencoder_name", type=str, default=None, 131 | help="name of trained autoencoder model", 132 | ) 133 | parser.add_argument( 134 | "--subj_load",type=int, default=None, choices=[1,2,3,4,5,6,7,8], nargs='+', 135 | help="subj want to be load in the model", 136 | ) 137 | parser.add_argument( 138 | "--subj_test",type=int, default=1, choices=[1,2,3,4,5,6,7,8], 139 | help="subj to test", 140 | ) 141 | parser.add_argument( 142 | "--samples",type=int, default=None, nargs='+', 143 | help="Specify sample indice to reconstruction" 144 | ) 145 | parser.add_argument( 146 | "--img2img_strength",type=float, default=.85, 147 | help="How much img2img (1=no img2img; 0=outputting the low-level image itself)", 148 | ) 149 | parser.add_argument( 150 | "--guidance_scale",type=float, default=3.5, 151 | help="Guidance scale for diffusion model.", 152 | ) 153 | parser.add_argument( 154 | "-num_inference_steps",type=int, default=20, 155 | help="Number of inference steps for diffusion model.", 156 | ) 157 | parser.add_argument( 158 | "--recons_per_sample", type=int, default=16, 159 | help="How many recons to output, to then automatically pick the best one (MindEye uses 16)", 160 | ) 161 | parser.add_argument( 162 | "--plotting", action=argparse.BooleanOptionalAction, default=True, 163 | help="plotting all the results", 164 | ) 165 | parser.add_argument( 166 | "--vd_cache_dir", type=str, default='../weights', 167 | help="Where is cached Versatile Diffusion model; if not cached will download to this path", 168 | ) 169 | parser.add_argument( 170 | "--gpu_id", type=int, default=0, 171 | help="ID of the GPU to be used", 172 | ) 173 | parser.add_argument( 174 | "--ckpt_from", type=str, default='last', 175 | help="ckpt_from ['last', 'best']", 176 | ) 177 | parser.add_argument( 178 | "--text_image_ratio", type=float, default=0.5, 179 | help="text_image_ratio in Versatile Diffusion. Only valid when use_text=True. 
0.5 means equally weight text and image, 0 means use only image", 180 | ) 181 | parser.add_argument( 182 | "--test_start", type=int, default=0, 183 | help="test range start index", 184 | ) 185 | parser.add_argument( 186 | "--test_end", type=int, default=None, 187 | help="test range end index, the total length of test data is 982, so max index is 981", 188 | ) 189 | parser.add_argument( 190 | "--only_embeddings", action=argparse.BooleanOptionalAction, default=False, 191 | help="only return semantic embeddings of networks", 192 | ) 193 | parser.add_argument( 194 | "--synthesis", action=argparse.BooleanOptionalAction, default=False, 195 | help="synthesize new fMRI signals", 196 | ) 197 | parser.add_argument( 198 | "--verbose", action=argparse.BooleanOptionalAction, default=True, 199 | help="print more information", 200 | ) 201 | parser.add_argument( 202 | "--results_path", type=str, default=None, 203 | help="path to reconstructed outputs", 204 | ) 205 | 206 | args = parser.parse_args() -------------------------------------------------------------------------------- /src/recon.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import torch 5 | from tqdm import tqdm 6 | from datetime import datetime 7 | 8 | import utils 9 | from models import Clipper, MindBridge, MindSingle, Voxel2StableDiffusionModel 10 | import data 11 | from options import args 12 | from eval import cal_metrics 13 | 14 | 15 | ## Load autoencoder 16 | def prepare_voxel2sd(args, ckpt_path, device): 17 | from models import Voxel2StableDiffusionModel 18 | checkpoint = torch.load(ckpt_path, map_location=device) 19 | state_dict = checkpoint['model_state_dict'] 20 | 21 | voxel2sd = Voxel2StableDiffusionModel(in_dim=args.num_voxels) 22 | 23 | voxel2sd.load_state_dict(state_dict,strict=False) 24 | voxel2sd.to(device) 25 | voxel2sd.eval() 26 | print("Loaded low-level model!") 27 | 28 | return voxel2sd 29 | 30 | def prepare_data(args): 31 | ## Load data 32 | subj_num_voxels = { 33 | 1: 15724, 34 | 2: 14278, 35 | 3: 15226, 36 | 4: 13153, 37 | 5: 13039, 38 | 6: 17907, 39 | 7: 12682, 40 | 8: 14386 41 | } 42 | args.num_voxels = subj_num_voxels[args.subj_test] 43 | 44 | test_path = "{}/webdataset_avg_split/test/subj0{}".format(args.data_path, args.subj_test) 45 | test_dl = data.get_dataloader( 46 | test_path, 47 | batch_size=args.batch_size, 48 | num_workers=args.num_workers, 49 | seed=args.seed, 50 | is_shuffle=False, 51 | extensions=['nsdgeneral.npy', "jpg", 'coco73k.npy', "subj"], 52 | pool_type=args.pool_type, 53 | pool_num=args.pool_num, 54 | ) 55 | 56 | return test_dl 57 | 58 | def prepare_VD(args, device): 59 | print('Creating versatile diffusion reconstruction pipeline...') 60 | from diffusers import VersatileDiffusionDualGuidedPipeline, UniPCMultistepScheduler 61 | from diffusers.models import DualTransformer2DModel 62 | try: 63 | vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(args.vd_cache_dir) 64 | except: 65 | print("Downloading Versatile Diffusion to", args.vd_cache_dir) 66 | vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( 67 | "shi-labs/versatile-diffusion", 68 | cache_dir = args.vd_cache_dir) 69 | 70 | vd_pipe.image_unet.eval().to(device) 71 | vd_pipe.vae.eval().to(device) 72 | vd_pipe.image_unet.requires_grad_(False) 73 | vd_pipe.vae.requires_grad_(False) 74 | 75 | vd_pipe.scheduler = UniPCMultistepScheduler.from_pretrained("shi-labs/versatile-diffusion", cache_dir=args.vd_cache_dir, 
subfolder="scheduler") 76 | 77 | # Set weighting of Dual-Guidance 78 | # text_image_ratio=0.5 means equally weight text and image, 0 means use only image 79 | for name, module in vd_pipe.image_unet.named_modules(): 80 | if isinstance(module, DualTransformer2DModel): 81 | module.mix_ratio = args.text_image_ratio 82 | for i, type in enumerate(("text", "image")): 83 | if type == "text": 84 | module.condition_lengths[i] = 77 85 | module.transformer_index_for_condition[i] = 1 # use the second (text) transformer 86 | else: 87 | module.condition_lengths[i] = 257 88 | module.transformer_index_for_condition[i] = 0 # use the first (image) transformer 89 | 90 | return vd_pipe 91 | 92 | def prepare_CLIP(args, device): 93 | clip_sizes = {"RN50": 1024, "ViT-L/14": 768, "ViT-B/32": 512, "ViT-H-14": 1024} 94 | clip_size = clip_sizes[args.clip_variant] 95 | out_dim_image = 257 * clip_size 96 | out_dim_text = 77 * clip_size 97 | clip_extractor = Clipper("ViT-L/14", hidden_state=True, norm_embs=True, device=device) 98 | 99 | return clip_extractor, out_dim_image, out_dim_text 100 | 101 | def prepare_voxel2clip(args, out_dim_image, out_dim_text, device): 102 | voxel2clip_kwargs = dict( 103 | in_dim=args.pool_num, out_dim_image=out_dim_image, out_dim_text=out_dim_text, 104 | h=args.h_size, n_blocks=args.n_blocks, subj_list=args.subj_load) 105 | 106 | # only need to load Single-subject version of MindBridge 107 | voxel2clip = MindSingle(**voxel2clip_kwargs) 108 | 109 | outdir = f'../train_logs/{args.model_name}' 110 | ckpt_path = os.path.join(outdir, f'{args.ckpt_from}.pth') 111 | print("ckpt_path",ckpt_path) 112 | checkpoint = torch.load(ckpt_path, map_location='cpu') 113 | print("EPOCH: ",checkpoint['epoch']) 114 | state_dict = checkpoint['model_state_dict'] 115 | 116 | voxel2clip.load_state_dict(state_dict,strict=False) 117 | voxel2clip.requires_grad_(False) 118 | voxel2clip.eval().to(device) 119 | 120 | return voxel2clip 121 | 122 | def main(device): 123 | args.batch_size = 1 124 | if args.subj_load is None: 125 | args.subj_load = [args.subj_test] 126 | 127 | # Load data 128 | test_dl = prepare_data(args) 129 | num_test = len(test_dl) 130 | 131 | # Load autoencoder 132 | outdir_ae = f'../train_logs/{args.autoencoder_name}' 133 | ckpt_path = os.path.join(outdir_ae, f'epoch120.pth') 134 | if os.path.exists(ckpt_path): 135 | voxel2sd = prepare_voxel2sd(args, ckpt_path, device) 136 | # pool later 137 | args.pool_type = None 138 | else: 139 | print("No valid path for low-level model specified; not using img2img!") 140 | args.img2img_strength = 1 141 | 142 | # Load VD pipeline 143 | vd_pipe = prepare_VD(args, device) 144 | unet = vd_pipe.image_unet 145 | vae = vd_pipe.vae 146 | noise_scheduler = vd_pipe.scheduler 147 | 148 | # Load CLIP 149 | clip_extractor, out_dim_image, out_dim_text = prepare_CLIP(args, device) 150 | 151 | # load voxel2clip 152 | voxel2clip = prepare_voxel2clip(args, out_dim_image, out_dim_text, device) 153 | 154 | outdir = f'../train_logs/{args.model_name}' 155 | save_dir = os.path.join(outdir, f"recon_on_subj{args.subj_test}") 156 | os.makedirs(save_dir, exist_ok=True) 157 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S')) 158 | 159 | # define test range 160 | test_range = np.arange(num_test) 161 | if args.test_end is None: 162 | args.test_end = num_test 163 | 164 | # define recon logic 165 | only_lowlevel = False 166 | if args.img2img_strength == 1: 167 | img2img = False 168 | elif args.img2img_strength == 0: 169 | img2img = True 170 | only_lowlevel = True 171 | else: 172 | img2img = True 
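# Summary of the branch above (matches the --img2img_strength help in options.py):
#   img2img_strength == 1   -> img2img disabled; only the high-level (semantic) pipeline is used
#   img2img_strength == 0   -> only_lowlevel; the low-level (autoencoder) reconstruction is returned as-is
#   0 < img2img_strength < 1 -> the low-level reconstruction seeds img2img in the high-level pipeline
# (If no autoencoder checkpoint was found above, img2img_strength has already been forced to 1.)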
173 | 174 | # recon loop 175 | for val_i, (voxel, img, coco, subj) in enumerate(tqdm(test_dl,total=len(test_range))): 176 | if val_i < args.test_start: 177 | continue 178 | if val_i >= args.test_end: 179 | break 180 | if (args.samples is not None) and (val_i not in args.samples): 181 | continue 182 | 183 | voxel = torch.mean(voxel,axis=1).float().to(device) 184 | img = img.to(device) 185 | 186 | with torch.no_grad(): 187 | if args.only_embeddings: 188 | results = voxel2clip(voxel) 189 | embeddings = results[:2] 190 | torch.save(embeddings, os.path.join(save_dir, f'embeddings_{val_i}.pt')) 191 | continue 192 | if img2img: # will apply low-level and high-level pipeline 193 | ae_preds = voxel2sd(voxel) 194 | blurry_recons = vd_pipe.vae.decode(ae_preds.to(device)/0.18215).sample / 2 + 0.5 195 | 196 | if val_i==0: 197 | plt.imshow(utils.torch_to_Image(blurry_recons)) 198 | plt.show() 199 | 200 | # pooling 201 | voxel = data.pool_voxels(voxel, args.pool_num, args.pool_type) 202 | else: # only high-level pipeline 203 | blurry_recons = None 204 | 205 | if only_lowlevel: # only low-level pipeline 206 | brain_recons = blurry_recons 207 | else: 208 | grid, brain_recons, best_picks, recon_img = utils.reconstruction( 209 | img, voxel, voxel2clip, 210 | clip_extractor, unet, vae, noise_scheduler, 211 | img_lowlevel = blurry_recons, 212 | num_inference_steps = args.num_inference_steps, 213 | n_samples_save = args.batch_size, 214 | recons_per_sample = args.recons_per_sample, 215 | guidance_scale = args.guidance_scale, 216 | img2img_strength = args.img2img_strength, # 0=fully rely on img_lowlevel, 1=not doing img2img 217 | seed = args.seed, 218 | plotting = args.plotting, 219 | verbose = args.verbose, 220 | device=device, 221 | mem_efficient=False, 222 | ) 223 | 224 | if args.plotting: 225 | grid.savefig(os.path.join(save_dir, f'{val_i}.png')) 226 | 227 | brain_recons = brain_recons[:,best_picks.astype(np.int8)] 228 | 229 | torch.save(img, os.path.join(save_dir, f'{val_i}_img.pt')) 230 | torch.save(brain_recons, os.path.join(save_dir, f'{val_i}_rec.pt')) 231 | 232 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S')) 233 | print("save path:", save_dir) 234 | 235 | if __name__ == "__main__": 236 | utils.seed_everything(seed=args.seed) 237 | 238 | device = torch.device('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu') 239 | print("device:",device) 240 | 241 | main(device) 242 | 243 | # args.results_path = f'../train_logs/{args.model_name}/recon_on_subj{args.subj_test}' 244 | # cal_metrics(args.results_path, device) -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import scipy as sp 4 | import pandas as pd 5 | import torch 6 | from torchvision import transforms 7 | from tqdm import tqdm 8 | from options import args 9 | import utils 10 | from torchvision.models.feature_extraction import create_feature_extractor 11 | 12 | @torch.no_grad() 13 | def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True, device='cpu'): 14 | preds = model(torch.stack([preprocess(recon) for recon in all_brain_recons], dim=0).to(device)) 15 | reals = model(torch.stack([preprocess(indiv) for indiv in all_images], dim=0).to(device)) 16 | if feature_layer is None: 17 | preds = preds.float().flatten(1).cpu().numpy() 18 | reals = reals.float().flatten(1).cpu().numpy() 19 | else: 20 | preds = 
preds[feature_layer].float().flatten(1).cpu().numpy() 21 | reals = reals[feature_layer].float().flatten(1).cpu().numpy() 22 | 23 | r = np.corrcoef(reals, preds) 24 | r = r[:len(all_images), len(all_images):] 25 | congruents = np.diag(r) 26 | 27 | success = r < congruents 28 | success_cnt = np.sum(success, 0) 29 | 30 | if return_avg: 31 | perf = np.mean(success_cnt) / (len(all_images)-1) 32 | return perf 33 | else: 34 | return success_cnt, len(all_images)-1 35 | 36 | def cal_metrics(results_path, device): 37 | # Load all images and brain recons 38 | all_files = [f for f in os.listdir(results_path)] 39 | 40 | number = 0 41 | all_images, all_brain_recons = None, None 42 | all_images, all_brain_recons = [], [] 43 | for file in tqdm(all_files): 44 | if file.endswith("_img.pt"): 45 | if file.replace("_img.pt", "_rec.pt") in all_files: 46 | number += 1 47 | all_images.append(torch.load(os.path.join(results_path, file), map_location=device)) 48 | all_brain_recons.append(torch.load(os.path.join(results_path, file.replace("_img.pt", "_rec.pt")), map_location=device)) 49 | 50 | all_images = torch.vstack(all_images) 51 | all_brain_recons = torch.vstack(all_brain_recons) 52 | all_images = all_images.to(device) 53 | all_brain_recons = all_brain_recons.to(device).to(all_images.dtype).clamp(0,1).squeeze() 54 | 55 | print("Images shape:", all_images.shape) 56 | print("Recons shape:", all_brain_recons.shape) 57 | print("Number:", number) 58 | 59 | ### PixCorr 60 | preprocess = transforms.Compose([ 61 | transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR), 62 | ]) 63 | 64 | # Flatten images while keeping the batch dimension 65 | all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu() 66 | all_brain_recons_flattened = preprocess(all_brain_recons).view(len(all_brain_recons), -1).cpu() 67 | 68 | print(all_images_flattened.shape) 69 | print(all_brain_recons_flattened.shape) 70 | 71 | print("\n------calculating pixcorr------") 72 | corrsum = 0 73 | for i in tqdm(range(number)): 74 | corrsum += np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1] 75 | corrmean = corrsum / number 76 | 77 | pixcorr = corrmean 78 | print(pixcorr) 79 | 80 | del all_images_flattened 81 | del all_brain_recons_flattened 82 | torch.cuda.empty_cache() 83 | 84 | ### SSIM 85 | # see https://github.com/zijin-gu/meshconv-decoding/issues/3 86 | from skimage.color import rgb2gray 87 | from skimage.metrics import structural_similarity as ssim 88 | 89 | preprocess = transforms.Compose([ 90 | transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR), 91 | ]) 92 | 93 | # convert image to grayscale with rgb2grey 94 | img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu()) 95 | recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu()) 96 | print("converted, now calculating ssim...") 97 | 98 | ssim_score=[] 99 | for im,rec in tqdm(zip(img_gray,recon_gray),total=len(all_images)): 100 | ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0)) 101 | 102 | ssim = np.mean(ssim_score) 103 | print(ssim) 104 | 105 | 106 | #### AlexNet 107 | from torchvision.models import alexnet, AlexNet_Weights 108 | alex_weights = AlexNet_Weights.IMAGENET1K_V1 109 | 110 | alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4','features.11']).to(device) 111 | alex_model.eval().requires_grad_(False) 112 | 113 | # see alex_weights.transforms() 114 
| preprocess = transforms.Compose([ 115 | transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR), 116 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 117 | std=[0.229, 0.224, 0.225]), 118 | ]) 119 | 120 | layer = 'early, AlexNet(2)' 121 | print(f"\n---{layer}---") 122 | all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images, 123 | alex_model, preprocess, 'features.4', device=device) 124 | alexnet2 = np.mean(all_per_correct) 125 | print(f"2-way Percent Correct: {alexnet2:.4f}") 126 | 127 | layer = 'mid, AlexNet(5)' 128 | print(f"\n---{layer}---") 129 | all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images, 130 | alex_model, preprocess, 'features.11', device=device) 131 | alexnet5 = np.mean(all_per_correct) 132 | print(f"2-way Percent Correct: {alexnet5:.4f}") 133 | 134 | del alex_model 135 | torch.cuda.empty_cache() 136 | 137 | #### InceptionV3 138 | from torchvision.models import inception_v3, Inception_V3_Weights 139 | weights = Inception_V3_Weights.DEFAULT 140 | inception_model = create_feature_extractor(inception_v3(weights=weights), 141 | return_nodes=['avgpool']).to(device) 142 | inception_model.eval().requires_grad_(False) 143 | 144 | # see weights.transforms() 145 | preprocess = transforms.Compose([ 146 | transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR), 147 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 148 | std=[0.229, 0.224, 0.225]), 149 | ]) 150 | 151 | all_per_correct = two_way_identification(all_brain_recons, all_images, 152 | inception_model, preprocess, 'avgpool', device=device) 153 | 154 | inception = np.mean(all_per_correct) 155 | print(f"2-way Percent Correct: {inception:.4f}") 156 | 157 | del inception_model 158 | torch.cuda.empty_cache() 159 | 160 | 161 | #### CLIP 162 | import clip 163 | clip_model, preprocess = clip.load("ViT-L/14", device=device) 164 | 165 | preprocess = transforms.Compose([ 166 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR), 167 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 168 | std=[0.26862954, 0.26130258, 0.27577711]), 169 | ]) 170 | 171 | all_per_correct = two_way_identification(all_brain_recons, all_images, 172 | clip_model.encode_image, preprocess, None, device=device) # final layer 173 | clip_ = np.mean(all_per_correct) 174 | print(f"2-way Percent Correct: {clip_:.4f}") 175 | 176 | 177 | del clip_model 178 | torch.cuda.empty_cache() 179 | 180 | 181 | #### Efficient Net 182 | from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights 183 | weights = EfficientNet_B1_Weights.DEFAULT 184 | eff_model = create_feature_extractor(efficientnet_b1(weights=weights), 185 | return_nodes=['avgpool']).to(device) 186 | eff_model.eval().requires_grad_(False) 187 | 188 | # see weights.transforms() 189 | preprocess = transforms.Compose([ 190 | transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR), 191 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 192 | std=[0.229, 0.224, 0.225]), 193 | ]) 194 | 195 | gt = eff_model(preprocess(all_images))['avgpool'] 196 | gt = gt.reshape(len(gt),-1).cpu().numpy() 197 | fake = eff_model(preprocess(all_brain_recons))['avgpool'] 198 | fake = fake.reshape(len(fake),-1).cpu().numpy() 199 | 200 | effnet = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean() 201 | print("Distance:",effnet) 202 | 203 | 204 | del eff_model 205 | torch.cuda.empty_cache() 206 | 207 | 208 | #### SwAV 209 | 
swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50') 210 | swav_model = create_feature_extractor(swav_model, 211 | return_nodes=['avgpool']).to(device) 212 | swav_model.eval().requires_grad_(False) 213 | 214 | preprocess = transforms.Compose([ 215 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR), 216 | transforms.Normalize(mean=[0.485, 0.456, 0.406], 217 | std=[0.229, 0.224, 0.225]), 218 | ]) 219 | 220 | gt = swav_model(preprocess(all_images))['avgpool'] 221 | gt = gt.reshape(len(gt),-1).cpu().numpy() 222 | fake = swav_model(preprocess(all_brain_recons))['avgpool'] 223 | fake = fake.reshape(len(fake),-1).cpu().numpy() 224 | 225 | swav = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean() 226 | print("Distance:",swav,"\n") 227 | 228 | 229 | del swav_model 230 | torch.cuda.empty_cache() 231 | 232 | 233 | # # Display in table 234 | # Create a dictionary to store variable names and their corresponding values 235 | data = { 236 | "Metric": ["PixCorr", "SSIM", "AlexNet(2)", "AlexNet(5)", "InceptionV3", "CLIP", "EffNet-B", "SwAV"], 237 | "Value": [pixcorr, ssim, alexnet2, alexnet5, inception, clip_, effnet, swav], 238 | } 239 | print(results_path) 240 | df = pd.DataFrame(data) 241 | print(df.to_string(index=False)) 242 | 243 | 244 | # save table to txt file 245 | df.to_csv(os.path.join(results_path, f'_metrics_on_{number}samples.csv'), sep='\t', index=False) 246 | 247 | if __name__ == "__main__": 248 | utils.seed_everything(seed=args.seed) 249 | 250 | device = torch.device('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu') 251 | print("device:", device) 252 | 253 | assert args.results_path is not None, "Please specify the path to the results folder" 254 | cal_metrics(args.results_path, device) -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torchvision import transforms 3 | import torch 4 | import torch.nn as nn 5 | import PIL 6 | import random 7 | import os 8 | import matplotlib.pyplot as plt 9 | import torchsnooper 10 | 11 | 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | def seed_everything(seed=0, cudnn_deterministic=True): 15 | random.seed(seed) 16 | os.environ['PYTHONHASHSEED'] = str(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | if cudnn_deterministic: 22 | torch.backends.cudnn.deterministic = True 23 | else: 24 | ## needs to be False to use conv3D 25 | print('Note: not using cudnn.deterministic') 26 | 27 | def np_to_Image(x): 28 | if x.ndim==4: 29 | x=x[0] 30 | return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0,255).astype('uint8')) 31 | 32 | def torch_to_Image(x): 33 | if x.ndim==4: 34 | x=x[0] 35 | return transforms.ToPILImage()(x) 36 | 37 | def Image_to_torch(x): 38 | try: 39 | x = (transforms.ToTensor()(x)[:3].unsqueeze(0)-.5)/.5 40 | except: 41 | x = (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5 42 | return x 43 | 44 | def torch_to_matplotlib(x,device=device): 45 | if torch.mean(x)>10: 46 | x = (x.permute(0, 2, 3, 1)).clamp(0, 255).to(torch.uint8) 47 | else: 48 | x = (x.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8) 49 | if device=='cpu': 50 | return x[0] 51 | else: 52 | return x.cpu().numpy()[0] 53 | 54 | def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8): 55 | 
#https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements 56 | numerator = A @ B.T 57 | A_l2 = torch.mul(A, A).sum(axis=dim) 58 | B_l2 = torch.mul(B, B).sum(axis=dim) 59 | denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps)) 60 | return torch.div(numerator, denominator) 61 | 62 | def batchwise_cosine_similarity(Z,B): 63 | # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc 64 | B = B.T 65 | Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True) # Size (n, 1). 66 | B_norm = torch.linalg.norm(B, dim=0, keepdim=True) # Size (1, b). 67 | cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T 68 | return cosine_similarity 69 | 70 | def topk(similarities,labels,k=5): 71 | if k > similarities.shape[0]: 72 | k = similarities.shape[0] 73 | topsum=0 74 | for i in range(k): 75 | topsum += torch.sum(torch.argsort(similarities,axis=1)[:,-(i+1)] == labels)/len(labels) 76 | return topsum 77 | 78 | def soft_clip_loss(preds, targs, temp=0.005, eps=1e-10): 79 | clip_clip = (targs @ targs.T)/temp + eps 80 | check_loss(clip_clip, "clip_clip") 81 | brain_clip = (preds @ targs.T)/temp + eps 82 | check_loss(brain_clip, "brain_clip") 83 | 84 | loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean() 85 | check_loss(loss1, "loss1") 86 | loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean() 87 | check_loss(loss2, "loss2") 88 | 89 | loss = (loss1 + loss2)/2 90 | return loss 91 | 92 | def count_params(model): 93 | total = sum(p.numel() for p in model.parameters()) 94 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) 95 | print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable)) 96 | 97 | def image_grid(imgs, rows, cols): 98 | w, h = imgs[0].size 99 | grid = PIL.Image.new('RGB', size=(cols*w, rows*h)) 100 | for i, img in enumerate(imgs): 101 | grid.paste(img, box=(i%cols*w, i//cols*h)) 102 | return grid 103 | 104 | def check_loss(loss, message="loss"): 105 | if loss.isnan().any(): 106 | raise ValueError(f'NaN loss in {message}') 107 | 108 | def cosine_anneal(start, end, steps): 109 | return end + (start - end)/2 * (1 + torch.cos(torch.pi*torch.arange(steps)/(steps-1))) 110 | 111 | def decode_latents(latents,vae): 112 | latents = 1 / 0.18215 * latents 113 | image = vae.decode(latents).sample 114 | image = (image / 2 + 0.5).clamp(0, 1) 115 | return image 116 | 117 | @torch.no_grad() 118 | def reconstruction( 119 | image, voxel, voxel2clip, 120 | clip_extractor, 121 | unet, vae, noise_scheduler, 122 | img_lowlevel = None, 123 | num_inference_steps = 50, 124 | recons_per_sample = 1, 125 | guidance_scale = 7.5, 126 | img2img_strength = .85, 127 | seed = 42, 128 | plotting=True, 129 | verbose=False, 130 | n_samples_save=1, 131 | device = None, 132 | mem_efficient = True, 133 | 134 | ): 135 | assert n_samples_save==1, "n_samples_save must = 1. 
Function must be called one image at a time" 136 | assert recons_per_sample>0, "recons_per_sample must > 0" 137 | 138 | brain_recons = None 139 | 140 | voxel=voxel[:n_samples_save] 141 | image=image[:n_samples_save] 142 | B = voxel.shape[0] 143 | 144 | if mem_efficient: 145 | clip_extractor.to("cpu") 146 | unet.to("cpu") 147 | vae.to("cpu") 148 | else: 149 | clip_extractor.to(device) 150 | unet.to(device) 151 | vae.to(device) 152 | 153 | if unet is not None: 154 | do_classifier_free_guidance = guidance_scale > 1.0 155 | vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) 156 | height = unet.config.sample_size * vae_scale_factor 157 | width = unet.config.sample_size * vae_scale_factor 158 | generator = torch.Generator(device=device) 159 | generator.manual_seed(seed) 160 | 161 | if voxel2clip is not None: 162 | clip_results = voxel2clip(voxel) 163 | 164 | if mem_efficient: 165 | voxel2clip.to('cpu') 166 | 167 | brain_clip_image_embeddings, brain_clip_text_embeddings = clip_results[:2] 168 | brain_clip_image_embeddings = brain_clip_image_embeddings.reshape(B,-1,768) 169 | brain_clip_text_embeddings = brain_clip_text_embeddings.reshape(B,-1,768) 170 | 171 | brain_clip_image_embeddings = brain_clip_image_embeddings.repeat(recons_per_sample, 1, 1) 172 | brain_clip_text_embeddings = brain_clip_text_embeddings.repeat(recons_per_sample, 1, 1) 173 | 174 | if recons_per_sample > 0: 175 | for samp in range(len(brain_clip_image_embeddings)): 176 | brain_clip_image_embeddings[samp] = brain_clip_image_embeddings[samp]/(brain_clip_image_embeddings[samp,0].norm(dim=-1).reshape(-1, 1, 1) + 1e-6) 177 | brain_clip_text_embeddings[samp] = brain_clip_text_embeddings[samp]/(brain_clip_text_embeddings[samp,0].norm(dim=-1).reshape(-1, 1, 1) + 1e-6) 178 | input_embedding = brain_clip_image_embeddings#.repeat(recons_per_sample, 1, 1) 179 | if verbose: print("input_embedding",input_embedding.shape) 180 | 181 | prompt_embeds = brain_clip_text_embeddings 182 | if verbose: print("prompt_embedding",prompt_embeds.shape) 183 | 184 | if do_classifier_free_guidance: 185 | input_embedding = torch.cat([torch.zeros_like(input_embedding), input_embedding]).to(device).to(unet.dtype) 186 | prompt_embeds = torch.cat([torch.zeros_like(prompt_embeds), prompt_embeds]).to(device).to(unet.dtype) 187 | 188 | # 3. dual_prompt_embeddings 189 | input_embedding = torch.cat([prompt_embeds, input_embedding], dim=1) 190 | 191 | # 4. Prepare timesteps 192 | noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device) 193 | 194 | # 5b. 
Prepare latent variables 195 | batch_size = input_embedding.shape[0] // 2 # divide by 2 bc we doubled it for classifier-free guidance 196 | shape = (batch_size, unet.in_channels, height // vae_scale_factor, width // vae_scale_factor) 197 | if img_lowlevel is not None: # use img_lowlevel for img2img initialization 198 | init_timestep = min(int(num_inference_steps * img2img_strength), num_inference_steps) 199 | t_start = max(num_inference_steps - init_timestep, 0) 200 | timesteps = noise_scheduler.timesteps[t_start:] 201 | latent_timestep = timesteps[:1].repeat(batch_size) 202 | 203 | if verbose: print("img_lowlevel", img_lowlevel.shape) 204 | img_lowlevel_embeddings = clip_extractor.normalize(img_lowlevel) 205 | if verbose: print("img_lowlevel_embeddings", img_lowlevel_embeddings.shape) 206 | if mem_efficient: 207 | vae.to(device) 208 | init_latents = vae.encode(img_lowlevel_embeddings.to(device).to(vae.dtype)).latent_dist.sample(generator) 209 | init_latents = vae.config.scaling_factor * init_latents 210 | init_latents = init_latents.repeat(recons_per_sample, 1, 1, 1) 211 | 212 | noise = torch.randn([recons_per_sample, 4, 64, 64], device=device, 213 | generator=generator, dtype=input_embedding.dtype) 214 | init_latents = noise_scheduler.add_noise(init_latents, noise, latent_timestep) 215 | latents = init_latents 216 | else: 217 | timesteps = noise_scheduler.timesteps 218 | latents = torch.randn([recons_per_sample, 4, 64, 64], device=device, 219 | generator=generator, dtype=input_embedding.dtype) 220 | latents = latents * noise_scheduler.init_noise_sigma 221 | 222 | # 7. Denoising loop 223 | if mem_efficient: 224 | unet.to(device) 225 | for i, t in enumerate(timesteps): 226 | # expand the latents if we are doing classifier free guidance 227 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 228 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 229 | if verbose: print("timesteps: {}, latent_model_input: {}, input_embedding: {}".format(i, latent_model_input.shape, input_embedding.shape)) 230 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_embedding).sample 231 | 232 | # perform guidance 233 | if do_classifier_free_guidance: 234 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 235 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 236 | 237 | # compute the previous noisy sample x_t -> x_t-1 238 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample 239 | 240 | if mem_efficient: 241 | unet.to("cpu") 242 | 243 | recons = decode_latents(latents.to(device),vae.to(device)).detach().cpu() 244 | 245 | brain_recons = recons.unsqueeze(0) 246 | 247 | if verbose: print("brain_recons",brain_recons.shape) 248 | 249 | # pick best reconstruction out of several 250 | best_picks = np.zeros(n_samples_save).astype(np.int16) 251 | 252 | if mem_efficient: 253 | vae.to("cpu") 254 | unet.to("cpu") 255 | clip_extractor.to(device) 256 | 257 | clip_image_target = clip_extractor.embed_image(image) 258 | clip_image_target_norm = nn.functional.normalize(clip_image_target.flatten(1), dim=-1) 259 | sims=[] 260 | for im in range(recons_per_sample): 261 | currecon = clip_extractor.embed_image(brain_recons[0,[im]].float()).to(clip_image_target_norm.device).to(clip_image_target_norm.dtype) 262 | currecon = nn.functional.normalize(currecon.view(len(currecon),-1),dim=-1) 263 | cursim = batchwise_cosine_similarity(clip_image_target_norm,currecon) 264 | 
sims.append(cursim.item()) 265 | if verbose: print(sims) 266 | best_picks[0] = int(np.nanargmax(sims)) 267 | if verbose: print(best_picks) 268 | if mem_efficient: 269 | clip_extractor.to("cpu") 270 | voxel2clip.to(device) 271 | 272 | img2img_samples = 0 if img_lowlevel is None else 1 273 | num_xaxis_subplots = 1+img2img_samples+recons_per_sample 274 | if plotting: 275 | fig, ax = plt.subplots(n_samples_save, num_xaxis_subplots, 276 | figsize=(num_xaxis_subplots*5,6*n_samples_save),facecolor=(1, 1, 1)) 277 | else: 278 | fig = None 279 | recon_img = None 280 | 281 | im = 0 282 | if plotting: 283 | ax[0].set_title(f"Original Image") 284 | ax[0].imshow(torch_to_Image(image[im])) 285 | if img2img_samples == 1: 286 | ax[1].set_title(f"Img2img ({img2img_strength})") 287 | ax[1].imshow(torch_to_Image(img_lowlevel[im].clamp(0,1))) 288 | for ii,i in enumerate(range(num_xaxis_subplots-recons_per_sample,num_xaxis_subplots)): 289 | recon = brain_recons[im][ii] 290 | if plotting: 291 | if ii == best_picks[im]: 292 | ax[i].set_title(f"Reconstruction",fontweight='bold') 293 | recon_img = recon 294 | else: 295 | ax[i].set_title(f"Recon {ii+1} from brain") 296 | ax[i].imshow(torch_to_Image(recon)) 297 | if plotting: 298 | for i in range(num_xaxis_subplots): 299 | ax[i].axis('off') 300 | 301 | return fig, brain_recons, best_picks, recon_img 302 | -------------------------------------------------------------------------------- /src/models.py: -------------------------------------------------------------------------------- 1 | import clip 2 | import torchsnooper 3 | import math 4 | import random 5 | import torch 6 | import torch.nn as nn 7 | import numpy as np 8 | from torchvision import transforms 9 | 10 | def add_hooks(module, parent_name=''): 11 | module_name = module.__class__.__name__ 12 | if parent_name: 13 | module_name = f'{parent_name}.{module_name}' 14 | 15 | module.register_forward_hook(lambda mod, inp, out: forward_hook(mod, inp, out, module_name)) 16 | module.register_backward_hook(lambda mod, inp, out: backward_hook(mod, inp, out, module_name)) 17 | 18 | for name, child in module.named_children(): 19 | add_hooks(child, parent_name=f'{module_name}.{name}') 20 | 21 | def forward_hook(module, input, output, name): 22 | if output.isnan().any(): 23 | print(f"NaN detected in forward pass in module: {name}") 24 | print(f"Input: {input}") 25 | print(f"Output: {output}") 26 | 27 | def backward_hook(module, grad_input, grad_output, name): 28 | if any(tensor is not None and torch.isnan(tensor).any() for tensor in [*grad_input, *grad_output]): 29 | print(f"NaN detected in backward pass in module: {name}") 30 | print(f"Grad Input: {grad_input}") 31 | print(f"Grad Output: {grad_output}") 32 | 33 | class Clipper(torch.nn.Module): 34 | def __init__(self, clip_variant, clamp_embs=False, norm_embs=False, 35 | hidden_state=False, device=torch.device('cpu')): 36 | super().__init__() 37 | assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \ 38 | "clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64" 39 | print(clip_variant, device) 40 | 41 | if clip_variant=="ViT-L/14" and hidden_state: 42 | from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPTokenizer 43 | image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval() 44 | image_encoder = image_encoder.to(device) 45 | for param in image_encoder.parameters(): 46 | param.requires_grad = False # dont need to calculate gradients 47 | self.image_encoder = image_encoder 
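            # Note: with hidden_state=True, embed_image returns per-token CLIP hidden states projected
            # into the shared 768-d space (rather than a single pooled embedding), which is what the
            # reconstruction code later reshapes to (B, -1, 768). Illustrative construction only --
            # the argument values below are assumptions, not taken from this repo's scripts:
            #   clip_extractor = Clipper("ViT-L/14", hidden_state=True, norm_embs=True,
            #                            device=torch.device("cuda"))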
48 | 49 | text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval() 50 | text_encoder = text_encoder.to(device) 51 | for param in text_encoder.parameters(): 52 | param.requires_grad = False # dont need to calculate gradients 53 | self.text_encoder = text_encoder 54 | self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") 55 | 56 | elif hidden_state: 57 | raise Exception("hidden_state embeddings only works with ViT-L/14 right now") 58 | 59 | clip_model, preprocess = clip.load(clip_variant, device=device) 60 | clip_model.eval() # dont want to train model 61 | for param in clip_model.parameters(): 62 | param.requires_grad = False # dont need to calculate gradients 63 | 64 | self.clip = clip_model 65 | self.clip_variant = clip_variant 66 | if clip_variant == "RN50x64": 67 | self.clip_size = (448,448) 68 | else: 69 | self.clip_size = (224,224) 70 | 71 | preproc = transforms.Compose([ 72 | transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC, antialias=None), 73 | transforms.CenterCrop(size=self.clip_size), 74 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711)) 75 | ]) 76 | self.preprocess = preproc 77 | self.hidden_state = hidden_state 78 | self.mean = np.array([0.48145466, 0.4578275, 0.40821073]) 79 | self.std = np.array([0.26862954, 0.26130258, 0.27577711]) 80 | self.normalize = transforms.Normalize(self.mean, self.std) 81 | self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist()) 82 | self.clamp_embs = clamp_embs 83 | self.norm_embs = norm_embs 84 | self.device= device 85 | 86 | def versatile_normalize_embeddings(encoder_output): 87 | embeds = encoder_output.last_hidden_state 88 | embeds = image_encoder.vision_model.post_layernorm(embeds) 89 | embeds = image_encoder.visual_projection(embeds) 90 | return embeds 91 | self.versatile_normalize_embeddings = versatile_normalize_embeddings 92 | 93 | def resize_image(self, image): 94 | # note: antialias should be False if planning to use Pinkney's Image Variation SD model 95 | return transforms.Resize(self.clip_size, antialias=None)(image.to(self.device)) 96 | 97 | def embed_image(self, image): 98 | """Expects images in -1 to 1 range""" 99 | if self.hidden_state: 100 | # clip_emb = self.preprocess((image/1.5+.25).to(self.device)) # for some reason the /1.5+.25 prevents oversaturation 101 | clip_emb = self.preprocess((image).to(self.device)) 102 | clip_emb = self.image_encoder(clip_emb) 103 | clip_emb = self.versatile_normalize_embeddings(clip_emb) 104 | else: 105 | clip_emb = self.preprocess(image.to(self.device)) 106 | clip_emb = self.clip.encode_image(clip_emb) 107 | # input is now in CLIP space, but mind-reader preprint further processes embeddings: 108 | if self.clamp_embs: 109 | clip_emb = torch.clamp(clip_emb, -1.5, 1.5) 110 | if self.norm_embs: 111 | if self.hidden_state: 112 | # normalize all tokens by cls token's norm 113 | clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1) 114 | else: 115 | clip_emb = nn.functional.normalize(clip_emb, dim=-1) 116 | return clip_emb 117 | 118 | def embed_text(self, prompt): 119 | r""" 120 | Encodes the prompt into text encoder hidden states. 
121 | 122 | Args: 123 | prompt (`str` or `List[str]`): 124 | prompt to be encoded 125 | device: (`torch.device`): 126 | torch device 127 | num_images_per_prompt (`int`): 128 | number of images that should be generated per prompt 129 | do_classifier_free_guidance (`bool`): 130 | whether to use classifier free guidance or not 131 | """ 132 | 133 | def normalize_embeddings(encoder_output): 134 | embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state) 135 | embeds_pooled = encoder_output.text_embeds 136 | embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True) 137 | return embeds 138 | 139 | text_inputs = self.tokenizer( 140 | prompt, 141 | padding="max_length", 142 | max_length=self.tokenizer.model_max_length, 143 | truncation=True, 144 | return_tensors="pt", 145 | ) 146 | text_input_ids = text_inputs.input_ids 147 | untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids 148 | with torch.no_grad(): 149 | prompt_embeds = self.text_encoder( 150 | text_input_ids.to(self.device), 151 | ) 152 | prompt_embeds = normalize_embeddings(prompt_embeds) 153 | 154 | # duplicate text embeddings for each generation per prompt, using mps friendly method 155 | # bs_embed, seq_len, _ = prompt_embeds.shape 156 | # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 157 | # prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 158 | 159 | return prompt_embeds 160 | 161 | def embed_curated_annotations(self, annots): 162 | for i,b in enumerate(annots): 163 | t = '' 164 | while t == '': 165 | rand = torch.randint(5,(1,1))[0][0] 166 | t = b[0,rand] 167 | if i==0: 168 | txt = np.array(t) 169 | else: 170 | txt = np.vstack((txt,t)) 171 | txt = txt.flatten() 172 | return self.embed_text(txt) 173 | 174 | class Adapter_Layer(nn.Module): 175 | def __init__(self, 176 | in_channels, 177 | bottleneck=32, 178 | out_channels=None, 179 | dropout=0.0, 180 | init_option="lora", 181 | adapter_scalar="1.0", 182 | adapter_layernorm_option=None): 183 | super().__init__() 184 | self.in_channels = in_channels 185 | self.down_size = bottleneck 186 | self.out_channels = out_channels if out_channels is not None else in_channels 187 | # self.non_linearity = args.non_linearity # use ReLU by default 188 | 189 | #_before 190 | self.adapter_layernorm_option = adapter_layernorm_option 191 | 192 | self.adapter_layer_norm = None 193 | if adapter_layernorm_option == "in" or adapter_layernorm_option == "out": 194 | self.adapter_layer_norm = nn.LayerNorm(self.n_embd) 195 | 196 | if adapter_scalar == "learnable_scalar": 197 | self.scale = nn.Parameter(torch.ones(1)) 198 | else: 199 | self.scale = float(adapter_scalar) 200 | 201 | self.down_proj = nn.Linear(self.in_channels, self.down_size) 202 | self.non_linear_func = nn.ReLU() 203 | self.up_proj = nn.Linear(self.down_size, self.out_channels) 204 | 205 | self.dropout = dropout 206 | 207 | if init_option == "lora": 208 | with torch.no_grad(): 209 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5)) 210 | nn.init.zeros_(self.up_proj.weight) 211 | nn.init.zeros_(self.down_proj.bias) 212 | nn.init.zeros_(self.up_proj.bias) 213 | 214 | def forward(self, x, add_residual=True, residual=None): 215 | residual = x if residual is None else residual 216 | if self.adapter_layernorm_option == 'in': 217 | x = self.adapter_layer_norm(x) 218 | 219 | down = self.down_proj(x) 220 | down = self.non_linear_func(down) 221 | down = nn.functional.dropout(down, p=self.dropout, 
training=self.training) 222 | up = self.up_proj(down) 223 | 224 | up = up * self.scale 225 | 226 | if self.adapter_layernorm_option == 'out': 227 | up = self.adapter_layer_norm(up) 228 | 229 | if add_residual: 230 | output = up + residual 231 | else: 232 | output = up 233 | 234 | return output 235 | 236 | class ResMLP(nn.Module): 237 | def __init__(self, h, n_blocks, dropout=0.15): 238 | super().__init__() 239 | self.n_blocks = n_blocks 240 | self.mlp = nn.ModuleList([ 241 | nn.Sequential( 242 | nn.Linear(h, h), 243 | nn.LayerNorm(h), 244 | nn.GELU(), 245 | nn.Dropout(dropout) 246 | ) for _ in range(n_blocks) 247 | ]) 248 | 249 | def forward(self, x): 250 | residual = x 251 | for res_block in range(self.n_blocks): 252 | x = self.mlp[res_block](x) 253 | x += residual 254 | residual = x 255 | return x 256 | 257 | class MindSingle(nn.Module): 258 | def __init__(self, in_dim=15724, out_dim_image=768, out_dim_text=None, 259 | h=4096, n_blocks=4, subj_list=None,): 260 | 261 | super().__init__() 262 | 263 | self.subj_list = subj_list 264 | self.embedder = nn.ModuleDict({ 265 | str(subj): nn.Sequential( 266 | Adapter_Layer(in_dim, 128), 267 | nn.Linear(in_dim, h), 268 | nn.LayerNorm(h), 269 | nn.GELU(), 270 | nn.Dropout(0.5), 271 | ) for subj in subj_list 272 | }) 273 | 274 | self.translator = ResMLP(h, n_blocks) 275 | self.head_image = nn.Linear(h, out_dim_image) 276 | self.head_text = nn.Linear(h, out_dim_text) 277 | 278 | # @torchsnooper.snoop() 279 | def forward(self, x): 280 | x = self.embedder[str(self.subj_list[0])](x) 281 | x = self.translator(x) 282 | 283 | x_image = self.head_image(x) 284 | x_image = x_image.reshape(len(x_image), -1) 285 | 286 | x_text = self.head_text(x) 287 | x_text = x_text.reshape(len(x_text), -1) 288 | 289 | return x_image, x_text 290 | 291 | class MindBridge(MindSingle): 292 | def __init__(self, in_dim=15724, out_dim_image=768, out_dim_text=None, 293 | h=4096, n_blocks=4, subj_list=None, adapting=False): 294 | 295 | assert len(subj_list) >= 2, "MindBridge requires at least 2 subjects" 296 | 297 | super().__init__(in_dim=in_dim, out_dim_image=out_dim_image, 298 | out_dim_text=out_dim_text, h=h, n_blocks=n_blocks, subj_list=subj_list 299 | ) 300 | 301 | self.builder = nn.ModuleDict({ 302 | str(subj): nn.Sequential( 303 | nn.Linear(h, in_dim), 304 | nn.LayerNorm(in_dim), 305 | nn.GELU(), 306 | Adapter_Layer(in_dim, 128), 307 | ) for subj in subj_list 308 | }) 309 | 310 | self.adapting = adapting 311 | self.cyc_loss = nn.MSELoss() 312 | 313 | # @torchsnooper.snoop() 314 | def forward(self, x): 315 | if len(x) == 2 and type(x) is tuple: 316 | subj_list = x[1].tolist() # (s,) 317 | x = x[0] # (b,n) 318 | else: 319 | subj_list = self.subj_list 320 | 321 | x = x.squeeze() 322 | x_subj = torch.chunk(x,len(subj_list)) 323 | x = [] 324 | x_rec = [] 325 | if self.adapting: # choose subj_a (source subject) and subj_b (target subject) 326 | subj_a, subj_b = subj_list[0], subj_list[-1] 327 | else: # random sample 2 subjects 328 | subj_a, subj_b = random.sample(subj_list, 2) 329 | for i, subj_i in enumerate(subj_list): # subj is 1-based 330 | x_i = self.embedder[str(subj_i)](x_subj[i]) # subj_i seman embedding 331 | if subj_i == subj_a: x_a = x_i # subj_a seman embedding are choosen 332 | x.append(x_i) 333 | x_i_rec = self.builder[str(subj_i)](x_i) # subj_i recon brain signals 334 | x_rec.append(x_i_rec) 335 | 336 | x = torch.concat(x, dim=0) 337 | x_rec = torch.concat(x_rec, dim=0) 338 | # del x_i, x_subj, x_i_rec 339 | 340 | # forward cycling 341 | x_b = 
self.builder[str(subj_b)](x_a) # subj_b recon brain signal using subj_a seman embedding 342 | x_b = self.embedder[str(subj_b)](x_b) # subj_b seman embedding (pseudo) 343 | loss_cyc = self.cyc_loss(x_a, x_b) 344 | 345 | x = self.translator(x) 346 | x_image = self.head_image(x) 347 | x_image = x_image.reshape(len(x_image), -1) 348 | 349 | x_text = self.head_text(x) 350 | x_text = x_text.reshape(len(x_text), -1) 351 | 352 | return x_image, x_text, x_rec, loss_cyc 353 | 354 | 355 | from diffusers.models.vae import Decoder 356 | class Voxel2StableDiffusionModel(torch.nn.Module): 357 | def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False, ups_mode='4x'): 358 | super().__init__() 359 | self.lin0 = nn.Sequential( 360 | nn.Linear(in_dim, h, bias=False), 361 | nn.LayerNorm(h), 362 | nn.SiLU(inplace=True), 363 | nn.Dropout(0.5), 364 | ) 365 | 366 | self.mlp = nn.ModuleList([ 367 | nn.Sequential( 368 | nn.Linear(h, h, bias=False), 369 | nn.LayerNorm(h), 370 | nn.SiLU(inplace=True), 371 | nn.Dropout(0.25) 372 | ) for _ in range(n_blocks) 373 | ]) 374 | self.ups_mode = ups_mode 375 | if ups_mode=='4x': 376 | self.lin1 = nn.Linear(h, 16384, bias=False) 377 | self.norm = nn.GroupNorm(1, 64) 378 | 379 | self.upsampler = Decoder( 380 | in_channels=64, 381 | out_channels=4, 382 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"], 383 | block_out_channels=[64, 128, 256], 384 | layers_per_block=1, 385 | ) 386 | 387 | if use_cont: 388 | self.maps_projector = nn.Sequential( 389 | nn.Conv2d(64, 512, 1, bias=False), 390 | nn.GroupNorm(1,512), 391 | nn.ReLU(True), 392 | nn.Conv2d(512, 512, 1, bias=False), 393 | nn.GroupNorm(1,512), 394 | nn.ReLU(True), 395 | nn.Conv2d(512, 512, 1, bias=True), 396 | ) 397 | else: 398 | self.maps_projector = nn.Identity() 399 | 400 | if ups_mode=='8x': # prev best 401 | self.lin1 = nn.Linear(h, 16384, bias=False) 402 | self.norm = nn.GroupNorm(1, 256) 403 | 404 | self.upsampler = Decoder( 405 | in_channels=256, 406 | out_channels=4, 407 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"], 408 | block_out_channels=[64, 128, 256, 256], 409 | layers_per_block=1, 410 | ) 411 | self.maps_projector = nn.Identity() 412 | 413 | if ups_mode=='16x': 414 | self.lin1 = nn.Linear(h, 8192, bias=False) 415 | self.norm = nn.GroupNorm(1, 512) 416 | 417 | self.upsampler = Decoder( 418 | in_channels=512, 419 | out_channels=4, 420 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D", "UpDecoderBlock2D"], 421 | block_out_channels=[64, 128, 256, 256, 512], 422 | layers_per_block=1, 423 | ) 424 | self.maps_projector = nn.Identity() 425 | 426 | if use_cont: 427 | self.maps_projector = nn.Sequential( 428 | nn.Conv2d(64, 512, 1, bias=False), 429 | nn.GroupNorm(1,512), 430 | nn.ReLU(True), 431 | nn.Conv2d(512, 512, 1, bias=False), 432 | nn.GroupNorm(1,512), 433 | nn.ReLU(True), 434 | nn.Conv2d(512, 512, 1, bias=True), 435 | ) 436 | else: 437 | self.maps_projector = nn.Identity() 438 | 439 | # @torchsnooper.snoop() 440 | def forward(self, x, return_transformer_feats=False): 441 | x = self.lin0(x) 442 | residual = x 443 | for res_block in self.mlp: 444 | x = res_block(x) 445 | x = x + residual 446 | residual = x 447 | x = x.reshape(len(x), -1) 448 | x = self.lin1(x) # bs, 4096 449 | 450 | if self.ups_mode == '4x': 451 | side = 16 452 | if self.ups_mode == '8x': 453 | side = 8 454 | if self.ups_mode == '16x': 455 | side = 4 456 | 457 | # decoder 458 | x = self.norm(x.reshape(x.shape[0], -1, 
side, side).contiguous()) 459 | if return_transformer_feats: 460 | return self.upsampler(x), self.maps_projector(x).flatten(2).permute(0,2,1) 461 | return self.upsampler(x) 462 | 463 | if __name__ == "__main__": 464 | in_dim=8000 465 | h=2048 466 | head = nn.Sequential( 467 | Adapter_Layer(in_dim, 128), 468 | nn.Linear(in_dim, h), 469 | nn.LayerNorm(h), 470 | nn.GELU(), 471 | nn.Dropout(0.5), 472 | ) 473 | 474 | def add_hooks(module, parent_name=''): 475 | module_name = module.__class__.__name__ 476 | if parent_name: 477 | module_name = f'{parent_name}.{module_name}' 478 | 479 | for name, child in module.named_children(): 480 | add_hooks(child, parent_name=f'{module_name}.{name}') 481 | 482 | print(module_name) 483 | 484 | add_hooks(head) 485 | -------------------------------------------------------------------------------- /src/nsd_access.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | import glob 4 | import nibabel as nb 5 | import numpy as np 6 | import pandas as pd 7 | from pandas import json_normalize 8 | from tqdm import tqdm 9 | import h5py 10 | import matplotlib.pyplot as plt 11 | 12 | import urllib.request 13 | import zipfile 14 | from pycocotools.coco import COCO 15 | 16 | 17 | 18 | class NSDAccess(object): 19 | """ 20 | Little class that provides easy access to the NSD data, see [http://naturalscenesdataset.org](their website) 21 | """ 22 | 23 | def __init__(self, nsd_folder, *args, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | self.nsd_folder = nsd_folder 26 | self.nsddata_folder = op.join(self.nsd_folder, 'nsddata') 27 | self.ppdata_folder = op.join(self.nsd_folder, 'nsddata', 'ppdata') 28 | self.nsddata_betas_folder = op.join( 29 | self.nsd_folder, 'nsddata_betas', 'ppdata') 30 | 31 | self.behavior_file = op.join( 32 | self.ppdata_folder, '{subject}', 'behav', 'responses.tsv') 33 | self.stimuli_file = op.join( 34 | self.nsd_folder, 'nsddata_stimuli', 'stimuli', 'nsd', 'nsd_stimuli.hdf5') 35 | self.stimuli_description_file = op.join( 36 | self.nsd_folder, 'nsddata', 'experiments', 'nsd', 'nsd_stim_info_merged.csv') 37 | 38 | self.coco_annotation_file = op.join( 39 | self.nsd_folder, 'nsddata_stimuli', 'stimuli', 'nsd', 'annotations', '{}_{}.json') 40 | 41 | def download_coco_annotation_file(self, url='http://images.cocodataset.org/annotations/annotations_trainval2017.zip'): 42 | """download_coco_annotation_file downloads and extracts the relevant annotations files 43 | 44 | Parameters 45 | ---------- 46 | url : str, optional 47 | url for zip file containing annotations, by default 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip' 48 | """ 49 | print('downloading annotations from {}'.format(url)) 50 | filehandle, _ = urllib.request.urlretrieve(url) 51 | zip_file_object = zipfile.ZipFile(filehandle, 'r') 52 | zip_file_object.extractall(path=op.split( 53 | op.split(self.coco_annotation_file)[0])[0]) 54 | 55 | def affine_header(self, subject, data_format='func1pt8mm'): 56 | """affine_header affine and header, for construction of Nifti image 57 | 58 | Parameters 59 | ---------- 60 | subject : str 61 | subject identifier, such as 'subj01' 62 | data_format : str, optional 63 | what type of data format, from ['func1pt8mm', 'func1mm'], by default 'func1pt8mm' 64 | 65 | Returns 66 | ------- 67 | tuple 68 | affine and header, for construction of Nifti image 69 | """ 70 | full_path = op.join(self.ppdata_folder, 71 | '{subject}', '{data_format}', 'brainmask.nii.gz') 72 | 
full_path = full_path.format(subject=subject, 73 | data_format=data_format) 74 | nii = nb.load(full_path) 75 | 76 | return nii.affine, nii.header 77 | 78 | def read_vol_ppdata(self, subject, filename='brainmask', data_format='func1pt8mm'): 79 | """load_brainmask, returns boolean brainmask for volumetric data formats 80 | 81 | Parameters 82 | ---------- 83 | subject : str 84 | subject identifier, such as 'subj01' 85 | data_format : str, optional 86 | what type of data format, from ['func1pt8mm', 'func1mm'], by default 'func1pt8mm' 87 | 88 | Returns 89 | ------- 90 | numpy.ndarray, 4D (bool) 91 | brain mask array 92 | """ 93 | full_path = op.join(self.ppdata_folder, 94 | '{subject}', '{data_format}', '{filename}.nii.gz') 95 | full_path = full_path.format(subject=subject, 96 | data_format=data_format, 97 | filename=filename) 98 | return nb.load(full_path).get_data() 99 | 100 | def read_betas(self, subject, session_index, trial_index=[], data_type='betas_fithrf_GLMdenoise_RR', data_format='fsaverage', mask=None): 101 | """read_betas read betas from MRI files 102 | 103 | Parameters 104 | ---------- 105 | subject : str 106 | subject identifier, such as 'subj01' 107 | session_index : int 108 | which session, counting from 1 109 | trial_index : list, optional 110 | which trials from this session's file to return, by default [], which returns all trials 111 | data_type : str, optional 112 | which type of beta values to return from ['betas_assumehrf', 'betas_fithrf', 'betas_fithrf_GLMdenoise_RR', 'restingbetas_fithrf'], by default 'betas_fithrf_GLMdenoise_RR' 113 | data_format : str, optional 114 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm'], by default 'fsaverage' 115 | mask : numpy.ndarray, if defined, selects 'mat' data_format, needs volumetric data_format 116 | binary/boolean mask into mat file beta data format. 117 | 118 | Returns 119 | ------- 120 | numpy.ndarray, 2D (fsaverage) or 4D (other data formats) 121 | the requested per-trial beta values 122 | """ 123 | data_folder = op.join(self.nsddata_betas_folder, 124 | subject, data_format, data_type) 125 | si_str = str(session_index).zfill(2) 126 | 127 | if type(mask) == np.ndarray: # will use the mat file iff exists, otherwise boom! 128 | ipf = op.join(data_folder, f'betas_session{si_str}.mat') 129 | assert op.isfile(ipf), \ 130 | 'Error: ' + ipf + ' not available for masking. You may need to download these separately.' 
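            # Illustrative usage of this masked branch (a sketch; the NSD root path and the boolean
            # `mask` array are assumptions and are not defined in this file):
            #   nsda = NSDAccess('/path/to/natural-scenes-dataset')
            #   betas = nsda.read_betas('subj01', session_index=1, trial_index=[0, 1, 2],
            #                           data_format='func1pt8mm', mask=mask)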
131 | # will do indexing of both space and time in one go for this option, 132 | # so will return results immediately from this 133 | h5 = h5py.File(ipf, 'r') 134 | betas = h5.get('betas') 135 | # embed() 136 | if len(trial_index) == 0: 137 | trial_index = slice(0, betas.shape[0]) 138 | # this isn't finished yet - binary masks cannot be used for indexing like this 139 | return betas[trial_index, np.nonzero(mask)] 140 | 141 | if data_format == 'fsaverage': 142 | session_betas = [] 143 | for hemi in ['lh', 'rh']: 144 | hdata = nb.load(op.join( 145 | data_folder, f'{hemi}.betas_session{si_str}.mgh')).get_data() 146 | session_betas.append(hdata) 147 | out_data = np.squeeze(np.vstack(session_betas)) 148 | else: 149 | # if no mask was specified, we'll use the nifti image 150 | out_data = nb.load( 151 | op.join(data_folder, f'betas_session{si_str}.nii.gz')).get_fdata() 152 | 153 | if len(trial_index) == 0: 154 | trial_index = slice(0, out_data.shape[-1]) 155 | 156 | return out_data[..., trial_index] 157 | 158 | def read_mapper_results(self, subject, mapper='prf', data_type='angle', data_format='fsaverage'): 159 | """read_mapper_results [summary] 160 | 161 | Parameters 162 | ---------- 163 | subject : str 164 | subject identifier, such as 'subj01' 165 | mapper : str, optional 166 | first part of the mapper filename, by default 'prf' 167 | data_type : str, optional 168 | second part of the mapper filename, by default 'angle' 169 | data_format : str, optional 170 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm'], by default 'fsaverage' 171 | 172 | Returns 173 | ------- 174 | numpy.ndarray, 2D (fsaverage) or 4D (other data formats) 175 | the requested mapper values 176 | """ 177 | if data_format == 'fsaverage': 178 | # unclear for now where the fsaverage mapper results would be 179 | # as they are still in fsnative format now. 180 | raise NotImplementedError( 181 | 'no mapper results in fsaverage present for now') 182 | else: # is 'func1pt8mm' or 'func1mm' 183 | return self.read_vol_ppdata(subject=subject, filename=f'{mapper}_{data_type}', data_format=data_format) 184 | 185 | def read_atlas_results(self, subject, atlas='HCP_MMP1', data_format='fsaverage'): 186 | """read_atlas_results [summary] 187 | 188 | Parameters 189 | ---------- 190 | subject : str 191 | subject identifier, such as 'subj01' 192 | for surface-based data formats, subject should be the same as data_format. 193 | for example, for fsaverage, both subject and data_format should be 'fsaverage' 194 | this requires a little more typing but makes data format explicit 195 | atlas : str, optional 196 | which atlas to read, 197 | for volume formats, any of ['HCP_MMP1', 'Kastner2015', 'nsdgeneral', 'visualsulc'] for volume, 198 | for fsaverage 199 | can be prefixed by 'lh.' or 'rh.' for hemisphere-specific atlases in volume 200 | for surface: takes both hemispheres by default, instead when prefixed by '.rh' or '.lh'. 201 | By default 'HCP_MMP1'. 202 | data_format : str, optional 203 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm', 'MNI'], by default 'fsaverage' 204 | 205 | Returns 206 | ------- 207 | numpy.ndarray, 1D/2D (surface) or 3D/4D (volume data formats) 208 | the requested atlas values 209 | dict, 210 | dictionary containing the mapping between ROI names and atlas values 211 | """ 212 | 213 | # first, get the mapping. 
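        # The FreeSurfer .mgz.ctab file pairs each integer atlas value with an ROI name; after the
        # inversion below, atlas_mapping maps name -> value (illustrative shape: {'ROI_name': 1, ...}).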
214 | atlas_name = atlas 215 | if atlas[:3] in ('rh.', 'lh.'): 216 | atlas_name = atlas[3:] 217 | 218 | mapp_df = pd.read_csv(os.path.join(self.nsddata_folder, 'freesurfer', 'fsaverage', 219 | 'label', f'{atlas_name}.mgz.ctab'), delimiter=' ', header=None, index_col=0) 220 | atlas_mapping = mapp_df.to_dict()[1] 221 | # dict((y,x) for x,y in atlas_mapping.iteritems()) 222 | atlas_mapping = {y: x for x, y in atlas_mapping.items()} 223 | 224 | if data_format not in ('func1pt8mm', 'func1mm', 'MNI'): 225 | # if surface based results by exclusion 226 | if atlas[:3] in ('rh.', 'lh.'): # check if hemisphere-specific atlas requested 227 | ipf = op.join(self.nsddata_folder, 'freesurfer', 228 | subject, 'label', f'{atlas}.mgz') 229 | return np.squeeze(nb.load(ipf).get_data()), atlas_mapping 230 | else: # more than one hemisphere requested 231 | session_betas = [] 232 | for hemi in ['lh', 'rh']: 233 | hdata = nb.load(op.join( 234 | self.nsddata_folder, 'freesurfer', subject, 'label', f'{hemi}.{atlas}.mgz')).get_data() 235 | session_betas.append(hdata) 236 | out_data = np.squeeze(np.vstack(session_betas)) 237 | return out_data, atlas_mapping 238 | else: # is 'func1pt8mm', 'MNI', or 'func1mm' 239 | ipf = op.join(self.ppdata_folder, subject, 240 | data_format, 'roi', f'{atlas}.nii.gz') 241 | return nb.load(ipf).get_fdata(), atlas_mapping 242 | 243 | def list_atlases(self, subject, data_format='fsaverage', abs_paths=False): 244 | """list_atlases [summary] 245 | 246 | Parameters 247 | ---------- 248 | subject : str 249 | subject identifier, such as 'subj01' 250 | for surface-based data formats, subject should be the same as data_format. 251 | for example, for fsaverage, both subject and data_format should be 'fsaverage' 252 | this requires a little more typing but makes data format explicit 253 | data_format : str, optional 254 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm', 'MNI'], by default 'fsaverage' 255 | 256 | Returns 257 | ------- 258 | list 259 | collection of absolute path names to 260 | """ 261 | if data_format in ('func1pt8mm', 'func1mm', 'MNI'): 262 | atlas_files = glob.glob( 263 | op.join(self.ppdata_folder, subject, data_format, 'roi', '*.nii.gz')) 264 | else: 265 | atlas_files = glob.glob( 266 | op.join(self.nsddata_folder, 'freesurfer', subject, 'label', '*.mgz')) 267 | 268 | # print this 269 | import pprint 270 | pp = pprint.PrettyPrinter(indent=4) 271 | print('Atlases found in {}:'.format(op.split(atlas_files[0])[0])) 272 | pp.pprint([op.split(f)[1] for f in atlas_files]) 273 | if abs_paths: 274 | return atlas_files 275 | else: # this is the format which you can input into other functions, so this is the default 276 | return np.unique([op.split(f)[1].replace('lh.', '').replace('rh.', '').replace('.mgz', '').replace('.nii.gz', '') for f in atlas_files]) 277 | 278 | def read_behavior(self, subject, session_index, trial_index=[]): 279 | """read_behavior [summary] 280 | 281 | Parameters 282 | ---------- 283 | subject : str 284 | subject identifier, such as 'subj01' 285 | session_index : int 286 | which session, counting from 0 287 | trial_index : list, optional 288 | which trials from this session's behavior to return, by default [], which returns all trials 289 | 290 | Returns 291 | ------- 292 | pandas DataFrame 293 | DataFrame containing the behavioral information for the requested trials 294 | """ 295 | 296 | behavior = pd.read_csv(self.behavior_file.format( 297 | subject=subject), delimiter='\t') 298 | 299 | # the behavior is encoded per run. 
300 | # I'm now setting this function up so that it aligns with the timepoints in the fmri files, 301 | # i.e. using indexing per session, and not using the 'run' information. 302 | session_behavior = behavior[behavior['SESSION'] == session_index] 303 | 304 | if len(trial_index) == 0: 305 | trial_index = slice(0, len(session_behavior)) 306 | 307 | return session_behavior.iloc[trial_index] 308 | 309 | def read_images(self, image_index, show=False): 310 | """read_images reads a list of images, and returns their data 311 | 312 | Parameters 313 | ---------- 314 | image_index : list of integers 315 | which images indexed in the 73k format to return 316 | show : bool, optional 317 | whether to also show the images, by default False 318 | 319 | Returns 320 | ------- 321 | numpy.ndarray, 3D 322 | RGB image data 323 | """ 324 | 325 | if not hasattr(self, 'stim_descriptions'): 326 | self.stim_descriptions = pd.read_csv( 327 | self.stimuli_description_file, index_col=0) 328 | 329 | sf = h5py.File(self.stimuli_file, 'r') 330 | sdataset = sf.get('imgBrick') 331 | if show: 332 | f, ss = plt.subplots(1, len(image_index), 333 | figsize=(6*len(image_index), 6)) 334 | if len(image_index) == 1: 335 | ss = [ss] 336 | for s, d in zip(ss, sdataset[image_index]): 337 | s.axis('off') 338 | s.imshow(d) 339 | return sdataset[image_index] 340 | 341 | def read_image_coco_info(self, image_index, info_type='captions', show_annot=False, show_img=False): 342 | """image_coco_info returns the coco annotations of a single image or a list of images 343 | 344 | Parameters 345 | ---------- 346 | image_index : list of integers 347 | which images indexed in the 73k format to return the captions for 348 | info_type : str, optional 349 | what type of annotation to return, from ['captions', 'person_keypoints', 'instances'], by default 'captions' 350 | show_annot : bool, optional 351 | whether to show the annotation, by default False 352 | show_img : bool, optional 353 | whether to show the image (from the nsd formatted data), by default False 354 | 355 | Returns 356 | ------- 357 | coco Annotation 358 | coco annotation, to be used in subsequent analysis steps 359 | 360 | Example 361 | ------- 362 | single image: 363 | ci = read_image_coco_info( 364 | [569], info_type='captions', show_annot=False, show_img=False) 365 | list of images: 366 | ci = read_image_coco_info( 367 | [569, 2569], info_type='captions') 368 | 369 | """ 370 | if not hasattr(self, 'stim_descriptions'): 371 | self.stim_descriptions = pd.read_csv( 372 | self.stimuli_description_file, index_col=0) 373 | if len(image_index) == 1: 374 | subj_info = self.stim_descriptions.iloc[image_index[0]] 375 | 376 | # checking whether annotation file for this trial exists. 377 | # This may not be the right place to call the download, and 378 | # re-opening the annotations for all images separately may be slowing things down 379 | # however images used in the experiment seem to have come from different sets. 
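            # The format call below fills the '{}_{}.json' template set up in __init__, giving e.g.
            # 'captions_train2017.json' or 'instances_val2017.json' under the annotations folder.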
380 | annot_file = self.coco_annotation_file.format( 381 | info_type, subj_info['cocoSplit']) 382 | print('getting annotations from ' + annot_file) 383 | if not os.path.isfile(annot_file): 384 | print('annotations file not found') 385 | self.download_coco_annotation_file() 386 | 387 | coco = COCO(annot_file) 388 | coco_annot_IDs = coco.getAnnIds([subj_info['cocoId']]) 389 | coco_annot = coco.loadAnns(coco_annot_IDs) 390 | 391 | if show_img: 392 | self.read_images(image_index, show=True) 393 | 394 | if show_annot: 395 | # still need to convert the annotations (especially person_keypoints and instances) to the right reference frame, 396 | # because the images were cropped. See image information per image to do this. 397 | coco.showAnns(coco_annot) 398 | 399 | elif len(image_index) > 1: 400 | 401 | # we output a list of annots 402 | coco_annot = [] 403 | 404 | # load train_2017 405 | annot_file = self.coco_annotation_file.format( 406 | info_type, 'train2017') 407 | coco_train = COCO(annot_file) 408 | 409 | # also load the val 2017 410 | annot_file = self.coco_annotation_file.format( 411 | info_type, 'val2017') 412 | coco_val = COCO(annot_file) 413 | 414 | for image in image_index: 415 | subj_info = self.stim_descriptions.iloc[image] 416 | if subj_info['cocoSplit'] == 'train2017': 417 | coco_annot_IDs = coco_train.getAnnIds( 418 | [subj_info['cocoId']]) 419 | coco_ann = coco_train.loadAnns(coco_annot_IDs) 420 | coco_annot.append(coco_ann) 421 | 422 | elif subj_info['cocoSplit'] == 'val2017': 423 | coco_annot_IDs = coco_val.getAnnIds( 424 | [subj_info['cocoId']]) 425 | coco_ann = coco_val.loadAnns(coco_annot_IDs) 426 | coco_annot.append(coco_ann) 427 | 428 | return coco_annot 429 | 430 | def read_image_coco_category(self, image_index): 431 | """image_coco_category returns the coco category of a single image or a list of images 432 | 433 | Args: 434 | image_index ([list of integers]): which images indexed in the 73k format to return 435 | the category for 436 | 437 | Returns 438 | ------- 439 | coco category 440 | coco category, to be used in subsequent analysis steps 441 | 442 | Example 443 | ------- 444 | single image: 445 | ci = read_image_coco_category( 446 | [569]) 447 | list of images: 448 | ci = read_image_coco_category( 449 | [569, 2569]) 450 | """ 451 | 452 | if not hasattr(self, 'stim_descriptions'): 453 | self.stim_descriptions = pd.read_csv( 454 | self.stimuli_description_file, index_col=0) 455 | 456 | if len(image_index) == 1: 457 | subj_info = self.stim_descriptions.iloc[image_index[0]] 458 | coco_id = subj_info['cocoId'] 459 | 460 | # checking whether annotation file for this trial exists. 461 | # This may not be the right place to call the download, and 462 | # re-opening the annotations for all images separately may be slowing things down 463 | # however images used in the experiment seem to have come from different sets. 
464 | annot_file = self.coco_annotation_file.format( 465 | 'instances', subj_info['cocoSplit']) 466 | print('getting annotations from ' + annot_file) 467 | if not os.path.isfile(annot_file): 468 | print('annotations file not found') 469 | self.download_coco_annotation_file() 470 | 471 | coco = COCO(annot_file) 472 | 473 | cat_ids = coco.getCatIds() 474 | categories = json_normalize(coco.loadCats(cat_ids)) 475 | 476 | coco_cats = [] 477 | for cat_id in cat_ids: 478 | this_img_list = coco.getImgIds(catIds=[cat_id]) 479 | if coco_id in this_img_list: 480 | this_cat = np.asarray(categories[categories['id']==cat_id]['name'])[0] 481 | coco_cats.append(this_cat) 482 | 483 | elif len(image_index) > 1: 484 | 485 | # we output a list of annots 486 | coco_cats = [] 487 | 488 | # load train_2017 489 | annot_file = self.coco_annotation_file.format( 490 | 'instances', 'train2017') 491 | coco_train = COCO(annot_file) 492 | cat_ids_train = coco_train.getCatIds() 493 | categories_train = json_normalize(coco_train.loadCats(cat_ids_train)) 494 | 495 | # also load the val 2017 496 | annot_file = self.coco_annotation_file.format( 497 | 'instances', 'val2017') 498 | coco_val = COCO(annot_file) 499 | cat_ids_val = coco_val.getCatIds() 500 | categories_val = json_normalize(coco_val.loadCats(cat_ids_val)) 501 | 502 | for image in tqdm(image_index, bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}'): 503 | subj_info = self.stim_descriptions.iloc[image] 504 | coco_id = subj_info['cocoId'] 505 | image_cat = [] 506 | if subj_info['cocoSplit'] == 'train2017': 507 | for cat_id in cat_ids_train: 508 | this_img_list = coco_train.getImgIds(catIds=[cat_id]) 509 | if coco_id in this_img_list: 510 | this_cat = np.asarray(categories_train[categories_train['id']==cat_id]['name'])[0] 511 | image_cat.append(this_cat) 512 | 513 | elif subj_info['cocoSplit'] == 'val2017': 514 | for cat_id in cat_ids_val: 515 | this_img_list = coco_val.getImgIds(catIds=[cat_id]) 516 | if coco_id in this_img_list: 517 | this_cat = np.asarray(categories_val[categories_val['id']==cat_id]['name'])[0] 518 | image_cat.append(this_cat) 519 | coco_cats.append(image_cat) 520 | return coco_cats 521 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | import os 3 | import math 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | from tqdm import tqdm 8 | import wandb 9 | 10 | # tf32 data type is faster than standard float32 11 | torch.backends.cuda.matmul.allow_tf32 = True 12 | 13 | # Custom models and functions # 14 | import utils 15 | import data 16 | 17 | class Trainer: 18 | def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None: 19 | # train logs path 20 | self.outdir = os.path.abspath(f'../train_logs/{args.model_name}') 21 | if not os.path.exists(self.outdir): 22 | os.makedirs(self.outdir,exist_ok=True) 23 | 24 | self.args = args 25 | self.accelerator = accelerator 26 | self.voxel2clip = voxel2clip 27 | self.clip_extractor = clip_extractor 28 | self.prompts_list = prompts_list 29 | self.device = device 30 | self.num_devices = max(torch.cuda.device_count(), 1) 31 | self.epoch_start = 0 32 | 33 | self.prepare_dataloader() 34 | self.prepare_optimizer() 35 | self.prepare_scheduler() 36 | self.prepare_multi_gpu() 37 | self.prepare_weights() 38 | 39 | def prepare_weights(self): 40 | # Resume or load ckpt 41 | if self.args.resume: 42 | 
self.resume() 43 | elif self.args.load_from: 44 | self.load() 45 | else: # the ckpt folder should not contain any ckpt. If it contains something, which means you forget to change experiment name, and will cause overwrite. 46 | file_count = len(os.listdir(self.outdir)) 47 | if file_count > 0: 48 | raise RuntimeError("The folder is not empty, please check to avoid overwriting! \n {}\n".format(self.outdir)) 49 | 50 | @abstractmethod 51 | def prepare_dataloader(self): 52 | pass 53 | 54 | def prepare_optimizer(self,): 55 | # Prepare optimizer 56 | no_decay = ['bias', 'Norm', 'temperature'] 57 | opt_grouped_parameters = [ 58 | {'params': [p for n, p in self.voxel2clip.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2}, 59 | {'params': [p for n, p in self.voxel2clip.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 60 | ] 61 | self.optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=self.args.max_lr) 62 | 63 | def prepare_scheduler(self): 64 | # prepare lr scheduler 65 | one_epoch_steps = self.num_batches 66 | if self.accelerator.state.deepspeed_plugin is not None: # Multi GPU 67 | one_epoch_steps = math.ceil(one_epoch_steps / self.num_devices) 68 | total_steps = self.args.num_epochs * one_epoch_steps 69 | print("one_epoch_steps_per_gpu:",one_epoch_steps) 70 | print("total_steps:",total_steps) 71 | 72 | if self.args.lr_scheduler_type == 'linear': 73 | self.lr_scheduler = torch.optim.lr_scheduler.LinearLR( 74 | self.optimizer, 75 | total_iters=total_steps, 76 | last_epoch=-1 77 | ) 78 | elif self.args.lr_scheduler_type == 'cycle': 79 | self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( 80 | self.optimizer, 81 | max_lr=self.args.max_lr, 82 | total_steps=total_steps, 83 | final_div_factor=100, 84 | last_epoch=-1, 85 | pct_start=2/self.args.num_epochs, 86 | ) 87 | 88 | def prepare_wandb(self, local_rank, args): 89 | ## Weights and Biases 90 | if local_rank==0 and args.wandb_log: # only use main process for wandb logging 91 | import wandb 92 | wandb_run = args.model_name 93 | wandb_notes = '' 94 | 95 | print(f"Wandb project {args.wandb_project} run {wandb_run}") 96 | wandb.login(host='https://api.wandb.ai') 97 | wandb_config = vars(args) 98 | print("wandb_config:\n",wandb_config) 99 | if args.resume: # wandb_auto_resume 100 | if args.resume_id is None: 101 | args.resume_id = args.model_name 102 | print("wandb_id:", args.resume_id) 103 | wandb.init( 104 | id = args.resume_id, 105 | project=args.wandb_project, 106 | name=wandb_run, 107 | config=wandb_config, 108 | notes=wandb_notes, 109 | resume="allow", 110 | ) 111 | else: 112 | wandb.init( 113 | project=args.wandb_project, 114 | name=wandb_run, 115 | config=wandb_config, 116 | notes=wandb_notes, 117 | ) 118 | 119 | @abstractmethod 120 | def prepare_multi_gpu(self): 121 | pass 122 | 123 | def input(self, voxel, subj_id): 124 | return (voxel, subj_id) 125 | 126 | def train(self, local_rank): 127 | epoch = self.epoch_start 128 | self.losses, self.val_losses, self.lrs = [], [], [] 129 | self.best_sim = 0 130 | self.best_epoch = 0 131 | 132 | self.val_voxel0 = self.val_image0 = None 133 | 134 | ## Main loop 135 | print(f"{self.args.model_name} starting with epoch {epoch} / {self.args.num_epochs}") 136 | progress_bar = tqdm(range(epoch, self.args.num_epochs), disable=(local_rank!=0)) 137 | 138 | for epoch in progress_bar: 139 | self.voxel2clip.train() 140 | 141 | self.sims_image = 0. 142 | self.sims_text = 0. 143 | self.val_sims_image = 0. 144 | self.val_sims_text = 0. 
145 | self.fwd_percent_correct = 0. 146 | self.bwd_percent_correct = 0. 147 | self.val_fwd_percent_correct = 0. 148 | self.val_bwd_percent_correct = 0. 149 | self.loss_clip_image_sum = 0. 150 | self.loss_clip_text_sum = 0. 151 | self.loss_mse_image_sum = 0. 152 | self.loss_mse_text_sum = 0. 153 | self.loss_rec_sum = 0. 154 | self.loss_cyc_sum = 0. 155 | self.val_loss_clip_image_sum = 0. 156 | self.val_loss_clip_text_sum = 0. 157 | self.val_loss_mse_image_sum = 0. 158 | self.val_loss_mse_text_sum = 0. 159 | self.val_loss_rec_sum = 0. 160 | self.val_loss_cyc_sum = 0. 161 | 162 | # wandb logging 163 | self.train_epoch(epoch) 164 | self.log_train() 165 | 166 | if epoch % self.args.eval_interval == 0: 167 | self.eval_epoch(epoch) 168 | self.log_val() 169 | 170 | if self.args.wandb_log and local_rank==0: 171 | wandb.log(self.logs) 172 | 173 | progress_dict = { 174 | "epoch": epoch, 175 | "lr": self.logs["train/lr"], 176 | "loss": self.logs["train/loss"], 177 | } 178 | 179 | progress_bar.set_postfix(progress_dict) 180 | 181 | # Main process 182 | if local_rank==0: 183 | # Uploading logs to wandb 184 | if self.args.wandb_log: 185 | wandb.log(self.logs) 186 | 187 | # Save model 188 | if epoch % self.args.ckpt_interval == 0 or epoch == self.args.num_epochs-1: 189 | self.save(epoch) 190 | 191 | # wait for other GPUs to catch up if needed 192 | self.accelerator.wait_for_everyone() 193 | 194 | @abstractmethod 195 | def train_epoch(self, epoch): 196 | pass 197 | 198 | def train_step(self, voxel, image, captions, subj_id): 199 | loss = 0. 200 | self.optimizer.zero_grad() 201 | 202 | if self.args.use_image_aug: 203 | image = data.img_augment(image) 204 | clip_image = self.clip_extractor.embed_image(image).float() 205 | clip_text = self.clip_extractor.embed_text(captions).float() 206 | 207 | # clip_image_pred, clip_text_pred, voxel_rec, loss_cyc = self.voxel2clip((voxel, subj_id)) 208 | results = self.voxel2clip(self.input(voxel, subj_id)) 209 | 210 | # image clip loss 211 | clip_image_pred = results[0] 212 | clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1) 213 | clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1) 214 | loss_clip_image = utils.soft_clip_loss( 215 | clip_image_pred_norm, 216 | clip_image_norm, 217 | ) 218 | 219 | utils.check_loss(loss_clip_image, "loss_clip_image") 220 | loss += loss_clip_image 221 | self.loss_clip_image_sum += loss_clip_image.item() 222 | 223 | # image mse loss 224 | if self.args.mse_mult: 225 | loss_mse_image = nn.MSELoss()(clip_image_pred_norm, clip_image_norm) 226 | utils.check_loss(loss_mse_image, "loss_mse_image") 227 | loss += self.args.mse_mult * loss_mse_image 228 | self.loss_mse_image_sum += loss_mse_image.item() 229 | 230 | # text clip loss 231 | clip_text_pred = results[1] 232 | clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1) 233 | clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1) 234 | loss_clip_text = utils.soft_clip_loss( 235 | clip_text_pred_norm, 236 | clip_text_norm, 237 | ) 238 | utils.check_loss(loss_clip_text, "loss_clip_text") 239 | loss += loss_clip_text 240 | self.loss_clip_text_sum += loss_clip_text.item() 241 | 242 | # text mse loss 243 | if self.args.mse_mult: 244 | loss_mse_text = nn.MSELoss()(clip_text_pred_norm, clip_text_norm) 245 | utils.check_loss(loss_mse_text, "loss_mse_text") 246 | loss += self.args.mse_mult * loss_mse_text 247 | self.loss_mse_text_sum += loss_mse_text.item() 248 | 249 | # brain reconstruction loss 250 | if 
self.args.rec_mult: 251 | voxel_rec = results[2] 252 | loss_rec = nn.MSELoss()(voxel, voxel_rec) 253 | utils.check_loss(loss_rec, "loss_rec") 254 | loss += self.args.rec_mult * loss_rec 255 | self.loss_rec_sum += loss_rec.item() 256 | 257 | # cycle loss 258 | if self.args.cyc_mult: 259 | loss_cyc = results[3] 260 | utils.check_loss(loss_cyc, "loss_cyc") 261 | loss += self.args.cyc_mult * loss_cyc 262 | self.loss_cyc_sum += loss_cyc.item() 263 | 264 | utils.check_loss(loss) 265 | self.accelerator.backward(loss) 266 | self.optimizer.step() 267 | 268 | self.losses.append(loss.item()) 269 | self.lrs.append(self.optimizer.param_groups[0]['lr']) 270 | self.lr_scheduler.step() 271 | 272 | self.sims_image += nn.functional.cosine_similarity(clip_image_norm,clip_image_pred_norm).mean().item() 273 | self.sims_text += nn.functional.cosine_similarity(clip_text_norm,clip_text_pred_norm).mean().item() 274 | 275 | # forward and backward top 1 accuracy 276 | labels = torch.arange(len(clip_image_norm)).to(self.device) 277 | self.fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_pred_norm, clip_image_norm), labels, k=1) 278 | self.bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_norm, clip_image_pred_norm), labels, k=1) 279 | 280 | @abstractmethod 281 | def eval_epoch(self, epoch): 282 | pass 283 | 284 | def eval_step(self, voxel, image, captions, subj_id): 285 | val_loss = 0. 286 | with torch.no_grad(): 287 | # used for reconstruction 288 | if self.val_image0 is None: 289 | self.val_image0 = image.detach().clone() 290 | self.val_voxel0 = voxel.detach().clone() 291 | 292 | clip_image = self.clip_extractor.embed_image(image).float() 293 | clip_text = self.clip_extractor.embed_text(captions).float() 294 | 295 | # clip_image_pred, clip_text_pred, voxel_rec, loss_cyc = self.voxel2clip((voxel, subj_id)) 296 | results = self.voxel2clip(self.input(voxel, subj_id)) 297 | 298 | # image clip loss 299 | clip_image_pred = results[0] 300 | clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1) 301 | clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1) 302 | val_loss_clip_image = utils.soft_clip_loss( 303 | clip_image_pred_norm, 304 | clip_image_norm, 305 | ) 306 | val_loss += val_loss_clip_image 307 | self.val_loss_clip_image_sum += val_loss_clip_image.item() 308 | 309 | # image mse loss 310 | if self.args.mse_mult: 311 | val_loss_mse_image = nn.MSELoss()(clip_image_pred_norm, clip_image_norm) 312 | val_loss += self.args.mse_mult * val_loss_mse_image 313 | self.val_loss_mse_image_sum += val_loss_mse_image.item() 314 | 315 | # text clip loss 316 | clip_text_pred = results[1] 317 | clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1) 318 | clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1) 319 | val_loss_clip_text = utils.soft_clip_loss( 320 | clip_text_pred_norm, 321 | clip_text_norm, 322 | ) 323 | val_loss += val_loss_clip_text 324 | self.val_loss_clip_text_sum += val_loss_clip_text.item() 325 | 326 | # text mse loss 327 | if self.args.mse_mult: 328 | val_loss_mse_text = nn.MSELoss()(clip_text_pred_norm, clip_text_norm) 329 | val_loss += self.args.mse_mult * val_loss_mse_text 330 | self.val_loss_mse_text_sum += val_loss_mse_text.item() 331 | 332 | # brain reconstruction loss 333 | if self.args.rec_mult: 334 | voxel_rec = results[2] 335 | val_loss_rec = nn.MSELoss()(voxel, voxel_rec) 336 | val_loss += self.args.rec_mult * val_loss_rec 337 | self.val_loss_rec_sum += 
val_loss_rec.item() 338 | 339 | # cycle loss 340 | if self.args.cyc_mult: 341 | loss_cyc = results[3] 342 | val_loss_cyc = loss_cyc 343 | val_loss += self.args.cyc_mult * val_loss_cyc 344 | self.val_loss_cyc_sum += val_loss_cyc.item() 345 | 346 | utils.check_loss(val_loss) 347 | self.val_losses.append(val_loss.item()) 348 | 349 | self.val_sims_image += nn.functional.cosine_similarity(clip_image_norm,clip_image_pred_norm).mean().item() 350 | self.val_sims_text += nn.functional.cosine_similarity(clip_text_norm,clip_text_pred_norm).mean().item() 351 | 352 | labels = torch.arange(len(clip_image_norm)).to(self.device) 353 | self.val_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_pred_norm, clip_image_norm), labels, k=1) 354 | self.val_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_norm, clip_image_pred_norm), labels, k=1) 355 | 356 | def vis(self,): 357 | pass 358 | 359 | def save_ckpt(self, tag, epoch): 360 | if self.accelerator.is_main_process: 361 | ckpt_path = self.outdir+f'/{tag}.pth' 362 | print(f'--- saving model: {ckpt_path} ---',flush=True) 363 | unwrapped_model = self.accelerator.unwrap_model(self.voxel2clip) 364 | # unwarpped_optimizer = self.accelerator.unwrap_model(self.optimizer) 365 | # print("unwarpped optimizer keys", unwarpped_optimizer.state_dict().keys()) 366 | 367 | try: 368 | torch.save({ 369 | 'epoch': epoch, 370 | 'model_state_dict': unwrapped_model.state_dict(), 371 | }, ckpt_path) 372 | except: 373 | print("Couldn't save... moving on to prevent crashing.") 374 | del unwrapped_model 375 | 376 | state_path = os.path.join(self.outdir, tag) 377 | print(f'--- saving state: {state_path} ---',flush=True) 378 | self.accelerator.save_state(state_path) 379 | 380 | def save(self, epoch): 381 | self.save_ckpt(f'last', epoch) 382 | # save best model 383 | current_sim = (self.val_sims_image + self.val_sims_text) / (self.val_i + 1) if hasattr(self, 'val_i') else 0 384 | if current_sim > self.best_sim: 385 | self.best_sim = current_sim 386 | self.best_epoch = epoch 387 | self.save_ckpt(f'best', epoch) 388 | else: 389 | print(f'Not best - current_similarity: {current_sim:.3f} @ epoch {epoch}, best_similarity: {self.best_sim:.3f} @ epoch {self.best_epoch}') 390 | 391 | def load(self,): 392 | print("\n--- load from ckpt: {} ---\n".format(self.args.load_from)) 393 | checkpoint = torch.load(self.args.load_from, map_location='cpu') 394 | unwrapped_voxel2clip = self.accelerator.unwrap_model(self.voxel2clip) 395 | unwrapped_voxel2clip.load_state_dict(checkpoint['model_state_dict'], strict=False) 396 | print("loaded keys", checkpoint['model_state_dict'].keys()) 397 | del checkpoint 398 | 399 | def resume(self,): 400 | state_path = os.path.join(self.outdir, "last") 401 | print(f"\n--- resuming from {state_path} ---\n") 402 | self.accelerator.load_state(state_path) 403 | 404 | ckpt_path = self.outdir+'/last.pth' 405 | print(f"\n--- Read resume epoch from {ckpt_path} ---\n") 406 | checkpoint = torch.load(ckpt_path, map_location='cpu') 407 | self.epoch_start = checkpoint['epoch'] 408 | print(">>> Resume at Epoch", self.epoch_start) 409 | # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']['base_optimizer_state']) 410 | # self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 411 | # self.voxel2clip.load_state_dict(checkpoint['model_state_dict']) 412 | del checkpoint 413 | 414 | def log_train(self): 415 | self.logs = { 416 | "train/loss": np.mean(self.losses[-(self.train_i+1):]), 417 | "train/lr": self.lrs[-1], 
418 | "train/num_steps": len(self.losses), 419 | "train/cosine_sim_image": self.sims_image / (self.train_i + 1), 420 | "train/cosine_sim_text": self.sims_text / (self.train_i + 1), 421 | "train/fwd_pct_correct": self.fwd_percent_correct / (self.train_i + 1), 422 | "train/bwd_pct_correct": self.bwd_percent_correct / (self.train_i + 1), 423 | "train/loss_clip_image": self.loss_clip_image_sum / (self.train_i + 1), 424 | "train/loss_clip_text": self.loss_clip_text_sum / (self.train_i + 1), 425 | "train/loss_mse_image": self.loss_mse_image_sum / (self.train_i + 1), 426 | "train/loss_mse_text": self.loss_mse_text_sum / (self.train_i + 1), 427 | "train/loss_rec": self.loss_rec_sum / (self.train_i + 1), 428 | "train/loss_cyc": self.loss_cyc_sum / (self.train_i + 1), 429 | } 430 | 431 | def log_val(self): 432 | self.logs.update({ 433 | "val/loss": np.mean(self.val_losses[-(self.val_i+1):]), 434 | "val/num_steps": len(self.val_losses), 435 | "val/cosine_sim_image": self.val_sims_image / (self.val_i + 1), 436 | "val/cosine_sim_text": self.val_sims_text / (self.val_i + 1), 437 | "val/val_fwd_pct_correct": self.val_fwd_percent_correct / (self.val_i + 1), 438 | "val/val_bwd_pct_correct": self.val_bwd_percent_correct / (self.val_i + 1), 439 | "val/loss_clip_image": self.val_loss_clip_image_sum / (self.val_i + 1), 440 | "val/loss_clip_text": self.val_loss_clip_text_sum / (self.val_i + 1), 441 | "val/loss_mse_image": self.val_loss_mse_image_sum / (self.val_i + 1), 442 | "val/loss_mse_text": self.val_loss_mse_text_sum / (self.val_i + 1), 443 | "val/loss_rec": self.val_loss_rec_sum / (self.val_i + 1), 444 | "val/loss_cyc": self.val_loss_cyc_sum / (self.val_i + 1), 445 | }) 446 | 447 | class Trainer_single(Trainer): 448 | def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None: 449 | super().__init__(args, accelerator, voxel2clip, clip_extractor, prompts_list, device) 450 | 451 | def prepare_dataloader(self): 452 | # Prepare data and dataloader 453 | print("Preparing data and dataloader...") 454 | self.train_dl, self.val_dl = data.get_dls( 455 | subject=self.args.subj_list[0], 456 | data_path=self.args.data_path, 457 | batch_size=self.args.batch_size, 458 | val_batch_size=self.args.val_batch_size, 459 | num_workers=self.args.num_workers, 460 | pool_type=self.args.pool_type, 461 | pool_num=self.args.pool_num, 462 | length=self.args.length, 463 | seed=self.args.seed, 464 | ) 465 | self.num_batches = len(self.train_dl) 466 | 467 | def prepare_multi_gpu(self): 468 | self.voxel2clip, self.optimizer, self.train_dl, self.val_dl, self.lr_scheduler = self.accelerator.prepare( 469 | self.voxel2clip, self.optimizer, self.train_dl, self.val_dl, self.lr_scheduler) 470 | 471 | def input(self, voxel, subj_id): 472 | # adapting need to know subj_id 473 | return voxel 474 | 475 | def train_epoch(self, epoch): 476 | # train loop 477 | for train_i, data_i in enumerate(self.train_dl): 478 | self.train_i = train_i 479 | repeat_index = train_i % 3 # randomly choose the one in the repeated three 480 | 481 | voxel, image, coco, subj_id = data_i 482 | voxel = voxel[:,repeat_index,...].float() 483 | subj_id = subj_id[[0],...] 
            coco_ids = coco.squeeze().tolist()
            current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
            captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

            print(">>> Epoch{} | Iter{} | voxel: {}".format(epoch, train_i, voxel.shape), flush=True)
            self.train_step(voxel, image, captions, subj_id)

    def eval_epoch(self, epoch):
        print("Evaluating...")
        self.voxel2clip.eval()

        for val_i, data_i in enumerate(self.val_dl):
            self.val_i = val_i
            repeat_index = val_i % 3  # cycle through the captions of the three repeats
            voxel, image, coco, subj_id = data_i
            voxel = torch.mean(voxel, axis=1)  # average the three repeated scans for validation
            subj_id = subj_id[[0], ...]

            coco_ids = coco.squeeze().tolist()
            current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
            captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

            print(">>> Epoch{} | Eval{} | voxel: {}".format(epoch, val_i, voxel.shape), flush=True)
            self.eval_step(voxel, image, captions, subj_id)

class Trainer_bridge(Trainer):
    def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None:
        super().__init__(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)

    def prepare_dataloader(self):
        # Prepare data and dataloader
        print("Preparing data and dataloader...")
        self.train_dls = []  # one train dataloader per subject
        self.val_dls = []    # one val dataloader per subject

        for subj in self.args.subj_list:
            train_dl, val_dl = data.get_dls(
                subject=subj,
                data_path=self.args.data_path,
                batch_size=self.args.batch_size,
                val_batch_size=self.args.val_batch_size,
                num_workers=self.args.num_workers,
                pool_type=self.args.pool_type,
                pool_num=self.args.pool_num,
                length=self.args.length,
                seed=self.args.seed,
            )
            self.train_dls.append(train_dl)
            self.val_dls.append(val_dl)

        self.num_batches = len(self.train_dls[0])

    def prepare_multi_gpu(self):
        self.voxel2clip, self.optimizer, self.lr_scheduler, _ = self.accelerator.prepare(
            self.voxel2clip, self.optimizer, self.lr_scheduler, self.train_dls[0])

        for i, dls in enumerate(zip(self.train_dls, self.val_dls)):
            train_dl, val_dl = dls
            self.train_dls[i] = self.accelerator.prepare(train_dl)
            self.val_dls[i] = self.accelerator.prepare(val_dl)

    def train_epoch(self, epoch):
        # train loop
        for train_i, datas in enumerate(zip(*self.train_dls)):
            self.train_i = train_i
            repeat_index = train_i % 3  # cycle through the three repeated scans of each image

            # assemble one batch from multiple subjects
            voxel_list, image_list, coco_list, subj_id_list = [], [], [], []
            for voxel, image, coco, subj_id in datas:
                voxel_list.append(voxel[:, repeat_index, ...])
                image_list.append(image)
                coco_list.append(coco)
                subj_id_list.append(subj_id[[0], ...])
            voxel = torch.cat(voxel_list, dim=0)
            image = torch.cat(image_list, dim=0)
            coco = torch.cat(coco_list, dim=0)
            subj_id = torch.cat(subj_id_list, dim=0)

            coco_ids = coco.squeeze().tolist()
            current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
            captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]
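            # the concatenated batch now holds one sub-batch per subject; subj_id keeps one id per subject,
            # presumably so voxel2clip can route each sub-batch through its subject-specific layers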
            print(">>> Epoch{} | Iter{} | voxel: {}".format(epoch, train_i, voxel.shape), flush=True)
            self.train_step(voxel, image, captions, subj_id)

    def eval_epoch(self, epoch):
        print("Evaluating...")
        self.voxel2clip.eval()
        for val_i, datas in enumerate(zip(*self.val_dls)):
            self.val_i = val_i
            repeat_index = val_i % 3  # cycle through the captions of the three repeats

            # assemble one batch from multiple subjects
            voxel_list, image_list, coco_list, subj_id_list = [], [], [], []
            for voxel, image, coco, subj_id in datas:
                voxel_list.append(torch.mean(voxel, axis=1))  # average the three repeated scans
                image_list.append(image)
                coco_list.append(coco)
                subj_id_list.append(subj_id[[0], ...])
            voxel = torch.cat(voxel_list, dim=0)
            image = torch.cat(image_list, dim=0)
            coco = torch.cat(coco_list, dim=0)
            subj_id = torch.cat(subj_id_list, dim=0)

            coco_ids = coco.squeeze().tolist()
            current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
            captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

            print(">>> Epoch{} | Eval{} | voxel: {}".format(epoch, val_i, voxel.shape), flush=True)
            self.eval_step(voxel, image, captions, subj_id)

class Trainer_adapt(Trainer):
    def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None:
        super().__init__(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)

    def prepare_dataloader(self):
        # Prepare data and dataloader
        print("Preparing data and dataloader...")
        self.train_dls_source = []  # one train dataloader per source subject
        self.val_dls_source = []    # one val dataloader per source subject

        # source subjects
        for subj in self.args.subj_source:
            train_dl, val_dl = data.get_dls(
                subject=subj,
                data_path=self.args.data_path,
                batch_size=self.args.batch_size,
                val_batch_size=self.args.val_batch_size,
                num_workers=self.args.num_workers,
                pool_type=self.args.pool_type,
                pool_num=self.args.pool_num,
                length=self.args.length,
                seed=self.args.seed,
            )
            self.train_dls_source.append(train_dl)
            self.val_dls_source.append(val_dl)

        # target subject
        self.train_dl_target, self.val_dl_target = data.get_dls(
            subject=self.args.subj_target,
            data_path=self.args.data_path,
            batch_size=self.args.batch_size,
            val_batch_size=self.args.val_batch_size,
            num_workers=self.args.num_workers,
            pool_type=self.args.pool_type,
            pool_num=self.args.pool_num,
            length=self.args.length,
            seed=self.args.seed,
        )

        self.num_batches = len(self.train_dl_target)

    def prepare_multi_gpu(self):
        self.voxel2clip, self.optimizer, self.lr_scheduler, self.train_dl_target, self.val_dl_target = self.accelerator.prepare(
            self.voxel2clip, self.optimizer, self.lr_scheduler, self.train_dl_target, self.val_dl_target)

        for i, dls in enumerate(zip(self.train_dls_source, self.val_dls_source)):
            train_dl, val_dl = dls
            self.train_dls_source[i] = self.accelerator.prepare(train_dl)
            self.val_dls_source[i] = self.accelerator.prepare(val_dl)

    def train_epoch(self, epoch):
        # make the source dataloaders iterable
        train_dls_source_iter = []
        for train_dl_s in self.train_dls_source:
            train_dls_source_iter.append(iter(train_dl_s))
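        # each adaptation step draws one batch from the target subject and one batch from a
        # round-robin-selected source subject, then trains on their concatenation (see the loop below)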
        # train loop
        for train_i, datas_target in enumerate(self.train_dl_target):
            self.train_i = train_i
            repeat_index = train_i % 3  # cycle through the three repeated scans of each image
            voxel_target, image, coco, subj_id = datas_target
            voxel = voxel_target[:, repeat_index, ...]

            source_index = train_i % len(train_dls_source_iter)  # pick one source subject per step, round-robin
            voxel_source, image_source, coco_source, subj_id_source = next(train_dls_source_iter[source_index])
            voxel_source = voxel_source[:, repeat_index, ...]
            voxel = torch.cat((voxel_source, voxel), dim=0)
            image = torch.cat((image_source, image), dim=0)
            coco = torch.cat((coco_source, coco), dim=0)
            subj_id = torch.cat((subj_id_source[[0], ...], subj_id[[0], ...]), dim=0)

            coco_ids = coco.squeeze().tolist()
            current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
            captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

            print(">>> Epoch{} | Iter{} | source{} | voxel: {}".format(epoch, train_i, source_index, voxel.shape), flush=True)
            self.train_step(voxel, image, captions, subj_id)

    def eval_epoch(self, epoch):
        print("Evaluating...")
        self.voxel2clip.eval()

        # make the source dataloaders iterable
        val_dls_source_iter = []
        for val_dl_s in self.val_dls_source:
            val_dls_source_iter.append(iter(val_dl_s))

        for val_i, datas_target in enumerate(self.val_dl_target):
            self.val_i = val_i
            repeat_index = val_i % 3  # cycle through the captions of the three repeats
            voxel, image, coco, subj_id = datas_target
            voxel = torch.mean(voxel, axis=1)  # average the three repeated scans for validation

            source_index = val_i % len(val_dls_source_iter)  # pick one source subject per step, round-robin
            print("Using source {}".format(source_index))
            voxel_source, image_source, coco_source, subj_id_source = next(val_dls_source_iter[source_index])
            voxel_source = torch.mean(voxel_source, axis=1)
            voxel = torch.cat((voxel_source, voxel), dim=0)
            image = torch.cat((image_source, image), dim=0)
            coco = torch.cat((coco_source, coco), dim=0)
            subj_id = torch.cat((subj_id_source[[0], ...], subj_id[[0], ...]), dim=0)

            coco_ids = coco.squeeze().tolist()
            current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
            captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

            print(">>> Epoch{} | Eval{} | voxel: {}".format(epoch, val_i, voxel.shape), flush=True)
            self.eval_step(voxel, image, captions, subj_id)

--------------------------------------------------------------------------------