├── README.md ├── environment.yml ├── howto.ipynb ├── loop4all.sh ├── map_fsaverage_to_hcp.sh ├── mesh_assets └── meshes │ └── 100-500-2k-8k-32k │ ├── README.md │ ├── ico100_neighbor_patches.npy │ ├── ico100_to_ico500_vertices.npy │ ├── ico100_to_ico500_vertices_closests_dists.npy │ ├── ico2k_neighbor_patches.npy │ ├── ico2k_to_ico500_vertices.npy │ ├── ico2k_to_ico500_vertices_closests_dists.npy │ ├── ico2k_to_ico8k_vertices.npy │ ├── ico2k_to_ico8k_vertices_closests_dists.npy │ ├── ico32k_neighbor_patches.npy │ ├── ico32k_to_ico8k_vertices.npy │ ├── ico32k_to_ico8k_vertices_closests_dists.npy │ ├── ico500_neighbor_patches.npy │ ├── ico500_to_ico100_vertices.npy │ ├── ico500_to_ico100_vertices_closests_dists.npy │ ├── ico500_to_ico2k_vertices.npy │ ├── ico500_to_ico2k_vertices_closests_dists.npy │ ├── ico8k_neighbor_patches.npy │ ├── ico8k_to_ico2k_vertices.npy │ ├── ico8k_to_ico2k_vertices_closests_dists.npy │ ├── ico8k_to_ico32k_vertices.npy │ ├── ico8k_to_ico32k_vertices_closests_dists.npy │ ├── icosphere_100.pkl │ ├── icosphere_2k.pkl │ ├── icosphere_32k.pkl │ ├── icosphere_500.pkl │ └── icosphere_8k.pkl ├── model.png ├── test.py ├── train_combined_decoding.py ├── train_feature_decoding.py └── utils4image ├── __init__.py ├── config.py ├── dataloaders.py ├── eva_utils.py ├── models.py ├── ugscnn_utils.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # Decoding natural image stimuli from fMRI data with a surface-based convolutional network 2 | 3 | Code for the MIDL 2023 paper [Decoding natural image stimuli from fMRI data with a surface-based convolutional network](https://arxiv.org/abs/2212.02409) (oral). 4 | 5 | ![](/model.png) 6 | 7 | ## !Update! 8 | Preprocessed brain data on the cortical surface can be found [here](https://cornell.box.com/s/7fjry2vhy11mpl908b6rrmpmbxmbhpdp). 9 | 10 | ## Requirements 11 | 1. Clone the IC-GAN repo: `git clone https://github.com/facebookresearch/ic_gan.git` 12 | 2. Install the required packages: 13 | ```shell 14 | conda env create -f environment.yml 15 | conda activate meshconvdec 16 | ``` 17 | 3. If you use NSD, it is publicly available at https://naturalscenesdataset.org/. Follow step 2 in the Instructions to prepare the data. 18 | 19 | ## Instructions 20 | 1. Install the required packages. 21 | 2. (optional) If your data is in fsaverage space instead of the fs_LR_32k surface (e.g., the NSD betas at https://natural-scenes-dataset.s3.amazonaws.com/index.html#nsddata_betas/ppdata/subj01/fsaverage/betas_fithrf_GLMdenoise_RR/), run `loop4all.sh`, which calls `map_fsaverage_to_hcp.sh` for every session and every subject in NSD. 22 | 3. Run `python train_feature_decoding.py` to train the `Cortex2Semantic` model. 23 | 4. Run `python train_combined_decoding.py` to train the `Cortex2Detail` model. 24 | 5. Run `python test.py` to generate the decoded images. 25 | 26 | Please note that the file paths and the hyperparameters may need to be changed according to your own settings. 27 | 28 | ## Availability 29 | We welcome researchers to use our models and to compare their new approaches with ours. 30 | Pretrained models and reconstructed images for the 1000 shared images in NSD can be downloaded [here](https://cornell.box.com/s/epev6y4y6foqjey4pxtg4txsyfmcvwmj). 
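For a quick look at the downloaded reconstructions, a short sketch is given below. It assumes the download contains the `*_pred_imgs.npy` arrays written by `test.py`, which have shape `(10, 1000, 256, 256, 3)` (10 sampled reconstructions for each of the 1000 shared NSD test images, values in `[0, 1]`); the exact file name inside the archive is an assumption, and `matplotlib` is only needed for this snippet.

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical file name: use whichever *_pred_imgs.npy file is in the downloaded archive.
imgs = np.load("meshpool_pred_imgs.npy")  # shape (n_samples, n_images, 256, 256, 3)
plt.imshow(imgs[0, 0])  # first sampled reconstruction of the first shared test image
plt.axis("off")
plt.show()
```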
31 | 32 | ## Citation 33 | If you find this work helpful for your research, please cite our paper: 34 | ``` 35 | @article{gu2022decoding, 36 | title={Decoding natural image stimuli from fMRI data with a surface-based convolutional network}, 37 | author={Gu, Zijin and Jamison, Keith and Kuceyeski, Amy and Sabuncu, Mert}, 38 | journal={arXiv preprint arXiv:2212.02409}, 39 | year={2022} 40 | } 41 | ``` 42 | 43 | 44 | 45 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: meshconvdec 2 | channels: 3 | - pytorch 4 | - anaconda 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | # os, sys, math, pickle, argparse, and collections are Python standard-library modules and need no conda entry 9 | - python 10 | - numpy 11 | - scipy 12 | - h5py 13 | - pytorch 14 | - torchvision 15 | -------------------------------------------------------------------------------- /loop4all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | i=$1 4 | 5 | mkdir -p ./nsd/responses/subj0$i/fs_LR_32k 6 | for j in {01..40} 7 | do 8 | bash map_fsaverage_to_hcp.sh ./nsd/responses/subj0$i/fsaverage/betas_fithrf_GLMdenoise_RR/lh.betas_session$j.mgh ./nsd/responses/subj0$i/fs_LR_32k/lh_betas_session$j.func.gii 9 | bash map_fsaverage_to_hcp.sh ./nsd/responses/subj0$i/fsaverage/betas_fithrf_GLMdenoise_RR/rh.betas_session$j.mgh ./nsd/responses/subj0$i/fs_LR_32k/rh_betas_session$j.func.gii 10 | done 11 | -------------------------------------------------------------------------------- /map_fsaverage_to_hcp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e 4 | 5 | infile=$1 6 | outgii=$2 #must be "*.func.gii" 7 | 8 | if [[ ${outgii} != *.func.gii ]]; then 9 | echo "Output file must end in func.gii!" 10 | exit 1 11 | fi 12 | 13 | #downloaded from: https://github.com/Washington-University/HCPpipelines/tree/master/global/templates/standard_mesh_atlases 14 | templatedir=$HOME/HCPpipelines/global/templates/standard_mesh_atlases/resample_fsaverage 15 | 16 | #lh.* -> L, rh.* -> R (e.g. lh.betas_session01.mgh -> L) 17 | H=$(basename ${infile} | sed -E 's/^lh\..+$/L/g' | sed -E 's/^rh\..+$/R/') 18 | 19 | tmpfile=${infile}_tmp.func.gii 20 | 21 | #use freesurfer mri_convert to go from .mgh to .gii 22 | mri_convert ${infile} ${tmpfile} 23 | 24 | #use HCP connectome workbench wb_command to map to fs_LR32k 25 | wb_command -metric-resample ${tmpfile} $templatedir/fsaverage_std_sphere.$H.164k_fsavg_$H.surf.gii $templatedir/fs_LR-deformed_to-fsaverage.$H.sphere.32k_fs_LR.surf.gii ADAP_BARY_AREA ${outgii} -area-metrics $templatedir/fsaverage.$H.midthickness_va_avg.164k_fsavg_$H.shape.gii $templatedir/fs_LR.$H.midthickness_va_avg.32k_fs_LR.shape.gii 26 | # wb_command -metric-resample ${tmpfile} $templatedir/fsaverage6_std_sphere.$H.41k_fsavg_$H.surf.gii $templatedir/fs_LR-deformed_to-fsaverage.$H.sphere.32k_fs_LR.surf.gii ADAP_BARY_AREA ${outgii} -area-metrics $templatedir/fsaverage6.$H.midthickness_va_avg.41k_fsavg_$H.shape.gii $templatedir/fs_LR.$H.midthickness_va_avg.32k_fs_LR.shape.gii 27 | 28 | rm -f ${tmpfile} 29 | -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/README.md: -------------------------------------------------------------------------------- 1 | The mesh files needed by this code. 
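For orientation, the sketch below shows roughly how these assets are consumed by `utils4image/ugscnn_utils.py`; see `MeshConv` and `MaxPool` there for the authoritative usage.

```python
import pickle
import numpy as np

# Each icosphere_<res>.pkl bundles the icosahedral mesh with the sparse operators used by MeshConv:
# 'V' (vertices), 'G' (vertex-to-face gradient), 'NS'/'EW' (per-face tangent vector fields),
# 'L' (Laplacian), 'F2V' (face-to-vertex mapping), 'nv_prev' (vertex count of the coarser level).
mesh = pickle.load(open("icosphere_32k.pkl", "rb"))
print(mesh["V"].shape)

# The index arrays drive MaxPool: features on the finer (32k) mesh are gathered with the vertex
# map and max-pooled over each coarse vertex's neighborhood patch to land on the coarser (8k) mesh.
vert_map = np.load("ico8k_to_ico32k_vertices.npy")
patches = np.load("ico8k_neighbor_patches.npy")
```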
2 | -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico100_neighbor_patches.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico100_neighbor_patches.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico100_to_ico500_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico100_to_ico500_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico100_to_ico500_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico100_to_ico500_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico2k_neighbor_patches.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico2k_neighbor_patches.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico500_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico500_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico500_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico500_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico8k_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico8k_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico8k_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico2k_to_ico8k_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico32k_neighbor_patches.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico32k_neighbor_patches.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico32k_to_ico8k_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico32k_to_ico8k_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico32k_to_ico8k_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico32k_to_ico8k_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico500_neighbor_patches.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico500_neighbor_patches.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico100_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico100_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico100_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico100_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico2k_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico2k_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico2k_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico500_to_ico2k_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico8k_neighbor_patches.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico8k_neighbor_patches.npy -------------------------------------------------------------------------------- 
/mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico2k_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico2k_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico2k_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico2k_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico32k_vertices.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico32k_vertices.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico32k_vertices_closests_dists.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/ico8k_to_ico32k_vertices_closests_dists.npy -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/icosphere_100.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/icosphere_100.pkl -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/icosphere_2k.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/icosphere_2k.pkl -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/icosphere_32k.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/icosphere_32k.pkl -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/icosphere_500.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/icosphere_500.pkl -------------------------------------------------------------------------------- /mesh_assets/meshes/100-500-2k-8k-32k/icosphere_8k.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/mesh_assets/meshes/100-500-2k-8k-32k/icosphere_8k.pkl -------------------------------------------------------------------------------- /model.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zijin-gu/meshconv-decoding/306bfab400470d28b1149c95833a992838bcad56/model.png -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from utils4image.dataloaders import NSDImageDataset 5 | from torch.utils.data import DataLoader 6 | from utils4image.utils import load_generator 7 | from utils4image.eva_utils import load_meshmodel 8 | import argparse 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--subject", type=int, default=1) 12 | parser.add_argument("--batch_size", type=int, default=8) 13 | args = parser.parse_args() 14 | 15 | subject = args.subject 16 | batch_size = args.batch_size 17 | 18 | scale = 1 19 | 20 | device = torch.device('cuda') 21 | mapping = 'meshpool' 22 | optim = 'adamw' 23 | lr = 1e-4 24 | decay = 0.1 25 | dropout_rate = 0.5 26 | fdim = 32 27 | indim = 32 28 | n_hidden_layer = 3 29 | recon_w = 1e-6 30 | kld_w = 1e-8 31 | annealing_epochs = 10 32 | kld_start_epoch = 0 33 | in_feature_w = 0 34 | out_feature_w = 1 35 | b2f_fix = True 36 | variation = True 37 | combined_type = 'variation' if variation else 'direct' 38 | factor = 1 39 | 40 | ckpt_dir = f'./decoding_ckpt/S{subject}/image_decoding/{combined_type}/' 41 | model_base = '%s_%s_lr%s_dc%s_dp%s_fd%d_ind%d_layer%d_rw%s_ifw%s_ofw%s_kldw%s_ae%d_kldse%d_vsf%s'% \ 42 | (mapping, optim, "{:.0e}".format(lr),"{:.0e}".format(decay), "{:.0e}".format(dropout_rate), 43 | fdim, indim, n_hidden_layer, "{:.0e}".format(recon_w), "{:.0e}".format(in_feature_w), 44 | "{:.0e}".format(out_feature_w), "{:.0e}".format(kld_w), annealing_epochs, kld_start_epoch, "{:.0e}".format(factor)) 45 | 46 | model_base = (model_base + '_fixb2f') if b2f_fix else (model_base + '_ftb2f') 47 | model = load_meshmodel(subject, 'image', fdim, indim, n_hidden_layer, dropout_rate, model_base).to(device) 48 | 49 | test_dataset = NSDImageDataset(mode='test', test_subject=subject) 50 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) 51 | 52 | pred_feat = [] 53 | pred_mu = [] 54 | pred_logvar = [] 55 | for i, (data, target, target_f) in enumerate(test_loader): 56 | print(i) 57 | pred, feat, mu, logvar = model(data.to(device)) 58 | pred_feat.append(feat.detach().cpu()) 59 | pred_mu.append(mu.detach().cpu()) 60 | pred_logvar.append(logvar.detach().cpu()) 61 | 62 | pred_mu = torch.vstack(pred_mu) 63 | pred_logvar = torch.vstack(pred_logvar)*scale 64 | pred_feat = torch.vstack(pred_feat) 65 | 66 | pred_muvar = {'mu':pred_mu.numpy(), 'logvar':pred_logvar.numpy()} 67 | n_pf_pz = [] # (10, 1000, 3, 256, 256) 68 | for n in range(10): 69 | print(n) 70 | #torch.manual_seed(seed=n) 71 | pf_pz = [] 72 | for i in range(len(test_loader)): 73 | print(i) 74 | pf = pred_feat[i*batch_size:(i+1)*batch_size].to(device) 75 | pz = model.brain2noise.reparametrize(pred_mu[i*batch_size:(i+1)*batch_size].to(device), pred_logvar[i*batch_size:(i+1)*batch_size].to(device)) 76 | img = model.generator(pz, None, pf).detach().cpu().numpy() 77 | pf_pz.append(img) 78 | pf_pz = np.vstack(pf_pz) 79 | pf_pz = (pf_pz+1)/2 80 | n_pf_pz.append(pf_pz) 81 | 82 | n_pf_pz = np.moveaxis(np.array(n_pf_pz), 2, -1) 83 | 84 | 85 | # pf_pz = [] #(1000, 3, 256, 256) 86 | # for i in range(len(test_loader)): 87 | # print(i) 88 | # pf = 
pred_feat[i*batch_size:(i+1)*batch_size].to(device) 89 | # pz = pred_mu[i*batch_size:(i+1)*batch_size].to(device) 90 | # img = model.generator(pz, None, pf).detach().cpu().numpy() 91 | # pf_pz.append(img) 92 | # pf_pz = np.vstack(pf_pz) 93 | # pf_pz = (pf_pz+1)/2 94 | 95 | # pf_pz = np.moveaxis(pf_pz, 1, -1) 96 | 97 | res_dir = f'./decoding_result/S{subject}/image_decoding/' 98 | if not os.path.exists(res_dir): 99 | os.makedirs(res_dir) 100 | 101 | #np.save(res_dir + model_base + f'_pred_imgs_{scale}var.npy', n_pf_pz) 102 | np.save(res_dir + model_base + f'_pred_imgs.npy', n_pf_pz) 103 | -------------------------------------------------------------------------------- /train_combined_decoding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | from torchvision import utils 6 | import numpy as np 7 | from utils4image.dataloaders import NSDImageDataset 8 | from utils4image.models import Brain2FeatureMeshPool, Brain2NoiseVarMeshPool, Brain2NoiseMeshPool, Brain2Image 9 | from utils4image.utils import load_generator, save_checkpoint 10 | from utils4image.eva_utils import load_meshmodel 11 | import argparse 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--subject", type=int, default=1) 17 | parser.add_argument("--batchsize", type=int, default=8) 18 | parser.add_argument("--lr", type=float, default=1e-4) 19 | parser.add_argument("--epoch", type=int, default=100) 20 | parser.add_argument("--mapping", type=str, default='meshpool') 21 | parser.add_argument("--optim", type=str, default='adamw') 22 | parser.add_argument("--dataset", type=str, default='imagenet') 23 | parser.add_argument("--indim", type=int, default=32, help='in dim for the first mesh conv') 24 | parser.add_argument("--fdim", type=int, default=32, help='hidden dim for mesh conv') 25 | parser.add_argument("--n_hidden_layer", type=int, default=3, help='number of hidden layers') 26 | parser.add_argument("--recon_w", type=float, default=1, help='weight for the reconstruction loss') 27 | parser.add_argument("--in_feature_w", type=float, default=1, help='weight for the in feature space loss') 28 | parser.add_argument("--out_feature_w", type=float, default=1, help='weight for the out feature space loss') 29 | parser.add_argument("--kld_w", type=float, default=1, help='weight for the KL divergence loss') 30 | parser.add_argument("--decay", type=float, default=0.1, help='weight decay') 31 | parser.add_argument("--dropout", type=float, default=0.5, help='dropout rate') 32 | parser.add_argument("--annealing_epochs", type=float, default=20, help='annealing epochs') 33 | parser.add_argument("--kld_start_epoch", type=float, default=20, help='when to start kld loss') 34 | parser.add_argument("--b2f_fix", default=True, action='store_false', help='fix the pretrained feature decoder or not') 35 | parser.add_argument("--gen_fix", default=True, action='store_false', help='fix the pretrained generator or not') 36 | parser.add_argument("--variation", default=False, action='store_true', help='variational or not') 37 | parser.add_argument("--factor", default=1, type=float, help='scaling factor for the variance if using variational approach') 38 | args = parser.parse_args() 39 | 40 | device = "cuda" 41 | subject = args.subject 42 | lr = args.lr 43 | n_epochs = args.epoch 44 | optim = args.optim 45 | mapping = args.mapping 46 | dataset = args.dataset 47 | decay = 
args.decay 48 | dropout_rate = args.dropout 49 | fdim = args.fdim 50 | indim = args.indim 51 | n_hidden_layer = args.n_hidden_layer 52 | recon_w = args.recon_w 53 | in_feature_w = args.in_feature_w 54 | out_feature_w = args.out_feature_w 55 | kld_w = args.kld_w 56 | batch_size = args.batchsize 57 | annealing_epochs = args.annealing_epochs 58 | kld_start_epoch = args.kld_start_epoch 59 | variation = args.variation 60 | combined_type = 'variation' if variation else 'direct' 61 | factor = args.factor 62 | b2f_fix = args.b2f_fix 63 | gen_fix = args.gen_fix 64 | ckpt_dir = f'./decoding_ckpt/S{subject}/image_decoding/{combined_type}/' 65 | samp_dir = f'./decoding_sample/S{subject}/image_decoding/{combined_type}/' 66 | if not os.path.exists(ckpt_dir): 67 | os.makedirs(ckpt_dir) 68 | if not os.path.exists(samp_dir): 69 | os.makedirs(samp_dir) 70 | 71 | model_base = '%s_%s_lr%s_dc%s_dp%s_fd%d_ind%d_layer%d_rw%s_ifw%s_ofw%s_kldw%s_ae%d_kldse%d_vsf%s'%(mapping, optim, "{:.0e}".format(lr), 72 | "{:.0e}".format(decay), "{:.0e}".format(dropout_rate), 73 | fdim, indim, n_hidden_layer, 74 | "{:.0e}".format(recon_w), "{:.0e}".format(in_feature_w), "{:.0e}".format(out_feature_w), 75 | "{:.0e}".format(kld_w), annealing_epochs, kld_start_epoch, "{:.0e}".format(factor)) 76 | model_base = (model_base + '_fixb2f') if b2f_fix else (model_base + '_ftb2f') 77 | train_dataset = NSDImageDataset(mode='train', test_subject=subject) 78 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 79 | val_dataset = NSDImageDataset(mode='val', test_subject=subject) 80 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 81 | 82 | feat_fdim, feat_indim, feat_n_hidden_layer = 32, 32, 3 83 | restore_file = '%s_%s_lr%s_dc%s_dp%s_fd%d_ind%d_layer%d'%(mapping, optim, "{:.0e}".format(1e-3), 84 | "{:.0e}".format(decay), "{:.0e}".format(dropout_rate), 85 | feat_fdim, feat_indim, feat_n_hidden_layer) 86 | b2f = load_meshmodel(subject, 'feature', feat_fdim, feat_indim, feat_n_hidden_layer, dropout_rate, restore_file) 87 | 88 | if variation: 89 | b2z = Brain2NoiseVarMeshPool(indim=indim, fdim=fdim, n_hidden_layer=n_hidden_layer, factor=factor) 90 | else: 91 | b2z = Brain2NoiseMeshPool(indim=indim, fdim=fdim, n_hidden_layer=n_hidden_layer) 92 | model = Brain2Image(b2f, b2z, b2f_fix=b2f_fix, generator_fix=gen_fix, variation=variation).to(device) 93 | print('Variation approach is ', variation) 94 | print('Feature decoder is fixed: ', b2f_fix, ' Generator is fixed: ', gen_fix) 95 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=decay) 96 | 97 | def train_step(data, target, target_f, annealed_kld_w): 98 | model.train() 99 | data, target, target_f = data.to(device), target.to(device), target_f.to(device) 100 | prediction, feature, mu, logvar = model(data) 101 | loss, (recon_loss, in_feat_loss, out_feature_loss, kld_loss) = model.compute_loss(target, prediction, target_f, feature, mu, logvar, recon_w, in_feature_w, out_feature_w, annealed_kld_w) 102 | optimizer.zero_grad() 103 | loss.backward() 104 | optimizer.step() 105 | return loss.item(), recon_loss.item(), in_feat_loss.item(), out_feature_loss.item(), kld_loss.item() 106 | 107 | def val_step(annealed_kld_w): 108 | model.eval() 109 | pred, true, val_loss, val_recon_loss, val_infeat_loss, val_outfeat_loss, val_kld_loss = [], [], 0, 0, 0, 0, 0 110 | for idx, (data, target, target_f) in enumerate(val_loader): 111 | data, target, target_f = data.to(device), target.to(device), target_f.to(device) 112 | prediction, feature, mu, 
logvar = model(data) 113 | loss, (recon_loss, infeat_loss, outfeat_loss, kld_loss) = model.compute_loss(target, prediction, target_f, feature, mu, logvar, recon_w, in_feature_w, out_feature_w, annealed_kld_w) 114 | 115 | val_loss += loss.item() 116 | val_recon_loss += recon_loss.item() 117 | val_infeat_loss += infeat_loss.item() 118 | val_outfeat_loss += outfeat_loss.item() 119 | val_kld_loss += kld_loss.item() 120 | pred.append(prediction.detach().cpu()) 121 | true.append(target.detach().cpu()) 122 | 123 | return torch.vstack(pred), torch.vstack(true), val_loss/(idx+1), val_recon_loss/(idx+1), val_infeat_loss/(idx+1), val_outfeat_loss/(idx+1), val_kld_loss/(idx+1) 124 | 125 | results = {'train_loss':[], 'train_recon_loss':[], 'train_infeat_loss':[], 'train_outfeat_loss':[], 'train_kld_loss':[], 126 | 'val_loss':[], 'val_recon_loss':[], 'val_infeat_loss':[], 'val_outfeat_loss':[], 'val_kld_loss':[]} 127 | best_loss = 1e5 128 | N_mini_batches = len(train_loader) 129 | for epoch in range(1, n_epochs+1): 130 | # training 131 | total_loss, total_recon_loss, total_infeat_loss, total_outfeat_loss, total_kld_loss = 0, 0, 0, 0, 0 132 | 133 | for batch_idx, (data, target, target_f) in enumerate(train_loader): 134 | if (kld_w == 0) or (epoch <= kld_start_epoch): 135 | annealing_factor = 0 136 | else: 137 | annealing_factor = kld_w * (float(batch_idx + (epoch - kld_start_epoch - 1) * N_mini_batches + 1) / 138 | float(annealing_epochs * N_mini_batches)) 139 | loss, recon_loss, infeat_loss, outfeat_loss, kld_loss = train_step(data, target, target_f, annealing_factor) 140 | total_loss += loss 141 | total_recon_loss += recon_loss 142 | total_infeat_loss += infeat_loss 143 | total_outfeat_loss += outfeat_loss 144 | total_kld_loss += kld_loss 145 | 146 | print("Training [{}:{}/{}] LOSS={:.2} RECON={:.2} INFEAT={:.2} OUTFEAT={:.2} KLD={:.2} ={:.2} ".format( 147 | epoch, batch_idx, len(train_loader), loss, recon_loss, infeat_loss, outfeat_loss, kld_loss, total_loss / (batch_idx + 1))) 148 | 149 | if batch_idx % 100 == 0: 150 | results['train_loss'].append(total_loss / (batch_idx + 1)) 151 | results['train_recon_loss'].append(total_recon_loss / (batch_idx + 1)) 152 | results['train_infeat_loss'].append(total_infeat_loss/ (batch_idx + 1)) 153 | results['train_outfeat_loss'].append(total_outfeat_loss/ (batch_idx + 1)) 154 | results['train_kld_loss'].append(total_kld_loss / (batch_idx + 1)) 155 | 156 | # val 157 | val_pred, val_true, val_loss, val_recon_loss, val_infeat_loss, val_outfeat_loss, val_kld_loss = val_step(annealing_factor) 158 | print("Validation [Epoch {} Test] ={:.2} ={:.2} ={:.2} ={:.2} ={:.2}".format(epoch, val_loss, val_recon_loss, val_infeat_loss, val_outfeat_loss, val_kld_loss)) 159 | results['val_loss'].append(val_loss) 160 | results['val_recon_loss'].append(val_recon_loss) 161 | results['val_infeat_loss'].append(val_infeat_loss) 162 | results['val_outfeat_loss'].append(val_outfeat_loss) 163 | results['val_kld_loss'].append(val_kld_loss) 164 | np.save(ckpt_dir + model_base + '_results.npy', results) 165 | 166 | if val_loss <= best_loss: 167 | best_loss = val_loss 168 | save_checkpoint(model, optimizer, epoch, model_base+'_best.pt', ckpt_dir) 169 | utils.save_image(torch.cat([val_true[:10], val_pred[:10]], 0), 170 | samp_dir+model_base+f"_val_best.png",nrow=10, 171 | normalize=True, value_range=(-1, 1)) 172 | 173 | save_checkpoint(model, optimizer, epoch, model_base+'_last.pt', ckpt_dir) 174 | utils.save_image(torch.cat([val_true[:10], val_pred[:10]], 0), 175 | 
samp_dir+model_base+f"_val_last.png",nrow=10, 176 | normalize=True, value_range=(-1, 1)) 177 | -------------------------------------------------------------------------------- /train_feature_decoding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from torch.utils.data import DataLoader 5 | from torchvision import utils 6 | import numpy as np 7 | from utils4image.dataloaders import NSDFeatureDataset 8 | from utils4image.models import Brain2FeatureMeshPool 9 | from utils4image.utils import load_generator, save_checkpoint 10 | import argparse 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument("--subject", type=int, default=1) 16 | parser.add_argument("--batchsize", type=int, default=32) 17 | parser.add_argument("--lr", type=float, default=1e-3) 18 | parser.add_argument("--epoch", type=int, default=100) 19 | parser.add_argument("--mapping", type=str, default='mesh') 20 | parser.add_argument("--optim", type=str, default='adamw') 21 | parser.add_argument("--dataset", type=str, default='imagenet') 22 | parser.add_argument("--indim", type=int, default=16, help='in dim for the first mesh conv') 23 | parser.add_argument("--fdim", type=int, default=16, help='hidden dim for mesh conv') 24 | parser.add_argument("--n_hidden_layer", type=int, default=3, help='number of hidden layers') 25 | parser.add_argument("--decay", type=float, default=0.01, help='weight decay') 26 | parser.add_argument("--dropout", type=float, default=0.5, help='dropout rate') 27 | args = parser.parse_args() 28 | 29 | device = "cuda" 30 | subject = args.subject 31 | lr = args.lr 32 | n_epochs = args.epoch 33 | optim = args.optim 34 | mapping = args.mapping 35 | dataset = args.dataset 36 | decay = args.decay 37 | dropout_rate = args.dropout 38 | fdim = args.fdim 39 | indim = args.indim 40 | n_hidden_layer = args.n_hidden_layer 41 | batch_size = args.batchsize 42 | ckpt_dir = f'./decoding_ckpt/S{subject}/feature_decoding/' 43 | samp_dir = f'./decoding_sample/S{subject}/feature_decoding/' 44 | if not os.path.exists(ckpt_dir): 45 | os.makedirs(ckpt_dir) 46 | if not os.path.exists(samp_dir): 47 | os.makedirs(samp_dir) 48 | 49 | model_base = '%s_%s_lr%s_dc%s_dp%s_fd%d_ind%d_layer%d'%(mapping, optim, "{:.0e}".format(lr),"{:.0e}".format(decay), "{:.0e}".format(dropout_rate), fdim, indim, n_hidden_layer) 50 | 51 | train_dataset = NSDFeatureDataset(mode='train', test_subject=subject) 52 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) 53 | val_dataset = NSDFeatureDataset(mode='val', test_subject=subject) 54 | val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False) 55 | 56 | model = Brain2FeatureMeshPool(fdim=fdim, indim=indim, n_hidden_layer=n_hidden_layer, dropout_rate=dropout_rate).to(device) 57 | proj_dir = '/home/zg243/image_generation/ic_gan/' 58 | generator = load_generator(f'icgan_biggan_{dataset}_res256', proj_dir+'pretrained_models', 'biggan', device) 59 | 60 | criterion = nn.MSELoss() 61 | optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=decay) 62 | 63 | def train_step(data, target): 64 | model.train() 65 | data, target = data.to(device), target.to(device) 66 | prediction = model(data) 67 | loss = criterion(prediction, target) 68 | optimizer.zero_grad() 69 | loss.backward() 70 | optimizer.step() 71 | return loss.item() 72 | 73 | def val_step(): 74 | model.eval() 75 | pred, true, val_loss = [], [], 0 76 | for 
idx, (data, target) in enumerate(val_loader): 77 | data, target = data.to(device), target.to(device) 78 | prediction = model(data) 79 | loss = criterion(prediction, target) 80 | 81 | val_loss += loss.item() 82 | pred.append(prediction.detach().cpu()) 83 | true.append(target.detach().cpu()) 84 | 85 | return torch.vstack(pred), torch.vstack(true), val_loss/(idx+1) 86 | 87 | results = {'train_loss':[], 'val_loss':[]} 88 | best_loss = 1e5 89 | for epoch in range(n_epochs): 90 | # training 91 | total_loss = 0 92 | 93 | for batch_idx, (data, target) in enumerate(train_loader): 94 | loss = train_step(data, target) 95 | total_loss += loss 96 | 97 | print("Training [{}:{}/{}] LOSS={:.2} ={:.2}".format( 98 | epoch, batch_idx, len(train_loader), loss, total_loss / (batch_idx + 1))) 99 | 100 | if batch_idx % 100 == 0: 101 | results['train_loss'].append(total_loss / (batch_idx + 1)) 102 | # val 103 | val_pred, val_true, val_loss = val_step() 104 | print("Validation [Epoch {} Test] ={:.2}".format(epoch, val_loss)) 105 | results['val_loss'].append(val_loss) 106 | np.save(ckpt_dir + model_base + '_results.npy', results) 107 | 108 | # generate samples 109 | zs = torch.empty(10, generator.dim_z,).normal_(mean=0, std=1.0).to(device) 110 | true_img = generator(zs, None, val_true[:10].to(device)).detach().cpu() 111 | pred_img = generator(zs, None, val_pred[:10].to(device)).detach().cpu() 112 | 113 | if val_loss <= best_loss: 114 | best_loss = val_loss 115 | save_checkpoint(model, optimizer, epoch+1, model_base+'_best.pt', ckpt_dir) 116 | utils.save_image(torch.cat([true_img[:10], pred_img[:10]], 0), 117 | samp_dir+model_base+f"_val_best.png",nrow=10, 118 | normalize=True, value_range=(-1, 1)) 119 | 120 | save_checkpoint(model, optimizer, epoch, model_base+'_last.pt', ckpt_dir) 121 | utils.save_image(torch.cat([true_img[:10], pred_img[:10]], 0), 122 | samp_dir+model_base+f"_val_last.png",nrow=10, 123 | normalize=True, value_range=(-1, 1)) 124 | -------------------------------------------------------------------------------- /utils4image/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /utils4image/config.py: -------------------------------------------------------------------------------- 1 | mapping = 'meshpool' 2 | optim = 'adamw' 3 | lr = 1e-4 4 | decay = 0.1 5 | dropout_rate = 0.5 6 | fdim = 32 7 | indim = 32 8 | n_hidden_layer = 3 9 | recon_w = 1e-7 10 | kld_w = 1e-8 11 | annealing_epochs = 10 12 | kld_start_epoch = 0 13 | in_feature_w = 0 14 | out_feature_w = 1 15 | b2f_fix = True 16 | variation = True 17 | combined_type = 'variation' if variation else 'direct' 18 | factor = 1 19 | 20 | feature_model_base = '%s_%s_lr%s_dc%s_dp%s_fd%d_ind%d_layer%d'%(mapping, optim, 21 | "{:.0e}".format(1e-3),"{:.0e}".format(decay), 22 | "{:.0e}".format(dropout_rate),fdim, indim, n_hidden_layer) 23 | 24 | combined_model_base = '%s_%s_lr%s_dc%s_dp%s_fd%d_ind%d_layer%d_rw%s_ifw%s_ofw%s_kldw%s_ae%d_kldse%d_vsf%s_fixb2f'% \ 25 | (mapping, optim, "{:.0e}".format(lr),"{:.0e}".format(decay), "{:.0e}".format(dropout_rate), 26 | fdim, indim, n_hidden_layer, "{:.0e}".format(recon_w), "{:.0e}".format(in_feature_w), 27 | "{:.0e}".format(out_feature_w), "{:.0e}".format(kld_w), annealing_epochs, kld_start_epoch, "{:.0e}".format(factor)) 28 | -------------------------------------------------------------------------------- /utils4image/dataloaders.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Dataset 3 | from torchvision import transforms 4 | import numpy as np 5 | import h5py 6 | from .config import combined_model_base 7 | # import torch.nn.functional as F 8 | 9 | # import shutil 10 | # import os 11 | # from collections import OrderedDict 12 | # import torch.nn as nn 13 | 14 | class NSDFeatureDataset(Dataset): 15 | def __init__(self, level='individual', mode='train', test_subject=1, part='whole'): 16 | if level == 'individual': 17 | image_feats = np.load(f'./cocofeatures/S{test_subject}/image_features.npy') 18 | if part == 'whole': 19 | response_data = np.load(f'/home/zg243/SharedRep/data/S{test_subject}_surface32k_mean.npy', mmap_mode='r') 20 | elif part == 'visual': 21 | response_data = np.load(f'/home/zg243/SharedRep/data/S{test_subject}_visual16k_mean.npy', mmap_mode='r') 22 | if mode == 'test': 23 | responses = response_data[:1000] 24 | feats = image_feats[:1000] 25 | elif mode == 'train': 26 | responses = response_data[1000:9500] 27 | feats = image_feats[1000:9500] 28 | elif mode == 'val': 29 | responses = response_data[9500:] 30 | feats = image_feats[9500:] 31 | self.responses = responses 32 | self.feats = feats 33 | self.n_neurons = self.responses.shape[-1] 34 | 35 | def __len__(self): 36 | 'Denotes the total number of samples' 37 | return len(self.responses) 38 | 39 | def __getitem__(self, index): 40 | 'Generates one sample of data' 41 | X = np.asarray(self.responses[index]).astype(np.float32) 42 | y = np.asarray(self.feats[index]).astype(np.float32) 43 | return X, y 44 | 45 | 46 | preprocess = transforms.Compose([ 47 | transforms.ToTensor(), 48 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 49 | ]) 50 | class NSDImageDataset(Dataset): 51 | def __init__(self, level='individual', mode='train', test_subject=1, part='whole'): 52 | if level == 'individual': 53 | image_data_set = h5py.File(f"/home/zg243/nsd/stimuli/S{test_subject}_stimuli_256.h5py", 'r') 54 | image_data = np.copy(image_data_set['stimuli']).astype(np.float32) / 255. 
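            # Stimuli are stored in the 0-255 range in the HDF5 file; dividing by 255 maps them to [0, 1].
            # The `preprocess` transform above then rescales them to [-1, 1], matching the generator's
            # output range (see value_range=(-1, 1) in the training scripts).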
55 | image_data_set.close() 56 | image_feats = np.load(f'./cocofeatures/S{test_subject}/image_features.npy') 57 | if part == 'whole': 58 | response_data = np.load(f'/home/zg243/SharedRep/data/S{test_subject}_surface32k_mean.npy', mmap_mode='r') 59 | elif part == 'visual': 60 | response_data = np.load(f'/home/zg243/SharedRep/data/S{test_subject}_visual16k_mean.npy', mmap_mode='r') 61 | if mode == 'test': 62 | responses = response_data[:1000] 63 | images = image_data[:1000] 64 | feats = image_feats[:1000] 65 | elif mode == 'train': 66 | responses = response_data[1000:9500] 67 | images = image_data[1000:9500] 68 | feats = image_feats[1000:9500] 69 | elif mode == 'val': 70 | responses = response_data[9500:] 71 | images = image_data[9500:] 72 | feats = image_feats[9500:] 73 | self.responses = responses 74 | self.feats = feats 75 | images = np.moveaxis(images, 1, -1) # (n, 3, 256, 256) -> (n, 256, 256, 3) 76 | self.images = images 77 | self.n_neurons = self.responses.shape[-1] 78 | 79 | def __len__(self): 80 | 'Denotes the total number of samples' 81 | return len(self.responses) 82 | 83 | def __getitem__(self, index): 84 | 'Generates one sample of data' 85 | X = np.asarray(self.responses[index]).astype(np.float32) 86 | y_img = np.asarray(self.images[index]).astype(np.float32) 87 | y_feat = np.asarray(self.feats[index]).astype(np.float32) 88 | y_img = preprocess(y_img) 89 | return X, y_img, y_feat 90 | -------------------------------------------------------------------------------- /utils4image/eva_utils.py: -------------------------------------------------------------------------------- 1 | from .models import Brain2FeatureMeshPool, Brain2Image, Brain2NoiseVarMeshPool 2 | import os 3 | import torch 4 | from torch.nn.parameter import Parameter 5 | 6 | def load_meshmodel(subject, element, fdim, indim, n_hidden_layer, dropout_rate, restore_file, variation=True, mtype='best', device=torch.device('cpu')): 7 | prefix = f'./decoding_ckpt/S{subject}/{element}_decoding/' 8 | if element == 'feature': 9 | model = Brain2FeatureMeshPool(fdim=fdim, indim=indim, n_hidden_layer=n_hidden_layer, dropout_rate=dropout_rate) 10 | elif element == 'image': 11 | b2f = Brain2FeatureMeshPool(fdim=fdim, indim=indim, n_hidden_layer=n_hidden_layer, dropout_rate=dropout_rate) 12 | if variation: 13 | b2z = Brain2NoiseVarMeshPool(indim=indim, fdim=fdim, n_hidden_layer=n_hidden_layer, factor=1) 14 | model = Brain2Image(b2f, b2z, b2f_fix=True, generator_fix=True, variation=variation) 15 | prefix = prefix + 'variation/' if variation else prefix 16 | restore_path = os.path.join(prefix, restore_file + f'_{mtype}.pt') 17 | resume_dict = torch.load(restore_path, map_location=torch.device('cpu')) 18 | state_dict = resume_dict['state_dict'] 19 | print(resume_dict['epoch']) 20 | model_state = model.state_dict() 21 | for name, param in state_dict.items(): 22 | if name not in model_state: 23 | continue 24 | if 'none' in name: 25 | continue 26 | if isinstance(param, Parameter): 27 | # backwards compatibility for serialized parameters 28 | param = param.data 29 | model_state[name].copy_(param) 30 | model.eval() 31 | return model.to(device) 32 | -------------------------------------------------------------------------------- /utils4image/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import os 4 | #from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | from .utils import load_generator, featurespace_loss, load_swav 7 | from 
.ugscnn_utils import * 8 | #import math 9 | #import pickle 10 | #from scipy import sparse 11 | #import numpy as np 12 | from torch.autograd import Variable 13 | 14 | class Brain2FeatureMeshPool(nn.Module): 15 | def __init__(self, mesh_dir='/home/zg243/SharedRep/mesh_assets/meshes/100-500-2k-8k-32k/', 16 | in_ch=2, indim=1, fdim=16, max_level=4, embed_dim=2048, n_hidden_layer=3, 17 | vertices=['100','500','2k','8k', '32k'], dropout_rate=0.5): 18 | super().__init__() 19 | self.mesh_dir = mesh_dir 20 | self.vertices = vertices 21 | self.fdim = fdim 22 | self.embed_dim = embed_dim 23 | self.n_hidden_layer = n_hidden_layer 24 | self.levels = max_level 25 | self.in_conv = MeshConv(in_ch, indim, self.__meshfile(max_level), stride=1) 26 | self.in_bn = nn.BatchNorm1d(indim) 27 | self.relu = nn.ReLU(inplace=True) 28 | self.in_block = nn.Sequential(self.in_conv, self.in_bn, self.relu) 29 | self.block1 = Down(indim, fdim*4, max_level-1, mesh_dir, vertices) 30 | self.block2 = Down(fdim*4, fdim*16, max_level-2, mesh_dir, vertices) 31 | self.block3 = Down(fdim*16, fdim*64, max_level-3, mesh_dir, vertices) 32 | self.block4 = Down(fdim*64, fdim*128, max_level-4, mesh_dir, vertices) 33 | if n_hidden_layer == 3: 34 | self.blocks = [self.block1, self.block2, self.block3] 35 | else: 36 | self.blocks = [self.block1, self.block2, self.block3, self.block4] 37 | self.avg = nn.AvgPool1d(kernel_size=self.blocks[-1].conv.nv_prev) # output shape batch x channels x 1 38 | self.out_layer = nn.Linear((2**(n_hidden_layer+3))*fdim, embed_dim) 39 | self.out_bn = nn.BatchNorm1d(embed_dim) 40 | self.out_block = nn.Sequential(self.out_layer, self.out_bn, self.relu) 41 | self.dropout_rate = dropout_rate 42 | 43 | def forward(self, x): 44 | x = self.in_block(x) 45 | for block in self.blocks: 46 | x = block(x) 47 | x = torch.squeeze(self.avg(x), dim=-1) 48 | x = F.dropout(x, p=self.dropout_rate, training=self.training) 49 | x = self.out_block(x) 50 | x = F.normalize(x, dim=1, p=2) 51 | return x 52 | 53 | def __meshfile(self, i): 54 | return os.path.join(self.mesh_dir, "icosphere_{}.pkl".format(self.vertices[i])) 55 | 56 | 57 | class Brain2NoiseMeshPool(nn.Module): 58 | def __init__(self, mesh_dir='/home/zg243/SharedRep/mesh_assets/meshes/100-500-2k-8k-32k/', 59 | in_ch=2, indim=2, fdim=16, n_hidden_layer=3, max_level=4, 60 | vertices=['100','500','2k','8k', '32k'], dropout_rate=0.5, latent_dim=119): 61 | super().__init__() 62 | self.mesh_dir = mesh_dir 63 | self.vertices = vertices 64 | self.fdim = fdim 65 | self.indim = indim 66 | self.n_hidden_layer = n_hidden_layer 67 | self.levels = max_level 68 | self.in_conv = MeshConv(in_ch, indim, self.__meshfile(max_level), stride=1) 69 | self.in_bn = nn.BatchNorm1d(indim) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.in_block = nn.Sequential(self.in_conv, self.in_bn, self.relu) 72 | self.block1 = Down(indim, fdim * 4, max_level - 1, mesh_dir, vertices) 73 | self.block2 = Down(fdim * 4, fdim * 16, max_level - 2, mesh_dir, vertices) 74 | self.block3 = Down(fdim * 16, fdim * 64, max_level - 3, mesh_dir, vertices) 75 | self.block4 = Down(fdim * 64, fdim * 128, max_level - 4, mesh_dir, vertices) 76 | if n_hidden_layer == 3: 77 | self.blocks = [self.block1, self.block2, self.block3] 78 | else: 79 | self.blocks = [self.block1, self.block2, self.block3, self.block4] 80 | self.avg = nn.AvgPool1d(kernel_size=self.block3.conv.nv_prev) # output shape batch x channels x 1 81 | self.dropout_rate = dropout_rate 82 | 83 | self.out_layer = nn.Linear((2**(n_hidden_layer+3)) * fdim, latent_dim) 84 | 
self.z_mag = np.sqrt(2)*1.386831e80/1.801679e79 # \sqrt(2)*Gamma(60)/Gamma(59.5) 85 | 86 | def forward(self, x): 87 | x = self.in_block(x) 88 | for block in self.blocks: 89 | x = block(x) 90 | x = torch.squeeze(self.avg(x), dim=-1) 91 | x = F.dropout(x, p=self.dropout_rate, training=self.training) 92 | z = self.out_layer(x) 93 | z = self.z_mag*F.normalize(z, p=2) 94 | return z 95 | 96 | def __meshfile(self, i): 97 | return os.path.join(self.mesh_dir, "icosphere_{}.pkl".format(self.vertices[i])) 98 | 99 | class Brain2NoiseVarMeshPool(nn.Module): 100 | def __init__(self, mesh_dir='/home/zg243/SharedRep/mesh_assets/meshes/100-500-2k-8k-32k/', 101 | in_ch=2, indim=2, fdim=16, n_hidden_layer=3, max_level=4, 102 | vertices=['100','500','2k','8k', '32k'], dropout_rate=0.5, latent_dim=119, factor=1): 103 | super().__init__() 104 | self.mesh_dir = mesh_dir 105 | self.vertices = vertices 106 | self.fdim = fdim 107 | self.indim = indim 108 | self.n_hidden_layer = n_hidden_layer 109 | self.levels = max_level 110 | self.in_conv = MeshConv(in_ch, indim, self.__meshfile(max_level), stride=1) 111 | self.in_bn = nn.BatchNorm1d(indim) 112 | self.relu = nn.ReLU(inplace=True) 113 | self.in_block = nn.Sequential(self.in_conv, self.in_bn, self.relu) 114 | self.block1 = Down(indim, fdim * 4, max_level - 1, mesh_dir, vertices) 115 | self.block2 = Down(fdim * 4, fdim * 16, max_level - 2, mesh_dir, vertices) 116 | self.block3 = Down(fdim * 16, fdim * 64, max_level - 3, mesh_dir, vertices) 117 | self.block4 = Down(fdim * 64, fdim * 128, max_level - 4, mesh_dir, vertices) 118 | if n_hidden_layer == 3: 119 | self.blocks = [self.block1, self.block2, self.block3] 120 | else: 121 | self.blocks = [self.block1, self.block2, self.block3, self.block4] 122 | self.avg = nn.AvgPool1d(kernel_size=self.block3.conv.nv_prev) # output shape batch x channels x 1 123 | self.dropout_rate = dropout_rate 124 | 125 | self.fc1 = nn.Linear((2**(n_hidden_layer+3)) * fdim, latent_dim) 126 | self.fc2 = nn.Linear((2**(n_hidden_layer+3)) * fdim, latent_dim) 127 | self.chisq_mean = 10.886 # \sqrt(2)*Gamma(60)/Gamma(59.5) 128 | self.radius = 12.9 # R: sqrt( qchisq( pchisq( 9,df=1 ), df = 119)) 129 | self.factor = factor 130 | 131 | def encode(self, x): 132 | x = self.in_block(x) 133 | for block in self.blocks: 134 | x = block(x) 135 | x = torch.squeeze(self.avg(x), dim=-1) 136 | x = F.dropout(x, p=self.dropout_rate, training=self.training) 137 | mu = self.fc1(x) 138 | #mu = self.chisq_mean*F.normalize(mu, p=2) # scale mean to be on the hypersphere 139 | logvar = self.fc2(x) 140 | logvar -= torch.log(torch.tensor(self.factor)**2) 141 | return mu, logvar 142 | 143 | def reparametrize(self, mu, logvar): 144 | std = logvar.mul(0.5).exp_() 145 | if torch.cuda.is_available(): 146 | eps = torch.cuda.FloatTensor(std.size()).normal_() 147 | else: 148 | eps = torch.FloatTensor(std.size()).normal_() 149 | eps = Variable(eps) 150 | return eps.mul(std).add_(mu) 151 | 152 | def forward(self, x): 153 | mu, logvar = self.encode(x) 154 | z = self.reparametrize(mu, logvar) 155 | return z, mu, logvar 156 | 157 | def __meshfile(self, i): 158 | return os.path.join(self.mesh_dir, "icosphere_{}.pkl".format(self.vertices[i])) 159 | 160 | 161 | class Brain2Image(nn.Module): 162 | def __init__(self, b2f, b2z, dataset='imagenet', b2f_fix=True, generator_fix=True, variation=False): 163 | super().__init__() 164 | self.brain2feature = b2f 165 | self.brain2noise = b2z 166 | self.generator = load_generator(f'icgan_biggan_{dataset}_res256', 
'/home/zg243/image_generation/ic_gan/pretrained_models', 'biggan') 167 | self.feature_extractor = load_swav() 168 | self.variation = variation 169 | if b2f_fix: 170 | self.fix_weights(self.brain2feature) 171 | if generator_fix: 172 | self.fix_weights(self.generator) 173 | 174 | def forward(self, x): 175 | feature = self.brain2feature(x) 176 | 177 | if self.variation: 178 | z, mu, logvar = self.brain2noise(x) 179 | y_hat = self.generator(z, None, feature) 180 | return y_hat, feature, mu, logvar 181 | else: 182 | z = self.brain2noise(x) 183 | y_hat = self.generator(z, None, feature) 184 | return y_hat, feature, 0, 0 185 | 186 | def fix_weights(self, block): 187 | for param in block.parameters(): 188 | param.requires_grad = False 189 | 190 | def compute_loss(self, y, y_hat, in_feature, in_feature_hat, mu=None, logvar=None, recon_w=1, in_feature_w=1, out_feature_w=1, kld_w=0): 191 | recon_loss = F.mse_loss(y, y_hat) 192 | in_feature_loss = F.mse_loss(in_feature, in_feature_hat) 193 | out_feature_loss = featurespace_loss(in_feature, y_hat, self.feature_extractor) 194 | if self.variation: 195 | kld_loss = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim = 1), dim = 0) 196 | else: 197 | kld_loss = torch.zeros(recon_loss.size()) 198 | loss = recon_w*recon_loss + in_feature_w*in_feature_loss + out_feature_w*out_feature_loss + kld_w*kld_loss 199 | return loss, (recon_loss, in_feature_loss, out_feature_loss, kld_loss) 200 | 201 | -------------------------------------------------------------------------------- /utils4image/ugscnn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import os 4 | from torch.nn.parameter import Parameter 5 | import torch.nn.functional as F 6 | import math 7 | import pickle 8 | from scipy import sparse 9 | import numpy as np 10 | 11 | 12 | class Down(nn.Module): 13 | def __init__(self, in_ch, out_ch, level, mesh_dir, vertices, bias=True): 14 | super().__init__() 15 | """use mesh_file to perform convolution to the next coarser resolution mesh""" 16 | self.conv = ResPoolBlock(in_ch, in_ch, out_ch, level+1, True, mesh_dir, vertices) 17 | 18 | def forward(self, x): 19 | x = self.conv(x) 20 | return x 21 | 22 | class Down_noReLU(nn.Module): 23 | def __init__(self, in_ch, out_ch, level, mesh_dir, vertices, bias=True): 24 | super().__init__() 25 | """use mesh_file to perform convolution to the next coarser resolution mesh""" 26 | self.conv = ResPoolBlock_noReLU(in_ch, in_ch, out_ch, level+1, True, mesh_dir, vertices) 27 | 28 | def forward(self, x): 29 | x = self.conv(x) 30 | return x 31 | 32 | class _MeshConv(nn.Module): 33 | def __init__(self, in_channels, out_channels, mesh_file, stride=1, bias=True): 34 | assert stride in [1, 2] 35 | super(_MeshConv, self).__init__() 36 | self.in_channels = in_channels 37 | self.out_channels = out_channels 38 | if bias: 39 | self.bias = Parameter(torch.Tensor(out_channels)) 40 | else: 41 | self.register_parameter('bias', None) 42 | self.ncoeff = 4 43 | self.coeffs = Parameter(torch.Tensor(out_channels, in_channels, self.ncoeff)) 44 | self.set_coeffs() 45 | 46 | pkl = pickle.load(open(mesh_file, "rb")) 47 | self.pkl = pkl 48 | self.nv = self.pkl['V'].shape[0] 49 | G = sparse2tensor(pkl['G']) # gradient matrix V->F, 3#F x #V 50 | NS = torch.tensor(pkl['NS'], dtype=torch.float32) # north-south vector field, #F x 3 51 | EW = torch.tensor(pkl['EW'], dtype=torch.float32) # east-west vector field, #F x 3 52 | self.register_buffer("G", G) 53 | 
self.register_buffer("NS", NS) 54 | self.register_buffer("EW", EW) 55 | 56 | def set_coeffs(self): 57 | n = self.in_channels * self.ncoeff 58 | stdv = 1. / math.sqrt(n) 59 | self.coeffs.data.uniform_(-stdv, stdv) 60 | if self.bias is not None: 61 | self.bias.data.uniform_(-stdv, stdv) 62 | 63 | class MeshConv(_MeshConv): 64 | def __init__(self, in_channels, out_channels, mesh_file, stride=1, bias=True): 65 | super(MeshConv, self).__init__(in_channels, out_channels, mesh_file, stride, bias) 66 | pkl = self.pkl 67 | if stride == 2: 68 | self.nv_prev = pkl['nv_prev'] 69 | L = sparse2tensor(pkl['L'].tocsr()[:self.nv_prev].tocoo()) # laplacian matrix V->V 70 | F2V = sparse2tensor(pkl['F2V'].tocsr()[:self.nv_prev].tocoo()) # F->V, #V x #F 71 | else: # stride == 1 72 | self.nv_prev = pkl['V'].shape[0] 73 | L = sparse2tensor(pkl['L'].tocoo()) 74 | F2V = sparse2tensor(pkl['F2V'].tocoo()) 75 | self.register_buffer("L", L) 76 | self.register_buffer("F2V", F2V) 77 | 78 | def forward(self, input): 79 | # compute gradient 80 | grad_face = spmatmul(input, self.G) 81 | grad_face = grad_face.view(*(input.size()[:2]), 3, -1).permute(0, 1, 3, 2) # gradient, 3 component per face 82 | laplacian = spmatmul(input, self.L) 83 | identity = input[..., :self.nv_prev] 84 | grad_face_ew = torch.sum(torch.mul(grad_face, self.EW), keepdim=False, dim=-1) 85 | grad_face_ns = torch.sum(torch.mul(grad_face, self.NS), keepdim=False, dim=-1) 86 | grad_vert_ew = spmatmul(grad_face_ew, self.F2V) 87 | grad_vert_ns = spmatmul(grad_face_ns, self.F2V) 88 | 89 | feat = [identity, laplacian, grad_vert_ew, grad_vert_ns] 90 | 91 | out = torch.stack(feat, dim=-1) 92 | out = torch.sum(torch.sum(torch.mul(out.unsqueeze(1), self.coeffs.unsqueeze(2)), dim=2), dim=-1) 93 | out += self.bias.unsqueeze(-1) 94 | return out 95 | 96 | class ResPoolBlock_noReLU(nn.Module): 97 | def __init__(self, in_chan, neck_chan, out_chan, level, coarsen, mesh_dir, vertices): 98 | super().__init__() 99 | l = level-1 if coarsen else level 100 | self.coarsen = coarsen 101 | mesh_file = os.path.join(mesh_dir, "icosphere_{}.pkl".format(vertices[l])) 102 | self.conv1 = nn.Conv1d(in_chan, neck_chan, kernel_size=1, stride=1) 103 | self.conv2 = MeshConv(neck_chan, neck_chan, mesh_file=mesh_file, stride=1) 104 | self.conv3 = nn.Conv1d(neck_chan, out_chan, kernel_size=1, stride=1) 105 | self.relu = nn.ReLU(inplace=True) 106 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 107 | self.bn1 = nn.BatchNorm1d(neck_chan) 108 | self.bn2 = nn.BatchNorm1d(neck_chan) 109 | self.bn3 = nn.BatchNorm1d(out_chan) 110 | self.nv_prev = self.conv2.nv_prev 111 | self.pool = MaxPool(mesh_dir, level, vertices) 112 | self.diff_chan = (in_chan != out_chan) 113 | 114 | if coarsen: 115 | self.seq1 = nn.Sequential(self.conv1, self.pool, self.bn1, 116 | self.conv2, self.bn2, 117 | self.conv3, self.bn3) 118 | else: 119 | self.seq1 = nn.Sequential(self.conv1, self.bn1, 120 | self.conv2, self.bn2, 121 | self.conv3, self.bn3) 122 | 123 | if self.diff_chan or coarsen: 124 | self.conv_ = nn.Conv1d(in_chan, out_chan, kernel_size=1, stride=1) 125 | self.bn_ = nn.BatchNorm1d(out_chan) 126 | if coarsen: 127 | self.seq2 = nn.Sequential(self.conv_, self.pool, self.bn_) 128 | else: 129 | self.seq2 = nn.Sequential(self.conv_, self.bn_) 130 | 131 | def forward(self, x): 132 | if self.diff_chan or self.coarsen: 133 | x2 = self.seq2(x) 134 | else: 135 | x2 = x 136 | x1 = self.seq1(x) 137 | out = x1 + x2 138 | #out = self.leaky_relu(out) 139 | return out 140 | 141 | 142 | class 
ResPoolBlock(nn.Module): 143 | def __init__(self, in_chan, neck_chan, out_chan, level, coarsen, mesh_dir, vertices): 144 | super().__init__() 145 | l = level-1 if coarsen else level 146 | self.coarsen = coarsen 147 | mesh_file = os.path.join(mesh_dir, "icosphere_{}.pkl".format(vertices[l])) 148 | self.conv1 = nn.Conv1d(in_chan, neck_chan, kernel_size=1, stride=1) 149 | self.conv2 = MeshConv(neck_chan, neck_chan, mesh_file=mesh_file, stride=1) 150 | self.conv3 = nn.Conv1d(neck_chan, out_chan, kernel_size=1, stride=1) 151 | self.relu = nn.ReLU(inplace=True) 152 | self.leaky_relu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 153 | self.bn1 = nn.BatchNorm1d(neck_chan) 154 | self.bn2 = nn.BatchNorm1d(neck_chan) 155 | self.bn3 = nn.BatchNorm1d(out_chan) 156 | self.nv_prev = self.conv2.nv_prev 157 | self.pool = MaxPool(mesh_dir, level, vertices) 158 | self.diff_chan = (in_chan != out_chan) 159 | 160 | if coarsen: 161 | self.seq1 = nn.Sequential(self.conv1, self.pool, self.bn1, self.relu, 162 | self.conv2, self.bn2, self.relu, 163 | self.conv3, self.bn3) 164 | else: 165 | self.seq1 = nn.Sequential(self.conv1, self.bn1, self.relu, 166 | self.conv2, self.bn2, self.relu, 167 | self.conv3, self.bn3) 168 | 169 | if self.diff_chan or coarsen: 170 | self.conv_ = nn.Conv1d(in_chan, out_chan, kernel_size=1, stride=1) 171 | self.bn_ = nn.BatchNorm1d(out_chan) 172 | if coarsen: 173 | self.seq2 = nn.Sequential(self.conv_, self.pool, self.bn_) 174 | else: 175 | self.seq2 = nn.Sequential(self.conv_, self.bn_) 176 | 177 | def forward(self, x): 178 | if self.diff_chan or self.coarsen: 179 | x2 = self.seq2(x) 180 | else: 181 | x2 = x 182 | x1 = self.seq1(x) 183 | out = x1 + x2 184 | out = self.relu(out) 185 | return out 186 | 187 | 188 | class MaxPool(nn.Module): 189 | def __init__(self, mesh_dir, level, vertices): 190 | super().__init__() 191 | self.level = level 192 | 193 | if self.level > 0: 194 | vertices_to_prev_lvl_file = os.path.join(mesh_dir, "ico%s_to_ico%s_vertices.npy" % (vertices[level-1], vertices[level])) 195 | self.vertices_to_prev_lvl = np.load(vertices_to_prev_lvl_file) 196 | 197 | neihboring_patches_file = os.path.join(mesh_dir, "ico%s_neighbor_patches.npy" % (vertices[level-1])) 198 | self.neihboring_patches = np.load(neihboring_patches_file) 199 | 200 | def forward(self, x): 201 | tmp = x[..., self.vertices_to_prev_lvl] 202 | out, indices = torch.max(tmp[:, :, self.neihboring_patches], -1) 203 | return out 204 | 205 | """ 206 | from https://github.com/maxjiang93/ugscnn 207 | """ 208 | 209 | def sparse2tensor(m): 210 | """ 211 | Convert sparse matrix (scipy.sparse) to tensor (torch.sparse) 212 | """ 213 | assert(isinstance(m, sparse.coo.coo_matrix)) 214 | i = torch.LongTensor([m.row, m.col]) 215 | v = torch.FloatTensor(m.data) 216 | return torch.sparse.FloatTensor(i, v, torch.Size(m.shape)) 217 | 218 | def spmatmul(den, sp): 219 | """ 220 | den: Dense tensor of shape batch_size x in_chan x #V 221 | sp : Sparse tensor of shape newlen x #V 222 | """ 223 | batch_size, in_chan, nv = list(den.size()) 224 | new_len = sp.size()[0] 225 | den = den.permute(2, 1, 0).contiguous().view(nv, -1) 226 | res = torch.spmm(sp, den).view(new_len, in_chan, batch_size).contiguous().permute(2, 1, 0) 227 | return res 228 | -------------------------------------------------------------------------------- /utils4image/utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append("/home/zg243/image_generation/ic_gan/") 3 | import os 4 | import 
BigGAN_PyTorch.utils as biggan_utils 5 | import inference.utils as inference_utils 6 | from collections import OrderedDict 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | 11 | def load_generator(exp_name, root_path, backbone, device="cpu"): 12 | parser = biggan_utils.prepare_parser() 13 | parser = biggan_utils.add_sample_parser(parser) 14 | parser = inference_utils.add_backbone_parser(parser) 15 | 16 | args = ["--experiment_name", exp_name] 17 | args += ["--base_root", root_path] 18 | args += ["--model_backbone", backbone] 19 | 20 | config = vars(parser.parse_args(args=args)) 21 | 22 | # Load model and overwrite configuration parameters if stored in the model 23 | config = biggan_utils.update_config_roots(config, change_weight_folder=False) 24 | generator, config = inference_utils.load_model_inference(config, device=device) 25 | biggan_utils.count_parameters(generator) 26 | generator.eval() 27 | 28 | return generator 29 | 30 | 31 | def save_checkpoint(model, optimizer, epoch, fname, output_dir): 32 | state_dict_no_sparse = [it for it in model.state_dict().items() if it[1].type() != "torch.cuda.sparse.FloatTensor"] 33 | state_dict_no_sparse = OrderedDict(state_dict_no_sparse) 34 | 35 | checkpoint = { 36 | 'epoch': epoch, 37 | 'state_dict': state_dict_no_sparse, 38 | #'scheduler': scheduler.state_dict(), 39 | 'optimizer': optimizer.state_dict(), 40 | } 41 | torch.save(checkpoint, os.path.join(output_dir, fname)) 42 | 43 | 44 | def featurespace_loss(true_feat, pred, extractor): 45 | """ 46 | Args: 47 | true_feat: the true image feature (no need to extract from image again) 48 | pred: the decoded image (need to extract feature first) 49 | extractor: swav layer extractor 50 | Returns: feature space loss 51 | """ 52 | pred_feat = extractor(pred).view(pred.shape[0], -1) # unnormalized 53 | pred_feat = F.normalize(pred_feat, dim=1, p=2) # normalized true feats 54 | return F.mse_loss(true_feat, pred_feat) 55 | 56 | 57 | def load_swav(): 58 | swav = torch.hub.load('facebookresearch/swav:main', 'resnet50') 59 | return nn.Sequential(*list(swav.children())[:-1]).eval() 60 | --------------------------------------------------------------------------------
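
A small sanity check (illustrative only, with random stand-in tensors) for the KLD term used in `compute_loss` in `utils4image/models.py`: the expression is the closed-form KL divergence from N(mu, diag(exp(logvar))) to a standard normal, summed over latent dimensions and averaged over the batch.

```python
# Illustrative check with synthetic tensors (not part of the repo).
import torch

torch.manual_seed(0)
mu = torch.randn(4, 8)       # stand-in latent means: batch of 4, latent dim 8
logvar = torch.randn(4, 8)   # stand-in log-variances

# The term exactly as written in compute_loss
kld = torch.mean(-0.5 * torch.sum(1 + logvar - mu ** 2 - logvar.exp(), dim=1), dim=0)

# Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)) with logvar = log(sigma^2)
ref = (0.5 * (mu ** 2 + logvar.exp() - 1 - logvar).sum(dim=1)).mean()

print(torch.allclose(kld, ref))  # True
```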
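
A minimal sketch of how `sparse2tensor` and `spmatmul` from `utils4image/ugscnn_utils.py` fit together. The sparse operator and features below are random stand-ins (the real `G`, `L`, and `F2V` operators come from the `icosphere_*.pkl` files), and the example assumes the repo root is on `PYTHONPATH` so the `utils4image` package imports cleanly.

```python
# Synthetic example only: apply a sparse V->V operator to (batch, channels, #V) features.
import numpy as np
import torch
from scipy import sparse

from utils4image.ugscnn_utils import sparse2tensor, spmatmul

nv = 5
L = sparse.random(nv, nv, density=0.4, format="coo", dtype=np.float32)  # stand-in for pkl['L']
L_t = sparse2tensor(L)          # scipy COO matrix -> torch sparse tensor

x = torch.randn(2, 3, nv)       # batch_size=2, in_chan=3, #V=5
out = spmatmul(x, L_t)          # -> (2, 3, 5): one sparse mat-mul shared across batch and channels

# Reference: the same product done densely, channel by channel
ref = torch.einsum("wv,bcv->bcw", torch.from_numpy(L.toarray()), x)
print(torch.allclose(out, ref, atol=1e-5))  # True
```

`MeshConv.forward` uses this same operation to obtain the per-vertex Laplacian and the east-west/north-south gradient components, which are then mixed per channel by the learned `coeffs`.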
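
Finally, a hedged usage sketch for the SwAV feature-space loss in `utils4image/utils.py`. Importing `utils4image.utils` requires the IC-GAN repo to be importable (step 1 of the README; the hard-coded `sys.path.append` at the top of the file may need to point at your own clone), and `load_swav()` downloads the SwAV ResNet-50 weights via `torch.hub`. The image and feature tensors below are random stand-ins, not real decoder outputs.

```python
import torch
import torch.nn.functional as F

from utils4image.utils import load_swav, featurespace_loss

extractor = load_swav()  # SwAV ResNet-50 with the final fc removed; outputs (B, 2048, 1, 1)

y_hat = torch.randn(2, 3, 224, 224)                        # stand-in for decoded images
true_feat = F.normalize(torch.randn(2, 2048), dim=1, p=2)  # stand-in for precomputed SwAV features

loss = featurespace_loss(true_feat, y_hat, extractor)      # MSE in normalized SwAV feature space
print(loss.item())
```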