├── .gitignore
├── LICENSE
├── README.md
├── bash
│   ├── eval.sh
│   ├── train_bert.sh
│   └── train_glove.sh
├── dataset
│   ├── __init__.py
│   ├── dataset.py
│   ├── load.py
│   ├── release.py
│   └── util.py
├── eval_mc.py
├── fig
│   └── example.png
├── main_qa.py
├── networks
│   ├── Attention.py
│   ├── CRN.py
│   ├── Embed_loss.py
│   ├── EncoderRNN.py
│   ├── GCN.py
│   ├── Transformer.py
│   ├── VQAModel
│   │   ├── B2A.py
│   │   ├── CoMem.py
│   │   ├── EVQA.py
│   │   ├── HCRN.py
│   │   ├── HGA.py
│   │   └── HME.py
│   ├── memory_module.py
│   ├── memory_rand.py
│   └── torchnlp_nn.py
├── requirement.txt
├── utils.py
└── videoqa.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | __ignore__
3 | data
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 BCMI
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Causal-VidQA
2 |
3 | ## News
4 |
 5 | * [2024.07.11] We have released the answers for the [test set](https://cloud.bcmi.sjtu.edu.cn/sharing/aU4Skr9EJ). You can download them and put them into ```['data/QA']``` to use them.
6 |
7 |
8 | ## Introduction
9 |
10 | The Causal-VidQA dataset contains 107,600 QA pairs over 26,900 video clips ([paper](https://arxiv.org/pdf/2205.14895.pdf)). The dataset aims to facilitate deeper video understanding towards video reasoning. In detail, we present the task of Causal-VidQA, which includes four types of questions ranging from scene description (description) to evidence reasoning (explanation) and commonsense reasoning (prediction and counterfactual). For commonsense reasoning, we set up a two-step solution that requires both answering the question and providing a proper reason.
11 |
12 | Here is an example from our dataset and the comparison between our dataset and other VisualQA datasets.
13 |
14 | ![Example from our Causal-VidQA Dataset](fig/example.png)
15 | Example from our Causal-VidQA Dataset
16 |
17 | | Dataset | Visual Type | Visual Source | Annotation | Description | Explanation | Prediction | Counterfactual | \#Video/Image | \#QA | Video Length (s) |
18 | |:--------------:|:-----------:|:-------------:|:----------:|:-----------:|:-----------:|:----------:|:--------------:|:-------------:|:-------:|:----------------:|
19 | | Motivation | Image | MS COCO | Man | ✔ | ✔ | ✔ | $\times$ | 10,191 | - | - |
20 | | VCR | Image | Movie Clip | Man | ✔ | ✔ | ✔ | $\times$ | 110,000 | 290,000 | - |
21 | | MovieQA | Video | Movie Stories | Auto | ✔ | ✔ | $\times$ | $\times$ | 548 | 21,406 | 200 |
22 | | TVQA | Video | TV Show | Man | ✔ | ✔ | $\times$ | $\times$ | 21,793 | 152,545 | 76 |
23 | | TGIF-QA | Video | TGIF | Auto | ✔ | $\times$ | $\times$ | $\times$ | 71,741 | 165,165 | 3 |
24 | | ActivityNet-QA | Video | ActivityNet | Man | ✔ | ✔ | $\times$ | $\times$ | 5,800 | 58,000 | 180 |
25 | | Social-IQ | Video | YouTube | Man | ✔ | ✔ | $\times$ | $\times$ | 1,250 | 7,500 | 60 |
26 | | CLEVRER | Video | Game Engine | Man | ✔ | ✔ | ✔ | ✔ | 20,000 | 305,280 | 5 |
27 | | V2C | Video | MSR-VTT | Man | ✔ | ✔ | $\times$ | $\times$ | 10,000 | 115,312 | 30 |
28 | | NExT-QA | Video | YFCC-100M | Man | ✔ | ✔ | $\times$ | $\times$ | 5,440 | 52,044 | 44 |
29 | | Causal-VidQA | Video | Kinetics-700 | Man | ✔ | ✔ | ✔ | ✔ | 26,900 | 107,600 | 9 |
30 |
31 | Comparison between our dataset and other VisualQA datasets
32 |
33 | On this page, you can find the code of several SOTA VideoQA methods and the dataset for our **CVPR** conference paper.
34 |
35 | * Jiangtong Li, Li Niu and Liqing Zhang. *From Representation to Reasoning: Towards both Evidence and Commonsense Reasoning for Video Question-Answering*. *CVPR*, 2022. [[paper link]](https://arxiv.org/pdf/2205.14895.pdf)
36 |
37 | ## Download
38 | 1. [Visual Feature](https://cloud.bcmi.sjtu.edu.cn/sharing/ZI1F0Hfd0)
39 | 2. [Text Feature](https://cloud.bcmi.sjtu.edu.cn/sharing/NeiJfafJq)
40 | 3. [Dataset Split](https://cloud.bcmi.sjtu.edu.cn/sharing/6kEtHMarE)
41 | 4. [Text annotation](https://cloud.bcmi.sjtu.edu.cn/sharing/aszEJs8VX)
42 | 5. [Original Data](https://cloud.bcmi.sjtu.edu.cn/sharing/FYDmyDwff)
43 |
44 | ## Install
45 | Please create an environment for this project using miniconda (install [miniconda](https://docs.conda.io/en/latest/miniconda.html) first if you have not):
46 | ```
47 | >conda create -n causal-vidqa python==3.6.12
48 | >conda activate causal-vidqa
49 | >git clone https://github.com/bcmi/Causal-VidQA
50 | >cd Causal-VidQA && pip install -r requirement.txt
51 | ```
52 |
53 | ## Data Preparation
54 | Please download the pre-computed features and QA annotations from [Download 1-4](#download)
55 | and place them in ```['data/visual_feature']```, ```['data/text_feature']```, ```['data/split']``` and ```['data/QA']```. Note that the ```Text annotation``` is packaged as QA.tar; you need to unpack it before placing it into ```['data/QA']```.
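For reference, a possible way to set up the directories and unpack the annotations (a sketch; adjust the paths to wherever you downloaded the archives):
```
>mkdir -p data/visual_feature data/text_feature data/split data/QA
>tar -xf QA.tar -C data/QA
```
After unpacking, ```dataset/dataset.py``` expects each video folder under ```['data/QA']``` to contain a ```text.json``` (the questions and five candidate answers/reasons per question type) and an ```answer.json``` (the indices of the correct answer/reason), as read by its ```load_text``` function.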
56 |
57 | If you want to extract different video features and text features from our Causal-VidQA dataset, you can download the original data from [Download 5](#download) and do whatever you want to extract features.
58 |
59 | ## Usage
60 | Once the data is ready, you can easily run the code. First, to run these models with the GloVe feature, you can directly train B2A by:
61 | ```
62 | >sh bash/train_glove.sh
63 | ```
64 | Note that if you want to train the model with the BERT feature, we suggest you first load the BERT feature into shared memory (via SharedArray) by:
65 | ```
66 | >python dataset/load.py
67 | ```
68 | and then train B2A with the BERT feature by:
69 | ```
70 | >sh bash/train_bert.sh
71 | ```
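The shared-memory arrays created by ```dataset/load.py``` persist after training; once you are done, you can free them with the companion script:
```
>python dataset/release.py
```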
72 | After the training script finishes, you can find the prediction file at the path given by ```--result_file``` (```['./result/model_type/model_prefix_split.json']``` with the provided scripts) and evaluate the prediction results by:
73 | ```
74 | >python eval_mc.py
75 | ```
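```eval_mc.py``` expects the path to the result json to be set in its ```__main__``` block; you can also call it from Python. A minimal sketch, assuming a result file produced with the default ```--result_file``` pattern from the provided bash scripts (adjust the path to your own run):
```
import eval_mc

# path follows the ./result/{model_type}/{model_prefix}_{split}.json pattern used in the bash scripts
result_file = './result/B2A/B2A_glove_test.json'
eval_mc.main(result_file, qtype=-1)  # qtype=-1 evaluates all question types
```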
76 | You can also obtain the prediction by running:
77 | ```
78 | >sh bash/eval.sh
79 | ```
80 | The command above will load the model from ```['experiment/model_name/model_prefix/model/best.pkl']``` and generate the prediction file.
81 |
82 | Hint: we have released a trained [model](https://cloud.bcmi.sjtu.edu.cn/sharing/c5IKQVMrM) for the ```B2A``` method. Please place the trained weights at ```['experiment/B2A/B2A/model/best.pkl']``` and then make predictions by running:
83 | ```
84 | >sh bash/eval.sh
85 | ```
86 |
87 | (*The results may be slightly different depending on the environments and random seeds.*)
88 |
89 | (*For comparison, please refer to the results in our paper.*)
90 |
91 | ## Citation
92 | ```
93 | @InProceedings{li2022from,
94 | author = {Li, Jiangtong and Niu, Li and Zhang, Liqing},
95 | title = {From Representation to Reasoning: Towards both Evidence and Commonsense Reasoning for Video Question-Answering},
96 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
97 | month = {June},
98 | year = {2022}
99 | }
100 | ```
101 | ## Acknowledgement
102 | Our reproduction of the methods is mainly based on [NExT-QA](https://github.com/doc-doc/NExT-QA) and the other respective official repositories; we thank the authors for releasing their code. If you use the related parts, please cite the corresponding papers referenced in the code comments.
103 |
--------------------------------------------------------------------------------
/bash/eval.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_qa.py \
2 | --mode test \
3 | --feature_path ./data/visual_feature/ \
4 | --text_feature_path ./data/text_feature/ \
5 | --data_path ./data/QA/ \
6 | --split_path ./data/split/ \
7 | --checkpoint_path ./experiment \
8 | --model_type B2A \
9 | --model_prefix B2A \
10 | --result_file ./result/{}/{}_{}.json \
11 | --vid_dim 4096 \
12 | --hidden_dim 256 \
13 | --word_dim 300 \
14 | --max_vid_len 16 \
15 | --max_qa_len 40 \
16 | --epoch_num 30 \
17 | --lr_rate 2e-4 \
18 | --batch_size 32
--------------------------------------------------------------------------------
/bash/train_bert.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_qa.py \
2 | --mode train \
3 | --feature_path ./data/visual_feature/ \
4 | --text_feature_path ./data/text_feature/ \
5 | --data_path ./data/QA/ \
6 | --split_path ./data/split/ \
7 | --checkpoint_path ./experiment \
8 | --model_type B2A \
9 | --model_prefix B2A_bert \
10 | --result_file ./result/{}/{}_{}.json \
11 | --vid_dim 4096 \
12 | --hidden_dim 128 \
13 | --word_dim 300 \
14 | --max_vid_len 16 \
15 | --max_qa_len 40 \
16 | --epoch_num 30 \
17 | --lr_rate 2e-4 \
18 | --batch_size 128 \
19 | --use_bert
--------------------------------------------------------------------------------
/bash/train_glove.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python main_qa.py \
2 | --mode train \
3 | --feature_path ./data/visual_feature/ \
4 | --text_feature_path ./data/text_feature/ \
5 | --data_path ./data/QA/ \
6 | --split_path ./data/split/ \
7 | --checkpoint_path ./experiment \
8 | --model_type B2A \
9 | --model_prefix B2A_glove \
10 | --result_file ./result/{}/{}_{}.json \
11 | --vid_dim 4096 \
12 | --hidden_dim 128 \
13 | --word_dim 300 \
14 | --max_vid_len 16 \
15 | --max_qa_len 40 \
16 | --epoch_num 30 \
17 | --lr_rate 2e-4 \
18 | --batch_size 128
--------------------------------------------------------------------------------
/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | from .dataset import VidQADataset, Vocabulary
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import enum
2 | from numpy import random
3 | import torch
4 | from torch.utils.data import Dataset
5 | import os.path as osp
6 | import numpy as np
7 | import nltk
8 | import h5py
9 | import os
10 | import json
11 | import pickle as pkl
12 | import pandas as pd
13 | from tqdm import tqdm
14 | import re
15 | import stanfordnlp
16 | import SharedArray as sa
17 |
18 | class Vocabulary(object):
19 | """Simple vocabulary wrapper."""
20 | def __init__(self, word2idx, idx2word):
21 | self.word2idx = word2idx
22 | self.idx2word = idx2word
23 | self.idx = len(idx2word)
24 |
25 | def add_word(self, word):
26 | if not word in self.word2idx:
27 | self.word2idx[word] = self.idx
28 | self.idx2word[self.idx] = word
29 | self.idx += 1
30 |
31 | def __call__(self, word):
32 | if not word in self.word2idx:
33 | return self.word2idx['']
34 | return self.word2idx[word]
35 |
36 | def __len__(self):
37 | return len(self.word2idx)
38 |
39 | class VidQADataset(Dataset):
40 | """load the dataset in dataloader"""
41 |
42 | def __init__(self, feature_path, text_feature_path, split_path, data_path, use_bert, vocab, qtype=-1, max_length=40):
43 | self.feature_path = feature_path
44 | self.text_feature_path = text_feature_path
45 | self.split_path = split_path
46 | self.data_path = data_path
47 | self.qtype = qtype
48 |
49 | self.vocab = vocab
50 | self.vids = pkload(self.split_path)
51 |
52 | self.max_length = max_length
53 | self.use_bert = use_bert
54 | if self.use_bert:
55 | self.bert_file = osp.join(text_feature_path, 'text_seq.h5')
56 | self.bert_length = osp.join(text_feature_path, 'text_seq_length.pkl')
57 | self.bert_token = osp.join(text_feature_path, 'token_org.pkl')
58 | with open(self.bert_token, 'rb') as fbt:
59 | self.token_dict = pkl.load(fbt)
60 | with open(self.bert_length, 'rb') as fl:
61 | self.length_dict = pkl.load(fl)
62 |
63 | if self.use_bert:
64 | self.adj_path = osp.join(text_feature_path, 'bert_adj_dict.pkl')
65 | else:
66 | self.adj_path = osp.join(text_feature_path, 'glove_adj_dict.pkl')
67 | with open(self.adj_path, 'rb') as fbt:
68 | self.token_adj = pkl.load(fbt)
69 |
70 | vf_info = pkload(osp.join(feature_path, 'idx2vid.pkl'))
71 | self.vf_info = dict()
72 | for idx, vid in enumerate(vf_info):
73 | if vid in self.vids:
74 | self.vf_info[vid] = idx
75 | app_file = osp.join(feature_path, 'appearance_feat.h5')
76 | mot_file = osp.join(feature_path, 'motion_feat.h5')
77 | print('Load {}...'.format(app_file))
78 | self.app_feats = dict()
79 | with h5py.File(app_file, 'r') as fp:
80 | feats = fp['resnet_features']
81 | for vid, idx in self.vf_info.items():
82 | self.app_feats[vid] = feats[idx][...]
83 | print('Load {}...'.format(mot_file))
84 | self.mot_feats = dict()
85 | with h5py.File(mot_file, 'r') as fp:
86 | feats = fp['resnet_features']
87 | for vid, idx in self.vf_info.items():
88 | self.mot_feats[vid] = feats[idx][...]
89 |
90 | self.txt_obj = dict()
91 | print('Load {}...'.format(osp.join(self.feature_path, 'ROI_text.h5')))
92 | with h5py.File(osp.join(self.feature_path, 'ROI_text.h5'), 'r') as f:
93 | keys = [item for item in self.vids if item in f.keys()]
94 | for key in keys:
95 | tmp = dict()
96 | labels = f[key].keys()
97 | for label in labels:
98 | new_label = '[' + label + ']'
99 | tmp[new_label] = f[key][label][...]
100 | self.txt_obj[key] = tmp
101 |
102 | def __len__(self):
103 | if self.qtype == -1:
104 | return len(self.vids)*6
105 | elif self.qtype == 0 or self.qtype == 1:
106 | return len(self.vids)
107 | elif self.qtype == 2 or self.qtype == 3:
108 | return len(self.vids)*2
109 |
110 | def get_video_feature(self, video_name):
111 | """
112 | :param video_name:
113 | :return:
114 | """
115 | app_feat = self.app_feats[video_name]
116 | mot_feat = self.mot_feats[video_name]
117 |
118 | return torch.from_numpy(app_feat).type(torch.float32), torch.from_numpy(mot_feat).type(torch.float32)
119 |
120 | def get_word_idx(self, text):
121 | """
122 | """
123 | tokens = nltk.tokenize.word_tokenize(str(text).lower())
124 | token_ids = [self.vocab(token) for i, token in enumerate(tokens) if i < (self.max_length - 2)]
125 |
126 | return token_ids
127 |
128 | def get_token_seq(self, text):
129 | """
130 | """
131 | tokens = nltk.tokenize.word_tokenize(str(text).lower())
132 | return tokens
133 |
134 | def get_adj(self, vidx, qtype):
135 | adj_vidx = self.token_adj[vidx]
136 | qas_adj = adj_vidx[6+qtype*5:11+qtype*5]
137 | ques_adj = adj_vidx[qtype]
138 | qas_adj_new = np.zeros((len(qas_adj), self.max_length, self.max_length))
139 | ques_adj_new = np.zeros((self.max_length, self.max_length))
140 | for idx, item in enumerate(qas_adj):
141 | if item.shape[0] > self.max_length:
142 | qas_adj_new[idx] = item[:self.max_length, :self.max_length]
143 | else:
144 | qas_adj_new[idx, :item.shape[0], :item.shape[1]] = item
145 | if ques_adj.shape[0] > self.max_length:
146 | ques_adj_new = ques_adj[:self.max_length, :self.max_length]
147 | else:
148 | ques_adj_new[:ques_adj.shape[0], :ques_adj.shape[1]] = ques_adj
149 | return qas_adj_new, ques_adj_new
150 |
151 | def get_trans_matrix(self, candidates):
152 |
153 | qa_lengths = [len(qa) for qa in candidates]
154 | candidates_matrix = torch.zeros([5, self.max_length]).long()
155 | for k in range(5):
156 | sentence = candidates[k]
157 | length = qa_lengths[k]
158 | if length > self.max_length:
159 | length = self.max_length
160 | candidates_matrix[k] = torch.Tensor(sentence[:length])
161 | else:
162 | candidates_matrix[k, :length] = torch.Tensor(sentence)
163 |
164 | return candidates_matrix, qa_lengths
165 |
166 | def get_ques_matrix(self, ques):
167 |
168 | q_lengths = len(ques)
169 | ques_matrix = torch.zeros([self.max_length]).long()
170 | ques_matrix[:q_lengths] = torch.Tensor(ques)
171 |
172 | return ques_matrix, q_lengths
173 |
174 | def get_tagname(self, line):
175 | tag = set()
176 | tmp_tag = re.findall(r"\[(.+?)\]", line)
177 | for item in tmp_tag:
178 | tag.add('['+item+']')
179 | return list(tag)
180 |
181 | def match_tok_tag(self, labels, tags, tok):
182 | tok_tag = [None for _ in range(len(tok))]
183 | if labels == list():
184 | return tok_tag
185 | for tag in tags:
186 | for idx in range(len(tok)):
187 | if tag.startswith(tok[idx]):
188 | new_idx = idx
189 | while not tag.endswith(tok[new_idx]):
190 | new_idx += 1
191 | new_tag = ''.join(tok[idx:new_idx+1])
192 | if new_tag == tag:
193 | for i in range(idx, new_idx+1):
194 | tok_tag[i] = tag
195 | if tag not in labels:
196 | label = random.choice(labels)
197 | for index, item in enumerate(tok_tag):
198 | if item == tag:
199 | tok_tag[index] = label
200 | else:
201 | pass
202 | return tok_tag
203 |
204 | def load_txt_obj(self, vid, tok, org):
205 | if vid in self.txt_obj:
206 | labels = list(self.txt_obj[vid].keys())
207 | else:
208 | labels = list()
209 | fea = list()
210 | for idx in range(len(tok)):
211 | tags = self.get_tagname(org[idx])
212 | tok_tag = self.match_tok_tag(labels, tags, tok[idx])
213 | fea_each = list()
214 | for item in tok_tag:
215 | if item is None:
216 | fea_each.append(np.zeros((2048,)))
217 | else:
218 | fea_each.append(self.txt_obj[vid][item])
219 | fea_each = np.stack(fea_each, axis=0)
220 | new_fea_each = np.zeros((self.max_length, 2048))
221 | if fea_each.shape[0] > self.max_length:
222 | new_fea_each = fea_each[:self.max_length]
223 | else:
224 | new_fea_each[:fea_each.shape[0]] = fea_each
225 |
226 | fea.append(new_fea_each)
227 | return fea
228 |
229 | def load_text(self, vid, qtype):
230 | text_file = os.path.join(self.data_path, vid, 'text.json')
231 | answer_file = os.path.join(self.data_path, vid, 'answer.json')
232 | with open(text_file, 'r') as fin:
233 | text = json.load(fin)
234 | with open(answer_file, 'r') as fin:
235 | answer = json.load(fin)
236 | if qtype == 0:
237 | qns = text['descriptive']['question']
238 | cand_ans = text['descriptive']['answer']
239 | ans_id = answer['descriptive']['answer']
240 | if qtype == 1:
241 | qns = text['explanatory']['question']
242 | cand_ans = text['explanatory']['answer']
243 | ans_id = answer['explanatory']['answer']
244 | if qtype == 2:
245 | qns = text['predictive']['question']
246 | cand_ans = text['predictive']['answer']
247 | ans_id = answer['predictive']['answer']
248 | if qtype == 3:
249 | qns = text['predictive']['question']
250 | cand_ans = text['predictive']['reason']
251 | ans_id = answer['predictive']['reason']
252 | if qtype == 4:
253 | qns = text['counterfactual']['question']
254 | cand_ans = text['counterfactual']['answer']
255 | ans_id = answer['counterfactual']['answer']
256 | if qtype == 5:
257 | qns = text['counterfactual']['question']
258 | cand_ans = text['counterfactual']['reason']
259 | ans_id = answer['counterfactual']['reason']
260 | return qns, cand_ans, ans_id
261 |
262 | def load_text_bert(self, vid, qtype):
263 | with h5py.File(self.bert_file, 'r') as fp:
264 | feature = sa.attach("shm://{}".format(vid))
265 | token_org = self.token_dict[vid]
266 | length = self.length_dict[vid]
267 | cand = feature[6+qtype*5:11+qtype*5]
268 | tok = token_org[0][6+qtype*5:11+qtype*5]
269 | org = token_org[1][6+qtype*5:11+qtype*5]
270 | cand_l = length[6+qtype*5:11+qtype*5]
271 | question = feature[qtype]
272 | tok_q = [token_org[0][qtype], ]
273 | org_q = [token_org[1][qtype], ]
274 | qns_len = length[qtype]
275 | dim = cand.shape[2]
276 | new_candidate = np.zeros((5, self.max_length, dim))
277 | new_question = np.zeros((self.max_length, dim))
278 | for idx, qa_l in enumerate(cand_l):
279 | if qa_l > self.max_length:
280 | new_candidate[idx] = cand[idx, :self.max_length]
281 | else:
282 | new_candidate[idx, :qa_l] = cand[idx, :qa_l]
283 | if qns_len > self.max_length:
284 | new_question = question[:self.max_length]
285 | else:
286 | new_question[:qns_len] = question[:qns_len]
287 | return torch.from_numpy(new_candidate).type(torch.float32), tok, org, cand_l, torch.from_numpy(new_question).type(torch.float32), tok_q, org_q, qns_len
288 |
289 | def __getitem__(self, idx):
290 | """
291 | """
292 | if self.qtype == -1:
293 | qtype = idx % 6
294 | idx = idx // 6
295 | elif self.qtype == 0 or self.qtype == 1:
296 | qtype = self.qtype
297 | elif self.qtype == 2:
298 | qtype = 2 + (idx % 2)
299 | idx = idx // 2
300 | elif self.qtype == 3:
301 | qtype = 4 + (idx % 2)
302 | idx = idx // 2
303 | vidx = self.vids[idx]
304 | # load text
305 | qns, cand_ans, ans_id = self.load_text(vidx, qtype)
306 | if self.use_bert:
307 | candidate, tok, org, can_lengths, question, tok_q, org_q, qns_len = self.load_text_bert(vidx, qtype)
308 | else:
309 | tok_q = [['',] + self.get_token_seq(qns) + ['', ], ]
310 | org_q = [qns, ]
311 | question, qns_len = self.get_ques_matrix([self.vocab(''), ] + self.get_word_idx(qns) + [self.vocab(''), ])
312 |
313 | tok = []
314 | org = []
315 | candidate = []
316 | qnstok = ['',] + self.get_token_seq(qns) + ['', ]
317 | qnsids = [self.vocab(''), ] + self.get_word_idx(qns) + [self.vocab(''), ]
318 | for ans in cand_ans:
319 | anstok = ['', ] + self.get_token_seq(ans) + ['', ]
320 | ansids = [self.vocab(''), ] + self.get_word_idx(ans) + [self.vocab(''), ]
321 | tok.append(qnstok+anstok)
322 | org.append(qns+ans)
323 | candidate.append(qnsids + ansids)
324 | candidate, can_lengths = self.get_trans_matrix(candidate)
325 | can_lengths = torch.tensor(can_lengths).clamp(max=self.max_length)
326 | qns_len = torch.tensor(qns_len).clamp(max=self.max_length)
327 | # load object feature
328 | obj_feature = torch.from_numpy(np.stack(self.load_txt_obj(vidx, tok, org), axis=0)).type(torch.float32)
329 | obj_feature_q = torch.from_numpy(self.load_txt_obj(vidx, tok_q, org_q)[0]).type(torch.float32)
330 | # load dependency relation
331 | adj_qas, adj_ques = self.get_adj(vidx, qtype)
332 | adj_ques = torch.from_numpy(adj_ques).type(torch.float32)
333 | adj_qas = torch.from_numpy(np.stack(adj_qas, axis=0)).type(torch.float32)
334 | # load video feature
335 | app_feature, mot_feature = self.get_video_feature(vidx)
336 | qns_key = vidx + '_' + str(qtype)
337 |
338 | return [app_feature, mot_feature], [candidate, can_lengths, obj_feature, adj_qas], [question, qns_len, obj_feature_q, adj_ques], torch.tensor(ans_id), qns_key
339 |
340 | def nozero_row(A):
341 | i = 0
342 | for row in A:
343 | if row.sum()==0:
344 | break
345 | i += 1
346 |
347 | return i
348 |
349 | def load_file(file_name):
350 | annos = None
351 | if osp.splitext(file_name)[-1] == '.csv':
352 | return pd.read_csv(file_name)
353 | with open(file_name, 'r') as fp:
354 | if osp.splitext(file_name)[1]== '.txt':
355 | annos = fp.readlines()
356 | annos = [line.rstrip() for line in annos]
357 | if osp.splitext(file_name)[1] == '.json':
358 | annos = json.load(fp)
359 | return annos
360 |
361 | def save_file(obj, filename):
362 | """
363 | save obj to filename
364 | :param obj:
365 | :param filename:
366 | :return:
367 | """
368 |     filepath = osp.dirname(filename)
369 |     # create the parent directory if needed, then always write the file
370 |     if filepath != '' and not osp.exists(filepath):
371 |         os.makedirs(filepath)
372 |     with open(filename, 'w') as fp:
373 |         json.dump(obj, fp, indent=4)
374 |
375 | def pkload(file):
376 | data = None
377 | if osp.exists(file) and osp.getsize(file) > 0:
378 | with open(file, 'rb') as fp:
379 | data = pkl.load(fp)
380 | return data
381 |
382 | def pkdump(data, file):
383 | dirname = osp.dirname(file)
384 | if not osp.exists(dirname):
385 | os.makedirs(dirname)
386 | with open(file, 'wb') as fp:
387 | pkl.dump(data, fp)
--------------------------------------------------------------------------------
/dataset/load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import SharedArray as sa
3 | import h5py
4 | from tqdm import tqdm
5 |
6 | bert_file = './data/text_feature/text_seq.h5'
7 | with h5py.File(bert_file, 'r') as fp:
8 | for key in tqdm(fp.keys()):
9 | tmp = sa.create("shm://{}".format(key), fp[key].shape, 'float32')
10 | tmp[:] = fp[key][...]
11 |
--------------------------------------------------------------------------------
/dataset/release.py:
--------------------------------------------------------------------------------
1 | import os
2 | import SharedArray as sa
3 | import h5py
4 | from tqdm import tqdm
5 |
6 | bert_file = './data/text_feature/text_seq.h5'
7 | with h5py.File(bert_file, 'r') as fp:
8 | for key in tqdm(fp.keys()):
9 | sa.delete("shm://{}".format(key))
--------------------------------------------------------------------------------
/dataset/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import pickle as pkl
6 | import pandas as pd
7 |
8 | def load_file(file_name):
9 | annos = None
10 | if osp.splitext(file_name)[-1] == '.csv':
11 | return pd.read_csv(file_name)
12 | with open(file_name, 'r') as fp:
13 | if osp.splitext(file_name)[1]== '.txt':
14 | annos = fp.readlines()
15 | annos = [line.rstrip() for line in annos]
16 | if osp.splitext(file_name)[1] == '.json':
17 | annos = json.load(fp)
18 |
19 | return annos
20 |
21 | def save_file(obj, filename):
22 | """
23 | save obj to filename
24 | :param obj:
25 | :param filename:
26 | :return:
27 | """
28 |     filepath = osp.dirname(filename)
29 |     # create the parent directory if needed, then always write the file
30 |     if filepath != '' and not osp.exists(filepath):
31 |         os.makedirs(filepath)
32 |     with open(filename, 'w') as fp:
33 |         json.dump(obj, fp, indent=4)
34 |
35 | def pkload(file):
36 | data = None
37 | if osp.exists(file) and osp.getsize(file) > 0:
38 | with open(file, 'rb') as fp:
39 | data = pkl.load(fp)
40 | # print('{} does not exist'.format(file))
41 | return data
42 |
43 |
44 | def pkdump(data, file):
45 | dirname = osp.dirname(file)
46 | if not osp.exists(dirname):
47 | os.makedirs(dirname)
48 | with open(file, 'wb') as fp:
49 | pkl.dump(data, fp)
50 |
51 |
52 |
--------------------------------------------------------------------------------
/eval_mc.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from utils import load_file
3 |
 4 | map_name = {'D': 'Description', 'E': 'Explanation', 'PA': 'Predictive-Answer', 'CA': 'Counterfactual-Answer', 'PR': 'Predictive-Reason', 'CR': 'Counterfactual-Reason', 'P': 'Predictive', 'C': 'Counterfactual'}
5 |
6 | def accuracy_metric(result_file, qtype):
7 | if qtype == -1:
8 | accuracy_metric_all(result_file)
9 | if qtype == 0:
10 | accuracy_metric_q0(result_file)
11 | if qtype == 1:
12 | accuracy_metric_q1(result_file)
13 | if qtype == 2:
14 | accuracy_metric_q2(result_file)
15 | if qtype == 3:
16 | accuracy_metric_q3(result_file)
17 |
18 | def accuracy_metric_q0(result_file):
19 | preds = list(load_file(result_file).items())
20 | group_acc = {'D': 0}
21 | group_cnt = {'D': 0}
22 | all_acc = 0
23 | all_cnt = 0
24 | for idx in range(len(preds)):
25 | id_qtypes = preds[idx]
26 | answer = id_qtypes[1]['answer']
27 | pred = id_qtypes[1]['prediction']
28 | group_cnt['D'] += 1
29 | all_cnt += 1
30 | if answer == pred:
31 | group_acc['D'] += 1
32 | all_acc += 1
33 | for qtype, acc in group_acc.items(): #
34 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype]))
35 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt))
36 |
37 | def accuracy_metric_q1(result_file):
38 | preds = list(load_file(result_file).items())
39 | group_acc = {'E': 0}
40 | group_cnt = {'E': 0}
41 | all_acc = 0
42 | all_cnt = 0
43 | for idx in range(len(preds)):
44 | id_qtypes = preds[idx]
45 | answer = id_qtypes[1]['answer']
46 | pred = id_qtypes[1]['prediction']
47 | group_cnt['E'] += 1
48 | all_cnt += 1
49 | if answer == pred:
50 | group_acc['E'] += 1
51 | all_acc += 1
52 | for qtype, acc in group_acc.items(): #
53 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype]))
54 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt))
55 |
56 | def accuracy_metric_q2(result_file):
57 | preds = list(load_file(result_file).items())
58 | qtype2short = ['PA', 'PR', 'P']
59 | group_acc = {'PA': 0, 'PR': 0, 'P': 0}
60 | group_cnt = {'PA': 0, 'PR': 0, 'P': 0}
61 | all_acc = 0
62 | all_cnt = 0
63 | for idx in range(len(preds)//2):
64 | id_qtypes = preds[idx*2:(idx+1)*2]
65 | qtypes = [0, 1]
66 | answer = [ans_pre[1]['answer'] for ans_pre in id_qtypes]
67 | pred = [ans_pre[1]['prediction'] for ans_pre in id_qtypes]
68 | for i in range(2):
69 | group_cnt[qtype2short[qtypes[i]]] += 1
70 | if answer[i] == pred[i]:
71 | group_acc[qtype2short[qtypes[i]]] += 1
72 | group_cnt['P'] += 1
73 | all_cnt += 1
74 | if answer[0] == pred[0] and answer[1] == pred[1]:
75 | group_acc['P'] += 1
76 | all_acc += 1
77 | for qtype, acc in group_acc.items(): #
78 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype]))
79 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt))
80 |
81 | def accuracy_metric_q3(result_file):
82 | preds = list(load_file(result_file).items())
83 | qtype2short = ['CA', 'CR', 'C']
84 | group_acc = {'CA': 0, 'CR': 0, 'C': 0}
85 | group_cnt = {'CA': 0, 'CR': 0, 'C': 0}
86 | all_acc = 0
87 | all_cnt = 0
88 | for idx in range(len(preds)//2):
89 | id_qtypes = preds[idx*2:(idx+1)*2]
90 | qtypes = [0, 1]
91 | answer = [ans_pre[1]['answer'] for ans_pre in id_qtypes]
92 | pred = [ans_pre[1]['prediction'] for ans_pre in id_qtypes]
93 | for i in range(2):
94 | group_cnt[qtype2short[qtypes[i]]] += 1
95 | if answer[i] == pred[i]:
96 | group_acc[qtype2short[qtypes[i]]] += 1
97 | group_cnt['C'] += 1
98 | all_cnt += 1
99 | if answer[0] == pred[0] and answer[1] == pred[1]:
100 | group_acc['C'] += 1
101 | all_acc += 1
102 | for qtype, acc in group_acc.items(): #
103 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype]))
104 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt))
105 |
106 | def accuracy_metric_all(result_file):
107 | preds = list(load_file(result_file).items())
108 | qtype2short = ['D', 'E', 'PA', 'PR', 'CA', 'CR', 'P', 'C']
109 | group_acc = {'D': 0, 'E': 0, 'PA': 0, 'PR': 0, 'CA': 0, 'CR': 0, 'P': 0, 'C': 0}
110 | group_cnt = {'D': 0, 'E': 0, 'PA': 0, 'PR': 0, 'CA': 0, 'CR': 0, 'P': 0, 'C': 0}
111 | all_acc = 0
112 | all_cnt = 0
113 | for idx in range(len(preds)//6):
114 | id_qtypes = preds[idx*6:(idx+1)*6]
115 | qtypes = [int(id_qtype[0].split('_')[-1]) for id_qtype in id_qtypes]
116 | answer = [ans_pre[1]['answer'] for ans_pre in id_qtypes]
117 | pred = [ans_pre[1]['prediction'] for ans_pre in id_qtypes]
118 | for i in range(6):
119 | group_cnt[qtype2short[qtypes[i]]] += 1
120 | if answer[i] == pred[i]:
121 | group_acc[qtype2short[qtypes[i]]] += 1
122 | group_cnt['C'] += 1
123 | group_cnt['P'] += 1
124 | all_cnt += 4
125 | if answer[0] == pred[0]:
126 | all_acc += 1
127 | if answer[1] == pred[1]:
128 | all_acc += 1
129 | if answer[2] == pred[2] and answer[3] == pred[3]:
130 | group_acc['P'] += 1
131 | all_acc += 1
132 | if answer[4] == pred[4] and answer[5] == pred[5]:
133 | group_acc['C'] += 1
134 | all_acc += 1
135 | for qtype, acc in group_acc.items(): #
136 | print('{0:21} ==> {1:6.2f}%'.format(map_name[qtype], acc*100.0/group_cnt[qtype]))
137 | print('{0:21} ==> {1:6.2f}%'.format('Acc', all_acc*100.0/all_cnt))
138 |
139 | def main(result_file, qtype=-1):
140 | print('Evaluating {}'.format(result_file))
141 |
142 | accuracy_metric(result_file, qtype)
143 |
144 |
145 | if __name__ == "__main__":
146 | result_file = 'path to results json'
147 | main(result_file, -1)
148 |
--------------------------------------------------------------------------------
/fig/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/bcmi/Causal-VidQA/d2c05a1bd306dc9e17f619a5c47f4731b81a3031/fig/example.png
--------------------------------------------------------------------------------
/main_qa.py:
--------------------------------------------------------------------------------
1 | from videoqa import *
2 | from dataset import VidQADataset, Vocabulary
3 | from torch.utils.data import Dataset, DataLoader
4 | from utils import *
5 | import argparse
6 | import eval_mc
7 | torch.multiprocessing.set_sharing_strategy('file_system')
8 |
9 | def main(args):
10 |
11 | mode = args.mode
12 | if mode == 'train':
13 | batch_size = args.batch_size
14 | num_worker = 8
15 | else:
16 | batch_size = 32
17 | num_worker = 8
18 |
19 | feature_path = args.feature_path
20 | text_feature_path = args.text_feature_path
21 | data_path = args.data_path
22 | train_split_path = osp.join(args.split_path, 'train.pkl')
23 | valid_split_path = osp.join(args.split_path, 'valid.pkl')
24 | test_split_path = osp.join(args.split_path, 'test.pkl')
25 | qtype=args.qtype
26 | max_qa_len = args.max_qa_len
27 |
28 | vocab = pkload(osp.join(text_feature_path, 'qa_vocab.pkl'))
29 |
30 | glove_embed = osp.join(text_feature_path, 'glove.840B.300d.npy')
31 | use_bert = args.use_bert
32 | checkpoint_path = args.checkpoint_path
33 | model_type = args.model_type
34 | model_prefix= args.model_prefix
35 |
36 | vis_step = args.vis_step
37 | lr_rate = args.lr_rate
38 | epoch_num = args.epoch_num
39 |
40 | if not osp.exists(osp.join(checkpoint_path, model_type, model_prefix)):
41 | os.makedirs(osp.join(checkpoint_path, model_type, model_prefix))
42 | if not osp.exists(osp.join(checkpoint_path, model_type, model_prefix, 'model')):
43 | os.makedirs(osp.join(checkpoint_path, model_type, model_prefix, 'model'))
44 | logger = make_logger(osp.join(checkpoint_path, model_type, model_prefix, 'log'))
45 |
46 | train_set = VidQADataset(feature_path=feature_path, text_feature_path=text_feature_path, split_path=train_split_path, data_path=data_path, use_bert=use_bert, vocab=vocab, qtype=qtype, max_length=max_qa_len)
47 | valid_set = VidQADataset(feature_path=feature_path, text_feature_path=text_feature_path, split_path=valid_split_path, data_path=data_path, use_bert=use_bert, vocab=vocab, qtype=qtype, max_length=max_qa_len)
48 | test_set = VidQADataset(feature_path=feature_path, text_feature_path=text_feature_path, split_path=test_split_path, data_path=data_path, use_bert=use_bert, vocab=vocab, qtype=qtype, max_length=max_qa_len)
49 |
50 | train_loader = DataLoader(
51 | dataset=train_set,
52 | batch_size=batch_size,
53 | shuffle=True,
54 | num_workers=num_worker)
55 |
56 | valid_loader = DataLoader(
57 | dataset=valid_set,
58 | batch_size=batch_size,
59 | shuffle=False,
60 | num_workers=num_worker)
61 |
62 | test_loader = DataLoader(
63 | dataset=test_set,
64 | batch_size=batch_size,
65 | shuffle=False,
66 | num_workers=num_worker)
67 |
68 | vqa = VideoQA(vocab, train_loader, valid_loader, test_loader, glove_embed, use_bert, checkpoint_path, model_type, model_prefix,
69 | vis_step, lr_rate, batch_size, epoch_num, logger, args)
70 |
71 | if mode != 'train':
72 | model_file = osp.join(args.checkpoint_path, model_type, model_prefix, 'model', 'best.ckpt')
73 | result_file1 = args.result_file.format(model_type, model_prefix, 'valid')
74 | result_file2 = args.result_file.format(model_type, model_prefix, 'test')
75 | vqa.predict(model_file, result_file1, vqa.val_loader)
76 | vqa.predict(model_file, result_file2, vqa.test_loader)
77 | print('Validation set')
78 | eval_mc.main(result_file1, qtype=args.qtype)
79 | print('Test set')
80 | eval_mc.main(result_file2, qtype=args.qtype)
81 | else:
82 | model_file = osp.join(model_type, model_prefix, 'model', '0-00.00.ckpt')
83 | vqa.run(model_file, pre_trained=False)
84 | model_file = osp.join(args.checkpoint_path, model_type, model_prefix, 'model', 'best.ckpt')
85 | result_file1 = args.result_file.format(model_type, model_prefix, 'valid')
86 | result_file2 = args.result_file.format(model_type, model_prefix, 'test')
87 | vqa.predict(model_file, result_file1, vqa.val_loader)
88 | vqa.predict(model_file, result_file2, vqa.test_loader)
89 | print('Validation set')
90 | eval_mc.main(result_file1, qtype=args.qtype)
91 | print('Test set')
92 | eval_mc.main(result_file2, qtype=args.qtype)
93 |
94 | if __name__ == "__main__":
95 | torch.backends.cudnn.enabled = False
96 | torch.manual_seed(666)
97 | torch.cuda.manual_seed(666)
98 | torch.backends.cudnn.benchmark = True
99 |
100 | parser = argparse.ArgumentParser()
101 | parser.add_argument('--gpu', type=int, default=0,
102 | help='gpu device id')
103 | parser.add_argument('--mode', type=str, default='train',
104 | help='train or val')
105 | parser.add_argument('--feature_path', type=str, default='',
106 | help='path to load visual feature')
107 | parser.add_argument('--text_feature_path', type=str, default='',
108 | help='path to load text feature')
109 | parser.add_argument('--data_path', type=str, default='',
110 | help='path to load original data')
111 | parser.add_argument('--split_path', type=str, default='',
112 | help='path for train/valid/test split')
113 | parser.add_argument('--use_bert', action='store_true',
114 | help='whether use bert embedding')
115 | parser.add_argument('--checkpoint_path', type=str, default='',
116 | help='path to save training model and log')
117 | parser.add_argument('--model_type', type=str, default='HGA',
118 | help='(B2A, EVQA, CoMem, HME, HGA, HCRN)')
119 | parser.add_argument('--model_prefix', type=str, default='debug',
120 | help='detail model info')
121 | parser.add_argument('--result_file', type=str, default='',
122 | help='where to save processed results')
123 |
124 | parser.add_argument('--vid_dim', type=int, default=4096,
125 | help='number of dim for video features')
126 | parser.add_argument('--hidden_dim', type=int, default=256,
127 | help='number of dim for hidden feature')
128 | parser.add_argument('--word_dim', type=int, default=300,
129 | help='number of dim for word feature')
130 | parser.add_argument('--max_vid_len', type=int, default=8,
131 | help='number of max length for video clips')
132 | parser.add_argument('--max_vid_frame_len', type=int, default=16,
133 | help='number of max length for frames in each video clip')
134 | parser.add_argument('--max_qa_len', type=int, default=40,
135 | help='number of max length for question and answer')
136 | parser.add_argument('--vis_step', type=int, default=100,
137 | help='number of step to print the training info')
138 | parser.add_argument('--epoch_num', type=int, default=30,
139 | help='number of epoch to train model')
140 | parser.add_argument('--lr_rate', type=float, default=1e-4,
141 | help='learning rate')
142 | parser.add_argument('--qtype', type=int, default=-1,
143 | help='question type in VVCR dataset')
144 | parser.add_argument('--batch_size', type=int, default=128,
145 | help='batch size')
146 | parser.add_argument('--gcn_layer', type=int, default=1,
147 | help='gcn layer')
148 | parser.add_argument('--spl_resolution', type=int, default=16,
149 | help='spl_resolution')
150 | args = parser.parse_args()
151 |
152 | main(args)
--------------------------------------------------------------------------------
/networks/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class TempAttention(nn.Module):
7 | """
8 | Applies an attention mechanism on the output features from the decoder.
9 | """
10 |
11 | def __init__(self, text_dim, visual_dim, hidden_dim):
12 | super(TempAttention, self).__init__()
13 | self.hidden_dim = hidden_dim
14 | self.linear_text = nn.Linear(text_dim, hidden_dim)
15 | self.linear_visual = nn.Linear(visual_dim, hidden_dim)
16 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False)
17 | self._init_weight()
18 |
19 | def _init_weight(self):
20 | nn.init.xavier_normal_(self.linear_text.weight)
21 | nn.init.xavier_normal_(self.linear_visual.weight)
22 | nn.init.xavier_normal_(self.linear_att.weight)
23 |
24 | def forward(self, qns_embed, vid_outputs):
25 | """
26 | Arguments:
27 | qns_embed {Variable} -- batch_size x dim
28 | vid_outputs {Variable} -- batch_size x seq_len x dim
29 |
30 | Returns:
31 | context -- context vector of size batch_size x dim
32 | """
33 | qns_embed_trans = self.linear_text(qns_embed)
34 |
35 | batch_size, seq_len, visual_dim = vid_outputs.size()
36 | vid_outputs_temp = vid_outputs.contiguous().view(batch_size*seq_len, visual_dim)
37 | vid_outputs_trans = self.linear_visual(vid_outputs_temp)
38 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.hidden_dim)
39 |
40 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1)
41 |
42 |
43 | o = self.linear_att(torch.tanh(qns_embed_trans+vid_outputs_trans))
44 |
45 | e = o.view(batch_size, seq_len)
46 | beta = F.softmax(e, dim=1)
47 | context = torch.bmm(beta.unsqueeze(1), vid_outputs).squeeze(1)
48 |
49 | return context, beta
50 |
51 |
52 | class SpatialAttention(nn.Module):
53 | """
54 | Apply spatial attention on vid feature before being fed into LSTM
55 | """
56 |
57 | def __init__(self, text_dim=1024, vid_dim=3072, hidden_dim=512, input_dropout_p=0.2):
58 | super(SpatialAttention, self).__init__()
59 |
60 | self.linear_v = nn.Linear(vid_dim, hidden_dim)
61 | self.linear_q = nn.Linear(text_dim, hidden_dim)
62 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False)
63 |
64 | self.softmax = nn.Softmax(dim=1)
65 | self.dropout = nn.Dropout(input_dropout_p)
66 | self._init_weight()
67 |
68 | def _init_weight(self):
69 | nn.init.xavier_normal_(self.linear_v.weight)
70 | nn.init.xavier_normal_(self.linear_q.weight)
71 | nn.init.xavier_normal_(self.linear_att.weight)
72 |
73 | def forward(self, qns_feat, vid_feats):
74 | """
75 | Apply question feature as semantic clue to guide feature aggregation at each frame
76 | :param vid_feats: fnum x feat_dim x 7 x 7
77 | :param qns_feat: dim_hidden*2
78 | :return:
79 | """
80 | # print(qns_feat.size(), vid_feats.size())
81 | # permute to fnum x 7 x 7 x feat_dim
82 | vid_feats = vid_feats.permute(0, 2, 3, 1)
83 | fnum, width, height, feat_dim = vid_feats.size()
84 | vid_feats = vid_feats.contiguous().view(-1, feat_dim)
85 | vid_feats_trans = self.linear_v(vid_feats)
86 |
87 | vid_feats_trans = vid_feats_trans.view(fnum, width*height, -1)
88 | region_num = vid_feats_trans.shape[1]
89 |
90 | qns_feat_trans = self.linear_q(qns_feat)
91 |
92 | qns_feat_trans = qns_feat_trans.repeat(fnum, region_num, 1)
93 | # print(vid_feats_trans.shape, qns_feat_trans.shape)
94 |
95 | vid_qns = self.linear_att(torch.tanh(vid_feats_trans + qns_feat_trans))
96 |
97 | vid_qns_o = vid_qns.view(fnum, region_num)
98 | alpha = self.softmax(vid_qns_o)
99 | alpha = alpha.unsqueeze(1)
100 | vid_feats = vid_feats.view(fnum, region_num, -1)
101 | feature = torch.bmm(alpha, vid_feats).squeeze(1)
102 | feature = self.dropout(feature)
103 | # print(feature.size())
104 | return feature, alpha
105 |
106 |
107 | class TempAttentionHis(nn.Module):
108 | """
109 | Applies an attention mechanism on the output features from the decoder.
110 | """
111 |
112 | def __init__(self, visual_dim, text_dim, his_dim, mem_dim):
113 | super(TempAttentionHis, self).__init__()
114 | # self.dim = dim
115 | self.mem_dim = mem_dim
116 | self.linear_v = nn.Linear(visual_dim, self.mem_dim, bias=False)
117 | self.linear_q = nn.Linear(text_dim, self.mem_dim, bias=False)
118 | self.linear_his1 = nn.Linear(his_dim, self.mem_dim, bias=False)
119 | self.linear_his2 = nn.Linear(his_dim, self.mem_dim, bias=False)
120 | self.linear_att = nn.Linear(self.mem_dim, 1, bias=False)
121 | self._init_weight()
122 |
123 |
124 | def _init_weight(self):
125 | nn.init.xavier_normal_(self.linear_v.weight)
126 | nn.init.xavier_normal_(self.linear_q.weight)
127 | nn.init.xavier_normal_(self.linear_his1.weight)
128 | nn.init.xavier_normal_(self.linear_his2.weight)
129 | nn.init.xavier_normal_(self.linear_att.weight)
130 |
131 |
132 | def forward(self, qns_embed, vid_outputs, his):
133 | """
134 | :param qns_embed: batch_size x 1024
135 | :param vid_outputs: batch_size x seq_num x feat_dim
136 | :param his: batch_size x 512
137 | :return:
138 | """
139 |
140 | batch_size, seq_len, feat_dim = vid_outputs.size()
141 | vid_outputs_trans = self.linear_v(vid_outputs.contiguous().view(batch_size * seq_len, feat_dim))
142 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.mem_dim)
143 |
144 | qns_embed_trans = self.linear_q(qns_embed)
145 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1)
146 |
147 |
148 | his_trans = self.linear_his1(his)
149 | his_trans = his_trans.unsqueeze(1).repeat(1, seq_len, 1)
150 |
151 | o = self.linear_att(torch.tanh(qns_embed_trans + vid_outputs_trans + his_trans))
152 |
153 | e = o.view(batch_size, seq_len)
154 | beta = F.softmax(e, dim=1)
155 | context = torch.bmm(beta.unsqueeze(1), vid_outputs_trans).squeeze(1)
156 |
157 | his_acc = torch.tanh(self.linear_his2(his))
158 |
159 | context += his_acc
160 |
161 | return context, beta
162 |
163 |
164 | class MultiModalAttentionModule(nn.Module):
165 |
166 | def __init__(self, hidden_size=512, simple=False):
167 | """Set the hyper-parameters and build the layers."""
168 | super(MultiModalAttentionModule, self).__init__()
169 |
170 | self.hidden_size = hidden_size
171 | self.simple = simple
172 |
173 | # alignment model
174 | # see appendices A.1.2 of neural machine translation
175 |
176 | self.Wav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
177 | self.Wat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
178 | self.Uav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
179 | self.Uat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
180 | self.Vav = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
181 | self.Vat = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
182 | self.bav = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True)
183 | self.bat = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True)
184 |
185 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
186 | self.Wvh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
187 | self.Wth = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
188 | self.bh = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True)
189 |
190 | self.video_sum_encoder = nn.Linear(hidden_size, hidden_size)
191 | self.question_sum_encoder = nn.Linear(hidden_size, hidden_size)
192 |
193 | self.Wb = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
194 | self.Vbv = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
195 | self.Vbt = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
196 | self.bbv = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
197 | self.bbt = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
198 | self.wb = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
199 | self.init_weights()
200 |
201 | def init_weights(self):
202 | self.Wav.data.normal_(0.0, 0.1)
203 | self.Wat.data.normal_(0.0, 0.1)
204 | self.Uav.data.normal_(0.0, 0.1)
205 | self.Uat.data.normal_(0.0, 0.1)
206 | self.Vav.data.normal_(0.0, 0.1)
207 | self.Vat.data.normal_(0.0, 0.1)
208 | self.bav.data.fill_(0)
209 | self.bat.data.fill_(0)
210 |
211 | self.Whh.data.normal_(0.0, 0.1)
212 | self.Wvh.data.normal_(0.0, 0.1)
213 | self.Wth.data.normal_(0.0, 0.1)
214 | self.bh.data.fill_(0)
215 |
216 | self.Wb.data.normal_(0.0, 0.01)
217 | self.Vbv.data.normal_(0.0, 0.01)
218 | self.Vbt.data.normal_(0.0, 0.01)
219 | self.wb.data.normal_(0.0, 0.01)
220 |
221 | self.bbv.data.fill_(0)
222 | self.bbt.data.fill_(0)
223 |
224 | def forward(self, h, hidden_frames, hidden_text, inv_attention=False):
225 | # print self.Uav
226 | # hidden_text: 1 x T1 x 1024 (looks like a two layer one-directional LSTM, combining each layer's hidden)
227 | # hidden_frame: 1 x T2 x 1024 (from video encoder output, 1024 is similar from above)
228 |
229 | # print hidden_frames.size(),hidden_text.size()
230 | Uhv = torch.matmul(h, self.Uav) # (1,512)
231 | Uhv = Uhv.view(Uhv.size(0), 1, Uhv.size(1)) # (1,1,512)
232 |
233 | Uht = torch.matmul(h, self.Uat) # (1,512)
234 | Uht = Uht.view(Uht.size(0), 1, Uht.size(1)) # (1,1,512)
235 |
236 | # print Uhv.size(),Uht.size()
237 |
238 | Wsv = torch.matmul(hidden_frames, self.Wav) # (1,T,512)
239 | # print Wsv.size()
240 | att_vec_v = torch.matmul(torch.tanh(Wsv + Uhv + self.bav), self.Vav)
241 |
242 | Wst = torch.matmul(hidden_text, self.Wat) # (1,T,512)
243 | att_vec_t = torch.matmul(torch.tanh(Wst + Uht + self.bat), self.Vat)
244 |
245 | if inv_attention == True:
246 | att_vec_v = -att_vec_v
247 | att_vec_t = -att_vec_t
248 |
249 | att_vec_v = torch.softmax(att_vec_v, dim=1)
250 | att_vec_t = torch.softmax(att_vec_t, dim=1)
251 |
252 | att_vec_v = att_vec_v.view(att_vec_v.size(0), att_vec_v.size(1), 1) # expand att_vec from 1xT to 1xTx1
253 | att_vec_t = att_vec_t.view(att_vec_t.size(0), att_vec_t.size(1), 1) # expand att_vec from 1xT to 1xTx1
254 |
255 | hv_weighted = att_vec_v * hidden_frames
256 | hv_sum = torch.sum(hv_weighted, dim=1)
257 | hv_sum2 = self.video_sum_encoder(hv_sum)
258 |
259 | ht_weighted = att_vec_t * hidden_text
260 | ht_sum = torch.sum(ht_weighted, dim=1)
261 | ht_sum2 = self.question_sum_encoder(ht_sum)
262 |
263 | Wbs = torch.matmul(h, self.Wb)
264 | mt1 = torch.matmul(ht_sum, self.Vbt) + self.bbt + Wbs
265 | mv1 = torch.matmul(hv_sum, self.Vbv) + self.bbv + Wbs
266 | mtv = torch.tanh(torch.cat([mv1, mt1], dim=0))
267 | mtv2 = torch.matmul(mtv, self.wb)
268 | beta = torch.softmax(mtv2, dim=0)
269 |
270 | output = torch.tanh(torch.matmul(h, self.Whh) + beta[0] * hv_sum2 +
271 | beta[1] * ht_sum2 + self.bh)
272 | output = output.view(output.size(1), output.size(2))
273 |
274 | return output
--------------------------------------------------------------------------------
/networks/CRN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import itertools
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn.modules.module import Module
7 |
8 | class EncoderVidCRN(nn.Module):
9 | def __init__(self, k_max_frame_level, k_max_clip_level, spl_resolution, vision_dim, dim_hidden=512):
10 | super(EncoderVidCRN, self).__init__()
11 |
12 | self.clip_level_motion_cond = CRN(dim_hidden, k_max_frame_level, k_max_frame_level, gating=False, spl_resolution=spl_resolution)
13 | self.clip_level_question_cond = CRN(dim_hidden, k_max_frame_level-2, k_max_frame_level-2, gating=True, spl_resolution=spl_resolution)
14 | self.video_level_motion_cond = CRN(dim_hidden, k_max_clip_level, k_max_clip_level, gating=False, spl_resolution=spl_resolution)
15 | self.video_level_question_cond = CRN(dim_hidden, k_max_clip_level-2, k_max_clip_level-2, gating=True, spl_resolution=spl_resolution)
16 |
17 | self.sequence_encoder = nn.LSTM(vision_dim, dim_hidden, batch_first=True, bidirectional=False)
18 | self.clip_level_motion_proj = nn.Linear(vision_dim, dim_hidden)
19 | self.video_level_motion_proj = nn.Linear(dim_hidden, dim_hidden)
20 | self.appearance_feat_proj = nn.Linear(vision_dim, dim_hidden)
21 |
22 | self.question_embedding_proj = nn.Linear(dim_hidden, dim_hidden)
23 |
24 | self.dim_hidden = dim_hidden
25 | self.activation = nn.ELU()
26 |
27 | def forward(self, appearance_video_feat, motion_video_feat, question_embedding):
28 | """
29 | Args:
30 | appearance_video_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim)
31 | motion_video_feat: [Tensor] (batch_size, num_clips, visual_inp_dim)
32 | question_embedding: [Tensor] (batch_size, dim_hidden)
33 | return:
34 | encoded video feature: [Tensor] (batch_size, N, dim_hidden)
35 | """
36 | batch_size = appearance_video_feat.size(0)
37 | clip_level_crn_outputs = []
38 | question_embedding_proj = self.question_embedding_proj(question_embedding)
39 | for i in range(appearance_video_feat.size(1)):
40 | clip_level_motion = motion_video_feat[:, i, :] # (bz, 2048)
41 | clip_level_motion_proj = self.clip_level_motion_proj(clip_level_motion)
42 |
43 | clip_level_appearance = appearance_video_feat[:, i, :, :] # (bz, 16, 2048)
44 | clip_level_appearance_proj = self.appearance_feat_proj(clip_level_appearance) # (bz, 16, 512)
45 | # clip level CRNs
46 | clip_level_crn_motion = self.clip_level_motion_cond(torch.unbind(clip_level_appearance_proj, dim=1),
47 | clip_level_motion_proj)
48 | clip_level_crn_question = self.clip_level_question_cond(clip_level_crn_motion, question_embedding_proj)
49 |
50 | clip_level_crn_output = torch.cat(
51 | [frame_relation.unsqueeze(1) for frame_relation in clip_level_crn_question],
52 | dim=1)
53 | clip_level_crn_output = clip_level_crn_output.view(batch_size, -1, self.dim_hidden)
54 | clip_level_crn_outputs.append(clip_level_crn_output)
55 |
56 | # Encode video level motion
57 | _, (video_level_motion, _) = self.sequence_encoder(motion_video_feat)
58 | video_level_motion = video_level_motion.transpose(0, 1)
59 | video_level_motion_feat_proj = self.video_level_motion_proj(video_level_motion)
60 | # video level CRNs
61 | video_level_crn_motion = self.video_level_motion_cond(clip_level_crn_outputs, video_level_motion_feat_proj)
62 | video_level_crn_question = self.video_level_question_cond(video_level_crn_motion,
63 | question_embedding_proj.unsqueeze(1))
64 |
65 | video_level_crn_output = torch.cat([clip_relation.unsqueeze(1) for clip_relation in video_level_crn_question],
66 | dim=1)
67 | video_level_crn_output = video_level_crn_output.view(batch_size, -1, self.dim_hidden)
68 |
69 | return video_level_crn_output
70 |
71 | class CRN(Module):
72 | def __init__(self, dim_hidden, num_objects, max_subset_size, gating=False, spl_resolution=1):
73 | super(CRN, self).__init__()
74 | self.dim_hidden = dim_hidden
75 | self.gating = gating
76 |
77 | self.k_objects_fusion = nn.ModuleList()
78 | if self.gating:
79 | self.gate_k_objects_fusion = nn.ModuleList()
80 | for i in range(min(num_objects, max_subset_size + 1), 1, -1):
81 | self.k_objects_fusion.append(nn.Linear(2 * dim_hidden, dim_hidden))
82 | if self.gating:
83 | self.gate_k_objects_fusion.append(nn.Linear(2 * dim_hidden, dim_hidden))
84 | self.spl_resolution = spl_resolution
85 | self.activation = nn.ELU()
86 | self.max_subset_size = max_subset_size
87 |
88 | def forward(self, object_list, cond_feat):
89 | """
90 | :param object_list: list of tensors or vectors
91 | :param cond_feat: conditioning feature
92 | :return: list of output objects
93 | """
94 | scales = [i for i in range(len(object_list), 1, -1)]
95 |
96 | relations_scales = []
97 | subsample_scales = []
98 | for scale in scales:
99 | relations_scale = self.relationset(len(object_list), scale)
100 | relations_scales.append(relations_scale)
101 | subsample_scales.append(min(self.spl_resolution, len(relations_scale)))
102 |
103 | crn_feats = []
104 | if len(scales) > 1 and self.max_subset_size == len(object_list):
105 | start_scale = 1
106 | else:
107 | start_scale = 0
108 | for scaleID in range(start_scale, min(len(scales), self.max_subset_size)):
109 | idx_relations_randomsample = np.random.choice(len(relations_scales[scaleID]),
110 | subsample_scales[scaleID], replace=False)
111 | mono_scale_features = 0
112 | for id_choice, idx in enumerate(idx_relations_randomsample):
113 | clipFeatList = [object_list[obj].unsqueeze(1) for obj in relations_scales[scaleID][idx]]
114 | # g_theta
115 | g_feat = torch.cat(clipFeatList, dim=1)
116 | g_feat = g_feat.mean(1)
117 | if len(g_feat.size()) == 2:
118 | h_feat = torch.cat((g_feat, cond_feat), dim=-1)
119 | elif len(g_feat.size()) == 3:
120 | cond_feat_repeat = cond_feat.repeat(1, g_feat.size(1), 1)
121 | h_feat = torch.cat((g_feat, cond_feat_repeat), dim=-1)
122 | if self.gating:
123 | h_feat = self.activation(self.k_objects_fusion[scaleID](h_feat)) * torch.sigmoid(
124 | self.gate_k_objects_fusion[scaleID](h_feat))
125 | else:
126 | h_feat = self.activation(self.k_objects_fusion[scaleID](h_feat))
127 | mono_scale_features += h_feat
128 | crn_feats.append(mono_scale_features / len(idx_relations_randomsample))
129 | return crn_feats
130 |
131 | def relationset(self, num_objects, num_object_relation):
132 | return list(itertools.combinations([i for i in range(num_objects)], num_object_relation))
133 |
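# A minimal usage sketch (hypothetical shapes; `_crn_usage_sketch` is illustrative only and
# is not referenced elsewhere in the repo): CRN fuses a list of object features, here four
# per-clip vectors, under a conditioning feature such as a question embedding.
def _crn_usage_sketch():
    import torch
    crn = CRN(dim_hidden=512, num_objects=4, max_subset_size=4, gating=False, spl_resolution=1)
    object_list = list(torch.unbind(torch.randn(2, 4, 512), dim=1))  # four (batch, dim_hidden) objects
    cond_feat = torch.randn(2, 512)                                  # conditioning feature
    relations = crn(object_list, cond_feat)                          # list of (batch, dim_hidden) relations
    return [r.shape for r in relations]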
--------------------------------------------------------------------------------
/networks/Embed_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | # __all__ = ['MultipleChoiceLoss', 'CountLoss']
6 |
7 |
8 | class MultipleChoiceLoss(nn.Module):
9 |
10 | def __init__(self, num_option=5, margin=1, size_average=True):
11 | super(MultipleChoiceLoss, self).__init__()
12 | self.margin = margin
13 | self.num_option = num_option
14 | self.size_average = size_average
15 |
16 | # score is N x C
17 |
18 | def forward(self, score, target):
19 | N = score.size(0)
20 | C = score.size(1)
21 | assert self.num_option == C
22 |
23 |         loss = torch.tensor(0.0, device=score.device)
24 |         zero = torch.tensor(0.0, device=score.device)
25 |
26 | cnt = 0
27 | #print(N,C)
28 | for b in range(N):
29 |             # loop over the incorrect answers and check that the correct answer scores higher by at least the margin
30 | c0 = target[b]
31 | for c in range(C):
32 | if c == c0:
33 | continue
34 |
35 |                 # the correct class should outscore each wrong class by at least the margin
36 |                 # (see the formula under Eq. (4) in the paper)
37 |                 loss += torch.max(zero, self.margin + score[b, c] - score[b, c0])
38 | cnt += 1
39 |
40 | if cnt == 0:
41 | return loss
42 |
43 | return loss / cnt if self.size_average else loss
44 |
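# A minimal sketch of the margin loss above (hypothetical scores; `_mc_loss_sketch` is
# illustrative only): every wrong option contributes max(0, margin + s_wrong - s_correct).
def _mc_loss_sketch():
    import torch
    crit = MultipleChoiceLoss(num_option=5, margin=1, size_average=True)
    score = torch.tensor([[1.0, 0.8, 0.2, 0.5, 0.1]])  # (N=1, C=5) option scores
    target = torch.tensor([0])                          # index of the correct option
    return crit(score, target)                          # (0.8 + 0.2 + 0.5 + 0.1) / 4 = 0.4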
--------------------------------------------------------------------------------
/networks/EncoderRNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4 | from torch.nn import init
5 | import numpy as np
6 | import os
7 |
8 | def init_modules(modules, w_init='kaiming_uniform'):
9 | if w_init == "normal":
10 | _init = init.normal_
11 | elif w_init == "xavier_normal":
12 | _init = init.xavier_normal_
13 | elif w_init == "xavier_uniform":
14 | _init = init.xavier_uniform_
15 | elif w_init == "kaiming_normal":
16 | _init = init.kaiming_normal_
17 | elif w_init == "kaiming_uniform":
18 | _init = init.kaiming_uniform_
19 | elif w_init == "orthogonal":
20 | _init = init.orthogonal_
21 | else:
22 | raise NotImplementedError
23 | for m in modules:
24 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
25 | _init(m.weight)
26 | if m.bias is not None:
27 | torch.nn.init.zeros_(m.bias)
28 | if isinstance(m, (nn.LSTM, nn.GRU)):
29 | for name, param in m.named_parameters():
30 | if 'bias' in name:
31 | nn.init.zeros_(param)
32 | elif 'weight' in name:
33 | _init(param)
34 |
35 | class EncoderQns(nn.Module):
36 | def __init__(self, dim_embed, dim_hidden, vocab_size, glove_embed, use_bert=True, input_dropout_p=0.2, rnn_dropout_p=0.1, n_layers=1, bidirectional=False, rnn_cell='gru'):
37 |         """Question encoder: BERT features (or GloVe/word embeddings) fused with a
38 |         projected 2048-d region feature and encoded by a single-layer RNN."""
39 | super(EncoderQns, self).__init__()
40 | self.dim_hidden = dim_hidden
41 | self.vocab_size = vocab_size
42 | self.glove_embed = glove_embed
43 | self.input_dropout_p = input_dropout_p
44 | self.rnn_dropout_p = rnn_dropout_p
45 | self.n_layers = n_layers
46 | self.bidirectional = bidirectional
47 | self.rnn_cell = rnn_cell
48 |
49 | self.input_dropout = nn.Dropout(input_dropout_p)
50 | self.rnn_dropout = nn.Dropout(rnn_dropout_p)
51 |
52 | if rnn_cell.lower() == 'lstm':
53 | self.rnn_cell = nn.LSTM
54 | elif rnn_cell.lower() == 'gru':
55 | self.rnn_cell = nn.GRU
56 |
57 | input_dim = dim_embed
58 | self.use_bert = use_bert
59 | if self.use_bert:
60 | self.embedding = nn.Linear(input_dim, dim_embed, bias=False)
61 | else:
62 | self.embedding = nn.Embedding(vocab_size, dim_embed)
63 |
64 | self.obj_embedding = nn.Linear(2048, dim_embed, bias=False)
65 |
66 | self.rnn = self.rnn_cell(dim_embed, dim_hidden, n_layers, batch_first=True,
67 | bidirectional=bidirectional)
68 |
69 | # init_modules(self.modules(), w_init="xavier_uniform")
70 | # nn.init.uniform_(self.embedding.weight, -1.0, 1.0)
71 |
72 | if not self.use_bert and os.path.exists(self.glove_embed):
73 | word_mat = torch.FloatTensor(np.load(self.glove_embed))
74 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False)
75 |
76 | def forward(self, qns, qns_lengths, hidden=None, obj=None):
77 |         """
78 |         encode question; `obj` is a required 2048-d region feature added to the token embeddings
79 |         :param qns: (batch, seq_len, dim_embed) BERT features or (batch, seq_len) word ids
80 |         :param qns_lengths: true question lengths before padding
81 |         :return: per-token outputs (batch, seq_len, dim_hidden) and the final hidden state
82 |         """
83 | qns_embed = self.embedding(qns)
84 | assert obj is not None
85 | obj_embed = self.obj_embedding(obj)
86 | qns_embed = qns_embed + obj_embed
87 | qns_embed = self.input_dropout(qns_embed)
88 | packed = pack_padded_sequence(qns_embed, qns_lengths, batch_first=True, enforce_sorted=False)
89 | packed_output, hidden = self.rnn(packed, hidden)
90 | output, _ = pad_packed_sequence(packed_output, batch_first=True)
91 | output = self.rnn_dropout(output)
92 | hidden = self.rnn_dropout(hidden).squeeze()
93 | return output, hidden
94 |
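# A minimal usage sketch (hypothetical shapes): the encoder expects BERT token features
# when use_bert=True, plus a pooled 2048-d region feature that is broadcast over tokens.
def _encoder_qns_sketch():
    import torch
    enc = EncoderQns(dim_embed=768, dim_hidden=256, vocab_size=8000,
                     glove_embed='', use_bert=True)
    qns = torch.randn(2, 10, 768)        # (batch, seq_len, dim_embed) BERT token features
    qns_lengths = torch.tensor([10, 7])  # true lengths before padding (kept on CPU for packing)
    obj = torch.randn(2, 1, 2048)        # pooled region feature, broadcast over the token dimension
    output, hidden = enc(qns, qns_lengths, obj=obj)
    return output.shape, hidden.shape    # (2, 10, 256), (2, 256)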
95 |
96 | class EncoderVid(nn.Module):
97 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
98 | n_layers=1, bidirectional=False, rnn_cell='gru'):
99 | """
100 | """
101 | super(EncoderVid, self).__init__()
102 | self.dim_vid = dim_vid
103 | self.dim_app = 2048
104 | self.dim_motion = 4096
105 | self.dim_hidden = dim_hidden
106 | self.input_dropout_p = input_dropout_p
107 | self.rnn_dropout_p = rnn_dropout_p
108 | self.n_layers = n_layers
109 | self.bidirectional = bidirectional
110 | self.rnn_cell = rnn_cell
111 |
112 | if rnn_cell.lower() == 'lstm':
113 | self.rnn_cell = nn.LSTM
114 | elif rnn_cell.lower() == 'gru':
115 | self.rnn_cell = nn.GRU
116 |
117 | self.rnn = self.rnn_cell(dim_vid, dim_hidden, n_layers, batch_first=True,
118 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
119 |
120 |
121 | def forward(self, vid_feats):
122 |
123 | self.rnn.flatten_parameters()
124 | foutput, fhidden = self.rnn(vid_feats)
125 |
126 | return foutput, fhidden
127 |
128 |
129 | class EncoderVidSTVQA(nn.Module):
130 | def __init__(self, input_dim, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
131 | n_layers=1, bidirectional=False, rnn_cell='gru'):
132 | """
133 | """
134 | super(EncoderVidSTVQA, self).__init__()
135 | self.input_dim = input_dim
136 | self.dim_hidden = dim_hidden
137 | self.input_dropout_p = input_dropout_p
138 | self.rnn_dropout_p = rnn_dropout_p
139 | self.n_layers = n_layers
140 | self.bidirectional = bidirectional
141 | self.rnn_cell = rnn_cell
142 |
143 |
144 | if rnn_cell.lower() == 'lstm':
145 | self.rnn_cell = nn.LSTM
146 | elif rnn_cell.lower() == 'gru':
147 | self.rnn_cell = nn.GRU
148 |
149 | self.rnn1 = self.rnn_cell(input_dim, dim_hidden, n_layers, batch_first=True,
150 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
151 |
152 | self.rnn2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
153 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
154 |
155 |
156 | def forward(self, vid_feats):
157 |         """
158 |         Two stacked recurrent layers; the hidden-state concatenation below assumes rnn_cell='lstm' (an (h, c) tuple)
159 |         """
160 |
161 | self.rnn1.flatten_parameters()
162 |
163 | foutput_1, fhidden_1 = self.rnn1(vid_feats)
164 | self.rnn2.flatten_parameters()
165 | foutput_2, fhidden_2 = self.rnn2(foutput_1)
166 |
167 | foutput = torch.cat((foutput_1, foutput_2), dim=2)
168 | fhidden = (torch.cat((fhidden_1[0], fhidden_2[0]), dim=0),
169 | torch.cat((fhidden_1[1], fhidden_2[1]), dim=0))
170 |
171 | return foutput, fhidden
172 |
173 |
174 | class EncoderVidCoMem(nn.Module):
175 | def __init__(self, dim_app, dim_motion, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
176 | n_layers=1, bidirectional=False, rnn_cell='gru'):
177 | """
178 | """
179 | super(EncoderVidCoMem, self).__init__()
180 | self.dim_app = dim_app
181 | self.dim_motion = dim_motion
182 | self.dim_hidden = dim_hidden
183 | self.input_dropout_p = input_dropout_p
184 | self.rnn_dropout_p = rnn_dropout_p
185 | self.n_layers = n_layers
186 | self.bidirectional = bidirectional
187 | self.rnn_cell = rnn_cell
188 |
189 | if rnn_cell.lower() == 'lstm':
190 | self.rnn_cell = nn.LSTM
191 | elif rnn_cell.lower() == 'gru':
192 | self.rnn_cell = nn.GRU
193 |
194 | self.rnn_app_l1 = self.rnn_cell(self.dim_app, dim_hidden, n_layers, batch_first=True,
195 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
196 | self.rnn_app_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
197 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
198 |
199 | self.rnn_motion_l1 = self.rnn_cell(self.dim_motion, dim_hidden, n_layers, batch_first=True,
200 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
201 | self.rnn_motion_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
202 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
203 |
204 |
205 | def forward(self, vid_feats):
206 |         """
207 |         two separate two-layer RNNs (GRU by default) encode the appearance and motion features
208 |         :param vid_feats: (batch, seq_len, dim_app + dim_motion) concatenated features
209 |         :return: layer-1 and layer-2 outputs for both the appearance and motion streams
210 |         """
211 | vid_app = vid_feats[:, :, 0:self.dim_app]
212 | vid_motion = vid_feats[:, :, self.dim_app:]
213 |
214 | app_output_l1, app_hidden_l1 = self.rnn_app_l1(vid_app)
215 | app_output_l2, app_hidden_l2 = self.rnn_app_l2(app_output_l1)
216 |
217 |
218 | motion_output_l1, motion_hidden_l1 = self.rnn_motion_l1(vid_motion)
219 | motion_output_l2, motion_hidden_l2 = self.rnn_motion_l2(motion_output_l1)
220 |
221 | return app_output_l1, app_output_l2, motion_output_l1, motion_output_l2
222 |
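# A minimal sketch (hypothetical sizes): appearance and motion features arrive concatenated
# on the channel dimension and are encoded by the two separate two-layer RNNs above.
def _encoder_comem_sketch():
    import torch
    enc = EncoderVidCoMem(dim_app=2048, dim_motion=4096, dim_hidden=256)
    vid_feats = torch.randn(2, 16, 2048 + 4096)   # (batch, frames, dim_app + dim_motion)
    app_l1, app_l2, mot_l1, mot_l2 = enc(vid_feats)
    return app_l2.shape, mot_l2.shape             # both (2, 16, 256)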
223 |
224 | class EncoderVidHGA(nn.Module):
225 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
226 | n_layers=1, bidirectional=False, rnn_cell='gru'):
227 | """
228 | """
229 | super(EncoderVidHGA, self).__init__()
230 | self.dim_vid = dim_vid
231 | self.dim_hidden = dim_hidden
232 | self.input_dropout_p = input_dropout_p
233 | self.rnn_dropout_p = rnn_dropout_p
234 | self.n_layers = n_layers
235 | self.bidirectional = bidirectional
236 | self.rnn_cell = rnn_cell
237 |
238 |
239 | self.vid2hid = nn.Sequential(nn.Linear(self.dim_vid, dim_hidden),
240 | nn.ReLU(),
241 | nn.Dropout(input_dropout_p))
242 |
243 |
244 | if rnn_cell.lower() == 'lstm':
245 | self.rnn_cell = nn.LSTM
246 | elif rnn_cell.lower() == 'gru':
247 | self.rnn_cell = nn.GRU
248 |
249 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
250 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
251 |
252 | self._init_weight()
253 |
254 |
255 | def _init_weight(self):
256 | nn.init.xavier_normal_(self.vid2hid[0].weight)
257 |
258 |
259 | def forward(self, vid_feats):
260 | """
261 | """
262 | batch_size, seq_len, dim_vid = vid_feats.size()
263 | vid_feats_trans = self.vid2hid(vid_feats.view(-1, self.dim_vid))
264 | vid_feats = vid_feats_trans.view(batch_size, seq_len, -1)
265 |
266 | self.rnn.flatten_parameters()
267 | foutput, fhidden = self.rnn(vid_feats)
268 |
269 | return foutput, fhidden
270 |
271 | class EncoderVidB2A(nn.Module):
272 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
273 | n_layers=1, bidirectional=False, rnn_cell='gru'):
274 | """
275 | """
276 | super(EncoderVidB2A, self).__init__()
277 | self.dim_vid = dim_vid
278 | self.dim_hidden = dim_hidden
279 | self.input_dropout_p = input_dropout_p
280 | self.rnn_dropout_p = rnn_dropout_p
281 | self.n_layers = n_layers
282 | self.bidirectional = bidirectional
283 | self.rnn_cell = rnn_cell
284 |
285 |
286 | self.vid2hid = nn.Sequential(nn.Linear(self.dim_vid, dim_hidden),
287 | nn.ReLU(),
288 | nn.Dropout(input_dropout_p))
289 |
290 |
291 | if rnn_cell.lower() == 'lstm':
292 | self.rnn_cell = nn.LSTM
293 | elif rnn_cell.lower() == 'gru':
294 | self.rnn_cell = nn.GRU
295 |
296 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
297 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
298 |
299 | self._init_weight()
300 |
301 |
302 | def _init_weight(self):
303 | nn.init.xavier_normal_(self.vid2hid[0].weight)
304 |
305 |
306 | def forward(self, app_feat, mot_feat):
307 |         """app_feat: (batch, clips, frames, dim_vid) region/appearance features;
308 |         mot_feat: (batch, clips, dim_vid) clip-level motion features."""
309 | batch_size, seq_len, seq_len2, dim_vid = app_feat.size()
310 |
311 | app_feat_trans = self.vid2hid(app_feat.view(-1, self.dim_vid))
312 | app_feat = app_feat_trans.view(batch_size, seq_len*seq_len2, -1)
313 |
314 | mot_feat_trans = self.vid2hid(mot_feat.view(-1, self.dim_vid))
315 | mot_feat = mot_feat_trans.view(batch_size, seq_len, -1)
316 |
317 | self.rnn.flatten_parameters()
318 | app_output, _ = self.rnn(app_feat)
319 | mot_output, _ = self.rnn(mot_feat)
320 |
321 | return app_output, mot_output
--------------------------------------------------------------------------------
/networks/GCN.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.parameter import Parameter
7 | from torch.nn.modules.module import Module
8 |
9 |
10 | def padding_mask_k(seq_q, seq_k):
11 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
12 | In batch 0:
13 | [[x x x 0] [[0 0 0 1]
14 | [x x x 0]-> [0 0 0 1]
15 | [x x x 0]] [0 0 0 1]] uint8
16 | """
17 | fake_q = torch.ones_like(seq_q)
18 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2))
19 | pad_mask = pad_mask.eq(0)
20 | # pad_mask = pad_mask.lt(1e-3)
21 | return pad_mask
22 |
23 |
24 | def padding_mask_q(seq_q, seq_k):
25 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
26 | In batch 0:
27 | [[x x x x] [[0 0 0 0]
28 | [x x x x] -> [0 0 0 0]
29 | [0 0 0 0]] [1 1 1 1]] uint8
30 | """
31 | fake_k = torch.ones_like(seq_k)
32 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2))
33 | pad_mask = pad_mask.eq(0)
34 | # pad_mask = pad_mask.lt(1e-3)
35 | return pad_mask
36 |
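# A tiny behaviour check for the zero-padding masks above (illustrative values only):
# a key/query position counts as padding when its feature vector is all zeros.
def _padding_mask_sketch():
    import torch
    q = torch.ones(1, 2, 4)
    k = torch.ones(1, 3, 4)
    k[:, -1] = 0.                    # last key position is zero padding
    mask_k = padding_mask_k(q, k)    # (1, 2, 3): True on the padded key column
    mask_q = padding_mask_q(q, k)    # (1, 2, 3): True on padded query rows (none here)
    return mask_k, mask_q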
37 |
38 | class GraphConvolution(Module):
39 | """
40 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
41 | """
42 |
43 | def __init__(self, in_features, out_features):
44 | super(GraphConvolution, self).__init__()
45 | self.weight = nn.Linear(in_features, out_features, bias=False)
46 | self.layer_norm = nn.LayerNorm(out_features, elementwise_affine=False)
47 |
48 | def forward(self, input, adj):
49 | # self.weight of shape (hidden_size, hidden_size)
50 | support = self.weight(input)
51 | output = torch.bmm(adj, support)
52 | output = self.layer_norm(output)
53 | return output
54 |
55 |
56 | class GraphAttention(nn.Module):
57 | """
58 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
59 | """
60 |
61 | def __init__(self, in_features, out_features, dropout, alpha, concat=True):
62 | super(GraphAttention, self).__init__()
63 | self.dropout = dropout
64 | self.in_features = in_features
65 | self.out_features = out_features
66 | self.alpha = alpha
67 | self.concat = concat
68 |
69 | self.W = nn.Parameter(
70 | nn.init.xavier_normal_(
71 | torch.Tensor(in_features, out_features).type(
72 | torch.cuda.FloatTensor if torch.cuda.is_available(
73 | ) else torch.FloatTensor),
74 | gain=np.sqrt(2.0)),
75 | requires_grad=True)
76 | self.a1 = nn.Parameter(
77 | nn.init.xavier_normal_(
78 | torch.Tensor(out_features, 1).type(
79 | torch.cuda.FloatTensor if torch.cuda.is_available(
80 | ) else torch.FloatTensor),
81 | gain=np.sqrt(2.0)),
82 | requires_grad=True)
83 | self.a2 = nn.Parameter(
84 | nn.init.xavier_normal_(
85 | torch.Tensor(out_features, 1).type(
86 | torch.cuda.FloatTensor if torch.cuda.is_available(
87 | ) else torch.FloatTensor),
88 | gain=np.sqrt(2.0)),
89 | requires_grad=True)
90 |
91 | self.leakyrelu = nn.LeakyReLU(self.alpha)
92 |
93 | def forward(self, input, adj):
94 | h = torch.mm(input, self.W)
95 | N = h.size()[0]
96 |
97 | f_1 = torch.matmul(h, self.a1)
98 | f_2 = torch.matmul(h, self.a2)
99 | e = self.leakyrelu(f_1 + f_2.transpose(0, 1))
100 |
101 | zero_vec = -9e15 * torch.ones_like(e)
102 | attention = torch.where(adj > 0, e, zero_vec)
103 | attention = F.softmax(attention, dim=1)
104 | attention = F.dropout(attention, self.dropout, training=self.training)
105 | h_prime = torch.matmul(attention, h)
106 |
107 | if self.concat:
108 | return F.elu(h_prime)
109 | else:
110 | return h_prime
111 |
112 |
113 | class GCN(nn.Module):
114 |
115 | def __init__(
116 | self, input_size, hidden_size, num_classes, num_layers=1,
117 | dropout=0.1):
118 | super(GCN, self).__init__()
119 | self.layers = nn.ModuleList()
120 | self.layers.append(GraphConvolution(input_size, hidden_size))
121 | for i in range(num_layers - 1):
122 | self.layers.append(GraphConvolution(hidden_size, hidden_size))
123 | self.layers.append(GraphConvolution(hidden_size, num_classes))
124 | self.dropout = nn.Dropout(p=dropout)
125 |
126 | def forward(self, x, adj):
127 | for i, layer in enumerate(self.layers):
128 | x = self.dropout(F.relu(layer(x, adj)))
129 |
130 | # x of shape (bs, q_v_len, num_classes)
131 | return x
132 |
133 |
134 | class AdjLearner(Module):
135 |
136 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1):
137 | super().__init__()
138 | '''
139 | ## Variables:
140 | - in_feature_dim: dimensionality of input features
141 | - hidden_size: dimensionality of the joint hidden embedding
142 | - K: number of graph nodes/objects on the image
143 | '''
144 |
145 | # Embedding layers. Padded 0 => 0
146 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False)
147 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False)
148 |
149 | # Regularisation
150 | self.dropout = nn.Dropout(p=dropout)
151 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1)
152 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2)
153 |
154 | def forward(self, questions, videos):
155 |         '''
156 |         ## Inputs: questions (batch_size, q_len, in_feat_dim), videos (batch_size, v_len, in_feat_dim)
157 |         ## Returns:
158 |         - adjacency matrix (batch_size, q_v_len, q_v_len)
159 |         '''
160 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features
161 | graph_nodes = torch.cat((questions, videos), dim=1)
162 |
163 | # layer 1
164 | h = self.edge_layer_1(graph_nodes)
165 | h = F.relu(h)
166 |
167 | # layer 2
168 | h = self.edge_layer_2(h)
169 | h = F.relu(h)
170 | # h * sigmoid(Wh)
171 | # h = F.tanh(h)
172 |
173 | # outer product
174 | adjacency_matrix = torch.bmm(h, h.transpose(1, 2))
175 |
176 | return adjacency_matrix
177 |
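# A minimal sketch of the intended pairing (hypothetical sizes): AdjLearner builds a joint
# question-video adjacency over which the GCN defined above then propagates features.
def _adj_gcn_sketch():
    import torch
    questions = torch.randn(2, 10, 256)
    videos = torch.randn(2, 16, 256)
    adj_learner = AdjLearner(in_feature_dim=256, hidden_size=256)
    gcn = GCN(input_size=256, hidden_size=256, num_classes=256, num_layers=2)
    adj = adj_learner(questions, videos)           # (2, 26, 26) learned adjacency
    nodes = torch.cat((questions, videos), dim=1)  # (2, 26, 256) joint graph nodes
    return gcn(nodes, adj).shape                   # (2, 26, 256)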
178 |
179 | class EvoAdjLearner(Module):
180 |
181 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1):
182 | super().__init__()
183 | '''
184 | ## Variables:
185 | - in_feature_dim: dimensionality of input features
186 | - hidden_size: dimensionality of the joint hidden embedding
187 | - K: number of graph nodes/objects on the image
188 | '''
189 |
190 | # Embedding layers. Padded 0 => 0
191 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False)
192 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False)
193 | self.edge_layer_3 = nn.Linear(in_feature_dim, hidden_size, bias=False)
194 | self.edge_layer_4 = nn.Linear(hidden_size, hidden_size, bias=False)
195 |
196 | # Regularisation
197 | self.dropout = nn.Dropout(p=dropout)
198 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1)
199 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2)
200 |
201 | def forward(self, questions, videos):
202 |         '''
203 |         ## Inputs: questions (batch_size, q_len, in_feat_dim), videos (batch_size, v_len, in_feat_dim)
204 |         ## Returns:
205 |         - adjacency matrix (batch_size, q_v_len, q_v_len)
206 |         '''
207 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features
208 | graph_nodes = torch.cat((questions, videos), dim=1)
209 |
210 | attn_mask = padding_mask_k(graph_nodes, graph_nodes)
211 | sf_mask = padding_mask_q(graph_nodes, graph_nodes)
212 |
213 | # layer 1
214 | h = self.edge_layer_1(graph_nodes)
215 | h = F.relu(h)
216 | # layer 2
217 | h = self.edge_layer_2(h)
218 | # h = F.relu(h)
219 |
220 | # layer 1
221 | h_ = self.edge_layer_3(graph_nodes)
222 | h_ = F.relu(h_)
223 | # layer 2
224 | h_ = self.edge_layer_4(h_)
225 | # h_ = F.relu(h_)
226 |
227 | # outer product
228 | adjacency_matrix = torch.bmm(h, h_.transpose(1, 2))
229 | # adjacency_matrix = adjacency_matrix.masked_fill(attn_mask, -np.inf)
230 |
231 | # softmaxed_adj = F.softmax(adjacency_matrix, dim=-1)
232 |
233 | # softmaxed_adj = softmaxed_adj.masked_fill(sf_mask, 0.)
234 |
235 | return adjacency_matrix
236 |
237 | class AdjGenerator(Module):
238 |
239 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1):
240 | super().__init__()
241 | '''
242 | ## Variables:
243 | - in_feature_dim: dimensionality of input features
244 | - hidden_size: dimensionality of the joint hidden embedding
245 | - K: number of graph nodes/objects on the image
246 | '''
247 |
248 | # Embedding layers. Padded 0 => 0
249 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False)
250 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False)
251 |
252 | # Regularisation
253 | self.dropout = nn.Dropout(p=dropout)
254 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1)
255 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2)
256 |
257 | def forward(self, features, adjacency=None):
258 |         '''
259 |         ## Inputs: features (batch_size, q_v_len, in_feat_dim), optional prior adjacency used to mask the result
260 |         ## Returns:
261 |         - adjacency matrix (batch_size, q_v_len, q_v_len)
262 |         '''
263 |
264 | # layer 1
265 | h = self.edge_layer_1(features)
266 | h = F.relu(h)
267 |
268 | # layer 2
269 | h = self.edge_layer_2(h)
270 | h = F.relu(h)
271 | # h * sigmoid(Wh)
272 | # h = F.tanh(h)
273 |
274 | # outer product
275 | adjacency_matrix = torch.bmm(h, h.transpose(1, 2))
276 | if adjacency is not None:
277 | adjacency_matrix = adjacency_matrix*adjacency[:, :adjacency_matrix.shape[1], :adjacency_matrix.shape[2]]
278 |
279 | return adjacency_matrix
--------------------------------------------------------------------------------
/networks/Transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import torchnlp_nn as nlpnn
6 |
7 |
8 | def padding_mask(seq_q, seq_k):
9 | # seq_k of shape (batch, k_len) and seq_q (batch, q_len), not embedded. q and k are padded with 0.
10 | seq_q = torch.unsqueeze(seq_q, 2)
11 | seq_k = torch.unsqueeze(seq_k, 2)
12 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2))
13 | pad_mask = pad_mask.eq(0)
14 | return pad_mask
15 |
16 |
17 | def padding_mask_transformer(seq_q, seq_k):
18 | # original padding_mask in transformer, for masking out the padding part of key sequence.
19 | len_q = seq_q.size(1)
20 | # `PAD` is 0
21 | pad_mask = seq_k.eq(0)
22 | pad_mask = pad_mask.unsqueeze(1).expand(
23 | -1, len_q, -1) # shape [B, L_q, L_k]
24 | return pad_mask
25 |
26 |
27 | def padding_mask_embedded(seq_q, seq_k):
28 | # seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len)
29 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2))
30 | pad_mask = pad_mask.eq(0)
31 | return pad_mask
32 |
33 |
34 | def padding_mask_k(seq_q, seq_k):
35 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
36 | In batch 0:
37 | [[x x x 0] [[0 0 0 1]
38 | [x x x 0]-> [0 0 0 1]
39 | [x x x 0]] [0 0 0 1]] uint8
40 | """
41 | fake_q = torch.ones_like(seq_q)
42 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2))
43 | pad_mask = pad_mask.eq(0)
44 | # pad_mask = pad_mask.lt(1e-3)
45 | return pad_mask
46 |
47 |
48 | def padding_mask_q(seq_q, seq_k):
49 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
50 | In batch 0:
51 | [[x x x x] [[0 0 0 0]
52 | [x x x x] -> [0 0 0 0]
53 | [0 0 0 0]] [1 1 1 1]] uint8
54 | """
55 | fake_k = torch.ones_like(seq_k)
56 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2))
57 | pad_mask = pad_mask.eq(0)
58 | # pad_mask = pad_mask.lt(1e-3)
59 | return pad_mask
60 |
61 |
62 | class PositionalEncoding(nn.Module):
63 |
64 | def __init__(self, d_model, max_seq_len):
65 | super(PositionalEncoding, self).__init__()
66 | self.max_seq_len = max_seq_len
67 |
68 | position_encoding = np.array(
69 | [
70 | [
71 | pos / np.power(10000, 2.0 * (j // 2) / d_model)
72 | for j in range(d_model)
73 | ]
74 | for pos in range(max_seq_len)
75 | ])
76 | position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
77 | position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])
78 |
79 | pad_row = torch.zeros([1, d_model])
80 | position_encoding = torch.cat(
81 | (pad_row, torch.from_numpy(position_encoding).float()))
82 |
83 | self.position_encoding = nn.Embedding(max_seq_len + 1, d_model)
84 | self.position_encoding.weight = nn.Parameter(
85 | position_encoding, requires_grad=False)
86 |
87 | def forward(self, input_len):
88 | # max_len = torch.max(input_len)
89 | max_len = self.max_seq_len
90 | tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor
91 | input_pos = [
92 | list(range(1, l + 1)) + [0] * (max_len - l.item())
93 | for l in input_len
94 | ]
95 | input_pos = tensor(input_pos)
96 | return self.position_encoding(input_pos)
97 |
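# A minimal sketch (illustrative lengths): positions 1..length receive a sinusoidal code,
# while padded positions map to the all-zero row stored at embedding index 0.
def _pos_encoding_sketch():
    import torch
    pe = PositionalEncoding(d_model=512, max_seq_len=35)
    lengths = torch.tensor([35, 20])   # per-sample sequence lengths
    codes = pe(lengths)                # (2, 35, 512); rows past each length are zeros
    return codes.shape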
98 |
99 | class PositionalWiseFeedForward(nn.Module):
100 |
101 | def __init__(self, model_dim=512, ffn_dim=512, dropout=0.0):
102 | super(PositionalWiseFeedForward, self).__init__()
103 | self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
104 |         self.w2 = nn.Conv1d(ffn_dim, model_dim, 1)
105 | self.dropout = nn.Dropout(dropout)
106 | self.layer_norm = nn.LayerNorm(model_dim)
107 |
108 | def forward(self, x):
109 | # x of shape (bs, seq_len, hs)
110 | output = x.transpose(1, 2)
111 | output = self.w2(F.relu(self.w1(output)))
112 | output = self.dropout(output.transpose(1, 2))
113 |
114 | # add residual and norm layer
115 | output = self.layer_norm(x + output)
116 | return output
117 |
118 |
119 | class MaskedPositionalWiseFeedForward(nn.Module):
120 |
121 | def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
122 | super().__init__()
123 | self.w1 = nn.Linear(model_dim, ffn_dim, bias=False)
124 | self.w2 = nn.Linear(ffn_dim, model_dim, bias=False)
125 | self.dropout = nn.Dropout(dropout)
126 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
127 |
128 | def forward(self, x):
129 | # x of shape (bs, seq_len, hs)
130 | output = self.w2(F.relu(self.w1(x)))
131 | output = self.dropout(output)
132 |
133 | # add residual and norm layer
134 | output = self.layer_norm(x + output)
135 | return output
136 |
137 |
138 | class ScaledDotProductAttention(nn.Module):
139 | """Scaled dot-product attention mechanism."""
140 |
141 | def __init__(self, attention_dropout=0.0):
142 | super(ScaledDotProductAttention, self).__init__()
143 | self.dropout = nn.Dropout(attention_dropout)
144 | self.softmax = nn.Softmax(dim=-1)
145 |
146 | def forward(self, q, k, v, scale=None, attn_mask=None):
147 | """
148 | Args:
149 | q: [B, L_q, D_q]
150 | k: [B, L_k, D_k]
151 | v: [B, L_v, D_v]
152 | """
153 | attention = torch.matmul(q, k.transpose(1, 2))
154 | if scale is not None:
155 | attention = attention * scale
156 | if attn_mask is not None:
157 | attention = attention.masked_fill(attn_mask, -np.inf)
158 | attention = self.softmax(attention)
159 | attention = self.dropout(attention)
160 | output = torch.matmul(attention, v)
161 | return output, attention
162 |
163 |
164 | class MaskedScaledDotProductAttention(nn.Module):
165 | """Scaled dot-product attention mechanism."""
166 |
167 | def __init__(self, attention_dropout=0.0):
168 | super().__init__()
169 | self.dropout = nn.Dropout(attention_dropout)
170 | self.softmax = nn.Softmax(dim=-1)
171 |
172 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None):
173 | """
174 | Args:
175 | q: [B, L_q, D_q]
176 | k: [B, L_k, D_k]
177 | v: [B, L_v, D_v]
178 | """
179 | attention = torch.matmul(q, k.transpose(-2, -1))
180 | if scale is not None:
181 | attention = attention * scale
182 | if attn_mask is not None:
183 | attention = attention.masked_fill(attn_mask, -np.inf)
184 | attention = self.softmax(attention)
185 | attention = attention.masked_fill(softmax_mask, 0.)
186 | attention = self.dropout(attention)
187 | output = torch.matmul(attention, v)
188 | return output, attention
189 |
190 |
191 | class MultiHeadAttention(nn.Module):
192 |
193 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
194 | super(MultiHeadAttention, self).__init__()
195 |
196 | self.dim_per_head = model_dim // num_heads
197 | self.num_heads = num_heads
198 | self.linear_k = nn.Linear(
199 | model_dim, self.dim_per_head * num_heads, bias=False)
200 | self.linear_v = nn.Linear(
201 | model_dim, self.dim_per_head * num_heads, bias=False)
202 | self.linear_q = nn.Linear(
203 | model_dim, self.dim_per_head * num_heads, bias=False)
204 |
205 | self.dot_product_attention = ScaledDotProductAttention(dropout)
206 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False)
207 | self.dropout = nn.Dropout(dropout)
208 | self.layer_norm = nn.LayerNorm(model_dim)
209 |
210 | def forward(self, query, key, value, attn_mask=None):
211 | residual = query
212 |
213 | dim_per_head = self.dim_per_head
214 | num_heads = self.num_heads
215 | batch_size = key.size(0)
216 |
217 | # linear projection
218 | key = self.linear_k(key)
219 | value = self.linear_v(value)
220 | query = self.linear_q(query)
221 |
222 | # split by heads
223 | key = key.view(batch_size * num_heads, -1, dim_per_head)
224 | value = value.view(batch_size * num_heads, -1, dim_per_head)
225 | query = query.view(batch_size * num_heads, -1, dim_per_head)
226 |
227 | if attn_mask is not None:
228 | attn_mask = attn_mask.repeat(num_heads, 1, 1)
229 |
230 | # scaled dot product attention
231 |         scale = key.size(-1)**-0.5  # key is already split per head, so this is 1/sqrt(dim_per_head)
232 | context, attention = self.dot_product_attention(
233 | query, key, value, scale, attn_mask)
234 |
235 | # concat heads
236 | context = context.view(batch_size, -1, dim_per_head * num_heads)
237 |
238 | # final linear projection
239 | output = self.linear_final(context)
240 |
241 | # dropout
242 | output = self.dropout(output)
243 |
244 | # add residual and norm layer
245 | output = self.layer_norm(residual + output)
246 |
247 | return output, attention
248 |
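# A minimal shape sketch for the multi-head attention above (hypothetical sizes): the
# query attends over a key/value sequence of a different length.
def _mha_sketch():
    import torch
    mha = MultiHeadAttention(model_dim=512, num_heads=8)
    q = torch.randn(2, 10, 512)
    kv = torch.randn(2, 20, 512)
    out, attn = mha(q, kv, kv)   # out: (2, 10, 512); attn: (batch * heads, L_q, L_k) = (16, 10, 20)
    return out.shape, attn.shape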
249 |
250 | class MaskedMultiHeadAttention(nn.Module):
251 |
252 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
253 | super().__init__()
254 |
255 | self.dim_per_head = model_dim // num_heads
256 | self.num_heads = num_heads
257 | self.linear_k = nn.Linear(
258 | model_dim, self.dim_per_head * num_heads, bias=False)
259 | self.linear_v = nn.Linear(
260 | model_dim, self.dim_per_head * num_heads, bias=False)
261 | self.linear_q = nn.Linear(
262 | model_dim, self.dim_per_head * num_heads, bias=False)
263 |
264 | self.dot_product_attention = MaskedScaledDotProductAttention(dropout)
265 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False)
266 | self.dropout = nn.Dropout(dropout)
267 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
268 |
269 | def forward(self, query, key, value, attn_mask=None, softmax_mask=None):
270 | residual = query
271 |
272 | dim_per_head = self.dim_per_head
273 | num_heads = self.num_heads
274 | batch_size = key.size(0)
275 |
276 | # linear projection
277 | key = self.linear_k(key)
278 | value = self.linear_v(value)
279 | query = self.linear_q(query)
280 |
281 | # split by heads
282 | key = key.view(batch_size, -1, num_heads, dim_per_head).transpose(1, 2)
283 | value = value.view(batch_size, -1, num_heads,
284 | dim_per_head).transpose(1, 2)
285 | query = query.view(batch_size, -1, num_heads,
286 | dim_per_head).transpose(1, 2)
287 |
288 | if attn_mask is not None:
289 | attn_mask = attn_mask.unsqueeze(1).repeat(1, num_heads, 1, 1)
290 | if softmax_mask is not None:
291 | softmax_mask = softmax_mask.unsqueeze(1).repeat(1, num_heads, 1, 1)
292 | # scaled dot product attention
293 |         # key.size(-1) == dim_per_head, so this is the standard 1/sqrt(d_k) scaling
294 | scale = key.size(-1)**-0.5
295 | context, attention = self.dot_product_attention(
296 | query, key, value, scale, attn_mask, softmax_mask)
297 |
298 | # concat heads
299 | context = context.transpose(1, 2).contiguous().view(
300 | batch_size, -1, dim_per_head * num_heads)
301 |
302 | # final linear projection
303 | output = self.linear_final(context)
304 |
305 | # dropout
306 | output = self.dropout(output)
307 |
308 | # add residual and norm layer
309 | output = self.layer_norm(residual + output)
310 |
311 | return output, attention
312 |
313 |
314 | class SelfTransformerLayer(nn.Module):
315 |
316 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
317 | super().__init__()
318 |
319 | self.transformer = MaskedMultiHeadAttention(
320 | model_dim, num_heads, dropout)
321 | self.feed_forward = MaskedPositionalWiseFeedForward(
322 | model_dim, ffn_dim, dropout)
323 |
324 | def forward(self, input, attn_mask=None, sf_mask=None):
325 | output, attention = self.transformer(
326 | input, input, input, attn_mask, sf_mask)
327 | # feed forward network
328 | output = self.feed_forward(output)
329 |
330 | return output, attention
331 |
332 |
333 | class SelfTransformer(nn.Module):
334 |
335 | def __init__(
336 | self,
337 | max_len=35,
338 | num_layers=2,
339 | model_dim=512,
340 | num_heads=8,
341 | ffn_dim=2048,
342 | dropout=0.0,
343 | position=False):
344 | super().__init__()
345 |
346 | self.position = position
347 |
348 | self.encoder_layers = nn.ModuleList(
349 | [
350 | SelfTransformerLayer(model_dim, num_heads, ffn_dim, dropout)
351 | for _ in range(num_layers)
352 | ])
353 |
354 | # max_seq_len is 35 or 80
355 | self.pos_embedding = PositionalEncoding(model_dim, max_len)
356 |
357 | def forward(self, input, input_length):
358 | # q_length of shape (batch, ), each item is the length of the seq
359 | if self.position:
360 | input += self.pos_embedding(input_length)[:, :input.size()[1], :]
361 |
362 | attention_mask = padding_mask_k(input, input)
363 | softmax_mask = padding_mask_q(input, input)
364 |
365 | attentions = []
366 | for encoder in self.encoder_layers:
367 | input, attention = encoder(input, attention_mask, softmax_mask)
368 | attentions.append(attention)
369 |
370 | return input, attentions
371 |
372 |
373 | class SelfAttentionLayer(nn.Module):
374 |
375 | def __init__(self, hidden_size, dropout_p=0.0):
376 | super().__init__()
377 | self.dropout = nn.Dropout(dropout_p)
378 | self.softmax = nn.Softmax(dim=-1)
379 |
380 | self.linear_k = nlpnn.WeightDropLinear(
381 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
382 | self.linear_q = nlpnn.WeightDropLinear(
383 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
384 | self.linear_v = nlpnn.WeightDropLinear(
385 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
386 |
387 | self.linear_final = nlpnn.WeightDropLinear(
388 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
389 |
390 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
391 |
392 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None):
393 | """
394 | Args:
395 | q: [B, L_q, D_q]
396 | k: [B, L_k, D_k]
397 | v: [B, L_v, D_v]
398 | """
399 | residual = q
400 |
401 | if attn_mask is None or softmax_mask is None:
402 | attn_mask = padding_mask_k(q, k)
403 | softmax_mask = padding_mask_q(q, k)
404 |
405 | # linear projection
406 | k = self.linear_k(k)
407 | v = self.linear_v(v)
408 | q = self.linear_q(q)
409 |
410 | scale = k.size(-1)**-0.5
411 |
412 | attention = torch.bmm(q, k.transpose(1, 2))
413 | if scale is not None:
414 | attention = attention * scale
415 | if attn_mask is not None:
416 | attention = attention.masked_fill(attn_mask, -np.inf)
417 | attention = self.softmax(attention)
418 | attention = attention.masked_fill(softmax_mask, 0.)
419 |
420 | # attention = self.dropout(attention)
421 | output = torch.bmm(attention, v)
422 | output = self.linear_final(output)
423 | output = self.layer_norm(output + residual)
424 | return output, attention
425 |
426 |
427 | class SelfAttention(nn.Module):
428 |
429 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
430 | super().__init__()
431 |
432 | self.encoder_layers = nn.ModuleList(
433 | [
434 | SelfAttentionLayer(hidden_size, dropout_p)
435 | for _ in range(n_layers)
436 | ])
437 |
438 | def forward(self, input):
439 |
440 | # q_attention_mask of shape (bs, q_len, v_len)
441 | attn_mask = padding_mask_k(input, input)
442 | # v_attention_mask of shape (bs, v_len, q_len)
443 | softmax_mask = padding_mask_q(input, input)
444 |
445 | attentions = []
446 | for encoder in self.encoder_layers:
447 | input, attention = encoder(
448 | input,
449 | input,
450 | input,
451 | attn_mask=attn_mask,
452 | softmax_mask=softmax_mask)
453 | attentions.append(attention)
454 |
455 | return input, attentions
456 |
457 |
458 | class CoAttentionLayer(nn.Module):
459 |
460 | def __init__(self, hidden_size, dropout_p=0.0):
461 | super().__init__()
462 | self.dropout = nn.Dropout(dropout_p)
463 | self.softmax = nn.Softmax(dim=-1)
464 |
465 | self.linear_question = nlpnn.WeightDropLinear(
466 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
467 | self.linear_video = nlpnn.WeightDropLinear(
468 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
469 | self.linear_v_question = nlpnn.WeightDropLinear(
470 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
471 | self.linear_v_video = nlpnn.WeightDropLinear(
472 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
473 |
474 | self.linear_final_qv = nlpnn.WeightDropLinear(
475 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
476 | self.linear_final_vq = nlpnn.WeightDropLinear(
477 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
478 |
479 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False)
480 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False)
481 |
482 | def forward(
483 | self,
484 | question,
485 | video,
486 | scale=None,
487 | attn_mask=None,
488 | softmax_mask=None,
489 | attn_mask_=None,
490 | softmax_mask_=None):
491 | """
492 | Args:
493 | q: [B, L_q, D_q]
494 | k: [B, L_k, D_k]
495 | v: [B, L_v, D_v]
496 | """
497 | q = question
498 | v = video
499 |
500 | if attn_mask is None or softmax_mask is None:
501 | attn_mask = padding_mask_k(question, video)
502 | softmax_mask = padding_mask_q(question, video)
503 | if attn_mask_ is None or softmax_mask_ is None:
504 | attn_mask_ = padding_mask_k(video, question)
505 | softmax_mask_ = padding_mask_q(video, question)
506 |
507 | # linear projection
508 | question_q = self.linear_question(question)
509 | video_k = self.linear_video(video)
510 | question = self.linear_v_question(question)
511 | video = self.linear_v_video(video)
512 |
513 | scale = video.size(-1)**-0.5
514 |
515 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2))
516 | if scale is not None:
517 | attention_qv = attention_qv * scale
518 | if attn_mask is not None:
519 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf)
520 | attention_qv = self.softmax(attention_qv)
521 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.)
522 |
523 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2))
524 | if scale is not None:
525 | attention_vq = attention_vq * scale
526 | if attn_mask_ is not None:
527 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf)
528 | attention_vq = self.softmax(attention_vq)
529 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.)
530 |
531 | # attention = self.dropout(attention)
532 | output_qv = torch.bmm(attention_qv, video)
533 | output_qv = self.linear_final_qv(output_qv)
534 | output_q = self.layer_norm_qv(output_qv + q)
535 |
536 | output_vq = torch.bmm(attention_vq, question)
537 | output_vq = self.linear_final_vq(output_vq)
538 | output_v = self.layer_norm_vq(output_vq + v)
539 | return output_q, output_v
540 |
541 |
542 | class CoAttention(nn.Module):
543 |
544 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
545 | super().__init__()
546 |
547 | self.encoder_layers = nn.ModuleList(
548 | [CoAttentionLayer(hidden_size, dropout_p) for _ in range(n_layers)])
549 |
550 | def forward(self, question, video):
551 | attn_mask = padding_mask_k(question, video)
552 | softmax_mask = padding_mask_q(question, video)
553 | attn_mask_ = padding_mask_k(video, question)
554 | softmax_mask_ = padding_mask_q(video, question)
555 |
556 | for encoder in self.encoder_layers:
557 | question, video = encoder(
558 | question,
559 | video,
560 | attn_mask=attn_mask,
561 | softmax_mask=softmax_mask,
562 | attn_mask_=attn_mask_,
563 | softmax_mask_=softmax_mask_)
564 |
565 | return question, video
566 |
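# A minimal sketch of the question-video co-attention wrapper above (hypothetical sizes);
# both streams keep their own lengths and are returned after attending to each other.
def _co_attention_sketch():
    import torch
    coatt = CoAttention(hidden_size=256, n_layers=1)
    question = torch.randn(2, 10, 256)
    video = torch.randn(2, 16, 256)
    q_out, v_out = coatt(question, video)   # (2, 10, 256), (2, 16, 256)
    return q_out.shape, v_out.shape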
567 |
568 | class CoConcatAttentionLayer(nn.Module):
569 |
570 | def __init__(self, hidden_size, dropout_p=0.0):
571 | super().__init__()
572 | self.dropout = nn.Dropout(dropout_p)
573 | self.softmax = nn.Softmax(dim=-1)
574 |
575 | self.linear_question = nlpnn.WeightDropLinear(
576 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
577 | self.linear_video = nlpnn.WeightDropLinear(
578 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
579 | self.linear_v_question = nlpnn.WeightDropLinear(
580 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
581 | self.linear_v_video = nlpnn.WeightDropLinear(
582 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
583 |
584 | self.linear_final_qv = nn.Sequential(
585 | nlpnn.WeightDropLinear(
586 | 2 * hidden_size,
587 | hidden_size,
588 | weight_dropout=dropout_p,
589 | bias=False), nn.ReLU(),
590 | nlpnn.WeightDropLinear(
591 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False))
592 | self.linear_final_vq = nn.Sequential(
593 | nlpnn.WeightDropLinear(
594 | 2 * hidden_size,
595 | hidden_size,
596 | weight_dropout=dropout_p,
597 | bias=False), nn.ReLU(),
598 | nlpnn.WeightDropLinear(
599 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False))
600 |
601 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False)
602 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False)
603 |
604 | def forward(
605 | self,
606 | question,
607 | video,
608 | scale=None,
609 | attn_mask=None,
610 | softmax_mask=None,
611 | attn_mask_=None,
612 | softmax_mask_=None):
613 | """
614 | Args:
615 | q: [B, L_q, D_q]
616 | k: [B, L_k, D_k]
617 | v: [B, L_v, D_v]
618 | """
619 | q = question
620 | v = video
621 |
622 | if attn_mask is None or softmax_mask is None:
623 | attn_mask = padding_mask_k(question, video)
624 | softmax_mask = padding_mask_q(question, video)
625 | if attn_mask_ is None or softmax_mask_ is None:
626 | attn_mask_ = padding_mask_k(video, question)
627 | softmax_mask_ = padding_mask_q(video, question)
628 |
629 | # linear projection
630 | question_q = self.linear_question(question)
631 | video_k = self.linear_video(video)
632 | question = self.linear_v_question(question)
633 | video = self.linear_v_video(video)
634 |
635 | scale = video.size(-1)**-0.5
636 |
637 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2))
638 | if scale is not None:
639 | attention_qv = attention_qv * scale
640 | if attn_mask is not None:
641 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf)
642 | attention_qv = self.softmax(attention_qv)
643 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.)
644 |
645 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2))
646 | if scale is not None:
647 | attention_vq = attention_vq * scale
648 | if attn_mask_ is not None:
649 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf)
650 | attention_vq = self.softmax(attention_vq)
651 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.)
652 |
653 | # attention = self.dropout(attention)
654 | output_qv = torch.bmm(attention_qv, video)
655 | output_qv = self.linear_final_qv(torch.cat((output_qv, q), dim=-1))
656 | # output_q = self.layer_norm_qv(output_qv + q)
657 | output_q = self.layer_norm_qv(output_qv)
658 |
659 | output_vq = torch.bmm(attention_vq, question)
660 | output_vq = self.linear_final_vq(torch.cat((output_vq, v), dim=-1))
661 | # output_v = self.layer_norm_vq(output_vq + v)
662 | output_v = self.layer_norm_vq(output_vq)
663 | return output_q, output_v
664 |
665 |
666 | class CoConcatAttention(nn.Module):
667 |
668 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
669 | super().__init__()
670 |
671 | self.encoder_layers = nn.ModuleList(
672 | [
673 | CoConcatAttentionLayer(hidden_size, dropout_p)
674 | for _ in range(n_layers)
675 | ])
676 |
677 | def forward(self, question, video):
678 | attn_mask = padding_mask_k(question, video)
679 | softmax_mask = padding_mask_q(question, video)
680 | attn_mask_ = padding_mask_k(video, question)
681 | softmax_mask_ = padding_mask_q(video, question)
682 |
683 | for encoder in self.encoder_layers:
684 | question, video = encoder(
685 | question,
686 | video,
687 | attn_mask=attn_mask,
688 | softmax_mask=softmax_mask,
689 | attn_mask_=attn_mask_,
690 | softmax_mask_=softmax_mask_)
691 |
692 | return question, video
693 |
694 |
695 | class CoSiameseAttentionLayer(nn.Module):
696 |
697 | def __init__(self, hidden_size, dropout_p=0.0):
698 | super().__init__()
699 | self.dropout = nn.Dropout(dropout_p)
700 | self.softmax = nn.Softmax(dim=-1)
701 |
702 | self.linear_question = nlpnn.WeightDropLinear(
703 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
704 | self.linear_video = nlpnn.WeightDropLinear(
705 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
706 | self.linear_v_question = nlpnn.WeightDropLinear(
707 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
708 | self.linear_v_video = nlpnn.WeightDropLinear(
709 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
710 |
711 | self.linear_final = nn.Sequential(
712 | nlpnn.WeightDropLinear(
713 | 2 * hidden_size,
714 | hidden_size,
715 | weight_dropout=dropout_p,
716 | bias=False), nn.ReLU(),
717 | nlpnn.WeightDropLinear(
718 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False))
719 |
720 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False)
721 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False)
722 |
723 | def forward(
724 | self,
725 | question,
726 | video,
727 | scale=None,
728 | attn_mask=None,
729 | softmax_mask=None,
730 | attn_mask_=None,
731 | softmax_mask_=None):
732 | """
733 | Args:
734 | q: [B, L_q, D_q]
735 | k: [B, L_k, D_k]
736 | v: [B, L_v, D_v]
737 | """
738 | q = question
739 | v = video
740 |
741 | if attn_mask is None or softmax_mask is None:
742 | attn_mask = padding_mask_k(question, video)
743 | softmax_mask = padding_mask_q(question, video)
744 | if attn_mask_ is None or softmax_mask_ is None:
745 | attn_mask_ = padding_mask_k(video, question)
746 | softmax_mask_ = padding_mask_q(video, question)
747 |
748 | # linear projection
749 | question_q = self.linear_question(question)
750 | video_k = self.linear_video(video)
751 | question = self.linear_v_question(question)
752 | video = self.linear_v_video(video)
753 |
754 | scale = video.size(-1)**-0.5
755 |
756 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2))
757 | if scale is not None:
758 | attention_qv = attention_qv * scale
759 | if attn_mask is not None:
760 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf)
761 | attention_qv = self.softmax(attention_qv)
762 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.)
763 |
764 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2))
765 | if scale is not None:
766 | attention_vq = attention_vq * scale
767 | if attn_mask_ is not None:
768 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf)
769 | attention_vq = self.softmax(attention_vq)
770 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.)
771 |
772 | # attention = self.dropout(attention)
773 | output_qv = torch.bmm(attention_qv, video)
774 | output_qv = self.linear_final(torch.cat((output_qv, q), dim=-1))
775 | # output_q = self.layer_norm_qv(output_qv + q)
776 | output_q = self.layer_norm_qv(output_qv)
777 |
778 | output_vq = torch.bmm(attention_vq, question)
779 | output_vq = self.linear_final(torch.cat((output_vq, v), dim=-1))
780 | # output_v = self.layer_norm_vq(output_vq + v)
781 | output_v = self.layer_norm_vq(output_vq)
782 | return output_q, output_v
783 |
784 |
785 | class CoSiameseAttention(nn.Module):
786 |
787 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
788 | super().__init__()
789 |
790 | self.encoder_layers = nn.ModuleList(
791 | [
792 | CoSiameseAttentionLayer(hidden_size, dropout_p)
793 | for _ in range(n_layers)
794 | ])
795 |
796 | def forward(self, question, video):
797 | attn_mask = padding_mask_k(question, video)
798 | softmax_mask = padding_mask_q(question, video)
799 | attn_mask_ = padding_mask_k(video, question)
800 | softmax_mask_ = padding_mask_q(video, question)
801 |
802 | for encoder in self.encoder_layers:
803 | question, video = encoder(
804 | question,
805 | video,
806 | attn_mask=attn_mask,
807 | softmax_mask=softmax_mask,
808 | attn_mask_=attn_mask_,
809 | softmax_mask_=softmax_mask_)
810 |
811 | return question, video
812 |
813 |
814 | class SingleAttentionLayer(nn.Module):
815 |
816 | def __init__(self, hidden_size, dropout_p=0.0):
817 | super().__init__()
818 | self.dropout = nn.Dropout(dropout_p)
819 | self.softmax = nn.Softmax(dim=-1)
820 |
821 | self.linear_q = nlpnn.WeightDropLinear(
822 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
823 | self.linear_v = nlpnn.WeightDropLinear(
824 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
825 | self.linear_k = nlpnn.WeightDropLinear(
826 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
827 |
828 | self.linear_final = nlpnn.WeightDropLinear(
829 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
830 |
831 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
832 |
833 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None):
834 | """
835 | Args:
836 | q: [B, L_q, D_q]
837 | k: [B, L_k, D_k]
838 | v: [B, L_v, D_v]
839 |         Return: same shape as q, but expressed in the 'v' feature space (a soft k-NN lookup)
840 | """
841 |
842 | if attn_mask is None or softmax_mask is None:
843 | attn_mask = padding_mask_k(q, k)
844 | softmax_mask = padding_mask_q(q, k)
845 |
846 | # linear projection
847 | q = self.linear_q(q)
848 | k = self.linear_k(k)
849 | v = self.linear_v(v)
850 |
851 | scale = v.size(-1)**-0.5
852 |
853 | attention = torch.bmm(q, k.transpose(-2, -1))
854 | if scale is not None:
855 | attention = attention * scale
856 | if attn_mask is not None:
857 | attention = attention.masked_fill(attn_mask, -np.inf)
858 | attention = self.softmax(attention)
859 | attention = attention.masked_fill(softmax_mask, 0.)
860 |
861 | # attention = self.dropout(attention)
862 | output = torch.bmm(attention, v)
863 | output = self.linear_final(output)
864 | output = self.layer_norm(output + q)
865 |
866 | return output
867 |
868 |
869 | class SingleAttention(nn.Module):
870 |
871 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
872 | super().__init__()
873 |
874 | self.encoder_layers = nn.ModuleList(
875 | [
876 | SingleAttentionLayer(hidden_size, dropout_p)
877 | for _ in range(n_layers)
878 | ])
879 |
880 | def forward(self, q, v):
881 | attn_mask = padding_mask_k(q, v)
882 | softmax_mask = padding_mask_q(q, v)
883 |
884 | for encoder in self.encoder_layers:
885 | q = encoder(q, v, v, attn_mask=attn_mask, softmax_mask=softmax_mask)
886 |
887 | return q
888 |
909 | class SingleSimpleAttentionLayer(nn.Module):
910 |
911 | def __init__(self, hidden_size, dropout_p=0.0):
912 | super().__init__()
913 | self.dropout = nn.Dropout(dropout_p)
914 | self.softmax = nn.Softmax(dim=-1)
915 |
916 | self.linear_final = nlpnn.WeightDropLinear(
917 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
918 |
919 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
920 |
921 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None):
922 | """
923 | Args:
924 | q: [B, L_q, D_q]
925 | k: [B, L_k, D_k]
926 | v: [B, L_v, D_v]
927 |         Return: same shape as q, but expressed in the 'v' feature space (a soft k-NN lookup)
928 | """
929 |
930 | if attn_mask is None or softmax_mask is None:
931 | attn_mask = padding_mask_k(q, k)
932 | softmax_mask = padding_mask_q(q, k)
933 |
934 |         # no q/k/v projection in this simplified layer; attention is computed on the inputs directly
935 |
936 | scale = v.size(-1)**-0.5
937 |
938 | attention = torch.bmm(q, k.transpose(-2, -1))
939 | if scale is not None:
940 | attention = attention * scale
941 | if attn_mask is not None:
942 | attention = attention.masked_fill(attn_mask, -np.inf)
943 | attention = self.softmax(attention)
944 | attention = attention.masked_fill(softmax_mask, 0.)
945 |
946 | # attention = self.dropout(attention)
947 | output = torch.bmm(attention, v)
948 | output = self.linear_final(output)
949 | output = self.layer_norm(output + q)
950 |
951 | return output
952 |
953 |
954 | class SingleSimpleAttention(nn.Module):
955 |
956 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
957 | super().__init__()
958 |
959 | self.encoder_layers = nn.ModuleList(
960 | [
961 | SingleSimpleAttentionLayer(hidden_size, dropout_p)
962 | for _ in range(n_layers)
963 | ])
964 |
965 | def forward(self, q, v):
966 | attn_mask = padding_mask_k(q, v)
967 | softmax_mask = padding_mask_q(q, v)
968 |
969 | for encoder in self.encoder_layers:
970 | q = encoder(q, v, v, attn_mask=attn_mask, softmax_mask=softmax_mask)
971 |
972 | return q
973 |
974 | class SoftKNN(nn.Module):
975 |
976 | def __init__(self, model_dim=512, num_heads=1, dropout=0.0):
977 | super().__init__()
978 |
979 | self.dim_per_head = model_dim // num_heads
980 | self.num_heads = num_heads
981 | self.linear_k = nn.Linear(
982 | model_dim, self.dim_per_head * num_heads, bias=False)
983 | self.linear_v = nn.Linear(
984 | model_dim, self.dim_per_head * num_heads, bias=False)
985 | self.linear_q = nn.Linear(
986 | model_dim, self.dim_per_head * num_heads, bias=False)
987 |
988 | self.dot_product_attention = ScaledDotProductAttention(dropout)
989 |
990 | def forward(self, query, key, value, attn_mask=None):
991 |
992 | dim_per_head = self.dim_per_head
993 | num_heads = self.num_heads
994 | batch_size = key.size(0)
995 |
996 | # linear projection
997 | key = self.linear_k(key)
998 | value = self.linear_v(value)
999 | query = self.linear_q(query)
1000 |
1001 | # split by heads
1002 | key = key.view(batch_size * num_heads, -1, dim_per_head)
1003 | value = value.view(batch_size * num_heads, -1, dim_per_head)
1004 | query = query.view(batch_size * num_heads, -1, dim_per_head)
1005 |
1006 | if attn_mask is not None:
1007 | attn_mask = attn_mask.repeat(num_heads, 1, 1)
1008 | # scaled dot product attention
1009 | scale = (key.size(-1) // num_heads)**-0.5
1010 | context, attention = self.dot_product_attention(
1011 | query, key, value, scale, attn_mask)
1012 |
1013 | # concat heads
1014 | output = context.view(batch_size, -1, dim_per_head * num_heads)
1015 |
1016 | return output, attention
1017 |
1018 |
1019 | class CrossoverTransformerLayer(nn.Module):
1020 |
1021 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
1022 | super().__init__()
1023 |
1024 | self.v_transformer = MultiHeadAttention(model_dim, num_heads, dropout)
1025 | self.q_transformer = MultiHeadAttention(model_dim, num_heads, dropout)
1026 | self.v_feed_forward = PositionalWiseFeedForward(
1027 | model_dim, ffn_dim, dropout)
1028 | self.q_feed_forward = PositionalWiseFeedForward(
1029 | model_dim, ffn_dim, dropout)
1030 |
1031 | def forward(self, question, video, q_mask=None, v_mask=None):
1032 | # self attention, v_attention of shape (bs, v_len, q_len)
1033 | video_, v_attention = self.v_transformer(
1034 | video, question, question, v_mask)
1035 | # feed forward network
1036 | video_ = self.v_feed_forward(video_)
1037 |
1038 | # self attention, q_attention of shape (bs, q_len, v_len)
1039 | question_, q_attention = self.q_transformer(
1040 | question, video, video, q_mask)
1041 | # feed forward network
1042 | question_ = self.q_feed_forward(question_)
1043 |
1044 | return video_, question_, v_attention, q_attention
1045 |
1046 |
1047 | class CrossoverTransformer(nn.Module):
1048 |
1049 | def __init__(
1050 | self,
1051 | q_max_len=35,
1052 | v_max_len=80,
1053 | num_layers=2,
1054 | model_dim=512,
1055 | num_heads=8,
1056 | ffn_dim=2048,
1057 | dropout=0.0):
1058 | super().__init__()
1059 |
1060 | self.encoder_layers = nn.ModuleList(
1061 | [
1062 | CrossoverTransformerLayer(
1063 | model_dim, num_heads, ffn_dim, dropout)
1064 | for _ in range(num_layers)
1065 | ])
1066 |
1067 | # max_seq_len is 35 or 80
1068 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len)
1069 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len)
1070 |
1071 | def forward(self, question, video, q_length, v_length):
1072 | # q_length of shape (batch, ), each item is the length of the seq
1073 | question += self.q_pos_embedding(q_length)[:, :question.size()[1], :]
1074 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :]
1075 |
1076 | # q_attention_mask of shape (bs, q_len, v_len)
1077 | q_attention_mask = padding_mask_k(question, video)
1078 | # v_attention_mask of shape (bs, v_len, q_len)
1079 | v_attention_mask = padding_mask_k(video, question)
1080 |
1081 | q_attentions = []
1082 | v_attentions = []
1083 | for encoder in self.encoder_layers:
1084 | video, question, v_attention, q_attention = encoder(
1085 | question, video, q_attention_mask, v_attention_mask)
1086 | q_attentions.append(q_attention)
1087 | v_attentions.append(v_attention)
1088 |
1089 | return question, video, q_attentions, v_attentions
1090 |
1091 |
1092 | class MaskedCrossoverTransformerLayer(nn.Module):
1093 |
1094 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
1095 | super().__init__()
1096 |
1097 | self.v_transformer = MaskedMultiHeadAttention(
1098 | model_dim, num_heads, dropout)
1099 | self.q_transformer = MaskedMultiHeadAttention(
1100 | model_dim, num_heads, dropout)
1101 | self.v_feed_forward = MaskedPositionalWiseFeedForward(
1102 | model_dim, ffn_dim, dropout)
1103 | self.q_feed_forward = MaskedPositionalWiseFeedForward(
1104 | model_dim, ffn_dim, dropout)
1105 |
1106 | def forward(
1107 | self,
1108 | question,
1109 | video,
1110 | q_mask=None,
1111 | v_mask=None,
1112 | q_sf_mask=None,
1113 | v_sf_mask=None):
1114 |         # cross-modal attention: video attends to question; v_attention of shape (bs, v_len, q_len)
1115 | video_, v_attention = self.v_transformer(
1116 | video, question, question, v_mask, v_sf_mask)
1117 | # feed forward network
1118 | video_ = self.v_feed_forward(video_)
1119 |
1120 |         # cross-modal attention: question attends to video; q_attention of shape (bs, q_len, v_len)
1121 | question_, q_attention = self.q_transformer(
1122 | question, video, video, q_mask, q_sf_mask)
1123 | # feed forward network
1124 | question_ = self.q_feed_forward(question_)
1125 |
1126 | return video_, question_, v_attention, q_attention
1127 |
1128 |
1129 | class MaskedCrossoverTransformer(nn.Module):
1130 |
1131 | def __init__(
1132 | self,
1133 | q_max_len=35,
1134 | v_max_len=80,
1135 | num_layers=2,
1136 | model_dim=512,
1137 | num_heads=8,
1138 | ffn_dim=2048,
1139 | dropout=0.0,
1140 | position=False):
1141 | super().__init__()
1142 |
1143 | self.position = position
1144 |
1145 | self.encoder_layers = nn.ModuleList(
1146 | [
1147 | MaskedCrossoverTransformerLayer(
1148 | model_dim, num_heads, ffn_dim, dropout)
1149 | for _ in range(num_layers)
1150 | ])
1151 |
1152 | # max_seq_len is 35 or 80
1153 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len)
1154 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len)
1155 |
1156 | def forward(self, question, video, q_length, v_length):
1157 | # q_length of shape (batch, ), each item is the length of the seq
1158 | if self.position:
1159 | question += self.q_pos_embedding(
1160 | q_length)[:, :question.size()[1], :]
1161 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :]
1162 |
1163 | q_attention_mask = padding_mask_k(question, video)
1164 | q_softmax_mask = padding_mask_q(question, video)
1165 | v_attention_mask = padding_mask_k(video, question)
1166 | v_softmax_mask = padding_mask_q(video, question)
1167 |
1168 | q_attentions = []
1169 | v_attentions = []
1170 | for encoder in self.encoder_layers:
1171 | video, question, v_attention, q_attention = encoder(
1172 | question, video, q_attention_mask, v_attention_mask,
1173 | q_softmax_mask, v_softmax_mask)
1174 | q_attentions.append(q_attention)
1175 | v_attentions.append(v_attention)
1176 |
1177 | return question, video, q_attentions, v_attentions
1178 |
1179 |
1180 | class SelfTransformerEncoder(nn.Module):
1181 |
1182 | def __init__(
1183 | self,
1184 | hidden_size,
1185 | n_layers,
1186 | dropout_p,
1187 | vocab_size,
1188 | q_max_len,
1189 | v_max_len,
1190 | embedding=None,
1191 | update_embedding=True,
1192 | position=True):
1193 | super().__init__()
1194 | self.dropout = nn.Dropout(p=dropout_p)
1195 | self.ln_q = nn.LayerNorm(hidden_size, elementwise_affine=False)
1196 | self.ln_v = nn.LayerNorm(hidden_size, elementwise_affine=False)
1197 | self.n_layers = n_layers
1198 | self.position = position
1199 |
1200 | embedding_dim = embedding.shape[
1201 | 1] if embedding is not None else hidden_size
1202 | self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
1203 |
1204 | # ! no embedding init
1205 | # if embedding is not None:
1206 | # # self.embedding.weight.data.copy_(torch.from_numpy(embedding))
1207 | # self.embedding.weight = nn.Parameter(
1208 | # torch.from_numpy(embedding).float())
1209 | self.upcompress_embedding = nlpnn.WeightDropLinear(
1210 | embedding_dim, hidden_size, weight_dropout=dropout_p, bias=False)
1211 | self.embedding.weight.requires_grad = update_embedding
1212 |
1213 | self.project_c3d = nlpnn.WeightDropLinear(4096, 2048, bias=False)
1214 |
1215 | self.project_resnet_and_c3d = nlpnn.WeightDropLinear(
1216 | 4096, hidden_size, weight_dropout=dropout_p, bias=False)
1217 |
1218 | # max_seq_len is 35 or 80
1219 | self.q_pos_embedding = PositionalEncoding(hidden_size, q_max_len)
1220 | self.v_pos_embedding = PositionalEncoding(hidden_size, v_max_len)
1221 |
1222 | def forward(self, question, resnet, c3d, q_length, v_length):
1223 | ### question
1224 | embedded = self.embedding(question)
1225 | embedded = self.dropout(embedded)
1226 | question = F.relu(self.upcompress_embedding(embedded))
1227 |
1228 | ### video
1229 | # ! no relu
1230 | c3d = self.project_c3d(c3d)
1231 | video = F.relu(
1232 | self.project_resnet_and_c3d(torch.cat((resnet, c3d), dim=2)))
1233 |
1234 | ### position encoding
1235 | if self.position:
1236 | question += self.q_pos_embedding(
1237 | q_length)[:, :question.size()[1], :]
1238 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :]
1239 |
1240 | # question = self.ln_q(question)
1241 | # video = self.ln_v(video)
1242 | return question, video
1243 |
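1244 | # Minimal usage sketch (illustrative only): it assumes the question/video features
1245 | # are already projected to `model_dim` and that the attention, feed-forward and
1246 | # PositionalEncoding modules defined earlier in this file behave as used above.
1247 | if __name__ == '__main__':
1248 |     xformer = CrossoverTransformer(q_max_len=35, v_max_len=80, num_layers=2, model_dim=512)
1249 |     q = torch.randn(4, 20, 512)             # padded question features (bs, q_len, model_dim)
1250 |     v = torch.randn(4, 60, 512)             # padded video features (bs, v_len, model_dim)
1251 |     q_len = torch.tensor([20, 18, 15, 12])  # true lengths per sample
1252 |     v_len = torch.tensor([60, 60, 55, 40])
1253 |     q_out, v_out, q_attns, v_attns = xformer(q, v, q_len, v_len)
1254 |     print(q_out.shape, v_out.shape)         # -> (4, 20, 512), (4, 60, 512)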
--------------------------------------------------------------------------------
/networks/VQAModel/B2A.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.insert(0, 'networks')
5 | from networks.Transformer import SingleSimpleAttention
6 | from networks.GCN import AdjGenerator, GCN
7 | import EncoderRNN
8 | from block import fusions #pytorch >= 1.1.0
9 |
10 |
11 | class B2A(nn.Module):
12 | def __init__(self, vid_encoder, qns_encoder, device, gcn_layer=1):
13 | """
14 | Bridge to Answer: Structure-aware Graph Interaction Network for Video Question Answering (CVPR 2021)
15 | """
16 | super(B2A, self).__init__()
17 | self.vid_encoder = vid_encoder
18 | self.qns_encoder = qns_encoder
19 | self.device = device
20 | hidden_size = qns_encoder.dim_hidden
21 | input_dropout_p = vid_encoder.input_dropout_p
22 |
23 | self.q_input_ln = nn.LayerNorm(hidden_size*2, elementwise_affine=False)
24 | self.v_input_ln = nn.LayerNorm(hidden_size*2, elementwise_affine=False)
25 |
26 | self.co_attn_t2m_qv = SingleSimpleAttention(
27 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
28 | self.co_attn_t2a_qv = SingleSimpleAttention(
29 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
30 | self.co_attn_a2t_vv = SingleSimpleAttention(
31 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
32 | self.co_attn_t2m_vv = SingleSimpleAttention(
33 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
34 | self.co_attn_m2t_vv = SingleSimpleAttention(
35 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
36 | self.co_attn_t2a_vv = SingleSimpleAttention(
37 | hidden_size*2, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
38 |
39 | self.adj_generator = AdjGenerator(hidden_size*2, hidden_size*2)
40 |
41 | self.gcn_v = GCN(
42 | hidden_size*2,
43 | hidden_size*2,
44 | hidden_size*2,
45 | num_layers=gcn_layer,
46 | dropout=input_dropout_p)
47 |
48 | self.gcn_t = GCN(
49 | hidden_size*2,
50 | hidden_size*2,
51 | hidden_size*2,
52 | num_layers=gcn_layer,
53 | dropout=input_dropout_p)
54 |
55 | self.output_layer = OutputUnitMultiChoices(hidden_size*2)
56 |
57 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q):
58 | """
59 | Args:
60 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim)
61 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim)
62 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
63 | candidates_len: [Tensor] (batch_size, 5)
64 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim)
65 | dep_adj: [Tensor] (batch_size, max_length, max_length)
66 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
67 | question_len: [Tensor] (batch_size, 5)
68 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim)
69 | dep_adj_q: [Tensor] (batch_size, max_length, max_length)
70 | return:
71 | logits, predict_idx
72 | """
73 | if self.qns_encoder.use_bert:
74 | candidates = candidates.permute(1, 0, 2, 3)
75 | else:
76 | candidates = candidates.permute(1, 0, 2)
77 |
78 | obj_feature = obj_feature.permute(1, 0, 2, 3)
79 | cand_len = candidates_len.permute(1, 0)
80 |
81 | app_output, mot_output = self.vid_encoder(video_appearance_feat, video_motion_feat)
82 | app_output = self.v_input_ln(app_output)
83 | mot_output = self.v_input_ln(mot_output)
84 |
85 | ques_output, ques_hidden = self.qns_encoder(question, question_len, obj=obj_fea_q)
86 | ques_output = ques_output.reshape(ques_output.shape[0], ques_output.shape[1], -1)
87 | ques_output = self.q_input_ln(ques_output)
88 | ques_hidden = ques_hidden.permute(1, 0, 2).reshape(ques_output.shape[0], -1)
89 |
90 | q_v_emb = self.q2v_v2v(app_output, mot_output, ques_output, dep_adj_q)
91 |
92 |
93 | out = []
94 | for idx, qas in enumerate(candidates):
95 | qas_output, qas_hidden = self.qns_encoder(qas, cand_len[idx], obj=obj_feature[idx])
96 | qas_output = qas_output.reshape(qas_output.shape[0], qas_output.shape[1], -1)
97 | qas_output = self.q_input_ln(qas_output)
98 | qas_hidden = qas_hidden.permute(1, 0, 2).reshape(qas_output.shape[0], -1)
99 | qa_v_emb = self.q2v_v2v(app_output, mot_output, qas_output, dep_adj[:, idx])
100 |
101 | final_output = self.output_layer(q_v_emb, qa_v_emb, ques_hidden, qas_hidden)
102 | out.append(final_output)
103 |
104 | out = torch.stack(out, 0).transpose(1, 0).squeeze()
105 | _, predict_idx = torch.max(out, 1)
106 |
107 | return out, predict_idx
108 |
109 | def q2v_v2v(self, app_feat, mot_feat, txt_feat, txt_cont=None):
110 | app_adj = self.adj_generator(app_feat)
111 | mot_adj = self.adj_generator(mot_feat)
112 | txt_adj = self.adj_generator(txt_feat, adjacency=txt_cont)
113 |
114 | # question-to-visual
115 | app_hat = self.gcn_v(self.co_attn_t2a_qv(app_feat, txt_feat), app_adj)
116 | mot_hat = self.gcn_v(self.co_attn_t2m_qv(mot_feat, txt_feat), mot_adj)
117 |
118 | # visual-to-visual
119 | txt_a2t = self.gcn_t(self.co_attn_a2t_vv(txt_feat, app_hat), txt_adj) + txt_feat
120 | txt_m2t = self.gcn_t(self.co_attn_m2t_vv(txt_feat, mot_hat), txt_adj) + txt_feat
121 | app_v2v = self.co_attn_t2a_vv(app_hat, txt_m2t)
122 |         mot_v2v = self.co_attn_t2m_vv(mot_hat, txt_a2t)  # motion branch uses its own t2m attention (t2a above is the appearance branch)
123 |
124 | return torch.cat([app_v2v.mean(dim=1), mot_v2v.mean(dim=1)], dim=-1)
125 |
126 | class OutputUnitMultiChoices(nn.Module):
127 | def __init__(self, module_dim=512):
128 | super(OutputUnitMultiChoices, self).__init__()
129 |
130 | self.question_proj = nn.Linear(module_dim, module_dim)
131 |
132 | self.ans_candidates_proj = nn.Linear(module_dim, module_dim)
133 |
134 | self.v_question_proj = nn.Linear(module_dim*2, module_dim)
135 |
136 | self.v_ans_candidates_proj = nn.Linear(module_dim*2, module_dim)
137 |
138 | self.classifier = nn.Sequential(nn.Dropout(0.15),
139 | nn.Linear(module_dim * 4, module_dim),
140 | nn.ELU(),
141 | nn.BatchNorm1d(module_dim),
142 | nn.Dropout(0.15),
143 | nn.Linear(module_dim, 1))
144 |
145 | def forward(self, q_visual_embedding, a_visual_embedding, question_embedding, ans_candidates_embedding):
146 | q_visual_embedding = self.v_question_proj(q_visual_embedding)
147 | a_visual_embedding = self.v_ans_candidates_proj(a_visual_embedding)
148 | question_embedding = self.question_proj(question_embedding)
149 | ans_candidates_embedding = self.ans_candidates_proj(ans_candidates_embedding)
150 | out = torch.cat([q_visual_embedding, question_embedding, a_visual_embedding, ans_candidates_embedding], 1)
151 | out = self.classifier(out)
152 |
153 | return out
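154 | 
155 | # Illustrative sketch of the forward() interface shared by the VQAModel classes, with
156 | # dummy tensors following the docstring above; the visual feature size (2048) and text
157 | # embedding size (768) are assumptions, and the video/question encoders (built elsewhere,
158 | # e.g. from networks/EncoderRNN.py) are not constructed here.
159 | if __name__ == '__main__':
160 |     bs, num_clips, num_frames, max_len, emb_dim = 2, 8, 16, 37, 768
161 |     video_appearance_feat = torch.randn(bs, num_clips, num_frames, 2048)
162 |     video_motion_feat = torch.randn(bs, num_clips, 2048)
163 |     candidates = torch.randn(bs, 5, max_len, emb_dim)       # BERT-style candidate embeddings
164 |     candidates_len = torch.randint(5, max_len, (bs, 5))
165 |     obj_feature = torch.randn(bs, 5, max_len, emb_dim)
166 |     dep_adj = torch.randn(bs, 5, max_len, max_len)          # one adjacency per candidate (indexed dep_adj[:, idx] above)
167 |     question = torch.randn(bs, 5, max_len, emb_dim)
168 |     question_len = torch.randint(5, max_len, (bs, 5))
169 |     obj_fea_q = torch.randn(bs, 5, max_len, emb_dim)
170 |     dep_adj_q = torch.randn(bs, max_len, max_len)
171 |     # model = B2A(vid_encoder, qns_encoder, device)  # encoders come from networks/EncoderRNN.py
172 |     # logits, predict_idx = model(video_appearance_feat, video_motion_feat, candidates, candidates_len,
173 |     #                             obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q)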
--------------------------------------------------------------------------------
/networks/VQAModel/CoMem.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.insert(0, 'networks')
5 | from memory_module import EpisodicMemory
6 |
7 |
8 | class CoMem(nn.Module):
9 | def __init__(self, vid_encoder, qns_encoder, max_len_v, max_len_q, device):
10 | """
11 | motion-appearance co-memory networks for video question answering (CVPR18)
12 | """
13 | super(CoMem, self).__init__()
14 | self.vid_encoder = vid_encoder
15 | self.qns_encoder = qns_encoder
16 |
17 | dim = qns_encoder.dim_hidden
18 |
19 | self.epm_app = EpisodicMemory(dim*2)
20 | self.epm_mot = EpisodicMemory(dim*2)
21 |
22 | self.linear_ma = nn.Linear(dim*2*4, dim*2)
23 | self.linear_mb = nn.Linear(dim*2*4, dim*2)
24 |
25 | self.vq2word = nn.Linear(dim*2*2, 1)
26 |
27 | self.device = device
28 |
29 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q):
30 | """
31 | Args:
32 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim)
33 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim)
34 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
35 | candidates_len: [Tensor] (batch_size, 5)
36 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim)
37 | dep_adj: [Tensor] (batch_size, max_length, max_length)
38 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
39 | question_len: [Tensor] (batch_size, 5)
40 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim)
41 | dep_adj_q: [Tensor] (batch_size, max_length, max_length)
42 | return:
43 | logits, predict_idx
44 | """
45 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1)
46 | if self.qns_encoder.use_bert:
47 | candidates = candidates.permute(1, 0, 2, 3) # for BERT
48 | else:
49 | candidates = candidates.permute(1, 0, 2)
50 |
51 | obj_feature = obj_feature.permute(1, 0, 2, 3)
52 | candidates_len = candidates_len.permute(1, 0)
53 |
54 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats)
55 | vid_feats = (outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2)
56 |
57 | _, qns_hidden = self.qns_encoder(question, question_len, obj=obj_fea_q)
58 | qas_hidden = list()
59 | for idx, qas in enumerate(candidates):
60 | _, ah_tmp = self.qns_encoder(qas, candidates_len[idx], obj=obj_feature[idx])
61 | qas_hidden.append(ah_tmp)
62 |
63 | out = []
64 | for idx, qas in enumerate(qas_hidden):
65 | encoder_out = self.vq_encoder(vid_feats, qns_hidden, qas)
66 | out.append(encoder_out)
67 |
68 | out = torch.stack(out, 0).transpose(1, 0)
69 |
70 | _, predict_idx = torch.max(out, 1)
71 |
72 |
73 | return out, predict_idx
74 |
75 | def vq_encoder(self, vid_feats, ques, qas, iter_num=3):
76 |
77 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = vid_feats
78 |
79 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1)
80 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1)
81 |
82 | batch_size = qas.shape[1]
83 |
84 | qns_embed = ques.permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
85 | qas_embed = qas.permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
86 |
87 | m_app = outputs_app[:, -1, :]
88 | m_mot = outputs_motion[:, -1, :]
89 | ma, mb = m_app.detach(), m_mot.detach()
90 | m_app = m_app.unsqueeze(1)
91 | m_mot = m_mot.unsqueeze(1)
92 | for _ in range(iter_num):
93 | mm = ma + mb
94 | m_app = self.epm_app(outputs_app, mm, m_app)
95 | m_mot = self.epm_mot(outputs_motion, mm, m_mot)
96 | ma_q = torch.cat((ma, m_app.squeeze(1), qns_embed, qas_embed), dim=1)
97 | mb_q = torch.cat((mb, m_mot.squeeze(1), qns_embed, qas_embed), dim=1)
98 | ma = torch.tanh(self.linear_ma(ma_q))
99 | mb = torch.tanh(self.linear_mb(mb_q))
100 |
101 | mem = torch.cat((ma, mb), dim=1)
102 | outputs = self.vq2word(mem).squeeze()
103 |
104 | return outputs
--------------------------------------------------------------------------------
/networks/VQAModel/EVQA.py:
--------------------------------------------------------------------------------
1 | # from locale import AM_STR  # unused
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class EVQA(nn.Module):
7 | def __init__(self, vid_encoder, qns_encoder, device, blind=False):
8 | super(EVQA, self).__init__()
9 | self.vid_encoder = vid_encoder
10 | self.qns_encoder = qns_encoder
11 | self.device = device
12 | self.blind = blind
13 | self.FC = nn.Linear(qns_encoder.dim_hidden, 1)
14 |
15 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q):
16 | """
17 | Args:
18 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim)
19 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim)
20 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
21 | candidates_len: [Tensor] (batch_size, 5)
22 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim)
23 | dep_adj: [Tensor] (batch_size, max_length, max_length)
24 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
25 | question_len: [Tensor] (batch_size, 5)
26 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim)
27 | dep_adj_q: [Tensor] (batch_size, max_length, max_length)
28 | return:
29 | logits, predict_idx
30 | """
31 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1)
32 | if self.qns_encoder.use_bert:
33 | candidates = candidates.permute(1, 0, 2, 3) # for BERT
34 | else:
35 | candidates = candidates.permute(1, 0, 2)
36 |
37 | obj_feature = obj_feature.permute(1, 0, 2, 3)
38 | if self.blind:
39 | obj_feature[:] = 0
40 | cand_len = candidates_len.permute(1, 0)
41 | out = []
42 | for idx, qnsans in enumerate(candidates):
43 | encoder_out = self.vq_encoder(vid_feats, qnsans, cand_len[idx], question_len, obj_feature[idx])
44 | out.append(encoder_out)
45 |
46 | out = torch.stack(out, 0).transpose(1, 0)
47 |
48 | _, predict_idx = torch.max(out, 1)
49 |
50 | return out, predict_idx
51 |
52 | def vq_encoder(self, vid_feats, qnsans, qnsans_len, qns_len, obj_feature):
53 |
54 | qmask = torch.zeros(qnsans.shape[0], qnsans.shape[1], dtype=qnsans.dtype, device=qnsans.device) # bs, maxlen
55 | amask = torch.zeros(qnsans.shape[0], qnsans.shape[1], dtype=qnsans.dtype, device=qnsans.device) # bs, maxlen
56 |
57 | for idx in range(qmask.shape[0]):
58 | qmask[idx, :qns_len[idx]] = 1
59 | amask[idx, qns_len[idx]:qnsans_len[idx]] = 1
60 |
61 | if len(qnsans.shape) == 2:
62 | qns = qnsans*qmask
63 | ans = qnsans*amask
64 | elif len(qnsans.shape) == 3:
65 | qns = qnsans*qmask.unsqueeze(-1)
66 | ans = qnsans*amask.unsqueeze(-1)
67 |
68 | obj_feature_q = obj_feature*qmask.unsqueeze(-1)
69 | obj_feature_a = obj_feature*amask.unsqueeze(-1)
70 |
71 | _, vid_hidden = self.vid_encoder(vid_feats)
72 | _, qs_hidden = self.qns_encoder(qns, qns_len, obj=obj_feature_q)
73 | _, as_hidden = self.qns_encoder(ans, qnsans_len, obj=obj_feature_a)
74 |
75 | vid_embed = vid_hidden.squeeze()
76 | qs_embed = qs_hidden.squeeze()
77 | as_embed = as_hidden.squeeze()
78 |
79 | if self.blind:
80 | fuse = qs_embed + as_embed
81 | else:
82 | fuse = qs_embed + as_embed + vid_embed
83 |
84 | outputs = self.FC(fuse).squeeze()
85 |
86 | return outputs
--------------------------------------------------------------------------------
/networks/VQAModel/HCRN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.nn import functional as F
3 |
4 | import itertools
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import init
9 | from torch.nn.modules.module import Module
10 |
11 | class HCRN(nn.Module):
12 | def __init__(self, vid_encoder, qns_encoder, device):
13 | super(HCRN, self).__init__()
14 | """
15 | Hierarchical Conditional Relation Networks for Video Question Answering (CVPR2020)
16 | """
17 | self.qns_encoder = qns_encoder
18 | self.vid_encoder = vid_encoder
19 | hidden_size = vid_encoder.dim_hidden
20 | self.feature_aggregation = FeatureAggregation(hidden_size)
21 |
22 | self.output_unit = OutputUnitMultiChoices(module_dim=hidden_size)
23 |
24 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q):
25 | """
26 | Args:
27 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim)
28 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim)
29 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
30 | candidates_len: [Tensor] (batch_size, 5)
31 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim)
32 | dep_adj: [Tensor] (batch_size, max_length, max_length)
33 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
34 | question_len: [Tensor] (batch_size, 5)
35 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim)
36 | dep_adj_q: [Tensor] (batch_size, max_length, max_length)
37 | return:
38 | logits, predict_idx
39 | """
40 | batch_size = candidates.size(0)
41 | if self.qns_encoder.use_bert:
42 | cand = candidates.permute(1, 0, 2, 3) # for BERT
43 | else:
44 | cand = candidates.permute(1, 0, 2)
45 | cand_len = candidates_len.permute(1, 0)
46 | out = list()
47 | _, question_embedding = self.qns_encoder(question, question_len, obj=obj_fea_q)
48 | visual_embedding = self.vid_encoder(video_appearance_feat, video_motion_feat, question_embedding)
49 | q_visual_embedding = self.feature_aggregation(question_embedding, visual_embedding)
50 | for idx, qas in enumerate(cand):
51 | _, qas_embedding = self.qns_encoder(qas, cand_len[idx], obj=obj_feature[:, idx])
52 | qa_visual_embedding = self.feature_aggregation(qas_embedding, visual_embedding)
53 | encoder_out = self.output_unit(q_visual_embedding, question_embedding, qa_visual_embedding, qas_embedding)
54 | out.append(encoder_out)
55 | out = torch.stack(out, 0).transpose(1, 0).squeeze()
56 | _, predict_idx = torch.max(out, 1)
57 | return out, predict_idx
58 |
59 | class FeatureAggregation(nn.Module):
60 | def __init__(self, module_dim=512):
61 | super(FeatureAggregation, self).__init__()
62 | self.module_dim = module_dim
63 |
64 | self.q_proj = nn.Linear(module_dim, module_dim, bias=False)
65 | self.v_proj = nn.Linear(module_dim, module_dim, bias=False)
66 |
67 | self.cat = nn.Linear(2 * module_dim, module_dim)
68 | self.attn = nn.Linear(module_dim, 1)
69 |
70 | self.activation = nn.ELU()
71 | self.dropout = nn.Dropout(0.15)
72 |
73 | def forward(self, question_rep, visual_feat):
74 | visual_feat = self.dropout(visual_feat)
75 | q_proj = self.q_proj(question_rep)
76 | v_proj = self.v_proj(visual_feat)
77 |
78 | v_q_cat = torch.cat((v_proj, q_proj.unsqueeze(1) * v_proj), dim=-1)
79 | v_q_cat = self.cat(v_q_cat)
80 | v_q_cat = self.activation(v_q_cat)
81 |
82 | attn = self.attn(v_q_cat) # (bz, k, 1)
83 | attn = F.softmax(attn, dim=1) # (bz, k, 1)
84 |
85 | v_distill = (attn * visual_feat).sum(1)
86 |
87 | return v_distill
88 |
89 | class OutputUnitMultiChoices(nn.Module):
90 | def __init__(self, module_dim=512):
91 | super(OutputUnitMultiChoices, self).__init__()
92 |
93 | self.question_proj = nn.Linear(module_dim, module_dim)
94 |
95 | self.ans_candidates_proj = nn.Linear(module_dim, module_dim)
96 |
97 | self.classifier = nn.Sequential(nn.Dropout(0.15),
98 | nn.Linear(module_dim * 4, module_dim),
99 | nn.ELU(),
100 | nn.BatchNorm1d(module_dim),
101 | nn.Dropout(0.15),
102 | nn.Linear(module_dim, 1))
103 |
104 | def forward(self, question_embedding, q_visual_embedding, ans_candidates_embedding,
105 | a_visual_embedding):
106 | question_embedding = self.question_proj(question_embedding)
107 | ans_candidates_embedding = self.ans_candidates_proj(ans_candidates_embedding)
108 | out = torch.cat([q_visual_embedding, question_embedding, a_visual_embedding,
109 | ans_candidates_embedding], 1)
110 | out = self.classifier(out)
111 |
112 | return out
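113 | 
114 | # Minimal sketch of FeatureAggregation on dummy tensors (the sizes are assumptions
115 | # for illustration): it attends over k visual features conditioned on a question
116 | # vector and returns one pooled visual vector per sample.
117 | if __name__ == '__main__':
118 |     agg = FeatureAggregation(module_dim=512)
119 |     question_rep = torch.randn(4, 512)       # (bs, module_dim)
120 |     visual_feat = torch.randn(4, 8, 512)     # (bs, k, module_dim)
121 |     pooled = agg(question_rep, visual_feat)  # -> (4, 512)
122 |     print(pooled.shape)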
--------------------------------------------------------------------------------
/networks/VQAModel/HGA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.insert(0, 'networks')
5 | from networks.Transformer import CoAttention
6 | from networks.GCN import AdjLearner, GCN
7 | from block import fusions #pytorch >= 1.1.0
8 |
9 |
10 | class HGA(nn.Module):
11 | def __init__(self, vid_encoder, qns_encoder, device):
12 | """
13 | Reasoning with Heterogeneous Graph Alignment for Video Question Answering (AAAI2020)
14 | """
15 | super(HGA, self).__init__()
16 | self.vid_encoder = vid_encoder
17 | self.qns_encoder = qns_encoder
18 | self.device = device
19 | hidden_size = vid_encoder.dim_hidden
20 | input_dropout_p = vid_encoder.input_dropout_p
21 |
22 | self.co_attn = CoAttention(
23 | hidden_size, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
24 |
25 | self.adj_learner = AdjLearner(
26 | hidden_size, hidden_size, dropout=input_dropout_p)
27 |
28 | self.gcn = GCN(
29 | hidden_size,
30 | hidden_size,
31 | hidden_size,
32 | num_layers=1,
33 | dropout=input_dropout_p)
34 |
35 | self.gcn_atten_pool = nn.Sequential(
36 | nn.Linear(hidden_size, hidden_size // 2),
37 | nn.Tanh(),
38 | nn.Linear(hidden_size // 2, 1),
39 | nn.Softmax(dim=-1)) #change to dim=-2 for attention-pooling otherwise sum-pooling
40 |
41 | self.global_fusion = fusions.Block(
42 | [hidden_size, hidden_size], hidden_size, dropout_input=input_dropout_p)
43 |
44 | self.fusion = fusions.Block([hidden_size, hidden_size], 1)
45 |
46 |
47 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q):
48 | """
49 | Args:
50 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim)
51 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim)
52 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
53 | candidates_len: [Tensor] (batch_size, 5)
54 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim)
55 | dep_adj: [Tensor] (batch_size, max_length, max_length)
56 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
57 | question_len: [Tensor] (batch_size, 5)
58 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim)
59 | dep_adj_q: [Tensor] (batch_size, max_length, max_length)
60 | return:
61 | logits, predict_idx
62 | """
63 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1)
64 | if self.qns_encoder.use_bert:
65 | candidates = candidates.permute(1, 0, 2, 3) # for BERT
66 | else:
67 | candidates = candidates.permute(1, 0, 2)
68 |
69 | obj_feature = obj_feature.permute(1, 0, 2, 3)
70 | cand_len = candidates_len.permute(1, 0)
71 |
72 | v_output, v_hidden = self.vid_encoder(vid_feats)
73 | v_last_hidden = torch.squeeze(v_hidden)
74 |
75 |
76 | out = []
77 | for idx, qas in enumerate(candidates):
78 | encoder_out = self.vq_encoder(v_output, v_last_hidden, qas, cand_len[idx], obj_feature[idx])
79 | out.append(encoder_out)
80 |
81 | out = torch.stack(out, 0).transpose(1, 0)
82 | _, predict_idx = torch.max(out, 1)
83 |
84 | return out, predict_idx
85 |
86 |
87 | def vq_encoder(self, v_output, v_last_hidden, qas, qas_len, obj_feature):
88 | q_output, s_hidden = self.qns_encoder(qas, qas_len, obj=obj_feature)
89 | qns_last_hidden = torch.squeeze(s_hidden)
90 |
91 | q_output, v_output = self.co_attn(q_output, v_output)
92 |
93 | adj = self.adj_learner(q_output, v_output)
94 | q_v_inputs = torch.cat((q_output, v_output), dim=1)
95 | q_v_output = self.gcn(q_v_inputs, adj)
96 |
97 | local_attn = self.gcn_atten_pool(q_v_output)
98 | local_out = torch.sum(q_v_output * local_attn, dim=1)
99 |
100 | global_out = self.global_fusion((qns_last_hidden, v_last_hidden))
101 |
102 |
103 | out = self.fusion((global_out, local_out)).squeeze()
104 |
105 | return out
106 |
--------------------------------------------------------------------------------
/networks/VQAModel/HME.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import sys
4 | sys.path.insert(0, 'networks')
5 | from Attention import TempAttention, SpatialAttention
6 | from memory_rand import MemoryRamTwoStreamModule2, MemoryRamModule2, MMModule2
7 |
8 |
9 | class HME(nn.Module):
10 | def __init__(self, vid_encoder, qns_encoder, max_len_v, max_len_q, device, input_drop_p=0.2):
11 | """
12 | Heterogeneous memory enhanced multimodal attention model for video question answering (CVPR19)
13 | """
14 | super(HME, self).__init__()
15 | self.vid_encoder = vid_encoder
16 | self.qns_encoder = qns_encoder
17 |
18 |
19 | dim = qns_encoder.dim_hidden
20 |
21 | self.temp_att_a = TempAttention(dim * 2, dim * 2, hidden_dim=256)
22 | self.temp_att_m = TempAttention(dim * 2, dim * 2, hidden_dim=256)
23 | self.mrm_vid = MemoryRamTwoStreamModule2(dim, dim, max_len_v, device)
24 | self.mrm_txt = MemoryRamModule2(dim, dim, max_len_q, device)
25 |
26 | self.mm_module_v1 = MMModule2(dim, input_drop_p, device)
27 |
28 | self.linear_vid = nn.Linear(dim*2, dim)
29 | self.linear_qns = nn.Linear(dim*2, dim)
30 | self.linear_mem = nn.Linear(dim*2, dim)
31 | self.vq2word_hme = nn.Linear(dim*3, 1)
32 | self.device = device
33 |
34 | def forward(self, video_appearance_feat, video_motion_feat, candidates, candidates_len, obj_feature, dep_adj, question, question_len, obj_fea_q, dep_adj_q):
35 | """
36 | Args:
37 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim)
38 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim)
39 | candidates: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
40 | candidates_len: [Tensor] (batch_size, 5)
41 | obj_feature: [Tensor] (batch_size, 5, max_length, emb_dim)
42 | dep_adj: [Tensor] (batch_size, max_length, max_length)
43 | question: [Tensor] (batch_size, 5, max_length, [emb_dim(for bert)])
44 | question_len: [Tensor] (batch_size, 5)
45 | obj_fea_q: [Tensor] (batch_size, 5, max_length, emb_dim)
46 | dep_adj_q: [Tensor] (batch_size, max_length, max_length)
47 | return:
48 | logits, predict_idx
49 | """
50 | vid_feats = torch.cat([video_appearance_feat.mean(2), video_motion_feat], dim=-1)
51 | if self.qns_encoder.use_bert:
52 | candidates = candidates.permute(1, 0, 2, 3) # for BERT
53 | else:
54 | candidates = candidates.permute(1, 0, 2)
55 |
56 | obj_feature = obj_feature.permute(1, 0, 2, 3)
57 | candidates_len = candidates_len.permute(1, 0)
58 |
59 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats)
60 | vid_feats = (outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2)
61 |
62 | qas_seq, qas_hidden = list(), list()
63 | for idx, qas in enumerate(candidates):
64 | q_output, s_hidden = self.qns_encoder(qas, candidates_len[idx], obj=obj_feature[idx])
65 | qas_seq.append(q_output)
66 | qas_hidden.append(s_hidden)
67 |
68 | out = []
69 | for idx, (qa_seq, qa_hidden) in enumerate(zip(qas_seq, qas_hidden)):
70 | encoder_out = self.vq_encoder(vid_feats, qa_seq, qa_hidden)
71 | out.append(encoder_out)
72 |
73 | out = torch.stack(out, 0).transpose(1, 0)
74 |
75 | _, predict_idx = torch.max(out, 1)
76 |
77 | return out, predict_idx
78 |
79 | def vq_encoder(self, vid_feats, qns_seq, qns_hidden, iter_num=3):
80 |
81 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = vid_feats
82 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1)
83 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1)
84 |
85 | batch_size, fnum, vid_feat_dim = outputs_app.size()
86 |
87 | batch_size, seq_len, qns_feat_dim = qns_seq.size()
88 |
89 | qns_embed = qns_hidden.permute(1, 0, 2).contiguous().view(batch_size, -1)
90 |
91 | # Apply temporal attention
92 | att_app, beta_app = self.temp_att_a(qns_embed, outputs_app)
93 | att_motion, beta_motion = self.temp_att_m(qns_embed, outputs_motion)
94 | tmp_app_motion = torch.cat((outputs_app_l2[:, -1, :], outputs_motion_l2[:, -1, :]), dim=-1)
95 |
96 | mem_output = torch.zeros(batch_size, vid_feat_dim).to(self.device)
97 |
98 | mem_ram_vid = self.mrm_vid(outputs_app_l2, outputs_motion_l2, fnum)
99 | mem_ram_txt = self.mrm_txt(qns_seq, qns_seq.shape[1])
100 | mem_output[:] = self.mm_module_v1(tmp_app_motion, mem_ram_vid, mem_ram_txt, iter_num)
101 |
102 | app_trans = torch.tanh(self.linear_vid(att_app))
103 | motion_trans = torch.tanh(self.linear_vid(att_motion))
104 | mem_trans = torch.tanh(self.linear_mem(mem_output))
105 |
106 | encoder_outputs = torch.cat((app_trans, motion_trans, mem_trans), dim=1)
107 | outputs = self.vq2word_hme(encoder_outputs).squeeze()
108 |
109 | return outputs
--------------------------------------------------------------------------------
/networks/memory_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 |
7 |
8 | class AttentionGRUCell(nn.Module):
9 | '''
10 | Eq (1)~(4), then modify by Eq (11)
11 | When forwarding, we feed attention gate g into GRU
12 | '''
13 | def __init__(self, input_size, hidden_size):
14 | super(AttentionGRUCell, self).__init__()
15 | self.hidden_size = hidden_size
16 | self.Wr = nn.Linear(input_size, hidden_size)
17 | self.Ur = nn.Linear(hidden_size, hidden_size)
18 | self.W = nn.Linear(input_size, hidden_size)
19 | self.U = nn.Linear(hidden_size, hidden_size)
20 |
21 | def init_weights(self):
22 | self.Wr.weight.data.normal_(0.0, 0.02)
23 | self.Wr.bias.data.fill_(0)
24 | self.Ur.weight.data.normal_(0.0, 0.02)
25 | self.Ur.bias.data.fill_(0)
26 | self.W.weight.data.normal_(0.0, 0.02)
27 | self.W.bias.data.fill_(0)
28 | self.U.weight.data.normal_(0.0, 0.02)
29 | self.U.bias.data.fill_(0)
30 |
31 | def forward(self, fact, C, g):
32 | '''
33 | fact.size() -> (#batch, #hidden = #embedding)
34 | c.size() -> (#hidden, ) -> (#batch, #hidden = #embedding)
35 | r.size() -> (#batch, #hidden = #embedding)
36 | h_tilda.size() -> (#batch, #hidden = #embedding)
37 | g.size() -> (#batch, )
38 | '''
39 |
40 | r = torch.sigmoid(self.Wr(fact) + self.Ur(C))
41 | h_tilda = torch.tanh(self.W(fact) + r * self.U(C))
42 | g = g.unsqueeze(1).expand_as(h_tilda)
43 | h = g * h_tilda + (1 - g) * C
44 | return h
45 |
46 | class AttentionGRU(nn.Module):
47 | '''
48 | Section 3.3
49 | continuously run AttnGRU to get contextual vector c at each time t
50 | '''
51 | def __init__(self, input_size, hidden_size):
52 | super(AttentionGRU, self).__init__()
53 | self.hidden_size = hidden_size
54 | self.AGRUCell = AttentionGRUCell(input_size, hidden_size)
55 |
56 | def init_weights(self):
57 | self.AGRUCell.init_weights()
58 |
59 | def forward(self, facts, G):
60 | '''
61 | facts.size() -> (#batch, #sentence, #hidden = #embedding)
62 | fact.size() -> (#batch, #hidden = #embedding)
63 | G.size() -> (#batch, #sentence)
64 | g.size() -> (#batch, )
65 | C.size() -> (#batch, #hidden)
66 | '''
67 | batch_num, sen_num, embedding_size = facts.size()
68 |         C = torch.zeros(self.hidden_size, device=facts.device)  # keep the initial context on the same device as the inputs
69 | for sid in range(sen_num):
70 | fact = facts[:, sid, :]
71 | g = G[:, sid]
72 | if sid == 0:
73 | C = C.unsqueeze(0).expand_as(fact)
74 | C = self.AGRUCell(fact, C, g)
75 | return C
76 |
77 | class EpisodicMemory(nn.Module):
78 | '''
79 | Section 3.3
80 | '''
81 |
82 | def __init__(self, hidden_size):
83 | super(EpisodicMemory, self).__init__()
84 | self.AGRU = AttentionGRU(hidden_size, hidden_size)
85 | self.z1 = nn.Linear(4 * hidden_size, hidden_size)
86 | self.z2 = nn.Linear(hidden_size, 1)
87 | self.next_mem = nn.Linear(3 * hidden_size, hidden_size)
88 |
89 |
90 | def init_weights(self):
91 | self.z1.weight.data.normal_(0.0, 0.02)
92 | self.z1.bias.data.fill_(0)
93 | self.z2.weight.data.normal_(0.0, 0.02)
94 | self.z2.bias.data.fill_(0)
95 | self.next_mem.weight.data.normal_(0.0, 0.02)
96 | self.next_mem.bias.data.fill_(0)
97 | self.AGRU.init_weights()
98 |
99 |
100 | def make_interaction(self, frames, questions, prevM):
101 | '''
102 | frames.size() -> (#batch, T, #hidden = #embedding)
103 |         questions.size() -> (#batch, #hidden), reshaped to (#batch, 1, #hidden) below
104 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
105 | z.size() -> (#batch, T, 4 x #embedding)
106 | G.size() -> (#batch, T)
107 | '''
108 | batch_num, T, embedding_size = frames.size()
109 | questions = questions.view(questions.size(0),1,questions.size(1))
110 |
111 |
112 | #questions = questions.expand_as(frames)
113 | #prevM = prevM.expand_as(frames)
114 |
115 | #print(questions.size(),prevM.size())
116 |
117 | # Eq (8)~(10)
118 | z = torch.cat([
119 | frames * questions,
120 | frames * prevM,
121 | torch.abs(frames - questions),
122 | torch.abs(frames - prevM)
123 | ], dim=2)
124 |
125 | z = z.view(-1, 4 * embedding_size)
126 |
127 | G = torch.tanh(self.z1(z))
128 | G = self.z2(G)
129 | G = G.view(batch_num, -1)
130 | G = F.softmax(G,dim=1)
131 | #print('G size',G.size())
132 | return G
133 |
134 | def forward(self, frames, questions, prevM):
135 | '''
136 | frames.size() -> (#batch, #sentence, #hidden = #embedding)
137 | questions.size() -> (#batch, #sentence = 1, #hidden)
138 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
139 | G.size() -> (#batch, #sentence)
140 | C.size() -> (#batch, #hidden)
141 | concat.size() -> (#batch, 3 x #embedding)
142 | '''
143 |
144 | '''
145 | section 3.3 - Attention based GRU
146 | input: F and q, as frames and questions
147 | then get gates g
148 | then (c,m,g) feed into memory update module Eq(13)
149 | output new memory state
150 | '''
151 | # print(frames.shape, questions.shape, prevM.shape)
152 |
153 | G = self.make_interaction(frames, questions, prevM)
154 | C = self.AGRU(frames, G)
155 | concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1)
156 | next_mem = F.relu(self.next_mem(concat))
157 | #print(next_mem.size())
158 | next_mem = next_mem.unsqueeze(1)
159 | return next_mem
160 |
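161 | # Minimal sketch of one episodic-memory update on dummy tensors (the sizes are
162 | # assumptions for illustration): frames (bs, T, H), a question vector (bs, H) and
163 | # the previous memory (bs, 1, H) produce the next memory state (bs, 1, H).
164 | if __name__ == '__main__':
165 |     em = EpisodicMemory(hidden_size=256)
166 |     frames = torch.randn(2, 8, 256)
167 |     question = torch.randn(2, 256)
168 |     prev_m = torch.randn(2, 1, 256)
169 |     next_m = em(frames, question, prev_m)
170 |     print(next_m.shape)   # -> (2, 1, 256)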
--------------------------------------------------------------------------------
/networks/memory_rand.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from Attention import MultiModalAttentionModule
5 |
6 | class MemoryRamModule(nn.Module):
7 |
8 | def __init__(self, input_size=1024, hidden_size=512, memory_bank_size=100, device=None):
9 | """Set the hyper-parameters and build the layers."""
10 | super(MemoryRamModule, self).__init__()
11 |
12 | self.input_size = input_size
13 | self.hidden_size = hidden_size
14 | self.memory_bank_size = memory_bank_size
15 | self.device = device
16 |
17 | self.hidden_to_content = nn.Linear(hidden_size+input_size, hidden_size)
18 | #self.read_to_hidden = nn.Linear(hidden_size+input_size, 1)
19 | self.write_gate = nn.Linear(hidden_size+input_size, 1)
20 | self.write_prob = nn.Linear(hidden_size+input_size, memory_bank_size)
21 |
22 | self.read_gate = nn.Linear(hidden_size+input_size, 1)
23 | self.read_prob = nn.Linear(hidden_size+input_size, memory_bank_size)
24 |
25 |
26 | self.Wxh = nn.Parameter(torch.FloatTensor(input_size, hidden_size),requires_grad=True)
27 | self.Wrh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True)
28 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True)
29 | self.bh = nn.Parameter(torch.FloatTensor(hidden_size),requires_grad=True)
30 |
31 | self.init_weights()
32 |
33 |
34 | def init_weights(self):
35 | self.Wxh.data.normal_(0.0, 0.1)
36 | self.Wrh.data.normal_(0.0, 0.1)
37 | self.Whh.data.normal_(0.0, 0.1)
38 | self.bh.data.fill_(0)
39 |
40 |
41 | def forward(self, hidden_frames, nImg):
42 |
43 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device)
44 | memory_ram.fill_(0)
45 |
46 | h_t = torch.zeros(1, self.hidden_size).to(self.device)
47 |
48 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device)
49 |
50 | for t in range(nImg):
51 | x_t = hidden_frames[t:t+1,:]
52 |
53 | x_h_t = torch.cat([x_t,h_t],dim=1)
54 |
55 | ############# read ############
56 | ar = torch.softmax(self.read_prob( x_h_t ),dim=1) # read prob from memories
57 | go = torch.sigmoid(self.read_gate( x_h_t )) # read gate
58 | r = go * torch.matmul(ar,memory_ram) # read vector
59 |
60 | ######### h_t #########
61 | # Eq (17)
62 | m1 = torch.matmul(x_t, self.Wxh)
63 | m2 = torch.matmul(r, self.Wrh)
64 | m3 = torch.matmul(h_t, self.Whh)
65 | h_t_p1 = F.relu(m1 + m2 + m3 + self.bh) # Eq(17)
66 |
67 |
68 | ############# write ############
69 | c_t = F.relu( self.hidden_to_content(x_h_t) ) # Eq(15), content vector
70 | aw = torch.softmax(self.write_prob( x_h_t ),dim=1) # write prob to memories
71 | aw = aw.view(self.memory_bank_size,1)
72 | gw = torch.sigmoid(self.write_gate( x_h_t )) # write gate
73 | #print gw.size(),aw.size(),c_t.size(),memory_ram.size()
74 | memory_ram = gw * aw * c_t + (1.0-aw) * memory_ram # Eq(16)
75 |
76 | h_t = h_t_p1
77 | hiddens[t,:] = h_t
78 |
79 | #return memory_ram
80 | return hiddens
81 |
82 |
83 | class MemoryRamTwoStreamModule(nn.Module):
84 |
85 | def __init__(self, input_size, hidden_size=512, memory_bank_size=100, device=None):
86 | """Set the hyper-parameters and build the layers."""
87 | super(MemoryRamTwoStreamModule, self).__init__()
88 |
89 | self.input_size = input_size
90 | self.hidden_size = hidden_size
91 | self.memory_bank_size = memory_bank_size
92 | self.device = device
93 |
94 | self.hidden_to_content_a = nn.Linear(hidden_size+input_size, hidden_size)
95 | self.hidden_to_content_m = nn.Linear(hidden_size+input_size, hidden_size)
96 |
97 | self.write_prob = nn.Linear(hidden_size*3, 3)
98 | self.write_prob_a = nn.Linear(hidden_size+input_size, memory_bank_size)
99 | self.write_prob_m = nn.Linear(hidden_size+input_size, memory_bank_size)
100 |
101 | self.read_prob = nn.Linear(hidden_size*3, memory_bank_size)
102 |
103 | self.read_to_hidden = nn.Linear(hidden_size*2, hidden_size)
104 | self.read_to_hidden_a = nn.Linear(hidden_size*2+input_size, hidden_size)
105 | self.read_to_hidden_m = nn.Linear(hidden_size*2+input_size, hidden_size)
106 | self.init_weights()
107 |
108 | def init_weights(self):
109 | pass
110 |
111 |
112 | def forward(self, hidden_out_a, hidden_out_m, nImg):
113 |
114 |
115 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device)
116 | memory_ram.fill_(0)
117 |
118 | h_t_a = torch.zeros(1, self.hidden_size).to(self.device)
119 | h_t_m = torch.zeros(1, self.hidden_size).to(self.device)
120 | h_t = torch.zeros(1, self.hidden_size).to(self.device)
121 |
122 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device)
123 |
124 | for t in range(nImg):
125 | x_t_a = hidden_out_a[t:t+1,:]
126 | x_t_m = hidden_out_m[t:t+1,:]
127 |
128 |
129 | ############# read ############
130 | x_h_t_am = torch.cat([h_t_a,h_t_m,h_t],dim=1)
131 | ar = torch.softmax(self.read_prob( x_h_t_am ),dim=1) # read prob from memories
132 | r = torch.matmul(ar,memory_ram) # read vector
133 |
134 |
135 | ######### h_t #########
136 | # Eq (17)
137 | f_0 = torch.cat([r, h_t],dim=1)
138 | f_a = torch.cat([x_t_a, r, h_t_a],dim=1)
139 | f_m = torch.cat([x_t_m, r, h_t_m],dim=1)
140 |
141 | h_t_1 = F.relu(self.read_to_hidden(f_0))
142 | h_t_a1 = F.relu(self.read_to_hidden_a(f_a))
143 | h_t_m1 = F.relu(self.read_to_hidden_m(f_m))
144 |
145 |
146 | ############# write ############
147 |
148 | # write probability of [keep, write appearance, write motion]
149 | aw = torch.softmax(self.write_prob( x_h_t_am ),dim=1) # write prob to memories
150 | x_h_ta = torch.cat([h_t_a,x_t_a],dim=1)
151 | x_h_tm = torch.cat([h_t_m,x_t_m],dim=1)
152 |
153 |
154 | # write content
155 | c_t_a = F.relu( self.hidden_to_content_a(x_h_ta) ) # Eq(15), content vector
156 | c_t_m = F.relu( self.hidden_to_content_m(x_h_tm) ) # Eq(15), content vector
157 |
158 | aw_a = torch.softmax(self.write_prob_a( x_h_ta ),dim=1) # write prob to memories
159 | aw_m = torch.softmax(self.write_prob_m( x_h_tm ),dim=1) # write prob to memories
160 |
161 |
162 | aw_a = aw_a.view(self.memory_bank_size,1)
163 | aw_m = aw_m.view(self.memory_bank_size,1)
164 |
165 | memory_ram = aw[0,0] * memory_ram + aw[0,1] * aw_a * c_t_a + aw[0,2] * aw_m * c_t_m
166 |
167 |
168 | h_t = h_t_1
169 | h_t_a = h_t_a1
170 | h_t_m = h_t_m1
171 |
172 | hiddens[t,:] = h_t
173 |
174 |
175 | return hiddens
176 |
177 | class MMModule(nn.Module):
178 | def __init__(self, dim, input_drop_p, device):
179 | """Set the hyper-parameters and build the layers."""
180 | super(MMModule, self).__init__()
181 | self.hidden_size = dim
182 | self.lstm_mm_1 = nn.LSTMCell(dim, dim)
183 | self.lstm_mm_2 = nn.LSTMCell(dim, dim)
184 | self.hidden_encoder_1 = nn.Linear(dim * 2, dim)
185 | self.hidden_encoder_2 = nn.Linear(dim * 2, dim)
186 | self.dropout = nn.Dropout(input_drop_p)
187 | self.mm_att = MultiModalAttentionModule(dim)
188 | self.device = device
189 | self.init_weights()
190 |
191 |
192 | def init_weights(self):
193 | nn.init.xavier_normal_(self.hidden_encoder_1.weight)
194 | nn.init.xavier_normal_(self.hidden_encoder_2.weight)
195 | self.init_hiddens()
196 |
197 | def init_hiddens(self):
198 | s_t = torch.zeros(1, self.hidden_size).to(self.device)
199 | s_t2 = torch.zeros(1, self.hidden_size).to(self.device)
200 | c_t = torch.zeros(1, self.hidden_size).to(self.device)
201 | c_t2 = torch.zeros(1, self.hidden_size).to(self.device)
202 | return s_t, s_t2, c_t, c_t2
203 |
204 | def forward(self, svt_tmp, memory_ram_vid, memory_ram_txt, loop=3):
205 | """
206 |         Iteratively fuses the visual and textual memories with two stacked LSTM cells.
207 |         :param svt_tmp: appearance/motion summary vector, shape (batch, dim * 2)
208 |         :param memory_ram_vid: visual memory hidden states from MemoryRamTwoStreamModule
209 |         :param memory_ram_txt: textual memory hidden states from MemoryRamModule
210 |         :param loop: number of memory-update iterations
211 |         :return: concatenated states of the two LSTM cells, shape (batch, dim * 2)
212 | """
213 |
214 | sm_q1, sm_q2, cm_q1, cm_q2 = self.init_hiddens()
215 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_1(svt_tmp)))
216 |
217 | for _ in range(loop):
218 | sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1))
219 | sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2))
220 |
221 | mm_o1 = self.mm_att(sm_q2, memory_ram_vid, memory_ram_txt)
222 | mm_o2 = torch.cat((sm_q2, mm_o1), dim=1)
223 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_2(mm_o2)))
224 |
225 | smq = torch.cat((sm_q1, sm_q2), dim=1)
226 |
227 | return smq
228 |
229 | class MemoryRamTwoStreamModule2(nn.Module):
230 |
231 | def __init__(self, input_size, hidden_size=512, memory_bank_size=100, device=None):
232 | """Set the hyper-parameters and build the layers."""
233 | super(MemoryRamTwoStreamModule2, self).__init__()
234 |
235 | self.input_size = input_size
236 | self.hidden_size = hidden_size
237 | self.memory_bank_size = memory_bank_size
238 | self.device = device
239 |
240 | self.hidden_to_content_a = nn.Linear(hidden_size+input_size, hidden_size)
241 | self.hidden_to_content_m = nn.Linear(hidden_size+input_size, hidden_size)
242 |
243 | self.write_prob = nn.Linear(hidden_size*3, 3)
244 | self.write_prob_a = nn.Linear(hidden_size+input_size, memory_bank_size)
245 | self.write_prob_m = nn.Linear(hidden_size+input_size, memory_bank_size)
246 |
247 | self.read_prob = nn.Linear(hidden_size*3, memory_bank_size)
248 |
249 | self.read_to_hidden = nn.Linear(hidden_size*2, hidden_size)
250 | self.read_to_hidden_a = nn.Linear(hidden_size*2+input_size, hidden_size)
251 | self.read_to_hidden_m = nn.Linear(hidden_size*2+input_size, hidden_size)
252 | self.init_weights()
253 |
254 | def init_weights(self):
255 | pass
256 |
257 |
258 | def forward(self, hidden_out_a, hidden_out_m, nImg):
259 |
260 |
261 | memory_ram = torch.FloatTensor(hidden_out_a.shape[0], self.memory_bank_size, self.hidden_size).to(self.device)
262 | memory_ram.fill_(0)
263 |
264 | h_t_a = torch.zeros(hidden_out_a.shape[0], 1, self.hidden_size).to(self.device)
265 | h_t_m = torch.zeros(hidden_out_a.shape[0], 1, self.hidden_size).to(self.device)
266 | h_t = torch.zeros(hidden_out_a.shape[0], 1, self.hidden_size).to(self.device)
267 |
268 | hiddens = torch.FloatTensor(hidden_out_a.shape[0], nImg, self.hidden_size).to(self.device)
269 |
270 | for t in range(nImg):
271 | x_t_a = hidden_out_a[:, t:t+1,:]
272 | x_t_m = hidden_out_m[:, t:t+1,:]
273 |
274 |
275 | ############# read ############
276 | x_h_t_am = torch.cat([h_t_a,h_t_m,h_t],dim=2)
277 | ar = torch.softmax(self.read_prob( x_h_t_am ),dim=2) # read prob from memories
278 | r = torch.matmul(ar,memory_ram) # read vector
279 |
280 |
281 | ######### h_t #########
282 | # Eq (17)
283 | f_0 = torch.cat([r, h_t],dim=2)
284 | f_a = torch.cat([x_t_a, r, h_t_a],dim=2)
285 | f_m = torch.cat([x_t_m, r, h_t_m],dim=2)
286 |
287 | h_t_1 = F.relu(self.read_to_hidden(f_0))
288 | h_t_a1 = F.relu(self.read_to_hidden_a(f_a))
289 | h_t_m1 = F.relu(self.read_to_hidden_m(f_m))
290 |
291 |
292 | ############# write ############
293 |
294 | # write probability of [keep, write appearance, write motion]
295 | aw = torch.softmax(self.write_prob( x_h_t_am ),dim=2) # write prob to memories
296 | x_h_ta = torch.cat([h_t_a,x_t_a],dim=2)
297 | x_h_tm = torch.cat([h_t_m,x_t_m],dim=2)
298 |
299 |
300 | # write content
301 | c_t_a = F.relu( self.hidden_to_content_a(x_h_ta) ) # Eq(15), content vector
302 | c_t_m = F.relu( self.hidden_to_content_m(x_h_tm) ) # Eq(15), content vector
303 |
304 | aw_a = torch.softmax(self.write_prob_a( x_h_ta ),dim=2) # write prob to memories
305 | aw_m = torch.softmax(self.write_prob_m( x_h_tm ),dim=2) # write prob to memories
306 |
307 |
308 | aw_a = aw_a.view(hidden_out_a.shape[0], self.memory_bank_size,1)
309 | aw_m = aw_m.view(hidden_out_a.shape[0], self.memory_bank_size,1)
310 |
311 | memory_ram = aw[:, 0,0].unsqueeze(1).unsqueeze(2) * memory_ram + aw[:, 0,1].unsqueeze(1).unsqueeze(2) * aw_a * c_t_a + aw[:, 0,2].unsqueeze(1).unsqueeze(2) * aw_m * c_t_m
312 |
313 |
314 | h_t = h_t_1
315 | h_t_a = h_t_a1
316 | h_t_m = h_t_m1
317 |
318 | hiddens[:, t,:] = h_t.squeeze()
319 |
320 |
321 | return hiddens
322 |
323 | class MemoryRamModule2(nn.Module):
324 |
325 | def __init__(self, input_size=1024, hidden_size=512, memory_bank_size=100, device=None):
326 | """Set the hyper-parameters and build the layers."""
327 | super(MemoryRamModule2, self).__init__()
328 |
329 | self.input_size = input_size
330 | self.hidden_size = hidden_size
331 | self.memory_bank_size = memory_bank_size
332 | self.device = device
333 |
334 | self.hidden_to_content = nn.Linear(hidden_size+input_size, hidden_size)
335 | #self.read_to_hidden = nn.Linear(hidden_size+input_size, 1)
336 | self.write_gate = nn.Linear(hidden_size+input_size, 1)
337 | self.write_prob = nn.Linear(hidden_size+input_size, memory_bank_size)
338 |
339 | self.read_gate = nn.Linear(hidden_size+input_size, 1)
340 | self.read_prob = nn.Linear(hidden_size+input_size, memory_bank_size)
341 |
342 |
343 | self.Wxh = nn.Parameter(torch.FloatTensor(input_size, hidden_size),requires_grad=True)
344 | self.Wrh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True)
345 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True)
346 | self.bh = nn.Parameter(torch.FloatTensor(hidden_size),requires_grad=True)
347 |
348 | self.init_weights()
349 |
350 |
351 | def init_weights(self):
352 | self.Wxh.data.normal_(0.0, 0.1)
353 | self.Wrh.data.normal_(0.0, 0.1)
354 | self.Whh.data.normal_(0.0, 0.1)
355 | self.bh.data.fill_(0)
356 |
357 |
358 | def forward(self, hidden_frames, nImg):
359 |
360 | memory_ram = torch.FloatTensor(hidden_frames.shape[0], self.memory_bank_size, self.hidden_size).to(self.device)
361 | memory_ram.fill_(0)
362 |
363 | h_t = torch.zeros(hidden_frames.shape[0], 1, self.hidden_size).to(self.device)
364 |
365 | hiddens = torch.FloatTensor(hidden_frames.shape[0], nImg, self.hidden_size).to(self.device)
366 |
367 | for t in range(nImg):
368 | x_t = hidden_frames[:, t:t+1,:]
369 |
370 | x_h_t = torch.cat([x_t,h_t],dim=2)
371 |
372 | ############# read ############
373 | ar = torch.softmax(self.read_prob( x_h_t ),dim=2) # read prob from memories
374 | go = torch.sigmoid(self.read_gate( x_h_t )) # read gate
375 | r = go * torch.matmul(ar,memory_ram) # read vector
376 |
377 | ######### h_t #########
378 | # Eq (17)
379 | m1 = torch.matmul(x_t, self.Wxh)
380 | m2 = torch.matmul(r, self.Wrh)
381 | m3 = torch.matmul(h_t, self.Whh)
382 | h_t_p1 = F.relu(m1 + m2 + m3 + self.bh) # Eq(17)
383 |
384 |
385 | ############# write ############
386 | c_t = F.relu( self.hidden_to_content(x_h_t) ) # Eq(15), content vector
387 | aw = torch.softmax(self.write_prob( x_h_t ),dim=2) # write prob to memories
388 | aw = aw.view(-1, self.memory_bank_size,1)
389 | gw = torch.sigmoid(self.write_gate( x_h_t )) # write gate
390 | #print gw.size(),aw.size(),c_t.size(),memory_ram.size()
391 | memory_ram = gw * aw * c_t + (1.0-aw) * memory_ram # Eq(16)
392 |
393 | h_t = h_t_p1
394 | hiddens[:, t,:] = h_t.squeeze()
395 |
396 | #return memory_ram
397 | return hiddens
398 |
399 | class MMModule2(nn.Module):
400 | def __init__(self, dim, input_drop_p, device):
401 | """Set the hyper-parameters and build the layers."""
402 | super(MMModule2, self).__init__()
403 | self.hidden_size = dim
404 | self.lstm_mm_1 = nn.LSTMCell(dim, dim)
405 | self.lstm_mm_2 = nn.LSTMCell(dim, dim)
406 | self.hidden_encoder_1 = nn.Linear(dim * 2, dim)
407 | self.hidden_encoder_2 = nn.Linear(dim * 2, dim)
408 | self.dropout = nn.Dropout(input_drop_p)
409 | self.mm_att = MultiModalAttentionModule(dim)
410 | self.device = device
411 | self.init_weights()
412 |
413 |
414 | def init_weights(self):
415 | nn.init.xavier_normal_(self.hidden_encoder_1.weight)
416 | nn.init.xavier_normal_(self.hidden_encoder_2.weight)
417 |
418 | def init_hiddens(self, bs):
419 | s_t = torch.zeros(bs, self.hidden_size).to(self.device)
420 | s_t2 = torch.zeros(bs, self.hidden_size).to(self.device)
421 | c_t = torch.zeros(bs, self.hidden_size).to(self.device)
422 | c_t2 = torch.zeros(bs, self.hidden_size).to(self.device)
423 | return s_t, s_t2, c_t, c_t2
424 |
425 | def forward(self, svt_tmp, memory_ram_vid, memory_ram_txt, loop=3):
426 | """
427 |         Batched version of MMModule: iteratively fuses the visual and textual memories.
428 |         :param svt_tmp: appearance/motion summary vector, shape (batch, dim * 2)
429 |         :param memory_ram_vid: visual memory hidden states from MemoryRamTwoStreamModule2
430 |         :param memory_ram_txt: textual memory hidden states from MemoryRamModule2
431 |         :param loop: number of memory-update iterations
432 |         :return: concatenated states of the two LSTM cells, shape (batch, dim * 2)
433 | """
434 | bs = svt_tmp.shape[0]
435 | sm_q1, sm_q2, cm_q1, cm_q2 = self.init_hiddens(bs)
436 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_1(svt_tmp)))
437 |
438 | for _ in range(loop):
439 | sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1))
440 | sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2))
441 |
442 | mm_o1 = self.mm_att(sm_q2, memory_ram_vid, memory_ram_txt)
443 | mm_o2 = torch.cat((sm_q2, mm_o1), dim=1)
444 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_2(mm_o2)))
445 |
446 | smq = torch.cat((sm_q1, sm_q2), dim=1)
447 |
448 | return smq
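449 | 
450 | # Data-flow note: each of the ``loop`` iterations feeds the fused feature through the two
451 | # stacked LSTMCells, lets the second cell's hidden state attend over the visual and textual
452 | # memory banks via MultiModalAttentionModule, and folds the attended context back into the
453 | # input of the next iteration; the hidden states of both cells are finally concatenated,
454 | # so ``smq`` has shape (batch, 2 * dim).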
--------------------------------------------------------------------------------
/networks/torchnlp_nn.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Parameter
2 |
3 | import torch
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | class weight_drop():
9 |
10 | def __init__(self, module, weights, dropout):
11 | for name_w in weights:
12 | w = getattr(module, name_w)
13 | del module._parameters[name_w]
14 | module.register_parameter(name_w + '_raw', Parameter(w))
15 |
16 | self.original_module_forward = module.forward
17 |
18 | self.weights = weights
19 | self.module = module
20 | self.dropout = dropout
21 |
22 | def __call__(self, *args, **kwargs):
23 | for name_w in self.weights:
24 | raw_w = getattr(self.module, name_w + '_raw')
25 | w = torch.nn.functional.dropout(
26 | raw_w, p=self.dropout, training=self.module.training)
27 | # module.register_parameter(name_w, Parameter(w))
28 | setattr(self.module, name_w, Parameter(w))
29 |
30 | return self.original_module_forward(*args, **kwargs)
31 |
32 |
33 | def _weight_drop(module, weights, dropout):
34 | setattr(module, 'forward', weight_drop(module, weights, dropout))
35 |
36 |
37 | # def _weight_drop(module, weights, dropout):
38 | # """
39 | # Helper for `WeightDrop`.
40 | # """
41 |
42 | # for name_w in weights:
43 | # w = getattr(module, name_w)
44 | # del module._parameters[name_w]
45 | # module.register_parameter(name_w + '_raw', Parameter(w))
46 |
47 | # original_module_forward = module.forward
48 |
49 | # def forward(*args, **kwargs):
50 | # for name_w in weights:
51 | # raw_w = getattr(module, name_w + '_raw')
52 | # w = torch.nn.functional.dropout(
53 | # raw_w, p=dropout, training=module.training)
54 | # # module.register_parameter(name_w, Parameter(w))
55 | # setattr(module, name_w, Parameter(w))
56 |
57 | # return original_module_forward(*args, **kwargs)
58 |
59 | # setattr(module, 'forward', forward)
60 |
61 |
62 | class WeightDrop(torch.nn.Module):
63 | """
64 | The weight-dropped module applies recurrent regularization through a DropConnect mask on the
65 | hidden-to-hidden recurrent weights.
66 | **Thank you** to Sales Force for their initial implementation of :class:`WeightDrop`. Here is
67 | their `License
68 | <https://github.com/salesforce/awd-lstm-lm/blob/master/LICENSE>`__.
69 | Args:
70 | module (:class:`torch.nn.Module`): Containing module.
71 | weights (:class:`list` of :class:`str`): Names of the module weight parameters to apply a
72 | dropout to.
73 | dropout (float): The probability a weight will be dropped.
74 | Example:
75 | >>> from torchnlp.nn import WeightDrop
76 | >>> import torch
77 | >>>
78 | >>> torch.manual_seed(123)
79 | <torch._C.Generator object ...
80 | >>>
81 | >>> gru = torch.nn.GRUCell(2, 2)
82 | >>> weights = ['weight_hh']
83 | >>> weight_drop_gru = WeightDrop(gru, weights, dropout=0.9)
84 | >>>
85 | >>> input_ = torch.randn(3, 2)
86 | >>> hidden_state = torch.randn(3, 2)
87 | >>> weight_drop_gru(input_, hidden_state)
88 | tensor(... grad_fn=<AddBackward0>)
89 | """
90 |
91 | def __init__(self, module, weights, dropout=0.0):
92 | super(WeightDrop, self).__init__()
93 | _weight_drop(module, weights, dropout)
94 | self.forward = module.forward
95 |
96 |
97 | class WeightDropLSTM(torch.nn.LSTM):
98 | """
99 | Wrapper around :class:`torch.nn.LSTM` that adds ``weight_dropout`` named argument.
100 | Args:
101 | weight_dropout (float): The probability a weight will be dropped.
102 | """
103 |
104 | def __init__(self, *args, weight_dropout=0.0, **kwargs):
105 | super().__init__(*args, **kwargs)
106 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)]
107 | _weight_drop(self, weights, weight_dropout)
108 |
109 |
110 | class WeightDropGRU(torch.nn.GRU):
111 | """
112 | Wrapper around :class:`torch.nn.GRU` that adds ``weight_dropout`` named argument.
113 | Args:
114 | weight_dropout (float): The probability a weight will be dropped.
115 | """
116 |
117 | def __init__(self, *args, weight_dropout=0.0, **kwargs):
118 | super().__init__(*args, **kwargs)
119 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)]
120 | _weight_drop(self, weights, weight_dropout)
121 |
122 |
123 | class WeightDropLinear(torch.nn.Linear):
124 | """
125 | Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument.
126 | Args:
127 | weight_dropout (float): The probability a weight will be dropped.
128 | """
129 |
130 | def __init__(self, *args, weight_dropout=0.0, **kwargs):
131 | super().__init__(*args, **kwargs)
132 | weights = ['weight']
133 | _weight_drop(self, weights, weight_dropout)
134 |
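135 | # Usage sketch (illustrative only): wrap an LSTM so that its hidden-to-hidden weights are
136 | # dropped with probability 0.5 on every forward pass during training
137 | # (DropConnect-style recurrent regularization).
138 | #   lstm = WeightDropLSTM(256, 256, num_layers=2, weight_dropout=0.5)
139 | #   out, (h, c) = lstm(torch.randn(10, 4, 256))  # input is (seq_len, batch, input_size)
140 | 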
--------------------------------------------------------------------------------
/requirement.txt:
--------------------------------------------------------------------------------
1 | addict==2.4.0
2 | antlr4-python3-runtime==4.8
3 | appdirs==1.4.4
4 | astor==0.8.1
5 | backcall==0.2.0
6 | black==21.4b2
7 | blis==0.7.4
8 | block.bootstrap.pytorch==0.1.6
9 | bootstrap.pytorch==0.0.13
10 | brotlipy==0.7.0
11 | catalogue==2.0.6
12 | certifi==2021.5.30
13 | cloudpickle==1.6.0
14 | colorama==0.4.4
15 | contextvars==2.4
16 | cycler==0.10.0
17 | cymem==2.0.5
18 | Cython==0.29.23
19 | dataclasses==0.7
20 | decorator==4.4.2
21 | filelock==3.0.10
22 | future==0.18.2
23 | fvcore==0.1.5.post20210515
24 | google-pasta==0.2.0
25 | hydra-core==1.1.0rc1
26 | imageio==2.15.0
27 | immutables==0.16
28 | importlib-resources==5.1.3
29 | instaboostfast==0.1.2
30 | iopath==0.1.8
31 | ipdb==0.13.9
32 | ipython==7.12.0
33 | Jinja2==3.0.2
34 | Keras-Preprocessing==1.1.0
35 | kiwisolver==1.3.1
36 | language-tool-python==2.5.5
37 | MarkupSafe==2.0.1
38 | matplotlib==3.3.4
39 | mmcv-full==1.4.8
40 | model-index==0.1.11
41 | munch==2.5.0
42 | murmurhash==1.0.5
43 | mypy-extensions==0.4.3
44 | networkx==2.5.1
45 | olefile==0.46
46 | omegaconf==2.1.0rc1
47 | opencv-python==4.5.1.48
48 | openmim==0.1.5
49 | ordered-set==4.0.2
50 | pathspec==0.8.1
51 | pathy==0.6.0
52 | Pillow==8.4.0
53 | plotly==5.3.1
54 | portalocker==2.3.0
55 | preshed==3.0.5
56 | pretrainedmodels==0.7.4
57 | protobuf==3.13.0
58 | pycocotools==2.0.2
59 | pydantic==1.8.2
60 | pydot==1.4.2
61 | pyhocon==0.3.58
62 | python-dateutil==2.8.1
63 | PyWavelets==1.1.1
64 | PyYAML==5.4.1
65 | pyzmq==19.0.2
66 | scikit-image==0.17.2
67 | scikit-learn==0.24.2
68 | seaborn==0.11.2
69 | sk-video==1.1.10
70 | skipthoughts==0.0.1
71 | smart-open==5.2.1
72 | spacy==3.1.3
73 | spacy-legacy==3.0.8
74 | srsly==2.4.1
75 | stanfordnlp==0.2.0
76 | tabulate==0.8.9
77 | tenacity==8.0.1
78 | tensorboard==1.14.0
79 | tensorflow-estimator==1.14.0
80 | termcolor==1.1.0
81 | terminaltables==3.1.10
82 | thinc==8.0.10
83 | threadpoolctl==3.0.0
84 | tifffile==2020.9.3
85 | tokenizers==0.10.3
86 | toml==0.10.2
87 | torch==1.7.1
88 | torchvision==0.8.2
89 | traitlets==4.3.3
90 | transformers==4.11.3
91 | typed-args==0.4.2
92 | typed-ast==1.4.3
93 | typer==0.4.0
94 | urllib3==1.26.7
95 | wasabi==0.8.2
96 | Werkzeug==1.0.1
97 | wrapt==1.12.1
98 | yacs==0.1.8
99 | yapf==0.32.0
100 | numpy==1.19.2
101 | nltk
102 | h5py
103 | SharedArray
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import pickle as pkl
5 | import pandas as pd
6 | import logging
7 |
8 | def make_logger(log_file):
9 | logger = logging.getLogger()
10 | logger.setLevel(logging.INFO)
11 |
12 | logfile = log_file
13 | fh = logging.FileHandler(logfile)
14 | fh.setLevel(logging.DEBUG)
15 |
16 | ch = logging.StreamHandler()
17 | ch.setLevel(logging.INFO)
18 |
19 | formatter = logging.Formatter("%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s")
20 | fh.setFormatter(formatter)
21 | ch.setFormatter(formatter)
22 |
23 | logger.addHandler(fh)
24 | logger.addHandler(ch)
25 | logger.info('logfile = {}'.format(logfile))
26 | return logger
27 |
28 | def set_gpu_devices(gpu_id):
29 | gpu = ''
30 | if gpu_id != -1:
31 | gpu = str(gpu_id)
32 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu
33 |
34 |
35 | def load_file(filename):
36 | """
37 | load obj from filename
38 | :param filename:
39 | :return:
40 | """
41 | cont = None
42 | if not osp.exists(filename):
43 | print('{} not exist'.format(filename))
44 | return cont
45 | if osp.splitext(filename)[-1] == '.csv':
46 | # return pd.read_csv(filename, delimiter= '\t', index_col=0)
47 | return pd.read_csv(filename, delimiter=',')
48 | with open(filename, 'r') as fp:
49 | if osp.splitext(filename)[1] == '.txt':
50 | cont = fp.readlines()
51 | cont = [c.rstrip('\n') for c in cont]
52 | elif osp.splitext(filename)[1] == '.json':
53 | cont = json.load(fp)
54 | return cont
55 |
56 |
57 | def save_file(obj, filename):
58 | """
59 | save obj to filename
60 | :param obj:
61 | :param filename:
62 | :return:
63 | """
64 | filepath = osp.dirname(filename)
65 | if filepath != '' and not osp.exists(filepath):
66 | os.makedirs(filepath)
67 | with open(filename, 'w') as fp:
68 | json.dump(obj, fp, indent=4)
69 |
70 |
71 | def pkload(file):
72 | data = None
73 | if osp.exists(file) and osp.getsize(file) > 0:
74 | with open(file, 'rb') as fp:
75 | data = pkl.load(fp)
76 | # print('{} does not exist'.format(file))
77 | return data
78 |
79 |
80 | def pkdump(data, file):
81 | dirname = osp.dirname(file)
82 | if not osp.exists(dirname):
83 | os.makedirs(dirname)
84 | with open(file, 'wb') as fp:
85 | pkl.dump(data, fp)
86 |
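87 | # Round-trip sketch (the path below is a placeholder): save_file writes an object as JSON,
88 | # while load_file dispatches on the file extension (.csv -> pandas DataFrame,
89 | # .txt -> list of stripped lines, .json -> parsed object).
90 | #   save_file({'acc': 42.0}, 'results/demo.json')
91 | #   results = load_file('results/demo.json')
92 | 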
--------------------------------------------------------------------------------
/videoqa.py:
--------------------------------------------------------------------------------
1 | from networks import Embed_loss, EncoderRNN, CRN
2 | from networks.VQAModel import EVQA, HCRN, CoMem, HME, HGA, B2A
3 | from utils import *
4 | from torch.optim.lr_scheduler import ReduceLROnPlateau
5 | import torch
6 | import torch.nn as nn
7 | import time
8 |
9 |
10 | class VideoQA():
11 | def __init__(self, vocab, train_loader, val_loader, test_loader, glove_embed, use_bert, checkpoint_path, model_type,
12 | model_prefix, vis_step, lr_rate, batch_size, epoch_num, logger, args):
13 | self.vocab = vocab
14 | self.train_loader = train_loader
15 | self.val_loader = val_loader
16 | self.test_loader = test_loader
17 | self.glove_embed = glove_embed
18 | self.use_bert = use_bert
19 | self.model_dir = checkpoint_path
20 | self.model_type = model_type
21 | self.model_prefix = model_prefix
22 | self.vis_step = vis_step
23 | self.lr_rate = lr_rate
24 | self.batch_size = batch_size
25 | self.epoch_num = epoch_num
26 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27 | self.model = None
28 | self.logger = logger
29 | self.args = args
30 |
31 | def build_model(self):
32 |
33 | vid_dim = self.args.vid_dim
34 | hidden_dim = self.args.hidden_dim
35 | word_dim = self.args.word_dim
36 | vocab_size = len(self.vocab)
37 | max_vid_len = self.args.max_vid_len
38 | max_vid_frame_len = self.args.max_vid_frame_len
39 | max_qa_len = self.args.max_qa_len
40 | spl_resolution = self.args.spl_resolution
41 |
42 | if self.model_type == 'EVQA' or self.model_type == 'BlindQA':
43 | #ICCV15, AAAI17
44 | vid_encoder = EncoderRNN.EncoderVid(vid_dim, hidden_dim, input_dropout_p=0.2, n_layers=1, rnn_dropout_p=0, bidirectional=False, rnn_cell='gru')
45 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, input_dropout_p=0.2, rnn_dropout_p=0, bidirectional=False, rnn_cell='gru')
46 |
47 | self.model = EVQA.EVQA(vid_encoder, qns_encoder, self.device, self.model_type == 'BlindQA')
48 |
49 | elif self.model_type == 'CoMem':
50 | #CVPR18
51 | app_dim = 2048
52 | motion_dim = 2048
53 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru')
54 |
55 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=2, rnn_dropout_p=0.5, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru')
56 |
57 | self.model = CoMem.CoMem(vid_encoder, qns_encoder, max_vid_len, max_qa_len, self.device)
58 |
59 | elif self.model_type == 'HME':
60 | #CVPR19
61 | app_dim = 2048
62 | motion_dim = 2048
63 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru')
64 |
65 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=2, rnn_dropout_p=0.5, input_dropout_p=0.2, bidirectional=False, rnn_cell='gru')
66 |
67 | self.model = HME.HME(vid_encoder, qns_encoder, max_vid_len, max_qa_len*2, self.device)
68 |
69 | elif self.model_type == 'HGA':
70 | #AAAI20
71 | vid_encoder = EncoderRNN.EncoderVidHGA(vid_dim, hidden_dim, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru')
72 |
73 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, rnn_dropout_p=0, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru')
74 |
75 | self.model = HGA.HGA(vid_encoder, qns_encoder, self.device)
76 |
77 | elif self.model_type == 'HCRN':
78 | #CVPR20
79 | vid_dim = vid_dim//2
80 | vid_encoder = CRN.EncoderVidCRN(max_vid_frame_len, max_vid_len, spl_resolution, vid_dim, hidden_dim)
81 |
82 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, rnn_dropout_p=0.2, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru')
83 |
84 | self.model = HCRN.HCRN(vid_encoder, qns_encoder, self.device)
85 |
86 | elif self.model_type == 'B2A' or self.model_type == 'B2A2':
87 | #CVPR21
88 | vid_dim = vid_dim // 2
89 | vid_encoder = EncoderRNN.EncoderVidB2A(vid_dim, hidden_dim*2, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru')
90 |
91 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, vocab_size, self.glove_embed, self.use_bert, n_layers=1, rnn_dropout_p=0, input_dropout_p=0.3, bidirectional=True, rnn_cell='gru')
92 |
93 | self.model = B2A.B2A(vid_encoder, qns_encoder, self.device)
94 |
95 |
96 |
97 | params = [{'params':self.model.parameters()}]
98 |
99 | self.optimizer = torch.optim.Adam(params = params, lr=self.lr_rate)
100 | self.scheduler = ReduceLROnPlateau(self.optimizer, 'max', factor=0.5, patience=5, verbose=True)
101 |
102 | self.model.to(self.device)
103 | self.criterion = Embed_loss.MultipleChoiceLoss().to(self.device)
104 |
105 |
106 | def save_model(self, epoch, acc, is_best=False):
107 | if not is_best:
108 | torch.save(self.model.state_dict(), osp.join(self.model_dir, self.model_type, self.model_prefix, 'model', '{}-{:.2f}.ckpt'
109 | .format(epoch, acc)))
110 | else:
111 | torch.save(self.model.state_dict(), osp.join(self.model_dir, self.model_type, self.model_prefix, 'model', 'best.ckpt'))
112 |
113 | def resume(self, model_file):
114 | """
115 | initialize model with pretrained weights
116 | :return:
117 | """
118 | self.logger.info('Warm-start (or test) with model: {}'.format(model_file))
119 | model_dict = torch.load(model_file)
120 | new_model_dict = {}
121 | for k, v in self.model.state_dict().items():
122 | if k in model_dict:
123 | v = model_dict[k]
124 | else:
125 | pass
126 | # print(k)
127 | new_model_dict[k] = v
128 | self.model.load_state_dict(new_model_dict)
129 |
130 |
131 | def run(self, model_file, pre_trained=False):
132 | self.build_model()
133 | self.logger.info(self.model)
134 | best_eval_score = 0.0
135 | if pre_trained:
136 | self.resume(model_file)
137 | best_eval_score = self.eval(0)
138 | self.logger.info('Initial Acc {:.2f}'.format(best_eval_score))
139 |
140 | for epoch in range(0, self.epoch_num):
141 | train_loss, train_acc = self.train(epoch)
142 | eval_score = self.eval(epoch)
143 | eval_score_test = self.eval_t(epoch)
144 | self.logger.info("==>Epoch:[{}/{}][Train Loss: {:.4f}; Train acc: {:.2f}; Val acc: {:.2f}; Test acc: {:.2f}]".
145 | format(epoch, self.epoch_num, train_loss, train_acc, eval_score, eval_score_test))
146 | self.scheduler.step(eval_score)
147 | self.save_model(epoch, eval_score)
148 | if eval_score > best_eval_score:
149 | best_eval_score = eval_score
150 | self.save_model(epoch, best_eval_score, True)
151 |
152 | def train(self, epoch):
153 | self.logger.info('==>Epoch:[{}/{}][lr_rate: {}]'.format(epoch, self.epoch_num, self.optimizer.param_groups[0]['lr']))
154 | self.model.train()
155 | total_step = len(self.train_loader)
156 | epoch_loss = 0.0
157 | prediction_list = []
158 | answer_list = []
159 | for iter, inputs in enumerate(self.train_loader):
160 | visual, can, ques, ans_id, qns_key = inputs
161 | app_inputs = visual[0].to(self.device)
162 | mot_inputs = visual[1].to(self.device)
163 | candidate = can[0].to(self.device)
164 | candidate_lengths = can[1]
165 | obj_fea_can = can[2].to(self.device)
166 | dep_adj_can = can[3].to(self.device)
167 | question = ques[0].to(self.device)
168 | ques_lengths = ques[1]
169 | obj_fea_q = ques[2].to(self.device)
170 | dep_adj_q = ques[3].to(self.device)
171 | ans_targets = ans_id.to(self.device)
172 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q)
173 |
174 | self.model.zero_grad()
175 | loss = self.criterion(out, ans_targets)
176 | if not torch.isnan(loss):  # skip the backward pass for NaN losses so NaN gradients never reach the optimizer
177 | loss.backward()
178 | else:
179 | print(out)
180 | print(ans_targets)
181 | self.optimizer.step()
182 | epoch_loss += loss.item()
183 | if iter % self.vis_step == 0:
184 | self.logger.info('\t[{}/{}] Training loss: {:.4f}'.format(iter, total_step, epoch_loss/(iter+1)))
185 |
186 | prediction_list.append(prediction)
187 | answer_list.append(ans_id)
188 |
189 | predict_answers = torch.cat(prediction_list, dim=0).long().cpu()
190 | ref_answers = torch.cat(answer_list, dim=0).long()
191 | acc_num = torch.sum(predict_answers==ref_answers).numpy()
192 | print(len(ref_answers))
193 |
194 | return epoch_loss / total_step, acc_num*100.0 / len(ref_answers)
195 |
196 |
197 | def eval(self, epoch):
198 | self.logger.info('==>Epoch:[{}/{}][validation stage]'.format(epoch, self.epoch_num))
199 | self.model.eval()
200 | total_step = len(self.val_loader)
201 | acc_count = 0
202 | prediction_list = []
203 | answer_list = []
204 | with torch.no_grad():
205 | for iter, inputs in enumerate(self.val_loader):
206 | visual, can, ques, ans_id, qns_key = inputs
207 | app_inputs = visual[0].to(self.device)
208 | mot_inputs = visual[1].to(self.device)
209 | candidate = can[0].to(self.device)
210 | candidate_lengths = can[1]
211 | obj_fea_can = can[2].to(self.device)
212 | dep_adj_can = can[3].to(self.device)
213 | question = ques[0].to(self.device)
214 | ques_lengths = ques[1]
215 | obj_fea_q = ques[2].to(self.device)
216 | dep_adj_q = ques[3].to(self.device)
217 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q)
218 |
219 | prediction_list.append(prediction)
220 | answer_list.append(ans_id)
221 |
222 | predict_answers = torch.cat(prediction_list, dim=0).long().cpu()
223 | ref_answers = torch.cat(answer_list, dim=0).long()
224 | acc_num = torch.sum(predict_answers == ref_answers).numpy()
225 | print(len(ref_answers))
226 |
227 | return acc_num*100.0 / len(ref_answers)
228 |
229 | def eval_t(self, epoch):
230 | self.logger.info('==>Epoch:[{}/{}][test stage]'.format(epoch, self.epoch_num))
231 | self.model.eval()
232 | total_step = len(self.test_loader)
233 | acc_count = 0
234 | prediction_list = []
235 | answer_list = []
236 | with torch.no_grad():
237 | for iter, inputs in enumerate(self.test_loader):
238 | visual, can, ques, ans_id, qns_key = inputs
239 | app_inputs = visual[0].to(self.device)
240 | mot_inputs = visual[1].to(self.device)
241 | candidate = can[0].to(self.device)
242 | candidate_lengths = can[1]
243 | obj_fea_can = can[2].to(self.device)
244 | dep_adj_can = can[3].to(self.device)
245 | question = ques[0].to(self.device)
246 | ques_lengths = ques[1]
247 | obj_fea_q = ques[2].to(self.device)
248 | dep_adj_q = ques[3].to(self.device)
249 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q)
250 |
251 | prediction_list.append(prediction)
252 | answer_list.append(ans_id)
253 |
254 | predict_answers = torch.cat(prediction_list, dim=0).long().cpu()
255 | ref_answers = torch.cat(answer_list, dim=0).long()
256 | acc_num = torch.sum(predict_answers == ref_answers).numpy()
257 | print(len(ref_answers))
258 |
259 | return acc_num*100.0 / len(ref_answers)
260 |
261 |
262 | def predict(self, model_file, result_file, loader):
263 | """
264 | predict the answer with the trained model
265 | :param model_file:
266 | :return:
267 | """
268 | self.build_model()
269 | self.resume(model_file)
270 |
271 | self.model.eval()
272 | results = {}
273 | with torch.no_grad():
274 | for iter, inputs in enumerate(loader):
275 | visual, can, ques, ans_id, qns_key = inputs
276 | app_inputs = visual[0].to(self.device)
277 | mot_inputs = visual[1].to(self.device)
278 | candidate = can[0].to(self.device)
279 | candidate_lengths = can[1]
280 | obj_fea_can = can[2].to(self.device)
281 | dep_adj_can = can[3].to(self.device)
282 | question = ques[0].to(self.device)
283 | ques_lengths = ques[1]
284 | obj_fea_q = ques[2].to(self.device)
285 | dep_adj_q = ques[3].to(self.device)
286 | out, prediction = self.model(app_inputs, mot_inputs, candidate, candidate_lengths, obj_fea_can, dep_adj_can, question, ques_lengths, obj_fea_q, dep_adj_q)
287 |
288 | prediction = prediction.data.cpu().numpy()
289 | ans_id = ans_id.numpy()
290 | for qid, pred, ans in zip(qns_key, prediction, ans_id):
291 | results[qid] = {'prediction': int(pred), 'answer': int(ans)}
292 |
293 | print(len(results))
294 | print(result_file)
295 | save_file(results, result_file)
296 |
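297 | # Driver sketch (argument values are placeholders; see main_qa.py and the scripts under
298 | # bash/ for the real setup): build a VideoQA instance with the data loaders and
299 | # hyper-parameters, then train or dump predictions for a trained checkpoint.
300 | #   vqa = VideoQA(vocab, train_loader, val_loader, test_loader, glove_embed, use_bert,
301 | #                 checkpoint_path, model_type='HGA', model_prefix='demo', vis_step=100,
302 | #                 lr_rate=1e-4, batch_size=64, epoch_num=30, logger=logger, args=args)
303 | #   vqa.run(model_file='', pre_trained=False)
304 | #   vqa.predict(model_file='best.ckpt', result_file='results/demo.json', loader=test_loader)
305 | 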
--------------------------------------------------------------------------------