├── README.md
├── args.py
├── asset
│   ├── cvqa.png
│   └── ovqa.png
├── dataset.py
├── model
│   ├── __init__.py
│   ├── adapter.py
│   ├── deberta.py
│   └── gnn.py
├── setup.sh
├── train.py
└── util
    ├── __init__.py
    ├── dist.py
    ├── metrics.py
    └── misc.py
/README.md:
--------------------------------------------------------------------------------
1 | # Open-Vocabulary Video Question Answering: A New Benchmark for Evaluating the Generalizability of Video Question Answering Models
2 |
3 | This is the official implementation of OVQA (ICCV 2023). ([arxiv](https://arxiv.org/abs/2308.09363))
4 |
5 | > Dohwan Ko, Ji Soo Lee, Miso Choi, Jaewon Chu, Jihwan Park, Hyunwoo J. Kim.
6 | >
7 | > Department of Computer Science and Engineering, Korea University
8 |
9 |
10 | ![Closed-vocabulary VideoQA](asset/cvqa.png) ![Open-vocabulary VideoQA](asset/ovqa.png)
11 |
12 |
13 |
14 | (a) Closed-vocabulary Video Question Answering (b) Open-vocabulary Video Question Answering (Ours)
15 |
16 |
17 |
18 | ## Setup
19 | To install requirements, run:
20 | ```
21 | conda create -n ovqa python=3.8
22 | conda activate ovqa
23 | sh setup.sh
24 | ```
25 | ## Data Preparation
26 | ### Download preprocessed data, visual features
27 | The pretrained checkpoint, preprocessed data, and data annotations are provided [here](https://drive.google.com/drive/folders/1HcrMsINkNcRUfnZMd15l9AH0R9X4qvQG). You can download the pretrained DeBERTa-v2-xlarge model [here](https://huggingface.co/microsoft/deberta-v2-xlarge).
28 |
29 | Then, place the files as follows:
30 |
31 | ```
32 | ./pretrained
33 | |─ pretrained.pth
34 | └─ deberta-v2-xlarge
35 | ./meta_data
36 | |─ activitynet
37 | │ |─ train.csv
38 | │ |─ test.csv
39 | │ |─ train_vocab.json
40 | │ |─ test_vocab.json
41 | │ |─ clipvitl14.pth
42 | │ |─ subtitles.pkl
43 | │ |─ ans2cat.json
44 | │ └─ answer_graph
45 | │ |─ train_edge_index.pth
46 | │ |─ train_x.pth
47 | │ |─ test_edge_index.pth
48 | │ └─ test_x.pth
49 | │
50 | |─ msvd
51 | │ |─ train.csv
52 | │ |─ test.csv
53 | │ |─ train_vocab.json
54 | │ |─ test_vocab.json
55 | │ |─ clipvitl14.pth
56 | │ |─ subtitles.pkl
57 | │ |─ ans2cat.json
58 | │ └─ answer_graph
59 | │ |─ train_edge_index.pth
60 | │ |─ train_x.pth
61 | │ |─ test_edge_index.pth
62 | │ └─ test_x.pth
63 | │ :
64 | ```
65 |
66 |
67 | ## Train FrozenBiLM+ on OVQA
68 |
69 | To train on ActivityNet-QA, MSVD-QA, TGIF-QA, or MSRVTT-QA, run the command below. Modify `--dataset activitynet` to change the dataset.
70 |
71 | ```
72 | python -m torch.distributed.launch --nproc_per_node 4 --use_env train.py --dist-url tcp://127.0.0.1:12345 \
73 | --dataset activitynet --lr 5e-5 --batch_size 8 --batch_size_test 32 --save_dir ./path/to/save/files --epochs 20 --eps 0.7
74 | ```
75 |
76 |
77 | ## Acknowledgements
78 | This repo is built upon [FrozenBiLM](https://github.com/antoyang/FrozenBiLM).
79 |
80 |
81 |
82 | ## Citation
83 | ```
84 | @inproceedings{ko2023open,
85 | title={Open-vocabulary Video Question Answering: A New Benchmark for Evaluating the Generalizability of Video Question Answering Models},
86 | author={Ko, Dohwan and Lee, Ji Soo and Choi, Miso and Chu, Jaewon and Park, Jihwan and Kim, Hyunwoo J},
87 | booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
88 | year={2023}
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/args.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | PRESAVE_DIR = ""
5 | MODEL_DIR = ""
6 | DATA_DIR = ""
7 | SSD_DIR = ""
8 | name2folder = {
9 | "msrvtt": "./meta_data/msrvtt",
10 | "msvd": "./meta_data/msvd",
11 | "activitynet": "./meta_data/activitynet",
12 | "tgif": "./meta_data/tgif",
13 | }
14 |
15 |
16 | def get_args_parser():
17 | parser = argparse.ArgumentParser("Set OVQA", add_help=False)
18 |
19 | # Dataset specific
20 | parser.add_argument("--dataset", help="dataset", type=str)
21 |
22 | # Training hyper-parameters
23 | parser.add_argument("--mlm_prob", type=float, default=0.15, help="masking probability for the MLM objective")
24 | parser.add_argument("--lr", default=3e-4, type=float, help="learning rate")
25 | parser.add_argument("--beta1", default=0.9, type=float, help="Adam optimizer parameter")
26 | parser.add_argument("--beta2", default=0.95, type=float, help="Adam optimizer parameter")
27 | parser.add_argument("--batch_size", default=32, type=int, help="batch size used for training")
28 | parser.add_argument("--batch_size_test", default=32, type=int, help="batch size used for test")
29 | parser.add_argument("--weight_decay", default=0, type=float)
30 | parser.add_argument("--epochs", default=10, type=int, help="number of training epochs")
31 | parser.add_argument("--lr_drop", default=10, type=int, help="number of epochs after which the learning rate is reduced when not using linear decay")
32 | parser.add_argument("--optimizer", default="adam", type=str)
33 | parser.add_argument("--clip_max_norm", default=0.1, type=float, help="gradient clipping max norm")
34 | parser.add_argument("--schedule", default="linear_with_warmup", choices=["", "linear_with_warmup"], help="learning rate decay schedule, default is constant")
35 | parser.add_argument("--fraction_warmup_steps", default=0.1, type=float, help="fraction of number of steps used for warmup when using linear schedule")
36 | parser.add_argument("--eval_skip", default=1, type=int, help='do evaluation every "eval_skip" epochs')
37 | parser.add_argument("--print_freq", type=int, default=100, help="print log every print_freq iterations")
38 |
39 | # Model parameters
40 | parser.add_argument("--ft_lm", dest="freeze_lm", action="store_false", help="whether to finetune the weights of the language model")
41 | parser.add_argument("--model_name", default="deberta-v2-xlarge", choices=("deberta-v2-xlarge"))
42 | parser.add_argument("--ds_factor_attn", type=int, default=8, help="downsampling factor for adapter attn")
43 | parser.add_argument("--ds_factor_ff", type=int, default=8, help="downsampling factor for adapter ff")
44 | parser.add_argument("--freeze_ln", dest="ft_ln", action="store_false", help="whether or not to freeze layer norm parameters")
45 | parser.add_argument("--ft_mlm", dest="freeze_mlm", action="store_false", help="whether or not to finetune the mlm head parameters")
46 | parser.add_argument("--dropout", default=0.1, type=float, help="dropout to use in the adapter")
47 | parser.add_argument("--scratch", action="store_true", help="whether to train the LM with or without language init")
48 | parser.add_argument("--n_ans", type=int, default=0, help="number of answers in the answer embedding module, it is automatically set")
49 | parser.add_argument("--ft_last", dest="freeze_last", action="store_false", help="whether to finetune answer embedding module or not")
50 | parser.add_argument("--eps", default=1.0, type=float, help="strength")
51 |
52 | # Run specific
53 | parser.add_argument("--test", action="store_true", help="whether to run evaluation on val or test set")
54 | parser.add_argument("--save_dir", default="", help="path where to save, empty for no saving")
55 | parser.add_argument("--presave_dir", default=PRESAVE_DIR, help="the actual save_dir is an union of presave_dir and save_dir")
56 | parser.add_argument("--device", default="cuda", help="device to use")
57 | parser.add_argument("--seed", default=42, type=int, help="random seed")
58 | parser.add_argument("--load", default="./pretrained/pretrained.pth", help="path to load checkpoint")
59 | parser.add_argument("--resume", action="store_true", help="continue training if loading checkpoint")
60 | parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="start epoch")
61 | parser.add_argument("--eval", action="store_true", help="only run evaluation")
62 | parser.add_argument("--num_workers", default=3, type=int, help="number of workers for dataloader")
63 |
64 | # Distributed training parameters
65 | parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
66 | parser.add_argument("--dist-url", default="env://", help="url used to set up distributed training")
67 |
68 | # Video and Text parameters
69 | parser.add_argument("--max_feats", type=int, default=10, help="maximum number of video features considered, one per frame")
70 | parser.add_argument("--features_dim", type=int, default=768, help="dimension of the visual embedding space")
71 | parser.add_argument("--no_video", dest="use_video", action="store_false", help="disables usage of video")
72 | parser.add_argument("--no_context", dest="use_context", action="store_false", help="disables usage of speech")
73 | parser.add_argument("--max_tokens", type=int, default=256, help="maximum number of tokens in the input text prompt")
74 | parser.add_argument("--max_atokens", type=int, default=5, help="maximum number of tokens in the answer")
75 | parser.add_argument("--prefix", default="", type=str, help="task induction before question for videoqa")
76 | parser.add_argument("--suffix", default=".", type=str, help="suffix after the answer mask for videoqa")
77 |
78 | return parser
79 |
--------------------------------------------------------------------------------
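A minimal usage sketch for `get_args_parser` above (the real entry point is `train.py`, which is not included here; treating the final output directory as `os.path.join(presave_dir, save_dir)` is an assumption for illustration):

```
# Hypothetical driver sketch -- the real entry point is train.py (not shown).
import argparse
import os

from args import get_args_parser

parser = argparse.ArgumentParser("OVQA training", parents=[get_args_parser()])
args = parser.parse_args(
    ["--dataset", "activitynet", "--lr", "5e-5", "--batch_size", "8", "--epochs", "20", "--eps", "0.7"]
)

# Assumption: the effective output directory combines presave_dir and save_dir.
save_dir = os.path.join(args.presave_dir, args.save_dir)
print(args.dataset, args.lr, args.eps, save_dir)
```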
/asset/cvqa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/OVQA/8f9fb810c11f348eb3ef6bdd4cc0d2c313299d5c/asset/cvqa.png
--------------------------------------------------------------------------------
/asset/ovqa.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/OVQA/8f9fb810c11f348eb3ef6bdd4cc0d2c313299d5c/asset/ovqa.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | from torch.utils.data.dataloader import default_collate
4 | import pandas as pd
5 | import collections
6 | import json
7 | import pickle
8 |
9 |
10 | class VideoQA_Dataset(Dataset):
11 | def __init__(self, args, tokenizer, split):
12 |
13 | path = f'./meta_data/{args.dataset}/'
14 | self.data = pd.read_csv(path + f'{split}.csv')
15 | self.features = torch.load(path + f'clipvitl14.pth')
16 | self.ans2id = json.load(open(path + f'{split}_vocab.json'))
17 | self.ans2cat = json.load(open(path + 'ans2cat.json'))
18 | self.max_feats = args.max_feats
19 | self.features_dim = args.features_dim
20 | self.split = split
21 | self.prefix = args.prefix
22 | self.suffix = args.suffix
23 | self.mask = tokenizer.mask_token
24 | self.pad = tokenizer.pad_token
25 | self.tokenizer = tokenizer
26 | self.use_context = (args.use_context and args.dataset != "tgif")
27 | self.subs = pickle.load(open(path + f'subtitles.pkl', "rb")) if self.use_context else None
28 | self.args = args
29 | self.load_answer_graph()
30 |
31 | def load_answer_graph(self):
32 | self.edge_index = torch.load(f'./meta_data/{self.args.dataset}/answer_graph/{self.split}_edge_index.pth')
33 | self.vocab_embeddings = torch.load(f'./meta_data/{self.args.dataset}/answer_graph/{self.split}_x.pth')
34 | cat2coef = {'base': 1.0, 'common': self.args.eps, 'rare': self.args.eps, 'unseen': self.args.eps}
35 | # cat2coef = {'base': self.args.eps, 'common': self.args.eps, 'rare': self.args.eps, 'unseen': self.args.eps}
36 | self.eps = torch.tensor([cat2coef[self.ans2cat[k]] for k, v in self.ans2id.items()])
37 |
38 | def __len__(self):
39 | return len(self.data)
40 |
41 | def _get_text(self, question, mask, sub):
42 | text = f"{self.prefix} Question: {question} Answer: {mask}{self.suffix}"
43 | if sub:
44 | text += f" Subtitles: {sub}"
45 | text = text.strip()
46 | return text
47 |
48 | def _get_video(self, video_id):
49 | if video_id not in self.features:
50 | print(video_id)
51 | video = torch.zeros(1, self.features_dim)
52 | else:
53 | video = self.features[video_id].float()
54 | if len(video) > self.max_feats:
55 | sampled = []
56 | for j in range(self.max_feats):
57 | sampled.append(video[(j * len(video)) // self.max_feats])
58 | video = torch.stack(sampled)
59 | video_len = self.max_feats
60 | elif len(video) < self.max_feats:
61 | video_len = len(video)
62 | video = torch.cat([video, torch.zeros(self.max_feats - video_len, self.features_dim)], 0)
63 | else:
64 | video_len = self.max_feats
65 |
66 | return video, video_len
67 |
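# Sampling note (illustrative): with max_feats=10 and a clip that has 25 cached
# frame features, _get_video keeps indices j * 25 // 10 for j in 0..9, i.e.
# [0, 2, 5, 7, 10, 12, 15, 17, 20, 22]; clips shorter than max_feats are
# zero-padded and video_len records how many features are real.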
68 | def __getitem__(self, idx):
69 | # get question
70 | question = self.data["question"].values[idx].capitalize().strip()
71 | if question[-1] != "?":
72 | question = str(question) + "?"
73 | type = 0
74 | if "type" in self.data:
75 | type = self.data["type"].values[idx]
76 |
77 |
78 | original_answer = self.data["answer"].values[idx]
79 | answer_id = self.ans2id[original_answer]
80 | video_id = self.data["video_id"].values[idx]
81 |
82 | # get subtitles
83 | sub = ""
84 | if self.subs is not None and video_id in self.subs:
85 | sub = self.subs[video_id]
86 | sub_bool = bool(sub)
87 | if not self.use_context:
88 | sub = ""
89 |
90 | # get pattern
91 | text = self._get_text(question, self.mask, sub)
92 |
93 | # get video
94 | video, video_len = self._get_video(video_id)
95 |
96 | return {"video": video, "video_len": video_len, "text": text, "qid": idx, "answer_id": answer_id,
97 | "type": type, "sub": sub_bool, "original_answer": original_answer}
98 |
99 |
100 | def videoqa_collate_fn(batch):
101 | bs = len(batch)
102 | video = torch.stack([batch[i]["video"] for i in range(bs)])
103 | video_len = torch.tensor([batch[i]["video_len"] for i in range(bs)], dtype=torch.long)
104 | text = [batch[i]["text"] for i in range(bs)] if isinstance(batch[0]["text"], str) else [[batch[i]["text"][j] for i in range(bs)] for j in range(len(batch[0]["text"]))]
105 | qid = [batch[i]["qid"] for i in range(bs)]
106 | answer_id = default_collate([batch[i]["answer_id"] for i in range(bs)])
107 | type = [batch[i]["type"] for i in range(bs)]
108 | original_answer = [batch[i]["original_answer"] for i in range(bs)]
109 | out = {"video": video, "video_len": video_len, "text": text, "qid": qid, "answer_id": answer_id, "type": type, "original_answer": original_answer}
110 | if "sub" in batch[0]:
111 | sub = [batch[i]["sub"] for i in range(bs)]
112 | out["sub"] = sub
113 | return out
--------------------------------------------------------------------------------
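A minimal wiring sketch for `VideoQA_Dataset` and `videoqa_collate_fn`, assuming the DeBERTa tokenizer is loaded from the `./pretrained/deberta-v2-xlarge` directory described in the README; the actual loaders are built in `train.py`, which is not shown:

```
# Hypothetical wiring sketch -- train.py (not shown) builds the real loaders.
from argparse import Namespace

from torch.utils.data import DataLoader
from transformers import DebertaV2Tokenizer

from dataset import VideoQA_Dataset, videoqa_collate_fn

args = Namespace(dataset="activitynet", max_feats=10, features_dim=768,
                 prefix="", suffix=".", use_context=True, eps=0.7)
tokenizer = DebertaV2Tokenizer.from_pretrained("./pretrained/deberta-v2-xlarge")

train_set = VideoQA_Dataset(args, tokenizer, split="train")
train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
                          collate_fn=videoqa_collate_fn, num_workers=3)

batch = next(iter(train_loader))
# batch["video"]: (B, max_feats, features_dim) float tensor of CLIP features;
# batch["text"]: list of prompts containing the tokenizer's mask token where
# the answer is predicted; batch["answer_id"]: tensor of vocabulary indices.
```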
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from .deberta import DebertaV2ForMaskedLM
--------------------------------------------------------------------------------
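For reference, a hedged construction sketch for the masked-LM wrapper exported here, using only keyword arguments visible in `model/deberta.py`; it assumes that `from_pretrained` forwards unrecognized keyword arguments to the model constructor and that the checkpoint directory follows the README layout (the real build logic lives in `train.py`, not shown):

```
# Hypothetical construction sketch -- keyword values mirror the defaults in args.py.
from model import DebertaV2ForMaskedLM

model = DebertaV2ForMaskedLM.from_pretrained(
    "./pretrained/deberta-v2-xlarge",  # path from the README's data layout
    max_feats=10,          # one visual token per sampled frame
    features_dim=768,      # CLIP ViT-L/14 feature dimension
    freeze_lm=True,        # keep the language model weights frozen
    freeze_mlm=True,       # keep the MLM head frozen
    ds_factor_attn=8,      # adapter bottleneck ratio after self-attention
    ds_factor_ff=8,        # adapter bottleneck ratio after the feed-forward block
    ft_ln=True,            # but leave LayerNorm parameters trainable
)
```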
/model/adapter.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 |
4 |
5 | class Adapter(nn.Module):
6 | def __init__(
7 | self, ds_factor, hidden_dim, ln_after=False, ln_before=False, dropout=0.1
8 | ):
9 | super().__init__()
10 | assert not hidden_dim % ds_factor
11 | self.down = nn.Linear(hidden_dim, hidden_dim // ds_factor)
12 | self.act = nn.ReLU()
13 | self.up = nn.Linear(hidden_dim // ds_factor, hidden_dim)
14 | self.apply(self.init_weights)
15 | self.ln_after = ln_after
16 | self.ln_before = ln_before
17 | self.dropout = dropout
18 | if ln_after or ln_before:
19 | self.ln = nn.LayerNorm(hidden_dim)
20 | if dropout:
21 | self.dropout = nn.Dropout(dropout)
22 |
23 | def init_weights(self, m: nn.Module, std=1e-3):
24 | if isinstance(m, nn.Linear):
25 | torch.nn.init.normal_(m.weight, std=std)
26 | torch.nn.init.normal_(m.bias, std=std)
27 | m.weight.data = torch.clamp(m.weight.data, min=-2 * std, max=2 * std)
28 | m.bias.data = torch.clamp(m.bias.data, min=-2 * std, max=2 * std)
29 | elif isinstance(m, nn.LayerNorm):
30 | m.bias.data.zero_()
31 | m.weight.data.fill_(1.0)
32 |
33 | def forward(self, hidden_states):
34 | if self.ln_before:
35 | residual = self.ln(hidden_states)
36 | residual = self.down(residual)
37 | else:
38 | residual = self.down(hidden_states)
39 | residual = self.act(residual)
40 | if self.dropout:
41 | residual = self.dropout(residual)
42 | residual = self.up(residual)
43 | if self.ln_after:
44 | residual = self.ln(hidden_states)
45 | return hidden_states + residual
--------------------------------------------------------------------------------
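A small self-contained sketch of the `Adapter` shape contract; the hidden size 1536 (DeBERTa-v2-xlarge) is used purely for illustration:

```
# Standalone sketch: the bottleneck adapter preserves the hidden size, so it can
# be inserted after attention/feed-forward sub-layers without any reshaping.
import torch

from model.adapter import Adapter

hidden_dim, ds_factor = 1536, 8      # 1536 -> 192 -> 1536 bottleneck
adapter = Adapter(ds_factor, hidden_dim, dropout=0.1)

x = torch.randn(2, 16, hidden_dim)   # (batch, seq_len, hidden)
y = adapter(x)                       # residual: y = x + up(act(down(x)))
assert y.shape == x.shape
```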
/model/deberta.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020 Microsoft and the Hugging Face Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """ PyTorch DeBERTa-v2 model. """
16 |
17 | import math
18 | import json
19 | from collections.abc import Sequence
20 | from typing import Tuple, Optional
21 |
22 | import numpy as np
23 | import torch
24 | from torch import _softmax_backward_data, nn, optim
25 | from torch.nn import CrossEntropyLoss, LayerNorm
26 | import torch.nn.functional as F
27 |
28 | from model.adapter import Adapter
29 | from model.gnn import GNNSoftVerbalizer
30 | from transformers.activations import ACT2FN
31 |
32 | from transformers.modeling_outputs import (
33 | # BaseModelOutput,
34 | ModelOutput,
35 | MaskedLMOutput,
36 | QuestionAnsweringModelOutput,
37 | SequenceClassifierOutput,
38 | TokenClassifierOutput,
39 | )
40 |
41 | from transformers.modeling_utils import PreTrainedModel
42 | from transformers import DebertaV2Config
43 |
44 |
45 |
46 | _CONFIG_FOR_DOC = "DebertaV2Config"
47 | _TOKENIZER_FOR_DOC = "DebertaV2Tokenizer"
48 | _CHECKPOINT_FOR_DOC = "microsoft/deberta-v2-xlarge"
49 |
50 | DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST = [
51 | "microsoft/deberta-v2-xlarge",
52 | "microsoft/deberta-v2-xxlarge",
53 | "microsoft/deberta-v2-xlarge-mnli",
54 | "microsoft/deberta-v2-xxlarge-mnli",
55 | ]
56 |
57 |
58 | class BaseModelOutput(ModelOutput):
59 | """
60 | Base class for model's outputs, with potential hidden states and attentions.
61 | Args:
62 | last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
63 | Sequence of hidden-states at the output of the last layer of the model.
64 | hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
65 | Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
66 | shape `(batch_size, sequence_length, hidden_size)`.
67 | Hidden-states of the model at the output of each layer plus the initial embedding outputs.
68 | attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
69 | Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
70 | sequence_length)`.
71 | Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
72 | heads.
73 | """
74 |
75 | last_hidden_state: torch.FloatTensor = None
76 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None
77 | attentions: Optional[Tuple[torch.FloatTensor]] = None
78 | position_embeddings: torch.FloatTensor = None
79 | attention_mask: torch.BoolTensor = None
80 |
81 |
82 | # Copied from transformers.models.deberta.modeling_deberta.ContextPooler
83 | class ContextPooler(nn.Module):
84 | def __init__(self, config):
85 | super().__init__()
86 | self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
87 | self.dropout = StableDropout(config.pooler_dropout)
88 | self.config = config
89 |
90 | def forward(self, hidden_states):
91 | # We "pool" the model by simply taking the hidden state corresponding
92 | # to the first token.
93 |
94 | context_token = hidden_states[:, 0]
95 | context_token = self.dropout(context_token)
96 | pooled_output = self.dense(context_token)
97 | pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
98 | return pooled_output
99 |
100 | @property
101 | def output_dim(self):
102 | return self.config.hidden_size
103 |
104 |
105 | # Copied from transformers.models.deberta.modeling_deberta.XSoftmax with deberta->deberta_v2
106 | class XSoftmax(torch.autograd.Function):
107 | """
108 | Masked Softmax which is optimized for saving memory
109 |
110 | Args:
111 | input (:obj:`torch.tensor`): The input tensor that will apply softmax.
112 | mask (:obj:`torch.IntTensor`): The mask matrix where 0 indicate that element will be ignored in the softmax calculation.
113 | dim (int): The dimension that will apply softmax
114 |
115 | Example::
116 |
117 | import torch
118 | from transformers.models.deberta_v2.modeling_deberta_v2 import XSoftmax
119 |
120 | # Make a tensor
121 | x = torch.randn([4,20,100])
122 |
123 | # Create a mask
124 | mask = (x>0).int()
125 |
126 | y = XSoftmax.apply(x, mask, dim=-1)
127 | """
128 |
129 | @staticmethod
130 | def forward(self, input, mask, dim):
131 | self.dim = dim
132 | rmask = ~(mask.bool())
133 |
134 | output = input.masked_fill(rmask, float("-inf"))
135 | output = torch.softmax(output, self.dim)
136 | output.masked_fill_(rmask, 0)
137 | self.save_for_backward(output)
138 | return output
139 |
140 | @staticmethod
141 | def backward(self, grad_output):
142 | (output,) = self.saved_tensors
143 | inputGrad = _softmax_backward_data(grad_output, output, self.dim, output)
144 | return inputGrad, None, None
145 |
146 |
147 | # Copied from transformers.models.deberta.modeling_deberta.DropoutContext
148 | class DropoutContext(object):
149 | def __init__(self):
150 | self.dropout = 0
151 | self.mask = None
152 | self.scale = 1
153 | self.reuse_mask = True
154 |
155 |
156 | # Copied from transformers.models.deberta.modeling_deberta.get_mask
157 | def get_mask(input, local_context):
158 | if not isinstance(local_context, DropoutContext):
159 | dropout = local_context
160 | mask = None
161 | else:
162 | dropout = local_context.dropout
163 | dropout *= local_context.scale
164 | mask = local_context.mask if local_context.reuse_mask else None
165 |
166 | if dropout > 0 and mask is None:
167 | mask = (1 - torch.empty_like(input).bernoulli_(1 - dropout)).bool()
168 |
169 | if isinstance(local_context, DropoutContext):
170 | if local_context.mask is None:
171 | local_context.mask = mask
172 |
173 | return mask, dropout
174 |
175 |
176 | # Copied from transformers.models.deberta.modeling_deberta.XDropout
177 | class XDropout(torch.autograd.Function):
178 | """Optimized dropout function to save computation and memory by using mask operation instead of multiplication."""
179 |
180 | @staticmethod
181 | def forward(ctx, input, local_ctx):
182 | mask, dropout = get_mask(input, local_ctx)
183 | ctx.scale = 1.0 / (1 - dropout)
184 | if dropout > 0:
185 | ctx.save_for_backward(mask)
186 | return input.masked_fill(mask, 0) * ctx.scale
187 | else:
188 | return input
189 |
190 | @staticmethod
191 | def backward(ctx, grad_output):
192 | if ctx.scale > 1:
193 | (mask,) = ctx.saved_tensors
194 | return grad_output.masked_fill(mask, 0) * ctx.scale, None
195 | else:
196 | return grad_output, None
197 |
198 |
199 | # Copied from transformers.models.deberta.modeling_deberta.StableDropout
200 | class StableDropout(nn.Module):
201 | """
202 | Optimized dropout module for stabilizing the training
203 |
204 | Args:
205 | drop_prob (float): the dropout probabilities
206 | """
207 |
208 | def __init__(self, drop_prob):
209 | super().__init__()
210 | self.drop_prob = drop_prob
211 | self.count = 0
212 | self.context_stack = None
213 |
214 | def forward(self, x):
215 | """
216 | Call the module
217 |
218 | Args:
219 | x (:obj:`torch.tensor`): The input tensor to apply dropout
220 | """
221 | if self.training and self.drop_prob > 0:
222 | return XDropout.apply(x, self.get_context())
223 | return x
224 |
225 | def clear_context(self):
226 | self.count = 0
227 | self.context_stack = None
228 |
229 | def init_context(self, reuse_mask=True, scale=1):
230 | if self.context_stack is None:
231 | self.context_stack = []
232 | self.count = 0
233 | for c in self.context_stack:
234 | c.reuse_mask = reuse_mask
235 | c.scale = scale
236 |
237 | def get_context(self):
238 | if self.context_stack is not None:
239 | if self.count >= len(self.context_stack):
240 | self.context_stack.append(DropoutContext())
241 | ctx = self.context_stack[self.count]
242 | ctx.dropout = self.drop_prob
243 | self.count += 1
244 | return ctx
245 | else:
246 | return self.drop_prob
247 |
248 |
249 | # Copied from transformers.models.deberta.modeling_deberta.DebertaSelfOutput with DebertaLayerNorm->LayerNorm
250 | class DebertaV2SelfOutput(nn.Module):
251 | def __init__(self, config, ds_factor, dropout):
252 | super().__init__()
253 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
254 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
255 | self.dropout = StableDropout(config.hidden_dropout_prob)
256 | self.ds_factor = ds_factor
257 | if self.ds_factor:
258 | self.adapter = Adapter(ds_factor, config.hidden_size, dropout=dropout)
259 |
260 | def forward(self, hidden_states, input_tensor):
261 | hidden_states = self.dense(hidden_states)
262 | if self.ds_factor:
263 | hidden_states = self.adapter(hidden_states)
264 | hidden_states = self.dropout(hidden_states)
265 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
266 | return hidden_states
267 |
268 |
269 | # Copied from transformers.models.deberta.modeling_deberta.DebertaAttention with Deberta->DebertaV2
270 | class DebertaV2Attention(nn.Module):
271 | def __init__(self, config, ds_factor, dropout):
272 | super().__init__()
273 | self.self = DisentangledSelfAttention(config)
274 | self.output = DebertaV2SelfOutput(config, ds_factor, dropout)
275 | self.config = config
276 |
277 | def forward(
278 | self,
279 | hidden_states,
280 | attention_mask,
281 | return_att=False,
282 | query_states=None,
283 | relative_pos=None,
284 | rel_embeddings=None,
285 | ):
286 | self_output = self.self(
287 | hidden_states,
288 | attention_mask,
289 | return_att,
290 | query_states=query_states,
291 | relative_pos=relative_pos,
292 | rel_embeddings=rel_embeddings,
293 | )
294 | if return_att:
295 | self_output, att_matrix = self_output
296 | if query_states is None:
297 | query_states = hidden_states
298 | attention_output = self.output(self_output, query_states)
299 |
300 | if return_att:
301 | return (attention_output, att_matrix)
302 | else:
303 | return attention_output
304 |
305 |
306 | # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->DebertaV2
307 | class DebertaV2Intermediate(nn.Module):
308 | def __init__(self, config):
309 | super().__init__()
310 | self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
311 | if isinstance(config.hidden_act, str):
312 | self.intermediate_act_fn = ACT2FN[config.hidden_act]
313 | else:
314 | self.intermediate_act_fn = config.hidden_act
315 |
316 | def forward(self, hidden_states):
317 | hidden_states = self.dense(hidden_states)
318 | hidden_states = self.intermediate_act_fn(hidden_states)
319 | return hidden_states
320 |
321 |
322 | # Copied from transformers.models.deberta.modeling_deberta.DebertaOutput with DebertaLayerNorm->LayerNorm
323 | class DebertaV2Output(nn.Module):
324 | def __init__(self, config, ds_factor, dropout):
325 | super().__init__()
326 | self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
327 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
328 | self.dropout = StableDropout(config.hidden_dropout_prob)
329 | self.config = config
330 | self.ds_factor = ds_factor
331 | if self.ds_factor:
332 | self.adapter = Adapter(ds_factor, config.hidden_size, dropout=dropout)
333 |
334 | def forward(self, hidden_states, input_tensor):
335 | hidden_states = self.dense(hidden_states)
336 | if self.ds_factor:
337 | hidden_states = self.adapter(hidden_states)
338 | hidden_states = self.dropout(hidden_states)
339 | hidden_states = self.LayerNorm(hidden_states + input_tensor)
340 | return hidden_states
341 |
342 |
343 | # Copied from transformers.models.deberta.modeling_deberta.DebertaLayer with Deberta->DebertaV2
344 | class DebertaV2Layer(nn.Module):
345 | def __init__(
346 | self,
347 | config,
348 | ds_factor_attn,
349 | ds_factor_ff,
350 | dropout,
351 | ):
352 | super().__init__()
353 | self.attention = DebertaV2Attention(config, ds_factor_attn, dropout)
354 | self.intermediate = DebertaV2Intermediate(config)
355 | self.output = DebertaV2Output(config, ds_factor_ff, dropout)
356 |
357 | def forward(
358 | self,
359 | hidden_states,
360 | attention_mask,
361 | return_att=False,
362 | query_states=None,
363 | relative_pos=None,
364 | rel_embeddings=None,
365 | ):
366 | attention_output = self.attention(
367 | hidden_states,
368 | attention_mask,
369 | return_att=return_att,
370 | query_states=query_states,
371 | relative_pos=relative_pos,
372 | rel_embeddings=rel_embeddings,
373 | )
374 | if return_att:
375 | attention_output, att_matrix = attention_output
376 | intermediate_output = self.intermediate(attention_output)
377 | layer_output = self.output(intermediate_output, attention_output)
378 | if return_att:
379 | return (layer_output, att_matrix)
380 | else:
381 | return layer_output
382 |
383 |
384 | class ConvLayer(nn.Module):
385 | def __init__(self, config):
386 | super().__init__()
387 | kernel_size = getattr(config, "conv_kernel_size", 3)
388 | groups = getattr(config, "conv_groups", 1)
389 | self.conv_act = getattr(config, "conv_act", "tanh")
390 | self.conv = nn.Conv1d(
391 | config.hidden_size,
392 | config.hidden_size,
393 | kernel_size,
394 | padding=(kernel_size - 1) // 2,
395 | groups=groups,
396 | )
397 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
398 | self.dropout = StableDropout(config.hidden_dropout_prob)
399 | self.config = config
400 |
401 | def forward(self, hidden_states, residual_states, input_mask):
402 | out = (
403 | self.conv(hidden_states.permute(0, 2, 1).contiguous())
404 | .permute(0, 2, 1)
405 | .contiguous()
406 | )
407 | rmask = (1 - input_mask).bool()
408 | out.masked_fill_(rmask.unsqueeze(-1).expand(out.size()), 0)
409 | out = ACT2FN[self.conv_act](self.dropout(out))
410 |
411 | layer_norm_input = residual_states + out
412 | output = self.LayerNorm(layer_norm_input).to(layer_norm_input)
413 |
414 | if input_mask is None:
415 | output_states = output
416 | else:
417 | if input_mask.dim() != layer_norm_input.dim():
418 | if input_mask.dim() == 4:
419 | input_mask = input_mask.squeeze(1).squeeze(1)
420 | input_mask = input_mask.unsqueeze(2)
421 |
422 | input_mask = input_mask.to(output.dtype)
423 | output_states = output * input_mask
424 |
425 | return output_states
426 |
427 |
428 | class DebertaV2Encoder(nn.Module):
429 | """Modified BertEncoder with relative position bias support"""
430 |
431 | def __init__(self, config, ds_factor_attn, ds_factor_ff, dropout):
432 | super().__init__()
433 |
434 | self.layer = nn.ModuleList([DebertaV2Layer(config, ds_factor_attn, ds_factor_ff, dropout) for _ in range(config.num_hidden_layers)])
435 | self.relative_attention = getattr(config, "relative_attention", False)
436 |
437 | if self.relative_attention:
438 | self.max_relative_positions = getattr(config, "max_relative_positions", -1)
439 | if self.max_relative_positions < 1:
440 | self.max_relative_positions = config.max_position_embeddings
441 |
442 | self.position_buckets = getattr(config, "position_buckets", -1)
443 | pos_ebd_size = self.max_relative_positions * 2
444 |
445 | if self.position_buckets > 0:
446 | pos_ebd_size = self.position_buckets * 2
447 |
448 | self.rel_embeddings = nn.Embedding(pos_ebd_size, config.hidden_size)
449 |
450 | self.norm_rel_ebd = [x.strip() for x in getattr(config, "norm_rel_ebd", "none").lower().split("|")]
451 |
452 | if "layer_norm" in self.norm_rel_ebd:
453 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
454 |
455 | self.conv = ConvLayer(config) if getattr(config, "conv_kernel_size", 0) > 0 else None
456 |
457 | def get_rel_embedding(self):
458 | rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
459 | if rel_embeddings is not None and ("layer_norm" in self.norm_rel_ebd):
460 | rel_embeddings = self.LayerNorm(rel_embeddings)
461 | return rel_embeddings
462 |
463 | def get_attention_mask(self, attention_mask):
464 | if attention_mask.dim() <= 2:
465 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
466 | attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
467 | attention_mask = attention_mask.byte()
468 | elif attention_mask.dim() == 3:
469 | attention_mask = attention_mask.unsqueeze(1)
470 |
471 | return attention_mask
472 |
473 | def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
474 | if self.relative_attention and relative_pos is None:
475 | q = query_states.size(-2) if query_states is not None else hidden_states.size(-2)
476 | relative_pos = build_relative_position(q, hidden_states.size(-2), bucket_size=self.position_buckets, max_position=self.max_relative_positions)
477 | return relative_pos
478 |
479 | def forward(self, hidden_states, attention_mask, output_hidden_states=True, output_attentions=False, query_states=None, relative_pos=None, return_dict=True):
480 | if attention_mask.dim() <= 2:
481 | input_mask = attention_mask
482 | else:
483 | input_mask = (attention_mask.sum(-2) > 0).byte()
484 | attention_mask = self.get_attention_mask(attention_mask)
485 | relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
486 |
487 | all_hidden_states = () if output_hidden_states else None
488 | all_attentions = () if output_attentions else None
489 |
490 | if isinstance(hidden_states, Sequence):
491 | next_kv = hidden_states[0]
492 | else:
493 | next_kv = hidden_states
494 | rel_embeddings = self.get_rel_embedding()
495 | output_states = next_kv
496 | for i, layer_module in enumerate(self.layer):
497 |
498 | if output_hidden_states:
499 | all_hidden_states = all_hidden_states + (output_states,)
500 |
501 | output_states = layer_module(next_kv, attention_mask, output_attentions, query_states=query_states, relative_pos=relative_pos, rel_embeddings=rel_embeddings)
502 | if output_attentions:
503 | output_states, att_m = output_states
504 |
505 | if i == 0 and self.conv is not None:
506 | output_states = self.conv(hidden_states, output_states, input_mask)
507 |
508 | if query_states is not None:
509 | query_states = output_states
510 | if isinstance(hidden_states, Sequence):
511 | next_kv = hidden_states[i + 1] if i + 1 < len(self.layer) else None
512 | else:
513 | next_kv = output_states
514 |
515 | if output_attentions:
516 | all_attentions = all_attentions + (att_m,)
517 |
518 | if output_hidden_states:
519 | all_hidden_states = all_hidden_states + (output_states,)
520 |
521 | if not return_dict:
522 | return tuple(v for v in [output_states, all_hidden_states, all_attentions] if v is not None)
523 |
524 | return BaseModelOutput(last_hidden_state=output_states, hidden_states=all_hidden_states, attentions=all_attentions)
525 |
526 |
527 | def make_log_bucket_position(relative_pos, bucket_size, max_position):
528 | sign = np.sign(relative_pos)
529 | mid = bucket_size // 2
530 | abs_pos = np.where(
531 | (relative_pos < mid) & (relative_pos > -mid), mid - 1, np.abs(relative_pos)
532 | )
533 | log_pos = (
534 | np.ceil(np.log(abs_pos / mid) / np.log((max_position - 1) / mid) * (mid - 1))
535 | + mid
536 | )
537 | bucket_pos = np.where(abs_pos <= mid, relative_pos, log_pos * sign).astype(int)
538 | return bucket_pos
539 |
540 |
541 | def build_relative_position(query_size, key_size, bucket_size=-1, max_position=-1):
542 | """
543 | Build relative position according to the query and key
544 |
545 | We assume the absolute position of query :math:`P_q` is range from (0, query_size) and the absolute position of key
546 | :math:`P_k` is range from (0, key_size), The relative positions from query to key is :math:`R_{q \\rightarrow k} =
547 | P_q - P_k`
548 |
549 | Args:
550 | query_size (int): the length of query
551 | key_size (int): the length of key
552 | bucket_size (int): the size of position bucket
553 | max_position (int): the maximum allowed absolute position
554 |
555 | Return:
556 | :obj:`torch.LongTensor`: A tensor with shape [1, query_size, key_size]
557 |
558 | """
559 | q_ids = np.arange(0, query_size)
560 | k_ids = np.arange(0, key_size)
561 | rel_pos_ids = q_ids[:, None] - np.tile(k_ids, (q_ids.shape[0], 1))
562 | if bucket_size > 0 and max_position > 0:
563 | rel_pos_ids = make_log_bucket_position(rel_pos_ids, bucket_size, max_position)
564 | rel_pos_ids = torch.tensor(rel_pos_ids, dtype=torch.long)
565 | rel_pos_ids = rel_pos_ids[:query_size, :]
566 | rel_pos_ids = rel_pos_ids.unsqueeze(0)
567 | return rel_pos_ids
568 |
569 |
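# Worked example (illustrative): for query_size=3, key_size=3 and no bucketing,
# build_relative_position returns rel_pos[0][i][j] = i - j, i.e.
#   [[ 0, -1, -2],
#    [ 1,  0, -1],
#    [ 2,  1,  0]]
# With bucket_size/max_position > 0, positions beyond +/- bucket_size // 2 are
# squashed into logarithmic buckets by make_log_bucket_position above.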
570 | @torch.jit.script
571 | # Copied from transformers.models.deberta.modeling_deberta.c2p_dynamic_expand
572 | def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
573 | return c2p_pos.expand(
574 | [
575 | query_layer.size(0),
576 | query_layer.size(1),
577 | query_layer.size(2),
578 | relative_pos.size(-1),
579 | ]
580 | )
581 |
582 |
583 | @torch.jit.script
584 | # Copied from transformers.models.deberta.modeling_deberta.p2c_dynamic_expand
585 | def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
586 | return c2p_pos.expand(
587 | [
588 | query_layer.size(0),
589 | query_layer.size(1),
590 | key_layer.size(-2),
591 | key_layer.size(-2),
592 | ]
593 | )
594 |
595 |
596 | @torch.jit.script
597 | # Copied from transformers.models.deberta.modeling_deberta.pos_dynamic_expand
598 | def pos_dynamic_expand(pos_index, p2c_att, key_layer):
599 | return pos_index.expand(
600 | p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))
601 | )
602 |
603 |
604 | class DisentangledSelfAttention(nn.Module):
605 | """
606 | Disentangled self-attention module
607 |
608 | Parameters:
609 | config (:obj:`DebertaV2Config`):
610 | A model config class instance with the configuration to build a new model. The schema is similar to
611 | `BertConfig`, for more details, please refer :class:`~transformers.DebertaV2Config`
612 |
613 | """
614 |
615 | def __init__(self, config):
616 | super().__init__()
617 | if config.hidden_size % config.num_attention_heads != 0:
618 | raise ValueError(
619 | f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
620 | f"heads ({config.num_attention_heads})"
621 | )
622 | self.num_attention_heads = config.num_attention_heads
623 | _attention_head_size = config.hidden_size // config.num_attention_heads
624 | self.attention_head_size = getattr(
625 | config, "attention_head_size", _attention_head_size
626 | )
627 | self.all_head_size = self.num_attention_heads * self.attention_head_size
628 | self.query_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
629 | self.key_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
630 | self.value_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=True)
631 |
632 | self.share_att_key = getattr(config, "share_att_key", False)
633 | self.pos_att_type = (
634 | config.pos_att_type if config.pos_att_type is not None else []
635 | )
636 | self.relative_attention = getattr(config, "relative_attention", False)
637 |
638 | if self.relative_attention:
639 | self.position_buckets = getattr(config, "position_buckets", -1)
640 | self.max_relative_positions = getattr(config, "max_relative_positions", -1)
641 | if self.max_relative_positions < 1:
642 | self.max_relative_positions = config.max_position_embeddings
643 | self.pos_ebd_size = self.max_relative_positions
644 | if self.position_buckets > 0:
645 | self.pos_ebd_size = self.position_buckets
646 |
647 | self.pos_dropout = StableDropout(config.hidden_dropout_prob)
648 |
649 | if not self.share_att_key:
650 | if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
651 | self.pos_key_proj = nn.Linear(
652 | config.hidden_size, self.all_head_size, bias=True
653 | )
654 | if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
655 | self.pos_query_proj = nn.Linear(
656 | config.hidden_size, self.all_head_size
657 | )
658 |
659 | self.dropout = StableDropout(config.attention_probs_dropout_prob)
660 |
661 | def transpose_for_scores(self, x, attention_heads):
662 | new_x_shape = x.size()[:-1] + (attention_heads, -1)
663 | x = x.view(*new_x_shape)
664 | return x.permute(0, 2, 1, 3).contiguous().view(-1, x.size(1), x.size(-1))
665 |
666 | def forward(
667 | self,
668 | hidden_states,
669 | attention_mask,
670 | return_att=False,
671 | query_states=None,
672 | relative_pos=None,
673 | rel_embeddings=None,
674 | ):
675 | """
676 | Call the module
677 |
678 | Args:
679 | hidden_states (:obj:`torch.FloatTensor`):
680 | Input states to the module usually the output from previous layer, it will be the Q,K and V in
681 | `Attention(Q,K,V)`
682 |
683 | attention_mask (:obj:`torch.ByteTensor`):
684 | An attention mask matrix of shape [`B`, `N`, `N`] where `B` is the batch size, `N` is the maximum
685 | sequence length in which element [i,j] = `1` means the `i` th token in the input can attend to the `j`
686 | th token.
687 |
688 | return_att (:obj:`bool`, optional):
689 | Whether return the attention matrix.
690 |
691 | query_states (:obj:`torch.FloatTensor`, optional):
692 | The `Q` state in `Attention(Q,K,V)`.
693 |
694 | relative_pos (:obj:`torch.LongTensor`):
695 | The relative position encoding between the tokens in the sequence. It's of shape [`B`, `N`, `N`] with
696 | values ranging in [`-max_relative_positions`, `max_relative_positions`].
697 |
698 | rel_embeddings (:obj:`torch.FloatTensor`):
699 | The embedding of relative distances. It's a tensor of shape [:math:`2 \\times
700 | \\text{max_relative_positions}`, `hidden_size`].
701 |
702 |
703 | """
704 | if query_states is None:
705 | query_states = hidden_states
706 | query_layer = self.transpose_for_scores(
707 | self.query_proj(query_states), self.num_attention_heads
708 | )
709 | key_layer = self.transpose_for_scores(
710 | self.key_proj(hidden_states), self.num_attention_heads
711 | )
712 | value_layer = self.transpose_for_scores(
713 | self.value_proj(hidden_states), self.num_attention_heads
714 | )
715 |
716 | rel_att = None
717 | # Take the dot product between "query" and "key" to get the raw attention scores.
718 | scale_factor = 1
719 | if "c2p" in self.pos_att_type:
720 | scale_factor += 1
721 | if "p2c" in self.pos_att_type:
722 | scale_factor += 1
723 | if "p2p" in self.pos_att_type:
724 | scale_factor += 1
725 | scale = math.sqrt(query_layer.size(-1) * scale_factor)
726 | attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale
727 | if self.relative_attention:
728 | rel_embeddings = self.pos_dropout(rel_embeddings)
729 | rel_att = self.disentangled_attention_bias(
730 | query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
731 | )
732 |
733 | if rel_att is not None:
734 | attention_scores = attention_scores + rel_att
735 | attention_scores = attention_scores
736 | attention_scores = attention_scores.view(
737 | -1,
738 | self.num_attention_heads,
739 | attention_scores.size(-2),
740 | attention_scores.size(-1),
741 | )
742 |
743 | # bsz x height x length x dimension
744 | attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
745 | attention_probs = self.dropout(attention_probs)
746 | context_layer = torch.bmm(
747 | attention_probs.view(
748 | -1, attention_probs.size(-2), attention_probs.size(-1)
749 | ),
750 | value_layer,
751 | )
752 | context_layer = (
753 | context_layer.view(
754 | -1,
755 | self.num_attention_heads,
756 | context_layer.size(-2),
757 | context_layer.size(-1),
758 | )
759 | .permute(0, 2, 1, 3)
760 | .contiguous()
761 | )
762 | new_context_layer_shape = context_layer.size()[:-2] + (-1,)
763 | context_layer = context_layer.view(*new_context_layer_shape)
764 | if return_att:
765 | return (context_layer, attention_probs)
766 | else:
767 | return context_layer
768 |
769 | def disentangled_attention_bias(
770 | self, query_layer, key_layer, relative_pos, rel_embeddings, scale_factor
771 | ):
772 | if relative_pos is None:
773 | q = query_layer.size(-2)
774 | relative_pos = build_relative_position(
775 | q,
776 | key_layer.size(-2),
777 | bucket_size=self.position_buckets,
778 | max_position=self.max_relative_positions,
779 | )
780 | if relative_pos.dim() == 2:
781 | relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
782 | elif relative_pos.dim() == 3:
783 | relative_pos = relative_pos.unsqueeze(1)
784 | # bsz x height x query x key
785 | elif relative_pos.dim() != 4:
786 | raise ValueError(
787 | f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}"
788 | )
789 |
790 | att_span = self.pos_ebd_size
791 | relative_pos = relative_pos.long().to(query_layer.device)
792 |
793 | rel_embeddings = rel_embeddings[
794 | self.pos_ebd_size - att_span : self.pos_ebd_size + att_span, :
795 | ].unsqueeze(0)
796 | if self.share_att_key:
797 | pos_query_layer = self.transpose_for_scores(
798 | self.query_proj(rel_embeddings), self.num_attention_heads
799 | ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
800 | pos_key_layer = self.transpose_for_scores(
801 | self.key_proj(rel_embeddings), self.num_attention_heads
802 | ).repeat(query_layer.size(0) // self.num_attention_heads, 1, 1)
803 | else:
804 | if "c2p" in self.pos_att_type or "p2p" in self.pos_att_type:
805 | pos_key_layer = self.transpose_for_scores(
806 | self.pos_key_proj(rel_embeddings), self.num_attention_heads
807 | ).repeat(
808 | query_layer.size(0) // self.num_attention_heads, 1, 1
809 | ) # .split(self.all_head_size, dim=-1)
810 | if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
811 | pos_query_layer = self.transpose_for_scores(
812 | self.pos_query_proj(rel_embeddings), self.num_attention_heads
813 | ).repeat(
814 | query_layer.size(0) // self.num_attention_heads, 1, 1
815 | ) # .split(self.all_head_size, dim=-1)
816 |
817 | score = 0
818 | # content->position
819 | if "c2p" in self.pos_att_type:
820 | scale = math.sqrt(pos_key_layer.size(-1) * scale_factor)
821 | c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
822 | c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
823 | c2p_att = torch.gather(
824 | c2p_att,
825 | dim=-1,
826 | index=c2p_pos.squeeze(0).expand(
827 | [query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]
828 | ),
829 | )
830 | score += c2p_att / scale
831 |
832 | # position->content
833 | if "p2c" in self.pos_att_type or "p2p" in self.pos_att_type:
834 | scale = math.sqrt(pos_query_layer.size(-1) * scale_factor)
835 | if key_layer.size(-2) != query_layer.size(-2):
836 | r_pos = build_relative_position(
837 | key_layer.size(-2),
838 | key_layer.size(-2),
839 | bucket_size=self.position_buckets,
840 | max_position=self.max_relative_positions,
841 | ).to(query_layer.device)
842 | r_pos = r_pos.unsqueeze(0)
843 | else:
844 | r_pos = relative_pos
845 |
846 | p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
847 | if query_layer.size(-2) != key_layer.size(-2):
848 | pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
849 |
850 | if "p2c" in self.pos_att_type:
851 | p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
852 | p2c_att = torch.gather(
853 | p2c_att,
854 | dim=-1,
855 | index=p2c_pos.squeeze(0).expand(
856 | [query_layer.size(0), key_layer.size(-2), key_layer.size(-2)]
857 | ),
858 | ).transpose(-1, -2)
859 | if query_layer.size(-2) != key_layer.size(-2):
860 | p2c_att = torch.gather(
861 | p2c_att,
862 | dim=-2,
863 | index=pos_index.expand(
864 | p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))
865 | ),
866 | )
867 | score += p2c_att / scale
868 |
869 | # position->position
870 | if "p2p" in self.pos_att_type:
871 | pos_query = pos_query_layer[:, :, att_span:, :]
872 | p2p_att = torch.matmul(pos_query, pos_key_layer.transpose(-1, -2))
873 | p2p_att = p2p_att.expand(query_layer.size()[:2] + p2p_att.size()[2:])
874 | if query_layer.size(-2) != key_layer.size(-2):
875 | p2p_att = torch.gather(
876 | p2p_att,
877 | dim=-2,
878 | index=pos_index.expand(
879 | query_layer.size()[:2] + (pos_index.size(-2), p2p_att.size(-1))
880 | ),
881 | )
882 | p2p_att = torch.gather(
883 | p2p_att,
884 | dim=-1,
885 | index=c2p_pos.expand(
886 | [
887 | query_layer.size(0),
888 | query_layer.size(1),
889 | query_layer.size(2),
890 | relative_pos.size(-1),
891 | ]
892 | ),
893 | )
894 | score += p2p_att
895 |
896 | return score
897 |
898 |
899 | # Copied from transformers.models.deberta.modeling_deberta.DebertaEmbeddings with DebertaLayerNorm->LayerNorm
900 | class DebertaV2Embeddings(nn.Module):
901 | """Construct the embeddings from word, position and token_type embeddings."""
902 |
903 | def __init__(self, config, features_dim):
904 | super().__init__()
905 | pad_token_id = getattr(config, "pad_token_id", 0)
906 | self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
907 | self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
908 |
909 | self.position_biased_input = getattr(config, "position_biased_input", True)
910 | self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size) # it is used for the decoder anyway
911 |
912 | if config.type_vocab_size > 0:
913 | self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
914 |
915 | if self.embedding_size != config.hidden_size:
916 | self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
917 | self.LayerNorm = LayerNorm(config.hidden_size, config.layer_norm_eps)
918 | self.dropout = StableDropout(config.hidden_dropout_prob)
919 | self.config = config
920 |
921 | # position_ids (1, len position emb) is contiguous in memory and exported when serialized
922 | self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
923 |
924 | self.features_dim = features_dim
925 | if self.features_dim:
926 | self.linear_video = nn.Linear(features_dim, config.hidden_size)
927 |
928 | # self.prompt_embedding = nn.Embedding(32, config.hidden_size)
929 | self.prompt_embedding = None
930 |
931 | def get_video_embedding(self, video):
932 | video = self.linear_video(video)
933 | return video
934 |
935 | def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None, video=None):
936 | if input_ids is not None:
937 | input_shape = input_ids.size()
938 | else:
939 | input_shape = inputs_embeds.size()[:-1]
940 |
941 | if inputs_embeds is None:
942 | inputs_embeds = self.word_embeddings(input_ids)
943 | if self.features_dim and video is not None:
944 | video = self.get_video_embedding(video)
945 | if self.prompt_embedding is not None:
946 | inputs_embeds = torch.cat([video, inputs_embeds, self.prompt_embedding.weight[None, :, :].repeat(video.shape[0], 1, 1)], 1)
947 | input_shape = inputs_embeds[:, :, 0].shape
948 | else:
949 | inputs_embeds = torch.cat([video, inputs_embeds], 1)
950 | input_shape = inputs_embeds[:, :, 0].shape
951 |
952 | seq_length = input_shape[1]
953 |
954 | if position_ids is None:
955 | position_ids = self.position_ids[:, :seq_length]
956 |
957 | if token_type_ids is None:
958 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
959 |
960 | if self.position_embeddings is not None:
961 | position_embeddings = self.position_embeddings(position_ids.long())
962 | else:
963 | position_embeddings = torch.zeros_like(inputs_embeds)
964 | embeddings = inputs_embeds
965 |
966 | if self.position_biased_input:
967 | embeddings += position_embeddings
968 | # embeddings[:, :position_embeddings.shape[1], :] += position_embeddings
969 | if self.config.type_vocab_size > 0:
970 | token_type_embeddings = self.token_type_embeddings(token_type_ids)
971 | embeddings += token_type_embeddings
972 |
973 | if self.embedding_size != self.config.hidden_size:
974 | embeddings = self.embed_proj(embeddings)
975 |
976 |
977 | embeddings = self.LayerNorm(embeddings)
978 |
979 | if mask is not None:
980 | if mask.dim() != embeddings.dim():
981 | if mask.dim() == 4:
982 | mask = mask.squeeze(1).squeeze(1)
983 | mask = mask.unsqueeze(2)
984 | mask = mask.to(embeddings.dtype)
985 |
986 | embeddings = embeddings * mask
987 |
988 | embeddings = self.dropout(embeddings)
989 | return {"embeddings": embeddings, "position_embeddings": position_embeddings}
990 |
991 |
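# Shape note (illustrative): when video features are passed in, the embeddings
# above prepend max_feats visual tokens, so the encoder sequence length becomes
# max_feats + n_text_tokens; DebertaV2Model.forward (below) concatenates
# video_mask with attention_mask to match.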
992 | # Copied from transformers.models.deberta.modeling_deberta.DebertaPreTrainedModel with Deberta->DebertaV2
993 |
994 |
995 | class DebertaV2PreTrainedModel(PreTrainedModel):
996 | """
997 | An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
998 | models.
999 | """
1000 |
1001 | config_class = DebertaV2Config
1002 | base_model_prefix = "deberta"
1003 | _keys_to_ignore_on_load_missing = ["position_ids"]
1004 | _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
1005 |
1006 | def __init__(self, config):
1007 | super().__init__(config)
1008 | self._register_load_state_dict_pre_hook(self._pre_load_hook)
1009 |
1010 | def _init_weights(self, module):
1011 | """Initialize the weights."""
1012 | if isinstance(module, nn.Linear):
1013 | # Slightly different from the TF version which uses truncated_normal for initialization
1014 | # cf https://github.com/pytorch/pytorch/pull/5617
1015 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1016 | if module.bias is not None:
1017 | module.bias.data.zero_()
1018 | elif isinstance(module, nn.Embedding):
1019 | module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1020 | if module.padding_idx is not None:
1021 | module.weight.data[module.padding_idx].zero_()
1022 |
1023 | def _pre_load_hook(
1024 | self,
1025 | state_dict,
1026 | prefix,
1027 | local_metadata,
1028 | strict,
1029 | missing_keys,
1030 | unexpected_keys,
1031 | error_msgs,
1032 | ):
1033 | """
1034 | Removes the classifier if it doesn't have the correct number of labels.
1035 | """
1036 | self_state = self.state_dict()
1037 | if (
1038 | ("classifier.weight" in self_state)
1039 | and ("classifier.weight" in state_dict)
1040 | and self_state["classifier.weight"].size()
1041 | != state_dict["classifier.weight"].size()
1042 | ):
1043 | print(
1044 | f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
1045 | f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
1046 | f"weights. You should train your model on new data."
1047 | )
1048 | del state_dict["classifier.weight"]
1049 | if "classifier.bias" in state_dict:
1050 | del state_dict["classifier.bias"]
1051 |
1052 |
1053 | # Copied from transformers.models.deberta.modeling_deberta.DebertaModel with Deberta->DebertaV2
1054 | class DebertaV2Model(DebertaV2PreTrainedModel):
1055 | def __init__(self, config, max_feats=10, features_dim=768, freeze_lm=False, ds_factor_attn=8, ds_factor_ff=8, ft_ln=False, dropout=0.1):
1056 | super().__init__(config)
1057 |
1058 | self.embeddings = DebertaV2Embeddings(config, features_dim)
1059 | self.encoder = DebertaV2Encoder(config, ds_factor_attn, ds_factor_ff, dropout)
1060 | self.z_steps = 0
1061 | self.config = config
1062 |
1063 | self.features_dim = features_dim
1064 | self.max_feats = max_feats
1065 | if freeze_lm:
1066 | for n, p in self.named_parameters():
1067 | if (not "linear_video" in n) and (not "adapter" in n):
1068 | if ft_ln and "LayerNorm" in n:
1069 | continue
1070 | else:
1071 | p.requires_grad_(False)
1072 |
1073 | self.init_weights()
1074 |
1075 | def get_input_embeddings(self):
1076 | return self.embeddings.word_embeddings
1077 |
1078 | def set_input_embeddings(self, new_embeddings):
1079 | self.embeddings.word_embeddings = new_embeddings
1080 |
1081 | def _prune_heads(self, heads_to_prune):
1082 | """
1083 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1084 | class PreTrainedModel
1085 | """
1086 | raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
1087 |
1088 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, output_attentions=None, output_hidden_states=None,
1089 | return_dict=None, video=None, video_mask=None):
1090 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1091 | output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1092 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1093 |
1094 | if input_ids is not None and inputs_embeds is not None:
1095 | raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1096 | elif input_ids is not None:
1097 | input_shape = input_ids.size()
1098 | elif inputs_embeds is not None:
1099 | input_shape = inputs_embeds.size()[:-1]
1100 | else:
1101 | raise ValueError("You have to specify either input_ids or inputs_embeds")
1102 |
1103 | device = input_ids.device if input_ids is not None else inputs_embeds.device
1104 |
1105 | if attention_mask is None:
1106 | attention_mask = torch.ones(input_shape, device=device)
1107 |
1108 | if self.features_dim and video is not None:
1109 | if video_mask is None:
1110 | video_shape = video[:, :, 0].size()
1111 | video_mask = torch.ones(video_shape, device=device)
1112 | attention_mask = torch.cat([video_mask, attention_mask], 1)
1113 | # attention_mask = torch.cat([video_mask, attention_mask, torch.ones(video_mask.shape[0], 32).to(video_mask.device)], 1)
1114 | input_shape = attention_mask.size()
1115 |
1116 | if token_type_ids is None:
1117 | token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1118 |
1119 | embedding_output = self.embeddings(input_ids=input_ids, token_type_ids=token_type_ids, position_ids=position_ids, mask=attention_mask, inputs_embeds=inputs_embeds, video=video)
1120 | embedding_output, position_embeddings = embedding_output["embeddings"], embedding_output["position_embeddings"]
1121 |
1122 | encoder_outputs = self.encoder(embedding_output, attention_mask, output_hidden_states=True, output_attentions=output_attentions, return_dict=return_dict)
1123 | encoded_layers = encoder_outputs[1]
1124 |
1125 | if self.z_steps > 1:
1126 | hidden_states = encoded_layers[-2]
1127 | layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
1128 | query_states = encoded_layers[-1]
1129 | rel_embeddings = self.encoder.get_rel_embedding()
1130 | attention_mask = self.encoder.get_attention_mask(attention_mask)
1131 | rel_pos = self.encoder.get_rel_pos(embedding_output)
1132 | for layer in layers[1:]:
1133 | query_states = layer(hidden_states, attention_mask, return_att=False, query_states=query_states, relative_pos=rel_pos, rel_embeddings=rel_embeddings)
1134 | encoded_layers.append(query_states)
1135 |
1136 | sequence_output = encoded_layers[-1]
1137 |
1138 | if not return_dict:
1139 | return sequence_output + encoder_outputs[1 if output_hidden_states else 2:]
1140 |
1141 | return BaseModelOutput(last_hidden_state=sequence_output, hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
1142 | attentions=encoder_outputs.attentions, position_embeddings=position_embeddings, attention_mask=attention_mask)
1143 |
1144 |
1145 | # Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
1146 | class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
1147 | _keys_to_ignore_on_load_unexpected = [r"pooler"]
1148 | _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
1149 |
1150 | def __init__(self, config, max_feats=10, features_dim=768, freeze_lm=True, freeze_mlm=True, ds_factor_attn=8, ds_factor_ff=8, ft_ln=True,
1151 | dropout=0.1, n_ans=0, freeze_last=True, args=None):
1152 | """
1153 | :param config: BiLM configuration
1154 | :param max_feats: maximum number of frames used by the model
1155 | :param features_dim: embedding dimension of the visual features, set = 0 for text-only mode
1156 |         :param freeze_lm: whether to freeze the language model (Transformer encoder + token embedder)
1157 |         :param freeze_mlm: whether to freeze the MLM head
1158 | :param ds_factor_attn: downsampling factor for the adapter after self-attention, no adapter if set to 0
1159 | :param ds_factor_ff: downsampling factor for the adapter after feed-forward, no adapter if set to 0
1160 |         :param ft_ln: whether to finetune the normalization layers
1161 | :param dropout: dropout probability in the adapter
1162 | :param n_ans: number of answers in the downstream vocabulary, set = 0 during cross-modal training
1163 |         :param freeze_last: whether to freeze the answer embedding module
1164 | """
1165 | super().__init__(config)
1166 |
1167 | self.deberta = DebertaV2Model(config, max_feats, features_dim, freeze_lm, ds_factor_attn, ds_factor_ff, ft_ln, dropout)
1168 | self.lm_predictions = DebertaV2OnlyMLMHead(config, args)
1169 | self.features_dim = features_dim
1170 | if freeze_mlm:
1171 | for n, p in self.lm_predictions.named_parameters():
1172 | if ft_ln and "LayerNorm" in n:
1173 | continue
1174 | else:
1175 | p.requires_grad_(False)
1176 |
1177 | self.init_weights()
1178 | self.n_ans = n_ans
1179 | if n_ans:
1180 | self.answer_embeddings = nn.Embedding(n_ans, self.deberta.embeddings.embedding_size)
1181 | self.answer_bias = nn.Parameter(torch.zeros(n_ans))
1182 | if freeze_last:
1183 | self.answer_embeddings.requires_grad_(False)
1184 | self.answer_bias.requires_grad_(False)
1185 |
1186 | self.lm_predictions.lm_head.gnn1.lin_src.weight.data = torch.eye(config.hidden_size)
1187 | self.lm_predictions.lm_head.gnn2.lin_src.weight.data = torch.eye(config.hidden_size)
1188 | self.lm_predictions.lm_head.gnn1.lin_dst.weight.data = torch.eye(config.hidden_size)
1189 | self.lm_predictions.lm_head.gnn2.lin_dst.weight.data = torch.eye(config.hidden_size)
1190 | self.lm_predictions.lm_head.gnn1.lin_src.weight.requires_grad_(True)
1191 | self.lm_predictions.lm_head.gnn2.lin_src.weight.requires_grad_(True)
1192 | self.lm_predictions.lm_head.gnn1.lin_dst.weight.requires_grad_(True)
1193 | self.lm_predictions.lm_head.gnn2.lin_dst.weight.requires_grad_(True)
1194 |
1195 | def get_output_embeddings(self):
1196 | return self.lm_predictions.lm_head.decoder
1197 |
1198 | def set_output_embeddings(self, new_embeddings):
1199 | self.lm_predictions.lm_head.decoder = new_embeddings
1200 |
1201 | def set_answer_embeddings(self, a2tok, freeze_last=True):
1202 | a2v = self.deberta.embeddings.word_embeddings(a2tok)
1203 | pad_token_id = getattr(self.config, "pad_token_id", 0)
1204 | sum_tokens = (a2tok != pad_token_id).sum(1, keepdims=True) # n_ans
1205 | if len(a2v) != self.n_ans: # reinitialize the answer embeddings
1206 | # assert not self.training
1207 | self.n_ans = len(a2v)
1208 | self.answer_embeddings = nn.Embedding(self.n_ans, self.deberta.embeddings.embedding_size).to(self.device)
1209 | self.answer_bias.requires_grad = False
1210 | self.answer_bias.resize_(self.n_ans)
1211 | self.answer_embeddings.weight.data = torch.div((a2v * (a2tok != pad_token_id).float()[:, :, None]).sum(1), sum_tokens.clamp(min=1)) # n_ans
1212 | a2b = self.lm_predictions.lm_head.bias[a2tok]
1213 | self.answer_bias.weight = torch.div((a2b * (a2tok != pad_token_id).float()).sum(1), sum_tokens.clamp(min=1))
1214 | if freeze_last:
1215 | self.answer_embeddings.requires_grad_(False)
1216 | self.answer_bias.requires_grad_(False)
1217 |
1218 | def get_verbalized_embeddings(self, verbalized_id):
1219 | verbalized_embeddings = self.deberta.embeddings.word_embeddings(verbalized_id)
1220 | pad_token_id = getattr(self.config, "pad_token_id", 0)
1221 | sum_tokens = (verbalized_id != pad_token_id).sum(1, keepdims=True)
1222 | return verbalized_embeddings, sum_tokens
1223 |
1224 | def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder):
1225 | if attention_mask.dim() <= 2:
1226 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
1227 | att_mask = extended_attention_mask.byte()
1228 | attention_mask = att_mask * att_mask.squeeze(-2).unsqueeze(-1)
1229 | elif attention_mask.dim() == 3:
1230 | attention_mask = attention_mask.unsqueeze(1)
1231 | hidden_states = encoder_layers[-2]
1232 |
1233 | if not self.config.position_biased_input:
1234 | layers = [encoder.layer[-1] for _ in range(2)]
1235 | z_states += hidden_states
1236 | query_states = z_states
1237 | query_mask = attention_mask
1238 | outputs = []
1239 | rel_embeddings = encoder.get_rel_embedding()
1240 | for layer in layers:
1241 | output = layer(hidden_states, query_mask, return_att=False, query_states=query_states, relative_pos=None, rel_embeddings=rel_embeddings)
1242 | query_states = output
1243 | outputs.append(query_states)
1244 | else:
1245 | outputs = [encoder_layers[-1]]
1246 |
1247 | return outputs
1248 |
1249 | def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, labels=None, output_attentions=None,
1250 | return_dict=None, video=None, video_mask=None, mlm=False, eps=None, edge_index=None, vocab_embeddings=None):
1251 | r"""
1252 | labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1253 | Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1254 | config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1255 | (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1256 | """
1257 |
1258 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1259 |
1260 | outputs = self.deberta(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds,
1261 | output_attentions=output_attentions, output_hidden_states=True, return_dict=return_dict, video=video, video_mask=video_mask)
1262 |
1263 |
1264 | if labels is not None:
1265 | if self.features_dim and video is not None: # ignore the label predictions for visual tokens
1266 | video_shape = video[:, :, 0].size()
1267 | video_labels = torch.tensor([[-100] * video_shape[1]] * video_shape[0], dtype=torch.long, device=labels.device)
1268 | labels = torch.cat([video_labels, labels], 1)
1269 |
1270 | # sequence_output = outputs[0]
1271 | modified = self.emd_context_layer(encoder_layers=outputs["hidden_states"],
1272 | z_states=outputs["position_embeddings"].repeat(input_ids.shape[0] // len(outputs["position_embeddings"]), 1, 1),
1273 | attention_mask=outputs["attention_mask"], encoder=self.deberta.encoder)
1274 |
1275 | bias = None
1276 | if self.n_ans and (not mlm): # downstream mode
1277 | embeddings = self.answer_embeddings.weight
1278 | bias = self.answer_bias
1279 | else:
1280 | embeddings = self.deberta.embeddings.word_embeddings.weight
1281 |
1282 | prediction_scores = self.lm_predictions(modified[-1], embeddings, edge_index, vocab_embeddings, eps, bias)
1283 | masked_lm_loss = None
1284 | if labels is not None:
1285 | loss_fct = CrossEntropyLoss() # -100 index = padding token
1286 | masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1)) # labels[labels > 0].view(-1)
1287 |
1288 | if not return_dict:
1289 | output = (prediction_scores,) + outputs[1:]
1290 | return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1291 |
1292 | return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
1293 |
1294 |
1295 | # copied from transformers.models.bert.BertPredictionHeadTransform with bert -> deberta
1296 | class DebertaV2PredictionHeadTransform(nn.Module):
1297 | def __init__(self, config):
1298 | super().__init__()
1299 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1300 | if isinstance(config.hidden_act, str):
1301 | self.transform_act_fn = ACT2FN[config.hidden_act]
1302 | else:
1303 | self.transform_act_fn = config.hidden_act
1304 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1305 |
1306 | def forward(self, hidden_states):
1307 | hidden_states = self.dense(hidden_states)
1308 | hidden_states = self.transform_act_fn(hidden_states)
1309 | hidden_states = self.LayerNorm(hidden_states)
1310 | return hidden_states
1311 |
1312 |
1313 | # copied from transformers.models.bert.BertLMPredictionHead with bert -> deberta
1314 | class DebertaV2LMPredictionHead(nn.Module):
1315 | def __init__(self, config, args):
1316 | super().__init__()
1317 | self.dense = nn.Linear(config.hidden_size, config.hidden_size)
1318 | if isinstance(config.hidden_act, str):
1319 | self.transform_act_fn = ACT2FN[config.hidden_act]
1320 | else:
1321 | self.transform_act_fn = config.hidden_act
1322 | self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1323 |
1324 | # The output weights are the same as the input embeddings, but there is
1325 | # an output-only bias for each token.
1326 |         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False) # only for compatibility
1327 |
1328 | self.bias = nn.Parameter(torch.zeros(config.vocab_size))
1329 |
1330 | # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
1331 |         self.decoder.bias = self.bias # only for compatibility
1332 |
1333 | self.gnn1 = GNNSoftVerbalizer(config.hidden_size, config.hidden_size)
1334 | self.gnn2 = GNNSoftVerbalizer(config.hidden_size, config.hidden_size)
1335 | self.args = args
1336 |
1337 | def forward(self, hidden_states, embedding_weight, edge_index, vocab_embeddings, eps, bias=None):
1338 | hidden_states = self.dense(hidden_states)
1339 | hidden_states = self.transform_act_fn(hidden_states)
1340 | hidden_states = self.LayerNorm(hidden_states)
1341 |
1342 | if bias is not None:
1343 | _embedding_weight = self.gnn1(vocab_embeddings, edge_index)
1344 | _embedding_weight = self.gnn2(_embedding_weight, edge_index)[:embedding_weight.shape[0]]
1345 | _embedding_weight = eps * embedding_weight + (1 - eps) * _embedding_weight
1346 |
1347 | logits = torch.matmul(hidden_states, _embedding_weight.t().to(hidden_states)) + bias
1348 |
1349 |
1350 | else:
1351 | logits = torch.matmul(hidden_states, embedding_weight.t().to(hidden_states)) + self.bias
1352 |
1353 | return logits
1354 |
1355 | # copied from transformers.models.bert.BertOnlyMLMHead with bert -> deberta
1356 | class DebertaV2OnlyMLMHead(nn.Module):
1357 | def __init__(self, config, args):
1358 | super().__init__()
1359 | self.lm_head = DebertaV2LMPredictionHead(config, args)
1360 |
1361 | def forward(self, sequence_output, embedding_weight, edge_index, vocab_embeddings, eps, bias=None):
1362 | prediction_scores = self.lm_head(sequence_output, embedding_weight, edge_index, vocab_embeddings, eps, bias=bias)
1363 | return prediction_scores
1364 |
--------------------------------------------------------------------------------
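
The answer scorer above (`DebertaV2LMPredictionHead.forward`) blends the GAT-refined answer embeddings with the token-averaged answer embeddings built in `set_answer_embeddings`, using the per-answer weight `eps`, and then scores each [MASK] hidden state by a dot product plus the answer bias. A minimal, torch-only sketch of that mixing and scoring step (all names and shapes below are illustrative, not taken from the repository):

```
import torch

hidden_size, n_ans = 8, 4
h_mask = torch.randn(2, hidden_size)       # hidden states at the [MASK] positions
emb = torch.randn(n_ans, hidden_size)      # token-averaged answer embeddings
emb_gnn = torch.randn(n_ans, hidden_size)  # answer embeddings after the two GAT layers
bias = torch.zeros(n_ans)                  # per-answer bias
eps = torch.full((n_ans, 1), 0.7)          # per-answer mixing weight

mixed = eps * emb + (1 - eps) * emb_gnn    # convex combination, as in the LM head
logits = h_mask @ mixed.t() + bias         # (2, n_ans) answer scores
print(logits.shape)
```
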
/model/gnn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn import Parameter
6 |
7 | from torch_geometric.nn.conv import MessagePassing
8 | from torch_geometric.nn.dense.linear import Linear
9 | from torch_geometric.nn.inits import zeros
10 | import torch_sparse
11 | from torch_geometric.utils import (
12 | add_remaining_self_loops,
13 | is_torch_sparse_tensor,
14 | scatter,
15 | spmm,
16 | to_torch_coo_tensor,
17 | )
18 | from torch_geometric.utils.num_nodes import maybe_num_nodes
19 |
20 |
21 | from typing import Optional, Tuple, Union
22 |
23 |
24 | import torch.nn.functional as F
25 | import torch_geometric.typing
26 | from torch_geometric.typing import NoneType # noqa
27 | from torch_geometric.typing import (
28 | Adj,
29 | OptPairTensor,
30 | OptTensor,
31 | Size,
32 | SparseTensor
33 | )
34 | from torch_geometric.utils import add_self_loops, remove_self_loops, scatter, softmax
35 | from torch_geometric.nn.inits import glorot, zeros
36 |
37 |
38 | class GNNSoftVerbalizer(MessagePassing):
39 | r"""The graph attentional operator from the `"Graph Attention Networks"
40 |     <https://arxiv.org/abs/1710.10903>`_ paper
41 |
42 | .. math::
43 | \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
44 | \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},
45 |
46 | where the attention coefficients :math:`\alpha_{i,j}` are computed as
47 |
48 | .. math::
49 | \alpha_{i,j} =
50 | \frac{
51 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
52 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
53 | \right)\right)}
54 | {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
55 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
56 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
57 | \right)\right)}.
58 |
59 | If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`,
60 | the attention coefficients :math:`\alpha_{i,j}` are computed as
61 |
62 | .. math::
63 | \alpha_{i,j} =
64 | \frac{
65 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
66 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j
67 | \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,j}]\right)\right)}
68 | {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
69 | \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
70 | [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k
71 | \, \Vert \, \mathbf{\Theta}_{e} \mathbf{e}_{i,k}]\right)\right)}.
72 |
73 | Args:
74 | in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
75 | derive the size from the first input(s) to the forward method.
76 | A tuple corresponds to the sizes of source and target
77 | dimensionalities.
78 | out_channels (int): Size of each output sample.
79 | heads (int, optional): Number of multi-head-attentions.
80 | (default: :obj:`1`)
81 | concat (bool, optional): If set to :obj:`False`, the multi-head
82 | attentions are averaged instead of concatenated.
83 | (default: :obj:`True`)
84 | negative_slope (float, optional): LeakyReLU angle of the negative
85 | slope. (default: :obj:`0.2`)
86 | dropout (float, optional): Dropout probability of the normalized
87 | attention coefficients which exposes each node to a stochastically
88 | sampled neighborhood during training. (default: :obj:`0`)
89 | add_self_loops (bool, optional): If set to :obj:`False`, will not add
90 | self-loops to the input graph. (default: :obj:`True`)
91 | edge_dim (int, optional): Edge feature dimensionality (in case
92 | there are any). (default: :obj:`None`)
93 | fill_value (float or torch.Tensor or str, optional): The way to
94 | generate edge features of self-loops (in case
95 | :obj:`edge_dim != None`).
96 | If given as :obj:`float` or :class:`torch.Tensor`, edge features of
97 | self-loops will be directly given by :obj:`fill_value`.
98 | If given as :obj:`str`, edge features of self-loops are computed by
99 | aggregating all features of edges that point to the specific node,
100 | according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`,
101 | :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`)
102 | bias (bool, optional): If set to :obj:`False`, the layer will not learn
103 | an additive bias. (default: :obj:`True`)
104 | **kwargs (optional): Additional arguments of
105 | :class:`torch_geometric.nn.conv.MessagePassing`.
106 |
107 | Shapes:
108 | - **input:**
109 | node features :math:`(|\mathcal{V}|, F_{in})` or
110 | :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))`
111 | if bipartite,
112 | edge indices :math:`(2, |\mathcal{E}|)`,
113 | edge features :math:`(|\mathcal{E}|, D)` *(optional)*
114 | - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or
115 |           :math:`(|\mathcal{V}_t|, H * F_{out})` if bipartite.
116 | If :obj:`return_attention_weights=True`, then
117 | :math:`((|\mathcal{V}|, H * F_{out}),
118 | ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))`
119 | or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|),
120 | (|\mathcal{E}|, H)))` if bipartite
121 | """
122 | def __init__(
123 | self,
124 | in_channels: Union[int, Tuple[int, int]],
125 | out_channels: int,
126 | heads: int = 1,
127 | concat: bool = True,
128 | negative_slope: float = 0.2,
129 | dropout: float = 0.0,
130 | add_self_loops: bool = True,
131 | edge_dim: Optional[int] = None,
132 | fill_value: Union[float, Tensor, str] = 'mean',
133 | bias: bool = True,
134 | **kwargs,
135 | ):
136 | kwargs.setdefault('aggr', 'add')
137 | super().__init__(node_dim=0, **kwargs)
138 |
139 | self.in_channels = in_channels
140 | self.out_channels = out_channels
141 | self.heads = heads
142 | self.concat = concat
143 | self.negative_slope = negative_slope
144 | self.dropout = dropout
145 | self.add_self_loops = add_self_loops
146 | self.edge_dim = edge_dim
147 | self.fill_value = fill_value
148 |
149 | # In case we are operating in bipartite graphs, we apply separate
150 | # transformations 'lin_src' and 'lin_dst' to source and target nodes:
151 | if isinstance(in_channels, int):
152 | self.lin_src = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot')
153 | self.lin_dst = Linear(in_channels, heads * out_channels, bias=False, weight_initializer='glorot')
154 | # self.lin_dst = self.lin_src
155 | else:
156 | self.lin_src = Linear(in_channels[0], heads * out_channels, False, weight_initializer='glorot')
157 | self.lin_dst = Linear(in_channels[1], heads * out_channels, False, weight_initializer='glorot')
158 |
159 |
160 | if edge_dim is not None:
161 | self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False, weight_initializer='glorot')
162 | self.att_edge = Parameter(torch.Tensor(1, heads, out_channels))
163 | else:
164 | self.lin_edge = None
165 | self.register_parameter('att_edge', None)
166 |
167 | self.reset_parameters()
168 |
169 | def reset_parameters(self):
170 | # super().reset_parameters()
171 | self.lin_src.reset_parameters()
172 | self.lin_dst.reset_parameters()
173 | if self.lin_edge is not None:
174 | self.lin_edge.reset_parameters()
175 | glorot(self.att_edge)
176 |
177 |
178 | def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
179 | edge_attr: OptTensor = None, size: Size = None,
180 | return_attention_weights=None):
181 | # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, NoneType) -> Tensor # noqa
182 | # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, NoneType) -> Tensor # noqa
183 | # type: (Union[Tensor, OptPairTensor], Tensor, OptTensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]] # noqa
184 | # type: (Union[Tensor, OptPairTensor], SparseTensor, OptTensor, Size, bool) -> Tuple[Tensor, SparseTensor] # noqa
185 | r"""Runs the forward pass of the module.
186 |
187 | Args:
188 | return_attention_weights (bool, optional): If set to :obj:`True`,
189 | will additionally return the tuple
190 | :obj:`(edge_index, attention_weights)`, holding the computed
191 | attention weights for each edge. (default: :obj:`None`)
192 | """
193 | # NOTE: attention weights will be returned whenever
194 | # `return_attention_weights` is set to a value, regardless of its
195 |         # actual value (might be `True` or `False`). This is currently a somewhat
196 | # hacky workaround to allow for TorchScript support via the
197 | # `torch.jit._overload` decorator, as we can only change the output
198 | # arguments conditioned on type (`None` or `bool`), not based on its
199 | # actual value.
200 |
201 | H, C = self.heads, self.out_channels
202 |
203 | # We first transform the input node features. If a tuple is passed, we
204 | # transform source and target node features via separate weights:
205 | if isinstance(x, Tensor):
206 | assert x.dim() == 2, "Static graphs not supported in 'GATConv'"
207 | # x_src = x_dst = self.lin_src(x).view(-1, H, C)
208 | x_src = x_dst = x.view(-1, H, C)
209 | else: # Tuple of source and target node features:
210 | x_src, x_dst = x
211 | assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'"
212 | x_src = self.lin_src(x_src).view(-1, H, C)
213 | if x_dst is not None:
214 | x_dst = self.lin_dst(x_dst).view(-1, H, C)
215 |
216 | x = (x_src, x_dst)
217 |
218 | # Next, we compute node-level attention coefficients, both for source
219 | # and target nodes (if present):
220 |
221 | if self.add_self_loops:
222 | if isinstance(edge_index, Tensor):
223 | # We only want to add self-loops for nodes that appear both as
224 | # source and target nodes:
225 | num_nodes = x_src.size(0)
226 | if x_dst is not None:
227 | num_nodes = min(num_nodes, x_dst.size(0))
228 | num_nodes = min(size) if size is not None else num_nodes
229 | edge_index, edge_attr = remove_self_loops(edge_index, edge_attr)
230 | edge_index, edge_attr = add_self_loops(edge_index, edge_attr, fill_value=self.fill_value, num_nodes=num_nodes)
231 | elif isinstance(edge_index, SparseTensor):
232 | if self.edge_dim is None:
233 | edge_index = torch_sparse.set_diag(edge_index)
234 | else:
235 | raise NotImplementedError(
236 | "The usage of 'edge_attr' and 'add_self_loops' "
237 | "simultaneously is currently not yet supported for "
238 | "'edge_index' in a 'SparseTensor' form")
239 |
240 | # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
241 | alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)
242 |
243 | # propagate_type: (x: OptPairTensor, alpha: Tensor)
244 | out = self.propagate(edge_index, x=x, alpha=alpha, size=size)
245 |
246 | if self.concat:
247 | out = out.view(-1, self.heads * self.out_channels)
248 | else:
249 | out = out.mean(dim=1)
250 |
251 | if isinstance(return_attention_weights, bool):
252 | if isinstance(edge_index, Tensor):
253 | return out, (edge_index, alpha)
254 | elif isinstance(edge_index, SparseTensor):
255 | return out, edge_index.set_value(alpha, layout='coo')
256 | else:
257 | return out
258 |
259 |
260 | def edge_update(self, x_j: Tensor, x_i: OptTensor, edge_attr: OptTensor, index: Tensor, ptr: OptTensor, size_i: Optional[int]) -> Tensor:
261 | # Given edge-level attention coefficients for source and target nodes,
262 | # we simply need to sum them up to "emulate" concatenation:
263 |
264 | _x_i, _x_j = self.lin_dst(x_i), self.lin_src(x_j)
265 | alpha = torch.bmm(_x_i, _x_j.transpose(1, 2)).squeeze(-1)
266 |
267 | if edge_attr is not None and self.lin_edge is not None:
268 | if edge_attr.dim() == 1:
269 | edge_attr = edge_attr.view(-1, 1)
270 | edge_attr = self.lin_edge(edge_attr)
271 | edge_attr = edge_attr.view(-1, self.heads, self.out_channels)
272 | alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
273 | alpha = alpha + alpha_edge
274 |
275 | alpha = F.leaky_relu(alpha, self.negative_slope)
276 |
277 |
278 | alpha = softmax(alpha, index, ptr, size_i)
279 | alpha = F.dropout(alpha, p=self.dropout, training=self.training)
280 | return alpha
281 |
282 | def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
283 | return alpha.unsqueeze(-1) * x_j
284 |
285 | def __repr__(self) -> str:
286 | return (f'{self.__class__.__name__}({self.in_channels}, '
287 | f'{self.out_channels}, heads={self.heads})')
--------------------------------------------------------------------------------
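
`GNNSoftVerbalizer` is a single-head GAT-style layer over the answer graph; note that when it is called with a plain node-feature tensor the input projection is skipped (`x_src = x_dst = x.view(-1, H, C)`), so `in_channels` must equal `heads * out_channels`. A minimal usage sketch, assuming the environment from `setup.sh` (with `torch_geometric` installed) and the repository root on `PYTHONPATH`:

```
import torch
from model.gnn import GNNSoftVerbalizer

hidden = 16
layer = GNNSoftVerbalizer(hidden, hidden)   # heads=1, so in_channels == heads * out_channels

x = torch.randn(5, hidden)                  # 5 answer-graph nodes
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 4]])   # (2, E) directed edges: row 0 = source, row 1 = target

out = layer(x, edge_index)                  # self-loops are added internally
print(out.shape)                            # torch.Size([5, 16])
```
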
/setup.sh:
--------------------------------------------------------------------------------
1 | pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
2 | pip install clip
3 | pip install jsonlines==3.0.0
4 | pip install numpy==1.22.0
5 | pip install pandas==1.4.1
6 | pip install Pillow==9.3.0
7 | pip install scikit-learn==1.0.2
8 | pip install scipy==1.8.0
9 | pip install sentencepiece==0.1.96
10 | pip install tokenizers==0.11.6
11 | pip install tqdm==4.63.1
12 | pip install transformers==4.17.0
13 | pip install hostlist==1.4.8
14 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_cluster-1.5.9-cp38-cp38-linux_x86_64.whl
15 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_scatter-2.0.8-cp38-cp38-linux_x86_64.whl
16 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_sparse-0.6.11-cp38-cp38-linux_x86_64.whl
17 | pip install https://data.pyg.org/whl/torch-1.8.0%2Bcu111/torch_spline_conv-1.2.1-cp38-cp38-linux_x86_64.whl
18 | pip install torch-geometric==2.2.0
--------------------------------------------------------------------------------
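
The pins above target Python 3.8 with CUDA 11.1 wheels for PyTorch 1.8.1 and PyG 2.2.0. A short import check (a sketch, not part of the repository) can confirm that the core dependencies resolved:

```
import torch
import torch_geometric
import transformers

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("torch_geometric:", torch_geometric.__version__)
print("transformers:", transformers.__version__)
```
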
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import numpy as np
6 | import random
7 | import json
8 | import math
9 | import sys
10 | from typing import Iterable
11 | import argparse
12 | import time
13 | import datetime
14 | from util import dist
15 | from torch.utils.data import DataLoader, DistributedSampler
16 | from collections import namedtuple
17 | from functools import reduce
18 | import pickle
19 |
20 | from dataset import VideoQA_Dataset, videoqa_collate_fn
21 | from args import get_args_parser
22 | from util.misc import get_mask, adjust_learning_rate
23 | from util.metrics import MetricLogger
24 |
25 | from transformers import DebertaV2Tokenizer
26 | from model import DebertaV2ForMaskedLM
27 |
28 | def train_one_epoch(model: torch.nn.Module, tokenizer, data_loader: Iterable, optimizer: torch.optim.Optimizer, device: torch.device, epoch: int,
29 | dataset_name, args, max_norm: float = 0):
30 | model.train()
31 | edge_index = data_loader.dataset.edge_index.to(device)
32 | vocab_embeddings = data_loader.dataset.vocab_embeddings.to(device)
33 | eps = data_loader.dataset.eps[:, None].to(device)
34 |
35 | metric_logger = MetricLogger(delimiter=" ")
36 | header = "Epoch: [{}]".format(epoch)
37 | num_training_steps = int(len(data_loader) * args.epochs)
38 | args.print_freq = int(len(data_loader) / 4)
39 | for i_batch, batch_dict in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
40 | video = batch_dict["video"].to(device)
41 | video_len = batch_dict["video_len"]
42 | video_mask = get_mask(video_len, video.size(1)).to(device)
43 | text = batch_dict["text"]
44 | encoded = tokenizer(text, add_special_tokens=True, max_length=args.max_tokens, padding="longest", truncation=True, return_tensors="pt")
45 |
46 | inputs = encoded["input_ids"].to(device)
47 | attention_mask = encoded["attention_mask"].to(device)
48 |
49 | # forward
50 | answer_id = batch_dict["answer_id"].to(device)
51 | output = model(video=video, video_mask=video_mask, input_ids=inputs, attention_mask=attention_mask, edge_index=edge_index, vocab_embeddings=vocab_embeddings, eps=eps)
52 | delay = args.max_feats if args.use_video else 0
53 | logits = output['logits']
54 | logits = logits[:, delay:encoded["input_ids"].size(1) + delay][encoded["input_ids"] == tokenizer.mask_token_id]
55 |
56 | if dataset_name == "ivqa":
57 | a = (answer_id / 2).clamp(max=1)
58 | nll = -F.log_softmax(logits, 1, _stacklevel=5)
59 | loss = (nll * a / a.sum(1, keepdim=True).clamp(min=1)).sum(dim=1).mean()
60 | elif dataset_name == "vqa":
61 | a = (answer_id / 3).clamp(max=1)
62 | nll = -F.log_softmax(logits, 1, _stacklevel=5)
63 | loss = (nll * a / a.sum(1, keepdim=True).clamp(min=1)).sum(dim=1).mean()
64 | else:
65 | loss = F.cross_entropy(logits, answer_id)
66 |
67 | loss_dict = {"cls_loss": loss}
68 |
69 | # reduce losses over all GPUs for logging purposes
70 | loss_dict_reduced = dist.reduce_dict(loss_dict)
71 | loss_reduced = sum(loss_dict_reduced.values())
72 | loss_value = loss_reduced.item()
73 |
74 | if not math.isfinite(loss_value):
75 | print("Loss is {}, stopping training".format(loss_value))
76 | print(loss_dict_reduced)
77 | sys.exit(1)
78 |
79 | optimizer.zero_grad()
80 | loss.backward()
81 | if max_norm > 0:
82 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
83 | optimizer.step()
84 |
85 | adjust_learning_rate(optimizer, curr_step=epoch * len(data_loader) + i_batch, num_training_steps=num_training_steps, args=args)
86 |
87 | metric_logger.update(loss=loss_value)
88 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
89 | # gather the stats from all processes
90 | metric_logger.synchronize_between_processes()
91 | print("Averaged stats:", metric_logger)
92 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
93 |
94 |
95 | @torch.no_grad()
96 | def evaluate(model: torch.nn.Module, tokenizer, data_loader, device: torch.device, dataset_name, args, thresholds=[1, 10], split="test", epoch=-1):
97 | model.eval()
98 | ans2cat = data_loader.dataset.ans2cat
99 | class_tensor = torch.zeros((len(data_loader.dataset.ans2id), 2), dtype=torch.float64, device="cuda")
100 | edge_index = data_loader.dataset.edge_index.to(device)
101 | vocab_embeddings = data_loader.dataset.vocab_embeddings.to(device)
102 | eps = data_loader.dataset.eps[:, None].to(device)
103 |
104 | metric_logger = MetricLogger(delimiter=" ")
105 | metric_logger.update(n=0, base=0)
106 | metric_logger.update(n=0, common=0)
107 | metric_logger.update(n=0, rare=0)
108 | metric_logger.update(n=0, unseen=0)
109 | metric_logger.update(n=0, total=0)
110 | header = f"{split}:"
111 |
112 | args.print_freq = int(len(data_loader) / 4)
113 | for i_batch, batch_dict in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
114 | video = batch_dict["video"].to(device)
115 | video_len = batch_dict["video_len"]
116 | video_mask = get_mask(video_len, video.size(1)).to(device)
117 | text = batch_dict["text"]
118 | encoded = tokenizer(text, add_special_tokens=True, max_length=args.max_tokens, padding="longest", truncation=True, return_tensors="pt")
119 | input_ids = encoded["input_ids"].to(device)
120 | attention_mask = encoded["attention_mask"].to(device)
121 | if not args.suffix and not args.use_context: # remove sep token if not using the suffix
122 | attention_mask[input_ids == tokenizer.sep_token_id] = 0
123 | input_ids[input_ids == tokenizer.sep_token_id] = tokenizer.pad_token_id
124 |
125 |
126 | answer_id, qids = batch_dict["answer_id"].to(device), batch_dict["qid"]
127 | output = model(video=video, video_mask=video_mask, input_ids=input_ids, attention_mask=attention_mask, edge_index=edge_index, vocab_embeddings=vocab_embeddings, eps=eps)
128 | logits = output["logits"]
129 | delay = args.max_feats if args.use_video else 0
130 | logits = logits[:, delay:encoded["input_ids"].size(1) + delay][encoded["input_ids"] == tokenizer.mask_token_id] # get the prediction on the mask token
131 | logits = logits.softmax(-1)
132 |
133 | topk_logits, topk_aids = torch.topk(logits, max(thresholds), -1)
134 |
135 | types = batch_dict["type"]
136 | original_answers = batch_dict['original_answer']
137 |
138 |
139 | for i, (p, ans) in enumerate(zip(answer_id == logits.max(1).indices, original_answers)):
140 | category = ans2cat[ans]
141 | class_tensor[answer_id[i]][0] += p.float().item()
142 | class_tensor[answer_id[i]][1] += 1.
143 | if category == 'base':
144 | metric_logger.update(n=1, base=p.float().item())
145 | elif category == 'common':
146 | metric_logger.update(n=1, common=p.float().item())
147 | elif category == 'rare':
148 | metric_logger.update(n=1, rare=p.float().item())
149 | elif category == 'unseen':
150 | metric_logger.update(n=1, unseen=p.float().item())
151 | metric_logger.update(n=1, total=p.float().item())
152 |
153 |
154 | torch.distributed.barrier()
155 | torch.distributed.all_reduce(class_tensor)
156 | macc = (class_tensor[:, 0] / class_tensor[:, 1]).mean().item()
157 | metric_logger.synchronize_between_processes()
158 | metric_logger.update(n=1, macc=macc)
159 | print("Averaged stats:", metric_logger)
160 |
161 | results = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
162 | return results
163 |
164 |
165 | def main(args):
166 | # Init distributed mode
167 | dist.init_distributed_mode(args)
168 | if dist.is_main_process():
169 | if args.save_dir and not (os.path.isdir(args.save_dir)):
170 |             os.makedirs(args.save_dir, exist_ok=True)
171 | print(args)
172 |
173 | device = torch.device(args.device)
174 |
175 | # Fix seeds
176 | seed = args.seed + dist.get_rank()
177 | torch.manual_seed(seed)
178 | np.random.seed(seed)
179 | random.seed(seed)
180 |
181 | # Build model
182 | tokenizer = DebertaV2Tokenizer.from_pretrained(args.model_name, local_files_only=True)
183 |
184 | dataset_test = VideoQA_Dataset(args, tokenizer, "test")
185 | sampler_test = DistributedSampler(dataset_test, shuffle=False) if args.distributed else torch.utils.data.SequentialSampler(dataset_test)
186 | dataloader_test = DataLoader(dataset_test, batch_size=args.batch_size_test, sampler=sampler_test, collate_fn=videoqa_collate_fn, num_workers=args.num_workers)
187 |
188 | if not args.eval:
189 | dataset_train = VideoQA_Dataset(args, tokenizer, 'train')
190 | sampler_train = DistributedSampler(dataset_train) if args.distributed else torch.utils.data.RandomSampler(dataset_train)
191 | dataloader_train = DataLoader(dataset_train, batch_size=args.batch_size, sampler=sampler_train, collate_fn=videoqa_collate_fn, num_workers=args.num_workers)
192 |
193 | args.n_ans = len(dataloader_test.dataset.ans2id)
194 |
195 | model = DebertaV2ForMaskedLM.from_pretrained(features_dim=args.features_dim if args.use_video else 0, max_feats=args.max_feats, freeze_lm=args.freeze_lm,
196 | freeze_mlm=args.freeze_mlm, ft_ln=args.ft_ln, ds_factor_attn=args.ds_factor_attn, ds_factor_ff=args.ds_factor_ff,
197 | dropout=args.dropout, n_ans=args.n_ans, freeze_last=args.freeze_last, pretrained_model_name_or_path=args.model_name,
198 | local_files_only=True, args=args)
199 | model.to(device)
200 |
201 | total_parameters = sum(p.numel() for p in model.parameters())
202 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
203 | print(f'Total params: {total_parameters:,}')
204 | print(f'Trained params: {n_parameters:,}')
205 |
206 | # Set up optimizer
207 | params_for_optimization = list(p for n, p in model.named_parameters() if (p.requires_grad and 'gat' not in n))
208 | answer_params_for_optimization = list(p for n, p in model.named_parameters() if (p.requires_grad and 'gat' in n))
209 | optimizer = torch.optim.Adam([{"params": params_for_optimization, "lr": args.lr}, {"params": answer_params_for_optimization, "lr": args.lr}],
210 | lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay)
211 |
212 | # Load pretrained checkpoint
213 | if args.load:
214 | print("loading from", args.load)
215 | checkpoint = torch.load(args.load, map_location="cpu")
216 | if 'model' in checkpoint.keys():
217 | model.load_state_dict(checkpoint['model'], strict=False)
218 | else:
219 | model.load_state_dict(checkpoint, strict=False)
220 |
221 | if args.resume and not args.eval:
222 | optimizer.load_state_dict(checkpoint["optimizer"])
223 | args.start_epoch = checkpoint["epoch"] + 1
224 |
225 | if not args.eval:
226 | train_aid2tokid = torch.zeros(len(dataloader_train.dataset.ans2id), args.max_atokens).long()
227 | for a, aid in dataloader_train.dataset.ans2id.items():
228 | tok = torch.tensor(tokenizer(a, add_special_tokens=False, max_length=args.max_atokens, truncation=True, padding="max_length")["input_ids"], dtype=torch.long)
229 | train_aid2tokid[aid] = tok
230 |
231 | print(f'Training Vocab : {len(train_aid2tokid)}')
232 | print(f'Training Samples : {len(dataloader_train.dataset)}')
233 | test_aid2tokid = torch.zeros(len(dataloader_test.dataset.ans2id), args.max_atokens).long()
234 | for a, aid in dataloader_test.dataset.ans2id.items():
235 | tok = torch.tensor(tokenizer(a, add_special_tokens=False, max_length=args.max_atokens, truncation=True, padding="max_length")["input_ids"], dtype=torch.long)
236 | test_aid2tokid[aid] = tok
237 | print(f'Test Vocab : {len(test_aid2tokid)}')
238 | print(f'Test Samples : {len(dataloader_test.dataset)}')
239 |
240 | if not args.eval:
241 | print("Start training")
242 | start_time = time.time()
243 | best_epoch = args.start_epoch
244 | best_acc = 0
245 | for epoch in range(args.start_epoch, args.epochs):
246 | print(f"Starting epoch {epoch}")
247 | if args.distributed:
248 | sampler_train.set_epoch(epoch)
249 |
250 | model.set_answer_embeddings(train_aid2tokid.to(model.device), freeze_last=args.freeze_last)
251 | train_stats = train_one_epoch(model=model, tokenizer=tokenizer, data_loader=dataloader_train, optimizer=optimizer, device=device, epoch=epoch,
252 | dataset_name=args.dataset, args=args, max_norm=args.clip_max_norm)
253 |
254 | if (epoch + 1) % args.eval_skip == 0:
255 | print(f"Validating {args.dataset}")
256 | val_stats = {}
257 | model.set_answer_embeddings(test_aid2tokid.to(model.device), freeze_last=args.freeze_last)
258 | results = evaluate(model=model, tokenizer=tokenizer, data_loader=dataloader_test, device=device, dataset_name=args.dataset,
259 | args=args, split="val", epoch=epoch)
260 | val_stats.update({args.dataset + "_" + k: v for k, v in results.items()})
261 |
262 | if results["total"] > best_acc:
263 | best_epoch = epoch
264 | best_acc = results["total"]
265 | if dist.is_main_process() and args.save_dir:
266 |                         checkpoint_path = os.path.join(args.save_dir, "best_model.pth")
267 | dist.save_on_master({"model": model, "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args}, checkpoint_path)
268 | json.dump({"acc": best_acc, "ep": epoch}, open(os.path.join(args.save_dir, args.dataset + "acc_val.json"), "w"))
269 | else:
270 | val_stats = {}
271 |
272 | log_stats = {**{f"train_{k}": v for k, v in train_stats.items()}, **{f"val_{k}": v for k, v in val_stats.items()}, "epoch": epoch, "n_parameters": n_parameters}
273 |
274 | if args.save_dir and dist.is_main_process():
275 | with open(os.path.join(args.save_dir, "log.txt"), "a") as f:
276 | f.write(json.dumps(log_stats) + "\n")
277 |             checkpoint_path = os.path.join(args.save_dir, "ckpt.pth")
278 | dist.save_on_master({"model": model, "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args}, checkpoint_path)
279 |
280 | total_time = time.time() - start_time
281 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
282 | print("Training time {}".format(total_time_str))
283 | # load best ckpt
284 | if dist.is_main_process() and args.save_dir:
285 | print(f"loading best checkpoint from epoch {best_epoch}")
286 | if args.save_dir:
287 | torch.distributed.barrier() # wait all processes
288 |         checkpoint = torch.load(os.path.join(args.save_dir, "best_model.pth"), map_location="cpu")
289 | model.load_state_dict(checkpoint["model"], strict=False)
290 |
291 | model.set_answer_embeddings(test_aid2tokid.to(model.device), freeze_last=args.freeze_last)
292 | results = evaluate(model=model, tokenizer=tokenizer, data_loader=dataloader_test, device=device, dataset_name=args.dataset,
293 | args=args, split="val" if (args.eval and not args.test) else "test")
294 |
295 | if args.save_dir and dist.is_main_process():
296 | json.dump(results, open(os.path.join(args.save_dir, args.dataset + ".json"), "w"))
297 |
298 |
299 |
300 | if __name__ == "__main__":
301 | parser = argparse.ArgumentParser(parents=[get_args_parser()])
302 | args = parser.parse_args()
303 |
304 | if args.save_dir:
305 | args.save_dir = os.path.join(args.presave_dir, args.save_dir)
306 | args.model_name = os.path.join('./pretrained', args.model_name)
307 | main(args)
308 |
--------------------------------------------------------------------------------
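
Both `train_one_epoch` and `evaluate` read the answer prediction at the [MASK] position by first skipping the `delay = args.max_feats` video tokens prepended to the sequence and then boolean-indexing with the text token ids. A toy illustration with dummy shapes and a made-up mask token id (not the DeBERTa one):

```
import torch

mask_token_id = 3
delay = 2                               # number of prepended video tokens (args.max_feats)
input_ids = torch.tensor([[5, 3, 7]])   # one [MASK] (id 3) at text position 1
logits = torch.randn(1, delay + 3, 10)  # (batch, video + text tokens, vocab size)

masked_logits = logits[:, delay:input_ids.size(1) + delay][input_ids == mask_token_id]
print(masked_logits.shape)              # torch.Size([1, 10]): one row per [MASK] token
```
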
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mlvlab/OVQA/8f9fb810c11f348eb3ef6bdd4cc0d2c313299d5c/util/__init__.py
--------------------------------------------------------------------------------
/util/dist.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import io
3 | import os
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import hostlist
8 |
9 | _LOCAL_PROCESS_GROUP = None
10 |
11 |
12 | @functools.lru_cache()
13 | def _get_global_gloo_group():
14 | """
15 | Return a process group based on gloo backend, containing all the ranks
16 | The result is cached.
17 | """
18 |
19 | if dist.get_backend() == "nccl":
20 | return dist.new_group(backend="gloo")
21 |
22 | return dist.group.WORLD
23 |
24 |
25 | def all_gather(data):
26 | """
27 | Run all_gather on arbitrary picklable data (not necessarily tensors)
28 | Args:
29 | data: any picklable object
30 | Returns:
31 | list[data]: list of data gathered from each rank
32 | """
33 |
34 | world_size = get_world_size()
35 | if world_size == 1:
36 | return [data]
37 |
38 | cpu_group = None
39 | if os.getenv("MDETR_CPU_REDUCE") == "1":
40 | cpu_group = _get_global_gloo_group()
41 |
42 | buffer = io.BytesIO()
43 | torch.save(data, buffer)
44 | data_view = buffer.getbuffer()
45 | device = "cuda" if cpu_group is None else "cpu"
46 | tensor = torch.ByteTensor(data_view).to(device)
47 |
48 | # obtain Tensor size of each rank
49 | local_size = torch.tensor([tensor.numel()], device=device, dtype=torch.long)
50 | size_list = [
51 | torch.tensor([0], device=device, dtype=torch.long) for _ in range(world_size)
52 | ]
53 | if cpu_group is None:
54 | dist.all_gather(size_list, local_size)
55 | else:
56 | print("gathering on cpu")
57 | dist.all_gather(size_list, local_size, group=cpu_group)
58 | size_list = [int(size.item()) for size in size_list]
59 | max_size = max(size_list)
60 | assert isinstance(local_size.item(), int)
61 | local_size = int(local_size.item())
62 |
63 | # receiving Tensor from all ranks
64 | # we pad the tensor because torch all_gather does not support
65 | # gathering tensors of different shapes
66 | tensor_list = []
67 | for _ in size_list:
68 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device=device))
69 | if local_size != max_size:
70 | padding = torch.empty(
71 | size=(max_size - local_size,), dtype=torch.uint8, device=device
72 | )
73 | tensor = torch.cat((tensor, padding), dim=0)
74 | if cpu_group is None:
75 | dist.all_gather(tensor_list, tensor)
76 | else:
77 | dist.all_gather(tensor_list, tensor, group=cpu_group)
78 |
79 | data_list = []
80 | for size, tensor in zip(size_list, tensor_list):
81 | tensor = torch.split(tensor, [size, max_size - size], dim=0)[0]
82 | buffer = io.BytesIO(tensor.cpu().numpy())
83 | obj = torch.load(buffer)
84 | data_list.append(obj)
85 |
86 | return data_list
87 |
88 |
89 | def reduce_dict(input_dict, average=True):
90 | """
91 | Args:
92 | input_dict (dict): all the values will be reduced
93 | average (bool): whether to do average or sum
94 | Reduce the values in the dictionary from all processes so that all processes
95 | have the averaged results. Returns a dict with the same fields as
96 | input_dict, after reduction.
97 | """
98 | world_size = get_world_size()
99 | if world_size < 2:
100 | return input_dict
101 | with torch.no_grad():
102 | names = []
103 | values = []
104 | # sort the keys so that they are consistent across processes
105 | for k in sorted(input_dict.keys()):
106 | names.append(k)
107 | values.append(input_dict[k])
108 | values = torch.stack(values, dim=0)
109 | dist.all_reduce(values)
110 | if average:
111 | values /= world_size
112 | reduced_dict = {k: v for k, v in zip(names, values)}
113 | return reduced_dict
114 |
115 |
116 | def setup_for_distributed(is_master):
117 | """
118 |     This function disables printing when not in the master process
119 | """
120 | import builtins as __builtin__
121 |
122 | builtin_print = __builtin__.print
123 |
124 | def print(*args, **kwargs):
125 | force = kwargs.pop("force", False)
126 | if is_master or force:
127 | builtin_print(*args, **kwargs)
128 |
129 | __builtin__.print = print
130 |
131 |
132 | def is_dist_avail_and_initialized():
133 | """
134 | Returns:
135 | True if distributed training is enabled
136 | """
137 | if not dist.is_available():
138 | return False
139 | if not dist.is_initialized():
140 | return False
141 | return True
142 |
143 |
144 | def get_world_size():
145 | """
146 | Returns:
147 | The number of processes in the process group
148 | """
149 | if not is_dist_avail_and_initialized():
150 | return 1
151 | return dist.get_world_size()
152 |
153 |
154 | def get_rank():
155 | """
156 | Returns:
157 | The rank of the current process within the global process group.
158 | """
159 | if not is_dist_avail_and_initialized():
160 | return 0
161 | return dist.get_rank()
162 |
163 |
164 | def get_local_rank() -> int:
165 | """
166 | Returns:
167 | The rank of the current process within the local (per-machine) process group.
168 | """
169 | if not dist.is_available():
170 | return 0
171 | if not dist.is_initialized():
172 | return 0
173 | assert _LOCAL_PROCESS_GROUP is not None
174 | return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
175 |
176 |
177 | def get_local_size() -> int:
178 | """
179 | Returns:
180 | The size of the per-machine process group,
181 | i.e. the number of processes per machine.
182 | """
183 | if not dist.is_available():
184 | return 1
185 | if not dist.is_initialized():
186 | return 1
187 | return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
188 |
189 |
190 | def is_main_process():
191 | """Return true if the current process is the main one"""
192 | return get_rank() == 0
193 |
194 |
195 | def save_on_master(state_dict, checkpoint_path):
196 | """Utility function to save only from the main process"""
197 | model_without_deberta = {}
198 | for n, p in state_dict['model'].named_parameters():
199 | if p.requires_grad:
200 | model_without_deberta[n] = p
201 | state_dict['model'] = model_without_deberta
202 | if is_main_process():
203 | torch.save(state_dict, checkpoint_path)
204 |
205 |
206 | def init_distributed_mode(args):
207 | """Initialize distributed training, if appropriate"""
208 | if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
209 | args.rank = int(os.environ["RANK"])
210 | args.world_size = int(os.environ["WORLD_SIZE"])
211 | args.gpu = int(os.environ["LOCAL_RANK"])
212 | elif "SLURM_PROCID" in os.environ:
213 | args.rank = int(os.environ["SLURM_PROCID"])
214 | args.gpu = args.rank % torch.cuda.device_count()
215 | # CLUSTER SPECIFIC
216 | args.world_size = int(os.environ["SLURM_NTASKS"])
217 | hostnames = hostlist.expand_hostlist(os.environ["SLURM_JOB_NODELIST"])
218 | gpu_ids = os.environ["SLURM_STEP_GPUS"].split(",")
219 | os.environ["MASTER_ADDR"] = hostnames[0]
220 | os.environ["MASTER_PORT"] = str(12345 + int(min(gpu_ids))) # to avoid port conflict on the same node
221 | else:
222 | print("Not using distributed mode")
223 | args.distributed = False
224 | return
225 |
226 | args.distributed = True
227 |
228 | torch.cuda.set_device(args.gpu)
229 | args.dist_backend = "nccl"
230 | print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True)
231 |
232 | dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
233 | dist.barrier()
234 | setup_for_distributed(args.rank == 0)
235 |
--------------------------------------------------------------------------------
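
`reduce_dict` averages (or sums) scalar tensors across ranks for logging; when no process group is initialized it simply returns its input, which makes it easy to smoke-test outside a distributed launch. A minimal sketch, assuming the repository root is on `PYTHONPATH`:

```
import torch
from util.dist import get_world_size, reduce_dict

loss_dict = {"cls_loss": torch.tensor(0.5)}
reduced = reduce_dict(loss_dict)   # with a single process the dict is returned unchanged
print(get_world_size(), reduced)
```
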
/util/metrics.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import time
3 | from collections import defaultdict, deque
4 |
5 | import torch
6 | import torch.distributed as dist
7 |
8 | from util.dist import is_dist_avail_and_initialized
9 |
10 |
11 | class SmoothedValue:
12 | """Track a series of values and provide access to smoothed values over a
13 | window or the global series average.
14 | """
15 |
16 | def __init__(self, window_size=20, fmt=None):
17 | if fmt is None:
18 | fmt = "{median:.4f} ({global_avg:.4f})"
19 | self.deque = deque(maxlen=window_size)
20 | self.total = 0.0
21 | self.count = 0
22 | self.fmt = fmt
23 |
24 | def update(self, value, num=1):
25 | self.deque.append(value)
26 | self.count += num
27 | self.total += value * num
28 |
29 | def synchronize_between_processes(self):
30 | """
31 | Distributed synchronization of the metric
32 | Warning: does not synchronize the deque!
33 | """
34 | if not is_dist_avail_and_initialized():
35 | return
36 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
37 | dist.barrier()
38 | dist.all_reduce(t)
39 | t = t.tolist()
40 | self.count = int(t[0])
41 | self.total = t[1]
42 |
43 | @property
44 | def median(self):
45 | d = torch.tensor(list(self.deque))
46 | return d.median().item()
47 |
48 | @property
49 | def avg(self):
50 | d = torch.tensor(list(self.deque), dtype=torch.float32)
51 | return d.mean().item()
52 |
53 | @property
54 | def global_avg(self):
55 | if self.count == 0:
56 | return 0.
57 | else:
58 | return self.total / self.count
59 |
60 | @property
61 | def max(self):
62 | return max(self.deque)
63 |
64 | @property
65 | def value(self):
66 | return self.deque[-1]
67 |
68 | def __str__(self):
69 | return self.fmt.format(
70 | median=self.median,
71 | avg=self.avg,
72 | global_avg=self.global_avg,
73 | max=self.max,
74 | value=self.value,
75 | )
76 |
77 |
78 | class MetricLogger(object):
79 | def __init__(self, delimiter="\t"):
80 | self.meters = defaultdict(SmoothedValue)
81 | self.delimiter = delimiter
82 |
83 | def update(self, n=1, **kwargs):
84 | for k, v in kwargs.items():
85 | if isinstance(v, torch.Tensor):
86 | v = v.item()
87 | assert isinstance(v, (float, int))
88 | self.meters[k].update(v, num=n)
89 |
90 | def __getattr__(self, attr):
91 | if attr in self.meters:
92 | return self.meters[attr]
93 | if attr in self.__dict__:
94 | return self.__dict__[attr]
95 | raise AttributeError(
96 | "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
97 | )
98 |
99 | def __str__(self):
100 | loss_str = []
101 | for name, meter in self.meters.items():
102 | loss_str.append("{}: {}".format(name, str(meter)))
103 | return self.delimiter.join(loss_str)
104 |
105 | def synchronize_between_processes(self):
106 | for meter in self.meters.values():
107 | meter.synchronize_between_processes()
108 |
109 | def add_meter(self, name, meter):
110 | self.meters[name] = meter
111 |
112 | def log_every(self, iterable, print_freq, header=None):
113 | i = 0
114 | if not header:
115 | header = ""
116 | start_time = time.time()
117 | end = time.time()
118 | iter_time = SmoothedValue(fmt="{avg:.4f}")
119 | data_time = SmoothedValue(fmt="{avg:.4f}")
120 | space_fmt = ":" + str(len(str(len(iterable)))) + "d"
121 | if torch.cuda.is_available():
122 | log_msg = self.delimiter.join(
123 | [
124 | header,
125 | "[{0" + space_fmt + "}/{1}]",
126 | "eta: {eta}",
127 | "{meters}",
128 | "time: {time}",
129 | "data: {data}",
130 | "max mem: {memory:.0f}",
131 | ]
132 | )
133 | else:
134 | log_msg = self.delimiter.join(
135 | [
136 | header,
137 | "[{0" + space_fmt + "}/{1}]",
138 | "eta: {eta}",
139 | "{meters}",
140 | "time: {time}",
141 | "data: {data}",
142 | ]
143 | )
144 | MB = 1024.0 * 1024.0
145 | for obj in iterable:
146 | data_time.update(time.time() - end)
147 | yield obj
148 | iter_time.update(time.time() - end)
149 | if i % print_freq == 0 or i == len(iterable) - 1:
150 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
151 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
152 | if torch.cuda.is_available():
153 | print(
154 | log_msg.format(
155 | i,
156 | len(iterable),
157 | eta=eta_string,
158 | meters=str(self),
159 | time=str(iter_time),
160 | data=str(data_time),
161 | memory=torch.cuda.max_memory_allocated() / MB,
162 | )
163 | )
164 | else:
165 | print(
166 | log_msg.format(
167 | i,
168 | len(iterable),
169 | eta=eta_string,
170 | meters=str(self),
171 | time=str(iter_time),
172 | data=str(data_time),
173 | )
174 | )
175 | i += 1
176 | end = time.time()
177 | total_time = time.time() - start_time
178 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
179 | print(
180 | "{} Total time: {} ({:.4f} s / it)".format(
181 | header, total_time_str, total_time / len(iterable)
182 | )
183 | )
184 |
--------------------------------------------------------------------------------
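
`MetricLogger` keeps one `SmoothedValue` per metric name and formats each as `median (global_avg)`; `update(n=..., name=value)` adds `value` with sample count `n`. A single-process usage sketch, assuming the repository root is on `PYTHONPATH`:

```
from util.metrics import MetricLogger

logger = MetricLogger(delimiter="  ")
for step in range(5):
    logger.update(n=1, loss=1.0 / (step + 1))   # n is the sample count for this update

print(logger)                   # e.g. "loss: 0.3333 (0.4567)"
print(logger.loss.global_avg)   # running average over all updates
```
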
/util/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Tuple, List
3 | import random
4 |
5 |
6 | def get_mask(lengths, max_length):
7 | """Computes a batch of padding masks given batched lengths"""
8 | mask = 1 * (torch.arange(max_length).unsqueeze(1).to(lengths.device) < lengths).transpose(0, 1)
9 | return mask
10 |
11 |
12 | def mask_tokens(inputs, tokenizer, mlm_probability):
13 | """
14 | Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
15 | """
16 |
17 | if tokenizer.mask_token is None:
18 | raise ValueError("This tokenizer does not have a mask token which is necessary for masked language modeling. Remove the --mlm flag if you want to use this tokenizer.")
19 |
20 | labels = inputs.clone()
21 |     # We sample a few tokens in each sequence for masked-LM training (with probability mlm_probability, which defaults to 0.15 in BERT/RoBERTa)
22 | probability_matrix = torch.full(labels.shape, mlm_probability)
23 | special_tokens_mask = [tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()]
24 | probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
25 | if tokenizer._pad_token is not None:
26 | padding_mask = labels.eq(tokenizer.pad_token_id)
27 | probability_matrix.masked_fill_(padding_mask, value=0.0)
28 | masked_indices = torch.bernoulli(probability_matrix).bool()
29 | labels[~masked_indices] = -100 # We only compute loss on masked tokens
30 |
31 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
32 | indices_replaced = (torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices)
33 | inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)
34 |
35 | # 10% of the time, we replace masked input tokens with random word
36 | indices_random = (torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced)
37 | random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
38 | inputs[indices_random] = random_words[indices_random]
39 |
40 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
41 | return inputs, labels
42 |
43 |
44 | def adjust_learning_rate(optimizer, curr_step: int, num_training_steps: int, args):
45 | num_warmup_steps: int = round(args.fraction_warmup_steps * num_training_steps)
46 | if args.schedule == "linear_with_warmup":
47 | if curr_step < num_warmup_steps:
48 | gamma = float(curr_step) / float(max(1, num_warmup_steps))
49 | else:
50 | gamma = max(0.0, float(num_training_steps - curr_step) / float(max(1, num_training_steps - num_warmup_steps)))
51 | else: # constant LR
52 | gamma = 1
53 |
54 | optimizer.param_groups[0]["lr"] = args.lr * gamma
55 |
--------------------------------------------------------------------------------
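
`get_mask` turns a batch of lengths into a binary `(batch, max_length)` padding mask, and `adjust_learning_rate` scales the base learning rate by a linear warmup/decay factor. Two quick checks (a sketch, assuming the repository root is on `PYTHONPATH`; the schedule numbers are illustrative):

```
import torch
from util.misc import get_mask

print(get_mask(torch.tensor([2, 4]), 5))
# tensor([[1, 1, 0, 0, 0],
#         [1, 1, 1, 1, 0]])

# The linear_with_warmup factor computed inside adjust_learning_rate:
num_training_steps, num_warmup_steps = 10, 2
for step in range(num_training_steps):
    if step < num_warmup_steps:
        gamma = step / max(1, num_warmup_steps)
    else:
        gamma = max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))
    print(step, round(gamma, 3))
```
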