├── src
│   ├── __init__.py
│   ├── main.py
│   ├── data.py
│   ├── options.py
│   ├── recon.py
│   ├── eval.py
│   ├── utils.py
│   ├── models.py
│   ├── nsd_access.py
│   └── trainer.py
├── assets
│   ├── galaxy_brain.gif
│   ├── MindBridge_method.png
│   └── MindBridge_teaser.png
├── requirements.txt
├── .gitignore
├── scripts
│   ├── train_single.sh
│   ├── train_bridge.sh
│   ├── inference.sh
│   └── adapt_bridge.sh
└── README.md
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/galaxy_brain.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlepure2333/MindBridge/HEAD/assets/galaxy_brain.gif
--------------------------------------------------------------------------------
/assets/MindBridge_method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlepure2333/MindBridge/HEAD/assets/MindBridge_method.png
--------------------------------------------------------------------------------
/assets/MindBridge_teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/littlepure2333/MindBridge/HEAD/assets/MindBridge_teaser.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==2.0.1
2 | torchvision==0.15.2
3 | diffusers==0.13.0
4 | kornia
5 | tqdm
6 | pandas
7 | scipy
8 | accelerate
9 | deepspeed
10 | torchsnooper
11 | matplotlib
12 | pycocotools
13 | h5py
14 | nibabel
15 | urllib3
16 | numpy
17 | wandb
18 | pillow
19 | scikit-image
20 | clip
21 | clip-retrieval
22 | transformers
23 | gpustat
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Project specific files
2 | data/
3 | train_logs/
4 | weights/
5 | wandb
6 |
7 | # Byte-compiled / optimized / DLL files
8 | __pycache__
9 | *.py[cod]
10 | *$py.class
11 |
12 | # Distribution / packaging
13 | dist/
14 | build/
15 | *.egg-info/
16 | *.egg
17 |
18 | # Virtual environments
19 | venv/
20 | env/
21 | .env
22 |
23 | # IDE specific files
24 | .idea/
25 | .vscode/
26 |
27 | # Compiled Python files
28 | *.pyc
29 |
30 | # Logs and databases
31 | *.log
32 | *.sqlite3
33 |
34 | # OS generated files
35 | .DS_Store
36 | Thumbs.db
--------------------------------------------------------------------------------
/scripts/train_single.sh:
--------------------------------------------------------------------------------
1 | batch_size=50
2 | val_batch_size=50
3 | num_epochs=600
4 | mse_mult=10000
5 | subj=1
6 | model_name="mindbridge_subj"$subj
7 |
8 | cd src/
9 |
10 | # CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \
11 | accelerate launch --multi_gpu --num_processes 2 --gpu_ids 0,1 --main_process_port 29501 \
12 | main.py \
13 | --model_name $model_name --subj_list $subj \
14 | --num_epochs $num_epochs --batch_size $batch_size --val_batch_size $val_batch_size \
15 | --h_size 2048 --n_blocks 4 --pool_type max --pool_num 8192 \
16 | --mse_mult $mse_mult \
17 | --eval_interval 10 --ckpt_interval 10 \
18 | --max_lr 3e-4 --num_workers 4
--------------------------------------------------------------------------------
/scripts/train_bridge.sh:
--------------------------------------------------------------------------------
1 | batch_size=50
2 | val_batch_size=50
3 | num_epochs=600
4 | mse_mult=10000
5 | rec_mult=1
6 | cyc_mult=1
7 | model_name="mindbridge_subj1257"
8 |
9 | cd src/
10 |
11 | # CUDA_VISIBLE_DEVICES=4 python -W ignore \
12 | accelerate launch --multi_gpu --num_processes 8 --gpu_ids 0,1,2,3,4,5,6,7 --main_process_port 29502 \
13 | main.py \
14 | --model_name $model_name --subj_list 1 2 5 7 \
15 | --num_epochs $num_epochs --batch_size $batch_size --val_batch_size $val_batch_size \
16 | --h_size 2048 --n_blocks 4 --pool_type max --pool_num 8192 \
17 |     --mse_mult $mse_mult --rec_mult $rec_mult --cyc_mult $cyc_mult \
18 | --eval_interval 10 --ckpt_interval 10 \
19 | --max_lr 1.5e-4 --num_workers 8
20 |
--------------------------------------------------------------------------------
/scripts/inference.sh:
--------------------------------------------------------------------------------
1 | subj_load=1
2 | subj_test=1
3 | model_name="mindbridge_subj125"
4 | ckpt_from="last"
5 | text_image_ratio=0.5
6 | guidance_scale=5
7 | gpu_id=0
8 |
9 | cd src/
10 |
11 | CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \
12 | recon.py \
13 | --model_name $model_name --ckpt_from $ckpt_from \
14 | --h_size 2048 --n_blocks 4 --pool_type max \
15 | --subj_load $subj_load --subj_test $subj_test \
16 |     --text_image_ratio $text_image_ratio --guidance_scale $guidance_scale \
17 | --recons_per_sample 8
18 | # --test_end 2 \
19 | # --test_start 0 \
20 |
21 |
22 | results_path="../train_logs/"$model_name"/recon_on_subj"$subj_test
23 |
24 | CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \
25 | eval.py --results_path $results_path
26 |
--------------------------------------------------------------------------------
/scripts/adapt_bridge.sh:
--------------------------------------------------------------------------------
1 | batch_size=50
2 | val_batch_size=50
3 | mse_mult=10000
4 | rec_mult=1
5 | cyc_mult=1
6 | num_epochs=200
7 | subj_source="1 2 5"
8 | subj_target=7
9 | length=4000
10 | model_name="mindbridge_subj125_a"$subj_target"_l"$length
11 | load_from="../train_logs/mindbridge_subj125/last.pth"
12 | gpu_id=0
13 |
14 | cd src/
15 |
16 | # accelerate launch --multi_gpu --num_processes 4 --gpu_ids 0,1,2,3 --main_process_port 29503 \
17 | CUDA_VISIBLE_DEVICES=$gpu_id python -W ignore \
18 | main.py \
19 | --model_name $model_name --subj_target $subj_target --subj_source $subj_source \
20 | --num_epochs $num_epochs --batch_size $batch_size --val_batch_size $val_batch_size \
21 | --h_size 2048 --n_blocks 4 --pool_type max --pool_num 8192 \
22 | --mse_mult $mse_mult --rec_mult $rec_mult --cyc_mult $cyc_mult \
23 | --eval_interval 10 --ckpt_interval 10 \
24 | --load_from $load_from --num_workers 8 \
25 | --max_lr 1.5e-4 --adapting --length $length
26 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | from accelerate import Accelerator, DeepSpeedPlugin
5 |
6 | # tf32 data type is faster than standard float32
7 | torch.backends.cuda.matmul.allow_tf32 = True
8 |
9 | # Custom models and functions #
10 | from models import Clipper, MindBridge, MindSingle
11 | from nsd_access import NSDAccess
12 | from trainer import *
13 | from options import args
14 | import utils
15 |
16 | def config_multi_gpu():
17 | # Multi-GPU config
18 | deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_clipping=1.0)
19 | accelerator = Accelerator(split_batches=False, mixed_precision='no', deepspeed_plugin=deepspeed_plugin)
20 | accelerator.print("PID of this process =",os.getpid())
21 | device = accelerator.device
22 | accelerator.print("device:",device)
23 | num_devices = torch.cuda.device_count()
24 | if num_devices==0: num_devices = 1
25 | accelerator.print(accelerator.state)
26 | local_rank = accelerator.state.local_process_index
27 | world_size = accelerator.state.num_processes
28 | distributed = not accelerator.state.distributed_type == 'NO'
29 | accelerator.print("distributed =",distributed, "num_devices =", num_devices, "local rank =", local_rank, "world size =", world_size)
30 |
31 | return accelerator, device, local_rank
32 |
33 | def prepare_CLIP(args, device):
34 | # Prepare CLIP
35 | clip_sizes = {"RN50": 1024, "ViT-L/14": 768, "ViT-B/32": 512, "ViT-H-14": 1024}
36 | clip_size = clip_sizes[args.clip_variant]
37 |
38 | print("Using hidden layer CLIP space (Versatile Diffusion)")
39 | if not args.norm_embs:
40 | print("WARNING: YOU WANT NORMED EMBEDDINGS FOR VERSATILE DIFFUSION!")
41 | clip_extractor = Clipper(args.clip_variant, device=device, hidden_state=True, norm_embs=args.norm_embs)
42 |
43 | out_dim_image = 257 * clip_size # 257*768 = 197376
44 | out_dim_text = 77 * clip_size # 77*768 = 59136
45 |
46 | print("clip_extractor loaded.")
47 | print("out_dim_image:",out_dim_image)
48 | print("out_dim_text:", out_dim_text)
49 |
50 | return clip_extractor, out_dim_image, out_dim_text
51 |
52 | def prepare_voxel2clip(args, out_dim_image, out_dim_text, device):
53 | # Prepare voxel2clip
54 | if args.adapting:
55 | args.subj_list = args.subj_source + [args.subj_target]
56 |
57 | voxel2clip_kwargs = dict(
58 | in_dim=args.pool_num, out_dim_image=out_dim_image, out_dim_text=out_dim_text,
59 | h=args.h_size, n_blocks=args.n_blocks, subj_list=args.subj_list, adapting=args.adapting)
60 | if len(args.subj_list) == 1: # Single subject does not need "brain builder"
61 | voxel2clip_kwargs.pop("adapting") # Single subject does not need "adapting"
62 | voxel2clip = MindSingle(**voxel2clip_kwargs).to(device)
63 | else:
64 | voxel2clip = MindBridge(**voxel2clip_kwargs).to(device)
65 |
66 | if args.adapting: # reset-tuning
67 | # Only let the parameters of embedder and builder in the voxel2clip trainable, keeping other parameters frozen
68 | voxel2clip.requires_grad_(False)
69 | voxel2clip.embedder[str(args.subj_target)].requires_grad_(True)
70 | voxel2clip.builder[str(args.subj_target)].requires_grad_(True)
71 |
72 | print("voxel2clip loaded.")
73 | print("params of voxel2clip:")
74 | utils.count_params(voxel2clip)
75 |
76 | return voxel2clip
77 |
78 | def prepare_coco(args):
79 | # Preload coco captions
80 | nsda = NSDAccess(args.data_path)
81 | coco_73k = list(range(0, 73000))
82 | prompts_list = nsda.read_image_coco_info(coco_73k,info_type='captions')
83 |
84 | print("coco captions loaded.")
85 |
86 | return prompts_list
87 |
88 | def prepare_trainer(args, accelerator, voxel2clip, clip_extractor, prompts_list, device):
89 | if args.adapting:
90 | trainer = Trainer_adapt(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)
91 | elif len(args.subj_list) == 1:
92 | trainer = Trainer_single(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)
93 | else:
94 | trainer = Trainer_bridge(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)
95 |
96 | return trainer
97 |
98 |
99 | def main():
100 | accelerator, device, local_rank = config_multi_gpu()
101 | if local_rank != 0: # suppress print for non-local_rank=0
102 | sys.stdout = open(os.devnull, 'w')
103 |
104 | # need non-deterministic CuDNN for conv3D to work
105 | utils.seed_everything(args.seed, cudnn_deterministic=False)
106 |
107 |     # learning rate will be scaled by "accelerate" based on the number of processes (GPUs)
108 | args.max_lr *= accelerator.num_processes
109 |
110 | # Prepare CLIP
111 | clip_extractor, out_dim_image, out_dim_text = prepare_CLIP(args, device)
112 |
113 | # Prepare voxel2clip
114 | voxel2clip = prepare_voxel2clip(args, out_dim_image, out_dim_text, device)
115 |
116 | # Prepare coco captions
117 | prompts_list = prepare_coco(args)
118 |
119 | # Init Trainer
120 | trainer = prepare_trainer(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)
121 | trainer.prepare_wandb(local_rank, args)
122 | # trainer.prepare_multi_gpu()
123 |
124 | # Train or Adapt
125 | trainer.train(local_rank)
126 |
127 | print("\n===Finished!===\n")
128 |
129 | if __name__ == '__main__':
130 | main()
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | from torch import nn
6 | import torch.nn.functional as F
7 | from torch.utils.data import Dataset, DataLoader
8 | import utils
9 | import kornia
10 | from kornia.augmentation.container import AugmentationSequential
11 |
12 |
13 | img_augment = AugmentationSequential(
14 | kornia.augmentation.RandomResizedCrop((224,224), (0.8,1), p=0.3),
15 | kornia.augmentation.Resize((224, 224)),
16 | kornia.augmentation.RandomBrightness(brightness=(0.8, 1.2), clip_output=True, p=0.2),
17 | kornia.augmentation.RandomContrast(contrast=(0.8, 1.2), clip_output=True, p=0.2),
18 | kornia.augmentation.RandomGamma((0.8, 1.2), (1.0, 1.3), p=0.2),
19 | kornia.augmentation.RandomSaturation((0.8,1.2), p=0.2),
20 | kornia.augmentation.RandomHue((-0.1,0.1), p=0.2),
21 | kornia.augmentation.RandomSharpness((0.8, 1.2), p=0.2),
22 | kornia.augmentation.RandomGrayscale(p=0.2),
23 | data_keys=["input"],
24 | )
25 |
26 | class NSDDataset(Dataset):
27 | def __init__(self, root_dir, extensions=None, pool_num=8192, pool_type="max", length=None):
28 | self.root_dir = root_dir
29 | self.extensions = extensions if extensions else []
30 | self.pool_num = pool_num
31 | self.pool_type = pool_type
32 | self.samples = self._load_samples()
33 | self.samples_keys = sorted(self.samples.keys())
34 | self.length = length
35 | if length is not None:
36 | if length > len(self.samples_keys):
37 |                 pass # enlarge the dataset (extra indices wrap around via idx % len in __getitem__)
38 | elif length > 0:
39 | self.samples_keys = self.samples_keys[:length]
40 | elif length < 0:
41 | self.samples_keys = self.samples_keys[length:]
42 | elif length == 0:
43 | raise ValueError("length must be a non-zero value!")
44 | else:
45 | self.length = len(self.samples_keys)
46 |
47 | def _load_samples(self):
48 | files = os.listdir(self.root_dir)
49 | samples = {}
50 | for file in files:
51 | file_path = os.path.join(self.root_dir, file)
52 | sample_id, ext = file.split(".",maxsplit=1)
53 | if ext in self.extensions:
54 | if sample_id in samples.keys():
55 | samples[sample_id][ext] = file_path
56 | else:
57 | samples[sample_id]={"subj": file_path}
58 | samples[sample_id][ext] = file_path
59 | # print(samples)
60 | return samples
61 |
62 | def _load_image(self, image_path):
63 | image = Image.open(image_path).convert('RGB')
64 | image = np.array(image).astype(np.float32) / 255.0
65 | image = torch.from_numpy(image.transpose(2, 0, 1))
66 | return image
67 |
68 | def _load_npy(self, npy_path):
69 | array = np.load(npy_path)
70 | array = torch.from_numpy(array)
71 | return array
72 |
73 | def vox_process(self, x):
74 | if self.pool_num is not None:
75 | x = pool_voxels(x, self.pool_num, self.pool_type)
76 | return x
77 |
78 | def subj_process(self, key):
79 | id = int(key.split("/")[-2].split("subj")[-1])
80 | return id
81 |
82 | def aug_process(self, brain3d):
83 | return brain3d
84 |
85 | def __len__(self):
86 | # return len(self.samples_keys)
87 | return self.length
88 |
89 | def __getitem__(self, idx):
90 | idx = idx % len(self.samples_keys)
91 | sample_key = self.samples_keys[idx]
92 | sample = self.samples[sample_key]
93 | items = []
94 | for ext in self.extensions:
95 | if ext == "jpg":
96 | items.append(self._load_image(sample[ext]))
97 | elif ext == "nsdgeneral.npy":
98 | voxel = self._load_npy(sample[ext])
99 | items.append(self.vox_process(voxel))
100 | elif ext == "coco73k.npy":
101 | items.append(self._load_npy(sample[ext]))
102 | elif ext == "subj":
103 | items.append(self.subj_process(sample[ext]))
104 | elif ext == "wholebrain_3d.npy":
105 | brain3d = self._load_npy(sample[ext])
106 |                 items.append(self.aug_process(brain3d))
107 |
108 | return items
109 |
110 | def pool_voxels(voxels, pool_num, pool_type):
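    # Unify each subject's variable voxel count to a fixed length (pool_num) via adaptive pooling or linear resizing.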
111 | voxels = voxels.float()
112 | if pool_type == 'avg':
113 | voxels = nn.AdaptiveAvgPool1d(pool_num)(voxels)
114 | elif pool_type == 'max':
115 | voxels = nn.AdaptiveMaxPool1d(pool_num)(voxels)
116 | elif pool_type == "resize":
117 | voxels = voxels.unsqueeze(1) # Add a dimension to make it (B, 1, L)
118 | voxels = F.interpolate(voxels, size=pool_num, mode='linear', align_corners=False)
119 | voxels = voxels.squeeze(1)
120 |
121 | return voxels
122 |
123 | def get_dataloader(
124 | root_dir,
125 | batch_size,
126 | num_workers=1,
127 | seed=42,
128 | is_shuffle=True,
129 | extensions=['nsdgeneral.npy', "jpg", 'coco73k.npy', "subj"],
130 | pool_type=None,
131 | pool_num=None,
132 | length=None,
133 | ):
134 | utils.seed_everything(seed)
135 | dataset = NSDDataset(root_dir=root_dir, extensions=extensions, pool_num=pool_num, pool_type=pool_type, length=length)
136 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=is_shuffle)
137 |
138 | return dataloader
139 |
140 | def get_dls(subject, data_path, batch_size, val_batch_size, num_workers, pool_type, pool_num, length, seed):
141 | train_path = "{}/webdataset_avg_split/train/subj0{}".format(data_path, subject)
142 | val_path = "{}/webdataset_avg_split/val/subj0{}".format(data_path, subject)
143 | extensions = ['nsdgeneral.npy', "jpg", 'coco73k.npy', "subj"]
144 |
145 | train_dl = get_dataloader(
146 | train_path,
147 | batch_size=batch_size,
148 | num_workers=num_workers,
149 | seed=seed,
150 | extensions=extensions,
151 | pool_type=pool_type,
152 | pool_num=pool_num,
153 | is_shuffle=True,
154 | length=length,
155 | )
156 |
157 | val_dl = get_dataloader(
158 | val_path,
159 | batch_size=val_batch_size,
160 | num_workers=num_workers,
161 | seed=seed,
162 | extensions=extensions,
163 | pool_type=pool_type,
164 | pool_num=pool_num,
165 | is_shuffle=False,
166 | )
167 |
168 | num_train=len(train_dl.dataset)
169 | num_val=len(val_dl.dataset)
170 | print(train_path,"\n",val_path)
171 | print("number of train data:", num_train)
172 | print("batch_size", batch_size)
173 | print("number of val data:", num_val)
174 | print("val_batch_size", val_batch_size)
175 |
176 | return train_dl, val_dl
177 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # MindBridge: A Cross-Subject Brain Decoding Framework
2 |
3 | 
4 |
5 | [Shizun Wang](https://littlepure2333.github.io/home/), [Songhua Liu](http://121.37.94.87/), [Zhenxiong Tan](https://github.com/Yuanshi9815), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
6 | National University of Singapore
7 |
8 | **CVPR 2024 Highlight**
9 | [Project](https://littlepure2333.github.io/MindBridge/) | [Arxiv](https://arxiv.org/abs/2404.07850)
10 |
11 | ## News
12 | **[2024.04.12]** MindBridge's paper, project and code are released.
13 | **[2024.04.05]** MindBridge is selected as a **CVPR 2024 Highlight** paper!
14 | **[2024.02.27]** MindBridge is accepted by **CVPR 2024**!
15 |
16 | ## Overview
17 | 
18 |
19 | > We present a novel approach, MindBridge, that achieves **cross-subject brain decoding by employing only one model**. Our proposed framework establishes a generic paradigm capable of addressing three key challenges: **1) the inherent variability** in input dimensions across subjects due to differences in brain size; **2) the unique intrinsic neural patterns**, influencing how different individuals perceive and process sensory information; **3) limited data availability for new subjects** in real-world scenarios, which hampers the performance of decoding models.
20 | Notably, by cycle reconstruction, MindBridge enables **novel brain signal synthesis**, which can also serve as pseudo data augmentation. Within this framework, we can **adapt** a pretrained MindBridge to a **new subject** using only limited data.
21 |
22 | ## Installation
23 |
24 | 1. Agree to the Natural Scenes Dataset's [Terms and Conditions](https://cvnlab.slite.page/p/IB6BSeW_7o/Terms-and-Conditions) and fill out the [NSD Data Access form](https://forms.gle/xue2bCdM9LaFNMeb7)
25 |
26 | 2. Download this repository: ``git clone https://github.com/littlepure2333/MindBridge.git``
27 |
28 | 3. Create a conda environment and install the packages necessary to run the code.
29 |
30 | ```bash
31 | conda create -n mindbridge python=3.10.8 -y
32 | conda activate mindbridge
33 | pip install -r requirements.txt
34 | ```
35 |
36 | ## Preparation
37 |
38 | ### Data
39 |
40 | Download the essential files we used from the [NSD dataset](https://natural-scenes-dataset.s3.amazonaws.com/index.html), which include `nsd_stim_info_merged.csv`. Also download the COCO captions from [this link](http://images.cocodataset.org/annotations/annotations_trainval2017.zip), which contains `captions_train2017.json` and `captions_val2017.json`.
41 | We use the same preprocessed data as [MindEye](https://github.com/MedARC-AI/fMRI-reconstruction-NSD), which can be downloaded from [Hugging Face](https://huggingface.co/datasets/pscotti/naturalscenesdataset/tree/main/webdataset_avg_split); extract all files from the compressed tar files.
42 | Then organize the data as follows (a download sketch is given after the data tree):
43 |
44 |
45 |
46 | **Data Organization**
47 |
48 | ```
49 | data/natural-scenes-dataset
50 | ├── nsddata
51 | │ └── experiments
52 | │ └── nsd
53 | │ └── nsd_stim_info_merged.csv
54 | ├── nsddata_stimuli
55 | │ └── stimuli
56 | │ └── nsd
57 | │ └── annotations
58 | │ ├── captions_train2017.json
59 | │ └── captions_val2017.json
60 | └── webdataset_avg_split
61 | ├── test
62 | │ ├── subj01
63 | │ │ ├── sample000000349.coco73k.npy
64 | │ │ ├── sample000000349.jpg
65 | │ │ ├── sample000000349.nsdgeneral.npy
66 | │ │ └── ...
67 | │ └── ...
68 | ├── train
69 | │ ├── subj01
70 | │ │ ├── sample000000300.coco73k.npy
71 | │ │ ├── sample000000300.jpg
72 | │ │ ├── sample000000300.nsdgeneral.npy
73 | │ │ └── ...
74 | │ └── ...
75 | └── val
76 | ├── subj01
77 | │ ├── sample000000000.coco73k.npy
78 | │ ├── sample000000000.jpg
79 | │ ├── sample000000000.nsdgeneral.npy
80 | │ └── ...
81 | └── ...
82 | ```
83 |
84 |
85 |
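A minimal download sketch is below, assuming everything lives under `data/natural-scenes-dataset` as in the tree above. The `nsd_stim_info_merged.csv` step and the internal layout of the Hugging Face tar files are left as comments, since the exact URLs and file names depend on what you download; double-check the result against the tree.

```bash
cd data/natural-scenes-dataset

# nsd_stim_info_merged.csv: fetch it from the NSD index into this folder
mkdir -p nsddata/experiments/nsd

# COCO captions (the zip contains annotations/captions_{train,val}2017.json)
mkdir -p nsddata_stimuli/stimuli/nsd
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
unzip annotations_trainval2017.zip -d nsddata_stimuli/stimuli/nsd

# MindEye's preprocessed splits: download the tar files from the Hugging Face link above,
# then extract them so that train/ val/ test/ end up under webdataset_avg_split/
mkdir -p webdataset_avg_split
for f in *.tar; do tar -xf "$f" -C webdataset_avg_split; done
```
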
86 | ### Checkpoints
87 | You can download our pretrained MindBridge checkpoints for subjects 01, 02, 05, and 07 from [Hugging Face](https://huggingface.co/littlepure2333/MindBridge/tree/main). Place the folders containing the checkpoints under the `./train_logs/` directory.
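The inference script looks for checkpoints at `./train_logs/<model_name>/<last|best>.pth` (see `scripts/inference.sh` and `src/recon.py`), so the layout should look roughly like this; the folder name below is just an example:

```
train_logs
└── mindbridge_subj1257
    ├── last.pth
    └── best.pth
```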
88 |
89 | ## Training
90 |
91 | The training commands are provided in the `./scripts` folder, and all command-line options are defined in `./src/options.py`. For example, you can resume training by adding the `--resume` option to a command; a sketch follows below. Training progress can be monitored through [wandb](https://wandb.ai/).
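For instance, a resumed single-subject run might look like the following; the flags mirror `scripts/train_single.sh`, and the wandb run id is a placeholder you replace with your own:

```bash
cd src/
accelerate launch --multi_gpu --num_processes 2 --gpu_ids 0,1 --main_process_port 29501 \
    main.py \
    --model_name mindbridge_subj1 --subj_list 1 \
    --num_epochs 600 --batch_size 50 --val_batch_size 50 \
    --h_size 2048 --n_blocks 4 --pool_type max --pool_num 8192 \
    --mse_mult 10000 --eval_interval 10 --ckpt_interval 10 \
    --max_lr 3e-4 --num_workers 4 \
    --resume --resume_id <your_wandb_run_id>   # continue from this model_name's latest checkpoint
```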
92 |
93 | ### Training on single subject
94 | This script trains the per-subject-per-model version of MindBridge (referred to as "Vanilla" in the paper) on one subject (e.g. subj01). You can specify which subject in the script.
95 |
96 | ```bash
97 | bash scripts/train_single.sh
98 | ```
99 |
100 | ### Training on multi-subjects
101 | This script trains MindBridge on multiple subjects (e.g. subj01, 02, 05, 07). You can specify which subjects in the script.
102 |
103 | ```bash
104 | bash scripts/train_bridge.sh
105 | ```
106 |
107 |
108 | ### Adapting to a new subject
109 | Once MindBridge has been trained on some known "source subjects" (e.g. subj01, 02, 05), you can adapt it to a new "target subject" (e.g. subj07) using only a limited amount of data (e.g. 4000 data points). You can specify the source subjects, the target subject, and the data volume (length) in the script.
110 |
111 | ```bash
112 | bash scripts/adapt_bridge.sh
113 | ```
114 |
115 | ## Reconstructing and evaluating
116 | This script reconstructs one subject's images (e.g. subj01) on the test set using a MindBridge model (e.g. one trained on subj01, 02, 05, 07), then calculates all the metrics. The metrics are saved to a CSV file; the expected output layout is sketched after the command below. You can specify which MindBridge model and which subject in the script.
117 |
118 | ```bash
119 | bash scripts/inference.sh
120 | ```
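The outputs land under the model's log directory. Roughly, you can expect the following layout (file names follow `src/recon.py` and `src/eval.py`; the sample index and count shown are illustrative):

```
train_logs/mindbridge_subj125/recon_on_subj1
├── 0.png                         # grid of candidate reconstructions (only with --plotting)
├── 0_img.pt                      # ground-truth image tensor
├── 0_rec.pt                      # best reconstruction among --recons_per_sample candidates
├── ...
└── _metrics_on_982samples.csv    # metrics table written by eval.py
```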
121 |
122 |
123 | ## TODO List
124 | - [x] Release pretrained checkpoints.
125 | - [ ] Train MindBridge on all 8 subjects of the NSD dataset.
126 |
127 | ## Citation
128 | ```
129 | @inproceedings{wang2024mindbridge,
130 | title={Mindbridge: A cross-subject brain decoding framework},
131 | author={Wang, Shizun and Liu, Songhua and Tan, Zhenxiong and Wang, Xinchao},
132 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
133 | pages={11333--11342},
134 | year={2024}
135 | }
136 | ```
137 |
138 | ## Acknowledgement
139 | We extend our gratitude to [MindEye](https://github.com/MedARC-AI/fMRI-reconstruction-NSD) and [nsd_access](https://github.com/tknapen/nsd_access) for generously sharing their codebases, upon which ours is built. We are indebted to the [NSD dataset](https://natural-scenes-dataset.s3.amazonaws.com/index.html) for providing access to high-quality, publicly available data.
140 | Our appreciation also extends to [Accelerate](https://huggingface.co/docs/accelerate/index) and [DeepSpeed](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for simplifying efficient multi-GPU training, which enabled us to train on 24GB-VRAM NVIDIA A5000 GPUs.
141 | Special thanks to [Xingyi Yang](https://adamdad.github.io/) and [Yifan Zhang](https://sites.google.com/view/yifan-zhang) for their invaluable discussions.
142 |
143 |
144 |
145 |
146 |
147 |
148 |
--------------------------------------------------------------------------------
/src/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description="MindBridge Configuration")
4 | parser.add_argument(
5 | "--model_name", type=str, default="testing",
6 | help="name of model, used for ckpt saving and wandb logging (if enabled)",
7 | )
8 | parser.add_argument(
9 | "--data_path", type=str, default="../data/natural-scenes-dataset",
10 | help="Path to where NSD data is stored / where to download it to",
11 | )
12 | parser.add_argument(
13 | "--subj_list",type=int, default=[1], choices=[1,2,3,4,5,6,7,8], nargs='+',
14 | help="Subject index to train on",
15 | )
16 | parser.add_argument(
17 | "--subj_source",type=int, default=[1], choices=[1,2,3,4,5,6,7,8], nargs='+',
18 | help="Source subject index to be adapted from (Can be multiple subjects)",
19 | )
20 | parser.add_argument(
21 |     "--subj_target",type=int, default=1, choices=[1,2,3,4,5,6,7,8],
22 | help="Target subject index to be adapted to (Only one subject)",
23 | )
24 | parser.add_argument(
25 | "--adapting",action=argparse.BooleanOptionalAction,default=False,
26 | help="Whether to adapt from source to target subject",
27 | )
28 | parser.add_argument(
29 | "--batch_size", type=int, default=50,
30 | help="Batch size per GPU",
31 | )
32 | parser.add_argument(
33 | "--val_batch_size", type=int, default=50,
34 | help="Validation batch size per GPU",
35 | )
36 | parser.add_argument(
37 | "--clip_variant",type=str,default="ViT-L/14",choices=["RN50", "ViT-L/14", "ViT-B/32", "RN50x64"],
38 | help='OpenAI clip variant',
39 | )
40 | parser.add_argument(
41 | "--wandb_log",action=argparse.BooleanOptionalAction,default=True,
42 | help="Whether to log to wandb",
43 | )
44 | parser.add_argument(
45 | "--wandb_project",type=str,default="MindBridge",
46 | help="Wandb project name",
47 | )
48 | parser.add_argument(
49 | "--resume",action=argparse.BooleanOptionalAction,default=False,
50 |     help="Resume training from the latest checkpoint; cannot be used together with --load_from",
51 | )
52 | parser.add_argument(
53 | "--resume_id",type=str,default=None,
54 | help="Run id for wandb resume",
55 | )
56 | parser.add_argument(
57 | "--load_from",type=str,default=None,
58 |     help="Load a model and restart training; cannot be used together with --resume",
59 | )
60 | parser.add_argument(
61 | "--norm_embs",action=argparse.BooleanOptionalAction,default=True,
62 | help="Do l2-norming of CLIP embeddings",
63 | )
64 | parser.add_argument(
65 | "--use_image_aug",action=argparse.BooleanOptionalAction,default=True,
66 | help="whether to use image augmentation",
67 | )
68 | parser.add_argument(
69 | "--num_epochs",type=int,default=2000,
70 | help="number of epochs of training",
71 | )
72 | parser.add_argument(
73 | "--lr_scheduler_type",type=str,default='cycle',choices=['cycle','linear'],
74 | help="Type of learning rate scheduler",
75 | )
76 | parser.add_argument(
77 | "--ckpt_interval",type=int,default=10,
78 | help="Save backup ckpt and reconstruct every x epochs",
79 | )
80 | parser.add_argument(
81 | "--eval_interval",type=int,default=10,
82 | help="Evaluate the model every x epochs",
83 | )
84 | parser.add_argument(
85 | "--h_size",type=int,default=2048,
86 | help="Hidden size of MLP",
87 | )
88 | parser.add_argument(
89 | "--n_blocks",type=int,default=2,
90 | help="Number of Hidden layers in MLP",
91 | )
92 | parser.add_argument(
93 | "--seed",type=int,default=42,
94 | help="Seed for reproducibility",
95 | )
96 | parser.add_argument(
97 | "--num_workers",type=int,default=5,
98 | help="Number of workers in dataloader"
99 | )
100 | parser.add_argument(
101 | "--max_lr",type=float,default=3e-4,
102 | help="Max learning rate",
103 | )
104 | parser.add_argument(
105 | "--pool_num", type=int, default=8192,
106 | help="Number of pooling",
107 | )
108 | parser.add_argument(
109 | "--pool_type", type=str, default='max',
110 | help="Type of pooling: avg, max",
111 | )
112 | parser.add_argument(
113 | "--mse_mult", type=float, default=1e4,
114 | help="The weight of mse loss",
115 | )
116 | parser.add_argument(
117 | "--rec_mult", type=float, default=0,
118 | help="The weight of brain reconstruction loss",
119 | )
120 | parser.add_argument(
121 | "--cyc_mult", type=float, default=0,
122 | help="The weight of cycle loss",
123 | )
124 | parser.add_argument(
125 | # "--length", type=int, default=8559,
126 | "--length", type=int, default=None,
127 | help="Indicate dataset length",
128 | )
129 | parser.add_argument(
130 | "--autoencoder_name", type=str, default=None,
131 | help="name of trained autoencoder model",
132 | )
133 | parser.add_argument(
134 | "--subj_load",type=int, default=None, choices=[1,2,3,4,5,6,7,8], nargs='+',
135 |     help="Subjects to load into the model",
136 | )
137 | parser.add_argument(
138 | "--subj_test",type=int, default=1, choices=[1,2,3,4,5,6,7,8],
139 |     help="Subject to test",
140 | )
141 | parser.add_argument(
142 | "--samples",type=int, default=None, nargs='+',
143 |     help="Specify sample indices to reconstruct"
144 | )
145 | parser.add_argument(
146 | "--img2img_strength",type=float, default=.85,
147 | help="How much img2img (1=no img2img; 0=outputting the low-level image itself)",
148 | )
149 | parser.add_argument(
150 | "--guidance_scale",type=float, default=3.5,
151 | help="Guidance scale for diffusion model.",
152 | )
153 | parser.add_argument(
154 |     "--num_inference_steps",type=int, default=20,
155 | help="Number of inference steps for diffusion model.",
156 | )
157 | parser.add_argument(
158 | "--recons_per_sample", type=int, default=16,
159 | help="How many recons to output, to then automatically pick the best one (MindEye uses 16)",
160 | )
161 | parser.add_argument(
162 | "--plotting", action=argparse.BooleanOptionalAction, default=True,
163 | help="plotting all the results",
164 | )
165 | parser.add_argument(
166 | "--vd_cache_dir", type=str, default='../weights',
167 | help="Where is cached Versatile Diffusion model; if not cached will download to this path",
168 | )
169 | parser.add_argument(
170 | "--gpu_id", type=int, default=0,
171 | help="ID of the GPU to be used",
172 | )
173 | parser.add_argument(
174 | "--ckpt_from", type=str, default='last',
175 | help="ckpt_from ['last', 'best']",
176 | )
177 | parser.add_argument(
178 | "--text_image_ratio", type=float, default=0.5,
179 | help="text_image_ratio in Versatile Diffusion. Only valid when use_text=True. 0.5 means equally weight text and image, 0 means use only image",
180 | )
181 | parser.add_argument(
182 | "--test_start", type=int, default=0,
183 | help="test range start index",
184 | )
185 | parser.add_argument(
186 | "--test_end", type=int, default=None,
187 | help="test range end index, the total length of test data is 982, so max index is 981",
188 | )
189 | parser.add_argument(
190 | "--only_embeddings", action=argparse.BooleanOptionalAction, default=False,
191 | help="only return semantic embeddings of networks",
192 | )
193 | parser.add_argument(
194 | "--synthesis", action=argparse.BooleanOptionalAction, default=False,
195 | help="synthesize new fMRI signals",
196 | )
197 | parser.add_argument(
198 | "--verbose", action=argparse.BooleanOptionalAction, default=True,
199 | help="print more information",
200 | )
201 | parser.add_argument(
202 | "--results_path", type=str, default=None,
203 | help="path to reconstructed outputs",
204 | )
205 |
206 | args = parser.parse_args()
--------------------------------------------------------------------------------
/src/recon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import matplotlib.pyplot as plt
4 | import torch
5 | from tqdm import tqdm
6 | from datetime import datetime
7 |
8 | import utils
9 | from models import Clipper, MindBridge, MindSingle, Voxel2StableDiffusionModel
10 | import data
11 | from options import args
12 | from eval import cal_metrics
13 |
14 |
15 | ## Load autoencoder
16 | def prepare_voxel2sd(args, ckpt_path, device):
17 | from models import Voxel2StableDiffusionModel
18 | checkpoint = torch.load(ckpt_path, map_location=device)
19 | state_dict = checkpoint['model_state_dict']
20 |
21 | voxel2sd = Voxel2StableDiffusionModel(in_dim=args.num_voxels)
22 |
23 | voxel2sd.load_state_dict(state_dict,strict=False)
24 | voxel2sd.to(device)
25 | voxel2sd.eval()
26 | print("Loaded low-level model!")
27 |
28 | return voxel2sd
29 |
30 | def prepare_data(args):
31 | ## Load data
32 | subj_num_voxels = {
33 | 1: 15724,
34 | 2: 14278,
35 | 3: 15226,
36 | 4: 13153,
37 | 5: 13039,
38 | 6: 17907,
39 | 7: 12682,
40 | 8: 14386
41 | }
42 | args.num_voxels = subj_num_voxels[args.subj_test]
43 |
44 | test_path = "{}/webdataset_avg_split/test/subj0{}".format(args.data_path, args.subj_test)
45 | test_dl = data.get_dataloader(
46 | test_path,
47 | batch_size=args.batch_size,
48 | num_workers=args.num_workers,
49 | seed=args.seed,
50 | is_shuffle=False,
51 | extensions=['nsdgeneral.npy', "jpg", 'coco73k.npy', "subj"],
52 | pool_type=args.pool_type,
53 | pool_num=args.pool_num,
54 | )
55 |
56 | return test_dl
57 |
58 | def prepare_VD(args, device):
59 | print('Creating versatile diffusion reconstruction pipeline...')
60 | from diffusers import VersatileDiffusionDualGuidedPipeline, UniPCMultistepScheduler
61 | from diffusers.models import DualTransformer2DModel
62 | try:
63 | vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(args.vd_cache_dir)
64 | except:
65 | print("Downloading Versatile Diffusion to", args.vd_cache_dir)
66 | vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained(
67 | "shi-labs/versatile-diffusion",
68 | cache_dir = args.vd_cache_dir)
69 |
70 | vd_pipe.image_unet.eval().to(device)
71 | vd_pipe.vae.eval().to(device)
72 | vd_pipe.image_unet.requires_grad_(False)
73 | vd_pipe.vae.requires_grad_(False)
74 |
75 | vd_pipe.scheduler = UniPCMultistepScheduler.from_pretrained("shi-labs/versatile-diffusion", cache_dir=args.vd_cache_dir, subfolder="scheduler")
76 |
77 | # Set weighting of Dual-Guidance
78 | # text_image_ratio=0.5 means equally weight text and image, 0 means use only image
79 | for name, module in vd_pipe.image_unet.named_modules():
80 | if isinstance(module, DualTransformer2DModel):
81 | module.mix_ratio = args.text_image_ratio
82 | for i, type in enumerate(("text", "image")):
83 | if type == "text":
84 | module.condition_lengths[i] = 77
85 | module.transformer_index_for_condition[i] = 1 # use the second (text) transformer
86 | else:
87 | module.condition_lengths[i] = 257
88 | module.transformer_index_for_condition[i] = 0 # use the first (image) transformer
89 |
90 | return vd_pipe
91 |
92 | def prepare_CLIP(args, device):
93 | clip_sizes = {"RN50": 1024, "ViT-L/14": 768, "ViT-B/32": 512, "ViT-H-14": 1024}
94 | clip_size = clip_sizes[args.clip_variant]
95 | out_dim_image = 257 * clip_size
96 | out_dim_text = 77 * clip_size
97 | clip_extractor = Clipper("ViT-L/14", hidden_state=True, norm_embs=True, device=device)
98 |
99 | return clip_extractor, out_dim_image, out_dim_text
100 |
101 | def prepare_voxel2clip(args, out_dim_image, out_dim_text, device):
102 | voxel2clip_kwargs = dict(
103 | in_dim=args.pool_num, out_dim_image=out_dim_image, out_dim_text=out_dim_text,
104 | h=args.h_size, n_blocks=args.n_blocks, subj_list=args.subj_load)
105 |
106 | # only need to load Single-subject version of MindBridge
107 | voxel2clip = MindSingle(**voxel2clip_kwargs)
108 |
109 | outdir = f'../train_logs/{args.model_name}'
110 | ckpt_path = os.path.join(outdir, f'{args.ckpt_from}.pth')
111 | print("ckpt_path",ckpt_path)
112 | checkpoint = torch.load(ckpt_path, map_location='cpu')
113 | print("EPOCH: ",checkpoint['epoch'])
114 | state_dict = checkpoint['model_state_dict']
115 |
116 | voxel2clip.load_state_dict(state_dict,strict=False)
117 | voxel2clip.requires_grad_(False)
118 | voxel2clip.eval().to(device)
119 |
120 | return voxel2clip
121 |
122 | def main(device):
123 | args.batch_size = 1
124 | if args.subj_load is None:
125 | args.subj_load = [args.subj_test]
126 |
127 | # Load data
128 | test_dl = prepare_data(args)
129 | num_test = len(test_dl)
130 |
131 | # Load autoencoder
132 | outdir_ae = f'../train_logs/{args.autoencoder_name}'
133 | ckpt_path = os.path.join(outdir_ae, f'epoch120.pth')
134 | if os.path.exists(ckpt_path):
135 | voxel2sd = prepare_voxel2sd(args, ckpt_path, device)
136 | # pool later
137 | args.pool_type = None
138 | else:
139 | print("No valid path for low-level model specified; not using img2img!")
140 | args.img2img_strength = 1
141 |
142 | # Load VD pipeline
143 | vd_pipe = prepare_VD(args, device)
144 | unet = vd_pipe.image_unet
145 | vae = vd_pipe.vae
146 | noise_scheduler = vd_pipe.scheduler
147 |
148 | # Load CLIP
149 | clip_extractor, out_dim_image, out_dim_text = prepare_CLIP(args, device)
150 |
151 | # load voxel2clip
152 | voxel2clip = prepare_voxel2clip(args, out_dim_image, out_dim_text, device)
153 |
154 | outdir = f'../train_logs/{args.model_name}'
155 | save_dir = os.path.join(outdir, f"recon_on_subj{args.subj_test}")
156 | os.makedirs(save_dir, exist_ok=True)
157 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
158 |
159 | # define test range
160 | test_range = np.arange(num_test)
161 | if args.test_end is None:
162 | args.test_end = num_test
163 |
164 | # define recon logic
165 | only_lowlevel = False
166 | if args.img2img_strength == 1:
167 | img2img = False
168 | elif args.img2img_strength == 0:
169 | img2img = True
170 | only_lowlevel = True
171 | else:
172 | img2img = True
173 |
174 | # recon loop
175 | for val_i, (voxel, img, coco, subj) in enumerate(tqdm(test_dl,total=len(test_range))):
176 | if val_i < args.test_start:
177 | continue
178 | if val_i >= args.test_end:
179 | break
180 | if (args.samples is not None) and (val_i not in args.samples):
181 | continue
182 |
183 | voxel = torch.mean(voxel,axis=1).float().to(device)
184 | img = img.to(device)
185 |
186 | with torch.no_grad():
187 | if args.only_embeddings:
188 | results = voxel2clip(voxel)
189 | embeddings = results[:2]
190 | torch.save(embeddings, os.path.join(save_dir, f'embeddings_{val_i}.pt'))
191 | continue
192 | if img2img: # will apply low-level and high-level pipeline
193 | ae_preds = voxel2sd(voxel)
194 | blurry_recons = vd_pipe.vae.decode(ae_preds.to(device)/0.18215).sample / 2 + 0.5
195 |
196 | if val_i==0:
197 | plt.imshow(utils.torch_to_Image(blurry_recons))
198 | plt.show()
199 |
200 | # pooling
201 | voxel = data.pool_voxels(voxel, args.pool_num, args.pool_type)
202 | else: # only high-level pipeline
203 | blurry_recons = None
204 |
205 | if only_lowlevel: # only low-level pipeline
206 | brain_recons = blurry_recons
207 | else:
208 | grid, brain_recons, best_picks, recon_img = utils.reconstruction(
209 | img, voxel, voxel2clip,
210 | clip_extractor, unet, vae, noise_scheduler,
211 | img_lowlevel = blurry_recons,
212 | num_inference_steps = args.num_inference_steps,
213 | n_samples_save = args.batch_size,
214 | recons_per_sample = args.recons_per_sample,
215 | guidance_scale = args.guidance_scale,
216 | img2img_strength = args.img2img_strength, # 0=fully rely on img_lowlevel, 1=not doing img2img
217 | seed = args.seed,
218 | plotting = args.plotting,
219 | verbose = args.verbose,
220 | device=device,
221 | mem_efficient=False,
222 | )
223 |
224 | if args.plotting:
225 | grid.savefig(os.path.join(save_dir, f'{val_i}.png'))
226 |
227 | brain_recons = brain_recons[:,best_picks.astype(np.int8)]
228 |
229 | torch.save(img, os.path.join(save_dir, f'{val_i}_img.pt'))
230 | torch.save(brain_recons, os.path.join(save_dir, f'{val_i}_rec.pt'))
231 |
232 | print(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
233 | print("save path:", save_dir)
234 |
235 | if __name__ == "__main__":
236 | utils.seed_everything(seed=args.seed)
237 |
238 | device = torch.device('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
239 | print("device:",device)
240 |
241 | main(device)
242 |
243 | # args.results_path = f'../train_logs/{args.model_name}/recon_on_subj{args.subj_test}'
244 | # cal_metrics(args.results_path, device)
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import scipy as sp
4 | import pandas as pd
5 | import torch
6 | from torchvision import transforms
7 | from tqdm import tqdm
8 | from options import args
9 | import utils
10 | from torchvision.models.feature_extraction import create_feature_extractor
11 |
12 | @torch.no_grad()
13 | def two_way_identification(all_brain_recons, all_images, model, preprocess, feature_layer=None, return_avg=True, device='cpu'):
14 | preds = model(torch.stack([preprocess(recon) for recon in all_brain_recons], dim=0).to(device))
15 | reals = model(torch.stack([preprocess(indiv) for indiv in all_images], dim=0).to(device))
16 | if feature_layer is None:
17 | preds = preds.float().flatten(1).cpu().numpy()
18 | reals = reals.float().flatten(1).cpu().numpy()
19 | else:
20 | preds = preds[feature_layer].float().flatten(1).cpu().numpy()
21 | reals = reals[feature_layer].float().flatten(1).cpu().numpy()
22 |
23 | r = np.corrcoef(reals, preds)
24 | r = r[:len(all_images), len(all_images):]
25 | congruents = np.diag(r)
26 |
27 | success = r < congruents
28 | success_cnt = np.sum(success, 0)
29 |
30 | if return_avg:
31 | perf = np.mean(success_cnt) / (len(all_images)-1)
32 | return perf
33 | else:
34 | return success_cnt, len(all_images)-1
35 |
36 | def cal_metrics(results_path, device):
37 | # Load all images and brain recons
38 | all_files = [f for f in os.listdir(results_path)]
39 |
40 | number = 0
41 | all_images, all_brain_recons = None, None
42 | all_images, all_brain_recons = [], []
43 | for file in tqdm(all_files):
44 | if file.endswith("_img.pt"):
45 | if file.replace("_img.pt", "_rec.pt") in all_files:
46 | number += 1
47 | all_images.append(torch.load(os.path.join(results_path, file), map_location=device))
48 | all_brain_recons.append(torch.load(os.path.join(results_path, file.replace("_img.pt", "_rec.pt")), map_location=device))
49 |
50 | all_images = torch.vstack(all_images)
51 | all_brain_recons = torch.vstack(all_brain_recons)
52 | all_images = all_images.to(device)
53 | all_brain_recons = all_brain_recons.to(device).to(all_images.dtype).clamp(0,1).squeeze()
54 |
55 | print("Images shape:", all_images.shape)
56 | print("Recons shape:", all_brain_recons.shape)
57 | print("Number:", number)
58 |
59 | ### PixCorr
60 | preprocess = transforms.Compose([
61 | transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
62 | ])
63 |
64 | # Flatten images while keeping the batch dimension
65 | all_images_flattened = preprocess(all_images).reshape(len(all_images), -1).cpu()
66 | all_brain_recons_flattened = preprocess(all_brain_recons).view(len(all_brain_recons), -1).cpu()
67 |
68 | print(all_images_flattened.shape)
69 | print(all_brain_recons_flattened.shape)
70 |
71 | print("\n------calculating pixcorr------")
72 | corrsum = 0
73 | for i in tqdm(range(number)):
74 | corrsum += np.corrcoef(all_images_flattened[i], all_brain_recons_flattened[i])[0][1]
75 | corrmean = corrsum / number
76 |
77 | pixcorr = corrmean
78 | print(pixcorr)
79 |
80 | del all_images_flattened
81 | del all_brain_recons_flattened
82 | torch.cuda.empty_cache()
83 |
84 | ### SSIM
85 | # see https://github.com/zijin-gu/meshconv-decoding/issues/3
86 | from skimage.color import rgb2gray
87 | from skimage.metrics import structural_similarity as ssim
88 |
89 | preprocess = transforms.Compose([
90 | transforms.Resize(425, interpolation=transforms.InterpolationMode.BILINEAR),
91 | ])
92 |
93 | # convert image to grayscale with rgb2grey
94 | img_gray = rgb2gray(preprocess(all_images).permute((0,2,3,1)).cpu())
95 | recon_gray = rgb2gray(preprocess(all_brain_recons).permute((0,2,3,1)).cpu())
96 | print("converted, now calculating ssim...")
97 |
98 | ssim_score=[]
99 | for im,rec in tqdm(zip(img_gray,recon_gray),total=len(all_images)):
100 | ssim_score.append(ssim(rec, im, multichannel=True, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1.0))
101 |
102 | ssim = np.mean(ssim_score)
103 | print(ssim)
104 |
105 |
106 | #### AlexNet
107 | from torchvision.models import alexnet, AlexNet_Weights
108 | alex_weights = AlexNet_Weights.IMAGENET1K_V1
109 |
110 | alex_model = create_feature_extractor(alexnet(weights=alex_weights), return_nodes=['features.4','features.11']).to(device)
111 | alex_model.eval().requires_grad_(False)
112 |
113 | # see alex_weights.transforms()
114 | preprocess = transforms.Compose([
115 | transforms.Resize(256, interpolation=transforms.InterpolationMode.BILINEAR),
116 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
117 | std=[0.229, 0.224, 0.225]),
118 | ])
119 |
120 | layer = 'early, AlexNet(2)'
121 | print(f"\n---{layer}---")
122 | all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images,
123 | alex_model, preprocess, 'features.4', device=device)
124 | alexnet2 = np.mean(all_per_correct)
125 | print(f"2-way Percent Correct: {alexnet2:.4f}")
126 |
127 | layer = 'mid, AlexNet(5)'
128 | print(f"\n---{layer}---")
129 | all_per_correct = two_way_identification(all_brain_recons.to(device).float(), all_images,
130 | alex_model, preprocess, 'features.11', device=device)
131 | alexnet5 = np.mean(all_per_correct)
132 | print(f"2-way Percent Correct: {alexnet5:.4f}")
133 |
134 | del alex_model
135 | torch.cuda.empty_cache()
136 |
137 | #### InceptionV3
138 | from torchvision.models import inception_v3, Inception_V3_Weights
139 | weights = Inception_V3_Weights.DEFAULT
140 | inception_model = create_feature_extractor(inception_v3(weights=weights),
141 | return_nodes=['avgpool']).to(device)
142 | inception_model.eval().requires_grad_(False)
143 |
144 | # see weights.transforms()
145 | preprocess = transforms.Compose([
146 | transforms.Resize(342, interpolation=transforms.InterpolationMode.BILINEAR),
147 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
148 | std=[0.229, 0.224, 0.225]),
149 | ])
150 |
151 | all_per_correct = two_way_identification(all_brain_recons, all_images,
152 | inception_model, preprocess, 'avgpool', device=device)
153 |
154 | inception = np.mean(all_per_correct)
155 | print(f"2-way Percent Correct: {inception:.4f}")
156 |
157 | del inception_model
158 | torch.cuda.empty_cache()
159 |
160 |
161 | #### CLIP
162 | import clip
163 | clip_model, preprocess = clip.load("ViT-L/14", device=device)
164 |
165 | preprocess = transforms.Compose([
166 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
167 | transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
168 | std=[0.26862954, 0.26130258, 0.27577711]),
169 | ])
170 |
171 | all_per_correct = two_way_identification(all_brain_recons, all_images,
172 | clip_model.encode_image, preprocess, None, device=device) # final layer
173 | clip_ = np.mean(all_per_correct)
174 | print(f"2-way Percent Correct: {clip_:.4f}")
175 |
176 |
177 | del clip_model
178 | torch.cuda.empty_cache()
179 |
180 |
181 | #### Efficient Net
182 | from torchvision.models import efficientnet_b1, EfficientNet_B1_Weights
183 | weights = EfficientNet_B1_Weights.DEFAULT
184 | eff_model = create_feature_extractor(efficientnet_b1(weights=weights),
185 | return_nodes=['avgpool']).to(device)
186 | eff_model.eval().requires_grad_(False)
187 |
188 | # see weights.transforms()
189 | preprocess = transforms.Compose([
190 | transforms.Resize(255, interpolation=transforms.InterpolationMode.BILINEAR),
191 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
192 | std=[0.229, 0.224, 0.225]),
193 | ])
194 |
195 | gt = eff_model(preprocess(all_images))['avgpool']
196 | gt = gt.reshape(len(gt),-1).cpu().numpy()
197 | fake = eff_model(preprocess(all_brain_recons))['avgpool']
198 | fake = fake.reshape(len(fake),-1).cpu().numpy()
199 |
200 | effnet = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean()
201 | print("Distance:",effnet)
202 |
203 |
204 | del eff_model
205 | torch.cuda.empty_cache()
206 |
207 |
208 | #### SwAV
209 | swav_model = torch.hub.load('facebookresearch/swav:main', 'resnet50')
210 | swav_model = create_feature_extractor(swav_model,
211 | return_nodes=['avgpool']).to(device)
212 | swav_model.eval().requires_grad_(False)
213 |
214 | preprocess = transforms.Compose([
215 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BILINEAR),
216 | transforms.Normalize(mean=[0.485, 0.456, 0.406],
217 | std=[0.229, 0.224, 0.225]),
218 | ])
219 |
220 | gt = swav_model(preprocess(all_images))['avgpool']
221 | gt = gt.reshape(len(gt),-1).cpu().numpy()
222 | fake = swav_model(preprocess(all_brain_recons))['avgpool']
223 | fake = fake.reshape(len(fake),-1).cpu().numpy()
224 |
225 | swav = np.array([sp.spatial.distance.correlation(gt[i],fake[i]) for i in range(len(gt))]).mean()
226 | print("Distance:",swav,"\n")
227 |
228 |
229 | del swav_model
230 | torch.cuda.empty_cache()
231 |
232 |
233 | # # Display in table
234 | # Create a dictionary to store variable names and their corresponding values
235 | data = {
236 | "Metric": ["PixCorr", "SSIM", "AlexNet(2)", "AlexNet(5)", "InceptionV3", "CLIP", "EffNet-B", "SwAV"],
237 | "Value": [pixcorr, ssim, alexnet2, alexnet5, inception, clip_, effnet, swav],
238 | }
239 | print(results_path)
240 | df = pd.DataFrame(data)
241 | print(df.to_string(index=False))
242 |
243 |
244 |     # save table to csv file
245 | df.to_csv(os.path.join(results_path, f'_metrics_on_{number}samples.csv'), sep='\t', index=False)
246 |
247 | if __name__ == "__main__":
248 | utils.seed_everything(seed=args.seed)
249 |
250 | device = torch.device('cuda:{}'.format(args.gpu_id) if torch.cuda.is_available() else 'cpu')
251 | print("device:", device)
252 |
253 | assert args.results_path is not None, "Please specify the path to the results folder"
254 | cal_metrics(args.results_path, device)
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torchvision import transforms
3 | import torch
4 | import torch.nn as nn
5 | import PIL
6 | import random
7 | import os
8 | import matplotlib.pyplot as plt
9 | import torchsnooper
10 |
11 |
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 | def seed_everything(seed=0, cudnn_deterministic=True):
15 | random.seed(seed)
16 | os.environ['PYTHONHASHSEED'] = str(seed)
17 | np.random.seed(seed)
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | if cudnn_deterministic:
22 | torch.backends.cudnn.deterministic = True
23 | else:
24 | ## needs to be False to use conv3D
25 | print('Note: not using cudnn.deterministic')
26 |
27 | def np_to_Image(x):
28 | if x.ndim==4:
29 | x=x[0]
30 | return PIL.Image.fromarray((x.transpose(1, 2, 0)*127.5+128).clip(0,255).astype('uint8'))
31 |
32 | def torch_to_Image(x):
33 | if x.ndim==4:
34 | x=x[0]
35 | return transforms.ToPILImage()(x)
36 |
37 | def Image_to_torch(x):
38 | try:
39 | x = (transforms.ToTensor()(x)[:3].unsqueeze(0)-.5)/.5
40 | except:
41 | x = (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5
42 | return x
43 |
44 | def torch_to_matplotlib(x,device=device):
45 | if torch.mean(x)>10:
46 | x = (x.permute(0, 2, 3, 1)).clamp(0, 255).to(torch.uint8)
47 | else:
48 | x = (x.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8)
49 | if device=='cpu':
50 | return x[0]
51 | else:
52 | return x.cpu().numpy()[0]
53 |
54 | def pairwise_cosine_similarity(A, B, dim=1, eps=1e-8):
55 | #https://stackoverflow.com/questions/67199317/pytorch-cosine-similarity-nxn-elements
56 | numerator = A @ B.T
57 | A_l2 = torch.mul(A, A).sum(axis=dim)
58 | B_l2 = torch.mul(B, B).sum(axis=dim)
59 | denominator = torch.max(torch.sqrt(torch.outer(A_l2, B_l2)), torch.tensor(eps))
60 | return torch.div(numerator, denominator)
61 |
62 | def batchwise_cosine_similarity(Z,B):
63 | # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
64 | B = B.T
65 | Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True) # Size (n, 1).
66 | B_norm = torch.linalg.norm(B, dim=0, keepdim=True) # Size (1, b).
67 | cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
68 | return cosine_similarity
69 |
70 | def topk(similarities,labels,k=5):
71 | if k > similarities.shape[0]:
72 | k = similarities.shape[0]
73 | topsum=0
74 | for i in range(k):
75 | topsum += torch.sum(torch.argsort(similarities,axis=1)[:,-(i+1)] == labels)/len(labels)
76 | return topsum
77 |
78 | def soft_clip_loss(preds, targs, temp=0.005, eps=1e-10):
79 | clip_clip = (targs @ targs.T)/temp + eps
80 | check_loss(clip_clip, "clip_clip")
81 | brain_clip = (preds @ targs.T)/temp + eps
82 | check_loss(brain_clip, "brain_clip")
83 |
84 | loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
85 | check_loss(loss1, "loss1")
86 | loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
87 | check_loss(loss2, "loss2")
88 |
89 | loss = (loss1 + loss2)/2
90 | return loss
91 |
92 | def count_params(model):
93 | total = sum(p.numel() for p in model.parameters())
94 | trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
95 | print('param counts:\n{:,} total\n{:,} trainable'.format(total, trainable))
96 |
97 | def image_grid(imgs, rows, cols):
98 | w, h = imgs[0].size
99 | grid = PIL.Image.new('RGB', size=(cols*w, rows*h))
100 | for i, img in enumerate(imgs):
101 | grid.paste(img, box=(i%cols*w, i//cols*h))
102 | return grid
103 |
104 | def check_loss(loss, message="loss"):
105 | if loss.isnan().any():
106 | raise ValueError(f'NaN loss in {message}')
107 |
108 | def cosine_anneal(start, end, steps):
109 | return end + (start - end)/2 * (1 + torch.cos(torch.pi*torch.arange(steps)/(steps-1)))
110 |
111 | def decode_latents(latents,vae):
112 | latents = 1 / 0.18215 * latents
113 | image = vae.decode(latents).sample
114 | image = (image / 2 + 0.5).clamp(0, 1)
115 | return image
116 |
117 | @torch.no_grad()
118 | def reconstruction(
119 | image, voxel, voxel2clip,
120 | clip_extractor,
121 | unet, vae, noise_scheduler,
122 | img_lowlevel = None,
123 | num_inference_steps = 50,
124 | recons_per_sample = 1,
125 | guidance_scale = 7.5,
126 | img2img_strength = .85,
127 | seed = 42,
128 | plotting=True,
129 | verbose=False,
130 | n_samples_save=1,
131 | device = None,
132 | mem_efficient = True,
133 |
134 | ):
135 | assert n_samples_save==1, "n_samples_save must = 1. Function must be called one image at a time"
136 | assert recons_per_sample>0, "recons_per_sample must > 0"
137 |
138 | brain_recons = None
139 |
140 | voxel=voxel[:n_samples_save]
141 | image=image[:n_samples_save]
142 | B = voxel.shape[0]
143 |
144 | if mem_efficient:
145 | clip_extractor.to("cpu")
146 | unet.to("cpu")
147 | vae.to("cpu")
148 | else:
149 | clip_extractor.to(device)
150 | unet.to(device)
151 | vae.to(device)
152 |
153 | if unet is not None:
154 | do_classifier_free_guidance = guidance_scale > 1.0
155 | vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
156 | height = unet.config.sample_size * vae_scale_factor
157 | width = unet.config.sample_size * vae_scale_factor
158 | generator = torch.Generator(device=device)
159 | generator.manual_seed(seed)
160 |
161 | if voxel2clip is not None:
162 | clip_results = voxel2clip(voxel)
163 |
164 | if mem_efficient:
165 | voxel2clip.to('cpu')
166 |
167 | brain_clip_image_embeddings, brain_clip_text_embeddings = clip_results[:2]
168 | brain_clip_image_embeddings = brain_clip_image_embeddings.reshape(B,-1,768)
169 | brain_clip_text_embeddings = brain_clip_text_embeddings.reshape(B,-1,768)
170 |
171 | brain_clip_image_embeddings = brain_clip_image_embeddings.repeat(recons_per_sample, 1, 1)
172 | brain_clip_text_embeddings = brain_clip_text_embeddings.repeat(recons_per_sample, 1, 1)
173 |
174 | if recons_per_sample > 0:
175 | for samp in range(len(brain_clip_image_embeddings)):
176 | brain_clip_image_embeddings[samp] = brain_clip_image_embeddings[samp]/(brain_clip_image_embeddings[samp,0].norm(dim=-1).reshape(-1, 1, 1) + 1e-6)
177 | brain_clip_text_embeddings[samp] = brain_clip_text_embeddings[samp]/(brain_clip_text_embeddings[samp,0].norm(dim=-1).reshape(-1, 1, 1) + 1e-6)
178 | input_embedding = brain_clip_image_embeddings#.repeat(recons_per_sample, 1, 1)
179 | if verbose: print("input_embedding",input_embedding.shape)
180 |
181 | prompt_embeds = brain_clip_text_embeddings
182 | if verbose: print("prompt_embedding",prompt_embeds.shape)
183 |
184 | if do_classifier_free_guidance:
185 | input_embedding = torch.cat([torch.zeros_like(input_embedding), input_embedding]).to(device).to(unet.dtype)
186 | prompt_embeds = torch.cat([torch.zeros_like(prompt_embeds), prompt_embeds]).to(device).to(unet.dtype)
187 |
188 | # 3. dual_prompt_embeddings
189 | input_embedding = torch.cat([prompt_embeds, input_embedding], dim=1)
190 |
191 | # 4. Prepare timesteps
192 | noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device)
193 |
194 | # 5b. Prepare latent variables
195 | batch_size = input_embedding.shape[0] // 2 # divide by 2 bc we doubled it for classifier-free guidance
196 | shape = (batch_size, unet.in_channels, height // vae_scale_factor, width // vae_scale_factor)
197 | if img_lowlevel is not None: # use img_lowlevel for img2img initialization
198 | init_timestep = min(int(num_inference_steps * img2img_strength), num_inference_steps)
199 | t_start = max(num_inference_steps - init_timestep, 0)
200 | timesteps = noise_scheduler.timesteps[t_start:]
201 | latent_timestep = timesteps[:1].repeat(batch_size)
202 |
203 | if verbose: print("img_lowlevel", img_lowlevel.shape)
204 | img_lowlevel_embeddings = clip_extractor.normalize(img_lowlevel)
205 | if verbose: print("img_lowlevel_embeddings", img_lowlevel_embeddings.shape)
206 | if mem_efficient:
207 | vae.to(device)
208 | init_latents = vae.encode(img_lowlevel_embeddings.to(device).to(vae.dtype)).latent_dist.sample(generator)
209 | init_latents = vae.config.scaling_factor * init_latents
210 | init_latents = init_latents.repeat(recons_per_sample, 1, 1, 1)
211 |
212 | noise = torch.randn([recons_per_sample, 4, 64, 64], device=device,
213 | generator=generator, dtype=input_embedding.dtype)
214 | init_latents = noise_scheduler.add_noise(init_latents, noise, latent_timestep)
215 | latents = init_latents
216 | else:
217 | timesteps = noise_scheduler.timesteps
218 | latents = torch.randn([recons_per_sample, 4, 64, 64], device=device,
219 | generator=generator, dtype=input_embedding.dtype)
220 | latents = latents * noise_scheduler.init_noise_sigma
221 |
222 | # 7. Denoising loop
223 | if mem_efficient:
224 | unet.to(device)
225 | for i, t in enumerate(timesteps):
226 | # expand the latents if we are doing classifier free guidance
227 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
228 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t)
229 | if verbose: print("timesteps: {}, latent_model_input: {}, input_embedding: {}".format(i, latent_model_input.shape, input_embedding.shape))
230 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_embedding).sample
231 |
232 | # perform guidance
233 | if do_classifier_free_guidance:
234 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
235 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
236 |
237 | # compute the previous noisy sample x_t -> x_t-1
238 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
239 |
240 | if mem_efficient:
241 | unet.to("cpu")
242 |
243 | recons = decode_latents(latents.to(device),vae.to(device)).detach().cpu()
244 |
245 | brain_recons = recons.unsqueeze(0)
246 |
247 | if verbose: print("brain_recons",brain_recons.shape)
248 |
249 | # pick best reconstruction out of several
250 | best_picks = np.zeros(n_samples_save).astype(np.int16)
251 |
252 | if mem_efficient:
253 | vae.to("cpu")
254 | unet.to("cpu")
255 | clip_extractor.to(device)
256 |
257 | clip_image_target = clip_extractor.embed_image(image)
258 | clip_image_target_norm = nn.functional.normalize(clip_image_target.flatten(1), dim=-1)
259 | sims=[]
260 | for im in range(recons_per_sample):
261 | currecon = clip_extractor.embed_image(brain_recons[0,[im]].float()).to(clip_image_target_norm.device).to(clip_image_target_norm.dtype)
262 | currecon = nn.functional.normalize(currecon.view(len(currecon),-1),dim=-1)
263 | cursim = batchwise_cosine_similarity(clip_image_target_norm,currecon)
264 | sims.append(cursim.item())
265 | if verbose: print(sims)
266 | best_picks[0] = int(np.nanargmax(sims))
267 | if verbose: print(best_picks)
268 | if mem_efficient:
269 | clip_extractor.to("cpu")
270 | voxel2clip.to(device)
271 |
272 | img2img_samples = 0 if img_lowlevel is None else 1
273 | num_xaxis_subplots = 1+img2img_samples+recons_per_sample
274 | if plotting:
275 | fig, ax = plt.subplots(n_samples_save, num_xaxis_subplots,
276 | figsize=(num_xaxis_subplots*5,6*n_samples_save),facecolor=(1, 1, 1))
277 | else:
278 | fig = None
279 | recon_img = None
280 |
281 | im = 0
282 | if plotting:
283 | ax[0].set_title(f"Original Image")
284 | ax[0].imshow(torch_to_Image(image[im]))
285 | if img2img_samples == 1:
286 | ax[1].set_title(f"Img2img ({img2img_strength})")
287 | ax[1].imshow(torch_to_Image(img_lowlevel[im].clamp(0,1)))
288 | for ii,i in enumerate(range(num_xaxis_subplots-recons_per_sample,num_xaxis_subplots)):
289 | recon = brain_recons[im][ii]
290 | if plotting:
291 | if ii == best_picks[im]:
292 | ax[i].set_title(f"Reconstruction",fontweight='bold')
293 | recon_img = recon
294 | else:
295 | ax[i].set_title(f"Recon {ii+1} from brain")
296 | ax[i].imshow(torch_to_Image(recon))
297 | if plotting:
298 | for i in range(num_xaxis_subplots):
299 | ax[i].axis('off')
300 |
301 | return fig, brain_recons, best_picks, recon_img
302 |
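# --- Illustrative call sketch; every object below (unet, vae, noise_scheduler, ...) is
# --- assumed to be already loaded by the caller, e.g. a reconstruction/inference script.
#
#   fig, brain_recons, best_picks, recon_img = reconstruction(
#       image, voxel, voxel2clip,
#       clip_extractor, unet, vae, noise_scheduler,
#       num_inference_steps=50, recons_per_sample=4,
#       guidance_scale=7.5, seed=42, plotting=True, device=device,
#   )
#
# brain_recons holds all candidate reconstructions, best_picks indexes the candidate with
# the highest CLIP similarity to the ground-truth image, and recon_img is that best image.
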
--------------------------------------------------------------------------------
/src/models.py:
--------------------------------------------------------------------------------
1 | import clip
2 | import torchsnooper
3 | import math
4 | import random
5 | import torch
6 | import torch.nn as nn
7 | import numpy as np
8 | from torchvision import transforms
9 |
10 | def add_hooks(module, parent_name=''):
11 | module_name = module.__class__.__name__
12 | if parent_name:
13 | module_name = f'{parent_name}.{module_name}'
14 |
15 | module.register_forward_hook(lambda mod, inp, out: forward_hook(mod, inp, out, module_name))
16 | module.register_backward_hook(lambda mod, inp, out: backward_hook(mod, inp, out, module_name))
17 |
18 | for name, child in module.named_children():
19 | add_hooks(child, parent_name=f'{module_name}.{name}')
20 |
21 | def forward_hook(module, input, output, name):
22 | if output.isnan().any():
23 | print(f"NaN detected in forward pass in module: {name}")
24 | print(f"Input: {input}")
25 | print(f"Output: {output}")
26 |
27 | def backward_hook(module, grad_input, grad_output, name):
28 | if any(tensor is not None and torch.isnan(tensor).any() for tensor in [*grad_input, *grad_output]):
29 | print(f"NaN detected in backward pass in module: {name}")
30 | print(f"Grad Input: {grad_input}")
31 | print(f"Grad Output: {grad_output}")
32 |
33 | class Clipper(torch.nn.Module):
34 | def __init__(self, clip_variant, clamp_embs=False, norm_embs=False,
35 | hidden_state=False, device=torch.device('cpu')):
36 | super().__init__()
37 | assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \
38 | "clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64"
39 | print(clip_variant, device)
40 |
41 | if clip_variant=="ViT-L/14" and hidden_state:
42 | from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPTokenizer
43 | image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval()
44 | image_encoder = image_encoder.to(device)
45 | for param in image_encoder.parameters():
46 |                 param.requires_grad = False # don't need to calculate gradients
47 | self.image_encoder = image_encoder
48 |
49 | text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval()
50 | text_encoder = text_encoder.to(device)
51 | for param in text_encoder.parameters():
52 |                 param.requires_grad = False # don't need to calculate gradients
53 | self.text_encoder = text_encoder
54 | self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
55 |
56 | elif hidden_state:
57 |             raise Exception("hidden_state embeddings only work with ViT-L/14 right now")
58 |
59 | clip_model, preprocess = clip.load(clip_variant, device=device)
60 |         clip_model.eval() # don't want to train model
61 |         for param in clip_model.parameters():
62 |             param.requires_grad = False # don't need to calculate gradients
63 |
64 | self.clip = clip_model
65 | self.clip_variant = clip_variant
66 | if clip_variant == "RN50x64":
67 | self.clip_size = (448,448)
68 | else:
69 | self.clip_size = (224,224)
70 |
71 | preproc = transforms.Compose([
72 | transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC, antialias=None),
73 | transforms.CenterCrop(size=self.clip_size),
74 | transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
75 | ])
76 | self.preprocess = preproc
77 | self.hidden_state = hidden_state
78 | self.mean = np.array([0.48145466, 0.4578275, 0.40821073])
79 | self.std = np.array([0.26862954, 0.26130258, 0.27577711])
80 | self.normalize = transforms.Normalize(self.mean, self.std)
81 | self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist())
82 | self.clamp_embs = clamp_embs
83 | self.norm_embs = norm_embs
84 | self.device= device
85 |
86 | def versatile_normalize_embeddings(encoder_output):
87 | embeds = encoder_output.last_hidden_state
88 | embeds = image_encoder.vision_model.post_layernorm(embeds)
89 | embeds = image_encoder.visual_projection(embeds)
90 | return embeds
91 | self.versatile_normalize_embeddings = versatile_normalize_embeddings
92 |
93 | def resize_image(self, image):
94 | # note: antialias should be False if planning to use Pinkney's Image Variation SD model
95 | return transforms.Resize(self.clip_size, antialias=None)(image.to(self.device))
96 |
97 | def embed_image(self, image):
98 | """Expects images in -1 to 1 range"""
99 | if self.hidden_state:
100 | # clip_emb = self.preprocess((image/1.5+.25).to(self.device)) # for some reason the /1.5+.25 prevents oversaturation
101 | clip_emb = self.preprocess((image).to(self.device))
102 | clip_emb = self.image_encoder(clip_emb)
103 | clip_emb = self.versatile_normalize_embeddings(clip_emb)
104 | else:
105 | clip_emb = self.preprocess(image.to(self.device))
106 | clip_emb = self.clip.encode_image(clip_emb)
107 | # input is now in CLIP space, but mind-reader preprint further processes embeddings:
108 | if self.clamp_embs:
109 | clip_emb = torch.clamp(clip_emb, -1.5, 1.5)
110 | if self.norm_embs:
111 | if self.hidden_state:
112 | # normalize all tokens by cls token's norm
113 | clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1)
114 | else:
115 | clip_emb = nn.functional.normalize(clip_emb, dim=-1)
116 | return clip_emb
117 |
118 | def embed_text(self, prompt):
119 | r"""
120 | Encodes the prompt into text encoder hidden states.
121 |
122 | Args:
123 | prompt (`str` or `List[str]`):
124 | prompt to be encoded
125 | device: (`torch.device`):
126 | torch device
127 | num_images_per_prompt (`int`):
128 | number of images that should be generated per prompt
129 | do_classifier_free_guidance (`bool`):
130 | whether to use classifier free guidance or not
131 | """
132 |
133 | def normalize_embeddings(encoder_output):
134 | embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
135 | embeds_pooled = encoder_output.text_embeds
136 | embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
137 | return embeds
138 |
139 | text_inputs = self.tokenizer(
140 | prompt,
141 | padding="max_length",
142 | max_length=self.tokenizer.model_max_length,
143 | truncation=True,
144 | return_tensors="pt",
145 | )
146 | text_input_ids = text_inputs.input_ids
147 | untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
148 | with torch.no_grad():
149 | prompt_embeds = self.text_encoder(
150 | text_input_ids.to(self.device),
151 | )
152 | prompt_embeds = normalize_embeddings(prompt_embeds)
153 |
154 | # duplicate text embeddings for each generation per prompt, using mps friendly method
155 | # bs_embed, seq_len, _ = prompt_embeds.shape
156 | # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
157 | # prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
158 |
159 | return prompt_embeds
160 |
161 | def embed_curated_annotations(self, annots):
162 | for i,b in enumerate(annots):
163 | t = ''
164 | while t == '':
165 | rand = torch.randint(5,(1,1))[0][0]
166 | t = b[0,rand]
167 | if i==0:
168 | txt = np.array(t)
169 | else:
170 | txt = np.vstack((txt,t))
171 | txt = txt.flatten()
172 | return self.embed_text(txt)
173 |
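# --- Illustrative usage sketch (demo helper; pretrained CLIP weights download on first use). ---
# With hidden_state=True, ViT-L/14 returns the full 257x768 image-token sequence and the
# 77x768 text-token sequence (normalised by the CLS/pooled norm when norm_embs=True),
# which is the embedding shape consumed by the reconstruction pipeline in utils.py.
def _example_clipper_embeddings():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    clip_extractor = Clipper("ViT-L/14", hidden_state=True, norm_embs=True, device=device)
    images = torch.rand(2, 3, 224, 224, device=device) * 2 - 1    # embed_image expects roughly [-1, 1]
    image_emb = clip_extractor.embed_image(images)                # (2, 257, 768)
    text_emb = clip_extractor.embed_text(["a dog running on a beach"])  # (1, 77, 768)
    return image_emb.shape, text_emb.shape
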
174 | class Adapter_Layer(nn.Module):
175 | def __init__(self,
176 | in_channels,
177 | bottleneck=32,
178 | out_channels=None,
179 | dropout=0.0,
180 | init_option="lora",
181 | adapter_scalar="1.0",
182 | adapter_layernorm_option=None):
183 | super().__init__()
184 | self.in_channels = in_channels
185 | self.down_size = bottleneck
186 | self.out_channels = out_channels if out_channels is not None else in_channels
187 | # self.non_linearity = args.non_linearity # use ReLU by default
188 |
189 | #_before
190 | self.adapter_layernorm_option = adapter_layernorm_option
191 |
192 | self.adapter_layer_norm = None
193 | if adapter_layernorm_option == "in" or adapter_layernorm_option == "out":
194 |             self.adapter_layer_norm = nn.LayerNorm(self.in_channels)
195 |
196 | if adapter_scalar == "learnable_scalar":
197 | self.scale = nn.Parameter(torch.ones(1))
198 | else:
199 | self.scale = float(adapter_scalar)
200 |
201 | self.down_proj = nn.Linear(self.in_channels, self.down_size)
202 | self.non_linear_func = nn.ReLU()
203 | self.up_proj = nn.Linear(self.down_size, self.out_channels)
204 |
205 | self.dropout = dropout
206 |
207 | if init_option == "lora":
208 | with torch.no_grad():
209 | nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
210 | nn.init.zeros_(self.up_proj.weight)
211 | nn.init.zeros_(self.down_proj.bias)
212 | nn.init.zeros_(self.up_proj.bias)
213 |
214 | def forward(self, x, add_residual=True, residual=None):
215 | residual = x if residual is None else residual
216 | if self.adapter_layernorm_option == 'in':
217 | x = self.adapter_layer_norm(x)
218 |
219 | down = self.down_proj(x)
220 | down = self.non_linear_func(down)
221 | down = nn.functional.dropout(down, p=self.dropout, training=self.training)
222 | up = self.up_proj(down)
223 |
224 | up = up * self.scale
225 |
226 | if self.adapter_layernorm_option == 'out':
227 | up = self.adapter_layer_norm(up)
228 |
229 | if add_residual:
230 | output = up + residual
231 | else:
232 | output = up
233 |
234 | return output
235 |
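# --- Illustrative sketch (demo helper): Adapter_Layer is a LoRA-style residual bottleneck,
# --- output = x + scale * up_proj(ReLU(down_proj(x))), and acts as an identity mapping at
# --- initialisation because up_proj is zero-initialised.
def _example_adapter_identity_at_init():
    adapter = Adapter_Layer(in_channels=16, bottleneck=4)
    x = torch.randn(2, 16)
    out = adapter(x)
    assert torch.allclose(out, x)   # zero-initialised up_proj => only the residual path contributes
    return out
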
236 | class ResMLP(nn.Module):
237 | def __init__(self, h, n_blocks, dropout=0.15):
238 | super().__init__()
239 | self.n_blocks = n_blocks
240 | self.mlp = nn.ModuleList([
241 | nn.Sequential(
242 | nn.Linear(h, h),
243 | nn.LayerNorm(h),
244 | nn.GELU(),
245 | nn.Dropout(dropout)
246 | ) for _ in range(n_blocks)
247 | ])
248 |
249 | def forward(self, x):
250 | residual = x
251 | for res_block in range(self.n_blocks):
252 | x = self.mlp[res_block](x)
253 | x += residual
254 | residual = x
255 | return x
256 |
257 | class MindSingle(nn.Module):
258 | def __init__(self, in_dim=15724, out_dim_image=768, out_dim_text=None,
259 | h=4096, n_blocks=4, subj_list=None,):
260 |
261 | super().__init__()
262 |
263 | self.subj_list = subj_list
264 | self.embedder = nn.ModuleDict({
265 | str(subj): nn.Sequential(
266 | Adapter_Layer(in_dim, 128),
267 | nn.Linear(in_dim, h),
268 | nn.LayerNorm(h),
269 | nn.GELU(),
270 | nn.Dropout(0.5),
271 | ) for subj in subj_list
272 | })
273 |
274 | self.translator = ResMLP(h, n_blocks)
275 | self.head_image = nn.Linear(h, out_dim_image)
276 | self.head_text = nn.Linear(h, out_dim_text)
277 |
278 | # @torchsnooper.snoop()
279 | def forward(self, x):
280 | x = self.embedder[str(self.subj_list[0])](x)
281 | x = self.translator(x)
282 |
283 | x_image = self.head_image(x)
284 | x_image = x_image.reshape(len(x_image), -1)
285 |
286 | x_text = self.head_text(x)
287 | x_text = x_text.reshape(len(x_text), -1)
288 |
289 | return x_image, x_text
290 |
291 | class MindBridge(MindSingle):
292 | def __init__(self, in_dim=15724, out_dim_image=768, out_dim_text=None,
293 | h=4096, n_blocks=4, subj_list=None, adapting=False):
294 |
295 | assert len(subj_list) >= 2, "MindBridge requires at least 2 subjects"
296 |
297 | super().__init__(in_dim=in_dim, out_dim_image=out_dim_image,
298 | out_dim_text=out_dim_text, h=h, n_blocks=n_blocks, subj_list=subj_list
299 | )
300 |
301 | self.builder = nn.ModuleDict({
302 | str(subj): nn.Sequential(
303 | nn.Linear(h, in_dim),
304 | nn.LayerNorm(in_dim),
305 | nn.GELU(),
306 | Adapter_Layer(in_dim, 128),
307 | ) for subj in subj_list
308 | })
309 |
310 | self.adapting = adapting
311 | self.cyc_loss = nn.MSELoss()
312 |
313 | # @torchsnooper.snoop()
314 | def forward(self, x):
315 | if len(x) == 2 and type(x) is tuple:
316 | subj_list = x[1].tolist() # (s,)
317 | x = x[0] # (b,n)
318 | else:
319 | subj_list = self.subj_list
320 |
321 | x = x.squeeze()
322 | x_subj = torch.chunk(x,len(subj_list))
323 | x = []
324 | x_rec = []
325 | if self.adapting: # choose subj_a (source subject) and subj_b (target subject)
326 | subj_a, subj_b = subj_list[0], subj_list[-1]
327 | else: # random sample 2 subjects
328 | subj_a, subj_b = random.sample(subj_list, 2)
329 | for i, subj_i in enumerate(subj_list): # subj is 1-based
330 |             x_i = self.embedder[str(subj_i)](x_subj[i]) # subj_i semantic embedding
331 |             if subj_i == subj_a: x_a = x_i # subj_a semantic embedding is chosen
332 | x.append(x_i)
333 | x_i_rec = self.builder[str(subj_i)](x_i) # subj_i recon brain signals
334 | x_rec.append(x_i_rec)
335 |
336 | x = torch.concat(x, dim=0)
337 | x_rec = torch.concat(x_rec, dim=0)
338 | # del x_i, x_subj, x_i_rec
339 |
340 | # forward cycling
341 | x_b = self.builder[str(subj_b)](x_a) # subj_b recon brain signal using subj_a seman embedding
342 | x_b = self.embedder[str(subj_b)](x_b) # subj_b seman embedding (pseudo)
343 | loss_cyc = self.cyc_loss(x_a, x_b)
344 |
345 | x = self.translator(x)
346 | x_image = self.head_image(x)
347 | x_image = x_image.reshape(len(x_image), -1)
348 |
349 | x_text = self.head_text(x)
350 | x_text = x_text.reshape(len(x_text), -1)
351 |
352 | return x_image, x_text, x_rec, loss_cyc
353 |
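# --- Illustrative sketch (demo helper; dimensions are arbitrary demo values). ---
# MindBridge expects the voxel batch of all subjects concatenated along dim 0, one
# equal-sized chunk per subject, and returns predicted CLIP-image and CLIP-text
# embeddings, the per-subject reconstructed voxels, and the cross-subject cycle loss.
def _example_mindbridge_forward():
    model = MindBridge(in_dim=4000, out_dim_image=768, out_dim_text=768,
                       h=512, n_blocks=2, subj_list=[1, 2])
    voxels = torch.randn(6, 4000)          # 3 samples for subject 1 + 3 for subject 2
    x_image, x_text, x_rec, loss_cyc = model(voxels)
    return x_image.shape, x_text.shape, x_rec.shape, loss_cyc
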
354 |
355 | from diffusers.models.vae import Decoder
356 | class Voxel2StableDiffusionModel(torch.nn.Module):
357 | def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False, ups_mode='4x'):
358 | super().__init__()
359 | self.lin0 = nn.Sequential(
360 | nn.Linear(in_dim, h, bias=False),
361 | nn.LayerNorm(h),
362 | nn.SiLU(inplace=True),
363 | nn.Dropout(0.5),
364 | )
365 |
366 | self.mlp = nn.ModuleList([
367 | nn.Sequential(
368 | nn.Linear(h, h, bias=False),
369 | nn.LayerNorm(h),
370 | nn.SiLU(inplace=True),
371 | nn.Dropout(0.25)
372 | ) for _ in range(n_blocks)
373 | ])
374 | self.ups_mode = ups_mode
375 | if ups_mode=='4x':
376 | self.lin1 = nn.Linear(h, 16384, bias=False)
377 | self.norm = nn.GroupNorm(1, 64)
378 |
379 | self.upsampler = Decoder(
380 | in_channels=64,
381 | out_channels=4,
382 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
383 | block_out_channels=[64, 128, 256],
384 | layers_per_block=1,
385 | )
386 |
387 | if use_cont:
388 | self.maps_projector = nn.Sequential(
389 | nn.Conv2d(64, 512, 1, bias=False),
390 | nn.GroupNorm(1,512),
391 | nn.ReLU(True),
392 | nn.Conv2d(512, 512, 1, bias=False),
393 | nn.GroupNorm(1,512),
394 | nn.ReLU(True),
395 | nn.Conv2d(512, 512, 1, bias=True),
396 | )
397 | else:
398 | self.maps_projector = nn.Identity()
399 |
400 | if ups_mode=='8x': # prev best
401 | self.lin1 = nn.Linear(h, 16384, bias=False)
402 | self.norm = nn.GroupNorm(1, 256)
403 |
404 | self.upsampler = Decoder(
405 | in_channels=256,
406 | out_channels=4,
407 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"],
408 | block_out_channels=[64, 128, 256, 256],
409 | layers_per_block=1,
410 | )
411 | self.maps_projector = nn.Identity()
412 |
413 | if ups_mode=='16x':
414 | self.lin1 = nn.Linear(h, 8192, bias=False)
415 | self.norm = nn.GroupNorm(1, 512)
416 |
417 | self.upsampler = Decoder(
418 | in_channels=512,
419 | out_channels=4,
420 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D", "UpDecoderBlock2D"],
421 | block_out_channels=[64, 128, 256, 256, 512],
422 | layers_per_block=1,
423 | )
424 | self.maps_projector = nn.Identity()
425 |
426 | if use_cont:
427 | self.maps_projector = nn.Sequential(
428 | nn.Conv2d(64, 512, 1, bias=False),
429 | nn.GroupNorm(1,512),
430 | nn.ReLU(True),
431 | nn.Conv2d(512, 512, 1, bias=False),
432 | nn.GroupNorm(1,512),
433 | nn.ReLU(True),
434 | nn.Conv2d(512, 512, 1, bias=True),
435 | )
436 | else:
437 | self.maps_projector = nn.Identity()
438 |
439 | # @torchsnooper.snoop()
440 | def forward(self, x, return_transformer_feats=False):
441 | x = self.lin0(x)
442 | residual = x
443 | for res_block in self.mlp:
444 | x = res_block(x)
445 | x = x + residual
446 | residual = x
447 | x = x.reshape(len(x), -1)
448 |         x = self.lin1(x) # (bs, 16384) for '4x'/'8x', (bs, 8192) for '16x'
449 |
450 | if self.ups_mode == '4x':
451 | side = 16
452 | if self.ups_mode == '8x':
453 | side = 8
454 | if self.ups_mode == '16x':
455 | side = 4
456 |
457 | # decoder
458 | x = self.norm(x.reshape(x.shape[0], -1, side, side).contiguous())
459 | if return_transformer_feats:
460 | return self.upsampler(x), self.maps_projector(x).flatten(2).permute(0,2,1)
461 | return self.upsampler(x)
462 |
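# --- Note (summary of the forward pass above): the lin1 output is reshaped to a
# --- (bs, C, side, side) feature map, where C and side depend on ups_mode, and the
# --- diffusers Decoder upsamples it 4x/8x/16x to (bs, 4, 64, 64) Stable Diffusion
# --- latents, which decode_latents() in utils.py converts back to images.
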
463 | if __name__ == "__main__":
464 | in_dim=8000
465 | h=2048
466 | head = nn.Sequential(
467 | Adapter_Layer(in_dim, 128),
468 | nn.Linear(in_dim, h),
469 | nn.LayerNorm(h),
470 | nn.GELU(),
471 | nn.Dropout(0.5),
472 | )
473 |
474 | def add_hooks(module, parent_name=''):
475 | module_name = module.__class__.__name__
476 | if parent_name:
477 | module_name = f'{parent_name}.{module_name}'
478 |
479 | for name, child in module.named_children():
480 | add_hooks(child, parent_name=f'{module_name}.{name}')
481 |
482 | print(module_name)
483 |
484 | add_hooks(head)
485 |
--------------------------------------------------------------------------------
/src/nsd_access.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as op
3 | import glob
4 | import nibabel as nb
5 | import numpy as np
6 | import pandas as pd
7 | from pandas import json_normalize
8 | from tqdm import tqdm
9 | import h5py
10 | import matplotlib.pyplot as plt
11 |
12 | import urllib.request
13 | import zipfile
14 | from pycocotools.coco import COCO
15 |
16 |
17 |
18 | class NSDAccess(object):
19 | """
20 |     Little class that provides easy access to the NSD data, see their website: http://naturalscenesdataset.org
21 | """
22 |
23 | def __init__(self, nsd_folder, *args, **kwargs):
24 | super().__init__(*args, **kwargs)
25 | self.nsd_folder = nsd_folder
26 | self.nsddata_folder = op.join(self.nsd_folder, 'nsddata')
27 | self.ppdata_folder = op.join(self.nsd_folder, 'nsddata', 'ppdata')
28 | self.nsddata_betas_folder = op.join(
29 | self.nsd_folder, 'nsddata_betas', 'ppdata')
30 |
31 | self.behavior_file = op.join(
32 | self.ppdata_folder, '{subject}', 'behav', 'responses.tsv')
33 | self.stimuli_file = op.join(
34 | self.nsd_folder, 'nsddata_stimuli', 'stimuli', 'nsd', 'nsd_stimuli.hdf5')
35 | self.stimuli_description_file = op.join(
36 | self.nsd_folder, 'nsddata', 'experiments', 'nsd', 'nsd_stim_info_merged.csv')
37 |
38 | self.coco_annotation_file = op.join(
39 | self.nsd_folder, 'nsddata_stimuli', 'stimuli', 'nsd', 'annotations', '{}_{}.json')
40 |
41 | def download_coco_annotation_file(self, url='http://images.cocodataset.org/annotations/annotations_trainval2017.zip'):
42 | """download_coco_annotation_file downloads and extracts the relevant annotations files
43 |
44 | Parameters
45 | ----------
46 | url : str, optional
47 | url for zip file containing annotations, by default 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
48 | """
49 | print('downloading annotations from {}'.format(url))
50 | filehandle, _ = urllib.request.urlretrieve(url)
51 | zip_file_object = zipfile.ZipFile(filehandle, 'r')
52 | zip_file_object.extractall(path=op.split(
53 | op.split(self.coco_annotation_file)[0])[0])
54 |
55 | def affine_header(self, subject, data_format='func1pt8mm'):
56 | """affine_header affine and header, for construction of Nifti image
57 |
58 | Parameters
59 | ----------
60 | subject : str
61 | subject identifier, such as 'subj01'
62 | data_format : str, optional
63 | what type of data format, from ['func1pt8mm', 'func1mm'], by default 'func1pt8mm'
64 |
65 | Returns
66 | -------
67 | tuple
68 | affine and header, for construction of Nifti image
69 | """
70 | full_path = op.join(self.ppdata_folder,
71 | '{subject}', '{data_format}', 'brainmask.nii.gz')
72 | full_path = full_path.format(subject=subject,
73 | data_format=data_format)
74 | nii = nb.load(full_path)
75 |
76 | return nii.affine, nii.header
77 |
78 | def read_vol_ppdata(self, subject, filename='brainmask', data_format='func1pt8mm'):
79 |         """read_vol_ppdata loads a volumetric data file given by `filename` (the boolean brain mask by default)
80 |
81 | Parameters
82 | ----------
83 | subject : str
84 | subject identifier, such as 'subj01'
85 | data_format : str, optional
86 | what type of data format, from ['func1pt8mm', 'func1mm'], by default 'func1pt8mm'
87 |
88 | Returns
89 | -------
90 |         numpy.ndarray, 3D or 4D
91 |             the requested volumetric data (boolean brain mask by default)
92 | """
93 | full_path = op.join(self.ppdata_folder,
94 | '{subject}', '{data_format}', '{filename}.nii.gz')
95 | full_path = full_path.format(subject=subject,
96 | data_format=data_format,
97 | filename=filename)
98 | return nb.load(full_path).get_data()
99 |
100 | def read_betas(self, subject, session_index, trial_index=[], data_type='betas_fithrf_GLMdenoise_RR', data_format='fsaverage', mask=None):
101 | """read_betas read betas from MRI files
102 |
103 | Parameters
104 | ----------
105 | subject : str
106 | subject identifier, such as 'subj01'
107 | session_index : int
108 | which session, counting from 1
109 | trial_index : list, optional
110 | which trials from this session's file to return, by default [], which returns all trials
111 | data_type : str, optional
112 | which type of beta values to return from ['betas_assumehrf', 'betas_fithrf', 'betas_fithrf_GLMdenoise_RR', 'restingbetas_fithrf'], by default 'betas_fithrf_GLMdenoise_RR'
113 | data_format : str, optional
114 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm'], by default 'fsaverage'
115 | mask : numpy.ndarray, if defined, selects 'mat' data_format, needs volumetric data_format
116 | binary/boolean mask into mat file beta data format.
117 |
118 | Returns
119 | -------
120 | numpy.ndarray, 2D (fsaverage) or 4D (other data formats)
121 | the requested per-trial beta values
122 | """
123 | data_folder = op.join(self.nsddata_betas_folder,
124 | subject, data_format, data_type)
125 | si_str = str(session_index).zfill(2)
126 |
127 |         if type(mask) == np.ndarray: # will use the .mat file if it exists, otherwise raise an error
128 | ipf = op.join(data_folder, f'betas_session{si_str}.mat')
129 | assert op.isfile(ipf), \
130 | 'Error: ' + ipf + ' not available for masking. You may need to download these separately.'
131 | # will do indexing of both space and time in one go for this option,
132 | # so will return results immediately from this
133 | h5 = h5py.File(ipf, 'r')
134 | betas = h5.get('betas')
135 | # embed()
136 | if len(trial_index) == 0:
137 | trial_index = slice(0, betas.shape[0])
138 | # this isn't finished yet - binary masks cannot be used for indexing like this
139 | return betas[trial_index, np.nonzero(mask)]
140 |
141 | if data_format == 'fsaverage':
142 | session_betas = []
143 | for hemi in ['lh', 'rh']:
144 | hdata = nb.load(op.join(
145 | data_folder, f'{hemi}.betas_session{si_str}.mgh')).get_data()
146 | session_betas.append(hdata)
147 | out_data = np.squeeze(np.vstack(session_betas))
148 | else:
149 | # if no mask was specified, we'll use the nifti image
150 | out_data = nb.load(
151 | op.join(data_folder, f'betas_session{si_str}.nii.gz')).get_fdata()
152 |
153 | if len(trial_index) == 0:
154 | trial_index = slice(0, out_data.shape[-1])
155 |
156 | return out_data[..., trial_index]
157 |
158 | def read_mapper_results(self, subject, mapper='prf', data_type='angle', data_format='fsaverage'):
159 | """read_mapper_results [summary]
160 |
161 | Parameters
162 | ----------
163 | subject : str
164 | subject identifier, such as 'subj01'
165 | mapper : str, optional
166 | first part of the mapper filename, by default 'prf'
167 | data_type : str, optional
168 | second part of the mapper filename, by default 'angle'
169 | data_format : str, optional
170 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm'], by default 'fsaverage'
171 |
172 | Returns
173 | -------
174 | numpy.ndarray, 2D (fsaverage) or 4D (other data formats)
175 | the requested mapper values
176 | """
177 | if data_format == 'fsaverage':
178 | # unclear for now where the fsaverage mapper results would be
179 | # as they are still in fsnative format now.
180 | raise NotImplementedError(
181 | 'no mapper results in fsaverage present for now')
182 | else: # is 'func1pt8mm' or 'func1mm'
183 | return self.read_vol_ppdata(subject=subject, filename=f'{mapper}_{data_type}', data_format=data_format)
184 |
185 | def read_atlas_results(self, subject, atlas='HCP_MMP1', data_format='fsaverage'):
186 | """read_atlas_results [summary]
187 |
188 | Parameters
189 | ----------
190 | subject : str
191 | subject identifier, such as 'subj01'
192 | for surface-based data formats, subject should be the same as data_format.
193 | for example, for fsaverage, both subject and data_format should be 'fsaverage'
194 | this requires a little more typing but makes data format explicit
195 | atlas : str, optional
196 |             which atlas to read,
197 |             any of ['HCP_MMP1', 'Kastner2015', 'nsdgeneral', 'visualsulc'];
198 |             for surface (fsaverage) formats the name can be prefixed by 'lh.' or 'rh.'
199 |             to request a single hemisphere,
200 |             otherwise both hemispheres are loaded and stacked.
201 |             By default 'HCP_MMP1'.
202 | data_format : str, optional
203 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm', 'MNI'], by default 'fsaverage'
204 |
205 | Returns
206 | -------
207 | numpy.ndarray, 1D/2D (surface) or 3D/4D (volume data formats)
208 | the requested atlas values
209 | dict,
210 | dictionary containing the mapping between ROI names and atlas values
211 | """
212 |
213 | # first, get the mapping.
214 | atlas_name = atlas
215 | if atlas[:3] in ('rh.', 'lh.'):
216 | atlas_name = atlas[3:]
217 |
218 | mapp_df = pd.read_csv(os.path.join(self.nsddata_folder, 'freesurfer', 'fsaverage',
219 | 'label', f'{atlas_name}.mgz.ctab'), delimiter=' ', header=None, index_col=0)
220 | atlas_mapping = mapp_df.to_dict()[1]
221 | # dict((y,x) for x,y in atlas_mapping.iteritems())
222 | atlas_mapping = {y: x for x, y in atlas_mapping.items()}
223 |
224 | if data_format not in ('func1pt8mm', 'func1mm', 'MNI'):
225 | # if surface based results by exclusion
226 | if atlas[:3] in ('rh.', 'lh.'): # check if hemisphere-specific atlas requested
227 | ipf = op.join(self.nsddata_folder, 'freesurfer',
228 | subject, 'label', f'{atlas}.mgz')
229 | return np.squeeze(nb.load(ipf).get_data()), atlas_mapping
230 | else: # more than one hemisphere requested
231 | session_betas = []
232 | for hemi in ['lh', 'rh']:
233 | hdata = nb.load(op.join(
234 | self.nsddata_folder, 'freesurfer', subject, 'label', f'{hemi}.{atlas}.mgz')).get_data()
235 | session_betas.append(hdata)
236 | out_data = np.squeeze(np.vstack(session_betas))
237 | return out_data, atlas_mapping
238 | else: # is 'func1pt8mm', 'MNI', or 'func1mm'
239 | ipf = op.join(self.ppdata_folder, subject,
240 | data_format, 'roi', f'{atlas}.nii.gz')
241 | return nb.load(ipf).get_fdata(), atlas_mapping
242 |
243 | def list_atlases(self, subject, data_format='fsaverage', abs_paths=False):
244 | """list_atlases [summary]
245 |
246 | Parameters
247 | ----------
248 | subject : str
249 | subject identifier, such as 'subj01'
250 | for surface-based data formats, subject should be the same as data_format.
251 | for example, for fsaverage, both subject and data_format should be 'fsaverage'
252 | this requires a little more typing but makes data format explicit
253 | data_format : str, optional
254 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm', 'MNI'], by default 'fsaverage'
255 |
256 | Returns
257 | -------
258 | list
259 |             atlas names found for this subject (absolute paths when abs_paths=True)
260 | """
261 | if data_format in ('func1pt8mm', 'func1mm', 'MNI'):
262 | atlas_files = glob.glob(
263 | op.join(self.ppdata_folder, subject, data_format, 'roi', '*.nii.gz'))
264 | else:
265 | atlas_files = glob.glob(
266 | op.join(self.nsddata_folder, 'freesurfer', subject, 'label', '*.mgz'))
267 |
268 | # print this
269 | import pprint
270 | pp = pprint.PrettyPrinter(indent=4)
271 | print('Atlases found in {}:'.format(op.split(atlas_files[0])[0]))
272 | pp.pprint([op.split(f)[1] for f in atlas_files])
273 | if abs_paths:
274 | return atlas_files
275 | else: # this is the format which you can input into other functions, so this is the default
276 | return np.unique([op.split(f)[1].replace('lh.', '').replace('rh.', '').replace('.mgz', '').replace('.nii.gz', '') for f in atlas_files])
277 |
278 | def read_behavior(self, subject, session_index, trial_index=[]):
279 | """read_behavior [summary]
280 |
281 | Parameters
282 | ----------
283 | subject : str
284 | subject identifier, such as 'subj01'
285 | session_index : int
286 | which session, counting from 0
287 | trial_index : list, optional
288 | which trials from this session's behavior to return, by default [], which returns all trials
289 |
290 | Returns
291 | -------
292 | pandas DataFrame
293 | DataFrame containing the behavioral information for the requested trials
294 | """
295 |
296 | behavior = pd.read_csv(self.behavior_file.format(
297 | subject=subject), delimiter='\t')
298 |
299 | # the behavior is encoded per run.
300 | # I'm now setting this function up so that it aligns with the timepoints in the fmri files,
301 | # i.e. using indexing per session, and not using the 'run' information.
302 | session_behavior = behavior[behavior['SESSION'] == session_index]
303 |
304 | if len(trial_index) == 0:
305 | trial_index = slice(0, len(session_behavior))
306 |
307 | return session_behavior.iloc[trial_index]
308 |
309 | def read_images(self, image_index, show=False):
310 | """read_images reads a list of images, and returns their data
311 |
312 | Parameters
313 | ----------
314 | image_index : list of integers
315 | which images indexed in the 73k format to return
316 | show : bool, optional
317 | whether to also show the images, by default False
318 |
319 | Returns
320 | -------
321 | numpy.ndarray, 3D
322 | RGB image data
323 | """
324 |
325 | if not hasattr(self, 'stim_descriptions'):
326 | self.stim_descriptions = pd.read_csv(
327 | self.stimuli_description_file, index_col=0)
328 |
329 | sf = h5py.File(self.stimuli_file, 'r')
330 | sdataset = sf.get('imgBrick')
331 | if show:
332 | f, ss = plt.subplots(1, len(image_index),
333 | figsize=(6*len(image_index), 6))
334 | if len(image_index) == 1:
335 | ss = [ss]
336 | for s, d in zip(ss, sdataset[image_index]):
337 | s.axis('off')
338 | s.imshow(d)
339 | return sdataset[image_index]
340 |
341 | def read_image_coco_info(self, image_index, info_type='captions', show_annot=False, show_img=False):
342 | """image_coco_info returns the coco annotations of a single image or a list of images
343 |
344 | Parameters
345 | ----------
346 | image_index : list of integers
347 | which images indexed in the 73k format to return the captions for
348 | info_type : str, optional
349 | what type of annotation to return, from ['captions', 'person_keypoints', 'instances'], by default 'captions'
350 | show_annot : bool, optional
351 | whether to show the annotation, by default False
352 | show_img : bool, optional
353 | whether to show the image (from the nsd formatted data), by default False
354 |
355 | Returns
356 | -------
357 | coco Annotation
358 | coco annotation, to be used in subsequent analysis steps
359 |
360 | Example
361 | -------
362 | single image:
363 | ci = read_image_coco_info(
364 | [569], info_type='captions', show_annot=False, show_img=False)
365 | list of images:
366 | ci = read_image_coco_info(
367 | [569, 2569], info_type='captions')
368 |
369 | """
370 | if not hasattr(self, 'stim_descriptions'):
371 | self.stim_descriptions = pd.read_csv(
372 | self.stimuli_description_file, index_col=0)
373 | if len(image_index) == 1:
374 | subj_info = self.stim_descriptions.iloc[image_index[0]]
375 |
376 | # checking whether annotation file for this trial exists.
377 | # This may not be the right place to call the download, and
378 | # re-opening the annotations for all images separately may be slowing things down
379 | # however images used in the experiment seem to have come from different sets.
380 | annot_file = self.coco_annotation_file.format(
381 | info_type, subj_info['cocoSplit'])
382 | print('getting annotations from ' + annot_file)
383 | if not os.path.isfile(annot_file):
384 | print('annotations file not found')
385 | self.download_coco_annotation_file()
386 |
387 | coco = COCO(annot_file)
388 | coco_annot_IDs = coco.getAnnIds([subj_info['cocoId']])
389 | coco_annot = coco.loadAnns(coco_annot_IDs)
390 |
391 | if show_img:
392 | self.read_images(image_index, show=True)
393 |
394 | if show_annot:
395 | # still need to convert the annotations (especially person_keypoints and instances) to the right reference frame,
396 | # because the images were cropped. See image information per image to do this.
397 | coco.showAnns(coco_annot)
398 |
399 | elif len(image_index) > 1:
400 |
401 | # we output a list of annots
402 | coco_annot = []
403 |
404 | # load train_2017
405 | annot_file = self.coco_annotation_file.format(
406 | info_type, 'train2017')
407 | coco_train = COCO(annot_file)
408 |
409 | # also load the val 2017
410 | annot_file = self.coco_annotation_file.format(
411 | info_type, 'val2017')
412 | coco_val = COCO(annot_file)
413 |
414 | for image in image_index:
415 | subj_info = self.stim_descriptions.iloc[image]
416 | if subj_info['cocoSplit'] == 'train2017':
417 | coco_annot_IDs = coco_train.getAnnIds(
418 | [subj_info['cocoId']])
419 | coco_ann = coco_train.loadAnns(coco_annot_IDs)
420 | coco_annot.append(coco_ann)
421 |
422 | elif subj_info['cocoSplit'] == 'val2017':
423 | coco_annot_IDs = coco_val.getAnnIds(
424 | [subj_info['cocoId']])
425 | coco_ann = coco_val.loadAnns(coco_annot_IDs)
426 | coco_annot.append(coco_ann)
427 |
428 | return coco_annot
429 |
430 | def read_image_coco_category(self, image_index):
431 | """image_coco_category returns the coco category of a single image or a list of images
432 |
433 | Args:
434 | image_index ([list of integers]): which images indexed in the 73k format to return
435 | the category for
436 |
437 | Returns
438 | -------
439 | coco category
440 | coco category, to be used in subsequent analysis steps
441 |
442 | Example
443 | -------
444 | single image:
445 | ci = read_image_coco_category(
446 | [569])
447 | list of images:
448 | ci = read_image_coco_category(
449 | [569, 2569])
450 | """
451 |
452 | if not hasattr(self, 'stim_descriptions'):
453 | self.stim_descriptions = pd.read_csv(
454 | self.stimuli_description_file, index_col=0)
455 |
456 | if len(image_index) == 1:
457 | subj_info = self.stim_descriptions.iloc[image_index[0]]
458 | coco_id = subj_info['cocoId']
459 |
460 | # checking whether annotation file for this trial exists.
461 | # This may not be the right place to call the download, and
462 | # re-opening the annotations for all images separately may be slowing things down
463 | # however images used in the experiment seem to have come from different sets.
464 | annot_file = self.coco_annotation_file.format(
465 | 'instances', subj_info['cocoSplit'])
466 | print('getting annotations from ' + annot_file)
467 | if not os.path.isfile(annot_file):
468 | print('annotations file not found')
469 | self.download_coco_annotation_file()
470 |
471 | coco = COCO(annot_file)
472 |
473 | cat_ids = coco.getCatIds()
474 | categories = json_normalize(coco.loadCats(cat_ids))
475 |
476 | coco_cats = []
477 | for cat_id in cat_ids:
478 | this_img_list = coco.getImgIds(catIds=[cat_id])
479 | if coco_id in this_img_list:
480 | this_cat = np.asarray(categories[categories['id']==cat_id]['name'])[0]
481 | coco_cats.append(this_cat)
482 |
483 | elif len(image_index) > 1:
484 |
485 | # we output a list of annots
486 | coco_cats = []
487 |
488 | # load train_2017
489 | annot_file = self.coco_annotation_file.format(
490 | 'instances', 'train2017')
491 | coco_train = COCO(annot_file)
492 | cat_ids_train = coco_train.getCatIds()
493 | categories_train = json_normalize(coco_train.loadCats(cat_ids_train))
494 |
495 | # also load the val 2017
496 | annot_file = self.coco_annotation_file.format(
497 | 'instances', 'val2017')
498 | coco_val = COCO(annot_file)
499 | cat_ids_val = coco_val.getCatIds()
500 | categories_val = json_normalize(coco_val.loadCats(cat_ids_val))
501 |
502 | for image in tqdm(image_index, bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}'):
503 | subj_info = self.stim_descriptions.iloc[image]
504 | coco_id = subj_info['cocoId']
505 | image_cat = []
506 | if subj_info['cocoSplit'] == 'train2017':
507 | for cat_id in cat_ids_train:
508 | this_img_list = coco_train.getImgIds(catIds=[cat_id])
509 | if coco_id in this_img_list:
510 | this_cat = np.asarray(categories_train[categories_train['id']==cat_id]['name'])[0]
511 | image_cat.append(this_cat)
512 |
513 | elif subj_info['cocoSplit'] == 'val2017':
514 | for cat_id in cat_ids_val:
515 | this_img_list = coco_val.getImgIds(catIds=[cat_id])
516 | if coco_id in this_img_list:
517 | this_cat = np.asarray(categories_val[categories_val['id']==cat_id]['name'])[0]
518 | image_cat.append(this_cat)
519 | coco_cats.append(image_cat)
520 | return coco_cats
521 |
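# --- Illustrative usage sketch (demo helper; '/path/to/NSD' is a placeholder for a local NSD download). ---
def _example_nsd_access():
    nsda = NSDAccess('/path/to/NSD')
    betas = nsda.read_betas('subj01', session_index=1,
                            data_format='func1pt8mm')          # 4D array, trials on the last axis
    imgs = nsda.read_images([0, 1], show=False)                # (2, H, W, 3) RGB stimuli
    behav = nsda.read_behavior('subj01', session_index=1)      # per-trial behaviour DataFrame
    return betas.shape, imgs.shape, len(behav)
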
--------------------------------------------------------------------------------
/src/trainer.py:
--------------------------------------------------------------------------------
1 | from abc import abstractmethod
2 | import os
3 | import math
4 | import numpy as np
5 | import torch
6 | import torch.nn as nn
7 | from tqdm import tqdm
8 | import wandb
9 |
10 | # TF32 matmuls are faster than full-precision float32 on GPUs that support them
11 | torch.backends.cuda.matmul.allow_tf32 = True
12 |
13 | # Custom models and functions #
14 | import utils
15 | import data
16 |
17 | class Trainer:
18 | def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None:
19 | # train logs path
20 | self.outdir = os.path.abspath(f'../train_logs/{args.model_name}')
21 | if not os.path.exists(self.outdir):
22 | os.makedirs(self.outdir,exist_ok=True)
23 |
24 | self.args = args
25 | self.accelerator = accelerator
26 | self.voxel2clip = voxel2clip
27 | self.clip_extractor = clip_extractor
28 | self.prompts_list = prompts_list
29 | self.device = device
30 | self.num_devices = max(torch.cuda.device_count(), 1)
31 | self.epoch_start = 0
32 |
33 | self.prepare_dataloader()
34 | self.prepare_optimizer()
35 | self.prepare_scheduler()
36 | self.prepare_multi_gpu()
37 | self.prepare_weights()
38 |
39 | def prepare_weights(self):
40 | # Resume or load ckpt
41 | if self.args.resume:
42 | self.resume()
43 | elif self.args.load_from:
44 | self.load()
45 |         else: # the ckpt folder should be empty; if it is not, you probably forgot to change the experiment name, and training would overwrite previous results.
46 | file_count = len(os.listdir(self.outdir))
47 | if file_count > 0:
48 | raise RuntimeError("The folder is not empty, please check to avoid overwriting! \n {}\n".format(self.outdir))
49 |
50 | @abstractmethod
51 | def prepare_dataloader(self):
52 | pass
53 |
54 | def prepare_optimizer(self,):
55 | # Prepare optimizer
56 | no_decay = ['bias', 'Norm', 'temperature']
57 | opt_grouped_parameters = [
58 | {'params': [p for n, p in self.voxel2clip.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
59 | {'params': [p for n, p in self.voxel2clip.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
60 | ]
61 | self.optimizer = torch.optim.AdamW(opt_grouped_parameters, lr=self.args.max_lr)
62 |
63 | def prepare_scheduler(self):
64 | # prepare lr scheduler
65 | one_epoch_steps = self.num_batches
66 | if self.accelerator.state.deepspeed_plugin is not None: # Multi GPU
67 | one_epoch_steps = math.ceil(one_epoch_steps / self.num_devices)
68 | total_steps = self.args.num_epochs * one_epoch_steps
69 | print("one_epoch_steps_per_gpu:",one_epoch_steps)
70 | print("total_steps:",total_steps)
71 |
72 | if self.args.lr_scheduler_type == 'linear':
73 | self.lr_scheduler = torch.optim.lr_scheduler.LinearLR(
74 | self.optimizer,
75 | total_iters=total_steps,
76 | last_epoch=-1
77 | )
78 | elif self.args.lr_scheduler_type == 'cycle':
79 | self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
80 | self.optimizer,
81 | max_lr=self.args.max_lr,
82 | total_steps=total_steps,
83 | final_div_factor=100,
84 | last_epoch=-1,
85 | pct_start=2/self.args.num_epochs,
86 | )
87 |
88 | def prepare_wandb(self, local_rank, args):
89 | ## Weights and Biases
90 | if local_rank==0 and args.wandb_log: # only use main process for wandb logging
91 | import wandb
92 | wandb_run = args.model_name
93 | wandb_notes = ''
94 |
95 | print(f"Wandb project {args.wandb_project} run {wandb_run}")
96 | wandb.login(host='https://api.wandb.ai')
97 | wandb_config = vars(args)
98 | print("wandb_config:\n",wandb_config)
99 | if args.resume: # wandb_auto_resume
100 | if args.resume_id is None:
101 | args.resume_id = args.model_name
102 | print("wandb_id:", args.resume_id)
103 | wandb.init(
104 | id = args.resume_id,
105 | project=args.wandb_project,
106 | name=wandb_run,
107 | config=wandb_config,
108 | notes=wandb_notes,
109 | resume="allow",
110 | )
111 | else:
112 | wandb.init(
113 | project=args.wandb_project,
114 | name=wandb_run,
115 | config=wandb_config,
116 | notes=wandb_notes,
117 | )
118 |
119 | @abstractmethod
120 | def prepare_multi_gpu(self):
121 | pass
122 |
123 | def input(self, voxel, subj_id):
124 | return (voxel, subj_id)
125 |
126 | def train(self, local_rank):
127 | epoch = self.epoch_start
128 | self.losses, self.val_losses, self.lrs = [], [], []
129 | self.best_sim = 0
130 | self.best_epoch = 0
131 |
132 | self.val_voxel0 = self.val_image0 = None
133 |
134 | ## Main loop
135 | print(f"{self.args.model_name} starting with epoch {epoch} / {self.args.num_epochs}")
136 | progress_bar = tqdm(range(epoch, self.args.num_epochs), disable=(local_rank!=0))
137 |
138 | for epoch in progress_bar:
139 | self.voxel2clip.train()
140 |
141 | self.sims_image = 0.
142 | self.sims_text = 0.
143 | self.val_sims_image = 0.
144 | self.val_sims_text = 0.
145 | self.fwd_percent_correct = 0.
146 | self.bwd_percent_correct = 0.
147 | self.val_fwd_percent_correct = 0.
148 | self.val_bwd_percent_correct = 0.
149 | self.loss_clip_image_sum = 0.
150 | self.loss_clip_text_sum = 0.
151 | self.loss_mse_image_sum = 0.
152 | self.loss_mse_text_sum = 0.
153 | self.loss_rec_sum = 0.
154 | self.loss_cyc_sum = 0.
155 | self.val_loss_clip_image_sum = 0.
156 | self.val_loss_clip_text_sum = 0.
157 | self.val_loss_mse_image_sum = 0.
158 | self.val_loss_mse_text_sum = 0.
159 | self.val_loss_rec_sum = 0.
160 | self.val_loss_cyc_sum = 0.
161 |
162 | # wandb logging
163 | self.train_epoch(epoch)
164 | self.log_train()
165 |
166 | if epoch % self.args.eval_interval == 0:
167 | self.eval_epoch(epoch)
168 | self.log_val()
169 |
170 | if self.args.wandb_log and local_rank==0:
171 | wandb.log(self.logs)
172 |
173 | progress_dict = {
174 | "epoch": epoch,
175 | "lr": self.logs["train/lr"],
176 | "loss": self.logs["train/loss"],
177 | }
178 |
179 | progress_bar.set_postfix(progress_dict)
180 |
181 | # Main process
182 | if local_rank==0:
183 | # Uploading logs to wandb
184 | if self.args.wandb_log:
185 | wandb.log(self.logs)
186 |
187 | # Save model
188 | if epoch % self.args.ckpt_interval == 0 or epoch == self.args.num_epochs-1:
189 | self.save(epoch)
190 |
191 | # wait for other GPUs to catch up if needed
192 | self.accelerator.wait_for_everyone()
193 |
194 | @abstractmethod
195 | def train_epoch(self, epoch):
196 | pass
197 |
198 | def train_step(self, voxel, image, captions, subj_id):
199 | loss = 0.
200 | self.optimizer.zero_grad()
201 |
202 | if self.args.use_image_aug:
203 | image = data.img_augment(image)
204 | clip_image = self.clip_extractor.embed_image(image).float()
205 | clip_text = self.clip_extractor.embed_text(captions).float()
206 |
207 | # clip_image_pred, clip_text_pred, voxel_rec, loss_cyc = self.voxel2clip((voxel, subj_id))
208 | results = self.voxel2clip(self.input(voxel, subj_id))
209 |
210 | # image clip loss
211 | clip_image_pred = results[0]
212 | clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1)
213 | clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1)
214 | loss_clip_image = utils.soft_clip_loss(
215 | clip_image_pred_norm,
216 | clip_image_norm,
217 | )
218 |
219 | utils.check_loss(loss_clip_image, "loss_clip_image")
220 | loss += loss_clip_image
221 | self.loss_clip_image_sum += loss_clip_image.item()
222 |
223 | # image mse loss
224 | if self.args.mse_mult:
225 | loss_mse_image = nn.MSELoss()(clip_image_pred_norm, clip_image_norm)
226 | utils.check_loss(loss_mse_image, "loss_mse_image")
227 | loss += self.args.mse_mult * loss_mse_image
228 | self.loss_mse_image_sum += loss_mse_image.item()
229 |
230 | # text clip loss
231 | clip_text_pred = results[1]
232 | clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1)
233 | clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1)
234 | loss_clip_text = utils.soft_clip_loss(
235 | clip_text_pred_norm,
236 | clip_text_norm,
237 | )
238 | utils.check_loss(loss_clip_text, "loss_clip_text")
239 | loss += loss_clip_text
240 | self.loss_clip_text_sum += loss_clip_text.item()
241 |
242 | # text mse loss
243 | if self.args.mse_mult:
244 | loss_mse_text = nn.MSELoss()(clip_text_pred_norm, clip_text_norm)
245 | utils.check_loss(loss_mse_text, "loss_mse_text")
246 | loss += self.args.mse_mult * loss_mse_text
247 | self.loss_mse_text_sum += loss_mse_text.item()
248 |
249 | # brain reconstruction loss
250 | if self.args.rec_mult:
251 | voxel_rec = results[2]
252 | loss_rec = nn.MSELoss()(voxel, voxel_rec)
253 | utils.check_loss(loss_rec, "loss_rec")
254 | loss += self.args.rec_mult * loss_rec
255 | self.loss_rec_sum += loss_rec.item()
256 |
257 | # cycle loss
258 | if self.args.cyc_mult:
259 | loss_cyc = results[3]
260 | utils.check_loss(loss_cyc, "loss_cyc")
261 | loss += self.args.cyc_mult * loss_cyc
262 | self.loss_cyc_sum += loss_cyc.item()
263 |
264 | utils.check_loss(loss)
265 | self.accelerator.backward(loss)
266 | self.optimizer.step()
267 |
268 | self.losses.append(loss.item())
269 | self.lrs.append(self.optimizer.param_groups[0]['lr'])
270 | self.lr_scheduler.step()
271 |
272 | self.sims_image += nn.functional.cosine_similarity(clip_image_norm,clip_image_pred_norm).mean().item()
273 | self.sims_text += nn.functional.cosine_similarity(clip_text_norm,clip_text_pred_norm).mean().item()
274 |
275 | # forward and backward top 1 accuracy
276 | labels = torch.arange(len(clip_image_norm)).to(self.device)
277 | self.fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_pred_norm, clip_image_norm), labels, k=1)
278 | self.bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_norm, clip_image_pred_norm), labels, k=1)
279 |
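    # --- Note (objective assembled in train_step above): ------------------------
    #   loss = L_clip_image + L_clip_text
    #        + mse_mult * (L_mse_image + L_mse_text)
    #        + rec_mult * L_rec + cyc_mult * L_cyc
    # where the soft CLIP terms align predicted and target embeddings, L_rec is the
    # per-subject voxel reconstruction loss, and L_cyc is the cross-subject cycle
    # loss returned by the MindBridge model.
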
280 | @abstractmethod
281 | def eval_epoch(self, epoch):
282 | pass
283 |
284 | def eval_step(self, voxel, image, captions, subj_id):
285 | val_loss = 0.
286 | with torch.no_grad():
287 | # used for reconstruction
288 | if self.val_image0 is None:
289 | self.val_image0 = image.detach().clone()
290 | self.val_voxel0 = voxel.detach().clone()
291 |
292 | clip_image = self.clip_extractor.embed_image(image).float()
293 | clip_text = self.clip_extractor.embed_text(captions).float()
294 |
295 | # clip_image_pred, clip_text_pred, voxel_rec, loss_cyc = self.voxel2clip((voxel, subj_id))
296 | results = self.voxel2clip(self.input(voxel, subj_id))
297 |
298 | # image clip loss
299 | clip_image_pred = results[0]
300 | clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1)
301 | clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1)
302 | val_loss_clip_image = utils.soft_clip_loss(
303 | clip_image_pred_norm,
304 | clip_image_norm,
305 | )
306 | val_loss += val_loss_clip_image
307 | self.val_loss_clip_image_sum += val_loss_clip_image.item()
308 |
309 | # image mse loss
310 | if self.args.mse_mult:
311 | val_loss_mse_image = nn.MSELoss()(clip_image_pred_norm, clip_image_norm)
312 | val_loss += self.args.mse_mult * val_loss_mse_image
313 | self.val_loss_mse_image_sum += val_loss_mse_image.item()
314 |
315 | # text clip loss
316 | clip_text_pred = results[1]
317 | clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1)
318 | clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1)
319 | val_loss_clip_text = utils.soft_clip_loss(
320 | clip_text_pred_norm,
321 | clip_text_norm,
322 | )
323 | val_loss += val_loss_clip_text
324 | self.val_loss_clip_text_sum += val_loss_clip_text.item()
325 |
326 | # text mse loss
327 | if self.args.mse_mult:
328 | val_loss_mse_text = nn.MSELoss()(clip_text_pred_norm, clip_text_norm)
329 | val_loss += self.args.mse_mult * val_loss_mse_text
330 | self.val_loss_mse_text_sum += val_loss_mse_text.item()
331 |
332 | # brain reconstruction loss
333 | if self.args.rec_mult:
334 | voxel_rec = results[2]
335 | val_loss_rec = nn.MSELoss()(voxel, voxel_rec)
336 | val_loss += self.args.rec_mult * val_loss_rec
337 | self.val_loss_rec_sum += val_loss_rec.item()
338 |
339 | # cycle loss
340 | if self.args.cyc_mult:
341 | loss_cyc = results[3]
342 | val_loss_cyc = loss_cyc
343 | val_loss += self.args.cyc_mult * val_loss_cyc
344 | self.val_loss_cyc_sum += val_loss_cyc.item()
345 |
346 | utils.check_loss(val_loss)
347 | self.val_losses.append(val_loss.item())
348 |
349 | self.val_sims_image += nn.functional.cosine_similarity(clip_image_norm,clip_image_pred_norm).mean().item()
350 | self.val_sims_text += nn.functional.cosine_similarity(clip_text_norm,clip_text_pred_norm).mean().item()
351 |
352 | labels = torch.arange(len(clip_image_norm)).to(self.device)
353 | self.val_fwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_pred_norm, clip_image_norm), labels, k=1)
354 | self.val_bwd_percent_correct += utils.topk(utils.batchwise_cosine_similarity(clip_image_norm, clip_image_pred_norm), labels, k=1)
355 |
356 | def vis(self,):
357 | pass
358 |
359 | def save_ckpt(self, tag, epoch):
360 | if self.accelerator.is_main_process:
361 | ckpt_path = self.outdir+f'/{tag}.pth'
362 | print(f'--- saving model: {ckpt_path} ---',flush=True)
363 | unwrapped_model = self.accelerator.unwrap_model(self.voxel2clip)
364 |             # unwrapped_optimizer = self.accelerator.unwrap_model(self.optimizer)
365 |             # print("unwrapped optimizer keys", unwrapped_optimizer.state_dict().keys())
366 |
367 | try:
368 | torch.save({
369 | 'epoch': epoch,
370 | 'model_state_dict': unwrapped_model.state_dict(),
371 | }, ckpt_path)
372 |             except Exception as e:
373 |                 print(f"Couldn't save checkpoint ({e})... moving on to prevent crashing.")
374 | del unwrapped_model
375 |
376 | state_path = os.path.join(self.outdir, tag)
377 | print(f'--- saving state: {state_path} ---',flush=True)
378 | self.accelerator.save_state(state_path)
379 |
380 | def save(self, epoch):
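    |         """Always save the 'last' checkpoint; additionally save 'best' when the combined
    |         image and text validation cosine similarity improves."""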
381 |         self.save_ckpt('last', epoch)
382 | # save best model
383 | current_sim = (self.val_sims_image + self.val_sims_text) / (self.val_i + 1) if hasattr(self, 'val_i') else 0
384 | if current_sim > self.best_sim:
385 | self.best_sim = current_sim
386 | self.best_epoch = epoch
387 |             self.save_ckpt('best', epoch)
388 | else:
389 | print(f'Not best - current_similarity: {current_sim:.3f} @ epoch {epoch}, best_similarity: {self.best_sim:.3f} @ epoch {self.best_epoch}')
390 |
391 | def load(self,):
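    |         """Load voxel2clip weights from args.load_from (strict=False), leaving optimizer state untouched."""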
392 | print("\n--- load from ckpt: {} ---\n".format(self.args.load_from))
393 | checkpoint = torch.load(self.args.load_from, map_location='cpu')
394 | unwrapped_voxel2clip = self.accelerator.unwrap_model(self.voxel2clip)
395 | unwrapped_voxel2clip.load_state_dict(checkpoint['model_state_dict'], strict=False)
396 | print("loaded keys", checkpoint['model_state_dict'].keys())
397 | del checkpoint
398 |
399 | def resume(self,):
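    |         """Restore the full accelerator state from <outdir>/last and read the starting epoch from last.pth."""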
400 | state_path = os.path.join(self.outdir, "last")
401 | print(f"\n--- resuming from {state_path} ---\n")
402 | self.accelerator.load_state(state_path)
403 |
404 | ckpt_path = self.outdir+'/last.pth'
405 | print(f"\n--- Read resume epoch from {ckpt_path} ---\n")
406 | checkpoint = torch.load(ckpt_path, map_location='cpu')
407 | self.epoch_start = checkpoint['epoch']
408 | print(">>> Resume at Epoch", self.epoch_start)
409 | # self.optimizer.load_state_dict(checkpoint['optimizer_state_dict']['base_optimizer_state'])
410 | # self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
411 | # self.voxel2clip.load_state_dict(checkpoint['model_state_dict'])
412 | del checkpoint
413 |
414 | def log_train(self):
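    |         # Average the accumulated training losses/metrics over the batches seen so far this epoch.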
415 | self.logs = {
416 | "train/loss": np.mean(self.losses[-(self.train_i+1):]),
417 | "train/lr": self.lrs[-1],
418 | "train/num_steps": len(self.losses),
419 | "train/cosine_sim_image": self.sims_image / (self.train_i + 1),
420 | "train/cosine_sim_text": self.sims_text / (self.train_i + 1),
421 | "train/fwd_pct_correct": self.fwd_percent_correct / (self.train_i + 1),
422 | "train/bwd_pct_correct": self.bwd_percent_correct / (self.train_i + 1),
423 | "train/loss_clip_image": self.loss_clip_image_sum / (self.train_i + 1),
424 | "train/loss_clip_text": self.loss_clip_text_sum / (self.train_i + 1),
425 | "train/loss_mse_image": self.loss_mse_image_sum / (self.train_i + 1),
426 | "train/loss_mse_text": self.loss_mse_text_sum / (self.train_i + 1),
427 | "train/loss_rec": self.loss_rec_sum / (self.train_i + 1),
428 | "train/loss_cyc": self.loss_cyc_sum / (self.train_i + 1),
429 | }
430 |
431 | def log_val(self):
432 | self.logs.update({
433 | "val/loss": np.mean(self.val_losses[-(self.val_i+1):]),
434 | "val/num_steps": len(self.val_losses),
435 | "val/cosine_sim_image": self.val_sims_image / (self.val_i + 1),
436 | "val/cosine_sim_text": self.val_sims_text / (self.val_i + 1),
437 | "val/val_fwd_pct_correct": self.val_fwd_percent_correct / (self.val_i + 1),
438 | "val/val_bwd_pct_correct": self.val_bwd_percent_correct / (self.val_i + 1),
439 | "val/loss_clip_image": self.val_loss_clip_image_sum / (self.val_i + 1),
440 | "val/loss_clip_text": self.val_loss_clip_text_sum / (self.val_i + 1),
441 | "val/loss_mse_image": self.val_loss_mse_image_sum / (self.val_i + 1),
442 | "val/loss_mse_text": self.val_loss_mse_text_sum / (self.val_i + 1),
443 | "val/loss_rec": self.val_loss_rec_sum / (self.val_i + 1),
444 | "val/loss_cyc": self.val_loss_cyc_sum / (self.val_i + 1),
445 | })
446 |
447 | class Trainer_single(Trainer):
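    |     """Single-subject trainer: one train/val dataloader for subj_list[0]; subj_id is not passed to the model."""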
448 | def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None:
449 | super().__init__(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)
450 |
451 | def prepare_dataloader(self):
452 | # Prepare data and dataloader
453 | print("Preparing data and dataloader...")
454 | self.train_dl, self.val_dl = data.get_dls(
455 | subject=self.args.subj_list[0],
456 | data_path=self.args.data_path,
457 | batch_size=self.args.batch_size,
458 | val_batch_size=self.args.val_batch_size,
459 | num_workers=self.args.num_workers,
460 | pool_type=self.args.pool_type,
461 | pool_num=self.args.pool_num,
462 | length=self.args.length,
463 | seed=self.args.seed,
464 | )
465 | self.num_batches = len(self.train_dl)
466 |
467 | def prepare_multi_gpu(self):
468 | self.voxel2clip, self.optimizer, self.train_dl, self.val_dl, self.lr_scheduler = self.accelerator.prepare(
469 | self.voxel2clip, self.optimizer, self.train_dl, self.val_dl, self.lr_scheduler)
470 |
471 | def input(self, voxel, subj_id):
472 |         # single-subject training does not use subj_id (only the bridge/adapt trainers do)
473 | return voxel
474 |
475 | def train_epoch(self, epoch):
476 | # train loop
477 | for train_i, data_i in enumerate(self.train_dl):
478 | self.train_i = train_i
479 |             repeat_index = train_i % 3 # cycle through the three repeated trials
480 |
481 | voxel, image, coco, subj_id = data_i
482 | voxel = voxel[:,repeat_index,...].float()
483 | subj_id = subj_id[[0],...]
484 |
485 | coco_ids = coco.squeeze().tolist()
486 | current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
487 | captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]
488 |
489 | print(">>> Epoch{} | Iter{} | voxel: {}".format(epoch, train_i, voxel.shape), flush=True)
490 | self.train_step(voxel, image, captions, subj_id)
491 |
492 | def eval_epoch(self, epoch):
493 | print("evaluating...")
494 | self.voxel2clip.eval()
495 |
496 | for val_i, data_i in enumerate(self.val_dl):
497 | self.val_i = val_i
498 |             repeat_index = val_i % 3 # cycle through the three repeated trials
499 | voxel, image, coco, subj_id = data_i
500 | voxel = torch.mean(voxel,axis=1)
501 | subj_id = subj_id[[0],...]
502 |
503 | coco_ids = coco.squeeze().tolist()
504 | current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
505 | captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]
506 |
507 | print(">>> Epoch{} | Eval{} | voxel: {}".format(epoch, val_i, voxel.shape), flush=True)
508 | self.eval_step(voxel, image, captions, subj_id)
509 |
510 | class Trainer_bridge(Trainer):
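    |     """Multi-subject trainer: one dataloader per subject in subj_list; each step concatenates one batch from every subject."""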
511 | def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None:
512 | super().__init__(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)
513 |
514 | def prepare_dataloader(self):
515 | # Prepare data and dataloader
516 | print("Preparing data and dataloader...")
517 |         self.train_dls = [] # one train dataloader per subject
518 |         self.val_dls = [] # one val dataloader per subject
519 |
520 | for subj in self.args.subj_list:
521 | train_dl, val_dl = data.get_dls(
522 | subject=subj,
523 | data_path=self.args.data_path,
524 | batch_size=self.args.batch_size,
525 | val_batch_size=self.args.val_batch_size,
526 | num_workers=self.args.num_workers,
527 | pool_type=self.args.pool_type,
528 | pool_num=self.args.pool_num,
529 | length=self.args.length,
530 | seed=self.args.seed,
531 | )
532 | self.train_dls.append(train_dl)
533 | self.val_dls.append(val_dl)
534 |
535 | self.num_batches = len(self.train_dls[0])
536 |
537 | def prepare_multi_gpu(self):
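    |         # The prepared copy of train_dls[0] returned here is discarded (`_`);
    |         # every train/val dataloader is prepared individually in the loop below.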
538 | self.voxel2clip, self.optimizer, self.lr_scheduler, _ = self.accelerator.prepare(
539 | self.voxel2clip, self.optimizer, self.lr_scheduler, self.train_dls[0])
540 |
541 | for i, dls in enumerate(zip(self.train_dls, self.val_dls)):
542 | train_dl, val_dl = dls
543 | self.train_dls[i] = self.accelerator.prepare(train_dl)
544 | self.val_dls[i] = self.accelerator.prepare(val_dl)
545 |
546 | def train_epoch(self, epoch):
547 | # train loop
548 | for train_i, datas in enumerate(zip(*self.train_dls)):
549 | self.train_i = train_i
550 |             repeat_index = train_i % 3 # cycle through the three repeated trials
551 |
552 |             # assemble one combined batch from all subjects
553 | voxel_list, image_list, coco_list, subj_id_list = [], [], [], []
554 | for voxel, image, coco, subj_id in datas:
555 | voxel_list.append(voxel[:,repeat_index,...])
556 | image_list.append(image)
557 | coco_list.append(coco)
558 | subj_id_list.append(subj_id[[0],...])
559 | voxel = torch.cat(voxel_list, dim=0)
560 | image = torch.cat(image_list, dim=0)
561 | coco = torch.cat(coco_list, dim=0)
562 | subj_id = torch.cat(subj_id_list, dim=0)
563 |
564 | coco_ids = coco.squeeze().tolist()
565 | current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
566 | captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]
567 |
568 | print(">>> Epoch{} | Iter{} | voxel: {}".format(epoch, train_i, voxel.shape), flush=True)
569 | self.train_step(voxel, image, captions, subj_id)
570 |
571 | def eval_epoch(self, epoch):
572 | print("Evaluating...")
573 | self.voxel2clip.eval()
574 | for val_i, datas in enumerate(zip(*self.val_dls)):
575 | self.val_i = val_i
576 |             repeat_index = val_i % 3 # cycle through the three repeated trials
577 |
578 | # ensemble data from multiple subjects
579 | voxel_list, image_list, coco_list, subj_id_list = [], [], [], []
580 | for voxel, image, coco, subj_id in datas:
581 | voxel_list.append(torch.mean(voxel,axis=1))
582 | image_list.append(image)
583 | coco_list.append(coco)
584 | subj_id_list.append(subj_id[[0],...])
585 | voxel = torch.cat(voxel_list, dim=0)
586 | image = torch.cat(image_list, dim=0)
587 | coco = torch.cat(coco_list, dim=0)
588 | subj_id = torch.cat(subj_id_list, dim=0)
589 |
590 | coco_ids = coco.squeeze().tolist()
591 | current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
592 | captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]
593 |
594 | print(">>> Epoch{} | Eval{} | voxel: {}".format(epoch, val_i, voxel.shape), flush=True)
595 | self.eval_step(voxel, image, captions, subj_id)
596 |
597 | class Trainer_adapt(Trainer):
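    |     """Adaptation trainer: each step pairs a batch from the target subject with a batch
    |     from one source subject, chosen round-robin across subj_source."""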
598 | def __init__(self, args, accelerator, voxel2clip, clip_extractor, prompts_list, device) -> None:
599 | super().__init__(args, accelerator, voxel2clip, clip_extractor, prompts_list, device)
600 |
601 | def prepare_dataloader(self):
602 | # Prepare data and dataloader
603 | print("Preparing data and dataloader...")
604 |         self.train_dls_source = [] # one train dataloader per source subject
605 |         self.val_dls_source = [] # one val dataloader per source subject
606 |
607 | # source subjects
608 | for subj in self.args.subj_source:
609 | train_dl, val_dl = data.get_dls(
610 | subject=subj,
611 | data_path=self.args.data_path,
612 | batch_size=self.args.batch_size,
613 | val_batch_size=self.args.val_batch_size,
614 | num_workers=self.args.num_workers,
615 | pool_type=self.args.pool_type,
616 | pool_num=self.args.pool_num,
617 | length=self.args.length,
618 | seed=self.args.seed,
619 | )
620 | self.train_dls_source.append(train_dl)
621 | self.val_dls_source.append(val_dl)
622 |
623 |         # target subject
624 | self.train_dl_target, self.val_dl_target = data.get_dls(
625 | subject=self.args.subj_target,
626 | data_path=self.args.data_path,
627 | batch_size=self.args.batch_size,
628 | val_batch_size=self.args.val_batch_size,
629 | num_workers=self.args.num_workers,
630 | pool_type=self.args.pool_type,
631 | pool_num=self.args.pool_num,
632 | length=self.args.length,
633 | seed=self.args.seed,
634 | )
635 |
636 | self.num_batches = len(self.train_dl_target)
637 |
638 | def prepare_multi_gpu(self):
639 | self.voxel2clip, self.optimizer, self.lr_scheduler, self.train_dl_target, self.val_dl_target = self.accelerator.prepare(
640 | self.voxel2clip, self.optimizer, self.lr_scheduler, self.train_dl_target, self.val_dl_target)
641 |
642 | for i, dls in enumerate(zip(self.train_dls_source, self.val_dls_source)):
643 | train_dl, val_dl = dls
644 | self.train_dls_source[i] = self.accelerator.prepare(train_dl)
645 | self.val_dls_source[i] = self.accelerator.prepare(val_dl)
646 |
647 | def train_epoch(self, epoch):
648 |         # make the source dataloaders manually iterable
649 | train_dls_source_iter = []
650 | for train_dl_s in self.train_dls_source:
651 | train_dls_source_iter.append(iter(train_dl_s))
652 |
653 | # train loop
654 | for train_i, datas_target in enumerate(self.train_dl_target):
655 | self.train_i = train_i
656 |             repeat_index = train_i % 3 # cycle through the three repeated trials
657 | voxel_target, image, coco, subj_id = datas_target
658 | voxel = voxel_target[:,repeat_index,...]
659 |
660 |             source_index = train_i % len(train_dls_source_iter) # round-robin over the source subjects
661 | voxel_source, image_source, coco_source, subj_id_source = next(train_dls_source_iter[source_index])
662 | voxel_source = voxel_source[:,repeat_index,...]
663 | voxel = torch.cat((voxel_source, voxel), dim=0)
664 | image = torch.cat((image_source, image), dim=0)
665 | coco = torch.cat((coco_source, coco), dim=0)
666 | subj_id = torch.cat((subj_id_source[[0],...], subj_id[[0],...]), dim=0)
667 |
668 | coco_ids = coco.squeeze().tolist()
669 | current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
670 | captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]
671 |
672 | print(">>> Epoch{} | Iter{} | source{} | voxel: {}".format(epoch, train_i, source_index, voxel.shape), flush=True)
673 | self.train_step(voxel, image, captions, subj_id)
674 |
675 | def eval_epoch(self, epoch):
676 | print("Evaluating...")
677 | self.voxel2clip.eval()
678 |
679 |         # make the source dataloaders manually iterable
680 | val_dls_source_iter = []
681 | for val_dl_s in self.val_dls_source:
682 | val_dls_source_iter.append(iter(val_dl_s))
683 |
684 | for val_i, datas_target in enumerate(self.val_dl_target):
685 | self.val_i = val_i
686 |             repeat_index = val_i % 3 # cycle through the three repeated trials
687 | voxel, image, coco, subj_id = datas_target
688 | voxel = torch.mean(voxel,axis=1)
689 |
690 |             source_index = val_i % len(val_dls_source_iter) # round-robin over the source subjects
691 | print("Using source {}".format(source_index))
692 | voxel_source, image_source, coco_source, subj_id_source = next(val_dls_source_iter[source_index])
693 | voxel_source = torch.mean(voxel_source, axis=1)
694 | voxel = torch.cat((voxel_source, voxel), dim=0)
695 | image = torch.cat((image_source, image), dim=0)
696 | coco = torch.cat((coco_source, coco), dim=0)
697 | subj_id = torch.cat((subj_id_source[[0],...], subj_id[[0],...]), dim=0)
698 |
699 | coco_ids = coco.squeeze().tolist()
700 | current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
701 | captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]
702 |
703 | print(">>> Epoch{} | Eval{} | voxel: {}".format(epoch, val_i, voxel.shape), flush=True)
704 | self.eval_step(voxel, image, captions, subj_id)
705 |
706 |
--------------------------------------------------------------------------------