├── .gitignore
├── LICENSE
├── README.md
├── bash
│   ├── eval.sh
│   ├── train_bert.sh
│   └── train_glove.sh
├── dataset
│   ├── __init__.py
│   ├── dataset.py
│   ├── load.py
│   ├── release.py
│   └── util.py
├── eval_mc.py
├── fig
│   └── example.png
├── main_qa.py
├── networks
│   ├── Attention.py
│   ├── CRN.py
│   ├── Embed_loss.py
│   ├── EncoderRNN.py
│   ├── GCN.py
│   ├── Transformer.py
│   ├── VQAModel
│   │   ├── B2A.py
│   │   ├── CoMem.py
│   │   ├── EVQA.py
│   │   ├── HCRN.py
│   │   ├── HGA.py
│   │   └── HME.py
│   ├── memory_module.py
│   ├── memory_rand.py
│   └── torchnlp_nn.py
├── requirement.txt
├── utils.py
└── videoqa.py

/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | __ignore__
3 | data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2022 BCMI
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Causal-VidQA
2 | 
3 | ## News
4 | 
5 | * [2024.07.11] We have released the answers for the [test set](https://cloud.bcmi.sjtu.edu.cn/sharing/aU4Skr9EJ). You can download them and put them in ```['data/QA']``` to use them.
6 | 
7 | 
8 | ## Introduction
9 | 
10 | The [Causal-VidQA dataset](https://arxiv.org/pdf/2205.14895.pdf) contains 107,600 QA pairs and aims to facilitate deeper video understanding towards video reasoning. In detail, we present the task of Causal-VidQA, which includes four types of questions ranging from scene description (description) to evidence reasoning (explanation) and commonsense reasoning (prediction and counterfactual). For commonsense reasoning, we set up a two-step solution: answering the question and providing a proper reason.
11 | 
12 | Here is an example from our dataset and a comparison between our dataset and other VisualQA datasets.
13 | 
14 | 
15 | ![Example from our Causal-VidQA Dataset](fig/example.png)
Example from our Causal-VidQA Dataset
16 | 17 | | Dataset | Visual Type | Visual Source | Annotation | Description | Explanation | Prediction | Counterfactual | \#Video/Image | \#QA | Video Length (s) | 18 | |:--------------:|:-----------:|:-------------:|:----------:|:-----------:|:-----------:|:----------:|:--------------:|:-------------:|:-------:|:----------------:| 19 | | Motivation | Image | MS COCO | Man | ✔ | ✔ | ✔ | $\times$ | 10,191 | - | - | 20 | | VCR | Image | Movie Clip | Man | ✔ | ✔ | ✔ | $\times$ | 110,000 | 290,000 | - | 21 | | MovieQA | Video | Movie Stories | Auto | ✔ | ✔ | $\times$ | $\times$ | 548 | 21,406 | 200 | 22 | | TVQA | Video | TV Show | Man | ✔ | ✔ | $\times$ | $\times$ | 21,793 | 152,545 | 76 | 23 | | TGIF-QA | Video | TGIF | Auto | ✔ | $\times$ | $\times$ | $\times$ | 71,741 | 165,165 | 3 | 24 | | ActivityNet-QA | Video | ActivityNet | Man | ✔ | ✔ | $\times$ | $\times$ | 5,800 | 58,000 | 180 | 25 | | Social-IQ | Video | YouTube | Man | ✔ | ✔ | $\times$ | $\times$ | 1,250 | 7,500 | 60 | 26 | | CLEVRER | Video | Game Engine | Man | ✔ | ✔ | ✔ | ✔ | 20,000 | 305,280 | 5 | 27 | | V2C | Video | MSR-VTT | Man | ✔ | ✔ | $\times$ | $\times$ | 10,000 | 115,312 | 30 | 28 | | NExT-QA | Video | YFCC-100M | Man | ✔ | ✔ | $\times$ | $\times$ | 5,440 | 52,044 | 44 | 29 | | Causal-VidQA | Video | Kinetics-700 | Man | ✔ | ✔ | ✔ | ✔ | 26,900 | 107,600 | 9 | 30 | 31 |
Comparison between our dataset and other VisualQA datasets
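For reference, the loader in ```dataset/dataset.py``` expands these four question types into six multiple-choice entries per video: prediction and counterfactual each contribute an answer and a reason, which are scored jointly. Below is a minimal sketch of that mapping, following ```load_text()``` in ```dataset/dataset.py``` and the short names used in ```eval_mc.py```; the loop and printout are only for illustration.
```
# Sketch: how the four question types map to the six QA entries per video,
# following load_text() in dataset/dataset.py and the D/E/PA/PR/CA/CR names in eval_mc.py.
QTYPE_TO_FIELDS = {
    0: ('D',  'descriptive',    'answer'),
    1: ('E',  'explanatory',    'answer'),
    2: ('PA', 'predictive',     'answer'),
    3: ('PR', 'predictive',     'reason'),
    4: ('CA', 'counterfactual', 'answer'),
    5: ('CR', 'counterfactual', 'reason'),
}

for qtype, (short, section, field) in QTYPE_TO_FIELDS.items():
    # e.g. qtype 3 reads text.json['predictive']['question'/'reason'] and the matching
    # ground-truth index in answer.json; PA+PR and CA+CR are scored jointly.
    print(qtype, short, section, field)
```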
32 | 
33 | In this page, you can find the code of several SOTA VideoQA methods and the dataset for our **CVPR** conference paper.
34 | 
35 | * Jiangtong Li, Li Niu and Liqing Zhang. *From Representation to Reasoning: Towards both Evidence and Commonsense Reasoning for Video Question-Answering*. *CVPR*, 2022. [[paper link]](https://arxiv.org/pdf/2205.14895.pdf)
36 | 
37 | ## Download
38 | 1. [Visual Feature](https://cloud.bcmi.sjtu.edu.cn/sharing/ZI1F0Hfd0)
39 | 2. [Text Feature](https://cloud.bcmi.sjtu.edu.cn/sharing/NeiJfafJq)
40 | 3. [Dataset Split](https://cloud.bcmi.sjtu.edu.cn/sharing/6kEtHMarE)
41 | 4. [Text annotation](https://cloud.bcmi.sjtu.edu.cn/sharing/aszEJs8VX)
42 | 5. [Original Data](https://cloud.bcmi.sjtu.edu.cn/sharing/FYDmyDwff)
43 | 
44 | ## Install
45 | Please create an environment for this project using miniconda (install [miniconda](https://docs.conda.io/en/latest/miniconda.html) first):
46 | ```
47 | >conda create -n causal-vidqa python==3.6.12
48 | >conda activate causal-vidqa
49 | >git clone https://github.com/bcmi/Causal-VidQA
50 | >pip install -r requirement.txt
51 | ```
52 | 
53 | ## Data Preparation
54 | Please download the pre-computed features and QA annotations from [Download 1-4](#download)
55 | and place them in ```['data/visual_feature']```, ```['data/text_feature']```, ```['data/split']``` and ```['data/QA']```. Note that the ```Text annotation``` is packaged as QA.tar; you need to unpack it before placing it in ```['data/QA']```.
56 | 
57 | If you want to extract different video features and text features from our Causal-VidQA dataset, you can download the original data from [Download 5](#download) and extract whatever features you want.
58 | 
59 | ## Usage
60 | Once the data is ready, you can easily run the code. First, to run these models with GloVe features, you can directly train B2A by:
61 | ```
62 | >sh bash/train_glove.sh
63 | ```
64 | Note that if you want to train the model with BERT features, we suggest you first load the BERT features into shared memory (SharedArray) by:
65 | ```
66 | >python dataset/load.py
67 | ```
68 | and then train B2A with BERT features by:
69 | ```
70 | >sh bash/train_bert.sh
71 | ```
72 | After the training script finishes, you can find the prediction files under ```['result/model_name/model_prefix_valid.json']``` and ```['result/model_name/model_prefix_test.json']```, and you can evaluate the prediction results by:
73 | ```
74 | >python eval_mc.py
75 | ```
76 | You can also obtain the predictions by running:
77 | ```
78 | >sh bash/eval.sh
79 | ```
80 | The command above will load the model from ```['experiment/model_name/model_prefix/model/best.pkl']``` and generate the prediction file. 
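For reference, here is a minimal sketch of the result file that ```eval_mc.py``` consumes and of how the joint metric treats the commonsense questions. The video id, file name and option indices are made up for illustration; only the key pattern ```<video_id>_<qtype>``` and the ```answer```/```prediction``` fields follow the evaluation code.
```
# Sketch of the prediction JSON that eval_mc.py reads: one entry per question,
# keyed by "<video_id>_<qtype>", holding the ground-truth and predicted option index.
# The ids and indices below are invented for illustration.
import json

predictions = {
    "video_0001_0": {"answer": 2, "prediction": 2},   # D
    "video_0001_1": {"answer": 0, "prediction": 3},   # E
    "video_0001_2": {"answer": 1, "prediction": 1},   # PA
    "video_0001_3": {"answer": 4, "prediction": 1},   # PR
    "video_0001_4": {"answer": 0, "prediction": 0},   # CA
    "video_0001_5": {"answer": 2, "prediction": 2},   # CR
}
with open("example_result.json", "w") as fp:
    json.dump(predictions, fp, indent=4)

# Scoring follows accuracy_metric_all in eval_mc.py: D and E count individually,
# while prediction (PA+PR) and counterfactual (CA+CR) are only counted as correct
# when both the answer and the reason are right.
hit = lambda key: predictions[key]["answer"] == predictions[key]["prediction"]
pred_ok = hit("video_0001_2") and hit("video_0001_3")   # False here: the PR choice is wrong
cf_ok = hit("video_0001_4") and hit("video_0001_5")     # True: both CA and CR are right
print(hit("video_0001_0"), hit("video_0001_1"), pred_ok, cf_ok)
```
In ```accuracy_metric_all```, the overall accuracy averages these four outcomes (D, E, P, C) over all videos.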
81 | 
82 | Hint: we have released a trained [model](https://cloud.bcmi.sjtu.edu.cn/sharing/c5IKQVMrM) for the ```B2A``` method. Please place the trained weights in ```['experiment/B2A/B2A/model/best.pkl']``` and then make predictions by running:
83 | ```
84 | >sh bash/eval.sh
85 | ```
86 | 
87 | (*The results may be slightly different depending on the environments and random seeds.*)
88 | 
89 | (*For comparison, please refer to the results in our paper.*)
90 | 
91 | ## Citation
92 | ```
93 | @InProceedings{li2022from,
94 | author = {Li, Jiangtong and Niu, Li and Zhang, Liqing},
95 | title = {From Representation to Reasoning: Towards both Evidence and Commonsense Reasoning for Video Question-Answering},
96 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
97 | month = {June},
98 | year = {2022}
99 | }
100 | ```
101 | ## Acknowledgement
102 | Our reproduction of the methods is mainly based on [NExT-QA](https://github.com/doc-doc/NExT-QA) and the other respective official repositories; we thank the authors for releasing their code. If you use the related parts, please cite the corresponding papers commented in the code.
103 | 
--------------------------------------------------------------------------------
/bash/eval.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_qa.py \
2 | --mode test \
3 | --feature_path ./data/visual_feature/ \
4 | --text_feature_path ./data/text_feature/ \
5 | --data_path ./data/QA/ \
6 | --split_path ./data/split/ \
7 | --checkpoint_path ./experiment \
8 | --model_type B2A \
9 | --model_prefix B2A \
10 | --result_file ./result/{}/{}_{}.json \
11 | --vid_dim 4096 \
12 | --hidden_dim 256 \
13 | --word_dim 300 \
14 | --max_vid_len 16 \
15 | --max_qa_len 40 \
16 | --epoch_num 30 \
17 | --lr_rate 2e-4 \
18 | --batch_size 32
--------------------------------------------------------------------------------
/bash/train_bert.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_qa.py \
2 | --mode train \
3 | --feature_path ./data/visual_feature/ \
4 | --text_feature_path ./data/text_feature/ \
5 | --data_path ./data/QA/ \
6 | --split_path ./data/split/ \
7 | --checkpoint_path ./experiment \
8 | --model_type B2A \
9 | --model_prefix B2A_bert \
10 | --result_file ./result/{}/{}_{}.json \
11 | --vid_dim 4096 \
12 | --hidden_dim 128 \
13 | --word_dim 300 \
14 | --max_vid_len 16 \
15 | --max_qa_len 40 \
16 | --epoch_num 30 \
17 | --lr_rate 2e-4 \
18 | --batch_size 128 \
19 | --use_bert
--------------------------------------------------------------------------------
/bash/train_glove.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_qa.py \
2 | --mode train \
3 | --feature_path ./data/visual_feature/ \
4 | --text_feature_path ./data/text_feature/ \
5 | --data_path ./data/QA/ \
6 | --split_path ./data/split/ \
7 | --checkpoint_path ./experiment \
8 | --model_type B2A \
9 | --model_prefix B2A_glove \
10 | --result_file ./result/{}/{}_{}.json \
11 | --vid_dim 4096 \
12 | --hidden_dim 128 \
13 | --word_dim 300 \
14 | --max_vid_len 16 \
15 | --max_qa_len 40 \
16 | --epoch_num 30 \
17 | --lr_rate 2e-4 \
18 | --batch_size 128
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import
VidQADataset, Vocabulary -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from numpy import random 3 | import torch 4 | from torch.utils.data import Dataset 5 | import os.path as osp 6 | import numpy as np 7 | import nltk 8 | import h5py 9 | import os 10 | import json 11 | import pickle as pkl 12 | import pandas as pd 13 | from tqdm import tqdm 14 | import re 15 | import stanfordnlp 16 | import SharedArray as sa 17 | 18 | class Vocabulary(object): 19 | """Simple vocabulary wrapper.""" 20 | def __init__(self, word2idx, idx2word): 21 | self.word2idx = word2idx 22 | self.idx2word = idx2word 23 | self.idx = len(idx2word) 24 | 25 | def add_word(self, word): 26 | if not word in self.word2idx: 27 | self.word2idx[word] = self.idx 28 | self.idx2word[self.idx] = word 29 | self.idx += 1 30 | 31 | def __call__(self, word): 32 | if not word in self.word2idx: 33 | return self.word2idx[''] 34 | return self.word2idx[word] 35 | 36 | def __len__(self): 37 | return len(self.word2idx) 38 | 39 | class VidQADataset(Dataset): 40 | """load the dataset in dataloader""" 41 | 42 | def __init__(self, feature_path, text_feature_path, split_path, data_path, use_bert, vocab, qtype=-1, max_length=40): 43 | self.feature_path = feature_path 44 | self.text_feature_path = text_feature_path 45 | self.split_path = split_path 46 | self.data_path = data_path 47 | self.qtype = qtype 48 | 49 | self.vocab = vocab 50 | self.vids = pkload(self.split_path) 51 | 52 | self.max_length = max_length 53 | self.use_bert = use_bert 54 | if self.use_bert: 55 | self.bert_file = osp.join(text_feature_path, 'text_seq.h5') 56 | self.bert_length = osp.join(text_feature_path, 'text_seq_length.pkl') 57 | self.bert_token = osp.join(text_feature_path, 'token_org.pkl') 58 | with open(self.bert_token, 'rb') as fbt: 59 | self.token_dict = pkl.load(fbt) 60 | with open(self.bert_length, 'rb') as fl: 61 | self.length_dict = pkl.load(fl) 62 | 63 | if self.use_bert: 64 | self.adj_path = osp.join(text_feature_path, 'bert_adj_dict.pkl') 65 | else: 66 | self.adj_path = osp.join(text_feature_path, 'glove_adj_dict.pkl') 67 | with open(self.adj_path, 'rb') as fbt: 68 | self.token_adj = pkl.load(fbt) 69 | 70 | vf_info = pkload(osp.join(feature_path, 'idx2vid.pkl')) 71 | self.vf_info = dict() 72 | for idx, vid in enumerate(vf_info): 73 | if vid in self.vids: 74 | self.vf_info[vid] = idx 75 | app_file = osp.join(feature_path, 'appearance_feat.h5') 76 | mot_file = osp.join(feature_path, 'motion_feat.h5') 77 | print('Load {}...'.format(app_file)) 78 | self.app_feats = dict() 79 | with h5py.File(app_file, 'r') as fp: 80 | feats = fp['resnet_features'] 81 | for vid, idx in self.vf_info.items(): 82 | self.app_feats[vid] = feats[idx][...] 83 | print('Load {}...'.format(mot_file)) 84 | self.mot_feats = dict() 85 | with h5py.File(mot_file, 'r') as fp: 86 | feats = fp['resnet_features'] 87 | for vid, idx in self.vf_info.items(): 88 | self.mot_feats[vid] = feats[idx][...] 89 | 90 | self.txt_obj = dict() 91 | print('Load {}...'.format(osp.join(self.feature_path, 'ROI_text.h5'))) 92 | with h5py.File(osp.join(self.feature_path, 'ROI_text.h5'), 'r') as f: 93 | keys = [item for item in self.vids if item in f.keys()] 94 | for key in keys: 95 | tmp = dict() 96 | labels = f[key].keys() 97 | for label in labels: 98 | new_label = '[' + label + ']' 99 | tmp[new_label] = f[key][label][...] 
100 | self.txt_obj[key] = tmp 101 | 102 | def __len__(self): 103 | if self.qtype == -1: 104 | return len(self.vids)*6 105 | elif self.qtype == 0 or self.qtype == 1: 106 | return len(self.vids) 107 | elif self.qtype == 2 or self.qtype == 3: 108 | return len(self.vids)*2 109 | 110 | def get_video_feature(self, video_name): 111 | """ 112 | :param video_name: 113 | :return: 114 | """ 115 | app_feat = self.app_feats[video_name] 116 | mot_feat = self.mot_feats[video_name] 117 | 118 | return torch.from_numpy(app_feat).type(torch.float32), torch.from_numpy(mot_feat).type(torch.float32) 119 | 120 | def get_word_idx(self, text): 121 | """ 122 | """ 123 | tokens = nltk.tokenize.word_tokenize(str(text).lower()) 124 | token_ids = [self.vocab(token) for i, token in enumerate(tokens) if i < (self.max_length - 2)] 125 | 126 | return token_ids 127 | 128 | def get_token_seq(self, text): 129 | """ 130 | """ 131 | tokens = nltk.tokenize.word_tokenize(str(text).lower()) 132 | return tokens 133 | 134 | def get_adj(self, vidx, qtype): 135 | adj_vidx = self.token_adj[vidx] 136 | qas_adj = adj_vidx[6+qtype*5:11+qtype*5] 137 | ques_adj = adj_vidx[qtype] 138 | qas_adj_new = np.zeros((len(qas_adj), self.max_length, self.max_length)) 139 | ques_adj_new = np.zeros((self.max_length, self.max_length)) 140 | for idx, item in enumerate(qas_adj): 141 | if item.shape[0] > self.max_length: 142 | qas_adj_new[idx] = item[:self.max_length, :self.max_length] 143 | else: 144 | qas_adj_new[idx, :item.shape[0], :item.shape[1]] = item 145 | if ques_adj.shape[0] > self.max_length: 146 | ques_adj_new = ques_adj[:self.max_length, :self.max_length] 147 | else: 148 | ques_adj_new[:ques_adj.shape[0], :ques_adj.shape[1]] = ques_adj 149 | return qas_adj_new, ques_adj_new 150 | 151 | def get_trans_matrix(self, candidates): 152 | 153 | qa_lengths = [len(qa) for qa in candidates] 154 | candidates_matrix = torch.zeros([5, self.max_length]).long() 155 | for k in range(5): 156 | sentence = candidates[k] 157 | length = qa_lengths[k] 158 | if length > self.max_length: 159 | length = self.max_length 160 | candidates_matrix[k] = torch.Tensor(sentence[:length]) 161 | else: 162 | candidates_matrix[k, :length] = torch.Tensor(sentence) 163 | 164 | return candidates_matrix, qa_lengths 165 | 166 | def get_ques_matrix(self, ques): 167 | 168 | q_lengths = len(ques) 169 | ques_matrix = torch.zeros([self.max_length]).long() 170 | ques_matrix[:q_lengths] = torch.Tensor(ques) 171 | 172 | return ques_matrix, q_lengths 173 | 174 | def get_tagname(self, line): 175 | tag = set() 176 | tmp_tag = re.findall(r"\[(.+?)\]", line) 177 | for item in tmp_tag: 178 | tag.add('['+item+']') 179 | return list(tag) 180 | 181 | def match_tok_tag(self, labels, tags, tok): 182 | tok_tag = [None for _ in range(len(tok))] 183 | if labels == list(): 184 | return tok_tag 185 | for tag in tags: 186 | for idx in range(len(tok)): 187 | if tag.startswith(tok[idx]): 188 | new_idx = idx 189 | while not tag.endswith(tok[new_idx]): 190 | new_idx += 1 191 | new_tag = ''.join(tok[idx:new_idx+1]) 192 | if new_tag == tag: 193 | for i in range(idx, new_idx+1): 194 | tok_tag[i] = tag 195 | if tag not in labels: 196 | label = random.choice(labels) 197 | for index, item in enumerate(tok_tag): 198 | if item == tag: 199 | tok_tag[index] = label 200 | else: 201 | pass 202 | return tok_tag 203 | 204 | def load_txt_obj(self, vid, tok, org): 205 | if vid in self.txt_obj: 206 | labels = list(self.txt_obj[vid].keys()) 207 | else: 208 | labels = list() 209 | fea = list() 210 | for idx in range(len(tok)): 211 | 
tags = self.get_tagname(org[idx]) 212 | tok_tag = self.match_tok_tag(labels, tags, tok[idx]) 213 | fea_each = list() 214 | for item in tok_tag: 215 | if item is None: 216 | fea_each.append(np.zeros((2048,))) 217 | else: 218 | fea_each.append(self.txt_obj[vid][item]) 219 | fea_each = np.stack(fea_each, axis=0) 220 | new_fea_each = np.zeros((self.max_length, 2048)) 221 | if fea_each.shape[0] > self.max_length: 222 | new_fea_each = fea_each[:self.max_length] 223 | else: 224 | new_fea_each[:fea_each.shape[0]] = fea_each 225 | 226 | fea.append(new_fea_each) 227 | return fea 228 | 229 | def load_text(self, vid, qtype): 230 | text_file = os.path.join(self.data_path, vid, 'text.json') 231 | answer_file = os.path.join(self.data_path, vid, 'answer.json') 232 | with open(text_file, 'r') as fin: 233 | text = json.load(fin) 234 | with open(answer_file, 'r') as fin: 235 | answer = json.load(fin) 236 | if qtype == 0: 237 | qns = text['descriptive']['question'] 238 | cand_ans = text['descriptive']['answer'] 239 | ans_id = answer['descriptive']['answer'] 240 | if qtype == 1: 241 | qns = text['explanatory']['question'] 242 | cand_ans = text['explanatory']['answer'] 243 | ans_id = answer['explanatory']['answer'] 244 | if qtype == 2: 245 | qns = text['predictive']['question'] 246 | cand_ans = text['predictive']['answer'] 247 | ans_id = answer['predictive']['answer'] 248 | if qtype == 3: 249 | qns = text['predictive']['question'] 250 | cand_ans = text['predictive']['reason'] 251 | ans_id = answer['predictive']['reason'] 252 | if qtype == 4: 253 | qns = text['counterfactual']['question'] 254 | cand_ans = text['counterfactual']['answer'] 255 | ans_id = answer['counterfactual']['answer'] 256 | if qtype == 5: 257 | qns = text['counterfactual']['question'] 258 | cand_ans = text['counterfactual']['reason'] 259 | ans_id = answer['counterfactual']['reason'] 260 | return qns, cand_ans, ans_id 261 | 262 | def load_text_bert(self, vid, qtype): 263 | with h5py.File(self.bert_file, 'r') as fp: 264 | feature = sa.attach("shm://{}".format(vid)) 265 | token_org = self.token_dict[vid] 266 | length = self.length_dict[vid] 267 | cand = feature[6+qtype*5:11+qtype*5] 268 | tok = token_org[0][6+qtype*5:11+qtype*5] 269 | org = token_org[1][6+qtype*5:11+qtype*5] 270 | cand_l = length[6+qtype*5:11+qtype*5] 271 | question = feature[qtype] 272 | tok_q = [token_org[0][qtype], ] 273 | org_q = [token_org[1][qtype], ] 274 | qns_len = length[qtype] 275 | dim = cand.shape[2] 276 | new_candidate = np.zeros((5, self.max_length, dim)) 277 | new_question = np.zeros((self.max_length, dim)) 278 | for idx, qa_l in enumerate(cand_l): 279 | if qa_l > self.max_length: 280 | new_candidate[idx] = cand[idx, :self.max_length] 281 | else: 282 | new_candidate[idx, :qa_l] = cand[idx, :qa_l] 283 | if qns_len > self.max_length: 284 | new_question = question[:self.max_length] 285 | else: 286 | new_question[:qns_len] = question[:qns_len] 287 | return torch.from_numpy(new_candidate).type(torch.float32), tok, org, cand_l, torch.from_numpy(new_question).type(torch.float32), tok_q, org_q, qns_len 288 | 289 | def __getitem__(self, idx): 290 | """ 291 | """ 292 | if self.qtype == -1: 293 | qtype = idx % 6 294 | idx = idx // 6 295 | elif self.qtype == 0 or self.qtype == 1: 296 | qtype = self.qtype 297 | elif self.qtype == 2: 298 | qtype = 2 + (idx % 2) 299 | idx = idx // 2 300 | elif self.qtype == 3: 301 | qtype = 4 + (idx % 2) 302 | idx = idx // 2 303 | vidx = self.vids[idx] 304 | # load text 305 | qns, cand_ans, ans_id = self.load_text(vidx, qtype) 306 | if 
self.use_bert: 307 | candidate, tok, org, can_lengths, question, tok_q, org_q, qns_len = self.load_text_bert(vidx, qtype) 308 | else: 309 | tok_q = [['',] + self.get_token_seq(qns) + ['', ], ] 310 | org_q = [qns, ] 311 | question, qns_len = self.get_ques_matrix([self.vocab(''), ] + self.get_word_idx(qns) + [self.vocab(''), ]) 312 | 313 | tok = [] 314 | org = [] 315 | candidate = [] 316 | qnstok = ['',] + self.get_token_seq(qns) + ['', ] 317 | qnsids = [self.vocab(''), ] + self.get_word_idx(qns) + [self.vocab(''), ] 318 | for ans in cand_ans: 319 | anstok = ['', ] + self.get_token_seq(ans) + ['', ] 320 | ansids = [self.vocab(''), ] + self.get_word_idx(ans) + [self.vocab(''), ] 321 | tok.append(qnstok+anstok) 322 | org.append(qns+ans) 323 | candidate.append(qnsids + ansids) 324 | candidate, can_lengths = self.get_trans_matrix(candidate) 325 | can_lengths = torch.tensor(can_lengths).clamp(max=self.max_length) 326 | qns_len = torch.tensor(qns_len).clamp(max=self.max_length) 327 | # load object feature 328 | obj_feature = torch.from_numpy(np.stack(self.load_txt_obj(vidx, tok, org), axis=0)).type(torch.float32) 329 | obj_feature_q = torch.from_numpy(self.load_txt_obj(vidx, tok_q, org_q)[0]).type(torch.float32) 330 | # load dependency relation 331 | adj_qas, adj_ques = self.get_adj(vidx, qtype) 332 | adj_ques = torch.from_numpy(adj_ques).type(torch.float32) 333 | adj_qas = torch.from_numpy(np.stack(adj_qas, axis=0)).type(torch.float32) 334 | # load video feature 335 | app_feature, mot_feature = self.get_video_feature(vidx) 336 | qns_key = vidx + '_' + str(qtype) 337 | 338 | return [app_feature, mot_feature], [candidate, can_lengths, obj_feature, adj_qas], [question, qns_len, obj_feature_q, adj_ques], torch.tensor(ans_id), qns_key 339 | 340 | def nozero_row(A): 341 | i = 0 342 | for row in A: 343 | if row.sum()==0: 344 | break 345 | i += 1 346 | 347 | return i 348 | 349 | def load_file(file_name): 350 | annos = None 351 | if osp.splitext(file_name)[-1] == '.csv': 352 | return pd.read_csv(file_name) 353 | with open(file_name, 'r') as fp: 354 | if osp.splitext(file_name)[1]== '.txt': 355 | annos = fp.readlines() 356 | annos = [line.rstrip() for line in annos] 357 | if osp.splitext(file_name)[1] == '.json': 358 | annos = json.load(fp) 359 | return annos 360 | 361 | def save_file(obj, filename): 362 | """ 363 | save obj to filename 364 | :param obj: 365 | :param filename: 366 | :return: 367 | """ 368 | filepath = osp.dirname(filename) 369 | if filepath != '' and not osp.exists(filepath): 370 | os.makedirs(filepath) 371 | else: 372 | with open(filename, 'w') as fp: 373 | json.dump(obj, fp, indent=4) 374 | 375 | def pkload(file): 376 | data = None 377 | if osp.exists(file) and osp.getsize(file) > 0: 378 | with open(file, 'rb') as fp: 379 | data = pkl.load(fp) 380 | return data 381 | 382 | def pkdump(data, file): 383 | dirname = osp.dirname(file) 384 | if not osp.exists(dirname): 385 | os.makedirs(dirname) 386 | with open(file, 'wb') as fp: 387 | pkl.dump(data, fp) -------------------------------------------------------------------------------- /dataset/load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import SharedArray as sa 3 | import h5py 4 | from tqdm import tqdm 5 | 6 | bert_file = './data/text_feature/text_seq.h5' 7 | with h5py.File(bert_file, 'r') as fp: 8 | for key in tqdm(fp.keys()): 9 | tmp = sa.create("shm://{}".format(key), fp[key].shape, 'float32') 10 | tmp[:] = fp[key][...] 
11 | -------------------------------------------------------------------------------- /dataset/release.py: -------------------------------------------------------------------------------- 1 | import os 2 | import SharedArray as sa 3 | import h5py 4 | from tqdm import tqdm 5 | 6 | bert_file = './data/text_feature/text_seq.h5' 7 | with h5py.File(bert_file, 'r') as fp: 8 | for key in tqdm(fp.keys()): 9 | sa.delete("shm://{}".format(key)) -------------------------------------------------------------------------------- /dataset/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import pickle as pkl 6 | import pandas as pd 7 | 8 | def load_file(file_name): 9 | annos = None 10 | if osp.splitext(file_name)[-1] == '.csv': 11 | return pd.read_csv(file_name) 12 | with open(file_name, 'r') as fp: 13 | if osp.splitext(file_name)[1]== '.txt': 14 | annos = fp.readlines() 15 | annos = [line.rstrip() for line in annos] 16 | if osp.splitext(file_name)[1] == '.json': 17 | annos = json.load(fp) 18 | 19 | return annos 20 | 21 | def save_file(obj, filename): 22 | """ 23 | save obj to filename 24 | :param obj: 25 | :param filename: 26 | :return: 27 | """ 28 | filepath = osp.dirname(filename) 29 | if filepath != '' and not osp.exists(filepath): 30 | os.makedirs(filepath) 31 | else: 32 | with open(filename, 'w') as fp: 33 | json.dump(obj, fp, indent=4) 34 | 35 | def pkload(file): 36 | data = None 37 | if osp.exists(file) and osp.getsize(file) > 0: 38 | with open(file, 'rb') as fp: 39 | data = pkl.load(fp) 40 | # print('{} does not exist'.format(file)) 41 | return data 42 | 43 | 44 | def pkdump(data, file): 45 | dirname = osp.dirname(file) 46 | if not osp.exists(dirname): 47 | os.makedirs(dirname) 48 | with open(file, 'wb') as fp: 49 | pkl.dump(data, fp) 50 | 51 | 52 | -------------------------------------------------------------------------------- /eval_mc.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from utils import load_file 3 | 4 | map_name = {'D': 'Description', 'E': 'Explaination', 'PA': 'Predictive-Answer', 'CA': 'Counterfactual-Answer', 'PR': 'Predictive-Reason', 'CR': 'Counterfactual-Reason', 'P':'Predictive', 'C': 'Counterfactual'} 5 | 6 | def accuracy_metric(result_file, qtype): 7 | if qtype == -1: 8 | accuracy_metric_all(result_file) 9 | if qtype == 0: 10 | accuracy_metric_q0(result_file) 11 | if qtype == 1: 12 | accuracy_metric_q1(result_file) 13 | if qtype == 2: 14 | accuracy_metric_q2(result_file) 15 | if qtype == 3: 16 | accuracy_metric_q3(result_file) 17 | 18 | def accuracy_metric_q0(result_file): 19 | preds = list(load_file(result_file).items()) 20 | group_acc = {'D': 0} 21 | group_cnt = {'D': 0} 22 | all_acc = 0 23 | all_cnt = 0 24 | for idx in range(len(preds)): 25 | id_qtypes = preds[idx] 26 | answer = id_qtypes[1]['answer'] 27 | pred = id_qtypes[1]['prediction'] 28 | group_cnt['D'] += 1 29 | all_cnt += 1 30 | if answer == pred: 31 | group_acc['D'] += 1 32 | all_acc += 1 33 | for qtype, acc in group_acc.items(): # 34 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype])) 35 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt)) 36 | 37 | def accuracy_metric_q1(result_file): 38 | preds = list(load_file(result_file).items()) 39 | group_acc = {'E': 0} 40 | group_cnt = {'E': 0} 41 | all_acc = 0 42 | all_cnt = 0 43 | for idx in range(len(preds)): 44 | id_qtypes = 
preds[idx] 45 | answer = id_qtypes[1]['answer'] 46 | pred = id_qtypes[1]['prediction'] 47 | group_cnt['E'] += 1 48 | all_cnt += 1 49 | if answer == pred: 50 | group_acc['E'] += 1 51 | all_acc += 1 52 | for qtype, acc in group_acc.items(): # 53 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype])) 54 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt)) 55 | 56 | def accuracy_metric_q2(result_file): 57 | preds = list(load_file(result_file).items()) 58 | qtype2short = ['PA', 'PR', 'P'] 59 | group_acc = {'PA': 0, 'PR': 0, 'P': 0} 60 | group_cnt = {'PA': 0, 'PR': 0, 'P': 0} 61 | all_acc = 0 62 | all_cnt = 0 63 | for idx in range(len(preds)//2): 64 | id_qtypes = preds[idx*2:(idx+1)*2] 65 | qtypes = [0, 1] 66 | answer = [ans_pre[1]['answer'] for ans_pre in id_qtypes] 67 | pred = [ans_pre[1]['prediction'] for ans_pre in id_qtypes] 68 | for i in range(2): 69 | group_cnt[qtype2short[qtypes[i]]] += 1 70 | if answer[i] == pred[i]: 71 | group_acc[qtype2short[qtypes[i]]] += 1 72 | group_cnt['P'] += 1 73 | all_cnt += 1 74 | if answer[0] == pred[0] and answer[1] == pred[1]: 75 | group_acc['P'] += 1 76 | all_acc += 1 77 | for qtype, acc in group_acc.items(): # 78 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype])) 79 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt)) 80 | 81 | def accuracy_metric_q3(result_file): 82 | preds = list(load_file(result_file).items()) 83 | qtype2short = ['CA', 'CR', 'C'] 84 | group_acc = {'CA': 0, 'CR': 0, 'C': 0} 85 | group_cnt = {'CA': 0, 'CR': 0, 'C': 0} 86 | all_acc = 0 87 | all_cnt = 0 88 | for idx in range(len(preds)//2): 89 | id_qtypes = preds[idx*2:(idx+1)*2] 90 | qtypes = [0, 1] 91 | answer = [ans_pre[1]['answer'] for ans_pre in id_qtypes] 92 | pred = [ans_pre[1]['prediction'] for ans_pre in id_qtypes] 93 | for i in range(2): 94 | group_cnt[qtype2short[qtypes[i]]] += 1 95 | if answer[i] == pred[i]: 96 | group_acc[qtype2short[qtypes[i]]] += 1 97 | group_cnt['C'] += 1 98 | all_cnt += 1 99 | if answer[0] == pred[0] and answer[1] == pred[1]: 100 | group_acc['C'] += 1 101 | all_acc += 1 102 | for qtype, acc in group_acc.items(): # 103 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype])) 104 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt)) 105 | 106 | def accuracy_metric_all(result_file): 107 | preds = list(load_file(result_file).items()) 108 | qtype2short = ['D', 'E', 'PA', 'PR', 'CA', 'CR', 'P', 'C'] 109 | group_acc = {'D': 0, 'E': 0, 'PA': 0, 'PR': 0, 'CA': 0, 'CR': 0, 'P': 0, 'C': 0} 110 | group_cnt = {'D': 0, 'E': 0, 'PA': 0, 'PR': 0, 'CA': 0, 'CR': 0, 'P': 0, 'C': 0} 111 | all_acc = 0 112 | all_cnt = 0 113 | for idx in range(len(preds)//6): 114 | id_qtypes = preds[idx*6:(idx+1)*6] 115 | qtypes = [int(id_qtype[0].split('_')[-1]) for id_qtype in id_qtypes] 116 | answer = [ans_pre[1]['answer'] for ans_pre in id_qtypes] 117 | pred = [ans_pre[1]['prediction'] for ans_pre in id_qtypes] 118 | for i in range(6): 119 | group_cnt[qtype2short[qtypes[i]]] += 1 120 | if answer[i] == pred[i]: 121 | group_acc[qtype2short[qtypes[i]]] += 1 122 | group_cnt['C'] += 1 123 | group_cnt['P'] += 1 124 | all_cnt += 4 125 | if answer[0] == pred[0]: 126 | all_acc += 1 127 | if answer[1] == pred[1]: 128 | all_acc += 1 129 | if answer[2] == pred[2] and answer[3] == pred[3]: 130 | group_acc['P'] += 1 131 | all_acc += 1 132 | if answer[4] == pred[4] and answer[5] == pred[5]: 133 | group_acc['C'] += 1 134 | all_acc += 1 135 | for qtype, acc in 
group_acc.items(): # 136 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype])) 137 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt)) 138 | 139 | def main(result_file, qtype=-1): 140 | print('Evaluating {}'.format(result_file)) 141 | 142 | accuracy_metric(result_file, qtype) 143 | 144 | 145 | if __name__ == "__main__": 146 | result_file = 'path to results json' 147 | main(result_file, -1) 148 | -------------------------------------------------------------------------------- /fig/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bcmi/Causal-VidQA/d2c05a1bd306dc9e17f619a5c47f4731b81a3031/fig/example.png -------------------------------------------------------------------------------- /main_qa.py: -------------------------------------------------------------------------------- 1 | from videoqa import * 2 | from dataset import VidQADataset, Vocabulary 3 | from torch.utils.data import Dataset, DataLoader 4 | from utils import * 5 | import argparse 6 | import eval_mc 7 | torch.multiprocessing.set_sharing_strategy('file_system') 8 | 9 | def main(args): 10 | 11 | mode = args.mode 12 | if mode == 'train': 13 | batch_size = args.batch_size 14 | num_worker = 8 15 | else: 16 | batch_size = 32 17 | num_worker = 8 18 | 19 | feature_path = args.feature_path 20 | text_feature_path = args.text_feature_path 21 | data_path = args.data_path 22 | train_split_path = osp.join(args.split_path, 'train.pkl') 23 | valid_split_path = osp.join(args.split_path, 'valid.pkl') 24 | test_split_path = osp.join(args.split_path, 'test.pkl') 25 | qtype=args.qtype 26 | max_qa_len = args.max_qa_len 27 | 28 | vocab = pkload(osp.join(text_feature_path, 'qa_vocab.pkl')) 29 | 30 | glove_embed = osp.join(text_feature_path, 'glove.840B.300d.npy') 31 | use_bert = args.use_bert 32 | checkpoint_path = args.checkpoint_path 33 | model_type = args.model_type 34 | model_prefix= args.model_prefix 35 | 36 | vis_step = args.vis_step 37 | lr_rate = args.lr_rate 38 | epoch_num = args.epoch_num 39 | 40 | if not osp.exists(osp.join(checkpoint_path, model_type, model_prefix)): 41 | os.makedirs(osp.join(checkpoint_path, model_type, model_prefix)) 42 | if not osp.exists(osp.join(checkpoint_path, model_type, model_prefix, 'model')): 43 | os.makedirs(osp.join(checkpoint_path, model_type, model_prefix, 'model')) 44 | logger = make_logger(osp.join(checkpoint_path, model_type, model_prefix, 'log')) 45 | 46 | train_set = VidQADataset(feature_path=feature_path, text_feature_path=text_feature_path, split_path=train_split_path, data_path=data_path, use_bert=use_bert, vocab=vocab, qtype=qtype, max_length=max_qa_len) 47 | valid_set = VidQADataset(feature_path=feature_path, text_feature_path=text_feature_path, split_path=valid_split_path, data_path=data_path, use_bert=use_bert, vocab=vocab, qtype=qtype, max_length=max_qa_len) 48 | test_set = VidQADataset(feature_path=feature_path, text_feature_path=text_feature_path, split_path=test_split_path, data_path=data_path, use_bert=use_bert, vocab=vocab, qtype=qtype, max_length=max_qa_len) 49 | 50 | train_loader = DataLoader( 51 | dataset=train_set, 52 | batch_size=batch_size, 53 | shuffle=True, 54 | num_workers=num_worker) 55 | 56 | valid_loader = DataLoader( 57 | dataset=valid_set, 58 | batch_size=batch_size, 59 | shuffle=False, 60 | num_workers=num_worker) 61 | 62 | test_loader = DataLoader( 63 | dataset=test_set, 64 | batch_size=batch_size, 65 | shuffle=False, 66 | num_workers=num_worker) 67 | 
68 | vqa = VideoQA(vocab, train_loader, valid_loader, test_loader, glove_embed, use_bert, checkpoint_path, model_type, model_prefix, 69 | vis_step, lr_rate, batch_size, epoch_num, logger, args) 70 | 71 | if mode != 'train': 72 | model_file = osp.join(args.checkpoint_path, model_type, model_prefix, 'model', 'best.ckpt') 73 | result_file1 = args.result_file.format(model_type, model_prefix, 'valid') 74 | result_file2 = args.result_file.format(model_type, model_prefix, 'test') 75 | vqa.predict(model_file, result_file1, vqa.val_loader) 76 | vqa.predict(model_file, result_file2, vqa.test_loader) 77 | print('Validation set') 78 | eval_mc.main(result_file1, qtype=args.qtype) 79 | print('Test set') 80 | eval_mc.main(result_file2, qtype=args.qtype) 81 | else: 82 | model_file = osp.join(model_type, model_prefix, 'model', '0-00.00.ckpt') 83 | vqa.run(model_file, pre_trained=False) 84 | model_file = osp.join(args.checkpoint_path, model_type, model_prefix, 'model', 'best.ckpt') 85 | result_file1 = args.result_file.format(model_type, model_prefix, 'valid') 86 | result_file2 = args.result_file.format(model_type, model_prefix, 'test') 87 | vqa.predict(model_file, result_file1, vqa.val_loader) 88 | vqa.predict(model_file, result_file2, vqa.test_loader) 89 | print('Validation set') 90 | eval_mc.main(result_file1, qtype=args.qtype) 91 | print('Test set') 92 | eval_mc.main(result_file2, qtype=args.qtype) 93 | 94 | if __name__ == "__main__": 95 | torch.backends.cudnn.enabled = False 96 | torch.manual_seed(666) 97 | torch.cuda.manual_seed(666) 98 | torch.backends.cudnn.benchmark = True 99 | 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('--gpu', type=int, default=0, 102 | help='gpu device id') 103 | parser.add_argument('--mode', type=str, default='train', 104 | help='train or val') 105 | parser.add_argument('--feature_path', type=str, default='', 106 | help='path to load visual feature') 107 | parser.add_argument('--text_feature_path', type=str, default='', 108 | help='path to load text feature') 109 | parser.add_argument('--data_path', type=str, default='', 110 | help='path to load original data') 111 | parser.add_argument('--split_path', type=str, default='', 112 | help='path for train/valid/test split') 113 | parser.add_argument('--use_bert', action='store_true', 114 | help='whether use bert embedding') 115 | parser.add_argument('--checkpoint_path', type=str, default='', 116 | help='path to save training model and log') 117 | parser.add_argument('--model_type', type=str, default='HGA', 118 | help='(B2A, EVQA, CoMem, HME, HGA, HCRN)') 119 | parser.add_argument('--model_prefix', type=str, default='debug', 120 | help='detail model info') 121 | parser.add_argument('--result_file', type=str, default='', 122 | help='where to save processed results') 123 | 124 | parser.add_argument('--vid_dim', type=int, default=4096, 125 | help='number of dim for video features') 126 | parser.add_argument('--hidden_dim', type=int, default=256, 127 | help='number of dim for hidden feature') 128 | parser.add_argument('--word_dim', type=int, default=300, 129 | help='number of dim for word feature') 130 | parser.add_argument('--max_vid_len', type=int, default=8, 131 | help='number of max length for video clips') 132 | parser.add_argument('--max_vid_frame_len', type=int, default=16, 133 | help='number of max length for frames in each video clip') 134 | parser.add_argument('--max_qa_len', type=int, default=40, 135 | help='number of max length for question and answer') 136 | parser.add_argument('--vis_step', type=int, 
default=100, 137 | help='number of step to print the training info') 138 | parser.add_argument('--epoch_num', type=int, default=30, 139 | help='number of epoch to train model') 140 | parser.add_argument('--lr_rate', type=float, default=1e-4, 141 | help='learning rate') 142 | parser.add_argument('--qtype', type=int, default=-1, 143 | help='question type in VVCR dataset') 144 | parser.add_argument('--batch_size', type=int, default=128, 145 | help='batch size') 146 | parser.add_argument('--gcn_layer', type=int, default=1, 147 | help='gcn layer') 148 | parser.add_argument('--spl_resolution', type=int, default=16, 149 | help='spl_resolution') 150 | args = parser.parse_args() 151 | 152 | main(args) -------------------------------------------------------------------------------- /networks/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TempAttention(nn.Module): 7 | """ 8 | Applies an attention mechanism on the output features from the decoder. 9 | """ 10 | 11 | def __init__(self, text_dim, visual_dim, hidden_dim): 12 | super(TempAttention, self).__init__() 13 | self.hidden_dim = hidden_dim 14 | self.linear_text = nn.Linear(text_dim, hidden_dim) 15 | self.linear_visual = nn.Linear(visual_dim, hidden_dim) 16 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False) 17 | self._init_weight() 18 | 19 | def _init_weight(self): 20 | nn.init.xavier_normal_(self.linear_text.weight) 21 | nn.init.xavier_normal_(self.linear_visual.weight) 22 | nn.init.xavier_normal_(self.linear_att.weight) 23 | 24 | def forward(self, qns_embed, vid_outputs): 25 | """ 26 | Arguments: 27 | qns_embed {Variable} -- batch_size x dim 28 | vid_outputs {Variable} -- batch_size x seq_len x dim 29 | 30 | Returns: 31 | context -- context vector of size batch_size x dim 32 | """ 33 | qns_embed_trans = self.linear_text(qns_embed) 34 | 35 | batch_size, seq_len, visual_dim = vid_outputs.size() 36 | vid_outputs_temp = vid_outputs.contiguous().view(batch_size*seq_len, visual_dim) 37 | vid_outputs_trans = self.linear_visual(vid_outputs_temp) 38 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.hidden_dim) 39 | 40 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1) 41 | 42 | 43 | o = self.linear_att(torch.tanh(qns_embed_trans+vid_outputs_trans)) 44 | 45 | e = o.view(batch_size, seq_len) 46 | beta = F.softmax(e, dim=1) 47 | context = torch.bmm(beta.unsqueeze(1), vid_outputs).squeeze(1) 48 | 49 | return context, beta 50 | 51 | 52 | class SpatialAttention(nn.Module): 53 | """ 54 | Apply spatial attention on vid feature before being fed into LSTM 55 | """ 56 | 57 | def __init__(self, text_dim=1024, vid_dim=3072, hidden_dim=512, input_dropout_p=0.2): 58 | super(SpatialAttention, self).__init__() 59 | 60 | self.linear_v = nn.Linear(vid_dim, hidden_dim) 61 | self.linear_q = nn.Linear(text_dim, hidden_dim) 62 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False) 63 | 64 | self.softmax = nn.Softmax(dim=1) 65 | self.dropout = nn.Dropout(input_dropout_p) 66 | self._init_weight() 67 | 68 | def _init_weight(self): 69 | nn.init.xavier_normal_(self.linear_v.weight) 70 | nn.init.xavier_normal_(self.linear_q.weight) 71 | nn.init.xavier_normal_(self.linear_att.weight) 72 | 73 | def forward(self, qns_feat, vid_feats): 74 | """ 75 | Apply question feature as semantic clue to guide feature aggregation at each frame 76 | :param vid_feats: fnum x feat_dim x 7 x 7 77 | :param 
qns_feat: dim_hidden*2 78 | :return: 79 | """ 80 | # print(qns_feat.size(), vid_feats.size()) 81 | # permute to fnum x 7 x 7 x feat_dim 82 | vid_feats = vid_feats.permute(0, 2, 3, 1) 83 | fnum, width, height, feat_dim = vid_feats.size() 84 | vid_feats = vid_feats.contiguous().view(-1, feat_dim) 85 | vid_feats_trans = self.linear_v(vid_feats) 86 | 87 | vid_feats_trans = vid_feats_trans.view(fnum, width*height, -1) 88 | region_num = vid_feats_trans.shape[1] 89 | 90 | qns_feat_trans = self.linear_q(qns_feat) 91 | 92 | qns_feat_trans = qns_feat_trans.repeat(fnum, region_num, 1) 93 | # print(vid_feats_trans.shape, qns_feat_trans.shape) 94 | 95 | vid_qns = self.linear_att(torch.tanh(vid_feats_trans + qns_feat_trans)) 96 | 97 | vid_qns_o = vid_qns.view(fnum, region_num) 98 | alpha = self.softmax(vid_qns_o) 99 | alpha = alpha.unsqueeze(1) 100 | vid_feats = vid_feats.view(fnum, region_num, -1) 101 | feature = torch.bmm(alpha, vid_feats).squeeze(1) 102 | feature = self.dropout(feature) 103 | # print(feature.size()) 104 | return feature, alpha 105 | 106 | 107 | class TempAttentionHis(nn.Module): 108 | """ 109 | Applies an attention mechanism on the output features from the decoder. 110 | """ 111 | 112 | def __init__(self, visual_dim, text_dim, his_dim, mem_dim): 113 | super(TempAttentionHis, self).__init__() 114 | # self.dim = dim 115 | self.mem_dim = mem_dim 116 | self.linear_v = nn.Linear(visual_dim, self.mem_dim, bias=False) 117 | self.linear_q = nn.Linear(text_dim, self.mem_dim, bias=False) 118 | self.linear_his1 = nn.Linear(his_dim, self.mem_dim, bias=False) 119 | self.linear_his2 = nn.Linear(his_dim, self.mem_dim, bias=False) 120 | self.linear_att = nn.Linear(self.mem_dim, 1, bias=False) 121 | self._init_weight() 122 | 123 | 124 | def _init_weight(self): 125 | nn.init.xavier_normal_(self.linear_v.weight) 126 | nn.init.xavier_normal_(self.linear_q.weight) 127 | nn.init.xavier_normal_(self.linear_his1.weight) 128 | nn.init.xavier_normal_(self.linear_his2.weight) 129 | nn.init.xavier_normal_(self.linear_att.weight) 130 | 131 | 132 | def forward(self, qns_embed, vid_outputs, his): 133 | """ 134 | :param qns_embed: batch_size x 1024 135 | :param vid_outputs: batch_size x seq_num x feat_dim 136 | :param his: batch_size x 512 137 | :return: 138 | """ 139 | 140 | batch_size, seq_len, feat_dim = vid_outputs.size() 141 | vid_outputs_trans = self.linear_v(vid_outputs.contiguous().view(batch_size * seq_len, feat_dim)) 142 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.mem_dim) 143 | 144 | qns_embed_trans = self.linear_q(qns_embed) 145 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1) 146 | 147 | 148 | his_trans = self.linear_his1(his) 149 | his_trans = his_trans.unsqueeze(1).repeat(1, seq_len, 1) 150 | 151 | o = self.linear_att(torch.tanh(qns_embed_trans + vid_outputs_trans + his_trans)) 152 | 153 | e = o.view(batch_size, seq_len) 154 | beta = F.softmax(e, dim=1) 155 | context = torch.bmm(beta.unsqueeze(1), vid_outputs_trans).squeeze(1) 156 | 157 | his_acc = torch.tanh(self.linear_his2(his)) 158 | 159 | context += his_acc 160 | 161 | return context, beta 162 | 163 | 164 | class MultiModalAttentionModule(nn.Module): 165 | 166 | def __init__(self, hidden_size=512, simple=False): 167 | """Set the hyper-parameters and build the layers.""" 168 | super(MultiModalAttentionModule, self).__init__() 169 | 170 | self.hidden_size = hidden_size 171 | self.simple = simple 172 | 173 | # alignment model 174 | # see appendices A.1.2 of neural machine translation 175 | 176 | 
self.Wav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 177 | self.Wat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 178 | self.Uav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 179 | self.Uat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 180 | self.Vav = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 181 | self.Vat = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 182 | self.bav = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True) 183 | self.bat = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True) 184 | 185 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 186 | self.Wvh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 187 | self.Wth = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 188 | self.bh = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True) 189 | 190 | self.video_sum_encoder = nn.Linear(hidden_size, hidden_size) 191 | self.question_sum_encoder = nn.Linear(hidden_size, hidden_size) 192 | 193 | self.Wb = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 194 | self.Vbv = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 195 | self.Vbt = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 196 | self.bbv = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 197 | self.bbt = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 198 | self.wb = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 199 | self.init_weights() 200 | 201 | def init_weights(self): 202 | self.Wav.data.normal_(0.0, 0.1) 203 | self.Wat.data.normal_(0.0, 0.1) 204 | self.Uav.data.normal_(0.0, 0.1) 205 | self.Uat.data.normal_(0.0, 0.1) 206 | self.Vav.data.normal_(0.0, 0.1) 207 | self.Vat.data.normal_(0.0, 0.1) 208 | self.bav.data.fill_(0) 209 | self.bat.data.fill_(0) 210 | 211 | self.Whh.data.normal_(0.0, 0.1) 212 | self.Wvh.data.normal_(0.0, 0.1) 213 | self.Wth.data.normal_(0.0, 0.1) 214 | self.bh.data.fill_(0) 215 | 216 | self.Wb.data.normal_(0.0, 0.01) 217 | self.Vbv.data.normal_(0.0, 0.01) 218 | self.Vbt.data.normal_(0.0, 0.01) 219 | self.wb.data.normal_(0.0, 0.01) 220 | 221 | self.bbv.data.fill_(0) 222 | self.bbt.data.fill_(0) 223 | 224 | def forward(self, h, hidden_frames, hidden_text, inv_attention=False): 225 | # print self.Uav 226 | # hidden_text: 1 x T1 x 1024 (looks like a two layer one-directional LSTM, combining each layer's hidden) 227 | # hidden_frame: 1 x T2 x 1024 (from video encoder output, 1024 is similar from above) 228 | 229 | # print hidden_frames.size(),hidden_text.size() 230 | Uhv = torch.matmul(h, self.Uav) # (1,512) 231 | Uhv = Uhv.view(Uhv.size(0), 1, Uhv.size(1)) # (1,1,512) 232 | 233 | Uht = torch.matmul(h, self.Uat) # (1,512) 234 | Uht = Uht.view(Uht.size(0), 1, Uht.size(1)) # (1,1,512) 235 | 236 | # print Uhv.size(),Uht.size() 237 | 238 | Wsv = torch.matmul(hidden_frames, self.Wav) # (1,T,512) 239 | # print Wsv.size() 240 | att_vec_v = torch.matmul(torch.tanh(Wsv + Uhv + self.bav), self.Vav) 241 | 242 | Wst = torch.matmul(hidden_text, self.Wat) # (1,T,512) 243 | att_vec_t = torch.matmul(torch.tanh(Wst + Uht + self.bat), self.Vat) 244 | 245 | if inv_attention == True: 246 | att_vec_v = -att_vec_v 247 | att_vec_t = -att_vec_t 248 | 249 | att_vec_v 
= torch.softmax(att_vec_v, dim=1) 250 | att_vec_t = torch.softmax(att_vec_t, dim=1) 251 | 252 | att_vec_v = att_vec_v.view(att_vec_v.size(0), att_vec_v.size(1), 1) # expand att_vec from 1xT to 1xTx1 253 | att_vec_t = att_vec_t.view(att_vec_t.size(0), att_vec_t.size(1), 1) # expand att_vec from 1xT to 1xTx1 254 | 255 | hv_weighted = att_vec_v * hidden_frames 256 | hv_sum = torch.sum(hv_weighted, dim=1) 257 | hv_sum2 = self.video_sum_encoder(hv_sum) 258 | 259 | ht_weighted = att_vec_t * hidden_text 260 | ht_sum = torch.sum(ht_weighted, dim=1) 261 | ht_sum2 = self.question_sum_encoder(ht_sum) 262 | 263 | Wbs = torch.matmul(h, self.Wb) 264 | mt1 = torch.matmul(ht_sum, self.Vbt) + self.bbt + Wbs 265 | mv1 = torch.matmul(hv_sum, self.Vbv) + self.bbv + Wbs 266 | mtv = torch.tanh(torch.cat([mv1, mt1], dim=0)) 267 | mtv2 = torch.matmul(mtv, self.wb) 268 | beta = torch.softmax(mtv2, dim=0) 269 | 270 | output = torch.tanh(torch.matmul(h, self.Whh) + beta[0] * hv_sum2 + 271 | beta[1] * ht_sum2 + self.bh) 272 | output = output.view(output.size(1), output.size(2)) 273 | 274 | return output -------------------------------------------------------------------------------- /networks/CRN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.modules.module import Module 7 | 8 | class EncoderVidCRN(nn.Module): 9 | def __init__(self, k_max_frame_level, k_max_clip_level, spl_resolution, vision_dim, dim_hidden=512): 10 | super(EncoderVidCRN, self).__init__() 11 | 12 | self.clip_level_motion_cond = CRN(dim_hidden, k_max_frame_level, k_max_frame_level, gating=False, spl_resolution=spl_resolution) 13 | self.clip_level_question_cond = CRN(dim_hidden, k_max_frame_level-2, k_max_frame_level-2, gating=True, spl_resolution=spl_resolution) 14 | self.video_level_motion_cond = CRN(dim_hidden, k_max_clip_level, k_max_clip_level, gating=False, spl_resolution=spl_resolution) 15 | self.video_level_question_cond = CRN(dim_hidden, k_max_clip_level-2, k_max_clip_level-2, gating=True, spl_resolution=spl_resolution) 16 | 17 | self.sequence_encoder = nn.LSTM(vision_dim, dim_hidden, batch_first=True, bidirectional=False) 18 | self.clip_level_motion_proj = nn.Linear(vision_dim, dim_hidden) 19 | self.video_level_motion_proj = nn.Linear(dim_hidden, dim_hidden) 20 | self.appearance_feat_proj = nn.Linear(vision_dim, dim_hidden) 21 | 22 | self.question_embedding_proj = nn.Linear(dim_hidden, dim_hidden) 23 | 24 | self.dim_hidden = dim_hidden 25 | self.activation = nn.ELU() 26 | 27 | def forward(self, appearance_video_feat, motion_video_feat, question_embedding): 28 | """ 29 | Args: 30 | appearance_video_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 31 | motion_video_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 32 | question_embedding: [Tensor] (batch_size, dim_hidden) 33 | return: 34 | encoded video feature: [Tensor] (batch_size, N, dim_hidden) 35 | """ 36 | batch_size = appearance_video_feat.size(0) 37 | clip_level_crn_outputs = [] 38 | question_embedding_proj = self.question_embedding_proj(question_embedding) 39 | for i in range(appearance_video_feat.size(1)): 40 | clip_level_motion = motion_video_feat[:, i, :] # (bz, 2048) 41 | clip_level_motion_proj = self.clip_level_motion_proj(clip_level_motion) 42 | 43 | clip_level_appearance = appearance_video_feat[:, i, :, :] # (bz, 16, 2048) 44 | clip_level_appearance_proj = self.appearance_feat_proj(clip_level_appearance) # (bz, 
16, 512) 45 | # clip level CRNs 46 | clip_level_crn_motion = self.clip_level_motion_cond(torch.unbind(clip_level_appearance_proj, dim=1), 47 | clip_level_motion_proj) 48 | clip_level_crn_question = self.clip_level_question_cond(clip_level_crn_motion, question_embedding_proj) 49 | 50 | clip_level_crn_output = torch.cat( 51 | [frame_relation.unsqueeze(1) for frame_relation in clip_level_crn_question], 52 | dim=1) 53 | clip_level_crn_output = clip_level_crn_output.view(batch_size, -1, self.dim_hidden) 54 | clip_level_crn_outputs.append(clip_level_crn_output) 55 | 56 | # Encode video level motion 57 | _, (video_level_motion, _) = self.sequence_encoder(motion_video_feat) 58 | video_level_motion = video_level_motion.transpose(0, 1) 59 | video_level_motion_feat_proj = self.video_level_motion_proj(video_level_motion) 60 | # video level CRNs 61 | video_level_crn_motion = self.video_level_motion_cond(clip_level_crn_outputs, video_level_motion_feat_proj) 62 | video_level_crn_question = self.video_level_question_cond(video_level_crn_motion, 63 | question_embedding_proj.unsqueeze(1)) 64 | 65 | video_level_crn_output = torch.cat([clip_relation.unsqueeze(1) for clip_relation in video_level_crn_question], 66 | dim=1) 67 | video_level_crn_output = video_level_crn_output.view(batch_size, -1, self.dim_hidden) 68 | 69 | return video_level_crn_output 70 | 71 | class CRN(Module): 72 | def __init__(self, dim_hidden, num_objects, max_subset_size, gating=False, spl_resolution=1): 73 | super(CRN, self).__init__() 74 | self.dim_hidden = dim_hidden 75 | self.gating = gating 76 | 77 | self.k_objects_fusion = nn.ModuleList() 78 | if self.gating: 79 | self.gate_k_objects_fusion = nn.ModuleList() 80 | for i in range(min(num_objects, max_subset_size + 1), 1, -1): 81 | self.k_objects_fusion.append(nn.Linear(2 * dim_hidden, dim_hidden)) 82 | if self.gating: 83 | self.gate_k_objects_fusion.append(nn.Linear(2 * dim_hidden, dim_hidden)) 84 | self.spl_resolution = spl_resolution 85 | self.activation = nn.ELU() 86 | self.max_subset_size = max_subset_size 87 | 88 | def forward(self, object_list, cond_feat): 89 | """ 90 | :param object_list: list of tensors or vectors 91 | :param cond_feat: conditioning feature 92 | :return: list of output objects 93 | """ 94 | scales = [i for i in range(len(object_list), 1, -1)] 95 | 96 | relations_scales = [] 97 | subsample_scales = [] 98 | for scale in scales: 99 | relations_scale = self.relationset(len(object_list), scale) 100 | relations_scales.append(relations_scale) 101 | subsample_scales.append(min(self.spl_resolution, len(relations_scale))) 102 | 103 | crn_feats = [] 104 | if len(scales) > 1 and self.max_subset_size == len(object_list): 105 | start_scale = 1 106 | else: 107 | start_scale = 0 108 | for scaleID in range(start_scale, min(len(scales), self.max_subset_size)): 109 | idx_relations_randomsample = np.random.choice(len(relations_scales[scaleID]), 110 | subsample_scales[scaleID], replace=False) 111 | mono_scale_features = 0 112 | for id_choice, idx in enumerate(idx_relations_randomsample): 113 | clipFeatList = [object_list[obj].unsqueeze(1) for obj in relations_scales[scaleID][idx]] 114 | # g_theta 115 | g_feat = torch.cat(clipFeatList, dim=1) 116 | g_feat = g_feat.mean(1) 117 | if len(g_feat.size()) == 2: 118 | h_feat = torch.cat((g_feat, cond_feat), dim=-1) 119 | elif len(g_feat.size()) == 3: 120 | cond_feat_repeat = cond_feat.repeat(1, g_feat.size(1), 1) 121 | h_feat = torch.cat((g_feat, cond_feat_repeat), dim=-1) 122 | if self.gating: 123 | h_feat = 
self.activation(self.k_objects_fusion[scaleID](h_feat)) * torch.sigmoid( 124 | self.gate_k_objects_fusion[scaleID](h_feat)) 125 | else: 126 | h_feat = self.activation(self.k_objects_fusion[scaleID](h_feat)) 127 | mono_scale_features += h_feat 128 | crn_feats.append(mono_scale_features / len(idx_relations_randomsample)) 129 | return crn_feats 130 | 131 | def relationset(self, num_objects, num_object_relation): 132 | return list(itertools.combinations([i for i in range(num_objects)], num_object_relation)) 133 | -------------------------------------------------------------------------------- /networks/Embed_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | # __all__ = ['MultipleChoiceLoss', 'CountLoss'] 6 | 7 | 8 | class MultipleChoiceLoss(nn.Module): 9 | 10 | def __init__(self, num_option=5, margin=1, size_average=True): 11 | super(MultipleChoiceLoss, self).__init__() 12 | self.margin = margin 13 | self.num_option = num_option 14 | self.size_average = size_average 15 | 16 | # score is N x C 17 | 18 | def forward(self, score, target): 19 | N = score.size(0) 20 | C = score.size(1) 21 | assert self.num_option == C 22 | 23 | loss = torch.tensor(0.0).cuda() 24 | zero = torch.tensor(0.0).cuda() 25 | 26 | cnt = 0 27 | #print(N,C) 28 | for b in range(N): 29 | # loop over incorrect answer, check if correct answer's score larger than a margin 30 | c0 = target[b] 31 | for c in range(C): 32 | if c == c0: 33 | continue 34 | 35 | # right class and wrong class should have score difference larger than a margin 36 | # see formula under paper Eq(4) 37 | loss += torch.max(zero, 1.0 + score[b, c] - score[b, c0]) 38 | cnt += 1 39 | 40 | if cnt == 0: 41 | return loss 42 | 43 | return loss / cnt if self.size_average else loss 44 | -------------------------------------------------------------------------------- /networks/EncoderRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | from torch.nn import init 5 | import numpy as np 6 | import os 7 | 8 | def init_modules(modules, w_init='kaiming_uniform'): 9 | if w_init == "normal": 10 | _init = init.normal_ 11 | elif w_init == "xavier_normal": 12 | _init = init.xavier_normal_ 13 | elif w_init == "xavier_uniform": 14 | _init = init.xavier_uniform_ 15 | elif w_init == "kaiming_normal": 16 | _init = init.kaiming_normal_ 17 | elif w_init == "kaiming_uniform": 18 | _init = init.kaiming_uniform_ 19 | elif w_init == "orthogonal": 20 | _init = init.orthogonal_ 21 | else: 22 | raise NotImplementedError 23 | for m in modules: 24 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 25 | _init(m.weight) 26 | if m.bias is not None: 27 | torch.nn.init.zeros_(m.bias) 28 | if isinstance(m, (nn.LSTM, nn.GRU)): 29 | for name, param in m.named_parameters(): 30 | if 'bias' in name: 31 | nn.init.zeros_(param) 32 | elif 'weight' in name: 33 | _init(param) 34 | 35 | class EncoderQns(nn.Module): 36 | def __init__(self, dim_embed, dim_hidden, vocab_size, glove_embed, use_bert=True, input_dropout_p=0.2, rnn_dropout_p=0.1, n_layers=1, bidirectional=False, rnn_cell='gru'): 37 | """ 38 | """ 39 | super(EncoderQns, self).__init__() 40 | self.dim_hidden = dim_hidden 41 | self.vocab_size = vocab_size 42 | self.glove_embed = glove_embed 43 | self.input_dropout_p = input_dropout_p 44 | self.rnn_dropout_p = 
rnn_dropout_p 45 | self.n_layers = n_layers 46 | self.bidirectional = bidirectional 47 | self.rnn_cell = rnn_cell 48 | 49 | self.input_dropout = nn.Dropout(input_dropout_p) 50 | self.rnn_dropout = nn.Dropout(rnn_dropout_p) 51 | 52 | if rnn_cell.lower() == 'lstm': 53 | self.rnn_cell = nn.LSTM 54 | elif rnn_cell.lower() == 'gru': 55 | self.rnn_cell = nn.GRU 56 | 57 | input_dim = dim_embed 58 | self.use_bert = use_bert 59 | if self.use_bert: 60 | self.embedding = nn.Linear(input_dim, dim_embed, bias=False) 61 | else: 62 | self.embedding = nn.Embedding(vocab_size, dim_embed) 63 | 64 | self.obj_embedding = nn.Linear(2048, dim_embed, bias=False) 65 | 66 | self.rnn = self.rnn_cell(dim_embed, dim_hidden, n_layers, batch_first=True, 67 | bidirectional=bidirectional) 68 | 69 | # init_modules(self.modules(), w_init="xavier_uniform") 70 | # nn.init.uniform_(self.embedding.weight, -1.0, 1.0) 71 | 72 | if not self.use_bert and os.path.exists(self.glove_embed): 73 | word_mat = torch.FloatTensor(np.load(self.glove_embed)) 74 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False) 75 | 76 | def forward(self, qns, qns_lengths, hidden=None, obj=None): 77 | """ 78 | encode question 79 | :param qns: 80 | :param qns_lengths: 81 | :return: 82 | """ 83 | qns_embed = self.embedding(qns) 84 | assert obj is not None 85 | obj_embed = self.obj_embedding(obj) 86 | qns_embed = qns_embed + obj_embed 87 | qns_embed = self.input_dropout(qns_embed) 88 | packed = pack_padded_sequence(qns_embed, qns_lengths, batch_first=True, enforce_sorted=False) 89 | packed_output, hidden = self.rnn(packed, hidden) 90 | output, _ = pad_packed_sequence(packed_output, batch_first=True) 91 | output = self.rnn_dropout(output) 92 | hidden = self.rnn_dropout(hidden).squeeze() 93 | return output, hidden 94 | 95 | 96 | class EncoderVid(nn.Module): 97 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 98 | n_layers=1, bidirectional=False, rnn_cell='gru'): 99 | """ 100 | """ 101 | super(EncoderVid, self).__init__() 102 | self.dim_vid = dim_vid 103 | self.dim_app = 2048 104 | self.dim_motion = 4096 105 | self.dim_hidden = dim_hidden 106 | self.input_dropout_p = input_dropout_p 107 | self.rnn_dropout_p = rnn_dropout_p 108 | self.n_layers = n_layers 109 | self.bidirectional = bidirectional 110 | self.rnn_cell = rnn_cell 111 | 112 | if rnn_cell.lower() == 'lstm': 113 | self.rnn_cell = nn.LSTM 114 | elif rnn_cell.lower() == 'gru': 115 | self.rnn_cell = nn.GRU 116 | 117 | self.rnn = self.rnn_cell(dim_vid, dim_hidden, n_layers, batch_first=True, 118 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 119 | 120 | 121 | def forward(self, vid_feats): 122 | 123 | self.rnn.flatten_parameters() 124 | foutput, fhidden = self.rnn(vid_feats) 125 | 126 | return foutput, fhidden 127 | 128 | 129 | class EncoderVidSTVQA(nn.Module): 130 | def __init__(self, input_dim, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 131 | n_layers=1, bidirectional=False, rnn_cell='gru'): 132 | """ 133 | """ 134 | super(EncoderVidSTVQA, self).__init__() 135 | self.input_dim = input_dim 136 | self.dim_hidden = dim_hidden 137 | self.input_dropout_p = input_dropout_p 138 | self.rnn_dropout_p = rnn_dropout_p 139 | self.n_layers = n_layers 140 | self.bidirectional = bidirectional 141 | self.rnn_cell = rnn_cell 142 | 143 | 144 | if rnn_cell.lower() == 'lstm': 145 | self.rnn_cell = nn.LSTM 146 | elif rnn_cell.lower() == 'gru': 147 | self.rnn_cell = nn.GRU 148 | 149 | self.rnn1 = self.rnn_cell(input_dim, dim_hidden, n_layers, batch_first=True, 
150 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 151 | 152 | self.rnn2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 153 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 154 | 155 | 156 | def forward(self, vid_feats): 157 | """ 158 | Dual-layer LSTM 159 | """ 160 | 161 | self.rnn1.flatten_parameters() 162 | 163 | foutput_1, fhidden_1 = self.rnn1(vid_feats) 164 | self.rnn2.flatten_parameters() 165 | foutput_2, fhidden_2 = self.rnn2(foutput_1) 166 | 167 | foutput = torch.cat((foutput_1, foutput_2), dim=2) 168 | fhidden = (torch.cat((fhidden_1[0], fhidden_2[0]), dim=0), 169 | torch.cat((fhidden_1[1], fhidden_2[1]), dim=0)) 170 | 171 | return foutput, fhidden 172 | 173 | 174 | class EncoderVidCoMem(nn.Module): 175 | def __init__(self, dim_app, dim_motion, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 176 | n_layers=1, bidirectional=False, rnn_cell='gru'): 177 | """ 178 | """ 179 | super(EncoderVidCoMem, self).__init__() 180 | self.dim_app = dim_app 181 | self.dim_motion = dim_motion 182 | self.dim_hidden = dim_hidden 183 | self.input_dropout_p = input_dropout_p 184 | self.rnn_dropout_p = rnn_dropout_p 185 | self.n_layers = n_layers 186 | self.bidirectional = bidirectional 187 | self.rnn_cell = rnn_cell 188 | 189 | if rnn_cell.lower() == 'lstm': 190 | self.rnn_cell = nn.LSTM 191 | elif rnn_cell.lower() == 'gru': 192 | self.rnn_cell = nn.GRU 193 | 194 | self.rnn_app_l1 = self.rnn_cell(self.dim_app, dim_hidden, n_layers, batch_first=True, 195 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 196 | self.rnn_app_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 197 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 198 | 199 | self.rnn_motion_l1 = self.rnn_cell(self.dim_motion, dim_hidden, n_layers, batch_first=True, 200 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 201 | self.rnn_motion_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 202 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 203 | 204 | 205 | def forward(self, vid_feats): 206 | """ 207 | two separate LSTM to encode app and motion feature 208 | :param vid_feats: 209 | :return: 210 | """ 211 | vid_app = vid_feats[:, :, 0:self.dim_app] 212 | vid_motion = vid_feats[:, :, self.dim_app:] 213 | 214 | app_output_l1, app_hidden_l1 = self.rnn_app_l1(vid_app) 215 | app_output_l2, app_hidden_l2 = self.rnn_app_l2(app_output_l1) 216 | 217 | 218 | motion_output_l1, motion_hidden_l1 = self.rnn_motion_l1(vid_motion) 219 | motion_output_l2, motion_hidden_l2 = self.rnn_motion_l2(motion_output_l1) 220 | 221 | return app_output_l1, app_output_l2, motion_output_l1, motion_output_l2 222 | 223 | 224 | class EncoderVidHGA(nn.Module): 225 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 226 | n_layers=1, bidirectional=False, rnn_cell='gru'): 227 | """ 228 | """ 229 | super(EncoderVidHGA, self).__init__() 230 | self.dim_vid = dim_vid 231 | self.dim_hidden = dim_hidden 232 | self.input_dropout_p = input_dropout_p 233 | self.rnn_dropout_p = rnn_dropout_p 234 | self.n_layers = n_layers 235 | self.bidirectional = bidirectional 236 | self.rnn_cell = rnn_cell 237 | 238 | 239 | self.vid2hid = nn.Sequential(nn.Linear(self.dim_vid, dim_hidden), 240 | nn.ReLU(), 241 | nn.Dropout(input_dropout_p)) 242 | 243 | 244 | if rnn_cell.lower() == 'lstm': 245 | self.rnn_cell = nn.LSTM 246 | elif rnn_cell.lower() == 'gru': 247 | self.rnn_cell = nn.GRU 248 | 249 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers, 
batch_first=True, 250 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 251 | 252 | self._init_weight() 253 | 254 | 255 | def _init_weight(self): 256 | nn.init.xavier_normal_(self.vid2hid[0].weight) 257 | 258 | 259 | def forward(self, vid_feats): 260 | """ 261 | """ 262 | batch_size, seq_len, dim_vid = vid_feats.size() 263 | vid_feats_trans = self.vid2hid(vid_feats.view(-1, self.dim_vid)) 264 | vid_feats = vid_feats_trans.view(batch_size, seq_len, -1) 265 | 266 | self.rnn.flatten_parameters() 267 | foutput, fhidden = self.rnn(vid_feats) 268 | 269 | return foutput, fhidden 270 | 271 | class EncoderVidB2A(nn.Module): 272 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 273 | n_layers=1, bidirectional=False, rnn_cell='gru'): 274 | """ 275 | """ 276 | super(EncoderVidB2A, self).__init__() 277 | self.dim_vid = dim_vid 278 | self.dim_hidden = dim_hidden 279 | self.input_dropout_p = input_dropout_p 280 | self.rnn_dropout_p = rnn_dropout_p 281 | self.n_layers = n_layers 282 | self.bidirectional = bidirectional 283 | self.rnn_cell = rnn_cell 284 | 285 | 286 | self.vid2hid = nn.Sequential(nn.Linear(self.dim_vid, dim_hidden), 287 | nn.ReLU(), 288 | nn.Dropout(input_dropout_p)) 289 | 290 | 291 | if rnn_cell.lower() == 'lstm': 292 | self.rnn_cell = nn.LSTM 293 | elif rnn_cell.lower() == 'gru': 294 | self.rnn_cell = nn.GRU 295 | 296 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 297 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 298 | 299 | self._init_weight() 300 | 301 | 302 | def _init_weight(self): 303 | nn.init.xavier_normal_(self.vid2hid[0].weight) 304 | 305 | 306 | def forward(self, app_feat, mot_feat): 307 | """ 308 | """ 309 | batch_size, seq_len, seq_len2, dim_vid = app_feat.size() 310 | 311 | app_feat_trans = self.vid2hid(app_feat.view(-1, self.dim_vid)) 312 | app_feat = app_feat_trans.view(batch_size, seq_len*seq_len2, -1) 313 | 314 | mot_feat_trans = self.vid2hid(mot_feat.view(-1, self.dim_vid)) 315 | mot_feat = mot_feat_trans.view(batch_size, seq_len, -1) 316 | 317 | self.rnn.flatten_parameters() 318 | app_output, _ = self.rnn(app_feat) 319 | mot_output, _ = self.rnn(mot_feat) 320 | 321 | return app_output, mot_output -------------------------------------------------------------------------------- /networks/GCN.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | 9 | 10 | def padding_mask_k(seq_q, seq_k): 11 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 12 | In batch 0: 13 | [[x x x 0] [[0 0 0 1] 14 | [x x x 0]-> [0 0 0 1] 15 | [x x x 0]] [0 0 0 1]] uint8 16 | """ 17 | fake_q = torch.ones_like(seq_q) 18 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2)) 19 | pad_mask = pad_mask.eq(0) 20 | # pad_mask = pad_mask.lt(1e-3) 21 | return pad_mask 22 | 23 | 24 | def padding_mask_q(seq_q, seq_k): 25 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 
26 | In batch 0: 27 | [[x x x x] [[0 0 0 0] 28 | [x x x x] -> [0 0 0 0] 29 | [0 0 0 0]] [1 1 1 1]] uint8 30 | """ 31 | fake_k = torch.ones_like(seq_k) 32 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2)) 33 | pad_mask = pad_mask.eq(0) 34 | # pad_mask = pad_mask.lt(1e-3) 35 | return pad_mask 36 | 37 | 38 | class GraphConvolution(Module): 39 | """ 40 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 41 | """ 42 | 43 | def __init__(self, in_features, out_features): 44 | super(GraphConvolution, self).__init__() 45 | self.weight = nn.Linear(in_features, out_features, bias=False) 46 | self.layer_norm = nn.LayerNorm(out_features, elementwise_affine=False) 47 | 48 | def forward(self, input, adj): 49 | # self.weight of shape (hidden_size, hidden_size) 50 | support = self.weight(input) 51 | output = torch.bmm(adj, support) 52 | output = self.layer_norm(output) 53 | return output 54 | 55 | 56 | class GraphAttention(nn.Module): 57 | """ 58 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 59 | """ 60 | 61 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 62 | super(GraphAttention, self).__init__() 63 | self.dropout = dropout 64 | self.in_features = in_features 65 | self.out_features = out_features 66 | self.alpha = alpha 67 | self.concat = concat 68 | 69 | self.W = nn.Parameter( 70 | nn.init.xavier_normal_( 71 | torch.Tensor(in_features, out_features).type( 72 | torch.cuda.FloatTensor if torch.cuda.is_available( 73 | ) else torch.FloatTensor), 74 | gain=np.sqrt(2.0)), 75 | requires_grad=True) 76 | self.a1 = nn.Parameter( 77 | nn.init.xavier_normal_( 78 | torch.Tensor(out_features, 1).type( 79 | torch.cuda.FloatTensor if torch.cuda.is_available( 80 | ) else torch.FloatTensor), 81 | gain=np.sqrt(2.0)), 82 | requires_grad=True) 83 | self.a2 = nn.Parameter( 84 | nn.init.xavier_normal_( 85 | torch.Tensor(out_features, 1).type( 86 | torch.cuda.FloatTensor if torch.cuda.is_available( 87 | ) else torch.FloatTensor), 88 | gain=np.sqrt(2.0)), 89 | requires_grad=True) 90 | 91 | self.leakyrelu = nn.LeakyReLU(self.alpha) 92 | 93 | def forward(self, input, adj): 94 | h = torch.mm(input, self.W) 95 | N = h.size()[0] 96 | 97 | f_1 = torch.matmul(h, self.a1) 98 | f_2 = torch.matmul(h, self.a2) 99 | e = self.leakyrelu(f_1 + f_2.transpose(0, 1)) 100 | 101 | zero_vec = -9e15 * torch.ones_like(e) 102 | attention = torch.where(adj > 0, e, zero_vec) 103 | attention = F.softmax(attention, dim=1) 104 | attention = F.dropout(attention, self.dropout, training=self.training) 105 | h_prime = torch.matmul(attention, h) 106 | 107 | if self.concat: 108 | return F.elu(h_prime) 109 | else: 110 | return h_prime 111 | 112 | 113 | class GCN(nn.Module): 114 | 115 | def __init__( 116 | self, input_size, hidden_size, num_classes, num_layers=1, 117 | dropout=0.1): 118 | super(GCN, self).__init__() 119 | self.layers = nn.ModuleList() 120 | self.layers.append(GraphConvolution(input_size, hidden_size)) 121 | for i in range(num_layers - 1): 122 | self.layers.append(GraphConvolution(hidden_size, hidden_size)) 123 | self.layers.append(GraphConvolution(hidden_size, num_classes)) 124 | self.dropout = nn.Dropout(p=dropout) 125 | 126 | def forward(self, x, adj): 127 | for i, layer in enumerate(self.layers): 128 | x = self.dropout(F.relu(layer(x, adj))) 129 | 130 | # x of shape (bs, q_v_len, num_classes) 131 | return x 132 | 133 | 134 | class AdjLearner(Module): 135 | 136 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1): 137 | super().__init__() 138 | ''' 139 | ## Variables: 140 
| - in_feature_dim: dimensionality of input features 141 | - hidden_size: dimensionality of the joint hidden embedding 142 | - K: number of graph nodes/objects on the image 143 | ''' 144 | 145 | # Embedding layers. Padded 0 => 0 146 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False) 147 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False) 148 | 149 | # Regularisation 150 | self.dropout = nn.Dropout(p=dropout) 151 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1) 152 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2) 153 | 154 | def forward(self, questions, videos): 155 | ''' 156 | ## Inputs: 157 | ## Returns: 158 | - adjacency matrix (batch_size, q_v_len, q_v_len) 159 | ''' 160 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features 161 | graph_nodes = torch.cat((questions, videos), dim=1) 162 | 163 | # layer 1 164 | h = self.edge_layer_1(graph_nodes) 165 | h = F.relu(h) 166 | 167 | # layer 2 168 | h = self.edge_layer_2(h) 169 | h = F.relu(h) 170 | # h * sigmoid(Wh) 171 | # h = F.tanh(h) 172 | 173 | # outer product 174 | adjacency_matrix = torch.bmm(h, h.transpose(1, 2)) 175 | 176 | return adjacency_matrix 177 | 178 | 179 | class EvoAdjLearner(Module): 180 | 181 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1): 182 | super().__init__() 183 | ''' 184 | ## Variables: 185 | - in_feature_dim: dimensionality of input features 186 | - hidden_size: dimensionality of the joint hidden embedding 187 | - K: number of graph nodes/objects on the image 188 | ''' 189 | 190 | # Embedding layers. Padded 0 => 0 191 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False) 192 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False) 193 | self.edge_layer_3 = nn.Linear(in_feature_dim, hidden_size, bias=False) 194 | self.edge_layer_4 = nn.Linear(hidden_size, hidden_size, bias=False) 195 | 196 | # Regularisation 197 | self.dropout = nn.Dropout(p=dropout) 198 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1) 199 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2) 200 | 201 | def forward(self, questions, videos): 202 | ''' 203 | ## Inputs: 204 | ## Returns: 205 | - adjacency matrix (batch_size, q_v_len, q_v_len) 206 | ''' 207 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features 208 | graph_nodes = torch.cat((questions, videos), dim=1) 209 | 210 | attn_mask = padding_mask_k(graph_nodes, graph_nodes) 211 | sf_mask = padding_mask_q(graph_nodes, graph_nodes) 212 | 213 | # layer 1 214 | h = self.edge_layer_1(graph_nodes) 215 | h = F.relu(h) 216 | # layer 2 217 | h = self.edge_layer_2(h) 218 | # h = F.relu(h) 219 | 220 | # layer 1 221 | h_ = self.edge_layer_3(graph_nodes) 222 | h_ = F.relu(h_) 223 | # layer 2 224 | h_ = self.edge_layer_4(h_) 225 | # h_ = F.relu(h_) 226 | 227 | # outer product 228 | adjacency_matrix = torch.bmm(h, h_.transpose(1, 2)) 229 | # adjacency_matrix = adjacency_matrix.masked_fill(attn_mask, -np.inf) 230 | 231 | # softmaxed_adj = F.softmax(adjacency_matrix, dim=-1) 232 | 233 | # softmaxed_adj = softmaxed_adj.masked_fill(sf_mask, 0.) 
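        # Unlike AdjLearner above, which forms a symmetric similarity
        # A = h h^T from a single two-layer projection of the concatenated
        # [question; video] nodes, this variant projects the nodes through two
        # independent branches (edge_layer_1/2 and edge_layer_3/4) and returns
        # the asymmetric product A = h h_^T. The padding masks and the row-wise
        # softmax sketched in the commented-out lines above are not applied, so
        # the raw, unnormalised scores of shape (batch_size, q_v_len, q_v_len)
        # are returned as-is.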
234 | 235 | return adjacency_matrix 236 | 237 | class AdjGenerator(Module): 238 | 239 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1): 240 | super().__init__() 241 | ''' 242 | ## Variables: 243 | - in_feature_dim: dimensionality of input features 244 | - hidden_size: dimensionality of the joint hidden embedding 245 | - K: number of graph nodes/objects on the image 246 | ''' 247 | 248 | # Embedding layers. Padded 0 => 0 249 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False) 250 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False) 251 | 252 | # Regularisation 253 | self.dropout = nn.Dropout(p=dropout) 254 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1) 255 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2) 256 | 257 | def forward(self, features, adjacency=None): 258 | ''' 259 | ## Inputs: 260 | ## Returns: 261 | - adjacency matrix (batch_size, q_v_len, q_v_len) 262 | ''' 263 | 264 | # layer 1 265 | h = self.edge_layer_1(features) 266 | h = F.relu(h) 267 | 268 | # layer 2 269 | h = self.edge_layer_2(h) 270 | h = F.relu(h) 271 | # h * sigmoid(Wh) 272 | # h = F.tanh(h) 273 | 274 | # outer product 275 | adjacency_matrix = torch.bmm(h, h.transpose(1, 2)) 276 | if adjacency is not None: 277 | adjacency_matrix = adjacency_matrix*adjacency[:, :adjacency_matrix.shape[1], :adjacency_matrix.shape[2]] 278 | 279 | return adjacency_matrix -------------------------------------------------------------------------------- /networks/Transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchnlp_nn as nlpnn 6 | 7 | 8 | def padding_mask(seq_q, seq_k): 9 | # seq_k of shape (batch, k_len) and seq_q (batch, q_len), not embedded. q and k are padded with 0. 10 | seq_q = torch.unsqueeze(seq_q, 2) 11 | seq_k = torch.unsqueeze(seq_k, 2) 12 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2)) 13 | pad_mask = pad_mask.eq(0) 14 | return pad_mask 15 | 16 | 17 | def padding_mask_transformer(seq_q, seq_k): 18 | # original padding_mask in transformer, for masking out the padding part of key sequence. 19 | len_q = seq_q.size(1) 20 | # `PAD` is 0 21 | pad_mask = seq_k.eq(0) 22 | pad_mask = pad_mask.unsqueeze(1).expand( 23 | -1, len_q, -1) # shape [B, L_q, L_k] 24 | return pad_mask 25 | 26 | 27 | def padding_mask_embedded(seq_q, seq_k): 28 | # seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len) 29 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2)) 30 | pad_mask = pad_mask.eq(0) 31 | return pad_mask 32 | 33 | 34 | def padding_mask_k(seq_q, seq_k): 35 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 36 | In batch 0: 37 | [[x x x 0] [[0 0 0 1] 38 | [x x x 0]-> [0 0 0 1] 39 | [x x x 0]] [0 0 0 1]] uint8 40 | """ 41 | fake_q = torch.ones_like(seq_q) 42 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2)) 43 | pad_mask = pad_mask.eq(0) 44 | # pad_mask = pad_mask.lt(1e-3) 45 | return pad_mask 46 | 47 | 48 | def padding_mask_q(seq_q, seq_k): 49 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 
50 | In batch 0: 51 | [[x x x x] [[0 0 0 0] 52 | [x x x x] -> [0 0 0 0] 53 | [0 0 0 0]] [1 1 1 1]] uint8 54 | """ 55 | fake_k = torch.ones_like(seq_k) 56 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2)) 57 | pad_mask = pad_mask.eq(0) 58 | # pad_mask = pad_mask.lt(1e-3) 59 | return pad_mask 60 | 61 | 62 | class PositionalEncoding(nn.Module): 63 | 64 | def __init__(self, d_model, max_seq_len): 65 | super(PositionalEncoding, self).__init__() 66 | self.max_seq_len = max_seq_len 67 | 68 | position_encoding = np.array( 69 | [ 70 | [ 71 | pos / np.power(10000, 2.0 * (j // 2) / d_model) 72 | for j in range(d_model) 73 | ] 74 | for pos in range(max_seq_len) 75 | ]) 76 | position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2]) 77 | position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2]) 78 | 79 | pad_row = torch.zeros([1, d_model]) 80 | position_encoding = torch.cat( 81 | (pad_row, torch.from_numpy(position_encoding).float())) 82 | 83 | self.position_encoding = nn.Embedding(max_seq_len + 1, d_model) 84 | self.position_encoding.weight = nn.Parameter( 85 | position_encoding, requires_grad=False) 86 | 87 | def forward(self, input_len): 88 | # max_len = torch.max(input_len) 89 | max_len = self.max_seq_len 90 | tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor 91 | input_pos = [ 92 | list(range(1, l + 1)) + [0] * (max_len - l.item()) 93 | for l in input_len 94 | ] 95 | input_pos = tensor(input_pos) 96 | return self.position_encoding(input_pos) 97 | 98 | 99 | class PositionalWiseFeedForward(nn.Module): 100 | 101 | def __init__(self, model_dim=512, ffn_dim=512, dropout=0.0): 102 | super(PositionalWiseFeedForward, self).__init__() 103 | self.w1 = nn.Conv1d(model_dim, ffn_dim, 1) 104 | self.w2 = nn.Conv1d(model_dim, ffn_dim, 1) 105 | self.dropout = nn.Dropout(dropout) 106 | self.layer_norm = nn.LayerNorm(model_dim) 107 | 108 | def forward(self, x): 109 | # x of shape (bs, seq_len, hs) 110 | output = x.transpose(1, 2) 111 | output = self.w2(F.relu(self.w1(output))) 112 | output = self.dropout(output.transpose(1, 2)) 113 | 114 | # add residual and norm layer 115 | output = self.layer_norm(x + output) 116 | return output 117 | 118 | 119 | class MaskedPositionalWiseFeedForward(nn.Module): 120 | 121 | def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0): 122 | super().__init__() 123 | self.w1 = nn.Linear(model_dim, ffn_dim, bias=False) 124 | self.w2 = nn.Linear(ffn_dim, model_dim, bias=False) 125 | self.dropout = nn.Dropout(dropout) 126 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False) 127 | 128 | def forward(self, x): 129 | # x of shape (bs, seq_len, hs) 130 | output = self.w2(F.relu(self.w1(x))) 131 | output = self.dropout(output) 132 | 133 | # add residual and norm layer 134 | output = self.layer_norm(x + output) 135 | return output 136 | 137 | 138 | class ScaledDotProductAttention(nn.Module): 139 | """Scaled dot-product attention mechanism.""" 140 | 141 | def __init__(self, attention_dropout=0.0): 142 | super(ScaledDotProductAttention, self).__init__() 143 | self.dropout = nn.Dropout(attention_dropout) 144 | self.softmax = nn.Softmax(dim=-1) 145 | 146 | def forward(self, q, k, v, scale=None, attn_mask=None): 147 | """ 148 | Args: 149 | q: [B, L_q, D_q] 150 | k: [B, L_k, D_k] 151 | v: [B, L_v, D_v] 152 | """ 153 | attention = torch.matmul(q, k.transpose(1, 2)) 154 | if scale is not None: 155 | attention = attention * scale 156 | if attn_mask is not None: 157 | attention = attention.masked_fill(attn_mask, -np.inf) 158 | attention = 
self.softmax(attention) 159 | attention = self.dropout(attention) 160 | output = torch.matmul(attention, v) 161 | return output, attention 162 | 163 | 164 | class MaskedScaledDotProductAttention(nn.Module): 165 | """Scaled dot-product attention mechanism.""" 166 | 167 | def __init__(self, attention_dropout=0.0): 168 | super().__init__() 169 | self.dropout = nn.Dropout(attention_dropout) 170 | self.softmax = nn.Softmax(dim=-1) 171 | 172 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None): 173 | """ 174 | Args: 175 | q: [B, L_q, D_q] 176 | k: [B, L_k, D_k] 177 | v: [B, L_v, D_v] 178 | """ 179 | attention = torch.matmul(q, k.transpose(-2, -1)) 180 | if scale is not None: 181 | attention = attention * scale 182 | if attn_mask is not None: 183 | attention = attention.masked_fill(attn_mask, -np.inf) 184 | attention = self.softmax(attention) 185 | attention = attention.masked_fill(softmax_mask, 0.) 186 | attention = self.dropout(attention) 187 | output = torch.matmul(attention, v) 188 | return output, attention 189 | 190 | 191 | class MultiHeadAttention(nn.Module): 192 | 193 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0): 194 | super(MultiHeadAttention, self).__init__() 195 | 196 | self.dim_per_head = model_dim // num_heads 197 | self.num_heads = num_heads 198 | self.linear_k = nn.Linear( 199 | model_dim, self.dim_per_head * num_heads, bias=False) 200 | self.linear_v = nn.Linear( 201 | model_dim, self.dim_per_head * num_heads, bias=False) 202 | self.linear_q = nn.Linear( 203 | model_dim, self.dim_per_head * num_heads, bias=False) 204 | 205 | self.dot_product_attention = ScaledDotProductAttention(dropout) 206 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False) 207 | self.dropout = nn.Dropout(dropout) 208 | self.layer_norm = nn.LayerNorm(model_dim) 209 | 210 | def forward(self, query, key, value, attn_mask=None): 211 | residual = query 212 | 213 | dim_per_head = self.dim_per_head 214 | num_heads = self.num_heads 215 | batch_size = key.size(0) 216 | 217 | # linear projection 218 | key = self.linear_k(key) 219 | value = self.linear_v(value) 220 | query = self.linear_q(query) 221 | 222 | # split by heads 223 | key = key.view(batch_size * num_heads, -1, dim_per_head) 224 | value = value.view(batch_size * num_heads, -1, dim_per_head) 225 | query = query.view(batch_size * num_heads, -1, dim_per_head) 226 | 227 | if attn_mask is not None: 228 | attn_mask = attn_mask.repeat(num_heads, 1, 1) 229 | 230 | # scaled dot product attention 231 | scale = (key.size(-1) // num_heads)**-0.5 232 | context, attention = self.dot_product_attention( 233 | query, key, value, scale, attn_mask) 234 | 235 | # concat heads 236 | context = context.view(batch_size, -1, dim_per_head * num_heads) 237 | 238 | # final linear projection 239 | output = self.linear_final(context) 240 | 241 | # dropout 242 | output = self.dropout(output) 243 | 244 | # add residual and norm layer 245 | output = self.layer_norm(residual + output) 246 | 247 | return output, attention 248 | 249 | 250 | class MaskedMultiHeadAttention(nn.Module): 251 | 252 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0): 253 | super().__init__() 254 | 255 | self.dim_per_head = model_dim // num_heads 256 | self.num_heads = num_heads 257 | self.linear_k = nn.Linear( 258 | model_dim, self.dim_per_head * num_heads, bias=False) 259 | self.linear_v = nn.Linear( 260 | model_dim, self.dim_per_head * num_heads, bias=False) 261 | self.linear_q = nn.Linear( 262 | model_dim, self.dim_per_head * num_heads, bias=False) 263 
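        # This masked variant differs from MultiHeadAttention above in three
        # ways: heads are kept as an explicit dimension (batch, heads, len,
        # dim_per_head) instead of being folded into the batch dimension; the
        # dot product is scaled by the conventional 1/sqrt(dim_per_head) rather
        # than (dim_per_head // num_heads)**-0.5; and rows belonging to padded
        # query positions are reset to zero after the softmax (softmax_mask),
        # which avoids the NaNs produced by softmaxing a row that is entirely
        # filled with -inf.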
| 264 | self.dot_product_attention = MaskedScaledDotProductAttention(dropout) 265 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False) 266 | self.dropout = nn.Dropout(dropout) 267 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False) 268 | 269 | def forward(self, query, key, value, attn_mask=None, softmax_mask=None): 270 | residual = query 271 | 272 | dim_per_head = self.dim_per_head 273 | num_heads = self.num_heads 274 | batch_size = key.size(0) 275 | 276 | # linear projection 277 | key = self.linear_k(key) 278 | value = self.linear_v(value) 279 | query = self.linear_q(query) 280 | 281 | # split by heads 282 | key = key.view(batch_size, -1, num_heads, dim_per_head).transpose(1, 2) 283 | value = value.view(batch_size, -1, num_heads, 284 | dim_per_head).transpose(1, 2) 285 | query = query.view(batch_size, -1, num_heads, 286 | dim_per_head).transpose(1, 2) 287 | 288 | if attn_mask is not None: 289 | attn_mask = attn_mask.unsqueeze(1).repeat(1, num_heads, 1, 1) 290 | if softmax_mask is not None: 291 | softmax_mask = softmax_mask.unsqueeze(1).repeat(1, num_heads, 1, 1) 292 | # scaled dot product attention 293 | # key.size(-1) is 64? 294 | scale = key.size(-1)**-0.5 295 | context, attention = self.dot_product_attention( 296 | query, key, value, scale, attn_mask, softmax_mask) 297 | 298 | # concat heads 299 | context = context.transpose(1, 2).contiguous().view( 300 | batch_size, -1, dim_per_head * num_heads) 301 | 302 | # final linear projection 303 | output = self.linear_final(context) 304 | 305 | # dropout 306 | output = self.dropout(output) 307 | 308 | # add residual and norm layer 309 | output = self.layer_norm(residual + output) 310 | 311 | return output, attention 312 | 313 | 314 | class SelfTransformerLayer(nn.Module): 315 | 316 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): 317 | super().__init__() 318 | 319 | self.transformer = MaskedMultiHeadAttention( 320 | model_dim, num_heads, dropout) 321 | self.feed_forward = MaskedPositionalWiseFeedForward( 322 | model_dim, ffn_dim, dropout) 323 | 324 | def forward(self, input, attn_mask=None, sf_mask=None): 325 | output, attention = self.transformer( 326 | input, input, input, attn_mask, sf_mask) 327 | # feed forward network 328 | output = self.feed_forward(output) 329 | 330 | return output, attention 331 | 332 | 333 | class SelfTransformer(nn.Module): 334 | 335 | def __init__( 336 | self, 337 | max_len=35, 338 | num_layers=2, 339 | model_dim=512, 340 | num_heads=8, 341 | ffn_dim=2048, 342 | dropout=0.0, 343 | position=False): 344 | super().__init__() 345 | 346 | self.position = position 347 | 348 | self.encoder_layers = nn.ModuleList( 349 | [ 350 | SelfTransformerLayer(model_dim, num_heads, ffn_dim, dropout) 351 | for _ in range(num_layers) 352 | ]) 353 | 354 | # max_seq_len is 35 or 80 355 | self.pos_embedding = PositionalEncoding(model_dim, max_len) 356 | 357 | def forward(self, input, input_length): 358 | # q_length of shape (batch, ), each item is the length of the seq 359 | if self.position: 360 | input += self.pos_embedding(input_length)[:, :input.size()[1], :] 361 | 362 | attention_mask = padding_mask_k(input, input) 363 | softmax_mask = padding_mask_q(input, input) 364 | 365 | attentions = [] 366 | for encoder in self.encoder_layers: 367 | input, attention = encoder(input, attention_mask, softmax_mask) 368 | attentions.append(attention) 369 | 370 | return input, attentions 371 | 372 | 373 | class SelfAttentionLayer(nn.Module): 374 | 375 | def __init__(self, hidden_size, 
dropout_p=0.0): 376 | super().__init__() 377 | self.dropout = nn.Dropout(dropout_p) 378 | self.softmax = nn.Softmax(dim=-1) 379 | 380 | self.linear_k = nlpnn.WeightDropLinear( 381 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 382 | self.linear_q = nlpnn.WeightDropLinear( 383 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 384 | self.linear_v = nlpnn.WeightDropLinear( 385 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 386 | 387 | self.linear_final = nlpnn.WeightDropLinear( 388 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 389 | 390 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False) 391 | 392 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None): 393 | """ 394 | Args: 395 | q: [B, L_q, D_q] 396 | k: [B, L_k, D_k] 397 | v: [B, L_v, D_v] 398 | """ 399 | residual = q 400 | 401 | if attn_mask is None or softmax_mask is None: 402 | attn_mask = padding_mask_k(q, k) 403 | softmax_mask = padding_mask_q(q, k) 404 | 405 | # linear projection 406 | k = self.linear_k(k) 407 | v = self.linear_v(v) 408 | q = self.linear_q(q) 409 | 410 | scale = k.size(-1)**-0.5 411 | 412 | attention = torch.bmm(q, k.transpose(1, 2)) 413 | if scale is not None: 414 | attention = attention * scale 415 | if attn_mask is not None: 416 | attention = attention.masked_fill(attn_mask, -np.inf) 417 | attention = self.softmax(attention) 418 | attention = attention.masked_fill(softmax_mask, 0.) 419 | 420 | # attention = self.dropout(attention) 421 | output = torch.bmm(attention, v) 422 | output = self.linear_final(output) 423 | output = self.layer_norm(output + residual) 424 | return output, attention 425 | 426 | 427 | class SelfAttention(nn.Module): 428 | 429 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 430 | super().__init__() 431 | 432 | self.encoder_layers = nn.ModuleList( 433 | [ 434 | SelfAttentionLayer(hidden_size, dropout_p) 435 | for _ in range(n_layers) 436 | ]) 437 | 438 | def forward(self, input): 439 | 440 | # q_attention_mask of shape (bs, q_len, v_len) 441 | attn_mask = padding_mask_k(input, input) 442 | # v_attention_mask of shape (bs, v_len, q_len) 443 | softmax_mask = padding_mask_q(input, input) 444 | 445 | attentions = [] 446 | for encoder in self.encoder_layers: 447 | input, attention = encoder( 448 | input, 449 | input, 450 | input, 451 | attn_mask=attn_mask, 452 | softmax_mask=softmax_mask) 453 | attentions.append(attention) 454 | 455 | return input, attentions 456 | 457 | 458 | class CoAttentionLayer(nn.Module): 459 | 460 | def __init__(self, hidden_size, dropout_p=0.0): 461 | super().__init__() 462 | self.dropout = nn.Dropout(dropout_p) 463 | self.softmax = nn.Softmax(dim=-1) 464 | 465 | self.linear_question = nlpnn.WeightDropLinear( 466 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 467 | self.linear_video = nlpnn.WeightDropLinear( 468 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 469 | self.linear_v_question = nlpnn.WeightDropLinear( 470 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 471 | self.linear_v_video = nlpnn.WeightDropLinear( 472 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 473 | 474 | self.linear_final_qv = nlpnn.WeightDropLinear( 475 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 476 | self.linear_final_vq = nlpnn.WeightDropLinear( 477 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 478 | 479 | self.layer_norm_qv = nn.LayerNorm(hidden_size, 
elementwise_affine=False) 480 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False) 481 | 482 | def forward( 483 | self, 484 | question, 485 | video, 486 | scale=None, 487 | attn_mask=None, 488 | softmax_mask=None, 489 | attn_mask_=None, 490 | softmax_mask_=None): 491 | """ 492 | Args: 493 | q: [B, L_q, D_q] 494 | k: [B, L_k, D_k] 495 | v: [B, L_v, D_v] 496 | """ 497 | q = question 498 | v = video 499 | 500 | if attn_mask is None or softmax_mask is None: 501 | attn_mask = padding_mask_k(question, video) 502 | softmax_mask = padding_mask_q(question, video) 503 | if attn_mask_ is None or softmax_mask_ is None: 504 | attn_mask_ = padding_mask_k(video, question) 505 | softmax_mask_ = padding_mask_q(video, question) 506 | 507 | # linear projection 508 | question_q = self.linear_question(question) 509 | video_k = self.linear_video(video) 510 | question = self.linear_v_question(question) 511 | video = self.linear_v_video(video) 512 | 513 | scale = video.size(-1)**-0.5 514 | 515 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2)) 516 | if scale is not None: 517 | attention_qv = attention_qv * scale 518 | if attn_mask is not None: 519 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf) 520 | attention_qv = self.softmax(attention_qv) 521 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.) 522 | 523 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2)) 524 | if scale is not None: 525 | attention_vq = attention_vq * scale 526 | if attn_mask_ is not None: 527 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf) 528 | attention_vq = self.softmax(attention_vq) 529 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.) 530 | 531 | # attention = self.dropout(attention) 532 | output_qv = torch.bmm(attention_qv, video) 533 | output_qv = self.linear_final_qv(output_qv) 534 | output_q = self.layer_norm_qv(output_qv + q) 535 | 536 | output_vq = torch.bmm(attention_vq, question) 537 | output_vq = self.linear_final_vq(output_vq) 538 | output_v = self.layer_norm_vq(output_vq + v) 539 | return output_q, output_v 540 | 541 | 542 | class CoAttention(nn.Module): 543 | 544 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 545 | super().__init__() 546 | 547 | self.encoder_layers = nn.ModuleList( 548 | [CoAttentionLayer(hidden_size, dropout_p) for _ in range(n_layers)]) 549 | 550 | def forward(self, question, video): 551 | attn_mask = padding_mask_k(question, video) 552 | softmax_mask = padding_mask_q(question, video) 553 | attn_mask_ = padding_mask_k(video, question) 554 | softmax_mask_ = padding_mask_q(video, question) 555 | 556 | for encoder in self.encoder_layers: 557 | question, video = encoder( 558 | question, 559 | video, 560 | attn_mask=attn_mask, 561 | softmax_mask=softmax_mask, 562 | attn_mask_=attn_mask_, 563 | softmax_mask_=softmax_mask_) 564 | 565 | return question, video 566 | 567 | 568 | class CoConcatAttentionLayer(nn.Module): 569 | 570 | def __init__(self, hidden_size, dropout_p=0.0): 571 | super().__init__() 572 | self.dropout = nn.Dropout(dropout_p) 573 | self.softmax = nn.Softmax(dim=-1) 574 | 575 | self.linear_question = nlpnn.WeightDropLinear( 576 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 577 | self.linear_video = nlpnn.WeightDropLinear( 578 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 579 | self.linear_v_question = nlpnn.WeightDropLinear( 580 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 581 | self.linear_v_video = nlpnn.WeightDropLinear( 582 | 
hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 583 | 584 | self.linear_final_qv = nn.Sequential( 585 | nlpnn.WeightDropLinear( 586 | 2 * hidden_size, 587 | hidden_size, 588 | weight_dropout=dropout_p, 589 | bias=False), nn.ReLU(), 590 | nlpnn.WeightDropLinear( 591 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)) 592 | self.linear_final_vq = nn.Sequential( 593 | nlpnn.WeightDropLinear( 594 | 2 * hidden_size, 595 | hidden_size, 596 | weight_dropout=dropout_p, 597 | bias=False), nn.ReLU(), 598 | nlpnn.WeightDropLinear( 599 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)) 600 | 601 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False) 602 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False) 603 | 604 | def forward( 605 | self, 606 | question, 607 | video, 608 | scale=None, 609 | attn_mask=None, 610 | softmax_mask=None, 611 | attn_mask_=None, 612 | softmax_mask_=None): 613 | """ 614 | Args: 615 | q: [B, L_q, D_q] 616 | k: [B, L_k, D_k] 617 | v: [B, L_v, D_v] 618 | """ 619 | q = question 620 | v = video 621 | 622 | if attn_mask is None or softmax_mask is None: 623 | attn_mask = padding_mask_k(question, video) 624 | softmax_mask = padding_mask_q(question, video) 625 | if attn_mask_ is None or softmax_mask_ is None: 626 | attn_mask_ = padding_mask_k(video, question) 627 | softmax_mask_ = padding_mask_q(video, question) 628 | 629 | # linear projection 630 | question_q = self.linear_question(question) 631 | video_k = self.linear_video(video) 632 | question = self.linear_v_question(question) 633 | video = self.linear_v_video(video) 634 | 635 | scale = video.size(-1)**-0.5 636 | 637 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2)) 638 | if scale is not None: 639 | attention_qv = attention_qv * scale 640 | if attn_mask is not None: 641 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf) 642 | attention_qv = self.softmax(attention_qv) 643 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.) 644 | 645 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2)) 646 | if scale is not None: 647 | attention_vq = attention_vq * scale 648 | if attn_mask_ is not None: 649 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf) 650 | attention_vq = self.softmax(attention_vq) 651 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.) 
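        # attention_qv lets each question position attend over video positions,
        # and attention_vq does the reverse. Below, each attended context is
        # fused with its own input by concatenation followed by the two-layer
        # linear_final_qv / linear_final_vq projections, instead of the additive
        # residual used in CoAttentionLayer (the residual form is kept only in
        # the commented-out lines), and the LayerNorm of the fused output is
        # what gets returned.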
652 | 653 | # attention = self.dropout(attention) 654 | output_qv = torch.bmm(attention_qv, video) 655 | output_qv = self.linear_final_qv(torch.cat((output_qv, q), dim=-1)) 656 | # output_q = self.layer_norm_qv(output_qv + q) 657 | output_q = self.layer_norm_qv(output_qv) 658 | 659 | output_vq = torch.bmm(attention_vq, question) 660 | output_vq = self.linear_final_vq(torch.cat((output_vq, v), dim=-1)) 661 | # output_v = self.layer_norm_vq(output_vq + v) 662 | output_v = self.layer_norm_vq(output_vq) 663 | return output_q, output_v 664 | 665 | 666 | class CoConcatAttention(nn.Module): 667 | 668 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 669 | super().__init__() 670 | 671 | self.encoder_layers = nn.ModuleList( 672 | [ 673 | CoConcatAttentionLayer(hidden_size, dropout_p) 674 | for _ in range(n_layers) 675 | ]) 676 | 677 | def forward(self, question, video): 678 | attn_mask = padding_mask_k(question, video) 679 | softmax_mask = padding_mask_q(question, video) 680 | attn_mask_ = padding_mask_k(video, question) 681 | softmax_mask_ = padding_mask_q(video, question) 682 | 683 | for encoder in self.encoder_layers: 684 | question, video = encoder( 685 | question, 686 | video, 687 | attn_mask=attn_mask, 688 | softmax_mask=softmax_mask, 689 | attn_mask_=attn_mask_, 690 | softmax_mask_=softmax_mask_) 691 | 692 | return question, video 693 | 694 | 695 | class CoSiameseAttentionLayer(nn.Module): 696 | 697 | def __init__(self, hidden_size, dropout_p=0.0): 698 | super().__init__() 699 | self.dropout = nn.Dropout(dropout_p) 700 | self.softmax = nn.Softmax(dim=-1) 701 | 702 | self.linear_question = nlpnn.WeightDropLinear( 703 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 704 | self.linear_video = nlpnn.WeightDropLinear( 705 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 706 | self.linear_v_question = nlpnn.WeightDropLinear( 707 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 708 | self.linear_v_video = nlpnn.WeightDropLinear( 709 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 710 | 711 | self.linear_final = nn.Sequential( 712 | nlpnn.WeightDropLinear( 713 | 2 * hidden_size, 714 | hidden_size, 715 | weight_dropout=dropout_p, 716 | bias=False), nn.ReLU(), 717 | nlpnn.WeightDropLinear( 718 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)) 719 | 720 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False) 721 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False) 722 | 723 | def forward( 724 | self, 725 | question, 726 | video, 727 | scale=None, 728 | attn_mask=None, 729 | softmax_mask=None, 730 | attn_mask_=None, 731 | softmax_mask_=None): 732 | """ 733 | Args: 734 | q: [B, L_q, D_q] 735 | k: [B, L_k, D_k] 736 | v: [B, L_v, D_v] 737 | """ 738 | q = question 739 | v = video 740 | 741 | if attn_mask is None or softmax_mask is None: 742 | attn_mask = padding_mask_k(question, video) 743 | softmax_mask = padding_mask_q(question, video) 744 | if attn_mask_ is None or softmax_mask_ is None: 745 | attn_mask_ = padding_mask_k(video, question) 746 | softmax_mask_ = padding_mask_q(video, question) 747 | 748 | # linear projection 749 | question_q = self.linear_question(question) 750 | video_k = self.linear_video(video) 751 | question = self.linear_v_question(question) 752 | video = self.linear_v_video(video) 753 | 754 | scale = video.size(-1)**-0.5 755 | 756 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2)) 757 | if scale is not None: 758 | attention_qv = 
attention_qv * scale 759 | if attn_mask is not None: 760 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf) 761 | attention_qv = self.softmax(attention_qv) 762 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.) 763 | 764 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2)) 765 | if scale is not None: 766 | attention_vq = attention_vq * scale 767 | if attn_mask_ is not None: 768 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf) 769 | attention_vq = self.softmax(attention_vq) 770 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.) 771 | 772 | # attention = self.dropout(attention) 773 | output_qv = torch.bmm(attention_qv, video) 774 | output_qv = self.linear_final(torch.cat((output_qv, q), dim=-1)) 775 | # output_q = self.layer_norm_qv(output_qv + q) 776 | output_q = self.layer_norm_qv(output_qv) 777 | 778 | output_vq = torch.bmm(attention_vq, question) 779 | output_vq = self.linear_final(torch.cat((output_vq, v), dim=-1)) 780 | # output_v = self.layer_norm_vq(output_vq + v) 781 | output_v = self.layer_norm_vq(output_vq) 782 | return output_q, output_v 783 | 784 | 785 | class CoSiameseAttention(nn.Module): 786 | 787 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 788 | super().__init__() 789 | 790 | self.encoder_layers = nn.ModuleList( 791 | [ 792 | CoSiameseAttentionLayer(hidden_size, dropout_p) 793 | for _ in range(n_layers) 794 | ]) 795 | 796 | def forward(self, question, video): 797 | attn_mask = padding_mask_k(question, video) 798 | softmax_mask = padding_mask_q(question, video) 799 | attn_mask_ = padding_mask_k(video, question) 800 | softmax_mask_ = padding_mask_q(video, question) 801 | 802 | for encoder in self.encoder_layers: 803 | question, video = encoder( 804 | question, 805 | video, 806 | attn_mask=attn_mask, 807 | softmax_mask=softmax_mask, 808 | attn_mask_=attn_mask_, 809 | softmax_mask_=softmax_mask_) 810 | 811 | return question, video 812 | 813 | 814 | class SingleAttentionLayer(nn.Module): 815 | 816 | def __init__(self, hidden_size, dropout_p=0.0): 817 | super().__init__() 818 | self.dropout = nn.Dropout(dropout_p) 819 | self.softmax = nn.Softmax(dim=-1) 820 | 821 | self.linear_q = nlpnn.WeightDropLinear( 822 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 823 | self.linear_v = nlpnn.WeightDropLinear( 824 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 825 | self.linear_k = nlpnn.WeightDropLinear( 826 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 827 | 828 | self.linear_final = nlpnn.WeightDropLinear( 829 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 830 | 831 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False) 832 | 833 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None): 834 | """ 835 | Args: 836 | q: [B, L_q, D_q] 837 | k: [B, L_k, D_k] 838 | v: [B, L_v, D_v] 839 | Return: Same shape to q, but in 'v' space, soft knn 840 | """ 841 | 842 | if attn_mask is None or softmax_mask is None: 843 | attn_mask = padding_mask_k(q, k) 844 | softmax_mask = padding_mask_q(q, k) 845 | 846 | # linear projection 847 | q = self.linear_q(q) 848 | k = self.linear_k(k) 849 | v = self.linear_v(v) 850 | 851 | scale = v.size(-1)**-0.5 852 | 853 | attention = torch.bmm(q, k.transpose(-2, -1)) 854 | if scale is not None: 855 | attention = attention * scale 856 | if attn_mask is not None: 857 | attention = attention.masked_fill(attn_mask, -np.inf) 858 | attention = self.softmax(attention) 859 | attention = 
attention.masked_fill(softmax_mask, 0.) 860 | 861 | # attention = self.dropout(attention) 862 | output = torch.bmm(attention, v) 863 | output = self.linear_final(output) 864 | output = self.layer_norm(output + q) 865 | 866 | return output 867 | 868 | 869 | class SingleAttention(nn.Module): 870 | 871 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 872 | super().__init__() 873 | 874 | self.encoder_layers = nn.ModuleList( 875 | [ 876 | SingleAttentionLayer(hidden_size, dropout_p) 877 | for _ in range(n_layers) 878 | ]) 879 | 880 | def forward(self, q, v): 881 | attn_mask = padding_mask_k(q, v) 882 | softmax_mask = padding_mask_q(q, v) 883 | 884 | for encoder in self.encoder_layers: 885 | q = encoder(q, v, v, attn_mask=attn_mask, softmax_mask=softmax_mask) 886 | 887 | return q 888 | 889 | class SingleAttention(nn.Module): 890 | 891 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 892 | super().__init__() 893 | 894 | self.encoder_layers = nn.ModuleList( 895 | [ 896 | SingleAttentionLayer(hidden_size, dropout_p) 897 | for _ in range(n_layers) 898 | ]) 899 | 900 | def forward(self, q, v): 901 | attn_mask = padding_mask_k(q, v) 902 | softmax_mask = padding_mask_q(q, v) 903 | 904 | for encoder in self.encoder_layers: 905 | q = encoder(q, v, v, attn_mask=attn_mask, softmax_mask=softmax_mask) 906 | 907 | return q 908 | 909 | class SingleSimpleAttentionLayer(nn.Module): 910 | 911 | def __init__(self, hidden_size, dropout_p=0.0): 912 | super().__init__() 913 | self.dropout = nn.Dropout(dropout_p) 914 | self.softmax = nn.Softmax(dim=-1) 915 | 916 | self.linear_final = nlpnn.WeightDropLinear( 917 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 918 | 919 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False) 920 | 921 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None): 922 | """ 923 | Args: 924 | q: [B, L_q, D_q] 925 | k: [B, L_k, D_k] 926 | v: [B, L_v, D_v] 927 | Return: Same shape to q, but in 'v' space, soft knn 928 | """ 929 | 930 | if attn_mask is None or softmax_mask is None: 931 | attn_mask = padding_mask_k(q, k) 932 | softmax_mask = padding_mask_q(q, k) 933 | 934 | # linear projection 935 | 936 | scale = v.size(-1)**-0.5 937 | 938 | attention = torch.bmm(q, k.transpose(-2, -1)) 939 | if scale is not None: 940 | attention = attention * scale 941 | if attn_mask is not None: 942 | attention = attention.masked_fill(attn_mask, -np.inf) 943 | attention = self.softmax(attention) 944 | attention = attention.masked_fill(softmax_mask, 0.) 
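        # Compared with SingleAttentionLayer, this "simple" variant attends with
        # the raw q/k/v (no learned q/k/v projections): the scores are scaled by
        # 1/sqrt(d), padded-query rows are zeroed after the softmax, and the
        # attended values only pass through the final weight-dropped linear
        # layer before the residual LayerNorm with q. (The SingleAttention
        # wrapper above is defined twice with identical bodies; the second
        # definition simply shadows the first.)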
945 | 946 | # attention = self.dropout(attention) 947 | output = torch.bmm(attention, v) 948 | output = self.linear_final(output) 949 | output = self.layer_norm(output + q) 950 | 951 | return output 952 | 953 | 954 | class SingleSimpleAttention(nn.Module): 955 | 956 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 957 | super().__init__() 958 | 959 | self.encoder_layers = nn.ModuleList( 960 | [ 961 | SingleSimpleAttentionLayer(hidden_size, dropout_p) 962 | for _ in range(n_layers) 963 | ]) 964 | 965 | def forward(self, q, v): 966 | attn_mask = padding_mask_k(q, v) 967 | softmax_mask = padding_mask_q(q, v) 968 | 969 | for encoder in self.encoder_layers: 970 | q = encoder(q, v, v, attn_mask=attn_mask, softmax_mask=softmax_mask) 971 | 972 | return q 973 | 974 | class SoftKNN(nn.Module): 975 | 976 | def __init__(self, model_dim=512, num_heads=1, dropout=0.0): 977 | super().__init__() 978 | 979 | self.dim_per_head = model_dim // num_heads 980 | self.num_heads = num_heads 981 | self.linear_k = nn.Linear( 982 | model_dim, self.dim_per_head * num_heads, bias=False) 983 | self.linear_v = nn.Linear( 984 | model_dim, self.dim_per_head * num_heads, bias=False) 985 | self.linear_q = nn.Linear( 986 | model_dim, self.dim_per_head * num_heads, bias=False) 987 | 988 | self.dot_product_attention = ScaledDotProductAttention(dropout) 989 | 990 | def forward(self, query, key, value, attn_mask=None): 991 | 992 | dim_per_head = self.dim_per_head 993 | num_heads = self.num_heads 994 | batch_size = key.size(0) 995 | 996 | # linear projection 997 | key = self.linear_k(key) 998 | value = self.linear_v(value) 999 | query = self.linear_q(query) 1000 | 1001 | # split by heads 1002 | key = key.view(batch_size * num_heads, -1, dim_per_head) 1003 | value = value.view(batch_size * num_heads, -1, dim_per_head) 1004 | query = query.view(batch_size * num_heads, -1, dim_per_head) 1005 | 1006 | if attn_mask is not None: 1007 | attn_mask = attn_mask.repeat(num_heads, 1, 1) 1008 | # scaled dot product attention 1009 | scale = (key.size(-1) // num_heads)**-0.5 1010 | context, attention = self.dot_product_attention( 1011 | query, key, value, scale, attn_mask) 1012 | 1013 | # concat heads 1014 | output = context.view(batch_size, -1, dim_per_head * num_heads) 1015 | 1016 | return output, attention 1017 | 1018 | 1019 | class CrossoverTransformerLayer(nn.Module): 1020 | 1021 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): 1022 | super().__init__() 1023 | 1024 | self.v_transformer = MultiHeadAttention(model_dim, num_heads, dropout) 1025 | self.q_transformer = MultiHeadAttention(model_dim, num_heads, dropout) 1026 | self.v_feed_forward = PositionalWiseFeedForward( 1027 | model_dim, ffn_dim, dropout) 1028 | self.q_feed_forward = PositionalWiseFeedForward( 1029 | model_dim, ffn_dim, dropout) 1030 | 1031 | def forward(self, question, video, q_mask=None, v_mask=None): 1032 | # self attention, v_attention of shape (bs, v_len, q_len) 1033 | video_, v_attention = self.v_transformer( 1034 | video, question, question, v_mask) 1035 | # feed forward network 1036 | video_ = self.v_feed_forward(video_) 1037 | 1038 | # self attention, q_attention of shape (bs, q_len, v_len) 1039 | question_, q_attention = self.q_transformer( 1040 | question, video, video, q_mask) 1041 | # feed forward network 1042 | question_ = self.q_feed_forward(question_) 1043 | 1044 | return video_, question_, v_attention, q_attention 1045 | 1046 | 1047 | class CrossoverTransformer(nn.Module): 1048 | 1049 | def __init__( 1050 | self, 1051 
| q_max_len=35, 1052 | v_max_len=80, 1053 | num_layers=2, 1054 | model_dim=512, 1055 | num_heads=8, 1056 | ffn_dim=2048, 1057 | dropout=0.0): 1058 | super().__init__() 1059 | 1060 | self.encoder_layers = nn.ModuleList( 1061 | [ 1062 | CrossoverTransformerLayer( 1063 | model_dim, num_heads, ffn_dim, dropout) 1064 | for _ in range(num_layers) 1065 | ]) 1066 | 1067 | # max_seq_len is 35 or 80 1068 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len) 1069 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len) 1070 | 1071 | def forward(self, question, video, q_length, v_length): 1072 | # q_length of shape (batch, ), each item is the length of the seq 1073 | question += self.q_pos_embedding(q_length)[:, :question.size()[1], :] 1074 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :] 1075 | 1076 | # q_attention_mask of shape (bs, q_len, v_len) 1077 | q_attention_mask = padding_mask_k(question, video) 1078 | # v_attention_mask of shape (bs, v_len, q_len) 1079 | v_attention_mask = padding_mask_k(video, question) 1080 | 1081 | q_attentions = [] 1082 | v_attentions = [] 1083 | for encoder in self.encoder_layers: 1084 | video, question, v_attention, q_attention = encoder( 1085 | question, video, q_attention_mask, v_attention_mask) 1086 | q_attentions.append(q_attention) 1087 | v_attentions.append(v_attention) 1088 | 1089 | return question, video, q_attentions, v_attentions 1090 | 1091 | 1092 | class MaskedCrossoverTransformerLayer(nn.Module): 1093 | 1094 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): 1095 | super().__init__() 1096 | 1097 | self.v_transformer = MaskedMultiHeadAttention( 1098 | model_dim, num_heads, dropout) 1099 | self.q_transformer = MaskedMultiHeadAttention( 1100 | model_dim, num_heads, dropout) 1101 | self.v_feed_forward = MaskedPositionalWiseFeedForward( 1102 | model_dim, ffn_dim, dropout) 1103 | self.q_feed_forward = MaskedPositionalWiseFeedForward( 1104 | model_dim, ffn_dim, dropout) 1105 | 1106 | def forward( 1107 | self, 1108 | question, 1109 | video, 1110 | q_mask=None, 1111 | v_mask=None, 1112 | q_sf_mask=None, 1113 | v_sf_mask=None): 1114 | # self attention, v_attention of shape (bs, v_len, q_len) 1115 | video_, v_attention = self.v_transformer( 1116 | video, question, question, v_mask, v_sf_mask) 1117 | # feed forward network 1118 | video_ = self.v_feed_forward(video_) 1119 | 1120 | # self attention, q_attention of shape (bs, q_len, v_len) 1121 | question_, q_attention = self.q_transformer( 1122 | question, video, video, q_mask, q_sf_mask) 1123 | # feed forward network 1124 | question_ = self.q_feed_forward(question_) 1125 | 1126 | return video_, question_, v_attention, q_attention 1127 | 1128 | 1129 | class MaskedCrossoverTransformer(nn.Module): 1130 | 1131 | def __init__( 1132 | self, 1133 | q_max_len=35, 1134 | v_max_len=80, 1135 | num_layers=2, 1136 | model_dim=512, 1137 | num_heads=8, 1138 | ffn_dim=2048, 1139 | dropout=0.0, 1140 | position=False): 1141 | super().__init__() 1142 | 1143 | self.position = position 1144 | 1145 | self.encoder_layers = nn.ModuleList( 1146 | [ 1147 | MaskedCrossoverTransformerLayer( 1148 | model_dim, num_heads, ffn_dim, dropout) 1149 | for _ in range(num_layers) 1150 | ]) 1151 | 1152 | # max_seq_len is 35 or 80 1153 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len) 1154 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len) 1155 | 1156 | def forward(self, question, video, q_length, v_length): 1157 | # q_length of shape (batch, ), each item is the 
length of the seq 1158 | if self.position: 1159 | question += self.q_pos_embedding( 1160 | q_length)[:, :question.size()[1], :] 1161 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :] 1162 | 1163 | q_attention_mask = padding_mask_k(question, video) 1164 | q_softmax_mask = padding_mask_q(question, video) 1165 | v_attention_mask = padding_mask_k(video, question) 1166 | v_softmax_mask = padding_mask_q(video, question) 1167 | 1168 | q_attentions = [] 1169 | v_attentions = [] 1170 | for encoder in self.encoder_layers: 1171 | video, question, v_attention, q_attention = encoder( 1172 | question, video, q_attention_mask, v_attention_mask, 1173 | q_softmax_mask, v_softmax_mask) 1174 | q_attentions.append(q_attention) 1175 | v_attentions.append(v_attention) 1176 | 1177 | return question, video, q_attentions, v_attentions 1178 | 1179 | 1180 | class SelfTransformerEncoder(nn.Module): 1181 | 1182 | def __init__( 1183 | self, 1184 | hidden_size, 1185 | n_layers, 1186 | dropout_p, 1187 | vocab_size, 1188 | q_max_len, 1189 | v_max_len, 1190 | embedding=None, 1191 | update_embedding=True, 1192 | position=True): 1193 | super().__init__() 1194 | self.dropout = nn.Dropout(p=dropout_p) 1195 | self.ln_q = nn.LayerNorm(hidden_size, elementwise_affine=False) 1196 | self.ln_v = nn.LayerNorm(hidden_size, elementwise_affine=False) 1197 | self.n_layers = n_layers 1198 | self.position = position 1199 | 1200 | embedding_dim = embedding.shape[ 1201 | 1] if embedding is not None else hidden_size 1202 | self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0) 1203 | 1204 | # ! no embedding init 1205 | # if embedding is not None: 1206 | # # self.embedding.weight.data.copy_(torch.from_numpy(embedding)) 1207 | # self.embedding.weight = nn.Parameter( 1208 | # torch.from_numpy(embedding).float()) 1209 | self.upcompress_embedding = nlpnn.WeightDropLinear( 1210 | embedding_dim, hidden_size, weight_dropout=dropout_p, bias=False) 1211 | self.embedding.weight.requires_grad = update_embedding 1212 | 1213 | self.project_c3d = nlpnn.WeightDropLinear(4096, 2048, bias=False) 1214 | 1215 | self.project_resnet_and_c3d = nlpnn.WeightDropLinear( 1216 | 4096, hidden_size, weight_dropout=dropout_p, bias=False) 1217 | 1218 | # max_seq_len is 35 or 80 1219 | self.q_pos_embedding = PositionalEncoding(hidden_size, q_max_len) 1220 | self.v_pos_embedding = PositionalEncoding(hidden_size, v_max_len) 1221 | 1222 | def forward(self, question, resnet, c3d, q_length, v_length): 1223 | ### question 1224 | embedded = self.embedding(question) 1225 | embedded = self.dropout(embedded) 1226 | question = F.relu(self.upcompress_embedding(embedded)) 1227 | 1228 | ### video 1229 | # ! 
no relu 1230 | c3d = self.project_c3d(c3d) 1231 | video = F.relu( 1232 | self.project_resnet_and_c3d(torch.cat((resnet, c3d), dim=2))) 1233 | 1234 | ### position encoding 1235 | if self.position: 1236 | question += self.q_pos_embedding( 1237 | q_length)[:, :question.size()[1], :] 1238 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :] 1239 | 1240 | # question = self.ln_q(question) 1241 | # video = self.ln_v(video) 1242 | return question, video 1243 | -------------------------------------------------------------------------------- /networks/VQAModel/B2A.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.insert(0, 'networks') 5 | from networks.Transformer import SingleSimpleAttention 6 | from networks.GCN import AdjGenerator, GCN 7 | import EncoderRNN 8 | from block import fusions #pytorch >= 1.1.0 9 | 10 | 11 | class B2A(nn.Module): 12 | def __init__(self, vid_encoder, qns_encoder, device, gcn_layer=1): 13 | """ 14 | Bridge to Answer: Structure-aware Graph Interaction Network for Video Question Answering (CVPR 2021) 15 | """ 16 | super(B2A, self).__init__() 17 | self.vid_encoder = vid_encoder 18 | self.qns_encoder = qns_encoder 19 | self.device = device 20 | hidden_size = qns_encoder.dim_hidden 21 | input_dropout_p = vid_encoder.input_dropout_p 22 | 23 | self.q_input_ln = nn.LayerNorm(hidden_size*2, elementwise_affine=False) 24 | self.v_input_ln = nn.LayerNorm(hidden_size*2, elementwise_affine=False) 25 | 26 | self.co_attn_t2m_qv = SingleSimpleAttention( 27 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 28 | self.co_attn_t2a_qv = SingleSimpleAttention( 29 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 30 | self.co_attn_a2t_vv = SingleSimpleAttention( 31 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 32 | self.co_attn_t2m_vv = SingleSimpleAttention( 33 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 34 | self.co_attn_m2t_vv = SingleSimpleAttention( 35 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 36 | self.co_attn_t2a_vv = SingleSimpleAttention( 37 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 38 | 39 | self.adj_generator = AdjGenerator(hidden_size*2, hidden_size*2) 40 | 41 | self.gcn_v = GCN( 42 | hidden_size*2, 43 | hidden_size*2, 44 | hidden_size*2, 45 | num_layers=gcn_layer, 46 | dropout=input_dropout_p) 47 | 48 | self.gcn_t = GCN( 49 | hidden_size*2, 50 | hidden_size*2, 51 | hidden_size*2, 52 | num_layers=gcn_layer, 53 | dropout=input_dropout_p) 54 | 55 | self.output_layer = OutputUnitMultiChoices(hidden_size*2) 56 | 57 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q): 58 | """ 59 | Args: 60 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 61 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 62 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 63 | candidates_len: [Tensor] (batch_size, 5) 64 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim) 65 | dep_adj: [Tensor] (batch_size, max_length, max_length) 66 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 67 | question_len: [Tensor] (batch_size, 5) 68 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim) 69 | dep_adj_q: [Tensor] 
(batch_size, max_length, max_length) 70 | return: 71 | logits, predict_idx 72 | """ 73 | if self.qns_encoder.use_bert: 74 | candidates = candidates.permute(1, 0, 2, 3) 75 | else: 76 | candidates = candidates.permute(1, 0, 2) 77 | 78 | obj_feature = obj_feature.permute(1, 0, 2, 3) 79 | cand_len = candidates_len.permute(1, 0) 80 | 81 | app_output, mot_output = self.vid_encoder(video_appearance_feat, video_motion_feat) 82 | app_output = self.v_input_ln(app_output) 83 | mot_output = self.v_input_ln(mot_output) 84 | 85 | ques_output, ques_hidden = self.qns_encoder(question, question_len, obj=obj_fea_q) 86 | ques_output = ques_output.reshape(ques_output.shape[0], ques_output.shape[1], -1) 87 | ques_output = self.q_input_ln(ques_output) 88 | ques_hidden = ques_hidden.permute(1, 0, 2).reshape(ques_output.shape[0], -1) 89 | 90 | q_v_emb = self.q2v_v2v(app_output, mot_output, ques_output, dep_adj_q) 91 | 92 | 93 | out = [] 94 | for idx, qas in enumerate(candidates): 95 | qas_output, qas_hidden = self.qns_encoder(qas, cand_len[idx], obj=obj_feature[idx]) 96 | qas_output = qas_output.reshape(qas_output.shape[0], qas_output.shape[1], -1) 97 | qas_output = self.q_input_ln(qas_output) 98 | qas_hidden = qas_hidden.permute(1, 0, 2).reshape(qas_output.shape[0], -1) 99 | qa_v_emb = self.q2v_v2v(app_output, mot_output, qas_output, dep_adj[:, idx]) 100 | 101 | final_output = self.output_layer(q_v_emb, qa_v_emb, ques_hidden, qas_hidden) 102 | out.append(final_output) 103 | 104 | out = torch.stack(out, 0).transpose(1, 0).squeeze() 105 | _, predict_idx = torch.max(out, 1) 106 | 107 | return out, predict_idx 108 | 109 | def q2v_v2v(self, app_feat, mot_feat, txt_feat, txt_cont=None): 110 | app_adj = self.adj_generator(app_feat) 111 | mot_adj = self.adj_generator(mot_feat) 112 | txt_adj = self.adj_generator(txt_feat, adjacency=txt_cont) 113 | 114 | # question-to-visual 115 | app_hat = self.gcn_v(self.co_attn_t2a_qv(app_feat, txt_feat), app_adj) 116 | mot_hat = self.gcn_v(self.co_attn_t2m_qv(mot_feat, txt_feat), mot_adj) 117 | 118 | # visual-to-visual 119 | txt_a2t = self.gcn_t(self.co_attn_a2t_vv(txt_feat, app_hat), txt_adj) + txt_feat 120 | txt_m2t = self.gcn_t(self.co_attn_m2t_vv(txt_feat, mot_hat), txt_adj) + txt_feat 121 | app_v2v = self.co_attn_t2a_vv(app_hat, txt_m2t) 122 | mot_v2v = self.co_attn_t2a_vv(mot_hat, txt_a2t) 123 | 124 | return torch.cat([app_v2v.mean(dim=1), mot_v2v.mean(dim=1)], dim=-1) 125 | 126 | class OutputUnitMultiChoices(nn.Module): 127 | def __init__(self, module_dim=512): 128 | super(OutputUnitMultiChoices, self).__init__() 129 | 130 | self.question_proj = nn.Linear(module_dim, module_dim) 131 | 132 | self.ans_candidates_proj = nn.Linear(module_dim, module_dim) 133 | 134 | self.v_question_proj = nn.Linear(module_dim*2, module_dim) 135 | 136 | self.v_ans_candidates_proj = nn.Linear(module_dim*2, module_dim) 137 | 138 | self.classifier = nn.Sequential(nn.Dropout(0.15), 139 | nn.Linear(module_dim * 4, module_dim), 140 | nn.ELU(), 141 | nn.BatchNorm1d(module_dim), 142 | nn.Dropout(0.15), 143 | nn.Linear(module_dim, 1)) 144 | 145 | def forward(self, q_visual_embedding, a_visual_embedding, question_embedding, ans_candidates_embedding): 146 | q_visual_embedding = self.v_question_proj(q_visual_embedding) 147 | a_visual_embedding = self.v_ans_candidates_proj(a_visual_embedding) 148 | question_embedding = self.question_proj(question_embedding) 149 | ans_candidates_embedding = self.ans_candidates_proj(ans_candidates_embedding) 150 | out = torch.cat([q_visual_embedding, question_embedding, 
a_visual_embedding, ans_candidates_embedding], 1) 151 | out = self.classifier(out) 152 | 153 | return out -------------------------------------------------------------------------------- /networks/VQAModel/CoMem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.insert(0, 'networks') 5 | from memory_module import EpisodicMemory 6 | 7 | 8 | class CoMem(nn.Module): 9 | def __init__(self, vid_encoder, qns_encoder, max_len_v, max_len_q, device): 10 | """ 11 | motion-appearance co-memory networks for video question answering (CVPR18) 12 | """ 13 | super(CoMem, self).__init__() 14 | self.vid_encoder = vid_encoder 15 | self.qns_encoder = qns_encoder 16 | 17 | dim = qns_encoder.dim_hidden 18 | 19 | self.epm_app = EpisodicMemory(dim*2) 20 | self.epm_mot = EpisodicMemory(dim*2) 21 | 22 | self.linear_ma = nn.Linear(dim*2*4, dim*2) 23 | self.linear_mb = nn.Linear(dim*2*4, dim*2) 24 | 25 | self.vq2word = nn.Linear(dim*2*2, 1) 26 | 27 | self.device = device 28 | 29 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q): 30 | """ 31 | Args: 32 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 33 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 34 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 35 | candidates_len: [Tensor] (batch_size, 5) 36 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim) 37 | dep_adj: [Tensor] (batch_size, max_length, max_length) 38 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 39 | question_len: [Tensor] (batch_size, 5) 40 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim) 41 | dep_adj_q: [Tensor] (batch_size, max_length, max_length) 42 | return: 43 | logits, predict_idx 44 | """ 45 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1) 46 | if self.qns_encoder.use_bert: 47 | candidates = candidates.permute(1, 0, 2, 3) # for BERT 48 | else: 49 | candidates = candidates.permute(1, 0, 2) 50 | 51 | obj_feature = obj_feature.permute(1, 0, 2, 3) 52 | candidates_len = candidates_len.permute(1, 0) 53 | 54 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) 55 | vid_feats = (outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2) 56 | 57 | _, qns_hidden = self.qns_encoder(question, question_len, obj=obj_fea_q) 58 | qas_hidden = list() 59 | for idx, qas in enumerate(candidates): 60 | _, ah_tmp = self.qns_encoder(qas, candidates_len[idx], obj=obj_feature[idx]) 61 | qas_hidden.append(ah_tmp) 62 | 63 | out = [] 64 | for idx, qas in enumerate(qas_hidden): 65 | encoder_out = self.vq_encoder(vid_feats, qns_hidden, qas) 66 | out.append(encoder_out) 67 | 68 | out = torch.stack(out, 0).transpose(1, 0) 69 | 70 | _, predict_idx = torch.max(out, 1) 71 | 72 | 73 | return out, predict_idx 74 | 75 | def vq_encoder(self, vid_feats, ques, qas, iter_num=3): 76 | 77 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = vid_feats 78 | 79 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1) 80 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1) 81 | 82 | batch_size = qas.shape[1] 83 | 84 | qns_embed = ques.permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 85 | qas_embed = qas.permute(1, 0, 
2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 86 | 87 | m_app = outputs_app[:, -1, :] 88 | m_mot = outputs_motion[:, -1, :] 89 | ma, mb = m_app.detach(), m_mot.detach() 90 | m_app = m_app.unsqueeze(1) 91 | m_mot = m_mot.unsqueeze(1) 92 | for _ in range(iter_num): 93 | mm = ma + mb 94 | m_app = self.epm_app(outputs_app, mm, m_app) 95 | m_mot = self.epm_mot(outputs_motion, mm, m_mot) 96 | ma_q = torch.cat((ma, m_app.squeeze(1), qns_embed, qas_embed), dim=1) 97 | mb_q = torch.cat((mb, m_mot.squeeze(1), qns_embed, qas_embed), dim=1) 98 | ma = torch.tanh(self.linear_ma(ma_q)) 99 | mb = torch.tanh(self.linear_mb(mb_q)) 100 | 101 | mem = torch.cat((ma, mb), dim=1) 102 | outputs = self.vq2word(mem).squeeze() 103 | 104 | return outputs -------------------------------------------------------------------------------- /networks/VQAModel/EVQA.py: -------------------------------------------------------------------------------- 1 | from locale import AM_STR 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class EVQA(nn.Module): 7 | def __init__(self, vid_encoder, qns_encoder, device, blind=False): 8 | super(EVQA, self).__init__() 9 | self.vid_encoder = vid_encoder 10 | self.qns_encoder = qns_encoder 11 | self.device = device 12 | self.blind = blind 13 | self.FC = nn.Linear(qns_encoder.dim_hidden, 1) 14 | 15 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q): 16 | """ 17 | Args: 18 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 19 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 20 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 21 | candidates_len: [Tensor] (batch_size, 5) 22 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim) 23 | dep_adj: [Tensor] (batch_size, max_length, max_length) 24 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 25 | question_len: [Tensor] (batch_size, 5) 26 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim) 27 | dep_adj_q: [Tensor] (batch_size, max_length, max_length) 28 | return: 29 | logits, predict_idx 30 | """ 31 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1) 32 | if self.qns_encoder.use_bert: 33 | candidates = candidates.permute(1, 0, 2, 3) # for BERT 34 | else: 35 | candidates = candidates.permute(1, 0, 2) 36 | 37 | obj_feature = obj_feature.permute(1, 0, 2, 3) 38 | if self.blind: 39 | obj_feature[:] = 0 40 | cand_len = candidates_len.permute(1, 0) 41 | out = [] 42 | for idx, qnsans in enumerate(candidates): 43 | encoder_out = self.vq_encoder(vid_feats, qnsans, cand_len[idx], question_len, obj_feature[idx]) 44 | out.append(encoder_out) 45 | 46 | out = torch.stack(out, 0).transpose(1, 0) 47 | 48 | _, predict_idx = torch.max(out, 1) 49 | 50 | return out, predict_idx 51 | 52 | def vq_encoder(self, vid_feats, qnsans, qnsans_len, qns_len, obj_feature): 53 | 54 | qmask = torch.zeros(qnsans.shape[0], qnsans.shape[1], dtype=qnsans.dtype, device=qnsans.device) # bs, maxlen 55 | amask = torch.zeros(qnsans.shape[0], qnsans.shape[1], dtype=qnsans.dtype, device=qnsans.device) # bs, maxlen 56 | 57 | for idx in range(qmask.shape[0]): 58 | qmask[idx, :qns_len[idx]] = 1 59 | amask[idx, qns_len[idx]:qnsans_len[idx]] = 1 60 | 61 | if len(qnsans.shape) == 2: 62 | qns = qnsans*qmask 63 | ans = qnsans*amask 64 | elif len(qnsans.shape) == 3: 65 | qns = qnsans*qmask.unsqueeze(-1) 66 | ans = 
qnsans*amask.unsqueeze(-1) 67 | 68 | obj_feature_q = obj_feature*qmask.unsqueeze(-1) 69 | obj_feature_a = obj_feature*amask.unsqueeze(-1) 70 | 71 | _, vid_hidden = self.vid_encoder(vid_feats) 72 | _, qs_hidden = self.qns_encoder(qns, qns_len, obj=obj_feature_q) 73 | _, as_hidden = self.qns_encoder(ans, qnsans_len, obj=obj_feature_a) 74 | 75 | vid_embed = vid_hidden.squeeze() 76 | qs_embed = qs_hidden.squeeze() 77 | as_embed = as_hidden.squeeze() 78 | 79 | if self.blind: 80 | fuse = qs_embed + as_embed 81 | else: 82 | fuse = qs_embed + as_embed + vid_embed 83 | 84 | outputs = self.FC(fuse).squeeze() 85 | 86 | return outputs -------------------------------------------------------------------------------- /networks/VQAModel/HCRN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.nn import functional as F 3 | 4 | import itertools 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import init 9 | from torch.nn.modules.module import Module 10 | 11 | class HCRN(nn.Module): 12 | def __init__(self, vid_encoder, qns_encoder, device): 13 | super(HCRN, self).__init__() 14 | """ 15 | Hierarchical Conditional Relation Networks for Video Question Answering (CVPR2020) 16 | """ 17 | self.qns_encoder = qns_encoder 18 | self.vid_encoder = vid_encoder 19 | hidden_size = vid_encoder.dim_hidden 20 | self.feature_aggregation = FeatureAggregation(hidden_size) 21 | 22 | self.output_unit = OutputUnitMultiChoices(module_dim=hidden_size) 23 | 24 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q): 25 | """ 26 | Args: 27 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 28 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 29 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 30 | candidates_len: [Tensor] (batch_size, 5) 31 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim) 32 | dep_adj: [Tensor] (batch_size, max_length, max_length) 33 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 34 | question_len: [Tensor] (batch_size, 5) 35 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim) 36 | dep_adj_q: [Tensor] (batch_size, max_length, max_length) 37 | return: 38 | logits, predict_idx 39 | """ 40 | batch_size = candidates.size(0) 41 | if self.qns_encoder.use_bert: 42 | cand = candidates.permute(1, 0, 2, 3) # for BERT 43 | else: 44 | cand = candidates.permute(1, 0, 2) 45 | cand_len = candidates_len.permute(1, 0) 46 | out = list() 47 | _, question_embedding = self.qns_encoder(question, question_len, obj=obj_fea_q) 48 | visual_embedding = self.vid_encoder(video_appearance_feat, video_motion_feat, question_embedding) 49 | q_visual_embedding = self.feature_aggregation(question_embedding, visual_embedding) 50 | for idx, qas in enumerate(cand): 51 | _, qas_embedding = self.qns_encoder(qas, cand_len[idx], obj=obj_feature[:, idx]) 52 | qa_visual_embedding = self.feature_aggregation(qas_embedding, visual_embedding) 53 | encoder_out = self.output_unit(q_visual_embedding, question_embedding, qa_visual_embedding, qas_embedding) 54 | out.append(encoder_out) 55 | out = torch.stack(out, 0).transpose(1, 0).squeeze() 56 | _, predict_idx = torch.max(out, 1) 57 | return out, predict_idx 58 | 59 | class FeatureAggregation(nn.Module): 60 | def __init__(self, module_dim=512): 61 | super(FeatureAggregation, self).__init__() 62 | self.module_dim = 
module_dim 63 | 64 | self.q_proj = nn.Linear(module_dim, module_dim, bias=False) 65 | self.v_proj = nn.Linear(module_dim, module_dim, bias=False) 66 | 67 | self.cat = nn.Linear(2 * module_dim, module_dim) 68 | self.attn = nn.Linear(module_dim, 1) 69 | 70 | self.activation = nn.ELU() 71 | self.dropout = nn.Dropout(0.15) 72 | 73 | def forward(self, question_rep, visual_feat): 74 | visual_feat = self.dropout(visual_feat) 75 | q_proj = self.q_proj(question_rep) 76 | v_proj = self.v_proj(visual_feat) 77 | 78 | v_q_cat = torch.cat((v_proj, q_proj.unsqueeze(1) * v_proj), dim=-1) 79 | v_q_cat = self.cat(v_q_cat) 80 | v_q_cat = self.activation(v_q_cat) 81 | 82 | attn = self.attn(v_q_cat) # (bz, k, 1) 83 | attn = F.softmax(attn, dim=1) # (bz, k, 1) 84 | 85 | v_distill = (attn * visual_feat).sum(1) 86 | 87 | return v_distill 88 | 89 | class OutputUnitMultiChoices(nn.Module): 90 | def __init__(self, module_dim=512): 91 | super(OutputUnitMultiChoices, self).__init__() 92 | 93 | self.question_proj = nn.Linear(module_dim, module_dim) 94 | 95 | self.ans_candidates_proj = nn.Linear(module_dim, module_dim) 96 | 97 | self.classifier = nn.Sequential(nn.Dropout(0.15), 98 | nn.Linear(module_dim * 4, module_dim), 99 | nn.ELU(), 100 | nn.BatchNorm1d(module_dim), 101 | nn.Dropout(0.15), 102 | nn.Linear(module_dim, 1)) 103 | 104 | def forward(self, question_embedding, q_visual_embedding, ans_candidates_embedding, 105 | a_visual_embedding): 106 | question_embedding = self.question_proj(question_embedding) 107 | ans_candidates_embedding = self.ans_candidates_proj(ans_candidates_embedding) 108 | out = torch.cat([q_visual_embedding, question_embedding, a_visual_embedding, 109 | ans_candidates_embedding], 1) 110 | out = self.classifier(out) 111 | 112 | return out -------------------------------------------------------------------------------- /networks/VQAModel/HGA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.insert(0, 'networks') 5 | from networks.Transformer import CoAttention 6 | from networks.GCN import AdjLearner, GCN 7 | from block import fusions #pytorch >= 1.1.0 8 | 9 | 10 | class HGA(nn.Module): 11 | def __init__(self, vid_encoder, qns_encoder, device): 12 | """ 13 | Reasoning with Heterogeneous Graph Alignment for Video Question Answering (AAAI2020) 14 | """ 15 | super(HGA, self).__init__() 16 | self.vid_encoder = vid_encoder 17 | self.qns_encoder = qns_encoder 18 | self.device = device 19 | hidden_size = vid_encoder.dim_hidden 20 | input_dropout_p = vid_encoder.input_dropout_p 21 | 22 | self.co_attn = CoAttention( 23 | hidden_size, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 24 | 25 | self.adj_learner = AdjLearner( 26 | hidden_size, hidden_size, dropout=input_dropout_p) 27 | 28 | self.gcn = GCN( 29 | hidden_size, 30 | hidden_size, 31 | hidden_size, 32 | num_layers=1, 33 | dropout=input_dropout_p) 34 | 35 | self.gcn_atten_pool = nn.Sequential( 36 | nn.Linear(hidden_size, hidden_size // 2), 37 | nn.Tanh(), 38 | nn.Linear(hidden_size // 2, 1), 39 | nn.Softmax(dim=-1)) #change to dim=-2 for attention-pooling otherwise sum-pooling 40 | 41 | self.global_fusion = fusions.Block( 42 | [hidden_size, hidden_size], hidden_size, dropout_input=input_dropout_p) 43 | 44 | self.fusion = fusions.Block([hidden_size, hidden_size], 1) 45 | 46 | 47 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q): 48 
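# Per-candidate pipeline (see vq_encoder below): co-attend the QA tokens with the
# video features, learn a joint adjacency over the concatenated nodes, run the GCN,
# attention-pool the graph, then Block-fuse the pooled (local) output with the
# global question/video hidden states to produce one logit per candidate.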
| """ 49 | Args: 50 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 51 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 52 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 53 | candidates_len: [Tensor] (batch_size, 5) 54 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim) 55 | dep_adj: [Tensor] (batch_size, max_length, max_length) 56 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 57 | question_len: [Tensor] (batch_size, 5) 58 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim) 59 | dep_adj_q: [Tensor] (batch_size, max_length, max_length) 60 | return: 61 | logits, predict_idx 62 | """ 63 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1) 64 | if self.qns_encoder.use_bert: 65 | candidates = candidates.permute(1, 0, 2, 3) # for BERT 66 | else: 67 | candidates = candidates.permute(1, 0, 2) 68 | 69 | obj_feature = obj_feature.permute(1, 0, 2, 3) 70 | cand_len = candidates_len.permute(1, 0) 71 | 72 | v_output, v_hidden = self.vid_encoder(vid_feats) 73 | v_last_hidden = torch.squeeze(v_hidden) 74 | 75 | 76 | out = [] 77 | for idx, qas in enumerate(candidates): 78 | encoder_out = self.vq_encoder(v_output, v_last_hidden, qas, cand_len[idx], obj_feature[idx]) 79 | out.append(encoder_out) 80 | 81 | out = torch.stack(out, 0).transpose(1, 0) 82 | _, predict_idx = torch.max(out, 1) 83 | 84 | return out, predict_idx 85 | 86 | 87 | def vq_encoder(self, v_output, v_last_hidden, qas, qas_len, obj_feature): 88 | q_output, s_hidden = self.qns_encoder(qas, qas_len, obj=obj_feature) 89 | qns_last_hidden = torch.squeeze(s_hidden) 90 | 91 | q_output, v_output = self.co_attn(q_output, v_output) 92 | 93 | adj = self.adj_learner(q_output, v_output) 94 | q_v_inputs = torch.cat((q_output, v_output), dim=1) 95 | q_v_output = self.gcn(q_v_inputs, adj) 96 | 97 | local_attn = self.gcn_atten_pool(q_v_output) 98 | local_out = torch.sum(q_v_output * local_attn, dim=1) 99 | 100 | global_out = self.global_fusion((qns_last_hidden, v_last_hidden)) 101 | 102 | 103 | out = self.fusion((global_out, local_out)).squeeze() 104 | 105 | return out 106 | -------------------------------------------------------------------------------- /networks/VQAModel/HME.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import sys 4 | sys.path.insert(0, 'networks') 5 | from Attention import TempAttention, SpatialAttention 6 | from memory_rand import MemoryRamTwoStreamModule2, MemoryRamModule2, MMModule2 7 | 8 | 9 | class HME(nn.Module): 10 | def __init__(self, vid_encoder, qns_encoder, max_len_v, max_len_q, device, input_drop_p=0.2): 11 | """ 12 | Heterogeneous memory enhanced multimodal attention model for video question answering (CVPR19) 13 | """ 14 | super(HME, self).__init__() 15 | self.vid_encoder = vid_encoder 16 | self.qns_encoder = qns_encoder 17 | 18 | 19 | dim = qns_encoder.dim_hidden 20 | 21 | self.temp_att_a = TempAttention(dim * 2, dim * 2, hidden_dim=256) 22 | self.temp_att_m = TempAttention(dim * 2, dim * 2, hidden_dim=256) 23 | self.mrm_vid = MemoryRamTwoStreamModule2(dim, dim, max_len_v, device) 24 | self.mrm_txt = MemoryRamModule2(dim, dim, max_len_q, device) 25 | 26 | self.mm_module_v1 = MMModule2(dim, input_drop_p, device) 27 | 28 | self.linear_vid = nn.Linear(dim*2, dim) 29 | self.linear_qns = nn.Linear(dim*2, dim) 30 | self.linear_mem = nn.Linear(dim*2, dim) 31 | self.vq2word_hme = nn.Linear(dim*3, 
1) 32 | self.device = device 33 | 34 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q): 35 | """ 36 | Args: 37 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 38 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 39 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 40 | candidates_len: [Tensor] (batch_size, 5) 41 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim) 42 | dep_adj: [Tensor] (batch_size, max_length, max_length) 43 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)]) 44 | question_len: [Tensor] (batch_size, 5) 45 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim) 46 | dep_adj_q: [Tensor] (batch_size, max_length, max_length) 47 | return: 48 | logits, predict_idx 49 | """ 50 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1) 51 | if self.qns_encoder.use_bert: 52 | candidates = candidates.permute(1, 0, 2, 3) # for BERT 53 | else: 54 | candidates = candidates.permute(1, 0, 2) 55 | 56 | obj_feature = obj_feature.permute(1, 0, 2, 3) 57 | candidates_len = candidates_len.permute(1, 0) 58 | 59 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) 60 | vid_feats = (outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2) 61 | 62 | qas_seq, qas_hidden = list(), list() 63 | for idx, qas in enumerate(candidates): 64 | q_output, s_hidden = self.qns_encoder(qas, candidates_len[idx], obj=obj_feature[idx]) 65 | qas_seq.append(q_output) 66 | qas_hidden.append(s_hidden) 67 | 68 | out = [] 69 | for idx, (qa_seq, qa_hidden) in enumerate(zip(qas_seq, qas_hidden)): 70 | encoder_out = self.vq_encoder(vid_feats, qa_seq, qa_hidden) 71 | out.append(encoder_out) 72 | 73 | out = torch.stack(out, 0).transpose(1, 0) 74 | 75 | _, predict_idx = torch.max(out, 1) 76 | 77 | return out, predict_idx 78 | 79 | def vq_encoder(self, vid_feats, qns_seq, qns_hidden, iter_num=3): 80 | 81 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = vid_feats 82 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1) 83 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1) 84 | 85 | batch_size, fnum, vid_feat_dim = outputs_app.size() 86 | 87 | batch_size, seq_len, qns_feat_dim = qns_seq.size() 88 | 89 | qns_embed = qns_hidden.permute(1, 0, 2).contiguous().view(batch_size, -1) 90 | 91 | # Apply temporal attention 92 | att_app, beta_app = self.temp_att_a(qns_embed, outputs_app) 93 | att_motion, beta_motion = self.temp_att_m(qns_embed, outputs_motion) 94 | tmp_app_motion = torch.cat((outputs_app_l2[:, -1, :], outputs_motion_l2[:, -1, :]), dim=-1) 95 | 96 | mem_output = torch.zeros(batch_size, vid_feat_dim).to(self.device) 97 | 98 | mem_ram_vid = self.mrm_vid(outputs_app_l2, outputs_motion_l2, fnum) 99 | mem_ram_txt = self.mrm_txt(qns_seq, qns_seq.shape[1]) 100 | mem_output[:] = self.mm_module_v1(tmp_app_motion, mem_ram_vid, mem_ram_txt, iter_num) 101 | 102 | app_trans = torch.tanh(self.linear_vid(att_app)) 103 | motion_trans = torch.tanh(self.linear_vid(att_motion)) 104 | mem_trans = torch.tanh(self.linear_mem(mem_output)) 105 | 106 | encoder_outputs = torch.cat((app_trans, motion_trans, mem_trans), dim=1) 107 | outputs = self.vq2word_hme(encoder_outputs).squeeze() 108 | 109 | return outputs -------------------------------------------------------------------------------- 
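All of the multiple-choice heads above (B2A, CoMem, EVQA, HCRN, HGA, HME) share the same protocol: each of the five question-answer candidates is scored independently against the video, the per-candidate logits are stacked into a (batch_size, 5) tensor, and the prediction is the arg-max over that dimension. The sketch below illustrates only this shared pattern; `score_candidate` is a hypothetical stand-in for any of the encoders above, and the cross-entropy objective is an assumption rather than this repository's actual loss (see `Embed_loss` in `videoqa.py`).

```python
import torch
import torch.nn as nn

def score_candidate(video_feat, qa_feat):
    # Hypothetical stand-in for a model-specific encoder: returns one scalar
    # logit per (video, candidate) pair, shape (batch_size,).
    return (video_feat * qa_feat).sum(dim=-1)

def multiple_choice_forward(video_feat, candidate_feats):
    # candidate_feats: (batch_size, 5, dim); loop over the five choices exactly
    # like the `for idx, qas in enumerate(candidates)` loops in the models above.
    logits = [score_candidate(video_feat, candidate_feats[:, i]) for i in range(5)]
    out = torch.stack(logits, dim=0).transpose(1, 0)  # (batch_size, 5)
    _, predict_idx = torch.max(out, dim=1)            # index of the chosen answer
    return out, predict_idx

# Toy usage: treat the stacked logits as a 5-way classification problem
# (cross-entropy here is an illustrative choice, not necessarily the repo's loss).
video_feat = torch.randn(8, 512)
candidate_feats = torch.randn(8, 5, 512)
answers = torch.randint(0, 5, (8,))
out, pred = multiple_choice_forward(video_feat, candidate_feats)
loss = nn.CrossEntropyLoss()(out, answers)
```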
/networks/memory_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | 8 | class AttentionGRUCell(nn.Module): 9 | ''' 10 | Eq (1)~(4), then modify by Eq (11) 11 | When forwarding, we feed attention gate g into GRU 12 | ''' 13 | def __init__(self, input_size, hidden_size): 14 | super(AttentionGRUCell, self).__init__() 15 | self.hidden_size = hidden_size 16 | self.Wr = nn.Linear(input_size, hidden_size) 17 | self.Ur = nn.Linear(hidden_size, hidden_size) 18 | self.W = nn.Linear(input_size, hidden_size) 19 | self.U = nn.Linear(hidden_size, hidden_size) 20 | 21 | def init_weights(self): 22 | self.Wr.weight.data.normal_(0.0, 0.02) 23 | self.Wr.bias.data.fill_(0) 24 | self.Ur.weight.data.normal_(0.0, 0.02) 25 | self.Ur.bias.data.fill_(0) 26 | self.W.weight.data.normal_(0.0, 0.02) 27 | self.W.bias.data.fill_(0) 28 | self.U.weight.data.normal_(0.0, 0.02) 29 | self.U.bias.data.fill_(0) 30 | 31 | def forward(self, fact, C, g): 32 | ''' 33 | fact.size() -> (#batch, #hidden = #embedding) 34 | c.size() -> (#hidden, ) -> (#batch, #hidden = #embedding) 35 | r.size() -> (#batch, #hidden = #embedding) 36 | h_tilda.size() -> (#batch, #hidden = #embedding) 37 | g.size() -> (#batch, ) 38 | ''' 39 | 40 | r = torch.sigmoid(self.Wr(fact) + self.Ur(C)) 41 | h_tilda = torch.tanh(self.W(fact) + r * self.U(C)) 42 | g = g.unsqueeze(1).expand_as(h_tilda) 43 | h = g * h_tilda + (1 - g) * C 44 | return h 45 | 46 | class AttentionGRU(nn.Module): 47 | ''' 48 | Section 3.3 49 | continuously run AttnGRU to get contextual vector c at each time t 50 | ''' 51 | def __init__(self, input_size, hidden_size): 52 | super(AttentionGRU, self).__init__() 53 | self.hidden_size = hidden_size 54 | self.AGRUCell = AttentionGRUCell(input_size, hidden_size) 55 | 56 | def init_weights(self): 57 | self.AGRUCell.init_weights() 58 | 59 | def forward(self, facts, G): 60 | ''' 61 | facts.size() -> (#batch, #sentence, #hidden = #embedding) 62 | fact.size() -> (#batch, #hidden = #embedding) 63 | G.size() -> (#batch, #sentence) 64 | g.size() -> (#batch, ) 65 | C.size() -> (#batch, #hidden) 66 | ''' 67 | batch_num, sen_num, embedding_size = facts.size() 68 | C = Variable(torch.zeros(self.hidden_size)).cuda() 69 | for sid in range(sen_num): 70 | fact = facts[:, sid, :] 71 | g = G[:, sid] 72 | if sid == 0: 73 | C = C.unsqueeze(0).expand_as(fact) 74 | C = self.AGRUCell(fact, C, g) 75 | return C 76 | 77 | class EpisodicMemory(nn.Module): 78 | ''' 79 | Section 3.3 80 | ''' 81 | 82 | def __init__(self, hidden_size): 83 | super(EpisodicMemory, self).__init__() 84 | self.AGRU = AttentionGRU(hidden_size, hidden_size) 85 | self.z1 = nn.Linear(4 * hidden_size, hidden_size) 86 | self.z2 = nn.Linear(hidden_size, 1) 87 | self.next_mem = nn.Linear(3 * hidden_size, hidden_size) 88 | 89 | 90 | def init_weights(self): 91 | self.z1.weight.data.normal_(0.0, 0.02) 92 | self.z1.bias.data.fill_(0) 93 | self.z2.weight.data.normal_(0.0, 0.02) 94 | self.z2.bias.data.fill_(0) 95 | self.next_mem.weight.data.normal_(0.0, 0.02) 96 | self.next_mem.bias.data.fill_(0) 97 | self.AGRU.init_weights() 98 | 99 | 100 | def make_interaction(self, frames, questions, prevM): 101 | ''' 102 | frames.size() -> (#batch, T, #hidden = #embedding) 103 | questions.size() -> (#batch, 1, #hidden) 104 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding) 105 | z.size() -> (#batch, T, 4 x #embedding) 
106 | G.size() -> (#batch, T) 107 | ''' 108 | batch_num, T, embedding_size = frames.size() 109 | questions = questions.view(questions.size(0),1,questions.size(1)) 110 | 111 | 112 | #questions = questions.expand_as(frames) 113 | #prevM = prevM.expand_as(frames) 114 | 115 | #print(questions.size(),prevM.size()) 116 | 117 | # Eq (8)~(10) 118 | z = torch.cat([ 119 | frames * questions, 120 | frames * prevM, 121 | torch.abs(frames - questions), 122 | torch.abs(frames - prevM) 123 | ], dim=2) 124 | 125 | z = z.view(-1, 4 * embedding_size) 126 | 127 | G = torch.tanh(self.z1(z)) 128 | G = self.z2(G) 129 | G = G.view(batch_num, -1) 130 | G = F.softmax(G,dim=1) 131 | #print('G size',G.size()) 132 | return G 133 | 134 | def forward(self, frames, questions, prevM): 135 | ''' 136 | frames.size() -> (#batch, #sentence, #hidden = #embedding) 137 | questions.size() -> (#batch, #sentence = 1, #hidden) 138 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding) 139 | G.size() -> (#batch, #sentence) 140 | C.size() -> (#batch, #hidden) 141 | concat.size() -> (#batch, 3 x #embedding) 142 | ''' 143 | 144 | ''' 145 | section 3.3 - Attention based GRU 146 | input: F and q, as frames and questions 147 | then get gates g 148 | then (c,m,g) feed into memory update module Eq(13) 149 | output new memory state 150 | ''' 151 | # print(frames.shape, questions.shape, prevM.shape) 152 | 153 | G = self.make_interaction(frames, questions, prevM) 154 | C = self.AGRU(frames, G) 155 | concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1) 156 | next_mem = F.relu(self.next_mem(concat)) 157 | #print(next_mem.size()) 158 | next_mem = next_mem.unsqueeze(1) 159 | return next_mem 160 | -------------------------------------------------------------------------------- /networks/memory_rand.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from Attention import MultiModalAttentionModule 5 | 6 | class MemoryRamModule(nn.Module): 7 | 8 | def __init__(self, input_size=1024, hidden_size=512, memory_bank_size=100, device=None): 9 | """Set the hyper-parameters and build the layers.""" 10 | super(MemoryRamModule, self).__init__() 11 | 12 | self.input_size = input_size 13 | self.hidden_size = hidden_size 14 | self.memory_bank_size = memory_bank_size 15 | self.device = device 16 | 17 | self.hidden_to_content = nn.Linear(hidden_size+input_size, hidden_size) 18 | #self.read_to_hidden = nn.Linear(hidden_size+input_size, 1) 19 | self.write_gate = nn.Linear(hidden_size+input_size, 1) 20 | self.write_prob = nn.Linear(hidden_size+input_size, memory_bank_size) 21 | 22 | self.read_gate = nn.Linear(hidden_size+input_size, 1) 23 | self.read_prob = nn.Linear(hidden_size+input_size, memory_bank_size) 24 | 25 | 26 | self.Wxh = nn.Parameter(torch.FloatTensor(input_size, hidden_size),requires_grad=True) 27 | self.Wrh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True) 28 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True) 29 | self.bh = nn.Parameter(torch.FloatTensor(hidden_size),requires_grad=True) 30 | 31 | self.init_weights() 32 | 33 | 34 | def init_weights(self): 35 | self.Wxh.data.normal_(0.0, 0.1) 36 | self.Wrh.data.normal_(0.0, 0.1) 37 | self.Whh.data.normal_(0.0, 0.1) 38 | self.bh.data.fill_(0) 39 | 40 | 41 | def forward(self, hidden_frames, nImg): 42 | 43 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device) 44 | 
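# memory_ram is an external memory of `memory_bank_size` slots that is softly
# addressed at every step: a read distribution `ar` (softmax over slots) and a
# read gate `go` yield r = go * (ar @ memory_ram), while a write distribution `aw`
# with write gate `gw` blends new content c_t into the slots via
# memory_ram = gw * aw * c_t + (1 - aw) * memory_ram (Eq. 15-16 below).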
memory_ram.fill_(0) 45 | 46 | h_t = torch.zeros(1, self.hidden_size).to(self.device) 47 | 48 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device) 49 | 50 | for t in range(nImg): 51 | x_t = hidden_frames[t:t+1,:] 52 | 53 | x_h_t = torch.cat([x_t,h_t],dim=1) 54 | 55 | ############# read ############ 56 | ar = torch.softmax(self.read_prob( x_h_t ),dim=1) # read prob from memories 57 | go = torch.sigmoid(self.read_gate( x_h_t )) # read gate 58 | r = go * torch.matmul(ar,memory_ram) # read vector 59 | 60 | ######### h_t ######### 61 | # Eq (17) 62 | m1 = torch.matmul(x_t, self.Wxh) 63 | m2 = torch.matmul(r, self.Wrh) 64 | m3 = torch.matmul(h_t, self.Whh) 65 | h_t_p1 = F.relu(m1 + m2 + m3 + self.bh) # Eq(17) 66 | 67 | 68 | ############# write ############ 69 | c_t = F.relu( self.hidden_to_content(x_h_t) ) # Eq(15), content vector 70 | aw = torch.softmax(self.write_prob( x_h_t ),dim=1) # write prob to memories 71 | aw = aw.view(self.memory_bank_size,1) 72 | gw = torch.sigmoid(self.write_gate( x_h_t )) # write gate 73 | #print gw.size(),aw.size(),c_t.size(),memory_ram.size() 74 | memory_ram = gw * aw * c_t + (1.0-aw) * memory_ram # Eq(16) 75 | 76 | h_t = h_t_p1 77 | hiddens[t,:] = h_t 78 | 79 | #return memory_ram 80 | return hiddens 81 | 82 | 83 | class MemoryRamTwoStreamModule(nn.Module): 84 | 85 | def __init__(self, input_size, hidden_size=512, memory_bank_size=100, device=None): 86 | """Set the hyper-parameters and build the layers.""" 87 | super(MemoryRamTwoStreamModule, self).__init__() 88 | 89 | self.input_size = input_size 90 | self.hidden_size = hidden_size 91 | self.memory_bank_size = memory_bank_size 92 | self.device = device 93 | 94 | self.hidden_to_content_a = nn.Linear(hidden_size+input_size, hidden_size) 95 | self.hidden_to_content_m = nn.Linear(hidden_size+input_size, hidden_size) 96 | 97 | self.write_prob = nn.Linear(hidden_size*3, 3) 98 | self.write_prob_a = nn.Linear(hidden_size+input_size, memory_bank_size) 99 | self.write_prob_m = nn.Linear(hidden_size+input_size, memory_bank_size) 100 | 101 | self.read_prob = nn.Linear(hidden_size*3, memory_bank_size) 102 | 103 | self.read_to_hidden = nn.Linear(hidden_size*2, hidden_size) 104 | self.read_to_hidden_a = nn.Linear(hidden_size*2+input_size, hidden_size) 105 | self.read_to_hidden_m = nn.Linear(hidden_size*2+input_size, hidden_size) 106 | self.init_weights() 107 | 108 | def init_weights(self): 109 | pass 110 | 111 | 112 | def forward(self, hidden_out_a, hidden_out_m, nImg): 113 | 114 | 115 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device) 116 | memory_ram.fill_(0) 117 | 118 | h_t_a = torch.zeros(1, self.hidden_size).to(self.device) 119 | h_t_m = torch.zeros(1, self.hidden_size).to(self.device) 120 | h_t = torch.zeros(1, self.hidden_size).to(self.device) 121 | 122 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device) 123 | 124 | for t in range(nImg): 125 | x_t_a = hidden_out_a[t:t+1,:] 126 | x_t_m = hidden_out_m[t:t+1,:] 127 | 128 | 129 | ############# read ############ 130 | x_h_t_am = torch.cat([h_t_a,h_t_m,h_t],dim=1) 131 | ar = torch.softmax(self.read_prob( x_h_t_am ),dim=1) # read prob from memories 132 | r = torch.matmul(ar,memory_ram) # read vector 133 | 134 | 135 | ######### h_t ######### 136 | # Eq (17) 137 | f_0 = torch.cat([r, h_t],dim=1) 138 | f_a = torch.cat([x_t_a, r, h_t_a],dim=1) 139 | f_m = torch.cat([x_t_m, r, h_t_m],dim=1) 140 | 141 | h_t_1 = F.relu(self.read_to_hidden(f_0)) 142 | h_t_a1 = F.relu(self.read_to_hidden_a(f_a)) 143 | h_t_m1 = 
F.relu(self.read_to_hidden_m(f_m)) 144 | 145 | 146 | ############# write ############ 147 | 148 | # write probability of [keep, write appearance, write motion] 149 | aw = torch.softmax(self.write_prob( x_h_t_am ),dim=1) # write prob to memories 150 | x_h_ta = torch.cat([h_t_a,x_t_a],dim=1) 151 | x_h_tm = torch.cat([h_t_m,x_t_m],dim=1) 152 | 153 | 154 | # write content 155 | c_t_a = F.relu( self.hidden_to_content_a(x_h_ta) ) # Eq(15), content vector 156 | c_t_m = F.relu( self.hidden_to_content_m(x_h_tm) ) # Eq(15), content vector 157 | 158 | aw_a = torch.softmax(self.write_prob_a( x_h_ta ),dim=1) # write prob to memories 159 | aw_m = torch.softmax(self.write_prob_m( x_h_tm ),dim=1) # write prob to memories 160 | 161 | 162 | aw_a = aw_a.view(self.memory_bank_size,1) 163 | aw_m = aw_m.view(self.memory_bank_size,1) 164 | 165 | memory_ram = aw[0,0] * memory_ram + aw[0,1] * aw_a * c_t_a + aw[0,2] * aw_m * c_t_m 166 | 167 | 168 | h_t = h_t_1 169 | h_t_a = h_t_a1 170 | h_t_m = h_t_m1 171 | 172 | hiddens[t,:] = h_t 173 | 174 | 175 | return hiddens 176 | 177 | class MMModule(nn.Module): 178 | def __init__(self, dim, input_drop_p, device): 179 | """Set the hyper-parameters and build the layers.""" 180 | super(MMModule, self).__init__() 181 | self.hidden_size = dim 182 | self.lstm_mm_1 = nn.LSTMCell(dim, dim) 183 | self.lstm_mm_2 = nn.LSTMCell(dim, dim) 184 | self.hidden_encoder_1 = nn.Linear(dim * 2, dim) 185 | self.hidden_encoder_2 = nn.Linear(dim * 2, dim) 186 | self.dropout = nn.Dropout(input_drop_p) 187 | self.mm_att = MultiModalAttentionModule(dim) 188 | self.device = device 189 | self.init_weights() 190 | 191 | 192 | def init_weights(self): 193 | nn.init.xavier_normal_(self.hidden_encoder_1.weight) 194 | nn.init.xavier_normal_(self.hidden_encoder_2.weight) 195 | self.init_hiddens() 196 | 197 | def init_hiddens(self): 198 | s_t = torch.zeros(1, self.hidden_size).to(self.device) 199 | s_t2 = torch.zeros(1, self.hidden_size).to(self.device) 200 | c_t = torch.zeros(1, self.hidden_size).to(self.device) 201 | c_t2 = torch.zeros(1, self.hidden_size).to(self.device) 202 | return s_t, s_t2, c_t, c_t2 203 | 204 | def forward(self, svt_tmp, memory_ram_vid, memory_ram_txt, loop=3): 205 | """ 206 | 207 | :param svt_tmp: 208 | :param memory_ram_vid: 209 | :param memory_ram_txt: 210 | :param loop: 211 | :return: 212 | """ 213 | 214 | sm_q1, sm_q2, cm_q1, cm_q2 = self.init_hiddens() 215 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_1(svt_tmp))) 216 | 217 | for _ in range(loop): 218 | sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1)) 219 | sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2)) 220 | 221 | mm_o1 = self.mm_att(sm_q2, memory_ram_vid, memory_ram_txt) 222 | mm_o2 = torch.cat((sm_q2, mm_o1), dim=1) 223 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_2(mm_o2))) 224 | 225 | smq = torch.cat((sm_q1, sm_q2), dim=1) 226 | 227 | return smq 228 | 229 | class MemoryRamTwoStreamModule2(nn.Module): 230 | 231 | def __init__(self, input_size, hidden_size=512, memory_bank_size=100, device=None): 232 | """Set the hyper-parameters and build the layers.""" 233 | super(MemoryRamTwoStreamModule2, self).__init__() 234 | 235 | self.input_size = input_size 236 | self.hidden_size = hidden_size 237 | self.memory_bank_size = memory_bank_size 238 | self.device = device 239 | 240 | self.hidden_to_content_a = nn.Linear(hidden_size+input_size, hidden_size) 241 | self.hidden_to_content_m = nn.Linear(hidden_size+input_size, hidden_size) 242 | 243 | self.write_prob = nn.Linear(hidden_size*3, 3) 244 | 
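# write_prob emits a 3-way softmax over [keep old memory, write appearance
# content, write motion content]; write_prob_a / write_prob_m below give the
# per-slot address distributions for the appearance and motion streams.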
self.write_prob_a = nn.Linear(hidden_size+input_size, memory_bank_size) 245 | self.write_prob_m = nn.Linear(hidden_size+input_size, memory_bank_size) 246 | 247 | self.read_prob = nn.Linear(hidden_size*3, memory_bank_size) 248 | 249 | self.read_to_hidden = nn.Linear(hidden_size*2, hidden_size) 250 | self.read_to_hidden_a = nn.Linear(hidden_size*2+input_size, hidden_size) 251 | self.read_to_hidden_m = nn.Linear(hidden_size*2+input_size, hidden_size) 252 | self.init_weights() 253 | 254 | def init_weights(self): 255 | pass 256 | 257 | 258 | def forward(self, hidden_out_a, hidden_out_m, nImg): 259 | 260 | 261 | memory_ram = torch.FloatTensor(hidden_out_a.shape[0], self.memory_bank_size, self.hidden_size).to(self.device) 262 | memory_ram.fill_(0) 263 | 264 | h_t_a = torch.zeros(hidden_out_a.shape[0], 1, self.hidden_size).to(self.device) 265 | h_t_m = torch.zeros(hidden_out_a.shape[0], 1, self.hidden_size).to(self.device) 266 | h_t = torch.zeros(hidden_out_a.shape[0], 1, self.hidden_size).to(self.device) 267 | 268 | hiddens = torch.FloatTensor(hidden_out_a.shape[0], nImg, self.hidden_size).to(self.device) 269 | 270 | for t in range(nImg): 271 | x_t_a = hidden_out_a[:, t:t+1,:] 272 | x_t_m = hidden_out_m[:, t:t+1,:] 273 | 274 | 275 | ############# read ############ 276 | x_h_t_am = torch.cat([h_t_a,h_t_m,h_t],dim=2) 277 | ar = torch.softmax(self.read_prob( x_h_t_am ),dim=2) # read prob from memories 278 | r = torch.matmul(ar,memory_ram) # read vector 279 | 280 | 281 | ######### h_t ######### 282 | # Eq (17) 283 | f_0 = torch.cat([r, h_t],dim=2) 284 | f_a = torch.cat([x_t_a, r, h_t_a],dim=2) 285 | f_m = torch.cat([x_t_m, r, h_t_m],dim=2) 286 | 287 | h_t_1 = F.relu(self.read_to_hidden(f_0)) 288 | h_t_a1 = F.relu(self.read_to_hidden_a(f_a)) 289 | h_t_m1 = F.relu(self.read_to_hidden_m(f_m)) 290 | 291 | 292 | ############# write ############ 293 | 294 | # write probability of [keep, write appearance, write motion] 295 | aw = torch.softmax(self.write_prob( x_h_t_am ),dim=2) # write prob to memories 296 | x_h_ta = torch.cat([h_t_a,x_t_a],dim=2) 297 | x_h_tm = torch.cat([h_t_m,x_t_m],dim=2) 298 | 299 | 300 | # write content 301 | c_t_a = F.relu( self.hidden_to_content_a(x_h_ta) ) # Eq(15), content vector 302 | c_t_m = F.relu( self.hidden_to_content_m(x_h_tm) ) # Eq(15), content vector 303 | 304 | aw_a = torch.softmax(self.write_prob_a( x_h_ta ),dim=2) # write prob to memories 305 | aw_m = torch.softmax(self.write_prob_m( x_h_tm ),dim=2) # write prob to memories 306 | 307 | 308 | aw_a = aw_a.view(hidden_out_a.shape[0], self.memory_bank_size,1) 309 | aw_m = aw_m.view(hidden_out_a.shape[0], self.memory_bank_size,1) 310 | 311 | memory_ram = aw[:, 0,0].unsqueeze(1).unsqueeze(2) * memory_ram + aw[:, 0,1].unsqueeze(1).unsqueeze(2) * aw_a * c_t_a + aw[:, 0,2].unsqueeze(1).unsqueeze(2) * aw_m * c_t_m 312 | 313 | 314 | h_t = h_t_1 315 | h_t_a = h_t_a1 316 | h_t_m = h_t_m1 317 | 318 | hiddens[:, t,:] = h_t.squeeze() 319 | 320 | 321 | return hiddens 322 | 323 | class MemoryRamModule2(nn.Module): 324 | 325 | def __init__(self, input_size=1024, hidden_size=512, memory_bank_size=100, device=None): 326 | """Set the hyper-parameters and build the layers.""" 327 | super(MemoryRamModule2, self).__init__() 328 | 329 | self.input_size = input_size 330 | self.hidden_size = hidden_size 331 | self.memory_bank_size = memory_bank_size 332 | self.device = device 333 | 334 | self.hidden_to_content = nn.Linear(hidden_size+input_size, hidden_size) 335 | #self.read_to_hidden = nn.Linear(hidden_size+input_size, 1) 336 | self.write_gate 
= nn.Linear(hidden_size+input_size, 1) 337 | self.write_prob = nn.Linear(hidden_size+input_size, memory_bank_size) 338 | 339 | self.read_gate = nn.Linear(hidden_size+input_size, 1) 340 | self.read_prob = nn.Linear(hidden_size+input_size, memory_bank_size) 341 | 342 | 343 | self.Wxh = nn.Parameter(torch.FloatTensor(input_size, hidden_size),requires_grad=True) 344 | self.Wrh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True) 345 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True) 346 | self.bh = nn.Parameter(torch.FloatTensor(hidden_size),requires_grad=True) 347 | 348 | self.init_weights() 349 | 350 | 351 | def init_weights(self): 352 | self.Wxh.data.normal_(0.0, 0.1) 353 | self.Wrh.data.normal_(0.0, 0.1) 354 | self.Whh.data.normal_(0.0, 0.1) 355 | self.bh.data.fill_(0) 356 | 357 | 358 | def forward(self, hidden_frames, nImg): 359 | 360 | memory_ram = torch.FloatTensor(hidden_frames.shape[0], self.memory_bank_size, self.hidden_size).to(self.device) 361 | memory_ram.fill_(0) 362 | 363 | h_t = torch.zeros(hidden_frames.shape[0], 1, self.hidden_size).to(self.device) 364 | 365 | hiddens = torch.FloatTensor(hidden_frames.shape[0], nImg, self.hidden_size).to(self.device) 366 | 367 | for t in range(nImg): 368 | x_t = hidden_frames[:, t:t+1,:] 369 | 370 | x_h_t = torch.cat([x_t,h_t],dim=2) 371 | 372 | ############# read ############ 373 | ar = torch.softmax(self.read_prob( x_h_t ),dim=2) # read prob from memories 374 | go = torch.sigmoid(self.read_gate( x_h_t )) # read gate 375 | r = go * torch.matmul(ar,memory_ram) # read vector 376 | 377 | ######### h_t ######### 378 | # Eq (17) 379 | m1 = torch.matmul(x_t, self.Wxh) 380 | m2 = torch.matmul(r, self.Wrh) 381 | m3 = torch.matmul(h_t, self.Whh) 382 | h_t_p1 = F.relu(m1 + m2 + m3 + self.bh) # Eq(17) 383 | 384 | 385 | ############# write ############ 386 | c_t = F.relu( self.hidden_to_content(x_h_t) ) # Eq(15), content vector 387 | aw = torch.softmax(self.write_prob( x_h_t ),dim=2) # write prob to memories 388 | aw = aw.view(-1, self.memory_bank_size,1) 389 | gw = torch.sigmoid(self.write_gate( x_h_t )) # write gate 390 | #print gw.size(),aw.size(),c_t.size(),memory_ram.size() 391 | memory_ram = gw * aw * c_t + (1.0-aw) * memory_ram # Eq(16) 392 | 393 | h_t = h_t_p1 394 | hiddens[:, t,:] = h_t.squeeze() 395 | 396 | #return memory_ram 397 | return hiddens 398 | 399 | class MMModule2(nn.Module): 400 | def __init__(self, dim, input_drop_p, device): 401 | """Set the hyper-parameters and build the layers.""" 402 | super(MMModule2, self).__init__() 403 | self.hidden_size = dim 404 | self.lstm_mm_1 = nn.LSTMCell(dim, dim) 405 | self.lstm_mm_2 = nn.LSTMCell(dim, dim) 406 | self.hidden_encoder_1 = nn.Linear(dim * 2, dim) 407 | self.hidden_encoder_2 = nn.Linear(dim * 2, dim) 408 | self.dropout = nn.Dropout(input_drop_p) 409 | self.mm_att = MultiModalAttentionModule(dim) 410 | self.device = device 411 | self.init_weights() 412 | 413 | 414 | def init_weights(self): 415 | nn.init.xavier_normal_(self.hidden_encoder_1.weight) 416 | nn.init.xavier_normal_(self.hidden_encoder_2.weight) 417 | 418 | def init_hiddens(self, bs): 419 | s_t = torch.zeros(bs, self.hidden_size).to(self.device) 420 | s_t2 = torch.zeros(bs, self.hidden_size).to(self.device) 421 | c_t = torch.zeros(bs, self.hidden_size).to(self.device) 422 | c_t2 = torch.zeros(bs, self.hidden_size).to(self.device) 423 | return s_t, s_t2, c_t, c_t2 424 | 425 | def forward(self, svt_tmp, memory_ram_vid, memory_ram_txt, loop=3): 426 | """ 427 | 428 | 
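Refines a multimodal state for `loop` iterations: two stacked LSTM cells update
the state, multimodal attention reads from the video and text memory banks, and
the attended read is fused back in before the next iteration. Returns the
concatenation of the two cells' hidden states.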
:param svt_tmp: 429 | :param memory_ram_vid: 430 | :param memory_ram_txt: 431 | :param loop: 432 | :return: 433 | """ 434 | bs = svt_tmp.shape[0] 435 | sm_q1, sm_q2, cm_q1, cm_q2 = self.init_hiddens(bs) 436 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_1(svt_tmp))) 437 | 438 | for _ in range(loop): 439 | sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1)) 440 | sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2)) 441 | 442 | mm_o1 = self.mm_att(sm_q2, memory_ram_vid, memory_ram_txt) 443 | mm_o2 = torch.cat((sm_q2, mm_o1), dim=1) 444 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_2(mm_o2))) 445 | 446 | smq = torch.cat((sm_q1, sm_q2), dim=1) 447 | 448 | return smq -------------------------------------------------------------------------------- /networks/torchnlp_nn.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Parameter 2 | 3 | import torch 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class weight_drop(): 9 | 10 | def __init__(self, module, weights, dropout): 11 | for name_w in weights: 12 | w = getattr(module, name_w) 13 | del module._parameters[name_w] 14 | module.register_parameter(name_w + '_raw', Parameter(w)) 15 | 16 | self.original_module_forward = module.forward 17 | 18 | self.weights = weights 19 | self.module = module 20 | self.dropout = dropout 21 | 22 | def __call__(self, *args, **kwargs): 23 | for name_w in self.weights: 24 | raw_w = getattr(self.module, name_w + '_raw') 25 | w = torch.nn.functional.dropout( 26 | raw_w, p=self.dropout, training=self.module.training) 27 | # module.register_parameter(name_w, Parameter(w)) 28 | setattr(self.module, name_w, Parameter(w)) 29 | 30 | return self.original_module_forward(*args, **kwargs) 31 | 32 | 33 | def _weight_drop(module, weights, dropout): 34 | setattr(module, 'forward', weight_drop(module, weights, dropout)) 35 | 36 | 37 | # def _weight_drop(module, weights, dropout): 38 | # """ 39 | # Helper for `WeightDrop`. 40 | # """ 41 | 42 | # for name_w in weights: 43 | # w = getattr(module, name_w) 44 | # del module._parameters[name_w] 45 | # module.register_parameter(name_w + '_raw', Parameter(w)) 46 | 47 | # original_module_forward = module.forward 48 | 49 | # def forward(*args, **kwargs): 50 | # for name_w in weights: 51 | # raw_w = getattr(module, name_w + '_raw') 52 | # w = torch.nn.functional.dropout( 53 | # raw_w, p=dropout, training=module.training) 54 | # # module.register_parameter(name_w, Parameter(w)) 55 | # setattr(module, name_w, Parameter(w)) 56 | 57 | # return original_module_forward(*args, **kwargs) 58 | 59 | # setattr(module, 'forward', forward) 60 | 61 | 62 | class WeightDrop(torch.nn.Module): 63 | """ 64 | The weight-dropped module applies recurrent regularization through a DropConnect mask on the 65 | hidden-to-hidden recurrent weights. 66 | **Thank you** to Sales Force for their initial implementation of :class:`WeightDrop`. Here is 67 | their `License 68 | `__. 69 | Args: 70 | module (:class:`torch.nn.Module`): Containing module. 71 | weights (:class:`list` of :class:`str`): Names of the module weight parameters to apply a 72 | dropout too. 73 | dropout (float): The probability a weight will be dropped. 
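The wrapper re-registers each listed weight as ``<name>_raw`` and, on every
forward call, rebuilds ``<name>`` from a dropped-out copy of the raw weight
before delegating to the wrapped module's original forward.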
74 | Example: 75 | >>> from torchnlp.nn import WeightDrop 76 | >>> import torch 77 | >>> 78 | >>> torch.manual_seed(123) 79 | >> 81 | >>> gru = torch.nn.GRUCell(2, 2) 82 | >>> weights = ['weight_hh'] 83 | >>> weight_drop_gru = WeightDrop(gru, weights, dropout=0.9) 84 | >>> 85 | >>> input_ = torch.randn(3, 2) 86 | >>> hidden_state = torch.randn(3, 2) 87 | >>> weight_drop_gru(input_, hidden_state) 88 | tensor(... grad_fn=) 89 | """ 90 | 91 | def __init__(self, module, weights, dropout=0.0): 92 | super(WeightDrop, self).__init__() 93 | _weight_drop(module, weights, dropout) 94 | self.forward = module.forward 95 | 96 | 97 | class WeightDropLSTM(torch.nn.LSTM): 98 | """ 99 | Wrapper around :class:`torch.nn.LSTM` that adds ``weight_dropout`` named argument. 100 | Args: 101 | weight_dropout (float): The probability a weight will be dropped. 102 | """ 103 | 104 | def __init__(self, *args, weight_dropout=0.0, **kwargs): 105 | super().__init__(*args, **kwargs) 106 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)] 107 | _weight_drop(self, weights, weight_dropout) 108 | 109 | 110 | class WeightDropGRU(torch.nn.GRU): 111 | """ 112 | Wrapper around :class:`torch.nn.GRU` that adds ``weight_dropout`` named argument. 113 | Args: 114 | weight_dropout (float): The probability a weight will be dropped. 115 | """ 116 | 117 | def __init__(self, *args, weight_dropout=0.0, **kwargs): 118 | super().__init__(*args, **kwargs) 119 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)] 120 | _weight_drop(self, weights, weight_dropout) 121 | 122 | 123 | class WeightDropLinear(torch.nn.Linear): 124 | """ 125 | Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument. 126 | Args: 127 | weight_dropout (float): The probability a weight will be dropped. 
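Example:
    >>> linear = WeightDropLinear(5, 5, weight_dropout=0.9)
    >>> input_ = torch.randn(3, 5)
    >>> output = linear(input_)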
128 | """ 129 | 130 | def __init__(self, *args, weight_dropout=0.0, **kwargs): 131 | super().__init__(*args, **kwargs) 132 | weights = ['weight'] 133 | _weight_drop(self, weights, weight_dropout) 134 | -------------------------------------------------------------------------------- /requirement.txt: -------------------------------------------------------------------------------- 1 | addict==2.4.0 2 | antlr4-python3-runtime==4.8 3 | appdirs==1.4.4 4 | astor==0.8.1 5 | backcall==0.2.0 6 | black==21.4b2 7 | blis==0.7.4 8 | block.bootstrap.pytorch==0.1.6 9 | bootstrap.pytorch==0.0.13 10 | brotlipy==0.7.0 11 | catalogue==2.0.6 12 | certifi==2021.5.30 13 | cloudpickle==1.6.0 14 | colorama==0.4.4 15 | contextvars==2.4 16 | cycler==0.10.0 17 | cymem==2.0.5 18 | Cython==0.29.23 19 | dataclasses==0.7 20 | decorator==4.4.2 21 | filelock==3.0.10 22 | future==0.18.2 23 | fvcore==0.1.5.post20210515 24 | google-pasta==0.2.0 25 | hydra-core==1.1.0rc1 26 | imageio==2.15.0 27 | immutables==0.16 28 | importlib-resources==5.1.3 29 | instaboostfast==0.1.2 30 | iopath==0.1.8 31 | ipdb==0.13.9 32 | ipython==7.12.0 33 | Jinja2==3.0.2 34 | Keras-Preprocessing==1.1.0 35 | kiwisolver==1.3.1 36 | language-tool-python==2.5.5 37 | MarkupSafe==2.0.1 38 | matplotlib==3.3.4 39 | mmcv-full==1.4.8 40 | model-index==0.1.11 41 | munch==2.5.0 42 | murmurhash==1.0.5 43 | mypy-extensions==0.4.3 44 | networkx==2.5.1 45 | olefile==0.46 46 | omegaconf==2.1.0rc1 47 | opencv-python==4.5.1.48 48 | openmim==0.1.5 49 | ordered-set==4.0.2 50 | pathspec==0.8.1 51 | pathy==0.6.0 52 | Pillow==8.4.0 53 | plotly==5.3.1 54 | portalocker==2.3.0 55 | preshed==3.0.5 56 | pretrainedmodels==0.7.4 57 | protobuf==3.13.0 58 | pycocotools==2.0.2 59 | pydantic==1.8.2 60 | pydot==1.4.2 61 | pyhocon==0.3.58 62 | python-dateutil==2.8.1 63 | PyWavelets==1.1.1 64 | PyYAML==5.4.1 65 | pyzmq==19.0.2 66 | scikit-image==0.17.2 67 | scikit-learn==0.24.2 68 | seaborn==0.11.2 69 | sk-video==1.1.10 70 | skipthoughts==0.0.1 71 | smart-open==5.2.1 72 | spacy==3.1.3 73 | spacy-legacy==3.0.8 74 | srsly==2.4.1 75 | stanfordnlp==0.2.0 76 | tabulate==0.8.9 77 | tenacity==8.0.1 78 | tensorboard==1.14.0 79 | tensorflow-estimator==1.14.0 80 | termcolor==1.1.0 81 | terminaltables==3.1.10 82 | thinc==8.0.10 83 | threadpoolctl==3.0.0 84 | tifffile==2020.9.3 85 | tokenizers==0.10.3 86 | toml==0.10.2 87 | torch==1.7.1 88 | torchvision==0.8.2 89 | traitlets==4.3.3 90 | transformers==4.11.3 91 | typed-args==0.4.2 92 | typed-ast==1.4.3 93 | typer==0.4.0 94 | urllib3==1.26.7 95 | wasabi==0.8.2 96 | Werkzeug==1.0.1 97 | wrapt==1.12.1 98 | yacs==0.1.8 99 | yapf==0.32.0 100 | numpy==1.19.2 101 | nltk 102 | h5py 103 | SharedArray -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import pickle as pkl 5 | import pandas as pd 6 | import logging 7 | 8 | def make_logger(log_file): 9 | logger = logging.getLogger() 10 | logger.setLevel(logging.INFO) 11 | 12 | logfile = log_file 13 | fh = logging.FileHandler(logfile) 14 | fh.setLevel(logging.DEBUG) 15 | 16 | ch = logging.StreamHandler() 17 | ch.setLevel(logging.INFO) 18 | 19 | formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s") 20 | fh.setFormatter(formatter) 21 | ch.setFormatter(formatter) 22 | 23 | logger.addHandler(fh) 24 | logger.addHandler(ch) 25 | logger.info('logfile = {}'.format(logfile)) 26 | return 
logger 27 | 28 | def set_gpu_devices(gpu_id): 29 | gpu = '' 30 | if gpu_id != -1: 31 | gpu = str(gpu_id) 32 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 33 | 34 | 35 | def load_file(filename): 36 | """ 37 | load obj from filename 38 | :param filename: 39 | :return: 40 | """ 41 | cont = None 42 | if not osp.exists(filename): 43 | print('{} not exist'.format(filename)) 44 | return cont 45 | if osp.splitext(filename)[-1] == '.csv': 46 | # return pd.read_csv(filename, delimiter= '\t', index_col=0) 47 | return pd.read_csv(filename, delimiter=',') 48 | with open(filename, 'r') as fp: 49 | if osp.splitext(filename)[1] == '.txt': 50 | cont = fp.readlines() 51 | cont = [c.rstrip('\n') for c in cont] 52 | elif osp.splitext(filename)[1] == '.json': 53 | cont = json.load(fp) 54 | return cont 55 | 56 | 57 | def save_file(obj, filename): 58 | """ 59 | save obj to filename 60 | :param obj: 61 | :param filename: 62 | :return: 63 | """ 64 | filepath = osp.dirname(filename) 65 | if filepath != '' and not osp.exists(filepath): 66 | os.makedirs(filepath) 67 | with open(filename, 'w') as fp: 68 | json.dump(obj, fp, indent=4) 69 | 70 | 71 | def pkload(file): 72 | data = None 73 | if osp.exists(file) and osp.getsize(file) > 0: 74 | with open(file, 'rb') as fp: 75 | data = pkl.load(fp) 76 | # print('{} does not exist'.format(file)) 77 | return data 78 | 79 | 80 | def pkdump(data, file): 81 | dirname = osp.dirname(file) 82 | if not osp.exists(dirname): 83 | os.makedirs(dirname) 84 | with open(file, 'wb') as fp: 85 | pkl.dump(data, fp) 86 | -------------------------------------------------------------------------------- /videoqa.py: -------------------------------------------------------------------------------- 1 | from networks import Embed_loss, EncoderRNN, CRN 2 | from networks.VQAModel import EVQA, HCRN, CoMem, HME, HGA, B2A 3 | from utils import * 4 | from torch.optim.lr_scheduler import ReduceLROnPlateau 5 | import torch 6 | import torch.nn as nn 7 | import time 8 | 9 | 10 | class VideoQA(): 11 | def __init__(self, vocab, train_loader, val_loader, test_loader, glove_embed, use_bert, checkpoint_path, model_type, 12 | model_prefix, vis_step, lr_rate, batch_size, epoch_num, logger, args): 13 | self.vocab = vocab 14 | self.train_loader = train_loader 15 | self.val_loader = val_loader 16 | self.test_loader = test_loader 17 | self.glove_embed = glove_embed 18 | self.use_bert = use_bert 19 | self.model_dir = checkpoint_path 20 | self.model_type = model_type 21 | self.model_prefix = model_prefix 22 | self.vis_step = vis_step 23 | self.lr_rate = lr_rate 24 | self.batch_size = batch_size 25 | self.epoch_num = epoch_num 26 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 27 | self.model = None 28 | self.logger = logger 29 | self.args = args 30 | 31 | def build_model(self): 32 | 33 | vid_dim = self.args.vid_dim 34 | hidden_dim = self.args.hidden_dim 35 | word_dim = self.args.word_dim 36 | vocab_size = len(self.vocab) 37 | max_vid_len = self.args.max_vid_len 38 | max_vid_frame_len = self.args.max_vid_frame_len 39 | max_qa_len = self.args.max_qa_len 40 | spl_resolution = self.args.spl_resolution 41 | 42 | if self.model_type == 'EVQA' or self.model_type == 'BlindQA': 43 | #ICCV15, AAAI17 44 | vid_encoder = EncoderRNN.EncoderVid(vid_dim, hidden_dim, input_dropout_p=0.2, n_layers=1, rnn_dropout_p=0, bidirectional=False, rnn_cell='gru') 45 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, input_dropout_p=0.2, rnn_dropout_p=0, 
bidirectional=False, rnn_cell='gru') 46 | 47 | self.model = EVQA.EVQA(vid_encoder, qns_encoder, self.device, self.model_type == 'BlindQA') 48 | 49 | elif self.model_type == 'CoMem': 50 | #CVPR18 51 | app_dim = 2048 52 | motion_dim = 2048 53 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru') 54 | 55 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=2, rnn_dropout_p=0.5, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru') 56 | 57 | self.model = CoMem.CoMem(vid_encoder, qns_encoder, max_vid_len, max_qa_len, self.device) 58 | 59 | elif self.model_type == 'HME': 60 | #CVPR19 61 | app_dim = 2048 62 | motion_dim = 2048 63 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru') 64 | 65 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=2, rnn_dropout_p=0.5, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru') 66 | 67 | self.model = HME.HME(vid_encoder, qns_encoder, max_vid_len, max_qa_len*2, self.device) 68 | 69 | elif self.model_type == 'HGA': 70 | #AAAI20 71 | vid_encoder = EncoderRNN.EncoderVidHGA(vid_dim, hidden_dim, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru') 72 | 73 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, rnn_dropout_p=0, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru') 74 | 75 | self.model = HGA.HGA(vid_encoder, qns_encoder, self.device) 76 | 77 | elif self.model_type == 'HCRN': 78 | #CVPR20 79 | vid_dim = vid_dim//2 80 | vid_encoder = CRN.EncoderVidCRN(max_vid_frame_len, max_vid_len, spl_resolution, vid_dim, hidden_dim) 81 | 82 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, rnn_dropout_p=0.2, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru') 83 | 84 | self.model = HCRN.HCRN(vid_encoder, qns_encoder, self.device) 85 | 86 | elif self.model_type == 'B2A' or self.model_type == 'B2A2': 87 | #CVPR21 88 | vid_dim = vid_dim // 2 89 | vid_encoder = EncoderRNN.EncoderVidB2A(vid_dim, hidden_dim*2, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru') 90 | 91 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, rnn_dropout_p=0, input_dropout_p=0.3, bidirectional=True, rnn_cell='gru') 92 | 93 | self.model = B2A.B2A(vid_encoder, qns_encoder, self.device) 94 | 95 | 96 | 97 | params = [{'params':self.model.parameters()}] 98 | 99 | self.optimizer = torch.optim.Adam(params = params, lr=self.lr_rate) 100 | self.scheduler = ReduceLROnPlateau(self.optimizer, 'max', factor=0.5, patience=5, verbose=True) 101 | 102 | self.model.to(self.device) 103 | self.criterion = Embed_loss.MultipleChoiceLoss().to(self.device) 104 | 105 | 106 | def save_model(self, epoch, acc, is_best=False): 107 | if not is_best: 108 | torch.save(self.model.state_dict(), osp.join(self.model_dir, self.model_type, self.model_prefix, 'model', '{}-{:.2f}.ckpt' 109 | .format(epoch, acc))) 110 | else: 111 | torch.save(self.model.state_dict(), osp.join(self.model_dir, self.model_type, self.model_prefix, 'model', 'best.ckpt')) 112 | 113 | def resume(self, model_file): 114 | """ 115 | initialize model with pretrained weights 116 | :return: 117 | """ 118 | self.logger.info('Warm-start (or test) with model: 
{}'.format(model_file)) 119 | model_dict = torch.load(model_file) 120 | new_model_dict = {} 121 | for k, v in self.model.state_dict().items(): 122 | if k in model_dict: 123 | v = model_dict[k] 124 | else: 125 | pass 126 | # print(k) 127 | new_model_dict[k] = v 128 | self.model.load_state_dict(new_model_dict) 129 | 130 | 131 | def run(self, model_file, pre_trained=False): 132 | self.build_model() 133 | self.logger.info(self.model) 134 | best_eval_score = 0.0 135 | if pre_trained: 136 | self.resume(model_file) 137 | best_eval_score = self.eval(0) 138 | self.logger.info('Initial Acc {:.2f}'.format(best_eval_score)) 139 | 140 | for epoch in range(0, self.epoch_num): 141 | train_loss, train_acc = self.train(epoch) 142 | eval_score = self.eval(epoch) 143 | eval_score_test = self.eval_t(epoch) 144 | self.logger.info("==>Epoch:[{}/{}][Train Loss: {:.4f}; Train acc: {:.2f}; Val acc: {:.2f}; Test acc: {:.2f}]". 145 | format(epoch, self.epoch_num, train_loss, train_acc, eval_score, eval_score_test)) 146 | self.scheduler.step(eval_score) 147 | self.save_model(epoch, eval_score) 148 | if eval_score > best_eval_score: 149 | best_eval_score = eval_score 150 | self.save_model(epoch, best_eval_score, True) 151 | 152 | def train(self, epoch): 153 | self.logger.info('==>Epoch:[{}/{}][lr_rate: {}]'.format(epoch, self.epoch_num, self.optimizer.param_groups[0]['lr'])) 154 | self.model.train() 155 | total_step = len(self.train_loader) 156 | epoch_loss = 0.0 157 | prediction_list = [] 158 | answer_list = [] 159 | for iter, inputs in enumerate(self.train_loader): 160 | visual, can, ques, ans_id, qns_key = inputs 161 | app_inputs = visual[0].to(self.device) 162 | mot_inputs = visual[1].to(self.device) 163 | candidate = can[0].to(self.device) 164 | candidate_lengths = can[1] 165 | obj_fea_can = can[2].to(self.device) 166 | dep_adj_can = can[3].to(self.device) 167 | question = ques[0].to(self.device) 168 | ques_lengths = ques[1] 169 | obj_fea_q = ques[2].to(self.device) 170 | dep_adj_q = ques[3].to(self.device) 171 | ans_targets = ans_id.to(self.device) 172 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q) 173 | 174 | self.model.zero_grad() 175 | loss = self.criterion(out, ans_targets) 176 | if not torch.isnan(loss): 177 | loss.backward() 178 | else: 179 | print(out) 180 | print(ans_targets) 181 | self.optimizer.step() 182 | epoch_loss += loss.item() 183 | if iter % self.vis_step == 0: 184 | self.logger.info('\t[{}/{}] Training loss: {:.4f}'.format(iter, total_step, epoch_loss/(iter+1))) 185 | 186 | prediction_list.append(prediction) 187 | answer_list.append(ans_id) 188 | 189 | predict_answers = torch.cat(prediction_list, dim=0).long().cpu() 190 | ref_answers = torch.cat(answer_list, dim=0).long() 191 | acc_num = torch.sum(predict_answers==ref_answers).numpy() 192 | print(len(ref_answers)) 193 | 194 | return epoch_loss / total_step, acc_num*100.0 / len(ref_answers) 195 | 196 | 197 | def eval(self, epoch): 198 | self.logger.info('==>Epoch:[{}/{}][validation stage]'.format(epoch, self.epoch_num)) 199 | self.model.eval() 200 | total_step = len(self.val_loader) 201 | acc_count = 0 202 | prediction_list = [] 203 | answer_list = [] 204 | with torch.no_grad(): 205 | for iter, inputs in enumerate(self.val_loader): 206 | visual, can, ques, ans_id, qns_key = inputs 207 | app_inputs = visual[0].to(self.device) 208 | mot_inputs = visual[1].to(self.device) 209 | candidate = can[0].to(self.device) 210 | candidate_lengths = 
can[1] 211 | obj_fea_can = can[2].to(self.device) 212 | dep_adj_can = can[3].to(self.device) 213 | question = ques[0].to(self.device) 214 | ques_lengths = ques[1] 215 | obj_fea_q = ques[2].to(self.device) 216 | dep_adj_q = ques[3].to(self.device) 217 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q) 218 | 219 | prediction_list.append(prediction) 220 | answer_list.append(ans_id) 221 | 222 | predict_answers = torch.cat(prediction_list, dim=0).long().cpu() 223 | ref_answers = torch.cat(answer_list, dim=0).long() 224 | acc_num = torch.sum(predict_answers == ref_answers).numpy() 225 | print(len(ref_answers)) 226 | 227 | return acc_num*100.0 / len(ref_answers) 228 | 229 | def eval_t(self, epoch): 230 | self.logger.info('==>Epoch:[{}/{}][test stage]'.format(epoch, self.epoch_num)) 231 | self.model.eval() 232 | total_step = len(self.test_loader) 233 | acc_count = 0 234 | prediction_list = [] 235 | answer_list = [] 236 | with torch.no_grad(): 237 | for iter, inputs in enumerate(self.test_loader): 238 | visual, can, ques, ans_id, qns_key = inputs 239 | app_inputs = visual[0].to(self.device) 240 | mot_inputs = visual[1].to(self.device) 241 | candidate = can[0].to(self.device) 242 | candidate_lengths = can[1] 243 | obj_fea_can = can[2].to(self.device) 244 | dep_adj_can = can[3].to(self.device) 245 | question = ques[0].to(self.device) 246 | ques_lengths = ques[1] 247 | obj_fea_q = ques[2].to(self.device) 248 | dep_adj_q = ques[3].to(self.device) 249 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q) 250 | 251 | prediction_list.append(prediction) 252 | answer_list.append(ans_id) 253 | 254 | predict_answers = torch.cat(prediction_list, dim=0).long().cpu() 255 | ref_answers = torch.cat(answer_list, dim=0).long() 256 | acc_num = torch.sum(predict_answers == ref_answers).numpy() 257 | print(len(ref_answers)) 258 | 259 | return acc_num*100.0 / len(ref_answers) 260 | 261 | 262 | def predict(self, model_file, result_file, loader): 263 | """ 264 | predict the answer with the trained model 265 | :param model_file: 266 | :return: 267 | """ 268 | self.build_model() 269 | self.resume(model_file) 270 | 271 | self.model.eval() 272 | results = {} 273 | with torch.no_grad(): 274 | for iter, inputs in enumerate(loader): 275 | visual, can, ques, ans_id, qns_key = inputs 276 | app_inputs = visual[0].to(self.device) 277 | mot_inputs = visual[1].to(self.device) 278 | candidate = can[0].to(self.device) 279 | candidate_lengths = can[1] 280 | obj_fea_can = can[2].to(self.device) 281 | dep_adj_can = can[3].to(self.device) 282 | question = ques[0].to(self.device) 283 | ques_lengths = ques[1] 284 | obj_fea_q = ques[2].to(self.device) 285 | dep_adj_q = ques[3].to(self.device) 286 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q) 287 | 288 | prediction = prediction.data.cpu().numpy() 289 | ans_id = ans_id.numpy() 290 | for qid, pred, ans in zip(qns_key, prediction, ans_id): 291 | results[qid] = {'prediction': int(pred), 'answer': int(ans)} 292 | 293 | print(len(results)) 294 | print(result_file) 295 | save_file(results, result_file) 296 | --------------------------------------------------------------------------------
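The `WeightDrop*` wrappers in `networks/torchnlp_nn.py` regularize a layer by re-sampling a DropConnect mask over its hidden-to-hidden (or weight) matrix on every forward call. Below is a minimal usage sketch of the wrappers defined above; the layer sizes, dropout rates and tensor shapes are arbitrary illustration values, not settings taken from this repository.

```python
import torch

from networks.torchnlp_nn import WeightDropGRU, WeightDropLinear

torch.manual_seed(0)

# Single-layer GRU whose hidden-to-hidden matrix (weight_hh_l0) is re-dropped on every call.
gru = WeightDropGRU(input_size=256, hidden_size=512, num_layers=1,
                    batch_first=True, weight_dropout=0.3)
x = torch.randn(4, 20, 256)      # (batch, seq_len, feature)
out, h = gru(x)                  # out: (4, 20, 512), h: (1, 4, 512)

# The same mechanism applied to the weight matrix of a Linear layer.
fc = WeightDropLinear(512, 128, weight_dropout=0.2)
y = fc(out[:, -1])               # (4, 128)

# In eval mode the dropout becomes a no-op and the raw weights are used unchanged.
gru.eval()
out_eval, _ = gru(x)
```

Because the mask is re-sampled inside `forward`, a wrapped layer can stand in for the plain `torch.nn` module without changing the surrounding training loop.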
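`videoqa.py` drives the whole pipeline through the `VideoQA` class: `build_model()` instantiates one of the supported architectures (EVQA/BlindQA, CoMem, HME, HGA, HCRN, B2A), `run()` alternates training, validation and testing per epoch, and `predict()` dumps per-question results. The sketch below only illustrates the wiring; the vocabulary and data loaders are assumed to be built elsewhere (e.g., by the `dataset` package and `main_qa.py`, not shown here), every concrete value is a placeholder, and each batch must unpack as `(visual, can, ques, ans_id, qns_key)` exactly as in `VideoQA.train()`.

```python
from argparse import Namespace

from utils import make_logger
from videoqa import VideoQA

# Hyper-parameters read inside VideoQA.build_model(); the values below are placeholders only.
args = Namespace(vid_dim=4096, hidden_dim=256, word_dim=300,
                 max_vid_len=16, max_vid_frame_len=16, max_qa_len=37, spl_resolution=1)

logger = make_logger('train_example.log')

# Placeholders: the real vocabulary and loaders come from the dataset package (not shown here).
vocab = ['<pad>', '<unk>']
train_loader = val_loader = test_loader = None

qa = VideoQA(vocab, train_loader, val_loader, test_loader,
             glove_embed='data/glove_embed.npy',    # placeholder path
             use_bert=False, checkpoint_path='models', model_type='HGA',
             model_prefix='glove', vis_step=100, lr_rate=1e-4,
             batch_size=64, epoch_num=30, logger=logger, args=args)

# With real loaders in place:
# qa.run(model_file='', pre_trained=False)                 # train, validate and test each epoch
# qa.predict('models/HGA/glove/model/best.ckpt',
#            'results/HGA-glove-test.json', test_loader)   # write per-question predictions
```

Checkpoints are written under `<checkpoint_path>/<model_type>/<model_prefix>/model/`, so the paths in the commented calls follow the layout used by `save_model()`.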