├── data ├── prompt_template.txt ├── configure.py └── dataset_augment.py ├── caption └── example.json ├── assets ├── MindGPT.png └── brain2text.png ├── requirements.txt ├── utils.py ├── README.md ├── modules ├── fmriencoder.py ├── vit.py ├── pos_embed.py ├── gpt2.py └── brain2text.py ├── feature_extract.py ├── brain2text_train.py └── brain2text_infer.py /data/prompt_template.txt: -------------------------------------------------------------------------------- 1 | This image shows 2 | -------------------------------------------------------------------------------- /caption/example.json: -------------------------------------------------------------------------------- 1 | {"n01443537_10014.JPEG": "xxx", "n01443537_1002.JPEG": "xxx"} -------------------------------------------------------------------------------- /assets/MindGPT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JxuanC/MindGPT/HEAD/assets/MindGPT.png -------------------------------------------------------------------------------- /assets/brain2text.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JxuanC/MindGPT/HEAD/assets/brain2text.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | bdpy==0.19 2 | einops==0.8.0 3 | faiss_gpu==1.7.2 4 | h5py==3.10.0 5 | numpy==1.25.2 6 | packaging==24.1 7 | pandas==2.2.2 8 | Pillow==10.4.0 9 | timm==0.9.12 10 | torch==2.0.1 11 | torchvision==0.15.2 12 | tqdm==4.66.1 13 | transformers==4.35.2 -------------------------------------------------------------------------------- /data/configure.py: -------------------------------------------------------------------------------- 1 | kamitani_Tr_Aug = '' # setting candidate images path 2 | 3 | DIR_dataset_dir = '' # setting DIR dataset path 4 | 5 | DIR_train_subs = {# setting train sub path 6 | 'sub-1':f'{DIR_dataset_dir}/sub-01_perceptionNaturalImageTraining_VC_v2.h5', 7 | 'sub-2':f'{DIR_dataset_dir}/sub-02_perceptionNaturalImageTraining_VC_v2.h5', 8 | 'sub-3':f'{DIR_dataset_dir}/sub-03_perceptionNaturalImageTraining_VC_v2.h5' 9 | } 10 | kamitani_sti_trainID = "" # setting DIR sti_trainID 11 | 12 | kamitani_sti_testID = "" # setting DIR sti_testID 13 | 14 | smallCap_Kamitani_train = "caption/example.json" # setting SMALLCAP caption path 15 | 16 | DIR_test_subs = {# setting test sub path 17 | 'sub-1':f'{DIR_dataset_dir}/sub-01_perceptionNaturalImageTest_VC_v2.h5', 18 | 'sub-2':f'{DIR_dataset_dir}/sub-02_perceptionNaturalImageTest_VC_v2.h5', 19 | 'sub-3':f'{DIR_dataset_dir}/sub-03_perceptionNaturalImageTest_VC_v2.h5' 20 | } 21 | 22 | CLIPGPTFEATURE = 'features/imagenet.hdf5' # CLIP ViT-B/32 features for brain2text 23 | 24 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | CAPTION_LENGTH = 25 2 | SIMPLE_PREFIX = "This image shows " 3 | 4 | def prep_strings(text, tokenizer, template=None, retrieved_caps=None, 5 | k=None, is_test=False, max_length=None): 6 | 7 | if is_test: 8 | padding = False 9 | truncation = False 10 | else: 11 | padding = True 12 | truncation = True 13 | 14 | if retrieved_caps is not None: 15 | infix = '\n\n'.join(retrieved_caps[:k]) + '.'
16 | prefix = template.replace('||', infix) 17 | else: 18 | prefix = SIMPLE_PREFIX 19 | 20 | prefix_ids = tokenizer.encode(prefix) 21 | len_prefix = len(prefix_ids) 22 | 23 | text_ids = tokenizer.encode(text, add_special_tokens=False) 24 | if truncation: 25 | text_ids = text_ids[:CAPTION_LENGTH] 26 | input_ids = prefix_ids + text_ids if not is_test else prefix_ids 27 | 28 | # we ignore the prefix (minus one as the first subtoken in the prefix is not predicted) 29 | label_ids = [-100] * (len_prefix - 1) + text_ids + [tokenizer.eos_token_id] 30 | if padding: 31 | input_ids += [tokenizer.pad_token_id] * (max_length - len(input_ids)) 32 | label_ids += [-100] * (max_length - len(label_ids)) 33 | 34 | if is_test: 35 | return input_ids 36 | else: 37 | return input_ids, label_ids 38 | 39 | def postprocess_preds(pred, tokenizer): 40 | pred = pred.split(SIMPLE_PREFIX)[-1] 41 | pred = pred.replace(tokenizer.pad_token, '') 42 | if pred.startswith(tokenizer.bos_token): 43 | pred = pred[len(tokenizer.bos_token):] 44 | if pred.endswith(tokenizer.eos_token): 45 | pred = pred[:-len(tokenizer.eos_token)] 46 | return pred -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MindGPT: Interpreting What You See with Non-invasive Brain Recordings 2 | Official PyTorch implementation of MindGPT. 3 | 4 | ## News 5 | 6 | * 2024-07-11 7 | 8 | Code released. 9 | 10 | 11 | * 2023-09-28 12 | 13 | Preprint released. Code will be released soon! 14 | 15 | 16 | 17 | ## Overview 18 | ![MindGPT](assets/MindGPT.png) 19 | 20 | ## Samples 21 | ![brain2text results](assets/brain2text.png) 22 | 23 | ## Environment setup 24 | 1. `pip install -r requirements.txt` 25 | 26 | 2. Download the [DIR dataset](https://figshare.com/articles/dataset/Deep_Image_Reconstruction/7033577) (Kamitani Lab) and the [ImageNet dataset](https://image-net.org/). 27 | 28 | 3. Extract CLIP visual representations by running `feature_extract.py`, and use [SMALLCAP](https://github.com/RitaRamo/smallcap) to generate pseudo-label captions (see `caption/example.json` for the format; a quick sanity check of the extracted features is sketched below, after the Acknowledgement section). 29 | 30 | 4. Change the paths in `data/configure.py` to match your file locations. 31 | 32 | 33 | ## Training 34 | Hyper-parameters can be changed via command-line arguments: 35 | ``` 36 | python brain2text_train.py --n_epochs 20 --batch_size 128 37 | ``` 38 | 39 | ## Reconstruction with Trained Checkpoints 40 | ``` 41 | python brain2text_infer.py 42 | ``` 43 | 44 | ## Acknowledgement 45 | We thank the Kamitani Lab for making their raw and pre-processed data public. Our MindGPT implementation is based on [SMALLCAP](https://github.com/RitaRamo/smallcap). We thank these authors for making their code and checkpoints publicly available!
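A quick sanity check of the CLIP features extracted in setup step 3 (a minimal sketch, assuming the default save name used by `feature_extract.py`): each entry of `features/imagenet.hdf5` should be a `(50, 768)` CLIP ViT-B/32 token map, keyed by the candidate image's file name, which is how `data/dataset_augment.py` looks the features up.
```
import h5py

feats = h5py.File('features/imagenet.hdf5', 'r')  # written by feature_extract.py
key = next(iter(feats))                           # e.g. 'n01443537_10014.JPEG'
print(key, feats[key].shape)                      # expected: (50, 768) = [CLS] + 49 patch tokens, dim 768
```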
46 | 47 | ## Cite 48 | ``` 49 | @article{chen2023mindgpt, 50 | title={MindGPT: Interpreting What You See with Non-invasive Brain Recordings}, 51 | author={Jiaxuan Chen and Yu Qi and Yueming Wang and Gang Pan}, 52 | year={2023}, 53 | journal={arXiv preprint arXiv:2309.15729}, 54 | } 55 | ``` 56 | -------------------------------------------------------------------------------- /modules/fmriencoder.py: -------------------------------------------------------------------------------- 1 | from modules.vit import ViT 2 | from transformers.modeling_utils import PreTrainedModel 3 | from transformers.configuration_utils import PretrainedConfig 4 | import torch.nn as nn 5 | from timm.models.vision_transformer import PatchEmbed 6 | import torch 7 | 8 | class fMRIViTEncoderConfig(PretrainedConfig): 9 | model_type = "fMRIViTEncoder" 10 | 11 | def __init__( 12 | self, 13 | fmri_dim, rois_len, embed_dim, depth, num_heads, 14 | fmri2img = False, **kwargs, 15 | ): 16 | super().__init__(**kwargs) 17 | self.fmri_dim = fmri_dim 18 | self.rois_len = rois_len 19 | self.embed_dim = embed_dim 20 | self.depth = depth 21 | self.num_heads = num_heads 22 | self.hidden_size = embed_dim 23 | self.fmri2img = fmri2img 24 | 25 | class fMRIViTEncoder(PreTrainedModel): 26 | config_class = fMRIViTEncoderConfig 27 | def __init__(self, config): 28 | super(fMRIViTEncoder, self).__init__(config) 29 | if config.fmri2img: 30 | self.proj = nn.Linear(config.fmri_dim, 112 * 112 * 3) 31 | self.patch_embed = PatchEmbed(112, 16, 3, config.embed_dim) 32 | self.encoder = ViT(config.embed_dim, 49, config.embed_dim, config.depth, config.num_heads) 33 | else: 34 | self.encoder = ViT(config.fmri_dim, config.rois_len, config.embed_dim, config.depth, config.num_heads) 35 | self.config = config 36 | 37 | def forward(self, encoder_inputs, **kwargs): 38 | # encoder_inputs shape (batch, roi_num, roi_dim) 39 | if(self.config.fmri2img): 40 | encoder_inputs = self.proj(encoder_inputs) 41 | encoder_inputs = torch.reshape(encoder_inputs, (-1, 3, 112, 112)) 42 | encoder_inputs = self.patch_embed(encoder_inputs) 43 | else: 44 | encoder_outputs = self.encoder(encoder_inputs) 45 | return encoder_outputs -------------------------------------------------------------------------------- /feature_extract.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import pandas as pd 3 | import faiss 4 | import h5py 5 | from tqdm import tqdm 6 | from PIL import Image 7 | import numpy as np 8 | import torch 9 | import data.configure as config 10 | from torch.utils.data import Dataset, DataLoader 11 | from transformers import CLIPFeatureExtractor, CLIPVisionModel 12 | 13 | DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu" 14 | 15 | def extract_visual_feature(images_dir, batch_size, index_type, encoder_name, save_name, faiss_index = False): 16 | last_hidden_state = [] 17 | class_embedding = [] 18 | h5py_file = h5py.File('./features/{}.hdf5'.format(save_name), 'w') 19 | feature_extractor = CLIPFeatureExtractor.from_pretrained(encoder_name) 20 | clip_encoder = CLIPVisionModel.from_pretrained(encoder_name).to(DEVICE) 21 | for idx in tqdm(range(0, len(images_dir), batch_size)): 22 | imgids = images_dir[idx:idx + batch_size] 23 | images = [Image.open(file_name).convert("RGB") for file_name in imgids] 24 | with torch.no_grad(): 25 | pixel_values = feature_extractor(images, return_tensors='pt').pixel_values.to(DEVICE) 26 | encodings = clip_encoder(pixel_values=pixel_values).last_hidden_state.cpu().numpy() 27 | 
#last_hidden_state.append(encodings) 28 | class_embedding.append(encodings[:, 0, :]) 29 | for imgid, encoding in zip(imgids, encodings): 30 | h5py_file.create_dataset('n' + str(imgid).split('n')[-1], (50, 768), data = encoding) 31 | 32 | if(faiss_index): 33 | class_embedding = np.vstack(class_embedding) 34 | embedding_dimension = class_embedding.shape[1] 35 | embedding_nums = class_embedding.shape[0] 36 | 37 | index_type = 'dot' 38 | if index_type == "L2": 39 | cpu_index = faiss.IndexFlatL2(embedding_dimension) 40 | #gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) 41 | if index_type == "dot": 42 | cpu_index = faiss.IndexFlatIP(embedding_dimension) 43 | #gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) 44 | if index_type == "cosine": 45 | # cosine = normalize & dot 46 | faiss.normalize_L2(class_embedding) 47 | cpu_index = faiss.IndexFlatIP(embedding_dimension) 48 | #gpu_index = faiss.index_cpu_to_all_gpus(cpu_index) 49 | 50 | print(cpu_index.is_trained) 51 | cpu_index.add(class_embedding) 52 | faiss.write_index(cpu_index, f"data/features/CLIP_{save_name}_index") 53 | #faiss.write_index(gpu_index, f"database/GPU_{save_name}") 54 | 55 | encoder_name = 'openai/clip-vit-base-patch32' 56 | 57 | imgs_dir = np.concatenate([np.array(glob.glob(f"{config.kamitani_Tr_Aug}/*/*.JPEG"))]) 58 | 59 | extract_visual_feature(imgs_dir, 128, "dot", encoder_name, 'imagenet') 60 | -------------------------------------------------------------------------------- /modules/vit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops import rearrange 5 | from timm.models.vision_transformer import Block 6 | from modules.pos_embed import get_2d_sincos_pos_embed 7 | 8 | class ViT(nn.Module): 9 | def __init__(self, d, num_patches, 10 | embed_dim=1024, depth=24, num_heads=16, 11 | mlp_ratio=4., norm_layer=nn.LayerNorm): 12 | super().__init__() 13 | 14 | self.proj = nn.Conv1d(d, embed_dim, kernel_size = 1) 15 | # -------------------------------------------------------------------------- 16 | # encoder specifics 17 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 18 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False) # fixed sin-cos embedding 19 | 20 | self.blocks = nn.ModuleList([ 21 | Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) 22 | for i in range(depth)]) 23 | self.norm = norm_layer(embed_dim) 24 | # -------------------------------------------------------------------------- 25 | 26 | 27 | def initialize_weights(self): 28 | # initialization 29 | # initialize (and freeze) pos_embed by sin-cos embedding 30 | pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True) 31 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 32 | 33 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 34 | w = self.patch_embed.proj.weight.data 35 | torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 36 | 37 | # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.) 
38 | torch.nn.init.normal_(self.cls_token, std=.02) 39 | 40 | # initialize nn.Linear and nn.LayerNorm 41 | self.apply(self._init_weights) 42 | 43 | def _init_weights(self, m): 44 | if isinstance(m, nn.Linear): 45 | # we use xavier_uniform following official JAX ViT: 46 | torch.nn.init.xavier_uniform_(m.weight) 47 | if isinstance(m, nn.Linear) and m.bias is not None: 48 | nn.init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.LayerNorm): 50 | nn.init.constant_(m.bias, 0) 51 | nn.init.constant_(m.weight, 1.0) 52 | 53 | def forward_encoder(self, x): 54 | x = rearrange(x, 'b n d -> b d n') 55 | x = self.proj(x) 56 | x = rearrange(x, 'b d n -> b n d') 57 | 58 | # add pos embed w/o cls token 59 | x = x + self.pos_embed[:, 1:, :] 60 | 61 | # append cls token 62 | cls_token = self.cls_token + self.pos_embed[:, :1, :] 63 | cls_tokens = cls_token.expand(x.shape[0], -1, -1) 64 | x = torch.cat((cls_tokens, x), dim=1) 65 | 66 | # apply Transformer blocks 67 | for blk in self.blocks: 68 | x = blk(x) 69 | x = self.norm(x) 70 | 71 | return nn.Tanh()(x) 72 | 73 | def forward(self, x): 74 | encode = self.forward_encoder(x) 75 | return encode#encode[:, :1, :].squeeze() -------------------------------------------------------------------------------- /modules/pos_embed.py: -------------------------------------------------------------------------------- 1 | 2 | import numpy as np 3 | 4 | import torch 5 | 6 | # -------------------------------------------------------- 7 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): 8 | """ 9 | grid_size: int of the grid height and width 10 | return: 11 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 12 | """ 13 | grid_h = np.arange(grid_size, dtype=np.float32) 14 | grid_w = np.arange(grid_size, dtype=np.float32) 15 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 16 | grid = np.stack(grid, axis=0) 17 | 18 | grid = grid.reshape([2, 1, grid_size, grid_size]) 19 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 20 | if cls_token: 21 | pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0) 22 | return pos_embed 23 | 24 | 25 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 26 | assert embed_dim % 2 == 0 27 | 28 | # use half of dimensions to encode grid_h 29 | emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) 30 | emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) 31 | 32 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 33 | return emb 34 | 35 | 36 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 37 | """ 38 | embed_dim: output dimension for each position 39 | pos: a list of positions to be encoded: size (M,) 40 | out: (M, D) 41 | """ 42 | assert embed_dim % 2 == 0 43 | omega = np.arange(embed_dim // 2, dtype=np.float) 44 | omega /= embed_dim / 2. 45 | omega = 1. 
/ 10000**omega # (D/2,) 46 | 47 | pos = pos.reshape(-1) # (M,) 48 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 49 | 50 | emb_sin = np.sin(out) # (M, D/2) 51 | emb_cos = np.cos(out) # (M, D/2) 52 | 53 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 54 | return emb 55 | 56 | 57 | # -------------------------------------------------------- 58 | def interpolate_pos_embed(model, checkpoint_model): 59 | if 'pos_embed' in checkpoint_model: 60 | pos_embed_checkpoint = checkpoint_model['pos_embed'] 61 | embedding_size = pos_embed_checkpoint.shape[-1] 62 | num_patches = model.patch_embed.num_patches 63 | num_extra_tokens = model.pos_embed.shape[-2] - num_patches 64 | # height (== width) for the checkpoint position embedding 65 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 66 | # height (== width) for the new position embedding 67 | new_size = int(num_patches ** 0.5) 68 | # class_token and dist_token are kept unchanged 69 | if orig_size != new_size: 70 | print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) 71 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 72 | # only the position tokens are interpolated 73 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 74 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 75 | pos_tokens = torch.nn.functional.interpolate( 76 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 77 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 78 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 79 | checkpoint_model['pos_embed'] = new_pos_embed 80 | -------------------------------------------------------------------------------- /brain2text_train.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import os 4 | import argparse 5 | import h5py 6 | 7 | os.environ["WANDB_DISABLED"] = "true" 8 | from transformers.models.auto.configuration_auto import AutoConfig 9 | from transformers import AutoTokenizer, GPT2Tokenizer, CLIPFeatureExtractor, AutoModel, AutoModelForCausalLM 10 | from transformers import Seq2SeqTrainer, default_data_collator, Seq2SeqTrainingArguments 11 | 12 | from modules.brain2text import Brain2Text, Brain2TextConfig 13 | from modules.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel 14 | from modules.fmriencoder import fMRIViTEncoderConfig, fMRIViTEncoder 15 | from data.dataset_augment import get_visual_text_dir_dataset 16 | from data.configure import CLIPGPTFEATURE 17 | 18 | # for attention with 28M params, we devide the attention dimensions by 1 19 | # for attention with 14M params, we devide the attention dimensions by 2, etc. 20 | PARAMS2REDUCE_FACTOR = {28: 1, 14: 2, 7: 4, 3.5: 8, 1.75: 16} 21 | ENCODERPARAMS = {'ViT-4-4': 4, 'ViT-8-8': 8, 'ViT-16-16': 16} 22 | 23 | PAD_TOKEN = '!' 24 | EOS_TOKEN = '.' 
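# Note: GPT-2's tokenizer has no native padding token, so '!' is reused as the pad token and '.' as the
# end-of-sequence token; generated captions therefore stop at the first period, and pad tokens are stripped
# again in utils.postprocess_preds. ENCODERPARAMS maps --encoder_cog (e.g. 'ViT-8-8') to the depth and the
# number of attention heads of the fMRI ViT encoder built in get_model_and_auxiliaries below.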
25 | CAPTION_LENGTH = 25 26 | CLIP_FEATURES = h5py.File(f'{CLIPGPTFEATURE}', 'r') 27 | 28 | def get_model_and_auxiliaries(args): 29 | 30 | AutoConfig.register("this_gpt2", ThisGPT2Config) 31 | AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel) 32 | AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel) 33 | 34 | AutoConfig.register("Brain2Text", Brain2TextConfig) 35 | AutoModel.register(Brain2TextConfig, Brain2Text) 36 | 37 | AutoConfig.register("fMRIViTEncoder", fMRIViTEncoderConfig) 38 | AutoModel.register(fMRIViTEncoderConfig, fMRIViTEncoder) 39 | 40 | # create and configure model 41 | cross_attention_reduce_factor = PARAMS2REDUCE_FACTOR[args.attention_size] 42 | 43 | #feature_extractor = CLIPFeatureExtractor.from_pretrained(args.encoder_name) 44 | tokenizer = AutoTokenizer.from_pretrained(args.decoder_name) 45 | #tokenizer = GPT2Tokenizer.from_pretrained(args.decoder_name) 46 | tokenizer.pad_token = PAD_TOKEN 47 | tokenizer.eos_token = EOS_TOKEN 48 | 49 | selected_rois = ['ROI_V1', 'ROI_V2', 'ROI_V3', 'ROI_V4', 'ROI_LOC', 'ROI_FFA', 'ROI_PPA'] 50 | 51 | dataset, dataloader, fmri_dim = get_visual_text_dir_dataset(args.sub, selected_rois, args.batch_size, 52 | tokenizer = tokenizer, mixup = True, candidate = True, 53 | clip_features = CLIP_FEATURES) 54 | 55 | encoder_config = fMRIViTEncoderConfig(fmri_dim, len(selected_rois), 768, ENCODERPARAMS[args.encoder_cog], 56 | ENCODERPARAMS[args.encoder_cog], fmri2img = False) 57 | 58 | model = Brain2Text.from_encoder_decoder_pretrained(fMRIViTEncoder(encoder_config), args.decoder_name, 59 | cross_attention_reduce_factor = cross_attention_reduce_factor) 60 | model.config.vocab_size = model.config.decoder.vocab_size 61 | model.config.decoder_start_token_id = None 62 | model.config.pad_token_id = tokenizer.pad_token_id 63 | model.config.eos_token_id = tokenizer.eos_token_id 64 | model.config.max_length = CAPTION_LENGTH 65 | 66 | model.config.k = 0 67 | 68 | # freeze parameters 69 | for param in model.encoder.parameters(): 70 | param.requires_grad = True 71 | 72 | if not args.train_decoder: 73 | for name, param in model.decoder.named_parameters(): 74 | if 'crossattention' not in name: 75 | param.requires_grad = False 76 | 77 | # count trainable parameters 78 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 79 | num_trainable_params = sum([np.prod(p.size()) for p in model_parameters]) 80 | print('Training a model with {}M trainable parameters.'.format(num_trainable_params/100/10000)) 81 | 82 | return model, tokenizer, dataset 83 | 84 | 85 | def main(args): 86 | model, tokenizer, dataset = get_model_and_auxiliaries(args) 87 | 88 | model_type = 'mindgpt' 89 | 90 | output_dir = '{}_{}M_{}'.format(model_type, args.attention_size, args.decoder_name) 91 | 92 | output_dir = os.path.join(args.experiments_dir, args.encoder_cog, args.dataset, args.sub, args.ROI, output_dir) 93 | 94 | training_args = Seq2SeqTrainingArguments( 95 | num_train_epochs=args.n_epochs, 96 | per_device_train_batch_size=args.batch_size, 97 | gradient_accumulation_steps=args.gradient_steps, 98 | learning_rate=args.lr, 99 | fp16=False, 100 | save_strategy="epoch", 101 | save_total_limit=args.n_epochs, 102 | logging_strategy="steps", 103 | output_dir=output_dir, 104 | overwrite_output_dir=True, 105 | weight_decay=1e-4, 106 | ) 107 | 108 | trainer = Seq2SeqTrainer( 109 | model=model, 110 | args=training_args, 111 | data_collator=default_data_collator, 112 | train_dataset=dataset, 113 | tokenizer=tokenizer, 114 | ) 115 | 116 | trainer.train() 117 | 
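# Example (cf. README): python brain2text_train.py --n_epochs 20 --batch_size 128
# With the argument defaults below, one checkpoint is saved per epoch under
# ./log/Brain2Text/ViT-16-16/DIR/sub-3/VC/mindgpt_1.75M_gpt2/
# (i.e. experiments_dir/encoder_cog/dataset/sub/ROI/mindgpt_<attention_size>M_<decoder_name>).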
118 | if __name__ == '__main__': 119 | parser = argparse.ArgumentParser(description='Model Training') 120 | parser.add_argument("--experiments_dir", type=str, default="./log/Brain2Text/", help="Directory where trained models will be saved") 121 | parser.add_argument("--encoder_name", type=str, default="fMRIEncoder", help="Encoder name as found of HuggingFace or stored locally") 122 | parser.add_argument("--encoder_cog", type=str, default="ViT-16-16", help="Encoder parameters") 123 | parser.add_argument("--decoder_name", type=str, default="gpt2", help="Decoder name as found of HuggingFace or stored locally") 124 | parser.add_argument("--attention_size", type=float, default=1.75, help="Number of parameters in the cross attention {28, 14, 7, 3.5, 1.75}") 125 | parser.add_argument("--train_decoder", action="store_true", default=False, help="Whether to train the decoder in addition to the attention") 126 | parser.add_argument("--n_epochs", type=int, help="Number of training epochs") 127 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate") 128 | parser.add_argument("--batch_size", type=int, default=32, help="Batch size") 129 | parser.add_argument("--gradient_steps", type=int, default=1, help="Number of gradient accumulation steps") 130 | parser.add_argument("--sub", type=str, default='sub-3', help="subject") 131 | parser.add_argument("--dataset", type=str, default='DIR', help="fMRI dataset name") 132 | parser.add_argument("--ROI", type=str, default='VC', help='brain ROIs') 133 | args = parser.parse_args() 134 | 135 | main(args) 136 | -------------------------------------------------------------------------------- /modules/gpt2.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team. 3 | # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
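# Adapted from the transformers GPT-2 implementation: ThisGPT2Attention below shrinks the cross-attention
# projections (q_attn, c_attn, c_proj) by config.cross_attention_reduce_factor, which is how the
# {28, 14, 7, 3.5, 1.75}M cross-attention sizes selected in brain2text_train.py are obtained.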
16 | """PyTorch OpenAI GPT-2 model.""" 17 | 18 | import math 19 | import os 20 | from dataclasses import dataclass 21 | from typing import Optional, Tuple, Union 22 | 23 | import torch 24 | import torch.utils.checkpoint 25 | from packaging import version 26 | from torch import nn 27 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 28 | 29 | from transformers.models.gpt2.modeling_gpt2 import load_tf_weights_in_gpt2, GPT2LMHeadModel, GPT2MLP, GPT2Attention, GPT2Block, GPT2Model 30 | 31 | from transformers.activations import ACT2FN 32 | from transformers.modeling_outputs import ( 33 | BaseModelOutputWithPastAndCrossAttentions, 34 | CausalLMOutputWithCrossAttentions, 35 | SequenceClassifierOutputWithPast, 36 | TokenClassifierOutput, 37 | ) 38 | from transformers.modeling_utils import PreTrainedModel, SequenceSummary 39 | from transformers.pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer 40 | from transformers.utils import ( 41 | ModelOutput, 42 | logging, 43 | ) 44 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 45 | from transformers.models.gpt2.configuration_gpt2 import GPT2Config 46 | 47 | 48 | if version.parse(torch.__version__) >= version.parse("1.6"): 49 | is_amp_available = True 50 | from torch.cuda.amp import autocast 51 | else: 52 | is_amp_available = False 53 | 54 | 55 | class ThisGPT2Config(GPT2Config): 56 | model_type = "this_gpt2" 57 | 58 | def __init__( 59 | self, 60 | cross_attention_reduce_factor = 1, 61 | **kwargs, 62 | ): 63 | super().__init__(**kwargs) 64 | self.cross_attention_reduce_factor = cross_attention_reduce_factor 65 | 66 | class ThisGPT2Attention(GPT2Attention): 67 | def __init__(self, config, is_cross_attention=False, layer_idx=None): 68 | super().__init__(config, is_cross_attention, layer_idx) 69 | 70 | #print("this gpt2") 71 | 72 | #print("self.is_cross_attention = is_cross_attention", self.is_cross_attention, is_cross_attention) 73 | 74 | self.cross_attention_reduce_factor = config.cross_attention_reduce_factor 75 | 76 | if self.is_cross_attention: 77 | self.c_attn = Conv1D(int(2 / self.cross_attention_reduce_factor * self.embed_dim), 78 | self.embed_dim) 79 | self.q_attn = Conv1D(int(self.embed_dim / self.cross_attention_reduce_factor), self.embed_dim) 80 | self.c_proj = Conv1D(self.embed_dim, int(self.embed_dim / self.cross_attention_reduce_factor)) 81 | else: 82 | self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) 83 | self.c_proj = Conv1D(self.embed_dim, self.embed_dim) 84 | 85 | def forward( 86 | self, 87 | hidden_states: Optional[Tuple[torch.FloatTensor]], 88 | layer_past: Optional[Tuple[torch.Tensor]] = None, 89 | attention_mask: Optional[torch.FloatTensor] = None, 90 | head_mask: Optional[torch.FloatTensor] = None, 91 | encoder_hidden_states: Optional[torch.Tensor] = None, 92 | encoder_attention_mask: Optional[torch.FloatTensor] = None, 93 | use_cache: Optional[bool] = False, 94 | output_attentions: Optional[bool] = False, 95 | ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: 96 | if encoder_hidden_states is not None: 97 | if not hasattr(self, "q_attn"): 98 | raise ValueError( 99 | "If class is used as cross attention, the weights `q_attn` have to be defined. " 100 | "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`." 
101 | ) 102 | split_size = int(self.split_size / self.cross_attention_reduce_factor) 103 | head_dim = int(self.head_dim / self.cross_attention_reduce_factor) 104 | 105 | query = self.q_attn(hidden_states) 106 | key, value = self.c_attn(encoder_hidden_states).split(split_size, dim=2) 107 | attention_mask = encoder_attention_mask 108 | 109 | query = self._split_heads(query, self.num_heads, head_dim) 110 | key = self._split_heads(key, self.num_heads, head_dim) 111 | value = self._split_heads(value, self.num_heads, head_dim) 112 | else: 113 | query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2) 114 | 115 | query = self._split_heads(query, self.num_heads, self.head_dim) 116 | key = self._split_heads(key, self.num_heads, self.head_dim) 117 | value = self._split_heads(value, self.num_heads, self.head_dim) 118 | 119 | if layer_past is not None: 120 | past_key, past_value = layer_past 121 | key = torch.cat((past_key, key), dim=-2) 122 | value = torch.cat((past_value, value), dim=-2) 123 | 124 | if use_cache is True: 125 | present = (key, value) 126 | else: 127 | present = None 128 | 129 | if self.reorder_and_upcast_attn: 130 | attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask) 131 | else: 132 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 133 | 134 | attn_output = self._merge_heads(attn_output, self.num_heads, int(self.head_dim / self.cross_attention_reduce_factor)) 135 | attn_output = self.c_proj(attn_output) 136 | attn_output = self.resid_dropout(attn_output) 137 | 138 | outputs = (attn_output, present) 139 | if output_attentions: 140 | outputs += (attn_weights,) 141 | 142 | return outputs # a, present, (attentions) 143 | 144 | 145 | class ThisGPT2Block(GPT2Block): 146 | def __init__(self, config, layer_idx=None): 147 | super().__init__(config, layer_idx) 148 | hidden_size = config.hidden_size 149 | 150 | if config.add_cross_attention: 151 | self.crossattention = ThisGPT2Attention(config, is_cross_attention=True, layer_idx=layer_idx) 152 | self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) 153 | 154 | class ThisGPT2Model(GPT2Model): 155 | 156 | def __init__(self, config): 157 | super().__init__(config) 158 | self.h = nn.ModuleList([ThisGPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)]) 159 | 160 | 161 | class ThisGPT2LMHeadModel(GPT2LMHeadModel): 162 | config_class = ThisGPT2Config 163 | 164 | def __init__(self, config): 165 | super().__init__(config) 166 | self.transformer = ThisGPT2Model(config) 167 | 168 | -------------------------------------------------------------------------------- /brain2text_infer.py: -------------------------------------------------------------------------------- 1 | 2 | import argparse 3 | import os 4 | import numpy as np 5 | from tqdm import tqdm 6 | import json 7 | from PIL import ImageFile 8 | import torch 9 | from transformers import AutoTokenizer, GPT2Tokenizer, CLIPFeatureExtractor, AutoModel 10 | from transformers.models.auto.configuration_auto import AutoConfig 11 | from transformers.modeling_outputs import BaseModelOutput 12 | from modules.vit import ViT 13 | from utils import prep_strings, postprocess_preds 14 | ImageFile.LOAD_TRUNCATED_IMAGES = True 15 | from data.dataset_augment import DIR_sub_without_images 16 | import faiss 17 | 18 | PAD_TOKEN = '!' 19 | EOS_TOKEN = '.' 
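# Unless --checkpoint_path is given, every epoch checkpoint found under the resolved --model_path is
# evaluated on the trial-averaged test fMRI, and the decoded captions are written to
# <checkpoint_dir>/preds.json (the --outfile_name default) as {"image_id", "caption"} entries.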
20 | CAPTION_LENGTH = 25 21 | 22 | def evaluate_norag_model(args, tokenizer, model, eval_df): 23 | out = [] 24 | bs = args.batch_size 25 | for idx in tqdm(range(0, len(eval_df[0]), bs)): 26 | fMRI = torch.tensor(eval_df[0][idx:idx + bs], dtype = torch.float32).to(args.device) 27 | image_ids = eval_df[1][idx:idx + bs] 28 | decoder_input_ids = [prep_strings('', tokenizer, is_test = True) for _ in range(len(image_ids))] 29 | 30 | with torch.no_grad(): 31 | encoder_last_hidden_state = torch.FloatTensor(model.encoder(fMRI).cpu()).to(args.device) 32 | encoder_outputs = BaseModelOutput(last_hidden_state = encoder_last_hidden_state) 33 | preds = model.generate(encoder_outputs = encoder_outputs, decoder_input_ids = torch.tensor(decoder_input_ids).to(args.device), 34 | **args.generation_kwargs) 35 | preds = tokenizer.batch_decode(preds) 36 | 37 | for image_id, pred in zip(image_ids, preds): 38 | pred = postprocess_preds(pred, tokenizer) 39 | out.append({"image_id": image_id, "caption": pred}) 40 | 41 | return out 42 | 43 | 44 | def load_model(args, checkpoint_path): 45 | config = AutoConfig.from_pretrained(checkpoint_path + '/config.json') 46 | model = AutoModel.from_pretrained(checkpoint_path) 47 | model.config = config 48 | model.eval() 49 | model.to(args.device) 50 | return model 51 | 52 | def infer_one_checkpoint(args, tokenizer, checkpoint_path, eval_df, infer_fn): 53 | model = load_model(args, checkpoint_path) 54 | preds = infer_fn(args, tokenizer, model, eval_df) 55 | with open(os.path.join(checkpoint_path, args.outfile_name), 'w') as outfile: 56 | json.dump(preds, outfile) 57 | 58 | def register_model_and_config(): 59 | from transformers import AutoModelForCausalLM 60 | from modules.brain2text import Brain2Text, Brain2TextConfig 61 | from modules.fmriencoder import fMRIViTEncoder, fMRIViTEncoderConfig 62 | from modules.gpt2 import ThisGPT2Config, ThisGPT2LMHeadModel 63 | 64 | AutoConfig.register("this_gpt2", ThisGPT2Config) 65 | AutoModel.register(ThisGPT2Config, ThisGPT2LMHeadModel) 66 | AutoModelForCausalLM.register(ThisGPT2Config, ThisGPT2LMHeadModel) 67 | 68 | AutoConfig.register("fMRIViTEncoder", fMRIViTEncoderConfig) 69 | AutoModel.register(fMRIViTEncoderConfig, fMRIViTEncoder) 70 | 71 | AutoConfig.register("Brain2Text", Brain2TextConfig) 72 | AutoModel.register(Brain2TextConfig, Brain2Text) 73 | 74 | @torch.no_grad() 75 | def rag_captions(data, sortedimageIDs, caps, args): 76 | fmri_dim = data[0].shape[-1] 77 | roi_num = data[0].shape[-2] 78 | retrieval_model = ViT(fmri_dim, roi_num, 512) 79 | retrieval_model.load_state_dict(torch.load(args.retrieval_model_path, map_location = 'cpu')) 80 | retrieval_model.to(args.device).eval() 81 | 82 | retrieval_index = faiss.read_index(args.retrieval_index_path) 83 | res = faiss.StandardGpuResources() 84 | retrieval_index = faiss.index_cpu_to_gpu(res, 0, retrieval_index) 85 | 86 | fmri_embedding = retrieval_model(torch.tensor(data[0], dtype = torch.float32).to(args.device))[:, 0, :] 87 | fmri_embedding = fmri_embedding / fmri_embedding.norm(dim=-1, keepdim=True) 88 | dis, nns = retrieval_index.search(fmri_embedding.cpu().numpy().astype(np.float32), args.k) 89 | 90 | return [data[0], data[1], [[caps[str(sortedimageIDs[nns[n][k]]).split('/')[-1]] for k in range(args.k)] for n in range(nns.shape[0])]] 91 | 92 | 93 | def main(args): 94 | 95 | register_model_and_config() 96 | 97 | args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 98 | 99 | 100 | if args.ROI == 'VC': 101 | selected_rois = ['ROI_V1', 'ROI_V2', 'ROI_V3', 
'ROI_V4', 'ROI_LOC', 'ROI_FFA', 'ROI_PPA'] 102 | elif args.ROI == 'LVC': 103 | selected_rois = ['ROI_V1', 'ROI_V2', 'ROI_V3'] 104 | elif args.ROI == 'HVC': 105 | selected_rois = ['ROI_LOC', 'ROI_FFA', 'ROI_PPA'] 106 | 107 | _, _, _, test_cat_rois, test_rois, testStiIDs = DIR_sub_without_images(args.sub, rois = selected_rois) 108 | data = [test_rois, testStiIDs] 109 | 110 | output_dir = 'mindgpt_{}M_{}'.format(args.attention_size, args.decoder_name) 111 | args.k = 0 112 | infer_fn = evaluate_norag_model 113 | 114 | args.model_path = os.path.join(args.model_path, args.encoder_cog, args.dataset, args.sub, args.ROI, output_dir) 115 | 116 | tokenizer = GPT2Tokenizer.from_pretrained(args.decoder_name) 117 | tokenizer.pad_token = PAD_TOKEN 118 | tokenizer.eos_token = EOS_TOKEN 119 | 120 | # configure generation 121 | args.generation_kwargs = {'max_new_tokens': CAPTION_LENGTH, 'no_repeat_ngram_size': 0, 'length_penalty': 0., 122 | 'num_beams': 3, 'early_stopping': True, 'eos_token_id': tokenizer.eos_token_id, 'bos_token_id': tokenizer.bos_token_id} 123 | 124 | # run inference once if checkpoint specified else run for all checkpoints 125 | if args.checkpoint_path is not None: 126 | checkpoint_path = os.path.join(args.model_path, args.checkpoint_path) 127 | infer_one_checkpoint(args, tokenizer, checkpoint_path, data, infer_fn) 128 | else: 129 | for checkpoint_path in os.listdir(args.model_path): 130 | if 'runs' in checkpoint_path: 131 | continue 132 | checkpoint_path = os.path.join(args.model_path, checkpoint_path) 133 | if os.path.exists(os.path.join(checkpoint_path, args.outfile_name)): 134 | print('Found existing file for', checkpoint_path) 135 | else: 136 | infer_one_checkpoint(args, tokenizer, checkpoint_path, data, infer_fn) 137 | 138 | 139 | if __name__ == '__main__': 140 | parser = argparse.ArgumentParser(description='Model Inferring') 141 | parser.add_argument("--model_path", type=str, default='./log/Brain2Text/', help="Path to model to use for inference") 142 | parser.add_argument("--outfile_name", type=str, default='preds.json', help="output file name") 143 | parser.add_argument("--checkpoint_path", type=str, help="Path to checkpoint to use for inference; If not specified, will infer with all checkpoints") 144 | parser.add_argument("--infer_data", type=str, default='test', help="Using test data for inference") 145 | parser.add_argument("--encoder_name", type=str, default="fMRIEncoder", help="Encoder name as found of HuggingFace or stored locally") 146 | parser.add_argument("--encoder_cog", type=str, default="ViT-8-8", help="Encoder parameters") 147 | parser.add_argument("--decoder_name", type=str, default="gpt2", help="Decoder name as found of HuggingFace or stored locally") 148 | parser.add_argument("--template_path", type=str, default="data/template.txt", help="TXT file with template") 149 | parser.add_argument("--attention_size", type=float, default=3.5, help="Number of parameters in the cross attention {28, 14, 7, 3.5, 1.75}") 150 | parser.add_argument("--batch_size", type=int, default=10, help="Batch size; only matter if evaluating a norag model") 151 | parser.add_argument("--sub", type=str, default='sub-3', help="subject") 152 | parser.add_argument("--dataset", type=str, default='DIR', help="fMRI dataset name") 153 | parser.add_argument("--ROI", type=str, default='VC', help='brain ROIs') 154 | args = parser.parse_args() 155 | 156 | main(args) 157 | 158 | -------------------------------------------------------------------------------- /data/dataset_augment.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import faiss 3 | import bdpy 4 | import json 5 | import glob 6 | import random 7 | import numpy as np 8 | import pandas as pd 9 | from PIL import Image 10 | import torch.utils.data 11 | import data.configure as config 12 | from torchvision import datasets, transforms 13 | from torch.utils.data import Dataset, DataLoader 14 | import torchvision.io.image as imageio 15 | 16 | def DIR_sub_without_images(sub = 'sub-3', rois = ['ROI_VC']): 17 | train_rois, test_rois = [], [] 18 | DIR_sub_train = bdpy.BData(os.path.join(config.DIR_dataset_dir, config.DIR_train_subs[sub])) 19 | DIR_sub_test = bdpy.BData(os.path.join(config.DIR_dataset_dir, config.DIR_test_subs[sub])) 20 | 21 | train_image_index = DIR_sub_train.select('image_index').squeeze().astype(int) - 1 22 | test_image_index = DIR_sub_test.select('image_index').squeeze().astype(int) - 1 23 | 24 | trainStiIDs = np.array(pd.read_csv(config.kamitani_sti_trainID, header = None)[1])[train_image_index] 25 | testStiIDs = np.array(pd.read_csv(config.kamitani_sti_testID, header = None)[1]) 26 | 27 | MAX_DIM = 0 28 | for roi in rois: 29 | train_roi_fMRI = DIR_sub_train.select(roi) 30 | test_roi_fMRI = DIR_sub_test.select(roi) 31 | 32 | test_roi_fMRI_avg = np.zeros([50, test_roi_fMRI.shape[1]]) 33 | for i in range(50): 34 | test_roi_fMRI_avg[i] = np.mean(test_roi_fMRI[test_image_index == i], axis = 0) 35 | 36 | train_rois.append(train_roi_fMRI) 37 | test_rois.append(test_roi_fMRI_avg) 38 | MAX_DIM = train_roi_fMRI.shape[-1] if train_roi_fMRI.shape[-1] > MAX_DIM else MAX_DIM 39 | 40 | train_rois = np.concatenate(([np.pad(fmri, ((0, 0), (0, MAX_DIM - fmri.shape[-1])))[:,None,:] for fmri in train_rois]), 1).squeeze() 41 | test_rois = np.concatenate(([np.pad(fmri, ((0, 0), (0, MAX_DIM - fmri.shape[-1])))[:,None,:] for fmri in test_rois]), 1).squeeze() 42 | 43 | train_cat_rois = {} 44 | trainCatIDs = [id.split('_')[0] for id in trainStiIDs] 45 | trainCatSet = set(trainCatIDs) 46 | for cat in trainCatSet: 47 | train_cat_rois[cat] = train_rois[np.array(trainCatIDs) == cat] 48 | 49 | test_cat_rois = {} 50 | testCatIDs = [id.split('_')[0] for id in testStiIDs] 51 | for cat in testCatIDs: 52 | test_cat_rois[cat] = test_rois[np.array(testCatIDs) == cat] 53 | 54 | return train_cat_rois, train_rois, trainStiIDs, test_cat_rois, test_rois, testStiIDs 55 | 56 | class Visual_Text_fMRI_Dataset(Dataset): 57 | def __init__(self, imageIDs, caps, categories, fMRI, tokenizer, mixup = False, train = True, 58 | transform = None, max_caption_length = 25, clip_features = None): 59 | self.tokenizer = tokenizer 60 | self.imageIDs = imageIDs 61 | self.fMRI = fMRI 62 | self.mixup = mixup 63 | self.train = train 64 | self.transform = transform 65 | self.caps = caps 66 | self.categories = categories 67 | self.clip_features = clip_features 68 | self.SIMPLE_PREFIX = "This image shows " 69 | self.retrieved_caps = None 70 | self.CAPTION_LENGTH = max_caption_length 71 | 72 | self.template = self.SIMPLE_PREFIX 73 | self.max_target_length = (max_caption_length 74 | + len(tokenizer.encode(self.template))) 75 | 76 | def __len__(self): 77 | return len(self.imageIDs) 78 | 79 | def prep_strings(self, text, tokenizer, retrieved_caps = None): 80 | if not self.train: 81 | padding = False 82 | truncation = False 83 | else: 84 | padding = True 85 | truncation = True 86 | 87 | if retrieved_caps is not None: 88 | infix = '\n\n'.join(retrieved_caps) + '.' 
89 | prefix = self.template.replace('||', infix) 90 | else: 91 | prefix = self.SIMPLE_PREFIX 92 | 93 | prefix_ids = tokenizer.encode(prefix) 94 | len_prefix = len(prefix_ids) 95 | 96 | text_ids = tokenizer.encode(text, add_special_tokens = False) 97 | if truncation: 98 | text_ids = text_ids[:self.CAPTION_LENGTH] 99 | input_ids = prefix_ids + text_ids if self.train else prefix_ids 100 | 101 | # we ignore the prefix (minus one as the first subtoken in the prefix is not predicted) 102 | label_ids = [-100] * (len_prefix - 1) + text_ids + [tokenizer.eos_token_id] 103 | if padding: 104 | input_ids += [tokenizer.pad_token_id] * (self.max_target_length - len(input_ids)) 105 | label_ids += [-100] * (self.max_target_length - len(label_ids)) 106 | 107 | if not self.train: 108 | return input_ids 109 | else: 110 | return input_ids, label_ids 111 | 112 | 113 | def __getitem__(self, idx): 114 | file_name = self.imageIDs[idx].split('/')[-1] 115 | category_id = file_name.split('_')[0] 116 | image_id = file_name.split('_')[1].split('.')[0] 117 | #image = Image.open(self.imageIDs[idx]) 118 | cap = self.caps[file_name] 119 | #category = self.categories[category_id] 120 | #visual_features = self.clip_features[idx] 121 | visual_features = self.clip_features[file_name][()] 122 | if(self.train): 123 | fmri_num = self.fMRI[category_id].shape[0] 124 | selected_no = np.random.permutation(range(fmri_num))[:random.randint(1, fmri_num - 1)] 125 | if(self.mixup and selected_no.shape[0] != 1): 126 | coefficient = torch.tensor(np.random.uniform(-1, 1, size = selected_no.shape[0])).softmax(0) 127 | selected_fMRI = torch.tensor(self.fMRI[category_id][selected_no]) 128 | coefficient = coefficient[:, None, None] if len(selected_fMRI.shape) == 3 else coefficient[:, None] 129 | mixup_fMRI = torch.sum(selected_fMRI * coefficient, 0) 130 | fMRI = mixup_fMRI.numpy() 131 | else: 132 | fMRI = self.fMRI[category_id][selected_no[0]].squeeze() 133 | else: 134 | selected = random.randint(0, self.fMRI[category_id].shape[0] - 1) 135 | fMRI = self.fMRI[category_id][selected] 136 | 137 | #image = self.transform(image) if self.transform else image 138 | k_caption = None 139 | decoder_input_ids, labels = self.prep_strings(cap, self.tokenizer, k_caption) 140 | data = {'encoder_inputs': fMRI.astype(np.float32), 'encoder_labels': visual_features, 141 | 'decoder_input_ids': np.array(decoder_input_ids), 'decoder_labels': np.array(labels)} 142 | return data 143 | 144 | def get_visual_text_dir_dataset(sub, rois, batch_size, mixup = True, candidate = True, 145 | tokenizer = None, clip_features = None): 146 | 147 | train_cat_rois, _, trainStiIDs,\ 148 | test_cat_rois, _, testStiIDs = DIR_sub_without_images(sub, rois) 149 | 150 | fmri_dim = train_cat_rois[trainStiIDs[0].split('_')[0]].shape[-1] 151 | Train_category = set([id.split('_')[0] for id in trainStiIDs]) 152 | if(candidate): 153 | train_images = np.concatenate([np.array(glob.glob(f"{config.kamitani_Tr_Aug}/{category}/*.JPEG")) for category in Train_category]) 154 | else: 155 | train_images = np.concatenate([np.array(glob.glob(f"{config.kamitani_Tr_Aug}/*/{image}")) for image in trainStiIDs]) 156 | 157 | train_caps = json.load(open(config.smallCap_Kamitani_train)) 158 | #classIDs = np.array(pd.read_csv(config.kamitani_sti_text, header = None)[0]) 159 | #classTexts = np.array(pd.read_csv(config.kamitani_sti_text, header = None)[1]) 160 | #id_text = dict(zip(classIDs, classTexts)) 161 | 162 | train_dataset = Visual_Text_fMRI_Dataset(train_images, train_caps, None, train_cat_rois, tokenizer, mixup, 
163 | True, clip_features = clip_features) 164 | train_dataloader = DataLoader(dataset = train_dataset, batch_size = batch_size, shuffle = True) 165 | return train_dataset, train_dataloader, fmri_dim -------------------------------------------------------------------------------- /modules/brain2text.py: -------------------------------------------------------------------------------- 1 | import timeit 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import CrossEntropyLoss 8 | from transformers.configuration_utils import PretrainedConfig 9 | from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput 10 | from transformers.modeling_utils import PreTrainedModel 11 | #from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings 12 | from transformers.utils import logging 13 | from transformers.models.auto.configuration_auto import AutoConfig 14 | from transformers.models.auto.modeling_auto import AutoModel, AutoModelForCausalLM 15 | from transformers.models.vision_encoder_decoder.configuration_vision_encoder_decoder import VisionEncoderDecoderConfig 16 | import inspect 17 | 18 | from modules.gpt2 import ThisGPT2LMHeadModel 19 | from modules.gpt2 import ThisGPT2Config 20 | 21 | 22 | 23 | # Copied from transformers.models.encoder_decoder.modeling_encoder_decoder.shift_tokens_right 24 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 25 | """ 26 | Shift input ids one token to the right. 27 | """ 28 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 29 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 30 | if decoder_start_token_id is None: 31 | raise ValueError("Make sure to set the decoder_start_token_id attribute of the model's configuration.") 32 | shifted_input_ids[:, 0] = decoder_start_token_id 33 | 34 | if pad_token_id is None: 35 | raise ValueError("Make sure to set the pad_token_id attribute of the model's configuration.") 36 | # replace possible -100 values in labels by `pad_token_id` 37 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 38 | 39 | return shifted_input_ids 40 | 41 | 42 | logger = logging.get_logger(__name__) 43 | 44 | _CONFIG_FOR_DOC = "Brain2TextConfig" 45 | 46 | VISION_ENCODER_DECODER_START_DOCSTRING = r""" 47 | This class can be used to initialize an image-to-text-sequence model with any pretrained vision autoencoding model 48 | as the encoder and any pretrained text autoregressive model as the decoder. The encoder is loaded via 49 | [`~AutoModel.from_pretrained`] function and the decoder is loaded via [`~AutoModelForCausalLM.from_pretrained`] 50 | function. Cross-attention layers are automatically added to the decoder and should be fine-tuned on a downstream 51 | generative task, like image captioning. 52 | 53 | The effectiveness of initializing sequence-to-sequence models with pretrained checkpoints for sequence generation 54 | tasks was shown in [Leveraging Pre-trained Checkpoints for Sequence Generation 55 | Tasks](https://arxiv.org/abs/1907.12461) by Sascha Rothe, Shashi Narayan, Aliaksei Severyn. Michael Matena, Yanqi 56 | Zhou, Wei Li, Peter J. Liu. 57 | 58 | Additionally, in [TrOCR: Transformer-based Optical Character Recognition with Pre-trained 59 | Models](https://arxiv.org/abs/2109.10282) it is shown how leveraging large pretrained vision models for optical 60 | character recognition (OCR) yields a significant performance improvement. 
61 | 62 | After such a Vision-Encoder-Text-Decoder model has been trained/fine-tuned, it can be saved/loaded just like any 63 | other models (see the examples for more information). 64 | 65 | This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the 66 | library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads 67 | etc.) 68 | 69 | This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. 70 | Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage 71 | and behavior. 72 | 73 | Parameters: 74 | config ([`VisionEncoderDecoderConfig`]): Model configuration class with all the parameters of the model. 75 | Initializing with a config file does not load the weights associated with the model, only the 76 | configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. 77 | """ 78 | 79 | VISION_ENCODER_DECODER_INPUTS_DOCSTRING = r""" 80 | Args: 81 | pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 82 | Pixel values. Pixel values can be obtained using a feature extractor (e.g. if you use ViT as the encoder, 83 | you should use [`ViTFeatureExtractor`]). See [`ViTFeatureExtractor.__call__`] for details. 84 | decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 85 | Indices of decoder input sequence tokens in the vocabulary. 86 | 87 | Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and 88 | [`PreTrainedTokenizer.__call__`] for details. 89 | 90 | [What are input IDs?](../glossary#input-ids) 91 | 92 | If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see 93 | `past_key_values`). 94 | 95 | For training, `decoder_input_ids` are automatically created by the model by shifting the `labels` to the 96 | right, replacing -100 by the `pad_token_id` and prepending them with the `decoder_start_token_id`. 97 | decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*): 98 | Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also 99 | be used by default. 100 | encoder_outputs (`tuple(torch.FloatTensor)`, *optional*): 101 | This tuple must consist of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`) 102 | `last_hidden_state` (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) is a tensor 103 | of hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the 104 | decoder. 105 | past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): 106 | Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. 107 | 108 | If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that 109 | don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all 110 | `decoder_input_ids` of shape `(batch_size, sequence_length)`. 
111 | decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`, *optional*): 112 | Optionally, instead of passing `decoder_input_ids` you can choose to directly pass an embedded 113 | representation. This is useful if you want more control over how to convert `decoder_input_ids` indices 114 | into associated vectors than the model's internal embedding lookup matrix. 115 | labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): 116 | Labels for computing the masked language modeling loss for the decoder. Indices should be in `[-100, 0, 117 | ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored 118 | (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` 119 | use_cache (`bool`, *optional*): 120 | If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see 121 | `past_key_values`). 122 | output_attentions (`bool`, *optional*): 123 | Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned 124 | tensors for more detail. 125 | output_hidden_states (`bool`, *optional*): 126 | Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for 127 | more detail. 128 | return_dict (`bool`, *optional*): 129 | If set to `True`, the model will return a [`~utils.Seq2SeqLMOutput`] instead of a plain tuple. 130 | kwargs: (*optional*) Remaining dictionary of keyword arguments. Keyword arguments come in two flavors: 131 | 132 | - Without a prefix which will be input as `**encoder_kwargs` for the encoder forward function. 133 | - With a *decoder_* prefix which will be input as `**decoder_kwargs` for the decoder forward function. 134 | """ 135 | 136 | class Brain2TextConfig(VisionEncoderDecoderConfig): 137 | model_type = "Brain2Text" 138 | 139 | def __init__( 140 | self, 141 | **kwargs, 142 | ): 143 | super().__init__(**kwargs) 144 | 145 | 146 | class Brain2Text(PreTrainedModel): 147 | config_class = Brain2TextConfig 148 | base_model_prefix = "Brain2Text" 149 | main_input_name = "fMRI" 150 | 151 | def __init__(self, 152 | config: Optional[PretrainedConfig] = None, 153 | encoder: Optional[PreTrainedModel] = None, 154 | decoder: Optional[PreTrainedModel] = None, 155 | ): 156 | if config is None and (encoder is None or decoder is None): 157 | raise ValueError("Either a configuration or an encoder and a decoder has to be provided.") 158 | if config is None: 159 | config = Brain2TextConfig.from_encoder_decoder_configs(encoder.config, decoder.config) 160 | else: 161 | if not isinstance(config, self.config_class): 162 | raise ValueError(f"Config: {config} has to be of type {self.config_class}") 163 | 164 | if config.decoder.cross_attention_hidden_size is not None: 165 | if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size: 166 | raise ValueError( 167 | "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal#" 168 | f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for" 169 | f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for" 170 | " `config.encoder.hidden_size`." 
171 | ) 172 | 173 | # initialize with config 174 | # make sure input & output embeddings is not tied 175 | config.tie_word_embeddings = False 176 | super().__init__(config) 177 | 178 | if encoder is None: 179 | encoder = AutoModel.from_config(config.encoder) 180 | 181 | if decoder is None: 182 | decoder = AutoModelForCausalLM.from_config(config.decoder) 183 | 184 | self.encoder = encoder 185 | self.encoder.main_input_name = 'fMRI' 186 | self.decoder = decoder 187 | # self.brainsem = None 188 | # make sure that the individual model's config refers to the shared config 189 | # so that the updates to the config will be synced 190 | self.encoder.config = self.config.encoder 191 | self.decoder.config = self.config.decoder 192 | 193 | def get_encoder(self): 194 | return self.encoder 195 | 196 | def get_decoder(self): 197 | return self.decoder 198 | 199 | 200 | def get_output_embeddings(self): 201 | return self.decoder.get_output_embeddings() 202 | 203 | def set_output_embeddings(self, new_embeddings): 204 | return self.decoder.set_output_embeddings(new_embeddings) 205 | 206 | @classmethod 207 | def from_pretrained(cls, *args, **kwargs): 208 | # At the moment fast initialization is not supported for composite models 209 | if kwargs.get("_fast_init", False): 210 | logger.warning( 211 | "Fast initialization is currently not supported for VisionEncoderDecoderModel. " 212 | "Falling back to slow initialization..." 213 | ) 214 | kwargs["_fast_init"] = False 215 | return super().from_pretrained(*args, **kwargs) 216 | 217 | @classmethod 218 | def from_encoder_decoder_pretrained( 219 | cls, 220 | encoder, 221 | decoder_name: str = None, 222 | cross_attention_reduce_factor: int = None, 223 | **kwargs 224 | ) -> PreTrainedModel: 225 | 226 | decoder_config = ThisGPT2Config.from_pretrained(decoder_name) 227 | decoder_config.is_decoder = True 228 | decoder_config.add_cross_attention = True 229 | decoder_config.encoder_hidden_size = encoder.config.hidden_size 230 | decoder_config.cross_attention_reduce_factor = cross_attention_reduce_factor 231 | decoder = ThisGPT2LMHeadModel.from_pretrained(decoder_name, config = decoder_config) 232 | 233 | # instantiate config with corresponding kwargs 234 | config = Brain2TextConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) 235 | 236 | # make sure input & output embeddings is not tied 237 | config.tie_word_embeddings = False 238 | return cls(encoder=encoder, decoder=decoder, config=config) 239 | 240 | def forward( 241 | self, 242 | encoder_inputs=None, 243 | encoder_labels=None, 244 | encoder_outputs=None, 245 | decoder_input_ids=None, 246 | decoder_attention_mask=None, 247 | past_key_values=None, 248 | decoder_inputs_embeds=None, 249 | decoder_labels=None, 250 | use_cache=None, 251 | output_attentions=None, 252 | output_hidden_states=None, 253 | return_dict=None, 254 | **kwargs, 255 | ): 256 | 257 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 258 | 259 | kwargs_decoder = {argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")} 260 | 261 | #kwargs_encoder = {argument[len("encoder_") :]: value for argument, value in kwargs.items() if argument.startswith("encoder_")} 262 | 263 | #encoder_outputs = kwargs_encoder['outputs'].last_hidden_state 264 | #if(encoder_outputs is None): 265 | if(encoder_outputs is None): 266 | encoder_outputs = self.encoder(encoder_inputs) 267 | encoder_hidden_states = encoder_outputs 268 | encoder_attention_mask = None 269 | 270 | 
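        # Auxiliary visual alignment: when CLIP features are supplied as encoder_labels, the encoder's
        # [CLS] output is regressed onto the CLIP [CLS] embedding with an MSE loss, which is added to the
        # captioning cross-entropy further below with a fixed weight of 10.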
visual_loss = None 271 | if encoder_labels is not None: 272 | #visual_loss = nn.MSELoss()(encoder_outputs[:, :, :], encoder_labels[:, :, :]) 273 | visual_loss = nn.MSELoss()(encoder_outputs[:, 0, :], encoder_labels[:, 0, :]) 274 | 275 | # Decode 276 | decoder_outputs = self.decoder( 277 | input_ids=decoder_input_ids, 278 | attention_mask=decoder_attention_mask, 279 | encoder_hidden_states=encoder_hidden_states, 280 | encoder_attention_mask=encoder_attention_mask, 281 | inputs_embeds=decoder_inputs_embeds, 282 | output_attentions=output_attentions, 283 | output_hidden_states=output_hidden_states, 284 | use_cache=use_cache, 285 | past_key_values=past_key_values, 286 | return_dict=return_dict, 287 | **kwargs_decoder, 288 | ) 289 | 290 | # Compute loss independent from decoder (as some shift the logits inside them) 291 | loss = None 292 | if decoder_labels is not None: 293 | logits = decoder_outputs.logits if return_dict else decoder_outputs[0] 294 | loss_fct = CrossEntropyLoss() 295 | loss = loss_fct(logits.reshape(-1, self.decoder.config.vocab_size), decoder_labels.view(-1)) 296 | 297 | final_loss = loss 298 | if visual_loss is not None: 299 | final_loss = loss + 10 * visual_loss 300 | 301 | if not return_dict: 302 | if loss is not None: 303 | return (loss,) + decoder_outputs + encoder_outputs 304 | else: 305 | return decoder_outputs + encoder_outputs 306 | 307 | #print(f'text_loss: {loss}, visual_loss: {visual_loss}') 308 | 309 | return Seq2SeqLMOutput( 310 | loss=final_loss, 311 | logits=decoder_outputs.logits, 312 | past_key_values=decoder_outputs.past_key_values, 313 | decoder_hidden_states=decoder_outputs.hidden_states, 314 | decoder_attentions=decoder_outputs.attentions, 315 | cross_attentions=decoder_outputs.cross_attentions, 316 | encoder_last_hidden_state=None, 317 | encoder_hidden_states=None, 318 | encoder_attentions=None 319 | ) 320 | 321 | def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): 322 | return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) 323 | 324 | def prepare_inputs_for_generation( 325 | self, input_ids, encoder_outputs, past=None, attention_mask=None, use_cache=None, **kwargs 326 | ): 327 | decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) 328 | decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None 329 | input_dict = { 330 | "attention_mask": attention_mask, 331 | "decoder_attention_mask": decoder_attention_mask, 332 | "decoder_input_ids": decoder_inputs["input_ids"], 333 | "encoder_outputs": encoder_outputs['last_hidden_state'], 334 | #"encoder_labels": encoder_labels, 335 | "past_key_values": decoder_inputs["past_key_values"], 336 | "use_cache": use_cache, 337 | } 338 | return input_dict 339 | 340 | def resize_token_embeddings(self, *args, **kwargs): 341 | raise NotImplementedError( 342 | "Resizing the embedding layers via the VisionEncoderDecoderModel directly is not supported.Please use the" 343 | " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" 344 | ) 345 | 346 | def _reorder_cache(self, past, beam_idx): 347 | # apply decoder cache reordering here 348 | return self.decoder._reorder_cache(past, beam_idx) 349 | --------------------------------------------------------------------------------