├── README.md ├── __init__.py ├── language_modelling ├── __init__.py ├── run_generation.py └── utils.py ├── model ├── __init__.py ├── graph.py ├── modelling_cross_attention.py └── modelling_self_attention.py ├── requirements.txt ├── script └── train_generation.sh └── wikiweb2m ├── __init__.py ├── cider ├── __init__.py ├── cider.py └── cider_scorer.py ├── data.py └── preprocess_data.py /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Graph Learning 2 | 3 | Most multimodal learning algorithms focus on modeling simple one-to-one pairs of data from two modalities, such as image-caption pairs or audio-text pairs. However, in most real-world settings, entities of different modalities 4 | interact with each other in more complex and multifaceted ways, going beyond one-to-one mappings. 5 | 6 | We propose Multimodal Graph Learning (MMGL), a systematic framework for capturing information from multiple multimodal 7 | neighbors with relational structures among them. 8 | In particular, we focus on MMGL for generative tasks, building upon pretrained Language Models (LMs), aiming to 9 | augment their text generation with multimodal neighbor contexts. 10 | 11 | The original paper can be found at [MMGL](https://arxiv.org/pdf/2310.07478.pdf). 12 | The initial version of the implementation, including graph encodings, can be found at [research-MMHG](https://github.com/minjiyoon/research-MMHG/tree/main). 13 | 14 | ## Setup 15 | 16 | Create a new conda environment, then install [PyTorch](https://pytorch.org) and the remaining requirements: 17 | ``` 18 | conda create python==3.7 -n mmgl 19 | conda activate mmgl 20 | pip install -r requirements.txt 21 | pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu117 22 | ``` 23 | The code is implemented with PyTorch DistributedDataParallel. 24 | The code supports the [WikiWeb2M](https://github.com/google-research-datasets/wit/blob/main/wikiweb2m.md) dataset. 25 | 26 | ## Data preprocessing 27 | 28 | First, make a folder to download the WikiWeb2M dataset: `mkdir wikiweb2m/raw`. 29 | Then download all Train/Validation/Test files from WikiWeb2M into `wikiweb2m/raw`. 30 | Next, make a folder to download images: `mkdir wikiweb2m/raw/images`. 31 | Finally, run `preprocess_data.py` to convert the WikiWeb2M dataset into PyTorch format. 32 | 33 | ``` 34 | python preprocess_data.py 35 | ``` 36 | 37 | The resulting training/validation/test set sizes for section summarization are as follows: 38 | 39 | | Number of | Train | Validation | Test | 40 | | ---- | ---- | ---- | ---- | 41 | | Sections | 680K | 170K | 170K | 42 | 43 | ## Training 44 | 45 | #### Script 46 | 47 | In `script/train_generation.sh`, you can specify the base model (`MODEL_NAME`), the task (`TASK`; currently only section summarization, 'section', is supported), and the neighbor context (`CONTEXT`). 48 | For `CONTEXT`, there are four options as follows: 49 | 50 | | CONTEXT | description | 51 | | ---- | ---- | 52 | | section_only | use only text in the target section | 53 | | section_all | use text and images in the target section | 54 | | text_only | use only text in the whole page | 55 | | all | use text and images in the whole page | 56 | 57 | You can set how to encode text neighbors using `NEIGHBOR_MODE` (see the sketch below for how these script variables are parsed).
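These script variables are passed to `language_modelling/run_generation.py` as command-line flags and parsed into its `Arguments` dataclass with Hugging Face's `HfArgumentParser`. The snippet below is a minimal sketch of that mechanism; `MiniArgs` is a hypothetical, trimmed-down stand-in for the real `Arguments` class, which defines many more fields.

```
# Minimal sketch of how the script variables reach the training code.
# `MiniArgs` is a hypothetical stand-in for the full `Arguments` dataclass
# defined in language_modelling/run_generation.py.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser


@dataclass
class MiniArgs:
    context: Optional[str] = field(default="section_only")  # section_only, section_all, text_only, all
    neighbor_mode: str = field(default="raw")                # raw, embedding
    peft_type: str = field(default="none")                   # none, prompt, prefix, lora, flamingo


parser = HfArgumentParser(MiniArgs)
(args,) = parser.parse_args_into_dataclasses(
    args=["--context", "all", "--neighbor_mode", "embedding", "--peft_type", "lora"]
)
print(args.context, args.neighbor_mode, args.peft_type)  # all embedding lora
```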
For `NEIGHBOR_MODE`, there are two options as follows: 58 | 59 | | NEIGHBOR_MODE | description | 60 | | ---- | ---- | 61 | | raw | concatenate text neighbors as raw text into the input text | 62 | | embedding | embed text neighbors using `text_model` and concatenate the embeddings into the input text | 63 | 64 | You can set the parameter-efficient fine-tuning (PEFT) option in the script using `PEFT_TYPE`. The options are as follows: 65 | 66 | | PEFT_TYPE | description | 67 | | ---- | ---- | 68 | | none | full fine-tuning | 69 | | prompt | prompt tuning | 70 | | prefix | prefix tuning | 71 | | lora | LoRA | 72 | | flamingo | fine-tune only the newly added cross-attention layers; can be used on decoder-only models with `neighbor_mode = embedding` | 73 | 74 | In the script, you can also change `max_input_length` and `max_output_length` in addition to other optimization hyperparameters (e.g., `epochs`, `learning_rate`, `per_device_train_batch_size`). 75 | You can set which models are used to encode text and image neighbors using `text_model` and `visual_model`. 76 | All configurable arguments are defined in the `Arguments` class in `language_modelling/run_generation.py`. 77 | 78 | #### File description 79 | 80 | We provide brief descriptions of each file as follows: 81 | 82 | | Directory/File | description | 83 | | ---- | ---- | 84 | | wikiweb2m/ | code related to the WikiWeb2M dataset | 85 | | wikiweb2m/cider | code to compute CIDEr scores | 86 | | wikiweb2m/data.py | prepares each training example based on `context` and `neighbor_mode` | 87 | | wikiweb2m/preprocess_data.py | code to preprocess the WikiWeb2M dataset and download images | 88 | | script/ | scripts to run MMGL | 89 | | script/train_generation.sh | sets hyperparameters | 90 | | language_modelling/ | main directory | 91 | | language_modelling/run_generation.py | prepares models, reads datasets, and runs the train/validation loops | 92 | | language_modelling/utils.py | utility functions | 93 | | model/ | language models | 94 | | model/modelling_self_attention.py | LMs with self-attention only, including encoder-decoder and decoder-only models | 95 | | model/modelling_cross_attention.py | LMs with cross-attention to encode neighbor information; decoder-only models | 96 | 97 | ## Citation 98 | If you find this work or our code useful, please consider citing: 99 | ``` 100 | @article{yoon2023multimodal, 101 | title={Multimodal Graph Learning for Generative Tasks}, 102 | author={Yoon, Minji and Koh, Jing Yu and Hooi, Bryan and Salakhutdinov, Ruslan}, 103 | journal={arXiv preprint arXiv:2310.07478}, 104 | year={2023} 105 | } 106 | ``` 107 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minjiyoon/MMGL/025cc8f850262bcf39ea958a79cbc6727540ec46/__init__.py -------------------------------------------------------------------------------- /language_modelling/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/minjiyoon/MMGL/025cc8f850262bcf39ea958a79cbc6727540ec46/language_modelling/__init__.py -------------------------------------------------------------------------------- /language_modelling/run_generation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | 4 | """ Finetuning summary generation models""" 5 | from collections import OrderedDict 6 | import json 7 | import os 8 | import random 9 | import sys 10 | import 
tqdm 11 | import wandb 12 | import warnings 13 | warnings.simplefilter(action='ignore', category=FutureWarning) 14 | warnings.simplefilter(action='ignore', category=UserWarning) 15 | 16 | import time 17 | from time import perf_counter 18 | from dataclasses import dataclass, field 19 | from typing import Optional 20 | 21 | import numpy as np 22 | 23 | import torch 24 | import torch.backends.cudnn as cudnn 25 | import torch.distributed as dist 26 | import torch.nn as nn 27 | import torch.multiprocessing as mp 28 | mp.set_sharing_strategy('file_system') 29 | from torch.utils.data import DataLoader 30 | from torch.nn.parallel import DistributedDataParallel as DDP 31 | from torch.optim.lr_scheduler import StepLR 32 | from torchmetrics import BLEUScore 33 | from torchmetrics.text import ROUGEScore 34 | from warmup_scheduler import GradualWarmupScheduler 35 | 36 | from datasets import load_dataset 37 | import evaluate 38 | import transformers 39 | from transformers import ( 40 | AutoConfig, 41 | AutoTokenizer, 42 | AutoModelForSeq2SeqLM, 43 | AutoModelForCausalLM, 44 | HfArgumentParser, 45 | set_seed, 46 | get_scheduler, 47 | ) 48 | from transformers.optimization import Adafactor 49 | from transformers.utils import check_min_version 50 | from transformers.utils.versions import require_version 51 | 52 | from wikiweb2m import load_wikiweb2m, WikiWeb2M 53 | from wikiweb2m.cider import Cider 54 | 55 | from language_modelling import utils 56 | from model import SelfAttentionModel, CrossAttentionModel 57 | 58 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 59 | check_min_version("4.17.0") 60 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/text-classification/requirements.txt") 61 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Only display errors 62 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 63 | 64 | best_acc1 = 0 # Variable to keep track of best model so far. 65 | 66 | @dataclass 67 | class Arguments: 68 | """ 69 | Arguments pertaining to what data we are going to input our model for training and eval. 70 | 71 | Using `HfArgumentParser` we can turn this class 72 | into argparse arguments to be able to specify them on 73 | the command line. 
74 | """ 75 | overwrite_cache: Optional[bool] = field( 76 | default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} 77 | ) 78 | dataset: Optional[str] = field( 79 | default='wikiweb2m', metadata={"help": "The name of the dataset to use."} 80 | ) 81 | task: Optional[str] = field( 82 | default='section', metadata={"help": "One of three generation tasks in WikiWeb2M"} 83 | ) 84 | context: Optional[str] = field( 85 | default='section_only', metadata={"help": "Range of neighbor context: section_only, section_all, text_only, all"} 86 | ) 87 | max_input_length: Optional[int] = field( 88 | default=512, metadata={"help": "maximum token length of input text"} 89 | ) 90 | max_output_length: Optional[int] = field( 91 | default=128, metadata={"help": "maximum token length of output text"} 92 | ) 93 | 94 | wandb_project: Optional[str] = field( 95 | default='MMGL', metadata={"help": "wandb project name"} 96 | ) 97 | wandb_run: Optional[str] = field( 98 | default='default', metadata={"help": "wandb run name"} 99 | ) 100 | log_dir: Optional[str] = field( 101 | default='log', metadata={"help": "logging dir"} 102 | ) 103 | save_dir: Optional[str] = field( 104 | default=None, metadata={"help": "save dir"} 105 | ) 106 | resume: Optional[str] = field( 107 | default=None, metadata={"help": "path to latest checkpoint (default: none)"} 108 | ) 109 | 110 | seed: Optional[int] = field( 111 | default=None, metadata={"help": "seed for initializing training."} 112 | ) 113 | fp16: Optional[bool] = field( 114 | default=False, metadata={"help": "What precision to train in."} 115 | ) 116 | bf16: Optional[bool] = field( 117 | default=False, metadata={"help": "What precision to train in."} 118 | ) 119 | 120 | test: Optional[bool] = field( 121 | default=False, metadata={"help": "evaluate model on validation set."} 122 | ) 123 | 124 | per_device_train_batch_size: Optional[int] = field( 125 | default=4, metadata={"help": "Batch size per device during training."} 126 | ) 127 | per_device_val_batch_size: Optional[int] = field( 128 | default=4, metadata={"help": "Batch size per device during validation/test."} 129 | ) 130 | dataloader_num_workers: Optional[int] = field( 131 | default=4, metadata={"help": "Number of threads to read data."} 132 | ) 133 | 134 | start_epoch: Optional[int] = field( 135 | default=0, metadata={"help": "Starting epoch."} 136 | ) 137 | epochs: Optional[int] = field( 138 | default=90, metadata={"help": "Total number of epochs."} 139 | ) 140 | steps_per_epoch: Optional[int] = field( 141 | default=2000, metadata={"help": "Number of training steps per epoch."} 142 | ) 143 | val_steps_per_epoch: Optional[int] = field( 144 | default=1000, metadata={"help": "Number of validation/test steps per epoch."} 145 | ) 146 | print_freq: Optional[int] = field( 147 | default=50, metadata={"help": "print frequency"} 148 | ) 149 | 150 | learning_rate: Optional[float] = field( 151 | default=0.001, metadata={"help": "initial learning rate."} 152 | ) 153 | adam_beta1: Optional[float] = field( 154 | default=0.9, metadata={"help": "beta1 for Adam."} 155 | ) 156 | adam_beta2: Optional[float] = field( 157 | default=0.95, metadata={"help": "beta2 for AdamDecay."} 158 | ) 159 | weight_decay: Optional[float] = field( 160 | default=0.01, metadata={"help": "Weight decay parameter."} 161 | ) 162 | grad_accumulation_steps: Optional[int] = field( 163 | default=4, metadata={"help": "number of gradient accumulation steps."} 164 | ) 165 | grad_clip: Optional[float] = field( 166 | default=1.0, metadata={"help": 
"gradient clipping amount."} 167 | ) 168 | lr_warmup_steps: Optional[int] = field( 169 | default=2000, metadata={"help": "Number of steps to warm up lr."} 170 | ) 171 | lr_schedule_step_size: Optional[int] = field( 172 | default=5, metadata={"help": "Number of steps before decaying lr."} 173 | ) 174 | lr_schedule_gamma: Optional[float] = field( 175 | default=0.1, metadata={"help": "Decay parameter for learning rate scheduler."} 176 | ) 177 | 178 | model_name_or_path: str = field( 179 | default=None, metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 180 | ) 181 | decoder_only: Optional[bool] = field( 182 | default=False, metadata={"help": "whether LM models are decoder-only: opt or mpt"} 183 | ) 184 | cross_attention: Optional[bool] = field( 185 | default=False, metadata={"help": "whether LM models use cross-attention: mpt"} 186 | ) 187 | text_model: str = field( 188 | default="roberta-base", metadata={"help": "text model to encode neighbor texts"} 189 | ) 190 | visual_model: str = field( 191 | default="openai/clip-vit-base-patch16", metadata={"help": "visual model to encode neighbor images"} 192 | ) 193 | n_text_tokens: int = field( 194 | default=4, metadata={"help": "number of tokens for text embeddings"} 195 | ) 196 | n_visual_tokens: int = field( 197 | default=4, metadata={"help": "number of tokens for visual embeddings"} 198 | ) 199 | freeze_lm: Optional[bool] = field( 200 | default=False, metadata={"help": "whether to freeze LM parameters"} 201 | ) 202 | neighbor_mode: str = field( 203 | default="raw", metadata={"help": "how to encode neighbor information: raw, embedding"} 204 | ) 205 | max_text_neighbors: int = field( 206 | default=11, metadata={"help": "maxinum number of text neighbors"} 207 | ) 208 | max_image_neighbors: int = field( 209 | default=5, metadata={"help": "maximum number of image neighbors"} 210 | ) 211 | position_type: str = field( 212 | default="none", metadata={"help": "position id type for text/image neighbors"} 213 | ) 214 | 215 | num_neighbor_layers: int = field( 216 | default=4, metadata={"help": "number of cross-attention layers to encode neighbor information"} 217 | ) 218 | peft_type: str = field( 219 | default="none", metadata={"help": "peft type: none, prefix, prompt, lora, flamingo"} 220 | ) 221 | lora_r: int = field( 222 | default=64, metadata={"help": "lora row rank"} 223 | ) 224 | lora_alpha: float = field( 225 | default=1, metadata={"help": "lora scaling factor"} 226 | ) 227 | lora_dropout: float = field( 228 | default=0.0, metadata={"help": "lora dropout rate"} 229 | ) 230 | 231 | 232 | def main(): 233 | 234 | parser = HfArgumentParser((Arguments)) 235 | args = parser.parse_args_into_dataclasses()[0] 236 | 237 | # Set a new log directory 238 | i = 0 239 | log_dir = os.path.join(args.log_dir, f'{args.wandb_run}_{i}') 240 | while os.path.exists(log_dir): 241 | i += 1 242 | log_dir = os.path.join(args.log_dir, f'{args.wandb_run}_{i}') 243 | os.makedirs(log_dir) 244 | args.save_dir = os.path.join(log_dir, 'ckpt.pth.tar') 245 | 246 | # Wandb logging 247 | combined_args = {**vars(args)} 248 | run = wandb.init(project=args.wandb_project, name=args.wandb_run) 249 | run.config.update(combined_args) 250 | 251 | print(f'Logging to {log_dir}.') 252 | 253 | # Prepare seed 254 | if args.seed is not None: 255 | random.seed(args.seed) 256 | torch.manual_seed(args.seed) 257 | cudnn.deterministic = True 258 | warnings.warn('You have chosen to seed training. 
' 259 | 'This will turn on the CUDNN deterministic setting, ' 260 | 'which can slow down your training considerably! ' 261 | 'You may see unexpected behavior when restarting ' 262 | 'from checkpoints.') 263 | 264 | # Prepare distributed data parallel 265 | ngpus_per_node = torch.cuda.device_count() 266 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, log_dir, run)) 267 | 268 | 269 | def main_worker(gpu, world_size, args, log_dir, run): 270 | """ 271 | Main worker function to train and evaluate models. 272 | Args: 273 | gpu (int): GPU id to use. 274 | world_size (int): Number of GPUs to use. 275 | args (Arguments): Arguments. 276 | log_dir (str): Logging directory. 277 | run (wandb run): Wandb run. 278 | """ 279 | 280 | # Variable to keep track of best model so far. 281 | global best_acc1 282 | print("Use GPU: {} for training".format(gpu)) 283 | dist.init_process_group(backend='nccl', init_method='tcp://127.0.0.1:1337', world_size=world_size, rank=gpu) 284 | 285 | # Prepare pretrained model 286 | if "t5" in args.model_name_or_path: 287 | # encoder-decoder models 288 | args.decoder_only = False 289 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 290 | model = SelfAttentionModel(args, tokenizer) 291 | elif "opt" in args.model_name_or_path: 292 | # decoder-only models 293 | args.decoder_only = True 294 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 295 | model = SelfAttentionModel(args, tokenizer) 296 | elif "mpt" in args.model_name_or_path: 297 | # OPT models with newly added cross-attention layers 298 | args.decoder_only = True 299 | args.model_name_or_path = args.model_name_or_path.replace("mpt", "opt") 300 | tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, use_fast=False) 301 | model = CrossAttentionModel(args, tokenizer) 302 | 303 | tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True 304 | if args.fp16: 305 | model = model.float() 306 | elif args.bf16: 307 | model = model.bfloat16() 308 | 309 | # Wandb logging 310 | if gpu % world_size == 0: 311 | _, total_trainable_params, total_nontrainable_params = utils.get_params_count(model) 312 | run.watch(model) 313 | run.config.update({"total_params": total_trainable_params + total_nontrainable_params}) 314 | run.config.update({"trainable_params": total_trainable_params}) 315 | run.config.update({"non_trainable_params": total_nontrainable_params}) 316 | 317 | torch.cuda.set_device(gpu) 318 | model.cuda(gpu) 319 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu], find_unused_parameters=False) 320 | 321 | if "t5" in args.model_name_or_path: 322 | print('Using Adafactor as the optimizer.') 323 | optimizer = Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=args.learning_rate) 324 | scheduler = None 325 | elif "opt" in args.model_name_or_path: 326 | print('Using AdamW as the optimizer.') 327 | optimizer_cls = torch.optim.AdamW 328 | optimizer = optimizer_cls(model.parameters(), args.learning_rate, 329 | betas=(args.adam_beta1, args.adam_beta2), 330 | weight_decay=args.weight_decay, eps=1e-8) 331 | """Sets the learning rate to the initial LR decayed by 10 every 5 epochs""" 332 | scheduler_steplr = StepLR(optimizer, step_size=(args.lr_schedule_step_size * args.steps_per_epoch) // args.grad_accumulation_steps, gamma=args.lr_schedule_gamma) 333 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1.0, total_epoch=args.lr_warmup_steps, 
after_scheduler=scheduler_steplr) 334 | 335 | # Detecting last checkpoint. 336 | if args.resume: 337 | checkpoint_path = os.path.join(args.log_dir, args.resume, 'ckpt.pth.tar') 338 | if os.path.isfile(checkpoint_path): 339 | print("=> loading checkpoint '{}'".format(checkpoint_path)) 340 | loc = 'cuda:{}'.format(gpu) 341 | checkpoint = torch.load(checkpoint_path, map_location=loc) 342 | args.start_epoch = checkpoint['epoch'] 343 | best_acc1 = checkpoint['best_acc1'] 344 | model.load_state_dict(checkpoint['state_dict'], strict=False) 345 | optimizer.load_state_dict(checkpoint['optimizer']) 346 | if scheduler is not None: 347 | scheduler.load_state_dict(checkpoint['scheduler']) 348 | print("=> loaded checkpoint '{}' (epoch {}, best_acc {})".format(checkpoint_path, checkpoint['epoch'], checkpoint['best_acc1'])) 349 | else: 350 | print("=> no checkpoint found at '{}'".format(checkpoint_path)) 351 | 352 | cudnn.benchmark = True 353 | 354 | # Prepare Dataset 355 | start_time = perf_counter() 356 | train_data, val_data, test_data, id_list = load_wikiweb2m(args.task) 357 | print(f'Loading wikiweb2m done: {perf_counter()-start_time}') 358 | start_time = perf_counter() 359 | train_dataset = WikiWeb2M(args, train_data, id_list["train"], tokenizer, args.visual_model) 360 | val_dataset = WikiWeb2M(args, val_data, id_list["val"], tokenizer, args.visual_model) 361 | test_dataset = WikiWeb2M(args, test_data, id_list["test"], tokenizer, args.visual_model) 362 | print(f'Initialize datasets: {perf_counter()-start_time}') 363 | print(f'Training with {len(train_dataset)} examples, validating with {len(val_dataset)} examples, testing with {len(test_dataset)} examples.') 364 | 365 | # Sampler 366 | train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, drop_last=True) 367 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True) 368 | test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False, drop_last=True) 369 | 370 | # Dataloader 371 | start_time = perf_counter() 372 | train_loader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size, 373 | shuffle=False, num_workers=args.dataloader_num_workers, prefetch_factor=10, pin_memory=False, sampler=train_sampler) 374 | val_loader = DataLoader(val_dataset, batch_size=args.per_device_val_batch_size, 375 | shuffle=False, num_workers=args.dataloader_num_workers, prefetch_factor=10, pin_memory=False, sampler=val_sampler) 376 | test_loader = DataLoader(test_dataset, batch_size=args.per_device_val_batch_size, 377 | shuffle=False, num_workers=args.dataloader_num_workers, prefetch_factor=10, pin_memory=False, sampler=test_sampler) 378 | print(f'Initialize dataloaders: {perf_counter()-start_time}') 379 | 380 | if args.test: 381 | evaluate_loop(test_loader, model, tokenizer, epoch, args, run) 382 | return 383 | total_time = 0 384 | for epoch in range(args.start_epoch, args.epochs): 385 | start_time = time.time() 386 | if epoch == 0: 387 | evaluate_loop(val_loader, model, tokenizer, epoch-1, args, run) 388 | 389 | # train for one epoch 390 | train_sampler.set_epoch(epoch) 391 | train_loop(train_loader, model, tokenizer, optimizer, epoch, scheduler, args, run) 392 | 393 | # evaluate on validation set 394 | acc1 = evaluate_loop(val_loader, model, tokenizer, epoch, args, run) 395 | 396 | # remember best acc@1 and save checkpoint 397 | is_best = acc1 > best_acc1 398 | best_acc1 = max(acc1, best_acc1) 399 | 400 | if gpu % world_size == 0 and (is_best or epoch 
== 0): 401 | # Only save non-frozen parameters. 402 | stripped_state_dict = { 403 | k: v for k, v in model.state_dict().items() if 404 | ('.text_model' not in k and '.visual_model' not in k) 405 | } 406 | stripped_state_dict = OrderedDict(sorted(stripped_state_dict.items())) 407 | state = { 408 | 'epoch': epoch, 409 | 'best_acc1': acc1, 410 | 'state_dict': stripped_state_dict, 411 | 'optimizer' : optimizer.state_dict(), 412 | } 413 | if scheduler is not None: 414 | state['scheduler'] = scheduler.state_dict() 415 | print('=> save best val model ...', args.save_dir) 416 | torch.save(state, args.save_dir) 417 | epoch_time = time.time() - start_time 418 | total_time += epoch_time 419 | print(f"Epoch {epoch} time: {epoch_time}s") 420 | print(f"Total time: {total_time}s") 421 | # Test 422 | checkpoint_path = args.save_dir 423 | print("=> loading best val checkpoint '{}'".format(checkpoint_path)) 424 | loc = 'cuda:{}'.format(gpu) 425 | checkpoint = torch.load(checkpoint_path, map_location=loc) 426 | model.load_state_dict(checkpoint['state_dict'], strict=False) 427 | print("=> loaded best val checkpoint '{}'".format(checkpoint_path)) 428 | evaluate_loop(test_loader, model, tokenizer, args.epochs, args, run, "test") 429 | 430 | def train_loop(train_loader, model, tokenizer, optimizer, epoch, scheduler, args, run): 431 | """ 432 | Train loop for one epoch. 433 | Args: 434 | train_loader (DataLoader): Training dataloader. 435 | model (nn.Module): Model to train. 436 | tokenizer (PreTrainedTokenizer): Tokenizer. 437 | optimizer (Optimizer): Optimizer to use. 438 | epoch (int): Current epoch. 439 | scheduler (Scheduler): Scheduler to use. 440 | args (Arguments): Arguments. 441 | run (wandb run): Wandb run. 442 | """ 443 | gpu, world_size = dist.get_rank(), dist.get_world_size() 444 | ngpus_per_node = torch.cuda.device_count() 445 | 446 | # Metrics 447 | batch_time = utils.AverageMeter('Time', ':6.3f') 448 | data_time = utils.AverageMeter('Data', ':6.3f') 449 | forward_time = utils.AverageMeter('Forward', ':6.3f') 450 | losses = utils.AverageMeter('Loss', ':.4e') 451 | 452 | # Progress bar 453 | if gpu % world_size == 0: 454 | progress = utils.ProgressMeter(args.steps_per_epoch, [batch_time, losses], prefix="Epoch: [{}]".format(epoch)) 455 | 456 | # Additional loss just to record the summary loss on decoder-only models 457 | if args.decoder_only: 458 | loss_fct = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) 459 | 460 | model.train() 461 | end = time.time() 462 | for i, batch in enumerate(train_loader): 463 | data_time.update(time.time() - end) 464 | batch = {k: v.cuda(gpu, non_blocking=True) for k, v in batch.items()} 465 | forward_start = time.time() 466 | outputs = model(**batch) 467 | forward_time.update(time.time() - forward_start) 468 | 469 | loss = outputs.loss 470 | if args.decoder_only: 471 | logits = outputs.logits 472 | # Only consider loss on reference summary just like encoder-decoder models 473 | shift_logits = logits[..., args.max_input_length:-1, :].contiguous() 474 | shift_labels = batch['labels'][..., (args.max_input_length + 1):].contiguous() 475 | # Ignore loss for some logits 476 | if shift_logits.shape[1] - shift_labels.shape[1] > 0: 477 | diff = shift_logits.shape[1] - shift_labels.shape[1] 478 | shift_logits = shift_logits[..., :-diff, :] 479 | # Summary_loss 480 | summary_loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 481 | losses.update(summary_loss.item(), batch["input_ids"].size(0)) 482 | else: 483 | losses.update(loss.item(), 
batch["input_ids"].size(0)) 484 | loss = loss / args.grad_accumulation_steps 485 | loss.backward() 486 | 487 | # Update weights every args.grad_accumulation_steps 488 | if ((i + 1) % args.grad_accumulation_steps == 0) or (i == args.steps_per_epoch - 1): 489 | optimizer.step() 490 | if scheduler is not None: 491 | scheduler.step() 492 | if args.grad_clip > 2: 493 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 494 | optimizer.zero_grad() 495 | 496 | # Log metrics every update step 497 | actual_step = (epoch * args.steps_per_epoch + i + 1) // args.grad_accumulation_steps 498 | if actual_step == 1 or actual_step % args.print_freq == 0: 499 | losses.all_reduce() 500 | batch_time.all_reduce() 501 | data_time.all_reduce() 502 | forward_time.all_reduce() 503 | ex_per_sec = (args.per_device_train_batch_size / batch_time.avg) * ngpus_per_node 504 | 505 | # Log only on the first GPU 506 | if gpu % world_size == 0: 507 | progress.display(i + 1) 508 | run.log({"train/loss": losses.avg}, step=actual_step) 509 | run.log({"metrics/total_secs_per_batch": batch_time.avg}, step=actual_step) 510 | run.log({"metrics/data_secs_per_batch": data_time.avg}, step=actual_step) 511 | run.log({"metrics/total_secs_captioning": forward_time.avg}, step=actual_step) 512 | run.log({"metrics/examples_per_sec": ex_per_sec}, step=actual_step) 513 | 514 | losses.reset() 515 | batch_time.reset() 516 | data_time.reset() 517 | forward_time.reset() 518 | 519 | # Measure elapsed time 520 | batch_time.update(time.time() - end) 521 | end = time.time() 522 | 523 | if i == args.steps_per_epoch - 1: 524 | break 525 | 526 | 527 | def evaluate_loop(val_loader, model, tokenizer, epoch, args, run, prefix="val"): 528 | """ 529 | Evaluate loop. 530 | Args: 531 | val_loader (DataLoader): Validation dataloader. 532 | model (nn.Module): Model to evaluate. 533 | tokenizer (PreTrainedTokenizer): Tokenizer. 534 | epoch (int): Current epoch. 535 | args (Arguments): Arguments. 536 | run (wandb run): Wandb run. 537 | prefix (str): Prefix to use for logging. 
538 | """ 539 | 540 | gpu, world_size = dist.get_rank(), dist.get_world_size() 541 | ngpus_per_node = torch.cuda.device_count() 542 | 543 | # Three metrics to evaluate summarization: BLEU, ROUGE, CIDEr 544 | bleu_scorers = [BLEUScore(n_gram=i) for i in [1, 2, 3, 4]] 545 | rouge_scorer = ROUGEScore() 546 | cider_scorer = Cider() 547 | actual_step = ((epoch + 1) * args.steps_per_epoch) // args.grad_accumulation_steps 548 | 549 | batch_time = utils.AverageMeter('Time', ':6.3f', utils.Summary.AVERAGE) 550 | losses = utils.AverageMeter('Loss', ':.4e', utils.Summary.AVERAGE) 551 | bleu1 = utils.AverageMeter('BLEU@1', ':6.2f', utils.Summary.AVERAGE) 552 | bleu2 = utils.AverageMeter('BLEU@2', ':6.2f', utils.Summary.AVERAGE) 553 | bleu3 = utils.AverageMeter('BLEU@3', ':6.2f', utils.Summary.AVERAGE) 554 | bleu4 = utils.AverageMeter('BLEU@4', ':6.2f', utils.Summary.AVERAGE) 555 | rouge1 = utils.AverageMeter('ROUGE@1', ':6.2f', utils.Summary.AVERAGE) 556 | rouge2 = utils.AverageMeter('ROUGE@2', ':6.2f', utils.Summary.AVERAGE) 557 | rougeL = utils.AverageMeter('ROUGE@L', ':6.2f', utils.Summary.AVERAGE) 558 | rougeLsum = utils.AverageMeter('ROUGE@Lsum', ':6.2f', utils.Summary.AVERAGE) 559 | cider = utils.AverageMeter('CIDER', ':6.2f', utils.Summary.AVERAGE) 560 | 561 | # Progress bar 562 | if gpu % world_size == 0: 563 | progress = utils.ProgressMeter(args.val_steps_per_epoch, [batch_time, losses], prefix=f'{prefix}: ') 564 | 565 | # Additional loss just to record the summary loss on decoder-only models 566 | if args.decoder_only: 567 | loss_fct = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id) 568 | 569 | # Switch to evaluate mode 570 | model.eval() 571 | with torch.no_grad(): 572 | end = time.time() 573 | all_generated_captions = [] 574 | all_gt_captions = [] 575 | max_to_display = 5 576 | 577 | for i, batch in enumerate(val_loader): 578 | batch = {k: v.cuda(gpu, non_blocking=True) for k, v in batch.items()} 579 | 580 | outputs = model(**batch) 581 | logits = outputs.logits 582 | if args.decoder_only: 583 | # Only consider loss on reference summary just like encoder-decoder models 584 | logits = logits[..., args.max_input_length:-1, :].contiguous() 585 | labels = batch['labels'][..., (args.max_input_length + 1):].contiguous() 586 | 587 | # Ignore loss for some logits 588 | if logits.shape[1] - labels.shape[1] > 0: 589 | diff = logits.shape[1] - labels.shape[1] 590 | logits = logits[..., :-diff, :] 591 | loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) 592 | else: 593 | labels = batch['labels'] 594 | loss = outputs.loss 595 | losses.update(loss.item(), batch["input_ids"].size(0)) 596 | 597 | if prefix == "test": 598 | # Generate tokens sequentially 599 | if args.decoder_only: 600 | generated_ids = model.module.generate(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], max_new_tokens=32) 601 | else: 602 | generated_ids = model.module.generate(input_ids=batch["input_ids"][..., :args.max_input_length, :].contiguous(), \ 603 | attention_mask=batch["attention_mask"][..., :args.max_input_length, :].contiguous(), max_new_tokens=32) 604 | else: 605 | # Generate tokens based on the input text 606 | generated_ids = torch.argmax(logits, dim=-1) 607 | 608 | all_generated_ids = [torch.zeros_like(generated_ids) for _ in range(dist.get_world_size())] 609 | dist.all_gather(all_generated_ids, generated_ids) 610 | all_generated_ids[dist.get_rank()] = generated_ids 611 | generated_ids = torch.cat(all_generated_ids) 612 | 613 | tgt_tokens = labels 614 | all_tgt_tokens = 
[torch.zeros_like(tgt_tokens) for _ in range(dist.get_world_size())] 615 | dist.all_gather(all_tgt_tokens, tgt_tokens) 616 | all_tgt_tokens[dist.get_rank()] = tgt_tokens 617 | all_tgt_tokens = torch.cat(all_tgt_tokens) 618 | 619 | if not args.decoder_only: 620 | all_tgt_tokens[all_tgt_tokens == -100] = tokenizer.pad_token_id 621 | generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) 622 | gt_captions = tokenizer.batch_decode(all_tgt_tokens, skip_special_tokens=True) 623 | 624 | for cap_i in range(len(generated_captions)): 625 | stop_idx = generated_captions[cap_i].find('.') 626 | if stop_idx > 5: 627 | all_generated_captions.append(generated_captions[cap_i][:stop_idx]) 628 | else: 629 | all_generated_captions.append(generated_captions[cap_i]) 630 | all_gt_captions.append([gt_captions[cap_i]]) 631 | 632 | # Measure elapsed time 633 | batch_time.update(time.time() - end) 634 | end = time.time() 635 | 636 | if i % args.print_freq == 0 and gpu % world_size == 0: 637 | progress.display(i + 1) 638 | 639 | if i == args.val_steps_per_epoch - 1: 640 | break 641 | 642 | if gpu % world_size == 0: 643 | print('=' * 30) 644 | print(f'Computing BLEU with {len(all_generated_captions)} generated captions and {len(all_gt_captions)} groundtruth captions.') 645 | for cap_i, cap in enumerate(all_generated_captions[:max_to_display]): 646 | print(f'{cap_i}) {cap}') 647 | print('=' * 30) 648 | print('Real samples:') 649 | for cap_i, cap in enumerate(all_gt_captions[:max_to_display]): 650 | print(f'{cap_i}) {cap}') 651 | print('=' * 30) 652 | 653 | bleu1_score = bleu_scorers[0](all_generated_captions, all_gt_captions) 654 | bleu1.update(bleu1_score, 1) 655 | bleu2_score = bleu_scorers[1](all_generated_captions, all_gt_captions) 656 | bleu2.update(bleu2_score, 1) 657 | bleu3_score = bleu_scorers[2](all_generated_captions, all_gt_captions) 658 | bleu3.update(bleu3_score, 1) 659 | bleu4_score = bleu_scorers[3](all_generated_captions, all_gt_captions) 660 | bleu4.update(bleu4_score, 1) 661 | 662 | rouge_scores = rouge_scorer(all_generated_captions, all_gt_captions) 663 | rouge1.update(rouge_scores['rouge1_fmeasure'], 1) 664 | rouge2.update(rouge_scores['rouge2_fmeasure'], 1) 665 | rougeL.update(rouge_scores['rougeL_fmeasure'], 1) 666 | rougeLsum.update(rouge_scores['rougeLsum_fmeasure'], 1) 667 | 668 | cands = {idx: [pred] for idx, pred in enumerate(all_generated_captions)} 669 | refs = {idx: [label] for idx, label in enumerate(all_gt_captions)} 670 | cider_scores, _ = cider_scorer.compute_score(refs, cands) 671 | cider.update(cider_scores, 1) 672 | 673 | batch_time.all_reduce() 674 | losses.all_reduce() 675 | bleu1.all_reduce() 676 | bleu2.all_reduce() 677 | bleu3.all_reduce() 678 | bleu4.all_reduce() 679 | rouge1.all_reduce() 680 | rouge2.all_reduce() 681 | rougeL.all_reduce() 682 | rougeLsum.all_reduce() 683 | cider.all_reduce() 684 | 685 | if gpu % world_size == 0: 686 | progress.display_summary() 687 | print("BLEU", bleu1.avg, bleu2.avg, bleu3.avg, bleu4.avg) 688 | print("ROUGE", rouge1.avg, rouge2.avg, rougeL.avg, rougeLsum.avg) 689 | print("CIDER", cider.avg) 690 | 691 | run.log({f"{prefix}/total_secs_per_batch": batch_time.avg}, step=actual_step) 692 | run.log({f"{prefix}/loss": losses.avg}, step=actual_step) 693 | run.log({f"{prefix}/bleu1": bleu1.avg}, step=actual_step) 694 | run.log({f"{prefix}/bleu2": bleu2.avg}, step=actual_step) 695 | run.log({f"{prefix}/bleu3": bleu3.avg}, step=actual_step) 696 | run.log({f"{prefix}/bleu4": bleu4.avg}, step=actual_step) 697 | 
run.log({f"{prefix}/rouge1": rouge1.avg}, step=actual_step) 698 | run.log({f"{prefix}/rouge2": rouge2.avg}, step=actual_step) 699 | run.log({f"{prefix}/rougeL": rougeL.avg}, step=actual_step) 700 | run.log({f"{prefix}/rougeLsum": rougeLsum.avg}, step=actual_step) 701 | run.log({f"{prefix}/cider": cider.avg}, step=actual_step) 702 | 703 | return bleu4.avg 704 | 705 | 706 | if __name__ == "__main__": 707 | main() 708 | -------------------------------------------------------------------------------- /language_modelling/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from transformers import AutoFeatureExtractor 3 | from PIL import Image 4 | import shutil 5 | import torch 6 | import torch.distributed as dist 7 | 8 | import nltk 9 | try: 10 | nltk.data.find("tokenizers/punkt") 11 | except (LookupError, OSError): 12 | nltk.download("punkt", quiet=True) 13 | 14 | 15 | def get_feature_extractor_for_model(model_name: str): 16 | print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.') 17 | feature_extractor = AutoFeatureExtractor.from_pretrained(model_name) 18 | return feature_extractor 19 | 20 | 21 | def get_pixel_values_for_model(feature_extractor, img: Image.Image): 22 | pixel_values = feature_extractor(img.convert('RGB'), return_tensors="pt").pixel_values[0, ...] # (3, H, W) 23 | return pixel_values 24 | 25 | 26 | def get_params_count(model, max_name_len: int = 60): 27 | params = [(name[:max_name_len], p.numel(), str(tuple(p.shape)), p.requires_grad) for name, p in model.named_parameters()] 28 | total_trainable_params = sum([x[1] for x in params if x[-1]]) 29 | total_nontrainable_params = sum([x[1] for x in params if not x[-1]]) 30 | return params, total_trainable_params, total_nontrainable_params 31 | 32 | 33 | def get_params_count_str(model, max_name_len: int = 60): 34 | padding = 70 # Hardcoded depending on desired amount of padding and separators. 
35 | params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len) 36 | param_counts_text = '' 37 | param_counts_text += '=' * (max_name_len + padding) + '\n' 38 | param_counts_text += f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n' 39 | param_counts_text += '-' * (max_name_len + padding) + '\n' 40 | for name, param_count, shape, trainable in params: 41 | param_counts_text += f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n' 42 | param_counts_text += '-' * (max_name_len + padding) + '\n' 43 | param_counts_text += f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n' 44 | param_counts_text += f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n' 45 | param_counts_text += '=' * (max_name_len + padding) + '\n' 46 | return param_counts_text 47 | 48 | 49 | def save_checkpoint(state, is_best, filename='checkpoint'): 50 | torch.save(state, filename + '.pth.tar') 51 | if is_best: 52 | shutil.copyfile(filename + '.pth.tar', filename + '_best.pth.tar') 53 | 54 | 55 | def postprocess_text(preds, labels): 56 | preds = [pred.strip() for pred in preds] 57 | labels = [label.strip() for label in labels] 58 | 59 | # rougeLSum expects newline after each sentence 60 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 61 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 62 | 63 | return preds, labels 64 | 65 | 66 | class Summary(Enum): 67 | NONE = 0 68 | AVERAGE = 1 69 | SUM = 2 70 | COUNT = 3 71 | 72 | class ProgressMeter(object): 73 | def __init__(self, num_batches, meters, prefix=""): 74 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 75 | self.meters = meters 76 | self.prefix = prefix 77 | 78 | def display(self, batch): 79 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 80 | entries += [str(meter) for meter in self.meters] 81 | print('\t'.join(entries)) 82 | 83 | def display_summary(self): 84 | entries = [" *"] 85 | entries += [meter.summary() for meter in self.meters] 86 | print(' '.join(entries)) 87 | 88 | def _get_batch_fmtstr(self, num_batches): 89 | num_digits = len(str(num_batches // 1)) 90 | fmt = '{:' + str(num_digits) + 'd}' 91 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 92 | 93 | class AverageMeter(object): 94 | """Computes and stores the average and current value""" 95 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE): 96 | self.name = name 97 | self.fmt = fmt 98 | self.summary_type = summary_type 99 | self.reset() 100 | 101 | def reset(self): 102 | self.val = 0 103 | self.avg = 0 104 | self.sum = 0 105 | self.count = 0 106 | 107 | def update(self, val, n=1): 108 | self.val = val 109 | self.sum += val * n 110 | self.count += n 111 | self.avg = self.sum / self.count 112 | 113 | def all_reduce(self): 114 | device = "cuda" if torch.cuda.is_available() else "cpu" 115 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device) 116 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False) 117 | self.sum, self.count = total.tolist() 118 | self.avg = self.sum / self.count 119 | 120 | def __str__(self): 121 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 122 | return fmtstr.format(**self.__dict__) 123 | 124 | def summary(self): 125 | fmtstr = '' 126 | if self.summary_type is Summary.NONE: 127 | fmtstr = '' 128 | elif 
self.summary_type is Summary.AVERAGE: 129 | fmtstr = '{name} {avg:.3f}' 130 | elif self.summary_type is Summary.SUM: 131 | fmtstr = '{name} {sum:.3f}' 132 | elif self.summary_type is Summary.COUNT: 133 | fmtstr = '{name} {count:.3f}' 134 | else: 135 | raise ValueError('invalid summary type %r' % self.summary_type) 136 | 137 | return fmtstr.format(**self.__dict__) 138 | 139 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from .modelling_self_attention import SelfAttentionModel 2 | from .modelling_cross_attention import CrossAttentionModel 3 | -------------------------------------------------------------------------------- /model/graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class GCN(nn.Module): 7 | 8 | def __init__(self, input_dim, output_dim, hidden_dim): 9 | super(GCN, self).__init__() 10 | self.input_dim = input_dim 11 | self.output_dim = output_dim 12 | self.hidden_dim = hidden_dim 13 | 14 | self.w1 = nn.Linear(2 * input_dim, hidden_dim, bias=False) 15 | self.w2 = nn.Linear(2 * hidden_dim, output_dim, bias=False) 16 | 17 | def forward(self, X, adj): 18 | null_root = torch.zeros((X.shape[0], 1, X.shape[2])).to(X.device) 19 | X = torch.cat((null_root, X), dim=1) 20 | batch_size, node_num, _ = X.shape 21 | 22 | agg = torch.bmm(adj, X) 23 | X = torch.cat((X, agg), dim=-1) 24 | X = self.w1(X.view(-1, 2 * self.input_dim)).view(batch_size, node_num, self.hidden_dim) 25 | X = F.relu(X) 26 | 27 | agg = torch.bmm(adj, X) 28 | X = torch.cat((X, agg), dim=-1) 29 | X = self.w2(X.view(-1, 2 * self.hidden_dim)).view(batch_size, node_num, self.output_dim) 30 | 31 | return X[:, 1:, :] 32 | -------------------------------------------------------------------------------- /model/modelling_cross_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code block is adapted from: 3 | - Repository: OPT model in Huggingface Transformers 4 | - Author: Huggingface 5 | - Link: https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py 6 | """ 7 | # coding=utf-8 8 | # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved. 9 | # 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at 13 | # 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | # 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # limitations under the License. 
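# Summary of this file (based on the classes defined below): MPTConfig copies the
# OPT configuration and adds the MMGL-specific fields (neighbor_mode,
# neighbor_layer_wise, peft_type, and the LoRA hyperparameters). MPTDecoder stacks
# the standard OPT self-attention decoder layers and, when the config's
# neighbor_mode is "cross_attention", inserts an extra cross-attention
# MPTDecoderLayer after every `neighbor_layer_wise` self-attention layers so that
# text tokens can attend to the neighbor embeddings. With peft_type == "flamingo",
# the output of each cross-attention layer is scaled by tanh(gate) with the gate
# initialized to 0, so training starts from the frozen OPT behavior.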
21 | """ PyTorch MPT model.""" 22 | import math 23 | from typing import List, Optional, Tuple, Union 24 | import torch 25 | import torch.utils.checkpoint 26 | from torch import nn 27 | from torch.nn import CrossEntropyLoss 28 | 29 | from transformers.activations import ACT2FN 30 | from transformers.modeling_outputs import ( 31 | BaseModelOutputWithPast, 32 | CausalLMOutputWithPast, 33 | ) 34 | from transformers.modeling_utils import PreTrainedModel 35 | from transformers import ( 36 | AutoConfig, 37 | PretrainedConfig, 38 | AutoTokenizer, 39 | AutoModelForSeq2SeqLM, 40 | AutoModelForCausalLM, 41 | CLIPVisionModel, 42 | CLIPTextModel, 43 | RobertaModel, 44 | ) 45 | 46 | from transformers.utils import logging 47 | logger = logging.get_logger(__name__) 48 | 49 | 50 | # Copied from transformers.models.bart.modeling_bart._make_causal_mask 51 | def _make_causal_mask( 52 | input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 53 | ): 54 | """ 55 | Make causal mask used for bi-directional self-attention. 56 | """ 57 | bsz, tgt_len = input_ids_shape 58 | mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) 59 | mask_cond = torch.arange(mask.size(-1), device=device) 60 | mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) 61 | mask = mask.to(dtype) 62 | 63 | if past_key_values_length > 0: 64 | mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) 65 | return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) 66 | 67 | 68 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 69 | """ 70 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 71 | """ 72 | bsz, src_len = mask.size() 73 | tgt_len = tgt_len if tgt_len is not None else src_len 74 | 75 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 76 | 77 | inverted_mask = 1.0 - expanded_mask 78 | 79 | return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) 80 | 81 | 82 | class MPTConfig(PretrainedConfig): 83 | 84 | def __init__(self, args, opt_config, **kwargs): 85 | super().__init__( 86 | pad_token_id=opt_config.pad_token_id, 87 | bos_token_id=opt_config.bos_token_id, 88 | eos_token_id=opt_config.eos_token_id, 89 | **kwargs, 90 | ) 91 | # MPT configuration 92 | self.neighbor_layer_wise = args.neighbor_layer_wise 93 | self.neighbor_mode = args.neighbor_mode 94 | self.peft_type = args.peft_type 95 | self.lora_r = args.lora_r 96 | self.lora_alpha = args.lora_alpha 97 | self.lora_dropout = args.lora_dropout 98 | 99 | # OPT configuration 100 | self.vocab_size = opt_config.vocab_size 101 | self.max_position_embeddings = opt_config.max_position_embeddings 102 | self.num_attention_heads = opt_config.num_attention_heads 103 | self.word_embed_proj_dim = opt_config.word_embed_proj_dim 104 | self.ffn_dim = opt_config.ffn_dim 105 | self.hidden_size = opt_config.hidden_size 106 | self.num_hidden_layers = opt_config.num_hidden_layers 107 | self.dropout = opt_config.dropout 108 | self.attention_dropout = opt_config.attention_dropout 109 | self.activation_function = opt_config.activation_function 110 | self.init_std = opt_config.init_std 111 | self.layerdrop = opt_config.layerdrop 112 | self.use_cache = opt_config.use_cache 113 | self.do_layer_norm_before = opt_config.do_layer_norm_before 114 | # We keep these variables at `True` for backward compatibility. 
115 | self.enable_bias = opt_config.enable_bias 116 | self.layer_norm_elementwise_affine = opt_config.layer_norm_elementwise_affine 117 | 118 | # Note that the only purpose of `_remove_final_layer_norm` is to keep backward compatibility 119 | # with checkpoints that have been fine-tuned before transformers v4.20.1 120 | # see https://github.com/facebookresearch/metaseq/pull/164 121 | self._remove_final_layer_norm = opt_config._remove_final_layer_norm 122 | 123 | 124 | class MPTLearnedPositionalEmbedding(nn.Embedding): 125 | """ 126 | This module learns positional embeddings up to a fixed maximum size. 127 | """ 128 | 129 | def __init__(self, num_embeddings: int, embedding_dim: int): 130 | # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2 131 | # and adjust num_embeddings appropriately. Other models don't have this hack 132 | self.offset = 2 133 | super().__init__(num_embeddings + self.offset, embedding_dim) 134 | 135 | def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): 136 | """`input_ids_shape` is expected to be [bsz x seqlen].""" 137 | attention_mask = attention_mask.long() 138 | 139 | # create positions depending on attention_mask 140 | positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 141 | 142 | # cut positions if `past_key_values_length` is > 0 143 | positions = positions[:, past_key_values_length:] 144 | 145 | return super().forward(positions + self.offset) 146 | 147 | 148 | class MPTAttention(nn.Module): 149 | """Multi-headed attention from 'Attention Is All You Need' paper""" 150 | 151 | def __init__(self, config, cross_attention): 152 | super().__init__() 153 | 154 | self.embed_dim = config.hidden_size 155 | self.num_heads = config.num_attention_heads 156 | self.dropout = config.attention_dropout 157 | self.head_dim = self.embed_dim // self.num_heads 158 | bias = config.enable_bias 159 | 160 | if (self.head_dim * self.num_heads) != self.embed_dim: 161 | raise ValueError( 162 | f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" 163 | f" and `num_heads`: {self.num_heads})." 
164 | ) 165 | self.scaling = self.head_dim**-0.5 166 | self.is_decoder = False 167 | 168 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) 169 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) 170 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) 171 | 172 | self.cross_attention = cross_attention 173 | self.peft_type = config.peft_type 174 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias) 175 | 176 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 177 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 178 | 179 | def forward( 180 | self, 181 | hidden_states: torch.Tensor, 182 | attention_mask: Optional[torch.Tensor] = None, 183 | neighbor_embeds: Optional[torch.Tensor] = None, 184 | neighbor_attention_mask: Optional[torch.Tensor] = None, 185 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 186 | layer_head_mask: Optional[torch.Tensor] = None, 187 | output_attentions: bool = False, 188 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 189 | """Input shape: Batch x Time x Channel""" 190 | 191 | bsz, tgt_len, _ = hidden_states.size() 192 | 193 | # get query proj 194 | query_states = self.q_proj(hidden_states) * self.scaling 195 | # get key, value proj 196 | if self.cross_attention: 197 | # cross_attentions 198 | key_states = self._shape(self.k_proj(neighbor_embeds), -1, bsz) 199 | value_states = self._shape(self.v_proj(neighbor_embeds), -1, bsz) 200 | attention_mask = neighbor_attention_mask 201 | else: 202 | # self_attention 203 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 204 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 205 | 206 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 207 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 208 | key_states = key_states.view(*proj_shape) 209 | value_states = value_states.view(*proj_shape) 210 | 211 | src_len = key_states.size(1) 212 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 213 | 214 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 215 | raise ValueError( 216 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" 217 | f" {attn_weights.size()}" 218 | ) 219 | 220 | if attention_mask is not None: 221 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 222 | raise ValueError( 223 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 224 | ) 225 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 226 | attn_weights = torch.max( 227 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) 228 | ) 229 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 230 | 231 | # upcast to fp32 if the weights are in fp16. 
Please see https://github.com/huggingface/transformers/pull/17437 232 | if attn_weights.dtype == torch.float16: 233 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(torch.float16) 234 | else: 235 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 236 | 237 | if layer_head_mask is not None: 238 | if layer_head_mask.size() != (self.num_heads,): 239 | raise ValueError( 240 | f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" 241 | f" {layer_head_mask.size()}" 242 | ) 243 | attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 244 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 245 | 246 | if output_attentions: 247 | # this operation is a bit awkward, but it's required to 248 | # make sure that attn_weights keeps its gradient. 249 | # In order to do so, attn_weights have to be reshaped 250 | # twice and have to be reused in the following 251 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 252 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 253 | else: 254 | attn_weights_reshaped = None 255 | 256 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 257 | 258 | attn_output = torch.bmm(attn_probs, value_states) 259 | 260 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 261 | raise ValueError( 262 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" 263 | f" {attn_output.size()}" 264 | ) 265 | 266 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 267 | attn_output = attn_output.transpose(1, 2) 268 | 269 | # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be 270 | # partitioned aross GPUs when using tensor-parallelism. 
271 | attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) 272 | 273 | attn_output = self.out_proj(attn_output) 274 | 275 | return attn_output, attn_weights_reshaped, None 276 | 277 | 278 | class MPTDecoderLayer(nn.Module): 279 | def __init__(self, config, cross_attention=False): 280 | super().__init__() 281 | self.embed_dim = config.hidden_size 282 | self.self_attn = MPTAttention(config, cross_attention) 283 | self.do_layer_norm_before = config.do_layer_norm_before 284 | self.dropout = config.dropout 285 | self.activation_fn = ACT2FN[config.activation_function] 286 | 287 | self.self_attn_layer_norm = nn.LayerNorm( 288 | self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine 289 | ) 290 | 291 | self.fc1 = nn.Linear(self.embed_dim, config.ffn_dim, bias=config.enable_bias) 292 | self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) 293 | 294 | self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) 295 | 296 | self.cross_attention = cross_attention 297 | self.peft_type = config.peft_type 298 | if self.cross_attention and self.peft_type == "flamingo": 299 | self.tanh_layer1 = nn.Tanh() 300 | self.tanh_layer2 = nn.Tanh() 301 | self.gating1 = nn.Parameter(torch.tensor(0.0)) 302 | self.gating2 = nn.Parameter(torch.tensor(0.0)) 303 | 304 | def forward( 305 | self, 306 | hidden_states: torch.Tensor, 307 | attention_mask: Optional[torch.Tensor] = None, 308 | neighbor_embeds: Optional[torch.FloatTensor] = None, 309 | neighbor_attention_mask: Optional[torch.Tensor] = None, 310 | layer_head_mask: Optional[torch.Tensor] = None, 311 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 312 | output_attentions: Optional[bool] = False, 313 | use_cache: Optional[bool] = False, 314 | ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: 315 | 316 | residual = hidden_states 317 | 318 | # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention 319 | if self.do_layer_norm_before: 320 | hidden_states = self.self_attn_layer_norm(hidden_states) 321 | 322 | # Self Attention 323 | hidden_states, self_attn_weights, present_key_value = self.self_attn( 324 | hidden_states=hidden_states, 325 | attention_mask=attention_mask, 326 | neighbor_embeds=neighbor_embeds, 327 | neighbor_attention_mask=neighbor_attention_mask, 328 | layer_head_mask=layer_head_mask, 329 | past_key_value=past_key_value, 330 | output_attentions=output_attentions, 331 | ) 332 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) 333 | 334 | if self.cross_attention and self.peft_type == "flamingo": 335 | hidden_states = residual + self.tanh_layer1(self.gating1) * hidden_states 336 | else: 337 | hidden_states = residual + hidden_states 338 | 339 | # 350m applies layer norm AFTER attention 340 | if not self.do_layer_norm_before: 341 | hidden_states = self.self_attn_layer_norm(hidden_states) 342 | 343 | # Fully Connected 344 | hidden_states_shape = hidden_states.shape 345 | hidden_states = hidden_states.reshape(-1, hidden_states.size(-1)) 346 | residual = hidden_states 347 | 348 | # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention 349 | if self.do_layer_norm_before: 350 | hidden_states = self.final_layer_norm(hidden_states) 351 | 352 | hidden_states = self.fc1(hidden_states) 353 | hidden_states = self.activation_fn(hidden_states) 354 | 355 | hidden_states = self.fc2(hidden_states) 356 | hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, 
training=self.training) 357 | 358 | if self.cross_attention and self.peft_type == "flamingo": 359 | hidden_states = (residual + self.tanh_layer2(self.gating2) * hidden_states).view(hidden_states_shape) 360 | else: 361 | hidden_states = (residual + hidden_states).view(hidden_states_shape) 362 | 363 | # 350m applies layer norm AFTER attention 364 | if not self.do_layer_norm_before: 365 | hidden_states = self.final_layer_norm(hidden_states) 366 | 367 | outputs = (hidden_states,) 368 | 369 | if output_attentions: 370 | outputs += (self_attn_weights,) 371 | 372 | if use_cache: 373 | outputs += (present_key_value,) 374 | 375 | return outputs 376 | 377 | 378 | class MPTPreTrainedModel(PreTrainedModel): 379 | config_class = MPTConfig 380 | base_model_prefix = "model" 381 | supports_gradient_checkpointing = True 382 | _no_split_modules = ["MPTDecoderLayer"] 383 | 384 | def _init_weights(self, module): 385 | std = self.config.init_std 386 | if isinstance(module, nn.Linear): 387 | module.weight.data.normal_(mean=0.0, std=std) 388 | if module.bias is not None: 389 | module.bias.data.zero_() 390 | elif isinstance(module, nn.Embedding): 391 | module.weight.data.normal_(mean=0.0, std=std) 392 | if module.padding_idx is not None: 393 | module.weight.data[module.padding_idx].zero_() 394 | 395 | def _set_gradient_checkpointing(self, module, value=False): 396 | if isinstance(module, (MPTDecoder)): 397 | module.gradient_checkpointing = value 398 | 399 | 400 | class MPTDecoder(MPTPreTrainedModel): 401 | 402 | def __init__(self, config: MPTConfig): 403 | super().__init__(config) 404 | self.dropout = config.dropout 405 | self.layerdrop = config.layerdrop 406 | self.padding_idx = config.pad_token_id 407 | self.max_target_positions = config.max_position_embeddings 408 | self.vocab_size = config.vocab_size 409 | 410 | self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx) 411 | self.embed_positions = MPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size) 412 | 413 | if config.word_embed_proj_dim != config.hidden_size: 414 | self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False) 415 | else: 416 | self.project_out = None 417 | 418 | if config.word_embed_proj_dim != config.hidden_size: 419 | self.project_in = nn.Linear(config.word_embed_proj_dim, config.hidden_size, bias=False) 420 | else: 421 | self.project_in = None 422 | 423 | # Note that the only purpose of `config._remove_final_layer_norm` is to keep backward compatibility 424 | # with checkpoints that have been fine-tuned before transformers v4.20.1 425 | # see https://github.com/facebookresearch/metaseq/pull/164 426 | if config.do_layer_norm_before and not config._remove_final_layer_norm: 427 | self.final_layer_norm = nn.LayerNorm( 428 | config.hidden_size, elementwise_affine=config.layer_norm_elementwise_affine 429 | ) 430 | else: 431 | self.final_layer_norm = None 432 | 433 | self.cross_attention = (config.neighbor_mode == "cross_attention") 434 | self.neighbor_layer_wise = config.neighbor_layer_wise 435 | self.peft_type = config.peft_type 436 | 437 | self.layers = nn.ModuleList() 438 | self.neighbor_layers = nn.ModuleList() 439 | for l in range(config.num_hidden_layers): 440 | self.layers.append(MPTDecoderLayer(config)) 441 | if self.cross_attention and (l + 1) % self.neighbor_layer_wise == 0: 442 | self.neighbor_layers.append(MPTDecoderLayer(config, cross_attention=True)) 443 | 444 | self.gradient_checkpointing = False 445 | # Initialize weights and apply 
final processing 446 | self.post_init() 447 | 448 | def get_input_embeddings(self): 449 | return self.embed_tokens 450 | 451 | def set_input_embeddings(self, value): 452 | self.embed_tokens = value 453 | 454 | # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask 455 | def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): 456 | # create causal mask 457 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 458 | combined_attention_mask = None 459 | if input_shape[-1] > 1: 460 | combined_attention_mask = _make_causal_mask( 461 | input_shape, 462 | inputs_embeds.dtype, 463 | device=inputs_embeds.device, 464 | past_key_values_length=past_key_values_length, 465 | ) 466 | 467 | if attention_mask is not None: 468 | # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] 469 | expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( 470 | inputs_embeds.device 471 | ) 472 | combined_attention_mask = ( 473 | expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask 474 | ) 475 | 476 | return combined_attention_mask 477 | 478 | def forward( 479 | self, 480 | input_ids: torch.LongTensor = None, 481 | attention_mask: Optional[torch.Tensor] = None, 482 | head_mask: Optional[torch.Tensor] = None, 483 | past_key_values: Optional[List[torch.FloatTensor]] = None, 484 | inputs_embeds: Optional[torch.FloatTensor] = None, 485 | neighbor_embeds: Optional[torch.FloatTensor] = None, 486 | neighbor_attention_mask: Optional[torch.Tensor] = None, 487 | use_cache: Optional[bool] = None, 488 | output_attentions: Optional[bool] = None, 489 | output_hidden_states: Optional[bool] = None, 490 | return_dict: Optional[bool] = None, 491 | ) -> Union[Tuple, BaseModelOutputWithPast]: 492 | """ 493 | Args: 494 | input_ids : token ids of input text 495 | attention_mask : attention_mask of input text 496 | head_mask : mask selected heads of the attention modules 497 | past_key_values : previous key values of the decoder 498 | inputs_embeds : embeddings of input text 499 | neighbor_embeds : embeddings of neighbor text/images 500 | neighbor_attention_mask : attention mask of neighbor text/images 501 | use_cache : whether to use cache 502 | output_attentions : whether to output attentions 503 | output_hidden_states : whether to output hidden states 504 | return_dict : whether to return a dict 505 | """ 506 | 507 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 508 | output_hidden_states = ( 509 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 510 | ) 511 | use_cache = use_cache if use_cache is not None else self.config.use_cache 512 | 513 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 514 | 515 | # retrieve input_ids and inputs_embeds 516 | if input_ids is not None and inputs_embeds is not None: 517 | raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") 518 | elif input_ids is not None: 519 | input_shape = input_ids.size() 520 | input_ids = input_ids.view(-1, input_shape[-1]) 521 | elif inputs_embeds is not None: 522 | input_shape = inputs_embeds.size()[:-1] 523 | else: 524 | raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") 525 | 526 | if inputs_embeds is None: 527 | inputs_embeds = 
self.embed_tokens(input_ids) 528 | 529 | batch_size, seq_length = input_shape 530 | past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 531 | # required mask seq length can be calculated via length of past 532 | mask_seq_length = past_key_values_length + seq_length 533 | 534 | # embed positions 535 | if attention_mask is None: 536 | attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) 537 | elif attention_mask.shape[1] != mask_seq_length: 538 | raise ValueError( 539 | f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " 540 | f"{mask_seq_length} (sum of the lengths of current and past inputs)" 541 | ) 542 | causal_attention_mask = self._prepare_decoder_attention_mask( 543 | attention_mask, input_shape, inputs_embeds, past_key_values_length 544 | ) 545 | if neighbor_attention_mask is not None: 546 | neighbor_attention_mask = _expand_mask(neighbor_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(inputs_embeds.device) 547 | 548 | pos_embeds = self.embed_positions(attention_mask, past_key_values_length) 549 | 550 | if self.project_in is not None: 551 | inputs_embeds = self.project_in(inputs_embeds) 552 | 553 | hidden_states = inputs_embeds + pos_embeds 554 | 555 | if self.gradient_checkpointing and self.training: 556 | if use_cache: 557 | logger.warning_once( 558 | "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 559 | ) 560 | use_cache = False 561 | 562 | # decoder layers 563 | all_hidden_states = () if output_hidden_states else None 564 | all_self_attns = () if output_attentions else None 565 | next_decoder_cache = () if use_cache else None 566 | 567 | # check if head_mask has a correct number of layers specified if desired 568 | for attn_mask, mask_name in zip([head_mask], ["head_mask"]): 569 | if attn_mask is not None: 570 | if attn_mask.size()[0] != (len(self.layers)): 571 | raise ValueError( 572 | f"The `{mask_name}` should be specified for {len(self.layers)} layers, but it is for" 573 | f" {head_mask.size()[0]}." 
574 | ) 575 | 576 | for idx, decoder_layer in enumerate(self.layers): 577 | # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) 578 | if output_hidden_states: 579 | all_hidden_states += (hidden_states,) 580 | 581 | if self.training: 582 | dropout_probability = torch.rand([]) 583 | if dropout_probability < self.layerdrop: 584 | continue 585 | 586 | past_key_value = past_key_values[idx] if past_key_values is not None else None 587 | 588 | if self.gradient_checkpointing and self.training: 589 | 590 | def create_custom_forward(module): 591 | def custom_forward(*inputs): 592 | # None for past_key_value 593 | return module(*inputs, output_attentions, None) 594 | 595 | return custom_forward 596 | 597 | layer_outputs = torch.utils.checkpoint.checkpoint( 598 | create_custom_forward(decoder_layer), 599 | hidden_states, 600 | causal_attention_mask, 601 | head_mask[idx] if head_mask is not None else None, 602 | None, 603 | ) 604 | else: 605 | layer_outputs = decoder_layer( 606 | hidden_states, 607 | attention_mask=causal_attention_mask, 608 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 609 | past_key_value=past_key_value, 610 | output_attentions=output_attentions, 611 | use_cache=use_cache, 612 | ) 613 | if self.cross_attention and (idx + 1) % self.neighbor_layer_wise == 0: 614 | hidden_states = layer_outputs[0] 615 | neighbor_idx = (idx + 1) // self.neighbor_layer_wise - 1 616 | layer_outputs = self.neighbor_layers[neighbor_idx]( 617 | hidden_states, 618 | attention_mask=causal_attention_mask, 619 | neighbor_embeds=neighbor_embeds, 620 | neighbor_attention_mask=neighbor_attention_mask, 621 | layer_head_mask=(head_mask[idx] if head_mask is not None else None), 622 | past_key_value=past_key_value, 623 | output_attentions=output_attentions, 624 | use_cache=use_cache, 625 | ) 626 | 627 | hidden_states = layer_outputs[0] 628 | 629 | if use_cache: 630 | next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) 631 | 632 | if output_attentions: 633 | all_self_attns += (layer_outputs[1],) 634 | 635 | if self.final_layer_norm is not None: 636 | hidden_states = self.final_layer_norm(hidden_states) 637 | 638 | if self.project_out is not None: 639 | hidden_states = self.project_out(hidden_states) 640 | 641 | # add hidden states from the last decoder layer 642 | if output_hidden_states: 643 | all_hidden_states += (hidden_states,) 644 | 645 | next_cache = next_decoder_cache if use_cache else None 646 | if not return_dict: 647 | return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) 648 | return BaseModelOutputWithPast( 649 | last_hidden_state=hidden_states, 650 | past_key_values=next_cache, 651 | hidden_states=all_hidden_states, 652 | attentions=all_self_attns, 653 | ) 654 | 655 | 656 | class MPTModel(MPTPreTrainedModel): 657 | def __init__(self, config: MPTConfig): 658 | super().__init__(config) 659 | self.decoder = MPTDecoder(config) 660 | # Initialize weights and apply final processing 661 | self.post_init() 662 | 663 | def get_input_embeddings(self): 664 | return self.decoder.embed_tokens 665 | 666 | def set_input_embeddings(self, value): 667 | self.decoder.embed_tokens = value 668 | 669 | def get_decoder(self): 670 | return self.decoder 671 | 672 | def forward( 673 | self, 674 | input_ids: torch.LongTensor = None, 675 | attention_mask: Optional[torch.Tensor] = None, 676 | head_mask: Optional[torch.Tensor] = None, 677 | past_key_values: Optional[List[torch.FloatTensor]] = None, 678 | inputs_embeds: 
Optional[torch.FloatTensor] = None, 679 | neighbor_embeds: Optional[torch.FloatTensor] = None, 680 | neighbor_attention_mask: Optional[torch.Tensor] = None, 681 | use_cache: Optional[bool] = None, 682 | output_attentions: Optional[bool] = None, 683 | output_hidden_states: Optional[bool] = None, 684 | return_dict: Optional[bool] = None, 685 | ) -> Union[Tuple, BaseModelOutputWithPast]: 686 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 687 | output_hidden_states = ( 688 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 689 | ) 690 | use_cache = use_cache if use_cache is not None else self.config.use_cache 691 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 692 | 693 | # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) 694 | decoder_outputs = self.decoder( 695 | input_ids=input_ids, 696 | attention_mask=attention_mask, 697 | head_mask=head_mask, 698 | past_key_values=past_key_values, 699 | inputs_embeds=inputs_embeds, 700 | neighbor_embeds=neighbor_embeds, 701 | neighbor_attention_mask=neighbor_attention_mask, 702 | use_cache=use_cache, 703 | output_attentions=output_attentions, 704 | output_hidden_states=output_hidden_states, 705 | return_dict=return_dict, 706 | ) 707 | 708 | if not return_dict: 709 | return decoder_outputs 710 | 711 | return BaseModelOutputWithPast( 712 | last_hidden_state=decoder_outputs.last_hidden_state, 713 | past_key_values=decoder_outputs.past_key_values, 714 | hidden_states=decoder_outputs.hidden_states, 715 | attentions=decoder_outputs.attentions, 716 | ) 717 | 718 | 719 | def reset_peft_parameters(model): 720 | for n, p in model.named_parameters(): 721 | if "lora_A" in n: 722 | nn.init.kaiming_uniform_(p, a=math.sqrt(5)) 723 | if "lora_B" in n: 724 | nn.init.zeros_(p) 725 | if "adapter" in n: 726 | identity = torch.eye(p.size(0), p.size(1)) 727 | # Add small random noise 728 | noise = torch.randn(p.size(0), p.size(1)) * 0.01 729 | p = identity + noise 730 | 731 | def mark_only_peft_as_trainable(model): 732 | for n, p in model.named_parameters(): 733 | p.requires_grad = False 734 | for m in model.modules(): 735 | if isinstance(m, MPTDecoderLayer) and m.cross_attention == True: 736 | for n, p in m.named_parameters(): 737 | p.requires_grad = True 738 | 739 | class MPTForCausalLM(MPTPreTrainedModel): 740 | _tied_weights_keys = ["lm_head.weight"] 741 | 742 | def __init__(self, config): 743 | super().__init__(config) 744 | self.model = MPTModel(config) 745 | 746 | # the lm_head weight is automatically tied to the embed tokens weight 747 | self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False) 748 | 749 | # Initialize weights and apply final processing 750 | self.post_init() 751 | 752 | if config.peft_type != 'none': 753 | reset_peft_parameters(self.model) 754 | mark_only_peft_as_trainable(self.model) 755 | 756 | def get_input_embeddings(self): 757 | return self.model.decoder.embed_tokens 758 | 759 | def set_input_embeddings(self, value): 760 | self.model.decoder.embed_tokens = value 761 | 762 | def get_output_embeddings(self): 763 | return self.lm_head 764 | 765 | def set_output_embeddings(self, new_embeddings): 766 | self.lm_head = new_embeddings 767 | 768 | def set_decoder(self, decoder): 769 | self.model.decoder = decoder 770 | 771 | def get_decoder(self): 772 | return self.model.decoder 773 | 774 | def forward( 775 | self, 776 | input_ids: torch.LongTensor = 
None, 777 | attention_mask: Optional[torch.Tensor] = None, 778 | head_mask: Optional[torch.Tensor] = None, 779 | past_key_values: Optional[List[torch.FloatTensor]] = None, 780 | inputs_embeds: Optional[torch.FloatTensor] = None, 781 | labels: Optional[torch.LongTensor] = None, 782 | neighbor_embeds: Optional[torch.FloatTensor] = None, 783 | neighbor_attention_mask: Optional[torch.Tensor] = None, 784 | use_cache: Optional[bool] = None, 785 | output_attentions: Optional[bool] = None, 786 | output_hidden_states: Optional[bool] = None, 787 | return_dict: Optional[bool] = None, 788 | ) -> Union[Tuple, CausalLMOutputWithPast]: 789 | """ 790 | Args: 791 | input_ids : token ids of input text 792 | attention_mask : attention_mask of input text 793 | head_mask : mask selected heads of the attention modules 794 | past_key_values : previous key values of the decoder 795 | inputs_embeds : embeddings of input text 796 | labels : token ids of output text 797 | neighbor_embeds : embeddings of neighbor text/images 798 | neighbor_attention_mask : attention mask of neighbor text/images 799 | use_cache : whether to use cache 800 | output_attentions : whether to output attentions 801 | output_hidden_states : whether to output hidden states 802 | return_dict : whether to return a dict 803 | """ 804 | 805 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 806 | output_hidden_states = ( 807 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 808 | ) 809 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 810 | 811 | # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) 812 | outputs = self.model.decoder( 813 | input_ids=input_ids, 814 | attention_mask=attention_mask, 815 | head_mask=head_mask, 816 | past_key_values=past_key_values, 817 | inputs_embeds=inputs_embeds, 818 | neighbor_embeds=neighbor_embeds, 819 | neighbor_attention_mask=neighbor_attention_mask, 820 | use_cache=use_cache, 821 | output_attentions=output_attentions, 822 | output_hidden_states=output_hidden_states, 823 | return_dict=return_dict, 824 | ) 825 | 826 | logits = self.lm_head(outputs[0]).contiguous() 827 | loss = None 828 | if labels is not None: 829 | # move labels to correct device to enable model parallelism 830 | labels = labels.to(logits.device) 831 | # Shift so that tokens < n predict n 832 | shift_logits = logits[..., :-1, :].contiguous() 833 | shift_labels = labels[..., 1:].contiguous() 834 | # Flatten the tokens 835 | loss_fct = CrossEntropyLoss() 836 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 837 | 838 | if not return_dict: 839 | output = (logits,) + outputs[1:] 840 | return (loss,) + output if loss is not None else output 841 | 842 | return CausalLMOutputWithPast( 843 | loss=loss, 844 | logits=logits, 845 | past_key_values=outputs.past_key_values, 846 | hidden_states=outputs.hidden_states, 847 | attentions=outputs.attentions, 848 | ) 849 | 850 | def prepare_inputs_for_generation( 851 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs 852 | ): 853 | if past_key_values: 854 | input_ids = input_ids[:, -1:] 855 | 856 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 857 | if inputs_embeds is not None and past_key_values is None: 858 | model_inputs = {"inputs_embeds": inputs_embeds} 859 | else: 860 | model_inputs = {"input_ids": input_ids} 861 | 862 | 
model_inputs.update( 863 | { 864 | "past_key_values": past_key_values, 865 | "use_cache": kwargs.get("use_cache"), 866 | "attention_mask": attention_mask, 867 | } 868 | ) 869 | return model_inputs 870 | 871 | @staticmethod 872 | def _reorder_cache(past_key_values, beam_idx): 873 | reordered_past = () 874 | for layer_past in past_key_values: 875 | reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) 876 | return reordered_past 877 | 878 | 879 | class TextPooler(nn.Module): 880 | """ 881 | Pool the hidden state corresponding to the first token. 882 | """ 883 | def __init__(self, config): 884 | super().__init__() 885 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 886 | self.activation = nn.Tanh() 887 | 888 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 889 | # We "pool" the model by simply taking the hidden state corresponding to the first token. 890 | first_token_tensor = hidden_states[:, 0] 891 | pooled_output = self.dense(first_token_tensor) 892 | pooled_output = self.activation(pooled_output) 893 | return pooled_output 894 | 895 | 896 | class CrossAttentionModel(nn.Module): 897 | """ 898 | CrossAttentionModel is a wrapper around a pretrained language model. 899 | It supports the decoder-only models (e.g., OPT).) 900 | """ 901 | def __init__(self, args, tokenizer): 902 | super().__init__() 903 | 904 | self.args = args 905 | self.context = args.context 906 | self.neighbor_mode = args.neighbor_mode 907 | self.n_text_tokens = args.n_text_tokens 908 | self.n_visual_tokens = args.n_visual_tokens 909 | self.tokenizer = tokenizer 910 | 911 | self.initialize_lm(args) 912 | self.input_embeddings = self.lm.get_input_embeddings() 913 | 914 | self.text_model = None 915 | if self.context != "section_only": 916 | # Text model to encode text neighbors 917 | embedding_dim = self.input_embeddings.embedding_dim * args.n_text_tokens 918 | if "clip" in args.text_model: 919 | self.text_model = CLIPTextModel.from_pretrained(args.text_model) 920 | else: 921 | self.text_model = RobertaModel.from_pretrained(args.text_model) 922 | self.text_pooler = TextPooler(self.text_model.config) 923 | self.text_embeddings = nn.Linear(self.text_model.config.hidden_size, embedding_dim) 924 | self.text_position_embeddings = nn.Embedding(args.max_output_length + 1, embedding_dim) # + 1 for padding neighbors 925 | # Freeze the text model 926 | self.text_model.eval() 927 | for name, param in self.text_model.named_parameters(): 928 | param.requires_grad = False 929 | 930 | self.visual_model = None 931 | if self.context in ("section_all", "all"): 932 | # Vision model to encode image neighbors 933 | embedding_dim = self.input_embeddings.embedding_dim * args.n_visual_tokens 934 | self.visual_model = CLIPVisionModel.from_pretrained(args.visual_model) 935 | self.visual_embeddings = nn.Linear(self.visual_model.config.hidden_size, embedding_dim) 936 | self.visual_position_embeddings = nn.Embedding(args.max_output_length + 1, embedding_dim) # + 1 for padding neighbors 937 | # Freeze the vision model 938 | self.visual_model.eval() 939 | for param in self.visual_model.parameters(): 940 | param.requires_grad = False 941 | 942 | # Freeze the base LM if needed 943 | if self.args.freeze_lm: 944 | print("Freezing the LM.") 945 | self.lm.eval() 946 | for param in self.lm.parameters(): 947 | param.requires_grad = False 948 | else: 949 | self.lm.train() 950 | 951 | def initialize_lm(self, args): 952 | # Initialize the LM using the pretrained model except cross-attention 
layers 953 | opt_config = AutoConfig.from_pretrained(args.model_name_or_path) 954 | opt_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, config=opt_config) 955 | 956 | mpt_config = MPTConfig(args, opt_config) 957 | mpt_model = MPTForCausalLM(mpt_config) 958 | 959 | # Copy embeddings 960 | mpt_model.model.decoder.embed_tokens.load_state_dict(opt_model.model.decoder.embed_tokens.state_dict()) 961 | mpt_model.model.decoder.embed_positions.load_state_dict(opt_model.model.decoder.embed_positions.state_dict()) 962 | if mpt_config.word_embed_proj_dim != mpt_config.hidden_size: 963 | mpt_model.model.decoder.project_out.load_state_dict(opt_model.model.decoder.project_out.state_dict()) 964 | mpt_model.model.decoder.project_in.load_state_dict(opt_model.model.decoder.project_in.state_dict()) 965 | if mpt_config.do_layer_norm_before and not mpt_config._remove_final_layer_norm: 966 | mpt_model.model.decoder.final_layer_norm.load_state_dict(opt_model.model.decoder.final_layer_norm.state_dict()) 967 | 968 | # Copy self-attention layers 969 | for idx in range(opt_config.num_hidden_layers): 970 | missing_keys, unexpected_keys = mpt_model.model.decoder.layers[idx].load_state_dict(opt_model.model.decoder.layers[idx].state_dict(), strict=False) 971 | print(f'{idx}th layer missing_keys: {missing_keys}, unexpected_keys: {unexpected_keys}') 972 | 973 | # Copy lm_head 974 | mpt_model.lm_head.load_state_dict(opt_model.lm_head.state_dict()) 975 | 976 | self.lm = mpt_model 977 | 978 | def get_text_embs(self, input_ids, attention_mask, pos_ids=None): 979 | """ 980 | Get the text embeddings from the text model. 981 | Args: 982 | input_ids: token ids of text neighbors (batch_size, neighbor_num, seq_len) 983 | attention_mask: attention mask of text neighbors (batch_size, neighbor_num, seq_len) 984 | pos_ids: position ids of text neighbors (batch_size, neighbor_num, seq_len) 985 | Returns: 986 | text_embs: text embeddings of text neighbors (batch_size, neighbor_num, n_text_tokens, hidden_dim) 987 | """ 988 | batch_size, neighbor_num, seq_len = input_ids.shape 989 | input_ids = input_ids.reshape(-1, seq_len) 990 | attention_mask = attention_mask.reshape(-1, seq_len) 991 | 992 | outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask) 993 | if "clip" in self.args.text_model: 994 | encoder_outputs = outputs.pooler_output 995 | else: 996 | encoder_outputs = self.text_pooler(outputs.last_hidden_state) 997 | text_embs = self.text_embeddings(encoder_outputs) 998 | 999 | if pos_ids is not None: 1000 | pos_ids = pos_ids.reshape(-1) 1001 | text_embs = text_embs + self.text_position_embeddings(pos_ids) 1002 | 1003 | text_embs = text_embs.reshape(text_embs.shape[0], self.n_text_tokens, -1) 1004 | return text_embs.reshape(batch_size, neighbor_num, self.n_text_tokens, -1) 1005 | 1006 | def get_visual_embs(self, pixel_values, pos_ids=None): 1007 | """ 1008 | Get the visual embeddings from the vision model. 
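        Each image is passed through the frozen CLIP vision model; its pooler output is projected by a
        linear layer into `n_visual_tokens` soft tokens in the LM embedding space, and a learned position
        embedding is added when `pos_ids` is provided.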
1009 | Args: 1010 | pixel_values: pixel values of image neighbors (batch_size, neighbor_num, pixel, width, height) 1011 | pos_ids: position ids of image neighbors (batch_size, neighbor_num) 1012 | Returns: 1013 | visual_embs: visual embeddings of image neighbors (batch_size, neighbor_num, n_visual_tokens, hidden_dim) 1014 | """ 1015 | batch_size, neighbor_num, pixel, width, height = pixel_values.shape 1016 | pixel_values = pixel_values.reshape(-1, pixel, width, height) 1017 | 1018 | outputs = self.visual_model(pixel_values) 1019 | encoder_outputs = outputs.pooler_output 1020 | visual_embs = self.visual_embeddings(encoder_outputs) 1021 | 1022 | if pos_ids is not None: 1023 | pos_ids = pos_ids.reshape(-1) 1024 | visual_embs = visual_embs + self.visual_position_embeddings(pos_ids) 1025 | 1026 | visual_embs = visual_embs.reshape(visual_embs.shape[0], self.n_visual_tokens, -1) 1027 | return visual_embs.reshape(batch_size, neighbor_num, self.n_visual_tokens, -1) 1028 | 1029 | def train(self, mode=True): 1030 | super(CrossAttentionModel, self).train(mode=mode) 1031 | if self.args.freeze_lm: 1032 | self.lm.eval() 1033 | if self.text_model is not None: 1034 | self.text_model.eval() 1035 | if self.visual_model is not None: 1036 | self.visual_model.eval() 1037 | 1038 | def forward( 1039 | self, 1040 | input_ids, 1041 | attention_mask, 1042 | labels, 1043 | images=None, 1044 | image_positions=None, 1045 | neighbor_input_ids=None, 1046 | neighbor_attention_mask=None, 1047 | neighbor_pos_ids=None, 1048 | text_locations=None, 1049 | neighbor_images=None, 1050 | neighbor_images_pos_ids=None, 1051 | image_locations=None, 1052 | ): 1053 | """ 1054 | Args: 1055 | input_ids: token ids of input text (batch_size, seq_len) 1056 | attention_mask: attention_mask of input text (batch_size, seq_len) 1057 | labels: token ids of labels (batch_size, seq_len) 1058 | images: neighbor image features (batch_size, image_num, pixel, width, height) 1059 | image_positions: neighbor image locations (batch_size, image_num) 1060 | neighbor_input_ids: token ids of neighbor text (batch_size, text_num, seq_len) 1061 | neighbor_attention_mask: attention mask of neighbor text (batch_size, text_num, seq_len) 1062 | neighbor_pos_ids: position ids of neighbor text (batch_size, text_num, seq_len) 1063 | text_locations: locations of text embeddings (batch_size, text_num) 1064 | neighbor_images: neighbor image features (batch_size, image_num, pixel, width, height) 1065 | neighbor_images_pos_ids: position ids of neighbor images (batch_size, image_num) 1066 | image_locations: locations of image embeddings (batch_size, image_num) 1067 | """ 1068 | if self.neighbor_mode == "raw" or self.context == "section_only": 1069 | # For sanity check: run the pure OPT model 1070 | neighbor_embeds = None 1071 | neighbor_attention_mask = None 1072 | elif self.neighbor_mode == "cross_attention" and self.context == "text_only": 1073 | # Text neighbors only; need to compute text embeddings 1074 | batch_size, neighbor_num, seq_len = neighbor_input_ids.shape 1075 | neighbor_embeds = self.get_text_embs(neighbor_input_ids, neighbor_attention_mask, neighbor_pos_ids) 1076 | neighbor_embeds = neighbor_embeds.reshape(batch_size, neighbor_num * self.n_text_tokens, -1) 1077 | neighbor_attention_mask = neighbor_pos_ids > 0 1078 | neighbor_attention_mask = torch.repeat_interleave(neighbor_attention_mask, repeats=self.n_text_tokens, dim=1) 1079 | 1080 | elif self.neighbor_mode == "cross_attention" and self.context in ("section_all", "all"): 1081 | # Text and image 
neighbors; need to compute text and image embeddings 1082 | text_embeds = self.get_text_embs(neighbor_input_ids, neighbor_attention_mask, neighbor_pos_ids) 1083 | batch_size, text_neighbor_num, n_tokens, hidden_dim = text_embeds.shape 1084 | text_attention_mask = neighbor_pos_ids > 0 1085 | text_attention_mask = text_attention_mask.unsqueeze(-1).expand(-1, -1, self.n_text_tokens) 1086 | 1087 | visual_embeds = self.get_visual_embs(neighbor_images, neighbor_images_pos_ids) 1088 | batch_size, visual_neighbor_num, n_tokens, hidden_dim = visual_embeds.shape 1089 | visual_attention_mask = neighbor_images_pos_ids > 0 1090 | visual_attention_mask = visual_attention_mask.unsqueeze(-1).expand(-1, -1, self.n_visual_tokens) 1091 | 1092 | # Interleave text and image neighbors 1093 | batch_idx = torch.arange(batch_size)[:, None] 1094 | total_neighbor_num = text_neighbor_num + visual_neighbor_num 1095 | neighbor_embeds = torch.zeros((batch_size, total_neighbor_num, n_tokens, hidden_dim)).to(neighbor_input_ids.device) 1096 | neighbor_embeds[batch_idx, text_locations] = text_embeds 1097 | neighbor_embeds[batch_idx, image_locations] = visual_embeds 1098 | neighbor_embeds = neighbor_embeds.reshape(batch_size, -1, hidden_dim) 1099 | 1100 | # Interleave text and image attention masks 1101 | neighbor_attention_mask = torch.zeros((batch_size, total_neighbor_num, n_tokens)).bool().to(neighbor_attention_mask.device) 1102 | neighbor_attention_mask[batch_idx, text_locations] = text_attention_mask 1103 | neighbor_attention_mask[batch_idx, image_locations] = visual_attention_mask 1104 | neighbor_attention_mask = neighbor_attention_mask.reshape(batch_size, -1) 1105 | else: 1106 | raise ValueError(f"Neighbor mode: {self.neighbor_mode} and context: {self.context} are not supported.") 1107 | 1108 | output = self.lm(input_ids=input_ids, 1109 | attention_mask=attention_mask, 1110 | labels=labels, 1111 | neighbor_embeds=neighbor_embeds, 1112 | neighbor_attention_mask=neighbor_attention_mask) 1113 | 1114 | return output 1115 | -------------------------------------------------------------------------------- /model/modelling_self_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | TextPooler code block is adapted from: 3 | - Repository: Bert model in Huggingface Transformers 4 | - Author: Huggingface 5 | - Link: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py 6 | """ 7 | 8 | import torch 9 | import torch.nn as nn 10 | from transformers import ( 11 | AutoConfig, 12 | AutoTokenizer, 13 | AutoModelForSeq2SeqLM, 14 | AutoModelForCausalLM, 15 | RobertaModel, 16 | CLIPVisionModel 17 | ) 18 | 19 | from peft import ( 20 | LoraConfig, 21 | PrefixTuningConfig, 22 | PromptTuningInit, 23 | PromptTuningConfig, 24 | TaskType, 25 | get_peft_model, 26 | ) 27 | 28 | from .graph import GCN 29 | 30 | 31 | class TextPooler(nn.Module): 32 | """ 33 | Pool the hidden state corresponding to the first token. 34 | """ 35 | def __init__(self, config): 36 | super().__init__() 37 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 38 | self.activation = nn.Tanh() 39 | 40 | def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: 41 | # We "pool" the model by simply taking the hidden state corresponding to the first token. 
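        # e.g. hidden_states of shape (batch, seq_len, hidden_size) -> (batch, hidden_size) by taking the
        # first (<s>/CLS-style) token, followed by a Linear + Tanh head, mirroring BERT's pooler.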
42 | first_token_tensor = hidden_states[:, 0] 43 | pooled_output = self.dense(first_token_tensor) 44 | pooled_output = self.activation(pooled_output) 45 | return pooled_output 46 | 47 | 48 | class SelfAttentionModel(nn.Module): 49 | """ 50 | SelfAttentionModel is a wrapper around a pretrained language model. 51 | It supports the encoder-decoder models (e.g., T5) and decoder-only models (e.g., OPT).) 52 | """ 53 | def __init__(self, args, tokenizer): 54 | super().__init__() 55 | 56 | self.args = args 57 | self.context = args.context 58 | self.decoder_only = args.decoder_only 59 | self.neighbor_mode = args.neighbor_mode 60 | self.position_type = args.position_type 61 | self.n_text_tokens = args.n_text_tokens 62 | self.n_visual_tokens = args.n_visual_tokens 63 | self.tokenizer = tokenizer 64 | 65 | if "t5" in args.model_name_or_path: 66 | peft_task_type = TaskType.SEQ_2_SEQ_LM 67 | config = AutoConfig.from_pretrained(args.model_name_or_path) 68 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name_or_path, config=config) 69 | elif "opt" in args.model_name_or_path: 70 | peft_task_type = TaskType.CAUSAL_LM 71 | config = AutoConfig.from_pretrained(args.model_name_or_path) 72 | model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, config=config) 73 | else: 74 | raise ValueError(f"SelfAttentionModel does not support {args.model_name_or_path}.") 75 | 76 | if args.peft_type == "none": 77 | self.lm = model 78 | else: 79 | if args.peft_type == "lora": 80 | peft_config = LoraConfig( 81 | r=args.lora_r, 82 | lora_alpha=args.lora_alpha, 83 | target_modules=["query", "value"], 84 | lora_dropout=args.lora_dropout, 85 | bias="none", 86 | modules_to_save=["lm_head"], 87 | ) 88 | elif args.peft_type == "prefix": 89 | peft_config = PrefixTuningConfig( 90 | task_type=peft_task_type, 91 | inference_mode=False, 92 | num_virtual_tokens=20 93 | ) 94 | elif args.peft_type == "prompt": 95 | peft_config = PromptTuningConfig( 96 | task_type=peft_task_type, 97 | prompt_tuning_init=PromptTuningInit.RANDOM, 98 | num_virtual_tokens=20, 99 | ) 100 | else: 101 | raise ValueError(f"SelfAttentionModel does not support {args.peft_type}.") 102 | self.lm = get_peft_model(model, peft_config) 103 | 104 | self.input_embeddings = self.lm.get_input_embeddings() 105 | 106 | self.text_model = None 107 | if self.neighbor_mode == "embedding": 108 | # Text model to compute embeddings for text neighbors 109 | config = AutoConfig.from_pretrained(args.text_model) 110 | embedding_dim = self.input_embeddings.embedding_dim * args.n_text_tokens 111 | self.text_model = RobertaModel.from_pretrained(args.text_model, config=config) 112 | self.text_pooler = TextPooler(config) 113 | self.text_embeddings = nn.Linear(config.hidden_size, embedding_dim) 114 | if args.position_type != "none": 115 | self.text_position_embeddings = nn.Embedding(args.max_output_length + 1, embedding_dim) # + 1 for padding neighbors 116 | # Freeze the text model 117 | self.text_model.eval() 118 | for name, param in self.text_model.named_parameters(): 119 | param.requires_grad = False 120 | 121 | self.visual_model = None 122 | if self.context in ("session_all", "all"): 123 | # Vision model to compute embeddings for image neighbors 124 | embedding_dim = self.input_embeddings.embedding_dim * args.n_visual_tokens 125 | self.visual_model = CLIPVisionModel.from_pretrained(args.visual_model) 126 | self.visual_embeddings = nn.Linear(self.visual_model.config.hidden_size, embedding_dim) 127 | if args.position_type != "none": 128 | self.visual_position_embeddings 
= nn.Embedding(args.max_output_length + 1, embedding_dim) # + 1 for padding neighbors 129 | # Freeze the vision model 130 | self.visual_model.eval() 131 | for param in self.visual_model.parameters(): 132 | param.requires_grad = False 133 | 134 | if self.position_type == "laplacian": 135 | if self.context in ("section_only", "section_all", "text_only") or self.neighbor_mode == "raw": 136 | raise ValueError(f"[Laplacian PE] neighbor mode: {self.neighbor_mode} and context: {self.context} are not supported.") 137 | k = 1 + args.max_text_neighbors + args.max_image_neighbors - 5 138 | embedding_dim = self.input_embeddings.embedding_dim * args.n_text_tokens 139 | self.lpe_embeddings = nn.Linear(k, embedding_dim) 140 | 141 | if self.position_type == "gnn": 142 | embedding_dim = self.input_embeddings.embedding_dim * args.n_text_tokens 143 | self.gnn = GCN(input_dim=embedding_dim, output_dim=embedding_dim, hidden_dim=self.text_model.config.hidden_size) 144 | 145 | # Freeze the base LM if needed 146 | if self.args.freeze_lm: 147 | print("Freezing the LM.") 148 | self.lm.eval() 149 | for param in self.lm.parameters(): 150 | param.requires_grad = False 151 | else: 152 | self.lm.train() 153 | 154 | def get_text_embs(self, input_ids, attention_mask, pos_ids=None): 155 | """ 156 | Compute embeddings for text neighbors. 157 | Args: 158 | input_ids: (batch_size, neighbor_num, seq_len) 159 | attention_mask: (batch_size, neighbor_num, seq_len) 160 | pos_ids: (batch_size, neighbor_num, seq_len) 161 | Returns: 162 | text_embs: (batch_size, neighbor_num, n_text_tokens, hidden_dim) 163 | """ 164 | batch_size, neighbor_num, seq_len = input_ids.shape 165 | input_ids = input_ids.reshape(-1, seq_len) 166 | attention_mask = attention_mask.reshape(-1, seq_len) 167 | 168 | outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask) 169 | encoder_outputs = self.text_pooler(outputs.last_hidden_state) 170 | text_embs = self.text_embeddings(encoder_outputs) 171 | 172 | if self.position_type != "none" and pos_ids is not None: 173 | pos_ids = pos_ids.reshape(-1) 174 | text_embs = text_embs + self.text_position_embeddings(pos_ids) 175 | 176 | text_embs = text_embs.reshape(text_embs.shape[0], self.n_text_tokens, -1) 177 | return text_embs.reshape(batch_size, neighbor_num, self.n_text_tokens, -1) 178 | 179 | def get_visual_embs(self, pixel_values, pos_ids=None): 180 | """ 181 | Compute embeddings for image neighbors. 
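        Example (shapes only; the sizes are hypothetical):
            pixel_values: (2, 3, 3, 224, 224) -> visual_embs: (2, 3, n_visual_tokens, hidden_dim)
        i.e. each image neighbor becomes `n_visual_tokens` soft tokens that are later interleaved with the
        text-neighbor embeddings.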
182 | Args: 183 | pixel_values: (batch_size, neighbor_num, pixel, width, height) 184 | pos_ids: (batch_size, neighbor_num) 185 | Returns: 186 | visual_embs: (batch_size, neighbor_num, n_visual_tokens, hidden_dim) 187 | """ 188 | batch_size, neighbor_num, pixel, width, height = pixel_values.shape 189 | pixel_values = pixel_values.reshape(-1, pixel, width, height) 190 | 191 | outputs = self.visual_model(pixel_values) 192 | encoder_outputs = outputs.pooler_output 193 | visual_embs = self.visual_embeddings(encoder_outputs) 194 | 195 | if self.position_type != "none" and pos_ids is not None: 196 | pos_ids = pos_ids.reshape(-1) 197 | visual_embs = visual_embs + self.visual_position_embeddings(pos_ids) 198 | 199 | visual_embs = visual_embs.reshape(visual_embs.shape[0], self.n_visual_tokens, -1) 200 | return visual_embs.reshape(batch_size, neighbor_num, self.n_visual_tokens, -1) 201 | 202 | def train(self, mode=True): 203 | super(SelfAttentionModel, self).train(mode=mode) 204 | if self.args.freeze_lm: 205 | self.lm.eval() 206 | if self.text_model is not None: 207 | self.text_model.eval() 208 | if self.visual_model is not None: 209 | self.visual_model.eval() 210 | 211 | def forward( 212 | self, 213 | input_ids, 214 | attention_mask, 215 | labels, 216 | images=None, 217 | image_positions=None, 218 | neighbor_input_ids=None, 219 | neighbor_attention_mask=None, 220 | neighbor_pos_ids=None, 221 | text_locations=None, 222 | neighbor_images=None, 223 | neighbor_images_pos_ids=None, 224 | image_locations=None, 225 | lpe=None, 226 | graph=None 227 | ): 228 | """ 229 | Args: 230 | input_ids: token ids of input text (batch_size, seq_len) 231 | attention_mask: attention_mask of input text (batch_size, seq_len) 232 | labels: token ids of labels (batch_size, seq_len) 233 | images: neighbor image features (batch_size, image_num, pixel, width, height) 234 | image_positions: neighbor image locations (batch_size, image_num) 235 | neighbor_input_ids: token ids of neighbor text (batch_size, text_num, seq_len) 236 | neighbor_attention_mask: attention mask of neighbor text (batch_size, text_num, seq_len) 237 | neighbor_pos_ids: position ids of neighbor text (batch_size, text_num, seq_len) 238 | text_locations: locations of text embeddings (batch_size, text_num) 239 | neighbor_images: neighbor image features (batch_size, image_num, pixel, width, height) 240 | neighbor_images_pos_ids: position ids of neighbor images (batch_size, image_num) 241 | image_locations: locations of image embeddings (batch_size, image_num) 242 | """ 243 | 244 | if self.neighbor_mode == "raw" and self.context in ("session", "text_only"): 245 | # Only text information is provided as raw text; no need to compute embeddings for neighbors 246 | return self.lm(input_ids=input_ids, attention_mask=attention_mask, labels=labels) 247 | 248 | elif self.neighbor_mode == "raw" and self.context in ("session_all", "all"): 249 | # Both text and image information are provided as raw text and images; no need to compute embeddings for neighbors 250 | input_embs = self.input_embeddings(input_ids) 251 | visual_embs = self.get_visual_embs(images) 252 | 253 | batch_size, seq_len, hidden_dim = input_embs.shape 254 | batch_idx = torch.arange(batch_size)[:, None] 255 | input_embs[batch_idx, image_positions] = visual_embs.reshape(batch_size, -1, hidden_dim) 256 | 257 | if self.decoder_only: 258 | # Labels should not be included in loss computation 259 | labels[batch_idx, image_positions] = -100 260 | 261 | return self.lm(inputs_embeds=input_embs, 
attention_mask=attention_mask, labels=labels) 262 | 263 | elif self.neighbor_mode == "embedding" and self.context in ("session", "text_only"): 264 | # Only text information is provided as embeddings; Compute embeddings for text neighbors 265 | batch_size, neighbor_num, seq_len = neighbor_input_ids.shape 266 | neighbor_embeds = self.get_text_embs(neighbor_input_ids, neighbor_attention_mask, neighbor_pos_ids) 267 | neighbor_embeds = neighbor_embeds.reshape(batch_size, neighbor_num * self.n_text_tokens, -1) 268 | neighbor_attention_mask = neighbor_pos_ids > 0 269 | neighbor_attention_mask = torch.repeat_interleave(neighbor_attention_mask, repeats=self.n_text_tokens, dim=1) 270 | 271 | input_embs = self.input_embeddings(input_ids) 272 | input_embs = torch.cat((input_embs, neighbor_embeds), dim=1) 273 | attention_mask = torch.cat((attention_mask, neighbor_attention_mask), dim=1) 274 | 275 | if self.decoder_only: 276 | # Labels should not be included in loss computation 277 | neighbor_labels = -100 * torch.ones((batch_size, neighbor_num * self.n_text_tokens), dtype=labels.dtype).to(labels.device) 278 | labels = torch.cat((labels, neighbor_labels), dim=1) 279 | 280 | return self.lm(inputs_embeds=input_embs, attention_mask=attention_mask, labels=labels) 281 | 282 | elif self.neighbor_mode == "embedding" and self.context in ("session_all", "all"): 283 | # Both text and image information are provided as embeddings; Compute embeddings for text and image neighbors 284 | text_embeds = self.get_text_embs(neighbor_input_ids, neighbor_attention_mask, neighbor_pos_ids) 285 | batch_size, text_neighbor_num, n_tokens, hidden_dim = text_embeds.shape 286 | text_attention_mask = neighbor_pos_ids > 0 287 | text_attention_mask = text_attention_mask.unsqueeze(-1).expand(-1, -1, self.n_text_tokens) 288 | 289 | visual_embeds = self.get_visual_embs(neighbor_images, neighbor_images_pos_ids) 290 | batch_size, visual_neighbor_num, n_tokens, hidden_dim = visual_embeds.shape 291 | batch_idx = torch.arange(batch_size)[:, None] 292 | visual_attention_mask = neighbor_images_pos_ids > 0 293 | visual_attention_mask = visual_attention_mask.unsqueeze(-1).expand(-1, -1, self.n_visual_tokens) 294 | 295 | # Interleave text and image neighbors 296 | neighbor_embeds = torch.zeros((batch_size, text_neighbor_num + visual_neighbor_num, n_tokens, hidden_dim), 297 | device=text_embeds.device) 298 | neighbor_embeds[batch_idx, text_locations] = text_embeds 299 | neighbor_embeds[batch_idx, image_locations] = visual_embeds 300 | neighbor_embeds = neighbor_embeds.reshape(batch_size, -1, hidden_dim) 301 | 302 | # Interleave text and image attention masks 303 | total_neighbor_num = text_neighbor_num + visual_neighbor_num 304 | neighbor_attention_mask = torch.zeros((batch_size, total_neighbor_num, n_tokens), 305 | device=text_attention_mask.device) 306 | neighbor_attention_mask[batch_idx, text_locations] = text_attention_mask.float() 307 | neighbor_attention_mask[batch_idx, image_locations] = visual_attention_mask.float() 308 | neighbor_attention_mask = neighbor_attention_mask.reshape(batch_size, -1) 309 | 310 | # Graph position encoding 311 | if self.context == "all": 312 | if self.position_type == "laplacian": 313 | lpe_embeddings = self.lpe_embeddings(lpe) 314 | lpe_embeddings = lpe_embeddings.reshape(batch_size, total_neighbor_num + 1, n_tokens, hidden_dim) 315 | neighbor_embeds = neighbor_embeds + lpe_embeddings[:, 1:].reshape(batch_size, -1, hidden_dim) 316 | elif self.position_type == "gnn": 317 | neighbor_embeds = 
neighbor_embeds.view(batch_size, total_neighbor_num, n_tokens, hidden_dim).view(batch_size, total_neighbor_num, -1) 318 | gnn_embeds = self.gnn(neighbor_embeds, graph) 319 | neighbor_embeds = neighbor_embeds + gnn_embeds 320 | neighbor_embeds = neighbor_embeds.view(batch_size, total_neighbor_num, n_tokens, hidden_dim).view(batch_size, -1, hidden_dim) 321 | 322 | # Concatenate neighbor embeddings into input token embeddings 323 | input_embs = self.input_embeddings(input_ids) 324 | input_embs = torch.cat((input_embs, neighbor_embeds), dim=1) 325 | attention_mask = torch.cat((attention_mask, neighbor_attention_mask), dim=1) 326 | 327 | if self.decoder_only: 328 | # Labels should not be included in loss computation 329 | neighbor_labels = -100 * torch.ones((batch_size, total_neighbor_num * n_tokens), dtype=labels.dtype).to(labels.device) 330 | labels = torch.cat((labels, neighbor_labels), dim=1) 331 | 332 | return self.lm(inputs_embeds=input_embs, attention_mask=attention_mask, labels=labels) 333 | 334 | else: 335 | raise ValueError(f"Neighbor mode: {self.neighbor_mode} and context: {self.context} are not supported.") 336 | 337 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | datasets 3 | dill 4 | evaluate 5 | nltk 6 | numpy 7 | pandas 8 | peft 9 | pillow 10 | pyarrow 11 | requests 12 | tensorflow 13 | tqdm 14 | transformers 15 | wandb 16 | warmup_scheduler 17 | -------------------------------------------------------------------------------- /script/train_generation.sh: -------------------------------------------------------------------------------- 1 | ulimit -c unlimited 2 | 3 | export WANDB_WATCH=false 4 | export PYTHONPATH=. 
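# The variables below pick the base LM, the task, the neighbor context, how neighbors are encoded, and the PEFT method.
# One hypothetical (untested) alternative to the defaults, aimed at the flamingo-style gated cross-attention
# implemented in model/modelling_cross_attention.py (values not verified against run_generation.py's Argument class):
#   NEIGHBOR_MODE='embedding'
#   PEFT_TYPE='flamingo'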
5 | 6 | MODEL_NAME='facebook/opt-350m' 7 | TASK='section' 8 | CONTEXT='all' 9 | NEIGHBOR_MODE='raw' 10 | PEFT_TYPE='none' 11 | DESCRIPTION=${MODEL_NAME}-${TASK}-${CONTEXT} 12 | 13 | python language_modelling/run_generation.py \ 14 | --dataset wikiweb2m \ 15 | --model_name_or_path ${MODEL_NAME} \ 16 | --task ${TASK} \ 17 | --context ${CONTEXT} \ 18 | --neighbor_mode ${NEIGHBOR_MODE} \ 19 | --peft_type ${PEFT_TYPE} \ 20 | --max_input_length 512 \ 21 | --max_output_length 128 \ 22 | --epochs 50 \ 23 | --steps_per_epoch 10000 \ 24 | --val_steps_per_epoch 400 \ 25 | --learning_rate 1e-4 \ 26 | --per_device_train_batch_size 2 \ 27 | --per_device_val_batch_size 2 \ 28 | --dataloader_num_workers 8 \ 29 | --grad_accumulation_steps 16 \ 30 | --fp16 \ 31 | --wandb_project MMHG \ 32 | --wandb_run ${DESCRIPTION} 33 | -------------------------------------------------------------------------------- /wikiweb2m/__init__.py: -------------------------------------------------------------------------------- 1 | from .data import load_wikiweb2m, WikiWeb2M 2 | -------------------------------------------------------------------------------- /wikiweb2m/cider/__init__.py: -------------------------------------------------------------------------------- 1 | from .cider import Cider 2 | -------------------------------------------------------------------------------- /wikiweb2m/cider/cider.py: -------------------------------------------------------------------------------- 1 | """ 2 | The following code block is adapted from: 3 | - Repository: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | - Author: Ramakrishna Vedantam and Tsung-Yi Lin 5 | - Link: https://github.com/vrama91/cider 6 | - Creation Date: Sun Feb 8 14:16:54 2015 7 | """ 8 | 9 | from .cider_scorer import CiderScorer 10 | 11 | class Cider: 12 | """ 13 | Main Class to compute the CIDEr metric 14 | 15 | """ 16 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 17 | # set cider to sum over 1 to 4-grams 18 | self._n = n 19 | # set the standard deviation parameter for gaussian penalty 20 | self._sigma = sigma 21 | 22 | def compute_score(self, gts, res): 23 | """ 24 | Main function to compute CIDEr score 25 | :param hypo_for_image (dict) : dictionary with key and value 26 | ref_for_image (dict) : dictionary with key and value 27 | :return: cider (float) : computed CIDEr score for the corpus 28 | """ 29 | 30 | assert(gts.keys() == res.keys()) 31 | imgIds = gts.keys() 32 | 33 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma) 34 | 35 | for id in imgIds: 36 | hypo = res[id] 37 | ref = gts[id] 38 | 39 | # Sanity check. 
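                # Expected format (enforced by the asserts below): res[id] is a single-element list holding
                # the hypothesis sentence, and gts[id] is a non-empty list of reference sentences.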
40 | assert(type(hypo) is list) 41 | assert(len(hypo) == 1) 42 | assert(type(ref) is list) 43 | assert(len(ref) > 0) 44 | 45 | cider_scorer += (hypo[0], ref) 46 | 47 | (score, scores) = cider_scorer.compute_score() 48 | 49 | return score, scores 50 | 51 | def method(self): 52 | return "CIDEr" 53 | -------------------------------------------------------------------------------- /wikiweb2m/cider/cider_scorer.py: -------------------------------------------------------------------------------- 1 | """ 2 | The following code block is adapted from: 3 | - Repository: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric 4 | - Author: Ramakrishna Vedantam and Tsung-Yi Lin 5 | - Link: https://github.com/vrama91/cider 6 | - Creation Date: Sun Feb 8 14:16:54 2015 7 | """ 8 | 9 | #!/usr/bin/env python 10 | 11 | import copy 12 | from collections import defaultdict 13 | import numpy as np 14 | import math 15 | 16 | def precook(s, n=4, out=False): 17 | """ 18 | Takes a string as input and returns an object that can be given to 19 | either cook_refs or cook_test. This is optional: cook_refs and cook_test 20 | can take string arguments as well. 21 | :param s: string : sentence to be converted into ngrams 22 | :param n: int : number of ngrams for which representation is calculated 23 | :return: term frequency vector for occuring ngrams 24 | """ 25 | if isinstance(s, (list, tuple)): 26 | s = s[0] 27 | words = s.split() 28 | counts = defaultdict(int) 29 | for k in range(1,n+1): 30 | for i in range(len(words)-k+1): 31 | ngram = tuple(words[i:i+k]) 32 | counts[ngram] += 1 33 | return counts 34 | 35 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average" 36 | '''Takes a list of reference sentences for a single segment 37 | and returns an object that encapsulates everything that BLEU 38 | needs to know about them. 39 | :param refs: list of string : reference sentences for some image 40 | :param n: int : number of ngrams for which (ngram) representation is calculated 41 | :return: result (list of dict) 42 | ''' 43 | return [precook(ref, n) for ref in refs] 44 | 45 | def cook_test(test, n=4): 46 | '''Takes a test sentence and returns an object that 47 | encapsulates everything that BLEU needs to know about it. 48 | :param test: list of string : hypothesis sentence for some image 49 | :param n: int : number of ngrams for which (ngram) representation is calculated 50 | :return: result (dict) 51 | ''' 52 | return precook(test, n, True) 53 | 54 | class CiderScorer(object): 55 | """CIDEr scorer. 56 | """ 57 | 58 | def copy(self): 59 | ''' copy the refs.''' 60 | new = CiderScorer(n=self.n) 61 | new.ctest = copy.copy(self.ctest) 62 | new.crefs = copy.copy(self.crefs) 63 | return new 64 | 65 | def __init__(self, test=None, refs=None, n=4, sigma=6.0): 66 | ''' singular instance ''' 67 | self.n = n 68 | self.sigma = sigma 69 | self.crefs = [] 70 | self.ctest = [] 71 | self.document_frequency = defaultdict(float) 72 | self.cook_append(test, refs) 73 | self.ref_len = None 74 | 75 | def cook_append(self, test, refs): 76 | '''called by constructor and __iadd__ to avoid creating new instances.''' 77 | 78 | if refs is not None: 79 | self.crefs.append(cook_refs(refs)) 80 | if test is not None: 81 | self.ctest.append(cook_test(test)) ## N.B.: -1 82 | else: 83 | self.ctest.append(None) # lens of crefs and ctest have to match 84 | 85 | def size(self): 86 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! 
%d<>%d" % (len(self.crefs), len(self.ctest)) 87 | return len(self.crefs) 88 | 89 | def __iadd__(self, other): 90 | '''add an instance (e.g., from another sentence).''' 91 | 92 | if type(other) is tuple: 93 | ## avoid creating new CiderScorer instances 94 | self.cook_append(other[0], other[1]) 95 | else: 96 | self.ctest.extend(other.ctest) 97 | self.crefs.extend(other.crefs) 98 | 99 | return self 100 | def compute_doc_freq(self): 101 | ''' 102 | Compute term frequency for reference data. 103 | This will be used to compute idf (inverse document frequency later) 104 | The term frequency is stored in the object 105 | :return: None 106 | ''' 107 | for refs in self.crefs: 108 | # refs, k ref captions of one image 109 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]): 110 | self.document_frequency[ngram] += 1 111 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count) 112 | 113 | def compute_cider(self): 114 | def counts2vec(cnts): 115 | """ 116 | Function maps counts of ngram to vector of tfidf weights. 117 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights. 118 | The n-th entry of array denotes length of n-grams. 119 | :param cnts: 120 | :return: vec (array of dict), norm (array of float), length (int) 121 | """ 122 | vec = [defaultdict(float) for _ in range(self.n)] 123 | length = 0 124 | norm = [0.0 for _ in range(self.n)] 125 | for (ngram,term_freq) in cnts.items(): 126 | # give word count 1 if it doesn't appear in reference corpus 127 | df = np.log(max(1.0, self.document_frequency[ngram])) 128 | # ngram index 129 | n = len(ngram)-1 130 | # tf (term_freq) * idf (precomputed idf) for n-grams 131 | vec[n][ngram] = float(term_freq)*(self.ref_len - df) 132 | # compute norm for the vector. the norm will be used for computing similarity 133 | norm[n] += pow(vec[n][ngram], 2) 134 | 135 | if n == 1: 136 | length += term_freq 137 | norm = [np.sqrt(n) for n in norm] 138 | return vec, norm, length 139 | 140 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref): 141 | ''' 142 | Compute the cosine similarity of two vectors. 
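            For each n-gram order, the score is a clipped tf-idf dot product between the hypothesis and
            reference vectors, normalised by their norms and scaled by a Gaussian length penalty
            exp(-delta^2 / (2 * sigma^2)), where delta is the hypothesis/reference length difference.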
143 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis 144 | :param vec_ref: array of dictionary for vector corresponding to reference 145 | :param norm_hyp: array of float for vector corresponding to hypothesis 146 | :param norm_ref: array of float for vector corresponding to reference 147 | :param length_hyp: int containing length of hypothesis 148 | :param length_ref: int containing length of reference 149 | :return: array of score for each n-grams cosine similarity 150 | ''' 151 | delta = float(length_hyp - length_ref) 152 | # measure consine similarity 153 | val = np.array([0.0 for _ in range(self.n)]) 154 | for n in range(self.n): 155 | # ngram 156 | for (ngram,count) in vec_hyp[n].items(): 157 | # vrama91 : added clipping 158 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram] 159 | 160 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0): 161 | val[n] /= (norm_hyp[n]*norm_ref[n]) 162 | 163 | assert(not math.isnan(val[n])) 164 | # vrama91: added a length based gaussian penalty 165 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2)) 166 | return val 167 | 168 | # compute log reference length 169 | self.ref_len = np.log(float(len(self.crefs))) 170 | 171 | scores = [] 172 | for test, refs in zip(self.ctest, self.crefs): 173 | # compute vector for test captions 174 | vec, norm, length = counts2vec(test) 175 | # compute vector for ref captions 176 | score = np.array([0.0 for _ in range(self.n)]) 177 | for ref in refs: 178 | vec_ref, norm_ref, length_ref = counts2vec(ref) 179 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref) 180 | # change by vrama91 - mean of ngram scores, instead of sum 181 | score_avg = np.mean(score) 182 | # divide by number of references 183 | score_avg /= len(refs) 184 | # multiply score by 10 185 | score_avg *= 10.0 186 | # append score of an image to the score list 187 | scores.append(score_avg) 188 | return scores 189 | 190 | def compute_score(self, option=None, verbose=0): 191 | # compute idf 192 | self.compute_doc_freq() 193 | # assert to check document frequency 194 | assert(len(self.ctest) >= max(self.document_frequency.values())) 195 | # compute cider score 196 | score = self.compute_cider() 197 | # debug 198 | # print score 199 | return np.mean(np.array(score)), np.array(score) 200 | -------------------------------------------------------------------------------- /wikiweb2m/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from transformers import AutoTokenizer 5 | import pickle 6 | import pandas as pd 7 | from PIL import Image 8 | from urllib.request import urlopen 9 | 10 | from language_modelling import utils 11 | from torch_geometric.data import Data 12 | 13 | def load_wikiweb2m(task): 14 | """ 15 | Load WikiWeb2M dataset 16 | Args: 17 | task: 'section' for section summarization task 18 | Returns: 19 | train_df: pandas dataframe for training_set 20 | val_df: pandas dataframe for validation_set 21 | test_df: pandas dataframe for test_set 22 | id_list: list of (page_id, section_id) pairs 23 | """ 24 | train_df = pd.read_parquet(f'./wikiweb2m/raw/wikiweb2m_train_large.parquet') 25 | val_df = pd.read_parquet(f'./wikiweb2m/raw/wikiweb2m_val_large.parquet') 26 | test_df = pd.read_parquet(f'./wikiweb2m/raw/wikiweb2m_test_large.parquet') 27 | 28 | with open(f'./wikiweb2m/raw/{task}_id_split_large.pkl', 'rb') as f: 29 | id_list = pickle.load(f) 30 | 31 | return train_df, val_df, test_df, id_list 32 | 33 | 34 | class 
WikiWeb2M(torch.utils.data.Dataset): 35 | """ 36 | WikiWeb2M dataset 37 | Args: 38 | args: args 39 | df: pandas dataframe for dataset 40 | id_list: list of (page_id, section_id) pairs 41 | tokenizer: tokenizer 42 | visual_feature_extractor_model: visual feature extractor model 43 | """ 44 | 45 | def __init__(self, args, df, id_list, tokenizer, visual_feature_extractor_model=None): 46 | self.path = './wikiweb2m/raw/' 47 | self.image_path = '/projects/rsalakhugroup/minjiy/images/' 48 | 49 | self.task = args.task 50 | self.context = args.context 51 | self.decoder_only = args.decoder_only 52 | self.neighbor_mode = args.neighbor_mode 53 | 54 | self.max_text_neighbors = args.max_text_neighbors 55 | self.max_image_neighbors = args.max_image_neighbors 56 | self.position_type = args.position_type 57 | 58 | self.df = df 59 | self.id_list = id_list 60 | self.tokenizer = tokenizer 61 | self.max_input_length = args.max_input_length 62 | self.max_output_length = args.max_output_length 63 | 64 | if visual_feature_extractor_model is not None and self.context in ('section_all', 'all'): 65 | self.visual_feature_extractor = utils.get_feature_extractor_for_model(visual_feature_extractor_model) 66 | 67 | self.n_text_tokens = args.n_text_tokens 68 | self.n_visual_tokens = args.n_visual_tokens 69 | 70 | def __len__(self): 71 | """ 72 | Get dataset length 73 | Returns: 74 | dataset length 75 | """ 76 | return len(self.id_list) 77 | 78 | def get_page_info(self, d): 79 | """ 80 | Get page text information 81 | Args: 82 | d: pandas dataframe for dataset 83 | Returns: 84 | page_info: page information in raw text 85 | """ 86 | page_url = d['page_url'].decode() 87 | page_title = d['page_title'].decode() 88 | page_description = d['page_description'].decode() 89 | page_info = ', '.join([page_title, page_description]) 90 | return ' '.join(page_info.replace('\n', ' ').split()) 91 | 92 | def get_section_info(self, section_id, d, remove_summary=True): 93 | """ 94 | Get section text information 95 | Args: 96 | section_id: section id 97 | d: pandas dataframe for dataset 98 | remove_summary: whether to remove summary (ground truth for section summarization) or not 99 | Returns: 100 | section_info: section information in raw text 101 | section_summary: section summary in raw text 102 | """ 103 | section_depth = str(d['section_depth'][section_id]) 104 | section_heading = str(d['section_heading'][section_id]) 105 | section_parent_index = str(d['section_parent_index'][section_id]) 106 | section_title = d['section_title'][section_id].decode() 107 | section_summary = d['section_summary'][section_id].decode() 108 | section_rest_sentence = d['section_rest_sentence'][section_id].decode() 109 | if remove_summary: 110 | section_info = ', '.join([section_rest_sentence]) 111 | section_info, section_summary = ' '.join(section_info.replace('\n', ' ').split()), ' '.join(section_summary.replace('\n', ' ').split()) 112 | return section_info, section_summary 113 | else: 114 | section_info = ', '.join([section_summary, section_rest_sentence]) 115 | section_info = ' '.join(section_info.replace('\n', ' ').split()) 116 | return section_info 117 | 118 | def get_section_images(self, page_id, section_id, d): 119 | """ 120 | Get section image information 121 | Args: 122 | page_id: page id 123 | section_id: section id 124 | d: pandas dataframe for dataset 125 | Returns: 126 | section_image: section image 127 | section_caption: section image caption 128 | """ 129 | section_num = d['section_title'].shape[0] 130 | image_urls = 
d['image_url'].reshape(section_num, -1) 131 | image_captions = d['image_caption'].reshape(section_num, -1) 132 | for image_id in range(image_urls[section_id].shape[0]): 133 | image_url = image_urls[section_id][image_id].decode() 134 | file_format = os.path.splitext(image_url)[1][1:] 135 | file_name = f'{self.image_path}/{page_id}_{section_id}_{image_id}.{file_format}' 136 | if os.path.exists(file_name): 137 | try: 138 | img = Image.open(f'./wikiweb2m/raw/images/{page_id}_{section_id}_{image_id}.{file_format}') 139 | section_image = utils.get_pixel_values_for_model(self.visual_feature_extractor, img) 140 | section_caption = image_captions[section_id][image_id].decode() 141 | return section_image, ' '.join(section_caption.replace('\n', ' ').split()) 142 | except: 143 | continue 144 | return None, None 145 | 146 | def __getitem__(self, index): 147 | """ 148 | Get item 149 | Args: 150 | index: index 151 | Returns: 152 | dictionary of { 153 | input_ids: tokenized input ids, 154 | attention_mask: attention mask for input ids, 155 | labels: tokenized label ids, 156 | images: image features (only for section_all and all), 157 | image_positions: image locations among texts (only for section_all and all), 158 | neighbor_input_ids: tokenized input ids for neighbor texts, 159 | neighbor_attention_mask: attention mask for neighbor texts, 160 | neighbor_pos_ids: position ids for neighbor texts, 161 | text_locations: neighbor text embedding locations among embeddings, 162 | neighbor_images: image features for neighbor images, 163 | neighbor_images_pos_ids: position ids for neighbor images, 164 | image_locations: neighbor image embedding locations among embeddings 165 | } 166 | """ 167 | if self.neighbor_mode == "embedding": 168 | return self.get_embedding_item(index) 169 | 170 | page_id, section_id = self.id_list[index] 171 | d = self.df[self.df['page_id'] == page_id].iloc[0] 172 | if self.context == 'section_only': 173 | # Get section text information only 174 | section_info, labels = self.get_section_info(section_id, d, remove_summary=True) 175 | inputs = 'summarize: ' + section_info 176 | input_ids = self.tokenizer(inputs, max_length=self.max_input_length, padding="do_not_pad", truncation=True, return_tensors="pt").input_ids[0] 177 | 178 | elif self.context == "section_all": 179 | # Get section text and image information 180 | section_info, labels = self.get_section_info(section_id, d, remove_summary=True) 181 | image, image_caption = self.get_section_images(page_id, section_id, d) 182 | 183 | images = [] 184 | image_positions = [] 185 | if image is None: 186 | # No image in the section 187 | inputs = "summarize: " + section_info 188 | visual_ids = torch.LongTensor(self.n_visual_tokens * [self.tokenizer.pad_token_id]) 189 | images.append(torch.zeros((3, 224, 224))) 190 | else: 191 | # If image exists, add image caption to the input text 192 | inputs = "summarize: " + section_info + ", context: " + image_caption 193 | visual_ids = torch.LongTensor(self.n_visual_tokens * [-1]) 194 | images.append(image) 195 | max_text_length = self.max_input_length - self.n_visual_tokens 196 | input_ids = self.tokenizer(inputs, max_length=max_text_length, padding="do_not_pad", truncation=True, return_tensors="pt").input_ids[0] 197 | # Image is concatenated at the end of the input text 198 | image_positions.append(input_ids.shape[0] + torch.arange(self.n_visual_tokens)) 199 | input_ids = torch.cat([input_ids, visual_ids], dim=0) 200 | 201 | elif self.context == "text_only": 202 | # Get all text information in the page 203 | page_info = self.get_page_info(d) 204 | section_info, labels = self.get_section_info(section_id, d, remove_summary=True) 205 | # Collect text information from other sections 206 | context_info = [] 207 | for context_id in range(len(d['section_title'])): 208 | if context_id == section_id: 209 | continue 210 | context_info.append(self.get_section_info(context_id, d, remove_summary=False)) 211 | context_info = ', '.join(context_info) 212 | inputs = "summarize: " + section_info + ", context: " + page_info + context_info 213 | input_ids = self.tokenizer(inputs, max_length=self.max_input_length, padding="do_not_pad", truncation=True, return_tensors="pt").input_ids[0] 214 | 215 | elif self.context == "all": 216 | # Get all text and image information in the page 217 | page_info = self.get_page_info(d) 218 | # Get text and image information in the target section 219 | section_info, labels = self.get_section_info(section_id, d, remove_summary=True) 220 | section_image, section_caption = self.get_section_images(page_id, section_id, d) 221 | 222 | images = [] 223 | image_positions = [] 224 | if section_image is None: 225 | # No image in the section 226 | inputs = "summarize: " + section_info 227 | visual_ids = torch.LongTensor(self.n_visual_tokens * [self.tokenizer.pad_token_id]) 228 | images.append(torch.zeros((3, 224, 224))) 229 | else: 230 | # If image exists, add image caption to the input text 231 | inputs = "summarize: " + section_info + ", context: " + section_caption 232 | visual_ids = torch.LongTensor(self.n_visual_tokens * [-1]) 233 | images.append(section_image) 234 | max_text_length = self.max_input_length - self.n_visual_tokens 235 | input_ids = self.tokenizer(inputs, max_length=max_text_length, padding="do_not_pad", truncation=True, return_tensors="pt").input_ids[0] 236 | # Image is concatenated at the end of the input text from the target section 237 | image_positions.append(input_ids.shape[0] + torch.arange(self.n_visual_tokens)) 238 | input_ids = torch.cat([input_ids, visual_ids], dim=0) 239 | 240 | # Collect text and image information from other sections 241 | for context_id in range(len(d['section_title'])): 242 | if context_id == section_id: 243 | continue 244 | context_info = self.get_section_info(context_id, d, remove_summary=False) 245 | context_image, context_caption = self.get_section_images(page_id, context_id, d) 246 | if context_image is None: 247 | # No image in the section 248 | context = context_info 249 | visual_ids = torch.LongTensor(self.n_visual_tokens * [self.tokenizer.pad_token_id]) 250 | context_image = torch.zeros((3, 224, 224)) 251 | else: 252 | # If image exists, add image caption to the input text 253 | context = context_info + context_caption 254 | visual_ids = torch.LongTensor(self.n_visual_tokens * [-1]) 255 | max_text_length = self.max_input_length - input_ids.shape[0] - self.n_visual_tokens 256 | context_ids = self.tokenizer(context, max_length=max_text_length, padding="do_not_pad", truncation=False, return_tensors="pt").input_ids[0] 257 | if input_ids.shape[0] + context_ids.shape[0] + visual_ids.shape[0] > self.max_input_length: 258 | break 259 | images.append(context_image) 260 | # Image is concatenated at the end of the input text from the corresponding section 261 | image_positions.append(input_ids.shape[0] + context_ids.shape[0] + torch.arange(self.n_visual_tokens)) 262 | input_ids = torch.cat([input_ids, context_ids, visual_ids], dim=0) 263 | 264 | if len(input_ids) > self.max_input_length: 265 | input_ids = input_ids[:self.max_input_length] 266 | 
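        # How the training targets are built below:
        # - Decoder-only LMs (self.decoder_only == True): the padded input ids and the
        #   tokenized ", summary: ..." target are concatenated into a single sequence, and
        #   `labels` is set to that same concatenated sequence (next-token prediction over
        #   the full input + summary).
        # - Encoder-decoder LMs: `labels` holds only the tokenized section summary, with
        #   padding positions replaced by -100 so they are ignored by the loss.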
267 | if self.decoder_only: 268 | # For decode-only models, labels are the same as inputs 269 | model_inputs = self.tokenizer.pad({"input_ids": [input_ids]}, max_length=self.max_input_length, padding="max_length", return_tensors="pt") 270 | labels = ", summary: " + labels 271 | label_ids = self.tokenizer(labels, max_length=self.max_output_length, padding="do_not_pad", truncation=True, return_tensors="pt").input_ids[0] 272 | # Remove SOS token and add EOS token 273 | label_ids = torch.cat([label_ids[1:], torch.LongTensor([self.tokenizer.eos_token_id])], dim=0) 274 | model_outputs = self.tokenizer.pad({"input_ids": [label_ids]}, max_length=self.max_output_length, padding="max_length", return_tensors="pt") 275 | 276 | result = {"input_ids": torch.cat((model_inputs.input_ids[0], model_outputs.input_ids[0]), dim=0),\ 277 | "attention_mask": torch.cat((model_inputs.attention_mask[0], model_outputs.attention_mask[0]), dim=0),\ 278 | "labels": torch.cat((model_inputs.input_ids[0], model_outputs.input_ids[0]), dim=0)} 279 | 280 | else: 281 | # For encoder-decoder models, labels are the summary of the target section 282 | model_inputs = self.tokenizer.pad({"input_ids": [input_ids]}, max_length=self.max_input_length, padding="max_length", return_tensors="pt") 283 | labels = self.tokenizer(labels, max_length=self.max_output_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0] 284 | labels_with_ignore_index = torch.LongTensor([label if label != 0 else -100 for label in labels]) 285 | result = {"input_ids": model_inputs.input_ids[0], "attention_mask": model_inputs.attention_mask[0], "labels": labels_with_ignore_index} 286 | 287 | if self.context in ("section_all", "all"): 288 | # When image information is included, add image features and image locations 289 | images = torch.stack(images, dim=0) 290 | image_positions = torch.cat(image_positions, dim=0) 291 | result["images"] = images 292 | result["image_positions"] = image_positions 293 | 294 | return result 295 | 296 | def get_embedding_item(self, index): 297 | """ 298 | Get item for embedding mode 299 | Args: 300 | index: index 301 | Returns: 302 | dictionary of { 303 | input_ids: tokenized input ids, 304 | attention_mask: attention mask for input ids, 305 | labels: tokenized label ids, 306 | neighbor_input_ids: tokenized input ids for neighbor texts, 307 | neighbor_attention_mask: attention mask for neighbor texts, 308 | neighbor_pos_ids: position ids for neighbor texts, 309 | text_locations: neighbor text embedding locations among embeddings, 310 | neighbor_images: image features for neighbor images, 311 | neighbor_images_pos_ids: position ids for neighbor images, 312 | image_locations: neighbor image embedding locations among embeddings 313 | } 314 | """ 315 | page_id, section_id = self.id_list[index] 316 | d = self.df[self.df['page_id'] == page_id].iloc[0] 317 | 318 | # Get section text information 319 | section_info, labels = self.get_section_info(section_id, d, remove_summary=True) 320 | inputs = "summarize: " + section_info 321 | model_inputs = self.tokenizer(inputs, max_length=self.max_input_length, padding="max_length", truncation=True, return_tensors="pt") 322 | 323 | if self.decoder_only: 324 | # For decode-only models, labels are the same as inputs 325 | labels = ", summary: " + labels 326 | label_ids = self.tokenizer(labels, max_length=self.max_output_length, padding="do_not_pad", truncation=True, return_tensors="pt").input_ids[0] 327 | # Remove SOS token and add EOS token 328 | label_ids = 
torch.cat([label_ids[1:], torch.LongTensor([self.tokenizer.eos_token_id])], dim=0) 329 | model_outputs = self.tokenizer.pad({"input_ids": [label_ids]}, max_length=self.max_output_length, padding="max_length", return_tensors="pt") 330 | 331 | result = {"input_ids": torch.cat((model_inputs.input_ids[0], model_outputs.input_ids[0]), dim=0), \ 332 | "attention_mask": torch.cat((model_inputs.attention_mask[0], model_outputs.attention_mask[0]), dim=0), \ 333 | "labels": torch.cat((model_inputs.input_ids[0], model_outputs.input_ids[0]), dim=0)} 334 | else: 335 | # For encoder-decoder models, labels are the summary of the target section 336 | labels = self.tokenizer(labels, max_length=self.max_output_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0] 337 | labels_with_ignore_index = torch.LongTensor([label if label != 0 else -100 for label in labels]) 338 | result = {"input_ids": model_inputs.input_ids[0], "attention_mask": model_inputs.attention_mask[0], "labels": labels_with_ignore_index} 339 | 340 | # Raw text and image features for neighbor texts and images 341 | neighbor_texts = [] 342 | neighbor_images = [] 343 | # Position ids for position embeddings 344 | position_texts = [] 345 | position_images = [] 346 | # Locations among neighbor embeddings 347 | location_texts = [] 348 | location_images = [] 349 | location = 0 350 | # Graph 351 | graph_index = {section_id: 0} # input text: 0, neighbors: location + 1 352 | edge_list = [] 353 | 354 | #(1) page information 355 | page_info = self.get_page_info(d) 356 | neighbor_texts.append(page_info) 357 | position_texts.append(len(position_texts)) 358 | location_texts.append(location) 359 | location += 1 360 | # Graph: input_text <-> page description 361 | edge_list.append((graph_index[section_id], location)) 362 | 363 | #(2) section image information 364 | section_image, section_caption = self.get_section_images(page_id, section_id, d) 365 | if section_image is not None: 366 | neighbor_images.append(section_image) 367 | position_images.append(len(position_images)) 368 | location_images.append(location) 369 | location += 1 370 | # Graph: input_text <-> image 371 | edge_list.append((graph_index[section_id], location)) 372 | previous_image_id = location 373 | 374 | neighbor_texts.append(section_caption) 375 | position_texts.append(len(position_texts)) 376 | location_texts.append(location) 377 | location += 1 378 | # Graph: input_text <-> caption 379 | edge_list.append((graph_index[section_id], location)) 380 | # Graph: image <-> caption 381 | edge_list.append((previous_image_id, location)) 382 | 383 | #(3) other section information from the same page 384 | previous_section_id = -1 385 | for context_id in range(len(d['section_title'])): 386 | if context_id == section_id: 387 | continue 388 | 389 | if len(neighbor_texts) < self.max_text_neighbors: 390 | context_info = self.get_section_info(context_id, d, remove_summary=False) 391 | neighbor_texts.append(context_info) 392 | position_texts.append(len(position_texts)) 393 | location_texts.append(location) 394 | location += 1 395 | # Graph: previous section - current section (order) 396 | if previous_section_id > -1: 397 | edge_list.append((previous_section_id, location)) 398 | graph_index[context_id] = location 399 | previous_section_id = location 400 | 401 | if len(neighbor_images) < self.max_image_neighbors: 402 | context_image, context_caption = self.get_section_images(page_id, context_id, d) 403 | if context_image is not None: 404 | neighbor_images.append(context_image) 405 | 
position_images.append(len(position_images)) 406 | location_images.append(location) 407 | location += 1 408 | # Graph: section <-> image 409 | edge_list.append((previous_section_id, location)) 410 | previous_image_id = location 411 | 412 | if len(neighbor_texts) < self.max_text_neighbors: 413 | neighbor_texts.append(context_caption) 414 | position_texts.append(len(position_texts)) 415 | location_texts.append(location) 416 | location += 1 417 | # Graph: section <-> caption 418 | edge_list.append((previous_section_id, location)) 419 | # Graph: image <-> caption 420 | edge_list.append((previous_image_id, location)) 421 | 422 | # Graph: hierachical relations 423 | for context_id in range(len(d['section_parent_index'])): 424 | parent_id = d['section_parent_index'][context_id] 425 | if context_id in graph_index.keys() and parent_id in graph_index.keys(): 426 | edge_list.append((graph_index[context_id], graph_index[parent_id])) 427 | 428 | # PyG graph data 429 | node_num = 1 + self.max_text_neighbors + self.max_image_neighbors 430 | edge_index = torch.LongTensor(edge_list).t().contiguous() 431 | if self.position_type == 'laplacian': 432 | node_value = torch.zeros((node_num)) 433 | graph = Data(x=node_value, edge_index=edge_index) 434 | lpe = utils.compute_LPE(graph) 435 | elif self.position_type == 'gnn': 436 | edge_value = torch.ones((edge_index.shape[1])) 437 | graph = torch.sparse_coo_tensor(edge_index, edge_value, [node_num, node_num]).to_dense() 438 | graph = utils.normalize_graph(graph) 439 | 440 | # Increase position ids by 1 for padding_id 441 | position_texts = [position_id + 1 for position_id in position_texts] 442 | position_images = [position_id + 1 for position_id in position_images] 443 | # Pad texts 444 | while len(neighbor_texts) < self.max_text_neighbors: 445 | neighbor_texts.append('') 446 | position_texts.append(0) 447 | location_texts.append(location) 448 | location += 1 449 | # Pad images 450 | while len(neighbor_images) < self.max_image_neighbors: 451 | neighbor_images.append(torch.zeros((3, 224, 224))) 452 | position_images.append(0) 453 | location_images.append(location) 454 | location += 1 455 | 456 | #Tokenize neighbor texts 457 | neighbor_texts = self.tokenizer(neighbor_texts, max_length=self.max_input_length, padding="max_length", truncation=True, return_tensors="pt") 458 | result["neighbor_input_ids"] = neighbor_texts.input_ids 459 | result["neighbor_attention_mask"] = neighbor_texts.attention_mask 460 | result["neighbor_pos_ids"] = torch.LongTensor(position_texts) 461 | result["text_locations"] = torch.LongTensor(location_texts) 462 | result["neighbor_images"] = torch.stack(neighbor_images, dim=0) 463 | result["neighbor_images_pos_ids"] = torch.LongTensor(position_images) 464 | result["image_locations"] = torch.LongTensor(location_images) 465 | if self.position_type == 'laplacian': 466 | result["lpe"] = lpe 467 | if self.position_type == 'gnn': 468 | result["graph"] = graph 469 | return result 470 | 471 | 472 | -------------------------------------------------------------------------------- /wikiweb2m/preprocess_data.py: -------------------------------------------------------------------------------- 1 | import json 2 | import time 3 | import glob 4 | import pickle 5 | import numpy as np 6 | import pandas as pd 7 | import pyarrow.parquet as pq 8 | import pyarrow as pa 9 | from PIL import Image 10 | 11 | import os 12 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 13 | import tensorflow.compat.v1 as tf 14 | from collections import defaultdict 15 | 16 | import requests 17 | 
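# Pipeline overview (see the __main__ block at the bottom of this file):
#   1. DataParser() parses the gzipped WikiWeb2M TFRecord shards under ./wikiweb2m/raw/.
#   2. split_ids('section') collects (page_id, section_id) pairs for the train/val/test
#      splits and saves them to section_id_split_large.pkl.
#   3. save_df_torch() converts the first 600k pages into pandas DataFrames and writes
#      wikiweb2m_{train,val,test}_large.parquet.
#   4. download_images() downloads section images from their URLs into ./wikiweb2m/raw/images/.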
18 | 19 | def sparse_tf_to_numpy(tf_sparse_tensor): 20 | """ 21 | Converts a tf.SparseTensor to a list of numpy object. 22 | Args: 23 | tf_sparse_tensor: A tf.SparseTensor object. 24 | Returns: 25 | A list of numpy object. 26 | """ 27 | array = tf.sparse.to_dense(tf_sparse_tensor).numpy() 28 | if len(array.shape) == 2: 29 | array = array.reshape(-1) 30 | return array.tolist() 31 | 32 | def convert_to_numpy(page_id, d): 33 | """ 34 | Converts a tf.Tensor to a list of numpy object. 35 | Args: 36 | page_id: A page id. 37 | d: A tf.Tensor object. 38 | Returns: 39 | A list of numpy object. 40 | """ 41 | page_url = d[0]['page_url'].numpy() 42 | page_title = d[0]['page_title'].numpy() 43 | page_description = d[0]['clean_page_description'].numpy() 44 | section_title = sparse_tf_to_numpy(d[1]['section_title']) 45 | section_depth = sparse_tf_to_numpy(d[1]['section_depth']) 46 | section_heading = sparse_tf_to_numpy(d[1]['section_heading_level']) 47 | section_parent_index = sparse_tf_to_numpy(d[1]['section_parent_index']) 48 | section_summary = sparse_tf_to_numpy(d[1]['section_clean_1st_sentence']) 49 | section_rest_sentence = sparse_tf_to_numpy(d[1]['section_rest_sentence']) 50 | image_url = sparse_tf_to_numpy(d[1]['section_image_url']) 51 | image_caption = sparse_tf_to_numpy(d[1]['section_image_captions']) 52 | 53 | return [page_id, page_url, page_title, page_description, section_title, section_depth, section_heading, \ 54 | section_parent_index, section_summary, section_rest_sentence, image_url, image_caption] 55 | 56 | class DataParser(): 57 | """ 58 | Parses the tfrecord files and saves the data as parquet files. 59 | Follow the WikiWeb2M dataset format (https://github.com/google-research-datasets/wit/blob/main/wikiweb2m.md). 60 | """ 61 | def __init__(self): 62 | self.path = './wikiweb2m/raw/' 63 | self.filepath = 'wikiweb2m-*' 64 | self.suffix = '.tfrecord*' 65 | self.parse_data() 66 | 67 | def parse_data(self): 68 | context_feature_description = { 69 | 'split': tf.io.FixedLenFeature([], dtype=tf.string), 70 | 'page_title': tf.io.FixedLenFeature([], dtype=tf.string), 71 | 'page_url': tf.io.FixedLenFeature([], dtype=tf.string), 72 | 'clean_page_description': tf.io.FixedLenFeature([], dtype=tf.string), 73 | 'raw_page_description': tf.io.FixedLenFeature([], dtype=tf.string), 74 | 'is_page_description_sample': tf.io.FixedLenFeature([], dtype=tf.int64), 75 | 'page_contains_images': tf.io.FixedLenFeature([], dtype=tf.int64), 76 | 'page_content_sections_without_table_list': tf.io.FixedLenFeature([] , dtype=tf.int64) 77 | } 78 | 79 | sequence_feature_description = { 80 | 'is_section_summarization_sample': tf.io.VarLenFeature(dtype=tf.int64), 81 | 'section_title': tf.io.VarLenFeature(dtype=tf.string), 82 | 'section_index': tf.io.VarLenFeature(dtype=tf.int64), 83 | 'section_depth': tf.io.VarLenFeature(dtype=tf.int64), 84 | 'section_heading_level': tf.io.VarLenFeature(dtype=tf.int64), 85 | 'section_subsection_index': tf.io.VarLenFeature(dtype=tf.int64), 86 | 'section_parent_index': tf.io.VarLenFeature(dtype=tf.int64), 87 | 'section_text': tf.io.VarLenFeature(dtype=tf.string), 88 | 'section_clean_1st_sentence': tf.io.VarLenFeature(dtype=tf.string), 89 | 'section_raw_1st_sentence': tf.io.VarLenFeature(dtype=tf.string), 90 | 'section_rest_sentence': tf.io.VarLenFeature(dtype=tf.string), 91 | 'is_image_caption_sample': tf.io.VarLenFeature(dtype=tf.int64), 92 | 'section_image_url': tf.io.VarLenFeature(dtype=tf.string), 93 | 'section_image_mime_type': tf.io.VarLenFeature(dtype=tf.string), 94 | 
'section_image_width': tf.io.VarLenFeature(dtype=tf.int64), 95 | 'section_image_height': tf.io.VarLenFeature(dtype=tf.int64), 96 | 'section_image_in_wit': tf.io.VarLenFeature(dtype=tf.int64), 97 | 'section_contains_table_or_list': tf.io.VarLenFeature(dtype=tf.int64), 98 | 'section_image_captions': tf.io.VarLenFeature(dtype=tf.string), 99 | 'section_image_alt_text': tf.io.VarLenFeature(dtype=tf.string), 100 | 'section_image_raw_attr_desc': tf.io.VarLenFeature(dtype=tf.string), 101 | 'section_image_clean_attr_desc': tf.io.VarLenFeature(dtype=tf.string), 102 | 'section_image_raw_ref_desc': tf.io.VarLenFeature(dtype=tf.string), 103 | 'section_image_clean_ref_desc': tf.io.VarLenFeature(dtype=tf.string), 104 | 'section_contains_images': tf.io.VarLenFeature(dtype=tf.int64) 105 | } 106 | 107 | def _parse_function(example_proto): 108 | return tf.io.parse_single_sequence_example(example_proto, 109 | context_feature_description, 110 | sequence_feature_description) 111 | 112 | data_path = glob.glob(self.path + self.filepath + self.suffix) 113 | raw_dataset = tf.data.TFRecordDataset(data_path, compression_type='GZIP') 114 | self.dataset = raw_dataset.map(_parse_function) 115 | 116 | def save_df_torch(self): 117 | # save as parquet files 118 | 119 | # columns describing each section 120 | columns = ['page_id', 'page_url', 'page_title', 'page_description', 'section_title', 'section_depth', 'section_heading', \ 121 | 'section_parent_index', 'section_summary', 'section_rest_sentence', 'image_url', 'image_caption'] 122 | 123 | train_df = pd.DataFrame(columns=columns) 124 | val_df = pd.DataFrame(columns=columns) 125 | test_df = pd.DataFrame(columns=columns) 126 | 127 | for page_id, d in enumerate(self.dataset): 128 | if page_id % 100000 == 0: 129 | print(page_id, 'have processed...') 130 | # we sample first 600k pages 131 | if page_id == 600000: 132 | break 133 | split = d[0]['split'].numpy().decode() 134 | # we sample first 400k pages for training, next 100k for validation, and next 100k for testing 135 | if page_id < 400000: 136 | train_df.loc[len(train_df)] = convert_to_numpy(page_id, d) 137 | elif page_id < 500000: 138 | val_df.loc[len(val_df)] = convert_to_numpy(page_id, d) 139 | else: 140 | test_df.loc[len(test_df)] = convert_to_numpy(page_id, d) 141 | 142 | print(f'train_num: ', len(train_df), ', val_num: ', len(val_df), ', test_num: ', len(test_df)) 143 | train_df.to_parquet(f'{self.path}/wikiweb2m_train_large.parquet') 144 | val_df.to_parquet(f'{self.path}/wikiweb2m_val_large.parquet') 145 | test_df.to_parquet(f'{self.path}/wikiweb2m_test_large.parquet') 146 | 147 | def split_ids(self, task): 148 | # split page ids into training/validation/test sets and save as pickle files 149 | id_list = defaultdict(list) 150 | for page_id, d in enumerate(self.dataset): 151 | if page_id % 100000 == 0: 152 | print(page_id, 'have processed...') 153 | # we sample first 600k pages 154 | if page_id == 600000: 155 | break 156 | # we sample first 400k pages for training, next 100k for validation, and next 100k for testing 157 | if page_id < 400000: 158 | split = "train" 159 | elif page_id < 500000: 160 | split = "val" 161 | else: 162 | split = "test" 163 | 164 | # when task is page summarization 165 | if task == 'page': 166 | is_sample = d[0]['is_page_description_sample'].numpy() 167 | if is_sample == 0: 168 | continue 169 | id_list[split].append(page_id) 170 | # when task is section summarization 171 | elif task == 'section': 172 | are_samples = d[1]['is_section_summarization_sample'].values.numpy() 173 | for section_id 
in range(are_samples.shape[0]): 174 | is_sample = are_samples[section_id] 175 | if is_sample == 0: 176 | continue 177 | id_list[split].append((page_id, section_id)) 178 | 179 | print(f'task: {task}, train_num: ', len(id_list['train']), ', val_num: ', len(id_list['val']), ', test_num: ', len(id_list['test'])) 180 | with open(f'{self.path}/{task}_id_split_large.pkl', 'wb') as file: 181 | pickle.dump(id_list, file) 182 | 183 | def download_images(self): 184 | # download images from image urls 185 | 186 | headers = {"User-Agent": "research (https://www.cs.cmu.edu/; minjiy@cs.cmu.edu)"} 187 | 188 | for page_id, d in enumerate(self.dataset): 189 | # we sample first 600k pages 190 | if page_id == 600000: 191 | break 192 | if page_id % 1000 == 0: 193 | print(page_id, 'have processed...') 194 | image_urls = tf.sparse.to_dense(d[1]['section_image_url']).numpy() 195 | for section_id in range(image_urls.shape[0]): 196 | for image_id in range(image_urls[section_id].shape[0]): 197 | image_url = image_urls[section_id][image_id] 198 | if image_url == b'': 199 | continue 200 | image_url = image_url.decode() 201 | file_format = os.path.splitext(image_url)[1][1:] 202 | file_name = f'{self.path}/images/{page_id}_{section_id}_{image_id}.{file_format}' 203 | if os.path.exists(file_name): 204 | break 205 | 206 | another_image = False 207 | try: 208 | response = requests.get(image_url, headers=headers) 209 | response.raise_for_status() 210 | except requests.exceptions.HTTPError as e: 211 | if "404 Client Error: Not Found for url" in str(e): 212 | # corresponding image does not exist 213 | another_image = True 214 | continue 215 | else: 216 | # Wikimedia server is busy; try again after 1 second 217 | time.sleep(1) 218 | response = requests.get(image_url) 219 | 220 | with open(file_name, 'wb') as file: 221 | for chunk in response.iter_content(8192): 222 | file.write(chunk) 223 | # check if the downloaded file is a right format 224 | try: 225 | img = Image.open(file_name) 226 | except: 227 | if os.path.exists(file_name): 228 | os.remove(file_name) 229 | another_image = True 230 | continue 231 | # if another_image == True, we try to download another image in the same section 232 | if another_image == False: 233 | break 234 | 235 | 236 | if __name__ == "__main__": 237 | parser = DataParser() 238 | # split (page ids, section_ids) into training/validation/test sets and save as pickle files 239 | parser.split_ids('section') 240 | # save WikiWeb2M data as parquet files 241 | parser.save_df_torch() 242 | # download images from image urls 243 | parser.download_images() 244 | --------------------------------------------------------------------------------
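
To tie the pieces above together, the snippet below is a minimal, illustrative sketch of loading the preprocessed WikiWeb2M data and iterating over the section-summarization dataset with raw-text context only (no images, no neighbor embeddings). It assumes `preprocess_data.py` has already produced the parquet and pickle files under `wikiweb2m/raw/`; the argument values, the `t5-base` tokenizer checkpoint, and the batch size are placeholder assumptions rather than the repository's actual training configuration.

```
# Illustrative sketch only; argument values and the tokenizer checkpoint are assumptions.
from types import SimpleNamespace

from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from wikiweb2m.data import load_wikiweb2m, WikiWeb2M

args = SimpleNamespace(
    task='section',                 # section summarization
    context='section_only',         # only the target section's text
    neighbor_mode='raw',            # neighbor text concatenated as raw text
    decoder_only=False,             # e.g., an encoder-decoder model such as T5
    max_input_length=512, max_output_length=128,
    max_text_neighbors=11, max_image_neighbors=5,   # placeholder values
    n_text_tokens=4, n_visual_tokens=4,             # placeholder values
    position_type='none',
)

# load_wikiweb2m returns the three dataframes plus the pickled split dictionary
# produced by DataParser.split_ids (keys: 'train', 'val', 'test').
train_df, val_df, test_df, id_list = load_wikiweb2m(args.task)
tokenizer = AutoTokenizer.from_pretrained('t5-base')    # assumed checkpoint

train_set = WikiWeb2M(args, train_df, id_list['train'], tokenizer)
loader = DataLoader(train_set, batch_size=2, shuffle=True)

batch = next(iter(loader))
print(batch['input_ids'].shape, batch['attention_mask'].shape, batch['labels'].shape)
```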