├── .gitignore ├── LICENSE ├── README.md ├── build_vocab.py ├── dataloader ├── __init__.py ├── sample_loader.py └── util.py ├── dataset └── nextqa │ ├── .gitignore │ ├── add_reference_answer_test.json │ ├── test.csv │ ├── train.csv │ └── val.csv ├── eval_oe.py ├── images ├── logo.png └── res-mc-oe.png ├── main.sh ├── main_qa.py ├── metrics.py ├── models └── .gitignore ├── networks ├── .gitignore ├── Attention.py ├── DecoderRNN.py ├── EncoderRNN.py ├── VQAModel │ ├── CoMem.py │ ├── EVQA.py │ ├── HGA.py │ ├── HME.py │ ├── STVQA.py │ └── UATT.py ├── VQAModel_bak.py ├── __init__.py ├── gcn.py ├── memory_module.py ├── memory_rand.py ├── q_v_transformer.py └── torchnlp_nn.py ├── requirements.txt ├── results ├── HGA-same-att-qns23ans7-test.json ├── HGA-same-att-qns23ans7-val-example.json └── HGA-same-att-qns23ans7-val.json ├── stopwords.txt ├── tools └── .gitignore ├── utils.py ├── videoqa.py └── word2vec.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Junbin 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [NExT-QA](https://arxiv.org/pdf/2105.08276.pdf) 2 | 3 | We reproduce some SOTA VideoQA methods to provide benchmark results for our NExT-QA dataset accepted to CVPR2021. 4 | 5 | NExT-QA is a VideoQA benchmark targeting the explanation of video contents. It challenges QA models to reason about the causal and temporal actions and understand the rich object interactions in daily activities. We set up both multi-choice and open-ended QA tasks on the dataset. This repo. provides resources for open-ended QA; multi-choice QA is found in [NExT-QA](https://github.com/doc-doc/NExT-QA). For more details, please refer to our [dataset](https://doc-doc.github.io/docs/nextqa.html) page. 6 | 7 | ## Todo 8 | 1. [x] Raw Videos are the same with [NExT-QA(MC)](https://github.com/doc-doc/NExT-QA). 9 | 2. [ ] Open online evaluation server and release [test data](https://drive.google.com/file/d/1bXBFN61PaTSHTnJqz3R79mpIgEQPFGIU/view?usp=sharing). 10 | 3. 
[x] RoI features are the same as [NExT-QA(MC)](https://github.com/doc-doc/NExT-QA). 11 | ## Environment 12 | 13 | We use Anaconda 4.8.4, Python 3.6.8, PyTorch 1.6, and CUDA 10.2. For other libraries, please refer to requirements.txt. 14 | 15 | ## Install 16 | Please create an environment for this project with Anaconda (install [Anaconda](https://docs.anaconda.com/anaconda/install/linux/) first): 17 | ``` 18 | >conda create -n videoqa python==3.6.8 19 | >conda activate videoqa 20 | >git clone https://github.com/doc-doc/NExT-OE.git 21 | >pip install -r requirements.txt 22 | ``` 23 | ## Data Preparation 24 | Please download the pre-computed features and QA annotations from [here](https://drive.google.com/drive/folders/14jSt4sGFQaZxBu4AGL2Svj34fUhcK2u0?usp=sharing). There are 3 zip files: 25 | - ```['vid_feat.zip']```: Appearance and motion features for video representation (same as multi-choice QA). 26 | - ```['nextqa.zip']```: QA annotations and GloVe embeddings (open-ended version). 27 | - ```['models.zip']```: HGA model (open-ended version). 28 | 29 | After downloading the data, please create a folder ```['data/feats']``` at the same directory level as ```['NExT-OE']```, then unzip the video features into it. You will then have the directories ```['data/feats/vid_feat/' and 'NExT-OE/']``` in your workspace. Please unzip the files in ```['nextqa.zip']``` into ```['NExT-OE/dataset/nextqa']``` and ```['models.zip']``` into ```['NExT-OE/models/']```. 30 | 31 | 32 | ## Usage 33 | Once the data is ready, you can easily run the code. First, to test the environment and code, we provide the prediction file and model of the SOTA approach (i.e., HGA) on NExT-QA. 34 | You can reproduce the results reported in the paper by running: 35 | ``` 36 | >python eval_oe.py 37 | ``` 38 | The command above loads the prediction file under ['results/'] and evaluates it. 39 | You can also generate the prediction yourself by running: 40 | ``` 41 | >./main.sh 0 val #Test the model with GPU id 0 42 | ``` 43 | The command above loads the model under ['models/'] and generates the prediction file. 44 | If you want to train the model, please run: 45 | ``` 46 | >./main.sh 0 train # Train the model with GPU id 0 47 | ``` 48 | This trains the model and saves checkpoints to ['models']. (*Results may differ slightly depending on the environment.*) 49 | ## Results on Val 50 | | Methods | Text Rep. 
| WUPS_C | WUPS_T | WUPS_D | WUPS | 51 | | -------------------------| --------: | ----: | ----: | ----: | ---:| 52 | | BlindQA | GloVe | 12.14 | 14.85 | 40.41 | 18.88 | 53 | | [STVQA](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/STVQA.py) ([CVPR17](https://openaccess.thecvf.com/content_cvpr_2017/papers/Jang_TGIF-QA_Toward_Spatio-Temporal_CVPR_2017_paper.pdf)) | GloVe | 12.52 | 14.57 | 45.64 | 20.08 | 54 | | [UATT](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/UATT.py) ([TIP17](https://ieeexplore.ieee.org/document/8017608)) | GloVe | 13.62 | **16.23** | 43.41 | 20.65 | 55 | | [HME](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/HME.py) ([CVPR19](https://openaccess.thecvf.com/content_CVPR_2019/papers/Fan_Heterogeneous_Memory_Enhanced_Multimodal_Attention_Model_for_Video_Question_Answering_CVPR_2019_paper.pdf)) | GloVe | 12.83 | 14.76 | 45.13 | 20.18 | 56 | | [HCRN](https://github.com/thaolmk54/hcrn-videoqa) ([CVPR20](https://openaccess.thecvf.com/content_CVPR_2020/papers/Le_Hierarchical_Conditional_Relation_Networks_for_Video_Question_Answering_CVPR_2020_paper.pdf)) | GloVe | 12.53 | 15.37 | 45.29 | 20.25 | 57 | | [HGA](https://github.com/doc-doc/NExT-OE/blob/main/networks/VQAModel/HGA.py) ([AAAI20](https://ojs.aaai.org//index.php/AAAI/article/view/6767)) | GloVe | **14.76** | 14.90 | **46.60** | **21.48** | 58 | 59 | Please refer to our paper for results on the test set. 60 | ## Multi-choice QA *vs.* Open-ended QA 61 | ![vis mc_oe](./images/res-mc-oe.png) 62 | 63 | ## Some Latest Results 64 | | Methods | Publication | Highlight | Val (WUPS@All) | Test (WUPS@All) | 65 | | -------------------------| --------: |--------: | ----: | ----:| 66 | |[Emu(0-shot)](https://arxiv.org/pdf/2307.05222v1.pdf) by BAAI | arXiv'23 | VL foundation model | - | 23.4 | 67 | | [Flamingo(0-shot)](https://arxiv.org/pdf/2204.14198.pdf) by DeepMind | NeurIPS'22 | VL foundation model | - | 26.7| 68 | | [KcGA](https://ojs.aaai.org/index.php/AAAI/article/view/25983) by Baidu | AAAI'23 | Knowledge base, GPT-2 | - | 28.2 | 69 | | [Flamingo(32-shot)](https://arxiv.org/pdf/2204.14198.pdf) by DeepMind | NeurIPS'22 | VL foundation model | - | 33.5| 70 | | [PaLI-X](https://arxiv.org/pdf/2305.18565.pdf) by Google Research | arXiv'23 | VL foundation model | - | 38.3 | 71 | 72 | ## Citation 73 | ``` 74 | @InProceedings{xiao2021next, 75 | author = {Xiao, Junbin and Shang, Xindi and Yao, Angela and Chua, Tat-Seng}, 76 | title = {NExT-QA: Next Phase of Question-Answering to Explaining Temporal Actions}, 77 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 78 | month = {June}, 79 | year = {2021}, 80 | pages = {9777-9786} 81 | } 82 | ``` 83 | ## Acknowledgement 84 | Our reproduction of the methods is based on the respective official repositories, we thank the authors to release their code. If you use the related part, please cite the corresponding paper commented in the code. 
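## Prediction File Format
For reference, `eval_oe.py` reads predictions as `res[video_id][qid]`, i.e., a JSON dictionary mapping each video id (string) to a dictionary from question id (string) to the predicted answer string; see the released examples under ['results/']. Below is a minimal sketch of writing such a file — the ids and answers are hypothetical and only illustrate the structure.
```
# Minimal sketch of the prediction format consumed by eval_oe.py.
# The video id, question ids, and answers below are hypothetical examples.
import json

predictions = {
    "3376544720": {             # video id (string) -- hypothetical
        "0": "hold the toy",    # qid (string) -> predicted answer -- hypothetical
        "1": "living room",
    }
}

with open('results/my-model-val.json', 'w') as f:
    json.dump(predictions, f, indent=4)
```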
85 | -------------------------------------------------------------------------------- /build_vocab.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | # nltk.download('punkt') 3 | import pickle 4 | import argparse 5 | from utils import load_file, save_file 6 | from collections import Counter 7 | 8 | 9 | 10 | class Vocabulary(object): 11 | """Simple vocabulary wrapper.""" 12 | def __init__(self): 13 | self.word2idx = {} 14 | self.idx2word = {} 15 | self.idx = 0 16 | 17 | def add_word(self, word): 18 | if not word in self.word2idx: 19 | self.word2idx[word] = self.idx 20 | self.idx2word[self.idx] = word 21 | self.idx += 1 22 | 23 | def __call__(self, word): 24 | if not word in self.word2idx: 25 | return self.word2idx['<unk>'] 26 | return self.word2idx[word] 27 | 28 | def __len__(self): 29 | return len(self.word2idx) 30 | 31 | 32 | 33 | def build_vocab(anno_file, threshold): 34 | """Build a simple vocabulary wrapper.""" 35 | 36 | annos = load_file(anno_file) 37 | print('total QA pairs', len(annos)) 38 | counter = Counter() 39 | 40 | for rid, (qns, ans) in enumerate(zip(annos['question'], annos['answer'])): 41 | # qns, ans = vqa['question'], vqa['answer'] 42 | text = qns +' ' +ans 43 | tokens = nltk.tokenize.word_tokenize(text.lower()) 44 | counter.update(tokens) 45 | 46 | counter = sorted(counter.items(), key=lambda item:item[1], reverse=True) 47 | 48 | # If the word frequency is less than 'threshold', then the word is discarded. 49 | words = [item[0] for item in counter if item[1] >= threshold] 50 | 51 | # Create a vocab wrapper and add some special tokens. 52 | vocab = Vocabulary() 53 | vocab.add_word('<pad>') 54 | vocab.add_word('<start>') 55 | vocab.add_word('<end>') 56 | vocab.add_word('<unk>') 57 | 58 | # Add the words to the vocabulary. 
59 | for i, word in enumerate(words): 60 | vocab.add_word(word) 61 | 62 | return vocab 63 | 64 | 65 | def main(args): 66 | vocab = build_vocab(args.caption_path, args.threshold) 67 | vocab_path = args.vocab_path 68 | with open(vocab_path, 'wb') as f: 69 | pickle.dump(vocab, f) 70 | print("Total vocabulary size: {}".format(len(vocab))) 71 | print("Saved the vocabulary wrapper to '{}'".format(vocab_path)) 72 | 73 | 74 | if __name__ == '__main__': 75 | parser = argparse.ArgumentParser() 76 | parser.add_argument('--anno_path', type=str, 77 | default='dataset/nextqa/all.csv', 78 | help='path for train annotation file') 79 | parser.add_argument('--vocab_path', type=str, default='dataset/nextqa/vocab.pkl', 80 | help='path for saving vocabulary wrapper') 81 | parser.add_argument('--threshold', type=int, default=5, 82 | help='minimum word count threshold') 83 | args = parser.parse_args() 84 | main(args) 85 | -------------------------------------------------------------------------------- /dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | # ==================================================== 2 | # @Time : 15/5/20 3:48 PM 3 | # @Author : Xiao Junbin 4 | # @Email : junbin@comp.nus.edu.sg 5 | # @File : __init__.py 6 | # ==================================================== 7 | from .sample_loader import * -------------------------------------------------------------------------------- /dataloader/sample_loader.py: -------------------------------------------------------------------------------- 1 | # ==================================================== 2 | # @Time : 19/5/20 10:42 PM 3 | # @Author : Xiao Junbin 4 | # @Email : junbin@comp.nus.edu.sg 5 | # @File : sample_loader.py 6 | # ==================================================== 7 | import torch 8 | from torch.utils.data import Dataset, DataLoader 9 | from .util import load_file, pkdump, pkload 10 | import os.path as osp 11 | import numpy as np 12 | import nltk 13 | import h5py 14 | 15 | class VidQADataset(Dataset): 16 | """load the dataset in dataloader""" 17 | 18 | def __init__(self, video_feature_path, video_feature_cache, sample_list_path, vocab_qns, vocab_ans, mode): 19 | self.video_feature_path = video_feature_path 20 | self.vocab_qns = vocab_qns 21 | self.vocab_ans = vocab_ans 22 | sample_list_file = osp.join(sample_list_path, '{}.csv'.format(mode)) 23 | self.sample_list = load_file(sample_list_file) 24 | self.video_feature_cache = video_feature_cache 25 | self.use_frame = True 26 | self.use_mot = True 27 | self.frame_feats = {} 28 | self.mot_feats = {} 29 | vid_feat_file = osp.join(video_feature_path, 'vid_feat/app_mot_{}.h5'.format(mode)) 30 | with h5py.File(vid_feat_file, 'r') as fp: 31 | vids = fp['ids'] 32 | feats = fp['feat'] 33 | for id, (vid, feat) in enumerate(zip(vids, feats)): 34 | if self.use_frame: 35 | self.frame_feats[str(vid)] = feat[:, :2048] # (16, 2048) 36 | if self.use_mot: 37 | self.mot_feats[str(vid)] = feat[:, 2048:] # (16, 2048) 38 | 39 | 40 | def __len__(self): 41 | return len(self.sample_list) 42 | 43 | 44 | def get_video_feature(self, video_name): 45 | """ 46 | 47 | """ 48 | if self.use_frame: 49 | app_feat = self.frame_feats[video_name] 50 | video_feature = app_feat # (16, 2048) 51 | if self.use_mot: 52 | mot_feat = self.mot_feats[video_name] 53 | video_feature = np.concatenate((video_feature, mot_feat), axis=1) #(16, 4096) 54 | 55 | return torch.from_numpy(video_feature).type(torch.float32) 56 | 57 | 58 | def get_word_idx(self, text, src='qns'): 59 | 
""" 60 | convert relation to index sequence 61 | :param relation: 62 | :return: 63 | """ 64 | if src=='qns': vocab = self.vocab_qns 65 | elif src=='ans': vocab = self.vocab_ans 66 | tokens = nltk.tokenize.word_tokenize(str(text).lower()) 67 | text = [] 68 | text.append(vocab('')) 69 | text.extend([vocab(token) for i,token in enumerate(tokens) if i < 23]) 70 | #text.append(vocab('')) 71 | target = torch.Tensor(text) 72 | 73 | return target 74 | 75 | 76 | def __getitem__(self, idx): 77 | """ 78 | 79 | """ 80 | 81 | sample = self.sample_list.loc[idx] 82 | video_name, qns, ans = sample['video'], sample['question'], sample['answer'] 83 | qid, qtype = sample['qid'], sample['type'] 84 | video_name = str(video_name) 85 | qns, ans, qid, qtype = str(qns), str(ans), str(qid), str(qtype) 86 | 87 | 88 | #video_feature = torch.tensor([0]) 89 | video_feature = self.get_video_feature(video_name) 90 | 91 | qns2idx = self.get_word_idx(qns, 'qns') 92 | ans2idx = self.get_word_idx(ans, 'ans') 93 | 94 | return video_feature, qns2idx, ans2idx, video_name, qid, qtype 95 | 96 | 97 | class QALoader(): 98 | def __init__(self, batch_size, num_worker, video_feature_path, video_feature_cache, 99 | sample_list_path, vocab_qns, vocab_ans, train_shuffle=True, val_shuffle=False): 100 | self.batch_size = batch_size 101 | self.num_worker = num_worker 102 | self.video_feature_path = video_feature_path 103 | self.video_feature_cache = video_feature_cache 104 | self.sample_list_path = sample_list_path 105 | self.vocab_qns = vocab_qns 106 | self.vocab_ans = vocab_ans 107 | self.train_shuffle = train_shuffle 108 | self.val_shuffle = val_shuffle 109 | 110 | 111 | def run(self, mode=''): 112 | if mode != 'train': 113 | train_loader = '' 114 | val_loader = self.validate(mode) 115 | else: 116 | train_loader = self.train('train') 117 | val_loader = self.validate('val') 118 | return train_loader, val_loader 119 | 120 | 121 | def train(self, mode): 122 | 123 | training_set = VidQADataset(self.video_feature_path, self.video_feature_cache, self.sample_list_path, 124 | self.vocab_qns, self.vocab_ans, mode) 125 | 126 | print('Eligible QA pairs for training : {}'.format(len(training_set))) 127 | train_loader = DataLoader( 128 | dataset=training_set, 129 | batch_size=self.batch_size, 130 | shuffle=self.train_shuffle, 131 | num_workers=self.num_worker, 132 | collate_fn=collate_fn) 133 | 134 | return train_loader 135 | 136 | def validate(self, mode): 137 | 138 | validation_set = VidQADataset(self.video_feature_path, self.video_feature_cache, self.sample_list_path, 139 | self.vocab_qns, self.vocab_ans, mode) 140 | 141 | print('Eligible QA pairs for validation : {}'.format(len(validation_set))) 142 | val_loader = DataLoader( 143 | dataset=validation_set, 144 | batch_size=self.batch_size, 145 | shuffle=self.val_shuffle, 146 | num_workers=self.num_worker, 147 | collate_fn=collate_fn) 148 | 149 | return val_loader 150 | 151 | 152 | def collate_fn (data): 153 | """ 154 | """ 155 | data.sort(key=lambda x : len(x[1]), reverse=True) 156 | videos, qns2idx, ans2idx, video_names, qids, qtypes = zip(*data) 157 | 158 | #merge videos 159 | videos = torch.stack(videos, 0) 160 | 161 | #merge relations 162 | qns_lengths = [len(qns) for qns in qns2idx] 163 | targets_qns = torch.zeros(len(qns2idx), max(qns_lengths)).long() 164 | for i, qns in enumerate(qns2idx): 165 | end = qns_lengths[i] 166 | targets_qns[i, :end] = qns[:end] 167 | 168 | ans_lengths = [len(ans) for ans in ans2idx] 169 | targets_ans = torch.zeros(len(ans2idx), max(ans_lengths)).long() 170 | for 
i, ans in enumerate(ans2idx): 171 | end = ans_lengths[i] 172 | targets_ans[i, :end] = ans[:end] 173 | 174 | return videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes 175 | -------------------------------------------------------------------------------- /dataloader/util.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import numpy as np 5 | import pickle as pkl 6 | import pandas as pd 7 | 8 | def load_file(file_name): 9 | annos = None 10 | if osp.splitext(file_name)[-1] == '.csv': 11 | return pd.read_csv(file_name, delimiter=',') 12 | with open(file_name, 'r') as fp: 13 | if osp.splitext(file_name)[1]== '.txt': 14 | annos = fp.readlines() 15 | annos = [line.rstrip() for line in annos] 16 | if osp.splitext(file_name)[1] == '.json': 17 | annos = json.load(fp) 18 | 19 | return annos 20 | 21 | def save_file(obj, filename): 22 | """ 23 | save obj to filename 24 | :param obj: 25 | :param filename: 26 | :return: 27 | """ 28 | filepath = osp.dirname(filename) 29 | if filepath != '' and not osp.exists(filepath): 30 | os.makedirs(filepath) 31 | else: 32 | with open(filename, 'w') as fp: 33 | json.dump(obj, fp, indent=4) 34 | 35 | def pkload(file): 36 | data = None 37 | if osp.exists(file) and osp.getsize(file) > 0: 38 | with open(file, 'rb') as fp: 39 | data = pkl.load(fp) 40 | # print('{} does not exist'.format(file)) 41 | return data 42 | 43 | 44 | def pkdump(data, file): 45 | dirname = osp.dirname(file) 46 | if not osp.exists(dirname): 47 | os.makedirs(dirname) 48 | with open(file, 'wb') as fp: 49 | pkl.dump(data, fp) 50 | 51 | 52 | 53 | -------------------------------------------------------------------------------- /dataset/nextqa/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all except .gitignore file 2 | * 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /eval_oe.py: -------------------------------------------------------------------------------- 1 | from utils import * 2 | from metrics import * 3 | import pandas as pd 4 | from pywsd.utils import lemmatize_sentence 5 | 6 | #use stopwords tailored for NExT-QA 7 | stopwords = load_file('stopwords.txt') 8 | def remove_stop(sentence): 9 | 10 | words = lemmatize_sentence(sentence) 11 | words = [w for w in words if not w in stopwords] 12 | return ' '.join(words) 13 | 14 | 15 | def evaluate(res_file, ref_file, ref_file_add): 16 | """ 17 | :param res_file: 18 | :param ref_file: 19 | :return: 20 | """ 21 | res = load_file(res_file) 22 | 23 | multi_ref_ans = False 24 | if osp.exists(ref_file_add): 25 | add_ref = load_file(ref_file_add) 26 | multi_ref_ans = True 27 | refer = pd.read_csv(ref_file) 28 | ref_num = len(refer) 29 | group_dict = {'CW': [], 'CH': [], 'TN': [], 'TC': [], 'DL': [], 'DB':[], 'DC': [], 'DO': []} 30 | for idx in range(ref_num): 31 | sample = refer.loc[idx] 32 | qtype = sample['type'] 33 | if qtype in ['TN', 'TP']: qtype = 'TN' 34 | group_dict[qtype].append(idx) 35 | wups0 = {'CW': 0, 'CH': 0, 'TN': 0, 'TC': 0, 'DB': 0, 'DC': 0, 'DL': 0, 'DO': 0} 36 | wups9 = {'CW': 0, 'CH': 0, 'TN': 0, 'TC': 0, 'DB': 0, 'DC': 0, 'DL': 0, 'DO': 0} 37 | wups0_e, wups0_t, wups0_c = 0, 0, 0 38 | wups0_all, wups9_all = 0, 0 39 | 40 | num = {'CW': 0, 'CH': 0, 'TN': 0, 'TC': 0, 'DB': 0, 'DC': 0, 'DL': 0, 'DO': 0} 41 | over_num = {'C':0, 'T':0, 'D':0} 42 | ref_num = 0 43 | for qtype, ids in group_dict.items(): 44 | for 
id in ids: 45 | sample = refer.loc[id] 46 | video, qid, ans, qns = str(sample['video']), str(sample['qid']), str(sample['answer']), str(sample['question']) 47 | num[qtype] += 1 48 | over_num[qtype[0]] += 1 49 | ref_num += 1 50 | 51 | pred_ans_src = res[video][qid] 52 | 53 | gt_ans = remove_stop(ans) 54 | pred_ans = remove_stop(pred_ans_src) 55 | if multi_ref_ans and (video in add_ref): 56 | gt_ans_add = remove_stop(add_ref[video][qid]) 57 | if qtype == 'DC' or qtype == 'DB': 58 | cur_0 = 1 if pred_ans == gt_ans_add or pred_ans == gt_ans else 0 59 | cur_9 = cur_0 60 | else: 61 | cur_0 = max(get_wups(pred_ans, gt_ans, 0), get_wups(pred_ans, gt_ans_add, 0)) 62 | cur_9 = max(get_wups(pred_ans, gt_ans, 0.9), get_wups(pred_ans, gt_ans_add, 0.9)) 63 | else: 64 | if qtype == 'DC' or qtype == 'DB': 65 | cur_0 = 1 if pred_ans == gt_ans else 0 66 | cur_9 = cur_0 67 | else: 68 | cur_0 = get_wups(pred_ans, gt_ans, 0) 69 | cur_9 = get_wups(pred_ans, gt_ans, 0.9) 70 | wups0[qtype] += cur_0 71 | wups9[qtype] += cur_9 72 | 73 | 74 | wups0_all += wups0[qtype] 75 | wups9_all += wups9[qtype] 76 | if qtype[0] == 'C': 77 | wups0_e += wups0[qtype] 78 | if qtype[0] == 'T': 79 | wups0_t += wups0[qtype] 80 | if qtype[0] == 'D': 81 | wups0_c += wups0[qtype] 82 | 83 | wups0[qtype] = wups0[qtype]/num[qtype] 84 | wups9[qtype] = wups9[qtype]/num[qtype] 85 | 86 | num_e = over_num['C'] 87 | num_t = over_num['T'] 88 | num_c = over_num['D'] 89 | 90 | wups0_e /= num_e 91 | wups0_t /= num_t 92 | wups0_c /= num_c 93 | 94 | wups0_all /= ref_num 95 | wups9_all /= ref_num 96 | 97 | for k in wups0: 98 | wups0[k] = wups0[k] * 100 99 | wups9[k] = wups9[k] * 100 100 | 101 | wups0_e *= 100 102 | wups0_t *= 100 103 | wups0_c *= 100 104 | wups0_all *= 100 105 | 106 | print('CW\tCH\tWUPS_C\tTPN\tTC\tWUPS_T\tDB\tDC\tDL\tDO\tWUPS_D\tWUPS') 107 | print('{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}\t{:.2f}' 108 | .format(wups0['CW'], wups0['CH'], wups0_e, wups0['TN'], wups0['TC'],wups0_t, 109 | wups0['DB'],wups0['DC'], wups0['DL'], wups0['DO'], wups0_c, wups0_all)) 110 | 111 | 112 | 113 | def main(filename, mode): 114 | res_dir = 'results' 115 | res_file = osp.join(res_dir, filename) 116 | print(f'Evaluate on {res_file}') 117 | ref_file = 'dataset/nextqa/{}.csv'.format(mode) 118 | ref_file_add = 'dataset/nextqa/add_reference_answer_{}.json'.format(mode) 119 | evaluate(res_file, ref_file, ref_file_add) 120 | 121 | 122 | if __name__ == "__main__": 123 | 124 | mode = 'val' 125 | model = 'HGA' 126 | result_file = '{}-same-att-qns23ans7-{}-example.json'.format(model, mode) 127 | main(result_file, mode) 128 | -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doc-doc/NExT-OE/a83f8f581191da07675e0fc83074e0dfcf907273/images/logo.png -------------------------------------------------------------------------------- /images/res-mc-oe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/doc-doc/NExT-OE/a83f8f581191da07675e0fc83074e0dfcf907273/images/res-mc-oe.png -------------------------------------------------------------------------------- /main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | GPU=$1 3 | MODE=$2 4 | CUDA_VISIBLE_DEVICES=$GPU python main_qa.py \ 5 | --mode $MODE 6 | 
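# Minimal usage sketch (illustrative, not shipped with the repo): besides the
# shell wrapper above, the evaluation entry point can be called directly from
# Python; main_qa.py (next file) does exactly this after writing a prediction
# file. Assumes the released HGA predictions and dataset/nextqa/val.csv are in
# place and that this is run from the repository root with requirements installed.
import eval_oe

# eval_oe.main(filename, mode) loads results/<filename> and scores it against
# dataset/nextqa/<mode>.csv (plus add_reference_answer_<mode>.json if present).
eval_oe.main('HGA-same-att-qns23ans7-val.json', 'val')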
-------------------------------------------------------------------------------- /main_qa.py: -------------------------------------------------------------------------------- 1 | from videoqa import * 2 | import dataloader 3 | from build_vocab import Vocabulary 4 | from utils import * 5 | import argparse 6 | import eval_oe 7 | 8 | 9 | def main(args): 10 | 11 | mode = args.mode 12 | if mode == 'train': 13 | batch_size = 64 14 | num_worker = 8 15 | else: 16 | batch_size = 64 17 | num_worker = 8 18 | spatial = False 19 | if spatial: 20 | #for STVQA 21 | video_feature_path = '../data/feats/spatial/' 22 | video_feature_cache = '../data/feats/cache_spatial/' 23 | else: 24 | video_feature_path = '../data/feats/' 25 | video_feature_cache = '../data/feats/cache/' 26 | 27 | dataset = 'nextqa' 28 | sample_list_path = 'dataset/{}/'.format(dataset) 29 | 30 | #We separate the dicts for qns and ans, in case one wants to use different word-dicts for them. 31 | vocab_qns = pkload('dataset/{}/vocab.pkl'.format(dataset)) 32 | vocab_ans = pkload('dataset/{}/vocab.pkl'.format(dataset)) 33 | 34 | word_type = 'glove' 35 | glove_embed_qns = 'dataset/{}/{}_embed.npy'.format(dataset, word_type) 36 | glove_embed_ans = 'dataset/{}/{}_embed.npy'.format(dataset, word_type) 37 | checkpoint_path = 'models' 38 | model_type = 'HGA' 39 | 40 | model_prefix = 'same-att-qns23ans7' 41 | vis_step = 116 42 | lr_rate = 5e-5 43 | epoch_num = 100 44 | 45 | data_loader = dataloader.QALoader(batch_size, num_worker, video_feature_path, video_feature_cache, 46 | sample_list_path, vocab_qns, vocab_ans, True, False) 47 | 48 | train_loader, val_loader = data_loader.run(mode=mode) 49 | 50 | vqa = VideoQA(vocab_qns, vocab_ans, train_loader, val_loader, glove_embed_qns, glove_embed_ans, 51 | checkpoint_path, model_type, model_prefix, vis_step,lr_rate, batch_size, epoch_num) 52 | 53 | ep = 36 54 | acc = 0.2163 55 | model_file = f'{model_type}-{model_prefix}-{ep}-{acc:.4f}.ckpt' 56 | 57 | if mode != 'train': 58 | result_file = f'{model_type}-{model_prefix}-{mode}.json' 59 | vqa.predict(model_file, result_file) 60 | eval_oe.main(result_file, mode) 61 | else: 62 | model_file = f'{model_type}-{model_prefix}-44-0.2140.ckpt' 63 | vqa.run(model_file, pre_trained=False) 64 | 65 | 66 | if __name__ == "__main__": 67 | torch.backends.cudnn.enabled = False 68 | torch.manual_seed(666) 69 | torch.cuda.manual_seed(666) 70 | torch.backends.cudnn.benchmark = True 71 | 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument('--gpu', dest='gpu', type=int, 74 | default=0, help='gpu device id') 75 | parser.add_argument('--mode', dest='mode', type=str, 76 | default='train', help='train or val') 77 | args = parser.parse_args() 78 | set_gpu_devices(args.gpu) 79 | main(args) 80 | -------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | # ==================================================== 2 | # @Time : 13/9/20 4:19 PM 3 | # @Author : Xiao Junbin 4 | # @Email : junbin@comp.nus.edu.sg 5 | # @File : metrics.py 6 | # ==================================================== 7 | from nltk.tokenize import word_tokenize 8 | from nltk.corpus import wordnet 9 | import numpy as np 10 | 11 | def wup(word1, word2, alpha): 12 | """ 13 | calculate the wup similarity 14 | :param word1: 15 | :param word2: 16 | :param alpha: 17 | :return: 18 | """ 19 | # print(word1, word2) 20 | if word1 == word2: 21 | return 1.0 22 | 23 | w1 = wordnet.synsets(word1) 24 | 
w1_len = len(w1) 25 | if w1_len == 0: return 0.0 26 | w2 = wordnet.synsets(word2) 27 | w2_len = len(w2) 28 | if w2_len == 0: return 0.0 29 | 30 | #match the first 31 | word_sim = w1[0].wup_similarity(w2[0]) 32 | if word_sim is None: 33 | word_sim = 0.0 34 | 35 | if word_sim < alpha: 36 | word_sim = 0.1*word_sim 37 | return word_sim 38 | 39 | 40 | def wups(words1, words2, alpha): 41 | """ 42 | 43 | :param pred: 44 | :param truth: 45 | :param alpha: 46 | :return: 47 | """ 48 | sim = 1.0 49 | flag = False 50 | for w1 in words1: 51 | max_sim = 0 52 | for w2 in words2: 53 | word_sim = wup(w1, w2, alpha) 54 | if word_sim > max_sim: 55 | max_sim = word_sim 56 | if max_sim == 0: continue 57 | sim *= max_sim 58 | flag = True 59 | if not flag: 60 | sim = 0.0 61 | return sim 62 | 63 | 64 | def get_wups(pred, truth, alpha): 65 | """ 66 | calculate the wups score 67 | :param pred: 68 | :param truth: 69 | :return: 70 | """ 71 | pred = word_tokenize(pred) 72 | truth = word_tokenize(truth) 73 | item1 = wups(pred, truth, alpha) 74 | item2 = wups(truth, pred, alpha) 75 | value = min(item1, item2) 76 | return value -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all except .gitignore file 2 | * 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /networks/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | -------------------------------------------------------------------------------- /networks/Attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class TempAttention(nn.Module): 7 | """ 8 | Applies an attention mechanism on the output features from the decoder. 
9 | """ 10 | 11 | def __init__(self, text_dim, visual_dim, hidden_dim): 12 | super(TempAttention, self).__init__() 13 | self.hidden_dim = hidden_dim 14 | self.linear_text = nn.Linear(text_dim, hidden_dim) 15 | self.linear_visual = nn.Linear(visual_dim, hidden_dim) 16 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False) 17 | self._init_weight() 18 | 19 | def _init_weight(self): 20 | nn.init.xavier_normal_(self.linear_text.weight) 21 | nn.init.xavier_normal_(self.linear_visual.weight) 22 | nn.init.xavier_normal_(self.linear_att.weight) 23 | 24 | def forward(self, qns_embed, vid_outputs): 25 | """ 26 | Arguments: 27 | qns_embed {Variable} -- batch_size x dim 28 | vid_outputs {Variable} -- batch_size x seq_len x dim 29 | 30 | Returns: 31 | context -- context vector of size batch_size x dim 32 | """ 33 | qns_embed_trans = self.linear_text(qns_embed) 34 | 35 | batch_size, seq_len, visual_dim = vid_outputs.size() 36 | vid_outputs_temp = vid_outputs.contiguous().view(batch_size*seq_len, visual_dim) 37 | vid_outputs_trans = self.linear_visual(vid_outputs_temp) 38 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.hidden_dim) 39 | 40 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1) 41 | 42 | 43 | o = self.linear_att(torch.tanh(qns_embed_trans+vid_outputs_trans)) 44 | 45 | e = o.view(batch_size, seq_len) 46 | beta = F.softmax(e, dim=1) 47 | context = torch.bmm(beta.unsqueeze(1), vid_outputs).squeeze(1) 48 | 49 | return context, beta 50 | 51 | 52 | class SpatialAttention(nn.Module): 53 | """ 54 | Apply spatial attention on vid feature before being fed into LSTM 55 | """ 56 | 57 | def __init__(self, text_dim=1024, vid_dim=3072, hidden_dim=512, input_dropout_p=0.2): 58 | super(SpatialAttention, self).__init__() 59 | 60 | self.linear_v = nn.Linear(vid_dim, hidden_dim) 61 | self.linear_q = nn.Linear(text_dim, hidden_dim) 62 | self.linear_att = nn.Linear(hidden_dim, 1, bias=False) 63 | 64 | self.softmax = nn.Softmax(dim=1) 65 | self.dropout = nn.Dropout(input_dropout_p) 66 | self._init_weight() 67 | 68 | def _init_weight(self): 69 | nn.init.xavier_normal_(self.linear_v.weight) 70 | nn.init.xavier_normal_(self.linear_q.weight) 71 | nn.init.xavier_normal_(self.linear_att.weight) 72 | 73 | def forward(self, qns_feat, vid_feats): 74 | """ 75 | Apply question feature as semantic clue to guide feature aggregation at each frame 76 | :param vid_feats: fnum x feat_dim x 7 x 7 77 | :param qns_feat: dim_hidden*2 78 | :return: 79 | """ 80 | # print(qns_feat.size(), vid_feats.size()) 81 | # permute to fnum x 7 x 7 x feat_dim 82 | vid_feats = vid_feats.permute(0, 2, 3, 1) 83 | fnum, width, height, feat_dim = vid_feats.size() 84 | vid_feats = vid_feats.contiguous().view(-1, feat_dim) 85 | vid_feats_trans = self.linear_v(vid_feats) 86 | 87 | vid_feats_trans = vid_feats_trans.view(fnum, width*height, -1) 88 | region_num = vid_feats_trans.shape[1] 89 | 90 | qns_feat_trans = self.linear_q(qns_feat) 91 | 92 | qns_feat_trans = qns_feat_trans.repeat(fnum, region_num, 1) 93 | # print(vid_feats_trans.shape, qns_feat_trans.shape) 94 | 95 | vid_qns = self.linear_att(torch.tanh(vid_feats_trans + qns_feat_trans)) 96 | 97 | vid_qns_o = vid_qns.view(fnum, region_num) 98 | alpha = self.softmax(vid_qns_o) 99 | alpha = alpha.unsqueeze(1) 100 | vid_feats = vid_feats.view(fnum, region_num, -1) 101 | feature = torch.bmm(alpha, vid_feats).squeeze(1) 102 | feature = self.dropout(feature) 103 | # print(feature.size()) 104 | return feature, alpha 105 | 106 | 107 | class TempAttentionHis(nn.Module): 
108 | """ 109 | Applies an attention mechanism on the output features from the decoder. 110 | """ 111 | 112 | def __init__(self, visual_dim, text_dim, his_dim, mem_dim): 113 | super(TempAttentionHis, self).__init__() 114 | # self.dim = dim 115 | self.mem_dim = mem_dim 116 | self.linear_v = nn.Linear(visual_dim, self.mem_dim, bias=False) 117 | self.linear_q = nn.Linear(text_dim, self.mem_dim, bias=False) 118 | self.linear_his1 = nn.Linear(his_dim, self.mem_dim, bias=False) 119 | self.linear_his2 = nn.Linear(his_dim, self.mem_dim, bias=False) 120 | self.linear_att = nn.Linear(self.mem_dim, 1, bias=False) 121 | self._init_weight() 122 | 123 | 124 | def _init_weight(self): 125 | nn.init.xavier_normal_(self.linear_v.weight) 126 | nn.init.xavier_normal_(self.linear_q.weight) 127 | nn.init.xavier_normal_(self.linear_his1.weight) 128 | nn.init.xavier_normal_(self.linear_his2.weight) 129 | nn.init.xavier_normal_(self.linear_att.weight) 130 | 131 | 132 | def forward(self, qns_embed, vid_outputs, his): 133 | """ 134 | :param qns_embed: batch_size x 1024 135 | :param vid_outputs: batch_size x seq_num x feat_dim 136 | :param his: batch_size x 512 137 | :return: 138 | """ 139 | 140 | batch_size, seq_len, feat_dim = vid_outputs.size() 141 | vid_outputs_trans = self.linear_v(vid_outputs.contiguous().view(batch_size * seq_len, feat_dim)) 142 | vid_outputs_trans = vid_outputs_trans.view(batch_size, seq_len, self.mem_dim) 143 | 144 | qns_embed_trans = self.linear_q(qns_embed) 145 | qns_embed_trans = qns_embed_trans.unsqueeze(1).repeat(1, seq_len, 1) 146 | 147 | 148 | his_trans = self.linear_his1(his) 149 | his_trans = his_trans.unsqueeze(1).repeat(1, seq_len, 1) 150 | 151 | o = self.linear_att(torch.tanh(qns_embed_trans + vid_outputs_trans + his_trans)) 152 | 153 | e = o.view(batch_size, seq_len) 154 | beta = F.softmax(e, dim=1) 155 | context = torch.bmm(beta.unsqueeze(1), vid_outputs_trans).squeeze(1) 156 | 157 | his_acc = torch.tanh(self.linear_his2(his)) 158 | 159 | context += his_acc 160 | 161 | return context, beta 162 | 163 | 164 | class MultiModalAttentionModule(nn.Module): 165 | 166 | def __init__(self, hidden_size=512, simple=False): 167 | """Set the hyper-parameters and build the layers.""" 168 | super(MultiModalAttentionModule, self).__init__() 169 | 170 | self.hidden_size = hidden_size 171 | self.simple = simple 172 | 173 | # alignment model 174 | # see appendices A.1.2 of neural machine translation 175 | 176 | self.Wav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 177 | self.Wat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 178 | self.Uav = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 179 | self.Uat = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 180 | self.Vav = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 181 | self.Vat = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 182 | self.bav = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True) 183 | self.bat = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True) 184 | 185 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 186 | self.Wvh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 187 | self.Wth = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 188 | self.bh = nn.Parameter(torch.FloatTensor(1, 1, hidden_size), requires_grad=True) 189 | 190 | 
self.video_sum_encoder = nn.Linear(hidden_size, hidden_size) 191 | self.question_sum_encoder = nn.Linear(hidden_size, hidden_size) 192 | 193 | self.Wb = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 194 | self.Vbv = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 195 | self.Vbt = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size), requires_grad=True) 196 | self.bbv = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 197 | self.bbt = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 198 | self.wb = nn.Parameter(torch.FloatTensor(hidden_size), requires_grad=True) 199 | self.init_weights() 200 | 201 | def init_weights(self): 202 | self.Wav.data.normal_(0.0, 0.1) 203 | self.Wat.data.normal_(0.0, 0.1) 204 | self.Uav.data.normal_(0.0, 0.1) 205 | self.Uat.data.normal_(0.0, 0.1) 206 | self.Vav.data.normal_(0.0, 0.1) 207 | self.Vat.data.normal_(0.0, 0.1) 208 | self.bav.data.fill_(0) 209 | self.bat.data.fill_(0) 210 | 211 | self.Whh.data.normal_(0.0, 0.1) 212 | self.Wvh.data.normal_(0.0, 0.1) 213 | self.Wth.data.normal_(0.0, 0.1) 214 | self.bh.data.fill_(0) 215 | 216 | self.Wb.data.normal_(0.0, 0.01) 217 | self.Vbv.data.normal_(0.0, 0.01) 218 | self.Vbt.data.normal_(0.0, 0.01) 219 | self.wb.data.normal_(0.0, 0.01) 220 | 221 | self.bbv.data.fill_(0) 222 | self.bbt.data.fill_(0) 223 | 224 | def forward(self, h, hidden_frames, hidden_text, inv_attention=False): 225 | # print self.Uav 226 | # hidden_text: 1 x T1 x 1024 (looks like a two layer one-directional LSTM, combining each layer's hidden) 227 | # hidden_frame: 1 x T2 x 1024 (from video encoder output, 1024 is similar from above) 228 | 229 | # print hidden_frames.size(),hidden_text.size() 230 | Uhv = torch.matmul(h, self.Uav) # (1,512) 231 | Uhv = Uhv.view(Uhv.size(0), 1, Uhv.size(1)) # (1,1,512) 232 | 233 | Uht = torch.matmul(h, self.Uat) # (1,512) 234 | Uht = Uht.view(Uht.size(0), 1, Uht.size(1)) # (1,1,512) 235 | 236 | # print Uhv.size(),Uht.size() 237 | 238 | Wsv = torch.matmul(hidden_frames, self.Wav) # (1,T,512) 239 | # print Wsv.size() 240 | att_vec_v = torch.matmul(torch.tanh(Wsv + Uhv + self.bav), self.Vav) 241 | 242 | Wst = torch.matmul(hidden_text, self.Wat) # (1,T,512) 243 | att_vec_t = torch.matmul(torch.tanh(Wst + Uht + self.bat), self.Vat) 244 | 245 | if inv_attention == True: 246 | att_vec_v = -att_vec_v 247 | att_vec_t = -att_vec_t 248 | 249 | att_vec_v = torch.softmax(att_vec_v, dim=1) 250 | att_vec_t = torch.softmax(att_vec_t, dim=1) 251 | 252 | att_vec_v = att_vec_v.view(att_vec_v.size(0), att_vec_v.size(1), 1) # expand att_vec from 1xT to 1xTx1 253 | att_vec_t = att_vec_t.view(att_vec_t.size(0), att_vec_t.size(1), 1) # expand att_vec from 1xT to 1xTx1 254 | 255 | hv_weighted = att_vec_v * hidden_frames 256 | hv_sum = torch.sum(hv_weighted, dim=1) 257 | hv_sum2 = self.video_sum_encoder(hv_sum) 258 | 259 | ht_weighted = att_vec_t * hidden_text 260 | ht_sum = torch.sum(ht_weighted, dim=1) 261 | ht_sum2 = self.question_sum_encoder(ht_sum) 262 | 263 | Wbs = torch.matmul(h, self.Wb) 264 | mt1 = torch.matmul(ht_sum, self.Vbt) + self.bbt + Wbs 265 | mv1 = torch.matmul(hv_sum, self.Vbv) + self.bbv + Wbs 266 | mtv = torch.tanh(torch.cat([mv1, mt1], dim=0)) 267 | mtv2 = torch.matmul(mtv, self.wb) 268 | beta = torch.softmax(mtv2, dim=0) 269 | 270 | output = torch.tanh(torch.matmul(h, self.Whh) + beta[0] * hv_sum2 + 271 | beta[1] * ht_sum2 + self.bh) 272 | output = output.view(output.size(1), output.size(2)) 273 | 274 | return output 
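# Minimal usage sketch for the TempAttention module defined above (illustrative
# only; the dimensions and batch size are arbitrary, not the values used by the
# released models). Run from the repository root so that `networks` is importable.
import torch
from networks.Attention import TempAttention

att = TempAttention(text_dim=1024, visual_dim=512, hidden_dim=256)
qns_embed = torch.randn(4, 1024)       # batch_size x text_dim
vid_outputs = torch.randn(4, 16, 512)  # batch_size x seq_len x visual_dim

# Returns a question-conditioned context vector over the video sequence and the
# temporal attention weights.
context, beta = att(qns_embed, vid_outputs)
print(context.shape)  # torch.Size([4, 512])
print(beta.shape)     # torch.Size([4, 16])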
-------------------------------------------------------------------------------- /networks/DecoderRNN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from .Attention import TempAttention 5 | 6 | 7 | class AnsUATT(nn.Module): 8 | """ 9 | """ 10 | def __init__(self, 11 | vocab_size, 12 | max_len, 13 | dim_hidden, 14 | dim_word, 15 | glove_embed, 16 | n_layers=1, 17 | rnn_cell='gru', 18 | bidirectional=False, 19 | input_dropout_p=0.2, 20 | rnn_dropout_p=0): 21 | super(AnsUATT, self).__init__() 22 | 23 | self.bidirectional_encoder = bidirectional 24 | self.vocab_size = vocab_size 25 | 26 | 27 | self.dim_hidden = dim_hidden 28 | self.dim_word = dim_word 29 | self.max_length = max_len 30 | self.glove_embed = glove_embed 31 | 32 | self.input_dropout = nn.Dropout(input_dropout_p) 33 | self.embedding = nn.Embedding(vocab_size, dim_word) 34 | 35 | if rnn_cell.lower() == 'lstm': 36 | self.rnn_cell = nn.LSTM 37 | elif rnn_cell.lower() == 'gru': 38 | self.rnn_cell = nn.GRU 39 | self.rnn = self.rnn_cell(self.dim_word, self.dim_hidden, n_layers, 40 | batch_first=True, dropout=rnn_dropout_p) 41 | 42 | self.out = nn.Linear(self.dim_hidden, vocab_size) 43 | 44 | self._init_weights() 45 | 46 | 47 | def forward(self, hidden, ans): 48 | """ 49 | decode answer 50 | :param encoder_outputs: 51 | :param encoder_hidden: 52 | :param ans: 53 | :param t: 54 | :return: 55 | """ 56 | ans_embed = self.embedding(ans) 57 | ans_embed = self.input_dropout(ans_embed) 58 | # ans_embed = self.transform(ans_embed) 59 | ans_embed = ans_embed.unsqueeze(1) 60 | decoder_input = ans_embed 61 | outputs, hidden = self.rnn(decoder_input, hidden) 62 | final_outputs = self.out(outputs.squeeze(1)) 63 | return final_outputs, hidden 64 | 65 | 66 | def sample(self, hidden, start=None): 67 | """ 68 | 69 | :param encoder_outputs: 70 | :param encoder_hidden: 71 | :return: 72 | """ 73 | sample_ids = [] 74 | start_embed = self.embedding(start) 75 | start_embed = self.input_dropout(start_embed) 76 | # start_embed = self.transform(start_embed) 77 | start_embed = start_embed.unsqueeze(1) 78 | 79 | inputs = start_embed 80 | 81 | for i in range(self.max_length): 82 | outputs, hidden = self.rnn(inputs, hidden) 83 | outputs = self.out(outputs.squeeze(1)) 84 | _, predict = outputs.max(1) 85 | sample_ids.append(predict) 86 | ans_embed = self.embedding(predict) 87 | ans_embed = self.input_dropout(ans_embed) 88 | # ans_embed = self.transform(ans_embed) 89 | 90 | inputs = ans_embed.unsqueeze(1) 91 | 92 | sample_ids = torch.stack(sample_ids, 1) 93 | return sample_ids 94 | 95 | 96 | def _init_weights(self): 97 | """ init the weight of some layers 98 | """ 99 | nn.init.xavier_normal_(self.out.weight) 100 | glove_embed = np.load(self.glove_embed) 101 | self.embedding.weight = nn.Parameter(torch.FloatTensor(glove_embed)) 102 | 103 | 104 | 105 | class AnsHME(nn.Module): 106 | """ 107 | """ 108 | 109 | def __init__(self, 110 | vocab_size, 111 | max_len, 112 | dim_hidden, 113 | dim_word, 114 | glove_embed, 115 | n_layers=1, 116 | rnn_cell='gru', 117 | bidirectional=False, 118 | input_dropout_p=0.2, 119 | rnn_dropout_p=0): 120 | super(AnsHME, self).__init__() 121 | 122 | self.bidirectional_encoder = bidirectional 123 | self.vocab_size = vocab_size 124 | 125 | 126 | self.dim_hidden = dim_hidden 127 | self.dim_word = dim_word 128 | self.max_length = max_len 129 | self.glove_embed = glove_embed 130 | 131 | self.input_dropout = nn.Dropout(input_dropout_p) 132 | 
self.embedding = nn.Embedding(vocab_size, dim_word) 133 | word_mat = torch.FloatTensor(np.load(self.glove_embed)) 134 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False) 135 | 136 | self.transform = nn.Sequential(nn.Linear(dim_hidden*2+dim_word, dim_hidden), 137 | nn.Dropout(input_dropout_p)) 138 | 139 | if rnn_cell.lower() == 'lstm': 140 | self.rnn_cell = nn.LSTM 141 | elif rnn_cell.lower() == 'gru': 142 | self.rnn_cell = nn.GRU 143 | self.rnn = self.rnn_cell(self.dim_hidden, self.dim_hidden, n_layers, 144 | batch_first=True, dropout=rnn_dropout_p) 145 | 146 | self.out = nn.Linear(self.dim_hidden, vocab_size) 147 | 148 | self._init_weights() 149 | 150 | 151 | def forward(self, encoder_outputs, hidden, ans_idx): 152 | """ 153 | decode answer 154 | :param encoder_outputs: 155 | :param encoder_hidden: 156 | :param ans: 157 | :param t: 158 | :return: 159 | """ 160 | ans_embed = self.embedding(ans_idx) 161 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1) 162 | ans_embed = self.transform(ans_embed) 163 | ans_embed = ans_embed.unsqueeze(1) 164 | decoder_input = ans_embed 165 | outputs, hidden = self.rnn(decoder_input, hidden) 166 | final_outputs = self.out(outputs.squeeze(1)) 167 | return final_outputs, hidden 168 | 169 | 170 | def sample(self, encoder_outputs, hidden, start=None): 171 | """ 172 | 173 | :param encoder_outputs: 174 | :param encoder_hidden: 175 | :return: 176 | """ 177 | sample_ids = [] 178 | start_embed = self.embedding(start) 179 | start_embed = torch.cat((encoder_outputs, start_embed), dim=-1) 180 | start_embed = self.transform(start_embed) 181 | start_embed = start_embed.unsqueeze(1) 182 | inputs = start_embed 183 | 184 | for i in range(self.max_length): 185 | outputs, hidden = self.rnn(inputs, hidden) 186 | outputs = self.out(outputs.squeeze(1)) 187 | _, predict = outputs.max(1) 188 | sample_ids.append(predict) 189 | ans_embed = self.embedding(predict) 190 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1) 191 | ans_embed = self.transform(ans_embed) 192 | inputs = ans_embed.unsqueeze(1) 193 | 194 | sample_ids = torch.stack(sample_ids, 1) 195 | return sample_ids 196 | 197 | 198 | def _init_weights(self): 199 | """ init the weight of some layers 200 | """ 201 | nn.init.xavier_normal_(self.out.weight) 202 | nn.init.xavier_normal_(self.transform[0].weight) 203 | 204 | 205 | class AnsQnsAns(nn.Module): 206 | """ 207 | """ 208 | 209 | def __init__(self, 210 | vocab_size, 211 | max_len, 212 | dim_hidden, 213 | dim_word, 214 | glove_embed, 215 | n_layers=1, 216 | rnn_cell='gru', 217 | bidirectional=False, 218 | input_dropout_p=0.2, 219 | rnn_dropout_p=0): 220 | super(AnsQnsAns, self).__init__() 221 | 222 | self.bidirectional_encoder = bidirectional 223 | self.vocab_size = vocab_size 224 | 225 | 226 | self.dim_hidden = dim_hidden 227 | self.dim_word = dim_word 228 | self.max_length = max_len 229 | self.glove_embed = glove_embed 230 | 231 | self.input_dropout = nn.Dropout(input_dropout_p) 232 | self.embedding = nn.Embedding(vocab_size, dim_word) 233 | word_mat = torch.FloatTensor(np.load(self.glove_embed)) 234 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False) 235 | 236 | self.transform = nn.Sequential(nn.Linear(dim_hidden+dim_word, dim_hidden), 237 | nn.Dropout(input_dropout_p)) 238 | 239 | if rnn_cell.lower() == 'lstm': 240 | self.rnn_cell = nn.LSTM 241 | elif rnn_cell.lower() == 'gru': 242 | self.rnn_cell = nn.GRU 243 | self.rnn = self.rnn_cell(self.dim_hidden, self.dim_hidden, n_layers, 244 | batch_first=True, 
dropout=rnn_dropout_p) 245 | 246 | self.out = nn.Linear(self.dim_hidden, vocab_size) 247 | 248 | self._init_weights() 249 | 250 | 251 | def forward(self, encoder_outputs, hidden, ans_idx): 252 | """ 253 | :param encoder_outputs: 254 | :param encoder_hidden: 255 | :param ans: 256 | :param t: 257 | :return: 258 | """ 259 | ans_embed = self.embedding(ans_idx) 260 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1) 261 | ans_embed = self.transform(ans_embed) 262 | ans_embed = ans_embed.unsqueeze(1) 263 | decoder_input = ans_embed 264 | outputs, hidden = self.rnn(decoder_input, hidden) 265 | final_outputs = self.out(outputs.squeeze(1)) 266 | return final_outputs, hidden 267 | 268 | 269 | def sample(self, encoder_outputs, hidden, start=None): 270 | """ 271 | 272 | :param encoder_outputs: 273 | :param encoder_hidden: 274 | :return: 275 | """ 276 | sample_ids = [] 277 | start_embed = self.embedding(start) 278 | start_embed = torch.cat((encoder_outputs, start_embed), dim=-1) 279 | start_embed = self.transform(start_embed) 280 | start_embed = start_embed.unsqueeze(1) 281 | inputs = start_embed 282 | 283 | for i in range(self.max_length): 284 | outputs, hidden = self.rnn(inputs, hidden) 285 | outputs = self.out(outputs.squeeze(1)) 286 | _, predict = outputs.max(1) 287 | sample_ids.append(predict) 288 | ans_embed = self.embedding(predict) 289 | ans_embed = torch.cat((encoder_outputs, ans_embed), dim=-1) 290 | ans_embed = self.transform(ans_embed) 291 | inputs = ans_embed.unsqueeze(1) 292 | 293 | sample_ids = torch.stack(sample_ids, 1) 294 | return sample_ids 295 | 296 | 297 | def _init_weights(self): 298 | """ init the weight of some layers 299 | """ 300 | nn.init.xavier_normal_(self.out.weight) 301 | nn.init.xavier_normal_(self.transform[0].weight) 302 | 303 | 304 | class AnsAttSeq(nn.Module): 305 | """ 306 | 307 | """ 308 | 309 | def __init__(self, 310 | vocab_size, 311 | max_len, 312 | dim_hidden, 313 | dim_word, 314 | glove_embed, 315 | n_layers=1, 316 | rnn_cell='gru', 317 | bidirectional=False, 318 | input_dropout_p=0.2, 319 | rnn_dropout_p=0): 320 | super(AnsAttSeq, self).__init__() 321 | 322 | self.bidirectional_encoder = bidirectional 323 | self.vocab_size = vocab_size 324 | 325 | 326 | self.dim_hidden = dim_hidden 327 | self.dim_word = dim_word 328 | self.max_length = max_len 329 | self.glove_embed = glove_embed 330 | 331 | self.input_dropout = nn.Dropout(input_dropout_p) 332 | self.embedding = nn.Embedding(vocab_size, dim_word) 333 | word_mat = torch.FloatTensor(np.load(self.glove_embed)) 334 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False) 335 | 336 | self.temp = TempAttention(dim_word, dim_hidden, dim_hidden//2) 337 | 338 | self.transform = nn.Sequential(nn.Linear(dim_word+dim_hidden, dim_hidden), 339 | nn.Dropout(input_dropout_p)) 340 | 341 | if rnn_cell.lower() == 'lstm': 342 | self.rnn_cell = nn.LSTM 343 | elif rnn_cell.lower() == 'gru': 344 | self.rnn_cell = nn.GRU 345 | self.rnn = self.rnn_cell(self.dim_hidden, self.dim_hidden, n_layers, 346 | batch_first=True, dropout=rnn_dropout_p) 347 | 348 | self.out = nn.Linear(self.dim_hidden, vocab_size) 349 | 350 | self._init_weights() 351 | 352 | 353 | def forward(self, seq_outs, hidden, ans_idx): 354 | """ 355 | decode answer 356 | :param encoder_outputs: 357 | :param encoder_hidden: 358 | :param ans: 359 | :param t: 360 | :return: 361 | """ 362 | ans_embed = self.embedding(ans_idx) 363 | ans_embed_att, _ = self.temp(ans_embed, seq_outs) 364 | ans_embed = torch.cat((ans_embed, ans_embed_att), dim=-1) 365 | 
366 | ans_embed = self.transform(ans_embed) 367 | ans_embed = ans_embed.unsqueeze(1) 368 | decoder_input = ans_embed 369 | outputs, hidden = self.rnn(decoder_input, hidden) 370 | final_outputs = self.out(outputs.squeeze(1)) 371 | return final_outputs, hidden 372 | 373 | 374 | def sample(self, seq_outs, hidden, start=None): 375 | """ 376 | 377 | :param encoder_outputs: 378 | :param encoder_hidden: 379 | :return: 380 | """ 381 | sample_ids = [] 382 | start_embed = self.embedding(start) 383 | start_embed_att, _ = self.temp(start_embed, seq_outs) 384 | start_embed = torch.cat((start_embed, start_embed_att), dim=-1) 385 | start_embed = self.transform(start_embed) 386 | start_embed = start_embed.unsqueeze(1) 387 | inputs = start_embed 388 | 389 | for i in range(self.max_length): 390 | outputs, hidden = self.rnn(inputs, hidden) 391 | outputs = self.out(outputs.squeeze(1)) 392 | _, predict = outputs.max(1) 393 | sample_ids.append(predict) 394 | ans_embed = self.embedding(predict) 395 | ans_embed_att, _ = self.temp(ans_embed, seq_outs) 396 | ans_embed = torch.cat((ans_embed, ans_embed_att), dim=-1) 397 | ans_embed = self.transform(ans_embed) 398 | inputs = ans_embed.unsqueeze(1) 399 | 400 | sample_ids = torch.stack(sample_ids, 1) 401 | return sample_ids 402 | 403 | 404 | def _init_weights(self): 405 | """ init the weight of some layers 406 | """ 407 | nn.init.xavier_normal_(self.out.weight) 408 | nn.init.xavier_normal_(self.transform[0].weight) 409 | 410 | 411 | class AnsNavieTrans(nn.Module): 412 | """ 413 | """ 414 | def __init__(self, 415 | vocab_size, 416 | max_len, 417 | dim_hidden, 418 | dim_word, 419 | glove_embed, 420 | n_layers=1, 421 | rnn_cell='gru', 422 | bidirectional=False, 423 | input_dropout_p=0.2, 424 | rnn_dropout_p=0): 425 | super(AnsNavieTrans, self).__init__() 426 | 427 | self.bidirectional_encoder = bidirectional 428 | self.vocab_size = vocab_size 429 | 430 | 431 | self.dim_hidden = dim_hidden 432 | self.dim_word = dim_word 433 | self.max_length = max_len 434 | self.glove_embed = glove_embed 435 | 436 | self.input_dropout = nn.Dropout(input_dropout_p) 437 | self.embedding = nn.Embedding(vocab_size, dim_word) 438 | word_mat = torch.FloatTensor(np.load(self.glove_embed)) 439 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False) 440 | 441 | self.transform = nn.Sequential(nn.Linear(dim_word, dim_hidden), 442 | nn.Dropout(input_dropout_p)) 443 | 444 | if rnn_cell.lower() == 'lstm': 445 | self.rnn_cell = nn.LSTM 446 | elif rnn_cell.lower() == 'gru': 447 | self.rnn_cell = nn.GRU 448 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers, 449 | batch_first=True, dropout=rnn_dropout_p) 450 | 451 | self.out = nn.Linear(self.dim_hidden, vocab_size) 452 | 453 | self._init_weights() 454 | 455 | 456 | def forward(self, hidden, ans_idx): 457 | """ 458 | decode answer 459 | :param encoder_outputs: 460 | :param encoder_hidden: 461 | :param ans: 462 | :param t: 463 | :return: 464 | """ 465 | ans_embed = self.embedding(ans_idx) 466 | ans_embed = self.transform(ans_embed) 467 | ans_embed = ans_embed.unsqueeze(1) 468 | decoder_input = ans_embed 469 | outputs, hidden = self.rnn(decoder_input, hidden) 470 | final_outputs = self.out(outputs.squeeze(1)) 471 | return final_outputs, hidden 472 | 473 | 474 | def sample(self, hidden, start=None): 475 | """ 476 | :param encoder_outputs: 477 | :param encoder_hidden: 478 | :return: 479 | """ 480 | sample_ids = [] 481 | start_embed = self.embedding(start) 482 | start_embed = self.transform(start_embed) 483 | start_embed = 
start_embed.unsqueeze(1) 484 | inputs = start_embed 485 | 486 | for i in range(self.max_length): 487 | outputs, hidden = self.rnn(inputs, hidden) 488 | outputs = self.out(outputs.squeeze(1)) 489 | _, predict = outputs.max(1) 490 | sample_ids.append(predict) 491 | ans_embed = self.embedding(predict) 492 | ans_embed = self.transform(ans_embed) 493 | inputs = ans_embed.unsqueeze(1) 494 | 495 | sample_ids = torch.stack(sample_ids, 1) 496 | return sample_ids 497 | 498 | 499 | def _init_weights(self): 500 | """ init the weight of some layers 501 | """ 502 | nn.init.xavier_normal_(self.out.weight) 503 | nn.init.xavier_normal_(self.transform[0].weight) 504 | 505 | 506 | -------------------------------------------------------------------------------- /networks/EncoderRNN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence 4 | import numpy as np 5 | 6 | class EncoderQns(nn.Module): 7 | def __init__(self, dim_embed, dim_hidden, vocab_size, glove_embed, input_dropout_p=0.2, rnn_dropout_p=0, 8 | n_layers=1, bidirectional=False, rnn_cell='gru'): 9 | """ 10 | """ 11 | super(EncoderQns, self).__init__() 12 | self.dim_hidden = dim_hidden 13 | self.vocab_size = vocab_size 14 | self.glove_embed = glove_embed 15 | self.input_dropout_p = input_dropout_p 16 | self.rnn_dropout_p = rnn_dropout_p 17 | self.n_layers = n_layers 18 | self.bidirectional = bidirectional 19 | self.rnn_cell = rnn_cell 20 | 21 | self.input_dropout = nn.Dropout(input_dropout_p) 22 | 23 | if rnn_cell.lower() == 'lstm': 24 | self.rnn_cell = nn.LSTM 25 | elif rnn_cell.lower() == 'gru': 26 | self.rnn_cell = nn.GRU 27 | 28 | self.embedding = nn.Embedding(vocab_size, dim_embed) 29 | word_mat = torch.FloatTensor(np.load(self.glove_embed)) 30 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False) 31 | self.rnn = self.rnn_cell(dim_embed, dim_hidden, n_layers, batch_first=True, 32 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 33 | 34 | 35 | def forward(self, qns, qns_lengths, hidden=None): 36 | """ 37 | """ 38 | qns_embed = self.embedding(qns) 39 | qns_embed = self.input_dropout(qns_embed) 40 | packed = pack_padded_sequence(qns_embed, qns_lengths, batch_first=True) 41 | packed_output, hidden = self.rnn(packed, hidden) 42 | output, _ = pad_packed_sequence(packed_output, batch_first=True) 43 | return output, hidden 44 | 45 | 46 | class EncoderVid(nn.Module): 47 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 48 | n_layers=1, bidirectional=False, rnn_cell='gru'): 49 | """ 50 | """ 51 | super(EncoderVid, self).__init__() 52 | self.dim_vid = dim_vid 53 | self.dim_app = 2048 54 | self.dim_motion = 4096 55 | self.dim_hidden = dim_hidden 56 | self.input_dropout_p = input_dropout_p 57 | self.rnn_dropout_p = rnn_dropout_p 58 | self.n_layers = n_layers 59 | self.bidirectional = bidirectional 60 | self.rnn_cell = rnn_cell 61 | 62 | if rnn_cell.lower() == 'lstm': 63 | self.rnn_cell = nn.LSTM 64 | elif rnn_cell.lower() == 'gru': 65 | self.rnn_cell = nn.GRU 66 | 67 | self.rnn = self.rnn_cell(dim_vid, dim_hidden, n_layers, batch_first=True, 68 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 69 | 70 | 71 | def forward(self, vid_feats): 72 | """ 73 | """ 74 | self.rnn.flatten_parameters() 75 | foutput, fhidden = self.rnn(vid_feats) 76 | 77 | return foutput, fhidden 78 | 79 | 80 | class EncoderVidSTVQA(nn.Module): 81 | def __init__(self, 
input_dim, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 82 | n_layers=1, bidirectional=False, rnn_cell='gru'): 83 | """ 84 | """ 85 | super(EncoderVidSTVQA, self).__init__() 86 | self.input_dim = input_dim 87 | self.dim_hidden = dim_hidden 88 | self.input_dropout_p = input_dropout_p 89 | self.rnn_dropout_p = rnn_dropout_p 90 | self.n_layers = n_layers 91 | self.bidirectional = bidirectional 92 | self.rnn_cell = rnn_cell 93 | 94 | 95 | if rnn_cell.lower() == 'lstm': 96 | self.rnn_cell = nn.LSTM 97 | elif rnn_cell.lower() == 'gru': 98 | self.rnn_cell = nn.GRU 99 | 100 | self.rnn1 = self.rnn_cell(input_dim, dim_hidden, n_layers, batch_first=True, 101 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 102 | 103 | self.rnn2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 104 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 105 | 106 | 107 | def forward(self, vid_feats): 108 | """ 109 | Dual-layer LSTM 110 | """ 111 | 112 | self.rnn1.flatten_parameters() 113 | 114 | foutput_1, fhidden_1 = self.rnn1(vid_feats) 115 | self.rnn2.flatten_parameters() 116 | foutput_2, fhidden_2 = self.rnn2(foutput_1) 117 | 118 | foutput = torch.cat((foutput_1, foutput_2), dim=2) 119 | fhidden = (torch.cat((fhidden_1[0], fhidden_2[0]), dim=0), 120 | torch.cat((fhidden_1[1], fhidden_2[1]), dim=0)) 121 | 122 | return foutput, fhidden 123 | 124 | 125 | class EncoderVidCoMem(nn.Module): 126 | def __init__(self, dim_app, dim_motion, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 127 | n_layers=1, bidirectional=False, rnn_cell='gru'): 128 | """ 129 | """ 130 | super(EncoderVidCoMem, self).__init__() 131 | self.dim_app = dim_app 132 | self.dim_motion = dim_motion 133 | self.dim_hidden = dim_hidden 134 | self.input_dropout_p = input_dropout_p 135 | self.rnn_dropout_p = rnn_dropout_p 136 | self.n_layers = n_layers 137 | self.bidirectional = bidirectional 138 | self.rnn_cell = rnn_cell 139 | 140 | if rnn_cell.lower() == 'lstm': 141 | self.rnn_cell = nn.LSTM 142 | elif rnn_cell.lower() == 'gru': 143 | self.rnn_cell = nn.GRU 144 | 145 | self.rnn_app_l1 = self.rnn_cell(self.dim_app, dim_hidden, n_layers, batch_first=True, 146 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 147 | self.rnn_app_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 148 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 149 | 150 | self.rnn_motion_l1 = self.rnn_cell(self.dim_motion, dim_hidden, n_layers, batch_first=True, 151 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 152 | self.rnn_motion_l2 = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 153 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 154 | 155 | 156 | def forward(self, vid_feats): 157 | """ 158 | two separate LSTM to encode app and motion feature 159 | :param vid_feats: 160 | :return: 161 | """ 162 | vid_app = vid_feats[:, :, 0:self.dim_app] 163 | vid_motion = vid_feats[:, :, self.dim_app:] 164 | 165 | app_output_l1, app_hidden_l1 = self.rnn_app_l1(vid_app) 166 | app_output_l2, app_hidden_l2 = self.rnn_app_l2(app_output_l1) 167 | 168 | 169 | motion_output_l1, motion_hidden_l1 = self.rnn_motion_l1(vid_motion) 170 | motion_output_l2, motion_hidden_l2 = self.rnn_motion_l2(motion_output_l1) 171 | 172 | 173 | return app_output_l1, app_output_l2, motion_output_l1, motion_output_l2 174 | 175 | 176 | 177 | class EncoderQnsHGA(nn.Module): 178 | def __init__(self, dim_embed, dim_hidden, vocab_size, glove_embed, input_dropout_p=0.2, rnn_dropout_p=0, 179 | n_layers=1, 
bidirectional=False, rnn_cell='gru'): 180 | """ 181 | 182 | """ 183 | super(EncoderQnsHGA, self).__init__() 184 | self.dim_hidden = dim_hidden 185 | self.vocab_size = vocab_size 186 | self.glove_embed = glove_embed 187 | self.input_dropout_p = input_dropout_p 188 | self.rnn_dropout_p = rnn_dropout_p 189 | self.n_layers = n_layers 190 | self.bidirectional = bidirectional 191 | self.rnn_cell = rnn_cell 192 | 193 | self.input_dropout = nn.Dropout(input_dropout_p) 194 | 195 | if rnn_cell.lower() == 'lstm': 196 | self.rnn_cell = nn.LSTM 197 | elif rnn_cell.lower() == 'gru': 198 | self.rnn_cell = nn.GRU 199 | 200 | self.embedding = nn.Embedding(vocab_size, dim_embed) 201 | word_mat = torch.FloatTensor(np.load(self.glove_embed)) 202 | self.embedding = nn.Embedding.from_pretrained(word_mat, freeze=False) 203 | self.FC = nn.Sequential(nn.Linear(dim_embed, dim_hidden, bias=False), 204 | nn.ReLU(), 205 | ) 206 | self.rnn = self.rnn_cell(dim_hidden, dim_hidden, n_layers, batch_first=True, 207 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 208 | 209 | 210 | def forward(self, qns, qns_lengths, hidden=None): 211 | """ 212 | """ 213 | qns_embed = self.embedding(qns) 214 | qns_embed = self.input_dropout(qns_embed) 215 | qns_embed = self.FC(qns_embed) 216 | packed = pack_padded_sequence(qns_embed, qns_lengths, batch_first=True) 217 | packed_output, hidden = self.rnn(packed, hidden) 218 | output, _ = pad_packed_sequence(packed_output, batch_first=True) 219 | return output, hidden 220 | 221 | 222 | class EncoderVidHGA(nn.Module): 223 | def __init__(self, dim_vid, dim_hidden, input_dropout_p=0.2, rnn_dropout_p=0, 224 | n_layers=1, bidirectional=False, rnn_cell='gru'): 225 | """ 226 | """ 227 | super(EncoderVidHGA, self).__init__() 228 | self.dim_vid = dim_vid 229 | self.dim_app = 2048 230 | self.dim_mot = 2048 231 | self.dim_hidden = dim_hidden 232 | self.input_dropout_p = input_dropout_p 233 | self.rnn_dropout_p = rnn_dropout_p 234 | self.n_layers = n_layers 235 | self.bidirectional = bidirectional 236 | self.rnn_cell = rnn_cell 237 | 238 | 239 | self.mot2hid = nn.Sequential(nn.Linear(self.dim_mot, self.dim_app), 240 | nn.Dropout(input_dropout_p), 241 | nn.ReLU()) 242 | 243 | self.appmot2hid = nn.Sequential(nn.Linear(self.dim_app*2, self.dim_app), 244 | nn.Dropout(input_dropout_p), 245 | nn.ReLU()) 246 | 247 | if rnn_cell.lower() == 'lstm': 248 | self.rnn_cell = nn.LSTM 249 | elif rnn_cell.lower() == 'gru': 250 | self.rnn_cell = nn.GRU 251 | 252 | self.rnn = self.rnn_cell(self.dim_app, dim_hidden, n_layers, batch_first=True, 253 | bidirectional=bidirectional, dropout=self.rnn_dropout_p) 254 | 255 | 256 | def forward(self, vid_feats): 257 | """ 258 | """ 259 | vid_app = vid_feats[:, :, 0:self.dim_app] 260 | vid_motion = vid_feats[:, :, self.dim_app:] 261 | 262 | vid_motion_redu = self.mot2hid(vid_motion) 263 | vid_feats = self.appmot2hid(torch.cat((vid_app, vid_motion_redu), dim=2)) 264 | 265 | self.rnn.flatten_parameters() 266 | foutput, fhidden = self.rnn(vid_feats) 267 | 268 | return foutput, fhidden -------------------------------------------------------------------------------- /networks/VQAModel/CoMem.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random as rd 4 | import sys 5 | sys.path.insert(0, 'networks') 6 | from memory_module import EpisodicMemory 7 | 8 | class CoMem(nn.Module): 9 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2): 10 | """ 
11 | motion-appearance co-memory networks for video question answering (CVPR18) 12 | """ 13 | super(CoMem, self).__init__() 14 | self.vid_encoder = vid_encoder 15 | self.qns_encoder = qns_encoder 16 | self.ans_decoder = ans_decoder 17 | 18 | dim = qns_encoder.dim_hidden 19 | 20 | self.epm_app = EpisodicMemory(dim*2) 21 | self.epm_mot = EpisodicMemory(dim*2) 22 | 23 | self.linear_ma = nn.Linear(dim*2*3, dim*2) 24 | self.linear_mb = nn.Linear(dim*2*3, dim*2) 25 | 26 | self.vq2word = nn.Linear(dim*2*2, dim) 27 | self._init_weights() 28 | self.device = device 29 | 30 | def _init_weights(self): 31 | """ 32 | initialize the linear weights 33 | :return: 34 | """ 35 | nn.init.xavier_normal_(self.linear_ma.weight) 36 | nn.init.xavier_normal_(self.linear_mb.weight) 37 | nn.init.xavier_normal_(self.vq2word.weight) 38 | 39 | 40 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, iter_num=3, mode='train'): 41 | """ 42 | Co-memory network 43 | """ 44 | 45 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim) 46 | 47 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1) 48 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1) 49 | 50 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths) 51 | 52 | # qns_output = qns_output.permute(1, 0, 2) 53 | batch_size, seq_len, qns_feat_dim = qns_output.size() 54 | 55 | 56 | qns_embed = qns_hidden.permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 57 | 58 | m_app = outputs_app[:, -1, :] 59 | m_mot = outputs_motion[:, -1, :] 60 | ma, mb = m_app.detach(), m_mot.detach() 61 | m_app = m_app.unsqueeze(1) 62 | m_mot = m_mot.unsqueeze(1) 63 | for _ in range(iter_num): 64 | mm = ma + mb 65 | m_app = self.epm_app(outputs_app, mm, m_app) 66 | m_mot = self.epm_mot(outputs_motion, mm, m_mot) 67 | ma_q = torch.cat((ma, m_app.squeeze(1), qns_embed), dim=1) 68 | mb_q = torch.cat((mb, m_mot.squeeze(1), qns_embed), dim=1) 69 | # print(ma_q.shape) 70 | ma = torch.tanh(self.linear_ma(ma_q)) 71 | mb = torch.tanh(self.linear_mb(mb_q)) 72 | 73 | mem = torch.cat((ma, mb), dim=1) 74 | encoder_outputs = self.vq2word(mem) 75 | # hidden = qns_hidden 76 | hidden = encoder_outputs.unsqueeze(0) 77 | 78 | # decoder_inputs = encoder_outputs 79 | 80 | if mode == 'train': 81 | vocab_size = self.ans_decoder.vocab_size 82 | ans_len = ans.shape[1] 83 | input = ans[:, 0] 84 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 85 | 86 | for t in range(0, ans_len): 87 | output, hidden = self.ans_decoder(qns_output, hidden, input) 88 | outputs[:, t] = output 89 | teacher_force = rd.random() < teacher_force_ratio 90 | top1 = output.argmax(1) 91 | input = ans[:, t] if teacher_force else top1 92 | else: 93 | start = torch.LongTensor([1] * batch_size).to(self.device) 94 | outputs = self.ans_decoder.sample(qns_output, hidden, start) 95 | 96 | return outputs -------------------------------------------------------------------------------- /networks/VQAModel/EVQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random as rd 4 | 5 | class EVQA(nn.Module): 6 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device): 7 | """ 8 | :param vid_encoder: 9 | :param qns_encoder: 10 | :param ans_decoder: 11 | :param device: 12 | """ 13 | super(EVQA, self).__init__() 14 | self.vid_encoder = vid_encoder 15 | self.qns_encoder = 
qns_encoder 16 | self.ans_decoder = ans_decoder 17 | self.device = device 18 | 19 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 20 | 21 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats) 22 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths) 23 | 24 | 25 | hidden = qns_hidden[0] +vid_hidden[0] 26 | batch_size = qns.shape[0] 27 | 28 | if mode == 'train': 29 | vocab_size = self.ans_decoder.vocab_size 30 | ans_len = ans.shape[1] 31 | input = ans[:, 0] 32 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 33 | for t in range(0, ans_len): 34 | output, hidden = self.ans_decoder(qns_outputs, hidden, input) 35 | outputs[:,t] = output 36 | teacher_force = rd.random() < teacher_force_ratio 37 | top1 = output.argmax(1) 38 | input = ans[:, t] if teacher_force else top1 39 | else: 40 | start = torch.LongTensor([1] * batch_size).to(self.device) 41 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start) 42 | 43 | return outputs 44 | -------------------------------------------------------------------------------- /networks/VQAModel/HGA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random as rd 4 | import sys 5 | sys.path.insert(0, 'networks') 6 | from q_v_transformer import CoAttention 7 | from gcn import AdjLearner, GCN 8 | from block import fusions #pytorch >= 1.1.0 9 | 10 | class HGA(nn.Module): 11 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device): 12 | """ 13 | Reasoning with Heterogeneous Graph Alignment for Video Question Answering (AAAI20) 14 | """ 15 | super(HGA, self).__init__() 16 | self.vid_encoder = vid_encoder 17 | self.qns_encoder = qns_encoder 18 | self.ans_decoder = ans_decoder 19 | self.max_len_v = max_len_v 20 | self.max_len_q = max_len_q 21 | self.device = device 22 | hidden_size = vid_encoder.dim_hidden 23 | input_dropout_p = vid_encoder.input_dropout_p 24 | 25 | self.q_input_ln = nn.LayerNorm(hidden_size, elementwise_affine=False) 26 | self.v_input_ln = nn.LayerNorm(hidden_size, elementwise_affine=False) 27 | 28 | self.co_attn = CoAttention( 29 | hidden_size, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 30 | 31 | self.adj_learner = AdjLearner( 32 | hidden_size, hidden_size, dropout=input_dropout_p) 33 | 34 | self.gcn = GCN( 35 | hidden_size, 36 | hidden_size, 37 | hidden_size, 38 | num_layers=2, 39 | dropout=input_dropout_p) 40 | 41 | self.gcn_atten_pool = nn.Sequential( 42 | nn.Linear(hidden_size, hidden_size // 2), 43 | nn.Tanh(), 44 | nn.Linear(hidden_size // 2, 1), 45 | nn.Softmax(dim=-1)) #dim=-2 for attention-pooling otherwise sum-pooling 46 | 47 | self.global_fusion = fusions.Block( 48 | [hidden_size, hidden_size], hidden_size, dropout_input=input_dropout_p) 49 | 50 | self.fusion = fusions.Block([hidden_size, hidden_size], hidden_size) 51 | 52 | 53 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 54 | """ 55 | 56 | """ 57 | encoder_out, qns_hidden, qns_out, vid_out = self.vq_encoder(vid_feats, qns, qns_lengths) 58 | 59 | batch_size = encoder_out.shape[0] 60 | 61 | hidden = encoder_out.unsqueeze(0) 62 | if mode == 'train': 63 | vocab_size = self.ans_decoder.vocab_size 64 | ans_len = ans.shape[1] 65 | input = ans[:, 0] 66 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 67 | for t in range(0, ans_len): 68 | 69 | output, hidden = 
self.ans_decoder(qns_out, hidden, input) #attqns, attvid 70 | outputs[:, t] = output 71 | teacher_force = rd.random() < teacher_force_ratio 72 | top1 = output.argmax(1) 73 | input = ans[:, t] if teacher_force else top1 74 | else: 75 | start = torch.LongTensor([1] * batch_size).to(self.device) 76 | 77 | outputs = self.ans_decoder.sample(qns_out, hidden, start) #vidatt, qns_att 78 | 79 | return outputs 80 | 81 | 82 | def vq_encoder(self, vid_feats, qns, qns_lengths): 83 | """ 84 | 85 | :param vid_feats: 86 | :param qns: 87 | :param qns_lengths: 88 | :return: 89 | """ 90 | q_output, s_hidden = self.qns_encoder(qns, qns_lengths) 91 | qns_last_hidden = torch.squeeze(s_hidden) 92 | 93 | 94 | v_output, v_hidden = self.vid_encoder(vid_feats) 95 | vid_last_hidden = torch.squeeze(v_hidden) 96 | 97 | q_output = self.q_input_ln(q_output) 98 | v_output = self.v_input_ln(v_output) 99 | 100 | q_output, v_output = self.co_attn(q_output, v_output) 101 | 102 | ### GCN 103 | adj = self.adj_learner(q_output, v_output) 104 | # q_v_inputs of shape (batch_size, q_v_len, hidden_size) 105 | q_v_inputs = torch.cat((q_output, v_output), dim=1) 106 | # q_v_output of shape (batch_size, q_v_len, hidden_size) 107 | q_v_output = self.gcn(q_v_inputs, adj) 108 | 109 | ## attention pool 110 | local_attn = self.gcn_atten_pool(q_v_output) 111 | local_out = torch.sum(q_v_output * local_attn, dim=1) 112 | 113 | # print(qns_last_hidden.shape, vid_last_hidden.shape) 114 | global_out = self.global_fusion((qns_last_hidden, vid_last_hidden)) 115 | 116 | 117 | out = self.fusion((global_out, local_out)).squeeze() #4 x 512 118 | 119 | return out, s_hidden, q_output, v_output, 120 | 121 | -------------------------------------------------------------------------------- /networks/VQAModel/HME.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random as rd 4 | import sys 5 | sys.path.insert(0, 'networks') 6 | from Attention import TempAttention 7 | from memory_rand import MemoryRamTwoStreamModule, MemoryRamModule, MMModule 8 | 9 | 10 | class HME(nn.Module): 11 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2): 12 | """ 13 | Heterogeneous memory enhanced multimodal attention model for video question answering (CVPR19) 14 | 15 | """ 16 | super(HME, self).__init__() 17 | self.vid_encoder = vid_encoder 18 | self.qns_encoder = qns_encoder 19 | self.ans_decoder = ans_decoder 20 | 21 | dim = qns_encoder.dim_hidden 22 | 23 | self.temp_att_a = TempAttention(dim * 2, dim * 2, hidden_dim=256) 24 | self.temp_att_m = TempAttention(dim * 2, dim * 2, hidden_dim=256) 25 | self.mrm_vid = MemoryRamTwoStreamModule(dim, dim, max_len_v, device) 26 | self.mrm_txt = MemoryRamModule(dim, dim, max_len_q, device) 27 | 28 | self.mm_module_v1 = MMModule(dim, input_drop_p, device) 29 | 30 | self.linear_vid = nn.Linear(dim*2, dim) 31 | self.linear_qns = nn.Linear(dim*2, dim) 32 | self.linear_mem = nn.Linear(dim*2, dim) 33 | self.vq2word_hme = nn.Linear(dim*3, dim*2) 34 | self._init_weights() 35 | self.device = device 36 | 37 | def _init_weights(self): 38 | """ 39 | initialize the linear weights 40 | :return: 41 | """ 42 | nn.init.xavier_normal_(self.linear_vid.weight) 43 | nn.init.xavier_normal_(self.linear_qns.weight) 44 | nn.init.xavier_normal_(self.linear_mem.weight) 45 | nn.init.xavier_normal_(self.vq2word_hme.weight) 46 | 47 | 48 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, 
iter_num=3, mode='train'): 49 | """ 50 | """ 51 | 52 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim) 53 | 54 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1) 55 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1) 56 | 57 | batch_size, fnum, vid_feat_dim = outputs_app.size() 58 | 59 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths) 60 | # print(qns_output.shape, qns_hidden[0].shape) #torch.Size([10, 23, 256]) torch.Size([2, 10, 256]) 61 | 62 | # qns_output = qns_output.permute(1, 0, 2) 63 | batch_size, seq_len, qns_feat_dim = qns_output.size() 64 | 65 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 66 | 67 | # Apply temporal attention 68 | att_app, beta_app = self.temp_att_a(qns_embed, outputs_app) 69 | att_motion, beta_motion = self.temp_att_m(qns_embed, outputs_motion) 70 | tmp_app_motion = torch.cat((outputs_app_l2[:, -1, :], outputs_motion_l2[:, -1, :]), dim=-1) 71 | 72 | mem_output = torch.zeros(batch_size, vid_feat_dim).to(self.device) 73 | 74 | for bs in range(batch_size): 75 | mem_ram_vid = self.mrm_vid(outputs_app_l2[bs], outputs_motion_l2[bs], fnum) 76 | cur_qns = qns_output[bs][:qns_lengths[bs]] 77 | mem_ram_txt = self.mrm_txt(cur_qns, qns_lengths[bs]) #should remove padded zeros 78 | mem_output[bs] = self.mm_module_v1(tmp_app_motion[bs].unsqueeze(0), mem_ram_vid, mem_ram_txt, iter_num) 79 | """ 80 | (64, 256) (22, 256) (1, 512) 81 | """ 82 | app_trans = torch.tanh(self.linear_vid(att_app)) 83 | motion_trans = torch.tanh(self.linear_vid(att_motion)) 84 | mem_trans = torch.tanh(self.linear_mem(mem_output)) 85 | 86 | encoder_outputs = torch.cat((app_trans, motion_trans, mem_trans), dim=1) 87 | decoder_inputs = self.vq2word_hme(encoder_outputs) 88 | hidden = qns_hidden 89 | if mode == 'train': 90 | vocab_size = self.ans_decoder.vocab_size 91 | ans_len = ans.shape[1] 92 | input = ans[:, 0] 93 | 94 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 95 | 96 | for t in range(0, ans_len): 97 | output, hidden = self.ans_decoder(decoder_inputs, hidden, input) 98 | outputs[:, t] = output 99 | teacher_force = rd.random() < teacher_force_ratio 100 | top1 = output.argmax(1) 101 | input = ans[:, t] if teacher_force else top1 102 | else: 103 | start = torch.LongTensor([1] * batch_size).to(self.device) 104 | outputs = self.ans_decoder.sample(decoder_inputs, hidden, start) 105 | 106 | return outputs -------------------------------------------------------------------------------- /networks/VQAModel/STVQA.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random as rd 4 | import sys 5 | sys.path.insert(0, 'networks') 6 | from Attention import TempAttention, SpatialAttention 7 | 8 | 9 | class STVQA(nn.Module): 10 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, att_dim, device): 11 | """ 12 | TGIF-QA: Toward Spatio-Temporal Reasoning in Visual Question Answering (CVPR17) 13 | """ 14 | super(STVQA, self).__init__() 15 | self.vid_encoder = vid_encoder 16 | self.qns_encoder = qns_encoder 17 | self.ans_decoder = ans_decoder 18 | self.att_dim = att_dim 19 | 20 | self.spatial_att = SpatialAttention(qns_encoder.dim_hidden*2, vid_encoder.input_dim, hidden_dim=self.att_dim) 21 | self.temp_att = TempAttention(qns_encoder.dim_hidden*2, vid_encoder.dim_hidden*2, hidden_dim=self.att_dim) 22 | self.device 
= device 23 | self.FC = nn.Linear(att_dim*2, att_dim) 24 | 25 | 26 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 27 | """ 28 | """ 29 | qns_output_1, qns_hidden_1 = self.qns_encoder(qns, qns_lengths) 30 | n_layers, batch_size, qns_dim = qns_hidden_1[0].size() 31 | 32 | # Concatenate the dual-layer hidden as qns embedding 33 | qns_embed = qns_hidden_1[0].permute(1, 0, 2) # batch first 34 | qns_embed = qns_embed.reshape(batch_size, -1) #(batch_size, feat_dim*2) 35 | batch_size, fnum, vid_dim, w, h = vid_feats.size() 36 | 37 | # Apply spatial attention 38 | vid_att_feats = torch.zeros(batch_size, fnum, vid_dim).to(self.device) 39 | for bs in range(batch_size): 40 | vid_att_feats[bs], alpha = self.spatial_att(qns_embed[bs], vid_feats[bs]) 41 | 42 | vid_outputs, vid_hidden = self.vid_encoder(vid_att_feats) 43 | 44 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths, vid_hidden) 45 | 46 | """ 47 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 48 | torch.Size([16, 3, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 49 | """ 50 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 51 | 52 | # Apply temporal attention 53 | temp_att_outputs, beta = self.temp_att(qns_embed, vid_outputs) 54 | encoder_outputs = self.FC(qns_embed + temp_att_outputs) 55 | # hidden = qns_hidden 56 | hidden = encoder_outputs.unsqueeze(0) 57 | # print(hidden.size()) 58 | 59 | if mode == 'train': 60 | vocab_size = self.ans_decoder.vocab_size 61 | ans_len = ans.shape[1] 62 | input = ans[:, 0] 63 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 64 | for t in range(0, ans_len): 65 | # output, hidden = self.ans_decoder(encoder_outputs, hidden, input) 66 | output, hidden = self.ans_decoder(qns_outputs, hidden, input) 67 | outputs[:, t] = output 68 | teacher_force = rd.random() < teacher_force_ratio 69 | top1 = output.argmax(1) 70 | input = ans[:, t] if teacher_force else top1 71 | else: 72 | start = torch.LongTensor([1] * batch_size).to(self.device) 73 | # outputs = self.ans_decoder.sample(encoder_outputs, hidden, start) 74 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start) 75 | 76 | return outputs 77 | -------------------------------------------------------------------------------- /networks/VQAModel/UATT.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random as rd 4 | import sys 5 | sys.path.insert(0, 'networks') 6 | from Attention import TempAttentionHis 7 | 8 | class UATT(nn.Module): 9 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device): 10 | """ 11 | Unifying the Video and Question Attentions for Open-Ended Video Question Answering (TIP17) 12 | """ 13 | super(UATT, self).__init__() 14 | self.vid_encoder = vid_encoder 15 | self.qns_encoder = qns_encoder 16 | self.ans_decoder = ans_decoder 17 | mem_dim = 512 18 | self.att_q2v = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim) 19 | self.att_v2q = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim) 20 | 21 | self.device = device 22 | 23 | 24 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 25 | 26 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats) 27 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths) 28 | qns_outputs = 
qns_outputs.permute(1, 0, 2) 29 | # print(vid_outputs.size(), vid_hidden[0].size(), vid_hidden[1].size()) 30 | # print(qns_outputs.size(), qns_hidden[0].size(), qns_hidden[1].size()) 31 | """ 32 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 33 | torch.Size([3, 16, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 34 | """ 35 | 36 | word_num, batch_size, feat_dim = qns_outputs.size() 37 | r = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device) 38 | 39 | for word in qns_outputs: 40 | r, beta_r = self.att_q2v(word, vid_outputs, r) 41 | 42 | vid_outputs = vid_outputs.permute(1, 0, 2) # change to fnum, batch_size, feat_dim 43 | qns_outputs = qns_outputs.permute(1, 0, 2) # change to batch_size, word_num, feat_dim 44 | w = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device) 45 | for frame in vid_outputs: 46 | w, beta_w = self.att_v2q(frame, qns_outputs, w) 47 | 48 | hidden = (torch.cat((r.unsqueeze(0), w.unsqueeze(0)), dim=0), 49 | torch.cat((vid_hidden[1][0].unsqueeze(0), qns_hidden[1][0].unsqueeze(0)), dim=0)) 50 | 51 | if mode == 'train': 52 | vocab_size = self.ans_decoder.vocab_size 53 | batch_size = ans.shape[0] 54 | ans_len = ans.shape[1] 55 | input = ans[:, 0] 56 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 57 | for t in range(0, ans_len): 58 | output, hidden = self.ans_decoder(hidden, input) 59 | outputs[:, t] = output 60 | teacher_force = rd.random() < teacher_force_ratio 61 | top1 = output.argmax(1) 62 | input = ans[:, t] if teacher_force else top1 63 | else: 64 | start = torch.LongTensor([1] * batch_size).to(self.device) 65 | outputs = self.ans_decoder.sample(hidden, start) 66 | return outputs 67 | -------------------------------------------------------------------------------- /networks/VQAModel_bak.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random as rd 4 | from .Attention import TempAttentionHis, TempAttention, SpatialAttention 5 | from .memory_rand import MemoryRamTwoStreamModule, MemoryRamModule, MMModule 6 | from .memory_module import EpisodicMemory 7 | from .q_v_transformer import CoAttention 8 | from .gcn import AdjLearner, GCN 9 | from block import fusions #pytorch >= 1.1.0 10 | 11 | 12 | class EVQA(nn.Module): 13 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device): 14 | """ 15 | 16 | :param vid_encoder: 17 | :param qns_encoder: 18 | :param ans_decoder: 19 | :param device: 20 | """ 21 | super(EVQA, self).__init__() 22 | self.vid_encoder = vid_encoder 23 | self.qns_encoder = qns_encoder 24 | self.ans_decoder = ans_decoder 25 | self.device = device 26 | 27 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 28 | 29 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats) 30 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths) 31 | 32 | # print(vid_outputs.size(), vid_hidden[0].size(), vid_hidden[1].size()) 33 | # print(qns_outputs.size(), qns_hidden[0].size(), qns_hidden[1].size()) 34 | """ 35 | torch.Size([64, 128, 512]) torch.Size([1, 64 512]) torch.Size([1, 64, 512]) 36 | torch.Size([16, 64, 512]) torch.Size([2, 64, 512]) torch.Size([2, 64, 512]) 37 | """ 38 | 39 | 40 | hidden = qns_hidden[0] +vid_hidden[0] 41 | batch_size = qns.shape[0] 42 | 43 | if mode == 'train': 44 | vocab_size = self.ans_decoder.vocab_size 45 | ans_len = ans.shape[1] 46 | input = ans[:, 0] 47 | outputs = torch.zeros(batch_size, ans_len, 
vocab_size).to(self.device) 48 | for t in range(0, ans_len): 49 | output, hidden = self.ans_decoder(qns_outputs, hidden, input) 50 | outputs[:,t] = output 51 | teacher_force = rd.random() < teacher_force_ratio 52 | top1 = output.argmax(1) 53 | input = ans[:, t] if teacher_force else top1 54 | else: 55 | start = torch.LongTensor([1] * batch_size).to(self.device) 56 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start) 57 | 58 | return outputs 59 | 60 | 61 | class UATT(nn.Module): 62 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, device): 63 | """ 64 | Unifying the Video and Question Attentions for Open-Ended Video Question Answering (TIP17) 65 | """ 66 | super(UATT, self).__init__() 67 | self.vid_encoder = vid_encoder 68 | self.qns_encoder = qns_encoder 69 | self.ans_decoder = ans_decoder 70 | mem_dim = 512 71 | self.att_q2v = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim) 72 | self.att_v2q = TempAttentionHis(vid_encoder.dim_hidden*2, qns_encoder.dim_hidden*2, mem_dim, mem_dim) 73 | 74 | self.device = device 75 | 76 | 77 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 78 | 79 | vid_outputs, vid_hidden = self.vid_encoder(vid_feats) 80 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths) 81 | qns_outputs = qns_outputs.permute(1, 0, 2) 82 | # print(vid_outputs.size(), vid_hidden[0].size(), vid_hidden[1].size()) 83 | # print(qns_outputs.size(), qns_hidden[0].size(), qns_hidden[1].size()) 84 | """ 85 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 86 | torch.Size([3, 16, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 87 | """ 88 | 89 | word_num, batch_size, feat_dim = qns_outputs.size() 90 | r = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device) 91 | 92 | for word in qns_outputs: 93 | r, beta_r = self.att_q2v(word, vid_outputs, r) 94 | 95 | vid_outputs = vid_outputs.permute(1, 0, 2) # change to fnum, batch_size, feat_dim 96 | qns_outputs = qns_outputs.permute(1, 0, 2) # change to batch_size, word_num, feat_dim 97 | w = torch.zeros((batch_size, vid_hidden[0].shape[-1])).to(self.device) 98 | for frame in vid_outputs: 99 | w, beta_w = self.att_v2q(frame, qns_outputs, w) 100 | 101 | hidden = (torch.cat((r.unsqueeze(0), w.unsqueeze(0)), dim=0), 102 | torch.cat((vid_hidden[1][0].unsqueeze(0), qns_hidden[1][0].unsqueeze(0)), dim=0)) 103 | 104 | if mode == 'train': 105 | vocab_size = self.ans_decoder.vocab_size 106 | batch_size = ans.shape[0] 107 | ans_len = ans.shape[1] 108 | input = ans[:, 0] 109 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 110 | for t in range(0, ans_len): 111 | output, hidden = self.ans_decoder(hidden, input) 112 | outputs[:, t] = output 113 | teacher_force = rd.random() < teacher_force_ratio 114 | top1 = output.argmax(1) 115 | input = ans[:, t] if teacher_force else top1 116 | else: 117 | start = torch.LongTensor([1] * batch_size).to(self.device) 118 | outputs = self.ans_decoder.sample(hidden, start) 119 | return outputs 120 | 121 | 122 | class STVQA(nn.Module): 123 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, att_dim, device): 124 | """ 125 | TGIF-QA: Toward Spatio-Temporal Reasoning in Visual Question Answering (CVPR17) 126 | """ 127 | super(STVQA, self).__init__() 128 | self.vid_encoder = vid_encoder 129 | self.qns_encoder = qns_encoder 130 | self.ans_decoder = ans_decoder 131 | self.att_dim = att_dim 132 | 133 | self.spatial_att = 
SpatialAttention(qns_encoder.dim_hidden*2, vid_encoder.input_dim, hidden_dim=self.att_dim) 134 | self.temp_att = TempAttention(qns_encoder.dim_hidden*2, vid_encoder.dim_hidden*2, hidden_dim=self.att_dim) 135 | self.device = device 136 | self.FC = nn.Linear(att_dim*2, att_dim) 137 | 138 | 139 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 140 | """ 141 | """ 142 | qns_output_1, qns_hidden_1 = self.qns_encoder(qns, qns_lengths) 143 | n_layers, batch_size, qns_dim = qns_hidden_1[0].size() 144 | 145 | # Concatenate the dual-layer hidden as qns embedding 146 | qns_embed = qns_hidden_1[0].permute(1, 0, 2) # batch first 147 | qns_embed = qns_embed.reshape(batch_size, -1) #(batch_size, feat_dim*2) 148 | batch_size, fnum, vid_dim, w, h = vid_feats.size() 149 | 150 | # Apply spatial attention 151 | vid_att_feats = torch.zeros(batch_size, fnum, vid_dim).to(self.device) 152 | for bs in range(batch_size): 153 | vid_att_feats[bs], alpha = self.spatial_att(qns_embed[bs], vid_feats[bs]) 154 | 155 | vid_outputs, vid_hidden = self.vid_encoder(vid_att_feats) 156 | 157 | qns_outputs, qns_hidden = self.qns_encoder(qns, qns_lengths, vid_hidden) 158 | 159 | """ 160 | torch.Size([3, 128, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 161 | torch.Size([16, 3, 1024]) torch.Size([2, 3, 512]) torch.Size([2, 3, 512]) 162 | """ 163 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 164 | 165 | # Apply temporal attention 166 | temp_att_outputs, beta = self.temp_att(qns_embed, vid_outputs) 167 | encoder_outputs = self.FC(qns_embed + temp_att_outputs) 168 | # hidden = qns_hidden 169 | hidden = encoder_outputs.unsqueeze(0) 170 | # print(hidden.size()) 171 | 172 | if mode == 'train': 173 | vocab_size = self.ans_decoder.vocab_size 174 | ans_len = ans.shape[1] 175 | input = ans[:, 0] 176 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 177 | for t in range(0, ans_len): 178 | # output, hidden = self.ans_decoder(encoder_outputs, hidden, input) 179 | output, hidden = self.ans_decoder(qns_outputs, hidden, input) 180 | outputs[:, t] = output 181 | teacher_force = rd.random() < teacher_force_ratio 182 | top1 = output.argmax(1) 183 | input = ans[:, t] if teacher_force else top1 184 | else: 185 | start = torch.LongTensor([1] * batch_size).to(self.device) 186 | # outputs = self.ans_decoder.sample(encoder_outputs, hidden, start) 187 | outputs = self.ans_decoder.sample(qns_outputs, hidden, start) 188 | 189 | return outputs 190 | 191 | 192 | class CoMem(nn.Module): 193 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2): 194 | """ 195 | motion-appearance co-memory networks for video question answering (CVPR18) 196 | """ 197 | super(CoMem, self).__init__() 198 | self.vid_encoder = vid_encoder 199 | self.qns_encoder = qns_encoder 200 | self.ans_decoder = ans_decoder 201 | 202 | dim = qns_encoder.dim_hidden 203 | 204 | self.epm_app = EpisodicMemory(dim*2) 205 | self.epm_mot = EpisodicMemory(dim*2) 206 | 207 | self.linear_ma = nn.Linear(dim*2*3, dim*2) 208 | self.linear_mb = nn.Linear(dim*2*3, dim*2) 209 | 210 | self.vq2word = nn.Linear(dim*2*2, dim) 211 | self._init_weights() 212 | self.device = device 213 | 214 | def _init_weights(self): 215 | """ 216 | initialize the linear weights 217 | :return: 218 | """ 219 | nn.init.xavier_normal_(self.linear_ma.weight) 220 | nn.init.xavier_normal_(self.linear_mb.weight) 221 | 
nn.init.xavier_normal_(self.vq2word.weight) 222 | 223 | 224 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, iter_num=3, mode='train'): 225 | """ 226 | Co-memory network 227 | """ 228 | 229 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim) 230 | 231 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1) 232 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1) 233 | 234 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths) 235 | 236 | # qns_output = qns_output.permute(1, 0, 2) 237 | batch_size, seq_len, qns_feat_dim = qns_output.size() 238 | 239 | 240 | qns_embed = qns_hidden.permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 241 | 242 | m_app = outputs_app[:, -1, :] 243 | m_mot = outputs_motion[:, -1, :] 244 | ma, mb = m_app.detach(), m_mot.detach() 245 | m_app = m_app.unsqueeze(1) 246 | m_mot = m_mot.unsqueeze(1) 247 | for _ in range(iter_num): 248 | mm = ma + mb 249 | m_app = self.epm_app(outputs_app, mm, m_app) 250 | m_mot = self.epm_mot(outputs_motion, mm, m_mot) 251 | ma_q = torch.cat((ma, m_app.squeeze(1), qns_embed), dim=1) 252 | mb_q = torch.cat((mb, m_mot.squeeze(1), qns_embed), dim=1) 253 | # print(ma_q.shape) 254 | ma = torch.tanh(self.linear_ma(ma_q)) 255 | mb = torch.tanh(self.linear_mb(mb_q)) 256 | 257 | mem = torch.cat((ma, mb), dim=1) 258 | encoder_outputs = self.vq2word(mem) 259 | # hidden = qns_hidden 260 | hidden = encoder_outputs.unsqueeze(0) 261 | 262 | # decoder_inputs = encoder_outputs 263 | 264 | if mode == 'train': 265 | vocab_size = self.ans_decoder.vocab_size 266 | ans_len = ans.shape[1] 267 | input = ans[:, 0] 268 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 269 | 270 | for t in range(0, ans_len): 271 | output, hidden = self.ans_decoder(qns_output, hidden, input) 272 | outputs[:, t] = output 273 | teacher_force = rd.random() < teacher_force_ratio 274 | top1 = output.argmax(1) 275 | input = ans[:, t] if teacher_force else top1 276 | else: 277 | start = torch.LongTensor([1] * batch_size).to(self.device) 278 | outputs = self.ans_decoder.sample(qns_output, hidden, start) 279 | 280 | return outputs 281 | 282 | 283 | class HME(nn.Module): 284 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device, input_drop_p=0.2): 285 | """ 286 | Heterogeneous memory enhanced multimodal attention model for video question answering (CVPR19) 287 | 288 | """ 289 | super(HME, self).__init__() 290 | self.vid_encoder = vid_encoder 291 | self.qns_encoder = qns_encoder 292 | self.ans_decoder = ans_decoder 293 | 294 | dim = qns_encoder.dim_hidden 295 | 296 | self.temp_att_a = TempAttention(dim * 2, dim * 2, hidden_dim=256) 297 | self.temp_att_m = TempAttention(dim * 2, dim * 2, hidden_dim=256) 298 | self.mrm_vid = MemoryRamTwoStreamModule(dim, dim, max_len_v, device) 299 | self.mrm_txt = MemoryRamModule(dim, dim, max_len_q, device) 300 | 301 | self.mm_module_v1 = MMModule(dim, input_drop_p, device) 302 | 303 | self.linear_vid = nn.Linear(dim*2, dim) 304 | self.linear_qns = nn.Linear(dim*2, dim) 305 | self.linear_mem = nn.Linear(dim*2, dim) 306 | self.vq2word_hme = nn.Linear(dim*3, dim*2) 307 | self._init_weights() 308 | self.device = device 309 | 310 | def _init_weights(self): 311 | """ 312 | initialize the linear weights 313 | :return: 314 | """ 315 | nn.init.xavier_normal_(self.linear_vid.weight) 316 | 
nn.init.xavier_normal_(self.linear_qns.weight) 317 | nn.init.xavier_normal_(self.linear_mem.weight) 318 | nn.init.xavier_normal_(self.vq2word_hme.weight) 319 | 320 | 321 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, iter_num=3, mode='train'): 322 | """ 323 | """ 324 | 325 | outputs_app_l1, outputs_app_l2, outputs_motion_l1, outputs_motion_l2 = self.vid_encoder(vid_feats) #(batch_size, fnum, feat_dim) 326 | 327 | outputs_app = torch.cat((outputs_app_l1, outputs_app_l2), dim=-1) 328 | outputs_motion = torch.cat((outputs_motion_l1, outputs_motion_l2), dim=-1) 329 | 330 | batch_size, fnum, vid_feat_dim = outputs_app.size() 331 | 332 | qns_output, qns_hidden = self.qns_encoder(qns, qns_lengths) 333 | # print(qns_output.shape, qns_hidden[0].shape) #torch.Size([10, 23, 256]) torch.Size([2, 10, 256]) 334 | 335 | # qns_output = qns_output.permute(1, 0, 2) 336 | batch_size, seq_len, qns_feat_dim = qns_output.size() 337 | 338 | qns_embed = qns_hidden[0].permute(1, 0, 2).contiguous().view(batch_size, -1) #(batch_size, feat_dim) 339 | 340 | # Apply temporal attention 341 | att_app, beta_app = self.temp_att_a(qns_embed, outputs_app) 342 | att_motion, beta_motion = self.temp_att_m(qns_embed, outputs_motion) 343 | tmp_app_motion = torch.cat((outputs_app_l2[:, -1, :], outputs_motion_l2[:, -1, :]), dim=-1) 344 | 345 | mem_output = torch.zeros(batch_size, vid_feat_dim).to(self.device) 346 | 347 | for bs in range(batch_size): 348 | mem_ram_vid = self.mrm_vid(outputs_app_l2[bs], outputs_motion_l2[bs], fnum) 349 | cur_qns = qns_output[bs][:qns_lengths[bs]] 350 | mem_ram_txt = self.mrm_txt(cur_qns, qns_lengths[bs]) #should remove padded zeros 351 | mem_output[bs] = self.mm_module_v1(tmp_app_motion[bs].unsqueeze(0), mem_ram_vid, mem_ram_txt, iter_num) 352 | """ 353 | (64, 256) (22, 256) (1, 512) 354 | """ 355 | app_trans = torch.tanh(self.linear_vid(att_app)) 356 | motion_trans = torch.tanh(self.linear_vid(att_motion)) 357 | mem_trans = torch.tanh(self.linear_mem(mem_output)) 358 | 359 | encoder_outputs = torch.cat((app_trans, motion_trans, mem_trans), dim=1) 360 | decoder_inputs = self.vq2word_hme(encoder_outputs) 361 | hidden = qns_hidden 362 | if mode == 'train': 363 | vocab_size = self.ans_decoder.vocab_size 364 | ans_len = ans.shape[1] 365 | input = ans[:, 0] 366 | 367 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 368 | 369 | for t in range(0, ans_len): 370 | output, hidden = self.ans_decoder(decoder_inputs, hidden, input) 371 | outputs[:, t] = output 372 | teacher_force = rd.random() < teacher_force_ratio 373 | top1 = output.argmax(1) 374 | input = ans[:, t] if teacher_force else top1 375 | else: 376 | start = torch.LongTensor([1] * batch_size).to(self.device) 377 | outputs = self.ans_decoder.sample(decoder_inputs, hidden, start) 378 | 379 | return outputs 380 | 381 | 382 | class HGA(nn.Module): 383 | def __init__(self, vid_encoder, qns_encoder, ans_decoder, max_len_v, max_len_q, device): 384 | """ 385 | Reasoning with Heterogeneous Graph Alignment for Video Question Answering (AAAI20) 386 | """ 387 | super(HGA, self).__init__() 388 | self.vid_encoder = vid_encoder 389 | self.qns_encoder = qns_encoder 390 | self.ans_decoder = ans_decoder 391 | self.max_len_v = max_len_v 392 | self.max_len_q = max_len_q 393 | self.device = device 394 | hidden_size = vid_encoder.dim_hidden 395 | input_dropout_p = vid_encoder.input_dropout_p 396 | 397 | self.q_input_ln = nn.LayerNorm(hidden_size, elementwise_affine=False) 398 | self.v_input_ln = 
nn.LayerNorm(hidden_size, elementwise_affine=False) 399 | 400 | self.co_attn = CoAttention( 401 | hidden_size, n_layers=vid_encoder.n_layers, dropout_p=input_dropout_p) 402 | 403 | self.adj_learner = AdjLearner( 404 | hidden_size, hidden_size, dropout=input_dropout_p) 405 | 406 | self.gcn = GCN( 407 | hidden_size, 408 | hidden_size, 409 | hidden_size, 410 | num_layers=2, 411 | dropout=input_dropout_p) 412 | 413 | self.gcn_atten_pool = nn.Sequential( 414 | nn.Linear(hidden_size, hidden_size // 2), 415 | nn.Tanh(), 416 | nn.Linear(hidden_size // 2, 1), 417 | nn.Softmax(dim=-1)) 418 | 419 | self.global_fusion = fusions.Block( 420 | [hidden_size, hidden_size], hidden_size, dropout_input=input_dropout_p) 421 | 422 | self.fusion = fusions.Block([hidden_size, hidden_size], hidden_size) 423 | 424 | 425 | def forward(self, vid_feats, qns, qns_lengths, ans, ans_lengths, teacher_force_ratio=0.5, mode='train'): 426 | """ 427 | 428 | """ 429 | encoder_out, qns_hidden, qns_out, vid_out = self.vq_encoder(vid_feats, qns, qns_lengths) 430 | 431 | batch_size = encoder_out.shape[0] 432 | 433 | hidden = encoder_out.unsqueeze(0) 434 | if mode == 'train': 435 | vocab_size = self.ans_decoder.vocab_size 436 | ans_len = ans.shape[1] 437 | input = ans[:, 0] 438 | outputs = torch.zeros(batch_size, ans_len, vocab_size).to(self.device) 439 | for t in range(0, ans_len): 440 | 441 | output, hidden = self.ans_decoder(qns_out, hidden, input) #attqns, attvid 442 | outputs[:, t] = output 443 | teacher_force = rd.random() < teacher_force_ratio 444 | top1 = output.argmax(1) 445 | input = ans[:, t] if teacher_force else top1 446 | else: 447 | start = torch.LongTensor([1] * batch_size).to(self.device) 448 | 449 | outputs = self.ans_decoder.sample(qns_out, hidden, start) #vidatt, qns_att 450 | 451 | return outputs 452 | 453 | 454 | def vq_encoder(self, vid_feats, qns, qns_lengths): 455 | """ 456 | 457 | :param vid_feats: 458 | :param qns: 459 | :param qns_lengths: 460 | :return: 461 | """ 462 | q_output, s_hidden = self.qns_encoder(qns, qns_lengths) 463 | qns_last_hidden = torch.squeeze(s_hidden) 464 | 465 | 466 | v_output, v_hidden = self.vid_encoder(vid_feats) 467 | vid_last_hidden = torch.squeeze(v_hidden) 468 | 469 | q_output = self.q_input_ln(q_output) 470 | v_output = self.v_input_ln(v_output) 471 | 472 | q_output, v_output = self.co_attn(q_output, v_output) 473 | 474 | ### GCN 475 | adj = self.adj_learner(q_output, v_output) 476 | # q_v_inputs of shape (batch_size, q_v_len, hidden_size) 477 | q_v_inputs = torch.cat((q_output, v_output), dim=1) 478 | # q_v_output of shape (batch_size, q_v_len, hidden_size) 479 | q_v_output = self.gcn(q_v_inputs, adj) 480 | 481 | ## attention pool 482 | local_attn = self.gcn_atten_pool(q_v_output) 483 | local_out = torch.sum(q_v_output * local_attn, dim=1) 484 | 485 | # print(qns_last_hidden.shape, vid_last_hidden.shape) 486 | global_out = self.global_fusion((qns_last_hidden, vid_last_hidden)) 487 | 488 | 489 | out = self.fusion((global_out, local_out)).squeeze() #4 x 512 490 | 491 | return out, s_hidden, q_output, v_output, 492 | 493 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .EncoderRNN import EncoderVid, EncoderQns 2 | from .DecoderRNN import AnsUATT 3 | from .VQAModel import EVQA 4 | 5 | -------------------------------------------------------------------------------- /networks/gcn.py: 
-------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from torch.nn.parameter import Parameter 7 | from torch.nn.modules.module import Module 8 | 9 | 10 | def padding_mask_k(seq_q, seq_k): 11 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 12 | In batch 0: 13 | [[x x x 0] [[0 0 0 1] 14 | [x x x 0]-> [0 0 0 1] 15 | [x x x 0]] [0 0 0 1]] uint8 16 | """ 17 | fake_q = torch.ones_like(seq_q) 18 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2)) 19 | pad_mask = pad_mask.eq(0) 20 | # pad_mask = pad_mask.lt(1e-3) 21 | return pad_mask 22 | 23 | 24 | def padding_mask_q(seq_q, seq_k): 25 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 26 | In batch 0: 27 | [[x x x x] [[0 0 0 0] 28 | [x x x x] -> [0 0 0 0] 29 | [0 0 0 0]] [1 1 1 1]] uint8 30 | """ 31 | fake_k = torch.ones_like(seq_k) 32 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2)) 33 | pad_mask = pad_mask.eq(0) 34 | # pad_mask = pad_mask.lt(1e-3) 35 | return pad_mask 36 | 37 | 38 | class GraphConvolution(Module): 39 | """ 40 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 41 | """ 42 | 43 | def __init__(self, in_features, out_features): 44 | super(GraphConvolution, self).__init__() 45 | self.weight = nn.Linear(in_features, out_features, bias=False) 46 | self.layer_norm = nn.LayerNorm(out_features, elementwise_affine=False) 47 | 48 | def forward(self, input, adj): 49 | # self.weight of shape (hidden_size, hidden_size) 50 | support = self.weight(input) 51 | output = torch.bmm(adj, support) 52 | output = self.layer_norm(output) 53 | return output 54 | 55 | 56 | class GraphAttention(nn.Module): 57 | """ 58 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 59 | """ 60 | 61 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 62 | super(GraphAttention, self).__init__() 63 | self.dropout = dropout 64 | self.in_features = in_features 65 | self.out_features = out_features 66 | self.alpha = alpha 67 | self.concat = concat 68 | 69 | self.W = nn.Parameter( 70 | nn.init.xavier_normal_( 71 | torch.Tensor(in_features, out_features).type( 72 | torch.cuda.FloatTensor if torch.cuda.is_available( 73 | ) else torch.FloatTensor), 74 | gain=np.sqrt(2.0)), 75 | requires_grad=True) 76 | self.a1 = nn.Parameter( 77 | nn.init.xavier_normal_( 78 | torch.Tensor(out_features, 1).type( 79 | torch.cuda.FloatTensor if torch.cuda.is_available( 80 | ) else torch.FloatTensor), 81 | gain=np.sqrt(2.0)), 82 | requires_grad=True) 83 | self.a2 = nn.Parameter( 84 | nn.init.xavier_normal_( 85 | torch.Tensor(out_features, 1).type( 86 | torch.cuda.FloatTensor if torch.cuda.is_available( 87 | ) else torch.FloatTensor), 88 | gain=np.sqrt(2.0)), 89 | requires_grad=True) 90 | 91 | self.leakyrelu = nn.LeakyReLU(self.alpha) 92 | 93 | def forward(self, input, adj): 94 | h = torch.mm(input, self.W) 95 | N = h.size()[0] 96 | 97 | f_1 = torch.matmul(h, self.a1) 98 | f_2 = torch.matmul(h, self.a2) 99 | e = self.leakyrelu(f_1 + f_2.transpose(0, 1)) 100 | 101 | zero_vec = -9e15 * torch.ones_like(e) 102 | attention = torch.where(adj > 0, e, zero_vec) 103 | attention = F.softmax(attention, dim=1) 104 | attention = F.dropout(attention, self.dropout, training=self.training) 105 | h_prime = torch.matmul(attention, h) 106 
| 107 | if self.concat: 108 | return F.elu(h_prime) 109 | else: 110 | return h_prime 111 | 112 | 113 | class GCN(nn.Module): 114 | 115 | def __init__( 116 | self, input_size, hidden_size, num_classes, num_layers=1, 117 | dropout=0.1): 118 | super(GCN, self).__init__() 119 | self.layers = nn.ModuleList() 120 | self.layers.append(GraphConvolution(input_size, hidden_size)) 121 | for i in range(num_layers - 1): 122 | self.layers.append(GraphConvolution(hidden_size, hidden_size)) 123 | self.layers.append(GraphConvolution(hidden_size, num_classes)) 124 | self.dropout = nn.Dropout(p=dropout) 125 | 126 | def forward(self, x, adj): 127 | for i, layer in enumerate(self.layers): 128 | x = self.dropout(F.relu(layer(x, adj))) 129 | 130 | # x of shape (bs, q_v_len, num_classes) 131 | return x 132 | 133 | 134 | class AdjLearner(Module): 135 | 136 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1): 137 | super().__init__() 138 | ''' 139 | ## Variables: 140 | - in_feature_dim: dimensionality of input features 141 | - hidden_size: dimensionality of the joint hidden embedding 142 | - K: number of graph nodes/objects on the image 143 | ''' 144 | 145 | # Embedding layers. Padded 0 => 0 146 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False) 147 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False) 148 | 149 | # Regularisation 150 | self.dropout = nn.Dropout(p=dropout) 151 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1) 152 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2) 153 | 154 | def forward(self, questions, videos): 155 | ''' 156 | ## Inputs: 157 | ## Returns: 158 | - adjacency matrix (batch_size, q_v_len, q_v_len) 159 | ''' 160 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features 161 | graph_nodes = torch.cat((questions, videos), dim=1) 162 | 163 | # layer 1 164 | h = self.edge_layer_1(graph_nodes) 165 | h = F.relu(h) 166 | 167 | # layer 2 168 | h = self.edge_layer_2(h) 169 | h = F.relu(h) 170 | # h * sigmoid(Wh) 171 | # h = F.tanh(h) 172 | 173 | # outer product 174 | adjacency_matrix = torch.bmm(h, h.transpose(1, 2)) 175 | 176 | return adjacency_matrix 177 | 178 | 179 | class EvoAdjLearner(Module): 180 | 181 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1): 182 | super().__init__() 183 | ''' 184 | ## Variables: 185 | - in_feature_dim: dimensionality of input features 186 | - hidden_size: dimensionality of the joint hidden embedding 187 | - K: number of graph nodes/objects on the image 188 | ''' 189 | 190 | # Embedding layers. 
Padded 0 => 0 191 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False) 192 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False) 193 | self.edge_layer_3 = nn.Linear(in_feature_dim, hidden_size, bias=False) 194 | self.edge_layer_4 = nn.Linear(hidden_size, hidden_size, bias=False) 195 | 196 | # Regularisation 197 | self.dropout = nn.Dropout(p=dropout) 198 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1) 199 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2) 200 | 201 | def forward(self, questions, videos): 202 | ''' 203 | ## Inputs: 204 | ## Returns: 205 | - adjacency matrix (batch_size, q_v_len, q_v_len) 206 | ''' 207 | # graph_nodes (batch_size, q_v_len, in_feat_dim): input features 208 | graph_nodes = torch.cat((questions, videos), dim=1) 209 | 210 | attn_mask = padding_mask_k(graph_nodes, graph_nodes) 211 | sf_mask = padding_mask_q(graph_nodes, graph_nodes) 212 | 213 | # layer 1 214 | h = self.edge_layer_1(graph_nodes) 215 | h = F.relu(h) 216 | # layer 2 217 | h = self.edge_layer_2(h) 218 | # h = F.relu(h) 219 | 220 | # layer 1 221 | h_ = self.edge_layer_3(graph_nodes) 222 | h_ = F.relu(h_) 223 | # layer 2 224 | h_ = self.edge_layer_4(h_) 225 | # h_ = F.relu(h_) 226 | 227 | # outer product 228 | adjacency_matrix = torch.bmm(h, h_.transpose(1, 2)) 229 | # adjacency_matrix = adjacency_matrix.masked_fill(attn_mask, -np.inf) 230 | 231 | # softmaxed_adj = F.softmax(adjacency_matrix, dim=-1) 232 | 233 | # softmaxed_adj = softmaxed_adj.masked_fill(sf_mask, 0.) 234 | 235 | return adjacency_matrix -------------------------------------------------------------------------------- /networks/memory_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | import torch.nn.functional as F 5 | import torch.nn.init as init 6 | 7 | 8 | class AttentionGRUCell(nn.Module): 9 | ''' 10 | Eq (1)~(4), then modify by Eq (11) 11 | When forwarding, we feed attention gate g into GRU 12 | ''' 13 | def __init__(self, input_size, hidden_size): 14 | super(AttentionGRUCell, self).__init__() 15 | self.hidden_size = hidden_size 16 | self.Wr = nn.Linear(input_size, hidden_size) 17 | self.Ur = nn.Linear(hidden_size, hidden_size) 18 | self.W = nn.Linear(input_size, hidden_size) 19 | self.U = nn.Linear(hidden_size, hidden_size) 20 | 21 | def init_weights(self): 22 | self.Wr.weight.data.normal_(0.0, 0.02) 23 | self.Wr.bias.data.fill_(0) 24 | self.Ur.weight.data.normal_(0.0, 0.02) 25 | self.Ur.bias.data.fill_(0) 26 | self.W.weight.data.normal_(0.0, 0.02) 27 | self.W.bias.data.fill_(0) 28 | self.U.weight.data.normal_(0.0, 0.02) 29 | self.U.bias.data.fill_(0) 30 | 31 | def forward(self, fact, C, g): 32 | ''' 33 | fact.size() -> (#batch, #hidden = #embedding) 34 | c.size() -> (#hidden, ) -> (#batch, #hidden = #embedding) 35 | r.size() -> (#batch, #hidden = #embedding) 36 | h_tilda.size() -> (#batch, #hidden = #embedding) 37 | g.size() -> (#batch, ) 38 | ''' 39 | 40 | r = torch.sigmoid(self.Wr(fact) + self.Ur(C)) 41 | h_tilda = torch.tanh(self.W(fact) + r * self.U(C)) 42 | g = g.unsqueeze(1).expand_as(h_tilda) 43 | h = g * h_tilda + (1 - g) * C 44 | return h 45 | 46 | class AttentionGRU(nn.Module): 47 | ''' 48 | Section 3.3 49 | continuously run AttnGRU to get contextual vector c at each time t 50 | ''' 51 | def __init__(self, input_size, hidden_size): 52 | super(AttentionGRU, self).__init__() 53 | self.hidden_size = hidden_size 54 | self.AGRUCell = 
AttentionGRUCell(input_size, hidden_size) 55 | 56 | def init_weights(self): 57 | self.AGRUCell.init_weights() 58 | 59 | def forward(self, facts, G): 60 | ''' 61 | facts.size() -> (#batch, #sentence, #hidden = #embedding) 62 | fact.size() -> (#batch, #hidden = #embedding) 63 | G.size() -> (#batch, #sentence) 64 | g.size() -> (#batch, ) 65 | C.size() -> (#batch, #hidden) 66 | ''' 67 | batch_num, sen_num, embedding_size = facts.size() 68 | C = Variable(torch.zeros(self.hidden_size)).cuda() 69 | for sid in range(sen_num): 70 | fact = facts[:, sid, :] 71 | g = G[:, sid] 72 | if sid == 0: 73 | C = C.unsqueeze(0).expand_as(fact) 74 | C = self.AGRUCell(fact, C, g) 75 | return C 76 | 77 | class EpisodicMemory(nn.Module): 78 | ''' 79 | Section 3.3 80 | ''' 81 | 82 | def __init__(self, hidden_size): 83 | super(EpisodicMemory, self).__init__() 84 | self.AGRU = AttentionGRU(hidden_size, hidden_size) 85 | self.z1 = nn.Linear(4 * hidden_size, hidden_size) 86 | self.z2 = nn.Linear(hidden_size, 1) 87 | self.next_mem = nn.Linear(3 * hidden_size, hidden_size) 88 | 89 | 90 | def init_weights(self): 91 | self.z1.weight.data.normal_(0.0, 0.02) 92 | self.z1.bias.data.fill_(0) 93 | self.z2.weight.data.normal_(0.0, 0.02) 94 | self.z2.bias.data.fill_(0) 95 | self.next_mem.weight.data.normal_(0.0, 0.02) 96 | self.next_mem.bias.data.fill_(0) 97 | self.AGRU.init_weights() 98 | 99 | 100 | def make_interaction(self, frames, questions, prevM): 101 | ''' 102 | frames.size() -> (#batch, T, #hidden = #embedding) 103 | questions.size() -> (#batch, 1, #hidden) 104 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding) 105 | z.size() -> (#batch, T, 4 x #embedding) 106 | G.size() -> (#batch, T) 107 | ''' 108 | batch_num, T, embedding_size = frames.size() 109 | questions = questions.view(questions.size(0),1,questions.size(1)) 110 | 111 | 112 | #questions = questions.expand_as(frames) 113 | #prevM = prevM.expand_as(frames) 114 | 115 | #print(questions.size(),prevM.size()) 116 | 117 | # Eq (8)~(10) 118 | z = torch.cat([ 119 | frames * questions, 120 | frames * prevM, 121 | torch.abs(frames - questions), 122 | torch.abs(frames - prevM) 123 | ], dim=2) 124 | 125 | z = z.view(-1, 4 * embedding_size) 126 | 127 | G = torch.tanh(self.z1(z)) 128 | G = self.z2(G) 129 | G = G.view(batch_num, -1) 130 | G = F.softmax(G,dim=1) 131 | #print('G size',G.size()) 132 | return G 133 | 134 | def forward(self, frames, questions, prevM): 135 | ''' 136 | frames.size() -> (#batch, #sentence, #hidden = #embedding) 137 | questions.size() -> (#batch, #sentence = 1, #hidden) 138 | prevM.size() -> (#batch, #sentence = 1, #hidden = #embedding) 139 | G.size() -> (#batch, #sentence) 140 | C.size() -> (#batch, #hidden) 141 | concat.size() -> (#batch, 3 x #embedding) 142 | ''' 143 | 144 | ''' 145 | section 3.3 - Attention based GRU 146 | input: F and q, as frames and questions 147 | then get gates g 148 | then (c,m,g) feed into memory update module Eq(13) 149 | output new memory state 150 | ''' 151 | # print(frames.shape, questions.shape, prevM.shape) 152 | 153 | G = self.make_interaction(frames, questions, prevM) 154 | C = self.AGRU(frames, G) 155 | concat = torch.cat([prevM.squeeze(1), C, questions.squeeze(1)], dim=1) 156 | next_mem = F.relu(self.next_mem(concat)) 157 | #print(next_mem.size()) 158 | next_mem = next_mem.unsqueeze(1) 159 | return next_mem 160 | -------------------------------------------------------------------------------- /networks/memory_rand.py: -------------------------------------------------------------------------------- 1 | 
import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from Attention import MultiModalAttentionModule 5 | 6 | class MemoryRamModule(nn.Module): 7 | 8 | def __init__(self, input_size=1024, hidden_size=512, memory_bank_size=100, device=None): 9 | """Set the hyper-parameters and build the layers.""" 10 | super(MemoryRamModule, self).__init__() 11 | 12 | self.input_size = input_size 13 | self.hidden_size = hidden_size 14 | self.memory_bank_size = memory_bank_size 15 | self.device = device 16 | 17 | self.hidden_to_content = nn.Linear(hidden_size+input_size, hidden_size) 18 | #self.read_to_hidden = nn.Linear(hidden_size+input_size, 1) 19 | self.write_gate = nn.Linear(hidden_size+input_size, 1) 20 | self.write_prob = nn.Linear(hidden_size+input_size, memory_bank_size) 21 | 22 | self.read_gate = nn.Linear(hidden_size+input_size, 1) 23 | self.read_prob = nn.Linear(hidden_size+input_size, memory_bank_size) 24 | 25 | 26 | self.Wxh = nn.Parameter(torch.FloatTensor(input_size, hidden_size),requires_grad=True) 27 | self.Wrh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True) 28 | self.Whh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size),requires_grad=True) 29 | self.bh = nn.Parameter(torch.FloatTensor(hidden_size),requires_grad=True) 30 | 31 | self.init_weights() 32 | 33 | 34 | def init_weights(self): 35 | self.Wxh.data.normal_(0.0, 0.1) 36 | self.Wrh.data.normal_(0.0, 0.1) 37 | self.Whh.data.normal_(0.0, 0.1) 38 | self.bh.data.fill_(0) 39 | 40 | 41 | def forward(self, hidden_frames, nImg): 42 | 43 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device) 44 | memory_ram.fill_(0) 45 | 46 | h_t = torch.zeros(1, self.hidden_size).to(self.device) 47 | 48 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device) 49 | 50 | for t in range(nImg): 51 | x_t = hidden_frames[t:t+1,:] 52 | 53 | x_h_t = torch.cat([x_t,h_t],dim=1) 54 | 55 | ############# read ############ 56 | ar = torch.softmax(self.read_prob( x_h_t ),dim=1) # read prob from memories 57 | go = torch.sigmoid(self.read_gate( x_h_t )) # read gate 58 | r = go * torch.matmul(ar,memory_ram) # read vector 59 | 60 | ######### h_t ######### 61 | # Eq (17) 62 | m1 = torch.matmul(x_t, self.Wxh) 63 | m2 = torch.matmul(r, self.Wrh) 64 | m3 = torch.matmul(h_t, self.Whh) 65 | h_t_p1 = F.relu(m1 + m2 + m3 + self.bh) # Eq(17) 66 | 67 | 68 | ############# write ############ 69 | c_t = F.relu( self.hidden_to_content(x_h_t) ) # Eq(15), content vector 70 | aw = torch.softmax(self.write_prob( x_h_t ),dim=1) # write prob to memories 71 | aw = aw.view(self.memory_bank_size,1) 72 | gw = torch.sigmoid(self.write_gate( x_h_t )) # write gate 73 | #print gw.size(),aw.size(),c_t.size(),memory_ram.size() 74 | memory_ram = gw * aw * c_t + (1.0-aw) * memory_ram # Eq(16) 75 | 76 | h_t = h_t_p1 77 | hiddens[t,:] = h_t 78 | 79 | #return memory_ram 80 | return hiddens 81 | 82 | 83 | class MemoryRamTwoStreamModule(nn.Module): 84 | 85 | def __init__(self, input_size, hidden_size=512, memory_bank_size=100, device=None): 86 | """Set the hyper-parameters and build the layers.""" 87 | super(MemoryRamTwoStreamModule, self).__init__() 88 | 89 | self.input_size = input_size 90 | self.hidden_size = hidden_size 91 | self.memory_bank_size = memory_bank_size 92 | self.device = device 93 | 94 | self.hidden_to_content_a = nn.Linear(hidden_size+input_size, hidden_size) 95 | self.hidden_to_content_m = nn.Linear(hidden_size+input_size, hidden_size) 96 | 97 | self.write_prob = 
nn.Linear(hidden_size*3, 3) 98 | self.write_prob_a = nn.Linear(hidden_size+input_size, memory_bank_size) 99 | self.write_prob_m = nn.Linear(hidden_size+input_size, memory_bank_size) 100 | 101 | self.read_prob = nn.Linear(hidden_size*3, memory_bank_size) 102 | 103 | self.read_to_hidden = nn.Linear(hidden_size*2, hidden_size) 104 | self.read_to_hidden_a = nn.Linear(hidden_size*2+input_size, hidden_size) 105 | self.read_to_hidden_m = nn.Linear(hidden_size*2+input_size, hidden_size) 106 | self.init_weights() 107 | 108 | def init_weights(self): 109 | pass 110 | 111 | 112 | def forward(self, hidden_out_a, hidden_out_m, nImg): 113 | 114 | 115 | memory_ram = torch.FloatTensor(self.memory_bank_size, self.hidden_size).to(self.device) 116 | memory_ram.fill_(0) 117 | 118 | h_t_a = torch.zeros(1, self.hidden_size).to(self.device) 119 | h_t_m = torch.zeros(1, self.hidden_size).to(self.device) 120 | h_t = torch.zeros(1, self.hidden_size).to(self.device) 121 | 122 | hiddens = torch.FloatTensor(nImg, self.hidden_size).to(self.device) 123 | 124 | for t in range(nImg): 125 | x_t_a = hidden_out_a[t:t+1,:] 126 | x_t_m = hidden_out_m[t:t+1,:] 127 | 128 | 129 | ############# read ############ 130 | x_h_t_am = torch.cat([h_t_a,h_t_m,h_t],dim=1) 131 | ar = torch.softmax(self.read_prob( x_h_t_am ),dim=1) # read prob from memories 132 | r = torch.matmul(ar,memory_ram) # read vector 133 | 134 | 135 | ######### h_t ######### 136 | # Eq (17) 137 | f_0 = torch.cat([r, h_t],dim=1) 138 | f_a = torch.cat([x_t_a, r, h_t_a],dim=1) 139 | f_m = torch.cat([x_t_m, r, h_t_m],dim=1) 140 | 141 | h_t_1 = F.relu(self.read_to_hidden(f_0)) 142 | h_t_a1 = F.relu(self.read_to_hidden_a(f_a)) 143 | h_t_m1 = F.relu(self.read_to_hidden_m(f_m)) 144 | 145 | 146 | ############# write ############ 147 | 148 | # write probability of [keep, write appearance, write motion] 149 | aw = torch.softmax(self.write_prob( x_h_t_am ),dim=1) # write prob to memories 150 | x_h_ta = torch.cat([h_t_a,x_t_a],dim=1) 151 | x_h_tm = torch.cat([h_t_m,x_t_m],dim=1) 152 | 153 | 154 | # write content 155 | c_t_a = F.relu( self.hidden_to_content_a(x_h_ta) ) # Eq(15), content vector 156 | c_t_m = F.relu( self.hidden_to_content_m(x_h_tm) ) # Eq(15), content vector 157 | 158 | aw_a = torch.softmax(self.write_prob_a( x_h_ta ),dim=1) # write prob to memories 159 | aw_m = torch.softmax(self.write_prob_m( x_h_tm ),dim=1) # write prob to memories 160 | 161 | 162 | aw_a = aw_a.view(self.memory_bank_size,1) 163 | aw_m = aw_m.view(self.memory_bank_size,1) 164 | 165 | memory_ram = aw[0,0] * memory_ram + aw[0,1] * aw_a * c_t_a + aw[0,2] * aw_m * c_t_m 166 | 167 | 168 | h_t = h_t_1 169 | h_t_a = h_t_a1 170 | h_t_m = h_t_m1 171 | 172 | hiddens[t,:] = h_t 173 | 174 | 175 | return hiddens 176 | 177 | class MMModule(nn.Module): 178 | def __init__(self, dim, input_drop_p, device): 179 | """Set the hyper-parameters and build the layers.""" 180 | super(MMModule, self).__init__() 181 | self.hidden_size = dim 182 | self.lstm_mm_1 = nn.LSTMCell(dim, dim) 183 | self.lstm_mm_2 = nn.LSTMCell(dim, dim) 184 | self.hidden_encoder_1 = nn.Linear(dim * 2, dim) 185 | self.hidden_encoder_2 = nn.Linear(dim * 2, dim) 186 | self.dropout = nn.Dropout(input_drop_p) 187 | self.mm_att = MultiModalAttentionModule(dim) 188 | self.device = device 189 | self.init_weights() 190 | 191 | 192 | def init_weights(self): 193 | nn.init.xavier_normal_(self.hidden_encoder_1.weight) 194 | nn.init.xavier_normal_(self.hidden_encoder_2.weight) 195 | self.init_hiddens() 196 | 197 | def init_hiddens(self): 198 | s_t = 
torch.zeros(1, self.hidden_size).to(self.device) 199 | s_t2 = torch.zeros(1, self.hidden_size).to(self.device) 200 | c_t = torch.zeros(1, self.hidden_size).to(self.device) 201 | c_t2 = torch.zeros(1, self.hidden_size).to(self.device) 202 | return s_t, s_t2, c_t, c_t2 203 | 204 | def forward(self, svt_tmp, memory_ram_vid, memory_ram_txt, loop=3): 205 | """ 206 | 207 | :param svt_tmp: 208 | :param memory_ram_vid: 209 | :param memory_ram_txt: 210 | :param loop: 211 | :return: 212 | """ 213 | 214 | sm_q1, sm_q2, cm_q1, cm_q2 = self.init_hiddens() 215 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_1(svt_tmp))) 216 | 217 | for _ in range(loop): 218 | sm_q1, cm_q1 = self.lstm_mm_1(mm_oo, (sm_q1, cm_q1)) 219 | sm_q2, cm_q2 = self.lstm_mm_2(sm_q1, (sm_q2, cm_q2)) 220 | 221 | mm_o1 = self.mm_att(sm_q2, memory_ram_vid, memory_ram_txt) 222 | mm_o2 = torch.cat((sm_q2, mm_o1), dim=1) 223 | mm_oo = self.dropout(torch.tanh(self.hidden_encoder_2(mm_o2))) 224 | 225 | smq = torch.cat((sm_q1, sm_q2), dim=1) 226 | 227 | return smq -------------------------------------------------------------------------------- /networks/q_v_transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchnlp_nn as nlpnn 6 | 7 | 8 | def padding_mask(seq_q, seq_k): 9 | # seq_k of shape (batch, k_len) and seq_q (batch, q_len), not embedded. q and k are padded with 0. 10 | seq_q = torch.unsqueeze(seq_q, 2) 11 | seq_k = torch.unsqueeze(seq_k, 2) 12 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2)) 13 | pad_mask = pad_mask.eq(0) 14 | return pad_mask 15 | 16 | 17 | def padding_mask_transformer(seq_q, seq_k): 18 | # original padding_mask in transformer, for masking out the padding part of key sequence. 19 | len_q = seq_q.size(1) 20 | # `PAD` is 0 21 | pad_mask = seq_k.eq(0) 22 | pad_mask = pad_mask.unsqueeze(1).expand( 23 | -1, len_q, -1) # shape [B, L_q, L_k] 24 | return pad_mask 25 | 26 | 27 | def padding_mask_embedded(seq_q, seq_k): 28 | # seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len) 29 | pad_mask = torch.bmm(seq_q, seq_k.transpose(1, 2)) 30 | pad_mask = pad_mask.eq(0) 31 | return pad_mask 32 | 33 | 34 | def padding_mask_k(seq_q, seq_k): 35 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 36 | In batch 0: 37 | [[x x x 0] [[0 0 0 1] 38 | [x x x 0]-> [0 0 0 1] 39 | [x x x 0]] [0 0 0 1]] uint8 40 | """ 41 | fake_q = torch.ones_like(seq_q) 42 | pad_mask = torch.bmm(fake_q, seq_k.transpose(1, 2)) 43 | pad_mask = pad_mask.eq(0) 44 | # pad_mask = pad_mask.lt(1e-3) 45 | return pad_mask 46 | 47 | 48 | def padding_mask_q(seq_q, seq_k): 49 | """ seq_k of shape (batch, k_len, k_feat) and seq_q (batch, q_len, q_feat). q and k are padded with 0. pad_mask is (batch, q_len, k_len). 
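Rows that correspond to padded (all-zero) query positions come out True across every key position, which is what lets callers zero those rows out after the softmax.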
50 | In batch 0: 51 | [[x x x x] [[0 0 0 0] 52 | [x x x x] -> [0 0 0 0] 53 | [0 0 0 0]] [1 1 1 1]] uint8 54 | """ 55 | fake_k = torch.ones_like(seq_k) 56 | pad_mask = torch.bmm(seq_q, fake_k.transpose(1, 2)) 57 | pad_mask = pad_mask.eq(0) 58 | # pad_mask = pad_mask.lt(1e-3) 59 | return pad_mask 60 | 61 | 62 | class PositionalEncoding(nn.Module): 63 | 64 | def __init__(self, d_model, max_seq_len): 65 | super(PositionalEncoding, self).__init__() 66 | self.max_seq_len = max_seq_len 67 | 68 | position_encoding = np.array( 69 | [ 70 | [ 71 | pos / np.power(10000, 2.0 * (j // 2) / d_model) 72 | for j in range(d_model) 73 | ] 74 | for pos in range(max_seq_len) 75 | ]) 76 | position_encoding[:, 0::2] = np.sin(position_encoding[:, 0::2]) 77 | position_encoding[:, 1::2] = np.cos(position_encoding[:, 1::2]) 78 | 79 | pad_row = torch.zeros([1, d_model]) 80 | position_encoding = torch.cat( 81 | (pad_row, torch.from_numpy(position_encoding).float())) 82 | 83 | self.position_encoding = nn.Embedding(max_seq_len + 1, d_model) 84 | self.position_encoding.weight = nn.Parameter( 85 | position_encoding, requires_grad=False) 86 | 87 | def forward(self, input_len): 88 | # max_len = torch.max(input_len) 89 | max_len = self.max_seq_len 90 | tensor = torch.cuda.LongTensor if input_len.is_cuda else torch.LongTensor 91 | input_pos = [ 92 | list(range(1, l + 1)) + [0] * (max_len - l.item()) 93 | for l in input_len 94 | ] 95 | input_pos = tensor(input_pos) 96 | return self.position_encoding(input_pos) 97 | 98 | 99 | class PositionalWiseFeedForward(nn.Module): 100 | 101 | def __init__(self, model_dim=512, ffn_dim=512, dropout=0.0): 102 | super(PositionalWiseFeedForward, self).__init__() 103 | self.w1 = nn.Conv1d(model_dim, ffn_dim, 1) 104 | self.w2 = nn.Conv1d(model_dim, ffn_dim, 1) 105 | self.dropout = nn.Dropout(dropout) 106 | self.layer_norm = nn.LayerNorm(model_dim) 107 | 108 | def forward(self, x): 109 | # x of shape (bs, seq_len, hs) 110 | output = x.transpose(1, 2) 111 | output = self.w2(F.relu(self.w1(output))) 112 | output = self.dropout(output.transpose(1, 2)) 113 | 114 | # add residual and norm layer 115 | output = self.layer_norm(x + output) 116 | return output 117 | 118 | 119 | class MaskedPositionalWiseFeedForward(nn.Module): 120 | 121 | def __init__(self, model_dim=512, ffn_dim=2048, dropout=0.0): 122 | super().__init__() 123 | self.w1 = nn.Linear(model_dim, ffn_dim, bias=False) 124 | self.w2 = nn.Linear(ffn_dim, model_dim, bias=False) 125 | self.dropout = nn.Dropout(dropout) 126 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False) 127 | 128 | def forward(self, x): 129 | # x of shape (bs, seq_len, hs) 130 | output = self.w2(F.relu(self.w1(x))) 131 | output = self.dropout(output) 132 | 133 | # add residual and norm layer 134 | output = self.layer_norm(x + output) 135 | return output 136 | 137 | 138 | class ScaledDotProductAttention(nn.Module): 139 | """Scaled dot-product attention mechanism.""" 140 | 141 | def __init__(self, attention_dropout=0.0): 142 | super(ScaledDotProductAttention, self).__init__() 143 | self.dropout = nn.Dropout(attention_dropout) 144 | self.softmax = nn.Softmax(dim=-1) 145 | 146 | def forward(self, q, k, v, scale=None, attn_mask=None): 147 | """ 148 | Args: 149 | q: [B, L_q, D_q] 150 | k: [B, L_k, D_k] 151 | v: [B, L_v, D_v] 152 | """ 153 | attention = torch.matmul(q, k.transpose(1, 2)) 154 | if scale is not None: 155 | attention = attention * scale 156 | if attn_mask is not None: 157 | attention = attention.masked_fill(attn_mask, -np.inf) 158 | attention = 
self.softmax(attention) 159 | attention = self.dropout(attention) 160 | output = torch.matmul(attention, v) 161 | return output, attention 162 | 163 | 164 | class MaskedScaledDotProductAttention(nn.Module): 165 | """Scaled dot-product attention mechanism.""" 166 | 167 | def __init__(self, attention_dropout=0.0): 168 | super().__init__() 169 | self.dropout = nn.Dropout(attention_dropout) 170 | self.softmax = nn.Softmax(dim=-1) 171 | 172 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None): 173 | """ 174 | Args: 175 | q: [B, L_q, D_q] 176 | k: [B, L_k, D_k] 177 | v: [B, L_v, D_v] 178 | """ 179 | attention = torch.matmul(q, k.transpose(-2, -1)) 180 | if scale is not None: 181 | attention = attention * scale 182 | if attn_mask is not None: 183 | attention = attention.masked_fill(attn_mask, -np.inf) 184 | attention = self.softmax(attention) 185 | attention = attention.masked_fill(softmax_mask, 0.) 186 | attention = self.dropout(attention) 187 | output = torch.matmul(attention, v) 188 | return output, attention 189 | 190 | 191 | class MultiHeadAttention(nn.Module): 192 | 193 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0): 194 | super(MultiHeadAttention, self).__init__() 195 | 196 | self.dim_per_head = model_dim // num_heads 197 | self.num_heads = num_heads 198 | self.linear_k = nn.Linear( 199 | model_dim, self.dim_per_head * num_heads, bias=False) 200 | self.linear_v = nn.Linear( 201 | model_dim, self.dim_per_head * num_heads, bias=False) 202 | self.linear_q = nn.Linear( 203 | model_dim, self.dim_per_head * num_heads, bias=False) 204 | 205 | self.dot_product_attention = ScaledDotProductAttention(dropout) 206 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False) 207 | self.dropout = nn.Dropout(dropout) 208 | self.layer_norm = nn.LayerNorm(model_dim) 209 | 210 | def forward(self, query, key, value, attn_mask=None): 211 | residual = query 212 | 213 | dim_per_head = self.dim_per_head 214 | num_heads = self.num_heads 215 | batch_size = key.size(0) 216 | 217 | # linear projection 218 | key = self.linear_k(key) 219 | value = self.linear_v(value) 220 | query = self.linear_q(query) 221 | 222 | # split by heads 223 | key = key.view(batch_size * num_heads, -1, dim_per_head) 224 | value = value.view(batch_size * num_heads, -1, dim_per_head) 225 | query = query.view(batch_size * num_heads, -1, dim_per_head) 226 | 227 | if attn_mask is not None: 228 | attn_mask = attn_mask.repeat(num_heads, 1, 1) 229 | 230 | # scaled dot product attention 231 | scale = (key.size(-1) // num_heads)**-0.5 232 | context, attention = self.dot_product_attention( 233 | query, key, value, scale, attn_mask) 234 | 235 | # concat heads 236 | context = context.view(batch_size, -1, dim_per_head * num_heads) 237 | 238 | # final linear projection 239 | output = self.linear_final(context) 240 | 241 | # dropout 242 | output = self.dropout(output) 243 | 244 | # add residual and norm layer 245 | output = self.layer_norm(residual + output) 246 | 247 | return output, attention 248 | 249 | 250 | class MaskedMultiHeadAttention(nn.Module): 251 | 252 | def __init__(self, model_dim=512, num_heads=8, dropout=0.0): 253 | super().__init__() 254 | 255 | self.dim_per_head = model_dim // num_heads 256 | self.num_heads = num_heads 257 | self.linear_k = nn.Linear( 258 | model_dim, self.dim_per_head * num_heads, bias=False) 259 | self.linear_v = nn.Linear( 260 | model_dim, self.dim_per_head * num_heads, bias=False) 261 | self.linear_q = nn.Linear( 262 | model_dim, self.dim_per_head * num_heads, bias=False) 263 
| 264 | self.dot_product_attention = MaskedScaledDotProductAttention(dropout) 265 | self.linear_final = nn.Linear(model_dim, model_dim, bias=False) 266 | self.dropout = nn.Dropout(dropout) 267 | self.layer_norm = nn.LayerNorm(model_dim, elementwise_affine=False) 268 | 269 | def forward(self, query, key, value, attn_mask=None, softmax_mask=None): 270 | residual = query 271 | 272 | dim_per_head = self.dim_per_head 273 | num_heads = self.num_heads 274 | batch_size = key.size(0) 275 | 276 | # linear projection 277 | key = self.linear_k(key) 278 | value = self.linear_v(value) 279 | query = self.linear_q(query) 280 | 281 | # split by heads 282 | key = key.view(batch_size, -1, num_heads, dim_per_head).transpose(1, 2) 283 | value = value.view(batch_size, -1, num_heads, 284 | dim_per_head).transpose(1, 2) 285 | query = query.view(batch_size, -1, num_heads, 286 | dim_per_head).transpose(1, 2) 287 | 288 | if attn_mask is not None: 289 | attn_mask = attn_mask.unsqueeze(1).repeat(1, num_heads, 1, 1) 290 | if softmax_mask is not None: 291 | softmax_mask = softmax_mask.unsqueeze(1).repeat(1, num_heads, 1, 1) 292 | # scaled dot product attention 293 | # key.size(-1) is 64? 294 | scale = key.size(-1)**-0.5 295 | context, attention = self.dot_product_attention( 296 | query, key, value, scale, attn_mask, softmax_mask) 297 | 298 | # concat heads 299 | context = context.transpose(1, 2).contiguous().view( 300 | batch_size, -1, dim_per_head * num_heads) 301 | 302 | # final linear projection 303 | output = self.linear_final(context) 304 | 305 | # dropout 306 | output = self.dropout(output) 307 | 308 | # add residual and norm layer 309 | output = self.layer_norm(residual + output) 310 | 311 | return output, attention 312 | 313 | 314 | class SelfTransformerLayer(nn.Module): 315 | 316 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): 317 | super().__init__() 318 | 319 | self.transformer = MaskedMultiHeadAttention( 320 | model_dim, num_heads, dropout) 321 | self.feed_forward = MaskedPositionalWiseFeedForward( 322 | model_dim, ffn_dim, dropout) 323 | 324 | def forward(self, input, attn_mask=None, sf_mask=None): 325 | output, attention = self.transformer( 326 | input, input, input, attn_mask, sf_mask) 327 | # feed forward network 328 | output = self.feed_forward(output) 329 | 330 | return output, attention 331 | 332 | 333 | class SelfTransformer(nn.Module): 334 | 335 | def __init__( 336 | self, 337 | max_len=35, 338 | num_layers=2, 339 | model_dim=512, 340 | num_heads=8, 341 | ffn_dim=2048, 342 | dropout=0.0, 343 | position=False): 344 | super().__init__() 345 | 346 | self.position = position 347 | 348 | self.encoder_layers = nn.ModuleList( 349 | [ 350 | SelfTransformerLayer(model_dim, num_heads, ffn_dim, dropout) 351 | for _ in range(num_layers) 352 | ]) 353 | 354 | # max_seq_len is 35 or 80 355 | self.pos_embedding = PositionalEncoding(model_dim, max_len) 356 | 357 | def forward(self, input, input_length): 358 | # q_length of shape (batch, ), each item is the length of the seq 359 | if self.position: 360 | input += self.pos_embedding(input_length)[:, :input.size()[1], :] 361 | 362 | attention_mask = padding_mask_k(input, input) 363 | softmax_mask = padding_mask_q(input, input) 364 | 365 | attentions = [] 366 | for encoder in self.encoder_layers: 367 | input, attention = encoder(input, attention_mask, softmax_mask) 368 | attentions.append(attention) 369 | 370 | return input, attentions 371 | 372 | 373 | class SelfAttentionLayer(nn.Module): 374 | 375 | def __init__(self, hidden_size, 
dropout_p=0.0): 376 | super().__init__() 377 | self.dropout = nn.Dropout(dropout_p) 378 | self.softmax = nn.Softmax(dim=-1) 379 | 380 | self.linear_k = nlpnn.WeightDropLinear( 381 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 382 | self.linear_q = nlpnn.WeightDropLinear( 383 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 384 | self.linear_v = nlpnn.WeightDropLinear( 385 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 386 | 387 | self.linear_final = nlpnn.WeightDropLinear( 388 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 389 | 390 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False) 391 | 392 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None): 393 | """ 394 | Args: 395 | q: [B, L_q, D_q] 396 | k: [B, L_k, D_k] 397 | v: [B, L_v, D_v] 398 | """ 399 | residual = q 400 | 401 | if attn_mask is None or softmax_mask is None: 402 | attn_mask = padding_mask_k(q, k) 403 | softmax_mask = padding_mask_q(q, k) 404 | 405 | # linear projection 406 | k = self.linear_k(k) 407 | v = self.linear_v(v) 408 | q = self.linear_q(q) 409 | 410 | scale = k.size(-1)**-0.5 411 | 412 | attention = torch.bmm(q, k.transpose(1, 2)) 413 | if scale is not None: 414 | attention = attention * scale 415 | if attn_mask is not None: 416 | attention = attention.masked_fill(attn_mask, -np.inf) 417 | attention = self.softmax(attention) 418 | attention = attention.masked_fill(softmax_mask, 0.) 419 | 420 | # attention = self.dropout(attention) 421 | output = torch.bmm(attention, v) 422 | output = self.linear_final(output) 423 | output = self.layer_norm(output + residual) 424 | return output, attention 425 | 426 | 427 | class SelfAttention(nn.Module): 428 | 429 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 430 | super().__init__() 431 | 432 | self.encoder_layers = nn.ModuleList( 433 | [ 434 | SelfAttentionLayer(hidden_size, dropout_p) 435 | for _ in range(n_layers) 436 | ]) 437 | 438 | def forward(self, input): 439 | 440 | # q_attention_mask of shape (bs, q_len, v_len) 441 | attn_mask = padding_mask_k(input, input) 442 | # v_attention_mask of shape (bs, v_len, q_len) 443 | softmax_mask = padding_mask_q(input, input) 444 | 445 | attentions = [] 446 | for encoder in self.encoder_layers: 447 | input, attention = encoder( 448 | input, 449 | input, 450 | input, 451 | attn_mask=attn_mask, 452 | softmax_mask=softmax_mask) 453 | attentions.append(attention) 454 | 455 | return input, attentions 456 | 457 | 458 | class CoAttentionLayer(nn.Module): 459 | 460 | def __init__(self, hidden_size, dropout_p=0.0): 461 | super().__init__() 462 | self.dropout = nn.Dropout(dropout_p) 463 | self.softmax = nn.Softmax(dim=-1) 464 | 465 | self.linear_question = nlpnn.WeightDropLinear( 466 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 467 | self.linear_video = nlpnn.WeightDropLinear( 468 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 469 | self.linear_v_question = nlpnn.WeightDropLinear( 470 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 471 | self.linear_v_video = nlpnn.WeightDropLinear( 472 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 473 | 474 | self.linear_final_qv = nlpnn.WeightDropLinear( 475 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 476 | self.linear_final_vq = nlpnn.WeightDropLinear( 477 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 478 | 479 | self.layer_norm_qv = nn.LayerNorm(hidden_size, 
elementwise_affine=False) 480 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False) 481 | 482 | def forward( 483 | self, 484 | question, 485 | video, 486 | scale=None, 487 | attn_mask=None, 488 | softmax_mask=None, 489 | attn_mask_=None, 490 | softmax_mask_=None): 491 | """ 492 | Args: 493 | q: [B, L_q, D_q] 494 | k: [B, L_k, D_k] 495 | v: [B, L_v, D_v] 496 | """ 497 | q = question 498 | v = video 499 | 500 | if attn_mask is None or softmax_mask is None: 501 | attn_mask = padding_mask_k(question, video) 502 | softmax_mask = padding_mask_q(question, video) 503 | if attn_mask_ is None or softmax_mask_ is None: 504 | attn_mask_ = padding_mask_k(video, question) 505 | softmax_mask_ = padding_mask_q(video, question) 506 | 507 | # linear projection 508 | question_q = self.linear_question(question) 509 | video_k = self.linear_video(video) 510 | question = self.linear_v_question(question) 511 | video = self.linear_v_video(video) 512 | 513 | scale = video.size(-1)**-0.5 514 | 515 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2)) 516 | if scale is not None: 517 | attention_qv = attention_qv * scale 518 | if attn_mask is not None: 519 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf) 520 | attention_qv = self.softmax(attention_qv) 521 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.) 522 | 523 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2)) 524 | if scale is not None: 525 | attention_vq = attention_vq * scale 526 | if attn_mask_ is not None: 527 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf) 528 | attention_vq = self.softmax(attention_vq) 529 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.) 530 | 531 | # attention = self.dropout(attention) 532 | output_qv = torch.bmm(attention_qv, video) 533 | output_qv = self.linear_final_qv(output_qv) 534 | output_q = self.layer_norm_qv(output_qv + q) 535 | 536 | output_vq = torch.bmm(attention_vq, question) 537 | output_vq = self.linear_final_vq(output_vq) 538 | output_v = self.layer_norm_vq(output_vq + v) 539 | return output_q, output_v 540 | 541 | 542 | class CoAttention(nn.Module): 543 | 544 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 545 | super().__init__() 546 | 547 | self.encoder_layers = nn.ModuleList( 548 | [CoAttentionLayer(hidden_size, dropout_p) for _ in range(n_layers)]) 549 | 550 | def forward(self, question, video): 551 | attn_mask = padding_mask_k(question, video) 552 | softmax_mask = padding_mask_q(question, video) 553 | attn_mask_ = padding_mask_k(video, question) 554 | softmax_mask_ = padding_mask_q(video, question) 555 | 556 | for encoder in self.encoder_layers: 557 | question, video = encoder( 558 | question, 559 | video, 560 | attn_mask=attn_mask, 561 | softmax_mask=softmax_mask, 562 | attn_mask_=attn_mask_, 563 | softmax_mask_=softmax_mask_) 564 | 565 | return question, video 566 | 567 | 568 | class CoConcatAttentionLayer(nn.Module): 569 | 570 | def __init__(self, hidden_size, dropout_p=0.0): 571 | super().__init__() 572 | self.dropout = nn.Dropout(dropout_p) 573 | self.softmax = nn.Softmax(dim=-1) 574 | 575 | self.linear_question = nlpnn.WeightDropLinear( 576 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 577 | self.linear_video = nlpnn.WeightDropLinear( 578 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 579 | self.linear_v_question = nlpnn.WeightDropLinear( 580 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 581 | self.linear_v_video = nlpnn.WeightDropLinear( 582 | 
hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 583 | 584 | self.linear_final_qv = nn.Sequential( 585 | nlpnn.WeightDropLinear( 586 | 2 * hidden_size, 587 | hidden_size, 588 | weight_dropout=dropout_p, 589 | bias=False), nn.ReLU(), 590 | nlpnn.WeightDropLinear( 591 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)) 592 | self.linear_final_vq = nn.Sequential( 593 | nlpnn.WeightDropLinear( 594 | 2 * hidden_size, 595 | hidden_size, 596 | weight_dropout=dropout_p, 597 | bias=False), nn.ReLU(), 598 | nlpnn.WeightDropLinear( 599 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)) 600 | 601 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False) 602 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False) 603 | 604 | def forward( 605 | self, 606 | question, 607 | video, 608 | scale=None, 609 | attn_mask=None, 610 | softmax_mask=None, 611 | attn_mask_=None, 612 | softmax_mask_=None): 613 | """ 614 | Args: 615 | q: [B, L_q, D_q] 616 | k: [B, L_k, D_k] 617 | v: [B, L_v, D_v] 618 | """ 619 | q = question 620 | v = video 621 | 622 | if attn_mask is None or softmax_mask is None: 623 | attn_mask = padding_mask_k(question, video) 624 | softmax_mask = padding_mask_q(question, video) 625 | if attn_mask_ is None or softmax_mask_ is None: 626 | attn_mask_ = padding_mask_k(video, question) 627 | softmax_mask_ = padding_mask_q(video, question) 628 | 629 | # linear projection 630 | question_q = self.linear_question(question) 631 | video_k = self.linear_video(video) 632 | question = self.linear_v_question(question) 633 | video = self.linear_v_video(video) 634 | 635 | scale = video.size(-1)**-0.5 636 | 637 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2)) 638 | if scale is not None: 639 | attention_qv = attention_qv * scale 640 | if attn_mask is not None: 641 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf) 642 | attention_qv = self.softmax(attention_qv) 643 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.) 644 | 645 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2)) 646 | if scale is not None: 647 | attention_vq = attention_vq * scale 648 | if attn_mask_ is not None: 649 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf) 650 | attention_vq = self.softmax(attention_vq) 651 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.) 
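        # Padding is handled in two steps above: padded key positions were filled with
        # -np.inf before the softmax (so they receive zero attention weight), and rows
        # belonging to padded query positions are zeroed after the softmax, since an
        # all -inf row would otherwise produce NaNs.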
652 | 653 | # attention = self.dropout(attention) 654 | output_qv = torch.bmm(attention_qv, video) 655 | output_qv = self.linear_final_qv(torch.cat((output_qv, q), dim=-1)) 656 | # output_q = self.layer_norm_qv(output_qv + q) 657 | output_q = self.layer_norm_qv(output_qv) 658 | 659 | output_vq = torch.bmm(attention_vq, question) 660 | output_vq = self.linear_final_vq(torch.cat((output_vq, v), dim=-1)) 661 | # output_v = self.layer_norm_vq(output_vq + v) 662 | output_v = self.layer_norm_vq(output_vq) 663 | return output_q, output_v 664 | 665 | 666 | class CoConcatAttention(nn.Module): 667 | 668 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 669 | super().__init__() 670 | 671 | self.encoder_layers = nn.ModuleList( 672 | [ 673 | CoConcatAttentionLayer(hidden_size, dropout_p) 674 | for _ in range(n_layers) 675 | ]) 676 | 677 | def forward(self, question, video): 678 | attn_mask = padding_mask_k(question, video) 679 | softmax_mask = padding_mask_q(question, video) 680 | attn_mask_ = padding_mask_k(video, question) 681 | softmax_mask_ = padding_mask_q(video, question) 682 | 683 | for encoder in self.encoder_layers: 684 | question, video = encoder( 685 | question, 686 | video, 687 | attn_mask=attn_mask, 688 | softmax_mask=softmax_mask, 689 | attn_mask_=attn_mask_, 690 | softmax_mask_=softmax_mask_) 691 | 692 | return question, video 693 | 694 | 695 | class CoSiameseAttentionLayer(nn.Module): 696 | 697 | def __init__(self, hidden_size, dropout_p=0.0): 698 | super().__init__() 699 | self.dropout = nn.Dropout(dropout_p) 700 | self.softmax = nn.Softmax(dim=-1) 701 | 702 | self.linear_question = nlpnn.WeightDropLinear( 703 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 704 | self.linear_video = nlpnn.WeightDropLinear( 705 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 706 | self.linear_v_question = nlpnn.WeightDropLinear( 707 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 708 | self.linear_v_video = nlpnn.WeightDropLinear( 709 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 710 | 711 | self.linear_final = nn.Sequential( 712 | nlpnn.WeightDropLinear( 713 | 2 * hidden_size, 714 | hidden_size, 715 | weight_dropout=dropout_p, 716 | bias=False), nn.ReLU(), 717 | nlpnn.WeightDropLinear( 718 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False)) 719 | 720 | self.layer_norm_qv = nn.LayerNorm(hidden_size, elementwise_affine=False) 721 | self.layer_norm_vq = nn.LayerNorm(hidden_size, elementwise_affine=False) 722 | 723 | def forward( 724 | self, 725 | question, 726 | video, 727 | scale=None, 728 | attn_mask=None, 729 | softmax_mask=None, 730 | attn_mask_=None, 731 | softmax_mask_=None): 732 | """ 733 | Args: 734 | q: [B, L_q, D_q] 735 | k: [B, L_k, D_k] 736 | v: [B, L_v, D_v] 737 | """ 738 | q = question 739 | v = video 740 | 741 | if attn_mask is None or softmax_mask is None: 742 | attn_mask = padding_mask_k(question, video) 743 | softmax_mask = padding_mask_q(question, video) 744 | if attn_mask_ is None or softmax_mask_ is None: 745 | attn_mask_ = padding_mask_k(video, question) 746 | softmax_mask_ = padding_mask_q(video, question) 747 | 748 | # linear projection 749 | question_q = self.linear_question(question) 750 | video_k = self.linear_video(video) 751 | question = self.linear_v_question(question) 752 | video = self.linear_v_video(video) 753 | 754 | scale = video.size(-1)**-0.5 755 | 756 | attention_qv = torch.bmm(question_q, video_k.transpose(1, 2)) 757 | if scale is not None: 758 | attention_qv = 
attention_qv * scale 759 | if attn_mask is not None: 760 | attention_qv = attention_qv.masked_fill(attn_mask, -np.inf) 761 | attention_qv = self.softmax(attention_qv) 762 | attention_qv = attention_qv.masked_fill(softmax_mask, 0.) 763 | 764 | attention_vq = torch.bmm(video_k, question_q.transpose(1, 2)) 765 | if scale is not None: 766 | attention_vq = attention_vq * scale 767 | if attn_mask_ is not None: 768 | attention_vq = attention_vq.masked_fill(attn_mask_, -np.inf) 769 | attention_vq = self.softmax(attention_vq) 770 | attention_vq = attention_vq.masked_fill(softmax_mask_, 0.) 771 | 772 | # attention = self.dropout(attention) 773 | output_qv = torch.bmm(attention_qv, video) 774 | output_qv = self.linear_final(torch.cat((output_qv, q), dim=-1)) 775 | # output_q = self.layer_norm_qv(output_qv + q) 776 | output_q = self.layer_norm_qv(output_qv) 777 | 778 | output_vq = torch.bmm(attention_vq, question) 779 | output_vq = self.linear_final(torch.cat((output_vq, v), dim=-1)) 780 | # output_v = self.layer_norm_vq(output_vq + v) 781 | output_v = self.layer_norm_vq(output_vq) 782 | return output_q, output_v 783 | 784 | 785 | class CoSiameseAttention(nn.Module): 786 | 787 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 788 | super().__init__() 789 | 790 | self.encoder_layers = nn.ModuleList( 791 | [ 792 | CoSiameseAttentionLayer(hidden_size, dropout_p) 793 | for _ in range(n_layers) 794 | ]) 795 | 796 | def forward(self, question, video): 797 | attn_mask = padding_mask_k(question, video) 798 | softmax_mask = padding_mask_q(question, video) 799 | attn_mask_ = padding_mask_k(video, question) 800 | softmax_mask_ = padding_mask_q(video, question) 801 | 802 | for encoder in self.encoder_layers: 803 | question, video = encoder( 804 | question, 805 | video, 806 | attn_mask=attn_mask, 807 | softmax_mask=softmax_mask, 808 | attn_mask_=attn_mask_, 809 | softmax_mask_=softmax_mask_) 810 | 811 | return question, video 812 | 813 | 814 | class SingleAttentionLayer(nn.Module): 815 | 816 | def __init__(self, hidden_size, dropout_p=0.0): 817 | super().__init__() 818 | self.dropout = nn.Dropout(dropout_p) 819 | self.softmax = nn.Softmax(dim=-1) 820 | 821 | self.linear_q = nlpnn.WeightDropLinear( 822 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 823 | self.linear_v = nlpnn.WeightDropLinear( 824 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 825 | self.linear_k = nlpnn.WeightDropLinear( 826 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 827 | 828 | self.linear_final = nlpnn.WeightDropLinear( 829 | hidden_size, hidden_size, weight_dropout=dropout_p, bias=False) 830 | 831 | self.layer_norm = nn.LayerNorm(hidden_size, elementwise_affine=False) 832 | 833 | def forward(self, q, k, v, scale=None, attn_mask=None, softmax_mask=None): 834 | """ 835 | Args: 836 | q: [B, L_q, D_q] 837 | k: [B, L_k, D_k] 838 | v: [B, L_v, D_v] 839 | Return: Same shape to q, but in 'v' space, soft knn 840 | """ 841 | 842 | if attn_mask is None or softmax_mask is None: 843 | attn_mask = padding_mask_k(q, k) 844 | softmax_mask = padding_mask_q(q, k) 845 | 846 | # linear projection 847 | q = self.linear_q(q) 848 | k = self.linear_k(k) 849 | v = self.linear_v(v) 850 | 851 | scale = v.size(-1)**-0.5 852 | 853 | attention = torch.bmm(q, k.transpose(-2, -1)) 854 | if scale is not None: 855 | attention = attention * scale 856 | if attn_mask is not None: 857 | attention = attention.masked_fill(attn_mask, -np.inf) 858 | attention = self.softmax(attention) 859 | attention = 
attention.masked_fill(softmax_mask, 0.) 860 | 861 | # attention = self.dropout(attention) 862 | output = torch.bmm(attention, v) 863 | output = self.linear_final(output) 864 | output = self.layer_norm(output + q) 865 | 866 | return output 867 | 868 | 869 | class SingleAttention(nn.Module): 870 | 871 | def __init__(self, hidden_size, n_layers=1, dropout_p=0.0): 872 | super().__init__() 873 | 874 | self.encoder_layers = nn.ModuleList( 875 | [ 876 | SingleAttentionLayer(hidden_size, dropout_p) 877 | for _ in range(n_layers) 878 | ]) 879 | 880 | def forward(self, q, v): 881 | attn_mask = padding_mask_k(q, v) 882 | softmax_mask = padding_mask_q(q, v) 883 | 884 | for encoder in self.encoder_layers: 885 | q = encoder(q, v, v, attn_mask=attn_mask, softmax_mask=softmax_mask) 886 | 887 | return q 888 | 889 | 890 | class SoftKNN(nn.Module): 891 | 892 | def __init__(self, model_dim=512, num_heads=1, dropout=0.0): 893 | super().__init__() 894 | 895 | self.dim_per_head = model_dim // num_heads 896 | self.num_heads = num_heads 897 | self.linear_k = nn.Linear( 898 | model_dim, self.dim_per_head * num_heads, bias=False) 899 | self.linear_v = nn.Linear( 900 | model_dim, self.dim_per_head * num_heads, bias=False) 901 | self.linear_q = nn.Linear( 902 | model_dim, self.dim_per_head * num_heads, bias=False) 903 | 904 | self.dot_product_attention = ScaledDotProductAttention(dropout) 905 | 906 | def forward(self, query, key, value, attn_mask=None): 907 | 908 | dim_per_head = self.dim_per_head 909 | num_heads = self.num_heads 910 | batch_size = key.size(0) 911 | 912 | # linear projection 913 | key = self.linear_k(key) 914 | value = self.linear_v(value) 915 | query = self.linear_q(query) 916 | 917 | # split by heads 918 | key = key.view(batch_size * num_heads, -1, dim_per_head) 919 | value = value.view(batch_size * num_heads, -1, dim_per_head) 920 | query = query.view(batch_size * num_heads, -1, dim_per_head) 921 | 922 | if attn_mask is not None: 923 | attn_mask = attn_mask.repeat(num_heads, 1, 1) 924 | # scaled dot product attention 925 | scale = (key.size(-1) // num_heads)**-0.5 926 | context, attention = self.dot_product_attention( 927 | query, key, value, scale, attn_mask) 928 | 929 | # concat heads 930 | output = context.view(batch_size, -1, dim_per_head * num_heads) 931 | 932 | return output, attention 933 | 934 | 935 | class CrossoverTransformerLayer(nn.Module): 936 | 937 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): 938 | super().__init__() 939 | 940 | self.v_transformer = MultiHeadAttention(model_dim, num_heads, dropout) 941 | self.q_transformer = MultiHeadAttention(model_dim, num_heads, dropout) 942 | self.v_feed_forward = PositionalWiseFeedForward( 943 | model_dim, ffn_dim, dropout) 944 | self.q_feed_forward = PositionalWiseFeedForward( 945 | model_dim, ffn_dim, dropout) 946 | 947 | def forward(self, question, video, q_mask=None, v_mask=None): 948 | # self attention, v_attention of shape (bs, v_len, q_len) 949 | video_, v_attention = self.v_transformer( 950 | video, question, question, v_mask) 951 | # feed forward network 952 | video_ = self.v_feed_forward(video_) 953 | 954 | # self attention, q_attention of shape (bs, q_len, v_len) 955 | question_, q_attention = self.q_transformer( 956 | question, video, video, q_mask) 957 | # feed forward network 958 | question_ = self.q_feed_forward(question_) 959 | 960 | return video_, question_, v_attention, q_attention 961 | 962 | 963 | class CrossoverTransformer(nn.Module): 964 | 965 | def __init__( 966 | self, 967 | q_max_len=35, 
968 | v_max_len=80, 969 | num_layers=2, 970 | model_dim=512, 971 | num_heads=8, 972 | ffn_dim=2048, 973 | dropout=0.0): 974 | super().__init__() 975 | 976 | self.encoder_layers = nn.ModuleList( 977 | [ 978 | CrossoverTransformerLayer( 979 | model_dim, num_heads, ffn_dim, dropout) 980 | for _ in range(num_layers) 981 | ]) 982 | 983 | # max_seq_len is 35 or 80 984 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len) 985 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len) 986 | 987 | def forward(self, question, video, q_length, v_length): 988 | # q_length of shape (batch, ), each item is the length of the seq 989 | question += self.q_pos_embedding(q_length)[:, :question.size()[1], :] 990 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :] 991 | 992 | # q_attention_mask of shape (bs, q_len, v_len) 993 | q_attention_mask = padding_mask_k(question, video) 994 | # v_attention_mask of shape (bs, v_len, q_len) 995 | v_attention_mask = padding_mask_k(video, question) 996 | 997 | q_attentions = [] 998 | v_attentions = [] 999 | for encoder in self.encoder_layers: 1000 | video, question, v_attention, q_attention = encoder( 1001 | question, video, q_attention_mask, v_attention_mask) 1002 | q_attentions.append(q_attention) 1003 | v_attentions.append(v_attention) 1004 | 1005 | return question, video, q_attentions, v_attentions 1006 | 1007 | 1008 | class MaskedCrossoverTransformerLayer(nn.Module): 1009 | 1010 | def __init__(self, model_dim=512, num_heads=8, ffn_dim=2048, dropout=0.0): 1011 | super().__init__() 1012 | 1013 | self.v_transformer = MaskedMultiHeadAttention( 1014 | model_dim, num_heads, dropout) 1015 | self.q_transformer = MaskedMultiHeadAttention( 1016 | model_dim, num_heads, dropout) 1017 | self.v_feed_forward = MaskedPositionalWiseFeedForward( 1018 | model_dim, ffn_dim, dropout) 1019 | self.q_feed_forward = MaskedPositionalWiseFeedForward( 1020 | model_dim, ffn_dim, dropout) 1021 | 1022 | def forward( 1023 | self, 1024 | question, 1025 | video, 1026 | q_mask=None, 1027 | v_mask=None, 1028 | q_sf_mask=None, 1029 | v_sf_mask=None): 1030 | # self attention, v_attention of shape (bs, v_len, q_len) 1031 | video_, v_attention = self.v_transformer( 1032 | video, question, question, v_mask, v_sf_mask) 1033 | # feed forward network 1034 | video_ = self.v_feed_forward(video_) 1035 | 1036 | # self attention, q_attention of shape (bs, q_len, v_len) 1037 | question_, q_attention = self.q_transformer( 1038 | question, video, video, q_mask, q_sf_mask) 1039 | # feed forward network 1040 | question_ = self.q_feed_forward(question_) 1041 | 1042 | return video_, question_, v_attention, q_attention 1043 | 1044 | 1045 | class MaskedCrossoverTransformer(nn.Module): 1046 | 1047 | def __init__( 1048 | self, 1049 | q_max_len=35, 1050 | v_max_len=80, 1051 | num_layers=2, 1052 | model_dim=512, 1053 | num_heads=8, 1054 | ffn_dim=2048, 1055 | dropout=0.0, 1056 | position=False): 1057 | super().__init__() 1058 | 1059 | self.position = position 1060 | 1061 | self.encoder_layers = nn.ModuleList( 1062 | [ 1063 | MaskedCrossoverTransformerLayer( 1064 | model_dim, num_heads, ffn_dim, dropout) 1065 | for _ in range(num_layers) 1066 | ]) 1067 | 1068 | # max_seq_len is 35 or 80 1069 | self.q_pos_embedding = PositionalEncoding(model_dim, q_max_len) 1070 | self.v_pos_embedding = PositionalEncoding(model_dim, v_max_len) 1071 | 1072 | def forward(self, question, video, q_length, v_length): 1073 | # q_length of shape (batch, ), each item is the length of the seq 1074 | if self.position: 1075 | 
question += self.q_pos_embedding( 1076 | q_length)[:, :question.size()[1], :] 1077 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :] 1078 | 1079 | q_attention_mask = padding_mask_k(question, video) 1080 | q_softmax_mask = padding_mask_q(question, video) 1081 | v_attention_mask = padding_mask_k(video, question) 1082 | v_softmax_mask = padding_mask_q(video, question) 1083 | 1084 | q_attentions = [] 1085 | v_attentions = [] 1086 | for encoder in self.encoder_layers: 1087 | video, question, v_attention, q_attention = encoder( 1088 | question, video, q_attention_mask, v_attention_mask, 1089 | q_softmax_mask, v_softmax_mask) 1090 | q_attentions.append(q_attention) 1091 | v_attentions.append(v_attention) 1092 | 1093 | return question, video, q_attentions, v_attentions 1094 | 1095 | 1096 | class SelfTransformerEncoder(nn.Module): 1097 | 1098 | def __init__( 1099 | self, 1100 | hidden_size, 1101 | n_layers, 1102 | dropout_p, 1103 | vocab_size, 1104 | q_max_len, 1105 | v_max_len, 1106 | embedding=None, 1107 | update_embedding=True, 1108 | position=True): 1109 | super().__init__() 1110 | self.dropout = nn.Dropout(p=dropout_p) 1111 | self.ln_q = nn.LayerNorm(hidden_size, elementwise_affine=False) 1112 | self.ln_v = nn.LayerNorm(hidden_size, elementwise_affine=False) 1113 | self.n_layers = n_layers 1114 | self.position = position 1115 | 1116 | embedding_dim = embedding.shape[ 1117 | 1] if embedding is not None else hidden_size 1118 | self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0) 1119 | 1120 | # ! no embedding init 1121 | # if embedding is not None: 1122 | # # self.embedding.weight.data.copy_(torch.from_numpy(embedding)) 1123 | # self.embedding.weight = nn.Parameter( 1124 | # torch.from_numpy(embedding).float()) 1125 | self.upcompress_embedding = nlpnn.WeightDropLinear( 1126 | embedding_dim, hidden_size, weight_dropout=dropout_p, bias=False) 1127 | self.embedding.weight.requires_grad = update_embedding 1128 | 1129 | self.project_c3d = nlpnn.WeightDropLinear(4096, 2048, bias=False) 1130 | 1131 | self.project_resnet_and_c3d = nlpnn.WeightDropLinear( 1132 | 4096, hidden_size, weight_dropout=dropout_p, bias=False) 1133 | 1134 | # max_seq_len is 35 or 80 1135 | self.q_pos_embedding = PositionalEncoding(hidden_size, q_max_len) 1136 | self.v_pos_embedding = PositionalEncoding(hidden_size, v_max_len) 1137 | 1138 | def forward(self, question, resnet, c3d, q_length, v_length): 1139 | ### question 1140 | embedded = self.embedding(question) 1141 | embedded = self.dropout(embedded) 1142 | question = F.relu(self.upcompress_embedding(embedded)) 1143 | 1144 | ### video 1145 | # ! 
no relu 1146 | c3d = self.project_c3d(c3d) 1147 | video = F.relu( 1148 | self.project_resnet_and_c3d(torch.cat((resnet, c3d), dim=2))) 1149 | 1150 | ### position encoding 1151 | if self.position: 1152 | question += self.q_pos_embedding( 1153 | q_length)[:, :question.size()[1], :] 1154 | video += self.v_pos_embedding(v_length)[:, :video.size()[1], :] 1155 | 1156 | # question = self.ln_q(question) 1157 | # video = self.ln_v(video) 1158 | return question, video 1159 | -------------------------------------------------------------------------------- /networks/torchnlp_nn.py: -------------------------------------------------------------------------------- 1 | from torch.nn import Parameter 2 | 3 | import torch 4 | 5 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 6 | 7 | 8 | class weight_drop(): 9 | 10 | def __init__(self, module, weights, dropout): 11 | for name_w in weights: 12 | w = getattr(module, name_w) 13 | del module._parameters[name_w] 14 | module.register_parameter(name_w + '_raw', Parameter(w)) 15 | 16 | self.original_module_forward = module.forward 17 | 18 | self.weights = weights 19 | self.module = module 20 | self.dropout = dropout 21 | 22 | def __call__(self, *args, **kwargs): 23 | for name_w in self.weights: 24 | raw_w = getattr(self.module, name_w + '_raw') 25 | w = torch.nn.functional.dropout( 26 | raw_w, p=self.dropout, training=self.module.training) 27 | # module.register_parameter(name_w, Parameter(w)) 28 | setattr(self.module, name_w, Parameter(w)) 29 | 30 | return self.original_module_forward(*args, **kwargs) 31 | 32 | 33 | def _weight_drop(module, weights, dropout): 34 | setattr(module, 'forward', weight_drop(module, weights, dropout)) 35 | 36 | 37 | # def _weight_drop(module, weights, dropout): 38 | # """ 39 | # Helper for `WeightDrop`. 40 | # """ 41 | 42 | # for name_w in weights: 43 | # w = getattr(module, name_w) 44 | # del module._parameters[name_w] 45 | # module.register_parameter(name_w + '_raw', Parameter(w)) 46 | 47 | # original_module_forward = module.forward 48 | 49 | # def forward(*args, **kwargs): 50 | # for name_w in weights: 51 | # raw_w = getattr(module, name_w + '_raw') 52 | # w = torch.nn.functional.dropout( 53 | # raw_w, p=dropout, training=module.training) 54 | # # module.register_parameter(name_w, Parameter(w)) 55 | # setattr(module, name_w, Parameter(w)) 56 | 57 | # return original_module_forward(*args, **kwargs) 58 | 59 | # setattr(module, 'forward', forward) 60 | 61 | 62 | class WeightDrop(torch.nn.Module): 63 | """ 64 | The weight-dropped module applies recurrent regularization through a DropConnect mask on the 65 | hidden-to-hidden recurrent weights. 66 | **Thank you** to Sales Force for their initial implementation of :class:`WeightDrop`. Here is 67 | their `License 68 | `__. 69 | Args: 70 | module (:class:`torch.nn.Module`): Containing module. 71 | weights (:class:`list` of :class:`str`): Names of the module weight parameters to apply a 72 | dropout too. 73 | dropout (float): The probability a weight will be dropped. 74 | Example: 75 | >>> from torchnlp.nn import WeightDrop 76 | >>> import torch 77 | >>> 78 | >>> torch.manual_seed(123) 79 | >> 81 | >>> gru = torch.nn.GRUCell(2, 2) 82 | >>> weights = ['weight_hh'] 83 | >>> weight_drop_gru = WeightDrop(gru, weights, dropout=0.9) 84 | >>> 85 | >>> input_ = torch.randn(3, 2) 86 | >>> hidden_state = torch.randn(3, 2) 87 | >>> weight_drop_gru(input_, hidden_state) 88 | tensor(... 
grad_fn=) 89 | """ 90 | 91 | def __init__(self, module, weights, dropout=0.0): 92 | super(WeightDrop, self).__init__() 93 | _weight_drop(module, weights, dropout) 94 | self.forward = module.forward 95 | 96 | 97 | class WeightDropLSTM(torch.nn.LSTM): 98 | """ 99 | Wrapper around :class:`torch.nn.LSTM` that adds ``weight_dropout`` named argument. 100 | Args: 101 | weight_dropout (float): The probability a weight will be dropped. 102 | """ 103 | 104 | def __init__(self, *args, weight_dropout=0.0, **kwargs): 105 | super().__init__(*args, **kwargs) 106 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)] 107 | _weight_drop(self, weights, weight_dropout) 108 | 109 | 110 | class WeightDropGRU(torch.nn.GRU): 111 | """ 112 | Wrapper around :class:`torch.nn.GRU` that adds ``weight_dropout`` named argument. 113 | Args: 114 | weight_dropout (float): The probability a weight will be dropped. 115 | """ 116 | 117 | def __init__(self, *args, weight_dropout=0.0, **kwargs): 118 | super().__init__(*args, **kwargs) 119 | weights = ['weight_hh_l' + str(i) for i in range(self.num_layers)] 120 | _weight_drop(self, weights, weight_dropout) 121 | 122 | 123 | class WeightDropLinear(torch.nn.Linear): 124 | """ 125 | Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument. 126 | Args: 127 | weight_dropout (float): The probability a weight will be dropped. 128 | """ 129 | 130 | def __init__(self, *args, weight_dropout=0.0, **kwargs): 131 | super().__init__(*args, **kwargs) 132 | weights = ['weight'] 133 | _weight_drop(self, weights, weight_dropout) 134 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backcall==0.2.0 2 | bert-score==0.3.6 3 | block==0.0.5 4 | block.bootstrap.pytorch==0.1.6 5 | bootstrap.pytorch==0.0.13 6 | certifi==2020.12.5 7 | chardet==4.0.0 8 | click==8.0.1 9 | cycler==0.10.0 10 | decorator==5.0.9 11 | future==0.18.2 12 | h5py==2.7.1 13 | idna==2.10 14 | importlib-metadata==4.0.1 15 | importlib-resources==5.1.4 16 | ipdb==0.13.7 17 | ipython==7.16.1 18 | ipython-genutils==0.2.0 19 | jedi==0.18.0 20 | kiwisolver==1.3.1 21 | matplotlib==3.3.4 22 | munch==2.5.0 23 | nltk==3.3 24 | numpy==1.19.5 25 | opencv-python==4.5.2.52 26 | pandas==1.1.2 27 | parso==0.8.2 28 | pexpect==4.8.0 29 | pickleshare==0.7.5 30 | Pillow==8.2.0 31 | plotly==4.14.3 32 | pretrainedmodels==0.7.4 33 | prompt-toolkit==3.0.18 34 | protobuf==3.17.0 35 | ptyprocess==0.7.0 36 | Pygments==2.9.0 37 | pyparsing==2.4.7 38 | python-dateutil==2.8.1 39 | pytz==2021.1 40 | pywsd==1.2.4 41 | PyYAML==5.4.1 42 | requests==2.25.1 43 | retrying==1.3.3 44 | scipy==1.5.4 45 | seaborn==0.11.1 46 | six==1.16.0 47 | skipthoughts==0.0.1 48 | tabulate==0.8.9 49 | tensorboardX==2.2 50 | toml==0.10.2 51 | torch==1.6.0 52 | torchvision==0.7.0 53 | tqdm==4.60.0 54 | traitlets==4.3.3 55 | typing-extensions==3.10.0.0 56 | urllib3==1.26.4 57 | wcwidth==0.2.5 58 | wn==0.0.23 59 | zipp==3.4.1 60 | -------------------------------------------------------------------------------- /stopwords.txt: -------------------------------------------------------------------------------- 1 | i 2 | me 3 | my 4 | myself 5 | we 6 | our 7 | ours 8 | ourselves 9 | you 10 | you're 11 | you've 12 | you'll 13 | you'd 14 | your 15 | yours 16 | yourself 17 | yourselves 18 | he 19 | him 20 | his 21 | himself 22 | she 23 | she's 24 | her 25 | hers 26 | herself 27 | it 28 | it's 29 | its 30 | itself 31 | they 32 
| them 33 | their 34 | theirs 35 | themselves 36 | what 37 | which 38 | who 39 | whom 40 | this 41 | that 42 | that'll 43 | these 44 | those 45 | am 46 | is 47 | are 48 | was 49 | were 50 | be 51 | been 52 | being 53 | have 54 | has 55 | had 56 | having 57 | do 58 | does 59 | did 60 | doing 61 | a 62 | an 63 | the 64 | and 65 | but 66 | if 67 | or 68 | because 69 | as 70 | until 71 | while 72 | to 73 | from 74 | of 75 | at 76 | for 77 | with 78 | about 79 | into 80 | through 81 | during 82 | again 83 | further 84 | then 85 | here 86 | there 87 | when 88 | where 89 | why 90 | how 91 | all 92 | any 93 | each 94 | most 95 | other 96 | some 97 | such 98 | only 99 | own 100 | so 101 | than 102 | too 103 | very 104 | s 105 | t 106 | can 107 | will 108 | just 109 | don 110 | don't 111 | should 112 | should've 113 | now 114 | d 115 | ll 116 | m 117 | o 118 | re 119 | ve 120 | y 121 | ain 122 | aren 123 | aren't 124 | couldn 125 | couldn't 126 | didn 127 | didn't 128 | doesn 129 | doesn't 130 | hadn 131 | hadn't 132 | hasn 133 | hasn't 134 | haven 135 | haven't 136 | isn 137 | isn't 138 | ma 139 | mightn 140 | mightn't 141 | mustn 142 | mustn't 143 | needn 144 | needn't 145 | shan 146 | shan't 147 | shouldn 148 | shouldn't 149 | wasn 150 | wasn't 151 | weren 152 | weren't 153 | won 154 | won't 155 | wouldn 156 | wouldn't 157 | -------------------------------------------------------------------------------- /tools/.gitignore: -------------------------------------------------------------------------------- 1 | # ignore all except .gitignore file 2 | * 3 | !.gitignore 4 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | import pickle as pkl 5 | import pandas as pd 6 | 7 | def set_gpu_devices(gpu_id): 8 | gpu = '' 9 | if gpu_id != -1: 10 | gpu = str(gpu_id) 11 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu 12 | 13 | 14 | def load_file(filename): 15 | """ 16 | load obj from filename 17 | :param filename: 18 | :return: 19 | """ 20 | cont = None 21 | if not osp.exists(filename): 22 | print('{} not exist'.format(filename)) 23 | return cont 24 | if osp.splitext(filename)[-1] == '.csv': 25 | return pd.read_csv(filename, delimiter=',') 26 | with open(filename, 'r') as fp: 27 | if osp.splitext(filename)[1] == '.txt': 28 | cont = fp.readlines() 29 | cont = [c.rstrip('\n') for c in cont] 30 | elif osp.splitext(filename)[1] == '.json': 31 | cont = json.load(fp) 32 | return cont 33 | 34 | def save_file(obj, filename): 35 | """ 36 | save obj to filename 37 | :param obj: 38 | :param filename: 39 | :return: 40 | """ 41 | filepath = osp.dirname(filename) 42 | if filepath != '' and not osp.exists(filepath): 43 | os.makedirs(filepath) 44 | # write the file once its directory is guaranteed to exist 45 | with open(filename, 'w') as fp: 46 | json.dump(obj, fp, indent=4) 47 | 48 | def pkload(file): 49 | data = None 50 | if osp.exists(file) and osp.getsize(file) > 0: 51 | with open(file, 'rb') as fp: 52 | data = pkl.load(fp) 53 | # print('{} does not exist'.format(file)) 54 | return data 55 | 56 | 57 | def pkdump(data, file): 58 | dirname = osp.dirname(file) 59 | if not osp.exists(dirname): 60 | os.makedirs(dirname) 61 | with open(file, 'wb') as fp: 62 | pkl.dump(data, fp) 63 | 64 | -------------------------------------------------------------------------------- /videoqa.py: -------------------------------------------------------------------------------- 1 | from networks import EncoderRNN,
DecoderRNN 2 | from networks.VQAModel import EVQA, UATT, STVQA, CoMem, HME, HGA 3 | from utils import * 4 | import torch 5 | from torch.optim.lr_scheduler import ReduceLROnPlateau 6 | import torch.nn as nn 7 | import time 8 | from metrics import get_wups 9 | from eval_oe import remove_stop 10 | 11 | class VideoQA(): 12 | def __init__(self, vocab_qns, vocab_ans, train_loader, val_loader, glove_embed_qns, glove_embed_ans, 13 | checkpoint_path, model_type, model_prefix, vis_step, 14 | lr_rate, batch_size, epoch_num): 15 | self.vocab_qns = vocab_qns 16 | self.vocab_ans = vocab_ans 17 | self.train_loader = train_loader 18 | self.val_loader = val_loader 19 | self.glove_embed_qns = glove_embed_qns 20 | self.glove_embed_ans = glove_embed_ans 21 | self.model_dir = checkpoint_path 22 | self.model_type = model_type 23 | self.model_prefix = model_prefix 24 | self.vis_step = vis_step 25 | self.lr_rate = lr_rate 26 | self.batch_size = batch_size 27 | self.epoch_num = epoch_num 28 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | self.model = None 30 | 31 | 32 | def build_model(self): 33 | 34 | vid_dim = 2048+2048 35 | hidden_dim = 512 36 | word_dim = 300 37 | qns_vocab_size = len(self.vocab_qns) 38 | ans_vocab_size = len(self.vocab_ans) 39 | max_ans_len = 7 40 | max_vid_len = 16 41 | max_qns_len = 23 42 | 43 | 44 | if self.model_type == 'EVQA' or self.model_type == 'BlindQA': 45 | #ICCV15, AAAI17 46 | vid_encoder = EncoderRNN.EncoderVid(vid_dim, hidden_dim, input_dropout_p=0.3, n_layers=1, rnn_dropout_p=0, 47 | bidirectional=False, rnn_cell='lstm') 48 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=1, 49 | input_dropout_p=0.3, rnn_dropout_p=0, bidirectional=False, rnn_cell='lstm') 50 | 51 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans, 52 | n_layers=1, input_dropout_p=0.3, rnn_dropout_p=0, rnn_cell='gru') 53 | self.model = EVQA.EVQA(vid_encoder, qns_encoder, ans_decoder, self.device) 54 | 55 | elif self.model_type == 'UATT': 56 | #TIP17 57 | # hidden_dim = 512 58 | vid_encoder = EncoderRNN.EncoderVid(vid_dim, hidden_dim, input_dropout_p=0.3, bidirectional=True, 59 | rnn_cell='lstm') 60 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, 61 | input_dropout_p=0.3, bidirectional=True, rnn_cell='lstm') 62 | 63 | ans_decoder = DecoderRNN.AnsUATT(ans_vocab_size, max_ans_len, hidden_dim, word_dim, 64 | self.glove_embed_ans, n_layers=2, input_dropout_p=0.3, 65 | rnn_dropout_p=0.5, rnn_cell='lstm') 66 | self.model = UATT.UATT(vid_encoder, qns_encoder, ans_decoder, self.device) 67 | 68 | elif self.model_type == 'STVQA': 69 | #CVPR17 70 | vid_dim = 2048 + 2048 # (64, 1024+2048, 7, 7) 71 | att_dim = 256 72 | hidden_dim = 256 73 | vid_encoder = EncoderRNN.EncoderVidSTVQA(vid_dim, hidden_dim, input_dropout_p=0.3, rnn_dropout_p=0, 74 | n_layers=1, rnn_cell='lstm') 75 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, 76 | input_dropout_p=0.3, rnn_dropout_p=0.5, n_layers=2, rnn_cell='lstm') 77 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans, 78 | input_dropout_p=0.3, rnn_dropout_p=0, n_layers=1, rnn_cell='gru') 79 | self.model = STVQA.STVQA(vid_encoder, qns_encoder, ans_decoder, att_dim, self.device) 80 | 81 | 82 | elif self.model_type == 'CoMem': 83 | #CVPR18 84 | app_dim = 2048 85 | motion_dim = 2048 86 | hidden_dim 
= 256 87 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.3, 88 | bidirectional=False, rnn_cell='gru') 89 | 90 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=2, 91 | rnn_dropout_p=0.5, input_dropout_p=0.3, bidirectional=False, rnn_cell='gru') 92 | 93 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans, 94 | n_layers=1, input_dropout_p=0.3, rnn_dropout_p=0, rnn_cell='gru') 95 | 96 | self.model = CoMem.CoMem(vid_encoder, qns_encoder, ans_decoder, max_vid_len, max_qns_len, self.device) 97 | 98 | 99 | elif self.model_type == 'HME': 100 | #CVPR19 101 | app_dim = 2048 102 | motion_dim = 2048 103 | vid_encoder = EncoderRNN.EncoderVidCoMem(app_dim, motion_dim, hidden_dim, input_dropout_p=0.3, 104 | bidirectional=False, rnn_cell='lstm') 105 | 106 | qns_encoder = EncoderRNN.EncoderQns(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=2, 107 | rnn_dropout_p=0.5, input_dropout_p=0.3, bidirectional=False, rnn_cell='lstm') 108 | 109 | ans_decoder = DecoderRNN.AnsHME(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans, 110 | n_layers=2, input_dropout_p=0.3, rnn_dropout_p=0.5, rnn_cell='lstm') 111 | 112 | self.model = HME.HME(vid_encoder, qns_encoder, ans_decoder, max_vid_len, max_qns_len, self.device) 113 | 114 | 115 | elif self.model_type == 'HGA': 116 | #AAAI20 117 | vid_encoder = EncoderRNN.EncoderVidHGA(vid_dim, hidden_dim, input_dropout_p=0.3, 118 | bidirectional=False, rnn_cell='gru') 119 | 120 | qns_encoder = EncoderRNN.EncoderQnsHGA(word_dim, hidden_dim, qns_vocab_size, self.glove_embed_qns, n_layers=1, 121 | rnn_dropout_p=0, input_dropout_p=0.3, bidirectional=False, 122 | rnn_cell='gru') 123 | 124 | ans_decoder = DecoderRNN.AnsAttSeq(ans_vocab_size, max_ans_len, hidden_dim, word_dim, self.glove_embed_ans, 125 | n_layers=1, input_dropout_p=0.3, rnn_dropout_p=0, rnn_cell='gru') 126 | 127 | self.model = HGA.HGA(vid_encoder, qns_encoder, ans_decoder, max_vid_len, max_qns_len, self.device) 128 | 129 | 130 | params = [{'params':self.model.parameters()}] 131 | # params = [{'params': vid_encoder.parameters()}, {'params': qns_encoder.parameters()}, 132 | # {'params': ans_decoder.parameters(), 'lr': self.lr_rate}] 133 | self.optimizer = torch.optim.Adam(params = params, lr=self.lr_rate) 134 | self.scheduler = ReduceLROnPlateau(self.optimizer, 'max', factor=0.5, patience=5, verbose=True) 135 | # if torch.cuda.device_count() > 1: 136 | # print("Let's use", torch.cuda.device_count(), "GPUs!") 137 | # self.model = nn.DataParallel(self.model) 138 | 139 | self.model.to(self.device) 140 | self.criterion = nn.CrossEntropyLoss().to(self.device) 141 | 142 | 143 | def save_model(self, epoch, loss): 144 | torch.save(self.model.state_dict(), osp.join(self.model_dir, '{}-{}-{}-{:.4f}.ckpt' 145 | .format(self.model_type, self.model_prefix, epoch, loss))) 146 | 147 | def resume(self, model_file): 148 | """ 149 | initialize model with pretrained weights 150 | :return: 151 | """ 152 | model_path = osp.join(self.model_dir, model_file) 153 | print(f'Warm-starting from model {model_path}') 154 | model_dict = torch.load(model_path) 155 | new_model_dict = {} 156 | for k, v in self.model.state_dict().items(): 157 | if k in model_dict: 158 | v = model_dict[k] 159 | 160 | new_model_dict[k] = v 161 | self.model.load_state_dict(new_model_dict) 162 | 163 | 164 | def run(self, model_file, pre_trained=False): 165 | self.build_model() 166 
| best_eval_score = 0.0 167 | if pre_trained: 168 | self.resume(model_file) 169 | best_eval_score = self.eval(0) 170 | print('Initial Acc {:.4f}'.format(best_eval_score)) 171 | 172 | for epoch in range(1, self.epoch_num): 173 | train_loss = self.train(epoch) 174 | eval_score = self.eval(epoch) 175 | print("==>Epoch:[{}/{}][Train Loss: {:.4f} Val acc: {:.4f}]". 176 | format(epoch, self.epoch_num, train_loss, eval_score)) 177 | self.scheduler.step(eval_score) 178 | if eval_score > best_eval_score or pre_trained: 179 | best_eval_score = eval_score 180 | if epoch > 10 or pre_trained: 181 | self.save_model(epoch, best_eval_score) 182 | 183 | 184 | def train(self, epoch): 185 | print('==>Epoch:[{}/{}][lr_rate: {}]'.format(epoch, self.epoch_num, self.optimizer.param_groups[0]['lr'])) 186 | self.model.train() 187 | total_step = len(self.train_loader) 188 | epoch_loss = 0.0 189 | for iter, inputs in enumerate(self.train_loader): 190 | videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes = inputs 191 | video_inputs = videos.to(self.device) 192 | qns_inputs = targets_qns.to(self.device) 193 | ans_inputs = targets_ans.to(self.device) 194 | prediction = self.model(video_inputs, qns_inputs, qns_lengths, ans_inputs, ans_lengths, 0.5) 195 | 196 | out_dim = prediction.shape[-1] 197 | prediction = prediction.view(-1, out_dim) 198 | ans_targets = ans_inputs.view(-1) 199 | 200 | loss = self.criterion(prediction, ans_targets) 201 | self.model.zero_grad() 202 | loss.backward() 203 | self.optimizer.step() 204 | cur_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) 205 | if iter % self.vis_step == 0: 206 | print('\t[{}/{}]-{}-{:.4f}'.format(iter, total_step,cur_time, loss.item())) 207 | epoch_loss += loss.item() 208 | 209 | return epoch_loss / total_step 210 | 211 | 212 | def eval(self, epoch): 213 | print('==>Epoch:[{}/{}][validation stage]'.format(epoch, self.epoch_num)) 214 | self.model.eval() 215 | total_step = len(self.val_loader) 216 | acc_count = 0 217 | with torch.no_grad(): 218 | for iter, inputs in enumerate(self.val_loader): 219 | videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes = inputs 220 | video_inputs = videos.to(self.device) 221 | qns_inputs = targets_qns.to(self.device) 222 | ans_inputs = targets_ans.to(self.device) 223 | prediction = self.model(video_inputs, qns_inputs, qns_lengths, ans_inputs, ans_lengths, mode='val') 224 | acc_count += get_acc_count(prediction, targets_ans, self.vocab_ans, qtypes) 225 | 226 | return acc_count*1.0 / ((total_step-1)*self.batch_size) 227 | 228 | 229 | def predict(self, model_file, res_file): 230 | """ 231 | predict the answer with the trained model 232 | :param model_file: 233 | :return: 234 | """ 235 | model_path = osp.join(self.model_dir, model_file) 236 | self.build_model() 237 | if self.model_type == 'HGA': 238 | self.resume(model_file) 239 | else: 240 | old_state_dict = torch.load(model_path) 241 | self.model.load_state_dict(old_state_dict) 242 | #self.resume() 243 | self.model.eval() 244 | total = len(self.val_loader) 245 | acc = 0 246 | results = {} 247 | with torch.no_grad(): 248 | for iter, inputs in enumerate(self.val_loader): 249 | videos, targets_qns, qns_lengths, targets_ans, ans_lengths, video_names, qids, qtypes = inputs 250 | video_inputs = videos.to(self.device) 251 | qns_inputs = targets_qns.to(self.device) 252 | ans_inputs = targets_ans.to(self.device) 253 | # predict_ans_idxs = self.model.predict(video_inputs, qns_inputs, qns_lengths) 254 | predict_ans_idxs = 
self.model(video_inputs, qns_inputs, qns_lengths, ans_inputs, ans_lengths, mode='val') 255 | ans_idxs = predict_ans_idxs.cpu().numpy() 256 | targets_ans = targets_ans.numpy() 257 | targets_qns = targets_qns.numpy() 258 | for vname in video_names: 259 | if vname not in results: 260 | results[vname] = {} 261 | for bs, idx in enumerate(ans_idxs): 262 | ans_pred = [self.vocab_ans.idx2word[ans_id] for ans_id in idx[1:] if ans_id >3] #the first 4 ids are reserved for special token 263 | ans_pred = ' '.join(ans_pred) 264 | groundtruth = [self.vocab_ans.idx2word[ans_id] for ans_id in targets_ans[bs][1:] if ans_id > 3] 265 | groundtruth = ' '.join(groundtruth) 266 | qns_text = [self.vocab_qns.idx2word[qns_id] for qns_id in targets_qns[bs][1:] if qns_id > 3] 267 | # if qids[bs] not in results[video_names[bs]]: 268 | qns_text = ' '.join(qns_text) 269 | results[video_names[bs]][qids[bs]] = ans_pred 270 | if ans_pred==groundtruth and ans_pred != '': 271 | acc += 1 272 | # print(f'[{iter}/{total}]{qns_text}? P:{ans_pred} G:{groundtruth}') 273 | 274 | save_file(results, f'results/{res_file}') 275 | 276 | 277 | def get_acc_count(prediction, labels, vocab_ans, qtypes): 278 | """ 279 | 280 | :param prediction: 281 | :param labels: 282 | :return: 283 | """ 284 | preds = prediction.data.cpu().numpy() 285 | labels = np.asarray(labels) 286 | batch_size = labels.shape[0] 287 | score = 0 288 | for i in range(batch_size): 289 | pred = [j for j in preds[i] if j > 3] 290 | ans = [j for j in labels[i] if j > 3] 291 | pred_ans = ' '.join([vocab_ans.idx2word[id] for id in pred]) 292 | gt_ans = ' '.join([vocab_ans.idx2word[id] for id in ans]) 293 | pred_ans = remove_stop(pred_ans) 294 | gt_ans = remove_stop(gt_ans) 295 | cur_s = 0 296 | if qtypes[i] in ['CC', 'CB']: 297 | if gt_ans == pred_ans: 298 | cur_s = 1 299 | else: 300 | cur_s = get_wups(pred_ans, gt_ans, 0) 301 | score += cur_s 302 | 303 | 304 | return score 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | -------------------------------------------------------------------------------- /word2vec.py: -------------------------------------------------------------------------------- 1 | from build_vocab import Vocabulary 2 | from utils import * 3 | import numpy as np 4 | import random as rd 5 | rd.seed(0) 6 | 7 | def word2vec(vocab, glove_file, save_filename): 8 | glove = load_file(glove_file) 9 | word2vec = {} 10 | for i, line in enumerate(glove): 11 | if i == 0: continue # for FastText 12 | line = line.split(' ') 13 | word2vec[line[0]] = np.array(line[1:]).astype(np.float) 14 | 15 | temp = [] 16 | for word, vec in word2vec.items(): 17 | temp.append(vec) 18 | temp = np.asarray(temp) 19 | print(temp.shape) 20 | row, col = temp.shape 21 | 22 | pad = np.mean(temp, axis=0) 23 | start = np.mean(temp[:int(row//2), :], axis=0) 24 | end = np.mean(temp[int(row//2):, :], axis=0) 25 | special_tokens = [pad, start, end] 26 | count = 0 27 | bad_words = [] 28 | sort_idx_word = sorted(vocab.idx2word.items(), key=lambda k:k[0]) 29 | glove_embed = np.zeros((len(vocab), 300)) 30 | for row, item in enumerate(sort_idx_word): 31 | idx, word = item[0], item[1] 32 | if word in word2vec: 33 | glove_embed[row] = word2vec[word] 34 | else: 35 | if row < 3: 36 | glove_embed[row] = special_tokens[row] 37 | else: 38 | glove_embed[row] = np.random.randn(300)*0.4 39 | print(word) 40 | bad_words.append(word) 41 | count += 1 42 | print(glove_embed.shape) 43 | save_file(bad_words, 'bad_words_qns.json') 44 | np.save(save_filename, glove_embed) 45 | print(count) 46 | 47 | 48 | def main(): 49 | 
data_dir = 'dataset/nextqa/' 50 | vocab_file = osp.join(data_dir, 'vocab.pkl') 51 | vocab = pkload(vocab_file) 52 | glove_file = '../data/Vocabulary/glove.840B.300d.txt' 53 | save_filename = 'dataset/nextqa/glove_embed.npy' 54 | word2vec(vocab, glove_file, save_filename) 55 | 56 | if __name__ == "__main__": 57 | main() 58 | --------------------------------------------------------------------------------
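Note that the open-ended answers above are not scored by exact match alone. In videoqa.py, `get_acc_count` first strips stopwords from the predicted and ground-truth answers with `remove_stop` (from eval_oe.py); question types `'CC'` and `'CB'` are then scored by exact string match, while all other types use the WUPS similarity returned by `get_wups` (from metrics.py). Below is a minimal sketch of that per-answer rule. `score_answer` and the toy strings are illustrative only and not part of the repo; the sketch only assumes the `remove_stop(ans)` and `get_wups(pred, gt, thresh)` call signatures as used in videoqa.py.
```
# Illustrative only: per-answer score as computed in videoqa.py:get_acc_count.
from eval_oe import remove_stop   # stopword removal used by the evaluation
from metrics import get_wups      # WUPS similarity between two answer strings


def score_answer(pred_ans, gt_ans, qtype):
    """Exact match for question types 'CC'/'CB'; WUPS with threshold 0 otherwise."""
    pred_ans = remove_stop(pred_ans)
    gt_ans = remove_stop(gt_ans)
    if qtype in ['CC', 'CB']:
        return 1 if pred_ans == gt_ans else 0
    return get_wups(pred_ans, gt_ans, 0)


if __name__ == '__main__':
    # toy example; the question type and answer strings are made up for illustration
    print(score_answer('play with the dog', 'play with a puppy', 'TC'))
    print(score_answer('2', '3', 'CC'))  # exact-match type, so this scores 0
```
A batch-level score, as in `get_acc_count`, is simply the sum of these per-answer scores over the batch.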