├── README.md ├── args.py ├── asset ├── cvqa.png └── ovqa.png ├── dataset.py ├── model ├── __init__.py ├── adapter.py ├── deberta.py └── gnn.py ├── setup.sh ├── train.py └── util ├── __init__.py ├── dist.py ├── metrics.py └── misc.py /README.md: -------------------------------------------------------------------------------- 1 | # Open-Vocabulary Video Question Answering: A New Benchmark for Evaluating the Generalizability of Video Question Answering Models 2 | 3 | This is the official implementation of OVQA (ICCV 2023). ([arxiv](https://arxiv.org/abs/2308.09363)) 4 | 5 | > Dohwan Ko, Ji Soo Lee, Miso Choi, Jaewon Chu, Jihwan Park, Hyunwoo J. Kim. 6 | > 7 | > Department of Computer Science and Engineering, Korea University 8 | 9 |
10 | 11 | 12 |
13 | 14 |  (a) Closed-vocabulary Video Question Answering       (b) Open-vocabulary Video Question Answering (Ours) 15 | 16 | 17 | 18 | ## Setup 19 | To install the requirements, run: 20 | ``` 21 | conda create -n ovqa python=3.8 22 | conda activate ovqa 23 | sh setup.sh 24 | ``` 25 | ## Data Preparation 26 | ### Download preprocessed data and visual features 27 | The pretrained checkpoint, preprocessed data, and data annotations are provided [here](https://drive.google.com/drive/folders/1HcrMsINkNcRUfnZMd15l9AH0R9X4qvQG). You can download the pretrained DeBERTa-v2-xlarge model [here](https://huggingface.co/microsoft/deberta-v2-xlarge). 28 | 29 | Then, place the files as follows: 30 | 31 | ``` 32 | ./pretrained 33 | |─ pretrained.pth 34 | └─ deberta-v2-xlarge 35 | ./meta_data 36 | |─ activitynet 37 | │ |─ train.csv 38 | │ |─ test.csv 39 | │ |─ train_vocab.json 40 | │ |─ test_vocab.json 41 | │ |─ clipvitl14.pth 42 | │ |─ subtitles.pkl 43 | │ |─ ans2cat.json 44 | │ └─ answer_graph 45 | │ |─ train_edge_index.pth 46 | │ |─ train_x.pth 47 | │ |─ test_edge_index.pth 48 | │ └─ test_x.pth 49 | │ 50 | |─ msvd 51 | │ |─ train.csv 52 | │ |─ test.csv 53 | │ |─ train_vocab.json 54 | │ |─ test_vocab.json 55 | │ |─ clipvitl14.pth 56 | │ |─ subtitles.pkl 57 | │ |─ ans2cat.json 58 | │ └─ answer_graph 59 | │ |─ train_edge_index.pth 60 | │ |─ train_x.pth 61 | │ |─ test_edge_index.pth 62 | │ └─ test_x.pth 63 | │ : 64 | ``` 65 | 66 | 67 | ## Train FrozenBiLM+ on OVQA 68 | 69 | To train on ActivityNet-QA, MSVD-QA, TGIF-QA, or MSRVTT-QA, run the command below. Change `--dataset activitynet` to train on a different dataset. 70 | 71 | ``` 72 | python -m torch.distributed.launch --nproc_per_node 4 --use_env train.py --dist-url tcp://127.0.0.1:12345 \ 73 | --dataset activitynet --lr 5e-5 --batch_size 8 --batch_size_test 32 --save_dir ./path/to/save/files --epochs 20 --eps 0.7 74 | ``` 75 | 76 | 77 | ## Acknowledgements 78 | This repo is built upon [FrozenBiLM](https://github.com/antoyang/FrozenBiLM).
#
#
#
# ## Citation
# ```
# @inproceedings{ko2023open,
# title={Open-vocabulary Video Question Answering: A New Benchmark for Evaluating the Generalizability of Video Question Answering Models},
# author={Ko, Dohwan and Lee, Ji Soo and Choi, Miso and Chu, Jaewon and Park, Jihwan and Kim, Hyunwoo J},
# booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
# year={2023}
# }
# ```
#
# --------------------------------------------------------------------------------
# /args.py:
# --------------------------------------------------------------------------------
import argparse
import os

PRESAVE_DIR = ""
MODEL_DIR = ""
DATA_DIR = ""
SSD_DIR = ""
name2folder = {
    "msrvtt": "./meta_data/msrvtt",
    "msvd": "./meta_data/msvd",
    "activitynet": "./meta_data/activitynet",
    "tgif": "./meta_data/tgif",
}


def get_args_parser():
    """Build the OVQA argument parser.

    The parser is created with ``add_help=False`` so it can be used as a
    parent parser by the entry-point script.

    Returns:
        argparse.ArgumentParser: parser holding the dataset, training,
        model, run, distributed and video/text options.
    """
    parser = argparse.ArgumentParser("Set OVQA", add_help=False)

    # Dataset specific
    parser.add_argument("--dataset", help="dataset", type=str)

    # Training hyper-parameters
    parser.add_argument("--mlm_prob", type=float, default=0.15, help="masking probability for the MLM objective")
    parser.add_argument("--lr", default=3e-4, type=float, help="learning rate")
    parser.add_argument("--beta1", default=0.9, type=float, help="Adam optimizer parameter")
    parser.add_argument("--beta2", default=0.95, type=float, help="Adam optimizer parameter")
    parser.add_argument("--batch_size", default=32, type=int, help="batch size used for training")
    parser.add_argument("--batch_size_test", default=32, type=int, help="batch size used for test")
    parser.add_argument("--weight_decay", default=0, type=float)
    parser.add_argument("--epochs", default=10, type=int, help="number of training epochs")
    parser.add_argument("--lr_drop", default=10, type=int, help="number of epochs after which the learning rate is reduced when not using linear decay")
    parser.add_argument("--optimizer", default="adam", type=str)
    parser.add_argument("--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm")
    parser.add_argument("--schedule", default="linear_with_warmup", choices=["", "linear_with_warmup"], help="learning rate decay schedule, default is constant")
    parser.add_argument("--fraction_warmup_steps", default=0.1, type=float, help="fraction of number of steps used for warmup when using linear schedule")
    parser.add_argument("--eval_skip", default=1, type=int, help='do evaluation every "eval_skip" epochs')
    parser.add_argument("--print_freq", type=int, default=100, help="print log every print_freq iterations")

    # Model parameters
    parser.add_argument("--ft_lm", dest="freeze_lm", action="store_false", help="whether to finetune the weights of the language model")
    # ``choices`` must be a real container of allowed values. A bare string
    # ("deberta-v2-xlarge") made argparse test substring membership (so e.g.
    # "deberta" was accepted) and list the choices one character at a time.
    parser.add_argument("--model_name", default="deberta-v2-xlarge", choices=("deberta-v2-xlarge",))
    parser.add_argument("--ds_factor_attn", type=int, default=8, help="downsampling factor for adapter attn")
    parser.add_argument("--ds_factor_ff", type=int, default=8, help="downsampling factor for adapter ff")
    parser.add_argument("--freeze_ln", dest="ft_ln", action="store_false", help="whether or not to freeze layer norm parameters")
    parser.add_argument("--ft_mlm", dest="freeze_mlm", action="store_false", help="whether or not to finetune the mlm head parameters")
    parser.add_argument("--dropout", default=0.1, type=float, help="dropout to use in the adapter")
    parser.add_argument("--scratch", action="store_true", help="whether to train the LM with or without language init")
    parser.add_argument("--n_ans", type=int, default=0, help="number of answers in the answer embedding module, it is automatically set")
    parser.add_argument("--ft_last", dest="freeze_last", action="store_false", help="whether to finetune answer embedding module or not")
    # --eps becomes the coefficient of every non-"base" answer category in
    # VideoQA_Dataset.load_answer_graph ("base" answers always get 1.0).
    parser.add_argument("--eps", default=1.0, type=float, help="strength")

    # Run specific
    parser.add_argument("--test", action="store_true", help="whether to run evaluation on val or test set")
    parser.add_argument("--save_dir", default="", help="path where to save, empty for no saving")
    parser.add_argument("--presave_dir", default=PRESAVE_DIR, help="the actual save_dir is an union of presave_dir and save_dir")
    parser.add_argument("--device", default="cuda", help="device to use")
    parser.add_argument("--seed", default=42, type=int, help="random seed")
    parser.add_argument("--load", default="./pretrained/pretrained.pth", help="path to load checkpoint")
    parser.add_argument("--resume", action="store_true", help="continue training if loading checkpoint")
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--eval", action="store_true", help="only run evaluation")
    parser.add_argument("--num_workers", default=3, type=int, help="number of workers for dataloader")

    # Distributed training parameters
    parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")

    # Video and Text parameters
    parser.add_argument("--max_feats", type=int, default=10, help="maximum number of video features considered, one per frame")
    parser.add_argument("--features_dim", type=int, default=768, help="dimension of the visual embedding space")
    parser.add_argument("--no_video", dest="use_video", action="store_false", help="disables usage of video")
    parser.add_argument("--no_context", dest="use_context", action="store_false", help="disables usage of speech")
    parser.add_argument("--max_tokens", type=int, default=256, help="maximum number of tokens in the input text prompt")
    parser.add_argument("--max_atokens", type=int, default=5, help="maximum number of tokens in the answer")
    parser.add_argument("--prefix", default="", type=str, help="task induction before question for videoqa")
    parser.add_argument("--suffix", default=".", type=str, help="suffix after the answer mask for videoqa")

    return parser

# --------------------------------------------------------------------------------
# /asset/cvqa.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/mlvlab/OVQA/8f9fb810c11f348eb3ef6bdd4cc0d2c313299d5c/asset/cvqa.png
# --------------------------------------------------------------------------------
# /asset/ovqa.png:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/mlvlab/OVQA/8f9fb810c11f348eb3ef6bdd4cc0d2c313299d5c/asset/ovqa.png
# --------------------------------------------------------------------------------
# /dataset.py:
# --------------------------------------------------------------------------------
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import default_collate
import pandas as pd
import collections
import json
import pickle


class VideoQA_Dataset(Dataset):
    """Video question-answering samples for one split of ``args.dataset``.

    All inputs are read from ``./meta_data/<args.dataset>/``:
    ``<split>.csv`` (question table), ``clipvitl14.pth`` (video features),
    ``<split>_vocab.json`` (answer -> id), ``ans2cat.json`` (answer ->
    category) and, when context is enabled and the dataset is not ``tgif``,
    ``subtitles.pkl``. The answer graph is loaded from ``answer_graph/``.
    """

    def __init__(self, args, tokenizer, split):
        path = f'./meta_data/{args.dataset}/'
        self.data = pd.read_csv(path + f'{split}.csv')
        # presumably a mapping video_id -> per-frame feature tensor (see _get_video); TODO confirm
        self.features = torch.load(path + 'clipvitl14.pth')
        # Context managers close the JSON files (json.load(open(...)) leaked them).
        with open(path + f'{split}_vocab.json') as f:
            self.ans2id = json.load(f)
        with open(path + 'ans2cat.json') as f:
            self.ans2cat = json.load(f)
        self.max_feats = args.max_feats
        self.features_dim = args.features_dim
        self.split = split
        self.prefix = args.prefix
        self.suffix = args.suffix
        self.mask = tokenizer.mask_token
        self.pad = tokenizer.pad_token
        self.tokenizer = tokenizer
        # Subtitle context is never used for the tgif dataset.
        self.use_context = (args.use_context and args.dataset != "tgif")
        self.subs = None
        if self.use_context:
            # NOTE: pickle executes arbitrary code; only load trusted local data files.
            with open(path + 'subtitles.pkl', "rb") as f:
                self.subs = pickle.load(f)
        self.args = args
        self.load_answer_graph()

    def load_answer_graph(self):
        """Load this split's answer graph and the per-answer ``eps`` coefficients.

        Sets ``self.edge_index`` and ``self.vocab_embeddings`` from
        ``answer_graph/<split>_edge_index.pth`` and ``answer_graph/<split>_x.pth``,
        and ``self.eps``: one coefficient per answer, in ``ans2id`` order,
        equal to 1.0 for "base" answers and ``args.eps`` for "common", "rare"
        and "unseen" answers.
        """
        graph_dir = f'./meta_data/{self.args.dataset}/answer_graph/'
        self.edge_index = torch.load(graph_dir + f'{self.split}_edge_index.pth')
        self.vocab_embeddings = torch.load(graph_dir + f'{self.split}_x.pth')
        cat2coef = {'base': 1.0, 'common': self.args.eps, 'rare': self.args.eps, 'unseen': self.args.eps}
        self.eps = torch.tensor([cat2coef[self.ans2cat[answer]] for answer in self.ans2id])

    def __len__(self):
        return len(self.data)

    def _get_text(self, question, mask, sub):
        """Build the text prompt: optional prefix, question, answer mask, optional subtitles."""
        text = f"{self.prefix} Question: {question} Answer: {mask}{self.suffix}"
        if sub:
            text += f" Subtitles: {sub}"
        text = text.strip()
        return text

    def _get_video(self, video_id):
        """Return ``(video, video_len)``: exactly ``max_feats`` feature rows and the count of real ones.

        Longer videos are uniformly subsampled; shorter ones are zero-padded
        along dim 0 (this assumes stored features are ``(frames, features_dim)``;
        TODO confirm). An id missing from the feature file is printed and
        replaced by a single zero frame.
        """
        if video_id not in self.features:
            print(video_id)
            video = torch.zeros(1, self.features_dim)
        else:
            video = self.features[video_id].float()
        if len(video) > self.max_feats:
            num_frames = len(video)
            picked = [(j * num_frames) // self.max_feats for j in range(self.max_feats)]
            video = torch.stack([video[i] for i in picked])
            video_len = self.max_feats
        elif len(video) < self.max_feats:
            video_len = len(video)
            video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0)
        else:
            video_len = self.max_feats

        return video, video_len

    def __getitem__(self, idx):
        """Return the prompt, padded video features and answer metadata of sample ``idx``."""
        # get question
        question = self.data["question"].values[idx].capitalize().strip()
        # endswith() also handles an empty question, where question[-1] raised IndexError.
        if not question.endswith("?"):
            question = str(question) + "?"
        qtype = 0
        if "type" in self.data:
            qtype = self.data["type"].values[idx]

        original_answer = self.data["answer"].values[idx]
        answer_id = self.ans2id[original_answer]
        video_id = self.data["video_id"].values[idx]

        # get subtitles
        sub = ""
        if self.subs is not None and video_id in self.subs:
            sub = self.subs[video_id]
        sub_bool = bool(sub)
        if not self.use_context:
            sub = ""

        # get pattern
        text = self._get_text(question, self.mask, sub)

        # get video
        video, video_len = self._get_video(video_id)

        return {"video": video, "video_len": video_len, "text": text, "qid": idx, "answer_id": answer_id,
                "type": qtype, "sub": sub_bool, "original_answer": original_answer}


def videoqa_collate_fn(batch):
    """Collate ``VideoQA_Dataset`` samples into one batch dict.

    ``video`` is stacked and ``video_len`` becomes a long tensor. ``text`` is
    a list of strings, or, when each sample holds a sequence of texts, a list
    with one inner list per text position. ``answer_id`` goes through
    ``default_collate``; the remaining fields stay plain lists.
    """
    bs = len(batch)
    video = torch.stack([batch[i]["video"] for i in range(bs)])
    video_len = torch.tensor([batch[i]["video_len"] for i in range(bs)], dtype=torch.long)
    if isinstance(batch[0]["text"], str):
        text = [batch[i]["text"] for i in range(bs)]
    else:
        text = [[batch[i]["text"][j] for i in range(bs)] for j in range(len(batch[0]["text"]))]
    qid = [batch[i]["qid"] for i in range(bs)]
    answer_id = default_collate([batch[i]["answer_id"] for i in range(bs)])
    qtype = [batch[i]["type"] for i in range(bs)]
    original_answer = [batch[i]["original_answer"] for i in range(bs)]
    out = {"video": video, "video_len": video_len, "text": text, "qid": qid, "answer_id": answer_id, "type": qtype, "original_answer": original_answer}
    if "sub" in batch[0]:
        sub = [batch[i]["sub"] for i in range(bs)]
        out["sub"] = sub
    return out

# --------------------------------------------------------------------------------
# /model/__init__.py:
# --------------------------------------------------------------------------------
from .deberta import DebertaV2ForMaskedLM
# --------------------------------------------------------------------------------
# /model/adapter.py:
# --------------------------------------------------------------------------------
import torch.nn as nn
import torch


class Adapter(nn.Module):
    """Residual bottleneck adapter computing ``x + up(act(down(x)))``.

    ``down`` projects the last dimension from ``hidden_dim`` to
    ``hidden_dim // ds_factor``. A ReLU and optional dropout follow, then
    ``up`` projects back to ``hidden_dim``, and the result is added to the input.

    Args:
        ds_factor: bottleneck down-sampling factor; must divide ``hidden_dim``.
        hidden_dim: size of the last dimension of the input.
        ln_after: apply LayerNorm to the bottleneck output before the residual add.
        ln_before: apply LayerNorm to the input before the down-projection.
        dropout: dropout probability inside the bottleneck; a falsy value disables it.
    """

    def __init__(
        self, ds_factor, hidden_dim, ln_after=False, ln_before=False, dropout=0.1
    ):
        super().__init__()
        assert not hidden_dim % ds_factor
        self.down = nn.Linear(hidden_dim, hidden_dim // ds_factor)
        self.act = nn.ReLU()
        self.up = nn.Linear(hidden_dim // ds_factor, hidden_dim)
        self.ln_after = ln_after
        self.ln_before = ln_before
        # Keeps the raw (falsy) value when dropout is disabled; forward() tests its truthiness.
        self.dropout = dropout
        if ln_after or ln_before:
            # One LayerNorm instance serves both the "before" and "after" positions.
            self.ln = nn.LayerNorm(hidden_dim)
        if dropout:
            self.dropout = nn.Dropout(dropout)
        # Initialise once every submodule exists, so init_weights also visits ``ln``.
        self.apply(self.init_weights)

    def init_weights(self, m: nn.Module, std=1e-3):
        """Near-zero init for Linear layers (normal(std), clamped to +-2*std); identity for LayerNorm."""
        if isinstance(m, nn.Linear):
            torch.nn.init.normal_(m.weight, std=std)
            torch.nn.init.normal_(m.bias, std=std)
            m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
            m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
        elif isinstance(m, nn.LayerNorm):
            m.bias.data.zero_()
            m.weight.data.fill_(1.0)

    def forward(self, hidden_states):
        """Return ``hidden_states`` plus the adapter branch computed from it."""
        if self.ln_before:
            residual = self.ln(hidden_states)
            residual = self.down(residual)
        else:
            residual = self.down(hidden_states)
        residual = self.act(residual)
        if self.dropout:
            residual = self.dropout(residual)
        residual = self.up(residual)
        if self.ln_after:
            # Normalise the adapter branch itself. The previous code normalised
            # ``hidden_states`` here, discarding the whole bottleneck output.
            residual = self.ln(residual)
        return hidden_states + residual

# --------------------------------------------------------------------------------
# /model/deberta.py:
# --------------------------------------------------------------------------------
# coding=utf-8
# Copyright 2020 Microsoft and the Hugging Face Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch DeBERTa-v2 model. """

import math
import json
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Tuple, Optional

import numpy as np
import torch
from torch import _softmax_backward_data, nn, optim
from torch.nn import CrossEntropyLoss, LayerNorm
import torch.nn.functional as F

from model.adapter import Adapter
from model.gnn import GNNSoftVerbalizer
from transformers.activations import ACT2FN

from transformers.modeling_outputs import (
    # BaseModelOutput,
    ModelOutput,
    MaskedLMOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)

from transformers.modeling_utils import PreTrainedModel
from transformers import DebertaV2Config



_CONFIG_FOR_DOC = "DebertaV2Config"
_TOKENIZER_FOR_DOC = "DebertaV2Tokenizer"
_CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"

DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "microsoft/deberta-v2-xlarge",
    "microsoft/deberta-v2-xxlarge",
    "microsoft/deberta-v2-xlarge-mnli",
    "microsoft/deberta-v2-xxlarge-mnli",
]


# transformers requires every ModelOutput subclass to be a dataclass (its
# __post_init__ fills the dict from the fields; recent versions raise TypeError
# in __init__ otherwise).
@dataclass
class BaseModelOutput(ModelOutput):
    """
    Base class for model's outputs, with potential hidden states and attentions.
    Args:
        last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
            shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        position_embeddings (`torch.FloatTensor`, *optional*):
            Extra field added in this project; contents are set by the model code (not visible here).
        attention_mask (`torch.BoolTensor`, *optional*):
            Extra field added in this project; contents are set by the model code (not visible here).
    """

    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    position_embeddings: torch.FloatTensor = None
    attention_mask: torch.BoolTensor = None


# Adapted from transformers.models.deberta.modeling_deberta.ContextPooler
# (output_dim reports the width ``dense`` actually produces)
class ContextPooler(nn.Module):
    """Pool a sequence by its first token: dropout, dense projection, then ``config.pooler_hidden_act``."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
        self.dropout = StableDropout(config.pooler_dropout)
        self.config = config

    def forward(self, hidden_states):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.

        context_token = hidden_states[:, 0]
        context_token = self.dropout(context_token)
        pooled_output = self.dense(context_token)
        pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
        return pooled_output

    @property
    def output_dim(self):
        # ``dense`` maps to pooler_hidden_size; hidden_size was only correct when the two coincide.
        return self.config.pooler_hidden_size


# Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
# (backward rewritten without the private torch._softmax_backward_data, whose
# 4th argument changed from a tensor to a dtype in PyTorch 1.11)
class XSoftmax(torch.autograd.Function):
    """
    Masked Softmax which is optimized for saving memory

    Args:
        input (:obj:`torch.tensor`): The input tensor that will apply softmax.
        mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
        dim (int): The dimension that will apply softmax

    Example::

        import torch
        from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax

        # Make a tensor
        x = torch.randn([4,20,100])

        # Create a mask
        mask = (x>0).int()

        y = XSoftmax.apply(x, mask, dim=-1)
    """

    @staticmethod
    def forward(ctx, input, mask, dim):
        ctx.dim = dim
        rmask = ~(mask.bool())

        output = input.masked_fill(rmask, float("-inf"))
        output = torch.softmax(output, ctx.dim)
        # Masked positions are forced to exactly 0 (this also clears the NaNs of fully-masked rows).
        output.masked_fill_(rmask, 0)
        ctx.save_for_backward(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (output,) = ctx.saved_tensors
        # Softmax vector-Jacobian product: y * (g - sum(g * y)) along ``dim``.
        # Masked positions have y == 0, so they receive zero gradient.
        input_grad = output * (grad_output - (grad_output * output).sum(ctx.dim, keepdim=True))
        return input_grad, None, None


# Copied from transformers.models.deberta.modeling_deberta.DropoutContext
class DropoutContext(object):
    """Mutable dropout state (rate, cached mask, scale, mask-reuse flag) shared with ``get_mask``."""

    def __init__(self):
        self.dropout = 0
        self.mask = None
        self.scale = 1
        self.reuse_mask = True


# Copied from
transformers.models.deberta.modeling_deberta.get_mask 157 | def get_mask(input, local_context): 158 | if not isinstance(local_context, DropoutContext): 159 | dropout = local_context 160 | mask = None 161 | else: 162 | dropout = local_context.dropout 163 | dropout *= local_context.scale 164 | mask = local_context.mask if local_context.reuse_mask else None 165 | 166 | if dropout > 0 and mask is None: 167 | mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool() 168 | 169 | if isinstance(local_context, DropoutContext): 170 | if local_context.mask is None: 171 | local_context.mask = mask 172 | 173 | return mask, dropout 174 | 175 | 176 | # Copied from transformers.models.deberta.modeling_deberta.XDropout 177 | class XDropout(torch.autograd.Function): 178 | """Optimized dropout function to save computation and memory by using mask operation instead of multiplication.""" 179 | 180 | @staticmethod 181 | def forward(ctx, input, local_ctx): 182 | mask, dropout = get_mask(input, local_ctx) 183 | ctx.scale = 1.0 / (1 - dropout) 184 | if dropout > 0: 185 | ctx.save_for_backward(mask) 186 | return input.masked_fill(mask, 0) * ctx.scale 187 | else: 188 | return input 189 | 190 | @staticmethod 191 | def backward(ctx, grad_output): 192 | if ctx.scale > 1: 193 | (mask,) = ctx.saved_tensors 194 | return grad_output.masked_fill(mask, 0) * ctx.scale, None 195 | else: 196 | return grad_output, None 197 | 198 | 199 | # Copied from transformers.models.deberta.modeling_deberta.StableDropout 200 | class StableDropout(nn.Module): 201 | """ 202 | Optimized dropout module for stabilizing the training 203 | 204 | Args: 205 | drop_prob (float): the dropout probabilities 206 | """ 207 | 208 | def __init__(self, drop_prob): 209 | super().__init__() 210 | self.drop_prob = drop_prob 211 | self.count = 0 212 | self.context_stack = None 213 | 214 | def forward(self, x): 215 | """ 216 | Call the module 217 | 218 | Args: 219 | x (:obj:`torch.tensor`): The input tensor to apply dropout 220 | 
""" 221 | if self.training and self.drop_prob > 0: 222 | return XDropout.apply(x, self.get_context()) 223 | return x 224 | 225 | def clear_context(self): 226 | self.count = 0 227 | self.context_stack = None 228 | 229 | def init_context(self, reuse_mask=True, scale=1): 230 | if self.context_stack is None: 231 | self.context_stack = [] 232 | self.count = 0 233 | for c in self.context_stack: 234 | c.reuse_mask = reuse_mask 235 | c.scale = scale 236 | 237 | def get_context(self): 238 | if self.context_stack is not None: 239 | if self.count >= len(self.context_stack): 240 | self.context_stack.append(DropoutContext()) 241 | ctx = self.context_stack[self.count] 242 | ctx.dropout = self.drop_prob 243 | self.count += 1 244 | return ctx 245 | else: 246 | return self.drop_prob 247 | 248 | 249 | # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm 250 | class DebertaV2SelfOutput(nn.Module): 251 | def __init__(self, config, ds_factor, dropout): 252 | super().__init__() 253 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 254 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) 255 | self.dropout = StableDropout(config.hidden_dropout_prob) 256 | self.ds_factor = ds_factor 257 | if self.ds_factor: 258 | self.adapter = Adapter(ds_factor, config.hidden_size, dropout=dropout) 259 | 260 | def forward(self, hidden_states, input_tensor): 261 | hidden_states = self.dense(hidden_states) 262 | if self.ds_factor: 263 | hidden_states = self.adapter(hidden_states) 264 | hidden_states = self.dropout(hidden_states) 265 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 266 | return hidden_states 267 | 268 | 269 | # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2 270 | class DebertaV2Attention(nn.Module): 271 | def __init__(self, config, ds_factor, dropout): 272 | super().__init__() 273 | self.self = DisentangledSelfAttention(config) 274 | 
self.output = DebertaV2SelfOutput(config, ds_factor, dropout) 275 | self.config = config 276 | 277 | def forward( 278 | self, 279 | hidden_states, 280 | attention_mask, 281 | return_att=False, 282 | query_states=None, 283 | relative_pos=None, 284 | rel_embeddings=None, 285 | ): 286 | self_output = self.self( 287 | hidden_states, 288 | attention_mask, 289 | return_att, 290 | query_states=query_states, 291 | relative_pos=relative_pos, 292 | rel_embeddings=rel_embeddings, 293 | ) 294 | if return_att: 295 | self_output, att_matrix = self_output 296 | if query_states is None: 297 | query_states = hidden_states 298 | attention_output = self.output(self_output, query_states) 299 | 300 | if return_att: 301 | return (attention_output, att_matrix) 302 | else: 303 | return attention_output 304 | 305 | 306 | # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2 307 | class DebertaV2Intermediate(nn.Module): 308 | def __init__(self, config): 309 | super().__init__() 310 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size) 311 | if isinstance(config.hidden_act, str): 312 | self.intermediate_act_fn = ACT2FN[config.hidden_act] 313 | else: 314 | self.intermediate_act_fn = config.hidden_act 315 | 316 | def forward(self, hidden_states): 317 | hidden_states = self.dense(hidden_states) 318 | hidden_states = self.intermediate_act_fn(hidden_states) 319 | return hidden_states 320 | 321 | 322 | # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm 323 | class DebertaV2Output(nn.Module): 324 | def __init__(self, config, ds_factor, dropout): 325 | super().__init__() 326 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size) 327 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) 328 | self.dropout = StableDropout(config.hidden_dropout_prob) 329 | self.config = config 330 | self.ds_factor = ds_factor 331 | if self.ds_factor: 332 | self.adapter = 
Adapter(ds_factor, config.hidden_size, dropout=dropout) 333 | 334 | def forward(self, hidden_states, input_tensor): 335 | hidden_states = self.dense(hidden_states) 336 | if self.ds_factor: 337 | hidden_states = self.adapter(hidden_states) 338 | hidden_states = self.dropout(hidden_states) 339 | hidden_states = self.LayerNorm(hidden_states + input_tensor) 340 | return hidden_states 341 | 342 | 343 | # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2 344 | class DebertaV2Layer(nn.Module): 345 | def __init__( 346 | self, 347 | config, 348 | ds_factor_attn, 349 | ds_factor_ff, 350 | dropout, 351 | ): 352 | super().__init__() 353 | self.attention = DebertaV2Attention(config, ds_factor_attn, dropout) 354 | self.intermediate = DebertaV2Intermediate(config) 355 | self.output = DebertaV2Output(config, ds_factor_ff, dropout) 356 | 357 | def forward( 358 | self, 359 | hidden_states, 360 | attention_mask, 361 | return_att=False, 362 | query_states=None, 363 | relative_pos=None, 364 | rel_embeddings=None, 365 | ): 366 | attention_output = self.attention( 367 | hidden_states, 368 | attention_mask, 369 | return_att=return_att, 370 | query_states=query_states, 371 | relative_pos=relative_pos, 372 | rel_embeddings=rel_embeddings, 373 | ) 374 | if return_att: 375 | attention_output, att_matrix = attention_output 376 | intermediate_output = self.intermediate(attention_output) 377 | layer_output = self.output(intermediate_output, attention_output) 378 | if return_att: 379 | return (layer_output, att_matrix) 380 | else: 381 | return layer_output 382 | 383 | 384 | class ConvLayer(nn.Module): 385 | def __init__(self, config): 386 | super().__init__() 387 | kernel_size = getattr(config, "conv_kernel_size", 3) 388 | groups = getattr(config, "conv_groups", 1) 389 | self.conv_act = getattr(config, "conv_act", "tanh") 390 | self.conv = nn.Conv1d( 391 | config.hidden_size, 392 | config.hidden_size, 393 | kernel_size, 394 | padding=(kernel_size - 1) 
// 2, 395 | groups=groups, 396 | ) 397 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) 398 | self.dropout = StableDropout(config.hidden_dropout_prob) 399 | self.config = config 400 | 401 | def forward(self, hidden_states, residual_states, input_mask): 402 | out = ( 403 | self.conv(hidden_states.permute(0, 2, 1).contiguous()) 404 | .permute(0, 2, 1) 405 | .contiguous() 406 | ) 407 | rmask = (1 - input_mask).bool() 408 | out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0) 409 | out = ACT2FN[self.conv_act](self.dropout(out)) 410 | 411 | layer_norm_input = residual_states + out 412 | output = self.LayerNorm(layer_norm_input).to(layer_norm_input) 413 | 414 | if input_mask is None: 415 | output_states = output 416 | else: 417 | if input_mask.dim() != layer_norm_input.dim(): 418 | if input_mask.dim() == 4: 419 | input_mask = input_mask.squeeze(1).squeeze(1) 420 | input_mask = input_mask.unsqueeze(2) 421 | 422 | input_mask = input_mask.to(output.dtype) 423 | output_states = output * input_mask 424 | 425 | return output_states 426 | 427 | 428 | class DebertaV2Encoder(nn.Module): 429 | """Modified BertEncoder with relative position bias support""" 430 | 431 | def __init__(self, config, ds_factor_attn, ds_factor_ff, dropout): 432 | super().__init__() 433 | 434 | self.layer = nn.ModuleList([DebertaV2Layer(config, ds_factor_attn, ds_factor_ff, dropout) for _ in range(config.num_hidden_layers)]) 435 | self.relative_attention = getattr(config, "relative_attention", False) 436 | 437 | if self.relative_attention: 438 | self.max_relative_positions = getattr(config, "max_relative_positions", -1) 439 | if self.max_relative_positions < 1: 440 | self.max_relative_positions = config.max_position_embeddings 441 | 442 | self.position_buckets = getattr(config, "position_buckets", -1) 443 | pos_ebd_size = self.max_relative_positions * 2 444 | 445 | if self.position_buckets > 0: 446 | pos_ebd_size = self.position_buckets * 2 447 | 448 | self.rel_embeddings = 
nn.Embedding(pos_ebd_size, config.hidden_size) 449 | 450 | self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")] 451 | 452 | if "layer_norm" in self.norm_rel_ebd: 453 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True) 454 | 455 | self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None 456 | 457 | def get_rel_embedding(self): 458 | rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None 459 | if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd): 460 | rel_embeddings = self.LayerNorm(rel_embeddings) 461 | return rel_embeddings 462 | 463 | def get_attention_mask(self, attention_mask): 464 | if attention_mask.dim() <= 2: 465 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 466 | attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1) 467 | attention_mask = attention_mask.byte() 468 | elif attention_mask.dim() == 3: 469 | attention_mask = attention_mask.unsqueeze(1) 470 | 471 | return attention_mask 472 | 473 | def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None): 474 | if self.relative_attention and relative_pos is None: 475 | q = query_states.size(-2) if query_states is not None else hidden_states.size(-2) 476 | relative_pos = build_relative_position(q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions) 477 | return relative_pos 478 | 479 | def forward(self, hidden_states, attention_mask, output_hidden_states=True, output_attentions=False, query_states=None, relative_pos=None, return_dict=True): 480 | if attention_mask.dim() <= 2: 481 | input_mask = attention_mask 482 | else: 483 | input_mask = (attention_mask.sum(-2) > 0).byte() 484 | attention_mask = self.get_attention_mask(attention_mask) 485 | relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos) 486 | 487 
| all_hidden_states = () if output_hidden_states else None 488 | all_attentions = () if output_attentions else None 489 | 490 | if isinstance(hidden_states, Sequence): 491 | next_kv = hidden_states[0] 492 | else: 493 | next_kv = hidden_states 494 | rel_embeddings = self.get_rel_embedding() 495 | output_states = next_kv 496 | for i, layer_module in enumerate(self.layer): 497 | 498 | if output_hidden_states: 499 | all_hidden_states = all_hidden_states + (output_states,) 500 | 501 | output_states = layer_module(next_kv, attention_mask, output_attentions, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings) 502 | if output_attentions: 503 | output_states, att_m = output_states 504 | 505 | if i == 0 and self.conv is not None: 506 | output_states = self.conv(hidden_states, output_states, input_mask) 507 | 508 | if query_states is not None: 509 | query_states = output_states 510 | if isinstance(hidden_states, Sequence): 511 | next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None 512 | else: 513 | next_kv = output_states 514 | 515 | if output_attentions: 516 | all_attentions = all_attentions + (att_m,) 517 | 518 | if output_hidden_states: 519 | all_hidden_states = all_hidden_states + (output_states,) 520 | 521 | if not return_dict: 522 | return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None) 523 | 524 | return BaseModelOutput(last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions) 525 | 526 | 527 | def make_log_bucket_position(relative_pos, bucket_size, max_position): 528 | sign = np.sign(relative_pos) 529 | mid = bucket_size // 2 530 | abs_pos = np.where( 531 | (relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos) 532 | ) 533 | log_pos = ( 534 | np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1)) 535 | + mid 536 | ) 537 | bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(np.int) 
538 | return bucket_pos 539 | 540 | 541 | def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1): 542 | """ 543 | Build relative position according to the query and key 544 | 545 | We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key 546 | :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} = 547 | P_q - P_k` 548 | 549 | Args: 550 | query_size (int): the length of query 551 | key_size (int): the length of key 552 | bucket_size (int): the size of position bucket 553 | max_position (int): the maximum allowed absolute position 554 | 555 | Return: 556 | :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size] 557 | 558 | """ 559 | q_ids = np.arange(0, query_size) 560 | k_ids = np.arange(0, key_size) 561 | rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1)) 562 | if bucket_size > 0 and max_position > 0: 563 | rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position) 564 | rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long) 565 | rel_pos_ids = rel_pos_ids[:query_size, :] 566 | rel_pos_ids = rel_pos_ids.unsqueeze(0) 567 | return rel_pos_ids 568 | 569 | 570 | @torch.jit.script 571 | # Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand 572 | def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos): 573 | return c2p_pos.expand( 574 | [ 575 | query_layer.size(0), 576 | query_layer.size(1), 577 | query_layer.size(2), 578 | relative_pos.size(-1), 579 | ] 580 | ) 581 | 582 | 583 | @torch.jit.script 584 | # Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand 585 | def p2c_dynamic_expand(c2p_pos, query_layer, key_layer): 586 | return c2p_pos.expand( 587 | [ 588 | query_layer.size(0), 589 | query_layer.size(1), 590 | key_layer.size(-2), 591 | key_layer.size(-2), 592 | ] 593 | ) 594 | 595 | 596 | @torch.jit.script 597 | # Copied 
from transformers.models.deberta.modeling_deberta.pos_dynamic_expand 598 | def pos_dynamic_expand(pos_index, p2c_att, key_layer): 599 | return pos_index.expand( 600 | p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)) 601 | ) 602 | 603 | 604 | class DisentangledSelfAttention(nn.Module): 605 | """ 606 | Disentangled self-attention module 607 | 608 | Parameters: 609 | config (:obj:`DebertaV2Config`): 610 | A model config class instance with the configuration to build a new model. The schema is similar to 611 | `BertConfig`, for more details, please refer :class:`~transformers.DebertaV2Config` 612 | 613 | """ 614 | 615 | def __init__(self, config): 616 | super().__init__() 617 | if config.hidden_size % config.num_attention_heads != 0: 618 | raise ValueError( 619 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " 620 | f"heads ({config.num_attention_heads})" 621 | ) 622 | self.num_attention_heads = config.num_attention_heads 623 | _attention_head_size = config.hidden_size // config.num_attention_heads 624 | self.attention_head_size = getattr( 625 | config, "attention_head_size", _attention_head_size 626 | ) 627 | self.all_head_size = self.num_attention_heads * self.attention_head_size 628 | self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) 629 | self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) 630 | self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True) 631 | 632 | self.share_att_key = getattr(config, "share_att_key", False) 633 | self.pos_att_type = ( 634 | config.pos_att_type if config.pos_att_type is not None else [] 635 | ) 636 | self.relative_attention = getattr(config, "relative_attention", False) 637 | 638 | if self.relative_attention: 639 | self.position_buckets = getattr(config, "position_buckets", -1) 640 | self.max_relative_positions = getattr(config, "max_relative_positions", -1) 641 | if self.max_relative_positions < 1: 642 | 
self.max_relative_positions = config.max_position_embeddings 643 | self.pos_ebd_size = self.max_relative_positions 644 | if self.position_buckets > 0: 645 | self.pos_ebd_size = self.position_buckets 646 | 647 | self.pos_dropout = StableDropout(config.hidden_dropout_prob) 648 | 649 | if not self.share_att_key: 650 | if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: 651 | self.pos_key_proj = nn.Linear( 652 | config.hidden_size, self.all_head_size, bias=True 653 | ) 654 | if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: 655 | self.pos_query_proj = nn.Linear( 656 | config.hidden_size, self.all_head_size 657 | ) 658 | 659 | self.dropout = StableDropout(config.attention_probs_dropout_prob) 660 | 661 | def transpose_for_scores(self, x, attention_heads): 662 | new_x_shape = x.size()[:-1] + (attention_heads, -1) 663 | x = x.view(*new_x_shape) 664 | return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1)) 665 | 666 | def forward( 667 | self, 668 | hidden_states, 669 | attention_mask, 670 | return_att=False, 671 | query_states=None, 672 | relative_pos=None, 673 | rel_embeddings=None, 674 | ): 675 | """ 676 | Call the module 677 | 678 | Args: 679 | hidden_states (:obj:`torch.FloatTensor`): 680 | Input states to the module usually the output from previous layer, it will be the Q,K and V in 681 | `Attention(Q,K,V)` 682 | 683 | attention_mask (:obj:`torch.ByteTensor`): 684 | An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum 685 | sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j` 686 | th token. 687 | 688 | return_att (:obj:`bool`, optional): 689 | Whether return the attention matrix. 690 | 691 | query_states (:obj:`torch.FloatTensor`, optional): 692 | The `Q` state in `Attention(Q,K,V)`. 693 | 694 | relative_pos (:obj:`torch.LongTensor`): 695 | The relative position encoding between the tokens in the sequence. 
It's of shape [`B`, `N`, `N`] with 696 | values ranging in [`-max_relative_positions`, `max_relative_positions`]. 697 | 698 | rel_embeddings (:obj:`torch.FloatTensor`): 699 | The embedding of relative distances. It's a tensor of shape [:math:`2 \\times 700 | \\text{max_relative_positions}`, `hidden_size`]. 701 | 702 | 703 | """ 704 | if query_states is None: 705 | query_states = hidden_states 706 | query_layer = self.transpose_for_scores( 707 | self.query_proj(query_states), self.num_attention_heads 708 | ) 709 | key_layer = self.transpose_for_scores( 710 | self.key_proj(hidden_states), self.num_attention_heads 711 | ) 712 | value_layer = self.transpose_for_scores( 713 | self.value_proj(hidden_states), self.num_attention_heads 714 | ) 715 | 716 | rel_att = None 717 | # Take the dot product between "query" and "key" to get the raw attention scores. 718 | scale_factor = 1 719 | if "c2p" in self.pos_att_type: 720 | scale_factor += 1 721 | if "p2c" in self.pos_att_type: 722 | scale_factor += 1 723 | if "p2p" in self.pos_att_type: 724 | scale_factor += 1 725 | scale = math.sqrt(query_layer.size(-1) * scale_factor) 726 | attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale 727 | if self.relative_attention: 728 | rel_embeddings = self.pos_dropout(rel_embeddings) 729 | rel_att = self.disentangled_attention_bias( 730 | query_layer, key_layer, relative_pos, rel_embeddings, scale_factor 731 | ) 732 | 733 | if rel_att is not None: 734 | attention_scores = attention_scores + rel_att 735 | attention_scores = attention_scores 736 | attention_scores = attention_scores.view( 737 | -1, 738 | self.num_attention_heads, 739 | attention_scores.size(-2), 740 | attention_scores.size(-1), 741 | ) 742 | 743 | # bsz x height x length x dimension 744 | attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1) 745 | attention_probs = self.dropout(attention_probs) 746 | context_layer = torch.bmm( 747 | attention_probs.view( 748 | -1, 
attention_probs.size(-2), attention_probs.size(-1) 749 | ), 750 | value_layer, 751 | ) 752 | context_layer = ( 753 | context_layer.view( 754 | -1, 755 | self.num_attention_heads, 756 | context_layer.size(-2), 757 | context_layer.size(-1), 758 | ) 759 | .permute(0, 2, 1, 3) 760 | .contiguous() 761 | ) 762 | new_context_layer_shape = context_layer.size()[:-2] + (-1,) 763 | context_layer = context_layer.view(*new_context_layer_shape) 764 | if return_att: 765 | return (context_layer, attention_probs) 766 | else: 767 | return context_layer 768 | 769 | def disentangled_attention_bias( 770 | self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor 771 | ): 772 | if relative_pos is None: 773 | q = query_layer.size(-2) 774 | relative_pos = build_relative_position( 775 | q, 776 | key_layer.size(-2), 777 | bucket_size=self.position_buckets, 778 | max_position=self.max_relative_positions, 779 | ) 780 | if relative_pos.dim() == 2: 781 | relative_pos = relative_pos.unsqueeze(0).unsqueeze(0) 782 | elif relative_pos.dim() == 3: 783 | relative_pos = relative_pos.unsqueeze(1) 784 | # bsz x height x query x key 785 | elif relative_pos.dim() != 4: 786 | raise ValueError( 787 | f"Relative position ids must be of dim 2 or 3 or 4. 
{relative_pos.dim()}" 788 | ) 789 | 790 | att_span = self.pos_ebd_size 791 | relative_pos = relative_pos.long().to(query_layer.device) 792 | 793 | rel_embeddings = rel_embeddings[ 794 | self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, : 795 | ].unsqueeze(0) 796 | if self.share_att_key: 797 | pos_query_layer = self.transpose_for_scores( 798 | self.query_proj(rel_embeddings), self.num_attention_heads 799 | ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) 800 | pos_key_layer = self.transpose_for_scores( 801 | self.key_proj(rel_embeddings), self.num_attention_heads 802 | ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1) 803 | else: 804 | if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type: 805 | pos_key_layer = self.transpose_for_scores( 806 | self.pos_key_proj(rel_embeddings), self.num_attention_heads 807 | ).repeat( 808 | query_layer.size(0) // self.num_attention_heads, 1, 1 809 | ) # .split(self.all_head_size, dim=-1) 810 | if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: 811 | pos_query_layer = self.transpose_for_scores( 812 | self.pos_query_proj(rel_embeddings), self.num_attention_heads 813 | ).repeat( 814 | query_layer.size(0) // self.num_attention_heads, 1, 1 815 | ) # .split(self.all_head_size, dim=-1) 816 | 817 | score = 0 818 | # content->position 819 | if "c2p" in self.pos_att_type: 820 | scale = math.sqrt(pos_key_layer.size(-1) * scale_factor) 821 | c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2)) 822 | c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1) 823 | c2p_att = torch.gather( 824 | c2p_att, 825 | dim=-1, 826 | index=c2p_pos.squeeze(0).expand( 827 | [query_layer.size(0), query_layer.size(1), relative_pos.size(-1)] 828 | ), 829 | ) 830 | score += c2p_att / scale 831 | 832 | # position->content 833 | if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type: 834 | scale = math.sqrt(pos_query_layer.size(-1) * scale_factor) 835 | if key_layer.size(-2) != 
query_layer.size(-2): 836 | r_pos = build_relative_position( 837 | key_layer.size(-2), 838 | key_layer.size(-2), 839 | bucket_size=self.position_buckets, 840 | max_position=self.max_relative_positions, 841 | ).to(query_layer.device) 842 | r_pos = r_pos.unsqueeze(0) 843 | else: 844 | r_pos = relative_pos 845 | 846 | p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1) 847 | if query_layer.size(-2) != key_layer.size(-2): 848 | pos_index = relative_pos[:, :, :, 0].unsqueeze(-1) 849 | 850 | if "p2c" in self.pos_att_type: 851 | p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2)) 852 | p2c_att = torch.gather( 853 | p2c_att, 854 | dim=-1, 855 | index=p2c_pos.squeeze(0).expand( 856 | [query_layer.size(0), key_layer.size(-2), key_layer.size(-2)] 857 | ), 858 | ).transpose(-1, -2) 859 | if query_layer.size(-2) != key_layer.size(-2): 860 | p2c_att = torch.gather( 861 | p2c_att, 862 | dim=-2, 863 | index=pos_index.expand( 864 | p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)) 865 | ), 866 | ) 867 | score += p2c_att / scale 868 | 869 | # position->position 870 | if "p2p" in self.pos_att_type: 871 | pos_query = pos_query_layer[:, :, att_span:, :] 872 | p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2)) 873 | p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:]) 874 | if query_layer.size(-2) != key_layer.size(-2): 875 | p2p_att = torch.gather( 876 | p2p_att, 877 | dim=-2, 878 | index=pos_index.expand( 879 | query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1)) 880 | ), 881 | ) 882 | p2p_att = torch.gather( 883 | p2p_att, 884 | dim=-1, 885 | index=c2p_pos.expand( 886 | [ 887 | query_layer.size(0), 888 | query_layer.size(1), 889 | query_layer.size(2), 890 | relative_pos.size(-1), 891 | ] 892 | ), 893 | ) 894 | score += p2p_att 895 | 896 | return score 897 | 898 | 899 | # Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm 900 | class 
DebertaV2Embeddings(nn.Module): 901 | """Construct the embeddings from word, position and token_type embeddings.""" 902 | 903 | def __init__(self, config, features_dim): 904 | super().__init__() 905 | pad_token_id = getattr(config, "pad_token_id", 0) 906 | self.embedding_size = getattr(config, "embedding_size", config.hidden_size) 907 | self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id) 908 | 909 | self.position_biased_input = getattr(config, "position_biased_input", True) 910 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) # it is used for the decoder anyway 911 | 912 | if config.type_vocab_size > 0: 913 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size) 914 | 915 | if self.embedding_size != config.hidden_size: 916 | self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False) 917 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps) 918 | self.dropout = StableDropout(config.hidden_dropout_prob) 919 | self.config = config 920 | 921 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized 922 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) 923 | 924 | self.features_dim = features_dim 925 | if self.features_dim: 926 | self.linear_video = nn.Linear(features_dim, config.hidden_size) 927 | 928 | # self.prompt_embedding = nn.Embedding(32, config.hidden_size) 929 | self.prompt_embedding = None 930 | 931 | def get_video_embedding(self, video): 932 | video = self.linear_video(video) 933 | return video 934 | 935 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, video=None): 936 | if input_ids is not None: 937 | input_shape = input_ids.size() 938 | else: 939 | input_shape = inputs_embeds.size()[:-1] 940 | 941 | if inputs_embeds is None: 942 | inputs_embeds = 
self.word_embeddings(input_ids) 943 | if self.features_dim and video is not None: 944 | video = self.get_video_embedding(video) 945 | if self.prompt_embedding is not None: 946 | inputs_embeds = torch.cat([video, inputs_embeds, self.prompt_embedding.weight[None, :, :].repeat(video.shape[0], 1, 1)], 1) 947 | input_shape = inputs_embeds[:, :, 0].shape 948 | else: 949 | inputs_embeds = torch.cat([video, inputs_embeds], 1) 950 | input_shape = inputs_embeds[:, :, 0].shape 951 | 952 | seq_length = input_shape[1] 953 | 954 | if position_ids is None: 955 | position_ids = self.position_ids[:, :seq_length] 956 | 957 | if token_type_ids is None: 958 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device) 959 | 960 | if self.position_embeddings is not None: 961 | position_embeddings = self.position_embeddings(position_ids.long()) 962 | else: 963 | position_embeddings = torch.zeros_like(inputs_embeds) 964 | embeddings = inputs_embeds 965 | 966 | if self.position_biased_input: 967 | embeddings += position_embeddings 968 | # embeddings[:, :position_embeddings.shape[1], :] += position_embeddings 969 | if self.config.type_vocab_size > 0: 970 | token_type_embeddings = self.token_type_embeddings(token_type_ids) 971 | embeddings += token_type_embeddings 972 | 973 | if self.embedding_size != self.config.hidden_size: 974 | embeddings = self.embed_proj(embeddings) 975 | 976 | 977 | embeddings = self.LayerNorm(embeddings) 978 | 979 | if mask is not None: 980 | if mask.dim() != embeddings.dim(): 981 | if mask.dim() == 4: 982 | mask = mask.squeeze(1).squeeze(1) 983 | mask = mask.unsqueeze(2) 984 | mask = mask.to(embeddings.dtype) 985 | 986 | embeddings = embeddings * mask 987 | 988 | embeddings = self.dropout(embeddings) 989 | return {"embeddings": embeddings, "position_embeddings": position_embeddings} 990 | 991 | 992 | # Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2 993 | 994 | 995 | class 
DebertaV2PreTrainedModel(PreTrainedModel): 996 | """ 997 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained 998 | models. 999 | """ 1000 | 1001 | config_class = DebertaV2Config 1002 | base_model_prefix = "deberta" 1003 | _keys_to_ignore_on_load_missing = ["position_ids"] 1004 | _keys_to_ignore_on_load_unexpected = ["position_embeddings"] 1005 | 1006 | def __init__(self, config): 1007 | super().__init__(config) 1008 | self._register_load_state_dict_pre_hook(self._pre_load_hook) 1009 | 1010 | def _init_weights(self, module): 1011 | """Initialize the weights.""" 1012 | if isinstance(module, nn.Linear): 1013 | # Slightly different from the TF version which uses truncated_normal for initialization 1014 | # cf https://github.com/pytorch/pytorch/pull/5617 1015 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 1016 | if module.bias is not None: 1017 | module.bias.data.zero_() 1018 | elif isinstance(module, nn.Embedding): 1019 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) 1020 | if module.padding_idx is not None: 1021 | module.weight.data[module.padding_idx].zero_() 1022 | 1023 | def _pre_load_hook( 1024 | self, 1025 | state_dict, 1026 | prefix, 1027 | local_metadata, 1028 | strict, 1029 | missing_keys, 1030 | unexpected_keys, 1031 | error_msgs, 1032 | ): 1033 | """ 1034 | Removes the classifier if it doesn't have the correct number of labels. 1035 | """ 1036 | self_state = self.state_dict() 1037 | if ( 1038 | ("classifier.weight" in self_state) 1039 | and ("classifier.weight" in state_dict) 1040 | and self_state["classifier.weight"].size() 1041 | != state_dict["classifier.weight"].size() 1042 | ): 1043 | print( 1044 | f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model " 1045 | f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint " 1046 | f"weights. 
You should train your model on new data." 1047 | ) 1048 | del state_dict["classifier.weight"] 1049 | if "classifier.bias" in state_dict: 1050 | del state_dict["classifier.bias"] 1051 | 1052 | 1053 | # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2 1054 | class DebertaV2Model(DebertaV2PreTrainedModel): 1055 | def __init__(self, config, max_feats=10, features_dim=768, freeze_lm=False, ds_factor_attn=8, ds_factor_ff=8, ft_ln=False, dropout=0.1): 1056 | super().__init__(config) 1057 | 1058 | self.embeddings = DebertaV2Embeddings(config, features_dim) 1059 | self.encoder = DebertaV2Encoder(config, ds_factor_attn, ds_factor_ff, dropout) 1060 | self.z_steps = 0 1061 | self.config = config 1062 | 1063 | self.features_dim = features_dim 1064 | self.max_feats = max_feats 1065 | if freeze_lm: 1066 | for n, p in self.named_parameters(): 1067 | if (not "linear_video" in n) and (not "adapter" in n): 1068 | if ft_ln and "LayerNorm" in n: 1069 | continue 1070 | else: 1071 | p.requires_grad_(False) 1072 | 1073 | self.init_weights() 1074 | 1075 | def get_input_embeddings(self): 1076 | return self.embeddings.word_embeddings 1077 | 1078 | def set_input_embeddings(self, new_embeddings): 1079 | self.embeddings.word_embeddings = new_embeddings 1080 | 1081 | def _prune_heads(self, heads_to_prune): 1082 | """ 1083 | Prunes heads of the model. 
heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 1084 | class PreTrainedModel 1085 | """ 1086 | raise NotImplementedError("The prune function is not implemented in DeBERTa model.") 1087 | 1088 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None, 1089 | return_dict=None, video=None, video_mask=None): 1090 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 1091 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 1092 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1093 | 1094 | if input_ids is not None and inputs_embeds is not None: 1095 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") 1096 | elif input_ids is not None: 1097 | input_shape = input_ids.size() 1098 | elif inputs_embeds is not None: 1099 | input_shape = inputs_embeds.size()[:-1] 1100 | else: 1101 | raise ValueError("You have to specify either input_ids or inputs_embeds") 1102 | 1103 | device = input_ids.device if input_ids is not None else inputs_embeds.device 1104 | 1105 | if attention_mask is None: 1106 | attention_mask = torch.ones(input_shape, device=device) 1107 | 1108 | if self.features_dim and video is not None: 1109 | if video_mask is None: 1110 | video_shape = video[:, :, 0].size() 1111 | video_mask = torch.ones(video_shape, device=device) 1112 | attention_mask = torch.cat([video_mask, attention_mask], 1) 1113 | # attention_mask = torch.cat([video_mask, attention_mask, torch.ones(video_mask.shape[0], 32).to(video_mask.device)], 1) 1114 | input_shape = attention_mask.size() 1115 | 1116 | if token_type_ids is None: 1117 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) 1118 | 1119 | embedding_output = 
self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids, mask=attention_mask, inputs_embeds=inputs_embeds, video=video) 1120 | embedding_output, position_embeddings = embedding_output["embeddings"], embedding_output["position_embeddings"] 1121 | 1122 | encoder_outputs = self.encoder(embedding_output, attention_mask, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict) 1123 | encoded_layers = encoder_outputs[1] 1124 | 1125 | if self.z_steps > 1: 1126 | hidden_states = encoded_layers[-2] 1127 | layers = [self.encoder.layer[-1] for _ in range(self.z_steps)] 1128 | query_states = encoded_layers[-1] 1129 | rel_embeddings = self.encoder.get_rel_embedding() 1130 | attention_mask = self.encoder.get_attention_mask(attention_mask) 1131 | rel_pos = self.encoder.get_rel_pos(embedding_output) 1132 | for layer in layers[1:]: 1133 | query_states = layer(hidden_states, attention_mask, return_att=False, query_states=query_states, relative_pos=rel_pos, rel_embeddings=rel_embeddings) 1134 | encoded_layers.append(query_states) 1135 | 1136 | sequence_output = encoded_layers[-1] 1137 | 1138 | if not return_dict: 1139 | return sequence_output + encoder_outputs[1 if output_hidden_states else 2:] 1140 | 1141 | return BaseModelOutput(last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states if output_hidden_states else None, 1142 | attentions=encoder_outputs.attentions, position_embeddings=position_embeddings, attention_mask=attention_mask) 1143 | 1144 | 1145 | # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2 1146 | class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel): 1147 | _keys_to_ignore_on_load_unexpected = [r"pooler"] 1148 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] 1149 | 1150 | def __init__(self, config, max_feats=10, features_dim=768, freeze_lm=True, freeze_mlm=True, ds_factor_attn=8, 
ds_factor_ff=8, ft_ln=True, 1151 | dropout=0.1, n_ans=0, freeze_last=True, args=None): 1152 | """ 1153 | :param config: BiLM configuration 1154 | :param max_feats: maximum number of frames used by the model 1155 | :param features_dim: embedding dimension of the visual features, set = 0 for text-only mode 1156 | :param freeze_lm: whether to freeze or not the language model (Transformer encoder + token embedder) 1157 | :param freeze_mlm: whether to freeze or not the MLM head 1158 | :param ds_factor_attn: downsampling factor for the adapter after self-attention, no adapter if set to 0 1159 | :param ds_factor_ff: downsampling factor for the adapter after feed-forward, no adapter if set to 0 1160 | :param ft_ln: whether to finetune or not the normalization layers 1161 | :param dropout: dropout probability in the adapter 1162 | :param n_ans: number of answers in the downstream vocabulary, set = 0 during cross-modal training 1163 | :param freeze_last: whether to freeze or not the answer embedding module 1164 | """ 1165 | super().__init__(config) 1166 | 1167 | self.deberta = DebertaV2Model(config, max_feats, features_dim, freeze_lm, ds_factor_attn, ds_factor_ff, ft_ln, dropout) 1168 | self.lm_predictions = DebertaV2OnlyMLMHead(config, args) 1169 | self.features_dim = features_dim 1170 | if freeze_mlm: 1171 | for n, p in self.lm_predictions.named_parameters(): 1172 | if ft_ln and "LayerNorm" in n: 1173 | continue 1174 | else: 1175 | p.requires_grad_(False) 1176 | 1177 | self.init_weights() 1178 | self.n_ans = n_ans 1179 | if n_ans: 1180 | self.answer_embeddings = nn.Embedding(n_ans, self.deberta.embeddings.embedding_size) 1181 | self.answer_bias = nn.Parameter(torch.zeros(n_ans)) 1182 | if freeze_last: 1183 | self.answer_embeddings.requires_grad_(False) 1184 | self.answer_bias.requires_grad_(False) 1185 | 1186 | self.lm_predictions.lm_head.gnn1.lin_src.weight.data = torch.eye(config.hidden_size) 1187 | self.lm_predictions.lm_head.gnn2.lin_src.weight.data = 
torch.eye(config.hidden_size) 1188 | self.lm_predictions.lm_head.gnn1.lin_dst.weight.data = torch.eye(config.hidden_size) 1189 | self.lm_predictions.lm_head.gnn2.lin_dst.weight.data = torch.eye(config.hidden_size) 1190 | self.lm_predictions.lm_head.gnn1.lin_src.weight.requires_grad_(True) 1191 | self.lm_predictions.lm_head.gnn2.lin_src.weight.requires_grad_(True) 1192 | self.lm_predictions.lm_head.gnn1.lin_dst.weight.requires_grad_(True) 1193 | self.lm_predictions.lm_head.gnn2.lin_dst.weight.requires_grad_(True) 1194 | 1195 | def get_output_embeddings(self): 1196 | return self.lm_predictions.lm_head.decoder 1197 | 1198 | def set_output_embeddings(self, new_embeddings): 1199 | self.lm_predictions.lm_head.decoder = new_embeddings 1200 | 1201 | def set_answer_embeddings(self, a2tok, freeze_last=True): 1202 | a2v = self.deberta.embeddings.word_embeddings(a2tok) 1203 | pad_token_id = getattr(self.config, "pad_token_id", 0) 1204 | sum_tokens = (a2tok != pad_token_id).sum(1, keepdims=True) # n_ans 1205 | if len(a2v) != self.n_ans: # reinitialize the answer embeddings 1206 | # assert not self.training 1207 | self.n_ans = len(a2v) 1208 | self.answer_embeddings = nn.Embedding(self.n_ans, self.deberta.embeddings.embedding_size).to(self.device) 1209 | self.answer_bias.requires_grad = False 1210 | self.answer_bias.resize_(self.n_ans) 1211 | self.answer_embeddings.weight.data = torch.div((a2v * (a2tok != pad_token_id).float()[:, :, None]).sum(1), sum_tokens.clamp(min=1)) # n_ans 1212 | a2b = self.lm_predictions.lm_head.bias[a2tok] 1213 | self.answer_bias.weight = torch.div((a2b * (a2tok != pad_token_id).float()).sum(1), sum_tokens.clamp(min=1)) 1214 | if freeze_last: 1215 | self.answer_embeddings.requires_grad_(False) 1216 | self.answer_bias.requires_grad_(False) 1217 | 1218 | def get_verbalized_embeddings(self, verbalized_id): 1219 | verbalized_embeddings = self.deberta.embeddings.word_embeddings(verbalized_id) 1220 | pad_token_id = getattr(self.config, "pad_token_id", 0) 1221 
| sum_tokens = (verbalized_id != pad_token_id).sum(1, keepdims=True) 1222 | return verbalized_embeddings, sum_tokens 1223 | 1224 | def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder): 1225 | if attention_mask.dim() <= 2: 1226 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 1227 | att_mask = extended_attention_mask.byte() 1228 | attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1) 1229 | elif attention_mask.dim() == 3: 1230 | attention_mask = attention_mask.unsqueeze(1) 1231 | hidden_states = encoder_layers[-2] 1232 | 1233 | if not self.config.position_biased_input: 1234 | layers = [encoder.layer[-1] for _ in range(2)] 1235 | z_states += hidden_states 1236 | query_states = z_states 1237 | query_mask = attention_mask 1238 | outputs = [] 1239 | rel_embeddings = encoder.get_rel_embedding() 1240 | for layer in layers: 1241 | output = layer(hidden_states, query_mask, return_att=False, query_states=query_states, relative_pos=None, rel_embeddings=rel_embeddings) 1242 | query_states = output 1243 | outputs.append(query_states) 1244 | else: 1245 | outputs = [encoder_layers[-1]] 1246 | 1247 | return outputs 1248 | 1249 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, labels=None, output_attentions=None, 1250 | return_dict=None, video=None, video_mask=None, mlm=False, eps=None, edge_index=None, vocab_embeddings=None): 1251 | r""" 1252 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 1253 | Labels for computing the masked language modeling loss. 
Indices should be in ``[-100, 0, ..., 1254 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored 1255 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` 1256 | """ 1257 | 1258 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1259 | 1260 | outputs = self.deberta(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, 1261 | output_attentions=output_attentions, output_hidden_states=True, return_dict=return_dict, video=video, video_mask=video_mask) 1262 | 1263 | 1264 | if labels is not None: 1265 | if self.features_dim and video is not None: # ignore the label predictions for visual tokens 1266 | video_shape = video[:, :, 0].size() 1267 | video_labels = torch.tensor([[-100] * video_shape[1]] * video_shape[0], dtype=torch.long, device=labels.device) 1268 | labels = torch.cat([video_labels, labels], 1) 1269 | 1270 | # sequence_output = outputs[0] 1271 | modified = self.emd_context_layer(encoder_layers=outputs["hidden_states"], 1272 | z_states=outputs["position_embeddings"].repeat(input_ids.shape[0] // len(outputs["position_embeddings"]), 1, 1), 1273 | attention_mask=outputs["attention_mask"], encoder=self.deberta.encoder) 1274 | 1275 | bias = None 1276 | if self.n_ans and (not mlm): # downstream mode 1277 | embeddings = self.answer_embeddings.weight 1278 | bias = self.answer_bias 1279 | else: 1280 | embeddings = self.deberta.embeddings.word_embeddings.weight 1281 | 1282 | prediction_scores = self.lm_predictions(modified[-1], embeddings, edge_index, vocab_embeddings, eps, bias) 1283 | masked_lm_loss = None 1284 | if labels is not None: 1285 | loss_fct = CrossEntropyLoss() # -100 index = padding token 1286 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) # labels[labels > 0].view(-1) 1287 | 1288 | if not return_dict: 
1289 | output = (prediction_scores,) + outputs[1:] 1290 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output 1291 | 1292 | return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions) 1293 | 1294 | 1295 | # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta 1296 | class DebertaV2PredictionHeadTransform(nn.Module): 1297 | def __init__(self, config): 1298 | super().__init__() 1299 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1300 | if isinstance(config.hidden_act, str): 1301 | self.transform_act_fn = ACT2FN[config.hidden_act] 1302 | else: 1303 | self.transform_act_fn = config.hidden_act 1304 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 1305 | 1306 | def forward(self, hidden_states): 1307 | hidden_states = self.dense(hidden_states) 1308 | hidden_states = self.transform_act_fn(hidden_states) 1309 | hidden_states = self.LayerNorm(hidden_states) 1310 | return hidden_states 1311 | 1312 | 1313 | # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta 1314 | class DebertaV2LMPredictionHead(nn.Module): 1315 | def __init__(self, config, args): 1316 | super().__init__() 1317 | self.dense = nn.Linear(config.hidden_size, config.hidden_size) 1318 | if isinstance(config.hidden_act, str): 1319 | self.transform_act_fn = ACT2FN[config.hidden_act] 1320 | else: 1321 | self.transform_act_fn = config.hidden_act 1322 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) 1323 | 1324 | # The output weights are the same as the input embeddings, but there is 1325 | # an output-only bias for each token. 
1326 | self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # only for compatiblity 1327 | 1328 | self.bias = nn.Parameter(torch.zeros(config.vocab_size)) 1329 | 1330 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` 1331 | self.decoder.bias = self.bias # only for compatiblity 1332 | 1333 | self.gnn1 = GNNSoftVerbalizer(config.hidden_size, config.hidden_size) 1334 | self.gnn2 = GNNSoftVerbalizer(config.hidden_size, config.hidden_size) 1335 | self.args = args 1336 | 1337 | def forward(self, hidden_states, embedding_weight, edge_index, vocab_embeddings, eps, bias=None): 1338 | hidden_states = self.dense(hidden_states) 1339 | hidden_states = self.transform_act_fn(hidden_states) 1340 | hidden_states = self.LayerNorm(hidden_states) 1341 | 1342 | if bias is not None: 1343 | _embedding_weight = self.gnn1(vocab_embeddings, edge_index) 1344 | _embedding_weight = self.gnn2(_embedding_weight, edge_index)[:embedding_weight.shape[0]] 1345 | _embedding_weight = eps * embedding_weight + (1 - eps) * _embedding_weight 1346 | 1347 | logits = torch.matmul(hidden_states, _embedding_weight.t().to(hidden_states)) + bias 1348 | 1349 | 1350 | else: 1351 | logits = torch.matmul(hidden_states, embedding_weight.t().to(hidden_states)) + self.bias 1352 | 1353 | return logits 1354 | 1355 | # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta 1356 | class DebertaV2OnlyMLMHead(nn.Module): 1357 | def __init__(self, config, args): 1358 | super().__init__() 1359 | self.lm_head = DebertaV2LMPredictionHead(config, args) 1360 | 1361 | def forward(self, sequence_output, embedding_weight, edge_index, vocab_embeddings, eps, bias=None): 1362 | prediction_scores = self.lm_head(sequence_output, embedding_weight, edge_index, vocab_embeddings, eps, bias=bias) 1363 | return prediction_scores 1364 | -------------------------------------------------------------------------------- /model/gnn.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import Parameter 6 | 7 | from torch_geometric.nn.conv import MessagePassing 8 | from torch_geometric.nn.dense.linear import Linear 9 | from torch_geometric.nn.inits import zeros 10 | import torch_sparse 11 | from torch_geometric.utils import ( 12 | add_remaining_self_loops, 13 | is_torch_sparse_tensor, 14 | scatter, 15 | spmm, 16 | to_torch_coo_tensor, 17 | ) 18 | from torch_geometric.utils.num_nodes import maybe_num_nodes 19 | 20 | 21 | from typing import Optional, Tuple, Union 22 | 23 | 24 | import torch.nn.functional as F 25 | import torch_geometric.typing 26 | from torch_geometric.typing import NoneType # noqa 27 | from torch_geometric.typing import ( 28 | Adj, 29 | OptPairTensor, 30 | OptTensor, 31 | Size, 32 | SparseTensor 33 | ) 34 | from torch_geometric.utils import add_self_loops, remove_self_loops, scatter, softmax 35 | from torch_geometric.nn.inits import glorot, zeros 36 | 37 | 38 | class GNNSoftVerbalizer(MessagePassing): 39 | r"""The graph attentional operator from the `"Graph Attention Networks" 40 | `_ paper 41 | 42 | .. math:: 43 | \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + 44 | \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, 45 | 46 | where the attention coefficients :math:`\alpha_{i,j}` are computed as 47 | 48 | .. math:: 49 | \alpha_{i,j} = 50 | \frac{ 51 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 52 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] 53 | \right)\right)} 54 | {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} 55 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 56 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] 57 | \right)\right)}. 
58 | 59 | If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, 60 | the attention coefficients :math:`\alpha_{i,j}` are computed as 61 | 62 | .. math:: 63 | \alpha_{i,j} = 64 | \frac{ 65 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 66 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j 67 | \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,j}]\right)\right)} 68 | {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} 69 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} 70 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k 71 | \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,k}]\right)\right)}. 72 | 73 | Args: 74 | in_channels (int or tuple): Size of each input sample, or :obj:`-1` to 75 | derive the size from the first input(s) to the forward method. 76 | A tuple corresponds to the sizes of source and target 77 | dimensionalities. 78 | out_channels (int): Size of each output sample. 79 | heads (int, optional): Number of multi-head-attentions. 80 | (default: :obj:`1`) 81 | concat (bool, optional): If set to :obj:`False`, the multi-head 82 | attentions are averaged instead of concatenated. 83 | (default: :obj:`True`) 84 | negative_slope (float, optional): LeakyReLU angle of the negative 85 | slope. (default: :obj:`0.2`) 86 | dropout (float, optional): Dropout probability of the normalized 87 | attention coefficients which exposes each node to a stochastically 88 | sampled neighborhood during training. (default: :obj:`0`) 89 | add_self_loops (bool, optional): If set to :obj:`False`, will not add 90 | self-loops to the input graph. (default: :obj:`True`) 91 | edge_dim (int, optional): Edge feature dimensionality (in case 92 | there are any). (default: :obj:`None`) 93 | fill_value (float or torch.Tensor or str, optional): The way to 94 | generate edge features of self-loops (in case 95 | :obj:`edge_dim != None`). 
96 | If given as :obj:`float` or :class:`torch.Tensor`, edge features of 97 | self-loops will be directly given by :obj:`fill_value`. 98 | If given as :obj:`str`, edge features of self-loops are computed by 99 | aggregating all features of edges that point to the specific node, 100 | according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, 101 | :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`) 102 | bias (bool, optional): If set to :obj:`False`, the layer will not learn 103 | an additive bias. (default: :obj:`True`) 104 | **kwargs (optional): Additional arguments of 105 | :class:`torch_geometric.nn.conv.MessagePassing`. 106 | 107 | Shapes: 108 | - **input:** 109 | node features :math:`(|\mathcal{V}|, F_{in})` or 110 | :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` 111 | if bipartite, 112 | edge indices :math:`(2, |\mathcal{E}|)`, 113 | edge features :math:`(|\mathcal{E}|, D)` *(optional)* 114 | - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or 115 | :math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite. 
116 | If :obj:`return_attention_weights=True`, then 117 | :math:`((|\mathcal{V}|, H * F_{out}), 118 | ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` 119 | or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), 120 | (|\mathcal{E}|, H)))` if bipartite 121 | """ 122 | def __init__( 123 | self, 124 | in_channels: Union[int, Tuple[int, int]], 125 | out_channels: int, 126 | heads: int = 1, 127 | concat: bool = True, 128 | negative_slope: float = 0.2, 129 | dropout: float = 0.0, 130 | add_self_loops: bool = True, 131 | edge_dim: Optional[int] = None, 132 | fill_value: Union[float, Tensor, str] = 'mean', 133 | bias: bool = True, 134 | **kwargs, 135 | ): 136 | kwargs.setdefault('aggr', 'add') 137 | super().__init__(node_dim=0, **kwargs) 138 | 139 | self.in_channels = in_channels 140 | self.out_channels = out_channels 141 | self.heads = heads 142 | self.concat = concat 143 | self.negative_slope = negative_slope 144 | self.dropout = dropout 145 | self.add_self_loops = add_self_loops 146 | self.edge_dim = edge_dim 147 | self.fill_value = fill_value 148 | 149 | # In case we are operating in bipartite graphs, we apply separate 150 | # transformations 'lin_src' and 'lin_dst' to source and target nodes: 151 | if isinstance(in_channels, int): 152 | self.lin_src = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot') 153 | self.lin_dst = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot') 154 | # self.lin_dst = self.lin_src 155 | else: 156 | self.lin_src = Linear(in_channels[0], heads * out_channels, False, weight_initializer='glorot') 157 | self.lin_dst = Linear(in_channels[1], heads * out_channels, False, weight_initializer='glorot') 158 | 159 | 160 | if edge_dim is not None: 161 | self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False, weight_initializer='glorot') 162 | self.att_edge = Parameter(torch.Tensor(1, heads, out_channels)) 163 | else: 164 | self.lin_edge = None 165 | 
self.register_parameter('att_edge', None) 166 | 167 | self.reset_parameters() 168 | 169 | def reset_parameters(self): 170 | # super().reset_parameters() 171 | self.lin_src.reset_parameters() 172 | self.lin_dst.reset_parameters() 173 | if self.lin_edge is not None: 174 | self.lin_edge.reset_parameters() 175 | glorot(self.att_edge) 176 | 177 | 178 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj, 179 | edge_attr: OptTensor = None, size: Size = None, 180 | return_attention_weights=None): 181 | # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor # noqa 182 | # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor # noqa 183 | # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa 184 | # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa 185 | r"""Runs the forward pass of the module. 186 | 187 | Args: 188 | return_attention_weights (bool, optional): If set to :obj:`True`, 189 | will additionally return the tuple 190 | :obj:`(edge_index, attention_weights)`, holding the computed 191 | attention weights for each edge. (default: :obj:`None`) 192 | """ 193 | # NOTE: attention weights will be returned whenever 194 | # `return_attention_weights` is set to a value, regardless of its 195 | # actual value (might be `True` or `False`). This is a current somewhat 196 | # hacky workaround to allow for TorchScript support via the 197 | # `torch.jit._overload` decorator, as we can only change the output 198 | # arguments conditioned on type (`None` or `bool`), not based on its 199 | # actual value. 200 | 201 | H, C = self.heads, self.out_channels 202 | 203 | # We first transform the input node features. 
If a tuple is passed, we 204 | # transform source and target node features via separate weights: 205 | if isinstance(x, Tensor): 206 | assert x.dim() == 2, "Static graphs not supported in 'GATConv'" 207 | # x_src = x_dst = self.lin_src(x).view(-1, H, C) 208 | x_src = x_dst = x.view(-1, H, C) 209 | else: # Tuple of source and target node features: 210 | x_src, x_dst = x 211 | assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" 212 | x_src = self.lin_src(x_src).view(-1, H, C) 213 | if x_dst is not None: 214 | x_dst = self.lin_dst(x_dst).view(-1, H, C) 215 | 216 | x = (x_src, x_dst) 217 | 218 | # Next, we compute node-level attention coefficients, both for source 219 | # and target nodes (if present): 220 | 221 | if self.add_self_loops: 222 | if isinstance(edge_index, Tensor): 223 | # We only want to add self-loops for nodes that appear both as 224 | # source and target nodes: 225 | num_nodes = x_src.size(0) 226 | if x_dst is not None: 227 | num_nodes = min(num_nodes, x_dst.size(0)) 228 | num_nodes = min(size) if size is not None else num_nodes 229 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) 230 | edge_index, edge_attr = add_self_loops(edge_index, edge_attr, fill_value=self.fill_value, num_nodes=num_nodes) 231 | elif isinstance(edge_index, SparseTensor): 232 | if self.edge_dim is None: 233 | edge_index = torch_sparse.set_diag(edge_index) 234 | else: 235 | raise NotImplementedError( 236 | "The usage of 'edge_attr' and 'add_self_loops' " 237 | "simultaneously is currently not yet supported for " 238 | "'edge_index' in a 'SparseTensor' form") 239 | 240 | # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor) 241 | alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr) 242 | 243 | # propagate_type: (x: OptPairTensor, alpha: Tensor) 244 | out = self.propagate(edge_index, x=x, alpha=alpha, size=size) 245 | 246 | if self.concat: 247 | out = out.view(-1, self.heads * self.out_channels) 248 | else: 249 | out = 
out.mean(dim=1) 250 | 251 | if isinstance(return_attention_weights, bool): 252 | if isinstance(edge_index, Tensor): 253 | return out, (edge_index, alpha) 254 | elif isinstance(edge_index, SparseTensor): 255 | return out, edge_index.set_value(alpha, layout='coo') 256 | else: 257 | return out 258 | 259 | 260 | def edge_update(self, x_j: Tensor, x_i: OptTensor, edge_attr: OptTensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor: 261 | # Given edge-level attention coefficients for source and target nodes, 262 | # we simply need to sum them up to "emulate" concatenation: 263 | 264 | _x_i, _x_j = self.lin_dst(x_i), self.lin_src(x_j) 265 | alpha = torch.bmm(_x_i, _x_j.transpose(1, 2)).squeeze(-1) 266 | 267 | if edge_attr is not None and self.lin_edge is not None: 268 | if edge_attr.dim() == 1: 269 | edge_attr = edge_attr.view(-1, 1) 270 | edge_attr = self.lin_edge(edge_attr) 271 | edge_attr = edge_attr.view(-1, self.heads, self.out_channels) 272 | alpha_edge = (edge_attr * self.att_edge).sum(dim=-1) 273 | alpha = alpha + alpha_edge 274 | 275 | alpha = F.leaky_relu(alpha, self.negative_slope) 276 | 277 | 278 | alpha = softmax(alpha, index, ptr, size_i) 279 | alpha = F.dropout(alpha, p=self.dropout, training=self.training) 280 | return alpha 281 | 282 | def message(self, x_j: Tensor, alpha: Tensor) -> Tensor: 283 | return alpha.unsqueeze(-1) * x_j 284 | 285 | def __repr__(self) -> str: 286 | return (f'{self.__class__.__name__}({self.in_channels}, ' 287 | f'{self.out_channels}, heads={self.heads})') -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html 2 | pip install clip 3 | pip install jsonlines==3.0.0 4 | pip install numpy==1.22.0 5 | pip install pandas==1.4.1 6 | pip install Pillow==9.3.0 7 | pip install 
scikit-learn==1.0.2 8 | pip install scipy==1.8.0 9 | pip install sentencepiece==0.1.96 10 | pip install tokenizers==0.11.6 11 | pip install tqdm==4.63.1 12 | pip install transformers==4.17.0 13 | pip install hostlist==1.4.8 14 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_cluster-1.5.9-cp38-cp38-linux_x86_64.whl 15 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_scatter-2.0.8-cp38-cp38-linux_x86_64.whl 16 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_sparse-0.6.11-cp38-cp38-linux_x86_64.whl 17 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_spline_conv-1.2.1-cp38-cp38-linux_x86_64.whl 18 | pip install torch-geometric==2.2.0 -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import numpy as np 6 | import random 7 | import json 8 | import math 9 | import sys 10 | from typing import Iterable 11 | import argparse 12 | import time 13 | import datetime 14 | from util import dist 15 | from torch.utils.data import DataLoader, DistributedSampler 16 | from collections import namedtuple 17 | from functools import reduce 18 | import pickle 19 | 20 | from dataset import VideoQA_Dataset, videoqa_collate_fn 21 | from args import get_args_parser 22 | from util.misc import get_mask, adjust_learning_rate 23 | from util.metrics import MetricLogger 24 | 25 | from transformers import DebertaV2Tokenizer 26 | from model import DebertaV2ForMaskedLM 27 | 28 | def train_one_epoch(model: torch.nn.Module, tokenizer, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int, 29 | dataset_name, args, max_norm: float = 0): 30 | model.train() 31 | edge_index = data_loader.dataset.edge_index.to(device) 32 | vocab_embeddings = data_loader.dataset.vocab_embeddings.to(device) 33 | 
eps = data_loader.dataset.eps[:, None].to(device) 34 | 35 | metric_logger = MetricLogger(delimiter=" ") 36 | header = "Epoch: [{}]".format(epoch) 37 | num_training_steps = int(len(data_loader) * args.epochs) 38 | args.print_freq = int(len(data_loader) / 4) 39 | for i_batch, batch_dict in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 40 | video = batch_dict["video"].to(device) 41 | video_len = batch_dict["video_len"] 42 | video_mask = get_mask(video_len, video.size(1)).to(device) 43 | text = batch_dict["text"] 44 | encoded = tokenizer(text, add_special_tokens=True, max_length=args.max_tokens, padding="longest", truncation=True, return_tensors="pt") 45 | 46 | inputs = encoded["input_ids"].to(device) 47 | attention_mask = encoded["attention_mask"].to(device) 48 | 49 | # forward 50 | answer_id = batch_dict["answer_id"].to(device) 51 | output = model(video=video, video_mask=video_mask, input_ids=inputs, attention_mask=attention_mask, edge_index=edge_index, vocab_embeddings=vocab_embeddings, eps=eps) 52 | delay = args.max_feats if args.use_video else 0 53 | logits = output['logits'] 54 | logits = logits[:, delay:encoded["input_ids"].size(1) + delay][encoded["input_ids"] == tokenizer.mask_token_id] 55 | 56 | if dataset_name == "ivqa": 57 | a = (answer_id / 2).clamp(max=1) 58 | nll = -F.log_softmax(logits, 1, _stacklevel=5) 59 | loss = (nll * a / a.sum(1, keepdim=True).clamp(min=1)).sum(dim=1).mean() 60 | elif dataset_name == "vqa": 61 | a = (answer_id / 3).clamp(max=1) 62 | nll = -F.log_softmax(logits, 1, _stacklevel=5) 63 | loss = (nll * a / a.sum(1, keepdim=True).clamp(min=1)).sum(dim=1).mean() 64 | else: 65 | loss = F.cross_entropy(logits, answer_id) 66 | 67 | loss_dict = {"cls_loss": loss} 68 | 69 | # reduce losses over all GPUs for logging purposes 70 | loss_dict_reduced = dist.reduce_dict(loss_dict) 71 | loss_reduced = sum(loss_dict_reduced.values()) 72 | loss_value = loss_reduced.item() 73 | 74 | if not math.isfinite(loss_value): 75 | 
print("Loss is {}, stopping training".format(loss_value)) 76 | print(loss_dict_reduced) 77 | sys.exit(1) 78 | 79 | optimizer.zero_grad() 80 | loss.backward() 81 | if max_norm > 0: 82 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 83 | optimizer.step() 84 | 85 | adjust_learning_rate(optimizer, curr_step=epoch * len(data_loader) + i_batch, num_training_steps=num_training_steps, args=args) 86 | 87 | metric_logger.update(loss=loss_value) 88 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 89 | # gather the stats from all processes 90 | metric_logger.synchronize_between_processes() 91 | print("Averaged stats:", metric_logger) 92 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 93 | 94 | 95 | @torch.no_grad() 96 | def evaluate(model: torch.nn.Module, tokenizer, data_loader, device: torch.device, dataset_name, args, thresholds=[1, 10], split="test", epoch=-1): 97 | model.eval() 98 | ans2cat = data_loader.dataset.ans2cat 99 | class_tensor = torch.zeros((len(data_loader.dataset.ans2id), 2), dtype=torch.float64, device="cuda") 100 | edge_index = data_loader.dataset.edge_index.to(device) 101 | vocab_embeddings = data_loader.dataset.vocab_embeddings.to(device) 102 | eps = data_loader.dataset.eps[:, None].to(device) 103 | 104 | metric_logger = MetricLogger(delimiter=" ") 105 | metric_logger.update(n=0, base=0) 106 | metric_logger.update(n=0, common=0) 107 | metric_logger.update(n=0, rare=0) 108 | metric_logger.update(n=0, unseen=0) 109 | metric_logger.update(n=0, total=0) 110 | header = f"{split}:" 111 | 112 | args.print_freq = int(len(data_loader) / 4) 113 | for i_batch, batch_dict in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)): 114 | video = batch_dict["video"].to(device) 115 | video_len = batch_dict["video_len"] 116 | video_mask = get_mask(video_len, video.size(1)).to(device) 117 | text = batch_dict["text"] 118 | encoded = tokenizer(text, add_special_tokens=True, max_length=args.max_tokens, 
padding="longest", truncation=True, return_tensors="pt") 119 | input_ids = encoded["input_ids"].to(device) 120 | attention_mask = encoded["attention_mask"].to(device) 121 | if not args.suffix and not args.use_context: # remove sep token if not using the suffix 122 | attention_mask[input_ids == tokenizer.sep_token_id] = 0 123 | input_ids[input_ids == tokenizer.sep_token_id] = tokenizer.pad_token_id 124 | 125 | 126 | answer_id, qids = batch_dict["answer_id"].to(device), batch_dict["qid"] 127 | output = model(video=video, video_mask=video_mask, input_ids=input_ids, attention_mask=attention_mask, edge_index=edge_index, vocab_embeddings=vocab_embeddings, eps=eps) 128 | logits = output["logits"] 129 | delay = args.max_feats if args.use_video else 0 130 | logits = logits[:, delay:encoded["input_ids"].size(1) + delay][encoded["input_ids"] == tokenizer.mask_token_id] # get the prediction on the mask token 131 | logits = logits.softmax(-1) 132 | 133 | topk_logits, topk_aids = torch.topk(logits, max(thresholds), -1) 134 | 135 | types = batch_dict["type"] 136 | original_answers = batch_dict['original_answer'] 137 | 138 | 139 | for i, (p, ans) in enumerate(zip(answer_id == logits.max(1).indices, original_answers)): 140 | category = ans2cat[ans] 141 | class_tensor[answer_id[i]][0] += p.float().item() 142 | class_tensor[answer_id[i]][1] += 1. 
143 | if category == 'base': 144 | metric_logger.update(n=1, base=p.float().item()) 145 | elif category == 'common': 146 | metric_logger.update(n=1, common=p.float().item()) 147 | elif category == 'rare': 148 | metric_logger.update(n=1, rare=p.float().item()) 149 | elif category == 'unseen': 150 | metric_logger.update(n=1, unseen=p.float().item()) 151 | metric_logger.update(n=1, total=p.float().item()) 152 | 153 | 154 | torch.distributed.barrier() 155 | torch.distributed.all_reduce(class_tensor) 156 | macc = (class_tensor[:, 0] / class_tensor[:, 1]).mean().item() 157 | metric_logger.synchronize_between_processes() 158 | metric_logger.update(n=1, macc=macc) 159 | print("Averaged stats:", metric_logger) 160 | 161 | results = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 162 | return results 163 | 164 | 165 | def main(args): 166 | # Init distributed mode 167 | dist.init_distributed_mode(args) 168 | if dist.is_main_process(): 169 | if args.save_dir and not (os.path.isdir(args.save_dir)): 170 | os.makedirs(os.path.join(args.save_dir), exist_ok=True) 171 | print(args) 172 | 173 | device = torch.device(args.device) 174 | 175 | # Fix seeds 176 | seed = args.seed + dist.get_rank() 177 | torch.manual_seed(seed) 178 | np.random.seed(seed) 179 | random.seed(seed) 180 | 181 | # Build model 182 | tokenizer = DebertaV2Tokenizer.from_pretrained(args.model_name, local_files_only=True) 183 | 184 | dataset_test = VideoQA_Dataset(args, tokenizer, "test") 185 | sampler_test = DistributedSampler(dataset_test, shuffle=False) if args.distributed else torch.utils.data.SequentialSampler(dataset_test) 186 | dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size_test, sampler=sampler_test, collate_fn=videoqa_collate_fn, num_workers=args.num_workers) 187 | 188 | if not args.eval: 189 | dataset_train = VideoQA_Dataset(args, tokenizer, 'train') 190 | sampler_train = DistributedSampler(dataset_train) if args.distributed else 
torch.utils.data.RandomSampler(dataset_train) 191 | dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, sampler=sampler_train, collate_fn=videoqa_collate_fn, num_workers=args.num_workers) 192 | 193 | args.n_ans = len(dataloader_test.dataset.ans2id) 194 | 195 | model = DebertaV2ForMaskedLM.from_pretrained(features_dim=args.features_dim if args.use_video else 0, max_feats=args.max_feats, freeze_lm=args.freeze_lm, 196 | freeze_mlm=args.freeze_mlm, ft_ln=args.ft_ln, ds_factor_attn=args.ds_factor_attn, ds_factor_ff=args.ds_factor_ff, 197 | dropout=args.dropout, n_ans=args.n_ans, freeze_last=args.freeze_last, pretrained_model_name_or_path=args.model_name, 198 | local_files_only=True, args=args) 199 | model.to(device) 200 | 201 | total_parameters = sum(p.numel() for p in model.parameters()) 202 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 203 | print(f'Total params: {total_parameters:,}') 204 | print(f'Trained params: {n_parameters:,}') 205 | 206 | # Set up optimizer 207 | params_for_optimization = list(p for n, p in model.named_parameters() if (p.requires_grad and 'gat' not in n)) 208 | answer_params_for_optimization = list(p for n, p in model.named_parameters() if (p.requires_grad and 'gat' in n)) 209 | optimizer = torch.optim.Adam([{"params": params_for_optimization, "lr": args.lr}, {"params": answer_params_for_optimization, "lr": args.lr}], 210 | lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay) 211 | 212 | # Load pretrained checkpoint 213 | if args.load: 214 | print("loading from", args.load) 215 | checkpoint = torch.load(args.load, map_location="cpu") 216 | if 'model' in checkpoint.keys(): 217 | model.load_state_dict(checkpoint['model'], strict=False) 218 | else: 219 | model.load_state_dict(checkpoint, strict=False) 220 | 221 | if args.resume and not args.eval: 222 | optimizer.load_state_dict(checkpoint["optimizer"]) 223 | args.start_epoch = checkpoint["epoch"] + 1 224 | 225 | if not 
args.eval: 226 | train_aid2tokid = torch.zeros(len(dataloader_train.dataset.ans2id), args.max_atokens).long() 227 | for a, aid in dataloader_train.dataset.ans2id.items(): 228 | tok = torch.tensor(tokenizer(a, add_special_tokens=False, max_length=args.max_atokens, truncation=True, padding="max_length")["input_ids"], dtype=torch.long) 229 | train_aid2tokid[aid] = tok 230 | 231 | print(f'Training Vocab : {len(train_aid2tokid)}') 232 | print(f'Training Samples : {len(dataloader_train.dataset)}') 233 | test_aid2tokid = torch.zeros(len(dataloader_test.dataset.ans2id), args.max_atokens).long() 234 | for a, aid in dataloader_test.dataset.ans2id.items(): 235 | tok = torch.tensor(tokenizer(a, add_special_tokens=False, max_length=args.max_atokens, truncation=True, padding="max_length")["input_ids"], dtype=torch.long) 236 | test_aid2tokid[aid] = tok 237 | print(f'Test Vocab : {len(test_aid2tokid)}') 238 | print(f'Test Samples : {len(dataloader_test.dataset)}') 239 | 240 | if not args.eval: 241 | print("Start training") 242 | start_time = time.time() 243 | best_epoch = args.start_epoch 244 | best_acc = 0 245 | for epoch in range(args.start_epoch, args.epochs): 246 | print(f"Starting epoch {epoch}") 247 | if args.distributed: 248 | sampler_train.set_epoch(epoch) 249 | 250 | model.set_answer_embeddings(train_aid2tokid.to(model.device), freeze_last=args.freeze_last) 251 | train_stats = train_one_epoch(model=model, tokenizer=tokenizer, data_loader=dataloader_train, optimizer=optimizer, device=device, epoch=epoch, 252 | dataset_name=args.dataset, args=args, max_norm=args.clip_max_norm) 253 | 254 | if (epoch + 1) % args.eval_skip == 0: 255 | print(f"Validating {args.dataset}") 256 | val_stats = {} 257 | model.set_answer_embeddings(test_aid2tokid.to(model.device), freeze_last=args.freeze_last) 258 | results = evaluate(model=model, tokenizer=tokenizer, data_loader=dataloader_test, device=device, dataset_name=args.dataset, 259 | args=args, split="val", epoch=epoch) 260 | 
val_stats.update({args.dataset + "_" + k: v for k, v in results.items()}) 261 | 262 | if results["total"] > best_acc: 263 | best_epoch = epoch 264 | best_acc = results["total"] 265 | if dist.is_main_process() and args.save_dir: 266 | checkpoint_path = os.path.join(args.save_dir, f"best_model.pth") 267 | dist.save_on_master({"model": model, "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args}, checkpoint_path) 268 | json.dump({"acc": best_acc, "ep": epoch}, open(os.path.join(args.save_dir, args.dataset + "acc_val.json"), "w")) 269 | else: 270 | val_stats = {} 271 | 272 | log_stats = {**{f"train_{k}": v for k, v in train_stats.items()}, **{f"val_{k}": v for k, v in val_stats.items()}, "epoch": epoch, "n_parameters": n_parameters} 273 | 274 | if args.save_dir and dist.is_main_process(): 275 | with open(os.path.join(args.save_dir, "log.txt"), "a") as f: 276 | f.write(json.dumps(log_stats) + "\n") 277 | checkpoint_path = os.path.join(args.save_dir, f"ckpt.pth") 278 | dist.save_on_master({"model": model, "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args}, checkpoint_path) 279 | 280 | total_time = time.time() - start_time 281 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 282 | print("Training time {}".format(total_time_str)) 283 | # load best ckpt 284 | if dist.is_main_process() and args.save_dir: 285 | print(f"loading best checkpoint from epoch {best_epoch}") 286 | if args.save_dir: 287 | torch.distributed.barrier() # wait all processes 288 | checkpoint = torch.load(os.path.join(args.save_dir, f"best_model.pth"), map_location="cpu") 289 | model.load_state_dict(checkpoint["model"], strict=False) 290 | 291 | model.set_answer_embeddings(test_aid2tokid.to(model.device), freeze_last=args.freeze_last) 292 | results = evaluate(model=model, tokenizer=tokenizer, data_loader=dataloader_test, device=device, dataset_name=args.dataset, 293 | args=args, split="val" if (args.eval and not args.test) else "test") 294 | 295 | if 
args.save_dir and dist.is_main_process(): 296 | json.dump(results, open(os.path.join(args.save_dir, args.dataset + ".json"), "w")) 297 | 298 | 299 | 300 | if __name__ == "__main__": 301 | parser = argparse.ArgumentParser(parents=[get_args_parser()]) 302 | args = parser.parse_args() 303 | 304 | if args.save_dir: 305 | args.save_dir = os.path.join(args.presave_dir, args.save_dir) 306 | args.model_name = os.path.join('./pretrained', args.model_name) 307 | main(args) 308 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mlvlab/OVQA/8f9fb810c11f348eb3ef6bdd4cc0d2c313299d5c/util/__init__.py -------------------------------------------------------------------------------- /util/dist.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import io 3 | import os 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import hostlist 8 | 9 | _LOCAL_PROCESS_GROUP = None 10 | 11 | 12 | @functools.lru_cache() 13 | def _get_global_gloo_group(): 14 | """ 15 | Return a process group based on gloo backend, containing all the ranks 16 | The result is cached. 
17 | """ 18 | 19 | if dist.get_backend() == "nccl": 20 | return dist.new_group(backend="gloo") 21 | 22 | return dist.group.WORLD 23 | 24 | 25 | def all_gather(data): 26 | """ 27 | Run all_gather on arbitrary picklable data (not necessarily tensors) 28 | Args: 29 | data: any picklable object 30 | Returns: 31 | list[data]: list of data gathered from each rank 32 | """ 33 | 34 | world_size = get_world_size() 35 | if world_size == 1: 36 | return [data] 37 | 38 | cpu_group = None 39 | if os.getenv("MDETR_CPU_REDUCE") == "1": 40 | cpu_group = _get_global_gloo_group() 41 | 42 | buffer = io.BytesIO() 43 | torch.save(data, buffer) 44 | data_view = buffer.getbuffer() 45 | device = "cuda" if cpu_group is None else "cpu" 46 | tensor = torch.ByteTensor(data_view).to(device) 47 | 48 | # obtain Tensor size of each rank 49 | local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long) 50 | size_list = [ 51 | torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size) 52 | ] 53 | if cpu_group is None: 54 | dist.all_gather(size_list, local_size) 55 | else: 56 | print("gathering on cpu") 57 | dist.all_gather(size_list, local_size, group=cpu_group) 58 | size_list = [int(size.item()) for size in size_list] 59 | max_size = max(size_list) 60 | assert isinstance(local_size.item(), int) 61 | local_size = int(local_size.item()) 62 | 63 | # receiving Tensor from all ranks 64 | # we pad the tensor because torch all_gather does not support 65 | # gathering tensors of different shapes 66 | tensor_list = [] 67 | for _ in size_list: 68 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device)) 69 | if local_size != max_size: 70 | padding = torch.empty( 71 | size=(max_size - local_size,), dtype=torch.uint8, device=device 72 | ) 73 | tensor = torch.cat((tensor, padding), dim=0) 74 | if cpu_group is None: 75 | dist.all_gather(tensor_list, tensor) 76 | else: 77 | dist.all_gather(tensor_list, tensor, group=cpu_group) 78 | 79 | data_list = [] 80 | 
for size, tensor in zip(size_list, tensor_list): 81 | tensor = torch.split(tensor, [size, max_size - size], dim=0)[0] 82 | buffer = io.BytesIO(tensor.cpu().numpy()) 83 | obj = torch.load(buffer) 84 | data_list.append(obj) 85 | 86 | return data_list 87 | 88 | 89 | def reduce_dict(input_dict, average=True): 90 | """ 91 | Args: 92 | input_dict (dict): all the values will be reduced 93 | average (bool): whether to do average or sum 94 | Reduce the values in the dictionary from all processes so that all processes 95 | have the averaged results. Returns a dict with the same fields as 96 | input_dict, after reduction. 97 | """ 98 | world_size = get_world_size() 99 | if world_size < 2: 100 | return input_dict 101 | with torch.no_grad(): 102 | names = [] 103 | values = [] 104 | # sort the keys so that they are consistent across processes 105 | for k in sorted(input_dict.keys()): 106 | names.append(k) 107 | values.append(input_dict[k]) 108 | values = torch.stack(values, dim=0) 109 | dist.all_reduce(values) 110 | if average: 111 | values /= world_size 112 | reduced_dict = {k: v for k, v in zip(names, values)} 113 | return reduced_dict 114 | 115 | 116 | def setup_for_distributed(is_master): 117 | """ 118 | This function disables printing when not in master process 119 | """ 120 | import builtins as __builtin__ 121 | 122 | builtin_print = __builtin__.print 123 | 124 | def print(*args, **kwargs): 125 | force = kwargs.pop("force", False) 126 | if is_master or force: 127 | builtin_print(*args, **kwargs) 128 | 129 | __builtin__.print = print 130 | 131 | 132 | def is_dist_avail_and_initialized(): 133 | """ 134 | Returns: 135 | True if distributed training is enabled 136 | """ 137 | if not dist.is_available(): 138 | return False 139 | if not dist.is_initialized(): 140 | return False 141 | return True 142 | 143 | 144 | def get_world_size(): 145 | """ 146 | Returns: 147 | The number of processes in the process group 148 | """ 149 | if not is_dist_avail_and_initialized(): 150 | return 1 
151 | return dist.get_world_size() 152 | 153 | 154 | def get_rank(): 155 | """ 156 | Returns: 157 | The rank of the current process within the global process group. 158 | """ 159 | if not is_dist_avail_and_initialized(): 160 | return 0 161 | return dist.get_rank() 162 | 163 | 164 | def get_local_rank() -> int: 165 | """ 166 | Returns: 167 | The rank of the current process within the local (per-machine) process group. 168 | """ 169 | if not dist.is_available(): 170 | return 0 171 | if not dist.is_initialized(): 172 | return 0 173 | assert _LOCAL_PROCESS_GROUP is not None 174 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP) 175 | 176 | 177 | def get_local_size() -> int: 178 | """ 179 | Returns: 180 | The size of the per-machine process group, 181 | i.e. the number of processes per machine. 182 | """ 183 | if not dist.is_available(): 184 | return 1 185 | if not dist.is_initialized(): 186 | return 1 187 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) 188 | 189 | 190 | def is_main_process(): 191 | """Return true if the current process is the main one""" 192 | return get_rank() == 0 193 | 194 | 195 | def save_on_master(state_dict, checkpoint_path): 196 | """Utility function to save only from the main process""" 197 | model_without_deberta = {} 198 | for n, p in state_dict['model'].named_parameters(): 199 | if p.requires_grad: 200 | model_without_deberta[n] = p 201 | state_dict['model'] = model_without_deberta 202 | if is_main_process(): 203 | torch.save(state_dict, checkpoint_path) 204 | 205 | 206 | def init_distributed_mode(args): 207 | """Initialize distributed training, if appropriate""" 208 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ: 209 | args.rank = int(os.environ["RANK"]) 210 | args.world_size = int(os.environ["WORLD_SIZE"]) 211 | args.gpu = int(os.environ["LOCAL_RANK"]) 212 | elif "SLURM_PROCID" in os.environ: 213 | args.rank = int(os.environ["SLURM_PROCID"]) 214 | args.gpu = args.rank % torch.cuda.device_count() 215 | # CLUSTER SPECIFIC 
216 | args.world_size = int(os.environ["SLURM_NTASKS"]) 217 | hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"]) 218 | gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",") 219 | os.environ["MASTER_ADDR"] = hostnames[0] 220 | os.environ["MASTER_PORT"] = str(12345 + int(min(gpu_ids))) # to avoid port conflict on the same node 221 | else: 222 | print("Not using distributed mode") 223 | args.distributed = False 224 | return 225 | 226 | args.distributed = True 227 | 228 | torch.cuda.set_device(args.gpu) 229 | args.dist_backend = "nccl" 230 | print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) 231 | 232 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank) 233 | dist.barrier() 234 | setup_for_distributed(args.rank == 0) 235 | -------------------------------------------------------------------------------- /util/metrics.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import time 3 | from collections import defaultdict, deque 4 | 5 | import torch 6 | import torch.distributed as dist 7 | 8 | from util.dist import is_dist_avail_and_initialized 9 | 10 | 11 | class SmoothedValue: 12 | """Track a series of values and provide access to smoothed values over a 13 | window or the global series average. 14 | """ 15 | 16 | def __init__(self, window_size=20, fmt=None): 17 | if fmt is None: 18 | fmt = "{median:.4f} ({global_avg:.4f})" 19 | self.deque = deque(maxlen=window_size) 20 | self.total = 0.0 21 | self.count = 0 22 | self.fmt = fmt 23 | 24 | def update(self, value, num=1): 25 | self.deque.append(value) 26 | self.count += num 27 | self.total += value * num 28 | 29 | def synchronize_between_processes(self): 30 | """ 31 | Distributed synchronization of the metric 32 | Warning: does not synchronize the deque! 
33 | """ 34 | if not is_dist_avail_and_initialized(): 35 | return 36 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 37 | dist.barrier() 38 | dist.all_reduce(t) 39 | t = t.tolist() 40 | self.count = int(t[0]) 41 | self.total = t[1] 42 | 43 | @property 44 | def median(self): 45 | d = torch.tensor(list(self.deque)) 46 | return d.median().item() 47 | 48 | @property 49 | def avg(self): 50 | d = torch.tensor(list(self.deque), dtype=torch.float32) 51 | return d.mean().item() 52 | 53 | @property 54 | def global_avg(self): 55 | if self.count == 0: 56 | return 0. 57 | else: 58 | return self.total / self.count 59 | 60 | @property 61 | def max(self): 62 | return max(self.deque) 63 | 64 | @property 65 | def value(self): 66 | return self.deque[-1] 67 | 68 | def __str__(self): 69 | return self.fmt.format( 70 | median=self.median, 71 | avg=self.avg, 72 | global_avg=self.global_avg, 73 | max=self.max, 74 | value=self.value, 75 | ) 76 | 77 | 78 | class MetricLogger(object): 79 | def __init__(self, delimiter="\t"): 80 | self.meters = defaultdict(SmoothedValue) 81 | self.delimiter = delimiter 82 | 83 | def update(self, n=1, **kwargs): 84 | for k, v in kwargs.items(): 85 | if isinstance(v, torch.Tensor): 86 | v = v.item() 87 | assert isinstance(v, (float, int)) 88 | self.meters[k].update(v, num=n) 89 | 90 | def __getattr__(self, attr): 91 | if attr in self.meters: 92 | return self.meters[attr] 93 | if attr in self.__dict__: 94 | return self.__dict__[attr] 95 | raise AttributeError( 96 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr) 97 | ) 98 | 99 | def __str__(self): 100 | loss_str = [] 101 | for name, meter in self.meters.items(): 102 | loss_str.append("{}: {}".format(name, str(meter))) 103 | return self.delimiter.join(loss_str) 104 | 105 | def synchronize_between_processes(self): 106 | for meter in self.meters.values(): 107 | meter.synchronize_between_processes() 108 | 109 | def add_meter(self, name, meter): 110 | 
self.meters[name] = meter 111 | 112 | def log_every(self, iterable, print_freq, header=None): 113 | i = 0 114 | if not header: 115 | header = "" 116 | start_time = time.time() 117 | end = time.time() 118 | iter_time = SmoothedValue(fmt="{avg:.4f}") 119 | data_time = SmoothedValue(fmt="{avg:.4f}") 120 | space_fmt = ":" + str(len(str(len(iterable)))) + "d" 121 | if torch.cuda.is_available(): 122 | log_msg = self.delimiter.join( 123 | [ 124 | header, 125 | "[{0" + space_fmt + "}/{1}]", 126 | "eta: {eta}", 127 | "{meters}", 128 | "time: {time}", 129 | "data: {data}", 130 | "max mem: {memory:.0f}", 131 | ] 132 | ) 133 | else: 134 | log_msg = self.delimiter.join( 135 | [ 136 | header, 137 | "[{0" + space_fmt + "}/{1}]", 138 | "eta: {eta}", 139 | "{meters}", 140 | "time: {time}", 141 | "data: {data}", 142 | ] 143 | ) 144 | MB = 1024.0 * 1024.0 145 | for obj in iterable: 146 | data_time.update(time.time() - end) 147 | yield obj 148 | iter_time.update(time.time() - end) 149 | if i % print_freq == 0 or i == len(iterable) - 1: 150 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 152 | if torch.cuda.is_available(): 153 | print( 154 | log_msg.format( 155 | i, 156 | len(iterable), 157 | eta=eta_string, 158 | meters=str(self), 159 | time=str(iter_time), 160 | data=str(data_time), 161 | memory=torch.cuda.max_memory_allocated() / MB, 162 | ) 163 | ) 164 | else: 165 | print( 166 | log_msg.format( 167 | i, 168 | len(iterable), 169 | eta=eta_string, 170 | meters=str(self), 171 | time=str(iter_time), 172 | data=str(data_time), 173 | ) 174 | ) 175 | i += 1 176 | end = time.time() 177 | total_time = time.time() - start_time 178 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 179 | print( 180 | "{} Total time: {} ({:.4f} s / it)".format( 181 | header, total_time_str, total_time / len(iterable) 182 | ) 183 | ) 184 | 
-------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Tuple, List 3 | import random 4 | 5 | 6 | def get_mask(lengths, max_length): 7 | """Computes a batch of padding masks given batched lengths""" 8 | mask = 1 * (torch.arange(max_length).unsqueeze(1).to(lengths.device) < lengths).transpose(0, 1) 9 | return mask 10 | 11 | 12 | def mask_tokens(inputs, tokenizer, mlm_probability): 13 | """ 14 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original. 15 | """ 16 | 17 | if tokenizer.mask_token is None: 18 | raise ValueError("This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer.") 19 | 20 | labels = inputs.clone() 21 | # We sample a few tokens in each sequence for masked-LM training (with probability args.mlm_probability defaults to 0.15 in Bert/RoBERTa) 22 | probability_matrix = torch.full(labels.shape, mlm_probability) 23 | special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()] 24 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0) 25 | if tokenizer._pad_token is not None: 26 | padding_mask = labels.eq(tokenizer.pad_token_id) 27 | probability_matrix.masked_fill_(padding_mask, value=0.0) 28 | masked_indices = torch.bernoulli(probability_matrix).bool() 29 | labels[~masked_indices] = -100 # We only compute loss on masked tokens 30 | 31 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 32 | indices_replaced = (torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices) 33 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token) 34 | 35 | # 10% of the time, we replace masked input tokens 
with random word 36 | indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced) 37 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long) 38 | inputs[indices_random] = random_words[indices_random] 39 | 40 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 41 | return inputs, labels 42 | 43 | 44 | def adjust_learning_rate(optimizer, curr_step: int, num_training_steps: int, args): 45 | num_warmup_steps: int = round(args.fraction_warmup_steps * num_training_steps) 46 | if args.schedule == "linear_with_warmup": 47 | if curr_step < num_warmup_steps: 48 | gamma = float(curr_step) / float(max(1, num_warmup_steps)) 49 | else: 50 | gamma = max(0.0, float(num_training_steps - curr_step) / float(max(1, num_training_steps - num_warmup_steps))) 51 | else: # constant LR 52 | gamma = 1 53 | 54 | optimizer.param_groups[0]["lr"] = args.lr * gamma 55 | --------------------------------------------------------------------------------