├── LICENSE
├── README.md
├── adaptive_detection_features_converter.py
├── data_utils
│   ├── __init__.py
│   ├── data_util.py
│   ├── dataset.py
│   ├── dataset_msrvtt.py
│   └── dataset_msvd.py
├── embed_loss.py
├── main.py
├── main_msrvtt.py
├── main_msvd.py
├── model
│   ├── masn.py
│   └── modules
│       ├── ban
│       │   ├── ban.py
│       │   └── fc.py
│       ├── fusion
│       │   ├── fusion.py
│       │   └── net_utils.py
│       ├── gcn.py
│       ├── linear_weightdrop.py
│       ├── position_embedding.py
│       └── rnn_encoder.py
├── model_overview.jpeg
├── util.py
└── warmup_scheduler.py

/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2021 서아정

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
Motion-Appearance Synergistic Networks for VideoQA (MASN)
========================================================================

PyTorch implementation for the paper:

**[Attend What You Need: Motion-Appearance Synergistic Networks for Video Question Answering][1]**
Ahjeong Seo, [Gi-Cheon Kang](https://gicheonkang.com), Joonhan Park, and [Byoung-Tak Zhang](https://bi.snu.ac.kr/~btzhang/)
In ACL 2021


Requirements
--------
Python 3.7, PyTorch 1.2.0


Dataset
--------
- Download the [TGIF-QA](https://github.com/YunseokJANG/tgif-qa) dataset and refer to the [paper](https://arxiv.org/abs/1704.04497) for details.
- Download [MSVD-QA and MSRVTT-QA](https://github.com/xudejing/video-question-answering).

Extract Features
--------
1. Appearance Features
   - For local features, we use Faster-RCNN pre-trained on Visual Genome. Please refer to this [link](https://github.com/peteanderson80/bottom-up-attention).
     * After extracting object features with Faster-RCNN, you can convert them to an HDF5 file by running: `python adaptive_detection_features_converter.py`
   - For global features, we use the ResNet152 provided by torchvision. Please refer to this [link](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py).

2. Motion Features
   - For local features, we use RoIAlign with the bounding boxes obtained from Faster-RCNN. Please refer to this [link](https://github.com/AceCoooool/RoIAlign-RoIPool-pytorch).
   - For global features, we use I3D pre-trained on Kinetics. Please refer to this [link](https://github.com/Tushar-N/pytorch-resnet3d).


We have uploaded our extracted features:
1) TGIF-QA
   * [`res152_avgpool.hdf5`][2]: appearance global features (3GB).
   * [`tgif_btup_f_obj10.hdf5`][3]: appearance local features (30GB).
   * [`tgif_i3d_hw7_perclip_avgpool.hdf5`][4]: motion global features (3GB).
   * [`tgif_i3d_roialign_hw7_perclip_avgpool.hdf5`][5]: motion local features (59GB).

2) MSRVTT-QA
   * [`msrvtt_res152_avgpool.hdf5`][10]: appearance global features (1.7GB).
   * [`msrvtt_btup_f_obj10.hdf5`][11]: appearance local features (17GB).
   * [`msrvtt_i3d_avgpool_perclip.hdf5`][12]: motion global features (1.7GB).
   * [`msrvtt_i3d_roialign_perclip_obj10.hdf5`][13]: motion local features (34GB).

3) MSVD-QA
   * [`msvd_res152_avgpool.hdf5`][14]: appearance global features (220MB).
   * [`msvd_btup_f_obj10.hdf5`][15]: appearance local features (2.2GB).
   * [`msvd_i3d_avgpool_perclip.hdf5`][16]: motion global features (220MB).
   * [`msvd_i3d_roialign_perclip_obj10.hdf5`][17]: motion local features (4.2GB).


Training
--------
For TGIF-QA (e.g., the Count task), run
```sh
CUDA_VISIBLE_DEVICES=0 python main.py --task Count --batch_size 32
```

For MSRVTT-QA, run
```sh
CUDA_VISIBLE_DEVICES=0 python main_msrvtt.py --task MS-QA --batch_size 32
```

For MSVD-QA, run
```sh
CUDA_VISIBLE_DEVICES=0 python main_msvd.py --task MS-QA --batch_size 32
```

### Saving model checkpoints
By default, a model checkpoint is saved at every epoch. You can change the save path with the `--save_path` option.
By default, each checkpoint is named `[TASK]_[PERFORMANCE].pth`.


Evaluation & Results
--------
```sh
CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint [NAME] --task Count --batch_size 32
```

Performance on the TGIF-QA dataset (Count is mean squared error, lower is better; the other tasks are accuracy in %):

Model | Count | Action | Trans. | FrameQA |
------- | ------ | ------ | ------ | ------ |
MASN | 3.75 | 84.4 | 87.4 | 59.5 |

You can download our pre-trained models here: [`Count`][6], [`Action`][7], [`Trans.`][8], [`FrameQA`][9]

Performance on the MSRVTT-QA and MSVD-QA datasets (accuracy in %):

Model | MSRVTT-QA | MSVD-QA |
------- | ------ | ------ |
MASN | 35.2 | 38.0 |


Citation
--------
If this repository is helpful for your research, we'd really appreciate it if you could cite the following paper:
```text
@inproceedings{seo-etal-2021-attend,
    title = "Attend What You Need: Motion-Appearance Synergistic Networks for Video Question Answering",
    author = "Seo, Ahjeong and
      Kang, Gi-Cheon and
      Park, Joonhan and
      Zhang, Byoung-Tak",
    booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
    month = aug,
    year = "2021",
    address = "Online",
    publisher = "Association for Computational Linguistics",
    url = "https://aclanthology.org/2021.acl-long.481",
    doi = "10.18653/v1/2021.acl-long.481",
    pages = "6167--6177",
    abstract = "Video Question Answering is a task which requires an AI agent to answer questions grounded in video. This task entails three key challenges: (1) understand the intention of various questions, (2) capturing various elements of the input video (e.g., object, action, causality), and (3) cross-modal grounding between language and vision information. We propose Motion-Appearance Synergistic Networks (MASN), which embed two cross-modal features grounded on motion and appearance information and selectively utilize them depending on the question{'}s intentions. MASN consists of a motion module, an appearance module, and a motion-appearance fusion module. The motion module computes the action-oriented cross-modal joint representations, while the appearance module focuses on the appearance aspect of the input video. Finally, the motion-appearance fusion module takes each output of the motion module and the appearance module as input, and performs question-guided fusion. As a result, MASN achieves new state-of-the-art performance on the TGIF-QA and MSVD-QA datasets. We also conduct qualitative analysis by visualizing the inference results of MASN.",
}

```


License
--------
MIT License

Acknowledgements
--------
This work was partly supported by the Institute of Information & Communications Technology Planning & Evaluation (2015-0-00310-SW.StarLab/25%, 2017-0-01772-VTT/25%, 2018-0-00622-RMI/25%, 2019-0-01371-BabyMind/25%) grant funded by the Korean government.


[1]: https://aclanthology.org/2021.acl-long.481/
[2]: https://drive.google.com/file/d/1tWY3gU4XohzhZjV5Wia5L8XqfaV10127/view?usp=sharing
[3]: https://drive.google.com/file/d/1rxLL6eqi3d9FXKq7e4Wx7jiisu7_gzJa/view?usp=sharing
[4]: https://drive.google.com/file/d/1ejP_V3CuJFB_jaUYf-OM9up5bsnnETP3/view?usp=sharing
[5]: https://drive.google.com/file/d/1JbHWs0yTExL7Lc_abCvaXX49IsazUVvw/view?usp=sharing
[6]: https://drive.google.com/file/d/1Z3r20wd2Mxco47WWggmNKazonfYnUDy1/view?usp=sharing
[7]: https://drive.google.com/file/d/1USUA5D9bN5Ar9rClfdhOUHiTYdX1di1P/view?usp=sharing
[8]: https://drive.google.com/file/d/1jZLDt14ZRmfHEqc8Yat7beQA6n-N6-h7/view?usp=sharing
[9]: https://drive.google.com/file/d/1bXGlOKWrqUlEOer2cRNIJ2654_H_2UeR/view?usp=sharing
[10]: https://drive.google.com/file/d/16UswbSjfhHBBUih-cGCZgvurNLq-gOKx/view?usp=sharing
[11]: https://drive.google.com/file/d/1KdsLDW3oE-xNtrzsoYKZv9N_hauOR1Of/view?usp=sharing
[12]: https://drive.google.com/file/d/1mX0oxSQXDS2h2Fxz091q6NdKuHihr0Fj/view?usp=sharing
[13]: https://drive.google.com/file/d/1wQERtue5TY3zEZJwX0u2t19ARhU4mhtY/view?usp=sharing
[14]: https://drive.google.com/file/d/1XtQNShBMbW3jNwuZPMYP9p-5QbpgtHF6/view?usp=sharing
[15]: https://drive.google.com/file/d/1efxWKIGxvmEV5nR9iJosTMOvpMzHwjlG/view?usp=sharing
[16]: https://drive.google.com/file/d/143miiDN3m9-QqptxtA6U6BJfSP8XOcoW/view?usp=sharing
[17]: https://drive.google.com/file/d/14DUT3_yazEFYqZRjzWgrZ6K3lYu0XfHm/view?usp=sharing

--------------------------------------------------------------------------------
/adaptive_detection_features_converter.py:
--------------------------------------------------------------------------------
"""
This code is modified from Jinhwa Kim's repository.
https://github.com/jnhwkim/ban-vqa
Reads in tsv files with pre-trained bottom-up attention features
for an adaptive number of boxes and stores them in HDF5 format.
Also stores {image_id: feature_idx} as a pickle file.
Hierarchy of HDF5 file:
{ 'image_features': num_boxes x 2048
  'image_bb': num_boxes x 4
  'spatial_features': num_boxes x 6
  'pos_boxes': num_images x 2 }
"""

from __future__ import print_function

import os
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import base64
import csv
import h5py
import pickle as cPickle
import numpy as np
# from utils import utils


def load_image_ids(base_dir):
    ''' Load the set of image ids ('<gif_name>/<frame_name>') found under base_dir. Modify this to suit your data locations.
    '''
    ids = set()
    for gif_dir in os.listdir(base_dir):
        img_dir = os.path.join(base_dir, gif_dir)
        for img_name in os.listdir(img_dir):
            gif_name = gif_dir.split('.')[0]
            image_id = img_name.split('.')[0]
            ids.add(gif_name + '/' + image_id)
    return ids

csv.field_size_limit(sys.maxsize)

def extract(split, infiles):
    FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features']
    # feature data name to write
    data_file = {
        'all': './tgif_all_btmup_f.hdf5'
    }
    indices_file = {
        'all': './tgif_all_imgid2idx.pkl'
    }
    ids_file = {
        'all': './tgif_all_ids.pkl'
    }
    path_imgs = {
        'all': '/root/data/TGIFQA/frames',
    }
    known_num_boxes = {
        'all': None
    }
    feature_length = 2048
    min_fixed_boxes = 10
    max_fixed_boxes = 100

    if os.path.exists(ids_file[split]):
        imgids = cPickle.load(open(ids_file[split], 'rb'))
    else:
        imgids = load_image_ids(path_imgs[split])
        cPickle.dump(imgids, open(ids_file[split], 'wb'))

    h = h5py.File(data_file[split], 'w')

    if known_num_boxes[split] is None:
        # First pass: count the boxes of known images to size the HDF5 datasets.
        num_boxes = 0
        for infile in infiles:
            print("reading tsv...%s" % infile)
            with open(infile, "r") as tsv_in_file:
                reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
                for item in reader:
                    item['num_boxes'] = int(item['num_boxes'])
                    image_id = str(item['image_id'])
                    if image_id in imgids:
                        num_boxes += item['num_boxes']
    else:
        num_boxes = known_num_boxes[split]

    print('num_boxes=%d' % num_boxes)

    img_features = h.create_dataset(
        'image_features', (num_boxes, feature_length), 'f')
    img_bb = h.create_dataset(
        'image_bb', (num_boxes, 4), 'f')
    spatial_img_features = h.create_dataset(
        'spatial_features', (num_boxes, 6), 'f')
    pos_boxes = h.create_dataset(
        'pos_boxes', (len(imgids), 2),
        dtype='int32')

    counter = 0
    num_boxes = 0
    indices = {}

    for infile in infiles:
        unknown_ids = []
        print("reading tsv...%s" % infile)
        with open(infile, "r") as tsv_in_file:
            reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
            for item in reader:
                item['num_boxes'] = int(item['num_boxes'])
                # csv yields str; encode to bytes for base64 decoding
                # (bytes(str) without an encoding raises TypeError on Python 3).
                item['boxes'] = item['boxes'].encode('utf-8')
                item['features'] = item['features'].encode('utf-8')
                image_id = str(item['image_id'])
                image_w = float(item['image_w'])
                image_h = float(item['image_h'])
                # base64.decodestring was removed in Python 3.9; b64decode decodes the same data.
                bboxes = np.frombuffer(
                    base64.b64decode(item['boxes']),
                    dtype=np.float32).reshape((item['num_boxes'], -1))

                box_width = bboxes[:, 2] - bboxes[:, 0]
                box_height = bboxes[:, 3] - bboxes[:, 1]
                scaled_width = box_width / image_w
                scaled_height = box_height / image_h
                scaled_x = bboxes[:, 0] / image_w
                scaled_y = bboxes[:, 1] / image_h

                box_width = box_width[..., np.newaxis]
                box_height = box_height[..., np.newaxis]
                scaled_width = scaled_width[..., np.newaxis]
                scaled_height = scaled_height[..., np.newaxis]
                scaled_x = scaled_x[..., np.newaxis]
                scaled_y = scaled_y[..., np.newaxis]

                spatial_features = np.concatenate(
                    (scaled_x,
                     scaled_y,
                     scaled_x + scaled_width,
                     scaled_y + scaled_height,
                     scaled_width,
                     scaled_height),
                    axis=1)

                if image_id in imgids:
                    imgids.remove(image_id)
                    indices[image_id] = counter
                    pos_boxes[counter, :] = np.array([num_boxes, num_boxes + item['num_boxes']])
                    img_bb[num_boxes:num_boxes+item['num_boxes'], :] = bboxes
                    img_features[num_boxes:num_boxes+item['num_boxes'], :] = np.frombuffer(
                        base64.b64decode(item['features']),
                        dtype=np.float32).reshape((item['num_boxes'], -1))
                    spatial_img_features[num_boxes:num_boxes+item['num_boxes'], :] = spatial_features
                    counter += 1
                    num_boxes += item['num_boxes']
                else:
                    unknown_ids.append(image_id)

        print('%d unknown_ids...' % len(unknown_ids))
        print('%d image_ids left...' % len(imgids))

    if len(imgids) != 0:
        print('Warning: %s_image_ids is not empty' % split)

    cPickle.dump(indices, open(indices_file[split], 'wb'))
    h.close()
    print("done!")

if __name__ == '__main__':
    infile = ['./tgifqa_1.tsv.0', './tgifqa_1.tsv.1']
    extract('all', infile)

--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------


--------------------------------------------------------------------------------
/data_utils/data_util.py:
--------------------------------------------------------------------------------
#-*- coding: utf-8 -*-

# --------------------------------------------------------
# This code is modified from Jumpin2's repository.
# https://github.com/Jumpin2/HGA
# --------------------------------------------------------

import numpy as np
import re


def clean_str(string, downcase=True):
    """
    Tokenization/string cleaning for strings.
    Taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`(_____)]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip().lower() if downcase else string.strip()


def recover_word(string):
    string = re.sub(r" \'s", "\'s", string)
    string = re.sub(r" ,", ",", string)
    return string


def clean_blank(blank_sent):
    clean_sent = clean_str(blank_sent).split()
    return ['' if x == '_____' else x for x in clean_sent]


def clean_root(string):
    """
    Remove unexpected characters in root (currently returns the string unchanged).
    """
    return string


def pad_sequences(
        sequences, pad_token="[PAD]", pad_location="LEFT", max_length=None):
    """
    Pads all sequences to the same length. By default the length is that of the
    longest sequence; if max_length is given, longer sequences are truncated.
    Returns padded sequences.
    """
    if not max_length:
        max_length = max(len(x) for x in sequences)

    result = []
    for i in range(len(sequences)):
        sentence = sequences[i]
        num_padding = max_length - len(sentence)
        if num_padding == 0:
            new_sentence = sentence
        elif num_padding < 0:
            new_sentence = sentence[:num_padding]
        elif pad_location == "RIGHT":
            new_sentence = sentence + [pad_token] * num_padding
        elif pad_location == "LEFT":
            new_sentence = [pad_token] * num_padding + sentence
        else:
            raise ValueError("Invalid pad_location. Specify LEFT or RIGHT.")
        result.append(new_sentence)
    return result


def convert_sent_to_index(sentence, word_to_index):
    """
    Convert a sentence (list of word strings) to a list of word indices.
    Unknown words are mapped to index 0.
    """
    return [word_to_index.get(word, 0) for word in sentence]


def batch_iter(data, batch_size, seed=None, fill=True):
    """
    Generates a batch iterator for a dataset.
    """
    random = np.random.RandomState(seed)
    data_length = len(data)
    num_batches = int(data_length / batch_size)
    if data_length % batch_size != 0:
        num_batches += 1
    # Shuffle the data once per call (i.e., once per epoch)
    shuffle_indices = random.permutation(np.arange(data_length))
    for batch_num in range(num_batches):
        start_index = batch_num * batch_size
        end_index = min((batch_num + 1) * batch_size, data_length)
        selected_indices = shuffle_indices[start_index:end_index]
        # If we don't have enough data left for a whole batch, fill it
        # randomly
        if fill is True and end_index >= data_length:
            num_missing = batch_size - len(selected_indices)
            selected_indices = np.concatenate(
                [selected_indices,
                 random.randint(0, data_length, num_missing)])
        yield [data[i] for i in selected_indices]


def fsr_iter(fsr_data, batch_size, random_seed=42, fill=True):
    """
    fsr_data : one of LSMDCData.build_data(), [[video_features], [sentences], [roots]]
    return per iter : [[feature]*batch_size, [sentences]*batch_size, [roots]*batch_size]

    Usage:
        train_data, val_data, test_data = LSMDCData.build_data()
        for features, sentences, roots in fsr_iter(train_data, 20, 10):
            feed_dict = {model.video_feature : features,
                         model.sentences : sentences,
                         model.roots : roots}
    """

    train_iter = batch_iter(
        list(zip(*fsr_data)), batch_size, fill=fill, seed=random_seed)
    return [list(zip(*batch)) for batch in train_iter]


def preprocess_sents(descriptions, word_to_index, max_length):

    descriptions = [clean_str(sent).split() for sent in descriptions]
    descriptions = pad_sequences(descriptions, max_length=max_length)
    # Convert each sentence from a list of strings to a list of int indices.
    descriptions = [
        convert_sent_to_index(sent, word_to_index) for sent in descriptions
    ]

    return descriptions


def preprocess_roots(roots, word_to_index):

    # remove punctuation marks and special chars from root.
    roots = [clean_root(root) for root in roots]
    # convert string to int index (unknown roots map to 0).
    roots = [word_to_index.get(root, 0) for root in roots]

    return roots


def pad_video(video_feature, dimension):
    '''
    Pad a video feature to a fixed length, adding the padding on the left:
    video = [pad,..., pad, frm1, frm2, ..., frmN]
    Videos longer than dimension[0] are uniformly subsampled instead.
    '''
    padded_feature = np.zeros(dimension)
    max_length = dimension[0]
    current_length = video_feature.shape[0]
    num_padding = max_length - current_length
    if num_padding == 0:
        padded_feature = video_feature
    elif num_padding < 0:
        steps = np.linspace(
            0, current_length, num=max_length, endpoint=False, dtype=np.int32)
        padded_feature = video_feature[steps]
    else:
        padded_feature[num_padding:] = video_feature

    return padded_feature


def video_3d_pad(sequence, obj_max_num, max_length):
    ''' Pad sequence with 0.
    '''
    sequence = np.array(sequence)
    sequence_shape = sequence.shape
    current_length = sequence_shape[0]
    current_num = sequence_shape[1]
    pad = np.zeros((max_length, obj_max_num, sequence_shape[2]), dtype=np.float32)
    num_padding = max_length - current_length
    if num_padding <= 0:
        pad = sequence[:max_length]
    else:
        pad[:current_length, :current_num] = sequence
    return pad


def fill_mask(max_length, current_length, zero_location='LEFT'):
    num_padding = max_length - current_length
    if num_padding <= 0:
        mask = np.ones(max_length)
    elif zero_location == 'LEFT':
        # zeros mark the padded positions at the start
        mask = np.ones(max_length)
        mask[:num_padding] = 0
    elif zero_location == 'RIGHT':
        # ones mark the valid positions at the start
        mask = np.zeros(max_length)
        mask[:current_length] = 1
    else:
        raise ValueError("Invalid zero_location. Specify LEFT or RIGHT.")

    return mask


def question_pad(sequence, max_length):
    ''' Pad sequence with 0.
    '''
    sequence = np.array(sequence)
    current_length = sequence.shape[0]
    pad = np.zeros(max_length, dtype=np.int64)
    num_padding = max_length - current_length
    if num_padding <= 0:
        pad = sequence[:max_length]
    else:
        pad[:current_length] = sequence
    return pad


def video_pad(sequence, max_length):
    ''' Pad sequence with 0.
    '''
    sequence = np.array(sequence)
    sequence_shape = sequence.shape
    current_length = sequence_shape[0]
    pad = np.zeros((max_length, sequence_shape[1]), dtype=np.float32)
    num_padding = max_length - current_length
    if num_padding <= 0:
        pad = sequence[:max_length]
    else:
        pad[:current_length] = sequence
    return pad


def dict_keys_bytes2string(dic):
    ''' Convert the bytes keys of a dict to UTF-8 str keys, in place. '''
    for k in list(dic.keys()):
        kk = k.decode('utf-8')
        v = dic.pop(k)
        dic[kk] = v
    return dic
--------------------------------------------------------------------------------
/data_utils/dataset.py:
--------------------------------------------------------------------------------
# --------------------------------------------------------
# This code is modified from Jumpin2's repository.
# https://github.com/Jumpin2/HGA
# --------------------------------------------------------

"""Generate data batch."""
import os

import pandas as pd
import numpy as np
import h5py
from torch.utils.data import Dataset
from . import data_util

from util import log

import os.path
import sys

import pickle as pkl
import nltk


nltk.download('punkt')
# nltk.download('averaged_perceptron_tagger')


__PATH__ = os.path.abspath(os.path.dirname(__file__))


def assert_exists(path):
    assert os.path.exists(path), 'Does not exist : {}'.format(path)


# PATHS
TYPE_TO_CSV = {
    'FrameQA': 'Train_frameqa_question.csv',
    'Count': 'Train_count_question.csv',
    'Trans': 'Train_transition_question.csv',
    'Action': 'Train_action_question.csv'
}

eos_word = ''


class TGIFQA(Dataset):

    def __init__(
            self,
            dataset_name='train',
            q_max_length=20,
            v_max_length=80,
            max_n_videos=None,
            data_type='FrameQA',
            csv_dir=None,
            vocab_dir=None,
            feat_dir=None):
        self.csv_dir = csv_dir
        self.vocabulary_dir = vocab_dir
        self.feat_dir = feat_dir
        self.dataset_name = dataset_name
        self.q_max_length = q_max_length
        self.v_max_length = v_max_length
        self.spa_q_max_length = 10
        self.tmp_q_max_length = 7
        self.obj_max_num = 10
        self.max_n_videos = max_n_videos
        self.data_type = data_type
        self.GLOVE_EMBEDDING_SIZE = 300

        self.data_df = self.read_df_from_csvfile()

        if max_n_videos is not None:
            self.data_df = self.data_df[:max_n_videos]

        self.res_avg_file = os.path.join(self.feat_dir, "feat_r6_vsync2/res152_avgpool.hdf5")
        self.i3d_avg_file = os.path.join(self.feat_dir, "feat_r6_vsync2/tgif_i3d_hw7_perclip_avgpool.hdf5")

        self.res_avg_feat = None
        self.i3d_avg_feat = None

        self.res_roi_file = os.path.join(self.feat_dir, "feat_r6_vsync2/tgif_btup_f_obj10.hdf5")
        self.i3d_roi_file = os.path.join(self.feat_dir, "feat_r6_vsync2/tgif_i3d_roialign_hw7_perclip_avgpool_obj10.hdf5")

        self.res_roi_feat = None
        self.i3d_roi_feat = None

        self.load_word_vocabulary()
        if self.data_type == 'FrameQA':
            self.get_FrameQA_result()
        elif self.data_type == 'Count':
            self.get_Count_result()
        elif self.data_type == 'Trans':
            self.get_Trans_result()
        elif self.data_type == 'Action':
            # Action uses the same multiple-choice format as Trans
            self.get_Trans_result()

    def __getitem__(self, index):
        if self.data_type == 'FrameQA':
            return self.getitem_frameqa(index)
        elif self.data_type == 'Count':
            return self.getitem_count(index)
        elif self.data_type == 'Trans':
            return self.getitem_trans(index)
        elif self.data_type == 'Action':
            return self.getitem_trans(index)

    def __len__(self):
        if self.max_n_videos is not None:
            if self.max_n_videos <= len(self.data_df):
                return self.max_n_videos
        return len(self.data_df)

    def read_df_from_csvfile(self):
        assert self.data_type in [
            'FrameQA', 'Count', 'Trans', 'Action'
        ], 'Should choose data type from FrameQA, Count, Trans, Action'

        if self.data_type == 'FrameQA':
            train_data_path = os.path.join(
                self.csv_dir, 'Train_frameqa_question.csv')
            test_data_path = os.path.join(
                self.csv_dir, 'Test_frameqa_question.csv')
            #self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_frameqa_question.csv'), sep='\t')
            self.total_q = pd.read_csv(
                os.path.join(self.csv_dir, 'Total_frameqa_question.csv'),
                sep='\t')
        elif self.data_type == 'Count':
            train_data_path = os.path.join(
                self.csv_dir, 'Train_count_question.csv')
            test_data_path = os.path.join(
                self.csv_dir, 'Test_count_question.csv')
            #self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_count_question.csv'), sep='\t')
            self.total_q = pd.read_csv(
                os.path.join(self.csv_dir, 'Total_count_question.csv'),
                sep='\t')
        elif self.data_type == 'Trans':
            train_data_path = os.path.join(
                self.csv_dir,
                'Train_transition_question.csv')
            test_data_path = os.path.join(
                self.csv_dir, 'Test_transition_question.csv')
            #self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_transition_question.csv'), sep='\t')
            self.total_q = pd.read_csv(
                os.path.join(self.csv_dir, 'Total_transition_question.csv'),
                sep='\t')
        elif self.data_type == 'Action':
            train_data_path = os.path.join(
                self.csv_dir, 'Train_action_question.csv')
            test_data_path = os.path.join(
                self.csv_dir, 'Test_action_question.csv')
            # self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_action_question.csv'), sep='\t')
            self.total_q = pd.read_csv(
                os.path.join(self.csv_dir, 'Total_action_question.csv'),
                sep='\t')

        assert_exists(train_data_path)
        assert_exists(test_data_path)

        if self.dataset_name == 'train':
            data_df = pd.read_csv(train_data_path, sep='\t')
        elif self.dataset_name == 'test':
            data_df = pd.read_csv(test_data_path, sep='\t')
        else:
            raise ValueError("dataset_name should be 'train' or 'test'")

        data_df = data_df.set_index('vid_id')
        data_df['row_index'] = list(
            range(1,
                  len(data_df) + 1))  # assign csv row index
        return data_df

    @property
    def n_words(self):
        ''' The dictionary size. '''
        if not hasattr(self, 'word2idx'):
            raise Exception('Dictionary not built yet!')
        return len(self.word2idx)

    def __repr__(self):
        if hasattr(self, 'word2idx'):
            return '<Dataset (%s) with %d videos and %d words>' % (
                self.dataset_name, len(self), len(self.word2idx))
        else:
            return '<Dataset (%s) with %d videos, dictionary not built>' % (
                self.dataset_name, len(self))

    def split_sentence_into_words(self, sentence, eos=True):
        '''
        Split the given sentence (str) and enumerate the words as strs.
        Each word is normalized: lower-cased, with most punctuation such as
        periods removed, while commas, '!', '?' and parentheses become separate tokens.
        Tokenization uses ``data_util.clean_str``.
        '''
        try:
            words = data_util.clean_str(sentence).split()
        except Exception:
            print(sentence)
            sys.exit(1)
        if eos:
            words = words + [eos_word]
        for w in words:
            if not w:
                continue
            yield w

    def build_word_vocabulary(
            self,
            all_captions_source=None,
            word_count_threshold=0,
    ):
        '''
        Borrowed from @karpathy's neuraltalk implementation.
        '''
        log.infov('Building word vocabulary (%s) ...', self.dataset_name)

        if all_captions_source is None:
            all_captions_source = self.get_all_captions()

        # enumerate all sentences to build frequency table
        word_counts = {}
        nsents = 0
        nwords = 0
        for sentence in all_captions_source:
            nsents += 1
            for w in self.split_sentence_into_words(sentence):
                word_counts[w] = word_counts.get(w, 0) + 1
                nwords += 1

        import pickle as pkl
        vocab = [
            w for w in word_counts if word_counts[w] >= word_count_threshold
        ]
        print("[%d] : [%d] word from captions" % (len(vocab), nwords))
        log.info(
            "Filtered vocab words (threshold = %d), from %d to %d",
            word_count_threshold, len(word_counts), len(vocab))

        # build index and vocabularies
        self.word2idx = {}
        self.idx2word = {}

        self.idx2word[0] = '.'
241 | self.idx2word[1] = 'UNK' 242 | self.word2idx['#START#'] = 0 243 | self.word2idx['UNK'] = 1 244 | for idx, w in enumerate(vocab, start=2): 245 | self.word2idx[w] = idx 246 | self.idx2word[idx] = w 247 | 248 | pkl.dump( 249 | self.word2idx, 250 | open( 251 | os.path.join( 252 | self.vocabulary_dir, 253 | 'word_to_index_%s.pkl' % self.data_type), 'wb')) 254 | pkl.dump( 255 | self.idx2word, 256 | open( 257 | os.path.join( 258 | self.vocabulary_dir, 259 | 'index_to_word_%s.pkl' % self.data_type), 'wb')) 260 | 261 | word_counts['.'] = nsents 262 | bias_init_vector = np.array( 263 | [ 264 | 1.0 * word_counts[w] if i > 1 else 0 265 | for i, w in self.idx2word.items() 266 | ]) 267 | bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies 268 | bias_init_vector = np.log(bias_init_vector + 1e-20) 269 | bias_init_vector -= np.max( 270 | bias_init_vector) # shift to nice numeric range 271 | self.bias_init_vector = bias_init_vector 272 | 273 | #self.total_q = pd.DataFrame().from_csv(os.path.join(csv_dir,'Total_desc_question.csv'), sep='\t') 274 | answers = list(set(self.total_q['answer'].values)) 275 | self.ans2idx = {} 276 | self.idx2ans = {} 277 | for idx, w in enumerate(answers): 278 | self.ans2idx[w] = idx 279 | self.idx2ans[idx] = w 280 | pkl.dump( 281 | self.ans2idx, 282 | open( 283 | os.path.join( 284 | self.vocabulary_dir, 285 | 'ans_to_index_%s.pkl' % self.data_type), 'wb')) 286 | pkl.dump( 287 | self.idx2ans, 288 | open( 289 | os.path.join( 290 | self.vocabulary_dir, 291 | 'index_to_ans_%s.pkl' % self.data_type), 'wb')) 292 | 293 | # Make glove embedding. 
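The index layout and bias initialization above can be condensed into a small standalone sketch. `build_vocab` and `encode` are hypothetical helpers that are not part of this repository. `encode` mirrors how `convert_sentence_to_matrix` maps unknown words to index 1 and truncates the result to `q_max_length`. The sketch also shows that setting `word_counts['.'] = nsents` has no effect on the bias: the `i > 1` test zeroes indices 0 and 1 before normalization.

```python
from collections import Counter

import numpy as np


def build_vocab(tokenized_sents, threshold=0):
    """Word/index maps plus a log-frequency bias, mirroring the layout above.

    idx2word[0] is '.', idx2word[1] is 'UNK'; real words start at index 2.
    (word2idx maps '#START#' to 0, as in the original code.)
    """
    counts = Counter(w for sent in tokenized_sents for w in sent)
    vocab = [w for w in counts if counts[w] >= threshold]

    word2idx = {'#START#': 0, 'UNK': 1}
    idx2word = {0: '.', 1: 'UNK'}
    for idx, w in enumerate(vocab, start=2):
        word2idx[w] = idx
        idx2word[idx] = w

    # Reserved slots get zero mass, so only real words contribute.
    freqs = np.array(
        [counts[w] if i > 1 else 0.0 for i, w in idx2word.items()],
        dtype=np.float64)
    freqs /= freqs.sum()              # normalize to frequencies
    bias = np.log(freqs + 1e-20)      # reserved slots end up near log(1e-20)
    bias -= bias.max()                # most frequent word gets bias 0
    return word2idx, idx2word, bias


def encode(tokens, word2idx, max_len, unk=1):
    """Map tokens to indices (unknown words -> UNK) and truncate to max_len."""
    return [word2idx.get(w, unk) for w in tokens][:max_len]
```

After the shift, the most frequent word has bias 0 and every other word has a negative bias equal to its log frequency ratio against that word.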
294 | #import spacy 295 | #nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors') 296 | 297 | with open(self.vocabulary_dir + '/glove.42B.300d.txt', 'rb') as f: 298 | lines = f.readlines() 299 | dict = {} 300 | 301 | for ix, line in enumerate(lines): 302 | if ix % 1000 == 0: 303 | print((ix, len(lines))) 304 | segs = line.split() 305 | wd = segs[0] 306 | dict[wd] = segs[1:] 307 | 308 | #max_length = len(vocab) 309 | max_length = len(self.idx2word) 310 | GLOVE_EMBEDDING_SIZE = 300 311 | 312 | glove_matrix = np.random.normal(size=[max_length, GLOVE_EMBEDDING_SIZE]) 313 | not_in_vocab = 0 314 | for i in range(len(vocab)): 315 | if i % 1000 == 0: 316 | print((i, len(vocab))) 317 | w = bytes(vocab[i], encoding='utf-8') 318 | #w_embed = nlp(u'%s' % w).vector 319 | if w in dict: 320 | w_embed = np.array(dict[w]) 321 | else: 322 | not_in_vocab += 1 323 | w_embed = np.random.normal(size=(GLOVE_EMBEDDING_SIZE)) 324 | #glove_matrix[i,:] = w_embed 325 | glove_matrix[i + 2, :] = w_embed # two placeholder words '.','UNK' 326 | 327 | print("[%d/%d] word is not saved" % (not_in_vocab, len(vocab))) 328 | 329 | vocab = pkl.dump( 330 | glove_matrix, 331 | open( 332 | os.path.join( 333 | self.vocabulary_dir, 334 | 'vocab_embedding_%s.pkl' % self.data_type), 'wb')) 335 | self.word_matrix = glove_matrix 336 | 337 | def load_word_vocabulary(self): 338 | 339 | word_matrix_path = os.path.join( 340 | self.vocabulary_dir, 'vocab_embedding_%s.pkl' % self.data_type) 341 | 342 | word2idx_path = os.path.join( 343 | self.vocabulary_dir, 'word_to_index_%s.pkl' % self.data_type) 344 | idx2word_path = os.path.join( 345 | self.vocabulary_dir, 'index_to_word_%s.pkl' % self.data_type) 346 | ans2idx_path = os.path.join( 347 | self.vocabulary_dir, 'ans_to_index_%s.pkl' % self.data_type) 348 | idx2ans_path = os.path.join( 349 | self.vocabulary_dir, 'index_to_ans_%s.pkl' % self.data_type) 350 | 351 | if not (os.path.exists(word_matrix_path) and 352 | os.path.exists(word2idx_path) and 353 | 
os.path.exists(idx2word_path) and 354 | os.path.exists(ans2idx_path) and os.path.exists(idx2ans_path)): 355 | self.build_word_vocabulary() 356 | 357 | # ndarray 358 | with open(word_matrix_path, 'rb') as f: 359 | self.word_matrix = pkl.load(f) 360 | log.info("Load word_matrix from pkl file : %s", word_matrix_path) 361 | 362 | with open(word2idx_path, 'rb') as f: 363 | self.word2idx = pkl.load(f) 364 | log.info("Load word2idx from pkl file : %s", word2idx_path) 365 | 366 | with open(idx2word_path, 'rb') as f: 367 | self.idx2word = pkl.load(f) 368 | log.info("Load idx2word from pkl file : %s", idx2word_path) 369 | 370 | with open(ans2idx_path, 'rb') as f: 371 | self.ans2idx = pkl.load(f) 372 | log.info("Load answer2idx from pkl file : %s", ans2idx_path) 373 | 374 | with open(idx2ans_path, 'rb') as f: 375 | self.idx2ans = pkl.load(f) 376 | log.info("Load idx2answers from pkl file : %s", idx2ans_path) 377 | 378 | def share_word_vocabulary_from(self, dataset): 379 | assert hasattr(dataset, 'idx2word') and hasattr( 380 | dataset, 'word2idx' 381 | ), 'The dataset instance should have idx2word and word2idx' 382 | assert ( 383 | isinstance(dataset.idx2word, dict) or 384 | isinstance(dataset.idx2word, list) 385 | ) and isinstance( 386 | dataset.word2idx, dict 387 | ), 'The dataset instance should have idx2word and word2idx (as dict)' 388 | 389 | if hasattr(self, 'word2idx'): 390 | log.warn( 391 | "Overriding %s' word vocabulary from %s ...", self, dataset) 392 | 393 | self.idx2word = dataset.idx2word 394 | self.word2idx = dataset.word2idx 395 | self.ans2idx = dataset.ans2idx 396 | self.idx2ans = dataset.idx2ans 397 | if hasattr(dataset, 'word_matrix'): 398 | self.word_matrix = dataset.word_matrix 399 | 400 | def get_all_captions(self): 401 | ''' 402 | Iterate caption strings associated in the vid/gifs. 
403 | ''' 404 | #qa_data_df = pd.DataFrame().from_csv(os.path.join(self.csv_dir, TYPE_TO_CSV[self.data_type]), sep='\t') 405 | qa_data_df = pd.read_csv( 406 | os.path.join(self.csv_dir, TYPE_TO_CSV[self.data_type]), sep='\t') 407 | 408 | all_sents = [] 409 | for row in qa_data_df.iterrows(): 410 | all_sents.extend(self.get_captions(row)) 411 | 412 | return all_sents 413 | 414 | def get_captions(self, row): 415 | if self.data_type == 'FrameQA': 416 | columns = ['description', 'question', 'answer'] 417 | elif self.data_type == 'Count': 418 | columns = ['question'] 419 | elif self.data_type == 'Trans': 420 | columns = ['question', 'a1', 'a2', 'a3', 'a4', 'a5'] 421 | elif self.data_type == 'Action': 422 | columns = ['question', 'a1', 'a2', 'a3', 'a4', 'a5'] 423 | 424 | sents = [row[1][col] for col in columns if not pd.isnull(row[1][col])] 425 | return sents 426 | 427 | def get_video_feature(self, key): # key : gif_name 428 | if self.res_avg_feat is None: 429 | self.res_avg_feat = h5py.File(self.res_avg_file, 'r') 430 | if self.i3d_avg_feat is None: 431 | self.i3d_avg_feat = h5py.File(self.i3d_avg_file, 'r') 432 | if self.res_roi_feat is None: 433 | self.res_roi_feat = h5py.File(self.res_roi_file, 'r') 434 | if self.i3d_roi_feat is None: 435 | self.i3d_roi_feat = h5py.File(self.i3d_roi_file, 'r') 436 | 437 | video_id = str(key) 438 | 439 | try: 440 | res_avg_feat = np.array(self.res_avg_feat[video_id]) # T, d 441 | i3d_avg_feat = np.array(self.i3d_avg_feat[video_id]) # T, d 442 | res_roi_feat = np.array(self.res_roi_feat['image_features'][video_id]) # T, 5, d 443 | roi_bbox_feat = np.array(self.res_roi_feat['spatial_features'][video_id]) # T, 5, 6 444 | i3d_roi_feat = np.array(self.i3d_roi_feat[video_id]) # T, 5, d 445 | except KeyError: # no img 446 | print('no image', key) 447 | res_avg_feat = np.zeros((1, 2048)) 448 | i3d_avg_feat = np.zeros((1, 2048)) 449 | res_roi_feat = np.zeros((1, self.obj_max_num, 2048)) 450 | roi_bbox_feat = np.zeros((1,
self.obj_max_num, 6)) 451 | i3d_roi_feat = np.zeros((1, self.obj_max_num, 2048)) 452 | 453 | return res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat 454 | 455 | def convert_sentence_to_matrix(self, task, question, answer=None, eos=True): 456 | ''' 457 | Convert the given question (optionally followed by an answer) into word indices. 458 | NOTE: Unknown words (not in the vocabulary) are mapped to index 1 ('UNK'). 459 | 460 | Args: 461 | question: A str for the unnormalized question 462 | answer: An optional str appended after the question; T is the total word count (``task`` and ``eos`` are unused) 463 | Returns: 464 | sentence_word_indices : list of length min(T, q_max_length), 465 | each being a word index 466 | ''' 467 | if answer is None: 468 | sentence = question 469 | else: 470 | sentence = question + ' ' + answer 471 | 472 | token = self.split_sentence_into_words(sentence, eos=False) 473 | token = [t for t in token] 474 | 475 | sent2indices = [ 476 | self.word2idx[w] if w in self.word2idx else 1 477 | for w in token 478 | ] 479 | T = len(sent2indices) 480 | length = min(T, self.q_max_length) 481 | 482 | return sent2indices[:length] 483 | 484 | def get_question(self, key): 485 | ''' 486 | Return question index for given key.
487 | ''' 488 | question = self.data_df.loc[key, ['question', 'description']].values 489 | if len(list(question.shape)) > 1: 490 | question = question[0] 491 | # A question string 492 | question = question[0] 493 | 494 | return self.convert_sentence_to_matrix(self.data_type, question, eos=False) 495 | 496 | def get_answer(self, key): 497 | answer = self.data_df.loc[key, ['answer', 'type']].values 498 | 499 | if len(list(answer.shape)) > 1: 500 | answer = answer[0] 501 | 502 | anstype = answer[1] 503 | answer = answer[0] 504 | 505 | return answer, anstype 506 | 507 | def get_FrameQA_result(self): 508 | self.padded_all_ques = [] 509 | self.answers = [] 510 | self.answer_type = [] 511 | self.all_ques_lengths = [] 512 | 513 | self.keys = self.data_df['key'].values.astype(np.int64) 514 | 515 | for index, row in self.data_df.iterrows(): 516 | # ====== Question ====== 517 | all_ques = self.get_question(index) 518 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length) 519 | self.padded_all_ques.append(all_ques_pad) 520 | 521 | all_ques_length = min(self.q_max_length, len(all_ques)) 522 | self.all_ques_lengths.append(all_ques_length) 523 | 524 | answer, answer_type = self.get_answer(index) 525 | if str(answer) in self.ans2idx: 526 | answer = self.ans2idx[answer] 527 | else: 528 | # unknown token, check later 529 | answer = 1 530 | self.answers.append(np.array(answer, dtype=np.int64)) 531 | self.answer_type.append(np.array(answer_type, dtype=np.int64)) 532 | 533 | def getitem_frameqa(self, index): 534 | key = self.keys[index] 535 | 536 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat = self.get_video_feature(key) 537 | 538 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32) 539 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32) 540 | 541 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num, 542 | self.v_max_length).astype(np.float32) 543 | bbox_pad = 
data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num, 544 | self.v_max_length).astype(np.float32) 545 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num, 546 | self.v_max_length).astype(np.float32) 547 | video_length = min(self.v_max_length, res_avg_feat.shape[0]) 548 | 549 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \ 550 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index], self.answer_type[index] 551 | 552 | def get_Count_question(self, key): 553 | ''' 554 | Return question string for given key. 555 | ''' 556 | question = self.data_df.loc[key, 'question'] 557 | return self.convert_sentence_to_matrix(self.data_type, question, eos=False) 558 | 559 | def get_Count_result(self): 560 | self.padded_all_ques = [] 561 | self.answers = [] 562 | self.all_ques_lengths = [] 563 | 564 | self.keys = self.data_df['key'].values.astype(np.int64) 565 | 566 | for index, row in self.data_df.iterrows(): 567 | # ====== Question ====== 568 | all_ques = self.get_Count_question(index) 569 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length) 570 | 571 | self.padded_all_ques.append(all_ques_pad) 572 | 573 | all_ques_length = min(self.q_max_length, len(all_ques)) 574 | self.all_ques_lengths.append(all_ques_length) 575 | 576 | # force answer not equal to 0 577 | answer = max(row['answer'], 1) 578 | self.answers.append(np.array(answer, dtype=np.float32)) 579 | 580 | def getitem_count(self, index): 581 | key = self.keys[index] 582 | 583 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat = self.get_video_feature(key) 584 | 585 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32) 586 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32) 587 | 588 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num, 589 | self.v_max_length).astype(np.float32) 590 | bbox_pad = 
data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num, 591 | self.v_max_length).astype(np.float32) 592 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num, 593 | self.v_max_length).astype(np.float32) 594 | video_length = min(self.v_max_length, res_avg_feat.shape[0]) 595 | 596 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \ 597 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index] 598 | 599 | def get_Trans_matrix(self, candidates, q_max_length, is_left=True): 600 | candidates_matrix = np.zeros([5, q_max_length], dtype=np.int64) 601 | for k in range(5): 602 | sentence = candidates[k] 603 | if is_left: 604 | candidates_matrix[k, :len(sentence)] = sentence 605 | else: 606 | candidates_matrix[k, -len(sentence):] = sentence 607 | return candidates_matrix 608 | 609 | def get_Trans_result(self): 610 | self.padded_candidates = [] 611 | self.answers = self.data_df['answer'].values.astype(np.int64) 612 | self.row_index = self.data_df['row_index'].values.astype(np.int64) 613 | self.candidate_lengths = [] 614 | 615 | self.keys = self.data_df['key'].values.astype(np.int64) 616 | 617 | for index, row in self.data_df.iterrows(): 618 | a1 = row['a1'].strip() 619 | a2 = row['a2'].strip() 620 | a3 = row['a3'].strip() 621 | a4 = row['a4'].strip() 622 | a5 = row['a5'].strip() 623 | candidates = [a1, a2, a3, a4, a5] 624 | raw_question = row['question'].strip() 625 | indexed_candidates = [] 626 | 627 | for x in candidates: 628 | all_cand = self.convert_sentence_to_matrix(self.data_type, raw_question, x, eos=False) 629 | indexed_candidates.append(all_cand) 630 | 631 | cand_lens = [] 632 | for i in range(len(indexed_candidates)): 633 | vl = min(self.q_max_length, len(indexed_candidates[i])) 634 | cand_lens.append(vl) 635 | self.candidate_lengths.append(cand_lens) 636 | 637 | # (5, self.q_max_length) 638 | candidates_pad = self.get_Trans_matrix(indexed_candidates, self.q_max_length) 639 | 
self.padded_candidates.append(candidates_pad) 640 | 641 | def getitem_trans(self, index): 642 | key = self.keys[index] 643 | 644 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat = self.get_video_feature(key) 645 | 646 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32) 647 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32) 648 | 649 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num, 650 | self.v_max_length).astype(np.float32) 651 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num, 652 | self.v_max_length).astype(np.float32) 653 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num, 654 | self.v_max_length).astype(np.float32) 655 | video_length = min(self.v_max_length, res_avg_feat.shape[0]) 656 | 657 | self.padded_candidates = np.asarray(self.padded_candidates).astype( 658 | np.int64) 659 | self.candidate_lengths = np.asarray(self.candidate_lengths).astype( 660 | np.int64) 661 | 662 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \ 663 | self.padded_candidates[index], self.candidate_lengths[index], self.answers[index], self.row_index[index] 664 | -------------------------------------------------------------------------------- /data_utils/dataset_msrvtt.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 
3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | """Generate data batch.""" 7 | import os 8 | import os.path 9 | import sys 10 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 11 | 12 | import pandas as pd 13 | import numpy as np 14 | import h5py 15 | from torch.utils.data import Dataset 16 | from data_utils import data_util 17 | 18 | from util import log 19 | 20 | import pickle as pkl 21 | import nltk 22 | 23 | 24 | nltk.download('punkt') 25 | # nltk.download('averaged_perceptron_tagger') 26 | 27 | 28 | __PATH__ = os.path.abspath(os.path.dirname(__file__)) 29 | 30 | 31 | def assert_exists(path): 32 | assert os.path.exists(path), 'Does not exist : {}'.format(path) 33 | 34 | eos_word = '' 35 | 36 | 37 | class MSRVTTQA(Dataset): 38 | 39 | def __init__( 40 | self, 41 | dataset_name='train', 42 | q_max_length=20, 43 | v_max_length=80, 44 | max_n_videos=None, 45 | csv_dir=None, 46 | vocab_dir=None, 47 | feat_dir=None): 48 | self.csv_dir = csv_dir 49 | self.vocabulary_dir = vocab_dir 50 | self.feat_dir = feat_dir 51 | self.dataset_name = dataset_name 52 | self.q_max_length = q_max_length 53 | self.v_max_length = v_max_length 54 | self.obj_max_num = 10 55 | self.max_n_videos = max_n_videos 56 | self.GLOVE_EMBEDDING_SIZE = 300 57 | 58 | self.data_df = self.read_df_from_json() 59 | 60 | if max_n_videos is not None: 61 | self.data_df = self.data_df[:max_n_videos] 62 | 63 | self.res_avg_file = os.path.join(self.feat_dir, "vfeat/msrvtt_res152_avgpool.hdf5") 64 | self.i3d_avg_file = os.path.join(self.feat_dir, "vfeat/msrvtt_i3d_avgpool_perclip.hdf5") 65 | 66 | self.res_avg_feat = None 67 | self.i3d_avg_feat = None 68 | 69 | self.res_roi_file = os.path.join(self.feat_dir, "vfeat/msrvtt_btup_f_obj10.hdf5") 70 | self.i3d_roi_file = os.path.join(self.feat_dir, "vfeat/msrvtt_i3d_roialign_perclip_obj10.hdf5") 71 | 72 | self.res_roi_feat = None 73 | self.i3d_roi_feat = None 74 | 75 | 
self.load_word_vocabulary() 76 | self.get_result() 77 | 78 | def __getitem__(self, index): 79 | return self.get_item(index) 80 | 81 | def __len__(self): 82 | if self.max_n_videos is not None: 83 | if self.max_n_videos <= len(self.data_df): 84 | return self.max_n_videos 85 | return len(self.data_df) 86 | 87 | @property 88 | def n_words(self): 89 | ''' The dictionary size. ''' 90 | if not hasattr(self, 'word2idx'): 91 | raise Exception('Dictionary not built yet!') 92 | return len(self.word2idx) 93 | 94 | def __repr__(self): 95 | if hasattr(self, 'word2idx'): 96 | return '' % ( 97 | self.dataset_name, len(self), len(self.word2idx)) 98 | else: 99 | return '' % ( 100 | self.dataset_name, len(self)) 101 | 102 | def read_df_from_json(self): 103 | if self.dataset_name == 'train': 104 | data_path = '%s/train_qa.json'%self.csv_dir 105 | elif self.dataset_name == 'val': 106 | data_path = '%s/val_qa.json'%self.csv_dir 107 | elif self.dataset_name == 'test': 108 | data_path = '%s/test_qa.json'%self.csv_dir 109 | 110 | with open(data_path, 'r') as f: 111 | data_df = pd.read_json(f) 112 | 113 | return data_df 114 | 115 | def split_sentence_into_words(self, sentence, eos=True): 116 | ''' 117 | Split the given sentence (str) and enumerate the words as strs. 118 | Each word is normalized, i.e. lower-cased, non-alphabet characters 119 | like period (.) or comma (,) are stripped. 120 | When tokenizing, I use ``data_util.clean_str`` 121 | ''' 122 | try: 123 | words = data_util.clean_str(sentence).split() 124 | except: 125 | print(sentence) 126 | sys.exit() 127 | if eos: 128 | words = words + [eos_word] 129 | for w in words: 130 | if not w: 131 | continue 132 | yield w 133 | 134 | def create_answerset(self, ans_df): 135 | """Return the ans_num (4000) most frequent answers in the training set. 136 | Args: 137 | ans_df: DataFrame loaded from train_qa.json (uses its 'answer' column).
138 | Returns: list of answer strings, most frequent first. 139 | """ 140 | ans_num = 4000 141 | answer_freq = ans_df['answer'].value_counts() 142 | answer_freq = list(answer_freq.iloc[0:ans_num].keys()) 143 | return answer_freq 144 | 145 | def build_word_vocabulary( 146 | self, 147 | all_sen=None, 148 | ans_df=None, 149 | word_count_threshold=0, 150 | ): 151 | ''' 152 | borrowed this implementation from @karpathy's neuraltalk. 153 | ''' 154 | log.infov('Building word vocabulary (%s) ...', self.dataset_name) 155 | 156 | if all_sen is None or ans_df is None: 157 | all_sen, ans_df = self.get_all_captions() 158 | all_captions_source = all_sen 159 | 160 | # enumerate all sentences to build frequency table 161 | word_counts = {} 162 | nsents = 0 163 | nwords = 0 164 | for sentence in all_captions_source: 165 | nsents += 1 166 | for w in self.split_sentence_into_words(sentence): 167 | word_counts[w] = word_counts.get(w, 0) + 1 168 | nwords += 1 169 | 170 | vocab = [ 171 | w for w in word_counts if word_counts[w] >= word_count_threshold 172 | ] 173 | print("Filtered vocab words (threshold = %d), from %d to %d" % (word_count_threshold, len(word_counts), len(vocab))) 174 | log.info( 175 | "Filtered vocab words (threshold = %d), from %d to %d", 176 | word_count_threshold, len(word_counts), len(vocab)) 177 | 178 | # build index and vocabularies 179 | self.word2idx = {} 180 | self.idx2word = {} 181 | 182 | self.idx2word[0] = '.
183 | self.idx2word[1] = 'UNK' 184 | self.word2idx['#START#'] = 0 185 | self.word2idx['UNK'] = 1 186 | for idx, w in enumerate(vocab, start=2): 187 | self.word2idx[w] = idx 188 | self.idx2word[idx] = w 189 | 190 | pkl.dump( 191 | self.word2idx, 192 | open(os.path.join(self.vocabulary_dir, 'word_to_index.pkl'), 'wb') 193 | ) 194 | pkl.dump( 195 | self.idx2word, 196 | open(os.path.join(self.vocabulary_dir, 'index_to_word.pkl'), 'wb') 197 | ) 198 | 199 | word_counts['.'] = nsents 200 | bias_init_vector = np.array( 201 | [ 202 | 1.0 * word_counts[w] if i > 1 else 0 203 | for i, w in self.idx2word.items() 204 | ]) 205 | bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies 206 | bias_init_vector = np.log(bias_init_vector + 1e-20) 207 | bias_init_vector -= np.max( 208 | bias_init_vector) # shift to nice numeric range 209 | self.bias_init_vector = bias_init_vector 210 | 211 | #self.total_q = pd.DataFrame().from_csv(os.path.join(csv_dir,'Total_desc_question.csv'), sep='\t') 212 | answers = self.create_answerset(ans_df) 213 | print("all answer num : %d"%len(answers)) 214 | self.ans2idx = {} 215 | self.idx2ans = {} 216 | for idx, w in enumerate(answers): 217 | self.ans2idx[w] = idx 218 | self.idx2ans[idx] = w 219 | pkl.dump( 220 | self.ans2idx, 221 | open(os.path.join(self.vocabulary_dir, 'ans_to_index.pkl'), 'wb') 222 | ) 223 | pkl.dump( 224 | self.idx2ans, 225 | open(os.path.join(self.vocabulary_dir, 'index_to_ans.pkl'), 'wb') 226 | ) 227 | 228 | # Make glove embedding. 
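`create_answerset` keeps only the most frequent training answers: 4000 for MSRVTT-QA and 1000 for MSVD-QA. `get_result` then labels every other answer as index 1. Because `ans2idx` numbers answers from 0, index 1 also belongs to the second most frequent answer, so out-of-set answers end up sharing its label.

The sketch below is a hypothetical alternative, not what the repository does. It assumes the out-of-set label should be PyTorch's default `ignore_index` of -100 for `torch.nn.CrossEntropyLoss`. The names `top_k_answers` and `encode_answer` are illustrative.

```python
import pandas as pd


def top_k_answers(ans_df, k):
    """The k most frequent values of ans_df['answer'], most frequent first."""
    return ans_df['answer'].value_counts().head(k).index.tolist()


def encode_answer(answer, ans2idx, oov_index=-100):
    """Out-of-set answers get oov_index (-100 is the default ignore_index of
    torch.nn.CrossEntropyLoss) instead of colliding with a real class."""
    return ans2idx.get(answer, oov_index)
```

With this labeling, a default `CrossEntropyLoss()` skips out-of-set samples during training. At evaluation time they would need to be counted as errors explicitly.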
229 | #import spacy 230 | #nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors') 231 | 232 | with open(self.vocabulary_dir + '/glove.42B.300d.txt', 'rb') as f: 233 | lines = f.readlines() 234 | dict = {} 235 | 236 | for ix, line in enumerate(lines): 237 | if ix % 1000 == 0: 238 | print((ix, len(lines))) 239 | segs = line.split() 240 | wd = segs[0] 241 | dict[wd] = segs[1:] 242 | 243 | #max_length = len(vocab) 244 | max_length = len(self.idx2word) 245 | GLOVE_EMBEDDING_SIZE = 300 246 | 247 | glove_matrix = np.random.normal(size=[max_length, GLOVE_EMBEDDING_SIZE]) 248 | not_in_vocab = 0 249 | for i in range(len(vocab)): 250 | if i % 1000 == 0: 251 | print((i, len(vocab))) 252 | w = bytes(vocab[i], encoding='utf-8') 253 | #w_embed = nlp(u'%s' % w).vector 254 | if w in dict: 255 | w_embed = np.array(dict[w]) 256 | else: 257 | not_in_vocab += 1 258 | w_embed = np.random.normal(size=(GLOVE_EMBEDDING_SIZE)) 259 | #glove_matrix[i,:] = w_embed 260 | glove_matrix[i + 2, :] = w_embed # two placeholder words '.','UNK' 261 | 262 | print("[%d/%d] word is not saved" % (not_in_vocab, len(vocab))) 263 | 264 | vocab = pkl.dump( 265 | glove_matrix, 266 | open(os.path.join(self.vocabulary_dir, 'vocab_embedding.pkl'), 'wb') 267 | ) 268 | self.word_matrix = glove_matrix 269 | 270 | def load_word_vocabulary(self): 271 | 272 | word_matrix_path = os.path.join( 273 | self.vocabulary_dir, 'vocab_embedding.pkl') 274 | 275 | word2idx_path = os.path.join( 276 | self.vocabulary_dir, 'word_to_index.pkl') 277 | idx2word_path = os.path.join( 278 | self.vocabulary_dir, 'index_to_word.pkl') 279 | ans2idx_path = os.path.join( 280 | self.vocabulary_dir, 'ans_to_index.pkl') 281 | idx2ans_path = os.path.join( 282 | self.vocabulary_dir, 'index_to_ans.pkl') 283 | 284 | if not (os.path.exists(word_matrix_path) and 285 | os.path.exists(word2idx_path) and 286 | os.path.exists(idx2word_path) and 287 | os.path.exists(ans2idx_path) and os.path.exists(idx2ans_path)): 288 | self.build_word_vocabulary() 289 
| 290 | # ndarray 291 | with open(word_matrix_path, 'rb') as f: 292 | self.word_matrix = pkl.load(f) 293 | log.info("Load word_matrix from pkl file : %s", word_matrix_path) 294 | 295 | with open(word2idx_path, 'rb') as f: 296 | self.word2idx = pkl.load(f) 297 | log.info("Load word2idx from pkl file : %s", word2idx_path) 298 | 299 | with open(idx2word_path, 'rb') as f: 300 | self.idx2word = pkl.load(f) 301 | log.info("Load idx2word from pkl file : %s", idx2word_path) 302 | 303 | with open(ans2idx_path, 'rb') as f: 304 | self.ans2idx = pkl.load(f) 305 | log.info("Load answer2idx from pkl file : %s", ans2idx_path) 306 | 307 | with open(idx2ans_path, 'rb') as f: 308 | self.idx2ans = pkl.load(f) 309 | log.info("Load idx2answers from pkl file : %s", idx2ans_path) 310 | 311 | def get_all_captions(self): 312 | ''' 313 | Iterate caption strings associated in the vid/gifs. 314 | ''' 315 | data_path = '%s/train_qa.json'%self.csv_dir 316 | with open(data_path, 'r') as f: 317 | data_df = pd.read_json(f) 318 | 319 | all_sents = list(data_df['question']) 320 | return all_sents, data_df 321 | 322 | def get_video_feature(self, key): # key : gif_name 323 | if self.res_avg_feat is None: 324 | self.res_avg_feat = h5py.File(self.res_avg_file, 'r') 325 | if self.i3d_avg_feat is None: 326 | self.i3d_avg_feat = h5py.File(self.i3d_avg_file, 'r') 327 | if self.res_roi_feat is None: 328 | self.res_roi_feat = h5py.File(self.res_roi_file, 'r') 329 | if self.i3d_roi_feat is None: 330 | self.i3d_roi_feat = h5py.File(self.i3d_roi_file, 'r') 331 | 332 | video_id = 'video' + str(key) 333 | 334 | try: 335 | res_avg_feat = np.array(self.res_avg_feat[video_id]) # T, d 336 | i3d_avg_feat = np.array(self.i3d_avg_feat[video_id]) # T, d 337 | res_roi_feat = np.array(self.res_roi_feat['image_features'][video_id]) # T, 5, d 338 | roi_bbox_feat = np.array(self.res_roi_feat['spatial_features'][video_id]) # T, 5, 6 339 | i3d_roi_feat = np.array(self.i3d_roi_feat[video_id]) # T, 5, d 340 | except KeyError: # no 
img 341 | print('no image', key) 342 | res_avg_feat = np.zeros((1, 2048)) 343 | i3d_avg_feat = np.zeros((1, 2048)) 344 | res_roi_feat = np.zeros((1, self.obj_max_num, 2048)) 345 | roi_bbox_feat = np.zeros((1, self.obj_max_num, 6)) 346 | i3d_roi_feat = np.zeros((1, self.obj_max_num, 2048)) 347 | 348 | return res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat 349 | 350 | def convert_sentence_to_matrix(self, sentence, eos=True): 351 | ''' 352 | Convert the given tokenized sentence into word indices. 353 | NOTE: Unknown words (not in the vocabulary) are mapped to index 1 ('UNK'). 354 | 355 | Args: 356 | sentence: A list of T normalized word tokens (``eos`` is unused) 357 | 358 | Returns: 359 | sentence_word_indices : list of length min(T, q_max_length), 360 | each being a word index 361 | ''' 362 | sent2indices = [ 363 | self.word2idx[w] if w in self.word2idx else 1 364 | for w in sentence 365 | ] # 1 is UNK, unknown 366 | T = len(sent2indices) 367 | length = min(T, self.q_max_length) 368 | return sent2indices[:length] 369 | 370 | def get_question(self, key): 371 | ''' 372 | Return question index for given key.
373 | ''' 374 | question = self.data_df.loc[key, ['question']].values 375 | question = question[0] 376 | 377 | question = self.split_sentence_into_words(question, eos=False) 378 | q_refine = [] 379 | for w in question: 380 | q_refine.append(w) 381 | q_refine = self.convert_sentence_to_matrix(q_refine) 382 | return q_refine 383 | 384 | def get_answer(self, key): 385 | answer = self.data_df.loc[key, ['answer']].values 386 | answer = answer[0] 387 | 388 | return answer 389 | 390 | def get_result(self): 391 | self.padded_all_ques = [] 392 | self.answers = [] 393 | 394 | self.all_ques_lengths = [] 395 | self.keys = self.data_df['video_id'].values.astype(np.int64) 396 | 397 | for index, row in self.data_df.iterrows(): 398 | # ====== Question ====== 399 | all_ques = self.get_question(index) 400 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length) 401 | self.padded_all_ques.append(all_ques_pad) 402 | all_ques_length = min(self.q_max_length, len(all_ques)) 403 | self.all_ques_lengths.append(all_ques_length) 404 | 405 | answer = self.get_answer(index) 406 | if str(answer) in self.ans2idx: 407 | answer = self.ans2idx[answer] 408 | else: 409 | # unknown token, check later 410 | answer = 1 411 | self.answers.append(np.array(answer, dtype=np.int64)) 412 | 413 | def get_item(self, index): 414 | key = self.keys[index] 415 | 416 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat \ 417 | = self.get_video_feature(key) 418 | 419 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32) 420 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32) 421 | 422 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num, 423 | self.v_max_length).astype(np.float32) 424 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num, 425 | self.v_max_length).astype(np.float32) 426 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num, 427 | 
self.v_max_length).astype(np.float32) 428 | video_length = min(self.v_max_length, res_avg_feat.shape[0]) 429 | 430 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \ 431 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index] 432 | -------------------------------------------------------------------------------- /data_utils/dataset_msvd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | """Generate data batch.""" 7 | import os 8 | import os.path 9 | import sys 10 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__)))) 11 | 12 | import pandas as pd 13 | import numpy as np 14 | import h5py 15 | from torch.utils.data import Dataset 16 | from data_utils import data_util 17 | 18 | from util import log 19 | 20 | import pickle as pkl 21 | import nltk 22 | 23 | 24 | nltk.download('punkt') 25 | # nltk.download('averaged_perceptron_tagger') 26 | 27 | 28 | __PATH__ = os.path.abspath(os.path.dirname(__file__)) 29 | 30 | 31 | def assert_exists(path): 32 | assert os.path.exists(path), 'Does not exist : {}'.format(path) 33 | 34 | eos_word = '' 35 | 36 | 37 | class MSVDQA(Dataset): 38 | 39 | def __init__( 40 | self, 41 | dataset_name='train', 42 | q_max_length=20, 43 | v_max_length=80, 44 | max_n_videos=None, 45 | csv_dir=None, 46 | vocab_dir=None, 47 | feat_dir=None): 48 | self.csv_dir = csv_dir 49 | self.vocabulary_dir = vocab_dir 50 | self.feat_dir = feat_dir 51 | self.dataset_name = dataset_name 52 | self.q_max_length = q_max_length 53 | self.v_max_length = v_max_length 54 | self.obj_max_num = 10 55 | self.max_n_videos = max_n_videos 56 | self.GLOVE_EMBEDDING_SIZE = 300 57 | 58 | self.data_df = self.read_df_from_json() 59 | 60 | if 
max_n_videos is not None: 61 | self.data_df = self.data_df[:max_n_videos] 62 | 63 | self.res_avg_file = os.path.join(self.feat_dir, "vfeat/msvd_res152_avgpool.hdf5") 64 | self.i3d_avg_file = os.path.join(self.feat_dir, "vfeat/msvd_i3d_avgpool_perclip.hdf5") 65 | 66 | self.res_avg_feat = None 67 | self.i3d_avg_feat = None 68 | 69 | self.res_roi_file = os.path.join(self.feat_dir, "vfeat/msvd_btup_f_obj10.hdf5") 70 | self.i3d_roi_file = os.path.join(self.feat_dir, "vfeat/msvd_i3d_roialign_perclip_obj10.hdf5") 71 | 72 | self.res_roi_feat = None 73 | self.i3d_roi_feat = None 74 | 75 | self.load_word_vocabulary() 76 | self.get_result() 77 | 78 | def __getitem__(self, index): 79 | return self.get_item(index) 80 | 81 | def __len__(self): 82 | if self.max_n_videos is not None: 83 | if self.max_n_videos <= len(self.data_df): 84 | return self.max_n_videos 85 | return len(self.data_df) 86 | 87 | @property 88 | def n_words(self): 89 | ''' The dictionary size. ''' 90 | if not hasattr(self, 'word2idx'): 91 | raise Exception('Dictionary not built yet!') 92 | return len(self.word2idx) 93 | 94 | def __repr__(self): 95 | if hasattr(self, 'word2idx'): 96 | return '' % ( 97 | self.dataset_name, len(self), len(self.word2idx)) 98 | else: 99 | return '' % ( 100 | self.dataset_name, len(self)) 101 | 102 | def read_df_from_json(self): 103 | if self.dataset_name == 'train': 104 | data_path = '%s/train_qa.json'%self.csv_dir 105 | elif self.dataset_name == 'val': 106 | data_path = '%s/val_qa.json'%self.csv_dir 107 | elif self.dataset_name == 'test': 108 | data_path = '%s/test_qa.json'%self.csv_dir 109 | 110 | with open(data_path, 'r') as f: 111 | data_df = pd.read_json(f) 112 | 113 | return data_df 114 | 115 | def split_sentence_into_words(self, sentence, eos=True): 116 | ''' 117 | Split the given sentence (str) and enumerate the words as strs. 118 | Each word is normalized, i.e. lower-cased, non-alphabet characters 119 | like period (.) or comma (,) are stripped. 
120 | When tokenizing, I use ``data_util.clean_str`` 121 | ''' 122 | try: 123 | words = data_util.clean_str(sentence).split() 124 | except: 125 | print(sentence) 126 | sys.exit() 127 | if eos: 128 | words = words + [eos_word] 129 | for w in words: 130 | if not w: 131 | continue 132 | yield w 133 | 134 | def create_answerset(self, ans_df): 135 | """Return the 1000 most frequent answers in ans_df (loaded from train_qa.json). 136 | Args: 137 | ans_df: QA DataFrame with an 'answer' column. 138 | Returns: a list of answer strings, most frequent first. 139 | """ 140 | answer_freq = ans_df['answer'].value_counts() 141 | answer_freq = list(answer_freq.iloc[0:1000].keys()) 142 | return answer_freq 143 | 144 | def build_word_vocabulary( 145 | self, 146 | all_sen=None, 147 | ans_df=None, 148 | word_count_threshold=0, 149 | ): 150 | ''' 151 | This implementation is borrowed from @karpathy's neuraltalk. 152 | ''' 153 | log.infov('Building word vocabulary (%s) ...', self.dataset_name) 154 | 155 | if all_sen is None or ans_df is None: 156 | all_sen, ans_df = self.get_all_captions() 157 | all_captions_source = all_sen 158 | 159 | # enumerate all sentences to build frequency table 160 | word_counts = {} 161 | nsents = 0 162 | nwords = 0 163 | for sentence in all_captions_source: 164 | nsents += 1 165 | for w in self.split_sentence_into_words(sentence): 166 | word_counts[w] = word_counts.get(w, 0) + 1 167 | nwords += 1 168 | 169 | vocab = [ 170 | w for w in word_counts if word_counts[w] >= word_count_threshold 171 | ] 172 | print("Filtered vocab words (threshold = %d), from %d to %d" % (word_count_threshold, len(word_counts), len(vocab))) 173 | log.info( 174 | "Filtered vocab words (threshold = %d), from %d to %d", 175 | word_count_threshold, len(word_counts), len(vocab)) 176 | 177 | # build index and vocabularies 178 | self.word2idx = {} 179 | self.idx2word = {} 180 | 181 | self.idx2word[0] = '.'
182 | self.idx2word[1] = 'UNK' 183 | self.word2idx['#START#'] = 0 184 | self.word2idx['UNK'] = 1 185 | for idx, w in enumerate(vocab, start=2): 186 | self.word2idx[w] = idx 187 | self.idx2word[idx] = w 188 | 189 | pkl.dump( 190 | self.word2idx, 191 | open(os.path.join(self.vocabulary_dir, 'word_to_index.pkl'), 'wb') 192 | ) 193 | pkl.dump( 194 | self.idx2word, 195 | open(os.path.join(self.vocabulary_dir, 'index_to_word.pkl'), 'wb') 196 | ) 197 | 198 | word_counts['.'] = nsents 199 | bias_init_vector = np.array( 200 | [ 201 | 1.0 * word_counts[w] if i > 1 else 0 202 | for i, w in self.idx2word.items() 203 | ]) 204 | bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies 205 | bias_init_vector = np.log(bias_init_vector + 1e-20) 206 | bias_init_vector -= np.max( 207 | bias_init_vector) # shift to nice numeric range 208 | self.bias_init_vector = bias_init_vector 209 | 210 | #self.total_q = pd.DataFrame().from_csv(os.path.join(csv_dir,'Total_desc_question.csv'), sep='\t') 211 | answers = self.create_answerset(ans_df) 212 | print("all answer num : %d"%len(answers)) 213 | self.ans2idx = {} 214 | self.idx2ans = {} 215 | for idx, w in enumerate(answers): 216 | self.ans2idx[w] = idx 217 | self.idx2ans[idx] = w 218 | pkl.dump( 219 | self.ans2idx, 220 | open(os.path.join(self.vocabulary_dir, 'ans_to_index.pkl'), 'wb') 221 | ) 222 | pkl.dump( 223 | self.idx2ans, 224 | open(os.path.join(self.vocabulary_dir, 'index_to_ans.pkl'), 'wb') 225 | ) 226 | 227 | # Make glove embedding. 
228 | #import spacy 229 | #nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors') 230 | 231 | with open(self.vocabulary_dir + '/glove.42B.300d.txt', 'rb') as f: 232 | lines = f.readlines() 233 | dict = {} 234 | 235 | for ix, line in enumerate(lines): 236 | if ix % 1000 == 0: 237 | print((ix, len(lines))) 238 | segs = line.split() 239 | wd = segs[0] 240 | dict[wd] = segs[1:] 241 | 242 | #max_length = len(vocab) 243 | max_length = len(self.idx2word) 244 | GLOVE_EMBEDDING_SIZE = 300 245 | 246 | glove_matrix = np.random.normal(size=[max_length, GLOVE_EMBEDDING_SIZE]) 247 | not_in_vocab = 0 248 | for i in range(len(vocab)): 249 | if i % 1000 == 0: 250 | print((i, len(vocab))) 251 | w = bytes(vocab[i], encoding='utf-8') 252 | #w_embed = nlp(u'%s' % w).vector 253 | if w in dict: 254 | w_embed = np.array(dict[w]) 255 | else: 256 | not_in_vocab += 1 257 | w_embed = np.random.normal(size=(GLOVE_EMBEDDING_SIZE)) 258 | #glove_matrix[i,:] = w_embed 259 | glove_matrix[i + 2, :] = w_embed # two placeholder words '.','UNK' 260 | 261 | print("[%d/%d] word is not saved" % (not_in_vocab, len(vocab))) 262 | 263 | vocab = pkl.dump( 264 | glove_matrix, 265 | open(os.path.join(self.vocabulary_dir, 'vocab_embedding.pkl'), 'wb') 266 | ) 267 | self.word_matrix = glove_matrix 268 | 269 | def load_word_vocabulary(self): 270 | 271 | word_matrix_path = os.path.join( 272 | self.vocabulary_dir, 'vocab_embedding.pkl') 273 | 274 | word2idx_path = os.path.join( 275 | self.vocabulary_dir, 'word_to_index.pkl') 276 | idx2word_path = os.path.join( 277 | self.vocabulary_dir, 'index_to_word.pkl') 278 | ans2idx_path = os.path.join( 279 | self.vocabulary_dir, 'ans_to_index.pkl') 280 | idx2ans_path = os.path.join( 281 | self.vocabulary_dir, 'index_to_ans.pkl') 282 | 283 | if not (os.path.exists(word_matrix_path) and 284 | os.path.exists(word2idx_path) and 285 | os.path.exists(idx2word_path) and 286 | os.path.exists(ans2idx_path) and os.path.exists(idx2ans_path)): 287 | self.build_word_vocabulary() 288 
| 289 | # ndarray 290 | with open(word_matrix_path, 'rb') as f: 291 | self.word_matrix = pkl.load(f) 292 | log.info("Load word_matrix from pkl file : %s", word_matrix_path) 293 | 294 | with open(word2idx_path, 'rb') as f: 295 | self.word2idx = pkl.load(f) 296 | log.info("Load word2idx from pkl file : %s", word2idx_path) 297 | 298 | with open(idx2word_path, 'rb') as f: 299 | self.idx2word = pkl.load(f) 300 | log.info("Load idx2word from pkl file : %s", idx2word_path) 301 | 302 | with open(ans2idx_path, 'rb') as f: 303 | self.ans2idx = pkl.load(f) 304 | log.info("Load answer2idx from pkl file : %s", ans2idx_path) 305 | 306 | with open(idx2ans_path, 'rb') as f: 307 | self.idx2ans = pkl.load(f) 308 | log.info("Load idx2answers from pkl file : %s", idx2ans_path) 309 | 310 | def get_all_captions(self): 311 | ''' 312 | Return all training questions (used to build the vocabulary) and the train QA DataFrame. 313 | ''' 314 | data_path = '%s/train_qa.json'%self.csv_dir 315 | with open(data_path, 'r') as f: 316 | data_df = pd.read_json(f) 317 | 318 | all_sents = list(data_df['question']) 319 | return all_sents, data_df 320 | 321 | def get_video_feature(self, key): # key : integer MSVD video id 322 | if self.res_avg_feat is None: 323 | self.res_avg_feat = h5py.File(self.res_avg_file, 'r') 324 | if self.i3d_avg_feat is None: 325 | self.i3d_avg_feat = h5py.File(self.i3d_avg_file, 'r') 326 | if self.res_roi_feat is None: 327 | self.res_roi_feat = h5py.File(self.res_roi_file, 'r') 328 | if self.i3d_roi_feat is None: 329 | self.i3d_roi_feat = h5py.File(self.i3d_roi_file, 'r') 330 | 331 | video_id = 'vid' + str(key) 332 | 333 | try: 334 | res_avg_feat = np.array(self.res_avg_feat[video_id]) # T, d 335 | i3d_avg_feat = np.array(self.i3d_avg_feat[video_id]) # T, d 336 | res_roi_feat = np.array(self.res_roi_feat['image_features'][video_id]) # T, N_obj, d 337 | roi_bbox_feat = np.array(self.res_roi_feat['spatial_features'][video_id]) # T, N_obj, 6 338 | i3d_roi_feat = np.array(self.i3d_roi_feat[video_id]) # T, N_obj, d 339 | except KeyError: # no
img 340 | print('no image', key) 341 | res_avg_feat = np.zeros((1, 2048)) 342 | i3d_avg_feat = np.zeros((1, 2048)) 343 | res_roi_feat = np.zeros((1, self.obj_max_num, 2048)) 344 | roi_bbox_feat = np.zeros((1, self.obj_max_num, 6)) 345 | i3d_roi_feat = np.zeros((1, self.obj_max_num, 2048)) 346 | 347 | return res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat 348 | 349 | def convert_sentence_to_matrix(self, sentence, eos=True): 350 | ''' 351 | Convert the given tokenized sentence into word indices. 352 | Note: unknown words (not in the vocabulary) are mapped to index 1 (UNK). 353 | 354 | Args: 355 | sentence: a list of T normalized word tokens 356 | 357 | Returns: 358 | sentence_word_indices : list of length min(T, q_max_length), 359 | each being a word index 360 | ''' 361 | sent2indices = [ 362 | self.word2idx[w] if w in self.word2idx else 1 363 | for w in sentence 364 | ] # 1 is UNK, unknown 365 | T = len(sent2indices) 366 | length = min(T, self.q_max_length) 367 | return sent2indices[:length] 368 | 369 | def get_question(self, key): 370 | ''' 371 | Return the word indices of the question in row `key`.
372 | ''' 373 | question = self.data_df.loc[key, ['question']].values 374 | question = question[0] 375 | 376 | question = self.split_sentence_into_words(question, eos=False) 377 | q_refine = [] 378 | for w in question: 379 | q_refine.append(w) 380 | q_refine = self.convert_sentence_to_matrix(q_refine) 381 | return q_refine 382 | 383 | def get_answer(self, key): 384 | answer = self.data_df.loc[key, ['answer']].values 385 | answer = answer[0] 386 | 387 | return answer 388 | 389 | def get_result(self): 390 | self.padded_all_ques = [] 391 | self.answers = [] 392 | 393 | self.all_ques_lengths = [] 394 | self.keys = self.data_df['video_id'].values.astype(np.int64) 395 | 396 | for index, row in self.data_df.iterrows(): 397 | # ====== Question ====== 398 | all_ques = self.get_question(index) 399 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length) 400 | self.padded_all_ques.append(all_ques_pad) 401 | all_ques_length = min(self.q_max_length, len(all_ques)) 402 | self.all_ques_lengths.append(all_ques_length) 403 | 404 | answer = self.get_answer(index) 405 | if str(answer) in self.ans2idx: 406 | answer = self.ans2idx[answer] 407 | else: 408 | # unknown token, check later 409 | answer = 1 410 | self.answers.append(np.array(answer, dtype=np.int64)) 411 | 412 | def get_item(self, index): 413 | key = self.keys[index] 414 | 415 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat \ 416 | = self.get_video_feature(key) 417 | 418 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32) 419 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32) 420 | 421 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num, 422 | self.v_max_length).astype(np.float32) 423 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num, 424 | self.v_max_length).astype(np.float32) 425 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num, 426 | 
self.v_max_length).astype(np.float32) 427 | video_length = min(self.v_max_length, res_avg_feat.shape[0]) 428 | 429 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \ 430 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index] 431 | -------------------------------------------------------------------------------- /embed_loss.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | # __all__ = ['MultipleChoiceLoss', 'CountLoss'] 11 | 12 | 13 | class MultipleChoiceLoss(nn.Module): 14 | 15 | def __init__(self, num_option=5, margin=1, size_average=True): 16 | super(MultipleChoiceLoss, self).__init__() 17 | self.margin = margin 18 | self.num_option = num_option 19 | self.size_average = size_average 20 | 21 | # score is N x C 22 | 23 | def forward(self, score, target): 24 | N = score.size(0) 25 | C = score.size(1) 26 | assert self.num_option == C 27 | 28 | loss = score.new_zeros(()) # allocated on score's device instead of forcing .cuda() 29 | zero = score.new_zeros(()) 30 | 31 | cnt = 0 32 | #print(N,C) 33 | for b in range(N): 34 | # loop over incorrect answers: the correct answer's score should exceed each by at least the margin 35 | c0 = target[b] 36 | for c in range(C): 37 | if c == c0: 38 | continue 39 | 40 | # right class and wrong class should have score difference larger than a margin 41 | # see formula under paper Eq(4) 42 | loss += torch.max(zero, self.margin + score[b, c] - score[b, c0]) 43 | cnt += 1 44 | 45 | if cnt == 0: 46 | return loss 47 | 48 | return loss / cnt if self.size_average else loss 49 | -------------------------------------------------------------------------------- /main.py:
-------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | import os 7 | import argparse 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import random 13 | import h5py 14 | 15 | seed = 999 16 | 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.enabled = False 22 | # torch.backends.cudnn.benchmark = True 23 | # torch.backends.cudnn.deterministic = True 24 | 25 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 26 | 27 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) 28 | 29 | 30 | def _init_fn(worker_id): 31 | np.random.seed(seed) 32 | 33 | 34 | from data_utils.dataset import TGIFQA 35 | from torch.utils.data import DataLoader 36 | from warmup_scheduler import GradualWarmupScheduler 37 | 38 | from model.masn import MASN 39 | 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | 42 | 43 | def main(): 44 | """Main script.""" 45 | 46 | args.pin_memory = False 47 | args.dataset = 'tgif_qa' 48 | args.log = './logs/%s' % args.model_name 49 | if not os.path.exists(args.log): 50 | os.mkdir(args.log) 51 | 52 | args.val_epoch_step = 1 53 | 54 | args.save_model_path = os.path.join(args.save_path, args.model_name) 55 | if not os.path.exists(args.save_model_path): 56 | os.makedirs(args.save_model_path) 57 | 58 | full_dataset = TGIFQA( 59 | dataset_name='train', 60 | q_max_length=args.q_max_length, 61 | v_max_length=args.v_max_length, 62 | max_n_videos=args.max_n_videos, 63 | data_type=args.task, 64 | csv_dir=args.df_dir, 65 | vocab_dir=args.vc_dir, 66 | feat_dir=args.feat_dir) 67 | test_dataset = TGIFQA( 68 | dataset_name='test', 69 | q_max_length=args.q_max_length, 
70 | v_max_length=args.v_max_length, 71 | max_n_videos=args.max_n_videos, 72 | data_type=args.task, 73 | csv_dir=args.df_dir, 74 | vocab_dir=args.vc_dir, 75 | feat_dir=args.feat_dir) 76 | 77 | val_size = int(args.val_ratio * len(full_dataset)) 78 | train_size = len(full_dataset) - val_size 79 | train_dataset, val_dataset = torch.utils.data.random_split( 80 | full_dataset, [train_size, val_size]) 81 | print( 82 | 'Dataset lengths train/val/test %d/%d/%d' % 83 | (len(train_dataset), len(val_dataset), len(test_dataset))) 84 | 85 | train_dataloader = DataLoader( 86 | train_dataset, 87 | args.batch_size, 88 | shuffle=True, 89 | num_workers=args.num_workers, 90 | pin_memory=args.pin_memory, 91 | worker_init_fn=_init_fn) 92 | val_dataloader = DataLoader( 93 | val_dataset, 94 | args.batch_size, 95 | shuffle=False, 96 | num_workers=args.num_workers, 97 | pin_memory=args.pin_memory, 98 | worker_init_fn=_init_fn) 99 | test_dataloader = DataLoader( 100 | test_dataset, 101 | args.batch_size, 102 | shuffle=False, 103 | num_workers=args.num_workers, 104 | pin_memory=args.pin_memory, 105 | worker_init_fn=_init_fn) 106 | 107 | print('Load data successful.') 108 | 109 | args.resnet_input_size = 2048 110 | args.i3d_input_size = 2048 111 | 112 | args.text_embed_size = train_dataset.dataset.GLOVE_EMBEDDING_SIZE 113 | args.answer_vocab_size = None 114 | 115 | args.word_matrix = train_dataset.dataset.word_matrix 116 | args.voc_len = args.word_matrix.shape[0] 117 | assert args.text_embed_size == args.word_matrix.shape[1] 118 | 119 | VOCABULARY_SIZE = train_dataset.dataset.n_words 120 | assert VOCABULARY_SIZE == args.voc_len 121 | 122 | ### criterions 123 | if args.task == 'Count': 124 | # add L2 loss 125 | criterion = nn.MSELoss().to(device) 126 | elif args.task in ['Action', 'Trans']: 127 | from embed_loss import MultipleChoiceLoss 128 | criterion = MultipleChoiceLoss( 129 | num_option=5, margin=1, size_average=True).to(device) 130 | elif args.task == 'FrameQA': 131 | # add 
classification loss 132 | args.answer_vocab_size = len(train_dataset.dataset.ans2idx) 133 | print(('Vocabulary size', args.answer_vocab_size, VOCABULARY_SIZE)) 134 | criterion = nn.CrossEntropyLoss().to(device) 135 | 136 | if not args.test: 137 | train( 138 | args, train_dataloader, val_dataloader, test_dataloader, criterion) 139 | else: 140 | print(args.checkpoint[:5], args.task[:5]) 141 | model = torch.load(os.path.join(args.save_model_path, args.checkpoint)) 142 | test(args, model, test_dataloader, 0, criterion) 143 | 144 | 145 | def train(args, train_dataloader, val_dataloader, test_dataloader, criterion): 146 | model = MASN( 147 | args.voc_len, 148 | args.rnn_layers, 149 | args.word_matrix, 150 | args.resnet_input_size, 151 | args.i3d_input_size, 152 | args.hidden_size, 153 | dropout_p=args.dropout, 154 | gcn_layers=args.gcn_layers, 155 | answer_vocab_size=args.answer_vocab_size, 156 | q_max_len=args.q_max_length, 157 | v_max_len=args.v_max_length, 158 | ablation=args.ablation) 159 | 160 | if torch.cuda.device_count() > 1: 161 | print("Let's use", torch.cuda.device_count(), "GPUs!") 162 | model = nn.DataParallel(model) 163 | 164 | model.to(device) 165 | 166 | if args.change_lr == 'none': 167 | optimizer = torch.optim.Adam( 168 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 169 | elif args.change_lr == 'acc': 170 | optimizer = torch.optim.Adam( 171 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 172 | # val plateau scheduler 173 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 174 | optimizer, mode='max', factor=0.1, patience=3, verbose=True) 175 | # target lr = args.lr * multiplier 176 | scheduler_warmup = GradualWarmupScheduler( 177 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 178 | elif args.change_lr == 'loss': 179 | optimizer = torch.optim.Adam( 180 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 181 | # val plateau scheduler 182 | scheduler = 
torch.optim.lr_scheduler.ReduceLROnPlateau( 183 | optimizer, mode='min', factor=0.1, patience=3, verbose=True) 184 | # target lr = args.lr * multiplier 185 | scheduler_warmup = GradualWarmupScheduler( 186 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 187 | elif args.change_lr == 'cos': 188 | optimizer = torch.optim.Adam( 189 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 190 | # cosine annealing 191 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 192 | optimizer, args.max_epoch) 193 | # target lr = args.lr * multiplier 194 | scheduler_warmup = GradualWarmupScheduler( 195 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 196 | elif args.change_lr == 'step': 197 | optimizer = torch.optim.Adam( 198 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 199 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 200 | optimizer, milestones=args.lr_list, gamma=0.1) 201 | # scheduler_warmup = GradualWarmupScheduler( 202 | # optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 203 | 204 | best_val_acc = 0. if args.task != 'Count' else -100.
205 | 206 | for epoch in range(args.max_epoch): 207 | print('Start Training Epoch: {}'.format(epoch)) 208 | 209 | model.train() 210 | 211 | loss_list = [] 212 | prediction_list = [] 213 | correct_answer_list = [] 214 | 215 | if args.change_lr == 'cos': 216 | # cosine annealing 217 | scheduler_warmup.step(epoch=epoch) 218 | 219 | for ii, data in enumerate(train_dataloader): 220 | if epoch == 0 and ii == 0: 221 | print([d.dtype for d in data], [d.size() for d in data]) 222 | data = [d.to(device) for d in data] 223 | 224 | optimizer.zero_grad() 225 | out, predictions, answers = model(args.task, *data) 226 | loss = criterion(out, answers) 227 | loss.backward() 228 | optimizer.step() 229 | 230 | correct_answer_list.append(answers) 231 | loss_list.append(loss.item()) 232 | prediction_list.append(predictions.detach()) 233 | if ii % 100 == 0: 234 | print("Batch: ", ii) 235 | 236 | train_loss = np.mean(loss_list) 237 | 238 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 239 | predict_answer = torch.cat(prediction_list, dim=0).long() 240 | assert correct_answer.shape == predict_answer.shape 241 | 242 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 243 | acc = current_num / len(correct_answer) * 100.
244 | 245 | print( 246 | "Train|Epoch: {}, Acc : {:.3f}={}/{}, Train Loss: {:.3f}".format( 247 | epoch, acc, current_num, len(correct_answer), train_loss)) 248 | if args.task == 'Count': 249 | count_loss = F.mse_loss( 250 | predict_answer.float(), correct_answer.float()) 251 | print('Train|Count Real Loss:\t {:.3f}'.format(count_loss)) 252 | 253 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 254 | logfile.write( 255 | "Train|Epoch: %d, Acc : %.3f=%d/%d, Train Loss: %.3f\n" 256 | % (epoch, acc, current_num, len(correct_answer), train_loss) 257 | ) 258 | if args.task == 'Count': 259 | logfile.write( 260 | "Train|Count Real Loss:\t %.3f\n"%count_loss 261 | ) 262 | logfile.close() 263 | 264 | val_acc, val_loss = val(args, model, val_dataloader, epoch, criterion) 265 | 266 | # step the schedulers after validation so that val_acc / val_loss exist for this epoch 267 | if args.change_lr == 'acc': 268 | scheduler_warmup.step(epoch, val_acc) 269 | elif args.change_lr == 'loss': 270 | scheduler_warmup.step(epoch, val_loss) 271 | elif args.change_lr == 'step': 272 | scheduler.step() 273 | 274 | if val_acc > best_val_acc: 275 | print('Best Val Acc ======') 276 | best_val_acc = val_acc 277 | if epoch % args.val_epoch_step == 0 or val_acc >= best_val_acc: 278 | test(args, model, test_dataloader, epoch, criterion) 279 | 280 | 281 | @torch.no_grad() 282 | def val(args, model, val_dataloader, epoch, criterion): 283 | model.eval() 284 | 285 | loss_list = [] 286 | prediction_list = [] 287 | correct_answer_list = [] 288 | 289 | for ii, data in enumerate(val_dataloader): 290 | data = [d.to(device) for d in data] 291 | 292 | out, predictions, answers = model(args.task, *data) 293 | loss = criterion(out, answers) 294 | 295 | correct_answer_list.append(answers) 296 | loss_list.append(loss.item()) 297 | prediction_list.append(predictions.detach()) 298 | 299 | val_loss = np.mean(loss_list) 300 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 301 | predict_answer =
torch.cat(prediction_list, dim=0).long() 302 | assert correct_answer.shape == predict_answer.shape 303 | 304 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 305 | 306 | acc = current_num / len(correct_answer) * 100. 307 | 308 | print( 309 | "VAL|Epoch: {}, Acc: {:.3f}={}/{}, Val Loss: {:.3f}".format( 310 | epoch, acc, current_num, len(correct_answer), val_loss)) 311 | 312 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 313 | logfile.write( 314 | "VAL|Epoch: %d, Acc: %.3f=%d/%d, Val Loss: %.3f\n" 315 | % (epoch, acc, current_num, len(correct_answer), val_loss) 316 | ) 317 | logfile.close() 318 | 319 | if args.task == 'Count': 320 | print( 321 | 'VAL|Count Real Loss:\t {:.3f}'.format( 322 | F.mse_loss(predict_answer.float(), correct_answer.float()))) 323 | acc = -F.mse_loss(predict_answer.float(), correct_answer.float()) 324 | 325 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 326 | logfile.write( 327 | "VAL|Count Real Loss:\t %.3f\n" 328 | % (F.mse_loss(predict_answer.float(), correct_answer.float())) 329 | ) 330 | logfile.close() 331 | return acc, val_loss 332 | 333 | 334 | @torch.no_grad() 335 | def test(args, model, test_dataloader, epoch, criterion): 336 | 337 | model.eval() 338 | 339 | loss_list = [] 340 | prediction_list = [] 341 | correct_answer_list = [] 342 | 343 | for ii, data in enumerate(test_dataloader): 344 | data = [d.to(device) for d in data] 345 | 346 | out, predictions, answers = model(args.task, *data) 347 | loss = criterion(out, answers) 348 | 349 | correct_answer_list.append(answers) 350 | loss_list.append(loss.item()) 351 | prediction_list.append(predictions.detach()) 352 | 353 | test_loss = np.mean(loss_list) 354 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 355 | predict_answer = torch.cat(prediction_list, dim=0).long() 356 | assert correct_answer.shape == predict_answer.shape 357 | 358 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
359 | 360 | acc = current_num / len(correct_answer) * 100. 361 | 362 | print( 363 | "Test|Epoch: {}, Acc: {:.3f}={}/{}, Test Loss: {:.3f}".format( 364 | epoch, acc, current_num, len(correct_answer), test_loss)) 365 | 366 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 367 | logfile.write( 368 | "Test|Epoch: %d, Acc: %.3f=%d/%d, Test Loss: %.3f\n" 369 | % (epoch, acc, current_num, len(correct_answer), test_loss) 370 | ) 371 | logfile.close() 372 | 373 | if args.save: 374 | if (args.task == 'Action' and 375 | acc >= 80) or (args.task == 'Trans' and 376 | acc >= 80) or (args.task == 'FrameQA' and 377 | acc >= 55): 378 | torch.save( 379 | model, os.path.join(args.save_model_path, 380 | args.task + '_' + str(acc.item())[:5] + '.pth')) 381 | print('Save model at ', args.save_model_path) 382 | 383 | if args.task == 'Count': 384 | count_loss = F.mse_loss(predict_answer.float(), correct_answer.float()) 385 | print('Test|Count Real Loss:\t {:.3f}'.format(count_loss)) 386 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 387 | logfile.write( 388 | 'Test|Count Real Loss:\t %.3f\n' % (count_loss) 389 | ) 390 | logfile.close() 391 | if args.save and count_loss <= 4.0: 392 | torch.save( 393 | model, os.path.join(args.save_model_path, 394 | args.task + '_' + str(count_loss.item())[:5] + '.pth')) 395 | print('Save model at ', args.save_model_path) 396 | 397 | 398 | if __name__ == '__main__': 399 | parser = argparse.ArgumentParser() 400 | 401 | parser.add_argument('--model_name', type=str, default='[MASN]') 402 | 403 | ################ path config ################ 404 | parser.add_argument('--feat_dir', default='/data/TGIFQA', help='path for resnet and i3d features') 405 | parser.add_argument('--vc_dir', default='/data/TGIFQA/vocab', help='path for vocabulary') 406 | parser.add_argument('--df_dir', default='/data/TGIFQA/question', help='path for tgif question csv files') 407 | 408 | ################ inference config ################ 409 |
parser.add_argument( 410 | '--checkpoint', 411 | type=str, 412 | default='FrameQA_59.73.pth', 413 | help='checkpoint file name, loaded from <save_path>/<model_name>') 414 | parser.add_argument( 415 | '--save_path', 416 | type=str, 417 | default='./saved_models/', 418 | help='path for saving trained models') 419 | 420 | parser.add_argument( 421 | '--save', action='store_true', default=True, help='save models or not') 422 | parser.add_argument( 423 | '--hidden_size', 424 | type=int, 425 | default=512, 426 | help='dimension of model') 427 | parser.add_argument( 428 | '--test', action='store_true', default=False, help='evaluate --checkpoint on the test set instead of training') 429 | parser.add_argument('--max_epoch', type=int, default=100) 430 | parser.add_argument('--val_ratio', type=float, default=0.1) 431 | parser.add_argument('--q_max_length', type=int, default=20) 432 | parser.add_argument('--v_max_length', type=int, default=20) 433 | 434 | parser.add_argument( 435 | '--task', 436 | type=str, 437 | default='Count', 438 | help='[Count, Action, FrameQA, Trans]') 439 | parser.add_argument( 440 | '--rnn_layers', type=int, default=1, help='number of layers in lstm') 441 | parser.add_argument( 442 | '--gcn_layers', 443 | type=int, 444 | default=2, 445 | help='number of layers in gcn (+1)') 446 | parser.add_argument('--batch_size', type=int, default=32) 447 | parser.add_argument('--max_n_videos', type=int, default=100000) 448 | parser.add_argument('--num_workers', type=int, default=1) 449 | parser.add_argument('--lr', type=float, default=0.0001) 450 | parser.add_argument('--lr_list', type=int, nargs='+', default=[10, 20, 30, 40], help='MultiStepLR milestones (epochs), used with --change_lr step') 451 | parser.add_argument('--dropout', type=float, default=0.3) 452 | parser.add_argument( 453 | '--change_lr', type=str, default='none', choices=['none', 'acc', 'loss', 'cos', 'step'], help='learning-rate schedule') 454 | parser.add_argument( 455 | '--weight_decay', type=float, default=0, help='Adam weight decay') 456 | parser.add_argument('--ablation', type=str, default='none') 457 | 458 | args = parser.parse_args() 459 | print(args) 460 | 461 | main()
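`main.py` selects `MultipleChoiceLoss` from `embed_loss.py` for the Action and Trans tasks. As a hedged aid for unit-testing that module, here is a minimal pure-Python reference of the same pairwise hinge (every wrong option should score at least `margin` below the correct one, averaged over all wrong-option pairs when `size_average` is set). It is a sketch written for this note, not code from the repository; the name `multiple_choice_hinge` is hypothetical, and it assumes `scores` is an N x C nested list and `targets` holds each sample's correct option index.

```python
from typing import Sequence


def multiple_choice_hinge(scores: Sequence[Sequence[float]],
                          targets: Sequence[int],
                          margin: float = 1.0,
                          size_average: bool = True) -> float:
    """Hypothetical reference hinge loss over (correct, wrong) option pairs."""
    total = 0.0
    count = 0
    for row, correct in zip(scores, targets):
        for c, s in enumerate(row):
            if c == correct:
                continue  # the correct option is not compared with itself
            # penalize a wrong option that is not at least `margin` below the correct one
            total += max(0.0, margin + s - row[correct])
            count += 1
    if count == 0:
        return 0.0  # no pairs (e.g. empty batch): zero loss
    return total / count if size_average else total
```

For example, scores `[[2.0, 0.5, 1.5]]` with target `[0]` give pair losses 0 and 0.5, so the averaged loss is 0.25 (0.5 with `size_average=False`). With `margin=1.0`, the value `main.py` passes, this should agree with `MultipleChoiceLoss` on small CPU inputs.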
-------------------------------------------------------------------------------- /main_msrvtt.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | import os 7 | import argparse 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import random 13 | import h5py 14 | 15 | seed = 999 16 | 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.enabled = False 22 | # torch.backends.cudnn.benchmark = True 23 | # torch.backends.cudnn.deterministic = True 24 | 25 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 26 | 27 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) 28 | 29 | 30 | def _init_fn(worker_id): 31 | np.random.seed(seed) 32 | 33 | 34 | from data_utils.dataset_msrvtt import MSRVTTQA 35 | from torch.utils.data import DataLoader 36 | from warmup_scheduler import GradualWarmupScheduler 37 | 38 | from model.masn import MASN 39 | 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | 42 | 43 | def main(): 44 | """Main script.""" 45 | 46 | args.pin_memory = False 47 | args.dataset = 'msrvtt_qa' 48 | args.log = './logs/%s' % args.model_name 49 | if not os.path.exists(args.log): 50 | os.mkdir(args.log) 51 | 52 | args.val_epoch_step = 1 53 | 54 | args.save_model_path = os.path.join(args.save_path, args.model_name) 55 | if not os.path.exists(args.save_model_path): 56 | os.makedirs(args.save_model_path) 57 | 58 | train_dataset = MSRVTTQA( 59 | dataset_name='train', 60 | q_max_length=args.q_max_length, 61 | v_max_length=args.v_max_length, 62 | max_n_videos=args.max_n_videos, 63 | csv_dir=args.df_dir, 64 | vocab_dir=args.vc_dir, 65 | feat_dir=args.feat_dir) 66 | 
val_dataset = MSRVTTQA( 67 | dataset_name='val', 68 | q_max_length=args.q_max_length, 69 | v_max_length=args.v_max_length, 70 | max_n_videos=args.max_n_videos, 71 | csv_dir=args.df_dir, 72 | vocab_dir=args.vc_dir, 73 | feat_dir=args.feat_dir) 74 | test_dataset = MSRVTTQA( 75 | dataset_name='test', 76 | q_max_length=args.q_max_length, 77 | v_max_length=args.v_max_length, 78 | max_n_videos=args.max_n_videos, 79 | csv_dir=args.df_dir, 80 | vocab_dir=args.vc_dir, 81 | feat_dir=args.feat_dir) 82 | 83 | print( 84 | 'Dataset lengths train/val/test %d/%d/%d' % 85 | (len(train_dataset), len(val_dataset), len(test_dataset))) 86 | 87 | train_dataloader = DataLoader( 88 | train_dataset, 89 | args.batch_size, 90 | shuffle=True, 91 | num_workers=args.num_workers, 92 | pin_memory=args.pin_memory, 93 | worker_init_fn=_init_fn) 94 | val_dataloader = DataLoader( 95 | val_dataset, 96 | args.batch_size, 97 | shuffle=False, 98 | num_workers=args.num_workers, 99 | pin_memory=args.pin_memory, 100 | worker_init_fn=_init_fn) 101 | test_dataloader = DataLoader( 102 | test_dataset, 103 | args.batch_size, 104 | shuffle=False, 105 | num_workers=args.num_workers, 106 | pin_memory=args.pin_memory, 107 | worker_init_fn=_init_fn) 108 | 109 | print('Load data successful.') 110 | 111 | args.resnet_input_size = 2048 112 | args.i3d_input_size = 2048 # train() passes args.i3d_input_size to MASN 113 | 114 | args.text_embed_size = train_dataset.GLOVE_EMBEDDING_SIZE 115 | args.answer_vocab_size = None 116 | 117 | args.word_matrix = train_dataset.word_matrix 118 | args.voc_len = args.word_matrix.shape[0] 119 | assert args.text_embed_size == args.word_matrix.shape[1] 120 | 121 | VOCABULARY_SIZE = train_dataset.n_words 122 | assert VOCABULARY_SIZE == args.voc_len 123 | 124 | # add classification loss 125 | args.answer_vocab_size = len(train_dataset.ans2idx) 126 | print(('Vocabulary size', args.answer_vocab_size, VOCABULARY_SIZE)) 127 | criterion = nn.CrossEntropyLoss().to(device) 128 | 129 | if not args.test: 130 | train( 131 | args, train_dataloader,
val_dataloader, test_dataloader, criterion) 132 | else: 133 | model = torch.load(os.path.join(args.save_model_path, args.checkpoint)) 134 | test(args, model, test_dataloader, 0, criterion) 135 | 136 | 137 | def train(args, train_dataloader, val_dataloader, test_dataloader, criterion): 138 | model = MASN( 139 | args.voc_len, 140 | args.rnn_layers, 141 | args.word_matrix, 142 | args.resnet_input_size, 143 | args.i3d_input_size, 144 | args.hidden_size, 145 | dropout_p=args.dropout, 146 | gcn_layers=args.gcn_layers, 147 | answer_vocab_size=args.answer_vocab_size, 148 | q_max_len=args.q_max_length, 149 | v_max_len=args.v_max_length, 150 | ablation=args.ablation) 151 | 152 | if torch.cuda.device_count() > 1: 153 | print("Let's use", torch.cuda.device_count(), "GPUs!") 154 | model = nn.DataParallel(model) 155 | 156 | model.to(device) 157 | 158 | if args.change_lr == 'none': 159 | optimizer = torch.optim.Adam( 160 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 161 | elif args.change_lr == 'acc': 162 | optimizer = torch.optim.Adam( 163 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 164 | # val plateau scheduler 165 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 166 | optimizer, mode='max', factor=0.1, patience=3, verbose=True) 167 | # target lr = args.lr * multiplier 168 | scheduler_warmup = GradualWarmupScheduler( 169 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 170 | elif args.change_lr == 'loss': 171 | optimizer = torch.optim.Adam( 172 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 173 | # val plateau scheduler 174 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 175 | optimizer, mode='min', factor=0.1, patience=3, verbose=True) 176 | # target lr = args.lr * multiplier 177 | scheduler_warmup = GradualWarmupScheduler( 178 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 179 | elif args.change_lr == 'cos': 180 | optimizer = torch.optim.Adam( 
181 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 182 | # consine annealing 183 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 184 | optimizer, args.max_epoch) 185 | # target lr = args.lr * multiplier 186 | scheduler_warmup = GradualWarmupScheduler( 187 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 188 | elif args.change_lr == 'step': 189 | optimizer = torch.optim.Adam( 190 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 191 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 192 | optimizer, milestones=args.lr_list, gamma=0.1) 193 | # scheduler_warmup = GradualWarmupScheduler( 194 | # optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 195 | 196 | best_val_acc = 0. 197 | 198 | for epoch in range(args.max_epoch): 199 | print('Start Training Epoch: {}'.format(epoch)) 200 | 201 | model.train() 202 | 203 | loss_list = [] 204 | prediction_list = [] 205 | correct_answer_list = [] 206 | 207 | if args.change_lr == 'cos': 208 | # consine annealing 209 | scheduler_warmup.step(epoch=epoch) 210 | 211 | for ii, data in enumerate(train_dataloader): 212 | if epoch == 0 and ii == 0: 213 | print([d.dtype for d in data], [d.size() for d in data]) 214 | data = [d.to(device) for d in data] 215 | 216 | optimizer.zero_grad() 217 | out, predictions, answers = model(args.task, *data) 218 | loss = criterion(out, answers) 219 | loss.backward() 220 | optimizer.step() 221 | 222 | correct_answer_list.append(answers) 223 | loss_list.append(loss.item()) 224 | prediction_list.append(predictions.detach()) 225 | if ii % 100 == 0: 226 | print("Batch: ", ii) 227 | 228 | train_loss = np.mean(loss_list) 229 | 230 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 231 | predict_answer = torch.cat(prediction_list, dim=0).long() 232 | assert correct_answer.shape == predict_answer.shape 233 | 234 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 235 | acc = current_num / 
len(correct_answer) * 100. 236 | 237 | print( 238 | "Train|Epoch: {}, Acc : {:.3f}={}/{}, Train Loss: {:.3f}".format( 239 | epoch, acc, current_num, len(correct_answer), train_loss)) 240 | 241 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 242 | logfile.write( 243 | "Train|Epoch: %d, Acc : %.3f=%d/%d, Train Loss: %.3f\n" 244 | % (epoch, acc, current_num, len(correct_answer), train_loss) 245 | ) 246 | logfile.close() 247 | 248 | val_acc, val_loss = val(args, model, val_dataloader, epoch, criterion) 249 | 250 | # step the LR schedulers only after validation, so val_acc / val_loss are defined for the current epoch 251 | if args.change_lr == 'acc': 252 | scheduler_warmup.step(epoch, val_acc) 253 | elif args.change_lr == 'loss': 254 | scheduler_warmup.step(epoch, val_loss) 255 | elif args.change_lr == 'step': 256 | scheduler.step() 257 | 258 | if val_acc > best_val_acc: 259 | print('Best Val Acc ======') 260 | best_val_acc = val_acc 261 | if epoch % args.val_epoch_step == 0 or val_acc >= best_val_acc: 262 | test(args, model, test_dataloader, epoch, criterion) 263 | 264 | 265 | @torch.no_grad() 266 | def val(args, model, val_dataloader, epoch, criterion): 267 | model.eval() 268 | 269 | loss_list = [] 270 | prediction_list = [] 271 | correct_answer_list = [] 272 | 273 | for ii, data in enumerate(val_dataloader): 274 | data = [d.to(device) for d in data] 275 | 276 | out, predictions, answers = model(args.task, *data) 277 | loss = criterion(out, answers) 278 | 279 | correct_answer_list.append(answers) 280 | loss_list.append(loss.item()) 281 | prediction_list.append(predictions.detach()) 282 | 283 | val_loss = np.mean(loss_list) 284 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 285 | predict_answer = torch.cat(prediction_list, dim=0).long() 286 | assert correct_answer.shape == predict_answer.shape 287 | 288 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 289 | 290 | acc = current_num / len(correct_answer) * 100.
291 | 292 | print( 293 | "VAL|Epoch: {}, Acc: {:.3f}={}/{}, Val Loss: {:.3f}".format( 294 | epoch, acc, current_num, len(correct_answer), val_loss)) 295 | 296 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 297 | logfile.write( 298 | "VAL|Epoch: %d, Acc: %.3f=%d/%d, Val Loss: %.3f\n" 299 | % (epoch, acc, current_num, len(correct_answer), val_loss) 300 | ) 301 | logfile.close() 302 | 303 | return acc, val_loss 304 | 305 | 306 | @torch.no_grad() 307 | def test(args, model, test_dataloader, epoch, criterion): 308 | 309 | model.eval() 310 | 311 | loss_list = [] 312 | prediction_list = [] 313 | correct_answer_list = [] 314 | 315 | for ii, data in enumerate(test_dataloader): 316 | data = [d.to(device) for d in data] 317 | 318 | out, predictions, answers = model(args.task, *data) 319 | loss = criterion(out, answers) 320 | 321 | correct_answer_list.append(answers) 322 | loss_list.append(loss.item()) 323 | prediction_list.append(predictions.detach()) 324 | 325 | test_loss = np.mean(loss_list) 326 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 327 | predict_answer = torch.cat(prediction_list, dim=0).long() 328 | assert correct_answer.shape == predict_answer.shape 329 | 330 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 331 | 332 | acc = current_num / len(correct_answer) * 100.
333 | 334 | print( 335 | "Test|Epoch: {}, Acc: {:.3f}={}/{}, Test Loss: {:.3f}".format( 336 | epoch, acc, current_num, len(correct_answer), test_loss)) 337 | 338 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 339 | logfile.write( 340 | "Test|Epoch: %d, Acc: %.3f=%d/%d, Test Loss: %.3f\n" 341 | % (epoch, acc, current_num, len(correct_answer), test_loss) 342 | ) 343 | logfile.close() 344 | 345 | if args.save: 346 | if acc >= 30: 347 | torch.save( 348 | model, os.path.join(args.save_model_path, 349 | args.task + '_' + str(acc.item())[:5] + '.pth')) 350 | print('Save model at ', args.save_model_path) 351 | 352 | if __name__ == '__main__': 353 | parser = argparse.ArgumentParser() 354 | 355 | parser.add_argument('--model_name', type=str, default='[MSVD_msrvtt]') 356 | 357 | ################ path config ################ 358 | parser.add_argument('--feat_dir', default='/data/MSRVTT-QA/video', help='path for ResNet and I3D features') 359 | parser.add_argument('--vc_dir', default='/data/MSRVTT-QA/vocab4000', help='path for vocabulary') 360 | parser.add_argument('--df_dir', default='/data/MSRVTT-QA/question', help='path for MSRVTT-QA question csv files') 361 | 362 | ################ inference config ################ 363 | parser.add_argument( 364 | '--checkpoint', 365 | type=str, 366 | default='', 367 | help='checkpoint file name inside save_path/model_name (used with --test)') 368 | parser.add_argument( 369 | '--save_path', 370 | type=str, 371 | default='./saved_models/', 372 | help='path for saving trained models') 373 | 374 | parser.add_argument( 375 | '--save', action='store_true', default=True, help='save models or not') 376 | parser.add_argument( 377 | '--hidden_size', 378 | type=int, 379 | default=512, 380 | help='dimension of lstm hidden states') 381 | parser.add_argument( 382 | '--test', action='store_true', default=False, help='evaluate --checkpoint on the test split instead of training') 383 | parser.add_argument('--max_epoch', type=int, default=100) 384 | 
parser.add_argument('--val_ratio', type=float, default=0.1) 385 |
parser.add_argument('--q_max_length', type=int, default=20) 386 | parser.add_argument('--v_max_length', type=int, default=30) 387 | 388 | parser.add_argument( 389 | '--task', 390 | type=str, 391 | default='MS-QA' 392 | ) 393 | parser.add_argument( 394 | '--rnn_layers', type=int, default=1, help='number of layers in lstm') 395 | parser.add_argument( 396 | '--gcn_layers', 397 | type=int, 398 | default=2, 399 | help='number of layers in gcn (+1)') 400 | parser.add_argument('--batch_size', type=int, default=32) 401 | parser.add_argument('--max_n_videos', type=int, default=100000) 402 | parser.add_argument('--num_workers', type=int, default=1) 403 | parser.add_argument('--lr', type=float, default=0.0001) 404 | parser.add_argument('--lr_list', type=int, nargs='+', default=[10, 20, 30, 40], help='MultiStepLR milestones (epochs) used with --change_lr step') 405 | parser.add_argument('--dropout', type=float, default=0.3) 406 | parser.add_argument( 407 | '--change_lr', type=str, default='none', choices=['none', 'acc', 'loss', 'cos', 'step'], help='learning-rate schedule') 408 | parser.add_argument( 409 | '--weight_decay', type=float, default=0, help='Adam weight decay (L2 penalty)') 410 | parser.add_argument('--ablation', type=str, default='none') 411 | 412 | args = parser.parse_args() 413 | print(args) 414 | 415 | main() -------------------------------------------------------------------------------- /main_msvd.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | import os 7 | import argparse 8 | import torch 9 | import numpy as np 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | import random 13 | import h5py 14 | 15 | seed = 999 16 | 17 | random.seed(seed) 18 | np.random.seed(seed) 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.enabled = False 22 | # torch.backends.cudnn.benchmark = True 23 | # torch.backends.cudnn.deterministic = True 24 | 25 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 26 | 27 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6) 28 | 29 | 30 | def _init_fn(worker_id): 31 | np.random.seed(seed) 32 | 33 | 34 | from data_utils.dataset_msvd import MSVDQA 35 | from torch.utils.data import DataLoader 36 | from warmup_scheduler import GradualWarmupScheduler 37 | 38 | from model.masn import MASN 39 | 40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 41 | 42 | 43 | def main(): 44 | """Main script.""" 45 | 46 | args.pin_memory = False 47 | args.dataset = 'msvd_qa' 48 | args.log = './logs/%s' % args.model_name 49 | if not os.path.exists(args.log): 50 | os.makedirs(args.log) # also creates ./logs if it does not exist yet 51 | 52 | args.val_epoch_step = 1 53 | 54 | args.save_model_path = os.path.join(args.save_path, args.model_name) 55 | if not os.path.exists(args.save_model_path): 56 | os.makedirs(args.save_model_path) 57 | 58 | train_dataset = MSVDQA( 59 | dataset_name='train', 60 | q_max_length=args.q_max_length, 61 | v_max_length=args.v_max_length, 62 | max_n_videos=args.max_n_videos, 63 | csv_dir=args.df_dir, 64 | vocab_dir=args.vc_dir, 65 | feat_dir=args.feat_dir) 66 | val_dataset = MSVDQA( 67 | dataset_name='val', 68 | q_max_length=args.q_max_length, 69 | v_max_length=args.v_max_length, 70 | max_n_videos=args.max_n_videos, 71 | csv_dir=args.df_dir, 72 | vocab_dir=args.vc_dir, 73 | feat_dir=args.feat_dir) 74 | test_dataset = MSVDQA( 75 | 
dataset_name='test', 76 |
q_max_length=args.q_max_length, 77 | v_max_length=args.v_max_length, 78 | max_n_videos=args.max_n_videos, 79 | csv_dir=args.df_dir, 80 | vocab_dir=args.vc_dir, 81 | feat_dir=args.feat_dir) 82 | 83 | print( 84 | 'Dataset lengths train/val/test %d/%d/%d' % 85 | (len(train_dataset), len(val_dataset), len(test_dataset))) 86 | 87 | train_dataloader = DataLoader( 88 | train_dataset, 89 | args.batch_size, 90 | shuffle=True, 91 | num_workers=args.num_workers, 92 | pin_memory=args.pin_memory, 93 | worker_init_fn=_init_fn) 94 | val_dataloader = DataLoader( 95 | val_dataset, 96 | args.batch_size, 97 | shuffle=False, 98 | num_workers=args.num_workers, 99 | pin_memory=args.pin_memory, 100 | worker_init_fn=_init_fn) 101 | test_dataloader = DataLoader( 102 | test_dataset, 103 | args.batch_size, 104 | shuffle=False, 105 | num_workers=args.num_workers, 106 | pin_memory=args.pin_memory, 107 | worker_init_fn=_init_fn) 108 | 109 | print('Load data successful.') 110 | 111 | args.resnet_input_size = 2048 112 | args.i3d_input_size = 2048 # I3D motion-feature size; train() passes args.i3d_input_size to MASN 113 | 114 | args.text_embed_size = train_dataset.GLOVE_EMBEDDING_SIZE 115 | args.answer_vocab_size = None 116 | 117 | args.word_matrix = train_dataset.word_matrix 118 | args.voc_len = args.word_matrix.shape[0] 119 | assert args.text_embed_size == args.word_matrix.shape[1] 120 | 121 | VOCABULARY_SIZE = train_dataset.n_words 122 | assert VOCABULARY_SIZE == args.voc_len 123 | 124 | # add classification loss 125 | args.answer_vocab_size = len(train_dataset.ans2idx) 126 | print(('Vocabulary size', args.answer_vocab_size, VOCABULARY_SIZE)) 127 | criterion = nn.CrossEntropyLoss().to(device) 128 | 129 | if not args.test: 130 | train( 131 | args, train_dataloader, val_dataloader, test_dataloader, criterion) 132 | else: 133 | model = torch.load(os.path.join(args.save_model_path, args.checkpoint)) 134 | test(args, model, test_dataloader, 0, criterion) 135 | 136 | 137 | def train(args, train_dataloader, val_dataloader, 
test_dataloader, criterion): 138 | model =
MASN( 139 | args.voc_len, 140 | args.rnn_layers, 141 | args.word_matrix, 142 | args.resnet_input_size, 143 | args.i3d_input_size, 144 | args.hidden_size, 145 | dropout_p=args.dropout, 146 | gcn_layers=args.gcn_layers, 147 | answer_vocab_size=args.answer_vocab_size, 148 | q_max_len=args.q_max_length, 149 | v_max_len=args.v_max_length, 150 | ablation=args.ablation) 151 | 152 | if torch.cuda.device_count() > 1: 153 | print("Let's use", torch.cuda.device_count(), "GPUs!") 154 | model = nn.DataParallel(model) 155 | 156 | model.to(device) 157 | 158 | if args.change_lr == 'none': 159 | optimizer = torch.optim.Adam( 160 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 161 | elif args.change_lr == 'acc': 162 | optimizer = torch.optim.Adam( 163 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 164 | # val plateau scheduler 165 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 166 | optimizer, mode='max', factor=0.1, patience=3, verbose=True) 167 | # target lr = args.lr * multiplier 168 | scheduler_warmup = GradualWarmupScheduler( 169 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 170 | elif args.change_lr == 'loss': 171 | optimizer = torch.optim.Adam( 172 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 173 | # val plateau scheduler 174 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( 175 | optimizer, mode='min', factor=0.1, patience=3, verbose=True) 176 | # target lr = args.lr * multiplier 177 | scheduler_warmup = GradualWarmupScheduler( 178 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 179 | elif args.change_lr == 'cos': 180 | optimizer = torch.optim.Adam( 181 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay) 182 | # consine annealing 183 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 184 | optimizer, args.max_epoch) 185 | # target lr = args.lr * multiplier 186 | scheduler_warmup = GradualWarmupScheduler( 187 | optimizer, 
multiplier=5, total_epoch=5, after_scheduler=scheduler) 188 | elif args.change_lr == 'step': 189 | optimizer = torch.optim.Adam( 190 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay) 191 | scheduler = torch.optim.lr_scheduler.MultiStepLR( 192 | optimizer, milestones=args.lr_list, gamma=0.1) 193 | # scheduler_warmup = GradualWarmupScheduler( 194 | # optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler) 195 | 196 | best_val_acc = 0. 197 | 198 | for epoch in range(args.max_epoch): 199 | print('Start Training Epoch: {}'.format(epoch)) 200 | 201 | model.train() 202 | 203 | loss_list = [] 204 | prediction_list = [] 205 | correct_answer_list = [] 206 | 207 | if args.change_lr == 'cos': 208 | # consine annealing 209 | scheduler_warmup.step(epoch=epoch) 210 | 211 | for ii, data in enumerate(train_dataloader): 212 | if epoch == 0 and ii == 0: 213 | print([d.dtype for d in data], [d.size() for d in data]) 214 | data = [d.to(device) for d in data] 215 | 216 | optimizer.zero_grad() 217 | out, predictions, answers = model(args.task, *data) 218 | loss = criterion(out, answers) 219 | loss.backward() 220 | optimizer.step() 221 | 222 | correct_answer_list.append(answers) 223 | loss_list.append(loss.item()) 224 | prediction_list.append(predictions.detach()) 225 | if ii % 100 == 0: 226 | print("Batch: ", ii) 227 | 228 | train_loss = np.mean(loss_list) 229 | 230 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 231 | predict_answer = torch.cat(prediction_list, dim=0).long() 232 | assert correct_answer.shape == predict_answer.shape 233 | 234 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 235 | acc = current_num / len(correct_answer) * 100. 
236 | 237 | print( 238 | "Train|Epoch: {}, Acc : {:.3f}={}/{}, Train Loss: {:.3f}".format( 239 | epoch, acc, current_num, len(correct_answer), train_loss)) 240 | 241 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 242 | logfile.write( 243 | "Train|Epoch: %d, Acc : %.3f=%d/%d, Train Loss: %.3f\n" 244 | % (epoch, acc, current_num, len(correct_answer), train_loss) 245 | ) 246 | logfile.close() 247 | 248 | val_acc, val_loss = val(args, model, val_dataloader, epoch, criterion) 249 | 250 | # step the LR schedulers only after validation, so val_acc / val_loss are defined for the current epoch 251 | if args.change_lr == 'acc': 252 | scheduler_warmup.step(epoch, val_acc) 253 | elif args.change_lr == 'loss': 254 | scheduler_warmup.step(epoch, val_loss) 255 | elif args.change_lr == 'step': 256 | scheduler.step() 257 | 258 | if val_acc > best_val_acc: 259 | print('Best Val Acc ======') 260 | best_val_acc = val_acc 261 | if epoch % args.val_epoch_step == 0 or val_acc >= best_val_acc: 262 | test(args, model, test_dataloader, epoch, criterion) 263 | 264 | 265 | @torch.no_grad() 266 | def val(args, model, val_dataloader, epoch, criterion): 267 | model.eval() 268 | 269 | loss_list = [] 270 | prediction_list = [] 271 | correct_answer_list = [] 272 | 273 | for ii, data in enumerate(val_dataloader): 274 | data = [d.to(device) for d in data] 275 | 276 | out, predictions, answers = model(args.task, *data) 277 | loss = criterion(out, answers) 278 | 279 | correct_answer_list.append(answers) 280 | loss_list.append(loss.item()) 281 | prediction_list.append(predictions.detach()) 282 | 283 | val_loss = np.mean(loss_list) 284 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 285 | predict_answer = torch.cat(prediction_list, dim=0).long() 286 | assert correct_answer.shape == predict_answer.shape 287 | 288 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 289 | 290 | acc = current_num / len(correct_answer) * 100.
291 | 292 | print( 293 | "VAL|Epoch: {}, Acc: {:.3f}={}/{}, Val Loss: {:.3f}".format( 294 | epoch, acc, current_num, len(correct_answer), val_loss)) 295 | 296 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 297 | logfile.write( 298 | "VAL|Epoch: %d, Acc: %.3f=%d/%d, Val Loss: %.3f\n" 299 | % (epoch, acc, current_num, len(correct_answer), val_loss) 300 | ) 301 | logfile.close() 302 | 303 | return acc, val_loss 304 | 305 | 306 | @torch.no_grad() 307 | def test(args, model, test_dataloader, epoch, criterion): 308 | 309 | model.eval() 310 | 311 | loss_list = [] 312 | prediction_list = [] 313 | correct_answer_list = [] 314 | 315 | for ii, data in enumerate(test_dataloader): 316 | data = [d.to(device) for d in data] 317 | 318 | out, predictions, answers = model(args.task, *data) 319 | loss = criterion(out, answers) 320 | 321 | correct_answer_list.append(answers) 322 | loss_list.append(loss.item()) 323 | prediction_list.append(predictions.detach()) 324 | 325 | test_loss = np.mean(loss_list) 326 | correct_answer = torch.cat(correct_answer_list, dim=0).long() 327 | predict_answer = torch.cat(prediction_list, dim=0).long() 328 | assert correct_answer.shape == predict_answer.shape 329 | 330 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy() 331 | 332 | acc = current_num / len(correct_answer) * 100.
333 | 334 | print( 335 | "Test|Epoch: {}, Acc: {:.3f}={}/{}, Test Loss: {:.3f}".format( 336 | epoch, acc, current_num, len(correct_answer), test_loss)) 337 | 338 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+') 339 | logfile.write( 340 | "Test|Epoch: %d, Acc: %.3f=%d/%d, Test Loss: %.3f\n" 341 | % (epoch, acc, current_num, len(correct_answer), test_loss) 342 | ) 343 | logfile.close() 344 | 345 | if args.save: 346 | if acc >= 30: 347 | torch.save( 348 | model, os.path.join(args.save_model_path, 349 | args.task + '_' + str(acc.item())[:5] + '.pth')) 350 | print('Save model at ', args.save_model_path) 351 | 352 | if __name__ == '__main__': 353 | parser = argparse.ArgumentParser() 354 | 355 | parser.add_argument('--model_name', type=str, default='[MSVD_msvd]') 356 | 357 | ################ path config ################ 358 | parser.add_argument('--feat_dir', default='/data/MSVD-QA', help='path for ResNet and I3D features') 359 | parser.add_argument('--vc_dir', default='/data/MSVD-QA/vocab', help='path for vocabulary') 360 | parser.add_argument('--df_dir', default='/data/MSVD-QA/question', help='path for MSVD-QA question csv files') 361 | 362 | ################ inference config ################ 363 | parser.add_argument( 364 | '--checkpoint', 365 | type=str, 366 | default='', 367 | help='checkpoint file name inside save_path/model_name (used with --test)') 368 | parser.add_argument( 369 | '--save_path', 370 | type=str, 371 | default='./saved_models/', 372 | help='path for saving trained models') 373 | 374 | parser.add_argument( 375 | '--save', action='store_true', default=True, help='save models or not') 376 | parser.add_argument( 377 | '--hidden_size', 378 | type=int, 379 | default=512, 380 | help='dimension of lstm hidden states') 381 | parser.add_argument( 382 | '--test', action='store_true', default=False, help='evaluate --checkpoint on the test split instead of training') 383 | parser.add_argument('--max_epoch', type=int, default=100) 384 | 
parser.add_argument('--val_ratio', type=float, default=0.1) 385 | parser.add_argument('--q_max_length',
type=int, default=20) 386 | parser.add_argument('--v_max_length', type=int, default=30) 387 | 388 | parser.add_argument( 389 | '--task', 390 | type=str, 391 | default='MS-QA' 392 | ) 393 | parser.add_argument( 394 | '--rnn_layers', type=int, default=1, help='number of layers in lstm') 395 | parser.add_argument( 396 | '--gcn_layers', 397 | type=int, 398 | default=2, 399 | help='number of layers in gcn (+1)') 400 | parser.add_argument('--batch_size', type=int, default=32) 401 | parser.add_argument('--max_n_videos', type=int, default=100000) 402 | parser.add_argument('--num_workers', type=int, default=1) 403 | parser.add_argument('--lr', type=float, default=0.0001) 404 | parser.add_argument('--lr_list', type=int, nargs='+', default=[10, 20, 30, 40], help='MultiStepLR milestones (epochs) used with --change_lr step') 405 | parser.add_argument('--dropout', type=float, default=0.3) 406 | parser.add_argument( 407 | '--change_lr', type=str, default='none', choices=['none', 'acc', 'loss', 'cos', 'step'], help='learning-rate schedule') 408 | parser.add_argument( 409 | '--weight_decay', type=float, default=0, help='Adam weight decay (L2 penalty)') 410 | parser.add_argument('--ablation', type=str, default='none') 411 | 412 | args = parser.parse_args() 413 | print(args) 414 | 415 | main() -------------------------------------------------------------------------------- /model/masn.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import functional as F 9 | import numpy as np 10 | from model.modules import linear_weightdrop as dropnn 11 | 12 | from torch.autograd import Variable 13 | 14 | from model.modules.rnn_encoder import SentenceEncoderRNN 15 | from model.modules.gcn import VideoAdjLearner, GCN 16 | from model.modules.position_embedding import PositionEncoding 17 | from model.modules.ban.ban import BAN 18 | from model.modules.fusion.fusion import MotionApprFusion, AttFlat 19 | 20 | # torch.set_printoptions(threshold=np.inf) 21 | 22 | 23 | class MASN(nn.Module): 24 | 25 | def __init__( 26 | self, 27 | vocab_size, 28 | s_layers, 29 | s_embedding, 30 | resnet_input_size, 31 | i3d_input_size, 32 | hidden_size, 33 | dropout_p=0.0, 34 | gcn_layers=2, 35 | answer_vocab_size=None, 36 | q_max_len=35, 37 | v_max_len=80, 38 | ablation='none'): 39 | super().__init__() 40 | 41 | self.ablation = ablation 42 | self.q_max_len = q_max_len 43 | self.v_max_len = v_max_len 44 | self.hidden_size = hidden_size 45 | 46 | self.compress_appr_local = dropnn.WeightDropLinear( 47 | resnet_input_size, 48 | hidden_size, 49 | weight_dropout=dropout_p, 50 | bias=False) 51 | self.compress_motion_local = dropnn.WeightDropLinear( 52 | i3d_input_size, 53 | hidden_size, 54 | weight_dropout=dropout_p, 55 | bias=False) 56 | self.compress_appr_global = dropnn.WeightDropLinear( 57 | resnet_input_size, 58 | hidden_size, 59 | weight_dropout=dropout_p, 60 | bias=False) 61 | self.compress_motion_global = dropnn.WeightDropLinear( 62 | i3d_input_size, 63 | hidden_size, 64 | weight_dropout=dropout_p, 65 | bias=False) 66 | 67 | embedding_dim = s_embedding.shape[1] if s_embedding is not None else hidden_size 68 | self.glove = nn.Embedding(vocab_size, embedding_dim, padding_idx=0) 69 | if s_embedding is not None: 70 | print("glove embedding weight is loaded!") 71 | 
self.glove.weight = nn.Parameter(torch.from_numpy(s_embedding).float()) 72 | self.glove.weight.requires_grad = False 73 | self.embedding_proj = nn.Sequential( 74 | nn.Dropout(p=dropout_p), 75 | nn.Linear(embedding_dim, hidden_size, bias=False) 76 | ) 77 | 78 | self.sentence_encoder = SentenceEncoderRNN( 79 | vocab_size, 80 | hidden_size, 81 | input_dropout_p=dropout_p, 82 | dropout_p=dropout_p, 83 | n_layers=s_layers, 84 | bidirectional=True, 85 | rnn_cell='lstm' 86 | ) 87 | 88 | self.bbox_location_encoding = nn.Linear(6, 64) 89 | self.pos_location_encoding = PositionEncoding(n_filters=64, max_len=self.v_max_len) 90 | 91 | self.appr_local_proj = nn.Linear(hidden_size+128, hidden_size) 92 | self.motion_local_proj = nn.Linear(hidden_size+128, hidden_size) 93 | 94 | self.pos_enc = PositionEncoding(n_filters=hidden_size, max_len=self.v_max_len) # added to hidden_size-dim global features, so its width must match 95 | self.appr_v = nn.Linear(hidden_size*2, hidden_size) 96 | self.motion_v = nn.Linear(hidden_size*2, hidden_size) 97 | 98 | self.appr_adj = VideoAdjLearner(hidden_size, hidden_size) 99 | self.appr_gcn = GCN(hidden_size, hidden_size, hidden_size, num_layers=gcn_layers) 100 | self.motion_adj = VideoAdjLearner(hidden_size, hidden_size) 101 | self.motion_gcn = GCN(hidden_size, hidden_size, hidden_size, num_layers=gcn_layers) 102 | 103 | self.res_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False) 104 | self.i3d_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False) 105 | 106 | self.appr_vq_interact = BAN(hidden_size, glimpse=4) 107 | self.motion_vq_interact = BAN(hidden_size, glimpse=4) 108 | 109 | self.motion_appr_fusion = MotionApprFusion(hidden_size, hidden_size, n_layer=1) 110 | self.attflat = AttFlat(hidden_size, hidden_size, 1, hidden_size) 111 | 112 | if answer_vocab_size is not None: 113 | self.fc = nn.Linear(hidden_size, answer_vocab_size) 114 | else: 115 | self.fc = nn.Linear(hidden_size, 1) 116 | 117 | def forward(self, task, *args): 118 | # expected 
sentence_inputs is of shape (batch_size, sentence_len, 1)
119 | # expected video_inputs is of shape (batch_size, frame_num, video_feature) 120 | self.task = task 121 | if task == 'Count': 122 | return self.forward_count(*args) 123 | elif task == 'FrameQA': 124 | return self.forward_frameqa(*args) 125 | elif task == 'Action' or task == 'Trans': 126 | return self.forward_trans_or_action(*args) 127 | elif task == 'MS-QA': 128 | return self.forward_msqa(*args) 129 | 130 | def model_block(self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, 131 | video_length, all_sen_inputs, all_ques_length): 132 | 133 | q_mask = self.make_mask(all_sen_inputs, all_ques_length) 134 | v_mask = self.make_mask(res_avg_inp[:,:,0], video_length) 135 | 136 | q_emb = F.relu(self.embedding_proj(self.glove(all_sen_inputs))) # b, q_len, d 137 | q_output, q_hidden = self.sentence_encoder(q_emb, input_lengths=all_ques_length) 138 | q_hidden = q_hidden.squeeze() 139 | 140 | bsz, v_len, obj_num = res_obj_inp.size(0), res_obj_inp.size(1), res_obj_inp.size(2) 141 | q_len = q_output.size(1) 142 | q_mask = q_mask[:,:q_len] 143 | 144 | # make local and global feature 145 | res_obj_inp = self.compress_appr_local(res_obj_inp) # b, v_len, N, d 146 | i3d_obj_inp = self.compress_motion_local(i3d_obj_inp) # b, v_len, N, d 147 | res_avg_inp = self.compress_appr_global(res_avg_inp) # b, v_len, d 148 | i3d_avg_inp = self.compress_motion_global(i3d_avg_inp) # b, v_len, d 149 | 150 | bbox_inp = self.bbox_location_encoding(bbox_inp) # b, v_len, N, d/8 151 | pos_inp = self.pos_location_encoding(res_obj_inp.contiguous().view(bsz*obj_num, v_len, -1)) # 1, v_len, 64 152 | pos_inp = pos_inp.unsqueeze(2).expand(bsz, v_len, obj_num, 64) * v_mask.unsqueeze(2).unsqueeze(3) # b, v_len, N, d/8 153 | 154 | appr_local = self.appr_local_proj(torch.cat([res_obj_inp, bbox_inp, pos_inp], dim=3)) # b, v_len, N, d 155 | motion_local = self.motion_local_proj(torch.cat([i3d_obj_inp, bbox_inp, pos_inp], dim=3)) # b, v_len, N, d 156 | 157 | v_len = appr_local.size(1) 158 | 
appr_local = appr_local.contiguous().view(bsz*v_len, obj_num, self.hidden_size) 159 | motion_local = motion_local.contiguous().view(bsz*v_len, obj_num, self.hidden_size) 160 | 161 | res_avg_inp = self.pos_enc(res_avg_inp) + res_avg_inp 162 | res_avg_inp = res_avg_inp.contiguous().view(bsz*v_len, self.hidden_size) 163 | res_avg_inp = res_avg_inp.unsqueeze(1).expand_as(appr_local) 164 | appr_v = self.appr_v(torch.cat([appr_local, res_avg_inp], dim=-1)) 165 | 166 | i3d_avg_inp = self.pos_enc(i3d_avg_inp) + i3d_avg_inp 167 | i3d_avg_inp = i3d_avg_inp.contiguous().view(bsz*v_len, self.hidden_size) 168 | i3d_avg_inp = i3d_avg_inp.unsqueeze(1).expand_as(motion_local) 169 | motion_v = self.motion_v(torch.cat([motion_local, i3d_avg_inp], dim=-1)) 170 | 171 | appr_v = appr_v.contiguous().view(bsz, v_len*obj_num, self.hidden_size) 172 | motion_v = motion_v.contiguous().view(bsz, v_len*obj_num, self.hidden_size) 173 | v_mask_expand = v_mask[:,:v_len].unsqueeze(2).expand(bsz, v_len, obj_num).contiguous().view(bsz, v_len*obj_num) 174 | 175 | # object graph convolution 176 | appr_adj = self.appr_adj(appr_v, v_mask_expand) 177 | appr_gcn = self.appr_gcn(appr_v, appr_adj) # b, v_len*obj_num, d 178 | motion_adj = self.motion_adj(motion_v, v_mask_expand) 179 | motion_gcn = self.motion_gcn(motion_v, motion_adj) # b, v_len*obj_num, d 180 | 181 | # vq interaction 182 | appr_vq, _ = self.appr_vq_interact(appr_gcn, q_output, v_mask_expand, q_mask) 183 | motion_vq, _ = self.motion_vq_interact(motion_gcn, q_output, v_mask_expand, q_mask) 184 | 185 | # motion-appr fusion 186 | U = torch.cat([appr_vq, motion_vq], dim=1) # b, 2*q_len, d 187 | q_mask_ = torch.cat([q_mask, q_mask], dim=1) 188 | U_mask = torch.matmul(q_mask_.unsqueeze(2), q_mask_.unsqueeze(2).transpose(1, 2)) 189 | 190 | fusion_out = self.motion_appr_fusion(U, q_hidden, U_mask) 191 | fusion_out = self.attflat(fusion_out, q_mask_) 192 | 193 | out = self.fc(fusion_out).squeeze() 194 | return out 195 | 196 | def make_mask(self, seq, 
seq_length): 197 | mask = seq 198 | mask = mask.data.new(*mask.size()).fill_(1) 199 | for i, l in enumerate(seq_length): 200 | mask[i][min(mask.size(1)-1, l):] = 0 201 | mask = Variable(mask) # b, seq_len 202 | mask = mask.to(torch.float) 203 | return mask 204 | 205 | def forward_count( 206 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 207 | all_sen_inputs, all_ques_length, answers): 208 | # out of shape (batch_size, ) 209 | out = self.model_block( 210 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 211 | all_sen_inputs, all_ques_length) 212 | predictions = torch.clamp(torch.round(out), min=1, max=10).long() 213 | # answers of shape (batch_size, ) 214 | return out, predictions, answers 215 | 216 | def forward_frameqa( 217 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 218 | all_sen_inputs, all_ques_length, answers, answer_type): 219 | # out of shape (batch_size, num_class) 220 | out = self.model_block( 221 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 222 | all_sen_inputs, all_ques_length) 223 | 224 | _, max_idx = torch.max(out, 1) 225 | # (batch_size, ), dtype is long 226 | predictions = max_idx 227 | # answers of shape (batch_size, ) 228 | return out, predictions, answers 229 | 230 | def forward_trans_or_action( 231 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 232 | all_cand_inputs, all_cand_length, answers, row_index): 233 | all_cand_inputs = all_cand_inputs.permute(1, 0, 2) 234 | all_cand_length = all_cand_length.permute(1, 0) 235 | 236 | all_out = [] 237 | for idx in range(5): 238 | # out of shape (batch_size, ) 239 | out = self.model_block( 240 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 241 | all_cand_inputs[idx], all_cand_length[idx]) 242 | all_out.append(out) 243 | # all_out of shape (batch_size, 5) 244 | all_out = torch.stack(all_out, 
0).transpose(1, 0) 245 | _, max_idx = torch.max(all_out, 1) 246 | # (batch_size, ) 247 | predictions = max_idx 248 | 249 | # answers of shape (batch_size, ) 250 | return all_out, predictions, answers 251 | 252 | def forward_msqa( 253 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 254 | all_sen_inputs, all_ques_length, answers): 255 | # out of shape (batch_size, num_class) 256 | out = self.model_block( 257 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length, 258 | all_sen_inputs, all_ques_length) 259 | 260 | _, max_idx = torch.max(out, 1) 261 | # (batch_size, ), dtype is long 262 | predictions = max_idx 263 | # answers of shape (batch_size, ) 264 | return out, predictions, answers 265 | -------------------------------------------------------------------------------- /model/modules/ban/ban.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from jnhwkim's repository. 
3 | # https://github.com/jnhwkim/ban-vqa 4 | # -------------------------------------------------------- 5 | 6 | from __future__ import print_function 7 | import math 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn.utils.weight_norm import weight_norm 11 | from model.modules.ban.fc import FCNet 12 | 13 | 14 | class BCNet(nn.Module): 15 | """Simple class for non-linear bilinear connect network 16 | """ 17 | def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=[.2,.5], k=3): 18 | super(BCNet, self).__init__() 19 | 20 | self.c = 32 21 | self.k = k 22 | self.v_dim = v_dim; self.q_dim = q_dim 23 | self.h_dim = h_dim; self.h_out = h_out 24 | 25 | self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout[0]) 26 | self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout[0]) 27 | self.dropout = nn.Dropout(dropout[1]) # attention 28 | if 1 < k: 29 | self.p_net = nn.AvgPool1d(self.k, stride=self.k) 30 | 31 | if None == h_out: 32 | pass 33 | elif h_out <= self.c: 34 | self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_()) 35 | self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_()) 36 | else: 37 | self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None) 38 | 39 | def forward(self, v, q): 40 | if None == self.h_out: 41 | v_ = self.v_net(v) 42 | q_ = self.q_net(q) 43 | logits = torch.einsum('bvk,bqk->bvqk', (v_, q_)) 44 | return logits 45 | 46 | # low-rank bilinear pooling using einsum 47 | elif self.h_out <= self.c: 48 | v_ = self.dropout(self.v_net(v)) 49 | q_ = self.q_net(q) 50 | logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias 51 | return logits # b x h_out x v x q 52 | 53 | # batch outer product, linear projection 54 | # memory efficient but slow computation 55 | else: 56 | v_ = self.dropout(self.v_net(v)).transpose(1,2).unsqueeze(3) 57 | q_ = self.q_net(q).transpose(1,2).unsqueeze(2) 58 | d_ = torch.matmul(v_, q_) # b x h_dim x v x q 59 | logits = 
self.h_net(d_.transpose(1,2).transpose(2,3)) # b x v x q x h_out 60 | return logits.transpose(2,3).transpose(1,2) # b x h_out x v x q 61 | 62 | def forward_with_weights(self, v, q, w): 63 | v_ = self.v_net(v) # b x v x d 64 | q_ = self.q_net(q) # b x q x d 65 | logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_)) 66 | if 1 < self.k: 67 | logits = logits.unsqueeze(1) # b x 1 x d 68 | logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling 69 | return logits 70 | 71 | 72 | class BiAttention(nn.Module): 73 | def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2,.5]): 74 | super(BiAttention, self).__init__() 75 | 76 | self.glimpse = glimpse 77 | self.logits = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3), \ 78 | name='h_mat', dim=None) 79 | 80 | def forward(self, v, q, v_mask, q_mask): 81 | """ 82 | v: [batch, v_num, vdim] 83 | q: [batch, q_num, qdim] 84 | """ 85 | p, logits = self.forward_all(v, q, v_mask, q_mask) 86 | return p, logits 87 | 88 | def forward_all(self, v, q, v_mask, q_mask, logit=False, mask_with=-float('inf')): 89 | v_num = v.size(1) 90 | q_num = q.size(1) 91 | logits = self.logits(v,q) # b x g x v x q 92 | 93 | if v_mask is not None: 94 | v_mask = v_mask.unsqueeze(1).unsqueeze(3).expand_as(logits) 95 | logits = logits - 1e10*(1 - v_mask) 96 | if q_mask is not None: 97 | q_mask = q_mask.unsqueeze(1).unsqueeze(2).expand_as(logits) 98 | logits = logits - 1e10*(1 - q_mask) 99 | 100 | if not logit: 101 | p = nn.functional.softmax(logits.view(-1, self.glimpse, v_num * q_num), 2) 102 | p = p.view(-1, self.glimpse, v_num, q_num) 103 | if v_mask is not None: 104 | p = p * v_mask 105 | if q_mask is not None: 106 | p = p * q_mask 107 | p = p.masked_fill(p != p, 0.)
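`BiAttention.forward_all` masks padded objects and padded question words by subtracting 1e10 from their logits, normalizes jointly over the flattened v × q grid of each glimpse, then multiplies by the masks and replaces NaNs with 0. The sketch below isolates that step as a standalone function, assuming masks hold 1.0 for valid positions and 0.0 for padding. `masked_bilinear_softmax` is a hypothetical name, not a function in this repository, and random values stand in for the BCNet logits.

```python
import torch
import torch.nn.functional as F


def masked_bilinear_softmax(logits, v_mask, q_mask):
    """Hypothetical helper mirroring the masking in BiAttention.forward_all.

    logits: (b, g, v, q) bilinear attention logits
    v_mask: (b, v), q_mask: (b, q); 1.0 = valid, 0.0 = padding (assumed)
    """
    b, g, v, q = logits.shape
    vm = v_mask[:, None, :, None].expand_as(logits)  # broadcast over glimpses and words
    qm = q_mask[:, None, None, :].expand_as(logits)  # broadcast over glimpses and objects
    # Push padded logits toward -inf so softmax assigns them ~0 weight.
    logits = logits - 1e10 * (1 - vm) - 1e10 * (1 - qm)
    # One softmax per glimpse over all (object, word) pairs.
    p = F.softmax(logits.view(b, g, v * q), dim=-1).view(b, g, v, q)
    p = p * vm * qm                       # zero padded entries exactly
    return p.masked_fill(p != p, 0.0)     # NaN guard, as in the original
```

Normalizing over all v × q pairs, rather than row by row, is what turns each glimpse into a single joint attention map over object-word pairs.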
108 | 109 | return p, logits 110 | 111 | return logits 112 | 113 | 114 | class BAN(nn.Module): 115 | def __init__(self, num_hid, glimpse): 116 | super(BAN, self).__init__() 117 | self.glimpse = glimpse 118 | self.v_att = BiAttention(num_hid, num_hid, num_hid, glimpse) 119 | 120 | b_net = [] 121 | q_prj = [] 122 | for i in range(glimpse): 123 | b_net.append(BCNet(num_hid, num_hid, num_hid, None, k=1)) 124 | q_prj.append(FCNet([num_hid, num_hid], '', .2)) 125 | 126 | self.b_net = nn.ModuleList(b_net) 127 | self.q_prj = nn.ModuleList(q_prj) 128 | self.drop = nn.Dropout(.5) 129 | self.tanh = nn.Tanh() 130 | 131 | def forward(self, v, q_emb, v_mask, q_mask): 132 | """Forward 133 | v: [batch, num_objs, dim] 134 | q: [batch_size, q_len, dim] 135 | v_mask: b, v_len 136 | q_mask: b, q_len 137 | return: logits, not probs 138 | """ 139 | # v_emb = v 140 | b_emb = [0] * self.glimpse 141 | att, logits = self.v_att.forward_all(v, q_emb, v_mask, q_mask) # b x g x v x q 142 | 143 | for g in range(self.glimpse): 144 | b_emb[g] = self.b_net[g].forward_with_weights(v, q_emb, att[:,g,:,:]) # b x l x h 145 | q_emb = self.q_prj[g](b_emb[g].unsqueeze(1)) + q_emb 146 | 147 | return q_emb, att 148 | -------------------------------------------------------------------------------- /model/modules/ban/fc.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from jnhwkim's repository. 
3 | # https://github.com/jnhwkim/ban-vqa 4 | # -------------------------------------------------------- 5 | 6 | 7 | from __future__ import print_function 8 | import torch.nn as nn 9 | from torch.nn.utils.weight_norm import weight_norm 10 | 11 | 12 | class FCNet(nn.Module): 13 | """Simple class for non-linear fully connect network 14 | """ 15 | def __init__(self, dims, act='ReLU', dropout=0): 16 | super(FCNet, self).__init__() 17 | 18 | layers = [] 19 | for i in range(len(dims)-2): 20 | in_dim = dims[i] 21 | out_dim = dims[i+1] 22 | if 0 < dropout: 23 | layers.append(nn.Dropout(dropout)) 24 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None)) 25 | if ''!=act: 26 | layers.append(getattr(nn, act)()) 27 | if 0 < dropout: 28 | layers.append(nn.Dropout(dropout)) 29 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None)) 30 | if ''!=act: 31 | layers.append(getattr(nn, act)()) 32 | 33 | self.main = nn.Sequential(*layers) 34 | 35 | def forward(self, x): 36 | return self.main(x) 37 | -------------------------------------------------------------------------------- /model/modules/fusion/fusion.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from cuiyuhao1996's repository. 
3 | # Licensed under The MIT License [see LICENSE for details] 4 | # https://github.com/cuiyuhao1996/mcan-vqa 5 | # -------------------------------------------------------- 6 | 7 | from model.modules.fusion.net_utils import FC, MLP, LayerNorm 8 | 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch, math 12 | 13 | 14 | class AttFlat(nn.Module): 15 | def __init__(self, hidden_size, flat_mlp_size, flat_glimpses, flat_out_size, dropout_r=0.1): 16 | super(AttFlat, self).__init__() 17 | 18 | self.flat_glimpses = flat_glimpses 19 | 20 | self.mlp = MLP( 21 | in_size=hidden_size, 22 | mid_size=flat_mlp_size, 23 | out_size=flat_glimpses, 24 | dropout_r=dropout_r, 25 | use_relu=True 26 | ) 27 | 28 | self.linear_merge = nn.Linear( 29 | hidden_size * flat_glimpses, 30 | flat_out_size 31 | ) 32 | 33 | def forward(self, x, x_mask): 34 | att = self.mlp(x) # b, L, glimpse 35 | if x_mask is not None: 36 | x_mask = x_mask.unsqueeze(2) 37 | att = att - 1e10*(1 - x_mask) 38 | att = F.softmax(att, dim=1) 39 | if x_mask is not None: 40 | att = att * x_mask 41 | att = att.masked_fill(att != att, 0.) 
# b, L, glimpse 42 | 43 | att_list = [] 44 | for i in range(self.flat_glimpses): 45 | att_list.append( 46 | torch.sum(att[:, :, i: i + 1] * x, dim=1) 47 | ) 48 | 49 | x_atted = torch.cat(att_list, dim=1) # b, d 50 | x_atted = self.linear_merge(x_atted) 51 | 52 | return x_atted 53 | 54 | 55 | class FFN(nn.Module): 56 | def __init__(self, hidden_size, ff_size, dropout_r=0.1): 57 | super(FFN, self).__init__() 58 | 59 | self.mlp = MLP( 60 | in_size=hidden_size, 61 | mid_size=ff_size, 62 | out_size=hidden_size, 63 | dropout_r=dropout_r, 64 | use_relu=True 65 | ) 66 | 67 | def forward(self, x): 68 | return self.mlp(x) 69 | 70 | 71 | class AttAdj(nn.Module): 72 | def __init__(self, hidden_size, dropout_r=0.1): 73 | super(AttAdj, self).__init__() 74 | self.hidden_size = hidden_size 75 | self.linear_k = nn.Linear(hidden_size, hidden_size) 76 | self.linear_q = nn.Linear(hidden_size, hidden_size) 77 | 78 | self.dropout = nn.Dropout(dropout_r) 79 | 80 | def forward(self, k, q, mask=None): 81 | ''' 82 | :param k: b, kv_l, d 83 | :param q: b, q_l, d 84 | :param mask: b, q_l, kv_l 85 | ''' 86 | 87 | k = self.linear_k(k) 88 | q = self.linear_q(q) 89 | 90 | adj_scores = self.att(k, q, mask) 91 | 92 | return adj_scores 93 | 94 | def att(self, key, query, mask): 95 | d_k = query.size(-1) 96 | 97 | scores = torch.matmul( 98 | query, key.transpose(-2, -1) 99 | ) / math.sqrt(d_k) 100 | 101 | if mask is not None: 102 | scores = scores - 1e10*(1 - mask) 103 | 104 | att_map = F.softmax(scores, dim=-1) 105 | if mask is not None: 106 | att_map = att_map * mask 107 | att_map = att_map.masked_fill(att_map != att_map, 0.) 
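`AttAdj.att` computes a scaled dot-product attention map and applies the same subtract-1e10 masking. `MotionApprFusion` reuses it with column masks that keep only the appearance half or only the motion half of `U`. The sketch below reproduces those two pieces without the learned `linear_k`/`linear_q` projections or dropout. `masked_att_adj` and `half_masks` are hypothetical helper names, and `U` is random data standing in for the concatenated appearance/motion features.

```python
import math

import torch
import torch.nn.functional as F


def masked_att_adj(key, query, mask=None):
    """Hypothetical helper: scaled dot-product attention map (b, q_l, kv_l).

    mask: optional (b, q_l, kv_l) tensor, 1.0 = keep, 0.0 = drop (assumed).
    """
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if mask is None:
        return F.softmax(scores, dim=-1)
    scores = scores - 1e10 * (1 - mask)      # masked logits ~ -inf
    att = F.softmax(scores, dim=-1) * mask   # zero masked entries exactly
    return att.masked_fill(att != att, 0.0)  # NaN guard


def half_masks(bsz, seq_len):
    """Hypothetical helper: column masks over the first (appearance) and
    second (motion) halves of U, as built in MotionApprFusion.forward."""
    half = seq_len // 2
    appr = torch.zeros(bsz, seq_len, seq_len)
    appr[:, :, :half] = 1
    motion = torch.zeros(bsz, seq_len, seq_len)
    motion[:, :, half:] = 1
    return appr, motion
```

Each row of the appearance-masked map can only attend to appearance columns (and vice versa), while the unmasked "all" map mixes both halves. This is why the fusion module keeps three separate adjacency matrices.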
108 | att_map = self.dropout(att_map) 109 | 110 | return att_map 111 | 112 | 113 | class MotionApprFusion(nn.Module): 114 | def __init__(self, hidden_size, ff_size, n_layer=1, dropout_r=0.1): 115 | super(MotionApprFusion, self).__init__() 116 | self.n_layer = n_layer 117 | self.hidden_size = hidden_size 118 | 119 | self.appr_att_score = AttAdj(hidden_size, dropout_r) 120 | self.motion_att_score = AttAdj(hidden_size, dropout_r) 121 | self.all_att_score = AttAdj(hidden_size, dropout_r) 122 | 123 | self.appr_linear = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layer)]) 124 | self.motion_linear = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layer)]) 125 | self.all_linear = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layer)]) 126 | 127 | self.appr_norm = LayerNorm(hidden_size) 128 | self.motion_norm = LayerNorm(hidden_size) 129 | self.all_norm = LayerNorm(hidden_size) 130 | 131 | self.ques_guide_att = AttAdj(hidden_size, dropout_r) 132 | 133 | self.ffn = FFN(hidden_size, ff_size, dropout_r) 134 | self.dropout = nn.Dropout(dropout_r) 135 | self.norm = LayerNorm(hidden_size) 136 | 137 | def forward(self, U, q_hid, U_mask=None): 138 | ''' 139 | :param U: b, 2L, d 140 | :param q_hid: b, d 141 | :param U_mask: b, 2L, 2L 142 | :return: b, 2L, d 143 | ''' 144 | residual = U 145 | bsz, seq_len = U.size(0), U.size(1) 146 | qword_len = int(seq_len/2) 147 | device = U.device 148 | 149 | appr_mask = torch.zeros(bsz, seq_len, seq_len).to(torch.float).to(device) 150 | appr_mask[:,:,:qword_len] = 1 151 | motion_mask = torch.zeros(bsz, seq_len, seq_len).to(torch.float).to(device) 152 | motion_mask[:,:,qword_len:] = 1 153 | 154 | if U_mask is not None: 155 | appr_mask = appr_mask * U_mask 156 | motion_mask = motion_mask * U_mask 157 | 158 | appr_att_adj = self.appr_att_score(U, U, appr_mask) # b, 2L, 2L 159 | motion_att_adj = self.motion_att_score(U, U, motion_mask) 160 | if U_mask is not None: 161 | all_att_adj = 
self.all_att_score(U, U, U_mask) 162 | else: 163 | all_att_adj = self.all_att_score(U, U, None) 164 | 165 | appr_inp, motion_inp, all_inp = [U, ], [U, ], [U, ] 166 | for i in range(self.n_layer): 167 | appr_x = self.appr_linear[i](appr_inp[i]) # b, 2L, d 168 | appr_x = torch.matmul(appr_att_adj, appr_x) 169 | appr_inp.append(appr_x) 170 | 171 | motion_x = self.motion_linear[i](motion_inp[i]) 172 | motion_x = torch.matmul(motion_att_adj, motion_x) 173 | motion_inp.append(motion_x) 174 | 175 | all_x = self.all_linear[i](all_inp[i]) 176 | all_x = torch.matmul(all_att_adj, all_x) 177 | all_inp.append(all_x) 178 | 179 | appr_x = self.appr_norm(residual + appr_x) 180 | motion_x = self.motion_norm(residual + motion_x) 181 | all_x = self.all_norm(residual + all_x) 182 | 183 | graph_out = torch.cat([appr_x.unsqueeze(1), motion_x.unsqueeze(1), all_x.unsqueeze(1)], dim=1) # b, 3, 2L, d 184 | fusion_k = torch.sum(graph_out, dim=2) # b, 3, d 185 | 186 | fusion_q = q_hid.unsqueeze(1) # b, 1, d 187 | fusion_att_score = self.ques_guide_att(fusion_k, fusion_q).squeeze(1) # b, 3 (squeeze only the query dim so batch size 1 still works) 188 | 189 | fusion_att = graph_out * fusion_att_score.unsqueeze(2).unsqueeze(3) # b, 3, 2L, d 190 | fusion_att = torch.sum(fusion_att, dim=1) # b, 2L, d 191 | 192 | fusion_out = self.norm(fusion_att + self.ffn(self.dropout(fusion_att))) 193 | 194 | return fusion_out 195 | 196 | -------------------------------------------------------------------------------- /model/modules/fusion/net_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from cuiyuhao1996's repository.
3 | # Licensed under The MIT License [see LICENSE for details] 4 | # https://github.com/cuiyuhao1996/mcan-vqa 5 | # -------------------------------------------------------- 6 | 7 | import torch.nn as nn 8 | import torch 9 | 10 | 11 | class FC(nn.Module): 12 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True): 13 | super(FC, self).__init__() 14 | self.dropout_r = dropout_r 15 | self.use_relu = use_relu 16 | 17 | self.linear = nn.Linear(in_size, out_size) 18 | 19 | if use_relu: 20 | self.relu = nn.ReLU(inplace=True) 21 | 22 | if dropout_r > 0: 23 | self.dropout = nn.Dropout(dropout_r) 24 | 25 | def forward(self, x): 26 | x = self.linear(x) 27 | 28 | if self.use_relu: 29 | x = self.relu(x) 30 | 31 | if self.dropout_r > 0: 32 | x = self.dropout(x) 33 | 34 | return x 35 | 36 | 37 | class MLP(nn.Module): 38 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True): 39 | super(MLP, self).__init__() 40 | 41 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu) 42 | self.linear = nn.Linear(mid_size, out_size) 43 | 44 | def forward(self, x): 45 | return self.linear(self.fc(x)) 46 | 47 | 48 | class LayerNorm(nn.Module): 49 | def __init__(self, size, eps=1e-6): 50 | super(LayerNorm, self).__init__() 51 | self.eps = eps 52 | 53 | self.a_2 = nn.Parameter(torch.ones(size)) 54 | self.b_2 = nn.Parameter(torch.zeros(size)) 55 | 56 | def forward(self, x): 57 | mean = x.mean(-1, keepdim=True) 58 | std = x.std(-1, keepdim=True) 59 | 60 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2 -------------------------------------------------------------------------------- /model/modules/gcn.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 
3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | import math 7 | import torch 8 | import numpy as np 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch.nn.parameter import Parameter 12 | from torch.nn.modules.module import Module 13 | from torch.autograd import Variable 14 | 15 | 16 | class GraphConvolution(Module): 17 | """ 18 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 19 | """ 20 | 21 | def __init__(self, in_features, out_features): 22 | super(GraphConvolution, self).__init__() 23 | self.weight = nn.Linear(in_features, out_features, bias=False) 24 | self.layer_norm = nn.LayerNorm(out_features, elementwise_affine=False) 25 | 26 | def forward(self, input, adj): 27 | # self.weight of shape (hidden_size, hidden_size) 28 | support = self.weight(input) 29 | output = torch.bmm(adj, support) 30 | output = self.layer_norm(output) 31 | return output 32 | 33 | 34 | class GCN(nn.Module): 35 | 36 | def __init__( 37 | self, input_size, hidden_size, num_classes, num_layers=1, 38 | dropout=0.1): 39 | super(GCN, self).__init__() 40 | self.layers = nn.ModuleList() 41 | self.layers.append(GraphConvolution(input_size, hidden_size)) 42 | for i in range(num_layers - 1): 43 | self.layers.append(GraphConvolution(hidden_size, hidden_size)) 44 | self.layers.append(GraphConvolution(hidden_size, num_classes)) 45 | 46 | self.layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False) 47 | self.dropout = nn.Dropout(p=dropout) 48 | 49 | def forward(self, x, adj): 50 | x_gcn = [x] 51 | for i, layer in enumerate(self.layers): 52 | x_gcn.append(self.dropout(F.relu(layer(x_gcn[i], adj)))) 53 | 54 | x = self.layernorm(x + x_gcn[-1]) 55 | return x 56 | 57 | 58 | class VideoAdjLearner(Module): 59 | 60 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1, scale=100): 61 | super().__init__() 62 | self.scale = scale 63 | 64 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, 
bias=False) 65 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False) 66 | 67 | # Regularization 68 | self.dropout = nn.Dropout(p=dropout) 69 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1) 70 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2) 71 | 72 | def forward(self, v, v_mask=None): 73 | ''' 74 | :param v: (b, v_len, d) 75 | :param v_mask: (b, v_len) 76 | :return: adj: (b, v_len, v_len) 77 | ''' 78 | # layer 1 79 | h = self.edge_layer_1(v) # b, v_l, d 80 | h = F.relu(h) 81 | 82 | # layer 2 83 | h = self.edge_layer_2(h) # b, v_l, d 84 | h = F.relu(h) 85 | 86 | # outer product 87 | adj = torch.bmm(h, h.transpose(1, 2)) # b, v_l, v_l 88 | 89 | if v_mask is not None: 90 | adj_mask = adj.data.new(*adj.size()).fill_(1) 91 | v_mask_ = torch.matmul(v_mask.unsqueeze(2), v_mask.unsqueeze(2).transpose(1, 2)) 92 | adj_mask = adj_mask * v_mask_ 93 | 94 | adj_mask = Variable(adj_mask) 95 | adj = adj - 1e10*(1 - adj_mask) 96 | 97 | adj = F.softmax(adj * self.scale, dim=-1) 98 | adj = adj * adj_mask 99 | adj = adj.masked_fill(adj != adj, 0.) 100 | else: 101 | adj = F.softmax(adj * self.scale, dim=-1) 102 | 103 | return adj 104 | 105 | -------------------------------------------------------------------------------- /model/modules/linear_weightdrop.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 
3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | 7 | 8 | from torch.nn import Parameter 9 | import torch 10 | 11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 12 | 13 | 14 | class weight_drop(): 15 | 16 | def __init__(self, module, weights, dropout): 17 | for name_w in weights: 18 | w = getattr(module, name_w) 19 | del module._parameters[name_w] 20 | module.register_parameter(name_w + '_raw', Parameter(w)) 21 | 22 | self.original_module_forward = module.forward 23 | 24 | self.weights = weights 25 | self.module = module 26 | self.dropout = dropout 27 | 28 | def __call__(self, *args, **kwargs): 29 | for name_w in self.weights: 30 | raw_w = getattr(self.module, name_w + '_raw') 31 | w = torch.nn.functional.dropout( 32 | raw_w, p=self.dropout, training=self.module.training) 33 | # module.register_parameter(name_w, Parameter(w)) 34 | setattr(self.module, name_w, w) # plain tensor: wrapping it in Parameter would cut the gradient path to name_w + '_raw' 35 | 36 | return self.original_module_forward(*args, **kwargs) 37 | 38 | 39 | def _weight_drop(module, weights, dropout): 40 | setattr(module, 'forward', weight_drop(module, weights, dropout)) 41 | 42 | 43 | class WeightDrop(torch.nn.Module): 44 | """ 45 | The weight-dropped module applies recurrent regularization through a DropConnect mask on the 46 | hidden-to-hidden recurrent weights. 47 | **Thank you** to Salesforce for their initial implementation of :class:`WeightDrop`. Here is 48 | their `License 49 | `__. 50 | Args: 51 | module (:class:`torch.nn.Module`): Containing module. 52 | weights (:class:`list` of :class:`str`): Names of the module weight parameters to apply a 53 | dropout to. 54 | dropout (float): The probability a weight will be dropped.
55 | Example: 56 | >>> from torchnlp.nn import WeightDrop 57 | >>> import torch 58 | >>> 59 | >>> torch.manual_seed(123) 60 | <torch._C.Generator object ... 61 | >>> 62 | >>> gru = torch.nn.GRUCell(2, 2) 63 | >>> weights = ['weight_hh'] 64 | >>> weight_drop_gru = WeightDrop(gru, weights, dropout=0.9) 65 | >>> 66 | >>> input_ = torch.randn(3, 2) 67 | >>> hidden_state = torch.randn(3, 2) 68 | >>> weight_drop_gru(input_, hidden_state) 69 | tensor(... grad_fn=<AddBackward0>) 70 | """ 71 | 72 | def __init__(self, module, weights, dropout=0.0): 73 | super(WeightDrop, self).__init__() 74 | _weight_drop(module, weights, dropout) 75 | self.forward = module.forward 76 | 77 | 78 | class WeightDropLinear(torch.nn.Linear): 79 | """ 80 | Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument. 81 | Args: 82 | weight_dropout (float): The probability a weight will be dropped. 83 | """ 84 | 85 | def __init__(self, *args, weight_dropout=0.0, **kwargs): 86 | super().__init__(*args, **kwargs) 87 | weights = ['weight'] 88 | _weight_drop(self, weights, weight_dropout) 89 | -------------------------------------------------------------------------------- /model/modules/position_embedding.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from jayleicn's repository. 3 | # https://github.com/jayleicn/TVQAplus 4 | # -------------------------------------------------------- 5 | 6 | import torch 7 | import math 8 | import torch.nn as nn 9 | 10 | 11 | class PositionEncoding(nn.Module): 12 | def __init__(self, n_filters=128, max_len=500): 13 | super(PositionEncoding, self).__init__() 14 | # Compute the positional encodings once in log space.
15 | pe = torch.zeros(max_len, n_filters) # (L, D) 16 | position = torch.arange(0, max_len).float().unsqueeze(1) 17 | div_term = torch.exp(torch.arange(0, n_filters, 2).float() * - (math.log(10000.0) / n_filters)) 18 | pe[:, 0::2] = torch.sin(position * div_term) 19 | pe[:, 1::2] = torch.cos(position * div_term) 20 | self.register_buffer('pe', pe) # buffer is a tensor, not a variable, (L, D) 21 | 22 | def forward(self, x): 23 | """ 24 | :Input: (*, L, D) 25 | :Output: (*, L, D) the same size as input 26 | """ 27 | pe = self.pe.data[:x.size(-2), :] # (#x.size(-2), n_filters) 28 | extra_dim = len(x.size()) - 2 29 | for _ in range(extra_dim): 30 | pe = pe.unsqueeze(0) 31 | 32 | return pe 33 | 34 | 35 | def test_pos_enc(): 36 | mdl = PositionEncoding() 37 | 38 | batch_size = 8 39 | n_channels = 128 40 | n_items = 60 41 | 42 | input = torch.ones(batch_size, n_items, n_channels) 43 | 44 | out = mdl(input) 45 | print(out) 46 | 47 | 48 | if __name__ == '__main__': 49 | test_pos_enc() 50 | -------------------------------------------------------------------------------- /model/modules/rnn_encoder.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 
3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch 9 | import numpy as np 10 | 11 | 12 | class BaseRNN(nn.Module): 13 | 14 | # SYM_MASK = "MASK" 15 | # SYM_EOS = "EOS" 16 | 17 | def __init__( 18 | self, input_size, hidden_size, input_dropout_p, dropout_p, n_layers, 19 | rnn_cell): 20 | super(BaseRNN, self).__init__() 21 | self.vocab_size = input_size 22 | self.hidden_size = hidden_size 23 | self.n_layers = n_layers 24 | self.input_dropout_p = input_dropout_p 25 | self.dropout = nn.Dropout(p=input_dropout_p) 26 | 27 | if rnn_cell.lower() == 'lstm': 28 | self.rnn_cell = nn.LSTM 29 | elif rnn_cell.lower() == 'gru': 30 | self.rnn_cell = nn.GRU 31 | else: 32 | raise ValueError("Unsupported RNN Cell: {}".format(rnn_cell)) 33 | 34 | self.dropout_p = dropout_p 35 | 36 | self.compress_hn_layers = nn.Linear( 37 | n_layers * hidden_size, hidden_size, bias=False) 38 | self.compress_hn_layers_bi = nn.Linear( 39 | n_layers * 2 * hidden_size, hidden_size, bias=False) 40 | self.compress_hn_bi = nn.Linear( 41 | 2 * hidden_size, hidden_size, bias=False) 42 | 43 | self.compress_output = nn.Linear( 44 | 2 * hidden_size, hidden_size, bias=False) 45 | self.compress_output_dropout = nn.Dropout(p=input_dropout_p) 46 | 47 | def forward(self, *args, **kwargs): 48 | raise NotImplementedError() 49 | 50 | 51 | class SentenceEncoderRNN(BaseRNN): 52 | r""" 53 | Applies a multi-layer RNN to an input sequence. 
54 | 55 | variable_lengths: whether to pack the padded input using input_lengths (default: True) 56 | 57 | Args: 58 | input_var (batch, seq_len, dim): glove embedding feature 59 | input_lengths (list of int, optional): A list that contains the lengths of sequences 60 | in the mini-batch 61 | 62 | 63 | Returns: output, hidden 64 | - **output** (batch, seq_len, hidden_size): variable containing the encoded features of the input sequence 65 | - **hidden** (num_layers * num_directions, batch, hidden_size): variable containing the features in the hidden state h 66 | """ 67 | 68 | def __init__( 69 | self, 70 | vocab_size, 71 | hidden_size, 72 | input_dropout_p=0, 73 | dropout_p=0, 74 | n_layers=1, 75 | bidirectional=False, 76 | rnn_cell='gru', 77 | variable_lengths=True): 78 | super().__init__( 79 | vocab_size, hidden_size, input_dropout_p, dropout_p, n_layers, 80 | rnn_cell) 81 | 82 | self.variable_lengths = variable_lengths 83 | self.n_layers = n_layers 84 | self.bidirectional = bidirectional 85 | self.rnn_name = rnn_cell 86 | 87 | self.rnn = self.rnn_cell( 88 | hidden_size, 89 | hidden_size, 90 | n_layers, 91 | batch_first=True, 92 | bidirectional=bidirectional, 93 | dropout=dropout_p) 94 | 95 | def forward(self, input_var, h_0=None, input_lengths=None): 96 | batch_size = input_var.size()[0] 97 | embedded = input_var 98 | 99 | if self.variable_lengths: 100 | embedded = nn.utils.rnn.pack_padded_sequence( 101 | embedded, input_lengths, batch_first=True, enforce_sorted=False) 102 | 103 | # output of shape (batch, seq_len, num_directions * hidden_size) 104 | # h_n of shape (num_layers * num_directions, batch, hidden_size) 105 | if self.rnn_name == 'gru': 106 | output, hidden = self.rnn(embedded, h_0) 107 | else: 108 | output, (hidden, _) = self.rnn(embedded, h_0) 109 | 110 | if self.variable_lengths: 111 | total_length = input_var.size()[1] 112 | output, _ = nn.utils.rnn.pad_packed_sequence( 113 | output, batch_first=True, total_length=total_length) # pad back to the input's seq_len 114 | 115 | if self.n_layers > 1 and
self.bidirectional: 116 | output = self.dropout(F.relu(self.compress_output(output))) 117 | hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1) 118 | hidden = self.dropout(F.relu(self.compress_hn_layers_bi(hidden))) 119 | elif self.n_layers > 1: 120 | hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1) 121 | hidden = self.dropout(F.relu(self.compress_hn_layers(hidden))) 122 | elif self.bidirectional: 123 | output = self.dropout(F.relu(self.compress_output(output))) 124 | hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1) 125 | hidden = self.dropout(F.relu(self.compress_hn_bi(hidden))) 126 | # output of shape (batch, seq_len, hidden_size) hidden of shape (batch, hidden_size) 127 | 128 | return output, hidden 129 | -------------------------------------------------------------------------------- /model_overview.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ajseo17/MASN-pytorch/5ca3fc80cf37f7b6124070b1aae5bc599db8fa29/model_overview.jpeg -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # This code is modified from Jumpin2's repository. 3 | # https://github.com/Jumpin2/HGA 4 | # -------------------------------------------------------- 5 | """ Common utilities. 
""" 6 | 7 | # Logging 8 | # ======= 9 | 10 | import logging 11 | import os, os.path 12 | from colorlog import ColoredFormatter 13 | import torch 14 | 15 | ch = logging.StreamHandler() 16 | ch.setLevel(logging.DEBUG) 17 | 18 | formatter = ColoredFormatter( 19 | "%(log_color)s[%(asctime)s] %(message)s", 20 | # datefmt='%H:%M:%S.%f', 21 | datefmt=None, 22 | reset=True, 23 | log_colors={ 24 | 'DEBUG': 'cyan', 25 | 'INFO': 'white,bold', 26 | 'INFOV': 'cyan,bold', 27 | 'WARNING': 'yellow', 28 | 'ERROR': 'red,bold', 29 | 'CRITICAL': 'red,bg_white', 30 | }, 31 | secondary_log_colors={}, 32 | style='%') 33 | ch.setFormatter(formatter) 34 | 35 | log = logging.getLogger('videocap') 36 | log.setLevel(logging.DEBUG) 37 | log.handlers = [] # No duplicated handlers 38 | log.propagate = False # workaround for duplicated logs in ipython 39 | log.addHandler(ch) 40 | 41 | logging.addLevelName(logging.INFO + 1, 'INFOV') 42 | 43 | 44 | def _infov(self, msg, *args, **kwargs): 45 | self.log(logging.INFO + 1, msg, *args, **kwargs) 46 | 47 | 48 | logging.Logger.infov = _infov 49 | 50 | 51 | class AverageMeter(object): 52 | """Computes and stores the average and current value""" 53 | 54 | def __init__(self): 55 | self.reset() 56 | 57 | def reset(self): 58 | self.val = 0 59 | self.avg = 0 60 | self.sum = 0 61 | self.count = 0 62 | 63 | def update(self, val, n=1): 64 | self.val = val 65 | self.sum += val * n 66 | self.count += n 67 | self.avg = self.sum / self.count 68 | 69 | 70 | class StrToBytes: 71 | 72 | def __init__(self, fileobj): 73 | self.fileobj = fileobj 74 | 75 | def read(self, size): 76 | return self.fileobj.read(size).encode() 77 | 78 | def readline(self, size=-1): 79 | return self.fileobj.readline(size).encode() 80 | 81 | 82 | def get_accuracy(logits, targets): 83 | correct = torch.sum(logits.eq(targets)).float() 84 | return correct * 100.0 / targets.size(0) 85 | 86 | 87 | class nvidia_prefetcher(): 88 | def __init__(self, loader): 89 | self.loader = iter(loader) 90 | 
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        try:
            self.next_data = next(self.loader)
        except StopIteration:
            self.next_data = None
            return
        # Copy the next batch to the GPU on a side stream so the host-to-device
        # transfer can overlap with computation on the current stream.
        with torch.cuda.stream(self.stream):
            self.next_data = [d.cuda(non_blocking=True) for d in self.next_data]
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)
        if self.next_data is None:
            raise StopIteration
        next_data = self.next_data
        # The batch was allocated on the side stream; record its use on the
        # current stream so the caching allocator does not reuse its memory
        # before the kernels consuming it have finished.
        for d in next_data:
            d.record_stream(torch.cuda.current_stream())
        self.preload()
        return next_data

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self
--------------------------------------------------------------------------------
/warmup_scheduler.py:
--------------------------------------------------------------------------------
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warms up (increases) the learning rate of the optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier
        total_epoch: the learning rate increases gradually and reaches the
            target learning rate at epoch total_epoch
        after_scheduler: after total_epoch, use this scheduler (e.g.
            ReduceLROnPlateau)
    """

    def __init__(
            self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier <= 1.:
            raise ValueError('multiplier should be greater than 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super().__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    self.after_scheduler.base_lrs = [
                        base_lr * self.multiplier for base_lr in self.base_lrs
                    ]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        return [
            base_lr *
            ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
            for base_lr in self.base_lrs
        ]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is stepped at the end of an epoch, whereas the
        # other schedulers are stepped at the beginning.
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [
                base_lr * (
                    (self.multiplier - 1.) * self.last_epoch / self.total_epoch
                    + 1.)
                for base_lr in self.base_lrs
            ]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            if epoch is None:
                self.after_scheduler.step(metrics, None)
            else:
                self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------
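The warm-up rule in GradualWarmupScheduler.get_lr() is linear: at epoch e <= total_epoch the learning rate is base_lr * ((multiplier - 1) * e / total_epoch + 1), so it starts at base_lr and reaches base_lr * multiplier at total_epoch. Below is a minimal, framework-free sketch of that rule. It assumes no after_scheduler is attached, so the rate simply stays at base_lr * multiplier after warm-up. It also ignores the separate ReduceLROnPlateau path, which treats epoch 0 as 1. The helper names warmup_lr and warmup_schedule are hypothetical and are not part of the repository.

```python
# Hypothetical helpers (not part of the repository): reproduce the per-epoch
# learning rate that GradualWarmupScheduler.get_lr() yields when no
# after_scheduler is attached.
from typing import List


def warmup_lr(base_lr: float, multiplier: float, total_epoch: int,
              epoch: int) -> float:
    """Learning rate at `epoch` for a linear warm-up from base_lr to
    base_lr * multiplier over `total_epoch` epochs."""
    if multiplier <= 1.0:
        # Same constraint as GradualWarmupScheduler.__init__.
        raise ValueError('multiplier should be greater than 1.')
    if epoch > total_epoch:
        # Past warm-up the target rate is held constant.
        return base_lr * multiplier
    # Linear interpolation: factor 1 at epoch 0, `multiplier` at total_epoch.
    return base_lr * ((multiplier - 1.0) * epoch / total_epoch + 1.0)


def warmup_schedule(base_lr: float, multiplier: float, total_epoch: int,
                    num_epochs: int) -> List[float]:
    """Learning rates for epochs 0 .. num_epochs - 1."""
    return [warmup_lr(base_lr, multiplier, total_epoch, e)
            for e in range(num_epochs)]


if __name__ == '__main__':
    # base_lr=0.01, multiplier=10, total_epoch=5: the rate ramps 0.01 -> 0.1.
    for epoch, lr in enumerate(warmup_schedule(0.01, 10.0, 5, 8)):
        print(f'epoch {epoch}: lr = {lr:.3f}')
```

With base_lr=0.01, multiplier=10 and total_epoch=5, this sketch gives 0.010, 0.028, 0.046, 0.064, 0.082 and then holds at 0.100. In the actual scheduler, `epoch` corresponds to `last_epoch`, which PyTorch's base scheduler advances by one on each `step()` call made without an explicit epoch.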