├── .DS_Store
├── .gitignore
├── LICENSE
├── README.md
├── VideoGPT2.py
├── dataset.py
├── generate.py
├── images
│   ├── Figure1.png
│   └── Figure2.png
├── requirements.txt
└── train.py

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ictnlp/DSTC8-AVSD/e9578fe1dc0d982928b4be8b5e133036664ad05c/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
#  Usually these files are written by a python script from a template
#  before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 ICTNLP

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# DSTC8-AVSD

We ranked 1st in the DSTC8 Audio-Visual Scene-Aware Dialog challenge. This is the source code for our AAAI 2020 DSTC8-AVSD workshop paper [Bridging Text and Video: A Universal Multimodal Transformer for Video-Audio Scene-Aware Dialog](https://arxiv.org/abs/2002.00163). Zekang Li, Zongjia Li, Jinchao Zhang, Yang Feng, Cheng Niu, Jie Zhou. AAAI 2020.

## News

Our paper has been accepted by IEEE/ACM Transactions on Audio, Speech, and Language Processing (TASLP). [url]()

## Abstract

Audio-Visual Scene-Aware Dialog (AVSD) is the task of generating responses in a conversation about a given video; it is organized as a track of the 8th Dialog System Technology Challenge (DSTC8). To solve the task, we propose a universal multimodal transformer and introduce a multi-task learning method to learn joint representations across modalities and to generate informative and fluent responses. Our method extends a pre-trained natural language generation model to the multimodal dialogue generation task. Our system achieved the best performance in both the objective and subjective evaluations of the challenge.

![A dialogue sampled from the DSTC8-AVSD dataset. For each dialogue, there are a video, audio, a video caption, a dialogue summary, and 10 turns of conversation about the video.](./images/Figure1.png)

## Model Architecture

![](./images/Figure2.png)



## How to Run

### Requirements

Python 3.6

torch==1.0.1
pytorch-ignite==0.2.1
transformers==2.1.1
tqdm==4.36.1

```shell
pip install -r requirements.txt
```

### Data

Download the DSTC8 [dataset](https://drive.google.com/drive/folders/1SlZTySJAk_2tiMG5F8ivxCfOl_OWwd_Q), which includes the training, validation, and test dialogues as well as features of the Charades videos extracted with the VGGish and I3D models.

All the data should be saved in the `data/` folder in the repository root. A quick way to sanity-check the download is sketched below.
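The snippet below is only a convenience sketch for checking that the download landed where the code expects it: it walks `data/` and summarizes the files found there. The actual file names and formats are defined by the DSTC8 release and consumed by `dataset.py`, not by this snippet.

```python
import os
from collections import Counter

DATA_DIR = "data"  # folder described above; adjust if you stored the data elsewhere

counts = Counter()
total_bytes = 0
for root, _, files in os.walk(DATA_DIR):
    for name in files:
        path = os.path.join(root, name)
        counts[os.path.splitext(name)[1] or "<no extension>"] += 1
        total_bytes += os.path.getsize(path)

print(f"{sum(counts.values())} files, {total_bytes / 1e9:.2f} GB")
for ext, n in counts.most_common():
    print(f"  {ext}: {n}")
```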
### Train

```shell
python train.py --log_path log/
```

### Generate

```shell
python generate.py --model_checkpoint log/ --output result.json --beam_search
```



## Citation

If you use this code in your research, please cite our AAAI 2020 DSTC8 workshop paper:

```
@article{li2020bridging,
  title={Bridging Text and Video: A Universal Multimodal Transformer for Video-Audio Scene-Aware Dialog},
  author={Zekang Li and Zongjia Li and Jinchao Zhang and Yang Feng and Cheng Niu and Jie Zhou},
  year={2020},
  eprint={2002.00163},
  archivePrefix={arXiv},
  journal={CoRR},
  primaryClass={cs.CL}
}
```

--------------------------------------------------------------------------------
/VideoGPT2.py:
--------------------------------------------------------------------------------
from transformers import *
import math
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss


def gelu(x):
    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))


class Attention(nn.Module):
    def __init__(self, nx, n_ctx, config, scale=False):
        super(Attention, self).__init__()
        self.output_attentions = config.output_attentions

        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % config.n_head == 0
        self.register_buffer("bias", torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
        self.n_head = config.n_head
        self.split_size = n_state
        self.scale = scale

        self.c_attn = Conv1D(n_state * 3, nx)
        self.c_proj = Conv1D(n_state, nx)
        self.attn_dropout = nn.Dropout(config.attn_pdrop)
        self.resid_dropout = nn.Dropout(config.resid_pdrop)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        if len(heads) == 0:
            return
        mask = torch.ones(self.n_head, self.split_size // self.n_head)
        heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
        for head in heads:
            # Compute how many pruned heads are before the head and move the index accordingly
            head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
            mask[head] = 0
        mask = mask.view(-1).contiguous().eq(1)
        index = torch.arange(len(mask))[mask].long()
        index_attn = torch.cat([index, index + self.split_size, index + (2*self.split_size)])

        # Prune conv1d layers
        self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
        self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)

        # Update hyper params
        self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
        self.n_head = self.n_head - len(heads)
        self.pruned_heads = self.pruned_heads.union(heads)

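    # Note on masking: `_attn` below deviates from the stock GPT-2 attention in
    # that `attention_mask` is expected to be a two-element list/tuple, as
    # prepared by VideoGPT2Model.forward():
    #   attention_mask[0] is added to the causal (lower-triangular) bias `b`;
    #     every position where the sum is positive may be attended to, so this
    #     mask can open up non-causal attention over part of the sequence
    #     (e.g. the video/audio prefix).
    #   attention_mask[1] is a 0/1 keep-mask (typically for padding); masked
    #     positions receive a large negative score (-1e18) before the softmax.
    # If no mask is passed, only the plain causal bias is applied.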
    def _attn(self, q, k, v, attention_mask=None, head_mask=None):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
        nd, ns = w.size(-2), w.size(-1)
        b = self.bias[:, :, ns-nd:ns, :ns]
        #w = w * b - 1e18 * (1 - b)

        if attention_mask is not None:
            # Apply the attention mask
            b = torch.gt(b + attention_mask[0], 0).float()
            w = w * b - 1e18 * (1 - b)
            w = w - 1e18 * (1 - attention_mask[1])
        else:
            w = w * b - 1e18 * (1 - b)

        w = nn.Softmax(dim=-1)(w)
        w = self.attn_dropout(w)

        # Mask heads if we want to
        if head_mask is not None:
            w = w * head_mask

        outputs = [torch.matmul(w, v)]
        if self.output_attentions:
            outputs.append(w)
        return outputs

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)  # (batch, head, head_features, seq_length)
        else:
            return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        x = self.c_attn(x)
        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        if layer_past is not None:
            past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1]  # transpose back cf below
            key = torch.cat((past_key, key), dim=-1)
            value = torch.cat((past_value, value), dim=-2)
        present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking

        attn_outputs = self._attn(query, key, value, attention_mask, head_mask)
        a = attn_outputs[0]

        a = self.merge_heads(a)
        a = self.c_proj(a)
        a = self.resid_dropout(a)

        outputs = [a, present] + attn_outputs[1:]
        return outputs  # a, present, (attentions)


class MLP(nn.Module):
    def __init__(self, n_state, config):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
        nx = config.n_embd
        self.c_fc = Conv1D(n_state, nx)
        self.c_proj = Conv1D(nx, n_state)
        self.act = gelu
        self.dropout = nn.Dropout(config.resid_pdrop)

    def forward(self, x):
        h = self.act(self.c_fc(x))
        h2 = self.c_proj(h)
        return self.dropout(h2)


class Block(nn.Module):
    def __init__(self, n_ctx, config, scale=False):
        super(Block, self).__init__()
        nx = config.n_embd
        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.attn = Attention(nx, n_ctx, config, scale)
        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
        self.mlp = MLP(4 * nx, config)

    def forward(self, x, layer_past=None, attention_mask=None, head_mask=None):
        output_attn = self.attn(self.ln_1(x),
                                layer_past=layer_past,
                                attention_mask=attention_mask,
                                head_mask=head_mask)
        a = output_attn[0]  # output_attn: a, present, (attentions)

        x = x + a
        m = self.mlp(self.ln_2(x))
        x = x + m

        outputs = [x] + output_attn[1:]
        return outputs  # x, present, (attentions)


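# VideoGPT2Model keeps the GPT-2 backbone (wte, wpe, drop and ln_f come from
# GPT2Model) but rebuilds the layer stack with the local Block/Attention classes
# so that the two-part attention mask described in Attention is supported.
# Unlike the stock GPT2Model, forward() takes pre-computed input embeddings
# (`input_embs`) rather than token ids, which lets the caller mix projected
# video/audio features with word embeddings before they enter the transformer.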
class VideoGPT2Model(GPT2Model):

    def __init__(self, config):
        super(VideoGPT2Model, self).__init__(config)
        self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

    def forward(self, input_embs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_embs.size(-2) + past_length, dtype=torch.long, device=input_embs.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_embs[:, :, 0])

        # Attention mask.
        if attention_mask is not None:
            # We create a broadcastable attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is simpler than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask[0] = attention_mask[0].unsqueeze(1).unsqueeze(2)
            attention_mask[1] = attention_mask[1].unsqueeze(1).unsqueeze(2)

            # attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions. Here the masks are only reshaped and cast; the
            # actual masking (adding a large negative value before the softmax)
            # is applied inside Attention._attn.
            attention_mask[0] = attention_mask[0].to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask[1] = attention_mask[1].to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            #attention_mask = (1.0 - attention_mask) * -1e18

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if needed + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer

        input_shape = input_embs.size()[:2]
        # input_ids = input_ids.view(-1, input_ids.size(-1))
        position_ids = position_ids.view(-1, position_ids.size(-1))

        # inputs_embeds = self.wte(input_ids)
        inputs_embeds = input_embs
        position_embeds = self.wpe(position_ids)
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
            token_type_embeds = self.wte(token_type_ids)
        else:
            token_type_embeds = 0
        hidden_states = inputs_embeds + position_embeds + token_type_embeds
        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = ()
        all_attentions = []
        all_hidden_states = ()
        for i, (block, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states.view(*output_shape),)

            outputs = block(hidden_states,
                            layer_past=layer_past,
                            attention_mask=attention_mask,
                            head_mask=head_mask[i])

            hidden_states, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs  # last hidden state, presents, (all hidden_states), (attentions)


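# VideoGPT2LMHeadModel adds three heads on top of VideoGPT2Model:
#   lm_head          - tied to the word-embedding matrix, used for response
#                      generation;
#   video_ff         - projects 4224-d video/audio features into the embedding
#                      space (4224 presumably corresponds to the concatenated
#                      I3D and VGGish features mentioned in the README; the
#                      exact split is not documented in this file). It is not
#                      called inside forward() below, so it is presumably
#                      applied where the input embeddings are built.
#   video_inverse_ff - maps hidden states back to the 4224-d feature space for
#                      the video-feature regression loss.
# With mode="reply", forward() computes the usual shifted cross-entropy loss on
# the text labels (labels[0]); otherwise it computes an MSE loss between the
# regressed features and the target video features (labels[1]).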
class VideoGPT2LMHeadModel(GPT2PreTrainedModel):
    def __init__(self, config):
        super(VideoGPT2LMHeadModel, self).__init__(config)
        self.transformer = VideoGPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.video_ff = nn.Linear(4224, config.n_embd)
        self.video_inverse_ff = nn.Linear(config.n_embd, 4224)

        self.init_weights()
        self.tie_weights()

    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head,
                                   self.transformer.wte)

    def forward(self, input_embs, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                labels=None, mode="reply"):
        transformer_outputs = self.transformer(input_embs,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask)
        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]
        if labels is not None:
            # Shift so that tokens < n predict n
            if mode == "reply":
                shift_logits = lm_logits[..., :-1, :].contiguous()
                shift_labels = labels[0][..., 1:].contiguous()
                # Flatten the tokens
                loss_text_fct = CrossEntropyLoss(ignore_index=-1)
                loss_text = loss_text_fct(shift_logits.view(-1, shift_logits.size(-1)),
                                          shift_labels.view(-1))
                loss = loss_text
            else:
                lm_video_regs = self.video_inverse_ff(hidden_states[:, :labels[1].size(1), :])
                shift_video_regs = lm_video_regs[..., :-1, :].contiguous()
                shift_video_labels = labels[1][..., :-1, :].contiguous()
                loss_video_fct = MSELoss(reduce=True, size_average=True)
                loss_video = loss_video_fct(shift_video_regs, shift_video_labels)
                loss = loss_video
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
# coding: utf-8
# author: noctli
import json
import pickle
import logging
import numpy as np
import torch
import torch.utils.data
from torch.utils.data import Dataset
from itertools import chain
# from train import SPECIAL_TOKENS, MODEL_INPUTS, PADDED_INPUTS
SPECIAL_TOKENS = ["", "", "", "","", "