├── .gitignore ├── CMVDM_structure ├── LICENSE ├── env.yml ├── scripts │ ├── infer.sh │ ├── train_dec_rgb.sh │ ├── train_dec_rgb_bold.sh │ ├── train_enc_rgb.sh │ └── train_enc_rgb_bold.sh ├── setup.py └── src │ ├── __init__.py │ ├── config.py │ ├── config_dec.py │ ├── config_enc.py │ ├── inference.py │ ├── lpips │ ├── __init__.py │ ├── lpips.py │ ├── pretrained_networks.py │ └── trainer.py │ ├── train_decoder.py │ ├── train_decoder_bold.py │ ├── train_encoder.py │ ├── train_encoder_bold.py │ └── utils │ ├── __init__.py │ ├── criterion.py │ ├── datasets.py │ ├── misc.py │ ├── models.py │ └── preprocess_BOLD5000.py ├── README.md ├── assets └── framework.jpg ├── code ├── config.py ├── dataset.py ├── dc_ldm │ ├── ldm_for_fmri.py │ ├── ldm_for_fmri_clip.py │ ├── ldm_for_fmri_control.py │ ├── models │ │ ├── autoencoder.py │ │ ├── diffusion │ │ │ ├── __init__.py │ │ │ ├── classifier.py │ │ │ ├── ddim.py │ │ │ ├── ddpm.py │ │ │ ├── ddpm_clip.py │ │ │ ├── ddpm_control.py │ │ │ ├── ddpm_control_res.py │ │ │ └── plms.py │ │ └── fmri_base_decoder.py │ ├── modules │ │ ├── attention.py │ │ ├── diffusionmodules │ │ │ ├── __init__.py │ │ │ ├── model.py │ │ │ ├── openaimodel.py │ │ │ ├── openaimodel_control.py │ │ │ └── util.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── ema.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ └── x_transformer.py │ └── util.py ├── eval_metrics.py ├── parallel_datasets.py ├── sc_mbm │ ├── mae_for_fmri.py │ ├── trainer.py │ └── utils.py ├── setup.py ├── stageA1_mbm_pretrain.py ├── stageA2_mbm_finetune.py ├── stageB_ldm_finetune.py ├── stageB_ldm_finetune_base.py └── stageB_ldm_finetune_clip.py ├── env.yaml └── pretrains └── ldm └── label2img ├── config.yaml ├── controlnet_config.yaml └── controlnet_config_BOLD5000.yaml /.gitignore: -------------------------------------------------------------------------------- 1 | results/ 2 | data/ 3 | __pycache__/ 4 | wandb/ 5 | # models/ 6 | out_ckpts/ 7 | results/ 8 | lightning_logs/ 9 | image_generation/ 10 | mind_vis.egg-info/ 11 | *.out 12 | *.npz 13 | *.ckpt 14 | *.DS_Store 15 | *.zip 16 | *.tar 17 | *.nii 18 | *.gz 19 | *.pth -------------------------------------------------------------------------------- /CMVDM_structure/LICENSE: -------------------------------------------------------------------------------- 1 | The Weizmann Institute of Science 2 | Academic Non Commercial Software Code License 3 | © 2020 The Weizmann Institute of Science ("WIS") and Yeda Research and Development Company Ltd. ("Yeda") All Rights Reserved 4 | 5 | 1. YEDA, the commercial arm of WIS, hereby grants you, an individual or a legal entity exercising rights under, and complying with all of the provisions, of this License (“You”) a royalty-free, non-exclusive, sublicensable, worldwide license to: use, copy, modify, create derivative works (including without limiting to: adapt, alter, transform), integrate with other works, distribute, enable access (including without limiting to: communicate copies), publicly display and perform the Work in binary form or in source code, for academic and noncommercial use only and subject to all provisions of this License: 6 | 2. 
YEDA hereby grants You a royalty-free, non-exclusive, sublicensable, worldwide license under patents claimed or owned by YEDA that are embodied in the Work, to make, have made and use the Work under the License, for avoidance of doubt for academic and noncommercial use only. 7 | 3. Distribution or provision of access to the Work and to derivative works of the Work ("Derivative Works") may be made only under this License, accompanied with a copy of the source code or a reference to an online repository where such source code can be accessed. 8 | 4. Neither the names of WIS or Yeda, nor any of their trademarks or service marks, may be used to endorse or promote Derivative Works or for any other purpose except as expressly permitted hereunder. 9 | 5. Except as expressly stated in this License, nothing in this License grants any license to trademarks, copyrights, patents, trade secrets or any other intellectual property of WIS or Yeda. No license is granted to the trademarks of WIS or Yeda's even if such marks are included in the Work. 10 | 6. Nothing in this License shall be interpreted to prohibit WIS or Yeda from licensing the Work under terms different from this License. For commercial use please e-mail Yeda at: info.yeda@weizmann.ac.il 11 | 7. You must retain, in the Source Code of any Derivative Works that You create, all copyright, patent, or trademark notices from the Source Code of the Work, as well as a notice to inform recipients that You have modified the Work with a description of such modifications. 12 | 8. THE WORK IS PROVIDED "AS IS" AND WITHOUT ANY WARRANTIES WHATSOEVER, EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION ANY WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. 13 | 9. IN NO EVENT WILL WIS, YEDA OR ANY OF THEIR RELATED ENTITES, SCIENTISTS, EMPLOYEES, MANAGERS OR ANY OTHE PERSON ACTING ON THEIR BEHALF, BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY OR CAUSE OF ACTION, WHETHER IN CONTRACT, TORT, STRICT LIABILITY, UNJUST ENRICHMENT OR ANY OTHER, ARISING IN ANY WAY OUT OF THE USE OF THE WORK OR THIS LICENSE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 14 | 10. This License will terminate automatically if any of its conditions is not met, or in case You commence an action, including a cross-claim or counterclaim, against WIS or YEDA or any licensee alleging that the Work (except due to combination with other software or hardware) infringes a patent. 15 | 11. This License shall be exclusively governed by the laws of the State of Israel, without giving effect to conflict of laws principles, and the competent courts in Tel Aviv will have exclusive jurisdiction and venue over any matter between You and WIS or YEDA or any of their related entities relating to this License or the Work. 16 | 12. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. 
-------------------------------------------------------------------------------- /CMVDM_structure/env.yml: -------------------------------------------------------------------------------- 1 | channels: 2 | - pytorch 3 | - omnia 4 | - conda-forge 5 | - anaconda 6 | - defaults 7 | dependencies: 8 | - python=3.6.7 9 | - pip 10 | - pytorch=1.4.0=py3.6_cuda10.0.130_cudnn7.6.3_0 11 | - torchvision 12 | - tensorboardx 13 | - absl-py 14 | - ipython 15 | - jupyter 16 | - numpy 17 | - pandas 18 | - pillow 19 | - scipy 20 | - seaborn 21 | - termcolor 22 | - tk=8.6.8 23 | - tqdm 24 | - dotmap 25 | - imutils 26 | - py-opencv 27 | - pretrainedmodels 28 | - natsort 29 | - pytest 30 | - pytest-profiling 31 | - pytest-benchmark -------------------------------------------------------------------------------- /CMVDM_structure/scripts/infer.sh: -------------------------------------------------------------------------------- 1 | gpu=2 # GPU ID 2 | sbj_num=3 3 | tensorboard_log_dir=./log 4 | 5 | echo src/inference.py --exp_prefix sub${sbj_num}_rgbd_noDE \ 6 | --enc_cpt_name sub${sbj_num}_rgbd_best --separable 1 --test_avg_qntl 1 --learning_rate 5e-3 --loss_weights 1,1,0 \ 7 | --fc_gl 1 --gl_l1 40 --gl_gl 400 --fc_mom2 0 --l1_convs 1e-4 --tv_reg 3e-1 --n_epochs 100 --batch_size_list 24,16,48,50 \ 8 | --scheduler 12345 --mslr 100,140 --sched_gamma 0.2 --percept_w 10,10,10,10,2 --rgb_mae 1 --is_rgbd 1 --norm_within_img 1 \ 9 | --sbj_num $sbj_num --tensorboard_log_dir $tensorboard_log_dir --gpu $gpu -------------------------------------------------------------------------------- /CMVDM_structure/scripts/train_dec_rgb.sh: -------------------------------------------------------------------------------- 1 | gpu=0 # GPU ID 2 | sbj_num=3 3 | tensorboard_log_dir=./log 4 | 5 | echo src/train_decoder.py --exp_prefix sub${sbj_num}_rgb_only_noDE \ 6 | --enc_cpt_name sub${sbj_num}_rgb_only_best --separable 1 --test_avg_qntl 1 --learning_rate 5e-3 --loss_weights 1,1,0 \ 7 | --fc_gl 1 --gl_l1 40 --gl_gl 400 --fc_mom2 0 --l1_convs 1e-4 --tv_reg 3e-1 --n_epochs 150 --batch_size_list 24,16,48,50 \ 8 | --scheduler 12345 --mslr 100,140 --sched_gamma 0.2 --percept_w 10,10,10,10,2 --rgb_mae 1 --is_rgbd 0 --norm_within_img 1 \ 9 | --sbj_num $sbj_num --depth_from_rgb 0 --tensorboard_log_dir $tensorboard_log_dir --gpu $gpu --may_save 1 10 | -------------------------------------------------------------------------------- /CMVDM_structure/scripts/train_dec_rgb_bold.sh: -------------------------------------------------------------------------------- 1 | gpu=2 # GPU ID 2 | tensorboard_log_dir=./log 3 | 4 | echo src/train_decoder_bold.py --exp_prefix bold_rgb_only_0506_noDE_strd2_largek \ 5 | --enc_cpt_name bold_rgb_only_best --separable 1 --test_avg_qntl 1 --learning_rate 6e-3 --loss_weights 1,1,0 \ 6 | --fc_gl 1 --gl_l1 40 --gl_gl 400 --fc_mom2 40 --l1_convs 1e-4 --tv_reg 3e-1 --n_epochs 60 --batch_size_list 64,16,48,48 \ 7 | --scheduler 1 --percept_w 10,10,10,10,2 --rgb_mae 1 --is_rgbd 0 --norm_within_img 1 \ 8 | --depth_from_rgb 0 --tensorboard_log_dir $tensorboard_log_dir --gpu $gpu --may_save 1 9 | -------------------------------------------------------------------------------- /CMVDM_structure/scripts/train_enc_rgb.sh: -------------------------------------------------------------------------------- 1 | gpu=0 # GPU ID 2 | sbj_num=3 3 | tensorboard_log_dir=SiluetteExtraction/enc 4 | 5 | echo src/train_encoder.py \ 6 | --exp_prefix sub3_rgb_only \ 7 | --separable 1 --n_epochs 50 --learning_rate 1e-3 --cos_loss 0.3 
--random_crop_pad_percent 3 --scheduler 10 --gamma 0.2 \ 8 | --fc_gl 1 --fc_mom2 10 --l1_convs 1e-4 --is_rgbd 0 --allow_bbn_detach 1 --train_bbn 0 --norm_within_img 1 --may_save 1 \ 9 | --sbj_num $sbj_num --tensorboard_log_dir $tensorboard_log_dir --gpu $gpu -------------------------------------------------------------------------------- /CMVDM_structure/scripts/train_enc_rgb_bold.sh: -------------------------------------------------------------------------------- 1 | gpu=0 # GPU ID 2 | tensorboard_log_dir=./log 3 | 4 | echo src/train_encoder_bold.py \ 5 | --exp_prefix bold_rgb_only_0429 \ 6 | --separable 1 --n_epochs 150 --learning_rate 2e-3 --cos_loss 0.3 --random_crop_pad_percent 3 --scheduler 10 --gamma 0.5 --scheduler 1 \ 7 | --fc_gl 1 --fc_mom2 10 --l1_convs 1e-4 --is_rgbd 0 --allow_bbn_detach 1 --train_bbn 0 --norm_within_img 1 --may_save 1 \ 8 | --tensorboard_log_dir $tensorboard_log_dir --gpu $gpu -------------------------------------------------------------------------------- /CMVDM_structure/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | from setuptools import setup, find_packages 5 | 6 | setup( 7 | name='CMVDM-siluette-extraction', 8 | version='1.0.0', 9 | author="Bohan Zeng, Shanglin Li, et al.", 10 | author_email='shanglin@buaa.edu.cn', 11 | description="Controllable Mind Visual Diffusion Model (Official PyTorch implementation for siluette extraction)", 12 | long_description=open("README.md", "r", encoding='utf-8').read(), 13 | long_description_content_type="text/markdown", 14 | keywords="fMRI Image Reconstruction; Computer Vision; Generative Models", 15 | url='https://github.com/zengbohan0217/CMVDM', 16 | packages=find_packages(), 17 | include_package_data=True, 18 | tests_require=['pytest'], 19 | license="Yeda", 20 | classifiers=[ 21 | 'Intended Audience :: Science/Research', 22 | 'Programming Language :: Python :: 3.6', 23 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 24 | ], 25 | test_suite='tests', 26 | ) 27 | -------------------------------------------------------------------------------- /CMVDM_structure/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengbohan0217/CMVDM/fa61e9d1ab27c644168fec4d6b5edd135bf88a80/CMVDM_structure/src/__init__.py -------------------------------------------------------------------------------- /CMVDM_structure/src/config.py: -------------------------------------------------------------------------------- 1 | 2 | import os, numpy as np 3 | from absl import flags 4 | import src 5 | from pathlib import Path 6 | FLAGS = flags.FLAGS 7 | 8 | PROJECT_ROOT = str(Path(__file__).parents[1]) 9 | 10 | placeholder_str = '' 11 | 12 | flags.DEFINE_list('gpu', ['0'], 'GPUs') 13 | flags.DEFINE_enum("im_res", '112', ['112','56'], 'Image pixel resolution') 14 | flags.DEFINE_string("checkpoint_out", 'checkpoints/{}.pth.tar'.format(placeholder_str), "Checkpoint path") 15 | flags.DEFINE_integer('may_save', 1, '') 16 | flags.DEFINE_integer("savebestonly", 1, 'Save checkpoint only if it is the best, otherwise do not save.') 17 | flags.DEFINE_integer("sbj_num", 3, '') 18 | flags.DEFINE_integer('separable', 0, 'Separable (Space-Feature) Encoder.') 19 | flags.DEFINE_integer("is_rgbd", 0, '1: RGBD | 2: Depth only') 20 | flags.DEFINE_integer("norm_within_img", 0, 'normalize within each depth map') 21 | 22 | flags.DEFINE_string("select_voxels", '', 'File with voxels 
selected for analysis') 23 | 24 | im_res = lambda : int(FLAGS.im_res) 25 | 26 | def num_workers(): 27 | return FLAGS.num_workers_gpu * len(FLAGS.gpu) 28 | 29 | def get_checkpoint_out(): 30 | if placeholder_str in FLAGS.checkpoint_out: 31 | return FLAGS.checkpoint_out.replace(placeholder_str, FLAGS.exp_prefix) 32 | else: 33 | return FLAGS.checkpoint_out 34 | 35 | if __name__ == '__main__': 36 | pass 37 | -------------------------------------------------------------------------------- /CMVDM_structure/src/config_dec.py: -------------------------------------------------------------------------------- 1 | from src.config import * 2 | Train_Decoder = True 3 | 4 | if Train_Decoder: 5 | flags.DEFINE_string("tensorboard_log_dir", '', "Log dir.") 6 | flags.DEFINE_string("exp_prefix", 'dec_tmp', "Experiment prefix.") 7 | flags.DEFINE_integer("num_workers_gpu", 2, "Number of workers per GPU.") 8 | flags.DEFINE_enum("roi", 'V2', ['V1', 'V2', 'V3', 'V4', 'FFA', 'PPA', 'LOC', 'LVC', 'HVC', 'VC', 'V1V2V3'], '') 9 | flags.DEFINE_integer("n_epochs", 150, "Number of epochs.") 10 | flags.DEFINE_float("learning_rate", 5e-4, "The initial value for the learning rate.") 11 | flags.DEFINE_integer("scheduler", 0, "Reduce learning rate by scheduler.") 12 | flags.DEFINE_float("decay", 0., "Weight decay.") 13 | flags.DEFINE_string('enc_cpt_name', '', 'Encoder checkpoint name to load') 14 | flags.DEFINE_string('enc_bbn_arch_name', 'alexnet', '') 15 | flags.DEFINE_list("batch_size_list", [24, 16, 16, 64], "Supervised training | unlabeled fMRI | unlabeled images | test") 16 | batch_size_list = lambda : [int(x) for x in FLAGS.batch_size_list] 17 | flags.DEFINE_list("loss_weights", [1, 1, 1], "Loss weights [loss_D, loss_ED, loss_DE]") 18 | flags.DEFINE_list("percept_w", [1, 1, 1, 10, 1], "Perceptual loss weights along blocks") 19 | flags.DEFINE_float("test_avg_qntl", 1., "Quantile of test repeats to be averaged at random") 20 | flags.DEFINE_float("mix_q", 0, "Quotient of mixed samples") 21 | flags.DEFINE_list("mslr", [20, 35, 45, 50], "Scheduler milestones.") 22 | flags.DEFINE_float("sched_gamma", .2, "Scheduler lr reduction gamma") 23 | flags.DEFINE_integer("n_conv_layers", 3, "Number of convolutional layers in Decoder model.") 24 | flags.DEFINE_integer("sum_level", 6, 'Summary level. 4: -BatchOutDist | 5: Report all. | 6: +ImageLoss components') 25 | flags.DEFINE_list("train_montage_indices", list(np.linspace(0, 1199, 50, dtype='int')), '') 26 | flags.DEFINE_enum('interp_mode', 'bicubic', ['nearest', 'bicubic', 'bilinear'], '') 27 | flags.DEFINE_integer("pred_interp", 0, "Size to interpolate the reconstructed image before applying loss. '0' means no interp") 28 | flags.DEFINE_float("tv_reg", 0.5, "Total variation regularization coefficient.") 29 | flags.DEFINE_float("ml_percep", 1., "Multi-layer perceptual loss.") 30 | flags.DEFINE_integer("random_crop_pad_percent", 3, "") 31 | flags.DEFINE_integer("config_train", 10, "Training configuration. 
1: Decoder supervised training only | 10: Full method") 32 | flags.DEFINE_float("l1_fcreg", 0., "") 33 | flags.DEFINE_float("l1_convs", 0., "") 34 | flags.DEFINE_float("l2_convs", 0., "") 35 | flags.DEFINE_float("l2_fcreg", 0., "") 36 | flags.DEFINE_float("fc_mom2", 0., "") 37 | flags.DEFINE_float("fc_gl", 0, "") 38 | flags.DEFINE_float("gl_l1", 20, "L1 reg component of GL") 39 | flags.DEFINE_float("gl_gl", 400, "GL component of GL") 40 | flags.DEFINE_float("vox_nc_reg", 0., "") 41 | flags.DEFINE_float("rgb_mae", 0.2, "") 42 | flags.DEFINE_integer('verbose', 1, '0: -within-epoch log | 1: Report all.') 43 | 44 | 45 | exp_folder = None 46 | exp_folder = lambda : os.path.join('results/', FLAGS.exp_prefix) 47 | enc_cpt_path = lambda : f'{PROJECT_ROOT}/checkpoints/{FLAGS.enc_cpt_name}.pth.tar' 48 | 49 | if __name__ == '__main__': 50 | pass 51 | -------------------------------------------------------------------------------- /CMVDM_structure/src/config_enc.py: -------------------------------------------------------------------------------- 1 | 2 | from src.config import * 3 | 4 | flags.DEFINE_string("tensorboard_log_dir", '', "Log dir.") 5 | flags.DEFINE_enum("roi", 'V2', ['V1', 'V2', 'V3', 'V4', 'FFA', 'PPA', 'LOC', 'LVC', 'HVC', 'VC', 'V1V2V3'], '') 6 | flags.DEFINE_string("exp_prefix", 'enc_tmp', "Experiment prefix.") 7 | 8 | flags.DEFINE_integer("num_workers_gpu", 5, "Number of workers per GPU.") 9 | flags.DEFINE_integer("n_epochs", 80, "Number of epochs.") 10 | 11 | flags.DEFINE_float("learning_rate", 1e-1, "The initial value for the learning rate.") 12 | flags.DEFINE_float("mse_loss", 1., "") 13 | flags.DEFINE_float("cos_loss", 0.1, "") 14 | 15 | flags.DEFINE_float("decay", .002, "Weight decay.") 16 | 17 | flags.DEFINE_float("l1_fcreg", 0., "") 18 | flags.DEFINE_float("l1_convs", 1e-5, "") 19 | flags.DEFINE_float("l2_fcreg", 0., "") 20 | flags.DEFINE_float("fc_mom2", 0., "") 21 | 22 | flags.DEFINE_float("fc_gl", 1., "") 23 | flags.DEFINE_float("l1_chan_mix", 5e-6, "") 24 | flags.DEFINE_float("l1_branch_mix", 5e-2, "") 25 | 26 | flags.DEFINE_integer("batch_size", 64, "Batch size.") 27 | 28 | flags.DEFINE_integer("scheduler", 12345, "Reduce learning rate by scheduler.") 29 | 30 | flags.DEFINE_float("gamma", 0.7, "Scheduler gamma") 31 | flags.DEFINE_enum("loss", 'mse', ['mse', 'l1'], '') 32 | 33 | flags.DEFINE_float("mix_q", 0, "Quotient of mixed samples") 34 | flags.DEFINE_integer("pw_corr_win", 0, "Window wize for piecewise correlation.") 35 | 36 | 37 | flags.DEFINE_integer("sum_level", 4, 'Summary level. 
4: -BatchOutDist | 5: Report all.') 38 | flags.DEFINE_enum('interp_mode', 'bicubic', ['bicubic', 'bilinear', 'trilinear'], '') 39 | 40 | flags.DEFINE_integer("random_crop_pad_percent", 10, "") 41 | flags.DEFINE_integer("keras_pretrained", 0, 'Load weights from keras case') 42 | 43 | flags.DEFINE_string('bbn_arch_name', 'alexnet', '') 44 | flags.DEFINE_integer('verbose', 1, '0: -within-epoch log | 1: Report all.') 45 | 46 | flags.DEFINE_string('init_cpt_name', '', 'Encoder checkpoint name to load') 47 | 48 | flags.DEFINE_integer("train_bbn", 0, 'Train the backbone network') 49 | flags.DEFINE_integer("allow_bbn_detach", 1, 'Allow bbn detach') 50 | 51 | init_cpt_path = lambda : 'checkpoints/{}.pth.tar'.format(FLAGS.init_cpt_name) 52 | 53 | if __name__ == '__main__': 54 | pass 55 | -------------------------------------------------------------------------------- /CMVDM_structure/src/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | from torchvision.utils import save_image 6 | from src.utils import * 7 | from src.config_dec import * 8 | from src.utils.misc import set_gpu 9 | 10 | 11 | def pure_test(data_loader_labeled, dec, save_folder='./vis_digit_0818'): 12 | os.makedirs(save_folder, exist_ok=True) 13 | dec.eval().cuda() 14 | 15 | main_loader = data_loader_labeled 16 | for batch_idx, (images_gt, fmri_gt) in enumerate(main_loader): 17 | images_gt, fmri_gt = map(lambda x: x.cuda(), [images_gt, fmri_gt]) 18 | with torch.no_grad(): 19 | images_D = dec(fmri_gt) 20 | pred_image_tensors = images_D.cpu() 21 | gt_image_tensors = images_gt.cpu() 22 | 23 | for index, image_tensor in enumerate(pred_image_tensors): 24 | save_image(torch.cat([gt_image_tensors[index], image_tensor], dim=1), pjoin(save_folder, str(index) + '_gt_gen.png')) 25 | 26 | def run_god_test(dec_path): 27 | 28 | # dataset / dataloader 29 | img_xfm_basic = transforms.Compose([transforms.Resize(size=112, interpolation=Image.BILINEAR), transforms.CenterCrop(112), transforms.ToTensor()]) 30 | val_labeled_avg = KamitaniDataset_discussion(fmri_xfm=np.float32, subset_case=KamitaniDataset_discussion.TEST) 31 | val_labeled_avg = CustomDataset(val_labeled_avg, input_xfm=img_xfm_basic) 32 | data_loaders_labeled = { 33 | 'test': data.DataLoader(val_labeled_avg, batch_size=min([24,16,48,50][-1], len(val_labeled_avg)), shuffle=False, num_workers=7, pin_memory=True), 34 | } 35 | 36 | # init fmri-decoder 37 | dec = make_model('BaseDecoder', 9919, 112, start_CHW=(64, 14, 14), n_conv_layers_ramp=3, n_chan=64, n_chan_output=3, depth_extractor=None) 38 | # Load pretrained encoder 39 | dec = nn.DataParallel(dec) 40 | assert os.path.isfile(dec_path) 41 | print('\t==> Loading checkpoint {}'.format(os.path.basename(dec_path))) 42 | dec.load_state_dict(torch.load(dec_path)['state_dict']) 43 | 44 | pure_test(data_loaders_labeled['test'], dec) 45 | 46 | def depth_infer(fmri, 47 | dec_path, 48 | save_folder='./vis_depth'): 49 | 50 | os.makedirs(save_folder, exist_ok=True) 51 | 52 | # init fmri-decoder 53 | dec = make_model('BaseDecoder', 4643, 112, start_CHW=(64, 14, 14), n_conv_layers_ramp=3, n_chan=64, n_chan_output=4, depth_extractor=None) 54 | # Load pretrained encoder 55 | dec = nn.DataParallel(dec) 56 | assert os.path.isfile(dec_path) 57 | print('\t==> Loading checkpoint {}'.format(os.path.basename(dec_path))) 58 | dec.load_state_dict(torch.load(dec_path)['state_dict']) 59 | 60 | 
dec.eval().cuda() 61 | with torch.no_grad(): 62 | images_D = dec(fmri) 63 | 64 | 65 | 66 | 67 | 68 | 69 | if __name__ == '__main__': 70 | run_god_test() -------------------------------------------------------------------------------- /CMVDM_structure/src/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | import torch 8 | # from torch.autograd import Variable 9 | 10 | from lpips.trainer import * 11 | from lpips.lpips import * 12 | 13 | # class PerceptualLoss(torch.nn.Module): 14 | # def __init__(self, model='lpips', net='alex', spatial=False, use_gpu=False, gpu_ids=[0], version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 15 | # # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 16 | # super(PerceptualLoss, self).__init__() 17 | # print('Setting up Perceptual loss...') 18 | # self.use_gpu = use_gpu 19 | # self.spatial = spatial 20 | # self.gpu_ids = gpu_ids 21 | # self.model = dist_model.DistModel() 22 | # self.model.initialize(model=model, net=net, use_gpu=use_gpu, spatial=self.spatial, gpu_ids=gpu_ids, version=version) 23 | # print('...[%s] initialized'%self.model.name()) 24 | # print('...Done') 25 | 26 | # def forward(self, pred, target, normalize=False): 27 | # """ 28 | # Pred and target are Variables. 29 | # If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 30 | # If normalize is False, assumes the images are already between [-1,+1] 31 | 32 | # Inputs pred and target are Nx3xHxW 33 | # Output pytorch Variable N long 34 | # """ 35 | 36 | # if normalize: 37 | # target = 2 * target - 1 38 | # pred = 2 * pred - 1 39 | 40 | # return self.model.forward(target, pred) 41 | 42 | def normalize_tensor(in_feat,eps=1e-10): 43 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 44 | return in_feat/(norm_factor+eps) 45 | 46 | def l2(p0, p1, range=255.): 47 | return .5*np.mean((p0 / range - p1 / range)**2) 48 | 49 | def psnr(p0, p1, peak=255.): 50 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 51 | 52 | def dssim(p0, p1, range=255.): 53 | from skimage.measure import compare_ssim 54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 
82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def load_image(path): 104 | if(path[-3:] == 'dng'): 105 | import rawpy 106 | with rawpy.imread(path) as raw: 107 | img = raw.postprocess() 108 | elif(path[-3:]=='bmp' or path[-3:]=='jpg' or path[-3:]=='png' or path[-4:]=='jpeg'): 109 | import cv2 110 | return cv2.imread(path)[:,:,::-1] 111 | else: 112 | img = (255*plt.imread(path)[:,:,:3]).astype('uint8') 113 | 114 | return img 115 | 116 | def rgb2lab(input): 117 | from skimage import color 118 | return color.rgb2lab(input / 255.) 119 | 120 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 121 | image_numpy = image_tensor[0].cpu().float().numpy() 122 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 123 | return image_numpy.astype(imtype) 124 | 125 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 126 | return torch.Tensor((image / factor - cent) 127 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 128 | 129 | def tensor2vec(vector_tensor): 130 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 131 | 132 | 133 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 134 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 135 | image_numpy = image_tensor[0].cpu().float().numpy() 136 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 137 | return image_numpy.astype(imtype) 138 | 139 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 140 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 141 | return torch.Tensor((image / factor - cent) 142 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 143 | 144 | 145 | 146 | def voc_ap(rec, prec, use_07_metric=False): 147 | """ ap = voc_ap(rec, prec, [use_07_metric]) 148 | Compute VOC AP given precision and recall. 149 | If use_07_metric is true, uses the 150 | VOC 07 11 point method (default:False). 151 | """ 152 | if use_07_metric: 153 | # 11 point metric 154 | ap = 0. 155 | for t in np.arange(0., 1.1, 0.1): 156 | if np.sum(rec >= t) == 0: 157 | p = 0 158 | else: 159 | p = np.max(prec[rec >= t]) 160 | ap = ap + p / 11. 
161 | else: 162 | # correct AP calculation 163 | # first append sentinel values at the end 164 | mrec = np.concatenate(([0.], rec, [1.])) 165 | mpre = np.concatenate(([0.], prec, [0.])) 166 | 167 | # compute the precision envelope 168 | for i in range(mpre.size - 1, 0, -1): 169 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 170 | 171 | # to calculate area under PR curve, look for points 172 | # where X axis (recall) changes value 173 | i = np.where(mrec[1:] != mrec[:-1])[0] 174 | 175 | # and sum (\Delta recall) * prec 176 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 177 | return ap 178 | 179 | -------------------------------------------------------------------------------- /CMVDM_structure/src/lpips/lpips.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.init as init 7 | from torch.autograd import Variable 8 | import numpy as np 9 | from lpips import pretrained_networks as pn 10 | import torch.nn 11 | 12 | import lpips 13 | 14 | def spatial_average(in_tens, keepdim=True): 15 | return in_tens.mean([2,3],keepdim=keepdim) 16 | 17 | def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W 18 | in_H, in_W = in_tens.shape[2], in_tens.shape[3] 19 | return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens) 20 | 21 | # Learned perceptual metric 22 | class LPIPS(nn.Module): 23 | def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, 24 | pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=True): 25 | # lpips - [True] means with linear calibration on top of base network 26 | # pretrained - [True] means load linear weights 27 | 28 | super(LPIPS, self).__init__() 29 | if(verbose): 30 | print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'% 31 | ('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off')) 32 | 33 | self.pnet_type = net 34 | self.pnet_tune = pnet_tune 35 | self.pnet_rand = pnet_rand 36 | self.spatial = spatial 37 | self.lpips = lpips # false means baseline of just averaging all layers 38 | self.version = version 39 | self.scaling_layer = ScalingLayer() 40 | 41 | if(self.pnet_type in ['vgg','vgg16']): 42 | net_type = pn.vgg16 43 | self.chns = [64,128,256,512,512] 44 | elif(self.pnet_type=='alex'): 45 | net_type = pn.alexnet 46 | self.chns = [64,192,384,256,256] 47 | elif(self.pnet_type=='squeeze'): 48 | net_type = pn.squeezenet 49 | self.chns = [64,128,256,384,384,512,512] 50 | self.L = len(self.chns) 51 | 52 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 53 | 54 | if(lpips): 55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 60 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 61 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 62 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 63 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 64 | self.lins+=[self.lin5,self.lin6] 65 | self.lins = nn.ModuleList(self.lins) 66 | 67 | if(pretrained): 68 | if(model_path is None): 69 | import inspect 70 | import os 71 | 
model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net))) 72 | 73 | if(verbose): 74 | print('Loading model from: %s'%model_path) 75 | self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False) 76 | 77 | if(eval_mode): 78 | self.eval() 79 | 80 | def forward(self, in0, in1, retPerLayer=False, normalize=False): 81 | if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1] 82 | in0 = 2 * in0 - 1 83 | in1 = 2 * in1 - 1 84 | 85 | # v0.0 - original release had a bug, where input was not scaled 86 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 87 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 88 | feats0, feats1, diffs = {}, {}, {} 89 | 90 | for kk in range(self.L): 91 | feats0[kk], feats1[kk] = lpips.normalize_tensor(outs0[kk]), lpips.normalize_tensor(outs1[kk]) 92 | diffs[kk] = (feats0[kk]-feats1[kk])**2 93 | 94 | if(self.lpips): 95 | if(self.spatial): 96 | res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)] 97 | else: 98 | res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)] 99 | else: 100 | if(self.spatial): 101 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)] 102 | else: 103 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 104 | 105 | val = res[0] 106 | for l in range(1,self.L): 107 | val += res[l] 108 | 109 | # a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 110 | # b = torch.max(self.lins[kk](feats0[kk]**2)) 111 | # for kk in range(self.L): 112 | # a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True) 113 | # b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2))) 114 | # a = a/self.L 115 | # from IPython import embed 116 | # embed() 117 | # return 10*torch.log10(b/a) 118 | 119 | if(retPerLayer): 120 | return (val, res) 121 | else: 122 | return val 123 | 124 | 125 | class ScalingLayer(nn.Module): 126 | def __init__(self): 127 | super(ScalingLayer, self).__init__() 128 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 129 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 130 | 131 | def forward(self, inp): 132 | return (inp - self.shift) / self.scale 133 | 134 | 135 | class NetLinLayer(nn.Module): 136 | ''' A single linear layer which does a 1x1 conv ''' 137 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 138 | super(NetLinLayer, self).__init__() 139 | 140 | layers = [nn.Dropout(),] if(use_dropout) else [] 141 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 142 | self.model = nn.Sequential(*layers) 143 | 144 | def forward(self, x): 145 | return self.model(x) 146 | 147 | class Dist2LogitLayer(nn.Module): 148 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 149 | def __init__(self, chn_mid=32, use_sigmoid=True): 150 | super(Dist2LogitLayer, self).__init__() 151 | 152 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 153 | layers += [nn.LeakyReLU(0.2,True),] 154 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 155 | layers += [nn.LeakyReLU(0.2,True),] 156 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 157 | if(use_sigmoid): 158 | layers += [nn.Sigmoid(),] 159 | 
self.model = nn.Sequential(*layers) 160 | 161 | def forward(self,d0,d1,eps=0.1): 162 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 163 | 164 | class BCERankingLoss(nn.Module): 165 | def __init__(self, chn_mid=32): 166 | super(BCERankingLoss, self).__init__() 167 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 168 | # self.parameters = list(self.net.parameters()) 169 | self.loss = torch.nn.BCELoss() 170 | 171 | def forward(self, d0, d1, judge): 172 | per = (judge+1.)/2. 173 | self.logit = self.net.forward(d0,d1) 174 | return self.loss(self.logit, per) 175 | 176 | # L2, DSSIM metrics 177 | class FakeNet(nn.Module): 178 | def __init__(self, use_gpu=True, colorspace='Lab'): 179 | super(FakeNet, self).__init__() 180 | self.use_gpu = use_gpu 181 | self.colorspace = colorspace 182 | 183 | class L2(FakeNet): 184 | def forward(self, in0, in1, retPerLayer=None): 185 | assert(in0.size()[0]==1) # currently only supports batchSize 1 186 | 187 | if(self.colorspace=='RGB'): 188 | (N,C,X,Y) = in0.size() 189 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 190 | return value 191 | elif(self.colorspace=='Lab'): 192 | value = lpips.l2(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), 193 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 194 | ret_var = Variable( torch.Tensor((value,) ) ) 195 | if(self.use_gpu): 196 | ret_var = ret_var.cuda() 197 | return ret_var 198 | 199 | class DSSIM(FakeNet): 200 | 201 | def forward(self, in0, in1, retPerLayer=None): 202 | assert(in0.size()[0]==1) # currently only supports batchSize 1 203 | 204 | if(self.colorspace=='RGB'): 205 | value = lpips.dssim(1.*lpips.tensor2im(in0.data), 1.*lpips.tensor2im(in1.data), range=255.).astype('float') 206 | elif(self.colorspace=='Lab'): 207 | value = lpips.dssim(lpips.tensor2np(lpips.tensor2tensorlab(in0.data,to_norm=False)), 208 | lpips.tensor2np(lpips.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 209 | ret_var = Variable( torch.Tensor((value,) ) ) 210 | if(self.use_gpu): 211 | ret_var = ret_var.cuda() 212 | return ret_var 213 | 214 | def print_network(net): 215 | num_params = 0 216 | for param in net.parameters(): 217 | num_params += param.numel() 218 | print('Network',net) 219 | print('Total number of parameters: %d' % num_params) 220 | -------------------------------------------------------------------------------- /CMVDM_structure/src/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | class squeezenet(torch.nn.Module): 6 | def __init__(self, requires_grad=False, pretrained=True): 7 | super(squeezenet, self).__init__() 8 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 9 | self.slice1 = torch.nn.Sequential() 10 | self.slice2 = torch.nn.Sequential() 11 | self.slice3 = torch.nn.Sequential() 12 | self.slice4 = torch.nn.Sequential() 13 | self.slice5 = torch.nn.Sequential() 14 | self.slice6 = torch.nn.Sequential() 15 | self.slice7 = torch.nn.Sequential() 16 | self.N_slices = 7 17 | for x in range(2): 18 | self.slice1.add_module(str(x), pretrained_features[x]) 19 | for x in range(2,5): 20 | self.slice2.add_module(str(x), pretrained_features[x]) 21 | for x in range(5, 8): 22 | self.slice3.add_module(str(x), pretrained_features[x]) 23 | for x in range(8, 10): 24 | 
self.slice4.add_module(str(x), pretrained_features[x]) 25 | for x in range(10, 11): 26 | self.slice5.add_module(str(x), pretrained_features[x]) 27 | for x in range(11, 12): 28 | self.slice6.add_module(str(x), pretrained_features[x]) 29 | for x in range(12, 13): 30 | self.slice7.add_module(str(x), pretrained_features[x]) 31 | if not requires_grad: 32 | for param in self.parameters(): 33 | param.requires_grad = False 34 | 35 | def forward(self, X): 36 | h = self.slice1(X) 37 | h_relu1 = h 38 | h = self.slice2(h) 39 | h_relu2 = h 40 | h = self.slice3(h) 41 | h_relu3 = h 42 | h = self.slice4(h) 43 | h_relu4 = h 44 | h = self.slice5(h) 45 | h_relu5 = h 46 | h = self.slice6(h) 47 | h_relu6 = h 48 | h = self.slice7(h) 49 | h_relu7 = h 50 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 51 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 52 | 53 | return out 54 | 55 | 56 | class alexnet(torch.nn.Module): 57 | def __init__(self, requires_grad=False, pretrained=True): 58 | super(alexnet, self).__init__() 59 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 60 | self.slice1 = torch.nn.Sequential() 61 | self.slice2 = torch.nn.Sequential() 62 | self.slice3 = torch.nn.Sequential() 63 | self.slice4 = torch.nn.Sequential() 64 | self.slice5 = torch.nn.Sequential() 65 | self.N_slices = 5 66 | for x in range(2): 67 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 68 | for x in range(2, 5): 69 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 70 | for x in range(5, 8): 71 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 72 | for x in range(8, 10): 73 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 74 | for x in range(10, 12): 75 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 76 | if not requires_grad: 77 | for param in self.parameters(): 78 | param.requires_grad = False 79 | 80 | def forward(self, X): 81 | h = self.slice1(X) 82 | h_relu1 = h 83 | h = self.slice2(h) 84 | h_relu2 = h 85 | h = self.slice3(h) 86 | h_relu3 = h 87 | h = self.slice4(h) 88 | h_relu4 = h 89 | h = self.slice5(h) 90 | h_relu5 = h 91 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 92 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 93 | 94 | return out 95 | 96 | class vgg16(torch.nn.Module): 97 | def __init__(self, requires_grad=False, pretrained=True): 98 | super(vgg16, self).__init__() 99 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 100 | self.slice1 = torch.nn.Sequential() 101 | self.slice2 = torch.nn.Sequential() 102 | self.slice3 = torch.nn.Sequential() 103 | self.slice4 = torch.nn.Sequential() 104 | self.slice5 = torch.nn.Sequential() 105 | self.N_slices = 5 106 | for x in range(4): 107 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 108 | for x in range(4, 9): 109 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(9, 16): 111 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(16, 23): 113 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(23, 30): 115 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 116 | if not requires_grad: 117 | for param in self.parameters(): 118 | param.requires_grad = False 119 | 120 | def forward(self, X): 121 | h = self.slice1(X) 122 | h_relu1_2 = h 123 | h = self.slice2(h) 124 | h_relu2_2 = h 125 | h = 
self.slice3(h) 126 | h_relu3_3 = h 127 | h = self.slice4(h) 128 | h_relu4_3 = h 129 | h = self.slice5(h) 130 | h_relu5_3 = h 131 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 132 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 133 | 134 | return out 135 | 136 | 137 | 138 | class resnet(torch.nn.Module): 139 | def __init__(self, requires_grad=False, pretrained=True, num=18): 140 | super(resnet, self).__init__() 141 | if(num==18): 142 | self.net = tv.resnet18(pretrained=pretrained) 143 | elif(num==34): 144 | self.net = tv.resnet34(pretrained=pretrained) 145 | elif(num==50): 146 | self.net = tv.resnet50(pretrained=pretrained) 147 | elif(num==101): 148 | self.net = tv.resnet101(pretrained=pretrained) 149 | elif(num==152): 150 | self.net = tv.resnet152(pretrained=pretrained) 151 | self.N_slices = 5 152 | 153 | self.conv1 = self.net.conv1 154 | self.bn1 = self.net.bn1 155 | self.relu = self.net.relu 156 | self.maxpool = self.net.maxpool 157 | self.layer1 = self.net.layer1 158 | self.layer2 = self.net.layer2 159 | self.layer3 = self.net.layer3 160 | self.layer4 = self.net.layer4 161 | 162 | def forward(self, X): 163 | h = self.conv1(X) 164 | h = self.bn1(h) 165 | h = self.relu(h) 166 | h_relu1 = h 167 | h = self.maxpool(h) 168 | h = self.layer1(h) 169 | h_conv2 = h 170 | h = self.layer2(h) 171 | h_conv3 = h 172 | h = self.layer3(h) 173 | h_conv4 = h 174 | h = self.layer4(h) 175 | h_conv5 = h 176 | 177 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 178 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 179 | 180 | return out 181 | -------------------------------------------------------------------------------- /CMVDM_structure/src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os, time 2 | 3 | from .misc import * 4 | from .datasets import * 5 | from .criterion import * 6 | from .models import * 7 | 8 | sns.set(color_codes=True) 9 | 10 | def starred(s, n_stars=10): 11 | return '*' * n_stars + '\n' + s + '\n' + '*' * n_stars 12 | 13 | class NormalizeBatch(object): 14 | """Normalize a tensor image with mean and standard deviation. 15 | Given mean: ``(M1,...,Mn)`` and std: ``(S1,..,Sn)`` for ``n`` channels, this transform 16 | will normalize each channel of the input ``torch.*Tensor`` i.e. 17 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]`` 18 | 19 | Args: 20 | mean (sequence): Sequence of means for each channel. 21 | std (sequence): Sequence of standard deviations for each channel. 
22 | """ 23 | 24 | def __init__(self, mean, std): 25 | self.mean = mean 26 | self.std = std 27 | 28 | def __call__(self, tensor): 29 | tensor = tensor.clone() 30 | dtype = tensor.dtype 31 | mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device) 32 | std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device) 33 | if tensor.ndim == 4: 34 | tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None]) 35 | elif tensor.ndim == 3: 36 | tensor.sub_(mean[:, None, None]).div_(std[:, None, None]) 37 | else: 38 | raise TypeError('tensor is not a torch image nor an image batch.') 39 | return tensor 40 | 41 | def __repr__(self): 42 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) 43 | 44 | class NormalizeImageNet(transforms.Normalize): 45 | def __init__(self): 46 | super(NormalizeImageNet, self).__init__(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 47 | 48 | normalize_imagenet = NormalizeImageNet() 49 | 50 | norm_depth_img = lambda img_depth: (img_depth - img_depth.mean()) / max(img_depth.std(), .1) 51 | 52 | norm_batch_imagenet = NormalizeBatch(NormalizeImageNet().mean, NormalizeImageNet().std) 53 | 54 | norm_imagenet_norm_depth_img = lambda img: torch.cat([normalize_imagenet(img[:3]), norm_depth_img(img[3:])]) 55 | 56 | def norm_batch_within_img(tensor_NCHW): 57 | tensor_NCHW_mean = tensor_NCHW.view(*tensor_NCHW.shape[:-2], -1).mean(-1)[:, :, None, None] 58 | tensor_NCHW_std = tensor_NCHW.view(*tensor_NCHW.shape[:-2], -1).std(-1)[:, :, None, None] 59 | return (tensor_NCHW - tensor_NCHW_mean) / tensor_NCHW_std.clamp(min=.1) 60 | 61 | def norm_imagenet_norm_depth_img_batch(tensor): 62 | if tensor.shape[1] in [3, 4]: 63 | tensor_rgb_norm = norm_batch_imagenet(tensor[:, :3]) 64 | if tensor.shape[1] == 4: 65 | tensor_depth_norm = norm_batch_within_img(tensor[:, 3:]) 66 | return torch.cat([tensor_rgb_norm, tensor_depth_norm] , 1) 67 | else: 68 | return tensor_rgb_norm 69 | elif tensor.shape[1] == 1: 70 | return norm_batch_within_img(tensor[:, :1]) 71 | else: 72 | raise NotImplementedError 73 | -------------------------------------------------------------------------------- /CMVDM_structure/src/utils/criterion.py: -------------------------------------------------------------------------------- 1 | 2 | from src.utils.misc import batch_flat, np 3 | import torch 4 | from torch import nn 5 | from src.utils.models import * 6 | import torchvision.models as torchvis_models 7 | from src.config import PROJECT_ROOT 8 | from pytorch_msssim import ssim 9 | 10 | 11 | def pearson_corr(feats_A, feats_B): 12 | ''' 13 | Assume structure NxK and returns K correlations for each of the features. 
14 | ''' 15 | feats_A = feats_A - torch.mean(feats_A, 0, keepdim=True) 16 | feats_B = feats_B - torch.mean(feats_B, 0, keepdim=True) 17 | overflow_factor = (batch_flat(feats_A**2).sum(1) + batch_flat(feats_B**2).sum(1)).mean()/2 18 | r = torch.diag((feats_A.transpose(1, 0) / overflow_factor) @ (feats_B / overflow_factor)) * overflow_factor**2 / \ 19 | (torch.sqrt(torch.sum(feats_A ** 2, 0)) * torch.sqrt(torch.sum(feats_B ** 2, 0))) 20 | return r 21 | 22 | def pearson_corr_piecewise(feats_A, feats_B, win_size=None): 23 | if win_size is None: 24 | win_size = len(feats_A) 25 | def corr(a, b): 26 | a = a - torch.mean(a, 0, keepdim=True) 27 | b = b - torch.mean(b, 0, keepdim=True) 28 | r = torch.diag(a.transpose(1, 0) @ b) 29 | return r 30 | def std(x): 31 | return ((x - torch.mean(x, 0, keepdim=True)) ** 2).sum(0).sqrt() 32 | numer = sum([corr(feats_A[i: i+win_size], feats_B[i: i+win_size]) for i in range(len(feats_A)-win_size)]) 33 | denom = sum([std(feats_A[i: i+win_size]) * std(feats_B[i: i+win_size]) for i in range(len(feats_A)-win_size)]) 34 | return numer / denom 35 | 36 | def total_variation(x, p=2): 37 | return ((x[..., 1:, :-1] - x[...,1:, 1:])**2 + 38 | (x[..., :-1, 1:] - x[..., 1:, 1:])**2).pow(p/2).mean() 39 | 40 | def cosine_loss(pred, actual, return_mean=True): 41 | cos_sim = F.cosine_similarity(pred.view(len(pred), -1), actual.view(len(pred), -1), dim=1) 42 | cos_loss = (1-cos_sim) / 2 43 | if return_mean: 44 | return cos_loss.mean() 45 | else: 46 | return cos_loss 47 | 48 | def normalize_channel(in_feat,eps=1e-10): 49 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True) + eps) 50 | return in_feat/(norm_factor + eps) 51 | 52 | def perceptual_loss_layer(pred, actual, mask=1): 53 | pred, actual = map(normalize_channel, [pred, actual]) 54 | return ((1 - (pred * actual).sum(1)) * mask).mean() 55 | 56 | def gradient(tensor_NCHW): 57 | return tensor_NCHW[:, :, :, 1:] - tensor_NCHW[:, :, :, :-1], tensor_NCHW[:, :, 1:, :] - tensor_NCHW[:, :, :-1, :] 58 | 59 | def euclidian_loss(org_matrix, target_matrix): 60 | """ 61 | Euclidian loss is the main loss function in the paper 62 | ||fi(x) - fi(x_0)||_2^2& / ||fi(x_0)||_2^2 63 | """ 64 | distance_matrix = target_matrix - org_matrix 65 | euclidian_distance = alpha_norm(distance_matrix, 2) 66 | normalized_euclidian_distance = euclidian_distance / alpha_norm(org_matrix, 2) 67 | return normalized_euclidian_distance 68 | 69 | def norm_tensor_batchwise(tensor): 70 | return tensor / tensor.view(len(tensor), -1).abs().sum(1).mean() 71 | 72 | class ImageLoss(nn.Module): 73 | def __init__(self, feats_extractor=None, img_xfm_norm=identity, data_norm_factors_images=None): 74 | super(ImageLoss, self).__init__() 75 | if feats_extractor is None: 76 | bbn = torchvis_models.__dict__['vgg16'](pretrained=True) 77 | 78 | 79 | branch_dict = { # VGG16 Blocks # selectedLayers = [3, 6, 10, 14, 18] 80 | # After maxpools 81 | 'conv1': ['features.{}'.format(i) for i in range(5)], 82 | 'conv2': ['features.{}'.format(i) for i in range(10)], 83 | 'conv3': ['features.{}'.format(i) for i in range(17)], 84 | 'conv4': ['features.{}'.format(i) for i in range(24)], 85 | 'conv5': ['features.{}'.format(i) for i in range(31)], 86 | } 87 | 88 | spatial_out_dims = None 89 | main_branch = branch_dict['conv5'] 90 | branch_dict = {layer: branch_module_list[-1] for layer, branch_module_list in branch_dict.items()} 91 | self.feats_extractor = MultiBranch(bbn, branch_dict, main_branch, spatial_out_dims=spatial_out_dims) 92 | else: 93 | self.feats_extractor = 
feats_extractor 94 | self.feats_extractor.eval() 95 | self.img_xfm_norm = img_xfm_norm 96 | 97 | self.feats_extractor_wrap = lambda x: self.feats_extractor(img_xfm_norm(x)) 98 | 99 | self.branch_weights_cos = [] 100 | 101 | if data_norm_factors_images is not None: 102 | cprint1(' >> Computing normalization factors.') 103 | self.feats_extractor.eval() 104 | with torch.no_grad(): 105 | images_tensor = torch.stack([x for x in data_norm_factors_images]) 106 | norm_factors_channel = [nn.Parameter(images_tensor.abs().mean(-1).mean(-1).mean(0)[None, :, None, None].clamp_(min=1e-6), requires_grad=False)] 107 | norm_factors_channel += [nn.Parameter(x.detach().abs().mean(-1).mean(-1).mean(0)[None, :, None, None].clamp_(min=1e-6), requires_grad=False) \ 108 | for x in self.feats_extractor_wrap(images_tensor)[:-1]] # The [:-1] is to ensure that class layer is not included 109 | else: 110 | norm_factors_channel = [nn.Parameter(torch.ones(1, 3, 1, 1,dtype=torch.float), requires_grad=False) for _ in range(len(branch_dict))] 111 | 112 | self.norm_factors_channel = nn.ParameterDict(dict(zip(['image'] + list(branch_dict.keys())[:-1], norm_factors_channel))) 113 | 114 | 115 | 116 | def forward(self, pred, actual, sum_writer=None): 117 | if FLAGS.pred_interp > 0: 118 | pred = interpolate(pred, size=FLAGS.pred_interp, mode=FLAGS.interp_mode) 119 | actual = interpolate(actual, size=pred.size(-1), mode=FLAGS.interp_mode) 120 | 121 | 122 | if FLAGS.is_rgbd in [0, 1]: 123 | loss_rgb_mae = F.l1_loss(*[(x / self.norm_factors_channel['image'])[:, :3] for x in [pred, actual]]) 124 | else: 125 | loss_rgb_mae = 0 126 | 127 | 128 | 129 | loss_feats_dict = {} 130 | with torch.no_grad(): 131 | actual_feats_list = [x.detach() for x in self.feats_extractor_wrap(actual)] 132 | 133 | pred_feats_list = [x for x in self.feats_extractor_wrap(pred)] 134 | for layer, (pred_feats, actual_feats) in zip(self.feats_extractor.branch_dict.keys(), zip(pred_feats_list, actual_feats_list)): 135 | loss_feats_dict[layer] = {'perceptual': perceptual_loss_layer(pred_feats, actual_feats)} 136 | 137 | for w_cos, branch_name in zip(self.branch_weights_cos, loss_feats_dict.keys()): 138 | if 'cosine' in loss_feats_dict[branch_name]: 139 | loss_feats_dict[branch_name]['cosine'] *= w_cos 140 | 141 | tv_loss = total_variation(pred, p=2*1.25) 142 | 143 | ssim_loss = 1-ssim(*[(x / self.norm_factors_channel['image'])[:, :3] for x in [pred, actual]]) 144 | 145 | loss_list = [('rgb_mae', FLAGS.rgb_mae * loss_rgb_mae)] + \ 146 | [ \ 147 | (f'mlpercep_layerconv{conv_i}', w * loss_feats_dict[f'conv{conv_i}']['perceptual']) for conv_i, w in zip([1,2,3,4,5], [int(x) for x in FLAGS.percept_w]) \ 148 | ] + \ 149 | [('tv', FLAGS.tv_reg * tv_loss)] + \ 150 | + [('ssim_loss', 0.1*ssim_loss)] 151 | 152 | return loss_list 153 | 154 | 155 | 156 | 157 | import torch.nn.functional as F 158 | from math import * 159 | 160 | def ssim_loss(img1, img2, window_size=11, size_average=True, sigma=1.5): 161 | # Set sigma value 162 | k1 = 0.01 163 | k2 = 0.03 164 | L = 1 # image depth 165 | device = img1.device 166 | 167 | # Create 1D Gaussian kernel 168 | gauss = torch.Tensor( 169 | [exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 170 | gauss = gauss / gauss.sum() 171 | 172 | # Create 2D Gaussian kernel 173 | kernel = torch.mm(gauss.unsqueeze(1), gauss.unsqueeze(0)).unsqueeze(0).unsqueeze(0) 174 | kernel = kernel.repeat(img1.size(0), img1.size(1), 1, 1).to(device) 175 | 176 | # Compute mean of images 177 | mu1 = F.conv2d(img1, kernel, 
padding=window_size // 2).to(device) 178 | mu2 = F.conv2d(img2, kernel, padding=window_size // 2).to(device) 179 | 180 | # Compute variance of images 181 | mu1_sq = mu1.pow(2) 182 | mu2_sq = mu2.pow(2) 183 | mu1_mu2 = mu1 * mu2 184 | 185 | sigma1_sq = F.conv2d(img1 * img1, kernel, padding=window_size // 2) - mu1_sq 186 | sigma2_sq = F.conv2d(img2 * img2, kernel, padding=window_size // 2) - mu2_sq 187 | sigma12 = F.conv2d(img1 * img2, kernel, padding=window_size // 2) - mu1_mu2 188 | 189 | # Compute SSIM 190 | C1 = (k1 * L) ** 2 191 | C2 = (k2 * L) ** 2 192 | 193 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 194 | 195 | if size_average: 196 | return ssim_map.mean() 197 | else: 198 | return ssim_map.mean(1).mean(1).mean(1) 199 | 200 | 201 | class FmriLoss(nn.Module): 202 | def __init__(self): 203 | super(FmriLoss, self).__init__() 204 | self.losses_dict = { 205 | 1.: nn.L1Loss(), 206 | 0.1: cosine_loss 207 | } 208 | 209 | def forward(self, pred, actual): 210 | return sum([w * loss_func(pred, actual) for w, loss_func in self.losses_dict.items()]) 211 | 212 | if __name__ == '__main__': 213 | pass 214 | # import torchvision.transforms as transforms 215 | # from PIL import Image 216 | # img = 217 | # img_xfm_train = transforms.Compose([ 218 | # transforms.Resize(size=112, interpolation=Image.BILINEAR), 219 | # transforms.GaussianBlur(5, sigma=(0.1, 2.0)), 220 | # transforms.CenterCrop(112), 221 | # transforms.ToTensor(), 222 | # ]) 223 | # img = img_xfm_train(img) 224 | # img = img.unsqueeze(0).permute(0,2,3,1) 225 | # print(img.shape) 226 | # rand_img = torch.randn(img.shape) 227 | # loss_ssim = 1- ssim_loss(img, rand_img) 228 | # print(loss_ssim) 229 | -------------------------------------------------------------------------------- /CMVDM_structure/src/utils/preprocess_BOLD5000.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from PIL import Image 4 | import cv2 5 | from tqdm import tqdm 6 | 7 | def canny_contour(images): 8 | ret = [] 9 | for image in tqdm(images): 10 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 11 | contour_img = image.copy() 12 | contour_img.fill(255) 13 | edges = cv2.Canny(gray, 100, 200) 14 | 15 | contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 16 | cv2.drawContours(contour_img, contours, -1, (0,255,0),5) 17 | pil_image = Image.fromarray(contour_img) 18 | pil_image = pil_image.convert("RGB") 19 | 20 | ret.append(pil_image) 21 | 22 | return ret 23 | 24 | def canny_edge(images): 25 | ret = [] 26 | for image in tqdm(images): 27 | gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) 28 | edges = cv2.Canny(gray, 100, 200) 29 | pil_image = Image.fromarray(edges) 30 | pil_image = pil_image.convert("RGB") 31 | 32 | ret.append(pil_image) 33 | 34 | return ret 35 | 36 | 37 | def list_get_all_index(list, value): 38 | return [i for i, v in enumerate(list) if v == value] 39 | 40 | def identity(x): 41 | return x 42 | 43 | def normalize(x, mean=None, std=None): 44 | mean = np.mean(x) if mean is None else mean 45 | std = np.std(x) if std is None else std 46 | return (x - mean) / (std * 1.0) 47 | 48 | def get_stimuli_list(root, sub): 49 | sti_name = [] 50 | path = os.path.join(root, 'Stimuli_Presentation_Lists', sub) 51 | folders = os.listdir(path) 52 | folders.sort() 53 | for folder in folders: 54 | if not os.path.isdir(os.path.join(path, folder)): 55 | continue 56 | files = os.listdir(os.path.join(path, folder)) 57 
| files.sort() 58 | for file in files: 59 | if file.endswith('.txt'): 60 | sti_name += list(np.loadtxt(os.path.join(path, folder, file), dtype=str)) 61 | 62 | sti_name_to_return = [] 63 | for name in sti_name: 64 | if name.startswith('rep_'): 65 | name = name.replace('rep_', '', 1) 66 | sti_name_to_return.append(name) 67 | return sti_name_to_return 68 | 69 | 70 | def create_BOLD5000_dataset(path='./data/BOLD5000', fmri_transform=identity, 71 | image_transform=identity, subjects = ['CSI1', 'CSI2', 'CSI3', 'CSI4'], include_nonavg_test=False): 72 | roi_list = ['EarlyVis', 'LOC', 'OPA', 'PPA', 'RSC'] 73 | fmri_path = os.path.join(path, 'BOLD5000_GLMsingle_ROI_betas/py') 74 | img_path = os.path.join(path, 'BOLD5000_Stimuli') 75 | 76 | imgs_dict = np.load(os.path.join(img_path, 'Scene_Stimuli/Presented_Stimuli/img_dict.npy'),allow_pickle=True).item() 77 | repeated_imgs_list = np.loadtxt(os.path.join(img_path, 'Scene_Stimuli', 'repeated_stimuli_113_list.txt'), dtype=str) 78 | 79 | fmri_files = [f for f in os.listdir(fmri_path) if f.endswith('.npy')] 80 | fmri_files.sort() 81 | 82 | fmri_train_major = [] 83 | fmri_test_major = [] 84 | img_train_major = [] 85 | img_test_major = [] 86 | for sub in subjects: 87 | # load fmri 88 | fmri_data_sub = [] 89 | for roi in roi_list: 90 | for npy in fmri_files: 91 | if npy.endswith('.npy') and sub in npy and roi in npy: 92 | fmri_data_sub.append(np.load(os.path.join(fmri_path, npy))) 93 | fmri_data_sub = np.concatenate(fmri_data_sub, axis=-1) # concatenate all rois 94 | 95 | 96 | # load image 97 | img_files = get_stimuli_list(img_path, sub) 98 | img_data_sub = [imgs_dict[name] for name in img_files] 99 | 100 | # split train test 101 | test_idx = [list_get_all_index(img_files, img) for img in repeated_imgs_list] 102 | test_idx = [i for i in test_idx if len(i) > 0] # remove empy list for CSI4 103 | test_fmri = np.stack([fmri_data_sub[idx].mean(axis=0) for idx in test_idx]) 104 | test_img = np.stack([img_data_sub[idx[0]] for idx in test_idx]) 105 | 106 | test_idx_flatten = [] 107 | for idx in test_idx: 108 | test_idx_flatten += idx # flatten 109 | if include_nonavg_test: 110 | test_fmri = np.concatenate([test_fmri, fmri_data_sub[test_idx_flatten]], axis=0) 111 | test_img = np.concatenate([test_img, np.stack([img_data_sub[idx] for idx in test_idx_flatten])], axis=0) 112 | 113 | train_idx = [i for i in range(len(img_files)) if i not in test_idx_flatten] 114 | train_img = np.stack([img_data_sub[idx] for idx in train_idx]) 115 | train_fmri = fmri_data_sub[train_idx] 116 | 117 | fmri_train_major.append(train_fmri) 118 | fmri_test_major.append(test_fmri) 119 | img_train_major.append(train_img) 120 | img_test_major.append(test_img) 121 | break 122 | 123 | num_voxels = fmri_train_major[0].shape[-1] 124 | img_train = img_train_major[0] 125 | img_test = img_test_major[0] 126 | 127 | img_train = [] 128 | for img in img_train_major[0]: 129 | img_train.append(Image.fromarray(img.astype(np.uint8))) 130 | 131 | img_test = [] 132 | for img in img_test_major[0]: 133 | img_test.append(Image.fromarray(img.astype(np.uint8))) 134 | 135 | data_dict = {'num_voxels': num_voxels, 'fmri_train': fmri_train_major[0], 'fmri_test': fmri_test_major[0], 136 | 'img_train': img_train, 'img_test': img_test} 137 | np.save(os.path.join(path, 'preprocessed_bold5000_data_rgb.npy'), data_dict) 138 | 139 | if __name__ == "__main__": 140 | create_BOLD5000_dataset(path='/path/to/Research/SiluetteExtraction/src/data') 141 | print("Done!") 
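A minimal sketch (not part of the repository) of how the dictionary written by `create_BOLD5000_dataset` could be read back downstream; the filename and keys mirror the `data_dict` saved at the end of the function, and the directory below assumes the default `./data/BOLD5000` root.
```python
import numpy as np

# np.save stored a plain Python dict, so allow_pickle=True and .item() are needed.
data = np.load('./data/BOLD5000/preprocessed_bold5000_data_rgb.npy',
               allow_pickle=True).item()

print('voxels per sample:', data['num_voxels'])
print('fmri_train shape :', data['fmri_train'].shape)   # (n_train, num_voxels)
print('fmri_test shape  :', data['fmri_test'].shape)    # averaged repeats (plus non-averaged trials if enabled)
print('train images     :', len(data['img_train']))     # list of RGB PIL images
print('test images      :', len(data['img_test']))
```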
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Controllable Mind Visual Diffusion Model 2 | 3 | 4 | This paper has been accepted to AAAI 2024. [arxiv](https://arxiv.org/pdf/2305.10135.pdf) 5 | 6 | 7 | ## Abstract 8 | Brain signal visualization has emerged as an active research area, serving as a critical interface between the human visual system and computer vision models. Although diffusion models have shown promise in analyzing functional magnetic resonance imaging (fMRI) data, including reconstructing high-quality images consistent with original visual stimuli, their accuracy in extracting semantic and silhouette information from brain signals remains limited. In this regard, we propose a novel approach, referred to as Controllable Mind Visual Diffusion Model (CMVDM). CMVDM extracts semantic and silhouette information from fMRI data using attribute alignment and assistant networks. Additionally, a residual block is incorporated to capture information beyond semantic and silhouette features. We then leverage a control model to fully exploit the extracted information for image synthesis, resulting in generated images that closely resemble the visual stimuli in terms of semantics and silhouette. Through extensive experimentation, we demonstrate that CMVDM outperforms existing state-of-the-art methods both qualitatively and quantitatively. 9 | 10 | 11 |
12 | 
13 | 
14 | ## Overview 
15 | Our framework consists of two main stages: 
16 | - Stage A: Semantic Training 
17 | - Stage B: Control Model Training 
18 | 
19 | Following Mind-Vis, the **data** and **pretrains** folders are not included in this repository. Please download them from [FigShare](https://figshare.com/s/94cd778e6afafb00946e). 
20 | 
21 | Expected file path structure: 
22 | ``` 
23 | 
24 | /data 
25 | ┣ HCP 
26 | ┃ ┣ npz 
27 | ┃ ┃ ┣ dummy_sub_01 
28 | ┃ ┃ ┃ ┗ HCP_visual_voxel.npz 
29 | ┃ ┃ ┣ dummy_sub_02 
30 | ┃ ┃ ┃ ┗ ... 
31 | 
32 | ┣ Kamitani 
33 | ┃ ┣ npz 
34 | ┃ ┃ ┗ sbj_1.npz 
35 | ┃ ┃ ┗ sbj_2.npz 
36 | ┃ ┃ ┗ sbj_3.npz 
37 | ┃ ┃ ┗ sbj_4.npz 
38 | ┃ ┃ ┗ sbj_5.npz 
39 | ┃ ┃ ┗ images_256.npz 
40 | ┃ ┃ ┗ imagenet_class_index.json 
41 | ┃ ┃ ┗ imagenet_training_label.csv 
42 | ┃ ┃ ┗ imagenet_testing_label.csv 
43 | 
44 | ┣ BOLD5000 
45 | ┃ ┣ BOLD5000_GLMsingle_ROI_betas 
46 | ┃ ┃ ┣ py 
47 | ┃ ┃ ┃ ┗ CSI1_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_LHEarlyVis.npy 
48 | ┃ ┃ ┃ ┗ ... 
49 | ┃ ┃ ┃ ┗ CSIx_GLMbetas-TYPED-FITHRF-GLMDENOISE-RR_allses_xx.npy 
50 | ┃ ┣ BOLD5000_Stimuli 
51 | ┃ ┃ ┣ Image_Labels 
52 | ┃ ┃ ┣ Scene_Stimuli 
53 | ┃ ┃ ┣ Stimuli_Presentation_Lists 
54 | 
55 | ``` 
56 | 
57 | 
58 | ## Environment setup 
59 | Create and activate a conda environment named ```cmvdm``` from our ```env.yaml```: 
60 | ```sh 
61 | conda env create -f env.yaml 
62 | conda activate cmvdm 
63 | ``` 
64 | 
65 | ## Semantic Training (Stage A) 
66 | In this stage, the cross-attention heads and the pre-trained fMRI encoder are jointly optimized on fMRI-image pairs. 
67 | 
68 | Run with a custom pre-trained fMRI encoder and parameters: 
69 | ```sh 
70 | python code/stageB_ldm_finetune_clip.py --batch_size 24 --kam_path data/Kamitani/npz --bold5000_path data/BOLD5000 --dataset GOD --pretrain_mbm_path frmi_pretrains/GOD/fmri_encoder.pth --pretrain_gm_path frmi_pretrains/ldm/label2img --pretrain_finetune_path frmi_pretrains/GOD/finetuned.pth --config_root pretrains/ldm/label2img/controlnet_config.yaml 
71 | ``` 
72 | 
73 | 
74 | ## Control Model Training (Stage B) 
75 | In this stage, the control model and the residual blocks are trained. 
76 | 
77 | Run with a custom pre-trained fMRI encoder and parameters: 
78 | ```sh 
79 | python code/stageB_ldm_finetune.py --batch_size 32 --kam_path data/Kamitani/npz --bold5000_path data/BOLD5000 --dataset GOD --pretrain_mbm_path frmi_pretrains/GOD/fmri_encoder.pth --pretrain_gm_path frmi_pretrains/ldm/label2img --pretrain_finetune_path frmi_pretrains/GOD/finetuned.pth --config_root pretrains/ldm/label2img/controlnet_config.yaml --checkpoint_path results/GOD/clip_generation/TIME/checkpoint_best.pth 
80 | ``` 
81 | 
82 | The ```frmi_pretrains``` folder can be found at the [FigShare](https://figshare.com/s/94cd778e6afafb00946e) link. 
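Before launching either stage, it can help to confirm that the downloaded folders sit where the commands above expect them. The snippet below is an illustrative sketch only (not part of the repository); it checks the paths referenced in the directory tree and in the example commands, so adjust the list if your layout differs.
```python
import os

# Paths taken from the directory tree and the Stage A/B commands above.
expected_paths = [
    'data/Kamitani/npz/sbj_3.npz',
    'data/BOLD5000/BOLD5000_Stimuli',
    'data/BOLD5000/BOLD5000_GLMsingle_ROI_betas/py',
    'pretrains/ldm/label2img/controlnet_config.yaml',
    'frmi_pretrains/GOD/fmri_encoder.pth',
    'frmi_pretrains/GOD/finetuned.pth',
    'frmi_pretrains/ldm/label2img',
]
for p in expected_paths:
    print(('ok      ' if os.path.exists(p) else 'MISSING ') + p)
```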
83 | -------------------------------------------------------------------------------- /assets/framework.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengbohan0217/CMVDM/fa61e9d1ab27c644168fec4d6b5edd135bf88a80/assets/framework.jpg -------------------------------------------------------------------------------- /code/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | class Config_MAE_fMRI: # back compatibility 5 | pass 6 | class Config_MBM_finetune: # back compatibility 7 | pass 8 | 9 | class Config_MBM_fMRI(Config_MAE_fMRI): 10 | # configs for fmri_pretrain.py 11 | def __init__(self): 12 | # -------------------------------------------- 13 | # MAE for fMRI 14 | # Training Parameters 15 | self.lr = 2.5e-4 16 | self.min_lr = 0. 17 | self.weight_decay = 0.05 18 | self.num_epoch = 500 19 | self.warmup_epochs = 40 20 | self.batch_size = 100 21 | self.clip_grad = 0.8 22 | 23 | # Model Parameters 24 | self.mask_ratio = 0.75 25 | self.patch_size = 16 26 | self.embed_dim = 1024 # has to be a multiple of num_heads 27 | self.decoder_embed_dim = 512 28 | self.depth = 24 29 | self.num_heads = 16 30 | self.decoder_num_heads = 16 31 | self.mlp_ratio = 1.0 32 | 33 | # Project setting 34 | self.root_path = '.' 35 | self.output_path = self.root_path 36 | self.seed = 2022 37 | self.roi = 'VC' 38 | self.aug_times = 1 39 | self.num_sub_limit = None 40 | self.include_hcp = True 41 | self.include_kam = True 42 | self.accum_iter = 1 43 | 44 | self.use_nature_img_loss = False 45 | self.img_recon_weight = 0.5 46 | self.focus_range = None # [0, 1500] # None to disable it 47 | self.focus_rate = 0.6 48 | 49 | # distributed training 50 | self.local_rank = 0 51 | 52 | class Config_MBM_finetune(Config_MBM_finetune): 53 | def __init__(self): 54 | 55 | # Project setting 56 | self.root_path = '.' 57 | self.output_path = self.root_path 58 | self.kam_path = os.path.join(self.root_path, 'data/Kamitani/npz') 59 | self.bold5000_path = os.path.join(self.root_path, 'data/BOLD5000') 60 | self.dataset = 'GOD' # GOD or BOLD5000 61 | self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth') 62 | 63 | self.include_nonavg_test = True 64 | self.kam_subs = ['sbj_3'] 65 | self.bold5000_subs = ['CSI1'] 66 | 67 | # Training Parameters 68 | self.lr = 5.3e-5 69 | self.weight_decay = 0.05 70 | self.num_epoch = 15 71 | self.batch_size = 16 if self.dataset == 'GOD' else 4 72 | self.mask_ratio = 0.75 73 | self.accum_iter = 1 74 | self.clip_grad = 0.8 75 | self.warmup_epochs = 2 76 | self.min_lr = 0. 77 | 78 | # distributed training 79 | self.local_rank = 0 80 | 81 | class Config_Generative_Model: 82 | def __init__(self): 83 | # project parameters 84 | self.seed = 2022 85 | self.root_path = '.' 
86 | self.kam_path = os.path.join(self.root_path, 'data/Kamitani/npz') 87 | self.bold5000_path = os.path.join(self.root_path, 'data/BOLD5000') 88 | self.roi = 'VC' 89 | self.patch_size = 16 90 | 91 | # self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/semantic') 92 | self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/label2img') 93 | # self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/text2img-large') 94 | # self.pretrain_gm_path = os.path.join(self.root_path, 'pretrains/ldm/layout2img') 95 | 96 | self.dataset = 'GOD' # GOD or BOLD5000 97 | self.kam_subs = ['sbj_3'] 98 | self.bold5000_subs = ['CSI1'] 99 | self.pretrain_mbm_path = os.path.join(self.root_path, f'pretrains/{self.dataset}/fmri_encoder.pth') 100 | 101 | self.img_size = 256 102 | 103 | np.random.seed(self.seed) 104 | # finetune parameters 105 | self.batch_size = 5 if self.dataset == 'GOD' else 25 106 | self.lr = 5.3e-5 107 | self.num_epoch = 500 108 | 109 | self.precision = 32 110 | self.accumulate_grad = 1 111 | self.crop_ratio = 0.2 112 | self.global_pool = False 113 | self.use_time_cond = True 114 | self.eval_avg = True 115 | 116 | # diffusion sampling parameters 117 | self.num_samples = 5 118 | self.ddim_steps = 250 119 | self.HW = None 120 | # resume check util 121 | self.model_meta = None 122 | self.checkpoint_path = None 123 | 124 | self.port = '12345' 125 | self.config_root=None 126 | self.pretrain_finetune_path=None 127 | -------------------------------------------------------------------------------- /code/dc_ldm/ldm_for_fmri.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from dc_ldm.util import instantiate_from_config 4 | from omegaconf import OmegaConf 5 | import torch.nn as nn 6 | import os 7 | from dc_ldm.models.diffusion.plms import PLMSSampler 8 | from einops import rearrange, repeat 9 | from torchvision.utils import make_grid 10 | from torch.utils.data import DataLoader 11 | from sc_mbm.mae_for_fmri import fmri_encoder 12 | 13 | def create_model_from_config(config, num_voxels, global_pool): 14 | model = fmri_encoder(num_voxels=num_voxels, patch_size=config.patch_size, embed_dim=config.embed_dim, 15 | depth=config.depth, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, global_pool=global_pool) 16 | return model 17 | 18 | class cond_stage_model(nn.Module): 19 | def __init__(self, metafile, num_voxels, cond_dim=1280, global_pool=True): 20 | super().__init__() 21 | # prepare pretrained fmri mae 22 | model = create_model_from_config(metafile['config'], num_voxels, global_pool) 23 | model.load_checkpoint(metafile['model']) 24 | self.mae = model 25 | self.fmri_seq_len = model.num_patches 26 | self.fmri_latent_dim = model.embed_dim 27 | if global_pool == False: 28 | self.channel_mapper = nn.Sequential( 29 | nn.Conv1d(self.fmri_seq_len, self.fmri_seq_len // 2, 1, bias=True), 30 | nn.Conv1d(self.fmri_seq_len // 2, 77, 1, bias=True) 31 | ) 32 | self.dim_mapper = nn.Linear(self.fmri_latent_dim, cond_dim, bias=True) 33 | self.global_pool = global_pool 34 | 35 | def forward(self, x): 36 | # n, c, w = x.shape 37 | latent_crossattn = self.mae(x) 38 | if self.global_pool == False: 39 | latent_crossattn = self.channel_mapper(latent_crossattn) 40 | latent_crossattn = self.dim_mapper(latent_crossattn) 41 | out = latent_crossattn 42 | return out 43 | 44 | class fLDM: 45 | 46 | def __init__(self, metafile, num_voxels, device=torch.device('cpu'), 47 | pretrain_root='../pretrains/ldm/label2img', 
48 | logger=None, ddim_steps=250, global_pool=True, use_time_cond=True, 49 | config_full=None): 50 | self.ckp_path = os.path.join(pretrain_root, 'model.ckpt') 51 | self.config_path = os.path.join(pretrain_root, 'config.yaml') 52 | config = OmegaConf.load(self.config_path) 53 | config.model.params.unet_config.params.use_time_cond = use_time_cond 54 | config.model.params.unet_config.params.global_pool = global_pool 55 | 56 | self.cond_dim = config.model.params.unet_config.params.context_dim 57 | 58 | model = instantiate_from_config(config.model) 59 | pl_sd = torch.load(self.ckp_path, map_location="cpu")['state_dict'] 60 | 61 | m, u = model.load_state_dict(pl_sd, strict=False) 62 | model.cond_stage_trainable = True 63 | model.cond_stage_model = cond_stage_model(metafile, num_voxels, self.cond_dim, global_pool=global_pool) 64 | 65 | model.ddim_steps = ddim_steps 66 | model.re_init_ema() 67 | if logger is not None: 68 | logger.watch(model, log="all", log_graph=False) 69 | 70 | model.p_channels = config.model.params.channels 71 | model.p_image_size = config.model.params.image_size 72 | model.ch_mult = config.model.params.first_stage_config.params.ddconfig.ch_mult 73 | 74 | self.device = device 75 | self.model = model 76 | self.ldm_config = config 77 | self.pretrain_root = pretrain_root 78 | self.fmri_latent_dim = model.cond_stage_model.fmri_latent_dim 79 | self.metafile = metafile 80 | 81 | def finetune(self, trainers, dataset, test_dataset, bs1, lr1, 82 | output_path, config=None): 83 | config.trainer = None 84 | config.logger = None 85 | self.model.main_config = config 86 | self.model.output_path = output_path 87 | # self.model.train_dataset = dataset 88 | self.model.run_full_validation_threshold = 0.15 89 | # stage one: train the cond encoder with the pretrained one 90 | 91 | # # stage one: only optimize conditional encoders 92 | print('\n##### Stage One: only optimize conditional encoders #####') 93 | dataloader = DataLoader(dataset, batch_size=bs1, shuffle=True) 94 | test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False) 95 | self.model.unfreeze_whole_model() 96 | self.model.freeze_first_stage() 97 | 98 | self.model.learning_rate = lr1 99 | self.model.train_cond_stage_only = True 100 | self.model.eval_avg = config.eval_avg 101 | trainers.fit(self.model, dataloader, val_dataloaders=test_loader) 102 | 103 | self.model.unfreeze_whole_model() 104 | 105 | torch.save( 106 | { 107 | 'model_state_dict': self.model.state_dict(), 108 | 'config': config, 109 | 'state': torch.random.get_rng_state() 110 | 111 | }, 112 | os.path.join(output_path, 'checkpoint.pth') 113 | ) 114 | 115 | 116 | @torch.no_grad() 117 | def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, state=None): 118 | # fmri_embedding: n, seq_len, embed_dim 119 | all_samples = [] 120 | if HW is None: 121 | shape = (self.ldm_config.model.params.channels, 122 | self.ldm_config.model.params.image_size, self.ldm_config.model.params.image_size) 123 | else: 124 | num_resolutions = len(self.ldm_config.model.params.first_stage_config.params.ddconfig.ch_mult) 125 | shape = (self.ldm_config.model.params.channels, 126 | HW[0] // 2**(num_resolutions-1), HW[1] // 2**(num_resolutions-1)) 127 | 128 | model = self.model.to(self.device) 129 | sampler = PLMSSampler(model) 130 | if state is not None: 131 | torch.cuda.set_rng_state(state) 132 | 133 | with model.ema_scope(): 134 | model.eval() 135 | for count, item in enumerate(fmri_embedding): 136 | if limit is not None: 137 | if count >= limit: 138 | 
break 139 | latent = item['fmri'] 140 | gt_image = rearrange(item['image'], 'h w c -> 1 c h w') # h w c 141 | print(f"rendering {num_samples} examples in {ddim_steps} steps.") 142 | # assert latent.shape[-1] == self.fmri_latent_dim, 'dim error' 143 | 144 | c = model.get_learned_conditioning(repeat(latent, 'h w -> c h w', c=num_samples).to(self.device)) 145 | samples_ddim, _ = sampler.sample(S=ddim_steps, 146 | conditioning=c, 147 | batch_size=num_samples, 148 | shape=shape, 149 | verbose=False) 150 | 151 | x_samples_ddim = model.decode_first_stage(samples_ddim) 152 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) 153 | gt_image = torch.clamp((gt_image+1.0)/2.0, min=0.0, max=1.0) 154 | 155 | all_samples.append(torch.cat([gt_image, x_samples_ddim.detach().cpu()], dim=0)) # put groundtruth at first 156 | 157 | 158 | # display as grid 159 | grid = torch.stack(all_samples, 0) 160 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 161 | grid = make_grid(grid, nrow=num_samples+1) 162 | 163 | # to image 164 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() 165 | model = model.to('cpu') 166 | 167 | return grid, (255. * torch.stack(all_samples, 0).cpu().numpy()).astype(np.uint8) 168 | 169 | 170 | -------------------------------------------------------------------------------- /code/dc_ldm/ldm_for_fmri_clip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import wandb 3 | import torch 4 | from dc_ldm.util import instantiate_from_config 5 | from omegaconf import OmegaConf 6 | import torch.nn as nn 7 | import os 8 | from dc_ldm.models.diffusion.plms import PLMSSampler 9 | from einops import rearrange, repeat 10 | from torchvision.utils import make_grid 11 | from torch.utils.data import DataLoader 12 | import torch.nn.functional as F 13 | from sc_mbm.mae_for_fmri import fmri_encoder 14 | 15 | def create_model_from_config(config, num_voxels, global_pool): 16 | model = fmri_encoder(num_voxels=num_voxels, patch_size=config.patch_size, embed_dim=config.embed_dim, 17 | depth=config.depth, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, global_pool=global_pool) 18 | return model 19 | 20 | class cond_stage_model(nn.Module): 21 | def __init__(self, metafile, num_voxels, cond_dim=1280, clip_dim=512, global_pool=True): 22 | super().__init__() 23 | # prepare pretrained fmri mae 24 | model = create_model_from_config(metafile['config'], num_voxels, global_pool) 25 | model.load_checkpoint(metafile['model']) 26 | self.mae = model 27 | self.fmri_seq_len = model.num_patches 28 | self.fmri_latent_dim = model.embed_dim 29 | if global_pool == False: 30 | self.channel_mapper = nn.Sequential( 31 | nn.Conv1d(self.fmri_seq_len, self.fmri_seq_len // 2, 1, bias=True), 32 | nn.Conv1d(self.fmri_seq_len // 2, 77, 1, bias=True) 33 | ) 34 | self.dim_mapper = nn.Linear(self.fmri_latent_dim, cond_dim, bias=True) 35 | self.global_pool = global_pool 36 | 37 | # define clip matcher 38 | inner_mlp_dim = 1024 39 | self.clip_pred_conv = nn.Sequential( 40 | nn.Conv1d(77, 64, 3, padding=1, bias=True), 41 | nn.Conv1d(64, 4, 3, padding=1, bias=True)) 42 | self.clip_matcher_ = nn.Sequential(nn.Linear(cond_dim*4, inner_mlp_dim), 43 | nn.SiLU(), 44 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 45 | nn.SiLU(), 46 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 47 | nn.SiLU(), 48 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 49 | nn.SiLU(), 50 | nn.Linear(inner_mlp_dim, clip_dim), 51 | nn.SiLU()) 52 | 53 | 54 | def forward(self, x): 55 | # n, c, w = 
x.shape 56 | latent_crossattn = self.mae(x) 57 | if self.global_pool == False: 58 | latent_crossattn = self.channel_mapper(latent_crossattn) 59 | latent_crossattn = self.dim_mapper(latent_crossattn) 60 | out = latent_crossattn 61 | 62 | return out 63 | 64 | def get_clip_feature(self, c): 65 | # n, c, w = x.shape 66 | c = self.clip_pred_conv(c).view(c.size(0), -1) 67 | clip_feat = self.clip_matcher_(c) 68 | return clip_feat 69 | 70 | 71 | class fLDM: 72 | 73 | def __init__(self, metafile, num_voxels, device=torch.device('cpu'), 74 | pretrain_root='../pretrains/ldm/label2img', 75 | logger=None, ddim_steps=250, global_pool=True, use_time_cond=True, 76 | config_full=None): 77 | self.ckp_path = os.path.join(pretrain_root, 'model.ckpt') 78 | config = OmegaConf.load(config_full.config_root) 79 | config.model.params.unet_config.params.use_time_cond = use_time_cond 80 | config.model.params.unet_config.params.global_pool = global_pool 81 | 82 | self.cond_dim = config.model.params.unet_config.params.context_dim 83 | 84 | model = instantiate_from_config(config.model) 85 | 86 | if config_full.pretrain_finetune_path == None: 87 | pl_sd = torch.load(self.ckp_path, map_location="cpu")['state_dict'] 88 | else: 89 | pl_sd = torch.load(config_full.pretrain_finetune_path, map_location='cpu')['model_state_dict'] 90 | 91 | # m, u = model.load_state_dict(pl_sd, strict=False) 92 | model.cond_stage_trainable = True 93 | model.cond_stage_model = cond_stage_model(metafile, num_voxels, self.cond_dim, global_pool=global_pool) 94 | m, u = model.load_state_dict(pl_sd, strict=False) 95 | 96 | model.ddim_steps = ddim_steps 97 | model.re_init_ema() 98 | if logger is not None: 99 | logger.watch(model, log="all", log_graph=False) 100 | 101 | model.p_channels = config.model.params.channels 102 | model.p_image_size = config.model.params.image_size 103 | model.ch_mult = config.model.params.first_stage_config.params.ddconfig.ch_mult 104 | 105 | self.device = device 106 | self.model = model 107 | self.ldm_config = config 108 | self.pretrain_root = pretrain_root 109 | self.fmri_latent_dim = model.cond_stage_model.fmri_latent_dim 110 | self.metafile = metafile 111 | 112 | def finetune(self, trainers, dataset, test_dataset, bs1, lr1, 113 | output_path, config=None): 114 | config.trainer = None 115 | config.logger = None 116 | self.model.main_config = config 117 | self.model.output_path = output_path 118 | # self.model.train_dataset = dataset 119 | self.model.run_full_validation_threshold = 0.15 120 | # stage one: train the cond encoder with the pretrained one 121 | 122 | # # stage one: only optimize conditional encoders 123 | print('\n##### Stage One: only optimize conditional encoders #####') 124 | dataloader = DataLoader(dataset, batch_size=bs1, shuffle=True) 125 | test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False) 126 | self.model.unfreeze_whole_model() 127 | self.model.freeze_first_stage() 128 | 129 | self.model.learning_rate = lr1 130 | self.model.train_cond_stage_only = True 131 | self.model.eval_avg = config.eval_avg 132 | trainers.fit(self.model, dataloader, val_dataloaders=test_loader) 133 | 134 | self.model.unfreeze_whole_model() 135 | 136 | torch.save( 137 | { 138 | 'model_state_dict': self.model.state_dict(), 139 | 'config': config, 140 | 'state': torch.random.get_rng_state() 141 | 142 | }, 143 | os.path.join(output_path, 'checkpoint.pth') 144 | ) 145 | 146 | 147 | @torch.no_grad() 148 | def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, state=None): 149 | # 
fmri_embedding: n, seq_len, embed_dim 150 | all_samples = [] 151 | if HW is None: 152 | shape = (self.ldm_config.model.params.channels, 153 | self.ldm_config.model.params.image_size, self.ldm_config.model.params.image_size) 154 | else: 155 | num_resolutions = len(self.ldm_config.model.params.first_stage_config.params.ddconfig.ch_mult) 156 | shape = (self.ldm_config.model.params.channels, 157 | HW[0] // 2**(num_resolutions-1), HW[1] // 2**(num_resolutions-1)) 158 | 159 | model = self.model.to(self.device) 160 | sampler = PLMSSampler(model) 161 | # sampler = DDIMSampler(model) 162 | if state is not None: 163 | torch.cuda.set_rng_state(state) 164 | 165 | with model.ema_scope(): 166 | model.eval() 167 | for count, item in enumerate(fmri_embedding): 168 | if limit is not None: 169 | if count >= limit: 170 | break 171 | latent = item['fmri'] 172 | gt_image = rearrange(item['image'], 'h w c -> 1 c h w') # h w c 173 | print(f"rendering {num_samples} examples in {ddim_steps} steps.") 174 | # assert latent.shape[-1] == self.fmri_latent_dim, 'dim error' 175 | 176 | c = model.get_learned_conditioning(repeat(latent, 'h w -> c h w', c=num_samples).to(self.device)) 177 | samples_ddim, _ = sampler.sample(S=ddim_steps, 178 | conditioning=c, 179 | batch_size=num_samples, 180 | shape=shape, 181 | verbose=False) 182 | 183 | x_samples_ddim = model.decode_first_stage(samples_ddim) 184 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) 185 | gt_image = torch.clamp((gt_image+1.0)/2.0, min=0.0, max=1.0) 186 | 187 | all_samples.append(torch.cat([gt_image, x_samples_ddim.detach().cpu()], dim=0)) # put groundtruth at first 188 | 189 | 190 | # display as grid 191 | grid = torch.stack(all_samples, 0) 192 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 193 | grid = make_grid(grid, nrow=num_samples+1) 194 | 195 | # to image 196 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() 197 | model = model.to('cpu') 198 | 199 | return grid, (255. 
* torch.stack(all_samples, 0).cpu().numpy()).astype(np.uint8) 200 | 201 | 202 | -------------------------------------------------------------------------------- /code/dc_ldm/ldm_for_fmri_control.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from dc_ldm.util import instantiate_from_config 4 | from omegaconf import OmegaConf 5 | import torch.nn as nn 6 | import os 7 | from dc_ldm.models.diffusion.plms import PLMSSampler 8 | from einops import rearrange, repeat 9 | from torchvision.utils import make_grid 10 | from torch.utils.data import DataLoader 11 | from sc_mbm.mae_for_fmri import fmri_encoder 12 | 13 | 14 | def create_model_from_config(config, num_voxels, global_pool): 15 | model = fmri_encoder(num_voxels=num_voxels, patch_size=config.patch_size, embed_dim=config.embed_dim, 16 | depth=config.depth, num_heads=config.num_heads, mlp_ratio=config.mlp_ratio, global_pool=global_pool) 17 | return model 18 | 19 | 20 | class cond_stage_model(nn.Module): 21 | def __init__(self, metafile, num_voxels, cond_dim=1280, clip_dim=512, global_pool=True): 22 | super().__init__() 23 | # prepare pretrained fmri mae 24 | model = create_model_from_config(metafile['config'], num_voxels, global_pool) 25 | model.load_checkpoint(metafile['model']) 26 | self.mae = model 27 | self.fmri_seq_len = model.num_patches 28 | self.fmri_latent_dim = model.embed_dim 29 | if global_pool == False: 30 | self.channel_mapper = nn.Sequential( 31 | nn.Conv1d(self.fmri_seq_len, self.fmri_seq_len // 2, 1, bias=True), 32 | nn.Conv1d(self.fmri_seq_len // 2, 77, 1, bias=True) 33 | ) 34 | self.dim_mapper = nn.Linear(self.fmri_latent_dim, cond_dim, bias=True) 35 | self.global_pool = global_pool 36 | 37 | def forward(self, x): 38 | # n, c, w = x.shape 39 | latent_crossattn = self.mae(x) 40 | if self.global_pool == False: 41 | latent_crossattn = self.channel_mapper(latent_crossattn) 42 | latent_crossattn = self.dim_mapper(latent_crossattn) 43 | out = latent_crossattn 44 | 45 | return out 46 | 47 | 48 | class fLDM: 49 | 50 | def __init__(self, metafile, num_voxels, device=torch.device('cpu'), 51 | pretrain_root='../pretrains/ldm/label2img', 52 | logger=None, ddim_steps=250, global_pool=True, use_time_cond=True, 53 | config_full=None): 54 | self.ckp_path = os.path.join(pretrain_root, 'model.ckpt') 55 | config = OmegaConf.load(config_full.config_root) 56 | config.model.params.unet_config.params.use_time_cond = use_time_cond 57 | config.model.params.unet_config.params.global_pool = global_pool 58 | 59 | self.cond_dim = config.model.params.unet_config.params.context_dim 60 | 61 | model = instantiate_from_config(config.model) 62 | 63 | if config_full.pretrain_finetune_path == None: 64 | pl_sd = torch.load(self.ckp_path, map_location="cpu")['state_dict'] 65 | else: 66 | pl_sd = torch.load(config_full.pretrain_finetune_path, map_location='cpu')['model_state_dict'] 67 | 68 | # m, u = model.load_state_dict(pl_sd, strict=False) 69 | model.cond_stage_trainable = True 70 | model.cond_stage_model = cond_stage_model(metafile, num_voxels, self.cond_dim, global_pool=global_pool) 71 | m, u = model.load_state_dict(pl_sd, strict=False) 72 | 73 | # here to load the stat_dict for controlnet 74 | model.model.diffusion_model.init_control_net() 75 | # model.model.diffusion_model.init_context_ext(model.context_ext.extend_deal) 76 | 77 | model.ddim_steps = ddim_steps 78 | model.re_init_ema() 79 | if logger is not None: 80 | logger.watch(model, log="all", log_graph=False) 81 | 82 | 
model.p_channels = config.model.params.channels 83 | model.p_image_size = config.model.params.image_size 84 | model.ch_mult = config.model.params.first_stage_config.params.ddconfig.ch_mult 85 | 86 | self.device = device 87 | self.model = model 88 | self.ldm_config = config 89 | self.pretrain_root = pretrain_root 90 | self.fmri_latent_dim = model.cond_stage_model.fmri_latent_dim 91 | self.metafile = metafile 92 | 93 | def finetune(self, trainers, dataset, test_dataset, bs1, lr1, 94 | output_path, config=None): 95 | config.trainer = None 96 | config.logger = None 97 | self.model.main_config = config 98 | self.model.output_path = output_path 99 | # self.model.train_dataset = dataset 100 | self.model.run_full_validation_threshold = 0.15 101 | # stage one: train the cond encoder with the pretrained one 102 | 103 | # # stage one: only optimize conditional encoders 104 | print('\n##### Stage One: only optimize conditional encoders #####') 105 | dataloader = DataLoader(dataset, batch_size=bs1, shuffle=True) 106 | test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False) 107 | self.model.unfreeze_whole_model() 108 | self.model.freeze_first_stage() 109 | 110 | self.model.learning_rate = lr1 111 | self.model.train_cond_stage_only = True 112 | self.model.eval_avg = config.eval_avg 113 | trainers.fit(self.model, dataloader, val_dataloaders=test_loader) 114 | 115 | self.model.unfreeze_whole_model() 116 | 117 | torch.save( 118 | { 119 | 'model_state_dict': self.model.state_dict(), 120 | 'config': config, 121 | 'state': torch.random.get_rng_state() 122 | }, 123 | os.path.join(output_path, 'checkpoint.pth') 124 | ) 125 | 126 | 127 | @torch.no_grad() 128 | def generate(self, fmri_embedding, num_samples, ddim_steps, HW=None, limit=None, state=None): 129 | # fmri_embedding: n, seq_len, embed_dim 130 | all_samples = [] 131 | if HW is None: 132 | shape = (self.ldm_config.model.params.channels, 133 | self.ldm_config.model.params.image_size, self.ldm_config.model.params.image_size) 134 | else: 135 | num_resolutions = len(self.ldm_config.model.params.first_stage_config.params.ddconfig.ch_mult) 136 | shape = (self.ldm_config.model.params.channels, 137 | HW[0] // 2**(num_resolutions-1), HW[1] // 2**(num_resolutions-1)) 138 | 139 | model = self.model.to(self.device) 140 | sampler = PLMSSampler(model) 141 | # sampler = DDIMSampler(model) 142 | if state is not None: 143 | torch.cuda.set_rng_state(state) 144 | 145 | with model.ema_scope(): 146 | model.eval() 147 | for count, item in enumerate(fmri_embedding): 148 | if limit is not None: 149 | if count >= limit: 150 | break 151 | latent = item['fmri'] 152 | gt_image = rearrange(item['image'], 'h w c -> 1 c h w') # h w c 153 | print(f"rendering {num_samples} examples in {ddim_steps} steps.") 154 | # assert latent.shape[-1] == self.fmri_latent_dim, 'dim error' 155 | 156 | c = model.get_learned_conditioning(repeat(latent, 'h w -> c h w', c=num_samples).to(self.device)) 157 | fmri_depth = model.mimic_model(latent.to(self.device)) 158 | if not isinstance(c, list): 159 | c = [c] 160 | c.append(fmri_depth) 161 | samples_ddim, _ = sampler.sample(S=ddim_steps, 162 | conditioning=c, 163 | batch_size=num_samples, 164 | shape=shape, 165 | verbose=False) 166 | 167 | x_samples_ddim = model.decode_first_stage(samples_ddim) 168 | x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0) 169 | gt_image = torch.clamp((gt_image+1.0)/2.0, min=0.0, max=1.0) 170 | 171 | all_samples.append(torch.cat([gt_image, x_samples_ddim.detach().cpu()], dim=0)) # 
put groundtruth at first 172 | 173 | 174 | # display as grid 175 | grid = torch.stack(all_samples, 0) 176 | grid = rearrange(grid, 'n b c h w -> (n b) c h w') 177 | grid = make_grid(grid, nrow=num_samples+1) 178 | 179 | # to image 180 | grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy() 181 | model = model.to('cpu') 182 | 183 | return grid, (255. * torch.stack(all_samples, 0).cpu().numpy()).astype(np.uint8) 184 | 185 | 186 | -------------------------------------------------------------------------------- /code/dc_ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengbohan0217/CMVDM/fa61e9d1ab27c644168fec4d6b5edd135bf88a80/code/dc_ldm/models/diffusion/__init__.py -------------------------------------------------------------------------------- /code/dc_ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from dc_ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from dc_ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of diffusion model 46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in diffusion model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in 
ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print(f'load from ckpt "{ckpt_path}"') 105 | self.init_from_ckpt(ckpt_path) 106 | 107 | @torch.no_grad() 108 | def get_x_noisy(self, x, t, noise=None): 109 | noise = default(noise, lambda: torch.randn_like(x)) 110 | continuous_sqrt_alpha_cumprod = None 111 | if self.diffusion_model.use_continuous_noise: 112 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 113 | # todo: make sure t+1 is correct here 114 | 115 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 116 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 117 | 118 | def forward(self, x_noisy, t, *args, **kwargs): 119 | return self.model(x_noisy, t) 120 | 121 | @torch.no_grad() 122 | def get_input(self, batch, k): 123 | x = batch[k] 124 | if len(x.shape) == 3: 125 | x = x[..., None] 126 | x = rearrange(x, 'b h w c -> b c h w') 127 | x = x.to(memory_format=torch.contiguous_format).float() 128 | return x 129 | 130 | @torch.no_grad() 131 | def get_conditioning(self, batch, k=None): 132 | if k is None: 133 | k = self.label_key 134 | assert k is not None, 'Needs to provide label key' 135 | 136 | targets = batch[k].to(self.device) 137 | 138 | if self.label_key == 'segmentation': 139 | targets = rearrange(targets, 'b h w c -> b c h w') 140 | for down in range(self.numd): 141 | h, w = targets.shape[-2:] 142 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 143 | 144 | return targets 145 | 146 | def compute_top_k(self, logits, labels, k, reduction="mean"): 147 | _, top_ks = torch.topk(logits, k, dim=1) 148 | if reduction == "mean": 149 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 150 | elif reduction == "none": 151 | return (top_ks == labels[:, None]).float().sum(dim=-1) 152 | 153 | def on_train_epoch_start(self): 154 | # save some memory 155 | self.diffusion_model.model.to('cpu') 156 | 157 | @torch.no_grad() 158 | def write_logs(self, loss, logits, targets): 159 | log_prefix = 'train' if self.training else 'val' 160 | log = {} 161 | log[f"{log_prefix}/loss"] = loss.mean() 162 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 163 | logits, targets, k=1, reduction="mean" 164 | ) 165 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 166 | logits, targets, k=5, reduction="mean" 167 | ) 168 | 169 | self.log_dict(log, 
prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 170 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 171 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 172 | lr = self.optimizers().param_groups[0]['lr'] 173 | self.log('lr_abs', lr, on_step=False, logger=True, on_epoch=False, prog_bar=True) 174 | 175 | def shared_step(self, batch, t=None): 176 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 177 | targets = self.get_conditioning(batch) 178 | if targets.dim() == 4: 179 | targets = targets.argmax(dim=1) 180 | if t is None: 181 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 182 | else: 183 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 184 | x_noisy = self.get_x_noisy(x, t) 185 | logits = self(x_noisy, t) 186 | 187 | loss = F.cross_entropy(logits, targets, reduction='none') 188 | 189 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 190 | 191 | loss = loss.mean() 192 | return loss, logits, x_noisy, targets 193 | 194 | def training_step(self, batch, batch_idx): 195 | loss, *_ = self.shared_step(batch) 196 | return loss 197 | 198 | def reset_noise_accs(self): 199 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 200 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 201 | 202 | def on_validation_start(self): 203 | self.reset_noise_accs() 204 | 205 | @torch.no_grad() 206 | def validation_step(self, batch, batch_idx): 207 | loss, *_ = self.shared_step(batch) 208 | 209 | for t in self.noisy_acc: 210 | _, logits, _, targets = self.shared_step(batch, t) 211 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 212 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 213 | 214 | return loss 215 | 216 | def configure_optimizers(self): 217 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 218 | 219 | if self.use_scheduler: 220 | scheduler = instantiate_from_config(self.scheduler_config) 221 | 222 | print("Setting up LambdaLR scheduler...") 223 | scheduler = [ 224 | { 225 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 226 | 'interval': 'step', 227 | 'frequency': 1 228 | }] 229 | return [optimizer], scheduler 230 | 231 | return optimizer 232 | 233 | @torch.no_grad() 234 | def log_images(self, batch, N=8, *args, **kwargs): 235 | log = dict() 236 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 237 | log['inputs'] = x 238 | 239 | y = self.get_conditioning(batch) 240 | 241 | if self.label_key == 'class_label': 242 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 243 | log['labels'] = y 244 | 245 | if ismap(y): 246 | log['labels'] = self.diffusion_model.to_rgb(y) 247 | 248 | for step in range(self.log_steps): 249 | current_time = step * self.log_time_interval 250 | 251 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 252 | 253 | log[f'inputs@t{current_time}'] = x_noisy 254 | 255 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 256 | pred = rearrange(pred, 'b h w c -> b c h w') 257 | 258 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 259 | 260 | for key in log: 261 | log[key] = log[key][:N] 262 | 263 | return log 264 | -------------------------------------------------------------------------------- 
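As a quick, standalone illustration (not from the repository) of the top-k logic used by `compute_top_k` in the classifier above: with logits of shape `(batch, classes)` and integer labels, `acc@k` is the fraction of samples whose true label appears among the k highest-scoring classes.
```python
import torch

logits = torch.tensor([[0.1, 2.0, 0.3],    # top-2 classes: 1, 2
                       [1.5, 0.2, 0.1]])   # top-2 classes: 0, 1
labels = torch.tensor([1, 2])

_, top_ks = torch.topk(logits, k=2, dim=1)
acc_at_2 = (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
print(acc_at_2)  # 0.5 -- the first sample is a hit, the second is a miss
```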
/code/dc_ldm/models/fmri_base_decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from dc_ldm.modules.diffusionmodules.util import zero_module 6 | 7 | 8 | class SwishImplementation(torch.autograd.Function): 9 | @staticmethod 10 | def forward(ctx, i): 11 | result = i * torch.sigmoid(i) 12 | ctx.save_for_backward(i) 13 | return result 14 | 15 | @staticmethod 16 | def backward(ctx, grad_output): 17 | i = ctx.saved_variables[0] 18 | sigmoid_i = torch.sigmoid(i) 19 | return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) 20 | 21 | 22 | class MemoryEfficientSwish(nn.Module): 23 | def forward(self, x): 24 | return SwishImplementation.apply(x) 25 | 26 | 27 | def norm_depth_01(depth): 28 | # Normalize to 0-1. 29 | # depth is NxHxW 30 | depth_min = depth.view(*depth.shape[:-2], -1).min(-1).values[:, None, None] 31 | depth_max = depth.view(*depth.shape[:-2], -1).max(-1).values[:, None, None] 32 | depth = (depth - depth_min) / (depth_max - depth_min) 33 | return depth 34 | 35 | 36 | class BaseDecoder(nn.Module): 37 | def __init__(self, in_dim, out_img_res, start_CHW=(64, 14, 14), n_conv_layers_ramp=3, n_chan=64, n_chan_output=3, depth_extractor=None): 38 | super(BaseDecoder, self).__init__() 39 | 40 | self.start_CHW = start_CHW 41 | upsample_scale_factor = (out_img_res / start_CHW[-1]) ** (1/n_conv_layers_ramp) 42 | self.input_fc = nn.Linear(in_dim, np.prod(self.start_CHW)) 43 | 44 | kernel_size = 5 45 | 46 | pad_size = int(kernel_size // 2) 47 | self.blocks = nn.ModuleList([nn.Sequential( 48 | nn.Upsample(scale_factor=upsample_scale_factor, mode='bicubic'), 49 | nn.ReflectionPad2d(pad_size), 50 | nn.Conv2d(start_CHW[0], n_chan, kernel_size), 51 | nn.GroupNorm(32, n_chan), 52 | MemoryEfficientSwish(), 53 | ) for block_index in range(n_conv_layers_ramp)] + \ 54 | [nn.Sequential( 55 | nn.Conv2d(start_CHW[0], n_chan, kernel_size, padding=pad_size), 56 | nn.ReLU(inplace=True), 57 | nn.BatchNorm2d(n_chan) 58 | ) for _ in range(0)]) 59 | 60 | self.top = nn.Sequential( 61 | nn.ReflectionPad2d(pad_size), 62 | nn.Conv2d(n_chan, n_chan_output, kernel_size), 63 | nn.Sigmoid() 64 | ) 65 | 66 | self.depth_extractor = depth_extractor 67 | self.trainable = [self.input_fc, self.blocks, self.top] 68 | self.fmri_dim=in_dim 69 | 70 | def forward(self, x): 71 | x = x.view(x.size(0), -1) 72 | x = x[:, :self.fmri_dim] 73 | x = self.input_fc(x) 74 | x = x.view(-1, *self.start_CHW) 75 | 76 | for block_index, block in enumerate(self.blocks): 77 | x = block(x) 78 | 79 | x = self.top(x) 80 | 81 | if self.depth_extractor: 82 | x_depth = self.depth_extractor(x) 83 | x_depth = norm_depth_01(x_depth).unsqueeze(1) 84 | x = torch.cat([x, x_depth], 1) 85 | x = F.interpolate(x, size=(512, 512)) 86 | 87 | return x 88 | 89 | 90 | class text_clip_encoder(nn.Module): 91 | def __init__(self, cond_dim=512, clip_dim=512): 92 | super(text_clip_encoder, self).__init__() 93 | inner_mlp_dim = 1024 94 | self.extend_deal = nn.Sequential(nn.Linear(cond_dim, inner_mlp_dim), 95 | nn.SiLU(), 96 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 97 | nn.SiLU(), 98 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 99 | nn.SiLU(), 100 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 101 | nn.SiLU(), 102 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 103 | nn.SiLU(), 104 | nn.Linear(inner_mlp_dim, cond_dim), 105 | nn.SiLU()) 106 | self.zero_linear = zero_module(nn.Linear(cond_dim, cond_dim)) 107 | 108 | # define clip matcher 
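        # Shape walk-through (assuming encode_c is the (B, 77, cond_dim)
        # conditioning produced by the fMRI cond stage): clip_pred_conv
        # squeezes the 77 token channels down to 4, giving (B, 4, cond_dim);
        # get_clip then flattens this to (B, 4*cond_dim) before clip_matcher_
        # maps it to a (B, clip_dim) CLIP-space feature.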
109 | self.clip_pred_conv = nn.Sequential( 110 | nn.Conv1d(77, 64, 3, padding=1, bias=True), 111 | nn.Conv1d(64, 4, 3, padding=1, bias=True)) 112 | self.clip_matcher_ = nn.Sequential(nn.Linear(cond_dim*4, inner_mlp_dim), 113 | nn.SiLU(), 114 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 115 | nn.SiLU(), 116 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 117 | nn.SiLU(), 118 | nn.Linear(inner_mlp_dim, inner_mlp_dim), 119 | nn.SiLU(), 120 | nn.Linear(inner_mlp_dim, clip_dim), 121 | nn.SiLU()) 122 | 123 | def forward(self, encode_c): 124 | return self.zero_linear(self.extend_deal(encode_c)) 125 | 126 | def get_clip(self, encode_c): 127 | out = self.extend_deal(encode_c) 128 | out = self.clip_pred_conv(out).view(out.size(0), -1) 129 | return self.clip_matcher_(out) 130 | -------------------------------------------------------------------------------- /code/dc_ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from dc_ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): # Optimize this module as well 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., cond_scale=1.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | self.cond_scale = cond_scale 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... 
-> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, cond_scale=1.): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,cond_scale=cond_scale) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout,cond_scale=cond_scale) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None, cond_scale=1.): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,cond_scale=cond_scale) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /code/dc_ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengbohan0217/CMVDM/fa61e9d1ab27c644168fec4d6b5edd135bf88a80/code/dc_ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /code/dc_ldm/modules/diffusionmodules/util.py: -------------------------------------------------------------------------------- 1 | 2 | import math 3 | import torch 4 | import torch.nn as nn 5 
| import numpy as np 6 | from einops import repeat 7 | 8 | from dc_ldm.util import instantiate_from_config 9 | 10 | 11 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 12 | if schedule == "linear": 13 | betas = ( 14 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 15 | ) 16 | 17 | elif schedule == "cosine": 18 | timesteps = ( 19 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 20 | ) 21 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 22 | alphas = torch.cos(alphas).pow(2) 23 | alphas = alphas / alphas[0] 24 | betas = 1 - alphas[1:] / alphas[:-1] 25 | betas = np.clip(betas, a_min=0, a_max=0.999) 26 | 27 | elif schedule == "sqrt_linear": 28 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 29 | elif schedule == "sqrt": 30 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 31 | else: 32 | raise ValueError(f"schedule '{schedule}' unknown.") 33 | return betas.numpy() 34 | 35 | 36 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 37 | if ddim_discr_method == 'uniform': 38 | c = num_ddpm_timesteps // num_ddim_timesteps 39 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 40 | elif ddim_discr_method == 'quad': 41 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 42 | else: 43 | raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') 44 | 45 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 46 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 47 | steps_out = ddim_timesteps + 1 48 | if verbose: 49 | print(f'Selected timesteps for ddim sampler: {steps_out}') 50 | return steps_out 51 | 52 | 53 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 54 | # select alphas for computing the variance schedule 55 | alphas = alphacums[ddim_timesteps] 56 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 57 | 58 | # according the the formula provided in https://arxiv.org/abs/2010.02502 59 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 60 | if verbose: 61 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 62 | print(f'For the chosen value of eta, which is {eta}, ' 63 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 64 | return sigmas, alphas, alphas_prev 65 | 66 | 67 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 68 | """ 69 | Create a beta schedule that discretizes the given alpha_t_bar function, 70 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 71 | :param num_diffusion_timesteps: the number of betas to produce. 72 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 73 | produces the cumulative product of (1-beta) up to that 74 | part of the diffusion process. 75 | :param max_beta: the maximum beta to use; use values lower than 1 to 76 | prevent singularities. 
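Illustrative example (an assumption, not code from this repo): the "cosine" schedule of Nichol & Dhariwal corresponds to calling betas_for_alpha_bar(1000, lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2).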
77 | """ 78 | betas = [] 79 | for i in range(num_diffusion_timesteps): 80 | t1 = i / num_diffusion_timesteps 81 | t2 = (i + 1) / num_diffusion_timesteps 82 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 83 | return np.array(betas) 84 | 85 | 86 | def extract_into_tensor(a, t, x_shape): 87 | b, *_ = t.shape 88 | out = a.gather(-1, t) 89 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 90 | 91 | 92 | def checkpoint(func, inputs, params, flag): 93 | """ 94 | Evaluate a function without caching intermediate activations, allowing for 95 | reduced memory at the expense of extra compute in the backward pass. 96 | :param func: the function to evaluate. 97 | :param inputs: the argument sequence to pass to `func`. 98 | :param params: a sequence of parameters `func` depends on but does not 99 | explicitly take as arguments. 100 | :param flag: if False, disable gradient checkpointing. 101 | """ 102 | if flag: 103 | args = tuple(inputs) + tuple(params) 104 | return CheckpointFunction.apply(func, len(inputs), *args) 105 | else: 106 | return func(*inputs) 107 | 108 | 109 | class CheckpointFunction(torch.autograd.Function): 110 | @staticmethod 111 | def forward(ctx, run_function, length, *args): 112 | ctx.run_function = run_function 113 | ctx.input_tensors = list(args[:length]) 114 | ctx.input_params = list(args[length:]) 115 | 116 | with torch.no_grad(): 117 | output_tensors = ctx.run_function(*ctx.input_tensors) 118 | return output_tensors 119 | 120 | @staticmethod 121 | def backward(ctx, *output_grads): 122 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 123 | with torch.enable_grad(): 124 | # Fixes a bug where the first op in run_function modifies the 125 | # Tensor storage in place, which is not allowed for detach()'d 126 | # Tensors. 127 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 128 | output_tensors = ctx.run_function(*shallow_copies) 129 | input_grads = torch.autograd.grad( 130 | output_tensors, 131 | ctx.input_tensors + ctx.input_params, 132 | output_grads, 133 | allow_unused=True, 134 | ) 135 | del ctx.input_tensors 136 | del ctx.input_params 137 | del output_tensors 138 | return (None, None) + input_grads 139 | 140 | 141 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 142 | """ 143 | Create sinusoidal timestep embeddings. 144 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 145 | These may be fractional. 146 | :param dim: the dimension of the output. 147 | :param max_period: controls the minimum frequency of the embeddings. 148 | :return: an [N x dim] Tensor of positional embeddings. 149 | """ 150 | if not repeat_only: 151 | half = dim // 2 152 | freqs = torch.exp( 153 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 154 | ).to(device=timesteps.device) 155 | args = timesteps[:, None].float() * freqs[None] 156 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 157 | if dim % 2: 158 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 159 | else: 160 | embedding = repeat(timesteps, 'b -> b d', d=dim) 161 | return embedding 162 | 163 | 164 | def zero_module(module): 165 | """ 166 | Zero out the parameters of a module and return it. 167 | """ 168 | for p in module.parameters(): 169 | p.detach().zero_() 170 | return module 171 | 172 | 173 | def scale_module(module, scale): 174 | """ 175 | Scale the parameters of a module and return it. 
176 | """ 177 | for p in module.parameters(): 178 | p.detach().mul_(scale) 179 | return module 180 | 181 | 182 | def mean_flat(tensor): 183 | """ 184 | Take the mean over all non-batch dimensions. 185 | """ 186 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 187 | 188 | 189 | def normalization(channels): 190 | """ 191 | Make a standard normalization layer. 192 | :param channels: number of input channels. 193 | :return: an nn.Module for normalization. 194 | """ 195 | return GroupNorm32(32, channels) 196 | 197 | 198 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 199 | class SiLU(nn.Module): 200 | def forward(self, x): 201 | return x * torch.sigmoid(x) 202 | 203 | 204 | class GroupNorm32(nn.GroupNorm): 205 | def forward(self, x): 206 | return super().forward(x.float()).type(x.dtype) 207 | 208 | def conv_nd(dims, *args, **kwargs): 209 | """ 210 | Create a 1D, 2D, or 3D convolution module. 211 | """ 212 | if dims == 1: 213 | return nn.Conv1d(*args, **kwargs) 214 | elif dims == 2: 215 | return nn.Conv2d(*args, **kwargs) 216 | elif dims == 3: 217 | return nn.Conv3d(*args, **kwargs) 218 | raise ValueError(f"unsupported dimensions: {dims}") 219 | 220 | 221 | def linear(*args, **kwargs): 222 | """ 223 | Create a linear module. 224 | """ 225 | return nn.Linear(*args, **kwargs) 226 | 227 | 228 | def avg_pool_nd(dims, *args, **kwargs): 229 | """ 230 | Create a 1D, 2D, or 3D average pooling module. 231 | """ 232 | if dims == 1: 233 | return nn.AvgPool1d(*args, **kwargs) 234 | elif dims == 2: 235 | return nn.AvgPool2d(*args, **kwargs) 236 | elif dims == 3: 237 | return nn.AvgPool3d(*args, **kwargs) 238 | raise ValueError(f"unsupported dimensions: {dims}") 239 | 240 | 241 | class HybridConditioner(nn.Module): 242 | 243 | def __init__(self, c_concat_config, c_crossattn_config): 244 | super().__init__() 245 | self.concat_conditioner = instantiate_from_config(c_concat_config) 246 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 247 | 248 | def forward(self, c_concat, c_crossattn): 249 | c_concat = self.concat_conditioner(c_concat) 250 | c_crossattn = self.crossattn_conditioner(c_crossattn) 251 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 252 | 253 | 254 | def noise_like(shape, device, repeat=False): 255 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 256 | noise = lambda: torch.randn(shape, device=device) 257 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /code/dc_ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengbohan0217/CMVDM/fa61e9d1ab27c644168fec4d6b5edd135bf88a80/code/dc_ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /code/dc_ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def 
__init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
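# Written out, this is the standard closed form that the return expression below evaluates element-wise:
#   KL(N(mean1, exp(logvar1)) || N(mean2, exp(logvar2)))
#     = 0.5 * (logvar2 - logvar1 + exp(logvar1 - logvar2) + (mean1 - mean2)**2 * exp(-logvar2) - 1)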
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /code/dc_ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
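Illustrative call pattern (an assumed sketch, not taken from this repo's training loop):
    ema.store(model.parameters()); ema.copy_to(model)
    ... run validation or save checkpoints with the EMA weights ...
    ema.restore(model.parameters())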
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /code/dc_ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zengbohan0217/CMVDM/fa61e9d1ab27c644168fec4d6b5edd135bf88a80/code/dc_ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /code/dc_ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | 5 | from dc_ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 6 | 7 | 8 | class AbstractEncoder(nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | 12 | def encode(self, *args, **kwargs): 13 | raise NotImplementedError 14 | 15 | 16 | 17 | class ClassEmbedder(nn.Module): 18 | def __init__(self, embed_dim, n_classes=1000, key='class'): 19 | super().__init__() 20 | self.key = key 21 | self.embedding = nn.Embedding(n_classes, embed_dim) 22 | 23 | def forward(self, batch, key=None): 24 | if key is None: 25 | key = self.key 26 | # this is for use in crossattn 27 | c = batch[key][:, None] 28 | c = self.embedding(c) 29 | return c 30 | 31 | 32 | class TransformerEmbedder(AbstractEncoder): 33 | """Some transformer encoder layers""" 34 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 35 | super().__init__() 36 | self.device = device 37 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 38 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 39 | 40 | def forward(self, tokens): 41 | tokens = tokens.to(self.device) # meh 42 | z = self.transformer(tokens, return_embeddings=True) 43 | return z 44 | 45 | def encode(self, x): 46 | return self(x) 47 | 48 | 49 | class BERTTokenizer(AbstractEncoder): 50 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 51 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 52 | super().__init__() 53 | from transformers import BertTokenizerFast # TODO: add to reuquirements 54 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 55 | self.device = device 56 | self.vq_interface = vq_interface 57 | self.max_length = max_length 58 | 59 | def forward(self, text): 60 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 61 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 62 | tokens = batch_encoding["input_ids"].to(self.device) 63 | return tokens 64 | 65 | @torch.no_grad() 66 | def encode(self, text): 67 | tokens = self(text) 68 | if not self.vq_interface: 69 | return tokens 70 | return None, None, [None, None, tokens] 71 | 72 | def decode(self, text): 73 | return text 74 | 75 | 76 | class BERTEmbedder(AbstractEncoder): 77 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 78 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 79 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 80 | super().__init__() 81 | self.use_tknz_fn = use_tokenizer 82 | if self.use_tknz_fn: 83 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 84 | self.device = device 85 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 86 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 87 | emb_dropout=embedding_dropout) 88 | 89 | def forward(self, text): 90 | if self.use_tknz_fn: 91 | tokens = self.tknz_fn(text)#.to(self.device) 92 | else: 93 | tokens = text 94 | z = self.transformer(tokens, return_embeddings=True) 95 | return z 96 | 97 | def encode(self, text): 98 | # output of length 77 99 | return self(text) 100 | 101 | 102 | class SpatialRescaler(nn.Module): 103 | def __init__(self, 104 | n_stages=1, 105 | method='bilinear', 106 | multiplier=0.5, 107 | in_channels=3, 108 | out_channels=None, 109 | bias=False): 110 | super().__init__() 111 | self.n_stages = n_stages 112 | assert self.n_stages >= 0 113 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 114 | self.multiplier = multiplier 115 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 116 | self.remap_output = out_channels is not None 117 | if self.remap_output: 118 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 119 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 120 | 121 | def forward(self,x): 122 | for stage in range(self.n_stages): 123 | x = self.interpolator(x, scale_factor=self.multiplier) 124 | 125 | 126 | if self.remap_output: 127 | x = self.channel_mapper(x) 128 | return x 129 | 130 | def encode(self, x): 131 | return self(x) 132 | -------------------------------------------------------------------------------- /code/dc_ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from dc_ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /code/dc_ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 
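# (The star import above is what provides LPIPS, NLayerDiscriminator, weights_init, hinge_d_loss,
#  vanilla_d_loss and adopt_weight used further down in this file; the commented-out line below
#  points at the repo's local vqperceptual.py as an alternative source for the same names.)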
5 | # from vqperceptual import * # replace taming dependency to local vqperceptual.py 6 | 7 | 8 | class LPIPSWithDiscriminator(nn.Module): 9 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 10 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 11 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 12 | disc_loss="hinge"): 13 | 14 | super().__init__() 15 | assert disc_loss in ["hinge", "vanilla"] 16 | self.kl_weight = kl_weight 17 | self.pixel_weight = pixelloss_weight 18 | self.perceptual_loss = LPIPS().eval() 19 | self.perceptual_weight = perceptual_weight 20 | # output log variance 21 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 22 | 23 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 24 | n_layers=disc_num_layers, 25 | use_actnorm=use_actnorm 26 | ).apply(weights_init) 27 | self.discriminator_iter_start = disc_start 28 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 29 | self.disc_factor = disc_factor 30 | self.discriminator_weight = disc_weight 31 | self.disc_conditional = disc_conditional 32 | 33 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 34 | if last_layer is not None: 35 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 36 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 37 | else: 38 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 39 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 40 | 41 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 42 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 43 | d_weight = d_weight * self.discriminator_weight 44 | return d_weight 45 | 46 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 47 | global_step, last_layer=None, cond=None, split="train", 48 | weights=None): 49 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 50 | if self.perceptual_weight > 0: 51 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 52 | rec_loss = rec_loss + self.perceptual_weight * p_loss 53 | 54 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 55 | weighted_nll_loss = nll_loss 56 | if weights is not None: 57 | weighted_nll_loss = weights*nll_loss 58 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 59 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 60 | kl_loss = posteriors.kl() 61 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 62 | 63 | # now the GAN part 64 | if optimizer_idx == 0: 65 | # generator update 66 | if cond is None: 67 | assert not self.disc_conditional 68 | logits_fake = self.discriminator(reconstructions.contiguous()) 69 | else: 70 | assert self.disc_conditional 71 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 72 | g_loss = -torch.mean(logits_fake) 73 | 74 | if self.disc_factor > 0.0: 75 | try: 76 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 77 | except RuntimeError: 78 | assert not self.training 79 | d_weight = torch.tensor(0.0) 80 | else: 81 | d_weight = torch.tensor(0.0) 82 | 83 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 84 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 85 | 86 | log = {"{}/total_loss".format(split): 
loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 87 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 88 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 89 | "{}/d_weight".format(split): d_weight.detach(), 90 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 91 | "{}/g_loss".format(split): g_loss.detach().mean(), 92 | } 93 | return loss, log 94 | 95 | if optimizer_idx == 1: 96 | # second pass for discriminator update 97 | if cond is None: 98 | logits_real = self.discriminator(inputs.contiguous().detach()) 99 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 100 | else: 101 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 102 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 103 | 104 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 105 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 106 | 107 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 108 | "{}/logits_real".format(split): logits_real.detach().mean(), 109 | "{}/logits_fake".format(split): logits_fake.detach().mean() 110 | } 111 | return d_loss, log 112 | 113 | -------------------------------------------------------------------------------- /code/dc_ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | from dc_ldm.util import exists 11 | 12 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 13 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 14 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 15 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 16 | loss_real = (weights * loss_real).sum() / weights.sum() 17 | loss_fake = (weights * loss_fake).sum() / weights.sum() 18 | d_loss = 0.5 * (loss_real + loss_fake) 19 | return d_loss 20 | 21 | def adopt_weight(weight, global_step, threshold=0, value=0.): 22 | if global_step < threshold: 23 | weight = value 24 | return weight 25 | 26 | 27 | def measure_perplexity(predicted_indices, n_embed): 28 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 29 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 30 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 31 | avg_probs = encodings.mean(0) 32 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 33 | cluster_use = torch.sum(avg_probs > 0) 34 | return perplexity, cluster_use 35 | 36 | def l1(x, y): 37 | return torch.abs(x-y) 38 | 39 | 40 | def l2(x, y): 41 | return torch.pow((x-y), 2) 42 | 43 | 44 | class VQLPIPSWithDiscriminator(nn.Module): 45 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 46 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 47 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 48 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 49 | pixel_loss="l1"): 50 | super().__init__() 51 | assert disc_loss in ["hinge", "vanilla"] 52 | assert perceptual_loss in ["lpips", "clips", "dists"] 53 | assert pixel_loss in ["l1", "l2"] 54 | self.codebook_weight = codebook_weight 55 | self.pixel_weight = pixelloss_weight 56 | if perceptual_loss == "lpips": 57 | print(f"{self.__class__.__name__}: Running with LPIPS.") 58 | self.perceptual_loss = LPIPS().eval() 59 | else: 60 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 61 | self.perceptual_weight = perceptual_weight 62 | 63 | if pixel_loss == "l1": 64 | self.pixel_loss = l1 65 | else: 66 | self.pixel_loss = l2 67 | 68 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 69 | n_layers=disc_num_layers, 70 | use_actnorm=use_actnorm, 71 | ndf=disc_ndf 72 | ).apply(weights_init) 73 | self.discriminator_iter_start = disc_start 74 | if disc_loss == "hinge": 75 | self.disc_loss = hinge_d_loss 76 | elif disc_loss == "vanilla": 77 | self.disc_loss = vanilla_d_loss 78 | else: 79 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 80 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 81 | self.disc_factor = disc_factor 82 | self.discriminator_weight = disc_weight 83 | self.disc_conditional = disc_conditional 84 | self.n_classes = n_classes 85 | 86 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 87 | if last_layer is not None: 88 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 89 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 90 | else: 91 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 92 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 93 | 94 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 95 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 96 | d_weight = d_weight * self.discriminator_weight 97 | return d_weight 98 | 99 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 100 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 101 | if not exists(codebook_loss): 102 | codebook_loss = torch.tensor([0.]).to(inputs.device) 103 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 104 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 105 | pixel_loss = rec_loss 106 | if self.perceptual_weight > 0: 107 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 108 | rec_loss = rec_loss + self.perceptual_weight * p_loss 109 | else: 110 | p_loss = torch.tensor([0.0]) 111 | 112 | nll_loss = rec_loss 113 | #nll_loss = 
torch.sum(nll_loss) / nll_loss.shape[0] 114 | nll_loss = torch.mean(nll_loss) 115 | 116 | if optimizer_idx == 2: 117 | log = { 118 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 119 | "{}/rec_loss".format(split): pixel_loss.detach().mean(), 120 | "{}/p_loss".format(split): p_loss.detach().mean() 121 | } 122 | return nll_loss, log 123 | # now the GAN part 124 | if optimizer_idx == 0: 125 | # generator update 126 | if cond is None: 127 | assert not self.disc_conditional 128 | logits_fake = self.discriminator(reconstructions.contiguous()) 129 | else: 130 | assert self.disc_conditional 131 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 132 | g_loss = -torch.mean(logits_fake) 133 | 134 | try: 135 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 136 | except RuntimeError: 137 | assert not self.training 138 | d_weight = torch.tensor(0.0) 139 | 140 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 141 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 142 | 143 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 144 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 145 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 146 | "{}/rec_loss".format(split): pixel_loss.detach().mean(), 147 | "{}/p_loss".format(split): p_loss.detach().mean(), 148 | "{}/d_weight".format(split): d_weight.detach(), 149 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 150 | "{}/g_loss".format(split): g_loss.detach().mean(), 151 | } 152 | if predicted_indices is not None: 153 | assert self.n_classes is not None 154 | with torch.no_grad(): 155 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 156 | log[f"{split}/perplexity"] = perplexity 157 | log[f"{split}/cluster_usage"] = cluster_usage 158 | return loss, log 159 | 160 | if optimizer_idx == 1: 161 | # second pass for discriminator update 162 | if cond is None: 163 | logits_real = self.discriminator(inputs.contiguous().detach()) 164 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 165 | else: 166 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 167 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 168 | 169 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 170 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 171 | 172 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 173 | "{}/logits_real".format(split): logits_real.detach().mean(), 174 | "{}/logits_fake".format(split): logits_fake.detach().mean() 175 | } 176 | return d_loss, log 177 | -------------------------------------------------------------------------------- /code/dc_ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | 6 | from inspect import isfunction 7 | from PIL import Image, ImageDraw, ImageFont 8 | 9 | 10 | def log_txt_as_img(wh, xc, size=10): 11 | # wh a tuple of (width, height) 12 | # xc a list of captions to plot 13 | b = len(xc) 14 | txts = list() 15 | for bi in range(b): 16 | txt = Image.new("RGB", wh, color="white") 17 | draw = ImageDraw.Draw(txt) 18 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 19 | nc = int(40 * 
(wh[0] / 256)) 20 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 21 | 22 | try: 23 | draw.text((0, 0), lines, fill="black", font=font) 24 | except UnicodeEncodeError: 25 | print("Cant encode string for logging. Skipping.") 26 | 27 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 28 | txts.append(txt) 29 | txts = np.stack(txts) 30 | txts = torch.tensor(txts) 31 | return txts 32 | 33 | 34 | def ismap(x): 35 | if not isinstance(x, torch.Tensor): 36 | return False 37 | return (len(x.shape) == 4) and (x.shape[1] > 3) 38 | 39 | 40 | def isimage(x): 41 | if not isinstance(x,torch.Tensor): 42 | return False 43 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 44 | 45 | 46 | def exists(x): 47 | return x is not None 48 | 49 | 50 | def default(val, d): 51 | if exists(val): 52 | return val 53 | return d() if isfunction(d) else d 54 | 55 | 56 | def mean_flat(tensor): 57 | """ 58 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 59 | Take the mean over all non-batch dimensions. 60 | """ 61 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 62 | 63 | 64 | def count_params(model, verbose=False): 65 | total_params = sum(p.numel() for p in model.parameters()) 66 | if verbose: 67 | print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") 68 | return total_params 69 | 70 | 71 | def instantiate_from_config(config): 72 | if not "target" in config: 73 | if config == '__is_first_stage__': 74 | return None 75 | elif config == "__is_unconditional__": 76 | return None 77 | raise KeyError("Expected key `target` to instantiate.") 78 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 79 | 80 | 81 | def get_obj_from_str(string, reload=False): 82 | module, cls = string.rsplit(".", 1) 83 | if reload: 84 | module_imp = importlib.import_module(module) 85 | importlib.reload(module_imp) 86 | return getattr(importlib.import_module(module, package=None), cls) -------------------------------------------------------------------------------- /code/eval_metrics.py: -------------------------------------------------------------------------------- 1 | from os import get_inheritable 2 | import numpy as np 3 | from skimage.metrics import structural_similarity as ssim 4 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 5 | from torchmetrics.image.fid import FrechetInceptionDistance 6 | from torchvision.models import ViT_H_14_Weights, vit_h_14 7 | import torch 8 | from einops import rearrange 9 | from torchmetrics.functional import accuracy 10 | from PIL import Image 11 | import torch.nn.functional as F 12 | 13 | def larger_the_better(gt, comp): 14 | return gt > comp 15 | 16 | def smaller_the_better(gt, comp): 17 | return gt < comp 18 | 19 | def mse_metric(img1, img2): 20 | return (np.square(img1 - img2)).mean() 21 | 22 | def pcc_metric(img1, img2): 23 | return np.corrcoef(img1.reshape(-1), img2.reshape(-1))[0, 1] 24 | 25 | def ssim_metric(img1, img2): 26 | return ssim(img1, img2, data_range=255, channel_axis=-1) 27 | 28 | def identity(x): 29 | return x 30 | 31 | class psm_wrapper: 32 | def __init__(self): 33 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 34 | self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='alex').to(self.device) 35 | 36 | @torch.no_grad() 37 | def __call__(self, img1, img2): 38 | if img1.shape[-1] == 3: 39 | img1 = rearrange(img1, 'w h c -> c w h') 40 | img2 = rearrange(img2, 'w h c -> 
c w h') 41 | img1 = img1 / 127.5 - 1.0 42 | img2 = img2 / 127.5 - 1.0 43 | img1 = np.expand_dims(img1, axis=0) 44 | img2 = np.expand_dims(img2, axis=0) 45 | return self.lpips(torch.FloatTensor(img1).to(self.device), torch.FloatTensor(img2).to(self.device)).item() 46 | 47 | class fid_wrapper: 48 | def __init__(self): 49 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 50 | self.fid = FrechetInceptionDistance(feature=64) 51 | 52 | @torch.no_grad() 53 | def __call__(self, pred_imgs, gt_imgs): 54 | self.fid.reset() 55 | self.fid.update(torch.tensor(rearrange(gt_imgs, 'n w h c -> n c w h')), real=True) 56 | self.fid.update(torch.tensor(rearrange(pred_imgs, 'n w h c -> n c w h')), real=False) 57 | return self.fid.compute().item() 58 | 59 | def pair_wise_score(pred_imgs, gt_imgs, metric, is_sucess): 60 | # pred_imgs: n, w, h, 3 61 | # gt_imgs: n, w, h, 3 62 | # all in pixel values: 0 ~ 255 63 | # return: list of scores 0 ~ 1. 64 | assert len(pred_imgs) == len(gt_imgs) 65 | assert np.min(pred_imgs) >= 0 and np.min(gt_imgs) >= 0 66 | # assert isinstance(metric, fid_wrapper) == False, 'FID not supported' 67 | corrects = [] 68 | for idx, pred in enumerate(pred_imgs): 69 | gt = gt_imgs[idx] 70 | gt_score = metric(pred, gt) 71 | rest = [img for i, img in enumerate(gt_imgs) if i != idx] 72 | count = 0 73 | for comp in rest: 74 | comp_score = metric(pred, comp) 75 | if is_sucess(gt_score, comp_score): 76 | count += 1 77 | corrects.append(count / len(rest)) 78 | return corrects 79 | 80 | def n_way_scores(pred_imgs, gt_imgs, metric, is_sucess, n=2, n_trials=100): 81 | # pred_imgs: n, w, h, 3 82 | # gt_imgs: n, w, h, 3 83 | # all in pixel values: 0 ~ 255 84 | # return: list of scores 0 ~ 1. 85 | assert len(pred_imgs) == len(gt_imgs) 86 | assert n <= len(pred_imgs) and n >= 2 87 | assert np.min(pred_imgs) >= 0 and np.min(gt_imgs) >= 0 88 | assert isinstance(metric, fid_wrapper) == False, 'FID not supported' 89 | corrects = [] 90 | for idx, pred in enumerate(pred_imgs): 91 | gt = gt_imgs[idx] 92 | gt_score = metric(pred, gt) 93 | rest = np.stack([img for i, img in enumerate(gt_imgs) if i != idx]) 94 | correct_count = 0 95 | for _ in range(n_trials): 96 | n_imgs_idx = np.random.choice(len(rest), n-1, replace=False) 97 | n_imgs = rest[n_imgs_idx] 98 | count = 0 99 | for comp in n_imgs: 100 | comp_score = metric(pred, comp) 101 | if is_sucess(gt_score, comp_score): 102 | count += 1 103 | if count == len(n_imgs): 104 | correct_count += 1 105 | corrects.append(correct_count / n_trials) 106 | return corrects 107 | 108 | def metrics_only(pred_imgs, gt_imgs, metric, *args, **kwargs): 109 | assert np.min(pred_imgs) >= 0 and np.min(gt_imgs) >= 0 110 | 111 | return metric(pred_imgs, gt_imgs) 112 | 113 | @torch.no_grad() 114 | def n_way_top_k_acc(pred, class_id, n_way, num_trials=40, top_k=1): 115 | pick_range =[i for i in np.arange(len(pred)) if i != class_id] 116 | acc_list = [] 117 | for t in range(num_trials): 118 | idxs_picked = np.random.choice(pick_range, n_way-1, replace=False) 119 | pred_picked = torch.cat([pred[class_id].unsqueeze(0), pred[idxs_picked]]) 120 | acc = accuracy(pred_picked.unsqueeze(0), torch.tensor([0], device=pred.device), top_k=top_k, task='multiclass', 121 | num_classes=len(pred_picked)) 122 | acc_list.append(acc.item()) 123 | return np.mean(acc_list), np.std(acc_list) 124 | 125 | @torch.no_grad() 126 | def get_n_way_top_k_acc(pred_imgs, ground_truth, n_way, num_trials, top_k, device, return_std=False): 127 | weights = ViT_H_14_Weights.DEFAULT 128 | model = 
vit_h_14(weights=weights) 129 | preprocess = weights.transforms() 130 | model = model.to(device) 131 | model = model.eval() 132 | 133 | acc_list = [] 134 | std_list = [] 135 | for pred, gt in zip(pred_imgs, ground_truth): 136 | pred = preprocess(Image.fromarray(pred.astype(np.uint8))).unsqueeze(0).to(device) 137 | gt = preprocess(Image.fromarray(gt.astype(np.uint8))).unsqueeze(0).to(device) 138 | gt_class_id = model(gt).squeeze(0).softmax(0).argmax().item() 139 | # gt_out = model(gt).squeeze(0).softmax(0).detach() 140 | pred_out = model(pred).squeeze(0).softmax(0).detach() 141 | # acc = F.cosine_similarity(pred_out, gt_out, dim=0).mean().item() 142 | # std = 0.01 143 | acc, std = n_way_top_k_acc(pred_out, gt_class_id, n_way, num_trials, top_k) 144 | acc_list.append(acc) 145 | std_list.append(std) 146 | 147 | if return_std: 148 | return acc_list, std_list 149 | return acc_list 150 | 151 | def get_similarity_metric(img1, img2, method='pair-wise', metric_name='mse', **kwargs): 152 | # img1: n, w, h, 3 153 | # img2: n, w, h, 3 154 | # all in pixel values: 0 ~ 255 155 | # return: list of scores 0 ~ 1. 156 | if img1.shape[-1] != 3: 157 | img1 = rearrange(img1, 'n c w h -> n w h c') 158 | if img2.shape[-1] != 3: 159 | img2 = rearrange(img2, 'n c w h -> n w h c') 160 | 161 | if method == 'pair-wise': 162 | eval_procedure_func = pair_wise_score 163 | elif method == 'n-way': 164 | eval_procedure_func = n_way_scores 165 | elif method == 'metrics-only': 166 | eval_procedure_func = metrics_only 167 | elif method == 'class': 168 | return get_n_way_top_k_acc(img1, img2, **kwargs) 169 | else: 170 | raise NotImplementedError 171 | 172 | if metric_name == 'mse': 173 | metric_func = mse_metric 174 | decision_func = smaller_the_better 175 | elif metric_name == 'pcc': 176 | metric_func = pcc_metric 177 | decision_func = larger_the_better 178 | elif metric_name == 'ssim': 179 | metric_func = ssim_metric 180 | decision_func = larger_the_better 181 | elif metric_name == 'psm': 182 | metric_func = psm_wrapper() 183 | decision_func = smaller_the_better 184 | elif metric_name == 'fid': 185 | metric_func = fid_wrapper() 186 | decision_func = smaller_the_better 187 | else: 188 | raise NotImplementedError 189 | 190 | return eval_procedure_func(img1, img2, metric_func, decision_func, **kwargs) 191 | -------------------------------------------------------------------------------- /code/parallel_datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_dataset_distributed_(_dataset, world_size, rank, batch_size, **kwargs): 5 | 6 | sampler = torch.utils.data.distributed.DistributedSampler( 7 | _dataset, 8 | num_replicas=world_size, 9 | rank=rank, 10 | ) 11 | dataloader = torch.utils.data.DataLoader( 12 | _dataset, 13 | sampler=sampler, 14 | batch_size=batch_size, 15 | shuffle=False, 16 | drop_last=True, 17 | pin_memory=False, 18 | num_workers=16, 19 | persistent_workers=True, 20 | ) 21 | 22 | return dataloader, 3 23 | -------------------------------------------------------------------------------- /code/sc_mbm/trainer.py: -------------------------------------------------------------------------------- 1 | import math, sys 2 | import torch 3 | import sc_mbm.utils as ut 4 | from torch._six import inf 5 | import numpy as np 6 | import time 7 | 8 | class NativeScalerWithGradNormCount: 9 | state_dict_key = "amp_scaler" 10 | 11 | def __init__(self): 12 | self._scaler = torch.cuda.amp.GradScaler() 13 | 14 | def __call__(self, loss, optimizer, clip_grad=None, 
parameters=None, create_graph=False, update_grad=True): 15 | self._scaler.scale(loss).backward(create_graph=create_graph) 16 | if update_grad: 17 | if clip_grad is not None: 18 | assert parameters is not None 19 | self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place 20 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 21 | else: 22 | self._scaler.unscale_(optimizer) 23 | norm = get_grad_norm_(parameters) 24 | self._scaler.step(optimizer) 25 | self._scaler.update() 26 | else: 27 | norm = None 28 | return norm 29 | 30 | def state_dict(self): 31 | return self._scaler.state_dict() 32 | 33 | def load_state_dict(self, state_dict): 34 | self._scaler.load_state_dict(state_dict) 35 | 36 | 37 | def get_grad_norm_(parameters, norm_type: float = 2.0): 38 | if isinstance(parameters, torch.Tensor): 39 | parameters = [parameters] 40 | parameters = [p for p in parameters if p.grad is not None] 41 | norm_type = float(norm_type) 42 | if len(parameters) == 0: 43 | return torch.tensor(0.) 44 | device = parameters[0].grad.device 45 | if norm_type == inf: 46 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 47 | else: 48 | total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type) 49 | return total_norm 50 | 51 | 52 | def train_one_epoch(model, data_loader, optimizer, device, epoch, 53 | loss_scaler,log_writer=None, config=None, start_time=None, model_without_ddp=None, 54 | img_feature_extractor=None, preprocess=None): 55 | model.train(True) 56 | optimizer.zero_grad() 57 | total_loss = [] 58 | total_cor = [] 59 | accum_iter = config.accum_iter 60 | for data_iter_step, (data_dcit) in enumerate(data_loader): 61 | 62 | # we use a per iteration (instead of per epoch) lr scheduler 63 | if data_iter_step % accum_iter == 0: 64 | ut.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, config) 65 | samples = data_dcit['fmri'] 66 | 67 | img_features = None 68 | valid_idx = None 69 | if img_feature_extractor is not None: 70 | images = data_dcit['image'] 71 | valid_idx = torch.nonzero(images.sum(dim=(1,2,3)) != 0).squeeze(1) 72 | img_feature_extractor.eval() 73 | with torch.no_grad(): 74 | img_features = img_feature_extractor(preprocess(images[valid_idx]).to(device))['layer2'] 75 | samples = samples.to(device) 76 | # img_features = img_features.to(device) 77 | 78 | optimizer.zero_grad() 79 | with torch.cuda.amp.autocast(enabled=True): 80 | loss, pred, _ = model(samples, img_features, valid_idx=valid_idx, mask_ratio=config.mask_ratio) 81 | # loss.backward() 82 | # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.clip_grad) 83 | # optimizer.step() 84 | 85 | loss_value = loss.item() 86 | 87 | if not math.isfinite(loss_value): 88 | print(f"Loss is {loss_value}, stopping training at step {data_iter_step} epoch {epoch}") 89 | sys.exit(1) 90 | 91 | # loss /= accum_iter 92 | loss_scaler(loss, optimizer, parameters=model.parameters(), clip_grad=config.clip_grad) 93 | 94 | # if (data_iter_step + 1) % accum_iter == 0: 95 | # cal the cor 96 | pred = pred.to('cpu').detach() 97 | samples = samples.to('cpu').detach() 98 | pred = model_without_ddp.unpatchify(pred) 99 | cor = torch.mean(torch.tensor([torch.corrcoef(torch.cat([p, s],axis=0))[0,1] for p, s in zip(pred, samples)])).item() 100 | optimizer.zero_grad() 101 | 102 | total_loss.append(loss_value) 103 | total_cor.append(cor) 104 | 105 | if log_writer is not None: 106 | lr = 
optimizer.param_groups[0]["lr"] 107 | log_writer.log('train_loss_step', np.mean(total_loss), step=epoch) 108 | log_writer.log('lr', lr, step=epoch) 109 | log_writer.log('cor', np.mean(total_cor), step=epoch) 110 | if start_time is not None: 111 | log_writer.log('time (min)', (time.time() - start_time)/60.0, step=epoch) 112 | if config.local_rank == 0: 113 | print(f'[Epoch {epoch}] loss: {np.mean(total_loss)}') 114 | 115 | return np.mean(total_cor) -------------------------------------------------------------------------------- /code/sc_mbm/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import math 3 | import torch 4 | import os 5 | 6 | def get_1d_sincos_pos_embed(embed_dim, length, cls_token=False): 7 | """ 8 | grid_size: int of the grid height and width 9 | return: 10 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 11 | """ 12 | grid_l = np.arange(length, dtype=np.float32) 13 | 14 | grid_l = grid_l.reshape([1, length]) 15 | pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid_l) 16 | if cls_token: 17 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 18 | return pos_embed 19 | 20 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 21 | """ 22 | embed_dim: output dimension for each position 23 | pos: a list of positions to be encoded: size (M,) 24 | out: (M, D) 25 | """ 26 | assert embed_dim % 2 == 0 27 | omega = np.arange(embed_dim // 2, dtype=float) 28 | omega /= embed_dim / 2. 29 | omega = 1. / 10000**omega # (D/2,) 30 | 31 | pos = pos.reshape(-1) # (M,) 32 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 33 | 34 | emb_sin = np.sin(out) # (M, D/2) 35 | emb_cos = np.cos(out) # (M, D/2) 36 | 37 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 38 | return emb 39 | 40 | 41 | # -------------------------------------------------------- 42 | # Interpolate position embeddings for high-resolution 43 | # References: 44 | # DeiT: https://github.com/facebookresearch/deit 45 | # -------------------------------------------------------- 46 | def interpolate_pos_embed(model, checkpoint_model): 47 | if 'pos_embed' in checkpoint_model: 48 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 49 | embedding_size = pos_embed_checkpoint.shape[-1] 50 | num_patches = model.patch_embed.num_patches 51 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches # cls token 52 | # height (== width) for the checkpoint position embedding 53 | orig_size = int(pos_embed_checkpoint.shape[-2] - num_extra_tokens) 54 | # height (== width) for the new position embedding 55 | new_size = int(num_patches) 56 | # class_token and dist_token are kept unchanged 57 | if orig_size != new_size: 58 | print("Position interpolate from %d to %d" % (orig_size, new_size)) 59 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 60 | # only the position tokens are interpolated 61 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 62 | pos_tokens = pos_tokens.reshape(-1, orig_size, embedding_size).permute(0, 2, 1) 63 | pos_tokens = torch.nn.functional.interpolate( 64 | pos_tokens, size=(new_size)) 65 | pos_tokens = pos_tokens.permute(0, 2, 1) 66 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 67 | checkpoint_model['pos_embed'] = new_pos_embed 68 | 69 | 70 | 71 | def adjust_learning_rate(optimizer, epoch, config): 72 | """Decay the learning rate with half-cycle cosine after warmup""" 73 | if epoch < 
config.warmup_epochs: 74 | lr = config.lr * epoch / config.warmup_epochs 75 | else: 76 | lr = config.min_lr + (config.lr - config.min_lr) * 0.5 * \ 77 | (1. + math.cos(math.pi * (epoch - config.warmup_epochs) / (config.num_epoch - config.warmup_epochs))) 78 | for param_group in optimizer.param_groups: 79 | if "lr_scale" in param_group: 80 | param_group["lr"] = lr * param_group["lr_scale"] 81 | else: 82 | param_group["lr"] = lr 83 | return lr 84 | 85 | 86 | def save_model(config, epoch, model, optimizer, loss_scaler, checkpoint_paths): 87 | os.makedirs(checkpoint_paths, exist_ok=True) 88 | to_save = { 89 | 'model': model.state_dict(), 90 | 'optimizer': optimizer.state_dict(), 91 | 'epoch': epoch, 92 | 'scaler': loss_scaler.state_dict(), 93 | 'config': config, 94 | } 95 | torch.save(to_save, os.path.join(checkpoint_paths, 'checkpoint.pth')) 96 | 97 | 98 | def load_model(config, model, checkpoint_path ): 99 | checkpoint = torch.load(checkpoint_path, map_location='cpu') 100 | model.load_state_dict(checkpoint['model']) 101 | print(f'Model loaded with {checkpoint_path}') 102 | 103 | def patchify(imgs, patch_size): 104 | """ 105 | imgs: (N, 1, num_voxels) 106 | x: (N, L, patch_size) 107 | """ 108 | p = patch_size 109 | assert imgs.ndim == 3 and imgs.shape[2] % p == 0 110 | 111 | h = imgs.shape[2] // p 112 | x = imgs.reshape(shape=(imgs.shape[0], h, p)) 113 | return x 114 | 115 | def unpatchify(x, patch_size): 116 | """ 117 | x: (N, L, patch_size) 118 | imgs: (N, 1, num_voxels) 119 | """ 120 | p = patch_size 121 | h = x.shape[1] 122 | 123 | imgs = x.reshape(shape=(x.shape[0], 1, h * p)) 124 | return imgs -------------------------------------------------------------------------------- /code/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='CMVDM', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | 'timm' 13 | ], 14 | ) -------------------------------------------------------------------------------- /code/stageA1_mbm_pretrain.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torch.nn.parallel import DistributedDataParallel 6 | import argparse 7 | import time 8 | import timm.optim.optim_factory as optim_factory 9 | import datetime 10 | import matplotlib.pyplot as plt 11 | import wandb 12 | import copy 13 | 14 | from config import Config_MBM_fMRI 15 | from dataset import hcp_dataset 16 | from sc_mbm.mae_for_fmri import MAEforFMRI 17 | from sc_mbm.trainer import train_one_epoch 18 | from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler 19 | from sc_mbm.utils import save_model 20 | 21 | os.environ["WANDB_START_METHOD"] = "thread" 22 | os.environ['WANDB_DIR'] = "." 
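# (As far as wandb's documented env vars go: WANDB_START_METHOD="thread" asks wandb to start its
#  backend in a thread rather than a separate process, and WANDB_DIR controls where the local
#  wandb/ run directory is written.)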
23 | 24 | class wandb_logger: 25 | def __init__(self, config): 26 | wandb.init( 27 | project="CMVDM", 28 | anonymous="allow", 29 | group='stageA_sc-mbm', 30 | config=config, 31 | reinit=True) 32 | 33 | self.config = config 34 | self.step = None 35 | 36 | def log(self, name, data, step=None): 37 | if step is None: 38 | wandb.log({name: data}) 39 | else: 40 | wandb.log({name: data}, step=step) 41 | self.step = step 42 | 43 | def watch_model(self, *args, **kwargs): 44 | wandb.watch(*args, **kwargs) 45 | 46 | def log_image(self, name, fig): 47 | if self.step is None: 48 | wandb.log({name: wandb.Image(fig)}) 49 | else: 50 | wandb.log({name: wandb.Image(fig)}, step=self.step) 51 | 52 | def finish(self): 53 | wandb.finish(quiet=True) 54 | 55 | def get_args_parser(): 56 | parser = argparse.ArgumentParser('MBM pre-training for fMRI', add_help=False) 57 | 58 | # Training Parameters 59 | parser.add_argument('--lr', type=float) 60 | parser.add_argument('--weight_decay', type=float) 61 | parser.add_argument('--num_epoch', type=int) 62 | parser.add_argument('--batch_size', type=int) 63 | 64 | # Model Parameters 65 | parser.add_argument('--mask_ratio', type=float) 66 | parser.add_argument('--patch_size', type=int) 67 | parser.add_argument('--embed_dim', type=int) 68 | parser.add_argument('--decoder_embed_dim', type=int) 69 | parser.add_argument('--depth', type=int) 70 | parser.add_argument('--num_heads', type=int) 71 | parser.add_argument('--decoder_num_heads', type=int) 72 | parser.add_argument('--mlp_ratio', type=float) 73 | 74 | # Project setting 75 | parser.add_argument('--root_path', type=str) 76 | parser.add_argument('--seed', type=str) 77 | parser.add_argument('--roi', type=str) 78 | parser.add_argument('--aug_times', type=int) 79 | parser.add_argument('--num_sub_limit', type=int) 80 | 81 | parser.add_argument('--include_hcp', type=bool) 82 | parser.add_argument('--include_kam', type=bool) 83 | 84 | parser.add_argument('--use_nature_img_loss', type=bool) 85 | parser.add_argument('--img_recon_weight', type=float) 86 | 87 | # distributed training parameters 88 | parser.add_argument('--local_rank', type=int) 89 | 90 | return parser 91 | 92 | def create_readme(config, path): 93 | print(config.__dict__) 94 | with open(os.path.join(path, 'README.md'), 'w+') as f: 95 | print(config.__dict__, file=f) 96 | 97 | def fmri_transform(x, sparse_rate=0.2): 98 | # x: 1, num_voxels 99 | x_aug = copy.deepcopy(x) 100 | idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False) 101 | x_aug[idx] = 0 102 | return torch.FloatTensor(x_aug) 103 | 104 | def main(config): 105 | if torch.cuda.device_count() > 1: 106 | torch.cuda.set_device(config.local_rank) 107 | torch.distributed.init_process_group(backend='nccl') 108 | output_path = os.path.join(config.root_path, 'results', 'fmri_pretrain', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) 109 | # output_path = os.path.join(config.root_path, 'results', 'fmri_pretrain') 110 | config.output_path = output_path 111 | logger = wandb_logger(config) if config.local_rank == 0 else None 112 | 113 | if config.local_rank == 0: 114 | os.makedirs(output_path, exist_ok=True) 115 | create_readme(config, output_path) 116 | 117 | device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu') 118 | torch.manual_seed(config.seed) 119 | np.random.seed(config.seed) 120 | 121 | # create dataset and dataloader 122 | dataset_pretrain = hcp_dataset(path=os.path.join(config.root_path, 'data/HCP/npz'), roi=config.roi, 
patch_size=config.patch_size, 123 | transform=fmri_transform, aug_times=config.aug_times, num_sub_limit=config.num_sub_limit, 124 | include_kam=config.include_kam, include_hcp=config.include_hcp) 125 | 126 | print(f'Dataset size: {len(dataset_pretrain)}\nNumber of voxels: {dataset_pretrain.num_voxels}') 127 | sampler = torch.utils.data.DistributedSampler(dataset_pretrain, rank=config.local_rank) if torch.cuda.device_count() > 1 else None 128 | 129 | dataloader_hcp = DataLoader(dataset_pretrain, batch_size=config.batch_size, sampler=sampler, 130 | shuffle=(sampler is None), pin_memory=True) 131 | 132 | # create model 133 | config.num_voxels = dataset_pretrain.num_voxels 134 | model = MAEforFMRI(num_voxels=dataset_pretrain.num_voxels, patch_size=config.patch_size, embed_dim=config.embed_dim, 135 | decoder_embed_dim=config.decoder_embed_dim, depth=config.depth, 136 | num_heads=config.num_heads, decoder_num_heads=config.decoder_num_heads, mlp_ratio=config.mlp_ratio, 137 | focus_range=config.focus_range, focus_rate=config.focus_rate, 138 | img_recon_weight=config.img_recon_weight, use_nature_img_loss=config.use_nature_img_loss) 139 | model.to(device) 140 | model_without_ddp = model 141 | if torch.cuda.device_count() > 1: 142 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 143 | model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank, find_unused_parameters=config.use_nature_img_loss) 144 | 145 | param_groups = optim_factory.add_weight_decay(model, config.weight_decay) 146 | optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(0.9, 0.95)) 147 | print(optimizer) 148 | loss_scaler = NativeScaler() 149 | 150 | if logger is not None: 151 | logger.watch_model(model,log='all', log_freq=1000) 152 | 153 | cor_list = [] 154 | start_time = time.time() 155 | print('Start Training the fmri MAE ... 
...') 156 | img_feature_extractor = None 157 | preprocess = None 158 | if config.use_nature_img_loss: 159 | from torchvision.models import resnet50, ResNet50_Weights 160 | from torchvision.models.feature_extraction import create_feature_extractor 161 | weights = ResNet50_Weights.DEFAULT 162 | preprocess = weights.transforms() 163 | m = resnet50(weights=weights) 164 | img_feature_extractor = create_feature_extractor(m, return_nodes={f'layer2': 'layer2'}).to(device).eval() 165 | for param in img_feature_extractor.parameters(): 166 | param.requires_grad = False 167 | 168 | for ep in range(config.num_epoch): 169 | if torch.cuda.device_count() > 1: 170 | sampler.set_epoch(ep) # to shuffle the data at every epoch 171 | cor = train_one_epoch(model, dataloader_hcp, optimizer, device, ep, loss_scaler, logger, config, start_time, model_without_ddp, 172 | img_feature_extractor, preprocess) 173 | cor_list.append(cor) 174 | if (ep % 20 == 0 or ep + 1 == config.num_epoch) and ep != 0 and config.local_rank == 0: 175 | # save models 176 | save_model(config, ep, model_without_ddp, optimizer, loss_scaler, os.path.join(output_path,'checkpoints')) 177 | # plot figures 178 | plot_recon_figures(model, device, dataset_pretrain, output_path, 5, config, logger, model_without_ddp) 179 | 180 | total_time = time.time() - start_time 181 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 182 | print('Training time {}'.format(total_time_str)) 183 | if logger is not None: 184 | logger.log('max cor', np.max(cor_list), step=config.num_epoch-1) 185 | logger.finish() 186 | return 187 | 188 | @torch.no_grad() 189 | def plot_recon_figures(model, device, dataset, output_path, num_figures = 5, config=None, logger=None, model_without_ddp=None): 190 | dataloader = DataLoader(dataset, batch_size=1, shuffle=True) 191 | model.eval() 192 | fig, axs = plt.subplots(num_figures, 3, figsize=(30,15)) 193 | fig.tight_layout() 194 | axs[0,0].set_title('Ground-truth') 195 | axs[0,1].set_title('Masked Ground-truth') 196 | axs[0,2].set_title('Reconstruction') 197 | 198 | for ax in axs: 199 | sample = next(iter(dataloader))['fmri'] 200 | sample = sample.to(device) 201 | _, pred, mask = model(sample, mask_ratio=config.mask_ratio) 202 | sample_with_mask = model_without_ddp.patchify(sample).to('cpu').numpy().reshape(-1, model_without_ddp.patch_size) 203 | pred = model_without_ddp.unpatchify(pred).to('cpu').numpy().reshape(-1) 204 | sample = sample.to('cpu').numpy().reshape(-1) 205 | mask = mask.to('cpu').numpy().reshape(-1) 206 | # cal the cor 207 | cor = np.corrcoef([pred, sample])[0,1] 208 | 209 | x_axis = np.arange(0, sample.shape[-1]) 210 | # groundtruth 211 | ax[0].plot(x_axis, sample) 212 | # groundtruth with mask 213 | s = 0 214 | for x, m in zip(sample_with_mask,mask): 215 | if m == 0: 216 | ax[1].plot(x_axis[s:s+len(x)], x, color='#1f77b4') 217 | s += len(x) 218 | # pred 219 | ax[2].plot(x_axis, pred) 220 | ax[2].set_ylabel('cor: %.4f'%cor, weight = 'bold') 221 | ax[2].yaxis.set_label_position("right") 222 | 223 | fig_name = 'reconst-%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")) 224 | fig.savefig(os.path.join(output_path, f'{fig_name}.png')) 225 | if logger is not None: 226 | logger.log_image('reconst', fig) 227 | plt.close(fig) 228 | 229 | def update_config(args, config): 230 | for attr in config.__dict__: 231 | if hasattr(args, attr): 232 | if getattr(args, attr) != None: 233 | setattr(config, attr, getattr(args, attr)) 234 | return config 235 | 236 | 237 | if __name__ == '__main__': 238 | args = 
get_args_parser() 239 | args = args.parse_args() 240 | config = Config_MBM_fMRI() 241 | config = update_config(args, config) 242 | main(config) 243 | -------------------------------------------------------------------------------- /code/stageA2_mbm_finetune.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torch.nn.parallel import DistributedDataParallel 6 | import argparse 7 | import time 8 | import timm.optim.optim_factory as optim_factory 9 | import datetime 10 | import matplotlib.pyplot as plt 11 | import wandb 12 | import copy 13 | 14 | # own code 15 | from config import Config_MBM_finetune 16 | from dataset import create_Kamitani_dataset, create_BOLD5000_dataset 17 | from sc_mbm.mae_for_fmri import MAEforFMRI 18 | from sc_mbm.trainer import train_one_epoch 19 | from sc_mbm.trainer import NativeScalerWithGradNormCount as NativeScaler 20 | from sc_mbm.utils import save_model 21 | 22 | 23 | os.environ["WANDB_START_METHOD"] = "thread" 24 | os.environ['WANDB_DIR'] = "." 25 | 26 | class wandb_logger: 27 | def __init__(self, config): 28 | wandb.init( project='CMVDM', 29 | group="stepA_sc-mbm_tune", 30 | anonymous="allow", 31 | config=config, 32 | reinit=True) 33 | 34 | self.config = config 35 | self.step = None 36 | 37 | def log(self, name, data, step=None): 38 | if step is None: 39 | wandb.log({name: data}) 40 | else: 41 | wandb.log({name: data}, step=step) 42 | self.step = step 43 | 44 | def watch_model(self, *args, **kwargs): 45 | wandb.watch(*args, **kwargs) 46 | 47 | def log_image(self, name, fig): 48 | if self.step is None: 49 | wandb.log({name: wandb.Image(fig)}) 50 | else: 51 | wandb.log({name: wandb.Image(fig)}, step=self.step) 52 | 53 | def finish(self): 54 | wandb.finish(quiet=True) 55 | 56 | def get_args_parser(): 57 | parser = argparse.ArgumentParser('MAE finetuning on Test fMRI', add_help=False) 58 | 59 | # Training Parameters 60 | parser.add_argument('--lr', type=float) 61 | parser.add_argument('--weight_decay', type=float) 62 | parser.add_argument('--num_epoch', type=int) 63 | parser.add_argument('--batch_size', type=int) 64 | parser.add_argument('--mask_ratio', type=float) 65 | 66 | # Project setting 67 | parser.add_argument('--root_path', type=str) 68 | parser.add_argument('--pretrain_mbm_path', type=str) 69 | parser.add_argument('--dataset', type=str) 70 | parser.add_argument('--include_nonavg_test', type=bool) 71 | 72 | # distributed training parameters 73 | parser.add_argument('--local_rank', type=int) 74 | 75 | return parser 76 | 77 | def create_readme(config, path): 78 | print(config.__dict__) 79 | with open(os.path.join(path, 'README.md'), 'w+') as f: 80 | print(config.__dict__, file=f) 81 | 82 | def fmri_transform(x, sparse_rate=0.2): 83 | # x: 1, num_voxels 84 | x_aug = copy.deepcopy(x) 85 | idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False) 86 | x_aug[idx] = 0 87 | return torch.FloatTensor(x_aug) 88 | 89 | def main(config): 90 | if torch.cuda.device_count() > 1: 91 | torch.cuda.set_device(config.local_rank) 92 | torch.distributed.init_process_group(backend='nccl') 93 | sd = torch.load(config.pretrain_mbm_path, map_location='cpu') 94 | config_pretrain = sd['config'] 95 | 96 | output_path = os.path.join(config.root_path, 'results', 'fmri_finetune', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) 97 | # output_path = os.path.join(config.root_path, 'results', 'fmri_finetune') 98 | 
config.output_path = output_path 99 | logger = wandb_logger(config) if config.local_rank == 0 else None 100 | 101 | if config.local_rank == 0: 102 | os.makedirs(output_path, exist_ok=True) 103 | create_readme(config, output_path) 104 | 105 | device = torch.device(f'cuda:{config.local_rank}') if torch.cuda.is_available() else torch.device('cpu') 106 | torch.manual_seed(config_pretrain.seed) 107 | np.random.seed(config_pretrain.seed) 108 | 109 | # create model 110 | num_voxels = (sd['model']['pos_embed'].shape[1] - 1)* config_pretrain.patch_size 111 | model = MAEforFMRI(num_voxels=num_voxels, patch_size=config_pretrain.patch_size, embed_dim=config_pretrain.embed_dim, 112 | decoder_embed_dim=config_pretrain.decoder_embed_dim, depth=config_pretrain.depth, 113 | num_heads=config_pretrain.num_heads, decoder_num_heads=config_pretrain.decoder_num_heads, 114 | mlp_ratio=config_pretrain.mlp_ratio, focus_range=None, use_nature_img_loss=False) 115 | model.load_state_dict(sd['model'], strict=False) 116 | 117 | model.to(device) 118 | model_without_ddp = model 119 | 120 | # create dataset and dataloader 121 | if config.dataset == 'GOD': 122 | _, test_set = create_Kamitani_dataset(path=config.kam_path, patch_size=config_pretrain.patch_size, 123 | subjects=config.kam_subs, fmri_transform=torch.FloatTensor, include_nonavg_test=config.include_nonavg_test) 124 | elif config.dataset == 'BOLD5000': 125 | _, test_set = create_BOLD5000_dataset(path=config.bold5000_path, patch_size=config_pretrain.patch_size, 126 | fmri_transform=torch.FloatTensor, subjects=config.bold5000_subs, include_nonavg_test=config.include_nonavg_test) 127 | else: 128 | raise NotImplementedError 129 | 130 | print(test_set.fmri.shape) 131 | if test_set.fmri.shape[-1] < num_voxels: 132 | test_set.fmri = np.pad(test_set.fmri, ((0,0), (0, num_voxels - test_set.fmri.shape[-1])), 'wrap') 133 | else: 134 | test_set.fmri = test_set.fmri[:, :num_voxels] 135 | print(f'Dataset size: {len(test_set)}') 136 | sampler = torch.utils.data.DistributedSampler(test_set) if torch.cuda.device_count() > 1 else torch.utils.data.RandomSampler(test_set) 137 | dataloader_hcp = DataLoader(test_set, batch_size=config.batch_size, sampler=sampler) 138 | 139 | if torch.cuda.device_count() > 1: 140 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 141 | model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank, find_unused_parameters=config.use_nature_img_loss) 142 | 143 | param_groups = optim_factory.add_weight_decay(model, config.weight_decay) 144 | optimizer = torch.optim.AdamW(param_groups, lr=config.lr, betas=(0.9, 0.95)) 145 | print(optimizer) 146 | loss_scaler = NativeScaler() 147 | 148 | if logger is not None: 149 | logger.watch_model(model,log='all', log_freq=1000) 150 | 151 | cor_list = [] 152 | start_time = time.time() 153 | print('Finetuning MAE on test fMRI ... 
...') 154 | for ep in range(config.num_epoch): 155 | if torch.cuda.device_count() > 1: 156 | sampler.set_epoch(ep) # to shuffle the data at every epoch 157 | cor = train_one_epoch(model, dataloader_hcp, optimizer, device, ep, loss_scaler, logger, config, start_time, model_without_ddp) 158 | cor_list.append(cor) 159 | if (ep % 2 == 0 or ep + 1 == config.num_epoch) and ep != 0 and config.local_rank == 0: 160 | # save models 161 | save_model(config_pretrain, ep, model_without_ddp, optimizer, loss_scaler, os.path.join(output_path,'checkpoints')) 162 | # plot figures 163 | plot_recon_figures(model, device, test_set, output_path, 5, config, logger, model_without_ddp) 164 | 165 | total_time = time.time() - start_time 166 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 167 | print('Training time {}'.format(total_time_str)) 168 | if logger is not None: 169 | logger.log('max cor', np.max(cor_list), step=config.num_epoch-1) 170 | logger.finish() 171 | return 172 | 173 | @torch.no_grad() 174 | def plot_recon_figures(model, device, dataset, output_path, num_figures = 5, config=None, logger=None, model_without_ddp=None): 175 | dataloader = DataLoader(dataset, batch_size=1, shuffle=True) 176 | model.eval() 177 | fig, axs = plt.subplots(num_figures, 3, figsize=(30,15)) 178 | fig.tight_layout() 179 | axs[0,0].set_title('Ground-truth') 180 | axs[0,1].set_title('Masked Ground-truth') 181 | axs[0,2].set_title('Reconstruction') 182 | 183 | for ax in axs: 184 | sample = next(iter(dataloader))['fmri'] 185 | sample = sample.to(device) 186 | _, pred, mask = model(sample, mask_ratio=config.mask_ratio) 187 | sample_with_mask = model_without_ddp.patchify(sample).to('cpu').numpy().reshape(-1, model_without_ddp.patch_size) 188 | pred = model_without_ddp.unpatchify(pred).to('cpu').numpy().reshape(-1) 189 | sample = sample.to('cpu').numpy().reshape(-1) 190 | mask = mask.to('cpu').numpy().reshape(-1) 191 | # cal the cor 192 | cor = np.corrcoef([pred, sample])[0,1] 193 | 194 | x_axis = np.arange(0, sample.shape[-1]) 195 | # groundtruth 196 | ax[0].plot(x_axis, sample) 197 | # groundtruth with mask 198 | s = 0 199 | for x, m in zip(sample_with_mask,mask): 200 | if m == 0: 201 | ax[1].plot(x_axis[s:s+len(x)], x, color='#1f77b4') 202 | s += len(x) 203 | # pred 204 | ax[2].plot(x_axis, pred) 205 | ax[2].set_ylabel('cor: %.4f'%cor, weight = 'bold') 206 | ax[2].yaxis.set_label_position("right") 207 | 208 | fig_name = 'reconst-%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S")) 209 | fig.savefig(os.path.join(output_path, f'{fig_name}.png')) 210 | if logger is not None: 211 | logger.log_image('reconst', fig) 212 | plt.close(fig) 213 | 214 | def update_config(args, config): 215 | for attr in config.__dict__: 216 | if hasattr(args, attr): 217 | if getattr(args, attr) != None: 218 | setattr(config, attr, getattr(args, attr)) 219 | return config 220 | 221 | if __name__ == '__main__': 222 | args = get_args_parser() 223 | args = args.parse_args() 224 | config = Config_MBM_finetune() 225 | config = update_config(args, config) 226 | main(config) 227 | 228 | -------------------------------------------------------------------------------- /code/stageB_ldm_finetune_base.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import torch 4 | import argparse 5 | import datetime 6 | import wandb 7 | import torchvision.transforms as transforms 8 | from einops import rearrange 9 | from PIL import Image 10 | import pytorch_lightning as pl 11 | 
from pytorch_lightning.loggers import WandbLogger 12 | import copy 13 | 14 | # own code 15 | from config import Config_Generative_Model 16 | from dataset import create_Kamitani_dataset, create_BOLD5000_dataset 17 | from dc_ldm.ldm_for_fmri import fLDM 18 | from eval_metrics import get_similarity_metric 19 | 20 | 21 | def wandb_init(config, output_path): 22 | wandb.init( project='CMVDM', 23 | group="stageB_dc-ldm", 24 | anonymous="allow", 25 | config=config, 26 | reinit=True) 27 | create_readme(config, output_path) 28 | 29 | def wandb_finish(): 30 | wandb.finish() 31 | 32 | def to_image(img): 33 | if img.shape[-1] != 3: 34 | img = rearrange(img, 'c h w -> h w c') 35 | img = 255. * img 36 | return Image.fromarray(img.astype(np.uint8)) 37 | 38 | def channel_last(img): 39 | if img.shape[-1] == 3: 40 | return img 41 | return rearrange(img, 'c h w -> h w c') 42 | 43 | def get_eval_metric(samples, avg=True): 44 | metric_list = ['mse', 'pcc', 'ssim', 'psm'] 45 | res_list = [] 46 | 47 | gt_images = [img[0] for img in samples] 48 | gt_images = rearrange(np.stack(gt_images), 'n c h w -> n h w c') 49 | samples_to_run = np.arange(1, len(samples[0])) if avg else [1] 50 | for m in metric_list: 51 | res_part = [] 52 | for s in samples_to_run: 53 | pred_images = [img[s] for img in samples] 54 | pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') 55 | res = get_similarity_metric(pred_images, gt_images, method='pair-wise', metric_name=m) 56 | res_part.append(np.mean(res)) 57 | res_list.append(np.mean(res_part)) 58 | res_part = [] 59 | for s in samples_to_run: 60 | pred_images = [img[s] for img in samples] 61 | pred_images = rearrange(np.stack(pred_images), 'n c h w -> n h w c') 62 | res = get_similarity_metric(pred_images, gt_images, 'class', None, 63 | n_way=50, num_trials=50, top_k=1, device='cuda') 64 | res_part.append(np.mean(res)) 65 | res_list.append(np.mean(res_part)) 66 | res_list.append(np.max(res_part)) 67 | metric_list.append('top-1-class') 68 | metric_list.append('top-1-class (max)') 69 | return res_list, metric_list 70 | 71 | def generate_images(generative_model, fmri_latents_dataset_train, fmri_latents_dataset_test, config): 72 | grid, _ = generative_model.generate(fmri_latents_dataset_train, config.num_samples, 73 | config.ddim_steps, config.HW, 10) # generate 10 instances 74 | grid_imgs = Image.fromarray(grid.astype(np.uint8)) 75 | grid_imgs.save(os.path.join(config.output_path, 'samples_train.png')) 76 | wandb.log({'summary/samples_train': wandb.Image(grid_imgs)}) 77 | 78 | grid, samples = generative_model.generate(fmri_latents_dataset_test, config.num_samples, 79 | config.ddim_steps, config.HW) 80 | grid_imgs = Image.fromarray(grid.astype(np.uint8)) 81 | grid_imgs.save(os.path.join(config.output_path,f'./samples_test.png')) 82 | for sp_idx, imgs in enumerate(samples): 83 | for copy_idx, img in enumerate(imgs[1:]): 84 | img = rearrange(img, 'c h w -> h w c') 85 | Image.fromarray(img).save(os.path.join(config.output_path, 86 | f'./test{sp_idx}-{copy_idx}.png')) 87 | 88 | wandb.log({f'summary/samples_test': wandb.Image(grid_imgs)}) 89 | 90 | # metric, metric_list = get_eval_metric(samples, avg=config.eval_avg) 91 | # metric_dict = {f'summary/pair-wise_{k}':v for k, v in zip(metric_list[:-2], metric[:-2])} 92 | # metric_dict[f'summary/{metric_list[-2]}'] = metric[-2] 93 | # metric_dict[f'summary/{metric_list[-1]}'] = metric[-1] 94 | # wandb.log(metric_dict) 95 | 96 | def normalize(img): 97 | if img.shape[-1] == 3: 98 | img = rearrange(img, 'h w c -> c h w') 99 | img = 
torch.tensor(img) 100 | img = img * 2.0 - 1.0 # to -1 ~ 1 101 | return img 102 | 103 | class random_crop: 104 | def __init__(self, size, p): 105 | self.size = size 106 | self.p = p 107 | def __call__(self, img): 108 | if torch.rand(1) < self.p: 109 | return transforms.RandomCrop(size=(self.size, self.size))(img) 110 | return img 111 | 112 | def fmri_transform(x, sparse_rate=0.2): 113 | # x: 1, num_voxels 114 | x_aug = copy.deepcopy(x) 115 | idx = np.random.choice(x.shape[0], int(x.shape[0]*sparse_rate), replace=False) 116 | x_aug[idx] = 0 117 | return torch.FloatTensor(x_aug) 118 | 119 | def main(config): 120 | # project setup 121 | device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') 122 | torch.manual_seed(config.seed) 123 | np.random.seed(config.seed) 124 | 125 | crop_pix = int(config.crop_ratio*config.img_size) 126 | img_transform_train = transforms.Compose([ 127 | normalize, 128 | random_crop(config.img_size-crop_pix, p=0.5), 129 | transforms.Resize((256, 256)), 130 | channel_last 131 | ]) 132 | img_transform_test = transforms.Compose([ 133 | normalize, transforms.Resize((256, 256)), 134 | channel_last 135 | ]) 136 | if config.dataset == 'GOD': 137 | fmri_latents_dataset_train, fmri_latents_dataset_test = create_Kamitani_dataset(config.kam_path, config.roi, config.patch_size, 138 | fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], 139 | subjects=config.kam_subs) 140 | num_voxels = fmri_latents_dataset_train.num_voxels 141 | elif config.dataset == 'BOLD5000': 142 | fmri_latents_dataset_train, fmri_latents_dataset_test = create_BOLD5000_dataset(config.bold5000_path, config.patch_size, 143 | fmri_transform=fmri_transform, image_transform=[img_transform_train, img_transform_test], 144 | subjects=config.bold5000_subs) 145 | num_voxels = fmri_latents_dataset_train.num_voxels 146 | else: 147 | raise NotImplementedError 148 | 149 | # prepare pretrained mbm 150 | pretrain_mbm_metafile = torch.load(config.pretrain_mbm_path, map_location='cpu') 151 | # create generateive model 152 | generative_model = fLDM(pretrain_mbm_metafile, num_voxels, 153 | device=device, pretrain_root=config.pretrain_gm_path, logger=config.logger, 154 | ddim_steps=config.ddim_steps, global_pool=config.global_pool, use_time_cond=config.use_time_cond, config_full=config) 155 | 156 | # resume training if applicable 157 | if config.checkpoint_path is not None: 158 | model_meta = torch.load(config.checkpoint_path, map_location='cpu') 159 | generative_model.model.load_state_dict(model_meta['model_state_dict'], strict=False) 160 | print('model resumed') 161 | # finetune the model 162 | trainer = create_trainer(config.num_epoch, config.precision, config.accumulate_grad, logger, check_val_every_n_epoch=5) 163 | generative_model.finetune(trainer, fmri_latents_dataset_train, fmri_latents_dataset_test, 164 | config.batch_size, config.lr, config.output_path, config=config) 165 | 166 | # generate images 167 | # generate limited train images and generate images for subjects seperately 168 | generate_images(generative_model, fmri_latents_dataset_train, fmri_latents_dataset_test, config) 169 | 170 | return 171 | 172 | def get_args_parser(): 173 | parser = argparse.ArgumentParser('Double Conditioning LDM Finetuning', add_help=False) 174 | # project parameters 175 | parser.add_argument('--seed', type=int) 176 | parser.add_argument('--root_path', type=str) 177 | parser.add_argument('--kam_path', type=str) 178 | parser.add_argument('--bold5000_path', type=str) 179 | 
parser.add_argument('--pretrain_mbm_path', type=str) 180 | parser.add_argument('--crop_ratio', type=float) 181 | parser.add_argument('--dataset', type=str) 182 | 183 | # finetune parameters 184 | parser.add_argument('--batch_size', type=int) 185 | parser.add_argument('--lr', type=float) 186 | parser.add_argument('--num_epoch', type=int) 187 | parser.add_argument('--precision', type=int) 188 | parser.add_argument('--accumulate_grad', type=int) 189 | parser.add_argument('--global_pool', type=bool) 190 | parser.add_argument('--checkpoint_path', type=str) 191 | 192 | # diffusion sampling parameters 193 | parser.add_argument('--pretrain_gm_path', type=str) 194 | parser.add_argument('--num_samples', type=int) 195 | parser.add_argument('--ddim_steps', type=int) 196 | parser.add_argument('--use_time_cond', type=bool) 197 | parser.add_argument('--eval_avg', type=bool) 198 | parser.add_argument('--pretrain_finetune_path', type=str) 199 | parser.add_argument('--config_root', type=str) 200 | 201 | # # distributed training parameters 202 | # parser.add_argument('--local_rank', type=int) 203 | 204 | return parser 205 | 206 | def update_config(args, config): 207 | for attr in config.__dict__: 208 | if hasattr(args, attr): 209 | if getattr(args, attr) != None: 210 | setattr(config, attr, getattr(args, attr)) 211 | return config 212 | 213 | def create_readme(config, path): 214 | print(config.__dict__) 215 | with open(os.path.join(path, 'README.md'), 'w+') as f: 216 | print(config.__dict__, file=f) 217 | 218 | 219 | def create_trainer(num_epoch, precision=32, accumulate_grad_batches=2,logger=None,check_val_every_n_epoch=0): 220 | acc = 'gpu' if torch.cuda.is_available() else 'cpu' 221 | return pl.Trainer(accelerator=acc, max_epochs=num_epoch, logger=logger, 222 | precision=precision, accumulate_grad_batches=accumulate_grad_batches, 223 | enable_checkpointing=False, enable_model_summary=False, gradient_clip_val=0.5, 224 | check_val_every_n_epoch=check_val_every_n_epoch) 225 | 226 | if __name__ == '__main__': 227 | args = get_args_parser() 228 | args = args.parse_args() 229 | config = Config_Generative_Model() 230 | config = update_config(args, config) 231 | 232 | if config.checkpoint_path is not None: 233 | model_meta = torch.load(config.checkpoint_path, map_location='cpu') 234 | ckp = config.checkpoint_path 235 | config = model_meta['config'] 236 | config.checkpoint_path = ckp 237 | print('Resuming from checkpoint: {}'.format(config.checkpoint_path)) 238 | 239 | output_path = os.path.join(config.root_path, 'results', f'{config.dataset}', 'attn_generation', '%s'%(datetime.datetime.now().strftime("%d-%m-%Y-%H-%M-%S"))) 240 | config.output_path = output_path 241 | os.makedirs(output_path, exist_ok=True) 242 | 243 | wandb_init(config, output_path) 244 | 245 | logger = WandbLogger() 246 | config.logger = logger 247 | main(config) 248 | wandb_finish() 249 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: CMVDM 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - python=3.8.5 7 | - pip>=20.3 8 | - pip: 9 | - numpy 10 | - matplotlib 11 | - omegaconf==2.1.1 12 | - einops==0.3.0 13 | - torch-fidelity==0.3.0 14 | - --extra-index-url https://download.pytorch.org/whl/cu116 15 | - torch==1.12.1 16 | - torchvision==0.13.1 17 | - Pillow==9.0.1 18 | - timm==0.5.4 19 | - tqdm==4.64.0 20 | - wandb==0.12.21 21 | - torchmetrics==0.9.2 22 | - scikit-image 23 | - 
pytorch-lightning==1.6.5 24 | - lpips==0.1.4 25 | - -e ./code -------------------------------------------------------------------------------- /pretrains/ldm/label2img/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: dc_ldm.models.diffusion.ddpm_clip.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: fmri 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | 19 | unet_config: 20 | target: dc_ldm.modules.diffusionmodules.openaimodel.UNetModel 21 | params: 22 | image_size: 64 23 | in_channels: 3 24 | out_channels: 3 25 | model_channels: 192 26 | attention_resolutions: 27 | - 8 28 | - 4 29 | - 2 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 3 35 | - 5 36 | num_heads: 1 37 | use_spatial_transformer: true 38 | transformer_depth: 1 39 | context_dim: 512 40 | cond_scale: 1.0 41 | use_time_cond: true 42 | global_pool: false 43 | 44 | first_stage_config: 45 | target: dc_ldm.models.autoencoder.VQModelInterface 46 | params: 47 | embed_dim: 3 48 | n_embed: 8192 49 | ddconfig: 50 | double_z: false 51 | z_channels: 3 52 | resolution: 256 53 | in_channels: 3 54 | out_ch: 3 55 | ch: 128 56 | ch_mult: 57 | - 1 58 | - 2 59 | - 4 60 | num_res_blocks: 2 61 | attn_resolutions: [] 62 | dropout: 0.0 63 | lossconfig: 64 | target: torch.nn.Identity 65 | 66 | cond_stage_config: 67 | target: dc_ldm.modules.encoders.modules.ClassEmbedder 68 | params: 69 | n_classes: 1001 70 | embed_dim: 512 71 | key: class_label -------------------------------------------------------------------------------- /pretrains/ldm/label2img/controlnet_config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: dc_ldm.models.diffusion.ddpm_control_res.LatentDiffusion # dc_ldm.models.diffusion.ddpm_control.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: fmri 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | shape_decoder_path: GOD silhouette decoder path (to be released) 19 | 20 | unet_config: 21 | target: dc_ldm.modules.diffusionmodules.openaimodel_control.UNetModel 22 | params: 23 | image_size: 64 24 | in_channels: 3 25 | out_channels: 3 26 | model_channels: 192 27 | attention_resolutions: 28 | - 8 29 | - 4 30 | - 2 31 | num_res_blocks: 2 32 | channel_mult: 33 | - 1 34 | - 2 35 | - 3 36 | - 5 37 | num_heads: 1 38 | use_spatial_transformer: true 39 | transformer_depth: 1 40 | context_dim: 512 41 | cond_scale: 1.0 42 | use_time_cond: true 43 | global_pool: false 44 | 45 | first_stage_config: 46 | target: dc_ldm.models.autoencoder.VQModelInterface 47 | params: 48 | embed_dim: 3 49 | n_embed: 8192 50 | ddconfig: 51 | double_z: false 52 | z_channels: 3 53 | resolution: 256 54 | in_channels: 3 55 | out_ch: 3 56 | ch: 128 57 | ch_mult: 58 | - 1 59 | - 2 60 | - 4 61 | num_res_blocks: 2 62 | attn_resolutions: [] 63 | dropout: 0.0 64 | lossconfig: 65 | target: torch.nn.Identity 66 | 67 | cond_stage_config: 68 | target: dc_ldm.modules.encoders.modules.ClassEmbedder 69 | 
params: 70 | n_classes: 1001 71 | embed_dim: 512 72 | key: class_label -------------------------------------------------------------------------------- /pretrains/ldm/label2img/controlnet_config_BOLD5000.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 0.0001 3 | target: dc_ldm.models.diffusion.ddpm_control_res.LatentDiffusion # dc_ldm.models.diffusion.ddpm_control.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: fmri 12 | image_size: 64 13 | channels: 3 14 | cond_stage_trainable: true 15 | conditioning_key: crossattn 16 | monitor: val/loss 17 | use_ema: False 18 | shape_decoder_path: BOLD5000 silhouette decoder path (to be released) 19 | data_input_dim: 1685 20 | 21 | unet_config: 22 | target: dc_ldm.modules.diffusionmodules.openaimodel_control.UNetModel 23 | params: 24 | image_size: 64 25 | in_channels: 3 26 | out_channels: 3 27 | model_channels: 192 28 | attention_resolutions: 29 | - 8 30 | - 4 31 | - 2 32 | num_res_blocks: 2 33 | channel_mult: 34 | - 1 35 | - 2 36 | - 3 37 | - 5 38 | num_heads: 1 39 | use_spatial_transformer: true 40 | transformer_depth: 1 41 | context_dim: 512 42 | cond_scale: 1.0 43 | use_time_cond: true 44 | global_pool: false 45 | 46 | first_stage_config: 47 | target: dc_ldm.models.autoencoder.VQModelInterface 48 | params: 49 | embed_dim: 3 50 | n_embed: 8192 51 | ddconfig: 52 | double_z: false 53 | z_channels: 3 54 | resolution: 256 55 | in_channels: 3 56 | out_ch: 3 57 | ch: 128 58 | ch_mult: 59 | - 1 60 | - 2 61 | - 4 62 | num_res_blocks: 2 63 | attn_resolutions: [] 64 | dropout: 0.0 65 | lossconfig: 66 | target: torch.nn.Identity 67 | 68 | cond_stage_config: 69 | target: dc_ldm.modules.encoders.modules.ClassEmbedder 70 | params: 71 | n_classes: 1001 72 | embed_dim: 512 73 | key: class_label --------------------------------------------------------------------------------
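The three YAML files above all follow the latent-diffusion convention of describing each sub-module as a `target` (dotted import path) plus `params` (constructor keyword arguments). The minimal sketch below illustrates how such a config can be loaded and resolved with the omegaconf version pinned in env.yaml; the `build_from_spec` helper is written here purely for illustration and is an assumption, not the repository's own loader, and it presumes `pip install -e ./code` has made the `dc_ldm` package importable. Constructing the full LatentDiffusion additionally requires the pretrained first-stage and conditioning checkpoints referenced by the training scripts.

import importlib

from omegaconf import OmegaConf


def build_from_spec(spec):
    # Resolve the dotted class path in `target` and call it with the keyword args in `params`.
    module_name, _, class_name = spec["target"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("params", {}))


cfg = OmegaConf.load("pretrains/ldm/label2img/config.yaml")
# For example, instantiate only the U-Net backbone described under model.params.unet_config.
unet = build_from_spec(cfg.model.params.unet_config)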