├── README.md
├── assets
    ├── framework.jpg
    ├── framework.pdf
    └── framework.png
├── clients
    └── client.py
├── main.py
├── requirements.txt
├── scripts
    └── train.sh
├── servers
    └── server.py
├── trainmodel
    └── models.py
└── utils
    ├── DFL.py
    ├── data.py
    ├── data_utils.py
    ├── loss_avg.py
    ├── nsd_access.py
    └── utils.py


/README.md:
--------------------------------------------------------------------------------
# BrainGuard: Privacy-Preserving Multisubject Image Reconstructions from Brain Activities [arXiv](https://arxiv.org/abs/2501.14309) [Project](https://zhibotian.github.io/BrainGuard/)

[Zhibo Tian](https://scholar.google.com/citations?user=HbKGBGgAAAAJ&hl=en), [Ruijie Quan](https://scholar.google.com/citations?user=WKLRPsAAAAAJ&hl=en), [Fan Ma](https://scholar.google.com/citations?user=FyglsaAAAAAJ&hl=en), [Kun Zhan](https://scholar.google.com/citations?user=sk7TcGAAAAAJ&hl=en), [Yi Yang](https://scholar.google.com/citations?hl=en&user=RMSuNFwAAAAJ)

## Overview
![framework](assets/framework.jpg)

Reconstructing perceived images from human brain activity forms a crucial link between human and machine learning through Brain-Computer Interfaces. Early methods primarily trained a separate model for each individual to account for individual variability in brain activity, overlooking valuable cross-subject commonalities. Recent work has explored multisubject methods, but these approaches face significant challenges, particularly in preserving data privacy and managing individual variability. To overcome these challenges, we introduce BrainGuard, a privacy-preserving collaborative training framework designed to enhance image reconstruction from multisubject fMRI data while safeguarding individual privacy.
BrainGuard employs a collaborative global-local architecture in which an individual model is trained on each subject's local data and operates in conjunction with a shared global model that captures and leverages cross-subject patterns. This architecture eliminates the need to aggregate fMRI data across subjects, thereby preserving privacy. To tackle the complexity of fMRI data, BrainGuard integrates a hybrid synchronization strategy that lets individual models dynamically incorporate parameters from the global model. By establishing a secure and collaborative training environment, BrainGuard not only protects sensitive brain data but also improves image reconstruction accuracy. Extensive experiments demonstrate that BrainGuard sets a new benchmark on both high-level and low-level metrics, advancing the state of the art in brain decoding.

## Installation

1. Agree to the Natural Scenes Dataset's [Terms and Conditions](https://cvnlab.slite.page/p/IB6BSeW_7o/Terms-and-Conditions) and fill out the [NSD Data Access form](https://forms.gle/xue2bCdM9LaFNMeb7).

2. Clone this repository: `git clone https://github.com/kunzhan/BrainGuard.git`

3. Create a conda environment and install the packages needed to run the code:

   ```bash
   conda create -n brainguard python=3.10.8 -y
   conda activate brainguard
   pip install -r requirements.txt
   ```

## Preparation

Download the essential files we use from the [NSD dataset](https://natural-scenes-dataset.s3.amazonaws.com/index.html): `nsd_stim_info_merged.csv`, `captions_train2017.json`, and `captions_val2017.json`.
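Once everything is in place, a quick check like the one below can confirm that these three annotation files sit where the data layout below expects them. This is a hypothetical helper, not part of the repository; the relative paths are copied from that layout, and the data root is whatever directory you pass as `--data_root`.

```python
from pathlib import Path

# Relative paths copied from the data layout in this README (hypothetical helper, not repo code).
REQUIRED_FILES = [
    "nsddata/experiments/nsd/nsd_stim_info_merged.csv",
    "nsddata_stimuli/stimuli/nsd/annotations/captions_train2017.json",
    "nsddata_stimuli/stimuli/nsd/annotations/captions_val2017.json",
]


def missing_nsd_files(data_root):
    """Return the REQUIRED_FILES entries that do not exist as regular files under data_root."""
    root = Path(data_root)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
```

For example, `missing_nsd_files("data/natural-scenes-dataset")` returns an empty list when all three files are present.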
We use the same preprocessed data as [MindEye](https://github.com/MedARC-AI/fMRI-reconstruction-NSD), which can be downloaded from [Hugging Face](https://huggingface.co/datasets/pscotti/naturalscenesdataset/tree/main/webdataset_avg_split). Extract all files from the tar archives, then organize the data as follows:

```
data/natural-scenes-dataset
├── nsddata
│   └── experiments
│       └── nsd
│           └── nsd_stim_info_merged.csv
├── nsddata_stimuli
│   └── stimuli
│       └── nsd
│           └── annotations
│               ├── captions_train2017.json
│               └── captions_val2017.json
└── webdataset_avg_split
    ├── test
    │   ├── subj01
    │   │   ├── sample000000349.coco73k.npy
    │   │   ├── sample000000349.jpg
    │   │   ├── sample000000349.nsdgeneral.npy
    │   │   └── ...
    │   └── ...
    ├── train
    │   ├── subj01
    │   │   ├── sample000000300.coco73k.npy
    │   │   ├── sample000000300.jpg
    │   │   ├── sample000000300.nsdgeneral.npy
    │   │   └── ...
    │   └── ...
    └── val
        ├── subj01
        │   ├── sample000000000.coco73k.npy
        │   ├── sample000000000.jpg
        │   ├── sample000000000.nsdgeneral.npy
        │   └── ...
        └── ...
```
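After extraction, each subject folder should hold three files per sample. The sketch below (hypothetical helpers, not part of the repository) groups file names by their `sampleXXXXXXXXX` prefix and keeps only samples that have all three files, which can help spot an incomplete extraction. It assumes nothing beyond the naming pattern shown in the layout above.

```python
from collections import defaultdict
from pathlib import Path

# Per-sample file suffixes, as in the layout above (hypothetical helper, not repo code).
SUFFIXES = (".coco73k.npy", ".jpg", ".nsdgeneral.npy")


def sample_key(filename):
    """Split 'sample000000349.jpg' into ('sample000000349', '.jpg'); None if no suffix matches."""
    for suffix in SUFFIXES:
        if filename.endswith(suffix):
            return filename[: -len(suffix)], suffix
    return None


def complete_samples(filenames):
    """Return the sorted sample keys for which all three files are present."""
    found = defaultdict(set)
    for name in filenames:
        parsed = sample_key(name)
        if parsed is not None:
            key, suffix = parsed
            found[key].add(suffix)
    return sorted(key for key, suffixes in found.items() if len(suffixes) == len(SUFFIXES))


def index_subject_dir(subject_dir):
    """Apply complete_samples to the files in one subject folder, e.g. .../train/subj01."""
    return complete_samples(p.name for p in Path(subject_dir).iterdir() if p.is_file())
```

For instance, `len(index_subject_dir("data/natural-scenes-dataset/webdataset_avg_split/train/subj01"))` counts the complete training samples for subject 1.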

### Checkpoints
You can download our pretrained BrainGuard checkpoints for subjects 01, 02, 05, and 07 from [Hugging Face](https://huggingface.co/Zhibo2333/Brainguard/tree/main). Place the folders containing the checkpoints under `./train_logs/`.

## Training

Set `data_root` in `scripts/train.sh` to your data directory, then run:

```bash
bash scripts/train.sh
```

## Citation
```
@InProceedings{tian2025brainguard,
  author    = {Zhibo Tian and Ruijie Quan and Fan Ma and Kun Zhan and Yi Yang},
  booktitle = {AAAI},
  title     = {{BrainGuard}: Privacy-preserving multisubject image reconstructions from brain activities},
  year      = {2025},
  volume    = {39},
}
```

## Acknowledgement
We thank [MindBridge](https://github.com/littlepure2333/MindBridge), [MindEye](https://github.com/MedARC-AI/fMRI-reconstruction-NSD), and [nsd_access](https://github.com/tknapen/nsd_access) for generously sharing their codebases, on which ours is built. We are indebted to the [NSD dataset](https://natural-scenes-dataset.s3.amazonaws.com/index.html) for providing high-quality, publicly available data.

## Contact
https://kunzhan.github.io/

If you have any questions, feel free to contact me.
(Email: `ice.echo#gmail.com`)


--------------------------------------------------------------------------------
/assets/framework.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/BrainGuard/6c767e6653ffca0c2f350c166952a9055de712fd/assets/framework.jpg

--------------------------------------------------------------------------------
/assets/framework.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/BrainGuard/6c767e6653ffca0c2f350c166952a9055de712fd/assets/framework.pdf

--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kunzhan/BrainGuard/6c767e6653ffca0c2f350c166952a9055de712fd/assets/framework.png

--------------------------------------------------------------------------------
/clients/client.py:
--------------------------------------------------------------------------------
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utils.data_utils import read_client_data
from utils.DFL import DFL
from utils.loss_avg import AverageMeter
from utils.utils import soft_clip_loss, Clipper
import math
from trainmodel.models import *
import utils.data as data
from utils.utils import prepare_coco

class client(object):
    def __init__(self, args, id, train_samples, cuda_id, test_samples=None):
        # test_samples is accepted because servers/server.py passes it as a keyword argument

        self.cuda_id = cuda_id
        self.test_samples = test_samples
        self.model = copy.deepcopy(args.model)
        # subject-specific ridge layer: its input size is this subject's voxel count
        self.model.ridge = RidgeRegression(input_size=args.multi_voxel_dims[id], out_features=2048)
        self.model = self.model.to('cuda:{}'.format(self.cuda_id))
        self.model_ema = copy.deepcopy(self.model)
        for param in self.model_ema.parameters():  # freeze the EMA model parameters
            param.detach_()
        self.dataset = args.dataset
        self.device = 'cuda:{}'.format(self.cuda_id)
        self.id = id
        self.args = args
        self.num_classes = getattr(args, 'num_classes', None)  # not defined by every parser version
        self.train_samples = train_samples
        self.batch_size = args.batch_size
        self.global_rounds = args.global_rounds
        self.learning_rate = args.local_learning_rate
        self.local_steps = args.local_steps
        self.set_opt_grouped_parameters(args)
        self.loss_mse = nn.MSELoss(reduction='mean').to(f'cuda:{self.cuda_id}')
        self.optimizer = torch.optim.AdamW(self.opt_grouped_parameters, betas=(0.9, 0.9999), lr=self.learning_rate, eps=1e-8)
        self.train_type = args.train_type
        self.prompts_list = prepare_coco(args.data_root)
        self.clip_extractor = self.prepare_CLIP(args, self.device)

        self.prepare_dataloader()
        self.set_lr_scheduler(args)

        self.eta = args.eta
        self.layer_idx = args.layer_idx
        self.DFL = DFL(self.id, self.cuda_id, soft_clip_loss, self.train_dl, self.layer_idx, self.eta, self.device)

        self.global_best_val_sim_image = 0.
        self.global_model_best_val_sim_image = 0.
        self.all_steps = 0
        self.best_val_bwd = 0.
        self.flag_dfl = True  # read by local_initialization() (was set under the name flag_ala)
        self.before_aggregate_bwd = 0.

        self.total_loss = AverageMeter()
        self.mse_image = AverageMeter()
        self.mse_text = AverageMeter()
        self.nce_image = AverageMeter()
        self.nce_text = AverageMeter()

    def train(self, writer, round, logger):

        self.model.to(f'cuda:{self.cuda_id}')
        self.model_ema.to(f'cuda:{self.cuda_id}')
        self.model.train()

        for step in range(self.local_steps):

            logger.info("Start train Client {}, global_round: {}/{} Local step:{}/{}".format(self.id, round+1, self.global_rounds, step+1, self.local_steps))

            for train_i, data_i in enumerate(self.train_dl):
                self.train_i = train_i
                repeat_index = train_i % 3  # cycle through the three repeated trials (deterministic, not random)
                voxel, image, coco = data_i
                voxel = voxel[:, repeat_index, ...].float()

                coco_ids = coco.squeeze().tolist()
                current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
                captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

                if self.args.use_image_aug:
                    image = data.img_augment(image)

                clip_image = self.clip_extractor.embed_image(image).float()
                clip_text = self.clip_extractor.embed_text(captions).float()

                voxel = voxel.to(f'cuda:{self.cuda_id}')
                clip_image = clip_image.to(f'cuda:{self.cuda_id}')
                clip_text = clip_text.to(f'cuda:{self.cuda_id}')

                # subject-specific ridge layer -> shared backbone
                ridge_out = self.model.ridge(voxel)
                results = self.model.backbone(ridge_out)

                # image branch: MSE + SoftCLIP on L2-normalized, flattened embeddings
                clip_image_pred = results[0]
                clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1)
                clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1)

                loss_mse_image = self.loss_mse(clip_image_pred_norm, clip_image_norm) * 10000

                loss_clip_image = soft_clip_loss(
                    clip_image_pred_norm,
                    clip_image_norm,
                )

                # text branch: same two losses
                clip_text_pred = results[1]
                clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1)
                clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1)

                loss_mse_text = self.loss_mse(clip_text_pred_norm, clip_text_norm) * 10000

                loss_clip_text = soft_clip_loss(
                    clip_text_pred_norm,
                    clip_text_norm,
                )

                loss = loss_mse_image * 2 + loss_clip_image + loss_clip_text + loss_mse_text * 2
                # the EMA copy is updated from the pre-step weights; the server aggregates model_ema
                self.update_local_ema(self.model, self.model_ema, self.args.ema_decay)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.lr_scheduler.step()

                current_lr = self.lr_scheduler.get_last_lr()[0]
                self.total_loss.update(loss.item())
                self.mse_image.update(loss_mse_image.item())
                self.mse_text.update(loss_mse_text.item())
                self.nce_image.update(loss_clip_image.item())
                self.nce_text.update(loss_clip_text.item())

                global_step = self.all_steps * len(self.train_dl) + train_i
                writer.add_scalar(f'Loss/loss_All_train_client_{self.id}', self.total_loss.avg, global_step)
                writer.add_scalar(f'Loss/loss_Mse_image_train_client_{self.id}', self.mse_image.avg, global_step)
                writer.add_scalar(f'Loss/loss_Mse_text_train_client_{self.id}', self.mse_text.avg, global_step)
                writer.add_scalar(f'Loss/loss_SoftCliptrain_image_client_{self.id}', self.nce_image.avg, global_step)
                writer.add_scalar(f'Loss/loss_SoftCliptrain_text_client_{self.id}', self.nce_text.avg, global_step)
                writer.add_scalar(f'Learning rate/train_client_{self.id}', current_lr, global_step)

                # log roughly eight times per epoch (max() guards against loaders shorter than 8 batches)
                if train_i % max(1, len(self.train_dl) // 8) == 0:
                    logger.info(f"client{self.id}: Learning rate: {current_lr:.2e}, Loss softclip image:{self.nce_image.avg:.4f}, Loss softclip text:{self.nce_text.avg:.4f}")

            self.all_steps += 1
        self.model.eval()

        with torch.no_grad():
            val_sims_base_image = AverageMeter()
            val_sims_base_text = AverageMeter()
            val_loss_base_mse_image = AverageMeter()
            val_loss_base_mse_text = AverageMeter()
            val_loss_base_nce_image = AverageMeter()
            val_loss_base_nce_text = AverageMeter()

            for val_i, data_i in enumerate(self.val_dl):
                self.val_i = val_i
                repeat_index = val_i % 3
                voxel, image, coco = data_i
                voxel = torch.mean(voxel, dim=1)  # validate on the average of the repeated trials
                voxel = voxel.to(f'cuda:{self.cuda_id}').float()

                coco_ids = coco.squeeze().tolist()
                current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
                captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

                clip_image = self.clip_extractor.embed_image(image).float()
                clip_text = self.clip_extractor.embed_text(captions).float()

                clip_image = clip_image.to(f'cuda:{self.cuda_id}')
                clip_text = clip_text.to(f'cuda:{self.cuda_id}')

                ridge_out = self.model.ridge(voxel)
                results = self.model.backbone(ridge_out)

                clip_image_pred = results[0]
                clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1)
                clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1)

                val_loss_mse_image = self.loss_mse(clip_image_pred_norm, clip_image_norm) * 10000

                loss_clip_image = soft_clip_loss(
                    clip_image_pred_norm,
                    clip_image_norm,
                )

                val_sims_image = nn.functional.cosine_similarity(clip_image_norm, clip_image_pred_norm).mean().item()

                clip_text_pred = results[1]
                clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1)
                clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1)

                val_loss_mse_text = self.loss_mse(clip_text_pred_norm, clip_text_norm) * 10000

                loss_clip_text = soft_clip_loss(
                    clip_text_pred_norm,
                    clip_text_norm,
                )

                val_sims_text = nn.functional.cosine_similarity(clip_text_norm, clip_text_pred_norm).mean().item()

                val_loss_base_nce_image.update(loss_clip_image.item())
                val_loss_base_nce_text.update(loss_clip_text.item())
                val_loss_base_mse_image.update(val_loss_mse_image.item())
                val_loss_base_mse_text.update(val_loss_mse_text.item())
                val_sims_base_image.update(val_sims_image)
                val_sims_base_text.update(val_sims_text)

        writer.add_scalar(f'Val/sim_image_{self.id}', val_sims_base_image.avg, self.all_steps)
        writer.add_scalar(f'Val/sim_text_{self.id}', val_sims_base_text.avg, self.all_steps)
        writer.add_scalar(f'Val/loss_mse_image{self.id}', val_loss_base_mse_image.avg, self.all_steps)
        writer.add_scalar(f'Val/loss_mse_text{self.id}', val_loss_base_mse_text.avg, self.all_steps)
        writer.add_scalar(f'Val/loss_SoftClip_image{self.id}', val_loss_base_nce_image.avg, self.all_steps)
        writer.add_scalar(f'Val/loss_SoftClip_text{self.id}', val_loss_base_nce_text.avg, self.all_steps)

        logger.info(f'client{self.id} Mean sim image: {val_sims_base_image.avg}, Mean sim text: {val_sims_base_text.avg}')
        # keep the local checkpoint with the best validation image similarity
        if val_sims_base_image.avg > self.global_best_val_sim_image:
            self.global_best_val_sim_image = val_sims_base_image.avg
            torch.save(self.model.state_dict(), './logs/model/client{}_best.pth'.format(self.id))

        logger.info("Train Client {} done".format(self.id))


    def eval_global_model(self, global_model, writer, round, logger):
        global_sims_base_image = AverageMeter()
        global_sims_base_text = AverageMeter()
        global_loss_base_nce_image = AverageMeter()
        global_loss_base_nce_text = AverageMeter()
        global_loss_base_mse_image = AverageMeter()
        global_loss_base_mse_text = AverageMeter()

        # pair this client's private ridge layer with the aggregated global backbone
        self.model.ridge.to(f'cuda:{self.cuda_id}')
        global_model.backbone.to(f'cuda:{self.cuda_id}')

        self.model.eval()
        global_model.eval()

        with torch.no_grad():
            for val_i, data_i in enumerate(self.val_dl):
                self.val_i = val_i
                repeat_index = val_i % 3
                voxel, image, coco = data_i
                voxel = torch.mean(voxel, dim=1)

                coco_ids = coco.squeeze().tolist()
                current_prompts_list = [self.prompts_list[coco_id] for coco_id in coco_ids]
                captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list]

                clip_image = self.clip_extractor.embed_image(image).float()
                clip_text = self.clip_extractor.embed_text(captions).float()

                voxel = voxel.to(f'cuda:{self.cuda_id}').float()
                clip_text = clip_text.to(f'cuda:{self.cuda_id}').float()
                clip_image = clip_image.to(f'cuda:{self.cuda_id}')

                ridge_out = self.model.ridge(voxel)
                results = global_model.backbone(ridge_out)

                clip_image_pred = results[0]
                clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1)
                clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1)

                global_sims_image = nn.functional.cosine_similarity(clip_image_norm, clip_image_pred_norm).mean().item()
                loss_clip_image = soft_clip_loss(
                    clip_image_pred_norm,
                    clip_image_norm,
                )
                global_loss_mse_image = self.loss_mse(clip_image_pred_norm, clip_image_norm) * 10000

                clip_text_pred = results[1]
                clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1)
                clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1)
                global_sims_text = nn.functional.cosine_similarity(clip_text_pred_norm, clip_text_norm).mean().item()

                global_loss_mse_text = self.loss_mse(clip_text_pred_norm, clip_text_norm) * 10000

                loss_clip_text = soft_clip_loss(
                    clip_text_pred_norm,
                    clip_text_norm,
                )

                global_sims_base_image.update(global_sims_image)
                global_sims_base_text.update(global_sims_text)
                global_loss_base_nce_image.update(loss_clip_image.item())
                global_loss_base_nce_text.update(loss_clip_text.item())
                global_loss_base_mse_image.update(global_loss_mse_image.item())
                global_loss_base_mse_text.update(global_loss_mse_text.item())

        writer.add_scalar(f'Global_val/sim_image{self.id}', global_sims_base_image.avg, self.all_steps)
        writer.add_scalar(f'Global_val/sim_text{self.id}', global_sims_base_text.avg, self.all_steps)
        writer.add_scalar(f'Global_val/loss_mse_image{self.id}', global_loss_base_mse_image.avg, self.all_steps)
        writer.add_scalar(f'Global_val/loss_mse_text{self.id}', global_loss_base_mse_text.avg, self.all_steps)
        writer.add_scalar(f'Global_val/loss_nce_image{self.id}', global_loss_base_nce_image.avg, self.all_steps)
        writer.add_scalar(f'Global_val/loss_nce_text{self.id}', global_loss_base_nce_text.avg, self.all_steps)

        logger.info(f'Global model on client{self.id} data:\n sim_image:{global_sims_base_image.avg:.4f} sim_text:{global_sims_base_text.avg:.4f}')
        self.model.ridge.to('cpu')
        global_model.backbone.to('cpu')


    def local_initialization(self, received_global_model, writer, round):
        # merge the received global model into the local model on the CPU via DFL
        self.model.to('cpu')
        temp_global_model = copy.deepcopy(received_global_model)
        temp_global_model.to('cpu')
        if self.flag_dfl:
            self.DFL.adaptive_local_aggregation(temp_global_model, self.model, writer, round, self.clip_extractor, self.prompts_list)


    def load_train_data(self, batch_size=None, is_train=True):

        if batch_size is None:
            batch_size = self.batch_size
        train_data, test_data = read_client_data(self.id, self.args.train_type)
        if is_train:
            return DataLoader(train_data, batch_size, drop_last=False, shuffle=True, num_workers=4)
        else:
            # the client never sets self.resume, so read the flag from args
            if getattr(self.args, 'resume', False):
                return DataLoader(test_data, batch_size=1, drop_last=False, shuffle=False, num_workers=4)
            else:
                return DataLoader(test_data, batch_size=self.batch_size, drop_last=False, shuffle=False, num_workers=4)

    def set_opt_grouped_parameters(self, args):
        # no weight decay for biases and LayerNorm parameters
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        self.opt_grouped_parameters = [
            {'params': [p for n, p in self.model.ridge.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
            {'params': [p for n, p in self.model.ridge.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
            {'params': [p for n, p in self.model.backbone.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 1e-2},
            {'params': [p for n, p in self.model.backbone.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0},
        ]


    def set_lr_scheduler(self, args):
        # one scheduler step per batch over all rounds; 8859 = training samples per subject
        total_steps = (args.global_rounds * self.local_steps) * math.ceil(8859 / args.batch_size)
        self.lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.optimizer,
            max_lr=self.learning_rate,
            total_steps=total_steps,
            final_div_factor=1000,
            last_epoch=-1,
            pct_start=2 / (args.global_rounds * self.local_steps),  # warm up over the first two local epochs
        )

    def prepare_CLIP(self, args, device):
        # Prepare CLIP
        clip_sizes = {"RN50": 1024, "ViT-L/14": 768, "ViT-B/32": 512, "ViT-H-14": 1024}
        clip_size = clip_sizes[args.clip_variant]

        print("Using hidden layer CLIP space (Versatile Diffusion)")
        if not args.norm_embs:
            print("WARNING: YOU WANT NORMED EMBEDDINGS FOR VERSATILE DIFFUSION!")
        # embeddings are always L2-normalized here, regardless of --norm_embs
        clip_extractor = Clipper(args.clip_variant, device=device, hidden_state=True, norm_embs=True)

        out_dim_image = 257 * clip_size  # 257*768 = 197376
        out_dim_text = 77 * clip_size  # 77*768 = 59136

        print("clip_extractor loaded.")
        print("out_dim_image:", out_dim_image)
        print("out_dim_text:", out_dim_text)

        return clip_extractor

    def prepare_dataloader(self):
        # Prepare data and dataloader
        print("Preparing data and dataloader...")
        self.train_dl, self.val_dl = data.get_dls(
            subject=self.id,
            data_path=self.args.data_root,
            batch_size=self.args.batch_size,
            val_batch_size=self.args.batch_size,
            num_workers=self.args.num_workers,
            pool_type='max',
            pool_num=8192,
            length=8859,
            seed=42,
        )
        self.num_batches = len(self.train_dl)

    def update_local_ema(self, local_model, ema_model, alpha):
        # Note: alpha (args.ema_decay, 0.999 by default) weights the *current* parameters,
        # so the EMA copy closely tracks the live model.
        for param, ema_param in zip(local_model.parameters(), ema_model.parameters()):
            ema_param.data = alpha * param.data + (1 - alpha) * ema_param.data


--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
import json
import torch
import argparse
import time
import numpy as np
import pprint
from utils.data_utils import count_params
from servers.server import Server
from trainmodel.models import *
import logging
from utils.utils import init_log

def run(args):

    # per-subject voxel counts (input size of each subject's ridge layer)
    args.multi_voxel_dims = {1: 15724, 2: 14278, 5: 13039, 7: 12682}
    args.client_name = ['subj1', 'subj2', 'subj5', 'subj7']

    for i in range(args.prev, args.times):
        args.logger.info("Creating server and clients ...")

        clip_emb_dim = 768
        hidden_dim = 2048

        model = BrainGuardModule()
        model.ridge = RidgeRegression(input_size=15724, out_features=hidden_dim)
        model.backbone = BrainNetwork(in_dim=hidden_dim, latent_size=clip_emb_dim, out_dim_image=257*768, out_dim_text=77*768, use_projector=True, train_type=args.train_type)

        args.logger.info('Total params: {:.1f}M\n'.format(count_params(model)))
        args.model = model.to('cpu' if args.cuda_id["server"] == -1 else
                              f'cuda:{args.cuda_id["server"]}')
        args.logger.info(args.model)

        server = Server(args)

        if args.resume:
            args.logger.info("Starting test directly!")
            # NOTE: Server in servers/server.py does not define a test() method
            server.test(resume=True)
            args.logger.info("Resume test done!")
            break

        args.logger.info(f'==========Start train {args.train_type} model==========')
        args.logger.info(f"============= Running time: {i}th =============")

        server.train(args)

    args.logger.info("All done!")



if __name__ == "__main__":
    total_start = time.time()

    parser = argparse.ArgumentParser()
    # general
    parser.add_argument('-dev', "--device", type=str, default="cuda", choices=["cpu", "cuda"])
    parser.add_argument('-data', "--dataset", type=str, default="NSD")
    parser.add_argument('-lbs', "--batch_size", type=int, default=50)
    parser.add_argument('-lr', "--local_learning_rate", type=float, default=3e-4, help="Local learning rate")
    parser.add_argument("--lr_scheduler_type", type=str, default='cycle', choices=['cycle', 'linear'])
    parser.add_argument('-gr', "--global_rounds", type=int, default=600)
    parser.add_argument('-ls', "--local_steps", type=int, default=1)
    parser.add_argument('-jr', "--join_ratio", type=float, default=1.0, help="Ratio of clients per round")
    parser.add_argument('-c', "--clients", type=int, nargs='+', default=[1, 2, 5, 7], help="Train clients")
    parser.add_argument('-t', "--times", type=int, default=1, help="Running times")
    parser.add_argument('-et', "--eta", type=float, default=1.0)
    parser.add_argument('-p', "--layer_idx", type=int, default=24)
    parser.add_argument('-ed', "--ema_decay", type=float, default=0.999, help="EMA decay rate")
    parser.add_argument("--cuda_id", type=json.loads, default='{"server":-1, "1": 0, "2": 1, "5": 2, "7": 3}', help="JSON map from 'server' or client id to GPU index; server -1 means CPU")
    parser.add_argument('--data_root', type=str, default='datapath')
    parser.add_argument("--clip_variant", type=str, default="ViT-L/14", choices=["RN50", "ViT-L/14", "ViT-B/32", "RN50x64"], help='OpenAI clip variant')
    parser.add_argument("--num_workers", type=int, default=5, help="Number of workers in dataloader")
    parser.add_argument("--norm_embs", action=argparse.BooleanOptionalAction, default=True, help="Do l2-norming of CLIP embeddings")
    parser.add_argument("--use_image_aug", action=argparse.BooleanOptionalAction, default=True, help="Whether to use image augmentation")
    # The options below are read by main.py, servers/server.py and clients/client.py but were
    # missing from this parser (scripts/train.sh already passes -tp); the defaults are assumptions.
    parser.add_argument('-tp', "--train_type", type=str, default="vision")
    parser.add_argument("--prev", type=int, default=0, help="Index of the first run")
    parser.add_argument("--resume", action=argparse.BooleanOptionalAction, default=False, help="Skip training and run the test path")
    parser.add_argument("--num_classes", type=int, default=0, help="Unused placeholder read by the client")
    parser.add_argument("--random_join_ratio", action=argparse.BooleanOptionalAction, default=False, help="Sample a random number of clients")
    parser.add_argument("--eval_gap", type=int, default=1)
    parser.add_argument("--eval_interval", type=int, default=1, help="Evaluate the global model every N rounds")
    args = parser.parse_args()

    logger = init_log('global', logging.INFO)
    logger.propagate = False
    args.logger = logger
    if args.device == "cuda" and not torch.cuda.is_available():
        print("\ncuda is not available.\n")
        args.device = "cpu"

    all_args = {**vars(args)}
    print('{}\n'.format(pprint.pformat(all_args)))
    run(args)


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
torch==2.0.1
torchvision==0.15.2
diffusers==0.13.0
kornia
tqdm
pandas
scipy
accelerate
deepspeed
torchsnooper
matplotlib
pycocotools
h5py
nibabel
urllib3
numpy
wandb
pillow
scikit-image
clip
clip-retrieval
transformers
gpustat


--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
now=$(date +"%Y%m%d_%H%M%S")

train_type=vision
save_path=./logs/$train_type/
data_root=/data/natural-scenes-dataset # data path

mkdir -p "$save_path"  # tee cannot create the log file if this directory is missing

python -u main.py \
    -ls 1 \
    -gr 600 \
    --cuda_id '{"server":-1, "1": 0, "2": 1, "5": 2, "7": 3}' \
    -tp $train_type \
    -p 24 \
    -lbs 50 \
    --data_root $data_root \
    2>&1 | tee $save_path/$now.log
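For reference, the learning-rate schedule that `set_lr_scheduler` in `clients/client.py` builds from these flags can be worked out by hand. The sketch below assumes the training loader yields `ceil(8859 / batch_size)` batches per epoch (that is, it keeps the last partial batch; `utils/data.py` is not shown here) and uses the relations documented for PyTorch's `OneCycleLR`: `initial_lr = max_lr / div_factor` (default `div_factor=25`) and `min_lr = initial_lr / final_div_factor`. `one_cycle_plan` is a hypothetical name.

```python
import math

# Flag values from scripts/train.sh (-gr 600, -ls 1, -lbs 50) and the -lr default in main.py.
GLOBAL_ROUNDS = 600
LOCAL_STEPS = 1
BATCH_SIZE = 50
TRAIN_SAMPLES = 8859       # hard-coded in clients/client.py
MAX_LR = 3e-4
DIV_FACTOR = 25.0          # OneCycleLR default (not overridden in set_lr_scheduler)
FINAL_DIV_FACTOR = 1000.0  # passed explicitly in set_lr_scheduler


def one_cycle_plan(rounds, local_steps, batch_size, n_samples, max_lr):
    """Reproduce the step counts and LR endpoints implied by set_lr_scheduler."""
    batches_per_epoch = math.ceil(n_samples / batch_size)
    total_steps = rounds * local_steps * batches_per_epoch
    pct_start = 2 / (rounds * local_steps)
    warmup_steps = round(pct_start * total_steps)  # equals 2 * batches_per_epoch
    initial_lr = max_lr / DIV_FACTOR
    min_lr = initial_lr / FINAL_DIV_FACTOR
    return {
        "batches_per_epoch": batches_per_epoch,
        "total_steps": total_steps,
        "warmup_steps": warmup_steps,
        "initial_lr": initial_lr,
        "min_lr": min_lr,
    }


plan = one_cycle_plan(GLOBAL_ROUNDS, LOCAL_STEPS, BATCH_SIZE, TRAIN_SAMPLES, MAX_LR)
for name, value in plan.items():
    print(f"{name}: {value}")
```

With these settings there are 178 batches per epoch and 106,800 scheduler steps in total; the warm-up covers about 356 steps (two local epochs), rising from 1.2e-5 to the 3e-4 peak before annealing toward roughly 1.2e-8.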

--------------------------------------------------------------------------------
/servers/server.py:
--------------------------------------------------------------------------------
import copy
import numpy as np
import torch
import time
from clients.client import *
from threading import Thread
from torch.utils.tensorboard import SummaryWriter
import os
from utils.utils import Clipper, prepare_coco
from utils.nsd_access import NSDAccess


class Server(object):
    def __init__(self, args):
        self.device = args.device
        self.dataset = args.dataset
        self.global_rounds = args.global_rounds
        self.global_model = copy.deepcopy(args.model)
        self.num_clients = [1, 2, 5, 7]  # subject ids, one client per subject
        self.join_ratio = args.join_ratio
        self.random_join_ratio = args.random_join_ratio
        self.join_clients = int(len(self.num_clients) * self.join_ratio)
        self.cuda_id = args.cuda_id
        self.clients = []
        self.selected_clients = []

        self.uploaded_weights = []
        self.uploaded_ids = []
        self.uploaded_models = []

        self.rs_test_loss = []
        self.rs_train_loss = []
        # self.clip_extractor = self.prepare_CLIP(args, 'cpu')
        # self.times = times
        self.eval_gap = args.eval_gap
        self.writer = None
        self.set_clients(args, client)
        if not args.resume:
            log_dir = './logs/{}/{}'.format(args.train_type, time.strftime("%b%d_%d-%H-%M", time.localtime()))
            self.writer = SummaryWriter(log_dir=log_dir)
            os.makedirs(log_dir, exist_ok=True)
        print(f"\nJoin ratio / total clients: {self.join_ratio} / {len(self.num_clients)}")
        print("Finished creating server and clients.")
        self.selected_clients = self.select_clients()
        # Server has no prepare_coco method; use the helper from utils.utils
        self.prompts_list = prepare_coco(args.data_root)
        self.Budget = []

    def train(self, args):
        for i in range(self.global_rounds):
            args.logger.info(f"============= Round: {i+1}th =============")
            s_t = time.time()
            # from the second round on, each selected client first merges the latest global model (DFL)
            if i != 0 and i < self.global_rounds:
                self.send_models(i)
            # if i % args.eval_interval == 0:
            #     for client in self.selected_clients:
            #         client.eval_local_model(self.writer, i, args.logger, self.prompts_list)

            # local training runs in parallel, one thread per client
            thread_train = [Thread(target=client.train, args=(self.writer, i, args.logger))
                            for client in self.selected_clients]
            for t in thread_train:
                t.start()
            for t in thread_train:
                t.join()

            # for client in self.selected_clients:
            #     client.train(self.writer, i, args.logger)
            if i < self.global_rounds:
                self.receive_models()
                self.aggregate_parameters()

            # eval global model after aggregate_parameters
            # threads_eval_global = [Thread(target=client.eval_global_model, args=(self.global_model, self.writer, i, args.logger))
            #                        for client in self.selected_clients]
            # [t.start() for t in threads_eval_global]
            # [t.join() for t in threads_eval_global]
            if i % args.eval_interval == 0:
                args.logger.info('======Start using clients data eval global model======')
                for client in self.selected_clients:
                    client.eval_global_model(self.global_model, self.writer, i, args.logger)

    def set_clients(self, args, clientObj):
        for client_id in self.num_clients:
            # train_data = read_client_data(client_id, is_train=True)
            # test_data = read_client_data(client_id, is_train=False)
            new_client = clientObj(args,
                                   id=client_id,
                                   train_samples=8859,
                                   test_samples=982, cuda_id=self.cuda_id[str(client_id)])
            self.clients.append(new_client)

    def select_clients(self):
        if self.random_join_ratio:
            # draw how many clients join: between join_clients and the total number of clients
            join_clients = np.random.choice(range(self.join_clients, len(self.num_clients) + 1), 1, replace=False)[0]
        else:
            join_clients = self.join_clients
        selected_clients = list(np.random.choice(self.clients, join_clients, replace=False))

        return selected_clients

    def send_models(self, round):
        assert (len(self.clients) > 0)

        # for client in self.clients:
102 | # client.local_initialization(self.global_model) 103 | threads = [Thread(target=client.local_initialization, args=(self.global_model, self.writer, round, )) 104 | for client in self.selected_clients] 105 | [t.start() for t in threads] 106 | [t.join() for t in threads] 107 | 108 | def receive_models(self): 109 | assert (len(self.selected_clients) > 0) 110 | 111 | active_train_samples = 0 112 | for client in self.selected_clients: 113 | active_train_samples += client.train_samples 114 | 115 | self.uploaded_weights = [] 116 | self.uploaded_ids = [] 117 | self.uploaded_models = [] 118 | for client in self.selected_clients: 119 | self.uploaded_weights.append(client.train_samples / active_train_samples) 120 | self.uploaded_ids.append(client.id) 121 | self.uploaded_models.append(client.model_ema) 122 | 123 | def add_parameters(self, w, client_model): 124 | client_model = client_model.to('cpu' if self.cuda_id["server"]==-1 else f'cuda:{self.cuda_id["server"]}') 125 | # for server_param, client_param in zip(self.global_model.parameters(), client_model.parameters()): 126 | # server_param.data += client_param.data.clone() * w 127 | for (server_param_name, server_param), (_, client_param) in zip(self.global_model.named_parameters(), client_model.named_parameters()): 128 | if 'ridge' not in server_param_name: 129 | server_param.data += client_param.data.clone() * w 130 | # # server_param.data += client_param.data.clone() 131 | # # server_param.data += client_param.data.clone() * w 132 | 133 | def aggregate_parameters(self): 134 | assert (len(self.uploaded_models) > 0) 135 | 136 | self.global_model = copy.deepcopy(self.uploaded_models[0]) 137 | self.global_model.to('cpu' if self.cuda_id["server"]==-1 else f'cuda:{self.cuda_id["server"]}') 138 | for param in self.global_model.parameters(): 139 | param.data = torch.zeros_like(param.data) 140 | 141 | for w, client_model in zip(self.uploaded_weights, self.uploaded_models): 142 | self.add_parameters(w, client_model) 143 | 144 | 
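In `receive_models` and `aggregate_parameters` above, each selected client is weighted by its share of the total training samples (`client.train_samples / active_train_samples`). The global model starts as a zeroed copy of the first uploaded model, and `add_parameters` then accumulates the weighted sum of every parameter whose name does not contain `'ridge'`. The subject-specific ridge layers therefore stay at zero in the global model.

Below is a minimal standalone sketch of that FedAvg-style rule. It assumes plain `nn.Module` clients whose non-ridge parameters line up by registration order. `aggregate` and `ToyClient` are hypothetical names and are not part of the repository.

```python
import copy

import torch
import torch.nn as nn


def aggregate(client_models, train_samples, skip_key="ridge"):
    """Sample-weighted average of client parameters (FedAvg-style sketch).

    Hypothetical helper mirroring receive_models/aggregate_parameters:
    parameters whose name contains `skip_key` are not averaged and stay zero.
    """
    total = sum(train_samples)
    weights = [n / total for n in train_samples]

    # Clone the first client's architecture and zero every parameter.
    global_model = copy.deepcopy(client_models[0])
    with torch.no_grad():
        for p in global_model.parameters():
            p.zero_()

        # Accumulate w_k * theta_k, pairing parameters by registration order.
        for w, model in zip(weights, client_models):
            for (name, g_param), (_, c_param) in zip(
                global_model.named_parameters(), model.named_parameters()
            ):
                if skip_key not in name:
                    g_param.add_(c_param * w)
    return global_model


class ToyClient(nn.Module):
    """Hypothetical client: subject-specific ridge plus a shared backbone."""

    def __init__(self, n_voxels, hidden=4, out=2):
        super().__init__()
        self.ridge = nn.Linear(n_voxels, hidden)  # input size differs per subject
        self.backbone = nn.Linear(hidden, out)    # same shape for every subject

    def forward(self, x):
        return self.backbone(self.ridge(x))


if __name__ == "__main__":
    a, b = ToyClient(10), ToyClient(7)
    with torch.no_grad():
        for p in a.backbone.parameters():
            p.fill_(1.0)
        for p in b.backbone.parameters():
            p.fill_(3.0)
    g = aggregate([a, b], [1, 3])
    print(g.backbone.weight)  # every entry is 0.25 * 1.0 + 0.75 * 3.0 = 2.5
```

The ridge layers are excluded because their input width equals each subject's voxel count. They cannot be averaged across subjects, and the skip lets clients with different voxel counts share one backbone.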
145 | def prepare_CLIP(self, args, device): 146 | # Prepare CLIP 147 | clip_sizes = {"RN50": 1024, "ViT-L/14": 768, "ViT-B/32": 512, "ViT-H-14": 1024} 148 | clip_size = clip_sizes[args.clip_variant] 149 | 150 | print("Using hidden layer CLIP space (Versatile Diffusion)") 151 | if not args.norm_embs: 152 | print("WARNING: YOU WANT NORMED EMBEDDINGS FOR VERSATILE DIFFUSION!") 153 | clip_extractor = Clipper(args.clip_variant, device=device, hidden_state=True, norm_embs=True) 154 | 155 | out_dim_image = 257 * clip_size # 257*768 = 197376 156 | out_dim_text = 77 * clip_size # 77*768 = 59136 157 | 158 | print("clip_extractor loaded.") 159 | print("out_dim_image:",out_dim_image) 160 | print("out_dim_text:", out_dim_text) 161 | 162 | return clip_extractor 163 | 164 | def prepare_coco(self, data_path): 165 | # Preload coco captions 166 | nsda = NSDAccess(data_path) 167 | coco_73k = list(range(0, 73000)) 168 | prompts_list = nsda.read_image_coco_info(coco_73k, info_type='captions') 169 | 170 | print("coco captions loaded.") 171 | 172 | return prompts_list 173 | 174 | -------------------------------------------------------------------------------- /trainmodel/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import math 5 | 6 | 7 | from diffusers.models.vae import Decoder 8 | class Voxel2StableDiffusionModel(torch.nn.Module): 9 | def __init__(self, in_dim=15724, h=4096, n_blocks=4, use_cont=False, ups_mode='4x'): 10 | super().__init__() 11 | self.lin0 = nn.Sequential( 12 | nn.Linear(in_dim, h, bias=False), 13 | nn.LayerNorm(h), 14 | nn.SiLU(inplace=True), 15 | nn.Dropout(0.5), 16 | ) 17 | 18 | self.mlp = nn.ModuleList([ 19 | nn.Sequential( 20 | nn.Linear(h, h, bias=False), 21 | nn.LayerNorm(h), 22 | nn.SiLU(inplace=True), 23 | nn.Dropout(0.25) 24 | ) for _ in range(n_blocks) 25 | ]) 26 | self.ups_mode = ups_mode 27 | if ups_mode=='4x': 28 | self.lin1 = 
nn.Linear(h, 16384, bias=False) 29 | self.norm = nn.GroupNorm(1, 64) 30 | 31 | self.upsampler = Decoder( 32 | in_channels=64, 33 | out_channels=4, 34 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"], 35 | block_out_channels=[64, 128, 256], 36 | layers_per_block=1, 37 | ) 38 | 39 | if use_cont: 40 | self.maps_projector = nn.Sequential( 41 | nn.Conv2d(64, 512, 1, bias=False), 42 | nn.GroupNorm(1,512), 43 | nn.ReLU(True), 44 | nn.Conv2d(512, 512, 1, bias=False), 45 | nn.GroupNorm(1,512), 46 | nn.ReLU(True), 47 | nn.Conv2d(512, 512, 1, bias=True), 48 | ) 49 | else: 50 | self.maps_projector = nn.Identity() 51 | 52 | if ups_mode=='8x': # prev best 53 | self.lin1 = nn.Linear(h, 16384, bias=False) 54 | self.norm = nn.GroupNorm(1, 256) 55 | 56 | self.upsampler = Decoder( 57 | in_channels=256, 58 | out_channels=4, 59 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D"], 60 | block_out_channels=[64, 128, 256, 256], 61 | layers_per_block=1, 62 | ) 63 | self.maps_projector = nn.Identity() 64 | 65 | if ups_mode=='16x': 66 | self.lin1 = nn.Linear(h, 8192, bias=False) 67 | self.norm = nn.GroupNorm(1, 512) 68 | 69 | self.upsampler = Decoder( 70 | in_channels=512, 71 | out_channels=4, 72 | up_block_types=["UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D","UpDecoderBlock2D", "UpDecoderBlock2D"], 73 | block_out_channels=[64, 128, 256, 256, 512], 74 | layers_per_block=1, 75 | ) 76 | self.maps_projector = nn.Identity() 77 | 78 | if use_cont: 79 | self.maps_projector = nn.Sequential( 80 | nn.Conv2d(64, 512, 1, bias=False), 81 | nn.GroupNorm(1,512), 82 | nn.ReLU(True), 83 | nn.Conv2d(512, 512, 1, bias=False), 84 | nn.GroupNorm(1,512), 85 | nn.ReLU(True), 86 | nn.Conv2d(512, 512, 1, bias=True), 87 | ) 88 | else: 89 | self.maps_projector = nn.Identity() 90 | 91 | # @torchsnooper.snoop() 92 | def forward(self, x, return_transformer_feats=False): 93 | x = self.lin0(x) 94 | residual = x 95 | for res_block in self.mlp: 
96 | x = res_block(x) 97 | x = x + residual 98 | residual = x 99 | x = x.reshape(len(x), -1) 100 | x = self.lin1(x) # bs, 16384 for '4x'/'8x', bs, 8192 for '16x' 101 | 102 | if self.ups_mode == '4x': 103 | side = 16 104 | if self.ups_mode == '8x': 105 | side = 8 106 | if self.ups_mode == '16x': 107 | side = 4 108 | 109 | # decoder 110 | x = self.norm(x.reshape(x.shape[0], -1, side, side).contiguous()) 111 | if return_transformer_feats: 112 | return self.upsampler(x), self.maps_projector(x).flatten(2).permute(0,2,1) 113 | return self.upsampler(x) 114 | 115 | 116 | class BrainGuardModule(nn.Module): 117 | def __init__(self): 118 | super(BrainGuardModule, self).__init__() 119 | def forward(self, x): 120 | return x 121 | 122 | class RidgeRegression(torch.nn.Module): 123 | # make sure to add weight_decay when initializing optimizer 124 | def __init__(self, input_size, out_features): 125 | super(RidgeRegression, self).__init__() 126 | self.linear = torch.nn.Linear(input_size, out_features) 127 | def forward(self, x): 128 | x = self.linear(x) 129 | return x 130 | 131 | class BrainNetwork(nn.Module): 132 | def __init__(self, out_dim_image=768, out_dim_text=768, in_dim=15724, latent_size=768, h=2048, n_blocks=4, norm_type='ln', use_projector=True, act_first=False, drop1=.5, drop2=.15, train_type='vision'): 133 | super().__init__() 134 | norm_func = partial(nn.BatchNorm1d, num_features=h) if norm_type == 'bn' else partial(nn.LayerNorm, normalized_shape=h) 135 | act_fn = partial(nn.ReLU, inplace=True) if norm_type == 'bn' else nn.GELU 136 | act_and_norm = (act_fn, norm_func) if act_first else (norm_func, act_fn) 137 | self.mlp = nn.ModuleList([ 138 | nn.Sequential( 139 | nn.Linear(h, h), 140 | *[item() for item in act_and_norm], 141 | nn.Dropout(drop2) 142 | ) for _ in range(n_blocks) 143 | ]) 144 | self.head_image = nn.Linear(h, out_dim_image, bias=True) 145 | self.head_text = nn.Linear(h, out_dim_text, bias=True) 146 | self.n_blocks = n_blocks 147 | self.latent_size = latent_size 148 | self.use_projector =
use_projector 149 | self.train_type = train_type 150 | if use_projector: 151 | self.projector_image = nn.Sequential( 152 | nn.LayerNorm(self.latent_size), 153 | nn.GELU(), 154 | nn.Linear(self.latent_size, 2048), 155 | nn.LayerNorm(2048), 156 | nn.GELU(), 157 | nn.Linear(2048, 2048), 158 | nn.LayerNorm(2048), 159 | nn.GELU(), 160 | nn.Linear(2048, self.latent_size) 161 | ) 162 | self.projector_text = nn.Sequential( 163 | nn.LayerNorm(self.latent_size), 164 | nn.GELU(), 165 | nn.Linear(self.latent_size, 2048), 166 | nn.LayerNorm(2048), 167 | nn.GELU(), 168 | nn.Linear(2048, 2048), 169 | nn.LayerNorm(2048), 170 | nn.GELU(), 171 | nn.Linear(2048, self.latent_size) 172 | ) 173 | 174 | def forward(self, x): 175 | residual = x 176 | for res_block in range(self.n_blocks): 177 | x = self.mlp[res_block](x) 178 | x += residual 179 | residual = x 180 | x = x.reshape(len(x), -1) 181 | x_image = self.head_image(x) 182 | x_text = self.head_text(x) 183 | if self.use_projector: 184 | return self.projector_image(x_image.reshape(len(x_image), -1, self.latent_size)), self.projector_text(x_text.reshape(len(x_text), -1, self.latent_size)) 185 | return x 186 | -------------------------------------------------------------------------------- /utils/DFL.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import copy 5 | from utils.loss_avg import AverageMeter 6 | 7 | 8 | class DFL: 9 | def __init__(self, 10 | cid: int, 11 | cuda_id: int, 12 | soft_clip_loss: nn.Module, 13 | logit_scale: float, 14 | train_dl, 15 | layer_idx: int = 0, 16 | eta: float = 1.0, 17 | device: str = 'cpu', 18 | threshold: float = 50, 19 | num_pre_loss: int = 10, 20 | ) -> None: 21 | 22 | self.cid = cid 23 | self.soft_clip_loss = soft_clip_loss 24 | self.logit_scale = logit_scale 25 | self.train_dl = train_dl 26 | self.layer_idx = layer_idx 27 | self.eta = eta 28 | self.threshold = threshold 29 | self.num_pre_loss 
= num_pre_loss 30 | self.device = device 31 | self.cuda_id = cuda_id 32 | self.weights = None # Learnable local aggregation weights. 33 | self.start_phase = True 34 | self.losses_value = AverageMeter() 35 | self.all_steps = 0 36 | self.loss_mse = nn.MSELoss(reduction='mean').to(f'cuda:{self.cuda_id}') 37 | 38 | def adaptive_local_aggregation(self, 39 | global_model: nn.Module, 40 | local_model: nn.Module, 41 | writer, 42 | round, 43 | clip_extractor, 44 | prompts_list 45 | ) -> None: 46 | 47 | # obtain references to the parameters 48 | params_g = list(global_model.parameters()) 49 | params = list(local_model.parameters()) 50 | # copy the lower layers from the global model; the first two parameter tensors stay local 51 | for param, param_g in zip(params[2:-self.layer_idx], params_g[2:-self.layer_idx]): 52 | param.data.copy_(param_g.data) 53 | 54 | # temporary local model used only for weight learning 55 | model_t = copy.deepcopy(local_model) 56 | model_t.to(f'cuda:{self.cuda_id}') 57 | params_t = list(model_t.parameters()) 58 | 59 | # only consider the higher layers 60 | params_p = params[-self.layer_idx:] 61 | params_gp = params_g[-self.layer_idx:] 62 | params_tp = params_t[-self.layer_idx:] 63 | 64 | # freeze the lower layers to reduce computational cost in PyTorch 65 | for param in params_t[:-self.layer_idx]: 66 | param.requires_grad = False 67 | 68 | # used only to obtain the gradients of the higher layers 69 | # optimizer.step() is never called, so lr=0 70 | optimizer = torch.optim.AdamW(params_tp, lr=0) 71 | 72 | # initialize all weights to one at the start 73 | if self.weights is None: 74 | self.weights = [torch.ones_like(param.data).to(f'cuda:{self.cuda_id}') for param in params_p] 75 | 76 | # initialize the higher layers in the temp local model 77 | for param_t, param, param_g, weight in zip(params_tp, params_p, params_gp, 78 | self.weights): 79 | param_gpu = param.to(f'cuda:{self.cuda_id}') 80 | param_g_gpu = param_g.to(f'cuda:{self.cuda_id}') 81 | param_t.data = param_gpu + (param_g_gpu - param_gpu) * weight 82 | 83 | # weight learning 84 | losses = [] #
record losses 85 | cnt = 0 # weight training iteration counter 86 | while True: 87 | for train_i, data_i in enumerate(self.train_dl): 88 | repeat_index = train_i % 3 # deterministic: cycles through caption indices 0, 1, 2 across batches 89 | voxel, image, coco = data_i 90 | voxel = torch.mean(voxel,axis=1) 91 | voxel = voxel.to(f'cuda:{self.cuda_id}').float() 92 | 93 | coco_ids = coco.squeeze().tolist() 94 | current_prompts_list = [prompts_list[coco_id] for coco_id in coco_ids] 95 | captions = [prompts[repeat_index]['caption'] for prompts in current_prompts_list] 96 | 97 | clip_image = clip_extractor.embed_image(image).float() 98 | clip_text = clip_extractor.embed_text(captions).float().to(f'cuda:{self.cuda_id}') 99 | clip_image = clip_image.to(f'cuda:{self.cuda_id}') 100 | ridge_out = model_t.ridge(voxel) 101 | results = model_t.backbone(ridge_out) 102 | 103 | optimizer.zero_grad() 104 | 105 | clip_image_pred = results[0] 106 | clip_image_pred_norm = nn.functional.normalize(clip_image_pred.flatten(1), dim=-1) 107 | clip_image_norm = nn.functional.normalize(clip_image.flatten(1), dim=-1) 108 | loss_mse_image = self.loss_mse(clip_image_pred_norm, clip_image_norm) * 10000 109 | loss_clip_image = self.soft_clip_loss( 110 | clip_image_pred_norm, 111 | clip_image_norm, 112 | ) 113 | 114 | clip_text_pred = results[1] 115 | clip_text_pred_norm = nn.functional.normalize(clip_text_pred.flatten(1), dim=-1) 116 | clip_text_norm = nn.functional.normalize(clip_text.flatten(1), dim=-1) 117 | loss_mse_text = self.loss_mse(clip_text_pred_norm, clip_text_norm) * 10000 118 | loss_clip_text = self.soft_clip_loss( 119 | clip_text_pred_norm, 120 | clip_text_norm, 121 | ) 122 | 123 | loss = loss_mse_image * 2 + loss_clip_image + loss_clip_text + loss_mse_text * 2 124 | self.losses_value.update(loss.item()) 125 | if train_i % max(1, len(self.train_dl) // 8) == 0: # max() avoids modulo by zero with fewer than 8 batches 126 | print(f"client{self.cid}: loss_DFL: {self.losses_value.avg:.4f}") 127 | loss.backward() 128 | writer.add_scalar(f'Loss_DFL/client_{self.cid}', self.losses_value.avg,
self.all_steps * len(self.train_dl) + train_i) 129 | # update weight in this batch 130 | for param_t, param, param_g, weight in zip(params_tp, params_p,params_gp, self.weights): 131 | 132 | param_gpu = param.to(f'cuda:{self.cuda_id}') 133 | param_g_gpu = param_g.to(f'cuda:{self.cuda_id}') 134 | weight_update = self.eta * ((param_t.grad * 1000) * (param_g_gpu - param_gpu)) 135 | weight.data = torch.clamp(weight.data - weight_update, 0, 1) 136 | 137 | # update temp local model in this batch 138 | for param_t, param, param_g, weight in zip(params_tp, params_p, 139 | params_gp, self.weights): 140 | param_gpu = param.to(f'cuda:{self.cuda_id}') 141 | param_g_gpu = param_g.to(f'cuda:{self.cuda_id}') 142 | param_t.data = param_gpu + (param_g_gpu - param_gpu) * weight 143 | self.all_steps += 1 144 | losses.append(loss.item()) 145 | cnt += 1 146 | 147 | # only train one epoch in the subsequent iterations 148 | if not self.start_phase: 149 | break 150 | 151 | # train the weight until convergence 152 | if len(losses) > self.num_pre_loss and np.std(losses[-self.num_pre_loss:]) < self.threshold: 153 | print('Client:', self.cid, '\tStd:', np.std(losses[-self.num_pre_loss:]), 154 | '\tDFL epochs:', cnt) 155 | break 156 | 157 | self.start_phase = False 158 | 159 | for param, param_t in zip(params_p, params_tp): 160 | param.data = param_t.data.clone().to('cpu') 161 | -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image 5 | from torch import nn 6 | import torch.nn.functional as F 7 | from torch.utils.data import Dataset, DataLoader 8 | from utils.utils import seed_everything 9 | import kornia 10 | from kornia.augmentation.container import AugmentationSequential 11 | 12 | 13 | img_augment = AugmentationSequential( 14 | kornia.augmentation.RandomResizedCrop((224,224), (0.8,1), p=0.3), 
15 | kornia.augmentation.Resize((224, 224)), 16 | kornia.augmentation.RandomBrightness(brightness=(0.8, 1.2), clip_output=True, p=0.2), 17 | kornia.augmentation.RandomContrast(contrast=(0.8, 1.2), clip_output=True, p=0.2), 18 | kornia.augmentation.RandomGamma((0.8, 1.2), (1.0, 1.3), p=0.2), 19 | kornia.augmentation.RandomSaturation((0.8,1.2), p=0.2), 20 | kornia.augmentation.RandomHue((-0.1,0.1), p=0.2), 21 | kornia.augmentation.RandomSharpness((0.8, 1.2), p=0.2), 22 | kornia.augmentation.RandomGrayscale(p=0.2), 23 | data_keys=["input"], 24 | ) 25 | 26 | class NSDDataset(Dataset): 27 | def __init__(self, root_dir, extensions=None, pool_num=8192, pool_type="max", length=None): 28 | self.root_dir = root_dir 29 | self.extensions = extensions if extensions else [] 30 | self.pool_num = pool_num 31 | self.pool_type = pool_type 32 | self.samples = self._load_samples() 33 | self.samples_keys = sorted(self.samples.keys()) 34 | self.length = length 35 | if length is not None: 36 | if length > len(self.samples_keys): 37 | pass # enlarge the dataset 38 | elif length > 0: 39 | self.samples_keys = self.samples_keys[:length] 40 | elif length < 0: 41 | self.samples_keys = self.samples_keys[length:] 42 | elif length == 0: 43 | raise ValueError("length must be a non-zero value!") 44 | else: 45 | self.length = len(self.samples_keys) 46 | 47 | def _load_samples(self): 48 | files = os.listdir(self.root_dir) 49 | samples = {} 50 | for file in files: 51 | file_path = os.path.join(self.root_dir, file) 52 | sample_id, ext = file.split(".",maxsplit=1) 53 | if ext in self.extensions: 54 | if sample_id in samples.keys(): 55 | samples[sample_id][ext] = file_path 56 | else: 57 | samples[sample_id]={"subj": file_path} 58 | samples[sample_id][ext] = file_path 59 | # print(samples) 60 | return samples 61 | 62 | def _load_image(self, image_path): 63 | image = Image.open(image_path).convert('RGB') 64 | image = np.array(image).astype(np.float32) / 255.0 65 | image = 
torch.from_numpy(image.transpose(2, 0, 1)) 66 | return image 67 | 68 | def _load_npy(self, npy_path): 69 | array = np.load(npy_path) 70 | array = torch.from_numpy(array) 71 | return array 72 | 73 | def vox_process(self, x): 74 | if self.pool_num is not None: 75 | x = pool_voxels(x, self.pool_num, self.pool_type) 76 | return x 77 | 78 | def subj_process(self, key): 79 | id = int(key.split("/")[-2].split("subj")[-1]) 80 | return id 81 | 82 | def aug_process(self, brain3d): 83 | return brain3d 84 | 85 | def __len__(self): 86 | # return len(self.samples_keys) 87 | return self.length 88 | 89 | def __getitem__(self, idx): 90 | idx = idx % len(self.samples_keys) 91 | sample_key = self.samples_keys[idx] 92 | sample = self.samples[sample_key] 93 | items = [] 94 | for ext in self.extensions: 95 | if ext == "jpg": 96 | items.append(self._load_image(sample[ext])) 97 | elif ext == "nsdgeneral.npy": 98 | voxel = self._load_npy(sample[ext]) 99 | items.append(voxel) 100 | # items.append(self.vox_process(voxel)) 101 | elif ext == "coco73k.npy": 102 | items.append(self._load_npy(sample[ext])) 103 | elif ext == "subj": 104 | items.append(self.subj_process(sample[ext])) 105 | elif ext == "wholebrain_3d.npy": 106 | brain3d = self._load_npy(sample[ext]) 107 | items.append(self.aug_process(brain3d, )) 108 | 109 | return items 110 | 111 | def pool_voxels(voxels, pool_num, pool_type): 112 | voxels = voxels.float() 113 | if pool_type == 'avg': 114 | voxels = nn.AdaptiveAvgPool1d(pool_num)(voxels) 115 | elif pool_type == 'max': 116 | voxels = nn.AdaptiveMaxPool1d(pool_num)(voxels) 117 | elif pool_type == "resize": 118 | voxels = voxels.unsqueeze(1) # Add a dimension to make it (B, 1, L) 119 | voxels = F.interpolate(voxels, size=pool_num, mode='linear', align_corners=False) 120 | voxels = voxels.squeeze(1) 121 | 122 | return voxels 123 | 124 | def get_dataloader( 125 | root_dir, 126 | batch_size, 127 | num_workers=1, 128 | seed=42, 129 | is_shuffle=True, 130 | extensions=['nsdgeneral.npy', 
"jpg", 'coco73k.npy'], 131 | pool_type=None, 132 | pool_num=None, 133 | length=None, 134 | ): 135 | seed_everything(seed) 136 | dataset = NSDDataset(root_dir=root_dir, extensions=extensions, pool_num=pool_num, pool_type=pool_type, length=length) 137 | dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=is_shuffle) 138 | 139 | return dataloader 140 | 141 | def get_dls(subject, data_path, batch_size, val_batch_size, num_workers, pool_type, pool_num, length, seed): 142 | train_path = "{}/webdataset_avg_split/train/train_subj0{}".format(data_path, subject) 143 | val_path = "{}/webdataset_avg_split/val/val_subj0{}".format(data_path, subject) 144 | # extensions = ['nsdgeneral.npy', "jpg", 'coco73k.npy', "subj"] 145 | extensions = ['nsdgeneral.npy', "jpg", 'coco73k.npy'] 146 | 147 | train_dl = get_dataloader( 148 | train_path, 149 | batch_size=batch_size, 150 | num_workers=num_workers, 151 | seed=seed, 152 | extensions=extensions, 153 | pool_type=pool_type, 154 | pool_num=pool_num, 155 | is_shuffle=True, 156 | length=length, 157 | ) 158 | 159 | val_dl = get_dataloader( 160 | val_path, 161 | batch_size=val_batch_size, 162 | num_workers=num_workers, 163 | seed=seed, 164 | extensions=extensions, 165 | pool_type=pool_type, 166 | pool_num=pool_num, 167 | is_shuffle=False, 168 | ) 169 | 170 | num_train=len(train_dl.dataset) 171 | num_val=len(val_dl.dataset) 172 | print(train_path,"\n",val_path) 173 | print("number of train data:", num_train) 174 | print("batch_size", batch_size) 175 | print("number of val data:", num_val) 176 | print("val_batch_size", val_batch_size) 177 | 178 | return train_dl, val_dl 179 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def get_voxels_num(sub): 5 | if sub ==0: 6 | num_voxels = 15724 7 | elif sub == 1: 8 | 
num_voxels = 15724 9 | elif sub == 2: 10 | num_voxels = 14278 11 | elif sub == 3: 12 | num_voxels = 15226 13 | elif sub == 4: 14 | num_voxels = 13153 15 | elif sub == 5: 16 | num_voxels = 13039 17 | elif sub == 6: 18 | num_voxels = 17907 19 | elif sub == 7: 20 | num_voxels = 12682 21 | elif sub == 8: 22 | num_voxels = 14386 23 | else: raise ValueError(f'unknown subject: {sub}') 24 | return num_voxels 25 | 26 | 27 | def read_data(client, training_type): 28 | 29 | def load_and_normalize_data(fmri_path): 30 | fmri_data = np.load(fmri_path) 31 | norm_mean = np.mean(fmri_data, axis=0) 32 | norm_scale = np.std(fmri_data, axis=0, ddof=1) 33 | return ((fmri_data - norm_mean) / norm_scale) 34 | 35 | def get_clip_path(training_type, client): 36 | if training_type == 'vision': 37 | return '/data/NSD/data/extracted_features/subj{:02d}/nsd_clipvision_noavg_train.npy'.format(client), '/data/NSD/data/extracted_features/subj{:02d}/nsd_clipvision_test.npy'.format(client) 38 | elif training_type == 'text': 39 | return '/data/NSD/data/extracted_features/subj{:02d}/nsd_cliptext_noavg_train.npy'.format(client), '/data/NSD/data/extracted_features/subj{:02d}/nsd_cliptext_test.npy'.format(client) 40 | raise ValueError(f'unknown training_type: {training_type}') 41 | train_fmri_path = '/data/NSD/data/processed_data/subj{:02d}/nsd_train_fmrinoavg_nsdgeneral_sub{}.npy'.format(client, client) 42 | test_fmri_path = '/data/NSD/data/processed_data/subj{:02d}/nsd_test_fmriavg_nsdgeneral_sub{}.npy'.format(client, client) 43 | 44 | train_clip_path, test_clip_path = get_clip_path(training_type, client) 45 | clip_data = np.load(train_clip_path) 46 | clip_data_test = np.load(test_clip_path) 47 | 48 | return (load_and_normalize_data(train_fmri_path), clip_data, 49 | load_and_normalize_data(test_fmri_path), clip_data_test) 50 | 51 | def read_client_data(client, training_type): 52 | 53 | fmri_data, latent_data, fmri_data_test, latent_data_test = read_data(client, training_type) 54 | fmri_data = torch.Tensor(fmri_data).type(torch.float32) 55 | latent_data = torch.Tensor(latent_data).type(torch.float32) 56 | 57 |
train_data = [(x, y) for x, y in zip(fmri_data, latent_data)] 58 | 59 | fmri_data_test = torch.Tensor(fmri_data_test).type(torch.float32) 60 | latent_data_test = torch.Tensor(latent_data_test).type(torch.float32) 61 | 62 | test_data = [(x, y) for x, y in zip(fmri_data_test, latent_data_test)] 63 | 64 | return train_data, test_data 65 | 66 | 67 | def count_params(model): 68 | param_num = sum(p.numel() for p in model.parameters()) 69 | return param_num / 1e6 70 | 71 | -------------------------------------------------------------------------------- /utils/loss_avg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class AverageMeter(object): 4 | """Computes and stores the average and current value""" 5 | 6 | def __init__(self, length=0): 7 | self.length = length 8 | self.reset() 9 | 10 | def reset(self): 11 | if self.length > 0: 12 | self.history = [] 13 | else: 14 | self.count = 0 15 | self.sum = 0.0 16 | self.val = 0.0 17 | self.avg = 0.0 18 | 19 | def update(self, val, num=1): 20 | if self.length > 0: 21 | # currently assert num==1 to avoid bad usage; refine when there are explicit requirements 22 | assert num == 1 23 | self.history.append(val) 24 | if len(self.history) > self.length: 25 | del self.history[0] 26 | 27 | self.val = self.history[-1] 28 | self.avg = np.mean(self.history) 29 | else: 30 | self.val = val 31 | self.sum += val * num 32 | self.count += num 33 | self.avg = self.sum / self.count 34 | -------------------------------------------------------------------------------- /utils/nsd_access.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as op 3 | import glob 4 | import nibabel as nb 5 | import numpy as np 6 | import pandas as pd 7 | from pandas import json_normalize 8 | from tqdm import tqdm 9 | import h5py 10 | import matplotlib.pyplot as plt 11 | 12 | import urllib.request 13 | import zipfile 14 | from
pycocotools.coco import COCO 15 | 16 | 17 | 18 | class NSDAccess(object): 19 | """ 20 | Small class that provides easy access to the NSD data; see [their website](http://naturalscenesdataset.org). 21 | """ 22 | 23 | def __init__(self, nsd_folder, *args, **kwargs): 24 | super().__init__(*args, **kwargs) 25 | self.nsd_folder = nsd_folder 26 | self.nsddata_folder = op.join(self.nsd_folder, 'nsddata') 27 | self.ppdata_folder = op.join(self.nsd_folder, 'nsddata', 'ppdata') 28 | self.nsddata_betas_folder = op.join( 29 | self.nsd_folder, 'nsddata_betas', 'ppdata') 30 | 31 | self.behavior_file = op.join( 32 | self.ppdata_folder, '{subject}', 'behav', 'responses.tsv') 33 | self.stimuli_file = op.join( 34 | self.nsd_folder, 'nsddata_stimuli', 'stimuli', 'nsd', 'nsd_stimuli.hdf5') 35 | self.stimuli_description_file = op.join( 36 | self.nsd_folder, 'nsddata', 'experiments', 'nsd', 'nsd_stim_info_merged.csv') 37 | 38 | self.coco_annotation_file = op.join( 39 | self.nsd_folder, 'nsddata_stimuli', 'stimuli', 'nsd', 'annotations', '{}_{}.json') 40 | 41 | def download_coco_annotation_file(self, url='http://images.cocodataset.org/annotations/annotations_trainval2017.zip'): 42 | """download_coco_annotation_file downloads and extracts the relevant annotation files 43 | 44 | Parameters 45 | ---------- 46 | url : str, optional 47 | URL of the zip file containing the annotations, by default 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip' 48 | """ 49 | print('downloading annotations from {}'.format(url)) 50 | filehandle, _ = urllib.request.urlretrieve(url) 51 | zip_file_object = zipfile.ZipFile(filehandle, 'r') 52 | zip_file_object.extractall(path=op.split( 53 | op.split(self.coco_annotation_file)[0])[0]) 54 | 55 | def affine_header(self, subject, data_format='func1pt8mm'): 56 | """affine_header returns the affine and header needed to construct a NIfTI image 57 | 58 | Parameters 59 | ---------- 60 | subject : str 61 | subject identifier, such as 'subj01' 62 | data_format : str,
optional 63 | what type of data format, from ['func1pt8mm', 'func1mm'], by default 'func1pt8mm' 64 | 65 | Returns 66 | ------- 67 | tuple 68 | affine and header for constructing a NIfTI image 69 | """ 70 | full_path = op.join(self.ppdata_folder, 71 | '{subject}', '{data_format}', 'brainmask.nii.gz') 72 | full_path = full_path.format(subject=subject, 73 | data_format=data_format) 74 | nii = nb.load(full_path) 75 | 76 | return nii.affine, nii.header 77 | 78 | def read_vol_ppdata(self, subject, filename='brainmask', data_format='func1pt8mm'): 79 | """read_vol_ppdata loads the volumetric ppdata file {filename}.nii.gz (the brain mask by default) 80 | 81 | Parameters 82 | ---------- 83 | subject : str 84 | subject identifier, such as 'subj01' 85 | data_format : str, optional 86 | what type of data format, from ['func1pt8mm', 'func1mm'], by default 'func1pt8mm' 87 | 88 | Returns 89 | ------- 90 | numpy.ndarray 91 | the requested volume (the brain mask by default) 92 | """ 93 | full_path = op.join(self.ppdata_folder, 94 | '{subject}', '{data_format}', '{filename}.nii.gz') 95 | full_path = full_path.format(subject=subject, 96 | data_format=data_format, 97 | filename=filename) 98 | return np.asanyarray(nb.load(full_path).dataobj) # get_data() was removed in nibabel 5; dataobj keeps the stored dtype 99 | 100 | def read_betas(self, subject, session_index, trial_index=[], data_type='betas_fithrf_GLMdenoise_RR', data_format='fsaverage', mask=None): 101 | """read_betas reads betas from MRI files 102 | 103 | Parameters 104 | ---------- 105 | subject : str 106 | subject identifier, such as 'subj01' 107 | session_index : int 108 | which session, counting from 1 109 | trial_index : list, optional 110 | which trials from this session's file to return, by default [], which returns all trials 111 | data_type : str, optional 112 | which type of beta values to return from ['betas_assumehrf', 'betas_fithrf', 'betas_fithrf_GLMdenoise_RR', 'restingbetas_fithrf'], by default 'betas_fithrf_GLMdenoise_RR' 113 | data_format : str, optional 114 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm'], by
default 'fsaverage' 115 | mask : numpy.ndarray, optional 116 | binary/boolean mask; if given, the betas are read from the .mat file instead, which needs a volumetric data_format. 117 | 118 | Returns 119 | ------- 120 | numpy.ndarray, 2D (fsaverage) or 4D (other data formats) 121 | the requested per-trial beta values 122 | """ 123 | data_folder = op.join(self.nsddata_betas_folder, 124 | subject, data_format, data_type) 125 | si_str = str(session_index).zfill(2) 126 | 127 | if isinstance(mask, np.ndarray): # uses the .mat file if it exists; otherwise the assert below fails 128 | ipf = op.join(data_folder, f'betas_session{si_str}.mat') 129 | assert op.isfile(ipf), \ 130 | 'Error: ' + ipf + ' not available for masking. You may need to download these separately.' 131 | # will do indexing of both space and time in one go for this option, 132 | # so will return results immediately from this 133 | h5 = h5py.File(ipf, 'r') 134 | betas = h5.get('betas') 135 | 136 | if len(trial_index) == 0: 137 | trial_index = slice(0, betas.shape[0]) 138 | # this isn't finished yet - binary masks cannot be used for indexing like this 139 | return betas[trial_index, np.nonzero(mask)] 140 | 141 | if data_format == 'fsaverage': 142 | session_betas = [] 143 | for hemi in ['lh', 'rh']: 144 | hdata = np.asanyarray(nb.load(op.join( 145 | data_folder, f'{hemi}.betas_session{si_str}.mgh')).dataobj) 146 | session_betas.append(hdata) 147 | out_data = np.squeeze(np.vstack(session_betas)) 148 | else: 149 | # if no mask was specified, we'll use the nifti image 150 | out_data = nb.load( 151 | op.join(data_folder, f'betas_session{si_str}.nii.gz')).get_fdata() 152 | 153 | if len(trial_index) == 0: 154 | trial_index = slice(0, out_data.shape[-1]) 155 | 156 | return out_data[..., trial_index] 157 | 158 | def read_mapper_results(self, subject, mapper='prf', data_type='angle', data_format='fsaverage'): 159 | """read_mapper_results reads the volumetric mapper results stored as {mapper}_{data_type}.nii.gz 160 | 161 | Parameters 162 | ---------- 163 | subject : str 164 | subject identifier, such as
'subj01' 165 | mapper : str, optional 166 | first part of the mapper filename, by default 'prf' 167 | data_type : str, optional 168 | second part of the mapper filename, by default 'angle' 169 | data_format : str, optional 170 | what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm'], by default 'fsaverage' 171 | 172 | Returns 173 | ------- 174 | numpy.ndarray, volumetric (fsaverage raises NotImplementedError for now) 175 | the requested mapper values 176 | """ 177 | if data_format == 'fsaverage': 178 | # unclear for now where the fsaverage mapper results would be 179 | # as they are still in fsnative format now. 180 | raise NotImplementedError( 181 | 'no mapper results in fsaverage present for now') 182 | else: # is 'func1pt8mm' or 'func1mm' 183 | return self.read_vol_ppdata(subject=subject, filename=f'{mapper}_{data_type}', data_format=data_format) 184 | 185 | def read_atlas_results(self, subject, atlas='HCP_MMP1', data_format='fsaverage'): 186 | """read_atlas_results reads an atlas (ROI) file and its ROI-name mapping 187 | 188 | Parameters 189 | ---------- 190 | subject : str 191 | subject identifier, such as 'subj01' 192 | for surface-based data formats, subject should be the same as data_format. 193 | for example, for fsaverage, both subject and data_format should be 'fsaverage' 194 | this requires a little more typing but makes data format explicit 195 | atlas : str, optional 196 | which atlas to read, 197 | for volume formats, any of ['HCP_MMP1', 'Kastner2015', 'nsdgeneral', 'visualsulc'], 198 | for fsaverage, see list_atlases for the available names; 199 | can be prefixed by 'lh.' or 'rh.' for hemisphere-specific atlases. 200 | For surface formats, both hemispheres are loaded and stacked by default, unless the name is prefixed by 'lh.' or 'rh.'. 201 | By default 'HCP_MMP1'.
        data_format : str, optional
            what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm', 'MNI'], by default 'fsaverage'

        Returns
        -------
        numpy.ndarray, 1D/2D (surface) or 3D/4D (volume data formats)
            the requested atlas values
        dict,
            dictionary containing the mapping between ROI names and atlas values
        """

        # first, get the mapping.
        atlas_name = atlas
        if atlas[:3] in ('rh.', 'lh.'):
            atlas_name = atlas[3:]

        mapp_df = pd.read_csv(os.path.join(self.nsddata_folder, 'freesurfer', 'fsaverage',
                                           'label', f'{atlas_name}.mgz.ctab'), delimiter=' ', header=None, index_col=0)
        atlas_mapping = mapp_df.to_dict()[1]
        # invert the mapping so that ROI names map to atlas values
        atlas_mapping = {y: x for x, y in atlas_mapping.items()}

        if data_format not in ('func1pt8mm', 'func1mm', 'MNI'):
            # if surface based results by exclusion
            if atlas[:3] in ('rh.', 'lh.'):  # check if hemisphere-specific atlas requested
                ipf = op.join(self.nsddata_folder, 'freesurfer',
                              subject, 'label', f'{atlas}.mgz')
                # np.asanyarray(img.dataobj) replaces img.get_data(), which was removed in nibabel 5
                return np.squeeze(np.asanyarray(nb.load(ipf).dataobj)), atlas_mapping
            else:  # more than one hemisphere requested
                session_betas = []
                for hemi in ['lh', 'rh']:
                    hdata = np.asanyarray(nb.load(op.join(
                        self.nsddata_folder, 'freesurfer', subject, 'label', f'{hemi}.{atlas}.mgz')).dataobj)
                    session_betas.append(hdata)
                out_data = np.squeeze(np.vstack(session_betas))
                return out_data, atlas_mapping
        else:  # is 'func1pt8mm', 'MNI', or 'func1mm'
            ipf = op.join(self.ppdata_folder, subject,
                          data_format, 'roi', f'{atlas}.nii.gz')
            return nb.load(ipf).get_fdata(), atlas_mapping

    def list_atlases(self, subject, data_format='fsaverage', abs_paths=False):
        """list_atlases [summary]

        Parameters
        ----------
        subject : str
            subject identifier, such as 'subj01'
            for surface-based data formats, subject should be the same as data_format.
            for example, for fsaverage, both subject and data_format should be 'fsaverage'
            this requires a little more typing but makes data format explicit
        data_format : str, optional
            what type of data format, from ['fsaverage', 'func1pt8mm', 'func1mm', 'MNI'], by default 'fsaverage'
        abs_paths : bool, optional
            whether to return absolute paths instead of atlas names, by default False

        Returns
        -------
        list
            absolute paths to the atlas files if abs_paths is True, otherwise the unique
            atlas names (without hemisphere prefix or file extension)
        """
        if data_format in ('func1pt8mm', 'func1mm', 'MNI'):
            atlas_files = glob.glob(
                op.join(self.ppdata_folder, subject, data_format, 'roi', '*.nii.gz'))
        else:
            atlas_files = glob.glob(
                op.join(self.nsddata_folder, 'freesurfer', subject, 'label', '*.mgz'))

        # print the atlases that were found
        import pprint
        pp = pprint.PrettyPrinter(indent=4)
        print('Atlases found in {}:'.format(op.split(atlas_files[0])[0]))
        pp.pprint([op.split(f)[1] for f in atlas_files])
        if abs_paths:
            return atlas_files
        else:  # this is the format which you can input into other functions, so this is the default
            return np.unique([op.split(f)[1].replace('lh.', '').replace('rh.', '').replace('.mgz', '').replace('.nii.gz', '') for f in atlas_files])

    def read_behavior(self, subject, session_index, trial_index=[]):
        """read_behavior [summary]

        Parameters
        ----------
        subject : str
            subject identifier, such as 'subj01'
        session_index : int
            which session, counting from 0
        trial_index : list, optional
            which trials from this session's behavior to return, by default [], which returns all trials

        Returns
        -------
        pandas DataFrame
            DataFrame containing the behavioral information for the requested trials
        """

        behavior = pd.read_csv(self.behavior_file.format(
            subject=subject), delimiter='\t')

        # the behavior is encoded per run.
        # I'm now setting this function up so that it aligns with the timepoints in the fmri files,
        # i.e. using indexing per session, and not using the 'run' information.
        session_behavior = behavior[behavior['SESSION'] == session_index]

        if len(trial_index) == 0:
            trial_index = slice(0, len(session_behavior))

        return session_behavior.iloc[trial_index]

    def read_images(self, image_index, show=False):
        """read_images reads a list of images, and returns their data

        Parameters
        ----------
        image_index : list of integers
            which images indexed in the 73k format to return
        show : bool, optional
            whether to also show the images, by default False

        Returns
        -------
        numpy.ndarray, 4D (n_images, height, width, 3)
            RGB image data, one image per requested index
        """

        if not hasattr(self, 'stim_descriptions'):
            self.stim_descriptions = pd.read_csv(
                self.stimuli_description_file, index_col=0)

        sf = h5py.File(self.stimuli_file, 'r')
        sdataset = sf.get('imgBrick')
        if show:
            f, ss = plt.subplots(1, len(image_index),
                                 figsize=(6*len(image_index), 6))
            if len(image_index) == 1:
                ss = [ss]
            for s, d in zip(ss, sdataset[image_index]):
                s.axis('off')
                s.imshow(d)
        return sdataset[image_index]

    def read_image_coco_info(self, image_index, info_type='captions', show_annot=False, show_img=False):
        """image_coco_info returns the coco annotations of a single image or a list of images

        Parameters
        ----------
        image_index : list of integers
            which images indexed in the 73k format to return the captions for
        info_type : str, optional
            what type of annotation to return, from ['captions', 'person_keypoints', 'instances'], by default 'captions'
        show_annot : bool, optional
            whether to show the annotation, by default False
        show_img : bool, optional
            whether to show the image (from the nsd formatted data), by default False

        Returns
        -------
        coco Annotation
            a list of coco annotations for a single image, or a list of such lists
            (one per image) for several images, to be used in subsequent analysis steps

        Example
        -------
        single image:
            ci = read_image_coco_info(
                [569], info_type='captions', show_annot=False, show_img=False)
        list of images:
            ci = read_image_coco_info(
                [569, 2569], info_type='captions')

        """
        if not hasattr(self, 'stim_descriptions'):
            self.stim_descriptions = pd.read_csv(
                self.stimuli_description_file, index_col=0)
        if len(image_index) == 1:
            subj_info = self.stim_descriptions.iloc[image_index[0]]

            # checking whether annotation file for this trial exists.
            # This may not be the right place to call the download, and
            # re-opening the annotations for all images separately may be slowing things down
            # however images used in the experiment seem to have come from different sets.
            annot_file = self.coco_annotation_file.format(
                info_type, subj_info['cocoSplit'])
            print('getting annotations from ' + annot_file)
            if not os.path.isfile(annot_file):
                print('annotations file not found')
                self.download_coco_annotation_file()

            coco = COCO(annot_file)
            coco_annot_IDs = coco.getAnnIds([subj_info['cocoId']])
            coco_annot = coco.loadAnns(coco_annot_IDs)

            if show_img:
                self.read_images(image_index, show=True)

            if show_annot:
                # still need to convert the annotations (especially person_keypoints and instances) to the right reference frame,
                # because the images were cropped. See image information per image to do this.
                coco.showAnns(coco_annot)

        elif len(image_index) > 1:

            # we output a list of annots
            coco_annot = []

            # load train_2017
            annot_file = self.coco_annotation_file.format(
                info_type, 'train2017')
            coco_train = COCO(annot_file)

            # also load the val 2017
            annot_file = self.coco_annotation_file.format(
                info_type, 'val2017')
            coco_val = COCO(annot_file)

            for image in image_index:
                subj_info = self.stim_descriptions.iloc[image]
                if subj_info['cocoSplit'] == 'train2017':
                    coco_annot_IDs = coco_train.getAnnIds(
                        [subj_info['cocoId']])
                    coco_ann = coco_train.loadAnns(coco_annot_IDs)
                    coco_annot.append(coco_ann)

                elif subj_info['cocoSplit'] == 'val2017':
                    coco_annot_IDs = coco_val.getAnnIds(
                        [subj_info['cocoId']])
                    coco_ann = coco_val.loadAnns(coco_annot_IDs)
                    coco_annot.append(coco_ann)

        return coco_annot

    def read_image_coco_category(self, image_index):
        """image_coco_category returns the coco category of a single image or a list of images

        Parameters
        ----------
        image_index : list of integers
            which images indexed in the 73k format to return the category for

        Returns
        -------
        coco category
            coco category, to be used in subsequent analysis steps

        Example
        -------
        single image:
            ci = read_image_coco_category(
                [569])
        list of images:
            ci = read_image_coco_category(
                [569, 2569])
        """

        if not hasattr(self, 'stim_descriptions'):
            self.stim_descriptions = pd.read_csv(
                self.stimuli_description_file, index_col=0)

        if len(image_index) == 1:
            subj_info = self.stim_descriptions.iloc[image_index[0]]
            coco_id = subj_info['cocoId']

            # checking whether annotation file for this trial exists.
            # This may not be the right place to call the download, and
            # re-opening the annotations for all images separately may be slowing things down
            # however images used in the experiment seem to have come from different sets.
            annot_file = self.coco_annotation_file.format(
                'instances', subj_info['cocoSplit'])
            print('getting annotations from ' + annot_file)
            if not os.path.isfile(annot_file):
                print('annotations file not found')
                self.download_coco_annotation_file()

            coco = COCO(annot_file)

            cat_ids = coco.getCatIds()
            categories = json_normalize(coco.loadCats(cat_ids))

            coco_cats = []
            for cat_id in cat_ids:
                this_img_list = coco.getImgIds(catIds=[cat_id])
                if coco_id in this_img_list:
                    this_cat = np.asarray(categories[categories['id']==cat_id]['name'])[0]
                    coco_cats.append(this_cat)

        elif len(image_index) > 1:

            # we output a list of annots
            coco_cats = []

            # load train_2017
            annot_file = self.coco_annotation_file.format(
                'instances', 'train2017')
            coco_train = COCO(annot_file)
            cat_ids_train = coco_train.getCatIds()
            categories_train = json_normalize(coco_train.loadCats(cat_ids_train))

            # also load the val 2017
            annot_file = self.coco_annotation_file.format(
                'instances', 'val2017')
            coco_val = COCO(annot_file)
            cat_ids_val = coco_val.getCatIds()
            categories_val = json_normalize(coco_val.loadCats(cat_ids_val))

            for image in tqdm(image_index, bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}'):
                subj_info = self.stim_descriptions.iloc[image]
                coco_id = subj_info['cocoId']
                image_cat = []
                if subj_info['cocoSplit'] == 'train2017':
                    for cat_id in cat_ids_train:
                        this_img_list = coco_train.getImgIds(catIds=[cat_id])
                        if coco_id in this_img_list:
                            this_cat = np.asarray(categories_train[categories_train['id']==cat_id]['name'])[0]
                            image_cat.append(this_cat)

                elif subj_info['cocoSplit'] == 'val2017':
                    for cat_id in cat_ids_val:
                        this_img_list = coco_val.getImgIds(catIds=[cat_id])
                        if coco_id in this_img_list:
                            this_cat = np.asarray(categories_val[categories_val['id']==cat_id]['name'])[0]
                            image_cat.append(this_cat)
                coco_cats.append(image_cat)
        return coco_cats

--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
import logging
import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from utils.nsd_access import NSDAccess

logs = set()

def init_log(name, level=logging.INFO):
    if (name, level) in logs:
        # already configured: return the existing logger instead of None
        return logging.getLogger(name)
    logs.add((name, level))
    logger = logging.getLogger(name)
    logger.setLevel(level)
    ch = logging.StreamHandler()
    ch.setLevel(level)
    if "SLURM_PROCID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        logger.addFilter(lambda record: rank == 0)
    else:
        rank = 0
    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
    formatter = logging.Formatter(format_str)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger


def batchwise_cosine_similarity(Z,B):
    # https://www.h4pz.co/blog/2021/4/2/batch-cosine-similarity-in-pytorch-or-numpy-jax-cupy-etc
    B = B.T
    Z_norm = torch.linalg.norm(Z, dim=1, keepdim=True)  # Size (n, 1).
    B_norm = torch.linalg.norm(B, dim=0, keepdim=True)  # Size (1, b).
    # Size (b, n): row j holds the cosine similarities of B[j] to every row of Z.
    cosine_similarity = ((Z @ B) / (Z_norm @ B_norm)).T
    return cosine_similarity

def topk(similarities, labels, k=5):
    # top-k retrieval accuracy: fraction of rows whose label is among the k highest similarities
    if k > similarities.shape[0]:
        k = similarities.shape[0]
    topsum = 0
    for i in range(k):
        topsum += torch.sum(torch.argsort(similarities, axis=1)[:, -(i+1)] == labels)/len(labels)
    return topsum

def mixco(voxels, beta=0.15, s_thresh=0.5):
    # MixCo augmentation: mixes a random subset of samples with shuffled partners using
    # Beta(beta, beta) weights (note: modifies `voxels` in place)
    perm = torch.randperm(voxels.shape[0])
    voxels_shuffle = voxels[perm].to(voxels.device, dtype=voxels.dtype)
    betas = torch.distributions.Beta(beta, beta).sample([voxels.shape[0]]).to(voxels.device, dtype=voxels.dtype)
    select = (torch.rand(voxels.shape[0]) <= s_thresh).to(voxels.device)
    betas_shape = [-1] + [1]*(len(voxels.shape)-1)
    voxels[select] = voxels[select] * betas[select].reshape(*betas_shape) + \
        voxels_shuffle[select] * (1 - betas[select]).reshape(*betas_shape)
    betas[~select] = 1
    return voxels, perm, betas, select


def mixco_nce(preds, targs, temp=0.1, perm=None, betas=None, select=None, distributed=False,
              accelerator=None, local_rank=None, bidirectional=True):
    brain_clip = (preds @ targs.T)/temp

    if perm is not None and betas is not None and select is not None:
        # soft targets: weight beta on the original pair, 1 - beta on the mixed-in partner
        probs = torch.diag(betas)
        probs[torch.arange(preds.shape[0]).to(preds.device), perm] = 1 - betas

        loss = -(brain_clip.log_softmax(-1) * probs).sum(-1).mean()
        if bidirectional:
            loss2 = -(brain_clip.T.log_softmax(-1) * probs.T).sum(-1).mean()
            loss = (loss + loss2)/2
        return loss
    else:
        loss = F.cross_entropy(brain_clip, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
        if bidirectional:
            loss2 = F.cross_entropy(brain_clip.T, torch.arange(brain_clip.shape[0]).to(brain_clip.device))
            loss = (loss + loss2)/2
        return loss

def simple_nce(preds, targs, logit_scale=3.51):
    # preds = preds.to(torch.float16)
    # targs = targs.to(torch.float16)
    brain_clip = logit_scale * preds @ targs.T
    clip_brain = logit_scale * targs @ preds.T

    # symmetric InfoNCE: brain->CLIP and CLIP->brain cross-entropy, averaged
    labels = torch.arange(brain_clip.shape[0]).to(brain_clip.device)
    loss = F.cross_entropy(brain_clip, labels)
    loss2 = F.cross_entropy(clip_brain, labels)
    loss = (loss + loss2)/2
    return loss

def cosine_anneal(start, end, steps):
    # cosine schedule that goes from `start` (step 0) to `end` (step steps-1)
    return end + (start - end)/2 * (1 + torch.cos(torch.pi*torch.arange(steps)/(steps-1)))


def soft_clip_loss(preds, targs, temp=0.05, eps=1e-10):
    # contrastive loss whose soft targets are the softmax of the CLIP-CLIP similarities
    clip_clip = (targs @ targs.T)/temp + eps
    brain_clip = (preds @ targs.T)/temp + eps

    loss1 = -(brain_clip.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()
    loss2 = -(brain_clip.T.log_softmax(-1) * clip_clip.softmax(-1)).sum(-1).mean()

    loss = (loss1 + loss2)/2
    return loss

def prepare_coco(data_path):
    # Preload the COCO captions of all 73k NSD images
    nsda = NSDAccess(data_path)
    coco_73k = list(range(0, 73000))
    prompts_list = nsda.read_image_coco_info(coco_73k, info_type='captions')

    print("coco captions loaded.")

    return prompts_list



from torchvision import transforms
import numpy as np
import clip
import torch.nn as nn
import random


def seed_everything(seed=0, cudnn_deterministic=True):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = True
    else:
        ## needs to be False to use conv3D
        print('Note: not using cudnn.deterministic')



class Clipper(torch.nn.Module):
    def __init__(self, clip_variant, clamp_embs=False, norm_embs=False,
                 hidden_state=False, device=torch.device('cpu')):
        super().__init__()
        assert clip_variant in ("RN50", "ViT-L/14", "ViT-B/32", "RN50x64"), \
            "clip_variant must be one of RN50, ViT-L/14, ViT-B/32, RN50x64"
        print(clip_variant, device)

        if clip_variant=="ViT-L/14" and hidden_state:
            from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection, CLIPTokenizer
            image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval()
            image_encoder = image_encoder.to(device)
            for param in image_encoder.parameters():
                param.requires_grad = False  # frozen: no gradients needed
            self.image_encoder = image_encoder

            text_encoder = CLIPTextModelWithProjection.from_pretrained("openai/clip-vit-large-patch14").eval()
            text_encoder = text_encoder.to(device)
            for param in text_encoder.parameters():
                param.requires_grad = False  # frozen: no gradients needed
            self.text_encoder = text_encoder
            self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

        elif hidden_state:
            raise Exception("hidden_state embeddings only works with ViT-L/14 right now")

        clip_model, preprocess = clip.load(clip_variant, device=device)
        clip_model.eval()  # inference only: the CLIP model is not trained
        for param in clip_model.parameters():
            param.requires_grad = False  # frozen: no gradients needed

        self.clip = clip_model
        self.clip_variant = clip_variant
        if clip_variant == "RN50x64":
            self.clip_size = (448,448)
        else:
            self.clip_size = (224,224)

        preproc = transforms.Compose([
            transforms.Resize(size=self.clip_size[0], interpolation=transforms.InterpolationMode.BICUBIC, antialias=None),
            transforms.CenterCrop(size=self.clip_size),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])
        self.preprocess = preproc
        self.hidden_state = hidden_state
        self.mean = np.array([0.48145466, 0.4578275, 0.40821073])
        self.std = np.array([0.26862954, 0.26130258, 0.27577711])
        self.normalize = transforms.Normalize(self.mean, self.std)
        self.denormalize = transforms.Normalize((-self.mean / self.std).tolist(), (1.0 / self.std).tolist())
        self.clamp_embs = clamp_embs
        self.norm_embs = norm_embs
        self.device = device

        # only used when hidden_state=True, i.e. when image_encoder above is defined
        def versatile_normalize_embeddings(encoder_output):
            embeds = encoder_output.last_hidden_state
            embeds = image_encoder.vision_model.post_layernorm(embeds)
            embeds = image_encoder.visual_projection(embeds)
            return embeds
        self.versatile_normalize_embeddings = versatile_normalize_embeddings

    def resize_image(self, image):
        # note: antialias should be False if planning to use Pinkney's Image Variation SD model
        return transforms.Resize(self.clip_size, antialias=None)(image.to(self.device))

    def embed_image(self, image):
        """Expects images in -1 to 1 range"""
        if self.hidden_state:
            # clip_emb = self.preprocess((image/1.5+.25).to(self.device)) # for some reason the /1.5+.25 prevents oversaturation
            clip_emb = self.preprocess((image).to(self.device))
            clip_emb = self.image_encoder(clip_emb)
            clip_emb = self.versatile_normalize_embeddings(clip_emb)
        else:
            clip_emb = self.preprocess(image.to(self.device))
            clip_emb = self.clip.encode_image(clip_emb)
        # input is now in CLIP space, but mind-reader preprint further processes embeddings:
        if self.clamp_embs:
            clip_emb = torch.clamp(clip_emb, -1.5, 1.5)
        if self.norm_embs:
            if self.hidden_state:
                # normalize all tokens by cls token's norm
                clip_emb = clip_emb / torch.norm(clip_emb[:, 0], dim=-1).reshape(-1, 1, 1)
            else:
                clip_emb = nn.functional.normalize(clip_emb, dim=-1)
        return clip_emb

    def embed_text(self, prompt):
        r"""
        Encodes the prompt into text encoder hidden states.

        Requires clip_variant="ViT-L/14" and hidden_state=True, because self.tokenizer
        and self.text_encoder are only created in that configuration.

        Args:
            prompt (`str` or `List[str]`):
                prompt to be encoded

        Returns:
            `torch.Tensor` of shape (batch, sequence_length, projection_dim): the projected
            per-token embeddings, each divided by the norm of the pooled text embedding.
        """

        def normalize_embeddings(encoder_output):
            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
            embeds_pooled = encoder_output.text_embeds
            embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
            return embeds

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="pt").input_ids
        with torch.no_grad():
            prompt_embeds = self.text_encoder(
                text_input_ids.to(self.device),
            )
        prompt_embeds = normalize_embeddings(prompt_embeds)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        # bs_embed, seq_len, _ = prompt_embeds.shape
        # prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        # prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    def embed_curated_annotations(self, annots):
        # pick one random non-empty caption (out of 5) per sample, then embed them all
        for i,b in enumerate(annots):
            t = ''
            while t == '':
                rand = torch.randint(5,(1,1))[0][0]
                t = b[0,rand]
            if i==0:
                txt = np.array(t)
            else:
                txt = np.vstack((txt,t))
        txt = txt.flatten()
        # the Hugging Face tokenizer accepts str or list of str, not a NumPy array
        return self.embed_text(txt.tolist())

def torch_to_Image(x):
    # convert a (1, C, H, W) or (C, H, W) tensor to a PIL image; float input is expected in [0, 1]
    if x.ndim==4:
        x=x[0]
    return transforms.ToPILImage()(x)

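# Hypothetical helpers (not part of the original repository), added only as a sketch
# to make the image value-range convention used in this file explicit:
#   - Image_to_torch (below) maps [0, 1] pixels to [-1, 1] via (x - .5) / .5
#   - decode_latents (below) returns, and torch_to_Image (above) expects, [0, 1] values
# so a [-1, 1] tensor should be mapped back to [0, 1] before calling torch_to_Image.
# Both helpers only use arithmetic and .clip, so they accept torch tensors and NumPy arrays.
def to_signed_range(x):
    """Map an image in [0, 1] to [-1, 1], matching the scaling in Image_to_torch."""
    return (x - 0.5) / 0.5


def to_unit_range(x):
    """Map an image in [-1, 1] back to [0, 1], clipping out-of-range values."""
    return (x * 0.5 + 0.5).clip(0, 1)

# Usage sketch: torch_to_Image(to_unit_range(Image_to_torch(pil_image))) should
# round-trip an RGB PIL image, up to 8-bit quantization.
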
def Image_to_torch(x):
    # convert a PIL image (or a sequence whose first element is one) to a (1, 3, H, W) tensor in [-1, 1]
    try:
        x = (transforms.ToTensor()(x)[:3].unsqueeze(0)-.5)/.5
    except Exception:
        x = (transforms.ToTensor()(x[0])[:3].unsqueeze(0)-.5)/.5
    return x


def decode_latents(latents,vae):
    # undo the Stable Diffusion latent scaling factor (0.18215), decode, and map [-1, 1] to [0, 1]
    latents = 1 / 0.18215 * latents
    image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    return image


def combine_with_memory(
    pred_embedding_vision, pred_embedding_text,
    clip_vision_train, clip_text_train,
    clip_vision_train_norm, clip_text_train_norm,
    clip_image_target, clip_text_target,
    clip_image_target_norm, clip_text_target_norm
):
    pred_embedding_text = pred_embedding_text.to(clip_text_train_norm.device)
    clip_image_target_norm = clip_image_target_norm.to(clip_text_train_norm.device)
    pred_embedding_vision = pred_embedding_vision.to(clip_text_train_norm.device)
    clip_text_target_norm = clip_text_target_norm.to(clip_text_train_norm.device)
    alpha = 0.1
    pred_embedding_text_norm = nn.functional.normalize(pred_embedding_text.flatten(1), dim=-1)
    pred_embedding_vision_norm = nn.functional.normalize(pred_embedding_vision.flatten(1), dim=-1)

    similarity_text = batchwise_cosine_similarity(pred_embedding_text_norm, clip_text_train_norm)
    similarity_vision = batchwise_cosine_similarity(pred_embedding_vision_norm, clip_vision_train_norm)
    # similarity_text = clip_text_train_norm @ pred_embedding_text_norm.T
    # similarity_vision = clip_vision_train_norm @ pred_embedding_vision_norm.T

    # Target with train data
    similarity_text_tar_with_train = batchwise_cosine_similarity(clip_text_target_norm, clip_text_train_norm)
    similarity_vision_tar_with_train = batchwise_cosine_similarity(clip_image_target_norm, clip_vision_train_norm)
    topk_index_text_tar_with_train = torch.topk(similarity_text_tar_with_train.flatten(), 1).indices
    topk_index_vision_tar_with_train = torch.topk(similarity_vision_tar_with_train.flatten(), 1).indices

    topk_index_text = torch.topk(similarity_text.flatten(), 1).indices
    topk_index_vision = torch.topk(similarity_vision.flatten(), 1).indices
    # similarity_vision_by_text = pred_embedding_vision_norm @ (nn.functional.normalize(clip_vision_train[topk_index_text].flatten(1), dim=-1)).T
    similarity_vision_by_text = batchwise_cosine_similarity(pred_embedding_vision_norm, (nn.functional.normalize(clip_vision_train[topk_index_text].flatten(1), dim=-1)))
    print('\n Alpha' , alpha)
    print('Max similarity between target and train data')
    print('Top indices text tar retrieval train data:',topk_index_text_tar_with_train,'Top similarity_text retrieval train data:', torch.topk(similarity_text_tar_with_train.flatten(), 1).values)
    print('Top indices vision tar retrieval train data:',topk_index_vision_tar_with_train,'Top similarity_vision retrieval train data', torch.topk(similarity_vision_tar_with_train.flatten(), 1).values)

    print('Top 10 matches of the prediction in train data')
    print('Top 10 index text:',torch.topk(similarity_text.flatten(), 10).indices,'Top 10 similarity_text:', torch.topk(similarity_text.flatten(), 10).values)
    print('Top 10 index vision:',torch.topk(similarity_vision.flatten(), 10).indices,'Top 10 similarity_vision', torch.topk(similarity_vision.flatten(), 10).values)

    print('Top 10 matches of the target in train data')
    print('Top 10 index text:',torch.topk(similarity_text_tar_with_train.flatten(), 10).indices,'Top 10 similarity_text:', torch.topk(similarity_text_tar_with_train.flatten(), 10).values)
    print('Top 10 index vision:',torch.topk(similarity_vision_tar_with_train.flatten(), 10).indices,'Top 10 similarity_vision', torch.topk(similarity_vision_tar_with_train.flatten(), 10).values)

    print('Max similarity between prediction and train data')
    print('Top indices text:',topk_index_text,'Top similarity_text:', torch.topk(similarity_text.flatten(), 1).values)
    print('Top indices vision:',topk_index_vision,'Top similarity_vision', torch.topk(similarity_vision.flatten(), 1).values)
    print('Top similarity_vision_by_text',torch.topk(similarity_vision_by_text.flatten(), 1).values)

    # blend the retrieved train embedding (weight 1 - alpha) with the prediction (weight alpha)
    combined_brain_clip_text_embeddings = (1-alpha) * clip_text_train[topk_index_text] + alpha * pred_embedding_text
    combined_brain_clip_image_embeddings = (1-alpha) * clip_vision_train[topk_index_vision] + alpha * pred_embedding_vision
    combined_brain_clip_image_embeddings_by_text = (1-alpha) * clip_vision_train[topk_index_text] + alpha * pred_embedding_vision

    combined_brain_clip_text_embeddings_using_target = (1-alpha) * clip_text_train[topk_index_text_tar_with_train] + alpha * pred_embedding_text
    combined_brain_clip_image_embeddings_using_target = (1-alpha) * clip_vision_train[topk_index_vision_tar_with_train] + alpha * pred_embedding_vision
    combined_brain_clip_image_embeddings_by_text_using_target = (1-alpha) * clip_vision_train[topk_index_text_tar_with_train] + alpha * pred_embedding_vision


    retrivaled_embedding_text_norm = nn.functional.normalize(clip_text_train[topk_index_text].flatten(1), dim=-1)
    retrivaled_embedding_vision_norm = nn.functional.normalize(clip_vision_train[topk_index_vision].flatten(1), dim=-1)

    similarity_text_retrival = batchwise_cosine_similarity(retrivaled_embedding_text_norm, clip_text_target_norm)
    similarity_vision_retrival = batchwise_cosine_similarity(retrivaled_embedding_vision_norm, clip_image_target_norm)
    print('Similarity between the retrieved embeddings and the target')
    print('Similarity_retrieval_text_with_tar', similarity_text_retrival)
    print('Similarity_retrieval_vision_with_tar', similarity_vision_retrival)

    combined_brain_clip_text_embeddings_norm = nn.functional.normalize(combined_brain_clip_text_embeddings.flatten(1), dim=-1)
    combined_brain_clip_image_embeddings_norm = nn.functional.normalize(combined_brain_clip_image_embeddings.flatten(1), dim=-1)
    combined_brain_clip_image_embeddings_by_text_norm = nn.functional.normalize(combined_brain_clip_image_embeddings_by_text.flatten(1), dim=-1)


    combined_brain_clip_text_embeddings_norm_using_tar = nn.functional.normalize(combined_brain_clip_text_embeddings_using_target.flatten(1), dim=-1)
    combined_brain_clip_image_embeddings_norm_using_tar = nn.functional.normalize(combined_brain_clip_image_embeddings_using_target.flatten(1), dim=-1)
    combined_brain_clip_image_embeddings_by_text_norm_using_tar = nn.functional.normalize(combined_brain_clip_image_embeddings_by_text_using_target.flatten(1), dim=-1)

    similarity_text_after_using_tar = batchwise_cosine_similarity(clip_text_target_norm , combined_brain_clip_text_embeddings_norm_using_tar)
    similarity_vision_after_using_tar = batchwise_cosine_similarity(clip_image_target_norm , combined_brain_clip_image_embeddings_norm_using_tar)
    similarity_vision_after_by_text_using_tar = batchwise_cosine_similarity(clip_image_target_norm , combined_brain_clip_image_embeddings_by_text_norm_using_tar)
    print('Similarity to the target after combining with target-based retrieval:')
    print('Similarity_text_combined_with_tar_using_tar:', torch.topk(similarity_text_after_using_tar.flatten(), 1).values)
    print('Similarity_vision_combined_with_combined_using_tar', torch.topk(similarity_vision_after_using_tar.flatten(), 1).values)
    print('Similarity_vision_combined_with_combined_by_text_using_tar', torch.topk(similarity_vision_after_by_text_using_tar.flatten(), 1).values)

    # print('Similarity_text_combined_with_tar:', torch.topk(similarity_text_after.flatten(), 1).values)
    # print('Similarity_vision_combined_with_combined', torch.topk(similarity_vision_after.flatten(), 1).values)
    # print('Similarity_vision_combined_with_combined_by_text', torch.topk(similarity_vision_after_by_text.flatten(), 1).values)
    # similarity_text_before = pred_embedding_text_norm @ clip_text_target_norm.T
    # similarity_vision_before = pred_embedding_vision_norm @ clip_image_target_norm.T

    similarity_text_before = batchwise_cosine_similarity(pred_embedding_text_norm, clip_text_target_norm)
    similarity_vision_before = batchwise_cosine_similarity(pred_embedding_vision_norm, clip_image_target_norm)

    # similarity_text_after = clip_text_target_norm @ combined_brain_clip_text_embeddings_norm.T
    # similarity_vision_after = clip_image_target_norm @ combined_brain_clip_image_embeddings_norm.T
    # similarity_vision_after_by_text = clip_image_target_norm @ combined_brain_clip_image_embeddings_by_text_norm.T

    similarity_text_after = batchwise_cosine_similarity(clip_text_target_norm , combined_brain_clip_text_embeddings_norm)
    similarity_vision_after = batchwise_cosine_similarity(clip_image_target_norm , combined_brain_clip_image_embeddings_norm)
    similarity_vision_after_by_text = batchwise_cosine_similarity(clip_image_target_norm , combined_brain_clip_image_embeddings_by_text_norm)
    print('Similarity between the prediction and the target')
    print('Similarity_text_pred_with_tar:', torch.topk(similarity_text_before.flatten(), 1).values)
    print('Similarity_vision_pred_with_tar', torch.topk(similarity_vision_before.flatten(), 1).values)
    print('Similarity to the target after combining with prediction-based retrieval')
    print('Similarity_text_combined_with_tar:', torch.topk(similarity_text_after.flatten(), 1).values)
    print('Similarity_vision_combined_with_tar', torch.topk(similarity_vision_after.flatten(), 1).values)
    print('Similarity_vision_combined_with_tar_by_text', torch.topk(similarity_vision_after_by_text.flatten(), 1).values)

    # similarity_text_after_combine = batchwise_cosine_similarity(clip_text_train_norm , combined_brain_clip_text_embeddings_norm)
    # topk_index_text_after_combine = torch.topk(similarity_text_after_combine.flatten(), 1).indices

    # print('Top indices text after combine:',topk_index_text_after_combine,'Top similarity_text_after_combine:', torch.topk(similarity_text_after_combine.flatten(), 1).values)
    # combined_brain_clip_image_embeddings_by_text_after_combine = (1-alpha) * clip_vision_train[topk_index_text] + alpha * pred_embedding_vision
    # combined_brain_clip_image_embeddings_by_text_after_combine_norm = nn.functional.normalize(combined_brain_clip_image_embeddings_by_text_after_combine.flatten(1), dim=-1)
    # similarity_vision_after_by_text_after_combine = batchwise_cosine_similarity(clip_image_target_norm , combined_brain_clip_image_embeddings_by_text_after_combine_norm)
    # print('Similarity_vision_combined_with_combined_by_text_after_combine', torch.topk(similarity_vision_after_by_text_after_combine.flatten(), 1).values)


    # NOTE: the returned embeddings use target-based retrieval, i.e. the nearest train
    # neighbours of the ground-truth CLIP embeddings of the test image.
    # return combined_brain_clip_image_embeddings, combined_brain_clip_text_embeddings
    return combined_brain_clip_image_embeddings_using_target, combined_brain_clip_text_embeddings_using_target

@torch.no_grad()
def reconstruction(
    args,
    image, voxel, captions,
    clip_vision_train, clip_text_train,
    clip_vision_train_norm, clip_text_train_norm,
    voxel2clip,
    clip_extractor,
    unet, vae, noise_scheduler,
    img_lowlevel = None,
    num_inference_steps = 50,
    recons_per_sample = 1,
    guidance_scale = 7.5,
    img2img_strength = .85,
    seed = 42,
    plotting=True,
    verbose=False,
    n_samples_save=1,
    device = None,
    mem_efficient = True,
    retrival_from_memory = False,

):
    assert n_samples_save==1, "n_samples_save must be 1. Function must be called one image at a time"
    assert recons_per_sample>0, "recons_per_sample must be > 0"

    brain_recons = None

    voxel=voxel[:n_samples_save]
    image=image[:n_samples_save]
    B = voxel.shape[0]

    clip_image_target = clip_extractor.embed_image(image)
    clip_text_target = clip_extractor.embed_text(captions)

    clip_image_target_norm = nn.functional.normalize(clip_image_target.flatten(1), dim=-1)
    clip_text_target_norm = nn.functional.normalize(clip_text_target.flatten(1), dim=-1)

    if mem_efficient:
        clip_extractor.to("cpu")
        unet.to("cpu")
        vae.to("cpu")
    else:
        clip_extractor.to(device)
        unet.to(device)
        vae.to(device)

    if unet is not None:
        do_classifier_free_guidance = guidance_scale > 1.0
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        height = unet.config.sample_size * vae_scale_factor
        width = unet.config.sample_size * vae_scale_factor
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)

    if voxel2clip is not None:
        ridge_out = voxel2clip.ridge(voxel)
        clip_results = voxel2clip.backbone(ridge_out)
        if mem_efficient:
            voxel2clip.to('cpu')
        # brain_clip_text_embeddings = clip_extractor.embed_text(captions).float()
        brain_clip_image_embeddings, brain_clip_text_embeddings = clip_results[:2]
        if retrival_from_memory:
            brain_clip_image_embeddings, brain_clip_text_embeddings = combine_with_memory(
                brain_clip_image_embeddings, brain_clip_text_embeddings,
                clip_vision_train, clip_text_train,
                clip_vision_train_norm, clip_text_train_norm,
                clip_image_target, clip_text_target,
                clip_image_target_norm, clip_text_target_norm)
        brain_clip_image_embeddings = brain_clip_image_embeddings.reshape(B,-1,768)
        brain_clip_text_embeddings =
brain_clip_text_embeddings.reshape(B,-1,768) 496 | 497 | brain_clip_image_embeddings = brain_clip_image_embeddings.repeat(recons_per_sample, 1, 1) 498 | brain_clip_text_embeddings = brain_clip_text_embeddings.repeat(recons_per_sample, 1, 1) 499 | 500 | if recons_per_sample > 0: 501 | for samp in range(len(brain_clip_image_embeddings)): 502 | brain_clip_image_embeddings[samp] = brain_clip_image_embeddings[samp]/(brain_clip_image_embeddings[samp,0].norm(dim=-1).reshape(-1, 1, 1) + 1e-6) 503 | brain_clip_text_embeddings[samp] = brain_clip_text_embeddings[samp]/(brain_clip_text_embeddings[samp,0].norm(dim=-1).reshape(-1, 1, 1) + 1e-6) 504 | input_embedding = brain_clip_image_embeddings#.repeat(recons_per_sample, 1, 1) 505 | if verbose: print("input_embedding",input_embedding.shape) 506 | 507 | prompt_embeds = brain_clip_text_embeddings 508 | if verbose: print("prompt_embedding",prompt_embeds.shape) 509 | 510 | if do_classifier_free_guidance: 511 | input_embedding = torch.cat([torch.zeros_like(input_embedding), input_embedding]).to(device).to(unet.dtype) 512 | prompt_embeds = torch.cat([torch.zeros_like(prompt_embeds), prompt_embeds]).to(device).to(unet.dtype) 513 | 514 | # 3. dual_prompt_embeddings 515 | input_embedding = torch.cat([prompt_embeds, input_embedding], dim=1) 516 | 517 | # 4. Prepare timesteps 518 | noise_scheduler.set_timesteps(num_inference_steps=num_inference_steps, device=device) 519 | 520 | # 5b. 
Prepare latent variables 521 | batch_size = input_embedding.shape[0] // 2 # divide by 2 bc we doubled it for classifier-free guidance 522 | shape = (batch_size, unet.in_channels, height // vae_scale_factor, width // vae_scale_factor) 523 | if img_lowlevel is not None: # use img_lowlevel for img2img initialization 524 | init_timestep = min(int(num_inference_steps * img2img_strength), num_inference_steps) 525 | t_start = max(num_inference_steps - init_timestep, 0) 526 | timesteps = noise_scheduler.timesteps[t_start:] 527 | latent_timestep = timesteps[:1].repeat(batch_size) 528 | 529 | if verbose: print("img_lowlevel", img_lowlevel.shape) 530 | img_lowlevel_embeddings = clip_extractor.normalize(img_lowlevel) 531 | if verbose: print("img_lowlevel_embeddings", img_lowlevel_embeddings.shape) 532 | if mem_efficient: 533 | vae.to(device) 534 | init_latents = vae.encode(img_lowlevel_embeddings.to(device).to(vae.dtype)).latent_dist.sample(generator) 535 | init_latents = vae.config.scaling_factor * init_latents 536 | init_latents = init_latents.repeat(recons_per_sample, 1, 1, 1) 537 | 538 | noise = torch.randn([recons_per_sample, 4, 64, 64], device=device, 539 | generator=generator, dtype=input_embedding.dtype) 540 | init_latents = noise_scheduler.add_noise(init_latents, noise, latent_timestep) 541 | latents = init_latents 542 | else: 543 | timesteps = noise_scheduler.timesteps 544 | latents = torch.randn([recons_per_sample, 4, 64, 64], device=device, 545 | generator=generator, dtype=input_embedding.dtype) 546 | latents = latents * noise_scheduler.init_noise_sigma 547 | 548 | # 7. 
Denoising loop 549 | if mem_efficient: 550 | unet.to(device) 551 | for i, t in enumerate(timesteps): 552 | # expand the latents if we are doing classifier free guidance 553 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 554 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 555 | if verbose: print("timesteps: {}, latent_model_input: {}, input_embedding: {}".format(i, latent_model_input.shape, input_embedding.shape)) 556 | noise_pred = unet(latent_model_input, t, encoder_hidden_states=input_embedding).sample 557 | 558 | # perform guidance 559 | if do_classifier_free_guidance: 560 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 561 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 562 | 563 | # compute the previous noisy sample x_t -> x_t-1 564 | latents = noise_scheduler.step(noise_pred, t, latents).prev_sample 565 | 566 | if mem_efficient: 567 | unet.to("cpu") 568 | 569 | recons = decode_latents(latents.to(device),vae.to(device)).detach().cpu() 570 | 571 | brain_recons = recons.unsqueeze(0) 572 | 573 | if verbose: print("brain_recons",brain_recons.shape) 574 | 575 | # pick best reconstruction out of several 576 | best_picks = np.zeros(n_samples_save).astype(np.int16) 577 | 578 | if mem_efficient: 579 | vae.to("cpu") 580 | unet.to("cpu") 581 | clip_extractor.to(device) 582 | 583 | # clip_image_target = clip_extractor.embed_image(image) 584 | # clip_image_target_norm = nn.functional.normalize(clip_image_target.flatten(1), dim=-1) 585 | sims=[] 586 | for im in range(recons_per_sample): 587 | currecon = clip_extractor.embed_image(brain_recons[0,[im]].float()).to(clip_image_target_norm.device).to(clip_image_target_norm.dtype) 588 | currecon = nn.functional.normalize(currecon.view(len(currecon),-1),dim=-1) 589 | # import ipdb;ipdb.set_trace() 590 | cursim = batchwise_cosine_similarity(clip_image_target_norm,currecon) 591 | 
sims.append(cursim.item()) 592 | if verbose: print(sims) 593 | best_picks[0] = int(np.nanargmax(sims)) 594 | if verbose: print(best_picks) 595 | if mem_efficient: 596 | clip_extractor.to("cpu") 597 | voxel2clip.to(device) 598 | 599 | img2img_samples = 0 if img_lowlevel is None else 1 600 | num_xaxis_subplots = 1+img2img_samples+recons_per_sample 601 | if plotting: 602 | fig, ax = plt.subplots(n_samples_save, num_xaxis_subplots, 603 | figsize=(num_xaxis_subplots*5,6*n_samples_save),facecolor=(1, 1, 1)) 604 | else: 605 | fig = None 606 | recon_img = None 607 | 608 | im = 0 609 | if plotting: 610 | ax[0].set_title(f"Original Image") 611 | ax[0].imshow(torch_to_Image(image[im])) 612 | if img2img_samples == 1: 613 | ax[1].set_title(f"Img2img ({img2img_strength})") 614 | ax[1].imshow(torch_to_Image(img_lowlevel[im].clamp(0,1))) 615 | for ii,i in enumerate(range(num_xaxis_subplots-recons_per_sample,num_xaxis_subplots)): 616 | recon = brain_recons[im][ii] 617 | if plotting: 618 | if ii == best_picks[im]: 619 | ax[i].set_title(f"Reconstruction",fontweight='bold') 620 | recon_img = recon 621 | else: 622 | ax[i].set_title(f"Recon {ii+1} from brain") 623 | ax[i].imshow(torch_to_Image(recon)) 624 | if plotting: 625 | for i in range(num_xaxis_subplots): 626 | ax[i].axis('off') 627 | 628 | return fig, brain_recons, best_picks, recon_img, ridge_out --------------------------------------------------------------------------------
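The commented-out lines near the end of `combine_with_memory` spell out a combination rule: a top-1 cosine-similarity lookup in a bank of training CLIP embeddings, followed by `(1-alpha) * retrieved + alpha * pred`. The pure-Python sketch below assumes exactly that rule; `combine_with_memory_sketch`, `memory`, and `alpha` are hypothetical names rather than the repository's API, and real inputs would be `torch` tensors rather than lists.

```python
import math
from typing import List, Sequence, Tuple


def _l2_normalize(v: Sequence[float], eps: float = 1e-12) -> List[float]:
    """Scale a vector to unit L2 norm (eps guards against division by zero)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / max(norm, eps) for x in v]


def combine_with_memory_sketch(pred: Sequence[float],
                               memory: List[Sequence[float]],
                               alpha: float) -> Tuple[int, List[float]]:
    """Hypothetical helper: retrieve the memory entry most similar to `pred`
    (top-1 cosine similarity) and blend it with `pred`.

    Returns (retrieved_index, (1 - alpha) * memory[idx] + alpha * pred).
    """
    pred_n = _l2_normalize(pred)
    sims = [sum(a * b for a, b in zip(pred_n, _l2_normalize(m))) for m in memory]
    idx = max(range(len(sims)), key=sims.__getitem__)  # top-1 retrieval
    blended = [(1 - alpha) * m + alpha * p for m, p in zip(memory[idx], pred)]
    return idx, blended


idx, blended = combine_with_memory_sketch([1.0, 0.0], [[0.0, 2.0], [3.0, 0.5]], alpha=0.5)
print(idx, blended)  # 1 [2.0, 0.25]
```

An `alpha` near 1 keeps mostly the brain-predicted embedding, while an `alpha` near 0 leans on the retrieved training embedding.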
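Before conditioning the UNet, `reconstruction` divides every token of each sample by the L2 norm of that sample's first token (plus `1e-6`), so the first token ends up with roughly unit norm while the relative scale between tokens is preserved. A list-based sketch of that rescaling for a single sample; `scale_by_first_token_norm` is a hypothetical name.

```python
import math
from typing import List, Sequence


def scale_by_first_token_norm(tokens: Sequence[Sequence[float]],
                              eps: float = 1e-6) -> List[List[float]]:
    """Divide all tokens by ||tokens[0]||_2 + eps, as in the per-sample loop of `reconstruction`."""
    scale = math.sqrt(sum(x * x for x in tokens[0])) + eps
    return [[x / scale for x in tok] for tok in tokens]


scaled = scale_by_first_token_norm([[3.0, 4.0], [6.0, 8.0]])
print([[round(x, 4) for x in tok] for tok in scaled])  # [[0.6, 0.8], [1.2, 1.6]]
```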
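Two pieces of arithmetic in the sampling code of `reconstruction` are worth making concrete. With `img_lowlevel`, only the last `int(num_inference_steps * img2img_strength)` scheduler timesteps are run. Classifier-free guidance mixes the unconditional and conditional noise predictions as `uncond + guidance_scale * (cond - uncond)`. The sketch below reproduces both on plain floats; the function names are hypothetical, and the step count assumes the scheduler exposes exactly `num_inference_steps` timesteps.

```python
from typing import List, Sequence, Tuple


def img2img_schedule(num_inference_steps: int, strength: float) -> Tuple[int, int]:
    """Return (t_start, steps_run), using the same formulas as `reconstruction`."""
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return t_start, num_inference_steps - t_start


def classifier_free_guidance(noise_uncond: Sequence[float],
                             noise_cond: Sequence[float],
                             guidance_scale: float) -> List[float]:
    """Elementwise uncond + guidance_scale * (cond - uncond)."""
    return [u + guidance_scale * (c - u) for u, c in zip(noise_uncond, noise_cond)]


print(img2img_schedule(50, 0.85))                             # (8, 42)
print(classifier_free_guidance([0.0, 1.0], [1.0, 1.0], 7.5))  # [7.5, 1.0]
```

With the defaults (`num_inference_steps=50`, `img2img_strength=.85`), the loop therefore skips the first 8 timesteps and denoises for 42. A `guidance_scale` of 1.0 reduces the mix to the conditional prediction alone, which is why the code only enables guidance when `guidance_scale > 1.0`.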
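The best-of-N step in `reconstruction` scores each of the `recons_per_sample` reconstructions by the cosine similarity between its CLIP image embedding and the ground-truth image embedding, then keeps `np.nanargmax` of the scores. Below is a dependency-free sketch of that selection over plain vectors. `cosine_similarity` and `pick_best_reconstruction` are hypothetical stand-ins for `batchwise_cosine_similarity` and the `sims` loop. As with `np.nanargmax`, NaN scores are skipped and ties go to the first index.

```python
import math
from typing import List, Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float], eps: float = 1e-12) -> float:
    """Cosine similarity of two flat vectors (eps guards against zero norms)."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / max(norms, eps)


def pick_best_reconstruction(target: Sequence[float],
                             candidates: List[Sequence[float]]) -> int:
    """Index of the candidate most similar to `target`, ignoring NaN scores."""
    sims = [cosine_similarity(target, c) for c in candidates]
    valid = [i for i, s in enumerate(sims) if not math.isnan(s)]
    if not valid:
        raise ValueError("all similarity scores are NaN")
    return max(valid, key=sims.__getitem__)  # first index wins on ties


print(pick_best_reconstruction([1.0, 0.0], [[0.0, 1.0], [0.9, 0.1], [1.0, 1.0]]))  # 1
```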