├── .gitignore
├── LICENSE
├── README.md
├── build_vocab.py
├── dataloader
│   ├── __init__.py
│   ├── sample_loader.py
│   └── util.py
├── dataset
│   └── nextqa
│       ├── .gitignore
│       ├── add_reference_answer_test.json
│       ├── test.csv
│       ├── train.csv
│       └── val.csv
├── eval_oe.py
├── images
│   ├── logo.png
│   └── res-mc-oe.png
├── main.sh
├── main_qa.py
├── metrics.py
├── models
│   └── .gitignore
├── networks
│   ├── .gitignore
│   ├── Attention.py
│   ├── DecoderRNN.py
│   ├── EncoderRNN.py
│   ├── VQAModel
│   │   ├── CoMem.py
│   │   ├── EVQA.py
│   │   ├── HGA.py
│   │   ├── HME.py
│   │   ├── STVQA.py
│   │   └── UATT.py
│   ├── VQAModel_bak.py
│   ├── __init__.py
│   ├── gcn.py
│   ├── memory_module.py
│   ├── memory_rand.py
│   ├── q_v_transformer.py
│   └── torchnlp_nn.py
├── requirements.txt
├── results
│   ├── HGA-same-att-qns23ans7-test.json
│   ├── HGA-same-att-qns23ans7-val-example.json
│   └── HGA-same-att-qns23ans7-val.json
├── stopwords.txt
├── tools
│   └── .gitignore
├── utils.py
├── videoqa.py
└── word2vec.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | __pycache__
3 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Junbin
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # [NExT-QA](https://arxiv.org/pdf/2105.08276.pdf)
2 |
3 | We reproduce some SOTA VideoQA methods to provide benchmark results for our NExT-QA dataset, accepted to CVPR 2021.
4 | 
5 | NExT-QA is a VideoQA benchmark targeting the explanation of video contents. It challenges QA models to reason about causal and temporal actions and to understand the rich object interactions in daily activities. We set up both multi-choice and open-ended QA tasks on the dataset. This repo provides resources for open-ended QA; resources for multi-choice QA can be found in [NExT-QA](https://github.com/doc-doc/NExT-QA). For more details, please refer to our [dataset](https://doc-doc.github.io/docs/nextqa.html) page.
6 |
7 | ## Todo
8 | 1. [x] Raw videos are the same as in [NExT-QA(MC)](https://github.com/doc-doc/NExT-QA).
9 | 2. [ ] Open an online evaluation server and release the [test data](https://drive.google.com/file/d/1bXBFN61PaTSHTnJqz3R79mpIgEQPFGIU/view?usp=sharing).
10 | 3. [x] RoI features are the same as in [NExT-QA(MC)](https://github.com/doc-doc/NExT-QA).
11 | ## Environment
12 |
13 | Anaconda 4.8.4, Python 3.6.8, PyTorch 1.6 and CUDA 10.2. For other libraries, please refer to requirements.txt.
14 |
15 | ## Install
16 | Please create an environment for this project using Anaconda (install [Anaconda](https://docs.anaconda.com/anaconda/install/linux/) first):
17 | ```
18 | >conda create -n videoqa python==3.6.8
19 | >conda activate videoqa
20 | >git clone https://github.com/doc-doc/NExT-OE.git
21 | >pip install -r requirements.txt
22 | ```
23 | ## Data Preparation
24 | Please download the pre-computed features and QA annotations from [here](https://drive.google.com/drive/folders/14jSt4sGFQaZxBu4AGL2Svj34fUhcK2u0?usp=sharing). There are 3 zip files:
25 | - ```['vid_feat.zip']```: Appearance and motion features for video representation (same as for multi-choice QA).
26 | - ```['nextqa.zip']```: Annotations of QAs and GloVe Embeddings (open-ended version).
27 | - ```['models.zip']```: HGA model (open-ended version).
28 |
29 | After downloading the data, please create a folder ```['data/feats']``` in the same directory as ```['NExT-OE']```, then unzip the video features into it. You will have directories like ```['data/feats/vid_feat/', and 'NExT-OE/']``` in your workspace. Please unzip the files in ```['nextqa.zip']``` into ```['NExT-OE/dataset/nextqa']``` and ```['models.zip']``` into ```['NExT-OE/models/']```.
30 |
31 |
32 | ## Usage
33 | Once the data is ready, you can easily run the code. First, to test the environment and code, we provide the predictions and the model of the SOTA approach (i.e., HGA) on NExT-QA.
34 | You can get the results reported in the paper by running:
35 | ```
36 | >python eval_oe.py
37 | ```
38 | The command above will load the prediction file under ['results/'] and evaluate it.
39 | You can also obtain the prediction by running:
40 | ```
41 | >./main.sh 0 val #Test the model with GPU id 0
42 | ```
43 | The command above will load the model under ['models/'] and generate the prediction file.
44 | If you want to train the model, please run
45 | ```
46 | >./main.sh 0 train # Train the model with GPU id 0
47 | ```
48 | It will train the model and save it to ['models/']. (*The results may differ slightly depending on the environment.*)
49 | ## Results on Val
50 | | Methods | Text Rep. | WUPS_C | WUPS_T | WUPS_D | WUPS |
51 | | -------------------------| --------: | ----: | ----: | ----: | ---:|
52 | | BlindQA | GloVe | 12.14 | 14.85 | 40.41 | 18.88 |
53 | | [STVQA](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/STVQA.py) ([CVPR17](https://openaccess.thecvf.com/content_cvpr_2017/papers/Jang_TGIF-QA_Toward_Spatio-Temporal_CVPR_2017_paper.pdf)) | GloVe | 12.52 | 14.57 | 45.64 | 20.08 |
54 | | [UATT](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/UATT.py) ([TIP17](https://ieeexplore.ieee.org/document/8017608)) | GloVe | 13.62 | **16.23** | 43.41 | 20.65 |
55 | | [HME](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/HME.py) ([CVPR19](https://openaccess.thecvf.com/content_CVPR_2019/papers/Fan_Heterogeneous_Memory_Enhanced_Multimodal_Attention_Model_for_Video_Question_Answering_CVPR_2019_paper.pdf)) | GloVe | 12.83 | 14.76 | 45.13 | 20.18 |
56 | | [HCRN](https://github.com/thaolmk54/hcrn-videoqa) ([CVPR20](https://openaccess.thecvf.com/content_CVPR_2020/papers/Le_Hierarchical_Conditional_Relation_Networks_for_Video_Question_Answering_CVPR_2020_paper.pdf)) | GloVe | 12.53 | 15.37 | 45.29 | 20.25 |
57 | | [HGA](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/HGA.py) ([AAAI20](https://ojs.aaai.org//index.php/AAAI/article/view/6767)) | GloVe | **14.76** | 14.90 | **46.60** | **21.48** |
58 |
59 | Please refer to our paper for results on the test set.
60 | ## Multi-choice QA *vs.* Open-ended QA
61 | ![Multi-choice QA vs. Open-ended QA](images/res-mc-oe.png)
62 |
63 | ## Some Latest Results
64 | | Methods | Publication | Highlight | Val (WUPS@All) | Test (WUPS@All) |
65 | | -------------------------| --------: |--------: | ----: | ----:|
66 | |[Emu(0-shot)](https://arxiv.org/pdf/2307.05222v1.pdf) by BAAI | arXiv'23 | VL foundation model | - | 23.4 |
67 | | [Flamingo(0-shot)](https://arxiv.org/pdf/2204.14198.pdf) by DeepMind | NeurIPS'22 | VL foundation model | - | 26.7|
68 | | [KcGA](https://ojs.aaai.org/index.php/AAAI/article/view/25983) by Baidu | AAAI'23 | Knowledge base, GPT-2 | - | 28.2 |
69 | | [Flamingo(32-shot)](https://arxiv.org/pdf/2204.14198.pdf) by DeepMind | NeurIPS'22 | VL foundation model | - | 33.5|
70 | | [PaLI-X](https://arxiv.org/pdf/2305.18565.pdf) by Google Research | arXiv'23 | VL foundation model | - | 38.3 |
71 |
72 | ## Citation
73 | ```
74 | @InProceedings{xiao2021next,
75 | author = {Xiao, Junbin and Shang, Xindi and Yao, Angela and Chua, Tat-Seng},
76 | title = {NExT-QA: Next Phase of Question-Answering to Explaining Temporal Actions},
77 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
78 | month = {June},
79 | year = {2021},
80 | pages = {9777-9786}
81 | }
82 | ```
83 | ## Acknowledgement
84 | Our reproduction of the methods is based on the respective official repositories; we thank the authors for releasing their code. If you use the related parts, please cite the corresponding papers commented in the code.
85 |
--------------------------------------------------------------------------------
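
The Data Preparation and Usage sections above assume a specific workspace layout. Below is a small sanity-check sketch; the paths are inferred from the README, `main_qa.py`, and `dataloader/sample_loader.py`, so treat them as assumptions and adjust to your setup.

```python
# Sketch: verify the data layout described in "Data Preparation" (paths are assumptions).
import os.path as osp

expected = [
    '../data/feats/vid_feat/app_mot_val.h5',  # appearance + motion features from vid_feat.zip
    'dataset/nextqa/val.csv',                 # QA annotations from nextqa.zip
    'dataset/nextqa/vocab.pkl',               # vocabulary used by main_qa.py
    'models/',                                # unzipped HGA checkpoint from models.zip
]
for path in expected:
    print(path, '->', 'found' if osp.exists(path) else 'MISSING')
```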
/build_vocab.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | # nltk.download('punkt')
3 | import pickle
4 | import argparse
5 | from utils import load_file, save_file
6 | from collections import Counter
7 |
8 |
9 |
10 | class Vocabulary(object):
11 | """Simple vocabulary wrapper."""
12 | def __init__(self):
13 | self.word2idx = {}
14 | self.idx2word = {}
15 | self.idx = 0
16 |
17 | def add_word(self, word):
18 | if not word in self.word2idx:
19 | self.word2idx[word] = self.idx
20 | self.idx2word[self.idx] = word
21 | self.idx += 1
22 |
23 | def __call__(self, word):
24 | if not word in self.word2idx:
25 |             return self.word2idx['<unk>']
26 | return self.word2idx[word]
27 |
28 | def __len__(self):
29 | return len(self.word2idx)
30 |
31 |
32 |
33 | def build_vocab(anno_file, threshold):
34 | """Build a simple vocabulary wrapper."""
35 |
36 | annos = load_file(anno_file)
37 | print('total QA pairs', len(annos))
38 | counter = Counter()
39 |
40 | for rid, (qns, ans) in enumerate(zip(annos['question'], annos['answer'])):
41 | # qns, ans = vqa['question'], vqa['answer']
42 | text = qns +' ' +ans
43 | tokens = nltk.tokenize.word_tokenize(text.lower())
44 | counter.update(tokens)
45 |
46 | counter = sorted(counter.items(), key=lambda item:item[1], reverse=True)
47 |
48 | # If the word frequency is less than 'threshold', then the word is discarded.
49 | words = [item[0] for item in counter if item[1] >= threshold]
50 |
51 | # Create a vocab wrapper and add some special tokens.
52 | vocab = Vocabulary()
53 |     vocab.add_word('<pad>')
54 |     vocab.add_word('<start>')
55 |     vocab.add_word('<end>')
56 |     vocab.add_word('<unk>')
57 |
58 | # Add the words to the vocabulary.
59 | for i, word in enumerate(words):
60 | vocab.add_word(word)
61 |
62 | return vocab
63 |
64 |
65 | def main(args):
66 |     vocab = build_vocab(args.anno_path, args.threshold)
67 | vocab_path = args.vocab_path
68 | with open(vocab_path, 'wb') as f:
69 | pickle.dump(vocab, f)
70 | print("Total vocabulary size: {}".format(len(vocab)))
71 | print("Saved the vocabulary wrapper to '{}'".format(vocab_path))
72 |
73 |
74 | if __name__ == '__main__':
75 | parser = argparse.ArgumentParser()
76 | parser.add_argument('--anno_path', type=str,
77 | default='dataset/nextqa/all.csv',
78 | help='path for train annotation file')
79 | parser.add_argument('--vocab_path', type=str, default='dataset/nextqa/vocab.pkl',
80 | help='path for saving vocabulary wrapper')
81 | parser.add_argument('--threshold', type=int, default=5,
82 | help='minimum word count threshold')
83 | args = parser.parse_args()
84 | main(args)
85 |
--------------------------------------------------------------------------------
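
For reference, a minimal usage sketch for `build_vocab.py`. It assumes an annotation CSV with `question` and `answer` columns at the script's default path (which may need to be created, e.g. by merging the train/val/test CSVs); the equivalent CLI call is `python build_vocab.py --anno_path dataset/nextqa/all.csv --vocab_path dataset/nextqa/vocab.pkl`.

```python
# Sketch: build and query the vocabulary programmatically (annotation path is an assumption).
import pickle
from build_vocab import Vocabulary, build_vocab

vocab = build_vocab('dataset/nextqa/all.csv', threshold=5)
print('vocabulary size:', len(vocab))   # kept words + 4 special tokens
print(vocab('what'))                    # index of a frequent word
print(vocab('zyzzyva'))                 # an unseen word falls back to the <unk> index

with open('dataset/nextqa/vocab.pkl', 'wb') as f:
    pickle.dump(vocab, f)               # same file main_qa.py later loads with pkload
```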
/dataloader/__init__.py:
--------------------------------------------------------------------------------
1 | # ====================================================
2 | # @Time : 15/5/20 3:48 PM
3 | # @Author : Xiao Junbin
4 | # @Email : junbin@comp.nus.edu.sg
5 | # @File : __init__.py
6 | # ====================================================
7 | from .sample_loader import *
--------------------------------------------------------------------------------
/dataloader/sample_loader.py:
--------------------------------------------------------------------------------
1 | # ====================================================
2 | # @Time : 19/5/20 10:42 PM
3 | # @Author : Xiao Junbin
4 | # @Email : junbin@comp.nus.edu.sg
5 | # @File : sample_loader.py
6 | # ====================================================
7 | import torch
8 | from torch.utils.data import Dataset, DataLoader
9 | from .util import load_file, pkdump, pkload
10 | import os.path as osp
11 | import numpy as np
12 | import nltk
13 | import h5py
14 |
15 | class VidQADataset(Dataset):
16 | """load the dataset in dataloader"""
17 |
18 | def __init__(self, video_feature_path, video_feature_cache, sample_list_path, vocab_qns, vocab_ans, mode):
19 | self.video_feature_path = video_feature_path
20 | self.vocab_qns = vocab_qns
21 | self.vocab_ans = vocab_ans
22 | sample_list_file = osp.join(sample_list_path, '{}.csv'.format(mode))
23 | self.sample_list = load_file(sample_list_file)
24 | self.video_feature_cache = video_feature_cache
25 | self.use_frame = True
26 | self.use_mot = True
27 | self.frame_feats = {}
28 | self.mot_feats = {}
29 | vid_feat_file = osp.join(video_feature_path, 'vid_feat/app_mot_{}.h5'.format(mode))
30 | with h5py.File(vid_feat_file, 'r') as fp:
31 | vids = fp['ids']
32 | feats = fp['feat']
33 | for id, (vid, feat) in enumerate(zip(vids, feats)):
34 | if self.use_frame:
35 | self.frame_feats[str(vid)] = feat[:, :2048] # (16, 2048)
36 | if self.use_mot:
37 | self.mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048)
38 |
39 |
40 | def __len__(self):
41 | return len(self.sample_list)
42 |
43 |
44 | def get_video_feature(self, video_name):
45 | """
46 |
47 | """
48 | if self.use_frame:
49 | app_feat = self.frame_feats[video_name]
50 | video_feature = app_feat # (16, 2048)
51 | if self.use_mot:
52 | mot_feat = self.mot_feats[video_name]
53 | video_feature = np.concatenate((video_feature, mot_feat), axis=1) #(16, 4096)
54 |
55 | return torch.from_numpy(video_feature).type(torch.float32)
56 |
57 |
58 | def get_word_idx(self, text, src='qns'):
59 | """
60 | convert relation to index sequence
61 | :param relation:
62 | :return:
63 | """
64 | if src=='qns': vocab = self.vocab_qns
65 | elif src=='ans': vocab = self.vocab_ans
66 | tokens = nltk.tokenize.word_tokenize(str(text).lower())
67 | text = []
68 |         text.append(vocab('<start>'))
69 | text.extend([vocab(token) for i,token in enumerate(tokens) if i < 23])
70 |         #text.append(vocab('<end>'))
71 | target = torch.Tensor(text)
72 |
73 | return target
74 |
75 |
76 | def __getitem__(self, idx):
77 | """
78 |
79 | """
80 |
81 | sample = self.sample_list.loc[idx]
82 | video_name, qns, ans = sample['video'], sample['question'], sample['answer']
83 | qid, qtype = sample['qid'], sample['type']
84 | video_name = str(video_name)
85 | qns, ans, qid, qtype = str(qns), str(ans), str(qid), str(qtype)
86 |
87 |
88 | #video_feature = torch.tensor([0])
89 | video_feature = self.get_video_feature(video_name)
90 |
91 | qns2idx = self.get_word_idx(qns, 'qns')
92 | ans2idx = self.get_word_idx(ans, 'ans')
93 |
94 | return video_feature, qns2idx, ans2idx, video_name, qid, qtype
95 |
96 |
97 | class QALoader():
98 | def __init__(self, batch_size, num_worker, video_feature_path, video_feature_cache,
99 | sample_list_path, vocab_qns, vocab_ans, train_shuffle=True, val_shuffle=False):
100 | self.batch_size = batch_size
101 | self.num_worker = num_worker
102 | self.video_feature_path = video_feature_path
103 | self.video_feature_cache = video_feature_cache
104 | self.sample_list_path = sample_list_path
105 | self.vocab_qns = vocab_qns
106 | self.vocab_ans = vocab_ans
107 | self.train_shuffle = train_shuffle
108 | self.val_shuffle = val_shuffle
109 |
110 |
111 | def run(self, mode=''):
112 | if mode != 'train':
113 | train_loader = ''
114 | val_loader = self.validate(mode)
115 | else:
116 | train_loader = self.train('train')
117 | val_loader = self.validate('val')
118 | return train_loader, val_loader
119 |
120 |
121 | def train(self, mode):
122 |
123 | training_set = VidQADataset(self.video_feature_path, self.video_feature_cache, self.sample_list_path,
124 | self.vocab_qns, self.vocab_ans, mode)
125 |
126 | print('Eligible QA pairs for training : {}'.format(len(training_set)))
127 | train_loader = DataLoader(
128 | dataset=training_set,
129 | batch_size=self.batch_size,
130 | shuffle=self.train_shuffle,
131 | num_workers=self.num_worker,
132 | collate_fn=collate_fn)
133 |
134 | return train_loader
135 |
136 | def validate(self, mode):
137 |
138 | validation_set = VidQADataset(self.video_feature_path, self.video_feature_cache, self.sample_list_path,
139 | self.vocab_qns, self.vocab_ans, mode)
140 |
141 | print('Eligible QA pairs for validation : {}'.format(len(validation_set)))
142 | val_loader = DataLoader(
143 | dataset=validation_set,
144 | batch_size=self.batch_size,
145 | shuffle=self.val_shuffle,
146 | num_workers=self.num_worker,
147 | collate_fn=collate_fn)
148 |
149 | return val_loader
150 |
151 |
152 | def collate_fn (data):
153 | """
154 | """
155 | data.sort(key=lambda x : len(x[1]), reverse=True)
156 | videos, qns2idx, ans2idx, video_names, qids, qtypes = zip(*data)
157 |
158 | #merge videos
159 | videos = torch.stack(videos, 0)
160 |
161 | #merge relations
162 | qns_lengths = [len(qns) for qns in qns2idx]
163 | targets_qns = torch.zeros(len(qns2idx), max(qns_lengths)).long()
164 | for i, qns in enumerate(qns2idx):
165 | end = qns_lengths[i]
166 | targets_qns[i, :end] = qns[:end]
167 |
168 | ans_lengths = [len(ans) for ans in ans2idx]
169 | targets_ans = torch.zeros(len(ans2idx), max(ans_lengths)).long()
170 | for i, ans in enumerate(ans2idx):
171 | end = ans_lengths[i]
172 | targets_ans[i, :end] = ans[:end]
173 |
174 | return videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes
175 |
--------------------------------------------------------------------------------
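
For context, this is how `main_qa.py` wires the loader up; the sketch below mirrors that call with the default paths, which are assumptions about the local data layout.

```python
# Sketch: build train/val loaders as main_qa.py does (paths and vocab file are assumptions).
from dataloader.sample_loader import QALoader
from dataloader.util import pkload

vocab = pkload('dataset/nextqa/vocab.pkl')          # shared word dict for questions and answers
data_loader = QALoader(batch_size=64, num_worker=8,
                       video_feature_path='../data/feats/',
                       video_feature_cache='../data/feats/cache/',
                       sample_list_path='dataset/nextqa/',
                       vocab_qns=vocab, vocab_ans=vocab)
train_loader, val_loader = data_loader.run(mode='train')

videos, qns, qns_lens, ans, ans_lens, names, qids, qtypes = next(iter(train_loader))
print(videos.shape)   # (64, 16, 4096): 16 clips x (2048 appearance + 2048 motion) features
```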
/dataloader/util.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import numpy as np
5 | import pickle as pkl
6 | import pandas as pd
7 |
8 | def load_file(file_name):
9 | annos = None
10 | if osp.splitext(file_name)[-1] == '.csv':
11 | return pd.read_csv(file_name, delimiter=',')
12 | with open(file_name, 'r') as fp:
13 | if osp.splitext(file_name)[1]== '.txt':
14 | annos = fp.readlines()
15 | annos = [line.rstrip() for line in annos]
16 | if osp.splitext(file_name)[1] == '.json':
17 | annos = json.load(fp)
18 |
19 | return annos
20 |
21 | def save_file(obj, filename):
22 | """
23 | save obj to filename
24 | :param obj:
25 | :param filename:
26 | :return:
27 | """
28 | filepath = osp.dirname(filename)
29 | if filepath != '' and not osp.exists(filepath):
30 | os.makedirs(filepath)
31 |     with open(filename, 'w') as fp:
32 |         json.dump(obj, fp, indent=4)
34 |
35 | def pkload(file):
36 | data = None
37 | if osp.exists(file) and osp.getsize(file) > 0:
38 | with open(file, 'rb') as fp:
39 | data = pkl.load(fp)
40 | # print('{} does not exist'.format(file))
41 | return data
42 |
43 |
44 | def pkdump(data, file):
45 | dirname = osp.dirname(file)
46 | if not osp.exists(dirname):
47 | os.makedirs(dirname)
48 | with open(file, 'wb') as fp:
49 | pkl.dump(data, fp)
50 |
51 |
52 |
53 |
--------------------------------------------------------------------------------
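
A small round-trip sketch for the helpers above; the `tmp/` paths are placeholders for illustration.

```python
# Illustrative round trip for dataloader/util.py (paths are placeholders).
from dataloader.util import save_file, load_file, pkdump, pkload

save_file({'video': '1001', 'answer': 'dog'}, 'tmp/example.json')
print(load_file('tmp/example.json'))     # {'video': '1001', 'answer': 'dog'}

pkdump([1, 2, 3], 'tmp/example.pkl')
print(pkload('tmp/example.pkl'))         # [1, 2, 3]
```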
/dataset/nextqa/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore all except .gitignore file
2 | *
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/eval_oe.py:
--------------------------------------------------------------------------------
1 | from utils import *
2 | from metrics import *
3 | import pandas as pd
4 | from pywsd.utils import lemmatize_sentence
5 |
6 | #use stopwords tailored for NExT-QA
7 | stopwords = load_file('stopwords.txt')
8 | def remove_stop(sentence):
9 |
10 | words = lemmatize_sentence(sentence)
11 | words = [w for w in words if not w in stopwords]
12 | return ' '.join(words)
13 |
14 |
15 | def evaluate(res_file, ref_file, ref_file_add):
16 | """
17 | :param res_file:
18 | :param ref_file:
19 | :return:
20 | """
21 | res = load_file(res_file)
22 |
23 | multi_ref_ans = False
24 | if osp.exists(ref_file_add):
25 | add_ref = load_file(ref_file_add)
26 | multi_ref_ans = True
27 | refer = pd.read_csv(ref_file)
28 | ref_num = len(refer)
29 | group_dict = {'CW': [], 'CH': [], 'TN': [], 'TC': [], 'DL': [], 'DB':[], 'DC': [], 'DO': []}
30 | for idx in range(ref_num):
31 | sample = refer.loc[idx]
32 | qtype = sample['type']
33 | if qtype in ['TN', 'TP']: qtype = 'TN'
34 | group_dict[qtype].append(idx)
35 | wups0 = {'CW': 0, 'CH': 0, 'TN': 0, 'TC': 0, 'DB': 0, 'DC': 0, 'DL': 0, 'DO': 0}
36 | wups9 = {'CW': 0, 'CH': 0, 'TN': 0, 'TC': 0, 'DB': 0, 'DC': 0, 'DL': 0, 'DO': 0}
37 | wups0_e, wups0_t, wups0_c = 0, 0, 0
38 | wups0_all, wups9_all = 0, 0
39 |
40 | num = {'CW': 0, 'CH': 0, 'TN': 0, 'TC': 0, 'DB': 0, 'DC': 0, 'DL': 0, 'DO': 0}
41 | over_num = {'C':0, 'T':0, 'D':0}
42 | ref_num = 0
43 | for qtype, ids in group_dict.items():
44 | for id in ids:
45 | sample = refer.loc[id]
46 | video, qid, ans, qns = str(sample['video']), str(sample['qid']), str(sample['answer']), str(sample['question'])
47 | num[qtype] += 1
48 | over_num[qtype[0]] += 1
49 | ref_num += 1
50 |
51 | pred_ans_src = res[video][qid]
52 |
53 | gt_ans = remove_stop(ans)
54 | pred_ans = remove_stop(pred_ans_src)
55 | if multi_ref_ans and (video in add_ref):
56 | gt_ans_add = remove_stop(add_ref[video][qid])
57 | if qtype == 'DC' or qtype == 'DB':
58 | cur_0 = 1 if pred_ans == gt_ans_add or pred_ans == gt_ans else 0
59 | cur_9 = cur_0
60 | else:
61 | cur_0 = max(get_wups(pred_ans, gt_ans, 0), get_wups(pred_ans, gt_ans_add, 0))
62 | cur_9 = max(get_wups(pred_ans, gt_ans, 0.9), get_wups(pred_ans, gt_ans_add, 0.9))
63 | else:
64 | if qtype == 'DC' or qtype == 'DB':
65 | cur_0 = 1 if pred_ans == gt_ans else 0
66 | cur_9 = cur_0
67 | else:
68 | cur_0 = get_wups(pred_ans, gt_ans, 0)
69 | cur_9 = get_wups(pred_ans, gt_ans, 0.9)
70 | wups0[qtype] += cur_0
71 | wups9[qtype] += cur_9
72 |
73 |
74 | wups0_all += wups0[qtype]
75 | wups9_all += wups9[qtype]
76 | if qtype[0] == 'C':
77 | wups0_e += wups0[qtype]
78 | if qtype[0] == 'T':
79 | wups0_t += wups0[qtype]
80 | if qtype[0] == 'D':
81 | wups0_c += wups0[qtype]
82 |
83 | wups0[qtype] = wups0[qtype]/num[qtype]
84 | wups9[qtype] = wups9[qtype]/num[qtype]
85 |
86 | num_e = over_num['C']
87 | num_t = over_num['T']
88 | num_c = over_num['D']
89 |
90 | wups0_e /= num_e
91 | wups0_t /= num_t
92 | wups0_c /= num_c
93 |
94 | wups0_all /= ref_num
95 | wups9_all /= ref_num
96 |
97 | for k in wups0:
98 | wups0[k] = wups0[k] * 100
99 | wups9[k] = wups9[k] * 100
100 |
101 | wups0_e *= 100
102 | wups0_t *= 100
103 | wups0_c *= 100
104 | wups0_all *= 100
105 |
106 | print('CW\tCH\tWUPS_C\tTPN\tTC\tWUPS_T\tDB\tDC\tDL\tDO\tWUPS_D\tWUPS')
107 | print('{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}'
108 | .format(wups0['CW'], wups0['CH'], wups0_e, wups0['TN'], wups0['TC'],wups0_t,
109 | wups0['DB'],wups0['DC'], wups0['DL'], wups0['DO'], wups0_c, wups0_all))
110 |
111 |
112 |
113 | def main(filename, mode):
114 | res_dir = 'results'
115 | res_file = osp.join(res_dir, filename)
116 | print(f'Evaluate on {res_file}')
117 | ref_file = 'dataset/nextqa/{}.csv'.format(mode)
118 | ref_file_add = 'dataset/nextqa/add_reference_answer_{}.json'.format(mode)
119 | evaluate(res_file, ref_file, ref_file_add)
120 |
121 |
122 | if __name__ == "__main__":
123 |
124 | mode = 'val'
125 | model = 'HGA'
126 | result_file = '{}-same-att-qns23ans7-{}-example.json'.format(model, mode)
127 | main(result_file, mode)
128 |
--------------------------------------------------------------------------------
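
For clarity on what `evaluate()` above consumes: `res[video][qid]` implies a JSON file keyed first by video id and then by question id, holding predicted answer strings. A hedged sketch of writing such a file (the ids, answers, and file name are illustrative):

```python
# Sketch of the prediction file format read by eval_oe.py: res[video][qid] -> answer string.
import json

predictions = {
    '4010069381': {                    # video id (illustrative value)
        '0': 'play with the dog',      # qid -> predicted open-ended answer
        '1': 'in the park',
    },
}
with open('results/my-model-val.json', 'w') as f:   # illustrative file name
    json.dump(predictions, f, indent=4)
```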
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/doc-doc/NExT-OE/a83f8f581191da07675e0fc83074e0dfcf907273/images/logo.png
--------------------------------------------------------------------------------
/images/res-mc-oe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/doc-doc/NExT-OE/a83f8f581191da07675e0fc83074e0dfcf907273/images/res-mc-oe.png
--------------------------------------------------------------------------------
/main.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | GPU=$1
3 | MODE=$2
4 | CUDA_VISIBLE_DEVICES=$GPU python main_qa.py \
5 | --mode $MODE
6 |
--------------------------------------------------------------------------------
/main_qa.py:
--------------------------------------------------------------------------------
1 | from videoqa import *
2 | import dataloader
3 | from build_vocab import Vocabulary
4 | from utils import *
5 | import argparse
6 | import eval_oe
7 |
8 |
9 | def main(args):
10 |
11 | mode = args.mode
12 | if mode == 'train':
13 | batch_size = 64
14 | num_worker = 8
15 | else:
16 | batch_size = 64
17 | num_worker = 8
18 | spatial = False
19 | if spatial:
20 | #for STVQA
21 | video_feature_path = '../data/feats/spatial/'
22 | video_feature_cache = '../data/feats/cache_spatial/'
23 | else:
24 | video_feature_path = '../data/feats/'
25 | video_feature_cache = '../data/feats/cache/'
26 |
27 | dataset = 'nextqa'
28 | sample_list_path = 'dataset/{}/'.format(dataset)
29 |
30 | #We separate the dicts for qns and ans, in case one wants to use different word-dicts for them.
31 | vocab_qns = pkload('dataset/{}/vocab.pkl'.format(dataset))
32 | vocab_ans = pkload('dataset/{}/vocab.pkl'.format(dataset))
33 |
34 | word_type = 'glove'
35 | glove_embed_qns = 'dataset/{}/{}_embed.npy'.format(dataset, word_type)
36 | glove_embed_ans = 'dataset/{}/{}_embed.npy'.format(dataset, word_type)
37 | checkpoint_path = 'models'
38 | model_type = 'HGA'
39 |
40 | model_prefix = 'same-att-qns23ans7'
41 | vis_step = 116
42 | lr_rate = 5e-5
43 | epoch_num = 100
44 |
45 | data_loader = dataloader.QALoader(batch_size, num_worker, video_feature_path, video_feature_cache,
46 | sample_list_path, vocab_qns, vocab_ans, True, False)
47 |
48 | train_loader, val_loader = data_loader.run(mode=mode)
49 |
50 | vqa = VideoQA(vocab_qns, vocab_ans, train_loader, val_loader, glove_embed_qns, glove_embed_ans,
51 | checkpoint_path, model_type, model_prefix, vis_step,lr_rate, batch_size, epoch_num)
52 |
53 | ep = 36
54 | acc = 0.2163
55 | model_file = f'{model_type}-{model_prefix}-{ep}-{acc:.4f}.ckpt'
56 |
57 | if mode != 'train':
58 | result_file = f'{model_type}-{model_prefix}-{mode}.json'
59 | vqa.predict(model_file, result_file)
60 | eval_oe.main(result_file, mode)
61 | else:
62 | model_file = f'{model_type}-{model_prefix}-44-0.2140.ckpt'
63 | vqa.run(model_file, pre_trained=False)
64 |
65 |
66 | if __name__ == "__main__":
67 | torch.backends.cudnn.enabled = False
68 | torch.manual_seed(666)
69 | torch.cuda.manual_seed(666)
70 | torch.backends.cudnn.benchmark = True
71 |
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument('--gpu', dest='gpu', type=int,
74 | default=0, help='gpu device id')
75 | parser.add_argument('--mode', dest='mode', type=str,
76 | default='train', help='train or val')
77 | args = parser.parse_args()
78 | set_gpu_devices(args.gpu)
79 | main(args)
80 |
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | # ====================================================
2 | # @Time : 13/9/20 4:19 PM
3 | # @Author : Xiao Junbin
4 | # @Email : junbin@comp.nus.edu.sg
5 | # @File : metrics.py
6 | # ====================================================
7 | from nltk.tokenize import word_tokenize
8 | from nltk.corpus import wordnet
9 | import numpy as np
10 |
11 | def wup(word1, word2, alpha):
12 | """
13 | calculate the wup similarity
14 | :param word1:
15 | :param word2:
16 | :param alpha:
17 | :return:
18 | """
19 | # print(word1, word2)
20 | if word1 == word2:
21 | return 1.0
22 |
23 | w1 = wordnet.synsets(word1)
24 | w1_len = len(w1)
25 | if w1_len == 0: return 0.0
26 | w2 = wordnet.synsets(word2)
27 | w2_len = len(w2)
28 | if w2_len == 0: return 0.0
29 |
30 | #match the first
31 | word_sim = w1[0].wup_similarity(w2[0])
32 | if word_sim is None:
33 | word_sim = 0.0
34 |
35 | if word_sim < alpha:
36 | word_sim = 0.1*word_sim
37 | return word_sim
38 |
39 |
40 | def wups(words1, words2, alpha):
41 | """
42 |
43 |     :param words1: tokens of the predicted answer
44 |     :param words2: tokens of the reference answer
45 |     :param alpha: similarity threshold; scores below it are down-weighted by 0.1
46 | :return:
47 | """
48 | sim = 1.0
49 | flag = False
50 | for w1 in words1:
51 | max_sim = 0
52 | for w2 in words2:
53 | word_sim = wup(w1, w2, alpha)
54 | if word_sim > max_sim:
55 | max_sim = word_sim
56 | if max_sim == 0: continue
57 | sim *= max_sim
58 | flag = True
59 | if not flag:
60 | sim = 0.0
61 | return sim
62 |
63 |
64 | def get_wups(pred, truth, alpha):
65 | """
66 | calculate the wups score
67 | :param pred:
68 | :param truth:
69 | :return:
70 | """
71 | pred = word_tokenize(pred)
72 | truth = word_tokenize(truth)
73 | item1 = wups(pred, truth, alpha)
74 | item2 = wups(truth, pred, alpha)
75 | value = min(item1, item2)
76 | return value
--------------------------------------------------------------------------------
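
A minimal sketch exercising `get_wups` directly; the answer strings are made up, and NLTK's `punkt` tokenizer and WordNet corpus are assumed to be installed.

```python
# Sketch: score a predicted open-ended answer against a reference with WUPS.
from metrics import get_wups

pred = 'feed the dog'
truth = 'give food to the puppy'
print(get_wups(pred, truth, 0))     # WUPS@0.0, a value in [0, 1]
print(get_wups(pred, truth, 0.9))   # WUPS@0.9, weak matches down-weighted by 0.1
```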
/models/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore all except .gitignore file
2 | *
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/networks/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
--------------------------------------------------------------------------------
/networks/Attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class TempAttention(nn.Module):
7 | """
8 | Applies an attention mechanism on the output features from the decoder.
9 | """
10 |
11 | def __init__(self, text_dim, visual_dim, hidden_dim):
12 | super(TempAttention, self).__init__()
13 | self.hidden_dim = hidden_dim
14 | self.linear_text = nn.Linear(text_dim, hidden_dim)
15 | self.linear_visual = nn.Linear(visual_dim, hidden_dim)
16 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False)
17 | self._init_weight()
18 |
19 | def _init_weight(self):
20 | nn.init.xavier_normal_(self.linear_text.weight)
21 | nn.init.xavier_normal_(self.linear_visual.weight)
22 | nn.init.xavier_normal_(self.linear_att.weight)
23 |
24 | def forward(self, qns_embed, vid_outputs):
25 | """
26 | Arguments:
27 | qns_embed {Variable} -- batch_size x dim
28 | vid_outputs {Variable} -- batch_size x seq_len x dim
29 |
30 | Returns:
31 | context -- context vector of size batch_size x dim
32 | """
33 | qns_embed_trans = self.linear_text(qns_embed)
34 |
35 | batch_size, seq_len, visual_dim = vid_outputs.size()
36 | vid_outputs_temp = vid_outputs.contiguous().view(batch_size*seq_len, visual_dim)
37 | vid_outputs_trans = self.linear_visual(vid_outputs_temp)
38 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.hidden_dim)
39 |
40 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1)
41 |
42 |
43 | o = self.linear_att(torch.tanh(qns_embed_trans+vid_outputs_trans))
44 |
45 | e = o.view(batch_size, seq_len)
46 | beta = F.softmax(e, dim=1)
47 | context = torch.bmm(beta.unsqueeze(1), vid_outputs).squeeze(1)
48 |
49 | return context, beta
50 |
51 |
52 | class SpatialAttention(nn.Module):
53 | """
54 | Apply spatial attention on vid feature before being fed into LSTM
55 | """
56 |
57 | def __init__(self, text_dim=1024, vid_dim=3072, hidden_dim=512, input_dropout_p=0.2):
58 | super(SpatialAttention, self).__init__()
59 |
60 | self.linear_v = nn.Linear(vid_dim, hidden_dim)
61 | self.linear_q = nn.Linear(text_dim, hidden_dim)
62 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False)
63 |
64 | self.softmax = nn.Softmax(dim=1)
65 | self.dropout = nn.Dropout(input_dropout_p)
66 | self._init_weight()
67 |
68 | def _init_weight(self):
69 | nn.init.xavier_normal_(self.linear_v.weight)
70 | nn.init.xavier_normal_(self.linear_q.weight)
71 | nn.init.xavier_normal_(self.linear_att.weight)
72 |
73 | def forward(self, qns_feat, vid_feats):
74 | """
75 | Apply question feature as semantic clue to guide feature aggregation at each frame
76 | :param vid_feats: fnum x feat_dim x 7 x 7
77 | :param qns_feat: dim_hidden*2
78 | :return:
79 | """
80 | # print(qns_feat.size(), vid_feats.size())
81 | # permute to fnum x 7 x 7 x feat_dim
82 | vid_feats = vid_feats.permute(0, 2, 3, 1)
83 | fnum, width, height, feat_dim = vid_feats.size()
84 | vid_feats = vid_feats.contiguous().view(-1, feat_dim)
85 | vid_feats_trans = self.linear_v(vid_feats)
86 |
87 | vid_feats_trans = vid_feats_trans.view(fnum, width*height, -1)
88 | region_num = vid_feats_trans.shape[1]
89 |
90 | qns_feat_trans = self.linear_q(qns_feat)
91 |
92 | qns_feat_trans = qns_feat_trans.repeat(fnum, region_num, 1)
93 | # print(vid_feats_trans.shape, qns_feat_trans.shape)
94 |
95 | vid_qns = self.linear_att(torch.tanh(vid_feats_trans + qns_feat_trans))
96 |
97 | vid_qns_o = vid_qns.view(fnum, region_num)
98 | alpha = self.softmax(vid_qns_o)
99 | alpha = alpha.unsqueeze(1)
100 | vid_feats = vid_feats.view(fnum, region_num, -1)
101 | feature = torch.bmm(alpha, vid_feats).squeeze(1)
102 | feature = self.dropout(feature)
103 | # print(feature.size())
104 | return feature, alpha
105 |
106 |
107 | class TempAttentionHis(nn.Module):
108 | """
109 | Applies an attention mechanism on the output features from the decoder.
110 | """
111 |
112 | def __init__(self, visual_dim, text_dim, his_dim, mem_dim):
113 | super(TempAttentionHis, self).__init__()
114 | # self.dim = dim
115 | self.mem_dim = mem_dim
116 | self.linear_v = nn.Linear(visual_dim, self.mem_dim, bias=False)
117 | self.linear_q = nn.Linear(text_dim, self.mem_dim, bias=False)
118 | self.linear_his1 = nn.Linear(his_dim, self.mem_dim, bias=False)
119 | self.linear_his2 = nn.Linear(his_dim, self.mem_dim, bias=False)
120 | self.linear_att = nn.Linear(self.mem_dim, 1, bias=False)
121 | self._init_weight()
122 |
123 |
124 | def _init_weight(self):
125 | nn.init.xavier_normal_(self.linear_v.weight)
126 | nn.init.xavier_normal_(self.linear_q.weight)
127 | nn.init.xavier_normal_(self.linear_his1.weight)
128 | nn.init.xavier_normal_(self.linear_his2.weight)
129 | nn.init.xavier_normal_(self.linear_att.weight)
130 |
131 |
132 | def forward(self, qns_embed, vid_outputs, his):
133 | """
134 | :param qns_embed: batch_size x 1024
135 | :param vid_outputs: batch_size x seq_num x feat_dim
136 | :param his: batch_size x 512
137 | :return:
138 | """
139 |
140 | batch_size, seq_len, feat_dim = vid_outputs.size()
141 | vid_outputs_trans = self.linear_v(vid_outputs.contiguous().view(batch_size * seq_len, feat_dim))
142 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.mem_dim)
143 |
144 | qns_embed_trans = self.linear_q(qns_embed)
145 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1)
146 |
147 |
148 | his_trans = self.linear_his1(his)
149 | his_trans = his_trans.unsqueeze(1).repeat(1, seq_len, 1)
150 |
151 | o = self.linear_att(torch.tanh(qns_embed_trans + vid_outputs_trans + his_trans))
152 |
153 | e = o.view(batch_size, seq_len)
154 | beta = F.softmax(e, dim=1)
155 | context = torch.bmm(beta.unsqueeze(1), vid_outputs_trans).squeeze(1)
156 |
157 | his_acc = torch.tanh(self.linear_his2(his))
158 |
159 | context += his_acc
160 |
161 | return context, beta
162 |
163 |
164 | class MultiModalAttentionModule(nn.Module):
165 |
166 | def __init__(self, hidden_size=512, simple=False):
167 | """Set the hyper-parameters and build the layers."""
168 | super(MultiModalAttentionModule, self).__init__()
169 |
170 | self.hidden_size = hidden_size
171 | self.simple = simple
172 |
173 | # alignment model
174 | # see appendices A.1.2 of neural machine translation
175 |
176 | self.Wav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
177 | self.Wat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
178 | self.Uav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
179 | self.Uat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
180 | self.Vav = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
181 | self.Vat = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
182 | self.bav = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True)
183 | self.bat = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True)
184 |
185 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
186 | self.Wvh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
187 | self.Wth = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
188 | self.bh = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True)
189 |
190 | self.video_sum_encoder = nn.Linear(hidden_size, hidden_size)
191 | self.question_sum_encoder = nn.Linear(hidden_size, hidden_size)
192 |
193 | self.Wb = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
194 | self.Vbv = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
195 | self.Vbt = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True)
196 | self.bbv = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
197 | self.bbt = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
198 | self.wb = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True)
199 | self.init_weights()
200 |
201 | def init_weights(self):
202 | self.Wav.data.normal_(0.0, 0.1)
203 | self.Wat.data.normal_(0.0, 0.1)
204 | self.Uav.data.normal_(0.0, 0.1)
205 | self.Uat.data.normal_(0.0, 0.1)
206 | self.Vav.data.normal_(0.0, 0.1)
207 | self.Vat.data.normal_(0.0, 0.1)
208 | self.bav.data.fill_(0)
209 | self.bat.data.fill_(0)
210 |
211 | self.Whh.data.normal_(0.0, 0.1)
212 | self.Wvh.data.normal_(0.0, 0.1)
213 | self.Wth.data.normal_(0.0, 0.1)
214 | self.bh.data.fill_(0)
215 |
216 | self.Wb.data.normal_(0.0, 0.01)
217 | self.Vbv.data.normal_(0.0, 0.01)
218 | self.Vbt.data.normal_(0.0, 0.01)
219 | self.wb.data.normal_(0.0, 0.01)
220 |
221 | self.bbv.data.fill_(0)
222 | self.bbt.data.fill_(0)
223 |
224 | def forward(self, h, hidden_frames, hidden_text, inv_attention=False):
225 | # print self.Uav
226 | # hidden_text: 1 x T1 x 1024 (looks like a two layer one-directional LSTM, combining each layer's hidden)
227 | # hidden_frame: 1 x T2 x 1024 (from video encoder output, 1024 is similar from above)
228 |
229 | # print hidden_frames.size(),hidden_text.size()
230 | Uhv = torch.matmul(h, self.Uav) # (1,512)
231 | Uhv = Uhv.view(Uhv.size(0), 1, Uhv.size(1)) # (1,1,512)
232 |
233 | Uht = torch.matmul(h, self.Uat) # (1,512)
234 | Uht = Uht.view(Uht.size(0), 1, Uht.size(1)) # (1,1,512)
235 |
236 | # print Uhv.size(),Uht.size()
237 |
238 | Wsv = torch.matmul(hidden_frames, self.Wav) # (1,T,512)
239 | # print Wsv.size()
240 | att_vec_v = torch.matmul(torch.tanh(Wsv + Uhv + self.bav), self.Vav)
241 |
242 | Wst = torch.matmul(hidden_text, self.Wat) # (1,T,512)
243 | att_vec_t = torch.matmul(torch.tanh(Wst + Uht + self.bat), self.Vat)
244 |
245 | if inv_attention == True:
246 | att_vec_v = -att_vec_v
247 | att_vec_t = -att_vec_t
248 |
249 | att_vec_v = torch.softmax(att_vec_v, dim=1)
250 | att_vec_t = torch.softmax(att_vec_t, dim=1)
251 |
252 | att_vec_v = att_vec_v.view(att_vec_v.size(0), att_vec_v.size(1), 1) # expand att_vec from 1xT to 1xTx1
253 | att_vec_t = att_vec_t.view(att_vec_t.size(0), att_vec_t.size(1), 1) # expand att_vec from 1xT to 1xTx1
254 |
255 | hv_weighted = att_vec_v * hidden_frames
256 | hv_sum = torch.sum(hv_weighted, dim=1)
257 | hv_sum2 = self.video_sum_encoder(hv_sum)
258 |
259 | ht_weighted = att_vec_t * hidden_text
260 | ht_sum = torch.sum(ht_weighted, dim=1)
261 | ht_sum2 = self.question_sum_encoder(ht_sum)
262 |
263 | Wbs = torch.matmul(h, self.Wb)
264 | mt1 = torch.matmul(ht_sum, self.Vbt) + self.bbt + Wbs
265 | mv1 = torch.matmul(hv_sum, self.Vbv) + self.bbv + Wbs
266 | mtv = torch.tanh(torch.cat([mv1, mt1], dim=0))
267 | mtv2 = torch.matmul(mtv, self.wb)
268 | beta = torch.softmax(mtv2, dim=0)
269 |
270 | output = torch.tanh(torch.matmul(h, self.Whh) + beta[0] * hv_sum2 +
271 | beta[1] * ht_sum2 + self.bh)
272 | output = output.view(output.size(1), output.size(2))
273 |
274 | return output
--------------------------------------------------------------------------------
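
A quick shape-check sketch for `TempAttention` above; the dimensions are arbitrary and only need to match the constructor arguments.

```python
# Sketch: TempAttention maps (B, text_dim) and (B, T, visual_dim) to a (B, visual_dim) context and (B, T) weights.
import torch
from networks.Attention import TempAttention

att = TempAttention(text_dim=1024, visual_dim=512, hidden_dim=256)
qns_embed = torch.randn(8, 1024)          # batch of question embeddings
vid_outputs = torch.randn(8, 16, 512)     # batch of 16-step video features
context, beta = att(qns_embed, vid_outputs)
print(context.shape, beta.shape)          # torch.Size([8, 512]) torch.Size([8, 16])
```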
/networks/DecoderRNN.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | from .Attention import TempAttention
5 |
6 |
7 | class AnsUATT(nn.Module):
8 | """
9 | """
10 | def __init__(self,
11 | vocab_size,
12 | max_len,
13 | dim_hidden,
14 | dim_word,
15 | glove_embed,
16 | n_layers=1,
17 | rnn_cell='gru',
18 | bidirectional=False,
19 | input_dropout_p=0.2,
20 | rnn_dropout_p=0):
21 | super(AnsUATT, self).__init__()
22 |
23 | self.bidirectional_encoder = bidirectional
24 | self.vocab_size = vocab_size
25 |
26 |
27 | self.dim_hidden = dim_hidden
28 | self.dim_word = dim_word
29 | self.max_length = max_len
30 | self.glove_embed = glove_embed
31 |
32 | self.input_dropout = nn.Dropout(input_dropout_p)
33 | self.embedding = nn.Embedding(vocab_size, dim_word)
34 |
35 | if rnn_cell.lower() == 'lstm':
36 | self.rnn_cell = nn.LSTM
37 | elif rnn_cell.lower() == 'gru':
38 | self.rnn_cell = nn.GRU
39 | self.rnn = self.rnn_cell(self.dim_word, self.dim_hidden, n_layers,
40 | batch_first=True, dropout=rnn_dropout_p)
41 |
42 | self.out = nn.Linear(self.dim_hidden, vocab_size)
43 |
44 | self._init_weights()
45 |
46 |
47 | def forward(self, hidden, ans):
48 | """
49 | decode answer
50 | :param encoder_outputs:
51 | :param encoder_hidden:
52 | :param ans:
53 | :param t:
54 | :return:
55 | """
56 | ans_embed = self.embedding(ans)
57 | ans_embed = self.input_dropout(ans_embed)
58 | # ans_embed = self.transform(ans_embed)
59 | ans_embed = ans_embed.unsqueeze(1)
60 | decoder_input = ans_embed
61 | outputs, hidden = self.rnn(decoder_input, hidden)
62 | final_outputs = self.out(outputs.squeeze(1))
63 | return final_outputs, hidden
64 |
65 |
66 | def sample(self, hidden, start=None):
67 | """
68 |
69 | :param encoder_outputs:
70 | :param encoder_hidden:
71 | :return:
72 | """
73 | sample_ids = []
74 | start_embed = self.embedding(start)
75 | start_embed = self.input_dropout(start_embed)
76 | # start_embed = self.transform(start_embed)
77 | start_embed = start_embed.unsqueeze(1)
78 |
79 | inputs = start_embed
80 |
81 | for i in range(self.max_length):
82 | outputs, hidden = self.rnn(inputs, hidden)
83 | outputs = self.out(outputs.squeeze(1))
84 | _, predict = outputs.max(1)
85 | sample_ids.append(predict)
86 | ans_embed = self.embedding(predict)
87 | ans_embed = self.input_dropout(ans_embed)
88 | # ans_embed = self.transform(ans_embed)
89 |
90 | inputs = ans_embed.unsqueeze(1)
91 |
92 | sample_ids = torch.stack(sample_ids, 1)
93 | return sample_ids
94 |
95 |
96 | def _init_weights(self):
97 | """ init the weight of some layers
98 | """
99 | nn.init.xavier_normal_(self.out.weight)
100 | glove_embed = np.load(self.glove_embed)
101 | self.embedding.weight = nn.Parameter(torch.FloatTensor(glove_embed))
102 |
103 |
104 |
105 | class AnsHME(nn.Module):
106 | """
107 | """
108 |
109 | def __init__(self,
110 | vocab_size,
111 | max_len,
112 | dim_hidden,
113 | dim_word,
114 | glove_embed,
115 | n_layers=1,
116 | rnn_cell='gru',
117 | bidirectional=False,
118 | input_dropout_p=0.2,
119 | rnn_dropout_p=0):
120 | super(AnsHME, self).__init__()
121 |
122 | self.bidirectional_encoder = bidirectional
123 | self.vocab_size = vocab_size
124 |
125 |
126 | self.dim_hidden = dim_hidden
127 | self.dim_word = dim_word
128 | self.max_length = max_len
129 | self.glove_embed = glove_embed
130 |
131 | self.input_dropout = nn.Dropout(input_dropout_p)
132 | self.embedding = nn.Embedding(vocab_size, dim_word)
133 | word_mat = torch.FloatTensor(np.load(self.glove_embed))
134 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False)
135 |
136 | self.transform = nn.Sequential(nn.Linear(dim_hidden*2+dim_word, dim_hidden),
137 | nn.Dropout(input_dropout_p))
138 |
139 | if rnn_cell.lower() == 'lstm':
140 | self.rnn_cell = nn.LSTM
141 | elif rnn_cell.lower() == 'gru':
142 | self.rnn_cell = nn.GRU
143 | self.rnn = self.rnn_cell(self.dim_hidden, self.dim_hidden, n_layers,
144 | batch_first=True, dropout=rnn_dropout_p)
145 |
146 | self.out = nn.Linear(self.dim_hidden, vocab_size)
147 |
148 | self._init_weights()
149 |
150 |
151 | def forward(self, encoder_outputs, hidden, ans_idx):
152 | """
153 | decode answer
154 | :param encoder_outputs:
155 | :param encoder_hidden:
156 | :param ans:
157 | :param t:
158 | :return:
159 | """
160 | ans_embed = self.embedding(ans_idx)
161 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1)
162 | ans_embed = self.transform(ans_embed)
163 | ans_embed = ans_embed.unsqueeze(1)
164 | decoder_input = ans_embed
165 | outputs, hidden = self.rnn(decoder_input, hidden)
166 | final_outputs = self.out(outputs.squeeze(1))
167 | return final_outputs, hidden
168 |
169 |
170 | def sample(self, encoder_outputs, hidden, start=None):
171 | """
172 |
173 | :param encoder_outputs:
174 | :param encoder_hidden:
175 | :return:
176 | """
177 | sample_ids = []
178 | start_embed = self.embedding(start)
179 | start_embed = torch.cat((encoder_outputs, start_embed), dim=-1)
180 | start_embed = self.transform(start_embed)
181 | start_embed = start_embed.unsqueeze(1)
182 | inputs = start_embed
183 |
184 | for i in range(self.max_length):
185 | outputs, hidden = self.rnn(inputs, hidden)
186 | outputs = self.out(outputs.squeeze(1))
187 | _, predict = outputs.max(1)
188 | sample_ids.append(predict)
189 | ans_embed = self.embedding(predict)
190 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1)
191 | ans_embed = self.transform(ans_embed)
192 | inputs = ans_embed.unsqueeze(1)
193 |
194 | sample_ids = torch.stack(sample_ids, 1)
195 | return sample_ids
196 |
197 |
198 | def _init_weights(self):
199 | """ init the weight of some layers
200 | """
201 | nn.init.xavier_normal_(self.out.weight)
202 | nn.init.xavier_normal_(self.transform[0].weight)
203 |
204 |
205 | class AnsQnsAns(nn.Module):
206 | """
207 | """
208 |
209 | def __init__(self,
210 | vocab_size,
211 | max_len,
212 | dim_hidden,
213 | dim_word,
214 | glove_embed,
215 | n_layers=1,
216 | rnn_cell='gru',
217 | bidirectional=False,
218 | input_dropout_p=0.2,
219 | rnn_dropout_p=0):
220 | super(AnsQnsAns, self).__init__()
221 |
222 | self.bidirectional_encoder = bidirectional
223 | self.vocab_size = vocab_size
224 |
225 |
226 | self.dim_hidden = dim_hidden
227 | self.dim_word = dim_word
228 | self.max_length = max_len
229 | self.glove_embed = glove_embed
230 |
231 | self.input_dropout = nn.Dropout(input_dropout_p)
232 | self.embedding = nn.Embedding(vocab_size, dim_word)
233 | word_mat = torch.FloatTensor(np.load(self.glove_embed))
234 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False)
235 |
236 | self.transform = nn.Sequential(nn.Linear(dim_hidden+dim_word, dim_hidden),
237 | nn.Dropout(input_dropout_p))
238 |
239 | if rnn_cell.lower() == 'lstm':
240 | self.rnn_cell = nn.LSTM
241 | elif rnn_cell.lower() == 'gru':
242 | self.rnn_cell = nn.GRU
243 | self.rnn = self.rnn_cell(self.dim_hidden, self.dim_hidden, n_layers,
244 | batch_first=True, dropout=rnn_dropout_p)
245 |
246 | self.out = nn.Linear(self.dim_hidden, vocab_size)
247 |
248 | self._init_weights()
249 |
250 |
251 | def forward(self, encoder_outputs, hidden, ans_idx):
252 | """
253 | :param encoder_outputs:
254 | :param encoder_hidden:
255 | :param ans:
256 | :param t:
257 | :return:
258 | """
259 | ans_embed = self.embedding(ans_idx)
260 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1)
261 | ans_embed = self.transform(ans_embed)
262 | ans_embed = ans_embed.unsqueeze(1)
263 | decoder_input = ans_embed
264 | outputs, hidden = self.rnn(decoder_input, hidden)
265 | final_outputs = self.out(outputs.squeeze(1))
266 | return final_outputs, hidden
267 |
268 |
269 | def sample(self, encoder_outputs, hidden, start=None):
270 | """
271 |
272 | :param encoder_outputs:
273 | :param encoder_hidden:
274 | :return:
275 | """
276 | sample_ids = []
277 | start_embed = self.embedding(start)
278 | start_embed = torch.cat((encoder_outputs, start_embed), dim=-1)
279 | start_embed = self.transform(start_embed)
280 | start_embed = start_embed.unsqueeze(1)
281 | inputs = start_embed
282 |
283 | for i in range(self.max_length):
284 | outputs, hidden = self.rnn(inputs, hidden)
285 | outputs = self.out(outputs.squeeze(1))
286 | _, predict = outputs.max(1)
287 | sample_ids.append(predict)
288 | ans_embed = self.embedding(predict)
289 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1)
290 | ans_embed = self.transform(ans_embed)
291 | inputs = ans_embed.unsqueeze(1)
292 |
293 | sample_ids = torch.stack(sample_ids, 1)
294 | return sample_ids
295 |
296 |
297 | def _init_weights(self):
298 | """ init the weight of some layers
299 | """
300 | nn.init.xavier_normal_(self.out.weight)
301 | nn.init.xavier_normal_(self.transform[0].weight)
302 |
303 |
304 | class AnsAttSeq(nn.Module):
305 | """
306 |
307 | """
308 |
309 | def __init__(self,
310 | vocab_size,
311 | max_len,
312 | dim_hidden,
313 | dim_word,
314 | glove_embed,
315 | n_layers=1,
316 | rnn_cell='gru',
317 | bidirectional=False,
318 | input_dropout_p=0.2,
319 | rnn_dropout_p=0):
320 | super(AnsAttSeq, self).__init__()
321 |
322 | self.bidirectional_encoder = bidirectional
323 | self.vocab_size = vocab_size
324 |
325 |
326 | self.dim_hidden = dim_hidden
327 | self.dim_word = dim_word
328 | self.max_length = max_len
329 | self.glove_embed = glove_embed
330 |
331 | self.input_dropout = nn.Dropout(input_dropout_p)
332 | self.embedding = nn.Embedding(vocab_size, dim_word)
333 | word_mat = torch.FloatTensor(np.load(self.glove_embed))
334 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False)
335 |
336 | self.temp = TempAttention(dim_word, dim_hidden, dim_hidden//2)
337 |
338 | self.transform = nn.Sequential(nn.Linear(dim_word+dim_hidden, dim_hidden),
339 | nn.Dropout(input_dropout_p))
340 |
341 | if rnn_cell.lower() == 'lstm':
342 | self.rnn_cell = nn.LSTM
343 | elif rnn_cell.lower() == 'gru':
344 | self.rnn_cell = nn.GRU
345 | self.rnn = self.rnn_cell(self.dim_hidden, self.dim_hidden, n_layers,
346 | batch_first=True, dropout=rnn_dropout_p)
347 |
348 | self.out = nn.Linear(self.dim_hidden, vocab_size)
349 |
350 | self._init_weights()
351 |
352 |
353 | def forward(self, seq_outs, hidden, ans_idx):
354 | """
355 | decode answer
356 | :param encoder_outputs:
357 | :param encoder_hidden:
358 | :param ans:
359 | :param t:
360 | :return:
361 | """
362 | ans_embed = self.embedding(ans_idx)
363 | ans_embed_att, _ = self.temp(ans_embed, seq_outs)
364 | ans_embed = torch.cat((ans_embed, ans_embed_att), dim=-1)
365 |
366 | ans_embed = self.transform(ans_embed)
367 | ans_embed = ans_embed.unsqueeze(1)
368 | decoder_input = ans_embed
369 | outputs, hidden = self.rnn(decoder_input, hidden)
370 | final_outputs = self.out(outputs.squeeze(1))
371 | return final_outputs, hidden
372 |
373 |
374 | def sample(self, seq_outs, hidden, start=None):
375 | """
376 |
377 | :param encoder_outputs:
378 | :param encoder_hidden:
379 | :return:
380 | """
381 | sample_ids = []
382 | start_embed = self.embedding(start)
383 | start_embed_att, _ = self.temp(start_embed, seq_outs)
384 | start_embed = torch.cat((start_embed, start_embed_att), dim=-1)
385 | start_embed = self.transform(start_embed)
386 | start_embed = start_embed.unsqueeze(1)
387 | inputs = start_embed
388 |
389 | for i in range(self.max_length):
390 | outputs, hidden = self.rnn(inputs, hidden)
391 | outputs = self.out(outputs.squeeze(1))
392 | _, predict = outputs.max(1)
393 | sample_ids.append(predict)
394 | ans_embed = self.embedding(predict)
395 | ans_embed_att, _ = self.temp(ans_embed, seq_outs)
396 | ans_embed = torch.cat((ans_embed, ans_embed_att), dim=-1)
397 | ans_embed = self.transform(ans_embed)
398 | inputs = ans_embed.unsqueeze(1)
399 |
400 | sample_ids = torch.stack(sample_ids, 1)
401 | return sample_ids
402 |
403 |
404 | def _init_weights(self):
405 | """ init the weight of some layers
406 | """
407 | nn.init.xavier_normal_(self.out.weight)
408 | nn.init.xavier_normal_(self.transform[0].weight)
409 |
410 |
411 | class AnsNavieTrans(nn.Module):
412 | """
413 | """
414 | def __init__(self,
415 | vocab_size,
416 | max_len,
417 | dim_hidden,
418 | dim_word,
419 | glove_embed,
420 | n_layers=1,
421 | rnn_cell='gru',
422 | bidirectional=False,
423 | input_dropout_p=0.2,
424 | rnn_dropout_p=0):
425 | super(AnsNavieTrans, self).__init__()
426 |
427 | self.bidirectional_encoder = bidirectional
428 | self.vocab_size = vocab_size
429 |
430 |
431 | self.dim_hidden = dim_hidden
432 | self.dim_word = dim_word
433 | self.max_length = max_len
434 | self.glove_embed = glove_embed
435 |
436 | self.input_dropout = nn.Dropout(input_dropout_p)
437 | self.embedding = nn.Embedding(vocab_size, dim_word)
438 | word_mat = torch.FloatTensor(np.load(self.glove_embed))
439 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False)
440 |
441 | self.transform = nn.Sequential(nn.Linear(dim_word, dim_hidden),
442 | nn.Dropout(input_dropout_p))
443 |
444 | if rnn_cell.lower() == 'lstm':
445 | self.rnn_cell = nn.LSTM
446 | elif rnn_cell.lower() == 'gru':
447 | self.rnn_cell = nn.GRU
448 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers,
449 | batch_first=True, dropout=rnn_dropout_p)
450 |
451 | self.out = nn.Linear(self.dim_hidden, vocab_size)
452 |
453 | self._init_weights()
454 |
455 |
456 | def forward(self, hidden, ans_idx):
457 | """
458 | decode answer
459 | :param encoder_outputs:
460 | :param encoder_hidden:
461 | :param ans:
462 | :param t:
463 | :return:
464 | """
465 | ans_embed = self.embedding(ans_idx)
466 | ans_embed = self.transform(ans_embed)
467 | ans_embed = ans_embed.unsqueeze(1)
468 | decoder_input = ans_embed
469 | outputs, hidden = self.rnn(decoder_input, hidden)
470 | final_outputs = self.out(outputs.squeeze(1))
471 | return final_outputs, hidden
472 |
473 |
474 | def sample(self, hidden, start=None):
475 | """
476 | :param encoder_outputs:
477 | :param encoder_hidden:
478 | :return:
479 | """
480 | sample_ids = []
481 | start_embed = self.embedding(start)
482 | start_embed = self.transform(start_embed)
483 | start_embed = start_embed.unsqueeze(1)
484 | inputs = start_embed
485 |
486 | for i in range(self.max_length):
487 | outputs, hidden = self.rnn(inputs, hidden)
488 | outputs = self.out(outputs.squeeze(1))
489 | _, predict = outputs.max(1)
490 | sample_ids.append(predict)
491 | ans_embed = self.embedding(predict)
492 | ans_embed = self.transform(ans_embed)
493 | inputs = ans_embed.unsqueeze(1)
494 |
495 | sample_ids = torch.stack(sample_ids, 1)
496 | return sample_ids
497 |
498 |
499 | def _init_weights(self):
500 | """ init the weight of some layers
501 | """
502 | nn.init.xavier_normal_(self.out.weight)
503 | nn.init.xavier_normal_(self.transform[0].weight)
504 |
505 |
506 |
--------------------------------------------------------------------------------
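
A hedged sketch of greedy decoding with `AnsUATT`, using a randomly generated embedding matrix in place of the real GloVe file; the sizes, the temp path, and the start-token index are illustrative assumptions.

```python
# Sketch: greedy answer decoding with AnsUATT (random embedding matrix stands in for GloVe).
import os
import tempfile
import numpy as np
import torch
from networks.DecoderRNN import AnsUATT

vocab_size, dim_word, dim_hidden, batch = 100, 300, 256, 4
glove_path = os.path.join(tempfile.gettempdir(), 'rand_glove.npy')
np.save(glove_path, np.random.randn(vocab_size, dim_word).astype(np.float32))

decoder = AnsUATT(vocab_size, max_len=7, dim_hidden=dim_hidden,
                  dim_word=dim_word, glove_embed=glove_path)
hidden = torch.zeros(1, batch, dim_hidden)           # (n_layers, batch, dim_hidden) for the GRU
start = torch.full((batch,), 1, dtype=torch.long)    # assumed index of the start token
sample_ids = decoder.sample(hidden, start)
print(sample_ids.shape)                              # torch.Size([4, 7])
```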
/networks/EncoderRNN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
4 | import numpy as np
5 |
6 | class EncoderQns(nn.Module):
7 | def __init__(self, dim_embed, dim_hidden, vocab_size, glove_embed, input_dropout_p=0.2, rnn_dropout_p=0,
8 | n_layers=1, bidirectional=False, rnn_cell='gru'):
9 | """
10 | """
11 | super(EncoderQns, self).__init__()
12 | self.dim_hidden = dim_hidden
13 | self.vocab_size = vocab_size
14 | self.glove_embed = glove_embed
15 | self.input_dropout_p = input_dropout_p
16 | self.rnn_dropout_p = rnn_dropout_p
17 | self.n_layers = n_layers
18 | self.bidirectional = bidirectional
19 | self.rnn_cell = rnn_cell
20 |
21 | self.input_dropout = nn.Dropout(input_dropout_p)
22 |
23 | if rnn_cell.lower() == 'lstm':
24 | self.rnn_cell = nn.LSTM
25 | elif rnn_cell.lower() == 'gru':
26 | self.rnn_cell = nn.GRU
27 |
28 | self.embedding = nn.Embedding(vocab_size, dim_embed)
29 | word_mat = torch.FloatTensor(np.load(self.glove_embed))
30 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False)
31 | self.rnn = self.rnn_cell(dim_embed, dim_hidden, n_layers, batch_first=True,
32 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
33 |
34 |
35 | def forward(self, qns, qns_lengths, hidden=None):
36 | """
37 | """
38 | qns_embed = self.embedding(qns)
39 | qns_embed = self.input_dropout(qns_embed)
40 | packed = pack_padded_sequence(qns_embed, qns_lengths, batch_first=True)
41 | packed_output, hidden = self.rnn(packed, hidden)
42 | output, _ = pad_packed_sequence(packed_output, batch_first=True)
43 | return output, hidden
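# Usage sketch for the encoder above (tensor names are illustrative, not taken
# from this repo): pack_padded_sequence with its default arguments expects the
# batch sorted by length in descending order, so a caller would typically do
#   lengths, order = qns_lengths.sort(descending=True)
#   output, hidden = encoder(qns[order], lengths)
# and undo the permutation afterwards (newer PyTorch alternatively accepts
# enforce_sorted=False).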
44 |
45 |
46 | class EncoderVid(nn.Module):
47 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
48 | n_layers=1, bidirectional=False, rnn_cell='gru'):
49 | """
50 | """
51 | super(EncoderVid, self).__init__()
52 | self.dim_vid = dim_vid
53 | self.dim_app = 2048
54 | self.dim_motion = 4096
55 | self.dim_hidden = dim_hidden
56 | self.input_dropout_p = input_dropout_p
57 | self.rnn_dropout_p = rnn_dropout_p
58 | self.n_layers = n_layers
59 | self.bidirectional = bidirectional
60 | self.rnn_cell = rnn_cell
61 |
62 | if rnn_cell.lower() == 'lstm':
63 | self.rnn_cell = nn.LSTM
64 | elif rnn_cell.lower() == 'gru':
65 | self.rnn_cell = nn.GRU
66 |
67 | self.rnn = self.rnn_cell(dim_vid, dim_hidden, n_layers, batch_first=True,
68 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
69 |
70 |
71 | def forward(self, vid_feats):
72 | """
73 | """
74 | self.rnn.flatten_parameters()
75 | foutput, fhidden = self.rnn(vid_feats)
76 |
77 | return foutput, fhidden
78 |
79 |
80 | class EncoderVidSTVQA(nn.Module):
81 | def __init__(self, input_dim, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
82 | n_layers=1, bidirectional=False, rnn_cell='gru'):
83 | """
84 | """
85 | super(EncoderVidSTVQA, self).__init__()
86 | self.input_dim = input_dim
87 | self.dim_hidden = dim_hidden
88 | self.input_dropout_p = input_dropout_p
89 | self.rnn_dropout_p = rnn_dropout_p
90 | self.n_layers = n_layers
91 | self.bidirectional = bidirectional
92 | self.rnn_cell = rnn_cell
93 |
94 |
95 | if rnn_cell.lower() == 'lstm':
96 | self.rnn_cell = nn.LSTM
97 | elif rnn_cell.lower() == 'gru':
98 | self.rnn_cell = nn.GRU
99 |
100 | self.rnn1 = self.rnn_cell(input_dim, dim_hidden, n_layers, batch_first=True,
101 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
102 |
103 | self.rnn2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
104 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
105 |
106 |
107 | def forward(self, vid_feats):
108 | """
109 | Dual-layer LSTM
110 | """
111 |
112 | self.rnn1.flatten_parameters()
113 |
114 | foutput_1, fhidden_1 = self.rnn1(vid_feats)
115 | self.rnn2.flatten_parameters()
116 | foutput_2, fhidden_2 = self.rnn2(foutput_1)
117 |
118 | foutput = torch.cat((foutput_1, foutput_2), dim=2)
119 | fhidden = (torch.cat((fhidden_1[0], fhidden_2[0]), dim=0),
120 | torch.cat((fhidden_1[1], fhidden_2[1]), dim=0))
121 |
122 | return foutput, fhidden
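# Shape notes for the forward pass above: foutput concatenates the outputs of the
# two stacked RNNs, so its last dimension is 2 * dim_hidden in the default
# unidirectional setting; stacking fhidden_1[0] / fhidden_1[1] assumes LSTM-style
# (h, c) hidden tuples, i.e. this encoder appears intended for rnn_cell='lstm'.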
123 |
124 |
125 | class EncoderVidCoMem(nn.Module):
126 | def __init__(self, dim_app, dim_motion, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
127 | n_layers=1, bidirectional=False, rnn_cell='gru'):
128 | """
129 | """
130 | super(EncoderVidCoMem, self).__init__()
131 | self.dim_app = dim_app
132 | self.dim_motion = dim_motion
133 | self.dim_hidden = dim_hidden
134 | self.input_dropout_p = input_dropout_p
135 | self.rnn_dropout_p = rnn_dropout_p
136 | self.n_layers = n_layers
137 | self.bidirectional = bidirectional
138 | self.rnn_cell = rnn_cell
139 |
140 | if rnn_cell.lower() == 'lstm':
141 | self.rnn_cell = nn.LSTM
142 | elif rnn_cell.lower() == 'gru':
143 | self.rnn_cell = nn.GRU
144 |
145 | self.rnn_app_l1 = self.rnn_cell(self.dim_app, dim_hidden, n_layers, batch_first=True,
146 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
147 | self.rnn_app_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
148 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
149 |
150 | self.rnn_motion_l1 = self.rnn_cell(self.dim_motion, dim_hidden, n_layers, batch_first=True,
151 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
152 | self.rnn_motion_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
153 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
154 |
155 |
156 | def forward(self, vid_feats):
157 | """
158 | Two separate two-layer RNNs encode the appearance and motion features.
159 | :param vid_feats:
160 | :return:
161 | """
162 | vid_app = vid_feats[:, :, 0:self.dim_app]
163 | vid_motion = vid_feats[:, :, self.dim_app:]
164 |
165 | app_output_l1, app_hidden_l1 = self.rnn_app_l1(vid_app)
166 | app_output_l2, app_hidden_l2 = self.rnn_app_l2(app_output_l1)
167 |
168 |
169 | motion_output_l1, motion_hidden_l1 = self.rnn_motion_l1(vid_motion)
170 | motion_output_l2, motion_hidden_l2 = self.rnn_motion_l2(motion_output_l1)
171 |
172 |
173 | return app_output_l1, app_output_l2, motion_output_l1, motion_output_l2
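# Input layout assumed by the slicing above: vid_feats is the frame-level
# concatenation of appearance and motion features, shape
# (batch_size, fnum, dim_app + dim_motion) -- e.g. 2048-d appearance followed by
# 4096-d motion features, matching the constants declared in EncoderVid.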
174 |
175 |
176 |
177 | class EncoderQnsHGA(nn.Module):
178 | def __init__(self, dim_embed, dim_hidden, vocab_size, glove_embed, input_dropout_p=0.2, rnn_dropout_p=0,
179 | n_layers=1, bidirectional=False, rnn_cell='gru'):
180 | """
181 |
182 | """
183 | super(EncoderQnsHGA, self).__init__()
184 | self.dim_hidden = dim_hidden
185 | self.vocab_size = vocab_size
186 | self.glove_embed = glove_embed
187 | self.input_dropout_p = input_dropout_p
188 | self.rnn_dropout_p = rnn_dropout_p
189 | self.n_layers = n_layers
190 | self.bidirectional = bidirectional
191 | self.rnn_cell = rnn_cell
192 |
193 | self.input_dropout = nn.Dropout(input_dropout_p)
194 |
195 | if rnn_cell.lower() == 'lstm':
196 | self.rnn_cell = nn.LSTM
197 | elif rnn_cell.lower() == 'gru':
198 | self.rnn_cell = nn.GRU
199 |
200 | self.embedding = nn.Embedding(vocab_size, dim_embed)  # placeholder; replaced below by the GloVe-initialized embedding
201 | word_mat = torch.FloatTensor(np.load(self.glove_embed))
202 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False)
203 | self.FC = nn.Sequential(nn.Linear(dim_embed, dim_hidden, bias=False),
204 | nn.ReLU(),
205 | )
206 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True,
207 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
208 |
209 |
210 | def forward(self, qns, qns_lengths, hidden=None):
211 | """
212 | """
213 | qns_embed = self.embedding(qns)
214 | qns_embed = self.input_dropout(qns_embed)
215 | qns_embed = self.FC(qns_embed)
216 | packed = pack_padded_sequence(qns_embed, qns_lengths, batch_first=True)
217 | packed_output, hidden = self.rnn(packed, hidden)
218 | output, _ = pad_packed_sequence(packed_output, batch_first=True)
219 | return output, hidden
220 |
221 |
222 | class EncoderVidHGA(nn.Module):
223 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0,
224 | n_layers=1, bidirectional=False, rnn_cell='gru'):
225 | """
226 | """
227 | super(EncoderVidHGA, self).__init__()
228 | self.dim_vid = dim_vid
229 | self.dim_app = 2048
230 | self.dim_mot = 2048
231 | self.dim_hidden = dim_hidden
232 | self.input_dropout_p = input_dropout_p
233 | self.rnn_dropout_p = rnn_dropout_p
234 | self.n_layers = n_layers
235 | self.bidirectional = bidirectional
236 | self.rnn_cell = rnn_cell
237 |
238 |
239 | self.mot2hid = nn.Sequential(nn.Linear(self.dim_mot, self.dim_app),
240 | nn.Dropout(input_dropout_p),
241 | nn.ReLU())
242 |
243 | self.appmot2hid = nn.Sequential(nn.Linear(self.dim_app*2, self.dim_app),
244 | nn.Dropout(input_dropout_p),
245 | nn.ReLU())
246 |
247 | if rnn_cell.lower() == 'lstm':
248 | self.rnn_cell = nn.LSTM
249 | elif rnn_cell.lower() == 'gru':
250 | self.rnn_cell = nn.GRU
251 |
252 | self.rnn = self.rnn_cell(self.dim_app, dim_hidden, n_layers, batch_first=True,
253 | bidirectional=bidirectional, dropout=self.rnn_dropout_p)
254 |
255 |
256 | def forward(self, vid_feats):
257 | """
258 | """
259 | vid_app = vid_feats[:, :, 0:self.dim_app]
260 | vid_motion = vid_feats[:, :, self.dim_app:]
261 |
262 | vid_motion_redu = self.mot2hid(vid_motion)
263 | vid_feats = self.appmot2hid(torch.cat((vid_app, vid_motion_redu), dim=2))
264 |
265 | self.rnn.flatten_parameters()
266 | foutput, fhidden = self.rnn(vid_feats)
267 |
268 | return foutput, fhidden
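# Note on the HGA video encoder above: appearance and motion features are both
# taken to be dim_app/dim_mot = 2048-d; the motion stream is linearly projected,
# concatenated with the appearance stream, fused back to a single dim_app-sized
# vector per frame, and only then passed through the RNN.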
--------------------------------------------------------------------------------
/networks/VQAModel/CoMem.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random as rd
4 | import sys
5 | sys.path.insert(0, 'networks')
6 | from memory_module import EpisodicMemory
7 |
8 | class CoMem(nn.Module):
9 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2):
10 | """
11 | motion-appearance co-memory networks for video question answering (CVPR18)
12 | """
13 | super(CoMem, self).__init__()
14 | self.vid_encoder = vid_encoder
15 | self.qns_encoder = qns_encoder
16 | self.ans_decoder = ans_decoder
17 |
18 | dim = qns_encoder.dim_hidden
19 |
20 | self.epm_app = EpisodicMemory(dim*2)
21 | self.epm_mot = EpisodicMemory(dim*2)
22 |
23 | self.linear_ma = nn.Linear(dim*2*3, dim*2)
24 | self.linear_mb = nn.Linear(dim*2*3, dim*2)
25 |
26 | self.vq2word = nn.Linear(dim*2*2, dim)
27 | self._init_weights()
28 | self.device = device
29 |
30 | def _init_weights(self):
31 | """
32 | initialize the linear weights
33 | :return:
34 | """
35 | nn.init.xavier_normal_(self.linear_ma.weight)
36 | nn.init.xavier_normal_(self.linear_mb.weight)
37 | nn.init.xavier_normal_(self.vq2word.weight)
38 |
39 |
40 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, iter_num=3, mode='train'):
41 | """
42 | Co-memory network
43 | """
44 |
45 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim)
46 |
47 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1)
48 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1)
49 |
50 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths)
51 |
52 | # qns_output = qns_output.permute(1, 0, 2)
53 | batch_size, seq_len, qns_feat_dim = qns_output.size()
54 |
55 |
56 | qns_embed = qns_hidden.permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
57 |
58 | m_app = outputs_app[:, -1, :]
59 | m_mot = outputs_motion[:, -1, :]
60 | ma, mb = m_app.detach(), m_mot.detach()
61 | m_app = m_app.unsqueeze(1)
62 | m_mot = m_mot.unsqueeze(1)
63 | for _ in range(iter_num):
64 | mm = ma + mb
65 | m_app = self.epm_app(outputs_app, mm, m_app)
66 | m_mot = self.epm_mot(outputs_motion, mm, m_mot)
67 | ma_q = torch.cat((ma, m_app.squeeze(1), qns_embed), dim=1)
68 | mb_q = torch.cat((mb, m_mot.squeeze(1), qns_embed), dim=1)
69 | # print(ma_q.shape)
70 | ma = torch.tanh(self.linear_ma(ma_q))
71 | mb = torch.tanh(self.linear_mb(mb_q))
72 |
73 | mem = torch.cat((ma, mb), dim=1)
74 | encoder_outputs = self.vq2word(mem)
75 | # hidden = qns_hidden
76 | hidden = encoder_outputs.unsqueeze(0)
77 |
78 | # decoder_inputs = encoder_outputs
79 |
80 | if mode == 'train':
81 | vocab_size = self.ans_decoder.vocab_size
82 | ans_len = ans.shape[1]
83 | input = ans[:, 0]
84 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
85 |
86 | for t in range(0, ans_len):
87 | output, hidden = self.ans_decoder(qns_output, hidden, input)
88 | outputs[:, t] = output
89 | teacher_force = rd.random() < teacher_force_ratio
90 | top1 = output.argmax(1)
91 | input = ans[:, t] if teacher_force else top1
92 | else:
93 | start = torch.LongTensor([1] * batch_size).to(self.device)
94 | outputs = self.ans_decoder.sample(qns_output, hidden, start)
95 |
96 | return outputs
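# Dimension notes for the forward pass above (derived from the layer sizes, not
# from any config in this file): with dim = qns_encoder.dim_hidden, the episodic
# memories and linear_ma / linear_mb operate on 2*dim features, so the video
# encoder is expected to use the same dim_hidden, and the flattened question
# hidden state qns_embed must also be 2*dim wide (e.g. a bidirectional or
# two-layer question RNN) for the dim*2*3 -> dim*2 projections to line up.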
--------------------------------------------------------------------------------
/networks/VQAModel/EVQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random as rd
4 |
5 | class EVQA(nn.Module):
6 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device):
7 | """
8 | :param vid_encoder: video feature encoder
9 | :param qns_encoder: question encoder
10 | :param ans_decoder: answer decoder
11 | :param device: torch.device on which tensors are created
12 | """
13 | super(EVQA, self).__init__()
14 | self.vid_encoder = vid_encoder
15 | self.qns_encoder = qns_encoder
16 | self.ans_decoder = ans_decoder
17 | self.device = device
18 |
19 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
20 |
21 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats)
22 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths)
23 |
24 |
25 | hidden = qns_hidden[0] + vid_hidden[0]
26 | batch_size = qns.shape[0]
27 |
28 | if mode == 'train':
29 | vocab_size = self.ans_decoder.vocab_size
30 | ans_len = ans.shape[1]
31 | input = ans[:, 0]
32 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
33 | for t in range(0, ans_len):
34 | output, hidden = self.ans_decoder(qns_outputs, hidden, input)
35 | outputs[:,t] = output
36 | teacher_force = rd.random() < teacher_force_ratio
37 | top1 = output.argmax(1)
38 | input = ans[:, t] if teacher_force else top1
39 | else:
40 | start = torch.LongTensor([1] * batch_size).to(self.device)
41 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start)
42 |
43 | return outputs
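# Teacher forcing in the training branch above: at each step t the decoder is fed
# the ground-truth token ans[:, t] with probability teacher_force_ratio and its
# own argmax prediction otherwise; at inference time decoding starts from token
# id 1, which is assumed to be the start-of-sentence symbol in the vocabulary.
# The same train/inference pattern is shared by the other models in this folder.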
44 |
--------------------------------------------------------------------------------
/networks/VQAModel/HGA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random as rd
4 | import sys
5 | sys.path.insert(0, 'networks')
6 | from q_v_transformer import CoAttention
7 | from gcn import AdjLearner, GCN
8 | from block import fusions #pytorch >= 1.1.0
9 |
10 | class HGA(nn.Module):
11 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device):
12 | """
13 | Reasoning with Heterogeneous Graph Alignment for Video Question Answering (AAAI20)
14 | """
15 | super(HGA, self).__init__()
16 | self.vid_encoder = vid_encoder
17 | self.qns_encoder = qns_encoder
18 | self.ans_decoder = ans_decoder
19 | self.max_len_v = max_len_v
20 | self.max_len_q = max_len_q
21 | self.device = device
22 | hidden_size = vid_encoder.dim_hidden
23 | input_dropout_p = vid_encoder.input_dropout_p
24 |
25 | self.q_input_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
26 | self.v_input_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
27 |
28 | self.co_attn = CoAttention(
29 | hidden_size, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
30 |
31 | self.adj_learner = AdjLearner(
32 | hidden_size, hidden_size, dropout=input_dropout_p)
33 |
34 | self.gcn = GCN(
35 | hidden_size,
36 | hidden_size,
37 | hidden_size,
38 | num_layers=2,
39 | dropout=input_dropout_p)
40 |
41 | self.gcn_atten_pool = nn.Sequential(
42 | nn.Linear(hidden_size, hidden_size // 2),
43 | nn.Tanh(),
44 | nn.Linear(hidden_size // 2, 1),
45 | nn.Softmax(dim=-1)) #dim=-2 for attention-pooling otherwise sum-pooling
46 |
47 | self.global_fusion = fusions.Block(
48 | [hidden_size, hidden_size], hidden_size, dropout_input=input_dropout_p)
49 |
50 | self.fusion = fusions.Block([hidden_size, hidden_size], hidden_size)
51 |
52 |
53 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
54 | """
55 |
56 | """
57 | encoder_out, qns_hidden, qns_out, vid_out = self.vq_encoder(vid_feats, qns, qns_lengths)
58 |
59 | batch_size = encoder_out.shape[0]
60 |
61 | hidden = encoder_out.unsqueeze(0)
62 | if mode == 'train':
63 | vocab_size = self.ans_decoder.vocab_size
64 | ans_len = ans.shape[1]
65 | input = ans[:, 0]
66 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
67 | for t in range(0, ans_len):
68 |
69 | output, hidden = self.ans_decoder(qns_out, hidden, input) #attqns, attvid
70 | outputs[:, t] = output
71 | teacher_force = rd.random() < teacher_force_ratio
72 | top1 = output.argmax(1)
73 | input = ans[:, t] if teacher_force else top1
74 | else:
75 | start = torch.LongTensor([1] * batch_size).to(self.device)
76 |
77 | outputs = self.ans_decoder.sample(qns_out, hidden, start) #vidatt, qns_att
78 |
79 | return outputs
80 |
81 |
82 | def vq_encoder(self, vid_feats, qns, qns_lengths):
83 | """
84 | Align the question and video streams via co-attention and graph reasoning.
85 | :param vid_feats: frame features of shape (batch_size, fnum, feat_dim)
86 | :param qns: padded question token ids of shape (batch_size, q_len)
87 | :param qns_lengths: true (unpadded) question lengths
88 | :return: fused global-local encoding, question hidden state, attended question and video streams
89 | """
90 | q_output, s_hidden = self.qns_encoder(qns, qns_lengths)
91 | qns_last_hidden = torch.squeeze(s_hidden)
92 |
93 |
94 | v_output, v_hidden = self.vid_encoder(vid_feats)
95 | vid_last_hidden = torch.squeeze(v_hidden)
96 |
97 | q_output = self.q_input_ln(q_output)
98 | v_output = self.v_input_ln(v_output)
99 |
100 | q_output, v_output = self.co_attn(q_output, v_output)
101 |
102 | ### GCN
103 | adj = self.adj_learner(q_output, v_output)
104 | # q_v_inputs of shape (batch_size, q_v_len, hidden_size)
105 | q_v_inputs = torch.cat((q_output, v_output), dim=1)
106 | # q_v_output of shape (batch_size, q_v_len, hidden_size)
107 | q_v_output = self.gcn(q_v_inputs, adj)
108 |
109 | ## attention pool
110 | local_attn = self.gcn_atten_pool(q_v_output)
111 | local_out = torch.sum(q_v_output * local_attn, dim=1)
112 |
113 | # print(qns_last_hidden.shape, vid_last_hidden.shape)
114 | global_out = self.global_fusion((qns_last_hidden, vid_last_hidden))
115 |
116 |
117 | out = self.fusion((global_out, local_out)).squeeze() #4 x 512
118 |
119 | return out, s_hidden, q_output, v_output,
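# Shape summary for vq_encoder (following the in-line comments above): q_output
# and v_output are (batch, len, hidden) streams after co-attention; their
# concatenation of length q_len + v_len is passed through the GCN and
# attention-pooled into local_out, which is fused with the Block fusion of the
# squeezed encoder hidden states and squeezed to a (batch, hidden) vector that
# forward() uses as the decoder's initial hidden state.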
120 |
121 |
--------------------------------------------------------------------------------
/networks/VQAModel/HME.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random as rd
4 | import sys
5 | sys.path.insert(0, 'networks')
6 | from Attention import TempAttention
7 | from memory_rand import MemoryRamTwoStreamModule, MemoryRamModule, MMModule
8 |
9 |
10 | class HME(nn.Module):
11 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2):
12 | """
13 | Heterogeneous memory enhanced multimodal attention model for video question answering (CVPR19)
14 |
15 | """
16 | super(HME, self).__init__()
17 | self.vid_encoder = vid_encoder
18 | self.qns_encoder = qns_encoder
19 | self.ans_decoder = ans_decoder
20 |
21 | dim = qns_encoder.dim_hidden
22 |
23 | self.temp_att_a = TempAttention(dim * 2, dim * 2, hidden_dim=256)
24 | self.temp_att_m = TempAttention(dim * 2, dim * 2, hidden_dim=256)
25 | self.mrm_vid = MemoryRamTwoStreamModule(dim, dim, max_len_v, device)
26 | self.mrm_txt = MemoryRamModule(dim, dim, max_len_q, device)
27 |
28 | self.mm_module_v1 = MMModule(dim, input_drop_p, device)
29 |
30 | self.linear_vid = nn.Linear(dim*2, dim)
31 | self.linear_qns = nn.Linear(dim*2, dim)
32 | self.linear_mem = nn.Linear(dim*2, dim)
33 | self.vq2word_hme = nn.Linear(dim*3, dim*2)
34 | self._init_weights()
35 | self.device = device
36 |
37 | def _init_weights(self):
38 | """
39 | initialize the linear weights
40 | :return:
41 | """
42 | nn.init.xavier_normal_(self.linear_vid.weight)
43 | nn.init.xavier_normal_(self.linear_qns.weight)
44 | nn.init.xavier_normal_(self.linear_mem.weight)
45 | nn.init.xavier_normal_(self.vq2word_hme.weight)
46 |
47 |
48 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, iter_num=3, mode='train'):
49 | """
50 | """
51 |
52 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim)
53 |
54 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1)
55 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1)
56 |
57 | batch_size, fnum, vid_feat_dim = outputs_app.size()
58 |
59 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths)
60 | # print(qns_output.shape, qns_hidden[0].shape) #torch.Size([10, 23, 256]) torch.Size([2, 10, 256])
61 |
62 | # qns_output = qns_output.permute(1, 0, 2)
63 | batch_size, seq_len, qns_feat_dim = qns_output.size()
64 |
65 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
66 |
67 | # Apply temporal attention
68 | att_app, beta_app = self.temp_att_a(qns_embed, outputs_app)
69 | att_motion, beta_motion = self.temp_att_m(qns_embed, outputs_motion)
70 | tmp_app_motion = torch.cat((outputs_app_l2[:, -1, :], outputs_motion_l2[:, -1, :]), dim=-1)
71 |
72 | mem_output = torch.zeros(batch_size, vid_feat_dim).to(self.device)
73 |
74 | for bs in range(batch_size):
75 | mem_ram_vid = self.mrm_vid(outputs_app_l2[bs], outputs_motion_l2[bs], fnum)
76 | cur_qns = qns_output[bs][:qns_lengths[bs]]
77 | mem_ram_txt = self.mrm_txt(cur_qns, qns_lengths[bs]) #should remove padded zeros
78 | mem_output[bs] = self.mm_module_v1(tmp_app_motion[bs].unsqueeze(0), mem_ram_vid, mem_ram_txt, iter_num)
79 | """
80 | (64, 256) (22, 256) (1, 512)
81 | """
82 | app_trans = torch.tanh(self.linear_vid(att_app))
83 | motion_trans = torch.tanh(self.linear_vid(att_motion))
84 | mem_trans = torch.tanh(self.linear_mem(mem_output))
85 |
86 | encoder_outputs = torch.cat((app_trans, motion_trans, mem_trans), dim=1)
87 | decoder_inputs = self.vq2word_hme(encoder_outputs)
88 | hidden = qns_hidden
89 | if mode == 'train':
90 | vocab_size = self.ans_decoder.vocab_size
91 | ans_len = ans.shape[1]
92 | input = ans[:, 0]
93 |
94 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
95 |
96 | for t in range(0, ans_len):
97 | output, hidden = self.ans_decoder(decoder_inputs, hidden, input)
98 | outputs[:, t] = output
99 | teacher_force = rd.random() < teacher_force_ratio
100 | top1 = output.argmax(1)
101 | input = ans[:, t] if teacher_force else top1
102 | else:
103 | start = torch.LongTensor([1] * batch_size).to(self.device)
104 | outputs = self.ans_decoder.sample(decoder_inputs, hidden, start)
105 |
106 | return outputs
--------------------------------------------------------------------------------
/networks/VQAModel/STVQA.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random as rd
4 | import sys
5 | sys.path.insert(0, 'networks')
6 | from Attention import TempAttention, SpatialAttention
7 |
8 |
9 | class STVQA(nn.Module):
10 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, att_dim, device):
11 | """
12 | TGIF-QA: Toward Spatio-Temporal Reasoning in Visual Question Answering (CVPR17)
13 | """
14 | super(STVQA, self).__init__()
15 | self.vid_encoder = vid_encoder
16 | self.qns_encoder = qns_encoder
17 | self.ans_decoder = ans_decoder
18 | self.att_dim = att_dim
19 |
20 | self.spatial_att = SpatialAttention(qns_encoder.dim_hidden*2, vid_encoder.input_dim, hidden_dim=self.att_dim)
21 | self.temp_att = TempAttention(qns_encoder.dim_hidden*2, vid_encoder.dim_hidden*2, hidden_dim=self.att_dim)
22 | self.device = device
23 | self.FC = nn.Linear(att_dim*2, att_dim)
24 |
25 |
26 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
27 | """
28 | """
29 | qns_output_1, qns_hidden_1 = self.qns_encoder(qns, qns_lengths)
30 | n_layers, batch_size, qns_dim = qns_hidden_1[0].size()
31 |
32 | # Concatenate the dual-layer hidden as qns embedding
33 | qns_embed = qns_hidden_1[0].permute(1, 0, 2) # batch first
34 | qns_embed = qns_embed.reshape(batch_size, -1) #(batch_size, feat_dim*2)
35 | batch_size, fnum, vid_dim, w, h = vid_feats.size()
36 |
37 | # Apply spatial attention
38 | vid_att_feats = torch.zeros(batch_size, fnum, vid_dim).to(self.device)
39 | for bs in range(batch_size):
40 | vid_att_feats[bs], alpha = self.spatial_att(qns_embed[bs], vid_feats[bs])
41 |
42 | vid_outputs, vid_hidden = self.vid_encoder(vid_att_feats)
43 |
44 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths, vid_hidden)
45 |
46 | """
47 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
48 | torch.Size([16, 3, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
49 | """
50 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
51 |
52 | # Apply temporal attention
53 | temp_att_outputs, beta = self.temp_att(qns_embed, vid_outputs)
54 | encoder_outputs = self.FC(qns_embed + temp_att_outputs)
55 | # hidden = qns_hidden
56 | hidden = encoder_outputs.unsqueeze(0)
57 | # print(hidden.size())
58 |
59 | if mode == 'train':
60 | vocab_size = self.ans_decoder.vocab_size
61 | ans_len = ans.shape[1]
62 | input = ans[:, 0]
63 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
64 | for t in range(0, ans_len):
65 | # output, hidden = self.ans_decoder(encoder_outputs, hidden, input)
66 | output, hidden = self.ans_decoder(qns_outputs, hidden, input)
67 | outputs[:, t] = output
68 | teacher_force = rd.random() < teacher_force_ratio
69 | top1 = output.argmax(1)
70 | input = ans[:, t] if teacher_force else top1
71 | else:
72 | start = torch.LongTensor([1] * batch_size).to(self.device)
73 | # outputs = self.ans_decoder.sample(encoder_outputs, hidden, start)
74 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start)
75 |
76 | return outputs
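# Input expectation for the forward pass above: vid_feats must be 5-D,
# (batch_size, fnum, vid_dim, w, h), i.e. per-frame spatial feature maps rather
# than pooled frame vectors, because question-guided spatial attention is applied
# to every frame before the frames are encoded temporally.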
77 |
--------------------------------------------------------------------------------
/networks/VQAModel/UATT.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random as rd
4 | import sys
5 | sys.path.insert(0, 'networks')
6 | from Attention import TempAttentionHis
7 |
8 | class UATT(nn.Module):
9 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device):
10 | """
11 | Unifying the Video and Question Attentions for Open-Ended Video Question Answering (TIP17)
12 | """
13 | super(UATT, self).__init__()
14 | self.vid_encoder = vid_encoder
15 | self.qns_encoder = qns_encoder
16 | self.ans_decoder = ans_decoder
17 | mem_dim = 512
18 | self.att_q2v = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim)
19 | self.att_v2q = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim)
20 |
21 | self.device = device
22 |
23 |
24 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
25 |
26 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats)
27 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths)
28 | qns_outputs = qns_outputs.permute(1, 0, 2)
29 | # print(vid_outputs.size(), vid_hidden[0].size(), vid_hidden[1].size())
30 | # print(qns_outputs.size(), qns_hidden[0].size(), qns_hidden[1].size())
31 | """
32 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
33 | torch.Size([3, 16, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
34 | """
35 |
36 | word_num, batch_size, feat_dim = qns_outputs.size()
37 | r = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device)
38 |
39 | for word in qns_outputs:
40 | r, beta_r = self.att_q2v(word, vid_outputs, r)
41 |
42 | vid_outputs = vid_outputs.permute(1, 0, 2) # change to fnum, batch_size, feat_dim
43 | qns_outputs = qns_outputs.permute(1, 0, 2) # change to batch_size, word_num, feat_dim
44 | w = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device)
45 | for frame in vid_outputs:
46 | w, beta_w = self.att_v2q(frame, qns_outputs, w)
47 |
48 | hidden = (torch.cat((r.unsqueeze(0), w.unsqueeze(0)), dim=0),
49 | torch.cat((vid_hidden[1][0].unsqueeze(0), qns_hidden[1][0].unsqueeze(0)), dim=0))
50 |
51 | if mode == 'train':
52 | vocab_size = self.ans_decoder.vocab_size
53 | batch_size = ans.shape[0]
54 | ans_len = ans.shape[1]
55 | input = ans[:, 0]
56 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
57 | for t in range(0, ans_len):
58 | output, hidden = self.ans_decoder(hidden, input)
59 | outputs[:, t] = output
60 | teacher_force = rd.random() < teacher_force_ratio
61 | top1 = output.argmax(1)
62 | input = ans[:, t] if teacher_force else top1
63 | else:
64 | start = torch.LongTensor([1] * batch_size).to(self.device)
65 | outputs = self.ans_decoder.sample(hidden, start)
66 | return outputs
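# Decoding setup above: the question-to-video memory r and the video-to-question
# memory w are stacked into a two-layer hidden state, paired with the first-layer
# cell states of the video and question encoders, so the answer decoder is
# expected to consume an LSTM-style two-layer (h, c) state (presumably AnsUATT
# from DecoderRNN.py).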
67 |
--------------------------------------------------------------------------------
/networks/VQAModel_bak.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random as rd
4 | from .Attention import TempAttentionHis, TempAttention, SpatialAttention
5 | from .memory_rand import MemoryRamTwoStreamModule, MemoryRamModule, MMModule
6 | from .memory_module import EpisodicMemory
7 | from .q_v_transformer import CoAttention
8 | from .gcn import AdjLearner, GCN
9 | from block import fusions #pytorch >= 1.1.0
10 |
11 |
12 | class EVQA(nn.Module):
13 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device):
14 | """
15 |
16 | :param vid_encoder:
17 | :param qns_encoder:
18 | :param ans_decoder:
19 | :param device:
20 | """
21 | super(EVQA, self).__init__()
22 | self.vid_encoder = vid_encoder
23 | self.qns_encoder = qns_encoder
24 | self.ans_decoder = ans_decoder
25 | self.device = device
26 |
27 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
28 |
29 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats)
30 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths)
31 |
32 | # print(vid_outputs.size(), vid_hidden[0].size(), vid_hidden[1].size())
33 | # print(qns_outputs.size(), qns_hidden[0].size(), qns_hidden[1].size())
34 | """
35 | torch.Size([64, 128, 512]) torch.Size([1, 64 512]) torch.Size([1, 64, 512])
36 | torch.Size([16, 64, 512]) torch.Size([2, 64, 512]) torch.Size([2, 64, 512])
37 | """
38 |
39 |
40 | hidden = qns_hidden[0] + vid_hidden[0]
41 | batch_size = qns.shape[0]
42 |
43 | if mode == 'train':
44 | vocab_size = self.ans_decoder.vocab_size
45 | ans_len = ans.shape[1]
46 | input = ans[:, 0]
47 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
48 | for t in range(0, ans_len):
49 | output, hidden = self.ans_decoder(qns_outputs, hidden, input)
50 | outputs[:,t] = output
51 | teacher_force = rd.random() < teacher_force_ratio
52 | top1 = output.argmax(1)
53 | input = ans[:, t] if teacher_force else top1
54 | else:
55 | start = torch.LongTensor([1] * batch_size).to(self.device)
56 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start)
57 |
58 | return outputs
59 |
60 |
61 | class UATT(nn.Module):
62 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device):
63 | """
64 | Unifying the Video and Question Attentions for Open-Ended Video Question Answering (TIP17)
65 | """
66 | super(UATT, self).__init__()
67 | self.vid_encoder = vid_encoder
68 | self.qns_encoder = qns_encoder
69 | self.ans_decoder = ans_decoder
70 | mem_dim = 512
71 | self.att_q2v = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim)
72 | self.att_v2q = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim)
73 |
74 | self.device = device
75 |
76 |
77 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
78 |
79 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats)
80 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths)
81 | qns_outputs = qns_outputs.permute(1, 0, 2)
82 | # print(vid_outputs.size(), vid_hidden[0].size(), vid_hidden[1].size())
83 | # print(qns_outputs.size(), qns_hidden[0].size(), qns_hidden[1].size())
84 | """
85 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
86 | torch.Size([3, 16, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
87 | """
88 |
89 | word_num, batch_size, feat_dim = qns_outputs.size()
90 | r = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device)
91 |
92 | for word in qns_outputs:
93 | r, beta_r = self.att_q2v(word, vid_outputs, r)
94 |
95 | vid_outputs = vid_outputs.permute(1, 0, 2) # change to fnum, batch_size, feat_dim
96 | qns_outputs = qns_outputs.permute(1, 0, 2) # change to batch_size, word_num, feat_dim
97 | w = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device)
98 | for frame in vid_outputs:
99 | w, beta_w = self.att_v2q(frame, qns_outputs, w)
100 |
101 | hidden = (torch.cat((r.unsqueeze(0), w.unsqueeze(0)), dim=0),
102 | torch.cat((vid_hidden[1][0].unsqueeze(0), qns_hidden[1][0].unsqueeze(0)), dim=0))
103 |
104 | if mode == 'train':
105 | vocab_size = self.ans_decoder.vocab_size
106 | batch_size = ans.shape[0]
107 | ans_len = ans.shape[1]
108 | input = ans[:, 0]
109 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
110 | for t in range(0, ans_len):
111 | output, hidden = self.ans_decoder(hidden, input)
112 | outputs[:, t] = output
113 | teacher_force = rd.random() < teacher_force_ratio
114 | top1 = output.argmax(1)
115 | input = ans[:, t] if teacher_force else top1
116 | else:
117 | start = torch.LongTensor([1] * batch_size).to(self.device)
118 | outputs = self.ans_decoder.sample(hidden, start)
119 | return outputs
120 |
121 |
122 | class STVQA(nn.Module):
123 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, att_dim, device):
124 | """
125 | TGIF-QA: Toward Spatio-Temporal Reasoning in Visual Question Answering (CVPR17)
126 | """
127 | super(STVQA, self).__init__()
128 | self.vid_encoder = vid_encoder
129 | self.qns_encoder = qns_encoder
130 | self.ans_decoder = ans_decoder
131 | self.att_dim = att_dim
132 |
133 | self.spatial_att = SpatialAttention(qns_encoder.dim_hidden*2, vid_encoder.input_dim, hidden_dim=self.att_dim)
134 | self.temp_att = TempAttention(qns_encoder.dim_hidden*2, vid_encoder.dim_hidden*2, hidden_dim=self.att_dim)
135 | self.device = device
136 | self.FC = nn.Linear(att_dim*2, att_dim)
137 |
138 |
139 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
140 | """
141 | """
142 | qns_output_1, qns_hidden_1 = self.qns_encoder(qns, qns_lengths)
143 | n_layers, batch_size, qns_dim = qns_hidden_1[0].size()
144 |
145 | # Concatenate the dual-layer hidden as qns embedding
146 | qns_embed = qns_hidden_1[0].permute(1, 0, 2) # batch first
147 | qns_embed = qns_embed.reshape(batch_size, -1) #(batch_size, feat_dim*2)
148 | batch_size, fnum, vid_dim, w, h = vid_feats.size()
149 |
150 | # Apply spatial attention
151 | vid_att_feats = torch.zeros(batch_size, fnum, vid_dim).to(self.device)
152 | for bs in range(batch_size):
153 | vid_att_feats[bs], alpha = self.spatial_att(qns_embed[bs], vid_feats[bs])
154 |
155 | vid_outputs, vid_hidden = self.vid_encoder(vid_att_feats)
156 |
157 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths, vid_hidden)
158 |
159 | """
160 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
161 | torch.Size([16, 3, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512])
162 | """
163 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
164 |
165 | # Apply temporal attention
166 | temp_att_outputs, beta = self.temp_att(qns_embed, vid_outputs)
167 | encoder_outputs = self.FC(qns_embed + temp_att_outputs)
168 | # hidden = qns_hidden
169 | hidden = encoder_outputs.unsqueeze(0)
170 | # print(hidden.size())
171 |
172 | if mode == 'train':
173 | vocab_size = self.ans_decoder.vocab_size
174 | ans_len = ans.shape[1]
175 | input = ans[:, 0]
176 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
177 | for t in range(0, ans_len):
178 | # output, hidden = self.ans_decoder(encoder_outputs, hidden, input)
179 | output, hidden = self.ans_decoder(qns_outputs, hidden, input)
180 | outputs[:, t] = output
181 | teacher_force = rd.random() < teacher_force_ratio
182 | top1 = output.argmax(1)
183 | input = ans[:, t] if teacher_force else top1
184 | else:
185 | start = torch.LongTensor([1] * batch_size).to(self.device)
186 | # outputs = self.ans_decoder.sample(encoder_outputs, hidden, start)
187 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start)
188 |
189 | return outputs
190 |
191 |
192 | class CoMem(nn.Module):
193 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2):
194 | """
195 | motion-appearance co-memory networks for video question answering (CVPR18)
196 | """
197 | super(CoMem, self).__init__()
198 | self.vid_encoder = vid_encoder
199 | self.qns_encoder = qns_encoder
200 | self.ans_decoder = ans_decoder
201 |
202 | dim = qns_encoder.dim_hidden
203 |
204 | self.epm_app = EpisodicMemory(dim*2)
205 | self.epm_mot = EpisodicMemory(dim*2)
206 |
207 | self.linear_ma = nn.Linear(dim*2*3, dim*2)
208 | self.linear_mb = nn.Linear(dim*2*3, dim*2)
209 |
210 | self.vq2word = nn.Linear(dim*2*2, dim)
211 | self._init_weights()
212 | self.device = device
213 |
214 | def _init_weights(self):
215 | """
216 | initialize the linear weights
217 | :return:
218 | """
219 | nn.init.xavier_normal_(self.linear_ma.weight)
220 | nn.init.xavier_normal_(self.linear_mb.weight)
221 | nn.init.xavier_normal_(self.vq2word.weight)
222 |
223 |
224 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, iter_num=3, mode='train'):
225 | """
226 | Co-memory network
227 | """
228 |
229 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim)
230 |
231 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1)
232 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1)
233 |
234 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths)
235 |
236 | # qns_output = qns_output.permute(1, 0, 2)
237 | batch_size, seq_len, qns_feat_dim = qns_output.size()
238 |
239 |
240 | qns_embed = qns_hidden.permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
241 |
242 | m_app = outputs_app[:, -1, :]
243 | m_mot = outputs_motion[:, -1, :]
244 | ma, mb = m_app.detach(), m_mot.detach()
245 | m_app = m_app.unsqueeze(1)
246 | m_mot = m_mot.unsqueeze(1)
247 | for _ in range(iter_num):
248 | mm = ma + mb
249 | m_app = self.epm_app(outputs_app, mm, m_app)
250 | m_mot = self.epm_mot(outputs_motion, mm, m_mot)
251 | ma_q = torch.cat((ma, m_app.squeeze(1), qns_embed), dim=1)
252 | mb_q = torch.cat((mb, m_mot.squeeze(1), qns_embed), dim=1)
253 | # print(ma_q.shape)
254 | ma = torch.tanh(self.linear_ma(ma_q))
255 | mb = torch.tanh(self.linear_mb(mb_q))
256 |
257 | mem = torch.cat((ma, mb), dim=1)
258 | encoder_outputs = self.vq2word(mem)
259 | # hidden = qns_hidden
260 | hidden = encoder_outputs.unsqueeze(0)
261 |
262 | # decoder_inputs = encoder_outputs
263 |
264 | if mode == 'train':
265 | vocab_size = self.ans_decoder.vocab_size
266 | ans_len = ans.shape[1]
267 | input = ans[:, 0]
268 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
269 |
270 | for t in range(0, ans_len):
271 | output, hidden = self.ans_decoder(qns_output, hidden, input)
272 | outputs[:, t] = output
273 | teacher_force = rd.random() < teacher_force_ratio
274 | top1 = output.argmax(1)
275 | input = ans[:, t] if teacher_force else top1
276 | else:
277 | start = torch.LongTensor([1] * batch_size).to(self.device)
278 | outputs = self.ans_decoder.sample(qns_output, hidden, start)
279 |
280 | return outputs
281 |
282 |
283 | class HME(nn.Module):
284 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2):
285 | """
286 | Heterogeneous memory enhanced multimodal attention model for video question answering (CVPR19)
287 |
288 | """
289 | super(HME, self).__init__()
290 | self.vid_encoder = vid_encoder
291 | self.qns_encoder = qns_encoder
292 | self.ans_decoder = ans_decoder
293 |
294 | dim = qns_encoder.dim_hidden
295 |
296 | self.temp_att_a = TempAttention(dim * 2, dim * 2, hidden_dim=256)
297 | self.temp_att_m = TempAttention(dim * 2, dim * 2, hidden_dim=256)
298 | self.mrm_vid = MemoryRamTwoStreamModule(dim, dim, max_len_v, device)
299 | self.mrm_txt = MemoryRamModule(dim, dim, max_len_q, device)
300 |
301 | self.mm_module_v1 = MMModule(dim, input_drop_p, device)
302 |
303 | self.linear_vid = nn.Linear(dim*2, dim)
304 | self.linear_qns = nn.Linear(dim*2, dim)
305 | self.linear_mem = nn.Linear(dim*2, dim)
306 | self.vq2word_hme = nn.Linear(dim*3, dim*2)
307 | self._init_weights()
308 | self.device = device
309 |
310 | def _init_weights(self):
311 | """
312 | initialize the linear weights
313 | :return:
314 | """
315 | nn.init.xavier_normal_(self.linear_vid.weight)
316 | nn.init.xavier_normal_(self.linear_qns.weight)
317 | nn.init.xavier_normal_(self.linear_mem.weight)
318 | nn.init.xavier_normal_(self.vq2word_hme.weight)
319 |
320 |
321 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, iter_num=3, mode='train'):
322 | """
323 | """
324 |
325 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim)
326 |
327 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1)
328 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1)
329 |
330 | batch_size, fnum, vid_feat_dim = outputs_app.size()
331 |
332 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths)
333 | # print(qns_output.shape, qns_hidden[0].shape) #torch.Size([10, 23, 256]) torch.Size([2, 10, 256])
334 |
335 | # qns_output = qns_output.permute(1, 0, 2)
336 | batch_size, seq_len, qns_feat_dim = qns_output.size()
337 |
338 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim)
339 |
340 | # Apply temporal attention
341 | att_app, beta_app = self.temp_att_a(qns_embed, outputs_app)
342 | att_motion, beta_motion = self.temp_att_m(qns_embed, outputs_motion)
343 | tmp_app_motion = torch.cat((outputs_app_l2[:, -1, :], outputs_motion_l2[:, -1, :]), dim=-1)
344 |
345 | mem_output = torch.zeros(batch_size, vid_feat_dim).to(self.device)
346 |
347 | for bs in range(batch_size):
348 | mem_ram_vid = self.mrm_vid(outputs_app_l2[bs], outputs_motion_l2[bs], fnum)
349 | cur_qns = qns_output[bs][:qns_lengths[bs]]
350 | mem_ram_txt = self.mrm_txt(cur_qns, qns_lengths[bs]) #should remove padded zeros
351 | mem_output[bs] = self.mm_module_v1(tmp_app_motion[bs].unsqueeze(0), mem_ram_vid, mem_ram_txt, iter_num)
352 | """
353 | (64, 256) (22, 256) (1, 512)
354 | """
355 | app_trans = torch.tanh(self.linear_vid(att_app))
356 | motion_trans = torch.tanh(self.linear_vid(att_motion))
357 | mem_trans = torch.tanh(self.linear_mem(mem_output))
358 |
359 | encoder_outputs = torch.cat((app_trans, motion_trans, mem_trans), dim=1)
360 | decoder_inputs = self.vq2word_hme(encoder_outputs)
361 | hidden = qns_hidden
362 | if mode == 'train':
363 | vocab_size = self.ans_decoder.vocab_size
364 | ans_len = ans.shape[1]
365 | input = ans[:, 0]
366 |
367 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
368 |
369 | for t in range(0, ans_len):
370 | output, hidden = self.ans_decoder(decoder_inputs, hidden, input)
371 | outputs[:, t] = output
372 | teacher_force = rd.random() < teacher_force_ratio
373 | top1 = output.argmax(1)
374 | input = ans[:, t] if teacher_force else top1
375 | else:
376 | start = torch.LongTensor([1] * batch_size).to(self.device)
377 | outputs = self.ans_decoder.sample(decoder_inputs, hidden, start)
378 |
379 | return outputs
380 |
381 |
382 | class HGA(nn.Module):
383 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device):
384 | """
385 | Reasoning with Heterogeneous Graph Alignment for Video Question Answering (AAAI20)
386 | """
387 | super(HGA, self).__init__()
388 | self.vid_encoder = vid_encoder
389 | self.qns_encoder = qns_encoder
390 | self.ans_decoder = ans_decoder
391 | self.max_len_v = max_len_v
392 | self.max_len_q = max_len_q
393 | self.device = device
394 | hidden_size = vid_encoder.dim_hidden
395 | input_dropout_p = vid_encoder.input_dropout_p
396 |
397 | self.q_input_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
398 | self.v_input_ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
399 |
400 | self.co_attn = CoAttention(
401 | hidden_size, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p)
402 |
403 | self.adj_learner = AdjLearner(
404 | hidden_size, hidden_size, dropout=input_dropout_p)
405 |
406 | self.gcn = GCN(
407 | hidden_size,
408 | hidden_size,
409 | hidden_size,
410 | num_layers=2,
411 | dropout=input_dropout_p)
412 |
413 | self.gcn_atten_pool = nn.Sequential(
414 | nn.Linear(hidden_size, hidden_size // 2),
415 | nn.Tanh(),
416 | nn.Linear(hidden_size // 2, 1),
417 | nn.Softmax(dim=-1))
418 |
419 | self.global_fusion = fusions.Block(
420 | [hidden_size, hidden_size], hidden_size, dropout_input=input_dropout_p)
421 |
422 | self.fusion = fusions.Block([hidden_size, hidden_size], hidden_size)
423 |
424 |
425 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'):
426 | """
427 |
428 | """
429 | encoder_out, qns_hidden, qns_out, vid_out = self.vq_encoder(vid_feats, qns, qns_lengths)
430 |
431 | batch_size = encoder_out.shape[0]
432 |
433 | hidden = encoder_out.unsqueeze(0)
434 | if mode == 'train':
435 | vocab_size = self.ans_decoder.vocab_size
436 | ans_len = ans.shape[1]
437 | input = ans[:, 0]
438 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device)
439 | for t in range(0, ans_len):
440 |
441 | output, hidden = self.ans_decoder(qns_out, hidden, input) #attqns, attvid
442 | outputs[:, t] = output
443 | teacher_force = rd.random() < teacher_force_ratio
444 | top1 = output.argmax(1)
445 | input = ans[:, t] if teacher_force else top1
446 | else:
447 | start = torch.LongTensor([1] * batch_size).to(self.device)
448 |
449 | outputs = self.ans_decoder.sample(qns_out, hidden, start) #vidatt, qns_att
450 |
451 | return outputs
452 |
453 |
454 | def vq_encoder(self, vid_feats, qns, qns_lengths):
455 | """
456 |
457 | :param vid_feats:
458 | :param qns:
459 | :param qns_lengths:
460 | :return:
461 | """
462 | q_output, s_hidden = self.qns_encoder(qns, qns_lengths)
463 | qns_last_hidden = torch.squeeze(s_hidden)
464 |
465 |
466 | v_output, v_hidden = self.vid_encoder(vid_feats)
467 | vid_last_hidden = torch.squeeze(v_hidden)
468 |
469 | q_output = self.q_input_ln(q_output)
470 | v_output = self.v_input_ln(v_output)
471 |
472 | q_output, v_output = self.co_attn(q_output, v_output)
473 |
474 | ### GCN
475 | adj = self.adj_learner(q_output, v_output)
476 | # q_v_inputs of shape (batch_size, q_v_len, hidden_size)
477 | q_v_inputs = torch.cat((q_output, v_output), dim=1)
478 | # q_v_output of shape (batch_size, q_v_len, hidden_size)
479 | q_v_output = self.gcn(q_v_inputs, adj)
480 |
481 | ## attention pool
482 | local_attn = self.gcn_atten_pool(q_v_output)
483 | local_out = torch.sum(q_v_output * local_attn, dim=1)
484 |
485 | # print(qns_last_hidden.shape, vid_last_hidden.shape)
486 | global_out = self.global_fusion((qns_last_hidden, vid_last_hidden))
487 |
488 |
489 | out = self.fusion((global_out, local_out)).squeeze() #4 x 512
490 |
491 | return out, s_hidden, q_output, v_output,
492 |
493 |
--------------------------------------------------------------------------------
/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .EncoderRNN import EncoderVid, EncoderQns
2 | from .DecoderRNN import AnsUATT
3 | from .VQAModel import EVQA
4 |
5 |
--------------------------------------------------------------------------------
/networks/gcn.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from torch.nn.parameter import Parameter
7 | from torch.nn.modules.module import Module
8 |
9 |
10 | def padding_mask_k(seq_q, seq_k):
11 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
12 | In batch 0:
13 | [[x x x 0] [[0 0 0 1]
14 | [x x x 0]-> [0 0 0 1]
15 | [x x x 0]] [0 0 0 1]] uint8
16 | """
17 | fake_q = torch.ones_like(seq_q)
18 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2))
19 | pad_mask = pad_mask.eq(0)
20 | # pad_mask = pad_mask.lt(1e-3)
21 | return pad_mask
22 |
23 |
24 | def padding_mask_q(seq_q, seq_k):
25 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
26 | In batch 0:
27 | [[x x x x] [[0 0 0 0]
28 | [x x x x] -> [0 0 0 0]
29 | [0 0 0 0]] [1 1 1 1]] uint8
30 | """
31 | fake_k = torch.ones_like(seq_k)
32 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2))
33 | pad_mask = pad_mask.eq(0)
34 | # pad_mask = pad_mask.lt(1e-3)
35 | return pad_mask
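# Reading the two masks together (consistent with the ASCII sketches in the
# docstrings above): padding_mask_k marks, for every query row, the key columns
# whose feature vectors are all zeros (padded keys to ignore in attention), while
# padding_mask_q marks the query rows that are themselves all-zero padding and
# whose attention output is typically zeroed out after the softmax.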
36 |
37 |
38 | class GraphConvolution(Module):
39 | """
40 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
41 | """
42 |
43 | def __init__(self, in_features, out_features):
44 | super(GraphConvolution, self).__init__()
45 | self.weight = nn.Linear(in_features, out_features, bias=False)
46 | self.layer_norm = nn.LayerNorm(out_features, elementwise_affine=False)
47 |
48 | def forward(self, input, adj):
49 | # self.weight of shape (hidden_size, hidden_size)
50 | support = self.weight(input)
51 | output = torch.bmm(adj, support)
52 | output = self.layer_norm(output)
53 | return output
54 |
55 |
56 | class GraphAttention(nn.Module):
57 | """
58 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
59 | """
60 |
61 | def __init__(self, in_features, out_features, dropout, alpha, concat=True):
62 | super(GraphAttention, self).__init__()
63 | self.dropout = dropout
64 | self.in_features = in_features
65 | self.out_features = out_features
66 | self.alpha = alpha
67 | self.concat = concat
68 |
69 | self.W = nn.Parameter(
70 | nn.init.xavier_normal_(
71 | torch.Tensor(in_features, out_features).type(
72 | torch.cuda.FloatTensor if torch.cuda.is_available(
73 | ) else torch.FloatTensor),
74 | gain=np.sqrt(2.0)),
75 | requires_grad=True)
76 | self.a1 = nn.Parameter(
77 | nn.init.xavier_normal_(
78 | torch.Tensor(out_features, 1).type(
79 | torch.cuda.FloatTensor if torch.cuda.is_available(
80 | ) else torch.FloatTensor),
81 | gain=np.sqrt(2.0)),
82 | requires_grad=True)
83 | self.a2 = nn.Parameter(
84 | nn.init.xavier_normal_(
85 | torch.Tensor(out_features, 1).type(
86 | torch.cuda.FloatTensor if torch.cuda.is_available(
87 | ) else torch.FloatTensor),
88 | gain=np.sqrt(2.0)),
89 | requires_grad=True)
90 |
91 | self.leakyrelu = nn.LeakyReLU(self.alpha)
92 |
93 | def forward(self, input, adj):
94 | h = torch.mm(input, self.W)
95 | N = h.size()[0]
96 |
97 | f_1 = torch.matmul(h, self.a1)
98 | f_2 = torch.matmul(h, self.a2)
99 | e = self.leakyrelu(f_1 + f_2.transpose(0, 1))
100 |
101 | zero_vec = -9e15 * torch.ones_like(e)
102 | attention = torch.where(adj > 0, e, zero_vec)
103 | attention = F.softmax(attention, dim=1)
104 | attention = F.dropout(attention, self.dropout, training=self.training)
105 | h_prime = torch.matmul(attention, h)
106 |
107 | if self.concat:
108 | return F.elu(h_prime)
109 | else:
110 | return h_prime
111 |
112 |
113 | class GCN(nn.Module):
114 |
115 | def __init__(
116 | self, input_size, hidden_size, num_classes, num_layers=1,
117 | dropout=0.1):
118 | super(GCN, self).__init__()
119 | self.layers = nn.ModuleList()
120 | self.layers.append(GraphConvolution(input_size, hidden_size))
121 | for i in range(num_layers - 1):
122 | self.layers.append(GraphConvolution(hidden_size, hidden_size))
123 | self.layers.append(GraphConvolution(hidden_size, num_classes))
124 | self.dropout = nn.Dropout(p=dropout)
125 |
126 | def forward(self, x, adj):
127 | for i, layer in enumerate(self.layers):
128 | x = self.dropout(F.relu(layer(x, adj)))
129 |
130 | # x of shape (bs, q_v_len, num_classes)
131 | return x
132 |
133 |
134 | class AdjLearner(Module):
135 |
136 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1):
137 | super().__init__()
138 | '''
139 | ## Variables:
140 | - in_feature_dim: dimensionality of input features
141 | - hidden_size: dimensionality of the joint hidden embedding
142 | - K: number of graph nodes/objects on the image
143 | '''
144 |
145 | # Embedding layers. Padded 0 => 0
146 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False)
147 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False)
148 |
149 | # Regularisation
150 | self.dropout = nn.Dropout(p=dropout)
151 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1)
152 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2)
153 |
154 | def forward(self, questions, videos):
155 | '''
156 | ## Inputs:
157 | ## Returns:
158 | - adjacency matrix (batch_size, q_v_len, q_v_len)
159 | '''
160 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features
161 | graph_nodes = torch.cat((questions, videos), dim=1)
162 |
163 | # layer 1
164 | h = self.edge_layer_1(graph_nodes)
165 | h = F.relu(h)
166 |
167 | # layer 2
168 | h = self.edge_layer_2(h)
169 | h = F.relu(h)
170 | # h * sigmoid(Wh)
171 | # h = F.tanh(h)
172 |
173 | # outer product
174 | adjacency_matrix = torch.bmm(h, h.transpose(1, 2))
175 |
176 | return adjacency_matrix
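# Usage sketch for AdjLearner together with the GCN above (names mirror
# HGA.vq_encoder):
#   adj = adj_learner(q_output, v_output)                       # (B, Lq+Lv, Lq+Lv)
#   q_v_output = gcn(torch.cat((q_output, v_output), dim=1), adj)
# The adjacency is a dense similarity matrix h @ h^T with no row normalisation or
# softmax; the LayerNorm inside GraphConvolution keeps the aggregated features at
# a stable scale.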
177 |
178 |
179 | class EvoAdjLearner(Module):
180 |
181 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1):
182 | super().__init__()
183 | '''
184 | ## Variables:
185 | - in_feature_dim: dimensionality of input features
186 | - hidden_size: dimensionality of the joint hidden embedding
187 | - K: number of graph nodes/objects on the image
188 | '''
189 |
190 | # Embedding layers. Padded 0 => 0
191 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False)
192 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False)
193 | self.edge_layer_3 = nn.Linear(in_feature_dim, hidden_size, bias=False)
194 | self.edge_layer_4 = nn.Linear(hidden_size, hidden_size, bias=False)
195 |
196 | # Regularisation
197 | self.dropout = nn.Dropout(p=dropout)
198 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1)
199 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2)
200 |
201 | def forward(self, questions, videos):
202 | '''
203 | ## Inputs:
204 | ## Returns:
205 | - adjacency matrix (batch_size, q_v_len, q_v_len)
206 | '''
207 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features
208 | graph_nodes = torch.cat((questions, videos), dim=1)
209 |
210 | attn_mask = padding_mask_k(graph_nodes, graph_nodes)
211 | sf_mask = padding_mask_q(graph_nodes, graph_nodes)
212 |
213 | # layer 1
214 | h = self.edge_layer_1(graph_nodes)
215 | h = F.relu(h)
216 | # layer 2
217 | h = self.edge_layer_2(h)
218 | # h = F.relu(h)
219 |
220 | # layer 1
221 | h_ = self.edge_layer_3(graph_nodes)
222 | h_ = F.relu(h_)
223 | # layer 2
224 | h_ = self.edge_layer_4(h_)
225 | # h_ = F.relu(h_)
226 |
227 | # outer product
228 | adjacency_matrix = torch.bmm(h, h_.transpose(1, 2))
229 | # adjacency_matrix = adjacency_matrix.masked_fill(attn_mask, -np.inf)
230 |
231 | # softmaxed_adj = F.softmax(adjacency_matrix, dim=-1)
232 |
233 | # softmaxed_adj = softmaxed_adj.masked_fill(sf_mask, 0.)
234 |
235 | return adjacency_matrix
--------------------------------------------------------------------------------
/networks/memory_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 | import torch.nn.functional as F
5 | import torch.nn.init as init
6 |
7 |
8 | class AttentionGRUCell(nn.Module):
9 | '''
10 | Eq (1)~(4), then modify by Eq (11)
11 | When forwarding, we feed attention gate g into GRU
12 | '''
13 | def __init__(self, input_size, hidden_size):
14 | super(AttentionGRUCell, self).__init__()
15 | self.hidden_size = hidden_size
16 | self.Wr = nn.Linear(input_size, hidden_size)
17 | self.Ur = nn.Linear(hidden_size, hidden_size)
18 | self.W = nn.Linear(input_size, hidden_size)
19 | self.U = nn.Linear(hidden_size, hidden_size)
20 |
21 | def init_weights(self):
22 | self.Wr.weight.data.normal_(0.0, 0.02)
23 | self.Wr.bias.data.fill_(0)
24 | self.Ur.weight.data.normal_(0.0, 0.02)
25 | self.Ur.bias.data.fill_(0)
26 | self.W.weight.data.normal_(0.0, 0.02)
27 | self.W.bias.data.fill_(0)
28 | self.U.weight.data.normal_(0.0, 0.02)
29 | self.U.bias.data.fill_(0)
30 |
31 | def forward(self, fact, C, g):
32 | '''
33 | fact.size() -> (#batch, #hidden = #embedding)
34 | c.size() -> (#hidden, ) -> (#batch, #hidden = #embedding)
35 | r.size() -> (#batch, #hidden = #embedding)
36 | h_tilda.size() -> (#batch, #hidden = #embedding)
37 | g.size() -> (#batch, )
38 | '''
39 |
40 | r = torch.sigmoid(self.Wr(fact) + self.Ur(C))
41 | h_tilda = torch.tanh(self.W(fact) + r * self.U(C))
42 | g = g.unsqueeze(1).expand_as(h_tilda)
43 | h = g * h_tilda + (1 - g) * C
44 | return h
45 |
46 | class AttentionGRU(nn.Module):
47 | '''
48 | Section 3.3
49 | continuously run AttnGRU to get contextual vector c at each time t
50 | '''
51 | def __init__(self, input_size, hidden_size):
52 | super(AttentionGRU, self).__init__()
53 | self.hidden_size = hidden_size
54 | self.AGRUCell = AttentionGRUCell(input_size, hidden_size)
55 |
56 | def init_weights(self):
57 | self.AGRUCell.init_weights()
58 |
59 | def forward(self, facts, G):
60 | '''
61 | facts.size() -> (#batch, #sentence, #hidden = #embedding)
62 | fact.size() -> (#batch, #hidden = #embedding)
63 | G.size() -> (#batch, #sentence)
64 | g.size() -> (#batch, )
65 | C.size() -> (#batch, #hidden)
66 | '''
67 | batch_num, sen_num, embedding_size = facts.size()
68 |         C = torch.zeros(self.hidden_size, device=facts.device)  # initial context, on the same device as the inputs
69 | for sid in range(sen_num):
70 | fact = facts[:, sid, :]
71 | g = G[:, sid]
72 | if sid == 0:
73 | C = C.unsqueeze(0).expand_as(fact)
74 | C = self.AGRUCell(fact, C, g)
75 | return C
76 |
77 | class EpisodicMemory(nn.Module):
78 | '''
79 | Section 3.3
80 | '''
81 |
82 | def __init__(self, hidden_size):
83 | super(EpisodicMemory, self).__init__()
84 | self.AGRU = AttentionGRU(hidden_size, hidden_size)
85 | self.z1 = nn.Linear(4 * hidden_size, hidden_size)
86 | self.z2 = nn.Linear(hidden_size, 1)
87 | self.next_mem = nn.Linear(3 * hidden_size, hidden_size)
88 |
89 |
90 | def init_weights(self):
91 | self.z1.weight.data.normal_(0.0, 0.02)
92 | self.z1.bias.data.fill_(0)
93 | self.z2.weight.data.normal_(0.0, 0.02)
94 | self.z2.bias.data.fill_(0)
95 | self.next_mem.weight.data.normal_(0.0, 0.02)
96 | self.next_mem.bias.data.fill_(0)
97 | self.AGRU.init_weights()
98 |
99 |
100 | def make_interaction(self, frames, questions, prevM):
101 | '''
102 | frames.size() -> (#batch, T, #hidden = #embedding)
103 |         questions.size() -> (#batch, #hidden); reshaped to (#batch, 1, #hidden) below
104 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
105 | z.size() -> (#batch, T, 4 x #embedding)
106 | G.size() -> (#batch, T)
107 | '''
108 | batch_num, T, embedding_size = frames.size()
109 | questions = questions.view(questions.size(0),1,questions.size(1))
110 |
111 |
112 | #questions = questions.expand_as(frames)
113 | #prevM = prevM.expand_as(frames)
114 |
115 | #print(questions.size(),prevM.size())
116 |
117 | # Eq (8)~(10)
118 | z = torch.cat([
119 | frames * questions,
120 | frames * prevM,
121 | torch.abs(frames - questions),
122 | torch.abs(frames - prevM)
123 | ], dim=2)
124 |
125 | z = z.view(-1, 4 * embedding_size)
126 |
127 | G = torch.tanh(self.z1(z))
128 | G = self.z2(G)
129 | G = G.view(batch_num, -1)
130 | G = F.softmax(G,dim=1)
131 | #print('G size',G.size())
132 | return G
133 |
134 | def forward(self, frames, questions, prevM):
135 | '''
136 | frames.size() -> (#batch, #sentence, #hidden = #embedding)
137 |         questions.size() -> (#batch, #hidden)
138 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding)
139 | G.size() -> (#batch, #sentence)
140 | C.size() -> (#batch, #hidden)
141 | concat.size() -> (#batch, 3 x #embedding)
142 | '''
143 |
144 | '''
145 | section 3.3 - Attention based GRU
146 | input: F and q, as frames and questions
147 | then get gates g
148 | then (c,m,g) feed into memory update module Eq(13)
149 | output new memory state
150 | '''
151 | # print(frames.shape, questions.shape, prevM.shape)
152 |
153 | G = self.make_interaction(frames, questions, prevM)
154 | C = self.AGRU(frames, G)
155 | concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1)
156 | next_mem = F.relu(self.next_mem(concat))
157 | #print(next_mem.size())
158 | next_mem = next_mem.unsqueeze(1)
159 | return next_mem
160 |
--------------------------------------------------------------------------------
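
A minimal shape sketch for EpisodicMemory (not part of the repository; the hop count, feature sizes, and flat import path are assumptions). Per hop, make_interaction produces soft gates G over the frames, AttentionGRU summarises the frames into a context C, and next_mem fuses (prevM, C, question) into the new memory state.

import torch
from memory_module import EpisodicMemory  # hypothetical flat import, matching the repo's own import style

hidden = 256
episodic = EpisodicMemory(hidden)
frames = torch.randn(4, 16, hidden)   # (batch, T, hidden)
question = torch.randn(4, hidden)     # (batch, hidden); reshaped to (batch, 1, hidden) inside make_interaction
memory = question.unsqueeze(1)        # start the memory from the question, (batch, 1, hidden)
for _ in range(3):                    # a few memory hops
    memory = episodic(frames, question, memory)
print(memory.shape)                   # torch.Size([4, 1, 256])
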
/networks/memory_rand.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from Attention import MultiModalAttentionModule
5 |
6 | class MemoryRamModule(nn.Module):
7 |
8 | def __init__(self, input_size=1024, hidden_size=512, memory_bank_size=100, device=None):
9 | """Set the hyper-parameters and build the layers."""
10 | super(MemoryRamModule, self).__init__()
11 |
12 | self.input_size = input_size
13 | self.hidden_size = hidden_size
14 | self.memory_bank_size = memory_bank_size
15 | self.device = device
16 |
17 | self.hidden_to_content = nn.Linear(hidden_size+input_size, hidden_size)
18 | #self.read_to_hidden = nn.Linear(hidden_size+input_size, 1)
19 | self.write_gate = nn.Linear(hidden_size+input_size, 1)
20 | self.write_prob = nn.Linear(hidden_size+input_size, memory_bank_size)
21 |
22 | self.read_gate = nn.Linear(hidden_size+input_size, 1)
23 | self.read_prob = nn.Linear(hidden_size+input_size, memory_bank_size)
24 |
25 |
26 | self.Wxh = nn.Parameter(torch.FloatTensor(input_size, hidden_size),requires_grad=True)
27 | self.Wrh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True)
28 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True)
29 | self.bh = nn.Parameter(torch.FloatTensor(hidden_size),requires_grad=True)
30 |
31 | self.init_weights()
32 |
33 |
34 | def init_weights(self):
35 | self.Wxh.data.normal_(0.0, 0.1)
36 | self.Wrh.data.normal_(0.0, 0.1)
37 | self.Whh.data.normal_(0.0, 0.1)
38 | self.bh.data.fill_(0)
39 |
40 |
41 | def forward(self, hidden_frames, nImg):
42 |
43 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device)
44 | memory_ram.fill_(0)
45 |
46 | h_t = torch.zeros(1, self.hidden_size).to(self.device)
47 |
48 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device)
49 |
50 | for t in range(nImg):
51 | x_t = hidden_frames[t:t+1,:]
52 |
53 | x_h_t = torch.cat([x_t,h_t],dim=1)
54 |
55 | ############# read ############
56 | ar = torch.softmax(self.read_prob( x_h_t ),dim=1) # read prob from memories
57 | go = torch.sigmoid(self.read_gate( x_h_t )) # read gate
58 | r = go * torch.matmul(ar,memory_ram) # read vector
59 |
60 | ######### h_t #########
61 | # Eq (17)
62 | m1 = torch.matmul(x_t, self.Wxh)
63 | m2 = torch.matmul(r, self.Wrh)
64 | m3 = torch.matmul(h_t, self.Whh)
65 | h_t_p1 = F.relu(m1 + m2 + m3 + self.bh) # Eq(17)
66 |
67 |
68 | ############# write ############
69 | c_t = F.relu( self.hidden_to_content(x_h_t) ) # Eq(15), content vector
70 | aw = torch.softmax(self.write_prob( x_h_t ),dim=1) # write prob to memories
71 | aw = aw.view(self.memory_bank_size,1)
72 | gw = torch.sigmoid(self.write_gate( x_h_t )) # write gate
73 | #print gw.size(),aw.size(),c_t.size(),memory_ram.size()
74 | memory_ram = gw * aw * c_t + (1.0-aw) * memory_ram # Eq(16)
75 |
76 | h_t = h_t_p1
77 | hiddens[t,:] = h_t
78 |
79 | #return memory_ram
80 | return hiddens
81 |
82 |
83 | class MemoryRamTwoStreamModule(nn.Module):
84 |
85 | def __init__(self, input_size, hidden_size=512, memory_bank_size=100, device=None):
86 | """Set the hyper-parameters and build the layers."""
87 | super(MemoryRamTwoStreamModule, self).__init__()
88 |
89 | self.input_size = input_size
90 | self.hidden_size = hidden_size
91 | self.memory_bank_size = memory_bank_size
92 | self.device = device
93 |
94 | self.hidden_to_content_a = nn.Linear(hidden_size+input_size, hidden_size)
95 | self.hidden_to_content_m = nn.Linear(hidden_size+input_size, hidden_size)
96 |
97 | self.write_prob = nn.Linear(hidden_size*3, 3)
98 | self.write_prob_a = nn.Linear(hidden_size+input_size, memory_bank_size)
99 | self.write_prob_m = nn.Linear(hidden_size+input_size, memory_bank_size)
100 |
101 | self.read_prob = nn.Linear(hidden_size*3, memory_bank_size)
102 |
103 | self.read_to_hidden = nn.Linear(hidden_size*2, hidden_size)
104 | self.read_to_hidden_a = nn.Linear(hidden_size*2+input_size, hidden_size)
105 | self.read_to_hidden_m = nn.Linear(hidden_size*2+input_size, hidden_size)
106 | self.init_weights()
107 |
108 | def init_weights(self):
109 | pass
110 |
111 |
112 | def forward(self, hidden_out_a, hidden_out_m, nImg):
113 |
114 |
115 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device)
116 | memory_ram.fill_(0)
117 |
118 | h_t_a = torch.zeros(1, self.hidden_size).to(self.device)
119 | h_t_m = torch.zeros(1, self.hidden_size).to(self.device)
120 | h_t = torch.zeros(1, self.hidden_size).to(self.device)
121 |
122 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device)
123 |
124 | for t in range(nImg):
125 | x_t_a = hidden_out_a[t:t+1,:]
126 | x_t_m = hidden_out_m[t:t+1,:]
127 |
128 |
129 | ############# read ############
130 | x_h_t_am = torch.cat([h_t_a,h_t_m,h_t],dim=1)
131 | ar = torch.softmax(self.read_prob( x_h_t_am ),dim=1) # read prob from memories
132 | r = torch.matmul(ar,memory_ram) # read vector
133 |
134 |
135 | ######### h_t #########
136 | # Eq (17)
137 | f_0 = torch.cat([r, h_t],dim=1)
138 | f_a = torch.cat([x_t_a, r, h_t_a],dim=1)
139 | f_m = torch.cat([x_t_m, r, h_t_m],dim=1)
140 |
141 | h_t_1 = F.relu(self.read_to_hidden(f_0))
142 | h_t_a1 = F.relu(self.read_to_hidden_a(f_a))
143 | h_t_m1 = F.relu(self.read_to_hidden_m(f_m))
144 |
145 |
146 | ############# write ############
147 |
148 | # write probability of [keep, write appearance, write motion]
149 | aw = torch.softmax(self.write_prob( x_h_t_am ),dim=1) # write prob to memories
150 | x_h_ta = torch.cat([h_t_a,x_t_a],dim=1)
151 | x_h_tm = torch.cat([h_t_m,x_t_m],dim=1)
152 |
153 |
154 | # write content
155 | c_t_a = F.relu( self.hidden_to_content_a(x_h_ta) ) # Eq(15), content vector
156 | c_t_m = F.relu( self.hidden_to_content_m(x_h_tm) ) # Eq(15), content vector
157 |
158 | aw_a = torch.softmax(self.write_prob_a( x_h_ta ),dim=1) # write prob to memories
159 | aw_m = torch.softmax(self.write_prob_m( x_h_tm ),dim=1) # write prob to memories
160 |
161 |
162 | aw_a = aw_a.view(self.memory_bank_size,1)
163 | aw_m = aw_m.view(self.memory_bank_size,1)
164 |
165 | memory_ram = aw[0,0] * memory_ram + aw[0,1] * aw_a * c_t_a + aw[0,2] * aw_m * c_t_m
166 |
167 |
168 | h_t = h_t_1
169 | h_t_a = h_t_a1
170 | h_t_m = h_t_m1
171 |
172 | hiddens[t,:] = h_t
173 |
174 |
175 | return hiddens
176 |
177 | class MMModule(nn.Module):
178 | def __init__(self, dim, input_drop_p, device):
179 | """Set the hyper-parameters and build the layers."""
180 | super(MMModule, self).__init__()
181 | self.hidden_size = dim
182 | self.lstm_mm_1 = nn.LSTMCell(dim, dim)
183 | self.lstm_mm_2 = nn.LSTMCell(dim, dim)
184 | self.hidden_encoder_1 = nn.Linear(dim * 2, dim)
185 | self.hidden_encoder_2 = nn.Linear(dim * 2, dim)
186 | self.dropout = nn.Dropout(input_drop_p)
187 | self.mm_att = MultiModalAttentionModule(dim)
188 | self.device = device
189 | self.init_weights()
190 |
191 |
192 | def init_weights(self):
193 | nn.init.xavier_normal_(self.hidden_encoder_1.weight)
194 | nn.init.xavier_normal_(self.hidden_encoder_2.weight)
195 | self.init_hiddens()
196 |
197 | def init_hiddens(self):
198 | s_t = torch.zeros(1, self.hidden_size).to(self.device)
199 | s_t2 = torch.zeros(1, self.hidden_size).to(self.device)
200 | c_t = torch.zeros(1, self.hidden_size).to(self.device)
201 | c_t2 = torch.zeros(1, self.hidden_size).to(self.device)
202 | return s_t, s_t2, c_t, c_t2
203 |
204 | def forward(self, svt_tmp, memory_ram_vid, memory_ram_txt, loop=3):
205 | """
206 |
207 |         :param svt_tmp: fused video-question encoding, shape (1, 2 * dim)
208 |         :param memory_ram_vid: visual memory bank attended over by mm_att
209 |         :param memory_ram_txt: textual memory bank attended over by mm_att
210 |         :param loop: number of multimodal reasoning iterations
211 |         :return: concatenation of the two LSTM hidden states, shape (1, 2 * dim)
212 | """
213 |
214 | sm_q1, sm_q2, cm_q1, cm_q2 = self.init_hiddens()
215 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_1(svt_tmp)))
216 |
217 | for _ in range(loop):
218 | sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1))
219 | sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2))
220 |
221 | mm_o1 = self.mm_att(sm_q2, memory_ram_vid, memory_ram_txt)
222 | mm_o2 = torch.cat((sm_q2, mm_o1), dim=1)
223 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_2(mm_o2)))
224 |
225 | smq = torch.cat((sm_q1, sm_q2), dim=1)
226 |
227 | return smq
--------------------------------------------------------------------------------
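
A minimal shape sketch for MemoryRamModule (not part of the repository; sizes, device handling, and the flat import path are assumptions). The module walks over the frame features one step at a time, reading from and writing to a fixed-size memory bank, and returns the hidden state of every step.

import torch
from memory_rand import MemoryRamModule  # hypothetical flat import; also needs networks/Attention.py on the path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mem_ram = MemoryRamModule(input_size=1024, hidden_size=512, memory_bank_size=100, device=device).to(device)
hidden_frames = torch.randn(20, 1024, device=device)   # (num_frames, input_size) -- note: no batch dimension
hiddens = mem_ram(hidden_frames, nImg=20)
print(hiddens.shape)                                    # torch.Size([20, 512])
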
/networks/q_v_transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import torchnlp_nn as nlpnn
6 |
7 |
8 | def padding_mask(seq_q, seq_k):
9 | # seq_k of shape (batch, k_len) and seq_q (batch, q_len), not embedded. q and k are padded with 0.
10 | seq_q = torch.unsqueeze(seq_q, 2)
11 | seq_k = torch.unsqueeze(seq_k, 2)
12 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2))
13 | pad_mask = pad_mask.eq(0)
14 | return pad_mask
15 |
16 |
17 | def padding_mask_transformer(seq_q, seq_k):
18 | # original padding_mask in transformer, for masking out the padding part of key sequence.
19 | len_q = seq_q.size(1)
20 | # `PAD` is 0
21 | pad_mask = seq_k.eq(0)
22 | pad_mask = pad_mask.unsqueeze(1).expand(
23 | -1, len_q, -1) # shape [B, L_q, L_k]
24 | return pad_mask
25 |
26 |
27 | def padding_mask_embedded(seq_q, seq_k):
28 | # seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len)
29 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2))
30 | pad_mask = pad_mask.eq(0)
31 | return pad_mask
32 |
33 |
34 | def padding_mask_k(seq_q, seq_k):
35 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
36 | In batch 0:
37 | [[x x x 0] [[0 0 0 1]
38 | [x x x 0]-> [0 0 0 1]
39 | [x x x 0]] [0 0 0 1]] uint8
40 | """
41 | fake_q = torch.ones_like(seq_q)
42 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2))
43 | pad_mask = pad_mask.eq(0)
44 | # pad_mask = pad_mask.lt(1e-3)
45 | return pad_mask
46 |
47 |
48 | def padding_mask_q(seq_q, seq_k):
49 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len).
50 | In batch 0:
51 | [[x x x x] [[0 0 0 0]
52 | [x x x x] -> [0 0 0 0]
53 | [0 0 0 0]] [1 1 1 1]] uint8
54 | """
55 | fake_k = torch.ones_like(seq_k)
56 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2))
57 | pad_mask = pad_mask.eq(0)
58 | # pad_mask = pad_mask.lt(1e-3)
59 | return pad_mask
60 |
61 |
62 | class PositionalEncoding(nn.Module):
63 |
64 | def __init__(self, d_model, max_seq_len):
65 | super(PositionalEncoding, self).__init__()
66 | self.max_seq_len = max_seq_len
67 |
68 | position_encoding = np.array(
69 | [
70 | [
71 | pos / np.power(10000, 2.0 * (j // 2) / d_model)
72 | for j in range(d_model)
73 | ]
74 | for pos in range(max_seq_len)
75 | ])
76 | position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2])
77 | position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2])
78 |
79 | pad_row = torch.zeros([1, d_model])
80 | position_encoding = torch.cat(
81 | (pad_row, torch.from_numpy(position_encoding).float()))
82 |
83 | self.position_encoding = nn.Embedding(max_seq_len + 1, d_model)
84 | self.position_encoding.weight = nn.Parameter(
85 | position_encoding, requires_grad=False)
86 |
87 | def forward(self, input_len):
88 | # max_len = torch.max(input_len)
89 | max_len = self.max_seq_len
90 | tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor
91 | input_pos = [
92 | list(range(1, l + 1)) + [0] * (max_len - l.item())
93 | for l in input_len
94 | ]
95 | input_pos = tensor(input_pos)
96 | return self.position_encoding(input_pos)
97 |
98 |
99 | class PositionalWiseFeedForward(nn.Module):
100 |
101 | def __init__(self, model_dim=512, ffn_dim=512, dropout=0.0):
102 | super(PositionalWiseFeedForward, self).__init__()
103 | self.w1 = nn.Conv1d(model_dim, ffn_dim, 1)
104 |         self.w2 = nn.Conv1d(ffn_dim, model_dim, 1)  # project back to model_dim so the residual addition matches
105 | self.dropout = nn.Dropout(dropout)
106 | self.layer_norm = nn.LayerNorm(model_dim)
107 |
108 | def forward(self, x):
109 | # x of shape (bs, seq_len, hs)
110 | output = x.transpose(1, 2)
111 | output = self.w2(F.relu(self.w1(output)))
112 | output = self.dropout(output.transpose(1, 2))
113 |
114 | # add residual and norm layer
115 | output = self.layer_norm(x + output)
116 | return output
117 |
118 |
119 | class MaskedPositionalWiseFeedForward(nn.Module):
120 |
121 | def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0):
122 | super().__init__()
123 | self.w1 = nn.Linear(model_dim, ffn_dim, bias=False)
124 | self.w2 = nn.Linear(ffn_dim, model_dim, bias=False)
125 | self.dropout = nn.Dropout(dropout)
126 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
127 |
128 | def forward(self, x):
129 | # x of shape (bs, seq_len, hs)
130 | output = self.w2(F.relu(self.w1(x)))
131 | output = self.dropout(output)
132 |
133 | # add residual and norm layer
134 | output = self.layer_norm(x + output)
135 | return output
136 |
137 |
138 | class ScaledDotProductAttention(nn.Module):
139 | """Scaled dot-product attention mechanism."""
140 |
141 | def __init__(self, attention_dropout=0.0):
142 | super(ScaledDotProductAttention, self).__init__()
143 | self.dropout = nn.Dropout(attention_dropout)
144 | self.softmax = nn.Softmax(dim=-1)
145 |
146 | def forward(self, q, k, v, scale=None, attn_mask=None):
147 | """
148 | Args:
149 | q: [B, L_q, D_q]
150 | k: [B, L_k, D_k]
151 | v: [B, L_v, D_v]
152 | """
153 | attention = torch.matmul(q, k.transpose(1, 2))
154 | if scale is not None:
155 | attention = attention * scale
156 | if attn_mask is not None:
157 | attention = attention.masked_fill(attn_mask, -np.inf)
158 | attention = self.softmax(attention)
159 | attention = self.dropout(attention)
160 | output = torch.matmul(attention, v)
161 | return output, attention
162 |
163 |
164 | class MaskedScaledDotProductAttention(nn.Module):
165 | """Scaled dot-product attention mechanism."""
166 |
167 | def __init__(self, attention_dropout=0.0):
168 | super().__init__()
169 | self.dropout = nn.Dropout(attention_dropout)
170 | self.softmax = nn.Softmax(dim=-1)
171 |
172 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None):
173 | """
174 | Args:
175 | q: [B, L_q, D_q]
176 | k: [B, L_k, D_k]
177 | v: [B, L_v, D_v]
178 | """
179 | attention = torch.matmul(q, k.transpose(-2, -1))
180 | if scale is not None:
181 | attention = attention * scale
182 | if attn_mask is not None:
183 | attention = attention.masked_fill(attn_mask, -np.inf)
184 | attention = self.softmax(attention)
185 | attention = attention.masked_fill(softmax_mask, 0.)
186 | attention = self.dropout(attention)
187 | output = torch.matmul(attention, v)
188 | return output, attention
189 |
190 |
191 | class MultiHeadAttention(nn.Module):
192 |
193 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
194 | super(MultiHeadAttention, self).__init__()
195 |
196 | self.dim_per_head = model_dim // num_heads
197 | self.num_heads = num_heads
198 | self.linear_k = nn.Linear(
199 | model_dim, self.dim_per_head * num_heads, bias=False)
200 | self.linear_v = nn.Linear(
201 | model_dim, self.dim_per_head * num_heads, bias=False)
202 | self.linear_q = nn.Linear(
203 | model_dim, self.dim_per_head * num_heads, bias=False)
204 |
205 | self.dot_product_attention = ScaledDotProductAttention(dropout)
206 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False)
207 | self.dropout = nn.Dropout(dropout)
208 | self.layer_norm = nn.LayerNorm(model_dim)
209 |
210 | def forward(self, query, key, value, attn_mask=None):
211 | residual = query
212 |
213 | dim_per_head = self.dim_per_head
214 | num_heads = self.num_heads
215 | batch_size = key.size(0)
216 |
217 | # linear projection
218 | key = self.linear_k(key)
219 | value = self.linear_v(value)
220 | query = self.linear_q(query)
221 |
222 | # split by heads
223 | key = key.view(batch_size * num_heads, -1, dim_per_head)
224 | value = value.view(batch_size * num_heads, -1, dim_per_head)
225 | query = query.view(batch_size * num_heads, -1, dim_per_head)
226 |
227 | if attn_mask is not None:
228 | attn_mask = attn_mask.repeat(num_heads, 1, 1)
229 |
230 | # scaled dot product attention
231 | scale = (key.size(-1) // num_heads)**-0.5
232 | context, attention = self.dot_product_attention(
233 | query, key, value, scale, attn_mask)
234 |
235 | # concat heads
236 | context = context.view(batch_size, -1, dim_per_head * num_heads)
237 |
238 | # final linear projection
239 | output = self.linear_final(context)
240 |
241 | # dropout
242 | output = self.dropout(output)
243 |
244 | # add residual and norm layer
245 | output = self.layer_norm(residual + output)
246 |
247 | return output, attention
248 |
249 |
250 | class MaskedMultiHeadAttention(nn.Module):
251 |
252 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0):
253 | super().__init__()
254 |
255 | self.dim_per_head = model_dim // num_heads
256 | self.num_heads = num_heads
257 | self.linear_k = nn.Linear(
258 | model_dim, self.dim_per_head * num_heads, bias=False)
259 | self.linear_v = nn.Linear(
260 | model_dim, self.dim_per_head * num_heads, bias=False)
261 | self.linear_q = nn.Linear(
262 | model_dim, self.dim_per_head * num_heads, bias=False)
263 |
264 | self.dot_product_attention = MaskedScaledDotProductAttention(dropout)
265 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False)
266 | self.dropout = nn.Dropout(dropout)
267 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False)
268 |
269 | def forward(self, query, key, value, attn_mask=None, softmax_mask=None):
270 | residual = query
271 |
272 | dim_per_head = self.dim_per_head
273 | num_heads = self.num_heads
274 | batch_size = key.size(0)
275 |
276 | # linear projection
277 | key = self.linear_k(key)
278 | value = self.linear_v(value)
279 | query = self.linear_q(query)
280 |
281 | # split by heads
282 | key = key.view(batch_size, -1, num_heads, dim_per_head).transpose(1, 2)
283 | value = value.view(batch_size, -1, num_heads,
284 | dim_per_head).transpose(1, 2)
285 | query = query.view(batch_size, -1, num_heads,
286 | dim_per_head).transpose(1, 2)
287 |
288 | if attn_mask is not None:
289 | attn_mask = attn_mask.unsqueeze(1).repeat(1, num_heads, 1, 1)
290 | if softmax_mask is not None:
291 | softmax_mask = softmax_mask.unsqueeze(1).repeat(1, num_heads, 1, 1)
292 | # scaled dot product attention
293 |         # key.size(-1) equals dim_per_head after the head split
294 | scale = key.size(-1)**-0.5
295 | context, attention = self.dot_product_attention(
296 | query, key, value, scale, attn_mask, softmax_mask)
297 |
298 | # concat heads
299 | context = context.transpose(1, 2).contiguous().view(
300 | batch_size, -1, dim_per_head * num_heads)
301 |
302 | # final linear projection
303 | output = self.linear_final(context)
304 |
305 | # dropout
306 | output = self.dropout(output)
307 |
308 | # add residual and norm layer
309 | output = self.layer_norm(residual + output)
310 |
311 | return output, attention
312 |
313 |
314 | class SelfTransformerLayer(nn.Module):
315 |
316 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
317 | super().__init__()
318 |
319 | self.transformer = MaskedMultiHeadAttention(
320 | model_dim, num_heads, dropout)
321 | self.feed_forward = MaskedPositionalWiseFeedForward(
322 | model_dim, ffn_dim, dropout)
323 |
324 | def forward(self, input, attn_mask=None, sf_mask=None):
325 | output, attention = self.transformer(
326 | input, input, input, attn_mask, sf_mask)
327 | # feed forward network
328 | output = self.feed_forward(output)
329 |
330 | return output, attention
331 |
332 |
333 | class SelfTransformer(nn.Module):
334 |
335 | def __init__(
336 | self,
337 | max_len=35,
338 | num_layers=2,
339 | model_dim=512,
340 | num_heads=8,
341 | ffn_dim=2048,
342 | dropout=0.0,
343 | position=False):
344 | super().__init__()
345 |
346 | self.position = position
347 |
348 | self.encoder_layers = nn.ModuleList(
349 | [
350 | SelfTransformerLayer(model_dim, num_heads, ffn_dim, dropout)
351 | for _ in range(num_layers)
352 | ])
353 |
354 | # max_seq_len is 35 or 80
355 | self.pos_embedding = PositionalEncoding(model_dim, max_len)
356 |
357 | def forward(self, input, input_length):
358 | # q_length of shape (batch, ), each item is the length of the seq
359 | if self.position:
360 | input += self.pos_embedding(input_length)[:, :input.size()[1], :]
361 |
362 | attention_mask = padding_mask_k(input, input)
363 | softmax_mask = padding_mask_q(input, input)
364 |
365 | attentions = []
366 | for encoder in self.encoder_layers:
367 | input, attention = encoder(input, attention_mask, softmax_mask)
368 | attentions.append(attention)
369 |
370 | return input, attentions
371 |
372 |
373 | class SelfAttentionLayer(nn.Module):
374 |
375 | def __init__(self, hidden_size, dropout_p=0.0):
376 | super().__init__()
377 | self.dropout = nn.Dropout(dropout_p)
378 | self.softmax = nn.Softmax(dim=-1)
379 |
380 | self.linear_k = nlpnn.WeightDropLinear(
381 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
382 | self.linear_q = nlpnn.WeightDropLinear(
383 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
384 | self.linear_v = nlpnn.WeightDropLinear(
385 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
386 |
387 | self.linear_final = nlpnn.WeightDropLinear(
388 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
389 |
390 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
391 |
392 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None):
393 | """
394 | Args:
395 | q: [B, L_q, D_q]
396 | k: [B, L_k, D_k]
397 | v: [B, L_v, D_v]
398 | """
399 | residual = q
400 |
401 | if attn_mask is None or softmax_mask is None:
402 | attn_mask = padding_mask_k(q, k)
403 | softmax_mask = padding_mask_q(q, k)
404 |
405 | # linear projection
406 | k = self.linear_k(k)
407 | v = self.linear_v(v)
408 | q = self.linear_q(q)
409 |
410 | scale = k.size(-1)**-0.5
411 |
412 | attention = torch.bmm(q, k.transpose(1, 2))
413 | if scale is not None:
414 | attention = attention * scale
415 | if attn_mask is not None:
416 | attention = attention.masked_fill(attn_mask, -np.inf)
417 | attention = self.softmax(attention)
418 | attention = attention.masked_fill(softmax_mask, 0.)
419 |
420 | # attention = self.dropout(attention)
421 | output = torch.bmm(attention, v)
422 | output = self.linear_final(output)
423 | output = self.layer_norm(output + residual)
424 | return output, attention
425 |
426 |
427 | class SelfAttention(nn.Module):
428 |
429 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
430 | super().__init__()
431 |
432 | self.encoder_layers = nn.ModuleList(
433 | [
434 | SelfAttentionLayer(hidden_size, dropout_p)
435 | for _ in range(n_layers)
436 | ])
437 |
438 | def forward(self, input):
439 |
440 | # q_attention_mask of shape (bs, q_len, v_len)
441 | attn_mask = padding_mask_k(input, input)
442 | # v_attention_mask of shape (bs, v_len, q_len)
443 | softmax_mask = padding_mask_q(input, input)
444 |
445 | attentions = []
446 | for encoder in self.encoder_layers:
447 | input, attention = encoder(
448 | input,
449 | input,
450 | input,
451 | attn_mask=attn_mask,
452 | softmax_mask=softmax_mask)
453 | attentions.append(attention)
454 |
455 | return input, attentions
456 |
457 |
458 | class CoAttentionLayer(nn.Module):
459 |
460 | def __init__(self, hidden_size, dropout_p=0.0):
461 | super().__init__()
462 | self.dropout = nn.Dropout(dropout_p)
463 | self.softmax = nn.Softmax(dim=-1)
464 |
465 | self.linear_question = nlpnn.WeightDropLinear(
466 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
467 | self.linear_video = nlpnn.WeightDropLinear(
468 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
469 | self.linear_v_question = nlpnn.WeightDropLinear(
470 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
471 | self.linear_v_video = nlpnn.WeightDropLinear(
472 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
473 |
474 | self.linear_final_qv = nlpnn.WeightDropLinear(
475 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
476 | self.linear_final_vq = nlpnn.WeightDropLinear(
477 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
478 |
479 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False)
480 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False)
481 |
482 | def forward(
483 | self,
484 | question,
485 | video,
486 | scale=None,
487 | attn_mask=None,
488 | softmax_mask=None,
489 | attn_mask_=None,
490 | softmax_mask_=None):
491 | """
492 | Args:
493 | q: [B, L_q, D_q]
494 | k: [B, L_k, D_k]
495 | v: [B, L_v, D_v]
496 | """
497 | q = question
498 | v = video
499 |
500 | if attn_mask is None or softmax_mask is None:
501 | attn_mask = padding_mask_k(question, video)
502 | softmax_mask = padding_mask_q(question, video)
503 | if attn_mask_ is None or softmax_mask_ is None:
504 | attn_mask_ = padding_mask_k(video, question)
505 | softmax_mask_ = padding_mask_q(video, question)
506 |
507 | # linear projection
508 | question_q = self.linear_question(question)
509 | video_k = self.linear_video(video)
510 | question = self.linear_v_question(question)
511 | video = self.linear_v_video(video)
512 |
513 | scale = video.size(-1)**-0.5
514 |
515 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2))
516 | if scale is not None:
517 | attention_qv = attention_qv * scale
518 | if attn_mask is not None:
519 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf)
520 | attention_qv = self.softmax(attention_qv)
521 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.)
522 |
523 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2))
524 | if scale is not None:
525 | attention_vq = attention_vq * scale
526 | if attn_mask_ is not None:
527 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf)
528 | attention_vq = self.softmax(attention_vq)
529 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.)
530 |
531 | # attention = self.dropout(attention)
532 | output_qv = torch.bmm(attention_qv, video)
533 | output_qv = self.linear_final_qv(output_qv)
534 | output_q = self.layer_norm_qv(output_qv + q)
535 |
536 | output_vq = torch.bmm(attention_vq, question)
537 | output_vq = self.linear_final_vq(output_vq)
538 | output_v = self.layer_norm_vq(output_vq + v)
539 | return output_q, output_v
540 |
541 |
542 | class CoAttention(nn.Module):
543 |
544 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
545 | super().__init__()
546 |
547 | self.encoder_layers = nn.ModuleList(
548 | [CoAttentionLayer(hidden_size, dropout_p) for _ in range(n_layers)])
549 |
550 | def forward(self, question, video):
551 | attn_mask = padding_mask_k(question, video)
552 | softmax_mask = padding_mask_q(question, video)
553 | attn_mask_ = padding_mask_k(video, question)
554 | softmax_mask_ = padding_mask_q(video, question)
555 |
556 | for encoder in self.encoder_layers:
557 | question, video = encoder(
558 | question,
559 | video,
560 | attn_mask=attn_mask,
561 | softmax_mask=softmax_mask,
562 | attn_mask_=attn_mask_,
563 | softmax_mask_=softmax_mask_)
564 |
565 | return question, video
566 |
567 |
568 | class CoConcatAttentionLayer(nn.Module):
569 |
570 | def __init__(self, hidden_size, dropout_p=0.0):
571 | super().__init__()
572 | self.dropout = nn.Dropout(dropout_p)
573 | self.softmax = nn.Softmax(dim=-1)
574 |
575 | self.linear_question = nlpnn.WeightDropLinear(
576 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
577 | self.linear_video = nlpnn.WeightDropLinear(
578 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
579 | self.linear_v_question = nlpnn.WeightDropLinear(
580 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
581 | self.linear_v_video = nlpnn.WeightDropLinear(
582 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
583 |
584 | self.linear_final_qv = nn.Sequential(
585 | nlpnn.WeightDropLinear(
586 | 2 * hidden_size,
587 | hidden_size,
588 | weight_dropout=dropout_p,
589 | bias=False), nn.ReLU(),
590 | nlpnn.WeightDropLinear(
591 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False))
592 | self.linear_final_vq = nn.Sequential(
593 | nlpnn.WeightDropLinear(
594 | 2 * hidden_size,
595 | hidden_size,
596 | weight_dropout=dropout_p,
597 | bias=False), nn.ReLU(),
598 | nlpnn.WeightDropLinear(
599 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False))
600 |
601 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False)
602 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False)
603 |
604 | def forward(
605 | self,
606 | question,
607 | video,
608 | scale=None,
609 | attn_mask=None,
610 | softmax_mask=None,
611 | attn_mask_=None,
612 | softmax_mask_=None):
613 | """
614 | Args:
615 | q: [B, L_q, D_q]
616 | k: [B, L_k, D_k]
617 | v: [B, L_v, D_v]
618 | """
619 | q = question
620 | v = video
621 |
622 | if attn_mask is None or softmax_mask is None:
623 | attn_mask = padding_mask_k(question, video)
624 | softmax_mask = padding_mask_q(question, video)
625 | if attn_mask_ is None or softmax_mask_ is None:
626 | attn_mask_ = padding_mask_k(video, question)
627 | softmax_mask_ = padding_mask_q(video, question)
628 |
629 | # linear projection
630 | question_q = self.linear_question(question)
631 | video_k = self.linear_video(video)
632 | question = self.linear_v_question(question)
633 | video = self.linear_v_video(video)
634 |
635 | scale = video.size(-1)**-0.5
636 |
637 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2))
638 | if scale is not None:
639 | attention_qv = attention_qv * scale
640 | if attn_mask is not None:
641 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf)
642 | attention_qv = self.softmax(attention_qv)
643 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.)
644 |
645 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2))
646 | if scale is not None:
647 | attention_vq = attention_vq * scale
648 | if attn_mask_ is not None:
649 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf)
650 | attention_vq = self.softmax(attention_vq)
651 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.)
652 |
653 | # attention = self.dropout(attention)
654 | output_qv = torch.bmm(attention_qv, video)
655 | output_qv = self.linear_final_qv(torch.cat((output_qv, q), dim=-1))
656 | # output_q = self.layer_norm_qv(output_qv + q)
657 | output_q = self.layer_norm_qv(output_qv)
658 |
659 | output_vq = torch.bmm(attention_vq, question)
660 | output_vq = self.linear_final_vq(torch.cat((output_vq, v), dim=-1))
661 | # output_v = self.layer_norm_vq(output_vq + v)
662 | output_v = self.layer_norm_vq(output_vq)
663 | return output_q, output_v
664 |
665 |
666 | class CoConcatAttention(nn.Module):
667 |
668 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
669 | super().__init__()
670 |
671 | self.encoder_layers = nn.ModuleList(
672 | [
673 | CoConcatAttentionLayer(hidden_size, dropout_p)
674 | for _ in range(n_layers)
675 | ])
676 |
677 | def forward(self, question, video):
678 | attn_mask = padding_mask_k(question, video)
679 | softmax_mask = padding_mask_q(question, video)
680 | attn_mask_ = padding_mask_k(video, question)
681 | softmax_mask_ = padding_mask_q(video, question)
682 |
683 | for encoder in self.encoder_layers:
684 | question, video = encoder(
685 | question,
686 | video,
687 | attn_mask=attn_mask,
688 | softmax_mask=softmax_mask,
689 | attn_mask_=attn_mask_,
690 | softmax_mask_=softmax_mask_)
691 |
692 | return question, video
693 |
694 |
695 | class CoSiameseAttentionLayer(nn.Module):
696 |
697 | def __init__(self, hidden_size, dropout_p=0.0):
698 | super().__init__()
699 | self.dropout = nn.Dropout(dropout_p)
700 | self.softmax = nn.Softmax(dim=-1)
701 |
702 | self.linear_question = nlpnn.WeightDropLinear(
703 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
704 | self.linear_video = nlpnn.WeightDropLinear(
705 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
706 | self.linear_v_question = nlpnn.WeightDropLinear(
707 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
708 | self.linear_v_video = nlpnn.WeightDropLinear(
709 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
710 |
711 | self.linear_final = nn.Sequential(
712 | nlpnn.WeightDropLinear(
713 | 2 * hidden_size,
714 | hidden_size,
715 | weight_dropout=dropout_p,
716 | bias=False), nn.ReLU(),
717 | nlpnn.WeightDropLinear(
718 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False))
719 |
720 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False)
721 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False)
722 |
723 | def forward(
724 | self,
725 | question,
726 | video,
727 | scale=None,
728 | attn_mask=None,
729 | softmax_mask=None,
730 | attn_mask_=None,
731 | softmax_mask_=None):
732 | """
733 | Args:
734 | q: [B, L_q, D_q]
735 | k: [B, L_k, D_k]
736 | v: [B, L_v, D_v]
737 | """
738 | q = question
739 | v = video
740 |
741 | if attn_mask is None or softmax_mask is None:
742 | attn_mask = padding_mask_k(question, video)
743 | softmax_mask = padding_mask_q(question, video)
744 | if attn_mask_ is None or softmax_mask_ is None:
745 | attn_mask_ = padding_mask_k(video, question)
746 | softmax_mask_ = padding_mask_q(video, question)
747 |
748 | # linear projection
749 | question_q = self.linear_question(question)
750 | video_k = self.linear_video(video)
751 | question = self.linear_v_question(question)
752 | video = self.linear_v_video(video)
753 |
754 | scale = video.size(-1)**-0.5
755 |
756 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2))
757 | if scale is not None:
758 | attention_qv = attention_qv * scale
759 | if attn_mask is not None:
760 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf)
761 | attention_qv = self.softmax(attention_qv)
762 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.)
763 |
764 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2))
765 | if scale is not None:
766 | attention_vq = attention_vq * scale
767 | if attn_mask_ is not None:
768 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf)
769 | attention_vq = self.softmax(attention_vq)
770 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.)
771 |
772 | # attention = self.dropout(attention)
773 | output_qv = torch.bmm(attention_qv, video)
774 | output_qv = self.linear_final(torch.cat((output_qv, q), dim=-1))
775 | # output_q = self.layer_norm_qv(output_qv + q)
776 | output_q = self.layer_norm_qv(output_qv)
777 |
778 | output_vq = torch.bmm(attention_vq, question)
779 | output_vq = self.linear_final(torch.cat((output_vq, v), dim=-1))
780 | # output_v = self.layer_norm_vq(output_vq + v)
781 | output_v = self.layer_norm_vq(output_vq)
782 | return output_q, output_v
783 |
784 |
785 | class CoSiameseAttention(nn.Module):
786 |
787 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
788 | super().__init__()
789 |
790 | self.encoder_layers = nn.ModuleList(
791 | [
792 | CoSiameseAttentionLayer(hidden_size, dropout_p)
793 | for _ in range(n_layers)
794 | ])
795 |
796 | def forward(self, question, video):
797 | attn_mask = padding_mask_k(question, video)
798 | softmax_mask = padding_mask_q(question, video)
799 | attn_mask_ = padding_mask_k(video, question)
800 | softmax_mask_ = padding_mask_q(video, question)
801 |
802 | for encoder in self.encoder_layers:
803 | question, video = encoder(
804 | question,
805 | video,
806 | attn_mask=attn_mask,
807 | softmax_mask=softmax_mask,
808 | attn_mask_=attn_mask_,
809 | softmax_mask_=softmax_mask_)
810 |
811 | return question, video
812 |
813 |
814 | class SingleAttentionLayer(nn.Module):
815 |
816 | def __init__(self, hidden_size, dropout_p=0.0):
817 | super().__init__()
818 | self.dropout = nn.Dropout(dropout_p)
819 | self.softmax = nn.Softmax(dim=-1)
820 |
821 | self.linear_q = nlpnn.WeightDropLinear(
822 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
823 | self.linear_v = nlpnn.WeightDropLinear(
824 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
825 | self.linear_k = nlpnn.WeightDropLinear(
826 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
827 |
828 | self.linear_final = nlpnn.WeightDropLinear(
829 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)
830 |
831 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
832 |
833 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None):
834 | """
835 | Args:
836 | q: [B, L_q, D_q]
837 | k: [B, L_k, D_k]
838 | v: [B, L_v, D_v]
839 |             Return: same shape as q, but expressed in the value ('v') feature space (a soft k-NN readout)
840 | """
841 |
842 | if attn_mask is None or softmax_mask is None:
843 | attn_mask = padding_mask_k(q, k)
844 | softmax_mask = padding_mask_q(q, k)
845 |
846 | # linear projection
847 | q = self.linear_q(q)
848 | k = self.linear_k(k)
849 | v = self.linear_v(v)
850 |
851 | scale = v.size(-1)**-0.5
852 |
853 | attention = torch.bmm(q, k.transpose(-2, -1))
854 | if scale is not None:
855 | attention = attention * scale
856 | if attn_mask is not None:
857 | attention = attention.masked_fill(attn_mask, -np.inf)
858 | attention = self.softmax(attention)
859 | attention = attention.masked_fill(softmax_mask, 0.)
860 |
861 | # attention = self.dropout(attention)
862 | output = torch.bmm(attention, v)
863 | output = self.linear_final(output)
864 | output = self.layer_norm(output + q)
865 |
866 | return output
867 |
868 |
869 | class SingleAttention(nn.Module):
870 |
871 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0):
872 | super().__init__()
873 |
874 | self.encoder_layers = nn.ModuleList(
875 | [
876 | SingleAttentionLayer(hidden_size, dropout_p)
877 | for _ in range(n_layers)
878 | ])
879 |
880 | def forward(self, q, v):
881 | attn_mask = padding_mask_k(q, v)
882 | softmax_mask = padding_mask_q(q, v)
883 |
884 | for encoder in self.encoder_layers:
885 | q = encoder(q, v, v, attn_mask=attn_mask, softmax_mask=softmax_mask)
886 |
887 | return q
888 |
889 |
890 | class SoftKNN(nn.Module):
891 |
892 | def __init__(self, model_dim=512, num_heads=1, dropout=0.0):
893 | super().__init__()
894 |
895 | self.dim_per_head = model_dim // num_heads
896 | self.num_heads = num_heads
897 | self.linear_k = nn.Linear(
898 | model_dim, self.dim_per_head * num_heads, bias=False)
899 | self.linear_v = nn.Linear(
900 | model_dim, self.dim_per_head * num_heads, bias=False)
901 | self.linear_q = nn.Linear(
902 | model_dim, self.dim_per_head * num_heads, bias=False)
903 |
904 | self.dot_product_attention = ScaledDotProductAttention(dropout)
905 |
906 | def forward(self, query, key, value, attn_mask=None):
907 |
908 | dim_per_head = self.dim_per_head
909 | num_heads = self.num_heads
910 | batch_size = key.size(0)
911 |
912 | # linear projection
913 | key = self.linear_k(key)
914 | value = self.linear_v(value)
915 | query = self.linear_q(query)
916 |
917 | # split by heads
918 | key = key.view(batch_size * num_heads, -1, dim_per_head)
919 | value = value.view(batch_size * num_heads, -1, dim_per_head)
920 | query = query.view(batch_size * num_heads, -1, dim_per_head)
921 |
922 | if attn_mask is not None:
923 | attn_mask = attn_mask.repeat(num_heads, 1, 1)
924 | # scaled dot product attention
925 | scale = (key.size(-1) // num_heads)**-0.5
926 | context, attention = self.dot_product_attention(
927 | query, key, value, scale, attn_mask)
928 |
929 | # concat heads
930 | output = context.view(batch_size, -1, dim_per_head * num_heads)
931 |
932 | return output, attention
933 |
934 |
935 | class CrossoverTransformerLayer(nn.Module):
936 |
937 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
938 | super().__init__()
939 |
940 | self.v_transformer = MultiHeadAttention(model_dim, num_heads, dropout)
941 | self.q_transformer = MultiHeadAttention(model_dim, num_heads, dropout)
942 | self.v_feed_forward = PositionalWiseFeedForward(
943 | model_dim, ffn_dim, dropout)
944 | self.q_feed_forward = PositionalWiseFeedForward(
945 | model_dim, ffn_dim, dropout)
946 |
947 | def forward(self, question, video, q_mask=None, v_mask=None):
948 | # self attention, v_attention of shape (bs, v_len, q_len)
949 | video_, v_attention = self.v_transformer(
950 | video, question, question, v_mask)
951 | # feed forward network
952 | video_ = self.v_feed_forward(video_)
953 |
954 | # self attention, q_attention of shape (bs, q_len, v_len)
955 | question_, q_attention = self.q_transformer(
956 | question, video, video, q_mask)
957 | # feed forward network
958 | question_ = self.q_feed_forward(question_)
959 |
960 | return video_, question_, v_attention, q_attention
961 |
962 |
963 | class CrossoverTransformer(nn.Module):
964 |
965 | def __init__(
966 | self,
967 | q_max_len=35,
968 | v_max_len=80,
969 | num_layers=2,
970 | model_dim=512,
971 | num_heads=8,
972 | ffn_dim=2048,
973 | dropout=0.0):
974 | super().__init__()
975 |
976 | self.encoder_layers = nn.ModuleList(
977 | [
978 | CrossoverTransformerLayer(
979 | model_dim, num_heads, ffn_dim, dropout)
980 | for _ in range(num_layers)
981 | ])
982 |
983 | # max_seq_len is 35 or 80
984 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len)
985 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len)
986 |
987 | def forward(self, question, video, q_length, v_length):
988 | # q_length of shape (batch, ), each item is the length of the seq
989 | question += self.q_pos_embedding(q_length)[:, :question.size()[1], :]
990 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :]
991 |
992 | # q_attention_mask of shape (bs, q_len, v_len)
993 | q_attention_mask = padding_mask_k(question, video)
994 | # v_attention_mask of shape (bs, v_len, q_len)
995 | v_attention_mask = padding_mask_k(video, question)
996 |
997 | q_attentions = []
998 | v_attentions = []
999 | for encoder in self.encoder_layers:
1000 | video, question, v_attention, q_attention = encoder(
1001 | question, video, q_attention_mask, v_attention_mask)
1002 | q_attentions.append(q_attention)
1003 | v_attentions.append(v_attention)
1004 |
1005 | return question, video, q_attentions, v_attentions
1006 |
1007 |
1008 | class MaskedCrossoverTransformerLayer(nn.Module):
1009 |
1010 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0):
1011 | super().__init__()
1012 |
1013 | self.v_transformer = MaskedMultiHeadAttention(
1014 | model_dim, num_heads, dropout)
1015 | self.q_transformer = MaskedMultiHeadAttention(
1016 | model_dim, num_heads, dropout)
1017 | self.v_feed_forward = MaskedPositionalWiseFeedForward(
1018 | model_dim, ffn_dim, dropout)
1019 | self.q_feed_forward = MaskedPositionalWiseFeedForward(
1020 | model_dim, ffn_dim, dropout)
1021 |
1022 | def forward(
1023 | self,
1024 | question,
1025 | video,
1026 | q_mask=None,
1027 | v_mask=None,
1028 | q_sf_mask=None,
1029 | v_sf_mask=None):
1030 | # self attention, v_attention of shape (bs, v_len, q_len)
1031 | video_, v_attention = self.v_transformer(
1032 | video, question, question, v_mask, v_sf_mask)
1033 | # feed forward network
1034 | video_ = self.v_feed_forward(video_)
1035 |
1036 | # self attention, q_attention of shape (bs, q_len, v_len)
1037 | question_, q_attention = self.q_transformer(
1038 | question, video, video, q_mask, q_sf_mask)
1039 | # feed forward network
1040 | question_ = self.q_feed_forward(question_)
1041 |
1042 | return video_, question_, v_attention, q_attention
1043 |
1044 |
1045 | class MaskedCrossoverTransformer(nn.Module):
1046 |
1047 | def __init__(
1048 | self,
1049 | q_max_len=35,
1050 | v_max_len=80,
1051 | num_layers=2,
1052 | model_dim=512,
1053 | num_heads=8,
1054 | ffn_dim=2048,
1055 | dropout=0.0,
1056 | position=False):
1057 | super().__init__()
1058 |
1059 | self.position = position
1060 |
1061 | self.encoder_layers = nn.ModuleList(
1062 | [
1063 | MaskedCrossoverTransformerLayer(
1064 | model_dim, num_heads, ffn_dim, dropout)
1065 | for _ in range(num_layers)
1066 | ])
1067 |
1068 | # max_seq_len is 35 or 80
1069 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len)
1070 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len)
1071 |
1072 | def forward(self, question, video, q_length, v_length):
1073 | # q_length of shape (batch, ), each item is the length of the seq
1074 | if self.position:
1075 | question += self.q_pos_embedding(
1076 | q_length)[:, :question.size()[1], :]
1077 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :]
1078 |
1079 | q_attention_mask = padding_mask_k(question, video)
1080 | q_softmax_mask = padding_mask_q(question, video)
1081 | v_attention_mask = padding_mask_k(video, question)
1082 | v_softmax_mask = padding_mask_q(video, question)
1083 |
1084 | q_attentions = []
1085 | v_attentions = []
1086 | for encoder in self.encoder_layers:
1087 | video, question, v_attention, q_attention = encoder(
1088 | question, video, q_attention_mask, v_attention_mask,
1089 | q_softmax_mask, v_softmax_mask)
1090 | q_attentions.append(q_attention)
1091 | v_attentions.append(v_attention)
1092 |
1093 | return question, video, q_attentions, v_attentions
1094 |
1095 |
1096 | class SelfTransformerEncoder(nn.Module):
1097 |
1098 | def __init__(
1099 | self,
1100 | hidden_size,
1101 | n_layers,
1102 | dropout_p,
1103 | vocab_size,
1104 | q_max_len,
1105 | v_max_len,
1106 | embedding=None,
1107 | update_embedding=True,
1108 | position=True):
1109 | super().__init__()
1110 | self.dropout = nn.Dropout(p=dropout_p)
1111 | self.ln_q = nn.LayerNorm(hidden_size, elementwise_affine=False)
1112 | self.ln_v = nn.LayerNorm(hidden_size, elementwise_affine=False)
1113 | self.n_layers = n_layers
1114 | self.position = position
1115 |
1116 | embedding_dim = embedding.shape[
1117 | 1] if embedding is not None else hidden_size
1118 | self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
1119 |
1120 | # ! no embedding init
1121 | # if embedding is not None:
1122 | # # self.embedding.weight.data.copy_(torch.from_numpy(embedding))
1123 | # self.embedding.weight = nn.Parameter(
1124 | # torch.from_numpy(embedding).float())
1125 | self.upcompress_embedding = nlpnn.WeightDropLinear(
1126 | embedding_dim, hidden_size, weight_dropout=dropout_p, bias=False)
1127 | self.embedding.weight.requires_grad = update_embedding
1128 |
1129 | self.project_c3d = nlpnn.WeightDropLinear(4096, 2048, bias=False)
1130 |
1131 | self.project_resnet_and_c3d = nlpnn.WeightDropLinear(
1132 | 4096, hidden_size, weight_dropout=dropout_p, bias=False)
1133 |
1134 | # max_seq_len is 35 or 80
1135 | self.q_pos_embedding = PositionalEncoding(hidden_size, q_max_len)
1136 | self.v_pos_embedding = PositionalEncoding(hidden_size, v_max_len)
1137 |
1138 | def forward(self, question, resnet, c3d, q_length, v_length):
1139 | ### question
1140 | embedded = self.embedding(question)
1141 | embedded = self.dropout(embedded)
1142 | question = F.relu(self.upcompress_embedding(embedded))
1143 |
1144 | ### video
1145 | # ! no relu
1146 | c3d = self.project_c3d(c3d)
1147 | video = F.relu(
1148 | self.project_resnet_and_c3d(torch.cat((resnet, c3d), dim=2)))
1149 |
1150 | ### position encoding
1151 | if self.position:
1152 | question += self.q_pos_embedding(
1153 | q_length)[:, :question.size()[1], :]
1154 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :]
1155 |
1156 | # question = self.ln_q(question)
1157 | # video = self.ln_v(video)
1158 | return question, video
1159 |
--------------------------------------------------------------------------------
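
A minimal shape sketch for the question-video CoAttention stack defined above (not part of the repository; the sizes and the flat import path are assumptions). Note that the padding masks are derived from the embeddings themselves, so padded positions must be all-zero vectors.

import torch
from q_v_transformer import CoAttention  # hypothetical flat import, matching the repo's own import style

co_att = CoAttention(hidden_size=512, n_layers=1, dropout_p=0.0)
question = torch.randn(2, 35, 512)   # (batch, q_len, hidden); zero rows would be treated as padding
video = torch.randn(2, 80, 512)      # (batch, v_len, hidden)
q_out, v_out = co_att(question, video)
print(q_out.shape, v_out.shape)      # torch.Size([2, 35, 512]) torch.Size([2, 80, 512])

CoConcatAttention and CoSiameseAttention expose the same interface; they differ only in how the attended features are fused with the inputs (residual addition versus concatenation, and whether the two directions share the fusion weights).
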
/networks/torchnlp_nn.py:
--------------------------------------------------------------------------------
1 | from torch.nn import Parameter
2 |
3 | import torch
4 |
5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6 |
7 |
8 | class weight_drop():
9 |
10 | def __init__(self, module, weights, dropout):
11 | for name_w in weights:
12 | w = getattr(module, name_w)
13 | del module._parameters[name_w]
14 | module.register_parameter(name_w + '_raw', Parameter(w))
15 |
16 | self.original_module_forward = module.forward
17 |
18 | self.weights = weights
19 | self.module = module
20 | self.dropout = dropout
21 |
22 | def __call__(self, *args, **kwargs):
23 | for name_w in self.weights:
24 | raw_w = getattr(self.module, name_w + '_raw')
25 | w = torch.nn.functional.dropout(
26 | raw_w, p=self.dropout, training=self.module.training)
27 | # module.register_parameter(name_w, Parameter(w))
28 | setattr(self.module, name_w, Parameter(w))
29 |
30 | return self.original_module_forward(*args, **kwargs)
31 |
32 |
33 | def _weight_drop(module, weights, dropout):
34 | setattr(module, 'forward', weight_drop(module, weights, dropout))
35 |
36 |
37 | # def _weight_drop(module, weights, dropout):
38 | # """
39 | # Helper for `WeightDrop`.
40 | # """
41 |
42 | # for name_w in weights:
43 | # w = getattr(module, name_w)
44 | # del module._parameters[name_w]
45 | # module.register_parameter(name_w + '_raw', Parameter(w))
46 |
47 | # original_module_forward = module.forward
48 |
49 | # def forward(*args, **kwargs):
50 | # for name_w in weights:
51 | # raw_w = getattr(module, name_w + '_raw')
52 | # w = torch.nn.functional.dropout(
53 | # raw_w, p=dropout, training=module.training)
54 | # # module.register_parameter(name_w, Parameter(w))
55 | # setattr(module, name_w, Parameter(w))
56 |
57 | # return original_module_forward(*args, **kwargs)
58 |
59 | # setattr(module, 'forward', forward)
60 |
61 |
62 | class WeightDrop(torch.nn.Module):
63 | """
64 | The weight-dropped module applies recurrent regularization through a DropConnect mask on the
65 | hidden-to-hidden recurrent weights.
66 | **Thank you** to Sales Force for their initial implementation of :class:`WeightDrop`. Here is
67 | their `License
68 | `__.
69 | Args:
70 | module (:class:`torch.nn.Module`): Containing module.
71 | weights (:class:`list` of :class:`str`): Names of the module weight parameters to apply a
72 | dropout too.
73 | dropout (float): The probability a weight will be dropped.
74 | Example:
75 | >>> from torchnlp.nn import WeightDrop
76 | >>> import torch
77 | >>>
78 | >>> torch.manual_seed(123)
79 |     >>>
81 | >>> gru = torch.nn.GRUCell(2, 2)
82 | >>> weights = ['weight_hh']
83 | >>> weight_drop_gru = WeightDrop(gru, weights, dropout=0.9)
84 | >>>
85 | >>> input_ = torch.randn(3, 2)
86 | >>> hidden_state = torch.randn(3, 2)
87 | >>> weight_drop_gru(input_, hidden_state)
88 | tensor(... grad_fn=)
89 | """
90 |
91 | def __init__(self, module, weights, dropout=0.0):
92 | super(WeightDrop, self).__init__()
93 | _weight_drop(module, weights, dropout)
94 | self.forward = module.forward
95 |
96 |
97 | class WeightDropLSTM(torch.nn.LSTM):
98 | """
99 | Wrapper around :class:`torch.nn.LSTM` that adds ``weight_dropout`` named argument.
100 | Args:
101 | weight_dropout (float): The probability a weight will be dropped.
102 | """
103 |
104 | def __init__(self, *args, weight_dropout=0.0, **kwargs):
105 | super().__init__(*args, **kwargs)
106 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)]
107 | _weight_drop(self, weights, weight_dropout)
108 |
109 |
110 | class WeightDropGRU(torch.nn.GRU):
111 | """
112 | Wrapper around :class:`torch.nn.GRU` that adds ``weight_dropout`` named argument.
113 | Args:
114 | weight_dropout (float): The probability a weight will be dropped.
115 | """
116 |
117 | def __init__(self, *args, weight_dropout=0.0, **kwargs):
118 | super().__init__(*args, **kwargs)
119 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)]
120 | _weight_drop(self, weights, weight_dropout)
121 |
122 |
123 | class WeightDropLinear(torch.nn.Linear):
124 | """
125 | Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument.
126 | Args:
127 | weight_dropout (float): The probability a weight will be dropped.
128 | """
129 |
130 | def __init__(self, *args, weight_dropout=0.0, **kwargs):
131 | super().__init__(*args, **kwargs)
132 | weights = ['weight']
133 | _weight_drop(self, weights, weight_dropout)
134 |
--------------------------------------------------------------------------------
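
A minimal sketch of WeightDropLinear (not part of the repository; sizes are assumptions). During training, DropConnect is applied to the weight matrix itself on every forward pass via the registered 'weight_raw' parameter, rather than to the activations.

import torch
from torchnlp_nn import WeightDropLinear  # hypothetical flat import, matching the repo's own import style

layer = WeightDropLinear(8, 4, weight_dropout=0.5, bias=False)
x = torch.randn(2, 8)

layer.train()
y1 = layer(x)    # a fresh dropped-weight mask is sampled on each training-mode call
y2 = layer(x)
layer.eval()
y3 = layer(x)    # at eval time the full weight matrix is used

print(y1.shape, torch.allclose(y1, y2))   # torch.Size([2, 4]) False (almost surely, since the masks differ)
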
/requirements.txt:
--------------------------------------------------------------------------------
1 | backcall==0.2.0
2 | bert-score==0.3.6
3 | block==0.0.5
4 | block.bootstrap.pytorch==0.1.6
5 | bootstrap.pytorch==0.0.13
6 | certifi==2020.12.5
7 | chardet==4.0.0
8 | click==8.0.1
9 | cycler==0.10.0
10 | decorator==5.0.9
11 | future==0.18.2
12 | h5py==2.7.1
13 | idna==2.10
14 | importlib-metadata==4.0.1
15 | importlib-resources==5.1.4
16 | ipdb==0.13.7
17 | ipython==7.16.1
18 | ipython-genutils==0.2.0
19 | jedi==0.18.0
20 | kiwisolver==1.3.1
21 | matplotlib==3.3.4
22 | munch==2.5.0
23 | nltk==3.3
24 | numpy==1.19.5
25 | opencv-python==4.5.2.52
26 | pandas==1.1.2
27 | parso==0.8.2
28 | pexpect==4.8.0
29 | pickleshare==0.7.5
30 | Pillow==8.2.0
31 | plotly==4.14.3
32 | pretrainedmodels==0.7.4
33 | prompt-toolkit==3.0.18
34 | protobuf==3.17.0
35 | ptyprocess==0.7.0
36 | Pygments==2.9.0
37 | pyparsing==2.4.7
38 | python-dateutil==2.8.1
39 | pytz==2021.1
40 | pywsd==1.2.4
41 | PyYAML==5.4.1
42 | requests==2.25.1
43 | retrying==1.3.3
44 | scipy==1.5.4
45 | seaborn==0.11.1
46 | six==1.16.0
47 | skipthoughts==0.0.1
48 | tabulate==0.8.9
49 | tensorboardX==2.2
50 | toml==0.10.2
51 | torch==1.6.0
52 | torchvision==0.7.0
53 | tqdm==4.60.0
54 | traitlets==4.3.3
55 | typing-extensions==3.10.0.0
56 | urllib3==1.26.4
57 | wcwidth==0.2.5
58 | wn==0.0.23
59 | zipp==3.4.1
60 |
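61 | # Install hint (a sketch, assuming a plain pip environment compatible with the
62 | # pinned torch==1.6.0 / torchvision==0.7.0 above):
63 | #   pip install -r requirements.txt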
--------------------------------------------------------------------------------
/stopwords.txt:
--------------------------------------------------------------------------------
1 | i
2 | me
3 | my
4 | myself
5 | we
6 | our
7 | ours
8 | ourselves
9 | you
10 | you're
11 | you've
12 | you'll
13 | you'd
14 | your
15 | yours
16 | yourself
17 | yourselves
18 | he
19 | him
20 | his
21 | himself
22 | she
23 | she's
24 | her
25 | hers
26 | herself
27 | it
28 | it's
29 | its
30 | itself
31 | they
32 | them
33 | their
34 | theirs
35 | themselves
36 | what
37 | which
38 | who
39 | whom
40 | this
41 | that
42 | that'll
43 | these
44 | those
45 | am
46 | is
47 | are
48 | was
49 | were
50 | be
51 | been
52 | being
53 | have
54 | has
55 | had
56 | having
57 | do
58 | does
59 | did
60 | doing
61 | a
62 | an
63 | the
64 | and
65 | but
66 | if
67 | or
68 | because
69 | as
70 | until
71 | while
72 | to
73 | from
74 | of
75 | at
76 | for
77 | with
78 | about
79 | into
80 | through
81 | during
82 | again
83 | further
84 | then
85 | here
86 | there
87 | when
88 | where
89 | why
90 | how
91 | all
92 | any
93 | each
94 | most
95 | other
96 | some
97 | such
98 | only
99 | own
100 | so
101 | than
102 | too
103 | very
104 | s
105 | t
106 | can
107 | will
108 | just
109 | don
110 | don't
111 | should
112 | should've
113 | now
114 | d
115 | ll
116 | m
117 | o
118 | re
119 | ve
120 | y
121 | ain
122 | aren
123 | aren't
124 | couldn
125 | couldn't
126 | didn
127 | didn't
128 | doesn
129 | doesn't
130 | hadn
131 | hadn't
132 | hasn
133 | hasn't
134 | haven
135 | haven't
136 | isn
137 | isn't
138 | ma
139 | mightn
140 | mightn't
141 | mustn
142 | mustn't
143 | needn
144 | needn't
145 | shan
146 | shan't
147 | shouldn
148 | shouldn't
149 | wasn
150 | wasn't
151 | weren
152 | weren't
153 | won
154 | won't
155 | wouldn
156 | wouldn't
157 |
--------------------------------------------------------------------------------
/tools/.gitignore:
--------------------------------------------------------------------------------
1 | # ignore all except .gitignore file
2 | *
3 | !.gitignore
4 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import os.path as osp
4 | import pickle as pkl
5 | import pandas as pd
6 |
7 | def set_gpu_devices(gpu_id):
8 | gpu = ''
9 | if gpu_id != -1:
10 | gpu = str(gpu_id)
11 |     os.environ['CUDA_VISIBLE_DEVICES'] = gpu
12 |
13 |
14 | def load_file(filename):
15 | """
16 | load obj from filename
17 | :param filename:
18 | :return:
19 | """
20 | cont = None
21 | if not osp.exists(filename):
22 |         print('{} does not exist'.format(filename))
23 | return cont
24 | if osp.splitext(filename)[-1] == '.csv':
25 | return pd.read_csv(filename, delimiter=',')
26 | with open(filename, 'r') as fp:
27 | if osp.splitext(filename)[1] == '.txt':
28 | cont = fp.readlines()
29 | cont = [c.rstrip('\n') for c in cont]
30 | elif osp.splitext(filename)[1] == '.json':
31 | cont = json.load(fp)
32 | return cont
33 |
34 | def save_file(obj, filename):
35 | """
36 | save obj to filename
37 | :param obj:
38 | :param filename:
39 | :return:
40 | """
41 |     filepath = osp.dirname(filename)
42 |     if filepath != '' and not osp.exists(filepath):
43 |         os.makedirs(filepath)
44 |     # create the parent directory if needed, then write the object as json
45 |     with open(filename, 'w') as fp:
46 |         json.dump(obj, fp, indent=4)
47 |
48 | def pkload(file):
49 | data = None
50 | if osp.exists(file) and osp.getsize(file) > 0:
51 | with open(file, 'rb') as fp:
52 | data = pkl.load(fp)
53 | # print('{} does not exist'.format(file))
54 | return data
55 |
56 |
57 | def pkdump(data, file):
58 | dirname = osp.dirname(file)
59 | if not osp.exists(dirname):
60 | os.makedirs(dirname)
61 | with open(file, 'wb') as fp:
62 | pkl.dump(data, fp)
63 |
64 |
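65 | # Minimal usage sketch (illustrative; only the dataset csv path below exists in this
66 | # repo, the other file names are hypothetical):
67 | #
68 | #   val_anno = load_file('dataset/nextqa/val.csv')                  # -> pandas DataFrame
69 | #   save_file({'vid001': {'0': 'an answer'}}, 'results/demo.json')  # hypothetical output path
70 | #   cache = pkload('tools/frame_feat.pkl')                          # hypothetical; None if absent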
--------------------------------------------------------------------------------
/videoqa.py:
--------------------------------------------------------------------------------
1 | from networks import EncoderRNN, DecoderRNN
2 | from networks.VQAModel import EVQA, UATT, STVQA, CoMem, HME, HGA
3 | from utils import *
4 | import torch
5 | from torch.optim.lr_scheduler import ReduceLROnPlateau
6 | import torch.nn as nn
7 | import time
8 | from metrics import get_wups
9 | from eval_oe import remove_stop
10 | import numpy as np  # used by get_acc_count()
11 | class VideoQA():
12 | def __init__(self, vocab_qns, vocab_ans, train_loader, val_loader, glove_embed_qns, glove_embed_ans,
13 | checkpoint_path, model_type, model_prefix, vis_step,
14 | lr_rate, batch_size, epoch_num):
15 | self.vocab_qns = vocab_qns
16 | self.vocab_ans = vocab_ans
17 | self.train_loader = train_loader
18 | self.val_loader = val_loader
19 | self.glove_embed_qns = glove_embed_qns
20 | self.glove_embed_ans = glove_embed_ans
21 | self.model_dir = checkpoint_path
22 | self.model_type = model_type
23 | self.model_prefix = model_prefix
24 | self.vis_step = vis_step
25 | self.lr_rate = lr_rate
26 | self.batch_size = batch_size
27 | self.epoch_num = epoch_num
28 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29 | self.model = None
30 |
31 |
32 | def build_model(self):
33 |
34 | vid_dim = 2048+2048
35 | hidden_dim = 512
36 | word_dim = 300
37 | qns_vocab_size = len(self.vocab_qns)
38 | ans_vocab_size = len(self.vocab_ans)
39 | max_ans_len = 7
40 | max_vid_len = 16
41 | max_qns_len = 23
42 |
43 |
44 | if self.model_type == 'EVQA' or self.model_type == 'BlindQA':
45 | #ICCV15, AAAI17
46 | vid_encoder = EncoderRNN.EncoderVid(vid_dim, hidden_dim, input_dropout_p=0.3, n_layers=1, rnn_dropout_p=0,
47 | bidirectional=False, rnn_cell='lstm')
48 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=1,
49 | input_dropout_p=0.3, rnn_dropout_p=0, bidirectional=False, rnn_cell='lstm')
50 |
51 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans,
52 | n_layers=1, input_dropout_p=0.3, rnn_dropout_p=0, rnn_cell='gru')
53 | self.model = EVQA.EVQA(vid_encoder, qns_encoder, ans_decoder, self.device)
54 |
55 | elif self.model_type == 'UATT':
56 | #TIP17
57 | # hidden_dim = 512
58 | vid_encoder = EncoderRNN.EncoderVid(vid_dim, hidden_dim, input_dropout_p=0.3, bidirectional=True,
59 | rnn_cell='lstm')
60 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns,
61 | input_dropout_p=0.3, bidirectional=True, rnn_cell='lstm')
62 |
63 | ans_decoder = DecoderRNN.AnsUATT(ans_vocab_size, max_ans_len, hidden_dim, word_dim,
64 | self.glove_embed_ans, n_layers=2, input_dropout_p=0.3,
65 | rnn_dropout_p=0.5, rnn_cell='lstm')
66 | self.model = UATT.UATT(vid_encoder, qns_encoder, ans_decoder, self.device)
67 |
68 | elif self.model_type == 'STVQA':
69 | #CVPR17
70 | vid_dim = 2048 + 2048 # (64, 1024+2048, 7, 7)
71 | att_dim = 256
72 | hidden_dim = 256
73 | vid_encoder = EncoderRNN.EncoderVidSTVQA(vid_dim, hidden_dim, input_dropout_p=0.3, rnn_dropout_p=0,
74 | n_layers=1, rnn_cell='lstm')
75 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns,
76 | input_dropout_p=0.3, rnn_dropout_p=0.5, n_layers=2, rnn_cell='lstm')
77 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans,
78 | input_dropout_p=0.3, rnn_dropout_p=0, n_layers=1, rnn_cell='gru')
79 | self.model = STVQA.STVQA(vid_encoder, qns_encoder, ans_decoder, att_dim, self.device)
80 |
81 |
82 | elif self.model_type == 'CoMem':
83 | #CVPR18
84 | app_dim = 2048
85 | motion_dim = 2048
86 | hidden_dim = 256
87 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.3,
88 | bidirectional=False, rnn_cell='gru')
89 |
90 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=2,
91 | rnn_dropout_p=0.5, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru')
92 |
93 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans,
94 | n_layers=1, input_dropout_p=0.3, rnn_dropout_p=0, rnn_cell='gru')
95 |
96 | self.model = CoMem.CoMem(vid_encoder, qns_encoder, ans_decoder, max_vid_len, max_qns_len, self.device)
97 |
98 |
99 | elif self.model_type == 'HME':
100 | #CVPR19
101 | app_dim = 2048
102 | motion_dim = 2048
103 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.3,
104 | bidirectional=False, rnn_cell='lstm')
105 |
106 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=2,
107 | rnn_dropout_p=0.5, input_dropout_p=0.3, bidirectional=False, rnn_cell='lstm')
108 |
109 | ans_decoder = DecoderRNN.AnsHME(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans,
110 | n_layers=2, input_dropout_p=0.3, rnn_dropout_p=0.5, rnn_cell='lstm')
111 |
112 | self.model = HME.HME(vid_encoder, qns_encoder, ans_decoder, max_vid_len, max_qns_len, self.device)
113 |
114 |
115 | elif self.model_type == 'HGA':
116 | #AAAI20
117 | vid_encoder = EncoderRNN.EncoderVidHGA(vid_dim, hidden_dim, input_dropout_p=0.3,
118 | bidirectional=False, rnn_cell='gru')
119 |
120 | qns_encoder = EncoderRNN.EncoderQnsHGA(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=1,
121 | rnn_dropout_p=0, input_dropout_p=0.3, bidirectional=False,
122 | rnn_cell='gru')
123 |
124 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans,
125 | n_layers=1, input_dropout_p=0.3, rnn_dropout_p=0, rnn_cell='gru')
126 |
127 | self.model = HGA.HGA(vid_encoder, qns_encoder, ans_decoder, max_vid_len, max_qns_len, self.device)
128 |
129 |
130 | params = [{'params':self.model.parameters()}]
131 | # params = [{'params': vid_encoder.parameters()}, {'params': qns_encoder.parameters()},
132 | # {'params': ans_decoder.parameters(), 'lr': self.lr_rate}]
133 | self.optimizer = torch.optim.Adam(params = params, lr=self.lr_rate)
134 | self.scheduler = ReduceLROnPlateau(self.optimizer, 'max', factor=0.5, patience=5, verbose=True)
135 | # if torch.cuda.device_count() > 1:
136 | # print("Let's use", torch.cuda.device_count(), "GPUs!")
137 | # self.model = nn.DataParallel(self.model)
138 |
139 | self.model.to(self.device)
140 | self.criterion = nn.CrossEntropyLoss().to(self.device)
141 |
142 |
143 | def save_model(self, epoch, loss):
144 | torch.save(self.model.state_dict(), osp.join(self.model_dir, '{}-{}-{}-{:.4f}.ckpt'
145 | .format(self.model_type, self.model_prefix, epoch, loss)))
146 |
147 | def resume(self, model_file):
148 | """
149 |         initialize the model with pretrained weights (only keys found in the checkpoint are loaded)
150 |         :param model_file: checkpoint file name under self.model_dir
151 | """
152 | model_path = osp.join(self.model_dir, model_file)
153 | print(f'Warm-starting from model {model_path}')
154 | model_dict = torch.load(model_path)
155 | new_model_dict = {}
156 | for k, v in self.model.state_dict().items():
157 | if k in model_dict:
158 | v = model_dict[k]
159 |
160 | new_model_dict[k] = v
161 | self.model.load_state_dict(new_model_dict)
162 |
163 |
164 | def run(self, model_file, pre_trained=False):
165 | self.build_model()
166 | best_eval_score = 0.0
167 | if pre_trained:
168 | self.resume(model_file)
169 | best_eval_score = self.eval(0)
170 | print('Initial Acc {:.4f}'.format(best_eval_score))
171 |
172 | for epoch in range(1, self.epoch_num):
173 | train_loss = self.train(epoch)
174 | eval_score = self.eval(epoch)
175 | print("==>Epoch:[{}/{}][Train Loss: {:.4f} Val acc: {:.4f}]".
176 | format(epoch, self.epoch_num, train_loss, eval_score))
177 | self.scheduler.step(eval_score)
178 | if eval_score > best_eval_score or pre_trained:
179 | best_eval_score = eval_score
180 | if epoch > 10 or pre_trained:
181 | self.save_model(epoch, best_eval_score)
182 |
183 |
184 | def train(self, epoch):
185 | print('==>Epoch:[{}/{}][lr_rate: {}]'.format(epoch, self.epoch_num, self.optimizer.param_groups[0]['lr']))
186 | self.model.train()
187 | total_step = len(self.train_loader)
188 | epoch_loss = 0.0
189 | for iter, inputs in enumerate(self.train_loader):
190 | videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes = inputs
191 | video_inputs = videos.to(self.device)
192 | qns_inputs = targets_qns.to(self.device)
193 | ans_inputs = targets_ans.to(self.device)
194 | prediction = self.model(video_inputs, qns_inputs, qns_lengths, ans_inputs, ans_lengths, 0.5)
195 |
196 | out_dim = prediction.shape[-1]
197 | prediction = prediction.view(-1, out_dim)
198 | ans_targets = ans_inputs.view(-1)
199 |
200 | loss = self.criterion(prediction, ans_targets)
201 | self.model.zero_grad()
202 | loss.backward()
203 | self.optimizer.step()
204 | cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
205 | if iter % self.vis_step == 0:
206 | print('\t[{}/{}]-{}-{:.4f}'.format(iter, total_step,cur_time, loss.item()))
207 | epoch_loss += loss.item()
208 |
209 | return epoch_loss / total_step
210 |
211 |
212 | def eval(self, epoch):
213 | print('==>Epoch:[{}/{}][validation stage]'.format(epoch, self.epoch_num))
214 | self.model.eval()
215 | total_step = len(self.val_loader)
216 | acc_count = 0
217 | with torch.no_grad():
218 | for iter, inputs in enumerate(self.val_loader):
219 | videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes = inputs
220 | video_inputs = videos.to(self.device)
221 | qns_inputs = targets_qns.to(self.device)
222 | ans_inputs = targets_ans.to(self.device)
223 | prediction = self.model(video_inputs, qns_inputs, qns_lengths, ans_inputs, ans_lengths, mode='val')
224 | acc_count += get_acc_count(prediction, targets_ans, self.vocab_ans, qtypes)
225 |
226 | return acc_count*1.0 / ((total_step-1)*self.batch_size)
227 |
228 |
229 | def predict(self, model_file, res_file):
230 | """
231 |         predict answers with a trained model and save them to results/
232 |         :param model_file: checkpoint file name under self.model_dir
233 |         :param res_file: output json file name under results/
234 | """
235 | model_path = osp.join(self.model_dir, model_file)
236 | self.build_model()
237 | if self.model_type == 'HGA':
238 | self.resume(model_file)
239 | else:
240 | old_state_dict = torch.load(model_path)
241 | self.model.load_state_dict(old_state_dict)
242 | #self.resume()
243 | self.model.eval()
244 | total = len(self.val_loader)
245 | acc = 0
246 | results = {}
247 | with torch.no_grad():
248 | for iter, inputs in enumerate(self.val_loader):
249 | videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes = inputs
250 | video_inputs = videos.to(self.device)
251 | qns_inputs = targets_qns.to(self.device)
252 | ans_inputs = targets_ans.to(self.device)
253 | # predict_ans_idxs = self.model.predict(video_inputs, qns_inputs, qns_lengths)
254 | predict_ans_idxs = self.model(video_inputs, qns_inputs, qns_lengths, ans_inputs, ans_lengths, mode='val')
255 | ans_idxs = predict_ans_idxs.cpu().numpy()
256 | targets_ans = targets_ans.numpy()
257 | targets_qns = targets_qns.numpy()
258 | for vname in video_names:
259 | if vname not in results:
260 | results[vname] = {}
261 | for bs, idx in enumerate(ans_idxs):
262 | ans_pred = [self.vocab_ans.idx2word[ans_id] for ans_id in idx[1:] if ans_id >3] #the first 4 ids are reserved for special token
263 | ans_pred = ' '.join(ans_pred)
264 | groundtruth = [self.vocab_ans.idx2word[ans_id] for ans_id in targets_ans[bs][1:] if ans_id > 3]
265 | groundtruth = ' '.join(groundtruth)
266 | qns_text = [self.vocab_qns.idx2word[qns_id] for qns_id in targets_qns[bs][1:] if qns_id > 3]
267 | # if qids[bs] not in results[video_names[bs]]:
268 | qns_text = ' '.join(qns_text)
269 | results[video_names[bs]][qids[bs]] = ans_pred
270 | if ans_pred==groundtruth and ans_pred != '':
271 | acc += 1
272 | # print(f'[{iter}/{total}]{qns_text}? P:{ans_pred} G:{groundtruth}')
273 |
274 | save_file(results, f'results/{res_file}')
275 |
276 |
277 | def get_acc_count(prediction, labels, vocab_ans, qtypes):
278 | """
279 |
280 | :param prediction:
281 | :param labels:
282 | :return:
283 | """
284 | preds = prediction.data.cpu().numpy()
285 | labels = np.asarray(labels)
286 | batch_size = labels.shape[0]
287 | score = 0
288 | for i in range(batch_size):
289 | pred = [j for j in preds[i] if j > 3]
290 | ans = [j for j in labels[i] if j > 3]
291 | pred_ans = ' '.join([vocab_ans.idx2word[id] for id in pred])
292 | gt_ans = ' '.join([vocab_ans.idx2word[id] for id in ans])
293 | pred_ans = remove_stop(pred_ans)
294 | gt_ans = remove_stop(gt_ans)
295 | cur_s = 0
296 | if qtypes[i] in ['CC', 'CB']:
297 | if gt_ans == pred_ans:
298 | cur_s = 1
299 | else:
300 | cur_s = get_wups(pred_ans, gt_ans, 0)
301 | score += cur_s
302 |
303 |
304 | return score
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
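314 | # Rough driver sketch (see main_qa.py / main.sh for the repo's own entry point;
315 | # every value below is a placeholder, not taken from this repo):
316 | #
317 | #   vqa = VideoQA(vocab_qns, vocab_ans, train_loader, val_loader,
318 | #                 glove_embed_qns, glove_embed_ans, checkpoint_path='models/',
319 | #                 model_type='HGA', model_prefix='same-att-qns23ans7',
320 | #                 vis_step=100, lr_rate=1e-4, batch_size=64, epoch_num=50)
321 | #   vqa.run(model_file='', pre_trained=False)   # train + validate; checkpoints go to models/
322 | #   vqa.predict('HGA-same-att-qns23ans7-<epoch>-<loss>.ckpt', 'HGA-same-att-qns23ans7-val.json')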
--------------------------------------------------------------------------------
/word2vec.py:
--------------------------------------------------------------------------------
1 | from build_vocab import Vocabulary
2 | from utils import *
3 | import numpy as np
4 | import random as rd
5 | rd.seed(0)
6 |
7 | def word2vec(vocab, glove_file, save_filename):
8 | glove = load_file(glove_file)
9 | word2vec = {}
10 | for i, line in enumerate(glove):
11 | if i == 0: continue # for FastText
12 | line = line.split(' ')
13 |         word2vec[line[0]] = np.array(line[1:]).astype(float)
14 |
15 | temp = []
16 | for word, vec in word2vec.items():
17 | temp.append(vec)
18 | temp = np.asarray(temp)
19 | print(temp.shape)
20 | row, col = temp.shape
21 |
22 | pad = np.mean(temp, axis=0)
23 | start = np.mean(temp[:int(row//2), :], axis=0)
24 | end = np.mean(temp[int(row//2):, :], axis=0)
25 | special_tokens = [pad, start, end]
26 | count = 0
27 | bad_words = []
28 | sort_idx_word = sorted(vocab.idx2word.items(), key=lambda k:k[0])
29 | glove_embed = np.zeros((len(vocab), 300))
30 | for row, item in enumerate(sort_idx_word):
31 | idx, word = item[0], item[1]
32 | if word in word2vec:
33 | glove_embed[row] = word2vec[word]
34 | else:
35 | if row < 3:
36 | glove_embed[row] = special_tokens[row]
37 | else:
38 | glove_embed[row] = np.random.randn(300)*0.4
39 | print(word)
40 | bad_words.append(word)
41 | count += 1
42 | print(glove_embed.shape)
43 | save_file(bad_words, 'bad_words_qns.json')
44 | np.save(save_filename, glove_embed)
45 | print(count)
46 |
47 |
48 | def main():
49 | data_dir = 'dataset/nextqa/'
50 | vocab_file = osp.join(data_dir, 'vocab.pkl')
51 | vocab = pkload(vocab_file)
52 | glove_file = '../data/Vocabulary/glove.840B.300d.txt'
53 | save_filename = 'dataset/nextqa/glove_embed.npy'
54 | word2vec(vocab, glove_file, save_filename)
55 |
56 | if __name__ == "__main__":
57 | main()
58 |
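59 | # Sketch of the expected artifact (descriptive comment only): glove_embed.npy holds a
60 | # (len(vocab), 300) array whose row i is the GloVe vector of vocab.idx2word[i]; the
61 | # first three rows (special tokens) fall back to mean vectors when missing from GloVe,
62 | # and other OOV words get Gaussian noise. Something like the following would consume it:
63 | #
64 | #   glove_embed_qns = np.load('dataset/nextqa/glove_embed.npy')
65 | #   # then passed to VideoQA(...) as glove_embed_qns / glove_embed_ans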
--------------------------------------------------------------------------------