├── metrics ├── mIOU │ ├── __init__.py │ ├── segment_vitadapter.sh │ ├── evaluate_img_gen.py │ ├── eval_miou.py │ └── segment_vitadapter.py └── image_metrics.py ├── custom_infra.yaml ├── setup.sh ├── LICENSE ├── CONTRIBUTING.md ├── setup_miou.sh ├── blur.py ├── config ├── config.yaml └── cfg.py ├── eval.py ├── CODE_OF_CONDUCT.md ├── readme.md ├── model ├── peft_utils.py ├── fmri_mlp.py └── models.py ├── prepare_data.py ├── dynadiff.py └── data.py /metrics/mIOU/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /metrics/mIOU/segment_vitadapter.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | eval "$(conda shell.bash hook)" 10 | conda activate miou 11 | echo "$(which pip)" 12 | echo "$2" "$1" 13 | 14 | CONFIG="ViT-Adapter/segmentation/configs/coco_stuff164k/mask2former_beitv2_adapter_large_896_80k_cocostuff164k_ss.py" 15 | CHECKPOINT="ViT-Adapter/segmentation/models/mask2former_beitv2_adapter_large_896_80k_cocostuff164k.pth" 16 | 17 | 18 | python metrics/mIOU/segment_vitadapter.py $CONFIG $CHECKPOINT --work-dir "$1" -------------------------------------------------------------------------------- /custom_infra.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | task_infra_data: 8 | cluster: auto 9 | folder: ./cache 10 | cpus_per_task: 10 11 | gpus_per_node: 0 12 | max_jobs: 128 13 | min_samples_per_job: 8 14 | slurm_partition: null 15 | slurm_constraint: null 16 | timeout_min: 30 17 | 18 | map_infra_image_generation: 19 | cluster: auto 20 | folder: ./cache 21 | cpus_per_task: 8 22 | min_samples_per_job: 10 23 | max_jobs: 128 24 | gpus_per_node: 1 25 | slurm_partition: null 26 | slurm_constraint: null 27 | timeout_min: 25 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | pip install torch==2.4.0 torchvision==0.19 torchaudio==2.4.0 scikit-image==0.24.0 10 | 11 | git clone git@github.com:huggingface/diffusers.git 12 | pushd diffusers 13 | git checkout v0.32.0-release 14 | patch -p1 -d . < ../vd_patch.diff 15 | pip install -e . 
16 | popd 17 | 18 | pip install lightning==2.4.0 deepspeed==0.15.1 x_transformers 19 | pip install xformers==0.0.28.post1 hydra-core==1.3.2 h5py==3.11.0 peft==0.13.0 20 | pip install git+https://github.com/openai/CLIP.git 21 | pip install retina-face 22 | 23 | # Downgrade tensorflow 24 | pip install tensorflow==2.13.0 25 | 26 | # Compatible dreamsim version 27 | pip install dreamsim==0.2.0 28 | 29 | pip install typing-extensions==4.13.2 exca nibabel nilearn -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. and affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to dynadiff 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to dynadiff, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 
-------------------------------------------------------------------------------- /setup_miou.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # All rights reserved. 5 | # 6 | # This source code is licensed under the license found in the 7 | # LICENSE file in the root directory of this source tree. 8 | 9 | git clone git@github.com:czczup/ViT-Adapter.git 10 | 11 | # # Installing MMSegmentation v0.20.2. as per https://github.com/czczup/ViT-Adapter/tree/main/segmentation#usage 12 | # # recommended environment: torch1.9 + cuda11.1 13 | pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html 14 | pip install mmcv-full==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html 15 | pip install timm==0.4.12 16 | # for Mask2Former 17 | pip install mmdet==2.22.0 18 | pip install mmsegmentation==0.20.2 19 | pushd ViT-Adapter/segmentation 20 | ln -s ../detection/ops ./ 21 | cd ops 22 | # compile deformable attention 23 | sh make.sh 24 | popd 25 | 26 | # Additional dependencies 27 | pip install tqdm scipy==1.7.3 28 | 29 | # Download checkpoint 30 | wget -O mask2former_beitv2_adapter_large_896_80k_cocostuff164k.zip https://github.com/czczup/ViT-Adapter/releases/download/v0.3.1/mask2former_beitv2_adapter_large_896_80k_cocostuff164k.zip && unzip mask2former_beitv2_adapter_large_896_80k_cocostuff164k.zip -d ViT-Adapter/segmentation/models/ && rm mask2former_beitv2_adapter_large_896_80k_cocostuff164k.zip -------------------------------------------------------------------------------- /blur.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
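#
# (annotation) Minimal usage sketch for this module, assuming a local RGB image
# file (the path below is hypothetical). `blur_faces` detects faces with
# RetinaFace and pixelates each detected region; the input image is returned
# unchanged when no face is found:
#
#     from PIL import Image
#     from blur import blur_faces
#
#     img = Image.open("photo.png").convert("RGB")
#     blurred = blur_faces(img)
#     blurred.save("photo_blurred.png")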
6 | 7 | import numpy as np 8 | from PIL import Image, ImageDraw 9 | from retinaface import RetinaFace 10 | 11 | 12 | def get_mask(image, faces, erosion=0.1): 13 | mask = Image.new(mode="L", size=image.size, color="black") 14 | width, height = image.size 15 | bboxes = [] 16 | for face in faces: 17 | x0, y0, x1, y1 = face 18 | if x0 > x1: 19 | x0, x1 = x1, x0 20 | if y0 > y1: 21 | y0, y1 = y1, y0 22 | x0 = min(x0, width) 23 | x1 = min(x1, width) 24 | y0 = min(y0, height) 25 | y1 = min(y1, height) 26 | bbox = [x0, y0, x1, y1] 27 | 28 | diagonal = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) 29 | bbox = [ 30 | bbox[0] - erosion * diagonal, 31 | bbox[1] - erosion * diagonal, 32 | bbox[2] + erosion * diagonal, 33 | bbox[3] + erosion * diagonal, 34 | ] 35 | draw = ImageDraw.Draw(mask) 36 | draw.rectangle(bbox, fill="white") 37 | bboxes.append(bbox) 38 | return mask, bboxes 39 | 40 | def blur_faces( 41 | img: Image.Image 42 | ) -> Image.Image: 43 | resp = RetinaFace.detect_faces(np.array(img)) 44 | faces = [face["facial_area"] for face in resp.values()] 45 | if len(faces) == 0: 46 | return img 47 | 48 | mask, bboxes = get_mask(img, faces) 49 | image_width = img.width 50 | image_height = img.height 51 | background = Image.new('RGB', (image_width, image_height)) 52 | for bbox in bboxes: 53 | img_face_crop = img.crop((bbox[0], bbox[1], bbox[2], bbox[3])) 54 | width, height = img_face_crop.size 55 | small_image = img_face_crop.resize((4, 4), Image.BILINEAR) 56 | pixelated_image = small_image.resize((width, height), Image.NEAREST) 57 | background.paste(pixelated_image, (int(bbox[0]), int(bbox[1]))) 58 | 59 | 60 | return Image.composite(background, img, mask) 61 | 62 | 63 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | batch_size: 1 3 | nsd_dataset_config: 4 | nsddata_path: ./nsddata 5 | subject_id: 1 6 | seed: 42 7 | averaged: true 8 | offset: 4.6 9 | duration: 8.0 10 | infra: 11 | cluster: slurm 12 | conda_env: null 13 | cpus_per_task: 10 14 | folder: ./cache 15 | forbid_single_item_computation: false 16 | gpus_per_node: 0 17 | job_name: dynadiff-prepare 18 | keep_in_ram: true 19 | logs: '{folder}/logs/{user}/%j' 20 | mem_gb: null 21 | min_samples_per_job: 8 22 | mode: cached 23 | nodes: 1 24 | permissions: 511 25 | slurm_additional_parameters: null 26 | slurm_use_srun: false 27 | tasks_per_node: 1 28 | timeout_min: 120 29 | version: '1' 30 | workdir: null 31 | name: NeuroImagesDataModuleConfig 32 | pin_memory: true 33 | # test_groupbyimg: averaged 34 | workers: 10 35 | infra: 36 | cluster: null 37 | conda_env: null 38 | cpus_per_task: 10 39 | folder: ./cache 40 | gpus_per_node: null 41 | job_name: dynadiff 42 | keep_in_ram: false 43 | logs: '{folder}/logs/{user}/%j' 44 | mem_gb: null 45 | mode: force 46 | permissions: 511 47 | slurm_additional_parameters: null 48 | slurm_constraint: null 49 | slurm_partition: null 50 | slurm_use_srun: false 51 | timeout_min: 4320 52 | version: '1' 53 | workdir: null 54 | seed: 33 55 | strategy: deepspeed 56 | versatilediffusion_config: 57 | brain_modules_config.clip_image: 58 | act_first: false 59 | blurry_recon: false 60 | deep_subject_layers: false 61 | hidden: 1552 62 | n_blocks: 0 63 | n_repetition_times: 6 64 | n_subjects: 1 65 | name: FmriMLP 66 | native_fmri_space: false 67 | norm_type: ln 68 | out_dim: null 69 | subject_layers: true 70 | subject_layers_dim: hidden 71 | subject_layers_id: false 72 | time_agg: 
out_linear 73 | tr_embed_dim: 16 74 | use_tr_embeds: false 75 | use_tr_layer: true 76 | diffusion_noise_offset: true 77 | drop_rate_clsfree: 0.1 78 | in_dim: 15724 79 | name: VersatileDiffusion 80 | noise_cubic_sampling: true 81 | num_inference_steps: 20 82 | prediction_type: epsilon 83 | trainable_unet_layers: lora 84 | training_strategy: w/_difloss 85 | vd_cache_dir: null 86 | 87 | -------------------------------------------------------------------------------- /config/cfg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from exca import ConfDict 8 | 9 | 10 | def get_cfg( 11 | subject: int, 12 | averaged_trial: bool, 13 | # save_recons_to: str, 14 | cache: str, 15 | seed: int, 16 | vd_cache_dir: str, 17 | custom_infra: dict | None = None, 18 | ) -> ConfDict: 19 | """ 20 | Get the configuration for the evaluation task. 21 | 22 | Args: 23 | subject (int): Subject number. 24 | averaged_trial (bool): Flag to indicate if averaged trial is used. 25 | cache (str): Directory for caching data. 26 | seed (int): Seed for RNG. 27 | vd_cache_dir (str): Directory for caching Versatile Diffusion. 28 | custom_infra (dict, optional): Custom TaskInfra/MapInfra configuration. Defaults to None (compute locally). 29 | 30 | Returns: 31 | ConfDict: Configuration dictionary for the evaluation task. 32 | """ 33 | 34 | with open("config/config.yaml", "r") as f: 35 | config = ConfDict.from_yaml(f) 36 | 37 | config["versatilediffusion_config.vd_cache_dir"] = vd_cache_dir 38 | config["seed"] = seed 39 | config["data.nsd_dataset_config.seed"] = seed 40 | config["data.nsd_dataset_config.averaged"] = averaged_trial 41 | config["data.nsd_dataset_config.subject_id"] = subject 42 | 43 | local_infra = { 44 | "cluster": None, 45 | "folder": cache, 46 | } 47 | 48 | config["infra"] = local_infra 49 | 50 | if custom_infra is not None: 51 | assert all( 52 | [ 53 | key 54 | in [ 55 | "task_infra_data", 56 | "map_infra_image_generation", 57 | ] 58 | for key in custom_infra 59 | ] 60 | ), "Infra can be specified only for 'task_infra_data' preparation and 'map_infra_image_generation'" 61 | 62 | 63 | 64 | config["data.nsd_dataset_config.infra"] = custom_infra["task_infra_data"] if custom_infra is not None else local_infra 65 | config["image_generation_infra"] = custom_infra["map_infra_image_generation"] if custom_infra is not None else local_infra 66 | 67 | 68 | 69 | return config 70 | -------------------------------------------------------------------------------- /metrics/mIOU/evaluate_img_gen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
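#
# (annotation) `compute_miou` below saves the ground-truth and reconstructed
# images into a temporary folder, segments both sets with ViT-Adapter by calling
# `segment_vitadapter.sh` (which switches to the separate `miou` conda env), and
# scores the resulting label maps with `SegmentationMetric` over 172 COCO-Stuff
# classes. A hedged usage sketch with PIL images, mirroring how dynadiff.py
# calls this function (file names are hypothetical):
#
#     from PIL import Image
#     from metrics.mIOU.evaluate_img_gen import compute_miou
#
#     recons = [Image.open("recon_0.png").convert("RGB")]
#     stims = [Image.open("stim_0.png").convert("RGB")]
#     score = compute_miou(preds=recons, trues=stims, eval_res=512)
#     print(f"mIoU: {score}")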
6 | 7 | 8 | import os 9 | import subprocess 10 | import tempfile 11 | import typing as tp 12 | from pathlib import Path 13 | 14 | import numpy as np 15 | import torch 16 | from PIL import Image 17 | from torchvision import transforms as TR 18 | from torchvision.transforms.functional import InterpolationMode 19 | from tqdm import tqdm, trange 20 | 21 | from .eval_miou import SegmentationMetric 22 | 23 | 24 | def compute_miou( 25 | preds: tp.List[str | Path | Image.Image], 26 | trues: tp.List[str | Path | Image.Image], 27 | eval_res: int = 512, 28 | ): 29 | 30 | def openLbl(path): 31 | lbl = Image.open(str(path)) 32 | lbl = ( 33 | TR.Resize((eval_res, eval_res), interpolation=InterpolationMode.NEAREST)( 34 | TR.functional.to_tensor(lbl) 35 | ) 36 | * 255 37 | ) 38 | 39 | return lbl 40 | 41 | def process_img(img): 42 | if isinstance(img, (str, Path)): 43 | img = Image.open(img).convert("RGB") 44 | img = img.resize((eval_res, eval_res)) 45 | return img 46 | 47 | preds = [process_img(x.copy()) for x in preds] 48 | trues = [process_img(x.copy()) for x in trues] 49 | with tempfile.TemporaryDirectory() as folder: 50 | folder = Path(folder) 51 | folder_gt = folder / "gts" / "images" 52 | folder_pred = folder / "preds" / "images" 53 | folder_gt.mkdir(exist_ok=True, parents=True) 54 | folder_pred.mkdir(exist_ok=True, parents=True) 55 | pred_labels_gt = folder_gt.parent.joinpath("pred_label") 56 | pred_labels_gt.mkdir(exist_ok=True) 57 | pred_labels_pred = folder_pred.parent.joinpath("pred_label") 58 | pred_labels_pred.mkdir(exist_ok=True) 59 | 60 | images_dir = os.path.join(folder, "images") 61 | os.makedirs(images_dir) 62 | 63 | for idx, (pred, true) in tqdm(enumerate(zip(preds, trues)), total=len(trues)): 64 | pred.save(folder_pred / f"pred_{idx}.png") 65 | true.save(folder_gt / f"gt_{idx}.png") 66 | 67 | bashCommand = f"bash metrics/mIOU/segment_vitadapter.sh {str(folder_gt.parent)}" 68 | process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) 69 | output, error = process.communicate() 70 | print(output, error) 71 | bashCommand = f"bash metrics/mIOU/segment_vitadapter.sh {str(folder_pred.parent)}" 72 | process = subprocess.Popen(bashCommand.split(), stdout=subprocess.PIPE) 73 | output, error = process.communicate() 74 | print(output, error) 75 | 76 | seg_metric = SegmentationMetric(172) 77 | 78 | tot_miou = [] 79 | for idx in trange(len(trues)): 80 | pred_lbl_path = pred_labels_pred / f"pred_{idx}.png" 81 | gt_lbl_path = pred_labels_gt / f"gt_{idx}.png" 82 | pred_lbl = openLbl(pred_lbl_path) 83 | gt_lbl = openLbl(gt_lbl_path) 84 | 85 | 86 | bool_mask = torch.ones_like(gt_lbl, dtype=torch.bool) 87 | 88 | seg_metric.update(pred_lbl, gt_lbl, bool_mask) 89 | miou = seg_metric.get() 90 | seg_metric.reset() 91 | tot_miou.append(miou) 92 | return np.mean(tot_miou) 93 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
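#
# (annotation) Typical invocations of this script, matching the CLI arguments
# defined below (subjects are restricted to 1, 2, 5 and 7):
#
#     python eval.py --subject 1                                  # single-trial reconstructions
#     python eval.py --subject 2 --averaged-trial                 # reconstruct from averaged test trials
#     python eval.py --subject 5 --infra-yaml custom_infra.yaml   # distribute on SLURM via exca
#     python eval.py --subject 7 --compute-miou                   # additionally report mIoU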
6 | 7 | import argparse 8 | from pathlib import Path 9 | 10 | import yaml 11 | from config.cfg import get_cfg 12 | from dynadiff import DynaDiffEval 13 | 14 | 15 | def evaluate(): 16 | parser = argparse.ArgumentParser(description="Evaluation script for Dynadiff") 17 | parser.add_argument( 18 | "--subject", 19 | type=int, 20 | choices=[1, 2, 5, 7], 21 | help="Subject identifier (must be 1, 2, 5, or 7)", 22 | ) 23 | 24 | parser.add_argument( 25 | "--averaged-trial", 26 | action="store_true", 27 | help="Reconstruct an image from each (1000) averaged test trial instead", 28 | ) 29 | 30 | parser.add_argument( 31 | "--cache", 32 | type=str, 33 | default="./cache", 34 | help="Folder used to prepare and store fMRI data. Defaults to ./cache.", 35 | ) 36 | 37 | parser.add_argument( 38 | "--seed", 39 | type=int, 40 | default=3, 41 | help="Seed for RNG (default: 3)", 42 | ) 43 | 44 | parser.add_argument( 45 | "--vd_cache_dir", 46 | type=str, 47 | default="./versatile_diffusion", 48 | help="Folder to cache Versatile Diffusion. Defaults to ./versatile_diffusion.", 49 | ) 50 | 51 | parser.add_argument( 52 | "--infra-yaml", 53 | type=str, 54 | default=None, 55 | help="Path to infra.yaml config file for data preparation and image generation." 56 | "Defaults to None, i.e. using local compute only", 57 | ) 58 | 59 | parser.add_argument( 60 | "--compute-miou", 61 | action="store_true", 62 | help="Compute mIoU for the stimulus / reconstruction pairs.", 63 | ) 64 | 65 | args = parser.parse_args() 66 | 67 | custom_infra = None 68 | if args.infra_yaml is not None: 69 | print(f"Using custom infra config at: {args.infra_yaml}") 70 | with open(args.infra_yaml, "r") as f: 71 | custom_infra = yaml.safe_load(f) 72 | 73 | print(f"Evaluating subject: sub{args.subject}") 74 | print(f"Using averaged trials: {args.averaged_trial}") 75 | print(f"Preparing data in: {args.cache}") 76 | print(f"Seed: {args.seed}") 77 | print(f"Caching Versatile Diffusion model in : {args.vd_cache_dir}") 78 | 79 | cfg = get_cfg( 80 | args.subject, 81 | args.averaged_trial, 82 | args.cache, 83 | args.seed, 84 | args.vd_cache_dir, 85 | custom_infra, 86 | ) 87 | 88 | average_id = 'averaged' if cfg['data']['nsd_dataset_config']['averaged'] else 'unaveraged' 89 | # subject_id = f"subj{cfg.data.nsd_dataset_config.subject_id:02d}" 90 | subject_id = f"subj{cfg['data']['nsd_dataset_config']['subject_id']:02d}" 91 | 92 | folder = Path(cfg['image_generation_infra']['folder']) / f"reconstructions_{subject_id}_{average_id}" 93 | print( 94 | f"Saving reconstructions to: {folder}" 95 | ) 96 | 97 | task = DynaDiffEval(**cfg) 98 | task.prepare() 99 | task.run() 100 | 101 | if args.compute_miou: 102 | print("Computing mIoU") 103 | task.compute_miou() 104 | 105 | 106 | if __name__ == "__main__": 107 | evaluate() 108 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 
71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | # Dynadiff: Single-stage Decoding of Images from Continuously Evolving fMRI 8 | 9 | This repository contains the official code for evaluating the **Dynadiff** model: **Dy**namic **N**eural **A**ctivity **Diff**usion for Image Reconstruction. 10 | 11 | ## Create the environment 12 | 13 | Create a conda environment for running and evaluating the **Dynadiff** model using the following command. 14 | The repository was tested with CUDA 12.2.0, cuDNN 8.8.1.3 (for CUDA 12.0), and GCC 12.2.0. We strongly recommend using this configuration. 15 | 16 | 17 | ```bash 18 | conda create --name dynadiff python==3.10 -y 19 | conda activate dynadiff 20 | chmod +x setup.sh 21 | ./setup.sh 22 | ``` 23 | 24 | 25 | ## Download Model Weights 26 | 27 | The **Dynadiff** model weights are available on [HuggingFace](https://huggingface.co/facebook/dynadiff). 28 | The `subj*` folders should be stored inside `./checkpoints`. 29 | 30 | ## Prepare the NSD Data 31 | 32 | Access to the [Natural Scenes Dataset (NSD)](https://naturalscenesdataset.org/) requires filling out this [form](https://docs.google.com/forms/d/e/1FAIpQLSduTPeZo54uEMKD-ihXmRhx0hBDdLHNsVyeo_kCb8qbyAkXuQ/viewform) and agreeing to the dataset's Terms & Conditions. 33 | After obtaining access, the data for evaluating **Dynadiff** can be downloaded by running the following command: 34 | 35 | ``` 36 | python prepare_data.py \ 37 | --nsd_bucket "s3://nsd_bucket_name" \ 38 | --path "./nsddata" 39 | ``` 40 | 41 | * By default, the **Dynadiff** evaluation script expects the NSD data to be downloaded and prepared in the folder `./nsddata`, as done with this command. This can be changed to any path `my/nsddata/path` by setting the config key: `data.nsd_dataset_config.nsddata_path: my/nsddata/path`. 42 | * Additional `s3` arguments can be specified using the string argument `--aws_args`. 43 | 44 | 45 | 46 | ## Evaluate the Model 47 | ### Locally 48 | Data preparation does not use GPU acceleration. Image reconstruction requires 8Gb of VRAM. 49 | 50 | The NSD data prepared for the **Dynadiff** evaluation will be cached by `exca` inside the folder `./cache` by default. This can be set to another folder using the config key `data.nsd_dataset_config.infra.folder`. 51 | 52 | To reconstruct images from fMRI timeseries recorded for subject `$SUBJECT_ID` (e.g., `SUBJECT_ID=1`), run: 53 | ```bash 54 | python eval.py --subject $SUBJECT_ID 55 | ``` 56 | 57 | By default, the reconstructed images (`{index}.png`) are stored along the stimuli (`{index}_gt.png`) in the `./cache/reconstructions_{subj_id}_{average_mode}` folder. This can be modified using the key `infra.folder` in `config/config.yaml`. 58 | 59 | ### On a SLURM cluster 60 | Both data preparation and image reconstruction can be distributed and accelerated on a SLURM cluster via the [`exca`](https://github.com/facebookresearch/exca) library. 
61 | The SLURM resources are specified in the YAML file `custom_infra.yaml` using: 62 | 63 | * `task_infra_data` for data preparation (handled by an `exca.TaskInfra` instance). 64 | * `map_infra_image_generation` for image generation (an`exca.MapInfra` instance). 65 | 66 | Evaluation using these resources is launched with: 67 | 68 | ```bash 69 | python eval.py --subject $SUBJECT_ID --infra-yaml custom_infra.yaml 70 | ``` 71 | 72 | 73 | ## Computing mIoU for stimulus / reconstruction pairs 74 | To accommodate incompatibilities between the `dynadiff` environment and requirements for computing the mIoU segmentation metric using the [`ViT-Adapter`](https://github.com/czczup/ViT-Adapter/tree/main/segmentation) repository, we use an additional conda environment `miou`. The `eval.py` script should still be run in the `dynadiff` environment and will automatically switch to the `miou` environment to compute the mIoU metric locally: 75 | 76 | ```bash 77 | conda create --name miou python==3.7 -y 78 | conda activate miou 79 | chmod +x setup_miou.sh 80 | ./setup_miou.sh 81 | ``` 82 | Pass the `--compute-miou` flag to compute mIoU for stimulus / reconstruction pairs: 83 | ```bash 84 | conda activate dynadiff 85 | python eval.py --subject $SUBJECT_ID --compute-miou 86 | ``` 87 | 88 | ## Contributing 89 | See the CONTRIBUTING file for how to help out. 90 | 91 | ## License 92 | `dynadiff` is MIT licensed, as found in the LICENSE file. Also check-out Meta Open Source [Terms of Use](https://opensource.fb.com/legal/terms/) and [Privacy Policy](https://opensource.fb.com/legal/privacy/). 93 | -------------------------------------------------------------------------------- /model/peft_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Copyright 2024 The HuggingFace Inc. team. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 
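#
# (annotation) A minimal, hedged sketch of how the helpers in this module can be
# used to attach a LoRA adapter to a diffusers UNet. The rank and target module
# names below are illustrative assumptions, not the values used by Dynadiff:
#
#     from peft import LoraConfig
#
#     lora_cfg = LoraConfig(r=8, lora_alpha=8, target_modules=["to_q", "to_k", "to_v"])
#     add_adapter(unet, lora_cfg)              # inject LoRA layers and make them active
#     set_adapter_layers(unet, enabled=False)  # temporarily disable all adapters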
20 | 21 | """Code taken from diffusers library (src/diffusers/loaders/peft.py with checkout v0.32.0-release) and modified to ensure compatibility with versatile diffusion model.""" 22 | 23 | from typing import List, Union 24 | 25 | from peft import PeftConfig, inject_adapter_in_model 26 | from peft.tuners.tuners_utils import BaseTunerLayer 27 | 28 | 29 | def set_adapter_layers(model, enabled=True): 30 | from peft.tuners.tuners_utils import BaseTunerLayer 31 | 32 | for module in model.modules(): 33 | if isinstance(module, BaseTunerLayer): 34 | # The recent version of PEFT needs to call `enable_adapters` instead 35 | if hasattr(module, "enable_adapters"): 36 | module.enable_adapters(enabled=enabled) 37 | else: 38 | module.disable_adapters = not enabled 39 | 40 | 41 | def add_adapter(model, adapter_config, adapter_name: str = "default") -> None: 42 | r""" 43 | Adds a new adapter to the current model for training. If no adapter name is passed, a default name is assigned 44 | to the adapter to follow the convention of the PEFT library. 45 | 46 | If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT 47 | [documentation](https://huggingface.co/docs/peft). 48 | 49 | Args: 50 | adapter_config (`[~peft.PeftConfig]`): 51 | The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt 52 | methods. 53 | adapter_name (`str`, *optional*, defaults to `"default"`): 54 | The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. 55 | """ 56 | # check_peft_version(min_version=MIN_PEFT_VERSION) 57 | 58 | # if not is_peft_available(): 59 | # raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.") 60 | 61 | if not model._hf_peft_config_loaded: 62 | model._hf_peft_config_loaded = True 63 | elif adapter_name in model.peft_config: 64 | raise ValueError( 65 | f"Adapter with name {adapter_name} already exists. Please use a different name." 66 | ) 67 | 68 | if not isinstance(adapter_config, PeftConfig): 69 | raise ValueError( 70 | f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." 71 | ) 72 | 73 | # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is 74 | # handled by the `load_lora_layers` or `StableDiffusionLoraLoaderMixin`. Therefore we set it to `None` here. 75 | adapter_config.base_model_name_or_path = None 76 | inject_adapter_in_model(adapter_config, model, adapter_name) 77 | set_adapter(model, adapter_name) 78 | 79 | 80 | def set_adapter(model, adapter_name: Union[str, List[str]]) -> None: 81 | """ 82 | Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters. 83 | 84 | If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT 85 | [documentation](https://huggingface.co/docs/peft). 86 | 87 | Args: 88 | adapter_name (Union[str, List[str]])): 89 | The list of adapters to set or the adapter name in the case of a single adapter. 90 | """ 91 | # check_peft_version(min_version=MIN_PEFT_VERSION) 92 | 93 | if not model._hf_peft_config_loaded: 94 | raise ValueError("No adapter loaded. 
Please load an adapter first.") 95 | 96 | if isinstance(adapter_name, str): 97 | adapter_name = [adapter_name] 98 | 99 | missing = set(adapter_name) - set(model.peft_config) 100 | if len(missing) > 0: 101 | raise ValueError( 102 | f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." 103 | f" current loaded adapters are: {list(model.peft_config.keys())}" 104 | ) 105 | 106 | _adapters_has_been_set = False 107 | 108 | for _, module in model.named_modules(): 109 | if isinstance(module, BaseTunerLayer): 110 | if hasattr(module, "set_adapter"): 111 | module.set_adapter(adapter_name) 112 | # Previous versions of PEFT does not support multi-adapter inference 113 | elif not hasattr(module, "set_adapter") and len(adapter_name) != 1: 114 | raise ValueError( 115 | "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT." 116 | " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`" 117 | ) 118 | else: 119 | module.active_adapter = adapter_name 120 | _adapters_has_been_set = True 121 | 122 | if not _adapters_has_been_set: 123 | raise ValueError( 124 | "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." 125 | ) 126 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import subprocess 9 | from pathlib import Path 10 | 11 | import h5py 12 | import numpy as np 13 | from PIL import Image 14 | from tqdm import trange 15 | 16 | 17 | # Helper to load 'nsd_expdesign.mat' file Copy-pasted, 18 | # from https://github.com/ozcelikfu/brain-diffuser/blob/main/data/prepare_nsddata.py 19 | # Commit 1c07200 20 | def _loadmat(filename): 21 | """ 22 | this function should be called instead of direct spio.loadmat 23 | as it cures the problem of not properly recovering python dictionaries 24 | from mat files. It calls the function check keys to cure all entries 25 | which are still mat-objects 26 | """ 27 | 28 | def _check_keys(d): 29 | """ 30 | checks if entries in dictionary are mat-objects. If yes 31 | todict is called to change them to nested dictionaries 32 | """ 33 | for key in d: 34 | if isinstance(d[key], spio.matlab.mat_struct): 35 | d[key] = _todict(d[key]) 36 | return d 37 | 38 | def _todict(matobj): 39 | """ 40 | A recursive function which constructs from matobjects nested dictionaries 41 | """ 42 | d = {} 43 | for strg in matobj._fieldnames: 44 | elem = matobj.__dict__[strg] 45 | if isinstance(elem, spio.matlab): 46 | d[strg] = _todict(elem) 47 | elif isinstance(elem, np.ndarray): 48 | d[strg] = _tolist(elem) 49 | else: 50 | d[strg] = elem 51 | return d 52 | 53 | def _tolist(ndarray): 54 | """ 55 | A recursive function which constructs lists from cellarrays 56 | (which are loaded as numpy ndarrays), recursing into the elements 57 | if they contain matobjects. 
58 | """ 59 | elem_list = [] 60 | for sub_elem in ndarray: 61 | if isinstance(sub_elem, spio.matlab.mio5_params.mat_struct): 62 | elem_list.append(_todict(sub_elem)) 63 | elif isinstance(sub_elem, np.ndarray): 64 | elem_list.append(_tolist(sub_elem)) 65 | else: 66 | elem_list.append(sub_elem) 67 | return elem_list 68 | 69 | import scipy.io as spio 70 | 71 | data = spio.loadmat(filename, struct_as_record=False, squeeze_me=True) 72 | return _check_keys(data) 73 | 74 | 75 | def download(nsd_bucket: str, path: str, aws_args: str) -> None: 76 | path = Path(path) 77 | download_nsd_timeseries_dataset(nsd_bucket, path, aws_args) 78 | prepare_dataset(path) 79 | 80 | 81 | def download_nsd_timeseries_dataset(nsd_bucket: str, path: Path, aws_args: str) -> None: 82 | path.mkdir(exist_ok=True, parents=True) 83 | 84 | aws_cmds = [] 85 | 86 | nsd_bucket = nsd_bucket.rstrip('/') 87 | for subject in [1, 2, 5, 7]: 88 | # timeseries data 89 | for data_type in ["timeseries", "design"]: 90 | aws_cmd = ( 91 | f"aws s3 {aws_args} sync" 92 | f" {nsd_bucket}/nsddata_timeseries/ppdata" 93 | f"/subj{subject:02}" 94 | f"/func1pt8mm/{data_type}/" 95 | f" {path}/nsddata_timeseries/ppdata/subj{subject:02}" 96 | f"/func1pt8mm/{data_type}/" 97 | " --exclude '*'" 98 | f" --include '{data_type}_session*'" 99 | ) 100 | aws_cmds.append(aws_cmd) 101 | # rois 102 | roi_aws_cmd = ( 103 | f"aws s3 {aws_args} sync" 104 | f" {nsd_bucket}/nsddata/ppdata/subj{subject:02}" 105 | "/func1pt8mm/roi/" 106 | f" {path}/nsddata/ppdata/subj{subject:02}/func1pt8mm/roi/" 107 | ) 108 | aws_cmds.append(roi_aws_cmd) 109 | 110 | # Experimental design matrix 111 | expdesign_mat_aws_cmd = ( 112 | f"aws s3 {aws_args} cp" 113 | f" {nsd_bucket}/nsddata/experiments/nsd/nsd_expdesign.mat" 114 | f" {path}" 115 | ) 116 | aws_cmds.append(expdesign_mat_aws_cmd) 117 | 118 | # Stimulus matrix 119 | stimuli_mat_aws_cmd = ( 120 | f"aws s3 {aws_args} cp" 121 | f" {nsd_bucket}/nsddata_stimuli/stimuli/nsd/nsd_stimuli.hdf5" 122 | f" {path}" 123 | ) 124 | aws_cmds.append(stimuli_mat_aws_cmd) 125 | 126 | for aws_cmd in aws_cmds: 127 | subprocess.run(aws_cmd, shell=True) 128 | 129 | 130 | def prepare_dataset(path: Path) -> None: 131 | extract_stimuli(path) 132 | extract_test_images_ids(path) 133 | 134 | 135 | def extract_stimuli(path: Path) -> None: 136 | f_stim = h5py.File(path / "nsd_stimuli.hdf5", "r") 137 | stim = f_stim["imgBrick"][:] 138 | 139 | nsd_stimuli_folder = path / "nsd_stimuli" 140 | nsd_stimuli_folder.mkdir(exist_ok=True, parents=True) 141 | 142 | for idx in trange(stim.shape[0]): 143 | Image.fromarray(stim[idx]).save(nsd_stimuli_folder / f"{idx}.png") 144 | 145 | 146 | def extract_test_images_ids(path: Path) -> None: 147 | path_to_expdesign_mat = path / "nsd_expdesign.mat" 148 | expdesign_mat = _loadmat(path_to_expdesign_mat) 149 | np.save(path / "test_images_ids.npy", expdesign_mat["sharedix"]) 150 | 151 | 152 | if __name__ == "__main__": 153 | argparser = argparse.ArgumentParser() 154 | argparser.add_argument( 155 | "--nsd_bucket", 156 | type=str, 157 | required=True, 158 | help="NSD S3 bucket URI, in the form 's3://...' 
(requires agreement to NSD Terms & Conditions)", 159 | ) 160 | 161 | argparser.add_argument( 162 | "--path", 163 | type=str, 164 | required=True, 165 | help="Path where the NSD data will be downloaded and prepared", 166 | ) 167 | 168 | argparser.add_argument( 169 | "--aws_args", 170 | type=str, 171 | default="", 172 | help="Additional AWS args to use for downloading NSD data", 173 | ) 174 | 175 | args = argparser.parse_args() 176 | download(args.nsd_bucket, args.path, args.aws_args) 177 | print("Done downloading and preparing NSD data.") 178 | -------------------------------------------------------------------------------- /metrics/mIOU/eval_miou.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | """Evaluation Metrics for Semantic Segmentation""" 8 | """Code taken from https://github.com/Tramac/mobilenetv3-segmentation/blob/master/core/utils/metric.py """ 9 | 10 | import numpy as np 11 | import torch 12 | 13 | __all__ = [ 14 | "SegmentationMetric", 15 | "batch_pix_accuracy", 16 | "batch_intersection_union", 17 | "intersectionAndUnion", 18 | "hist_info", 19 | "compute_score", 20 | ] 21 | 22 | 23 | class SegmentationMetric(object): 24 | """Computes pixAcc and mIoU metric scores""" 25 | 26 | def __init__(self, nclass): 27 | super(SegmentationMetric, self).__init__() 28 | self.nclass = nclass 29 | self.reset() 30 | 31 | def update(self, preds, labels, masks): 32 | """Updates the internal evaluation result. 33 | Parameters 34 | ---------- 35 | labels : 'NumpyArray' or list of `NumpyArray` 36 | The labels of the data. 37 | preds : 'NumpyArray' or list of `NumpyArray` 38 | Predicted values. 39 | """ 40 | 41 | def evaluate_worker(self, pred, label, mask): 42 | inter, union, cls_visited = batch_intersection_union( 43 | pred, label, self.nclass, mask 44 | ) 45 | 46 | if self.total_inter.device != inter.device: 47 | self.total_inter = self.total_inter.to(inter.device) 48 | self.total_union = self.total_union.to(union.device) 49 | self.total_inter += inter 50 | self.total_union += union 51 | 52 | self.total_visited.update(cls_visited) 53 | 54 | if isinstance(preds, torch.Tensor): 55 | evaluate_worker(self, preds, labels, masks) 56 | elif isinstance(preds, (list, tuple)): 57 | for pred, label, mask in zip(preds, labels, masks): 58 | evaluate_worker(self, pred, label, mask) 59 | 60 | def get(self): 61 | """Gets the current evaluation result. 
62 | Returns 63 | ------- 64 | metrics : tuple of float 65 | pixAcc and mIoU 66 | """ 67 | tens_totVisited = torch.tensor(list(self.total_visited)).long() 68 | IoU = 1.0 * self.total_inter / (2.220446049250313e-16 + self.total_union) 69 | mIoU = IoU[tens_totVisited].mean().item() 70 | return mIoU 71 | 72 | def reset(self): 73 | """Resets the internal evaluation result to initial state.""" 74 | self.total_inter = torch.zeros(self.nclass) 75 | self.total_union = torch.zeros(self.nclass) 76 | self.total_correct = 0 77 | self.total_label = 0 78 | self.total_visited = set() 79 | 80 | 81 | # pytorch version 82 | def batch_pix_accuracy(output, target, mask): 83 | """PixAcc""" 84 | predict = output 85 | predict = predict[mask] 86 | target = target[mask] 87 | 88 | pixel_labeled = torch.sum(target > 0).item() 89 | pixel_correct = torch.sum((predict == target) * (target > 0)).item() 90 | assert pixel_correct <= pixel_labeled, "Correct area should be smaller than Labeled" 91 | return pixel_correct, pixel_labeled 92 | 93 | 94 | def batch_intersection_union(output, target, nclass, mask): 95 | """mIoU""" 96 | # inputs are numpy array, output 4D, target 3D 97 | mini = 1 98 | maxi = nclass 99 | nbins = nclass 100 | predict = output 101 | predict = (predict[mask] + 1).float() 102 | target = (target[mask] + 1).float() 103 | cls_visited = torch.unique(target - 1).tolist() 104 | 105 | predict = predict.float() * (target > -0).float() 106 | intersection = predict * (predict == target).float() 107 | # areas of intersection and union 108 | # element 0 in intersection occur the main difference from np.bincount. set boundary to -1 is necessary. 109 | area_inter = torch.histc(intersection.cpu(), bins=nbins, min=mini, max=maxi) 110 | area_pred = torch.histc(predict.cpu(), bins=nbins, min=mini, max=maxi) 111 | area_lab = torch.histc(target.cpu(), bins=nbins, min=mini, max=maxi) 112 | area_union = area_pred + area_lab - area_inter 113 | 114 | assert ( 115 | torch.sum(area_inter > area_union).item() == 0 116 | ), "Intersection area should be smaller than Union area" 117 | return area_inter.float(), area_union.float(), cls_visited 118 | 119 | def intersectionAndUnion(imPred, imLab, numClass): 120 | """ 121 | This function takes the prediction and label of a single image, 122 | returns intersection and union areas for each class 123 | To compute over many images do: 124 | for i in range(Nimages): 125 | (area_intersection[:,i], area_union[:,i]) = intersectionAndUnion(imPred[i], imLab[i]) 126 | IoU = 1.0 * np.sum(area_intersection, axis=1) / np.sum(np.spacing(1)+area_union, axis=1) 127 | """ 128 | # Remove classes from unlabeled pixels in gt image. 129 | # We should not penalize detections in unlabeled portions of the image. 
130 | imPred = imPred * (imLab >= 0) 131 | 132 | # Compute area intersection: 133 | intersection = imPred * (imPred == imLab) 134 | (area_intersection, _) = np.histogram( 135 | intersection, bins=numClass, range=(1, numClass) 136 | ) 137 | 138 | # Compute area union: 139 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 140 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 141 | area_union = area_pred + area_lab - area_intersection 142 | return (area_intersection, area_union) 143 | 144 | 145 | def hist_info(pred, label, num_cls): 146 | assert pred.shape == label.shape 147 | k = (label >= 0) & (label < num_cls) 148 | labeled = np.sum(k) 149 | correct = np.sum((pred[k] == label[k])) 150 | 151 | return ( 152 | np.bincount( 153 | num_cls * label[k].astype(int) + pred[k], minlength=num_cls**2 154 | ).reshape(num_cls, num_cls), 155 | labeled, 156 | correct, 157 | ) 158 | 159 | 160 | def compute_score(hist, correct, labeled): 161 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist)) 162 | mean_IU = np.nanmean(iu) 163 | mean_IU_no_back = np.nanmean(iu[1:]) 164 | freq = hist.sum(1) / hist.sum() 165 | freq_IU = (iu[freq > 0] * freq[freq > 0]).sum() 166 | mean_pixel_acc = correct / labeled 167 | 168 | return iu, mean_IU, mean_IU_no_back, mean_pixel_acc 169 | -------------------------------------------------------------------------------- /dynadiff.py: -------------------------------------------------------------------------------- 1 | # Lightning module for model training and evaluation# Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import json 8 | import os 9 | import typing as tp 10 | from pathlib import Path 11 | 12 | import deepspeed 13 | import numpy as np 14 | import pydantic 15 | import torch 16 | import torchvision.transforms as T 17 | from blur import blur_faces 18 | from data import NeuroImagesDataModuleConfig 19 | from exca import MapInfra, TaskInfra 20 | from metrics.image_metrics import compute_image_generation_metrics 21 | from model.models import VersatileDiffusionConfig 22 | from torch import nn 23 | from tqdm import tqdm 24 | 25 | 26 | class DynaDiffEval(pydantic.BaseModel): 27 | data: NeuroImagesDataModuleConfig 28 | seed: int = 33 29 | versatilediffusion_config: VersatileDiffusionConfig 30 | strategy: str = "auto" 31 | device: str = "cuda" 32 | checkpoint_path: str = "./checkpoints" 33 | infra: TaskInfra = TaskInfra(version="1") 34 | image_generation_infra: MapInfra = MapInfra(version="1") 35 | 36 | _exclude_from_cls_uid: tp.ClassVar[list[str]] = [ 37 | "device", 38 | "seed", 39 | ] 40 | 41 | def model_post_init(self, __context: tp.Any) -> None: 42 | if self.infra.folder is None: 43 | msg = "infra.folder needs to be specified to save the results." 
44 | raise ValueError(msg) 45 | self.data.workers = self.infra.cpus_per_task 46 | self.infra.folder = Path(self.infra.folder) 47 | self.infra.folder.mkdir(exist_ok=True, parents=True) 48 | self.image_generation_infra.folder = self.infra.folder 49 | 50 | def _get_brain_model(self, data_module) -> nn.Module: 51 | 52 | brain_n_in_channels, brain_temp_dim = data_module.eval_dataset[0]["brain"].size() 53 | 54 | copy_versatilediffusion_config = self.infra.clone_obj().versatilediffusion_config 55 | if copy_versatilediffusion_config.brain_modules_config is not None: 56 | for k in copy_versatilediffusion_config.brain_modules_config: 57 | copy_versatilediffusion_config.brain_modules_config[k].n_subjects = ( 58 | 1 59 | ) 60 | 61 | brain_model = copy_versatilediffusion_config.build( 62 | brain_n_in_channels, brain_temp_dim 63 | ) 64 | 65 | checkpoint_dir = ( 66 | Path(self.checkpoint_path).resolve() 67 | / f"sub{self.data.nsd_dataset_config.subject_id:02d}" 68 | ) 69 | 70 | print(f"Loading checkpoint at: {checkpoint_dir}") 71 | 72 | state_dict = ( 73 | deepspeed.utils.zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint( 74 | checkpoint_dir=checkpoint_dir, 75 | tag="checkpoint", 76 | ) 77 | ) 78 | 79 | state_dict = { 80 | k[len("model.") :]: v 81 | for k, v in state_dict.items() 82 | if k[: len("model.")] == "model." 83 | } 84 | brain_model.load_state_dict(state_dict, strict=False) 85 | brain_model = brain_model.eval() 86 | return brain_model 87 | 88 | @image_generation_infra.apply(item_uid=str, cache_type="MemmapArrayFile") 89 | def generate_images(self, images_idx: list[int]) -> tp.Iterator[np.ndarray]: 90 | data = self.data.build() 91 | brain_model = self._get_brain_model(data).to(self.device) 92 | print("IMG IDX ARE: ", images_idx) 93 | with torch.no_grad(): 94 | for img_idx in images_idx: 95 | ipt_data = data.eval_dataset[img_idx] 96 | ipt_data = { 97 | k: ( 98 | v[None, ...].to(device=self.device) 99 | if isinstance(v, torch.Tensor) 100 | else v 101 | ) 102 | for k, v in ipt_data.items() 103 | } 104 | 105 | if self.data.test_groupbyimg == "unaveraged": 106 | ipt_data = { 107 | k: v[0] if isinstance(v, torch.Tensor) else v 108 | for k, v in ipt_data.items() 109 | } 110 | print("UNAVG", ipt_data["brain"].size()) 111 | else: 112 | batch_out = brain_model(**ipt_data, is_img_gen_mode=True).image 113 | for image in batch_out: 114 | yield image.cpu().numpy() 115 | 116 | def prepare(self): 117 | torch.cuda.manual_seed(self.seed) 118 | data = self.data.build() 119 | return data, self.generate_images(list(range(len(data.eval_dataset)))) 120 | 121 | @infra.apply 122 | def run(self): 123 | config_path = Path(self.infra.folder) / "config.yaml" 124 | if not config_path.exists(): 125 | os.makedirs(self.infra.folder, exist_ok=True) 126 | self.infra.config(uid=False, exclude_defaults=False).to_yaml(config_path) 127 | 128 | data, recons = self.prepare() 129 | 130 | average_id = 'averaged' if self.data.nsd_dataset_config.averaged else 'unaveraged' 131 | subject_id = f"subj{self.data.nsd_dataset_config.subject_id:02d}" 132 | 133 | folder = self.image_generation_infra.folder / f"reconstructions_{subject_id}_{average_id}" 134 | folder.mkdir(exist_ok=True, parents=True) 135 | metrics_recimg = [] 136 | metrics_gtimg = [] 137 | for i, image in tqdm(enumerate(recons), total=len(data.eval_dataset)): 138 | recimg = T.ToPILImage()((image.transpose(1, 2, 0) * 255).astype(np.uint8)) 139 | recimg = blur_faces(recimg) 140 | recimg.save(folder / f"{i}.png") 141 | metrics_recimg.append(recimg) 142 | 143 | gtimg = 
data.eval_dataset[i]["img"] 144 | gtimg = T.ToPILImage()((gtimg * 255).to(torch.uint8)) 145 | gtimg.save(folder / f"{i}_gt.png") 146 | metrics_gtimg.append(gtimg) 147 | 148 | metrics = compute_image_generation_metrics(metrics_gtimg, metrics_recimg) 149 | 150 | mean_values = {k: float(v) for k, v in metrics.items() if "scores" not in k} 151 | with open("oss_metrics.json", "w") as f: 152 | json.dump(mean_values, f, indent=4) 153 | print(mean_values) 154 | 155 | def compute_miou(self): 156 | from metrics.mIOU.evaluate_img_gen import compute_miou 157 | 158 | data, recons = self.prepare() 159 | recons = [ 160 | T.ToPILImage()((image.transpose(1, 2, 0) * 255).astype(np.uint8)) 161 | for image in recons 162 | ] 163 | gtims = [ 164 | T.ToPILImage()((x["img"] * 255).to(torch.uint8)) for x in data.eval_dataset 165 | ] 166 | miou = compute_miou(recons, gtims, eval_res=512) 167 | print(f"mIoU: {miou}") 168 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import random 8 | import typing as tp 9 | from collections import defaultdict 10 | from pathlib import Path 11 | 12 | import nibabel 13 | import nilearn.signal 14 | import numpy as np 15 | import pandas as pd 16 | import pydantic 17 | import torch 18 | from exca.map import MapInfra 19 | from lightning.pytorch import LightningDataModule 20 | from PIL import Image 21 | from torch.utils.data import DataLoader 22 | 23 | TR_s = 4 / 3 24 | 25 | 26 | class NsdDataset(torch.utils.data.Dataset): 27 | def __init__(self, fmris: np.ndarray, images: np.ndarray): 28 | self.fmris = fmris 29 | self.images = images 30 | 31 | def __len__(self): 32 | return len(self.fmris) 33 | 34 | def __getitem__(self, idx): 35 | fmri = torch.from_numpy(self.fmris[idx]).float() 36 | image = torch.from_numpy(self.images[idx]).long().permute(2, 0, 1) / 255.0 37 | subject_id = torch.tensor( 38 | 0, dtype=torch.long 39 | ) # Single-subject models, for compatibility 40 | return {"brain": fmri, "img": image, "subject_idx": subject_id} 41 | 42 | 43 | class ImageEvent(pydantic.BaseModel): 44 | im_fp: str | Path 45 | nifti_fp: str | Path 46 | roi_fp: str | Path 47 | start_idx: int 48 | end_idx: int 49 | 50 | 51 | class NsdDatasetConfig(pydantic.BaseModel): 52 | nsddata_path: str 53 | subject_id: int 54 | offset: float = 4.6 55 | duration: float = 8.0 56 | seed: int = 42 57 | averaged: bool = False 58 | infra: MapInfra = MapInfra(version="1") 59 | 60 | def create_event_list(self) -> list[ImageEvent]: 61 | nsddata_path = Path(self.nsddata_path).resolve() 62 | test_im_ids = np.load(nsddata_path / "test_images_ids.npy") 63 | subject_to_14run_sessions = { 64 | 1: (21, 38), 65 | 5: (21, 38), 66 | 2: (21, 30), 67 | 7: (21, 30), 68 | } 69 | 70 | events = [] 71 | for session in range(1, 41): 72 | runs = ( 73 | range(2, 14) 74 | if subject_to_14run_sessions[self.subject_id][0] 75 | <= session 76 | <= subject_to_14run_sessions[self.subject_id][1] 77 | else range(1, 13) 78 | ) 79 | for run in runs: 80 | run_id = f"session{session:02d}_run{run:02d}" 81 | path_to_df = ( 82 | nsddata_path 83 | / f"nsddata_timeseries/ppdata/subj{self.subject_id:02d}/func1pt8mm/design/design_{run_id}.tsv" 84 | ) 85 | im_ids = pd.read_csv(path_to_df, header=None).iloc[:, 0].to_list() 
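                # (annotation) Each row of the design TSV corresponds to one TR
                # (TR_s = 4/3 s); a nonzero entry gives the 1-based id of the NSD
                # stimulus presented at that TR. Only shared/test images are kept,
                # and the fMRI window below spans roughly `offset` to
                # `offset + duration` seconds after stimulus onset.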
86 | for timestep, image_id in enumerate(im_ids): 87 | if image_id != 0 and image_id in test_im_ids: 88 | im_fp = nsddata_path / f"nsd_stimuli/{image_id-1}.png" 89 | nifti_fp = ( 90 | nsddata_path 91 | / f"nsddata_timeseries/ppdata/subj{self.subject_id:02d}/func1pt8mm/timeseries/timeseries_{run_id}.nii.gz" 92 | ) 93 | roi_fp = ( 94 | nsddata_path 95 | / f"nsddata/ppdata/subj{int(self.subject_id):02d}/func1pt8mm/roi/nsdgeneral.nii.gz" 96 | ) 97 | start_idx = timestep + int(round(self.offset / TR_s)) 98 | end_idx = timestep + int( 99 | round(self.offset + self.duration) / TR_s 100 | ) 101 | 102 | events.append( 103 | ImageEvent( 104 | im_fp=im_fp, 105 | nifti_fp=nifti_fp, 106 | roi_fp=roi_fp, 107 | start_idx=start_idx, 108 | end_idx=end_idx, 109 | ) 110 | ) 111 | return events 112 | 113 | @infra.apply(item_uid=str, exclude_from_cache_uid=("averaged", "seed")) 114 | def prepare(self, events: list[ImageEvent]) -> tp.Iterator[tp.Any]: 115 | for event in events: 116 | nifti = nibabel.load(event.nifti_fp, mmap=True) 117 | nifti = nifti.slicer[..., :225] 118 | roi_np = nibabel.load(event.roi_fp, mmap=True).get_fdata() 119 | nifti_data = nifti.get_fdata()[roi_np > 0] 120 | 121 | # z-score across run and detrend 122 | nifti_data = nifti_data.T # set time as first dim 123 | shape = nifti_data.shape 124 | nifti_data = nilearn.signal.clean( 125 | nifti_data.reshape(shape[0], -1), 126 | detrend=True, 127 | high_pass=None, 128 | t_r=TR_s, 129 | standardize="zscore_sample", 130 | ) 131 | nifti_data = nifti_data.reshape(shape).T 132 | 133 | image = np.array( 134 | Image.open(event.im_fp).convert("RGB").resize((512, 512), Image.BILINEAR), 135 | dtype=np.uint8, 136 | ) 137 | 138 | yield { 139 | "brain": nifti_data[..., event.start_idx : event.end_idx], 140 | "img": image, 141 | } 142 | 143 | def build(self): 144 | events = self.create_event_list() 145 | data = self.prepare(events) 146 | data = list(data) 147 | 148 | grouped_events = defaultdict(list) 149 | for idx, event in enumerate(events): 150 | grouped_events[event.im_fp].append(idx) 151 | if self.averaged: 152 | averaged_data_list = [] 153 | for im_fp in grouped_events: 154 | averaged_brain = np.mean( 155 | [data[idx]["brain"] for idx in grouped_events[im_fp]], axis=0 156 | ) 157 | averaged_data = data[grouped_events[im_fp][0]].copy() 158 | averaged_data["brain"] = averaged_brain 159 | averaged_data_list.append(averaged_data) 160 | data = averaged_data_list 161 | else: 162 | random.seed(self.seed) 163 | data = [ 164 | data[random.choice(grouped_events[im_fp])] for im_fp in grouped_events 165 | ] 166 | 167 | fmri = np.stack([item["brain"] for item in data], axis=0) 168 | images = np.stack([item["img"] for item in data], axis=0) 169 | return NsdDataset(fmris=fmri, images=images) 170 | 171 | 172 | class NeuroImagesDataModuleConfig(pydantic.BaseModel): 173 | name: tp.Literal["NeuroImagesDataModuleConfig"] = "NeuroImagesDataModuleConfig" 174 | model_config = pydantic.ConfigDict(extra="forbid") 175 | 176 | nsd_dataset_config: NsdDatasetConfig 177 | 178 | pin_memory: bool = True 179 | workers: int = 0 180 | batch_size: int 181 | 182 | test_groupbyimg: tp.Literal["averaged", "unaveraged"] | None = None 183 | 184 | def build( 185 | self, 186 | ) -> LightningDataModule: 187 | data_module = NeuroImagesDataModule( 188 | config=self, 189 | ) 190 | data_module.setup() 191 | return data_module 192 | 193 | 194 | class NeuroImagesDataModule(LightningDataModule): 195 | def __init__(self, config): 196 | super().__init__() 197 | config = config if config is not None else 
NeuroImagesDataModuleConfig() 198 | self.nsd_dataset_config = config.nsd_dataset_config 199 | 200 | self.batch_size = config.batch_size 201 | self.workers = config.workers 202 | self.pin_memory = config.pin_memory 203 | 204 | self.test_groupbyimg = config.test_groupbyimg 205 | 206 | def setup(self, stage: tp.Optional[str] = None): 207 | 208 | self.eval_dataset = self.nsd_dataset_config.build() 209 | 210 | print( 211 | "Number of samples in test dataset:", 212 | len(self.eval_dataset), 213 | ) 214 | 215 | def val_dataloader(self): 216 | return DataLoader( 217 | self.eval_dataset, 218 | batch_size=self.batch_size, 219 | num_workers=self.workers, 220 | pin_memory=self.pin_memory, 221 | drop_last=False, 222 | ) 223 | 224 | def test_dataloader(self): 225 | return self.val_dataloader() -------------------------------------------------------------------------------- /metrics/image_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Code modified from https://github.com/ozcelikfu/brain-diffuser evaluation scripts 8 | 9 | from pathlib import Path 10 | 11 | import clip 12 | import numpy as np 13 | import scipy as sp 14 | import torch 15 | import torchvision.models as tvmodels 16 | import torchvision.transforms as transforms 17 | import torchvision.transforms as T 18 | from dreamsim import dreamsim 19 | from PIL import Image 20 | from scipy.stats import binom 21 | from skimage.color import rgb2gray 22 | from skimage.metrics import structural_similarity as ssim 23 | from torch.utils.data import DataLoader, Dataset 24 | from tqdm import tqdm 25 | 26 | 27 | def _compute_image_generation_features(images, emb_batch_size=32, device="cuda:0"): 28 | class batch_generator_external_images(Dataset): 29 | def __init__(self, images: list, net_name="clip"): 30 | self.images = images 31 | self.net_name = net_name 32 | 33 | if self.net_name == "clip": 34 | self.normalize = transforms.Normalize( 35 | mean=[0.48145466, 0.4578275, 0.40821073], 36 | std=[0.26862954, 0.26130258, 0.27577711], 37 | ) 38 | else: 39 | self.normalize = transforms.Normalize( 40 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 41 | ) 42 | 43 | def __getitem__(self, idx): 44 | img = self.images[idx] 45 | img = T.functional.resize(img, (224, 224)) 46 | img = T.functional.to_tensor(img).float() 47 | img = self.normalize(img) 48 | return img 49 | 50 | def __len__(self): 51 | return len(self.images) 52 | 53 | global feat_list 54 | feat_list = [] 55 | 56 | def fn(module, inputs, outputs): 57 | feat_list.append(outputs.cpu().numpy()) 58 | 59 | net_list = [ 60 | ("inceptionv3", "avgpool"), 61 | ("clip", "final"), 62 | ("alexnet", 2), 63 | ("alexnet", 5), 64 | ("efficientnet", "avgpool"), 65 | ("swav", "avgpool"), 66 | ] 67 | 68 | net = None 69 | 70 | result = dict() 71 | 72 | for net_name, layer in net_list: 73 | feat_list = [] 74 | print(net_name, layer) 75 | dataset = batch_generator_external_images(images=images, net_name=net_name) 76 | loader = DataLoader(dataset, emb_batch_size, shuffle=False) 77 | 78 | if net_name == "inceptionv3": 79 | net = tvmodels.inception_v3(pretrained=True) 80 | if layer == "avgpool": 81 | net.avgpool.register_forward_hook(fn) 82 | elif layer == "lastconv": 83 | net.Mixed_7c.register_forward_hook(fn) 84 | 85 | elif net_name == "alexnet": 86 | net = 
tvmodels.alexnet(pretrained=True) 87 | if layer == 2: 88 | net.features[4].register_forward_hook(fn) 89 | elif layer == 5: 90 | net.features[11].register_forward_hook(fn) 91 | elif layer == 7: 92 | net.classifier[5].register_forward_hook(fn) 93 | 94 | elif net_name == "clip": 95 | model, _ = clip.load("ViT-L/14", device=device) 96 | net = model.visual 97 | net = net.to(torch.float32) 98 | if layer == 7: 99 | net.transformer.resblocks[7].register_forward_hook(fn) 100 | elif layer == 12: 101 | net.transformer.resblocks[12].register_forward_hook(fn) 102 | elif layer == "final": 103 | net.register_forward_hook(fn) 104 | 105 | elif net_name == "efficientnet": 106 | net = tvmodels.efficientnet_b1(weights=True) 107 | net.avgpool.register_forward_hook(fn) 108 | 109 | elif net_name == "swav": 110 | net = torch.hub.load("facebookresearch/swav:main", "resnet50") 111 | net.avgpool.register_forward_hook(fn) 112 | 113 | net = net.to(device) 114 | net.eval() 115 | 116 | with torch.no_grad(): 117 | for i, x in tqdm(enumerate(loader), total=len(loader)): 118 | x = x.to(device) 119 | _ = net(x) 120 | if net_name == "clip": 121 | if layer == 7 or layer == 12: 122 | feat_list = np.concatenate(feat_list, axis=1).transpose((1, 0, 2)) 123 | else: 124 | feat_list = np.concatenate(feat_list) 125 | else: 126 | feat_list = np.concatenate(feat_list) 127 | 128 | result[net_name + "-" + str(layer)] = feat_list 129 | 130 | return result 131 | 132 | 133 | def _pairwise_corr_all(ground_truth, predictions): 134 | r = np.corrcoef(ground_truth, predictions) 135 | r = r[: len(ground_truth), len(ground_truth) :] 136 | congruents = np.diag(r) 137 | 138 | success = r < congruents 139 | success_cnt = np.sum(success, 0) 140 | 141 | perf = np.mean(success_cnt) / (len(ground_truth) - 1) 142 | p = 1 - binom.cdf( 143 | perf * len(ground_truth) * (len(ground_truth) - 1), 144 | len(ground_truth) * (len(ground_truth) - 1), 145 | 0.5, 146 | ) 147 | 148 | return perf, p 149 | 150 | 151 | def compute_dreamsim( 152 | preds: list[str | Path | Image.Image], 153 | trues: list[str | Path | Image.Image], 154 | device: str = "cuda:0", 155 | ): 156 | 157 | 158 | model_dreamsim, preprocess_dreamsim = dreamsim( 159 | pretrained=True, 160 | ) 161 | 162 | dreamsim_list = [] 163 | for pred, true in zip(preds, trues): 164 | dreamsim_list += [ 165 | model_dreamsim( 166 | preprocess_dreamsim(pred).to(device), 167 | preprocess_dreamsim(true).to(device), 168 | ).cpu().numpy().item() 169 | ] 170 | return np.array(dreamsim_list).mean() 171 | 172 | 173 | def compute_image_generation_metrics( 174 | preds, 175 | trues, 176 | imsize_for_pixel_level_metrics=425, 177 | emb_batch_size=32, 178 | device="cuda:0", 179 | ): 180 | result = dict() 181 | 182 | assert len(preds) == len(trues) 183 | n = len(preds) 184 | 185 | all_images = trues + preds 186 | feats = _compute_image_generation_features( 187 | all_images, emb_batch_size=emb_batch_size, device=device 188 | ) 189 | 190 | gt_feats = dict() 191 | eval_feats = dict() 192 | 193 | for metric_name in feats: 194 | gt_feats[metric_name] = feats[metric_name][:n] 195 | eval_feats[metric_name] = feats[metric_name][n:] 196 | 197 | distance_fn = sp.spatial.distance.correlation 198 | for metric_name in gt_feats.keys(): 199 | gt_feat = gt_feats[metric_name] 200 | eval_feat = eval_feats[metric_name] 201 | n = len(gt_feat) 202 | 203 | gt_feat = gt_feat.reshape((len(gt_feat), -1)) 204 | eval_feat = eval_feat.reshape((len(eval_feat), -1)) 205 | 206 | net_name, _ = metric_name.split("-") 207 | 208 | if net_name in ["efficientnet", 
"swav"]: 209 | distances = np.array( 210 | [distance_fn(gt_feat[i], eval_feat[i]) for i in range(n)] 211 | ) 212 | result[metric_name] = distances.mean() 213 | else: 214 | result[metric_name] = _pairwise_corr_all(gt_feat, eval_feat)[0] 215 | 216 | ssim_list = [] 217 | pixcorr_list = [] 218 | for i in range(n): 219 | gen_image = preds[i].resize( 220 | (imsize_for_pixel_level_metrics, imsize_for_pixel_level_metrics) 221 | ) 222 | gt_image = trues[i].resize( 223 | (imsize_for_pixel_level_metrics, imsize_for_pixel_level_metrics) 224 | ) 225 | 226 | gen_image = np.array(gen_image) / 255.0 227 | gt_image = np.array(gt_image) / 255.0 228 | pixcorr_res = np.corrcoef(gt_image.reshape(1, -1), gen_image.reshape(1, -1))[0, 1] 229 | pixcorr_list.append(pixcorr_res) 230 | gen_image = rgb2gray(gen_image) 231 | gt_image = rgb2gray(gt_image) 232 | ssim_res = ssim( 233 | gen_image, 234 | gt_image, 235 | multichannel=True, 236 | gaussian_weights=True, 237 | sigma=1.5, 238 | use_sample_covariance=False, 239 | data_range=1.0, 240 | ) 241 | ssim_list.append(ssim_res) 242 | 243 | ssim_list = np.array(ssim_list) 244 | pixcorr_list = np.array(pixcorr_list) 245 | result["pixcorr"] = pixcorr_list.mean() 246 | result["ssim"] = ssim_list.mean() 247 | result['dreamsim'] = compute_dreamsim(preds, trues, device=device) 248 | 249 | return result 250 | -------------------------------------------------------------------------------- /model/fmri_mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | import typing as tp 9 | from functools import partial 10 | 11 | import pydantic 12 | import torch 13 | from torch import nn 14 | 15 | try: 16 | from diffusers.models.vae import Decoder 17 | except: 18 | from diffusers.models.autoencoders.vae import Decoder 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | class Mean(nn.Module): 25 | def __init__(self, dim: int, keepdim: bool = False): 26 | super().__init__() 27 | self.dim = dim 28 | self.keepdim = keepdim 29 | 30 | def forward(self, x: torch.Tensor) -> torch.Tensor: 31 | return x.mean(dim=self.dim, keepdim=self.keepdim) 32 | 33 | 34 | class SubjectLayers(nn.Module): 35 | """Per subject linear layer.""" 36 | 37 | def __init__( 38 | self, 39 | in_channels: int, 40 | out_channels: int, 41 | n_subjects: int, 42 | init_id: bool = False, 43 | mode: tp.Literal["gather", "for_loop"] = "gather", 44 | ): 45 | super().__init__() 46 | self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels)) 47 | if init_id: 48 | assert in_channels == out_channels 49 | self.weights.data[:] = torch.eye(in_channels)[None] 50 | self.weights.data *= 1 / in_channels**0.5 51 | self.mode = mode 52 | 53 | def forward(self, x: torch.Tensor, subjects: torch.Tensor) -> torch.Tensor: 54 | N, C, D = self.weights.shape 55 | 56 | if self.mode == "gather": 57 | weights = self.weights.gather(0, subjects.view(-1, 1, 1).expand(-1, C, D)) 58 | out = torch.einsum("bct,bcd->bdt", x, weights) 59 | elif self.mode == "for_loop": 60 | B, _, T = x.shape 61 | out = torch.empty((B, D, T), device=x.device, dtype=x.dtype) 62 | for subject in subjects.unique(): 63 | mask = subjects.reshape(-1) == subject 64 | id_weights = subject 65 | out[mask] = torch.einsum("bct,cd->bdt", x[mask], self.weights[id_weights]) 66 | else: 67 | 
raise NotImplementedError() 68 | 69 | return out 70 | 71 | def __repr__(self): 72 | S, C, D = self.weights.shape 73 | return f"SubjectLayers({C}, {D}, {S})" 74 | 75 | 76 | class DeeperSubjectLayers(nn.Module): 77 | """Per subject linear layer.""" 78 | 79 | def __init__( 80 | self, 81 | in_channels: int, 82 | out_channels: int, 83 | n_subjects: int, 84 | init_id: bool = False, 85 | mode: tp.Literal["gather", "for_loop"] = "gather", 86 | mlp_n_blocks: int = 4, 87 | ): 88 | super().__init__() 89 | self.mlp_n_blocks = mlp_n_blocks 90 | self.weights = nn.Parameter(torch.randn(n_subjects, in_channels, out_channels)) 91 | if init_id: 92 | assert in_channels == out_channels 93 | self.weights.data[:] = torch.eye(in_channels)[None] 94 | self.weights.data *= 1 / in_channels**0.5 95 | self.mode = mode 96 | 97 | norm_func = partial(nn.LayerNorm, normalized_shape=out_channels) 98 | act_fn = nn.GELU 99 | act_and_norm = (norm_func, act_fn) 100 | self.mlp = nn.ModuleList( 101 | nn.ModuleList( 102 | [ 103 | nn.Sequential( 104 | nn.Linear(out_channels, out_channels), 105 | *[item() for item in act_and_norm], 106 | nn.Dropout(0.15), 107 | ) 108 | for _ in range(mlp_n_blocks) 109 | ] 110 | ) 111 | for _ in range(n_subjects) 112 | ) 113 | 114 | def forward(self, x: torch.Tensor, subjects: torch.Tensor) -> torch.Tensor: 115 | N, C, D = self.weights.shape 116 | assert subjects.max() < N, ( 117 | f"Subject index ({subjects.max()}) too high for number of subjects used to initialize" 118 | f" the weights ({N})." 119 | ) 120 | 121 | if self.mode == "gather": 122 | weights = self.weights.gather(0, subjects.view(-1, 1, 1).expand(-1, C, D)) 123 | out = torch.einsum("bct,bcd->bdt", x, weights) 124 | elif self.mode == "for_loop": 125 | B, _, T = x.shape 126 | out = torch.empty((B, D, T), device=x.device, dtype=x.dtype) 127 | for subject in subjects.unique(): 128 | mask = subjects.reshape(-1) == subject 129 | out_int = torch.einsum("bct,cd->bdt", x[mask], self.weights[subject]) 130 | out_int = out_int.permute(0, 2, 1) 131 | mlp = self.mlp[subject] 132 | residual = out_int 133 | for res_block in range(self.mlp_n_blocks): 134 | out_int = mlp[res_block](out_int) 135 | out_int += residual 136 | residual = out_int 137 | out[mask] = out_int.permute(0, 2, 1) 138 | else: 139 | raise NotImplementedError() 140 | 141 | return out 142 | 143 | def __repr__(self): 144 | S, C, D = self.weights.shape 145 | return f"SubjectLayers({C}, {D}, {S})" 146 | 147 | 148 | class FmriMLPConfig(pydantic.BaseModel): 149 | model_config = pydantic.ConfigDict(extra="forbid") 150 | name: tp.Literal["FmriMLP"] = "FmriMLP" # type: ignore 151 | 152 | hidden: int = 4096 153 | n_blocks: int = 4 154 | norm_type: str = "ln" 155 | act_first: bool = False 156 | 157 | n_repetition_times: int = 1 158 | time_agg: tp.Literal["in_mean", "in_linear", "out_mean", "out_linear"] = "out_linear" 159 | 160 | # TR embeddings 161 | use_tr_embeds: bool = False 162 | tr_embed_dim: int = 16 163 | use_tr_layer: bool = False 164 | 165 | # Control output size explicitly 166 | out_dim: int | None = None 167 | 168 | # Subject-specific settings 169 | subject_layers: bool = False 170 | deep_subject_layers: bool = False 171 | n_subjects: int = 20 172 | subject_layers_dim: tp.Literal["input", "hidden"] = "hidden" 173 | subject_layers_id: bool = False 174 | 175 | # blurry recons 176 | blurry_recon: bool = False 177 | 178 | native_fmri_space: bool = False 179 | 180 | def build(self, n_in_channels: int, n_outputs: int | None) -> nn.Module: 181 | if n_outputs is None and self.out_dim is None: 
182 | raise ValueError("One of n_outputs or config.out_dim must be set.") 183 | return FmriMLP( 184 | in_dim=n_in_channels, 185 | out_dim=self.out_dim if n_outputs is None else n_outputs, 186 | config=self, 187 | ) 188 | 189 | 190 | class FmriMLP(nn.Module): 191 | """Residual MLP adapted from [1]. 192 | 193 | See https://github.com/MedARC-AI/fMRI-reconstruction-NSD/blob/main/src/models.py#L171 194 | 195 | References 196 | ---------- 197 | [1] Scotti, Paul, et al. "Reconstructing the mind's eye: fMRI-to-image with contrastive 198 | learning and diffusion priors." Advances in Neural Information Processing Systems 36 199 | (2024). 200 | """ 201 | 202 | def __init__( 203 | self, 204 | in_dim: int, 205 | out_dim: int, 206 | config: FmriMLPConfig | None = None, 207 | ): 208 | super().__init__() 209 | 210 | # Temporal aggregation 211 | self.in_time_agg, self.out_time_agg = None, None 212 | self.n_repetition_times = config.n_repetition_times 213 | self.blurry_recon = config.blurry_recon 214 | if config.time_agg == "in_mean": 215 | self.in_time_agg = Mean(dim=2, keepdim=True) 216 | self.n_repetition_times = 1 217 | elif config.time_agg == "in_linear": 218 | self.in_time_agg = nn.Linear(self.n_repetition_times, 1) 219 | 220 | self.n_repetition_times = 1 221 | elif config.time_agg == "out_mean": 222 | self.out_time_agg = Mean(dim=2) 223 | elif config.time_agg == "out_linear": 224 | self.out_time_agg = nn.Linear(self.n_repetition_times, 1) 225 | 226 | norm_func = ( 227 | partial(nn.BatchNorm1d, num_features=config.hidden) 228 | if config.norm_type == "bn" 229 | else partial(nn.LayerNorm, normalized_shape=config.hidden) 230 | ) 231 | act_fn = partial(nn.ReLU, inplace=True) if config.norm_type == "bn" else nn.GELU 232 | act_and_norm = (act_fn, norm_func) if config.act_first else (norm_func, act_fn) 233 | 234 | self.proj2flat = None 235 | if config.native_fmri_space: 236 | self.proj2flat = nn.Sequential( 237 | *[ 238 | nn.Conv3d(1, 8, 3, stride=2, padding=1), 239 | nn.LayerNorm([39, 47, 40]), 240 | nn.GELU(), 241 | nn.Conv3d(8, 8, 3, stride=2, padding=1), 242 | nn.LayerNorm([20, 24, 20]), 243 | nn.GELU(), 244 | nn.Conv3d(8, 4, 2, stride=1, padding=1), 245 | nn.LayerNorm([21, 25, 21]), 246 | nn.GELU(), 247 | nn.Conv3d(4, 2, 2, stride=1, padding=1), 248 | ] 249 | ) 250 | 251 | # Subject-specific linear layer 252 | self.subject_layers = None 253 | if config.subject_layers: 254 | dim = {"hidden": config.hidden, "input": in_dim}[config.subject_layers_dim] 255 | if not config.deep_subject_layers: 256 | self.subject_layers = SubjectLayers( 257 | in_dim, 258 | dim, 259 | config.n_subjects, 260 | config.subject_layers_id, 261 | mode="for_loop", 262 | ) 263 | else: 264 | self.subject_layers = DeeperSubjectLayers( 265 | in_dim, 266 | dim, 267 | config.n_subjects, 268 | config.subject_layers_id, 269 | mode="for_loop", 270 | ) 271 | 272 | in_dim = dim 273 | 274 | # TR embeddings 275 | self.tr_embeddings = None 276 | if config.use_tr_embeds: 277 | self.tr_embeddings = nn.Embedding( 278 | self.n_repetition_times, config.tr_embed_dim 279 | ) 280 | in_dim += config.tr_embed_dim 281 | 282 | if config.use_tr_layer: 283 | # depthwise convolution 284 | # Each TR is passed to a (distinct) linear layer Linear(in_dim, config.hidden) 285 | self.lin0 = nn.Conv1d( 286 | in_channels=self.n_repetition_times, 287 | out_channels=self.n_repetition_times * config.hidden, 288 | kernel_size=in_dim, 289 | groups=self.n_repetition_times, 290 | bias=True, 291 | ) 292 | else: 293 | self.lin0 = nn.Linear(in_dim, config.hidden) 294 | 
self.post_lin0 = nn.Sequential( 295 | *[item() for item in act_and_norm], nn.Dropout(0.5) 296 | ) 297 | 298 | self.n_blocks = config.n_blocks 299 | self.mlp = nn.ModuleList( 300 | [ 301 | nn.Sequential( 302 | nn.Linear(config.hidden, config.hidden), 303 | *[item() for item in act_and_norm], 304 | nn.Dropout(0.15), 305 | ) 306 | for _ in range(config.n_blocks) 307 | ] 308 | ) 309 | if not self.blurry_recon: 310 | self.lin1 = nn.Linear(config.hidden, out_dim) 311 | else: 312 | self.blin1 = nn.Linear(config.hidden, 4 * 28 * 28) 313 | self.bdropout = nn.Dropout(0.3) 314 | self.bnorm = nn.GroupNorm(1, 64) 315 | self.bupsampler = Decoder( 316 | in_channels=64, 317 | out_channels=4, 318 | up_block_types=[ 319 | "UpDecoderBlock2D", 320 | "UpDecoderBlock2D", 321 | "UpDecoderBlock2D", 322 | ], 323 | block_out_channels=[32, 64, 128], 324 | layers_per_block=1, 325 | ) 326 | 327 | 328 | def forward( 329 | self, 330 | x: torch.Tensor, 331 | subject_ids: torch.Tensor | None = None, 332 | channel_positions: torch.Tensor | None = None, # Unused 333 | ) -> torch.Tensor: 334 | 335 | if self.proj2flat is not None: 336 | bs = x.size(0) 337 | x = x.permute(0, 4, 1, 2, 3) ## to have (B,T,D,H,W) 338 | x = x.reshape( 339 | x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4] 340 | ) # (B*T,D,H,W) 341 | x = x[:, None] # (B*T,1, D,H,W) to have C = 1 for 3d conv 342 | x = self.proj2flat(x) # (B*T, 2, 22,26,22) 343 | x = x.reshape(bs, self.n_repetition_times, -1) # (B, T, F) 344 | x = x.permute(0, 2, 1) # (B, F, T) 345 | else: 346 | x = x.reshape(x.shape[0], -1, x.shape[-1]) # (B, F, T) 347 | 348 | if self.in_time_agg is not None: 349 | x = self.in_time_agg(x) # (B, F, 1) 350 | 351 | B, F, T = x.shape 352 | 353 | assert ( 354 | T == self.n_repetition_times 355 | ), f"Mismatch between expected and provided number TRs: {T} != {self.n_repetition_times}" 356 | 357 | if self.subject_layers is not None: 358 | x = self.subject_layers(x, subject_ids) 359 | x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) 360 | 361 | if self.tr_embeddings is not None: 362 | embeds = self.tr_embeddings(torch.arange(T, device=x.device)) 363 | embeds = torch.tile(embeds, dims=(B, 1, 1)) 364 | x = torch.cat([x, embeds], dim=2) 365 | 366 | x = self.lin0(x).reshape(B, T, -1) # (B, T, F) -> (B, T * F, 1) -> (B, T, F) 367 | x = self.post_lin0(x) 368 | 369 | residual = x 370 | for res_block in range(self.n_blocks): 371 | x = self.mlp[res_block](x) 372 | x += residual 373 | residual = x 374 | 375 | x = x.permute(0, 2, 1) # (B, T, F) -> (B, F, T) 376 | if self.out_time_agg is not None: 377 | x = self.out_time_agg(x) # (B, F, 1) 378 | x = x.flatten(1) # Ensure 2D 379 | 380 | if self.blurry_recon: 381 | b = self.blin1(x) 382 | b = self.bdropout(b) 383 | b = b.reshape(b.shape[0], -1, 7, 7).contiguous() 384 | b = self.bnorm(b) 385 | x_final = self.bupsampler(b) 386 | else: 387 | x_final = self.lin1(x) 388 | 389 | 390 | return { 391 | "MSELoss": x_final 392 | } 393 | -------------------------------------------------------------------------------- /metrics/mIOU/segment_vitadapter.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Copyright (c) OpenMMLab. All rights reserved. 
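# Rough usage, based on parse_args/main below: load an mmseg config and checkpoint,
# segment every PNG found in <work-dir>/images/, and write one predicted label map
# per image to <work-dir>/pred_label/, e.g.
#   python metrics/mIOU/segment_vitadapter.py <config.py> <checkpoint.pth> --work-dir <dir>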
8 | import argparse 9 | import sys 10 | from pathlib import Path 11 | 12 | import cv2 13 | import mmcv 14 | import torch 15 | from mmcv.parallel import MMDataParallel 16 | from mmcv.runner import load_checkpoint, wrap_fp16_model 17 | from mmseg.datasets.pipelines import Compose 18 | from mmseg.models import build_segmentor 19 | from PIL import Image 20 | from tqdm import tqdm 21 | 22 | sys.path.append("ViT-Adapter/segmentation/") 23 | 24 | # CAVEAT: the script won't run without this seemingly unused import 25 | # see https://github.com/czczup/ViT-Adapter/issues/29 26 | import mmseg_custom # noqa: F401 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description="mmseg test (and eval) a model") 31 | parser.add_argument("config", help="test config file path") 32 | parser.add_argument("checkpoint", help="checkpoint file") 33 | parser.add_argument( 34 | "--work-dir", 35 | help=( 36 | "if specified, the evaluation metric results will be dumped" 37 | "into the directory as json" 38 | ), 39 | ) 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | def main(args): 45 | cfg = mmcv.Config.fromfile(args.config) 46 | # set cudnn_benchmark 47 | if cfg.get("cudnn_benchmark", False): 48 | torch.backends.cudnn.benchmark = True 49 | cfg.model.pretrained = None 50 | cfg.data.test.test_mode = True 51 | 52 | # build the model and load checkpoint 53 | cfg.model.train_cfg = None 54 | model = build_segmentor(cfg.model, test_cfg=cfg.get("test_cfg")) 55 | fp16_cfg = cfg.get("fp16", None) 56 | if fp16_cfg is not None: 57 | wrap_fp16_model(model) 58 | checkpoint = load_checkpoint(model, args.checkpoint, map_location="cpu") 59 | if "CLASSES" in checkpoint.get("meta", {}): 60 | model.CLASSES = checkpoint["meta"]["CLASSES"] 61 | else: 62 | print('"CLASSES" not found in meta, use dataset.CLASSES instead') 63 | model.CLASSES = ADE_CLASSES 64 | if "PALETTE" in checkpoint.get("meta", {}): 65 | model.PALETTE = checkpoint["meta"]["PALETTE"] 66 | else: 67 | print('"PALETTE" not found in meta, use dataset.PALETTE instead') 68 | model.PALETTE = ADE_PALETTE 69 | print(model.CLASSES) 70 | # clean gpu memory when starting a new evaluation. 71 | torch.cuda.empty_cache() 72 | 73 | model = MMDataParallel(model, device_ids=[0]) 74 | model.eval() 75 | 76 | # clean gpu memory when starting a new evaluation. 
77 | torch.cuda.empty_cache() 78 | 79 | #### Read data to be segmented in image folder 80 | root = Path(args.work_dir) 81 | predlbl_path = root.joinpath("pred_label") 82 | predlbl_path.mkdir(exist_ok=True, parents=True) 83 | images_path = root.joinpath("images") 84 | 85 | sample_ids = [f.name for f in images_path.iterdir() if f.name.endswith(".png")] 86 | for filename in tqdm(sample_ids): 87 | file_name = predlbl_path.joinpath(filename) 88 | Image.new("RGB", (1, 1)).save(file_name) 89 | pipeline = Compose(cfg.data.test.pipeline) 90 | results = {} 91 | results["img_info"] = {"filename": filename} 92 | results["seg_fields"] = [] 93 | results["img_prefix"] = images_path 94 | results["seg_prefix"] = None 95 | data = pipeline(results) 96 | data["img"][0] = torch.unsqueeze(data["img"][0], dim=0) 97 | data["img_metas"][0]._data = [[data["img_metas"][0]._data]] 98 | 99 | with torch.no_grad(): 100 | result = model(return_loss=False, rescale=True, **data) 101 | # yield result[0] 102 | cv2.imwrite(str(predlbl_path.joinpath(filename)), result[0]) 103 | 104 | 105 | CLASSES = ( 106 | "person", 107 | "bicycle", 108 | "car", 109 | "motorcycle", 110 | "airplane", 111 | "bus", 112 | "train", 113 | "truck", 114 | "boat", 115 | "traffic light", 116 | "fire hydrant", 117 | "stop sign", 118 | "parking meter", 119 | "bench", 120 | "bird", 121 | "cat", 122 | "dog", 123 | "horse", 124 | "sheep", 125 | "cow", 126 | "elephant", 127 | "bear", 128 | "zebra", 129 | "giraffe", 130 | "backpack", 131 | "umbrella", 132 | "handbag", 133 | "tie", 134 | "suitcase", 135 | "frisbee", 136 | "skis", 137 | "snowboard", 138 | "sports ball", 139 | "kite", 140 | "baseball bat", 141 | "baseball glove", 142 | "skateboard", 143 | "surfboard", 144 | "tennis racket", 145 | "bottle", 146 | "wine glass", 147 | "cup", 148 | "fork", 149 | "knife", 150 | "spoon", 151 | "bowl", 152 | "banana", 153 | "apple", 154 | "sandwich", 155 | "orange", 156 | "broccoli", 157 | "carrot", 158 | "hot dog", 159 | "pizza", 160 | "donut", 161 | "cake", 162 | "chair", 163 | "couch", 164 | "potted plant", 165 | "bed", 166 | "dining table", 167 | "toilet", 168 | "tv", 169 | "laptop", 170 | "mouse", 171 | "remote", 172 | "keyboard", 173 | "cell phone", 174 | "microwave", 175 | "oven", 176 | "toaster", 177 | "sink", 178 | "refrigerator", 179 | "book", 180 | "clock", 181 | "vase", 182 | "scissors", 183 | "teddy bear", 184 | "hair drier", 185 | "toothbrush", 186 | "banner", 187 | "blanket", 188 | "branch", 189 | "bridge", 190 | "building-other", 191 | "bush", 192 | "cabinet", 193 | "cage", 194 | "cardboard", 195 | "carpet", 196 | "ceiling-other", 197 | "ceiling-tile", 198 | "cloth", 199 | "clothes", 200 | "clouds", 201 | "counter", 202 | "cupboard", 203 | "curtain", 204 | "desk-stuff", 205 | "dirt", 206 | "door-stuff", 207 | "fence", 208 | "floor-marble", 209 | "floor-other", 210 | "floor-stone", 211 | "floor-tile", 212 | "floor-wood", 213 | "flower", 214 | "fog", 215 | "food-other", 216 | "fruit", 217 | "furniture-other", 218 | "grass", 219 | "gravel", 220 | "ground-other", 221 | "hill", 222 | "house", 223 | "leaves", 224 | "light", 225 | "mat", 226 | "metal", 227 | "mirror-stuff", 228 | "moss", 229 | "mountain", 230 | "mud", 231 | "napkin", 232 | "net", 233 | "paper", 234 | "pavement", 235 | "pillow", 236 | "plant-other", 237 | "plastic", 238 | "platform", 239 | "playingfield", 240 | "railing", 241 | "railroad", 242 | "river", 243 | "road", 244 | "rock", 245 | "roof", 246 | "rug", 247 | "salad", 248 | "sand", 249 | "sea", 250 | "shelf", 251 | "sky-other", 252 | 
"skyscraper", 253 | "snow", 254 | "solid-other", 255 | "stairs", 256 | "stone", 257 | "straw", 258 | "structural-other", 259 | "table", 260 | "tent", 261 | "textile-other", 262 | "towel", 263 | "tree", 264 | "vegetable", 265 | "wall-brick", 266 | "wall-concrete", 267 | "wall-other", 268 | "wall-panel", 269 | "wall-stone", 270 | "wall-tile", 271 | "wall-wood", 272 | "water-other", 273 | "waterdrops", 274 | "window-blind", 275 | "window-other", 276 | "wood", 277 | ) 278 | 279 | PALETTE = [ 280 | [0, 192, 64], 281 | [0, 192, 64], 282 | [0, 64, 96], 283 | [128, 192, 192], 284 | [0, 64, 64], 285 | [0, 192, 224], 286 | [0, 192, 192], 287 | [128, 192, 64], 288 | [0, 192, 96], 289 | [128, 192, 64], 290 | [128, 32, 192], 291 | [0, 0, 224], 292 | [0, 0, 64], 293 | [0, 160, 192], 294 | [128, 0, 96], 295 | [128, 0, 192], 296 | [0, 32, 192], 297 | [128, 128, 224], 298 | [0, 0, 192], 299 | [128, 160, 192], 300 | [128, 128, 0], 301 | [128, 0, 32], 302 | [128, 32, 0], 303 | [128, 0, 128], 304 | [64, 128, 32], 305 | [0, 160, 0], 306 | [0, 0, 0], 307 | [192, 128, 160], 308 | [0, 32, 0], 309 | [0, 128, 128], 310 | [64, 128, 160], 311 | [128, 160, 0], 312 | [0, 128, 0], 313 | [192, 128, 32], 314 | [128, 96, 128], 315 | [0, 0, 128], 316 | [64, 0, 32], 317 | [0, 224, 128], 318 | [128, 0, 0], 319 | [192, 0, 160], 320 | [0, 96, 128], 321 | [128, 128, 128], 322 | [64, 0, 160], 323 | [128, 224, 128], 324 | [128, 128, 64], 325 | [192, 0, 32], 326 | [128, 96, 0], 327 | [128, 0, 192], 328 | [0, 128, 32], 329 | [64, 224, 0], 330 | [0, 0, 64], 331 | [128, 128, 160], 332 | [64, 96, 0], 333 | [0, 128, 192], 334 | [0, 128, 160], 335 | [192, 224, 0], 336 | [0, 128, 64], 337 | [128, 128, 32], 338 | [192, 32, 128], 339 | [0, 64, 192], 340 | [0, 0, 32], 341 | [64, 160, 128], 342 | [128, 64, 64], 343 | [128, 0, 160], 344 | [64, 32, 128], 345 | [128, 192, 192], 346 | [0, 0, 160], 347 | [192, 160, 128], 348 | [128, 192, 0], 349 | [128, 0, 96], 350 | [192, 32, 0], 351 | [128, 64, 128], 352 | [64, 128, 96], 353 | [64, 160, 0], 354 | [0, 64, 0], 355 | [192, 128, 224], 356 | [64, 32, 0], 357 | [0, 192, 128], 358 | [64, 128, 224], 359 | [192, 160, 0], 360 | [0, 192, 0], 361 | [192, 128, 96], 362 | [192, 96, 128], 363 | [0, 64, 128], 364 | [64, 0, 96], 365 | [64, 224, 128], 366 | [128, 64, 0], 367 | [192, 0, 224], 368 | [64, 96, 128], 369 | [128, 192, 128], 370 | [64, 0, 224], 371 | [192, 224, 128], 372 | [128, 192, 64], 373 | [192, 0, 96], 374 | [192, 96, 0], 375 | [128, 64, 192], 376 | [0, 128, 96], 377 | [0, 224, 0], 378 | [64, 64, 64], 379 | [128, 128, 224], 380 | [0, 96, 0], 381 | [64, 192, 192], 382 | [0, 128, 224], 383 | [128, 224, 0], 384 | [64, 192, 64], 385 | [128, 128, 96], 386 | [128, 32, 128], 387 | [64, 0, 192], 388 | [0, 64, 96], 389 | [0, 160, 128], 390 | [192, 0, 64], 391 | [128, 64, 224], 392 | [0, 32, 128], 393 | [192, 128, 192], 394 | [0, 64, 224], 395 | [128, 160, 128], 396 | [192, 128, 0], 397 | [128, 64, 32], 398 | [128, 32, 64], 399 | [192, 0, 128], 400 | [64, 192, 32], 401 | [0, 160, 64], 402 | [64, 0, 0], 403 | [192, 192, 160], 404 | [0, 32, 64], 405 | [64, 128, 128], 406 | [64, 192, 160], 407 | [128, 160, 64], 408 | [64, 128, 0], 409 | [192, 192, 32], 410 | [128, 96, 192], 411 | [64, 0, 128], 412 | [64, 64, 32], 413 | [0, 224, 192], 414 | [192, 0, 0], 415 | [192, 64, 160], 416 | [0, 96, 192], 417 | [192, 128, 128], 418 | [64, 64, 160], 419 | [128, 224, 192], 420 | [192, 128, 64], 421 | [192, 64, 32], 422 | [128, 96, 64], 423 | [192, 0, 192], 424 | [0, 192, 32], 425 | [64, 224, 64], 426 | [64, 0, 64], 427 | 
[128, 192, 160], 428 | [64, 96, 64], 429 | [64, 128, 192], 430 | [0, 192, 160], 431 | [192, 224, 64], 432 | [64, 128, 64], 433 | [128, 192, 32], 434 | [192, 32, 192], 435 | [64, 64, 192], 436 | [0, 64, 32], 437 | [64, 160, 192], 438 | [192, 64, 64], 439 | [128, 64, 160], 440 | [64, 32, 192], 441 | [192, 192, 192], 442 | [0, 64, 160], 443 | [192, 160, 192], 444 | [192, 192, 0], 445 | [128, 64, 96], 446 | [192, 32, 64], 447 | [192, 64, 128], 448 | [64, 192, 96], 449 | [64, 160, 64], 450 | [64, 64, 0], 451 | ] 452 | 453 | 454 | ADE_CLASSES = ( 455 | "wall", 456 | "building", 457 | "sky", 458 | "floor", 459 | "tree", 460 | "ceiling", 461 | "road", 462 | "bed ", 463 | "windowpane", 464 | "grass", 465 | "cabinet", 466 | "sidewalk", 467 | "person", 468 | "earth", 469 | "door", 470 | "table", 471 | "mountain", 472 | "plant", 473 | "curtain", 474 | "chair", 475 | "car", 476 | "water", 477 | "painting", 478 | "sofa", 479 | "shelf", 480 | "house", 481 | "sea", 482 | "mirror", 483 | "rug", 484 | "field", 485 | "armchair", 486 | "seat", 487 | "fence", 488 | "desk", 489 | "rock", 490 | "wardrobe", 491 | "lamp", 492 | "bathtub", 493 | "railing", 494 | "cushion", 495 | "base", 496 | "box", 497 | "column", 498 | "signboard", 499 | "chest of drawers", 500 | "counter", 501 | "sand", 502 | "sink", 503 | "skyscraper", 504 | "fireplace", 505 | "refrigerator", 506 | "grandstand", 507 | "path", 508 | "stairs", 509 | "runway", 510 | "case", 511 | "pool table", 512 | "pillow", 513 | "screen door", 514 | "stairway", 515 | "river", 516 | "bridge", 517 | "bookcase", 518 | "blind", 519 | "coffee table", 520 | "toilet", 521 | "flower", 522 | "book", 523 | "hill", 524 | "bench", 525 | "countertop", 526 | "stove", 527 | "palm", 528 | "kitchen island", 529 | "computer", 530 | "swivel chair", 531 | "boat", 532 | "bar", 533 | "arcade machine", 534 | "hovel", 535 | "bus", 536 | "towel", 537 | "light", 538 | "truck", 539 | "tower", 540 | "chandelier", 541 | "awning", 542 | "streetlight", 543 | "booth", 544 | "television receiver", 545 | "airplane", 546 | "dirt track", 547 | "apparel", 548 | "pole", 549 | "land", 550 | "bannister", 551 | "escalator", 552 | "ottoman", 553 | "bottle", 554 | "buffet", 555 | "poster", 556 | "stage", 557 | "van", 558 | "ship", 559 | "fountain", 560 | "conveyer belt", 561 | "canopy", 562 | "washer", 563 | "plaything", 564 | "swimming pool", 565 | "stool", 566 | "barrel", 567 | "basket", 568 | "waterfall", 569 | "tent", 570 | "bag", 571 | "minibike", 572 | "cradle", 573 | "oven", 574 | "ball", 575 | "food", 576 | "step", 577 | "tank", 578 | "trade name", 579 | "microwave", 580 | "pot", 581 | "animal", 582 | "bicycle", 583 | "lake", 584 | "dishwasher", 585 | "screen", 586 | "blanket", 587 | "sculpture", 588 | "hood", 589 | "sconce", 590 | "vase", 591 | "traffic light", 592 | "tray", 593 | "ashcan", 594 | "fan", 595 | "pier", 596 | "crt screen", 597 | "plate", 598 | "monitor", 599 | "bulletin board", 600 | "shower", 601 | "radiator", 602 | "glass", 603 | "clock", 604 | "flag", 605 | ) 606 | 607 | ADE_PALETTE = [ 608 | [120, 120, 120], 609 | [180, 120, 120], 610 | [6, 230, 230], 611 | [80, 50, 50], 612 | [4, 200, 3], 613 | [120, 120, 80], 614 | [140, 140, 140], 615 | [204, 5, 255], 616 | [230, 230, 230], 617 | [4, 250, 7], 618 | [224, 5, 255], 619 | [235, 255, 7], 620 | [150, 5, 61], 621 | [120, 120, 70], 622 | [8, 255, 51], 623 | [255, 6, 82], 624 | [143, 255, 140], 625 | [204, 255, 4], 626 | [255, 51, 7], 627 | [204, 70, 3], 628 | [0, 102, 200], 629 | [61, 230, 250], 630 | [255, 6, 51], 631 | [11, 102, 
255], 632 | [255, 7, 71], 633 | [255, 9, 224], 634 | [9, 7, 230], 635 | [220, 220, 220], 636 | [255, 9, 92], 637 | [112, 9, 255], 638 | [8, 255, 214], 639 | [7, 255, 224], 640 | [255, 184, 6], 641 | [10, 255, 71], 642 | [255, 41, 10], 643 | [7, 255, 255], 644 | [224, 255, 8], 645 | [102, 8, 255], 646 | [255, 61, 6], 647 | [255, 194, 7], 648 | [255, 122, 8], 649 | [0, 255, 20], 650 | [255, 8, 41], 651 | [255, 5, 153], 652 | [6, 51, 255], 653 | [235, 12, 255], 654 | [160, 150, 20], 655 | [0, 163, 255], 656 | [140, 140, 140], 657 | [250, 10, 15], 658 | [20, 255, 0], 659 | [31, 255, 0], 660 | [255, 31, 0], 661 | [255, 224, 0], 662 | [153, 255, 0], 663 | [0, 0, 255], 664 | [255, 71, 0], 665 | [0, 235, 255], 666 | [0, 173, 255], 667 | [31, 0, 255], 668 | [11, 200, 200], 669 | [255, 82, 0], 670 | [0, 255, 245], 671 | [0, 61, 255], 672 | [0, 255, 112], 673 | [0, 255, 133], 674 | [255, 0, 0], 675 | [255, 163, 0], 676 | [255, 102, 0], 677 | [194, 255, 0], 678 | [0, 143, 255], 679 | [51, 255, 0], 680 | [0, 82, 255], 681 | [0, 255, 41], 682 | [0, 255, 173], 683 | [10, 0, 255], 684 | [173, 255, 0], 685 | [0, 255, 153], 686 | [255, 92, 0], 687 | [255, 0, 255], 688 | [255, 0, 245], 689 | [255, 0, 102], 690 | [255, 173, 0], 691 | [255, 0, 20], 692 | [255, 184, 184], 693 | [0, 31, 255], 694 | [0, 255, 61], 695 | [0, 71, 255], 696 | [255, 0, 204], 697 | [0, 255, 194], 698 | [0, 255, 82], 699 | [0, 10, 255], 700 | [0, 112, 255], 701 | [51, 0, 255], 702 | [0, 194, 255], 703 | [0, 122, 255], 704 | [0, 255, 163], 705 | [255, 153, 0], 706 | [0, 255, 10], 707 | [255, 112, 0], 708 | [143, 255, 0], 709 | [82, 0, 255], 710 | [163, 255, 0], 711 | [255, 235, 0], 712 | [8, 184, 170], 713 | [133, 0, 255], 714 | [0, 255, 92], 715 | [184, 0, 255], 716 | [255, 0, 31], 717 | [0, 184, 255], 718 | [0, 214, 255], 719 | [255, 0, 112], 720 | [92, 255, 0], 721 | [0, 224, 255], 722 | [112, 224, 255], 723 | [70, 184, 160], 724 | [163, 0, 255], 725 | [153, 0, 255], 726 | [71, 255, 0], 727 | [255, 0, 163], 728 | [255, 204, 0], 729 | [255, 0, 143], 730 | [0, 255, 235], 731 | [133, 255, 0], 732 | [255, 0, 235], 733 | [245, 0, 255], 734 | [255, 0, 122], 735 | [255, 245, 0], 736 | [10, 190, 212], 737 | [214, 255, 0], 738 | [0, 204, 255], 739 | [20, 0, 255], 740 | [255, 255, 0], 741 | [0, 153, 255], 742 | [0, 41, 255], 743 | [0, 255, 204], 744 | [41, 0, 255], 745 | [41, 255, 0], 746 | [173, 0, 255], 747 | [0, 245, 255], 748 | [71, 0, 255], 749 | [122, 0, 255], 750 | [0, 255, 184], 751 | [0, 92, 255], 752 | [184, 255, 0], 753 | [0, 133, 255], 754 | [255, 214, 0], 755 | [25, 194, 194], 756 | [102, 255, 0], 757 | [92, 0, 255], 758 | ] 759 | 760 | if __name__ == "__main__": 761 | args = parse_args() 762 | main(args) 763 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
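# `OrderedDict` is required by `state_dict_cust` at the bottom of this file (the
# patched `vae.state_dict` returns an empty OrderedDict when checkpoints are saved).
from collections import OrderedDict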
6 | 7 | import typing as tp 8 | from functools import partial 9 | from typing import List 10 | 11 | import pydantic 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import torchvision.transforms as T 16 | from diffusers import DDIMScheduler, DDPMScheduler, VersatileDiffusionDualGuidedPipeline 17 | from diffusers.models import DualTransformer2DModel 18 | from model.fmri_mlp import FmriMLPConfig 19 | from model.peft_utils import add_adapter 20 | from peft import LoraConfig 21 | from torch import Tensor 22 | 23 | 24 | class DiffusionOutput(tp.NamedTuple): 25 | image: Tensor 26 | losses: dict = None 27 | t_diffusion: Tensor = None 28 | brain_embeddings: dict = None 29 | 30 | 31 | class VersatileDiffusionConfig(pydantic.BaseModel): 32 | model_config = pydantic.ConfigDict(extra="forbid") 33 | name: tp.Literal["VersatileDiffusion"] = "VersatileDiffusion" 34 | 35 | vd_cache_dir: str = "/fsx-brainai/marlenec/vd_cache_dir" 36 | in_dim: int = 15724 37 | num_inference_steps: int = 20 38 | 39 | 40 | diffusion_noise_offset: bool = False 41 | prediction_type: str = "epsilon" 42 | noise_cubic_sampling: bool = False 43 | drop_rate_clsfree: float = 0.1 44 | trainable_unet_layers: str = "lora" 45 | 46 | training_strategy: tp.Literal["w/_difloss", "w/o_difloss"] = "w/_difloss" 47 | 48 | 49 | brain_modules_config: dict[str, FmriMLPConfig] | None = None 50 | 51 | 52 | 53 | 54 | def build( 55 | self, brain_n_in_channels: int | None = None, brain_temp_dim: int | None = None 56 | ) -> nn.Module: 57 | return VersatileDiffusion( 58 | config=self, 59 | brain_n_in_channels=brain_n_in_channels, 60 | brain_temp_dim=brain_temp_dim, 61 | ) 62 | 63 | 64 | class VersatileDiffusion(nn.Module): 65 | """End-to-end finetuning on brain signals""" 66 | 67 | def __init__( 68 | self, 69 | config: VersatileDiffusionConfig | None = None, 70 | brain_n_in_channels: int | None = None, 71 | brain_temp_dim: int | None = None, 72 | ): 73 | super().__init__() 74 | config = config if config is not None else VersatileDiffusionConfig() 75 | self.config = config 76 | self.drop_rate_clsfree = config.drop_rate_clsfree 77 | self.guidance_scale = 3.5 78 | self.diffusion_noise_offset = config.diffusion_noise_offset 79 | self.prediction_type = config.prediction_type 80 | self.noise_cubic_sampling = config.noise_cubic_sampling 81 | in_dim = config.in_dim 82 | self.brain_modules_config = config.brain_modules_config 83 | self.training_strategy = config.training_strategy 84 | 85 | 86 | print("VD cache dir is : ", config.vd_cache_dir) 87 | try: 88 | vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( 89 | config.vd_cache_dir 90 | ) 91 | except: 92 | print("Downloading Versatile Diffusion to", config.vd_cache_dir) 93 | vd_pipe = VersatileDiffusionDualGuidedPipeline.from_pretrained( 94 | "shi-labs/versatile-diffusion", cache_dir=config.vd_cache_dir 95 | ) 96 | vd_pipe.vae.eval() 97 | vd_pipe.scheduler = DDPMScheduler.from_pretrained( 98 | "shi-labs/versatile-diffusion", 99 | subfolder="scheduler", 100 | ) 101 | self.num_inference_steps = 20 102 | 103 | text_image_ratio = ( 104 | 0.0 105 | ) 106 | for name, module in vd_pipe.image_unet.named_modules(): 107 | if isinstance(module, DualTransformer2DModel): 108 | module.mix_ratio = text_image_ratio 109 | for i, type in enumerate(("text", "image")): 110 | if type == "text": 111 | module.condition_lengths[i] = 77 112 | module.transformer_index_for_condition[i] = ( 113 | 1 114 | ) 115 | else: 116 | module.condition_lengths[i] = 257 117 | 
module.transformer_index_for_condition[i] = ( 118 | 0 119 | ) 120 | 121 | self.unet = vd_pipe.image_unet 122 | self.vae = vd_pipe.vae 123 | self.noise_scheduler = vd_pipe.scheduler 124 | self.eval_noise_scheduler = DDIMScheduler.from_pretrained( 125 | "shi-labs/versatile-diffusion", 126 | subfolder="scheduler", 127 | ) 128 | self.unet.enable_xformers_memory_efficient_attention() 129 | 130 | 131 | self.vae.state_dict = partial(state_dict_cust, self=self.vae) 132 | 133 | 134 | if ( 135 | self.brain_modules_config is not None 136 | and "blurry" in self.brain_modules_config.keys() 137 | ): 138 | num_spa_channels = 4 139 | orig_inpt_0_weights = self.unet.conv_in.weight 140 | dst_input_blocks = torch.zeros( 141 | (320, 4 + num_spa_channels, 3, 3), dtype=self.unet.conv_in.weight.dtype 142 | ) 143 | dst_input_blocks[:, :4, :, :] = orig_inpt_0_weights 144 | self.unet.conv_in.weight = torch.nn.Parameter(dst_input_blocks) 145 | 146 | del vd_pipe 147 | 148 | 149 | self.has_channel_merger = False 150 | self.brain_modules = nn.ModuleDict() 151 | if self.brain_modules_config is not None: 152 | 153 | for brain_module_name in self.brain_modules_config.keys(): 154 | brain_module = self.brain_modules_config[brain_module_name] 155 | brain_module = brain_module.build( 156 | 157 | n_in_channels=brain_n_in_channels, 158 | n_outputs=257 * 768, 159 | ) 160 | self.brain_modules[brain_module_name] = brain_module 161 | 162 | 163 | self.vae.requires_grad_(False) 164 | self.unet.requires_grad_(False) 165 | if ( 166 | self.brain_modules_config is not None 167 | and "blurry" in self.brain_modules_config.keys() 168 | ): 169 | set_requires_grad(self.unet.conv_in, True) 170 | if config.trainable_unet_layers == "lora": 171 | 172 | self.unet._hf_peft_config_loaded = False 173 | self.unet.peft_config = {} 174 | unet_lora_config = LoraConfig( 175 | r=4, 176 | lora_alpha=4, 177 | target_modules=[ 178 | "attn2.to_k", 179 | "attn2.to_q", 180 | "attn2.to_v", 181 | "attn2.to_out.0", 182 | ], 183 | ) 184 | add_adapter(self.unet, unet_lora_config) 185 | elif config.trainable_unet_layers == "no": 186 | pass 187 | else: 188 | raise ValueError("config.trainable_unet_layers is unknown") 189 | 190 | def get_condition( 191 | self, brain: Tensor, subject_idx: Tensor, **kwargs 192 | ) -> tp.Dict: 193 | brain_embeddings = { 194 | name: head( 195 | x=brain, subject_ids=subject_idx, 196 | ) 197 | for name, head in self.brain_modules.items() 198 | } 199 | BS = brain_embeddings["clip_image"]["MSELoss"].size(0) 200 | brain_embeddings["clip_image"]["MSELoss"] = brain_embeddings["clip_image"][ 201 | "MSELoss" 202 | ].reshape(BS, -1, 768) 203 | if "Clip" in brain_embeddings["clip_image"]: 204 | brain_embeddings["clip_image"]["Clip"] = brain_embeddings["clip_image"][ 205 | "Clip" 206 | ].reshape(BS, -1, 768) 207 | 208 | if "blurry" in brain_embeddings: 209 | brain_embeddings["blurry"]["MSELoss"] = F.interpolate( 210 | brain_embeddings["blurry"]["MSELoss"], 211 | [64, 64], 212 | mode="bicubic", 213 | antialias=True, 214 | align_corners=False, 215 | ) 216 | return brain_embeddings 217 | 218 | def compute_diffusion_loss(self, img, brain_embeddings: tp.Dict) -> DiffusionOutput: 219 | 220 | brain_clip_embeddings = brain_embeddings["clip_image"]["MSELoss"] 221 | brain_blurry_embeddings = None 222 | if "blurry" in brain_embeddings: 223 | brain_blurry_embeddings = brain_embeddings["blurry"]["MSELoss"] 224 | 225 | 226 | img = 2 * img - 1 227 | img = img.to(dtype=self.vae.dtype, device=self.vae.device) 228 | 229 | latent = ( 230 | 
self.vae.encode(img).latent_dist.sample() * self.vae.config.scaling_factor 231 | ) 232 | 233 | 234 | losses = {} 235 | timesteps = None 236 | 237 | bsz = latent.size(0) 238 | 239 | if self.noise_cubic_sampling: 240 | timesteps = torch.rand((bsz,), device=latent.device) 241 | timesteps = ( 242 | 1 - timesteps**3 243 | ) * self.noise_scheduler.config.num_train_timesteps 244 | timesteps = timesteps.long().to(self.noise_scheduler.timesteps.dtype) 245 | timesteps = timesteps.clamp( 246 | 0, self.noise_scheduler.config.num_train_timesteps - 1 247 | ) 248 | else: 249 | timesteps = torch.randint( 250 | 0, 251 | self.noise_scheduler.config.num_train_timesteps, 252 | (bsz,), 253 | device=latent.device, 254 | ) 255 | timesteps = timesteps.long() 256 | 257 | 258 | if self.diffusion_noise_offset: 259 | noise = torch.randn_like(latent) + 0.1 * torch.randn( 260 | latent.shape[0], latent.shape[1], 1, 1, device=latent.device 261 | ) 262 | noise = noise.to(self.unet.dtype) 263 | else: 264 | noise = torch.randn_like(latent) 265 | noisy_latents = self.noise_scheduler.add_noise(latent, noise, timesteps) 266 | 267 | 268 | if self.noise_scheduler.config.prediction_type == "epsilon": 269 | target = noise 270 | elif self.noise_scheduler.config.prediction_type == "v_prediction": 271 | 272 | target = self.noise_scheduler.get_velocity(latent, noise, timesteps) 273 | else: 274 | raise ValueError( 275 | f"Unknown prediction type {self.noise_scheduler.config.prediction_type}" 276 | ) 277 | 278 | 279 | 280 | brain_clip_embeddings = torch.cat( 281 | [ 282 | torch.zeros(len(brain_clip_embeddings), 77, 768) 283 | .to(self.unet.dtype) 284 | .to(self.unet.device), 285 | brain_clip_embeddings, 286 | ], 287 | dim=1, 288 | ) 289 | 290 | 291 | mask = ( 292 | torch.rand( 293 | size=(len(brain_clip_embeddings),), 294 | device=brain_clip_embeddings.device, 295 | ) 296 | < self.drop_rate_clsfree 297 | ) 298 | if self.drop_rate_clsfree > 0.0: 299 | brain_clip_embeddings[mask] = 0 300 | 301 | if brain_blurry_embeddings != None: 302 | noisy_latents = torch.cat( 303 | (noisy_latents, brain_blurry_embeddings.to(self.unet.dtype)), dim=1 304 | ).to(self.unet.dtype) 305 | 306 | losses["blurry"] = F.mse_loss( 307 | brain_blurry_embeddings, latent, reduction="mean" 308 | ) 309 | 310 | 311 | noise_pred = self.unet( 312 | noisy_latents, 313 | timesteps, 314 | encoder_hidden_states=brain_clip_embeddings.to(self.unet.dtype), 315 | cross_attention_kwargs=None, 316 | 317 | return_dict=False, 318 | )[0] 319 | dif_losses = F.mse_loss(noise_pred, target, reduction="mean") 320 | losses["diffusion"] = dif_losses 321 | 322 | rec_image = noise_pred 323 | 324 | return DiffusionOutput( 325 | image=rec_image, 326 | losses=losses, 327 | t_diffusion=timesteps, 328 | brain_embeddings=brain_embeddings, 329 | ) 330 | 331 | def forward( 332 | self, 333 | brain: Tensor, 334 | subject_idx: Tensor, 335 | img: Tensor = None, 336 | is_img_gen_mode: bool = False, 337 | seed=0, 338 | return_interm_noisy: bool = False, 339 | uncond: bool = False, 340 | blurry_recons_extern: Tensor = None, 341 | **kwargs, 342 | ) -> DiffusionOutput: 343 | 344 | brain_embeddings = self.get_condition(brain, subject_idx) 345 | if not is_img_gen_mode: 346 | if self.training_strategy == "w/o_difloss": 347 | return DiffusionOutput( 348 | image=img, losses={}, brain_embeddings=brain_embeddings 349 | ) 350 | return self.compute_diffusion_loss(img, brain_embeddings) 351 | else: 352 | 353 | 354 | return self.reconstruction_from_clipbrainimage( 355 | brain_embeddings, 356 | 357 | 358 | 
img_lowlevel=blurry_recons_extern, 359 | num_inference_steps=self.num_inference_steps, 360 | recons_per_sample=1, 361 | guidance_scale=self.guidance_scale, 362 | img2img_strength=0.85, 363 | seed=seed, 364 | verbose=False, 365 | img_variations=False, 366 | return_interm_noisy=return_interm_noisy, 367 | uncond=uncond, 368 | 369 | ) 370 | 371 | @torch.no_grad() 372 | def decode_latents(self, latents): 373 | latents = 1 / 0.18215 * latents 374 | image = self.vae.decode(latents).sample 375 | image = (image / 2 + 0.5).clamp(0, 1) 376 | return image 377 | 378 | @torch.no_grad() 379 | def reconstruction_from_clipbrainimage( 380 | self, 381 | 382 | brain_embeddings, 383 | 384 | text_token=None, 385 | 386 | img_lowlevel=None, 387 | num_inference_steps=20, 388 | recons_per_sample=1, 389 | guidance_scale=3.5, 390 | img2img_strength=0.85, 391 | 392 | seed=0, 393 | 394 | 395 | verbose=False, 396 | img_variations=False, 397 | return_interm_noisy=False, 398 | uncond=False, 399 | 400 | 401 | 402 | ): 403 | 404 | 405 | 406 | 407 | 408 | 409 | brain_clip_embeddings = brain_embeddings["clip_image"]["MSELoss"] 410 | img_lowlevel_latent = None 411 | if "blurry" in brain_embeddings: 412 | img_lowlevel_latent = brain_embeddings["blurry"]["MSELoss"] 413 | 414 | batchsize = brain_clip_embeddings.size(0) 415 | with torch.no_grad(): 416 | brain_recons = None 417 | if img_lowlevel is not None: 418 | img_lowlevel = img_lowlevel.to(self.unet.dtype).to(self.unet.device) 419 | 420 | if self.unet is not None: 421 | do_classifier_free_guidance = guidance_scale > 1.0 422 | 423 | generator = torch.Generator(device=self.unet.device) 424 | generator.manual_seed(seed) 425 | 426 | if uncond: 427 | input_embedding = torch.zeros_like(brain_clip_embeddings) 428 | else: 429 | input_embedding = ( 430 | brain_clip_embeddings 431 | ) 432 | if verbose: 433 | print("input_embedding", input_embedding.shape) 434 | 435 | if text_token is not None: 436 | prompt_embeds = text_token.repeat(recons_per_sample, 1, 1) 437 | else: 438 | prompt_embeds = ( 439 | torch.zeros(len(input_embedding), 77, 768) 440 | .to(self.unet.dtype) 441 | .to(self.unet.device) 442 | ) 443 | if verbose: 444 | print("prompt!", prompt_embeds.shape) 445 | 446 | if do_classifier_free_guidance: 447 | input_embedding = ( 448 | torch.cat([torch.zeros_like(input_embedding), input_embedding]) 449 | .to(self.unet.dtype) 450 | .to(self.unet.device) 451 | ) 452 | prompt_embeds = ( 453 | torch.cat([torch.zeros_like(prompt_embeds), prompt_embeds]) 454 | .to(self.unet.dtype) 455 | .to(self.unet.device) 456 | ) 457 | 458 | 459 | 460 | 461 | if not img_variations: 462 | 463 | input_embedding = torch.cat([prompt_embeds, input_embedding], dim=1) 464 | 465 | 466 | 467 | 468 | 469 | self.noise_scheduler.set_timesteps( 470 | num_inference_steps=num_inference_steps, device=self.unet.device 471 | ) 472 | 473 | 474 | batch_size = ( 475 | input_embedding.shape[0] // 2 476 | ) 477 | 478 | if ( 479 | img_lowlevel is not None 480 | and not self.sanity_check_blurry 481 | and img2img_strength != 1.0 482 | ): 483 | print("use img_lowlevel for img2img initialization") 484 | init_timestep = min( 485 | int(num_inference_steps * img2img_strength), num_inference_steps 486 | ) 487 | t_start = max(num_inference_steps - init_timestep, 0) 488 | timesteps = self.noise_scheduler.timesteps[t_start:] 489 | latent_timestep = timesteps[:1].repeat(batch_size) 490 | 491 | if verbose: 492 | print("img_lowlevel", img_lowlevel.shape) 493 | 494 | img_lowlevel_embeddings = 2 * img_lowlevel - 1 495 | if verbose: 496 | 
print("img_lowlevel_embeddings", img_lowlevel_embeddings.shape) 497 | init_latents = self.vae.encode( 498 | img_lowlevel_embeddings.to(self.vae.dtype) 499 | ).latent_dist.sample( 500 | generator 501 | ) 502 | init_latents = self.vae.config.scaling_factor * init_latents 503 | init_latents = init_latents.repeat(recons_per_sample, 1, 1, 1) 504 | print("init with low level") 505 | noise = torch.randn( 506 | [recons_per_sample, 4, 64, 64], 507 | device=self.unet.device, 508 | generator=generator, 509 | dtype=input_embedding.dtype, 510 | ) 511 | init_latents = self.noise_scheduler.add_noise( 512 | init_latents, noise, latent_timestep 513 | ) 514 | latents = init_latents 515 | else: 516 | timesteps = self.noise_scheduler.timesteps 517 | latents = torch.randn( 518 | [recons_per_sample * batchsize, 4, 64, 64], 519 | device=self.unet.device, 520 | generator=generator, 521 | dtype=input_embedding.dtype, 522 | ) 523 | latents = latents * self.noise_scheduler.init_noise_sigma 524 | 525 | 526 | interm_noisy = [] 527 | for i, t in enumerate(timesteps): 528 | 529 | 530 | if img_lowlevel_latent != None: 531 | latents = torch.cat((latents, img_lowlevel_latent), dim=1) 532 | 533 | latent_model_input = ( 534 | torch.cat([latents] * 2) if do_classifier_free_guidance else latents 535 | ) 536 | latent_model_input = self.noise_scheduler.scale_model_input( 537 | latent_model_input, t 538 | ) 539 | 540 | if verbose: 541 | print("latent_model_input", latent_model_input.shape) 542 | if verbose: 543 | print("input_embedding", input_embedding.shape) 544 | noise_pred = self.unet( 545 | latent_model_input, 546 | t, 547 | encoder_hidden_states=input_embedding, 548 | ).sample 549 | 550 | 551 | if do_classifier_free_guidance: 552 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 553 | noise_pred = noise_pred_uncond + guidance_scale * ( 554 | noise_pred_text - noise_pred_uncond 555 | ) 556 | 557 | 558 | 559 | 560 | 561 | 562 | 563 | 564 | if img_lowlevel_latent != None: 565 | latents = latents[:, :4] 566 | res_prev = self.noise_scheduler.step(noise_pred, t, latents) 567 | latents = res_prev.prev_sample 568 | 569 | if return_interm_noisy and i % 4 == 0: 570 | pred_orig = res_prev.pred_original_sample 571 | interm_noisy.append( 572 | T.Resize((64, 64))(self.decode_latents(pred_orig).detach().cpu()) 573 | ) 574 | 575 | recons = self.decode_latents(latents).detach().cpu() 576 | 577 | brain_recons = recons.unsqueeze(0) 578 | 579 | 580 | 581 | 582 | return DiffusionOutput(image=brain_recons[0], brain_embeddings=brain_embeddings) 583 | 584 | def collect_parameters( 585 | self, 586 | ) -> List[nn.Parameter]: 587 | """Return the trainable parameters of the model. 588 | 589 | Returns: 590 | model parameter_dict 591 | """ 592 | 593 | model_parameters = {n: p for n, p in self.named_parameters() if p.requires_grad} 594 | 595 | return [v for _, v in model_parameters.items()] 596 | 597 | 598 | def state_dict_cust(*args, destination=None, prefix="", keep_vars=False, self=None): 599 | return OrderedDict() 600 | 601 | 602 | def set_requires_grad(module, value): 603 | for n, p in module.named_parameters(): 604 | p.requires_grad = value 605 | --------------------------------------------------------------------------------
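For orientation, a minimal, hypothetical sketch of how the pieces above fit together at evaluation time; the NSD path, subject id, and batch size are placeholders, and the reconstruction step is elided because it depends on a trained checkpoint.

# Hypothetical evaluation sketch; paths, subject id and batch size are placeholders.
import torch
import torchvision.transforms as T

from data import NeuroImagesDataModuleConfig, NsdDatasetConfig
from metrics.image_metrics import compute_image_generation_metrics

datamodule = NeuroImagesDataModuleConfig(
    nsd_dataset_config=NsdDatasetConfig(
        nsddata_path="/path/to/nsd",  # placeholder
        subject_id=1,
        averaged=True,
    ),
    batch_size=4,
).build()

# `recons` would come from running the trained model on datamodule.test_dataloader();
# it is assumed here to be a list of PIL.Image reconstructions, one per test sample.
gt_images = [
    T.ToPILImage()((sample["img"] * 255).to(torch.uint8))
    for sample in datamodule.eval_dataset
]
# metrics = compute_image_generation_metrics(recons, gt_images)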