├── LICENSE
├── README.md
├── adaptive_detection_features_converter.py
├── data_utils
│   ├── __init__.py
│   ├── data_util.py
│   ├── dataset.py
│   ├── dataset_msrvtt.py
│   └── dataset_msvd.py
├── embed_loss.py
├── main.py
├── main_msrvtt.py
├── main_msvd.py
├── model
│   ├── masn.py
│   └── modules
│       ├── ban
│       │   ├── ban.py
│       │   └── fc.py
│       ├── fusion
│       │   ├── fusion.py
│       │   └── net_utils.py
│       ├── gcn.py
│       ├── linear_weightdrop.py
│       ├── position_embedding.py
│       └── rnn_encoder.py
├── model_overview.jpeg
├── util.py
└── warmup_scheduler.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 서아정
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | Motion-Appearance Synergistic Networks for VideoQA (MASN)
2 | ========================================================================
3 |
4 | PyTorch implementation of the paper:
5 |
6 | **[Attend What You Need: Motion-Appearance Synergistic Networks for Video Question Answering][1]**
7 | Ahjeong Seo, [Gi-Cheon Kang](https://gicheonkang.com), Joonhan Park, and [Byoung-Tak Zhang](https://bi.snu.ac.kr/~btzhang/)
8 | In ACL 2021
9 |
10 |
11 | ![MASN model overview](model_overview.jpeg)
12 |
13 | Requirements
14 | --------
15 | Python 3.7, PyTorch 1.2.0
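The data loading and preprocessing code additionally imports `h5py`, `pandas`, `numpy`, and `nltk`, and the vocabulary builder expects `glove.42B.300d.txt` in the vocabulary directory. A possible environment setup (only the Python and PyTorch versions are stated by this repository; the other packages are unpinned assumptions):

```sh
conda create -n masn python=3.7   # environment name is arbitrary
conda activate masn
pip install torch==1.2.0 h5py pandas numpy nltk
```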
16 |
17 |
18 | Dataset
19 | --------
20 | - Download [TGIF-QA](https://github.com/YunseokJANG/tgif-qa) dataset and refer to the [paper](https://arxiv.org/abs/1704.04497) for details.
21 | - Download [MSVD-QA and MSRVTT-QA](https://github.com/xudejing/video-question-answering).
22 |
23 | Extract Features
24 | --------
25 | 1. Appearance Features
26 | - For local features, we used a Faster R-CNN pre-trained on Visual Genome. Please refer to this [Link](https://github.com/peteanderson80/bottom-up-attention).
27 |  * After extracting object features with Faster R-CNN, you can convert them to an HDF5 file by running `python adaptive_detection_features_converter.py`.
28 | - For global features, we used the ResNet-152 provided by torchvision. Please refer to this [Link](https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py).
29 |
30 | 2. Motion Features
31 | - For local features, we used RoIAlign on the bounding boxes obtained from Faster R-CNN. Please refer to this [Link](https://github.com/AceCoooool/RoIAlign-RoIPool-pytorch).
32 | - For global features, we used an I3D pre-trained on Kinetics. Please refer to this [Link](https://github.com/Tushar-N/pytorch-resnet3d).
33 |
34 |
35 | We uploaded our extracted features:
36 | 1) TGIF-QA
37 | * [`res152_avgpool.hdf5`][2]: appearance global features (3GB).
38 | * [`tgif_btup_f_obj10.hdf5`][3]: appearance local features (30GB).
39 | * [`tgif_i3d_hw7_perclip_avgpool.hdf5`][4]: motion global features (3GB).
40 | * [`tgif_i3d_roialign_hw7_perclip_avgpool.hdf5`][5]: motion local features (59GB).
41 |
42 | 2) MSRVTT-QA
43 | * [`msrvtt_res152_avgpool.hdf5`][10]: appearance global features (1.7GB).
44 | * [`msrvtt_btup_f_obj10.hdf5`][11]: appearance local features (17GB).
45 | * [`msrvtt_i3d_avgpool_perclip.hdf5`][12]: motion global features (1.7GB).
46 | * [`msrvtt_i3d_roialign_perclip_obj10.hdf5`][13]: motion local features (34GB).
47 |
48 | 3) MSVD-QA
49 | * [`msvd_res152_avgpool.hdf5`][14]: appearance global features (220MB).
50 | * [`msvd_btup_f_obj10.hdf5`][15]: appearance local features (2.2GB).
51 | * [`msvd_i3d_avgpool_perclip.hdf5`][16]: motion global features (220MB).
52 | * [`msvd_i3d_roialign_perclip_obj10.hdf5`][17]: motion local features (4.2GB).
53 |
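To sanity-check a downloaded (or self-extracted) feature file, you can open it with `h5py`. Below is a minimal sketch, assuming the per-video key layout that `data_utils/dataset.py` expects (global features keyed by video id; local features stored under `image_features` and `spatial_features`); the file paths and video id are placeholders:

```python
import h5py
import numpy as np

# Placeholder paths and video id -- adjust to where the files were downloaded.
with h5py.File('res152_avgpool.hdf5', 'r') as f_app, \
     h5py.File('tgif_btup_f_obj10.hdf5', 'r') as f_roi:
    vid = 'some_gif_name'
    app_global = np.array(f_app[vid])                    # (T, 2048) appearance global features
    app_local = np.array(f_roi['image_features'][vid])   # (T, n_obj, 2048) object features
    bbox = np.array(f_roi['spatial_features'][vid])      # (T, n_obj, 6) normalized box coordinates
    print(app_global.shape, app_local.shape, bbox.shape)
```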
54 |
55 | Training
56 | --------
57 | For TGIF-QA, simply run
58 | ```sh
59 | CUDA_VISIBLE_DEVICES=0 python main.py --task Count --batch_size 32
60 | ```
61 |
62 | For MSRVTT-QA, run
63 | ```sh
64 | CUDA_VISIBLE_DEVICES=0 python main_msrvtt.py --task MS-QA --batch_size 32
65 | ```
66 |
67 | For MSVD-QA, run
68 | ```sh
69 | CUDA_VISIBLE_DEVICES=0 python main_msvd.py --task MS-QA --batch_size 32
70 | ```
71 |
72 | ### Saving model checkpoints
73 | By default, the model saves a checkpoint at every epoch. You can change the directory for saving models with the `--save_path` option.
74 | By default, each checkpoint is named '[TASK]_[PERFORMANCE].pth'.
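For example, to keep Count checkpoints in a custom directory (the directory name below is only an illustration, and we assume `--save_path` points at a directory):

```sh
CUDA_VISIBLE_DEVICES=0 python main.py --task Count --batch_size 32 --save_path ./checkpoints
# after each epoch this produces a file such as ./checkpoints/Count_[PERFORMANCE].pth
```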
75 |
76 |
77 | Evaluation & Results
78 | --------
79 | ```sh
80 | CUDA_VISIBLE_DEVICES=0 python main.py --test --checkpoint [NAME] --task Count --batch_size 32
81 | ```
82 |
83 | Performance on the TGIF-QA dataset (MSE for Count, accuracy (%) for the others):
84 |
85 | Model | Count | Action | Trans. | FrameQA |
86 | ------- | ------ | ------ | ------ | ------ |
87 | MASN | 3.75 | 84.4 | 87.4 | 59.5|
88 |
89 | You can download our pre-trained models via these links: [`Count`][6], [`Action`][7], [`Trans.`][8], [`FrameQA`][9]
90 |
91 | Performance on the MSRVTT-QA and MSVD-QA datasets:
92 | Model | MSRVTT-QA | MSVD-QA |
93 | ------- | ------ | ------ |
94 | MASN | 35.2 | 38.0 |
95 |
96 |
97 | Citation
98 | --------
99 | If this repository is helpful for your research, we'd really appreciate it if you could cite the following paper:
100 | ```text
101 | @inproceedings{seo-etal-2021-attend,
102 | title = "Attend What You Need: Motion-Appearance Synergistic Networks for Video Question Answering",
103 | author = "Seo, Ahjeong and
104 | Kang, Gi-Cheon and
105 | Park, Joonhan and
106 | Zhang, Byoung-Tak",
107 | booktitle = "Proceedings of the 59th Annual Meeting of the Association for Computational Linguistics and the 11th International Joint Conference on Natural Language Processing (Volume 1: Long Papers)",
108 | month = aug,
109 | year = "2021",
110 | address = "Online",
111 | publisher = "Association for Computational Linguistics",
112 | url = "https://aclanthology.org/2021.acl-long.481",
113 | doi = "10.18653/v1/2021.acl-long.481",
114 | pages = "6167--6177",
115 | abstract = "Video Question Answering is a task which requires an AI agent to answer questions grounded in video. This task entails three key challenges: (1) understand the intention of various questions, (2) capturing various elements of the input video (e.g., object, action, causality), and (3) cross-modal grounding between language and vision information. We propose Motion-Appearance Synergistic Networks (MASN), which embed two cross-modal features grounded on motion and appearance information and selectively utilize them depending on the question{'}s intentions. MASN consists of a motion module, an appearance module, and a motion-appearance fusion module. The motion module computes the action-oriented cross-modal joint representations, while the appearance module focuses on the appearance aspect of the input video. Finally, the motion-appearance fusion module takes each output of the motion module and the appearance module as input, and performs question-guided fusion. As a result, MASN achieves new state-of-the-art performance on the TGIF-QA and MSVD-QA datasets. We also conduct qualitative analysis by visualizing the inference results of MASN.",
116 | }
117 |
118 | ```
119 |
120 |
121 | License
122 | --------
123 | MIT License
124 |
125 | Acknowledgements
126 | --------
127 | This work was partly supported by the Institute of Information & Communications Technology Planning & Evaluation (2015-0-00310-SW.StarLab/25%, 2017-0-01772-VTT/25%, 2018-0-00622-RMI/25%, 2019-0-01371-BabyMind/25%) grant funded by the Korean government.
128 |
129 |
130 | [1]: https://aclanthology.org/2021.acl-long.481/
131 | [2]: https://drive.google.com/file/d/1tWY3gU4XohzhZjV5Wia5L8XqfaV10127/view?usp=sharing
132 | [3]: https://drive.google.com/file/d/1rxLL6eqi3d9FXKq7e4Wx7jiisu7_gzJa/view?usp=sharing
133 | [4]: https://drive.google.com/file/d/1ejP_V3CuJFB_jaUYf-OM9up5bsnnETP3/view?usp=sharing
134 | [5]: https://drive.google.com/file/d/1JbHWs0yTExL7Lc_abCvaXX49IsazUVvw/view?usp=sharing
135 | [6]: https://drive.google.com/file/d/1Z3r20wd2Mxco47WWggmNKazonfYnUDy1/view?usp=sharing
136 | [7]: https://drive.google.com/file/d/1USUA5D9bN5Ar9rClfdhOUHiTYdX1di1P/view?usp=sharing
137 | [8]: https://drive.google.com/file/d/1jZLDt14ZRmfHEqc8Yat7beQA6n-N6-h7/view?usp=sharing
138 | [9]: https://drive.google.com/file/d/1bXGlOKWrqUlEOer2cRNIJ2654_H_2UeR/view?usp=sharing
139 | [10]: https://drive.google.com/file/d/16UswbSjfhHBBUih-cGCZgvurNLq-gOKx/view?usp=sharing
140 | [11]: https://drive.google.com/file/d/1KdsLDW3oE-xNtrzsoYKZv9N_hauOR1Of/view?usp=sharing
141 | [12]: https://drive.google.com/file/d/1mX0oxSQXDS2h2Fxz091q6NdKuHihr0Fj/view?usp=sharing
142 | [13]: https://drive.google.com/file/d/1wQERtue5TY3zEZJwX0u2t19ARhU4mhtY/view?usp=sharing
143 | [14]: https://drive.google.com/file/d/1XtQNShBMbW3jNwuZPMYP9p-5QbpgtHF6/view?usp=sharing
144 | [15]: https://drive.google.com/file/d/1efxWKIGxvmEV5nR9iJosTMOvpMzHwjlG/view?usp=sharing
145 | [16]: https://drive.google.com/file/d/143miiDN3m9-QqptxtA6U6BJfSP8XOcoW/view?usp=sharing
146 | [17]: https://drive.google.com/file/d/14DUT3_yazEFYqZRjzWgrZ6K3lYu0XfHm/view?usp=sharing
147 |
--------------------------------------------------------------------------------
/adaptive_detection_features_converter.py:
--------------------------------------------------------------------------------
1 | """
2 | This code is modified from Jin-Hwa Kim's repository.
3 | https://github.com/jnhwkim/ban-vqa
4 | Reads in a tsv file with pre-trained bottom up attention features
5 | of the adaptive number of boxes and stores it in HDF5 format.
6 | Also stores {image_id: feature_idx} as a pickle file.
7 | Hierarchy of HDF5 file:
8 | { 'image_features': num_boxes x 2048
9 | 'image_bb': num_boxes x 4
10 | 'spatial_features': num_boxes x 6
11 | 'pos_boxes': num_images x 2 }
12 | """
13 |
14 | from __future__ import print_function
15 |
16 | import os
17 | import sys
18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19 |
20 | import base64
21 | import csv
22 | import h5py
23 | import pickle as cPickle
24 | import numpy as np
25 | # from utils import utils
26 |
27 |
28 | def load_image_ids(base_dir):
29 | ''' Load the set of image ids ("gif_name/frame_id"). Modify this to suit your data locations. '''
30 | ids = set()
31 | for gif_dir in os.listdir(base_dir):
32 | img_dir = os.path.join(base_dir, gif_dir)
33 | for img_name in os.listdir(img_dir):
34 | gif_name = gif_dir.split('.')[0]
35 | image_id = img_name.split('.')[0]
36 | ids.add(gif_name + '/' + image_id)
37 | return ids
38 |
39 | csv.field_size_limit(sys.maxsize)
40 |
41 | def extract(split, infiles):
42 | FIELDNAMES = ['image_id', 'image_w', 'image_h', 'num_boxes', 'boxes', 'features']
43 | # feature data name to write
44 | data_file = {
45 | 'all': './tgif_all_btmup_f.hdf5'
46 | }
47 | indices_file = {
48 | 'all': './tgif_all_imgid2idx.pkl'
49 | }
50 | ids_file = {
51 | 'all': './tgif_all_ids.pkl'
52 | }
53 | path_imgs = {
54 | 'all': '/root/data/TGIFQA/frames',
55 | }
56 | known_num_boxes = {
57 | 'all': None
58 | }
59 | feature_length = 2048
60 | min_fixed_boxes = 10
61 | max_fixed_boxes = 100
62 |
63 | if os.path.exists(ids_file[split]):
64 | imgids = cPickle.load(open(ids_file[split], 'rb'))
65 | else:
66 | imgids = load_image_ids(path_imgs[split])
67 | cPickle.dump(imgids, open(ids_file[split], 'wb'))
68 |
69 | h = h5py.File(data_file[split], 'w')
70 |
71 | if known_num_boxes[split] is None:
72 | num_boxes = 0
73 | for infile in infiles:
74 | print("reading tsv...%s" % infile)
75 | with open(infile, "r+") as tsv_in_file:
76 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
77 | for item in reader:
78 | item['num_boxes'] = int(item['num_boxes'])
79 | image_id = str(item['image_id'])
80 | if image_id in imgids:
81 | num_boxes += item['num_boxes']
82 | else:
83 | num_boxes = known_num_boxes[split]
84 |
85 | print('num_boxes=%d' % num_boxes)
86 |
87 | img_features = h.create_dataset(
88 | 'image_features', (num_boxes, feature_length), 'f')
89 | img_bb = h.create_dataset(
90 | 'image_bb', (num_boxes, 4), 'f')
91 | spatial_img_features = h.create_dataset(
92 | 'spatial_features', (num_boxes, 6), 'f')
93 | pos_boxes = h.create_dataset(
94 | 'pos_boxes', (len(imgids), 2), dtype='int32')
95 |
96 | counter = 0
97 | num_boxes = 0
98 | indices = {}
99 |
100 | for infile in infiles:
101 | unknown_ids = []
102 | print("reading tsv...%s" % infile)
103 | with open(infile, "r+") as tsv_in_file:
104 | reader = csv.DictReader(tsv_in_file, delimiter='\t', fieldnames=FIELDNAMES)
105 | for item in reader:
106 | item['num_boxes'] = int(item['num_boxes'])
107 | item['boxes'] = item['boxes'].encode('utf-8')
108 | item['features'] = item['features'].encode('utf-8')
109 | image_id = str(item['image_id'])
110 | image_w = float(item['image_w'])
111 | image_h = float(item['image_h'])
112 | bboxes = np.frombuffer(
113 | base64.decodebytes(item['boxes']),
114 | dtype=np.float32).reshape((item['num_boxes'], -1))
115 |
116 | box_width = bboxes[:, 2] - bboxes[:, 0]
117 | box_height = bboxes[:, 3] - bboxes[:, 1]
118 | scaled_width = box_width / image_w
119 | scaled_height = box_height / image_h
120 | scaled_x = bboxes[:, 0] / image_w
121 | scaled_y = bboxes[:, 1] / image_h
122 |
123 | box_width = box_width[..., np.newaxis]
124 | box_height = box_height[..., np.newaxis]
125 | scaled_width = scaled_width[..., np.newaxis]
126 | scaled_height = scaled_height[..., np.newaxis]
127 | scaled_x = scaled_x[..., np.newaxis]
128 | scaled_y = scaled_y[..., np.newaxis]
129 |
130 | spatial_features = np.concatenate(  # 6-dim spatial feature: [x1, y1, x2, y2, w, h], normalized by image size
131 | (scaled_x,
132 | scaled_y,
133 | scaled_x + scaled_width,
134 | scaled_y + scaled_height,
135 | scaled_width,
136 | scaled_height),
137 | axis=1)
138 |
139 | if image_id in imgids:
140 | imgids.remove(image_id)
141 | indices[image_id] = counter
142 | pos_boxes[counter,:] = np.array([num_boxes, num_boxes + item['num_boxes']])
143 | img_bb[num_boxes:num_boxes+item['num_boxes'], :] = bboxes
144 | img_features[num_boxes:num_boxes+item['num_boxes'], :] = np.frombuffer(
145 | base64.decodebytes(item['features']),
146 | dtype=np.float32).reshape((item['num_boxes'], -1))
147 | spatial_img_features[num_boxes:num_boxes+item['num_boxes'], :] = spatial_features
148 | counter += 1
149 | num_boxes += item['num_boxes']
150 | else:
151 | unknown_ids.append(image_id)
152 |
153 | print('%d unknown_ids...' % len(unknown_ids))
154 | print('%d image_ids left...' % len(imgids))
155 |
156 | if len(imgids) != 0:
157 | print('Warning: %s_image_ids is not empty' % split)
158 |
159 | cPickle.dump(indices, open(indices_file[split], 'wb'))
160 | h.close()
161 | print("done!")
162 |
163 | if __name__ == '__main__':
164 | infile = ['./tgifqa_1.tsv.0', './tgifqa_1.tsv.1']
165 | extract('all', infile)
166 |
--------------------------------------------------------------------------------
/data_utils/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data_utils/data_util.py:
--------------------------------------------------------------------------------
1 | #-*- coding: utf-8 -*-
2 |
3 | # --------------------------------------------------------
4 | # This code is modified from Jumpin2's repository.
5 | # https://github.com/Jumpin2/HGA
6 | # --------------------------------------------------------
7 |
8 | import numpy as np
9 | import re
10 |
11 |
12 | def clean_str(string, downcase=True):
13 | """
14 | Tokenization/string cleaning for strings.
15 | Taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
16 | """
17 | string = re.sub(r"[^A-Za-z0-9(),!?\'\`(_____)]", " ", string)
18 | string = re.sub(r"\'s", " \'s", string)
19 | string = re.sub(r"\'ve", " \'ve", string)
20 | string = re.sub(r"n\'t", " n\'t", string)
21 | string = re.sub(r"\'re", " \'re", string)
22 | string = re.sub(r"\'d", " \'d", string)
23 | string = re.sub(r"\'ll", " \'ll", string)
24 | string = re.sub(r",", " , ", string)
25 | string = re.sub(r"!", " ! ", string)
26 | string = re.sub(r"\(", " \( ", string)
27 | string = re.sub(r"\)", " \) ", string)
28 | string = re.sub(r"\?", " \? ", string)
29 | string = re.sub(r"\s{2,}", " ", string)
30 | return string.strip().lower() if downcase else string.strip()
31 |
32 |
33 | def recover_word(string):
34 | string = re.sub(r" \'s", "\'s", string)
35 | string = re.sub(r" ,", ",", string)
36 | return string
37 |
38 |
39 | def clean_blank(blank_sent):
40 | clean_sent = clean_str(blank_sent).split()
41 | return ['' if x == '_____' else x for x in clean_sent]
42 |
43 |
44 | def clean_root(string):
45 | """
46 | Remove unexpected character in root.
47 | """
48 | return string
49 |
50 |
51 | def pad_sequences(
52 | sequences, pad_token="[PAD]", pad_location="LEFT", max_length=None):
53 | """
54 | Pads all sequences to the same length. The length is defined by the longest sequence.
55 | Returns padded sequences.
56 | """
57 | if not max_length:
58 | max_length = max(len(x) for x in sequences)
59 |
60 | result = []
61 | for i in range(len(sequences)):
62 | sentence = sequences[i]
63 | num_padding = max_length - len(sentence)
64 | if num_padding == 0:
65 | new_sentence = sentence
66 | elif num_padding < 0:
67 | new_sentence = sentence[:num_padding]
68 | elif pad_location == "RIGHT":
69 | new_sentence = sentence + [pad_token] * num_padding
70 | elif pad_location == "LEFT":
71 | new_sentence = [pad_token] * num_padding + sentence
72 | else:
73 | raise ValueError("Invalid pad_location. Specify LEFT or RIGHT.")
74 | result.append(new_sentence)
75 | return result
76 |
77 |
78 | def convert_sent_to_index(sentence, word_to_index):
79 | """
80 | Convert sentence consisting of string to indexed sentence.
81 | """
82 | return [
83 | word_to_index[word] if word in list(word_to_index.keys()) else 0
84 | for word in sentence
85 | ]
86 |
87 |
88 | def batch_iter(data, batch_size, seed=None, fill=True):
89 | """
90 | Generates a batch iterator for a dataset.
91 | """
92 | random = np.random.RandomState(seed)
93 | data_length = len(data)
94 | num_batches = int(data_length / batch_size)
95 | if data_length % batch_size != 0:
96 | num_batches += 1
97 | # Shuffle the data at each epoch
98 | shuffle_indices = random.permutation(np.arange(data_length))
99 | for batch_num in range(num_batches):
100 | start_index = batch_num * batch_size
101 | end_index = min((batch_num + 1) * batch_size, data_length)
102 | selected_indices = shuffle_indices[start_index:end_index]
103 | # If we don't have enough data left for a whole batch, fill it
104 | # randomly
105 | if fill is True and end_index >= data_length:
106 | num_missing = batch_size - len(selected_indices)
107 | selected_indices = np.concatenate(
108 | [selected_indices,
109 | random.randint(0, data_length, num_missing)])
110 | yield [data[i] for i in selected_indices]
111 |
112 |
113 | def fsr_iter(fsr_data, batch_size, random_seed=42, fill=True):
114 | """
115 | fsr_data : one of LSMDCData.build_data(), [[video_features], [sentences], [roots]]
116 | return per iter : [[feature]*batch_size, [sentences]*batch_size, [roots]*batch]
117 |
118 | Usage:
119 | train_data, val_data, test_data = LSMDCData.build_data()
120 | for features, sentences, roots in fsr_iter(train_data, 20, 10):
121 | feed_dict = {model.video_feature : features,
122 | model.sentences : sentences,
123 | model.roots : roots}
124 | """
125 |
126 | train_iter = batch_iter(
127 | list(zip(*fsr_data)), batch_size, fill=fill, seed=random_seed)
128 | return [list(zip(*batch)) for batch in train_iter]
129 |
130 |
131 | def preprocess_sents(descriptions, word_to_index, max_length):
132 |
133 | descriptions = [clean_str(sent).split() for sent in descriptions]
134 | descriptions = pad_sequences(descriptions, max_length=max_length)
135 | # Convert each sentence from a list of word strings to a list of int indices.
136 | descriptions = [
137 | convert_sent_to_index(sent, word_to_index) for sent in descriptions
138 | ]
139 |
140 | return descriptions
141 | # remove punctuation mark and special chars from root.
142 |
143 |
144 | def preprocess_roots(roots, word_to_index):
145 |
146 | roots = [clean_root(root) for root in roots]
147 | # convert string to int index.
148 | roots = [
149 | word_to_index[root] if root in list(word_to_index.keys()) else 0
150 | for root in roots
151 | ]
152 |
153 | return roots
154 |
155 |
156 | def pad_video(video_feature, dimension):
157 | '''
158 | Pad the video feature so that every video has the same length.
159 | Padding is added on the left:
160 | video = [pad,..., pad, frm1, frm2, ..., frmN]
161 | '''
162 | padded_feature = np.zeros(dimension)
163 | max_length = dimension[0]
164 | current_length = video_feature.shape[0]
165 | num_padding = max_length - current_length
166 | if num_padding == 0:
167 | padded_feature = video_feature
168 | elif num_padding < 0:
169 | steps = np.linspace(
170 | 0, current_length, num=max_length, endpoint=False, dtype=np.int32)
171 | padded_feature = video_feature[steps]
172 | else:
173 | padded_feature[num_padding:] = video_feature
174 |
175 | return padded_feature
176 |
177 |
178 | def video_3d_pad(sequence, obj_max_num, max_length):
179 | ''' Pad sequence with 0.
180 | '''
181 | sequence = np.array(sequence)
182 | sequence_shape = np.array(sequence).shape
183 | current_length = sequence_shape[0]
184 | current_num = sequence_shape[1]
185 | pad = np.zeros((max_length, obj_max_num, sequence_shape[2]),dtype=np.float32)
186 | num_padding = max_length - current_length
187 | num_obj_padding = obj_max_num - current_num
188 | if num_padding <= 0:
189 | pad = sequence[:max_length]
190 | else:
191 | pad[:current_length, :current_num] = sequence
192 | return pad
193 |
194 |
195 | def fill_mask(max_length, current_length, zero_location='LEFT'):
196 | num_padding = max_length - current_length
197 | if num_padding <= 0:
198 | mask = np.ones(max_length)
199 | elif zero_location == 'LEFT':
200 | mask = np.ones(max_length)
201 | for i in range(num_padding):
202 | mask[i] = 0
203 | elif zero_location == 'RIGHT':
204 | mask = np.zeros(max_length)
205 | for i in range(current_length):
206 | mask[i] = 1
207 |
208 | return mask
209 |
210 |
211 | def question_pad(sequence, max_length):
212 | ''' Pad sequence with 0.
213 | '''
214 | sequence = np.array(sequence)
215 | sequence_shape = np.array(sequence).shape
216 | current_length = sequence_shape[0]
217 | pad = np.zeros(max_length, dtype=np.int64)
218 | num_padding = max_length - current_length
219 | if num_padding <= 0:
220 | pad = sequence[:max_length]
221 | else:
222 | pad[:current_length] = sequence
223 | return pad
224 |
225 |
226 | def video_pad(sequence, max_length):
227 | ''' Pad sequence with 0.
228 | '''
229 | sequence = np.array(sequence)
230 | sequence_shape = np.array(sequence).shape
231 | current_length = sequence_shape[0]
232 | pad = np.zeros((max_length, sequence_shape[1]),dtype=np.float32)
233 | num_padding = max_length - current_length
234 | if num_padding <= 0:
235 | pad = sequence[:max_length]
236 | else:
237 | pad[:current_length] = sequence
238 | return pad
239 |
240 |
241 | def dict_keys_bytes2string(dic):
242 | for k in list(dic.keys()):
243 | kk = k.decode('utf-8')
244 | v = dic.pop(k)
245 | dic[kk] = v
246 | return dic
--------------------------------------------------------------------------------
/data_utils/dataset.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | """Generate data batch."""
7 | import os
8 |
9 | import pandas as pd
10 | import numpy as np
11 | import h5py
12 | from torch.utils.data import Dataset
13 | from . import data_util
14 |
15 | from util import log
16 |
17 | import os.path
18 | import sys
19 |
20 | import pickle as pkl
21 | import nltk
22 |
23 |
24 | nltk.download('punkt')
25 | # nltk.download('averaged_perceptron_tagger')
26 |
27 |
28 | __PATH__ = os.path.abspath(os.path.dirname(__file__))
29 |
30 |
31 | def assert_exists(path):
32 | assert os.path.exists(path), 'Does not exist : {}'.format(path)
33 |
34 |
35 | # PATHS
36 | TYPE_TO_CSV = {
37 | 'FrameQA': 'Train_frameqa_question.csv',
38 | 'Count': 'Train_count_question.csv',
39 | 'Trans': 'Train_transition_question.csv',
40 | 'Action': 'Train_action_question.csv'
41 | }
42 |
43 | eos_word = ''
44 |
45 |
46 | class TGIFQA(Dataset):
47 |
48 | def __init__(
49 | self,
50 | dataset_name='train',
51 | q_max_length=20,
52 | v_max_length=80,
53 | max_n_videos=None,
54 | data_type='FrameQA',
55 | csv_dir=None,
56 | vocab_dir=None,
57 | feat_dir=None):
58 | self.csv_dir = csv_dir
59 | self.vocabulary_dir = vocab_dir
60 | self.feat_dir = feat_dir
61 | self.dataset_name = dataset_name
62 | self.q_max_length = q_max_length
63 | self.v_max_length = v_max_length
64 | self.spa_q_max_length = 10
65 | self.tmp_q_max_length = 7
66 | self.obj_max_num = 10
67 | self.max_n_videos = max_n_videos
68 | self.data_type = data_type
69 | self.GLOVE_EMBEDDING_SIZE = 300
70 |
71 | self.data_df = self.read_df_from_csvfile()
72 |
73 | if max_n_videos is not None:
74 | self.data_df = self.data_df[:max_n_videos]
75 |
76 | self.res_avg_file = os.path.join(self.feat_dir, "feat_r6_vsync2/res152_avgpool.hdf5")
77 | self.i3d_avg_file = os.path.join(self.feat_dir, "feat_r6_vsync2/tgif_i3d_hw7_perclip_avgpool.hdf5")
78 |
79 | self.res_avg_feat = None
80 | self.i3d_avg_feat = None
81 |
82 | self.res_roi_file = os.path.join(self.feat_dir, "feat_r6_vsync2/tgif_btup_f_obj10.hdf5")
83 | self.i3d_roi_file = os.path.join(self.feat_dir, "feat_r6_vsync2/tgif_i3d_roialign_hw7_perclip_avgpool_obj10.hdf5")
84 |
85 | self.res_roi_feat = None
86 | self.i3d_roi_feat = None
87 |
88 | self.load_word_vocabulary()
89 | if self.data_type == 'FrameQA':
90 | self.get_FrameQA_result()
91 | elif self.data_type == 'Count':
92 | self.get_Count_result()
93 | elif self.data_type == 'Trans':
94 | self.get_Trans_result()
95 | elif self.data_type == 'Action':
96 | self.get_Trans_result()
97 |
98 | def __getitem__(self, index):
99 | if self.data_type == 'FrameQA':
100 | return self.getitem_frameqa(index)
101 | elif self.data_type == 'Count':
102 | return self.getitem_count(index)
103 | elif self.data_type == 'Trans':
104 | return self.getitem_trans(index)
105 | elif self.data_type == 'Action':
106 | return self.getitem_trans(index)
107 |
108 | def __len__(self):
109 | if self.max_n_videos is not None:
110 | if self.max_n_videos <= len(self.data_df):
111 | return self.max_n_videos
112 | return len(self.data_df)
113 |
114 | def read_df_from_csvfile(self):
115 | assert self.data_type in [
116 | 'FrameQA', 'Count', 'Trans', 'Action'
117 | ], 'data_type should be one of FrameQA, Count, Trans, Action'
118 |
119 | if self.data_type == 'FrameQA':
120 | train_data_path = os.path.join(
121 | self.csv_dir, 'Train_frameqa_question.csv')
122 | test_data_path = os.path.join(
123 | self.csv_dir, 'Test_frameqa_question.csv')
124 | #self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_frameqa_question.csv'), sep='\t')
125 | self.total_q = pd.read_csv(
126 | os.path.join(self.csv_dir, 'Total_frameqa_question.csv'),
127 | sep='\t')
128 | elif self.data_type == 'Count':
129 | train_data_path = os.path.join(
130 | self.csv_dir, 'Train_count_question.csv')
131 | test_data_path = os.path.join(
132 | self.csv_dir, 'Test_count_question.csv')
133 | #self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_count_question.csv'), sep='\t')
134 | self.total_q = pd.read_csv(
135 | os.path.join(self.csv_dir, 'Total_count_question.csv'),
136 | sep='\t')
137 | elif self.data_type == 'Trans':
138 | train_data_path = os.path.join(
139 | self.csv_dir, 'Train_transition_question.csv')
140 | test_data_path = os.path.join(
141 | self.csv_dir, 'Test_transition_question.csv')
142 | #self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_transition_question.csv'), sep='\t')
143 | self.total_q = pd.read_csv(
144 | os.path.join(self.csv_dir, 'Total_transition_question.csv'),
145 | sep='\t')
146 | elif self.data_type == 'Action':
147 | train_data_path = os.path.join(
148 | self.csv_dir, 'Train_action_question.csv')
149 | test_data_path = os.path.join(
150 | self.csv_dir, 'Test_action_question.csv')
151 | # self.total_q = pd.DataFrame().from_csv(os.path.join(self.csv_dir,'Total_action_question.csv'), sep='\t')
152 | self.total_q = pd.read_csv(
153 | os.path.join(self.csv_dir, 'Total_action_question.csv'),
154 | sep='\t')
155 |
156 | assert_exists(train_data_path)
157 | assert_exists(test_data_path)
158 |
159 | if self.dataset_name == 'train':
160 | data_df = pd.read_csv(train_data_path, sep='\t')
161 | elif self.dataset_name == 'test':
162 | data_df = pd.read_csv(test_data_path, sep='\t')
163 |
164 | data_df = data_df.set_index('vid_id')
165 | data_df['row_index'] = list(
166 | range(1,
167 | len(data_df) + 1)) # assign csv row index
168 | return data_df
169 |
170 | @property
171 | def n_words(self):
172 | ''' The dictionary size. '''
173 | if not hasattr(self, 'word2idx'):
174 | raise Exception('Dictionary not built yet!')
175 | return len(self.word2idx)
176 |
177 | def __repr__(self):
178 | if hasattr(self, 'word2idx'):
179 | return '<TGIFQA dataset (%s) : %d examples, %d words>' % (
180 | self.dataset_name, len(self), len(self.word2idx))
181 | else:
182 | return '<TGIFQA dataset (%s) : %d examples>' % (
183 | self.dataset_name, len(self))
184 |
185 | def split_sentence_into_words(self, sentence, eos=True):
186 | '''
187 | Split the given sentence (str) and enumerate the words as strs.
188 | Each word is normalized, i.e. lower-cased, non-alphabet characters
189 | like period (.) or comma (,) are stripped.
190 | When tokenizing, I use ``data_util.clean_str``
191 | '''
192 | try:
193 | words = data_util.clean_str(sentence).split()
194 | except:
195 | print(sentence)
196 | sys.exit()
197 | if eos:
198 | words = words + [eos_word]
199 | for w in words:
200 | if not w:
201 | continue
202 | yield w
203 |
204 | def build_word_vocabulary(
205 | self,
206 | all_captions_source=None,
207 | word_count_threshold=0,
208 | ):
209 | '''
210 | borrowed this implementation from @karpathy's neuraltalk.
211 | '''
212 | log.infov('Building word vocabulary (%s) ...', self.dataset_name)
213 |
214 | if all_captions_source is None:
215 | all_captions_source = self.get_all_captions()
216 |
217 | # enumerate all sentences to build frequency table
218 | word_counts = {}
219 | nsents = 0
220 | nwords = 0
221 | for sentence in all_captions_source:
222 | nsents += 1
223 | for w in self.split_sentence_into_words(sentence):
224 | word_counts[w] = word_counts.get(w, 0) + 1
225 | nwords += 1
226 |
227 | import pickle as pkl
228 | vocab = [
229 | w for w in word_counts if word_counts[w] >= word_count_threshold
230 | ]
231 | print("[%d] : [%d] word from captions" % (len(vocab), nwords))
232 | log.info(
233 | "Filtered vocab words (threshold = %d), from %d to %d",
234 | word_count_threshold, len(word_counts), len(vocab))
235 |
236 | # build index and vocabularies
237 | self.word2idx = {}
238 | self.idx2word = {}
239 |
240 | self.idx2word[0] = '.'
241 | self.idx2word[1] = 'UNK'
242 | self.word2idx['#START#'] = 0
243 | self.word2idx['UNK'] = 1
244 | for idx, w in enumerate(vocab, start=2):
245 | self.word2idx[w] = idx
246 | self.idx2word[idx] = w
247 |
248 | pkl.dump(
249 | self.word2idx,
250 | open(
251 | os.path.join(
252 | self.vocabulary_dir,
253 | 'word_to_index_%s.pkl' % self.data_type), 'wb'))
254 | pkl.dump(
255 | self.idx2word,
256 | open(
257 | os.path.join(
258 | self.vocabulary_dir,
259 | 'index_to_word_%s.pkl' % self.data_type), 'wb'))
260 |
261 | word_counts['.'] = nsents
262 | bias_init_vector = np.array(
263 | [
264 | 1.0 * word_counts[w] if i > 1 else 0
265 | for i, w in self.idx2word.items()
266 | ])
267 | bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies
268 | bias_init_vector = np.log(bias_init_vector + 1e-20)
269 | bias_init_vector -= np.max(
270 | bias_init_vector) # shift to nice numeric range
271 | self.bias_init_vector = bias_init_vector
272 |
273 | #self.total_q = pd.DataFrame().from_csv(os.path.join(csv_dir,'Total_desc_question.csv'), sep='\t')
274 | answers = list(set(self.total_q['answer'].values))
275 | self.ans2idx = {}
276 | self.idx2ans = {}
277 | for idx, w in enumerate(answers):
278 | self.ans2idx[w] = idx
279 | self.idx2ans[idx] = w
280 | pkl.dump(
281 | self.ans2idx,
282 | open(
283 | os.path.join(
284 | self.vocabulary_dir,
285 | 'ans_to_index_%s.pkl' % self.data_type), 'wb'))
286 | pkl.dump(
287 | self.idx2ans,
288 | open(
289 | os.path.join(
290 | self.vocabulary_dir,
291 | 'index_to_ans_%s.pkl' % self.data_type), 'wb'))
292 |
293 | # Make glove embedding.
294 | #import spacy
295 | #nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors')
296 |
297 | with open(self.vocabulary_dir + '/glove.42B.300d.txt', 'rb') as f:
298 | lines = f.readlines()
299 | dict = {}
300 |
301 | for ix, line in enumerate(lines):
302 | if ix % 1000 == 0:
303 | print((ix, len(lines)))
304 | segs = line.split()
305 | wd = segs[0]
306 | dict[wd] = segs[1:]
307 |
308 | #max_length = len(vocab)
309 | max_length = len(self.idx2word)
310 | GLOVE_EMBEDDING_SIZE = 300
311 |
312 | glove_matrix = np.random.normal(size=[max_length, GLOVE_EMBEDDING_SIZE])
313 | not_in_vocab = 0
314 | for i in range(len(vocab)):
315 | if i % 1000 == 0:
316 | print((i, len(vocab)))
317 | w = bytes(vocab[i], encoding='utf-8')
318 | #w_embed = nlp(u'%s' % w).vector
319 | if w in dict:
320 | w_embed = np.array(dict[w])
321 | else:
322 | not_in_vocab += 1
323 | w_embed = np.random.normal(size=(GLOVE_EMBEDDING_SIZE))
324 | #glove_matrix[i,:] = w_embed
325 | glove_matrix[i + 2, :] = w_embed # two placeholder words '.','UNK'
326 |
327 | print("[%d/%d] word is not saved" % (not_in_vocab, len(vocab)))
328 |
329 | vocab = pkl.dump(
330 | glove_matrix,
331 | open(
332 | os.path.join(
333 | self.vocabulary_dir,
334 | 'vocab_embedding_%s.pkl' % self.data_type), 'wb'))
335 | self.word_matrix = glove_matrix
336 |
337 | def load_word_vocabulary(self):
338 |
339 | word_matrix_path = os.path.join(
340 | self.vocabulary_dir, 'vocab_embedding_%s.pkl' % self.data_type)
341 |
342 | word2idx_path = os.path.join(
343 | self.vocabulary_dir, 'word_to_index_%s.pkl' % self.data_type)
344 | idx2word_path = os.path.join(
345 | self.vocabulary_dir, 'index_to_word_%s.pkl' % self.data_type)
346 | ans2idx_path = os.path.join(
347 | self.vocabulary_dir, 'ans_to_index_%s.pkl' % self.data_type)
348 | idx2ans_path = os.path.join(
349 | self.vocabulary_dir, 'index_to_ans_%s.pkl' % self.data_type)
350 |
351 | if not (os.path.exists(word_matrix_path) and
352 | os.path.exists(word2idx_path) and
353 | os.path.exists(idx2word_path) and
354 | os.path.exists(ans2idx_path) and os.path.exists(idx2ans_path)):
355 | self.build_word_vocabulary()
356 |
357 | # ndarray
358 | with open(word_matrix_path, 'rb') as f:
359 | self.word_matrix = pkl.load(f)
360 | log.info("Load word_matrix from pkl file : %s", word_matrix_path)
361 |
362 | with open(word2idx_path, 'rb') as f:
363 | self.word2idx = pkl.load(f)
364 | log.info("Load word2idx from pkl file : %s", word2idx_path)
365 |
366 | with open(idx2word_path, 'rb') as f:
367 | self.idx2word = pkl.load(f)
368 | log.info("Load idx2word from pkl file : %s", idx2word_path)
369 |
370 | with open(ans2idx_path, 'rb') as f:
371 | self.ans2idx = pkl.load(f)
372 | log.info("Load answer2idx from pkl file : %s", ans2idx_path)
373 |
374 | with open(idx2ans_path, 'rb') as f:
375 | self.idx2ans = pkl.load(f)
376 | log.info("Load idx2answers from pkl file : %s", idx2ans_path)
377 |
378 | def share_word_vocabulary_from(self, dataset):
379 | assert hasattr(dataset, 'idx2word') and hasattr(
380 | dataset, 'word2idx'
381 | ), 'The dataset instance should have idx2word and word2idx'
382 | assert (
383 | isinstance(dataset.idx2word, dict) or
384 | isinstance(dataset.idx2word, list)
385 | ) and isinstance(
386 | dataset.word2idx, dict
387 | ), 'The dataset instance should have idx2word and word2idx (as dict)'
388 |
389 | if hasattr(self, 'word2idx'):
390 | log.warn(
391 | "Overriding %s' word vocabulary from %s ...", self, dataset)
392 |
393 | self.idx2word = dataset.idx2word
394 | self.word2idx = dataset.word2idx
395 | self.ans2idx = dataset.ans2idx
396 | self.idx2ans = dataset.idx2ans
397 | if hasattr(dataset, 'word_matrix'):
398 | self.word_matrix = dataset.word_matrix
399 |
400 | def get_all_captions(self):
401 | '''
402 | Iterate caption strings associated in the vid/gifs.
403 | '''
404 | #qa_data_df = pd.DataFrame().from_csv(os.path.join(self.csv_dir, TYPE_TO_CSV[self.data_type]), sep='\t')
405 | qa_data_df = pd.read_csv(
406 | os.path.join(self.csv_dir, TYPE_TO_CSV[self.data_type]), sep='\t')
407 |
408 | all_sents = []
409 | for row in qa_data_df.iterrows():
410 | all_sents.extend(self.get_captions(row))
411 | self.data_type
412 | return all_sents
413 |
414 | def get_captions(self, row):
415 | if self.data_type == 'FrameQA':
416 | columns = ['description', 'question', 'answer']
417 | elif self.data_type == 'Count':
418 | columns = ['question']
419 | elif self.data_type == 'Trans':
420 | columns = ['question', 'a1', 'a2', 'a3', 'a4', 'a5']
421 | elif self.data_type == 'Action':
422 | columns = ['question', 'a1', 'a2', 'a3', 'a4', 'a5']
423 |
424 | sents = [row[1][col] for col in columns if not pd.isnull(row[1][col])]
425 | return sents
426 |
427 | def get_video_feature(self, key): # key : gif_name
428 | if self.res_avg_feat is None:
429 | self.res_avg_feat = h5py.File(self.res_avg_file, 'r')
430 | if self.i3d_avg_feat is None:
431 | self.i3d_avg_feat = h5py.File(self.i3d_avg_file, 'r')
432 | if self.res_roi_feat is None:
433 | self.res_roi_feat = h5py.File(self.res_roi_file, 'r')
434 | if self.i3d_roi_feat is None:
435 | self.i3d_roi_feat = h5py.File(self.i3d_roi_file, 'r')
436 |
437 | video_id = str(key)
438 |
439 | try:
440 | res_avg_feat = np.array(self.res_avg_feat[video_id]) # T, d
441 | i3d_avg_feat = np.array(self.i3d_avg_feat[video_id]) # T, d
442 | res_roi_feat = np.array(self.res_roi_feat['image_features'][video_id]) # T, 5, d
443 | roi_bbox_feat = np.array(self.res_roi_feat['spatial_features'][video_id]) # T, 5, 6
444 | i3d_roi_feat = np.array(self.i3d_roi_feat[video_id]) # T, 5, d
445 | except KeyError: # no img
446 | print('no image', key)
447 | res_avg_feat = np.zeros((1, 2048))
448 | i3d_avg_feat = np.zeros((1, 2048))
449 | res_roi_feat = np.zeros((1, self.obj_max_num, 2048))
450 | roi_bbox_feat = np.zeros((1, self.obj_max_num, 6))
451 | i3d_roi_feat = np.zeros((1, self.obj_max_num, 2048))
452 |
453 | return res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat
454 |
455 | def convert_sentence_to_matrix(self, task, question, answer=None, eos=True):
456 | '''
457 | Convert the given sentence into word indices and masks.
458 | Note: unknown words (not in the vocabulary) are mapped to the UNK index (1).
459 |
460 | Args:
461 | sentence: A str for unnormalized sentence, containing T words
462 |
463 | Returns:
464 | sentence_word_indices : list of (at most) length T,
465 | each being a word index
466 | '''
467 | if answer is None:
468 | sentence = question
469 | else:
470 | sentence = question + ' ' + answer
471 |
472 | token = self.split_sentence_into_words(sentence, eos=False)
473 | token = [t for t in token]
474 |
475 | sent2indices = [
476 | self.word2idx[w] if w in self.word2idx else 1
477 | for w in token
478 | ]
479 | T = len(sent2indices)
480 | length = min(T, self.q_max_length)
481 |
482 | return sent2indices[:length]
483 |
484 | def get_question(self, key):
485 | '''
486 | Return question index for given key.
487 | '''
488 | question = self.data_df.loc[key, ['question', 'description']].values
489 | if len(list(question.shape)) > 1:
490 | question = question[0]
491 | # A question string
492 | question = question[0]
493 |
494 | return self.convert_sentence_to_matrix(self.data_type, question, eos=False)
495 |
496 | def get_answer(self, key):
497 | answer = self.data_df.loc[key, ['answer', 'type']].values
498 |
499 | if len(list(answer.shape)) > 1:
500 | answer = answer[0]
501 |
502 | anstype = answer[1]
503 | answer = answer[0]
504 |
505 | return answer, anstype
506 |
507 | def get_FrameQA_result(self):
508 | self.padded_all_ques = []
509 | self.answers = []
510 | self.answer_type = []
511 | self.all_ques_lengths = []
512 |
513 | self.keys = self.data_df['key'].values.astype(np.int64)
514 |
515 | for index, row in self.data_df.iterrows():
516 | # ====== Question ======
517 | all_ques = self.get_question(index)
518 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length)
519 | self.padded_all_ques.append(all_ques_pad)
520 |
521 | all_ques_length = min(self.q_max_length, len(all_ques))
522 | self.all_ques_lengths.append(all_ques_length)
523 |
524 | answer, answer_type = self.get_answer(index)
525 | if str(answer) in self.ans2idx:
526 | answer = self.ans2idx[answer]
527 | else:
528 | # unknown token, check later
529 | answer = 1
530 | self.answers.append(np.array(answer, dtype=np.int64))
531 | self.answer_type.append(np.array(answer_type, dtype=np.int64))
532 |
533 | def getitem_frameqa(self, index):
534 | key = self.keys[index]
535 |
536 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat = self.get_video_feature(key)
537 |
538 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32)
539 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32)
540 |
541 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num,
542 | self.v_max_length).astype(np.float32)
543 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num,
544 | self.v_max_length).astype(np.float32)
545 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num,
546 | self.v_max_length).astype(np.float32)
547 | video_length = min(self.v_max_length, res_avg_feat.shape[0])
548 |
549 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \
550 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index], self.answer_type[index]
551 |
552 | def get_Count_question(self, key):
553 | '''
554 | Return question string for given key.
555 | '''
556 | question = self.data_df.loc[key, 'question']
557 | return self.convert_sentence_to_matrix(self.data_type, question, eos=False)
558 |
559 | def get_Count_result(self):
560 | self.padded_all_ques = []
561 | self.answers = []
562 | self.all_ques_lengths = []
563 |
564 | self.keys = self.data_df['key'].values.astype(np.int64)
565 |
566 | for index, row in self.data_df.iterrows():
567 | # ====== Question ======
568 | all_ques = self.get_Count_question(index)
569 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length)
570 |
571 | self.padded_all_ques.append(all_ques_pad)
572 |
573 | all_ques_length = min(self.q_max_length, len(all_ques))
574 | self.all_ques_lengths.append(all_ques_length)
575 |
576 | # force answer not equal to 0
577 | answer = max(row['answer'], 1)
578 | self.answers.append(np.array(answer, dtype=np.float32))
579 |
580 | def getitem_count(self, index):
581 | key = self.keys[index]
582 |
583 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat = self.get_video_feature(key)
584 |
585 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32)
586 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32)
587 |
588 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num,
589 | self.v_max_length).astype(np.float32)
590 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num,
591 | self.v_max_length).astype(np.float32)
592 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num,
593 | self.v_max_length).astype(np.float32)
594 | video_length = min(self.v_max_length, res_avg_feat.shape[0])
595 |
596 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \
597 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index]
598 |
599 | def get_Trans_matrix(self, candidates, q_max_length, is_left=True):
600 | candidates_matrix = np.zeros([5, q_max_length], dtype=np.int64)
601 | for k in range(5):
602 | sentence = candidates[k]
603 | if is_left:
604 | candidates_matrix[k, :len(sentence)] = sentence
605 | else:
606 | candidates_matrix[k, -len(sentence):] = sentence
607 | return candidates_matrix
608 |
609 | def get_Trans_result(self):
610 | self.padded_candidates = []
611 | self.answers = self.data_df['answer'].values.astype(np.int64)
612 | self.row_index = self.data_df['row_index'].values.astype(np.int64)
613 | self.candidate_lengths = []
614 |
615 | self.keys = self.data_df['key'].values.astype(np.int64)
616 |
617 | for index, row in self.data_df.iterrows():
618 | a1 = row['a1'].strip()
619 | a2 = row['a2'].strip()
620 | a3 = row['a3'].strip()
621 | a4 = row['a4'].strip()
622 | a5 = row['a5'].strip()
623 | candidates = [a1, a2, a3, a4, a5]
624 | raw_question = row['question'].strip()
625 | indexed_candidates = []
626 |
627 | for x in candidates:
628 | all_cand = self.convert_sentence_to_matrix(self.data_type, raw_question, x, eos=False)
629 | indexed_candidates.append(all_cand)
630 |
631 | cand_lens = []
632 | for i in range(len(indexed_candidates)):
633 | vl = min(self.q_max_length, len(indexed_candidates[i]))
634 | cand_lens.append(vl)
635 | self.candidate_lengths.append(cand_lens)
636 |
637 | # (5, self.q_max_length)
638 | candidates_pad = self.get_Trans_matrix(indexed_candidates, self.q_max_length)
639 | self.padded_candidates.append(candidates_pad)
640 |
641 | def getitem_trans(self, index):
642 | key = self.keys[index]
643 |
644 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat = self.get_video_feature(key)
645 |
646 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32)
647 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32)
648 |
649 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num,
650 | self.v_max_length).astype(np.float32)
651 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num,
652 | self.v_max_length).astype(np.float32)
653 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num,
654 | self.v_max_length).astype(np.float32)
655 | video_length = min(self.v_max_length, res_avg_feat.shape[0])
656 |
657 | self.padded_candidates = np.asarray(self.padded_candidates).astype(
658 | np.int64)
659 | self.candidate_lengths = np.asarray(self.candidate_lengths).astype(
660 | np.int64)
661 |
662 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \
663 | self.padded_candidates[index], self.candidate_lengths[index], self.answers[index], self.row_index[index]
664 |
--------------------------------------------------------------------------------
/data_utils/dataset_msrvtt.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | """Generate data batch."""
7 | import os
8 | import os.path
9 | import sys
10 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
11 |
12 | import pandas as pd
13 | import numpy as np
14 | import h5py
15 | from torch.utils.data import Dataset
16 | from data_utils import data_util
17 |
18 | from util import log
19 |
20 | import pickle as pkl
21 | import nltk
22 |
23 |
24 | nltk.download('punkt')
25 | # nltk.download('averaged_perceptron_tagger')
26 |
27 |
28 | __PATH__ = os.path.abspath(os.path.dirname(__file__))
29 |
30 |
31 | def assert_exists(path):
32 | assert os.path.exists(path), 'Does not exist : {}'.format(path)
33 |
34 | eos_word = ''
35 |
36 |
37 | class MSRVTTQA(Dataset):
38 |
39 | def __init__(
40 | self,
41 | dataset_name='train',
42 | q_max_length=20,
43 | v_max_length=80,
44 | max_n_videos=None,
45 | csv_dir=None,
46 | vocab_dir=None,
47 | feat_dir=None):
48 | self.csv_dir = csv_dir
49 | self.vocabulary_dir = vocab_dir
50 | self.feat_dir = feat_dir
51 | self.dataset_name = dataset_name
52 | self.q_max_length = q_max_length
53 | self.v_max_length = v_max_length
54 | self.obj_max_num = 10
55 | self.max_n_videos = max_n_videos
56 | self.GLOVE_EMBEDDING_SIZE = 300
57 |
58 | self.data_df = self.read_df_from_json()
59 |
60 | if max_n_videos is not None:
61 | self.data_df = self.data_df[:max_n_videos]
62 |
63 | self.res_avg_file = os.path.join(self.feat_dir, "vfeat/msrvtt_res152_avgpool.hdf5")
64 | self.i3d_avg_file = os.path.join(self.feat_dir, "vfeat/msrvtt_i3d_avgpool_perclip.hdf5")
65 |
66 | self.res_avg_feat = None
67 | self.i3d_avg_feat = None
68 |
69 | self.res_roi_file = os.path.join(self.feat_dir, "vfeat/msrvtt_btup_f_obj10.hdf5")
70 | self.i3d_roi_file = os.path.join(self.feat_dir, "vfeat/msrvtt_i3d_roialign_perclip_obj10.hdf5")
71 |
72 | self.res_roi_feat = None
73 | self.i3d_roi_feat = None
74 |
75 | self.load_word_vocabulary()
76 | self.get_result()
77 |
78 | def __getitem__(self, index):
79 | return self.get_item(index)
80 |
81 | def __len__(self):
82 | if self.max_n_videos is not None:
83 | if self.max_n_videos <= len(self.data_df):
84 | return self.max_n_videos
85 | return len(self.data_df)
86 |
87 | @property
88 | def n_words(self):
89 | ''' The dictionary size. '''
90 | if not hasattr(self, 'word2idx'):
91 | raise Exception('Dictionary not built yet!')
92 | return len(self.word2idx)
93 |
94 | def __repr__(self):
95 | if hasattr(self, 'word2idx'):
96 | return '<MSRVTTQA dataset (%s) : %d examples, %d words>' % (
97 | self.dataset_name, len(self), len(self.word2idx))
98 | else:
99 | return '<MSRVTTQA dataset (%s) : %d examples>' % (
100 | self.dataset_name, len(self))
101 |
102 | def read_df_from_json(self):
103 | if self.dataset_name == 'train':
104 | data_path = '%s/train_qa.json'%self.csv_dir
105 | elif self.dataset_name == 'val':
106 | data_path = '%s/val_qa.json'%self.csv_dir
107 | elif self.dataset_name == 'test':
108 | data_path = '%s/test_qa.json'%self.csv_dir
109 |
110 | with open(data_path, 'r') as f:
111 | data_df = pd.read_json(f)
112 |
113 | return data_df
114 |
115 | def split_sentence_into_words(self, sentence, eos=True):
116 | '''
117 | Split the given sentence (str) and enumerate the words as strs.
118 | Each word is normalized, i.e. lower-cased, non-alphabet characters
119 | like period (.) or comma (,) are stripped.
120 | When tokenizing, I use ``data_util.clean_str``
121 | '''
122 | try:
123 | words = data_util.clean_str(sentence).split()
124 | except:
125 | print(sentence)
126 | sys.exit()
127 | if eos:
128 | words = words + [eos_word]
129 | for w in words:
130 | if not w:
131 | continue
132 | yield w
133 |
134 | def create_answerset(self, ans_df):
135 | """Generate 1000 answer set from train_qa.json.
136 | Args:
137 | trainqa_path: path to train_qa.json.
138 | answerset_path: generate answer set of mc_qa
139 | """
140 | ans_num = 4000
141 | answer_freq = ans_df['answer'].value_counts()
142 | answer_freq = list(answer_freq.iloc[0:ans_num].keys())
143 | return answer_freq
144 |
145 | def build_word_vocabulary(
146 | self,
147 | all_sen=None,
148 | ans_df=None,
149 | word_count_threshold=0,
150 | ):
151 | '''
152 | borrowed this implementation from @karpathy's neuraltalk.
153 | '''
154 | log.infov('Building word vocabulary (%s) ...', self.dataset_name)
155 |
156 | if all_sen is None or ans_df is None:
157 | all_sen, ans_df = self.get_all_captions()
158 | all_captions_source = all_sen
159 |
160 | # enumerate all sentences to build frequency table
161 | word_counts = {}
162 | nsents = 0
163 | nwords = 0
164 | for sentence in all_captions_source:
165 | nsents += 1
166 | for w in self.split_sentence_into_words(sentence):
167 | word_counts[w] = word_counts.get(w, 0) + 1
168 | nwords += 1
169 |
170 | vocab = [
171 | w for w in word_counts if word_counts[w] >= word_count_threshold
172 | ]
173 | print("Filtered vocab words (threshold = %d), from %d to %d" % (word_count_threshold, len(word_counts), len(vocab)))
174 | log.info(
175 | "Filtered vocab words (threshold = %d), from %d to %d",
176 | word_count_threshold, len(word_counts), len(vocab))
177 |
178 | # build index and vocabularies
179 | self.word2idx = {}
180 | self.idx2word = {}
181 |
182 | self.idx2word[0] = '.'
183 | self.idx2word[1] = 'UNK'
184 | self.word2idx['#START#'] = 0
185 | self.word2idx['UNK'] = 1
186 | for idx, w in enumerate(vocab, start=2):
187 | self.word2idx[w] = idx
188 | self.idx2word[idx] = w
189 |
190 | pkl.dump(
191 | self.word2idx,
192 | open(os.path.join(self.vocabulary_dir, 'word_to_index.pkl'), 'wb')
193 | )
194 | pkl.dump(
195 | self.idx2word,
196 | open(os.path.join(self.vocabulary_dir, 'index_to_word.pkl'), 'wb')
197 | )
198 |
199 | word_counts['.'] = nsents
200 | bias_init_vector = np.array(
201 | [
202 | 1.0 * word_counts[w] if i > 1 else 0
203 | for i, w in self.idx2word.items()
204 | ])
205 | bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies
206 | bias_init_vector = np.log(bias_init_vector + 1e-20)
207 | bias_init_vector -= np.max(
208 | bias_init_vector) # shift to nice numeric range
209 | self.bias_init_vector = bias_init_vector
210 |
211 | #self.total_q = pd.DataFrame().from_csv(os.path.join(csv_dir,'Total_desc_question.csv'), sep='\t')
212 | answers = self.create_answerset(ans_df)
213 | print("all answer num : %d"%len(answers))
214 | self.ans2idx = {}
215 | self.idx2ans = {}
216 | for idx, w in enumerate(answers):
217 | self.ans2idx[w] = idx
218 | self.idx2ans[idx] = w
219 | pkl.dump(
220 | self.ans2idx,
221 | open(os.path.join(self.vocabulary_dir, 'ans_to_index.pkl'), 'wb')
222 | )
223 | pkl.dump(
224 | self.idx2ans,
225 | open(os.path.join(self.vocabulary_dir, 'index_to_ans.pkl'), 'wb')
226 | )
227 |
228 | # Make glove embedding.
229 | #import spacy
230 | #nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors')
231 |
232 | with open(self.vocabulary_dir + '/glove.42B.300d.txt', 'rb') as f:
233 | lines = f.readlines()
234 | dict = {}
235 |
236 | for ix, line in enumerate(lines):
237 | if ix % 1000 == 0:
238 | print((ix, len(lines)))
239 | segs = line.split()
240 | wd = segs[0]
241 | dict[wd] = segs[1:]
242 |
243 | #max_length = len(vocab)
244 | max_length = len(self.idx2word)
245 | GLOVE_EMBEDDING_SIZE = 300
246 |
247 | glove_matrix = np.random.normal(size=[max_length, GLOVE_EMBEDDING_SIZE])
248 | not_in_vocab = 0
249 | for i in range(len(vocab)):
250 | if i % 1000 == 0:
251 | print((i, len(vocab)))
252 | w = bytes(vocab[i], encoding='utf-8')
253 | #w_embed = nlp(u'%s' % w).vector
254 | if w in dict:
255 | w_embed = np.array(dict[w])
256 | else:
257 | not_in_vocab += 1
258 | w_embed = np.random.normal(size=(GLOVE_EMBEDDING_SIZE))
259 | #glove_matrix[i,:] = w_embed
260 | glove_matrix[i + 2, :] = w_embed # two placeholder words '.','UNK'
261 |
262 | print("[%d/%d] word is not saved" % (not_in_vocab, len(vocab)))
263 |
264 | vocab = pkl.dump(
265 | glove_matrix,
266 | open(os.path.join(self.vocabulary_dir, 'vocab_embedding.pkl'), 'wb')
267 | )
268 | self.word_matrix = glove_matrix
269 |
270 | def load_word_vocabulary(self):
271 |
272 | word_matrix_path = os.path.join(
273 | self.vocabulary_dir, 'vocab_embedding.pkl')
274 |
275 | word2idx_path = os.path.join(
276 | self.vocabulary_dir, 'word_to_index.pkl')
277 | idx2word_path = os.path.join(
278 | self.vocabulary_dir, 'index_to_word.pkl')
279 | ans2idx_path = os.path.join(
280 | self.vocabulary_dir, 'ans_to_index.pkl')
281 | idx2ans_path = os.path.join(
282 | self.vocabulary_dir, 'index_to_ans.pkl')
283 |
284 | if not (os.path.exists(word_matrix_path) and
285 | os.path.exists(word2idx_path) and
286 | os.path.exists(idx2word_path) and
287 | os.path.exists(ans2idx_path) and os.path.exists(idx2ans_path)):
288 | self.build_word_vocabulary()
289 |
290 | # ndarray
291 | with open(word_matrix_path, 'rb') as f:
292 | self.word_matrix = pkl.load(f)
293 | log.info("Load word_matrix from pkl file : %s", word_matrix_path)
294 |
295 | with open(word2idx_path, 'rb') as f:
296 | self.word2idx = pkl.load(f)
297 | log.info("Load word2idx from pkl file : %s", word2idx_path)
298 |
299 | with open(idx2word_path, 'rb') as f:
300 | self.idx2word = pkl.load(f)
301 | log.info("Load idx2word from pkl file : %s", idx2word_path)
302 |
303 | with open(ans2idx_path, 'rb') as f:
304 | self.ans2idx = pkl.load(f)
305 | log.info("Load answer2idx from pkl file : %s", ans2idx_path)
306 |
307 | with open(idx2ans_path, 'rb') as f:
308 | self.idx2ans = pkl.load(f)
309 | log.info("Load idx2answers from pkl file : %s", idx2ans_path)
310 |
311 | def get_all_captions(self):
312 | '''
313 | Collect the question strings associated with the videos (from train_qa.json).
314 | '''
315 | data_path = '%s/train_qa.json'%self.csv_dir
316 | with open(data_path, 'r') as f:
317 | data_df = pd.read_json(f)
318 |
319 | all_sents = list(data_df['question'])
320 | return all_sents, data_df
321 |
322 | def get_video_feature(self, key): # key : gif_name
323 | if self.res_avg_feat is None:
324 | self.res_avg_feat = h5py.File(self.res_avg_file, 'r')
325 | if self.i3d_avg_feat is None:
326 | self.i3d_avg_feat = h5py.File(self.i3d_avg_file, 'r')
327 | if self.res_roi_feat is None:
328 | self.res_roi_feat = h5py.File(self.res_roi_file, 'r')
329 | if self.i3d_roi_feat is None:
330 | self.i3d_roi_feat = h5py.File(self.i3d_roi_file, 'r')
331 |
332 | video_id = 'video' + str(key)
333 |
334 | try:
335 | res_avg_feat = np.array(self.res_avg_feat[video_id]) # T, d
336 | i3d_avg_feat = np.array(self.i3d_avg_feat[video_id]) # T, d
337 | res_roi_feat = np.array(self.res_roi_feat['image_features'][video_id]) # T, obj_max_num, d
338 | roi_bbox_feat = np.array(self.res_roi_feat['spatial_features'][video_id]) # T, obj_max_num, 6
339 | i3d_roi_feat = np.array(self.i3d_roi_feat[video_id]) # T, obj_max_num, d
340 | except KeyError: # no img
341 | print('no image', key)
342 | res_avg_feat = np.zeros((1, 2048))
343 | i3d_avg_feat = np.zeros((1, 2048))
344 | res_roi_feat = np.zeros((1, self.obj_max_num, 2048))
345 | roi_bbox_feat = np.zeros((1, self.obj_max_num, 6))
346 | i3d_roi_feat = np.zeros((1, self.obj_max_num, 2048))
347 |
348 | return res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat
349 |
350 | def convert_sentence_to_matrix(self, sentence, eos=True):
351 | '''
352 | Convert the given sentence (sequence of words) into word indices.
353 | Unknown words (not in vocabulary) are mapped to the UNK index (1).
354 |
355 | Args:
356 | sentence: A sequence of T word tokens (already split and normalized)
357 |
358 | Returns:
359 | sentence_word_indices : list of (at most) length T,
360 | each being a word index
361 | '''
362 | sent2indices = [
363 | self.word2idx[w] if w in self.word2idx else 1
364 | for w in sentence
365 | ] # 1 is UNK, unknown
366 | T = len(sent2indices)
367 | length = min(T, self.q_max_length)
368 | return sent2indices[:length]
369 |
370 | def get_question(self, key):
371 | '''
372 | Return question index for given key.
373 | '''
374 | question = self.data_df.loc[key, ['question']].values
375 | question = question[0]
376 |
377 | question = self.split_sentence_into_words(question, eos=False)
378 | q_refine = []
379 | for w in question:
380 | q_refine.append(w)
381 | q_refine = self.convert_sentence_to_matrix(q_refine)
382 | return q_refine
383 |
384 | def get_answer(self, key):
385 | answer = self.data_df.loc[key, ['answer']].values
386 | answer = answer[0]
387 |
388 | return answer
389 |
390 | def get_result(self):
391 | self.padded_all_ques = []
392 | self.answers = []
393 |
394 | self.all_ques_lengths = []
395 | self.keys = self.data_df['video_id'].values.astype(np.int64)
396 |
397 | for index, row in self.data_df.iterrows():
398 | # ====== Question ======
399 | all_ques = self.get_question(index)
400 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length)
401 | self.padded_all_ques.append(all_ques_pad)
402 | all_ques_length = min(self.q_max_length, len(all_ques))
403 | self.all_ques_lengths.append(all_ques_length)
404 |
405 | answer = self.get_answer(index)
406 | if str(answer) in self.ans2idx:
407 | answer = self.ans2idx[answer]
408 | else:
409 | # unknown token, check later
410 | answer = 1
411 | self.answers.append(np.array(answer, dtype=np.int64))
412 |
413 | def get_item(self, index):
414 | key = self.keys[index]
415 |
416 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat \
417 | = self.get_video_feature(key)
418 |
419 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32)
420 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32)
421 |
422 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num,
423 | self.v_max_length).astype(np.float32)
424 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num,
425 | self.v_max_length).astype(np.float32)
426 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num,
427 | self.v_max_length).astype(np.float32)
428 | video_length = min(self.v_max_length, res_avg_feat.shape[0])
429 |
430 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \
431 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index]
432 |
--------------------------------------------------------------------------------
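
Note on the padding helpers used above: `data_util.question_pad`, `data_util.video_pad`, and `data_util.video_3d_pad` are defined in `data_utils/data_util.py`. The following is only a minimal sketch of zero-padding utilities with the shapes implied by `get_result`/`get_item` (question indices padded to `q_max_length`, frame features to `(v_max_length, d)`, object features to `(v_max_length, obj_max_num, d)`); the repository's own implementations are authoritative and may differ in detail.

    import numpy as np

    # Sketch only: call signatures and shapes are inferred from dataset_msrvtt.py / dataset_msvd.py.
    def question_pad(indices, q_max_length):
        out = np.zeros(q_max_length, dtype=np.int64)    # zero-pad / truncate word indices
        n = min(len(indices), q_max_length)
        out[:n] = indices[:n]
        return out

    def video_pad(feat, v_max_length):
        out = np.zeros((v_max_length, feat.shape[1]), dtype=feat.dtype)  # (T, d) -> (v_max_length, d)
        t = min(feat.shape[0], v_max_length)
        out[:t] = feat[:t]
        return out

    def video_3d_pad(feat, obj_max_num, v_max_length):
        out = np.zeros((v_max_length, obj_max_num, feat.shape[2]), dtype=feat.dtype)  # (T, N, d) padded
        t, n = min(feat.shape[0], v_max_length), min(feat.shape[1], obj_max_num)
        out[:t, :n] = feat[:t, :n]
        return out
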
/data_utils/dataset_msvd.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | """Generate data batch."""
7 | import os
8 | import os.path
9 | import sys
10 | sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
11 |
12 | import pandas as pd
13 | import numpy as np
14 | import h5py
15 | from torch.utils.data import Dataset
16 | from data_utils import data_util
17 |
18 | from util import log
19 |
20 | import pickle as pkl
21 | import nltk
22 |
23 |
24 | nltk.download('punkt')
25 | # nltk.download('averaged_perceptron_tagger')
26 |
27 |
28 | __PATH__ = os.path.abspath(os.path.dirname(__file__))
29 |
30 |
31 | def assert_exists(path):
32 | assert os.path.exists(path), 'Does not exist : {}'.format(path)
33 |
34 | eos_word = ''
35 |
36 |
37 | class MSVDQA(Dataset):
38 |
39 | def __init__(
40 | self,
41 | dataset_name='train',
42 | q_max_length=20,
43 | v_max_length=80,
44 | max_n_videos=None,
45 | csv_dir=None,
46 | vocab_dir=None,
47 | feat_dir=None):
48 | self.csv_dir = csv_dir
49 | self.vocabulary_dir = vocab_dir
50 | self.feat_dir = feat_dir
51 | self.dataset_name = dataset_name
52 | self.q_max_length = q_max_length
53 | self.v_max_length = v_max_length
54 | self.obj_max_num = 10
55 | self.max_n_videos = max_n_videos
56 | self.GLOVE_EMBEDDING_SIZE = 300
57 |
58 | self.data_df = self.read_df_from_json()
59 |
60 | if max_n_videos is not None:
61 | self.data_df = self.data_df[:max_n_videos]
62 |
63 | self.res_avg_file = os.path.join(self.feat_dir, "vfeat/msvd_res152_avgpool.hdf5")
64 | self.i3d_avg_file = os.path.join(self.feat_dir, "vfeat/msvd_i3d_avgpool_perclip.hdf5")
65 |
66 | self.res_avg_feat = None
67 | self.i3d_avg_feat = None
68 |
69 | self.res_roi_file = os.path.join(self.feat_dir, "vfeat/msvd_btup_f_obj10.hdf5")
70 | self.i3d_roi_file = os.path.join(self.feat_dir, "vfeat/msvd_i3d_roialign_perclip_obj10.hdf5")
71 |
72 | self.res_roi_feat = None
73 | self.i3d_roi_feat = None
74 |
75 | self.load_word_vocabulary()
76 | self.get_result()
77 |
78 | def __getitem__(self, index):
79 | return self.get_item(index)
80 |
81 | def __len__(self):
82 | if self.max_n_videos is not None:
83 | if self.max_n_videos <= len(self.data_df):
84 | return self.max_n_videos
85 | return len(self.data_df)
86 |
87 | @property
88 | def n_words(self):
89 | ''' The dictionary size. '''
90 | if not hasattr(self, 'word2idx'):
91 | raise Exception('Dictionary not built yet!')
92 | return len(self.word2idx)
93 |
94 | def __repr__(self):
95 | if hasattr(self, 'word2idx'):
96 | return '<MSVDQA dataset (%s) with %d videos and %d words>' % (
97 | self.dataset_name, len(self), len(self.word2idx))
98 | else:
99 | return '<MSVDQA dataset (%s) with %d videos>' % (
100 | self.dataset_name, len(self))
101 |
102 | def read_df_from_json(self):
103 | if self.dataset_name == 'train':
104 | data_path = '%s/train_qa.json'%self.csv_dir
105 | elif self.dataset_name == 'val':
106 | data_path = '%s/val_qa.json'%self.csv_dir
107 | elif self.dataset_name == 'test':
108 | data_path = '%s/test_qa.json'%self.csv_dir
109 |
110 | with open(data_path, 'r') as f:
111 | data_df = pd.read_json(f)
112 |
113 | return data_df
114 |
115 | def split_sentence_into_words(self, sentence, eos=True):
116 | '''
117 | Split the given sentence (str) and enumerate the words as strs.
118 | Each word is normalized, i.e. lower-cased, non-alphabet characters
119 | like period (.) or comma (,) are stripped.
120 | When tokenizing, I use ``data_util.clean_str``
121 | '''
122 | try:
123 | words = data_util.clean_str(sentence).split()
124 | except:
125 | print(sentence)
126 | sys.exit()
127 | if eos:
128 | words = words + [eos_word]
129 | for w in words:
130 | if not w:
131 | continue
132 | yield w
133 |
134 | def create_answerset(self, ans_df):
135 | """Build the answer set from the 1000 most frequent training answers.
136 | Args:
137 | ans_df: DataFrame of the training QA pairs, containing an 'answer' column.
138 | Returns: list of the 1000 most frequent answers.
139 | """
140 | answer_freq = ans_df['answer'].value_counts()
141 | answer_freq = list(answer_freq.iloc[0:1000].keys())
142 | return answer_freq
143 |
144 | def build_word_vocabulary(
145 | self,
146 | all_sen=None,
147 | ans_df=None,
148 | word_count_threshold=0,
149 | ):
150 | '''
151 | borrowed this implementation from @karpathy's neuraltalk.
152 | '''
153 | log.infov('Building word vocabulary (%s) ...', self.dataset_name)
154 |
155 | if all_sen is None or ans_df is None:
156 | all_sen, ans_df = self.get_all_captions()
157 | all_captions_source = all_sen
158 |
159 | # enumerate all sentences to build frequency table
160 | word_counts = {}
161 | nsents = 0
162 | nwords = 0
163 | for sentence in all_captions_source:
164 | nsents += 1
165 | for w in self.split_sentence_into_words(sentence):
166 | word_counts[w] = word_counts.get(w, 0) + 1
167 | nwords += 1
168 |
169 | vocab = [
170 | w for w in word_counts if word_counts[w] >= word_count_threshold
171 | ]
172 | print("Filtered vocab words (threshold = %d), from %d to %d" % (word_count_threshold, len(word_counts), len(vocab)))
173 | log.info(
174 | "Filtered vocab words (threshold = %d), from %d to %d",
175 | word_count_threshold, len(word_counts), len(vocab))
176 |
177 | # build index and vocabularies
178 | self.word2idx = {}
179 | self.idx2word = {}
180 |
181 | self.idx2word[0] = '.'
182 | self.idx2word[1] = 'UNK'
183 | self.word2idx['#START#'] = 0
184 | self.word2idx['UNK'] = 1
185 | for idx, w in enumerate(vocab, start=2):
186 | self.word2idx[w] = idx
187 | self.idx2word[idx] = w
188 |
189 | pkl.dump(
190 | self.word2idx,
191 | open(os.path.join(self.vocabulary_dir, 'word_to_index.pkl'), 'wb')
192 | )
193 | pkl.dump(
194 | self.idx2word,
195 | open(os.path.join(self.vocabulary_dir, 'index_to_word.pkl'), 'wb')
196 | )
197 |
198 | word_counts['.'] = nsents
199 | bias_init_vector = np.array(
200 | [
201 | 1.0 * word_counts[w] if i > 1 else 0
202 | for i, w in self.idx2word.items()
203 | ])
204 | bias_init_vector /= np.sum(bias_init_vector) # normalize to frequencies
205 | bias_init_vector = np.log(bias_init_vector + 1e-20)
206 | bias_init_vector -= np.max(
207 | bias_init_vector) # shift to nice numeric range
208 | self.bias_init_vector = bias_init_vector
209 |
210 | #self.total_q = pd.DataFrame().from_csv(os.path.join(csv_dir,'Total_desc_question.csv'), sep='\t')
211 | answers = self.create_answerset(ans_df)
212 | print("all answer num : %d"%len(answers))
213 | self.ans2idx = {}
214 | self.idx2ans = {}
215 | for idx, w in enumerate(answers):
216 | self.ans2idx[w] = idx
217 | self.idx2ans[idx] = w
218 | pkl.dump(
219 | self.ans2idx,
220 | open(os.path.join(self.vocabulary_dir, 'ans_to_index.pkl'), 'wb')
221 | )
222 | pkl.dump(
223 | self.idx2ans,
224 | open(os.path.join(self.vocabulary_dir, 'index_to_ans.pkl'), 'wb')
225 | )
226 |
227 | # Make glove embedding.
228 | #import spacy
229 | #nlp = spacy.load('en', vectors='en_glove_cc_300_1m_vectors')
230 |
231 | with open(self.vocabulary_dir + '/glove.42B.300d.txt', 'rb') as f:
232 | lines = f.readlines()
233 | dict = {}
234 |
235 | for ix, line in enumerate(lines):
236 | if ix % 1000 == 0:
237 | print((ix, len(lines)))
238 | segs = line.split()
239 | wd = segs[0]
240 | dict[wd] = segs[1:]
241 |
242 | #max_length = len(vocab)
243 | max_length = len(self.idx2word)
244 | GLOVE_EMBEDDING_SIZE = 300
245 |
246 | glove_matrix = np.random.normal(size=[max_length, GLOVE_EMBEDDING_SIZE])
247 | not_in_vocab = 0
248 | for i in range(len(vocab)):
249 | if i % 1000 == 0:
250 | print((i, len(vocab)))
251 | w = bytes(vocab[i], encoding='utf-8')
252 | #w_embed = nlp(u'%s' % w).vector
253 | if w in dict:
254 | w_embed = np.array(dict[w])
255 | else:
256 | not_in_vocab += 1
257 | w_embed = np.random.normal(size=(GLOVE_EMBEDDING_SIZE))
258 | #glove_matrix[i,:] = w_embed
259 | glove_matrix[i + 2, :] = w_embed # two placeholder words '.','UNK'
260 |
261 | print("[%d/%d] words were not found in GloVe (kept random init)" % (not_in_vocab, len(vocab)))
262 |
263 | pkl.dump(
264 | glove_matrix,
265 | open(os.path.join(self.vocabulary_dir, 'vocab_embedding.pkl'), 'wb')
266 | )
267 | self.word_matrix = glove_matrix
268 |
269 | def load_word_vocabulary(self):
270 |
271 | word_matrix_path = os.path.join(
272 | self.vocabulary_dir, 'vocab_embedding.pkl')
273 |
274 | word2idx_path = os.path.join(
275 | self.vocabulary_dir, 'word_to_index.pkl')
276 | idx2word_path = os.path.join(
277 | self.vocabulary_dir, 'index_to_word.pkl')
278 | ans2idx_path = os.path.join(
279 | self.vocabulary_dir, 'ans_to_index.pkl')
280 | idx2ans_path = os.path.join(
281 | self.vocabulary_dir, 'index_to_ans.pkl')
282 |
283 | if not (os.path.exists(word_matrix_path) and
284 | os.path.exists(word2idx_path) and
285 | os.path.exists(idx2word_path) and
286 | os.path.exists(ans2idx_path) and os.path.exists(idx2ans_path)):
287 | self.build_word_vocabulary()
288 |
289 | # ndarray
290 | with open(word_matrix_path, 'rb') as f:
291 | self.word_matrix = pkl.load(f)
292 | log.info("Load word_matrix from pkl file : %s", word_matrix_path)
293 |
294 | with open(word2idx_path, 'rb') as f:
295 | self.word2idx = pkl.load(f)
296 | log.info("Load word2idx from pkl file : %s", word2idx_path)
297 |
298 | with open(idx2word_path, 'rb') as f:
299 | self.idx2word = pkl.load(f)
300 | log.info("Load idx2word from pkl file : %s", idx2word_path)
301 |
302 | with open(ans2idx_path, 'rb') as f:
303 | self.ans2idx = pkl.load(f)
304 | log.info("Load answer2idx from pkl file : %s", ans2idx_path)
305 |
306 | with open(idx2ans_path, 'rb') as f:
307 | self.idx2ans = pkl.load(f)
308 | log.info("Load idx2answers from pkl file : %s", idx2ans_path)
309 |
310 | def get_all_captions(self):
311 | '''
312 | Collect the question strings associated with the videos (from train_qa.json).
313 | '''
314 | data_path = '%s/train_qa.json'%self.csv_dir
315 | with open(data_path, 'r') as f:
316 | data_df = pd.read_json(f)
317 |
318 | all_sents = list(data_df['question'])
319 | return all_sents, data_df
320 |
321 | def get_video_feature(self, key): # key : gif_name
322 | if self.res_avg_feat is None:
323 | self.res_avg_feat = h5py.File(self.res_avg_file, 'r')
324 | if self.i3d_avg_feat is None:
325 | self.i3d_avg_feat = h5py.File(self.i3d_avg_file, 'r')
326 | if self.res_roi_feat is None:
327 | self.res_roi_feat = h5py.File(self.res_roi_file, 'r')
328 | if self.i3d_roi_feat is None:
329 | self.i3d_roi_feat = h5py.File(self.i3d_roi_file, 'r')
330 |
331 | video_id = 'vid' + str(key)
332 |
333 | try:
334 | res_avg_feat = np.array(self.res_avg_feat[video_id]) # T, d
335 | i3d_avg_feat = np.array(self.i3d_avg_feat[video_id]) # T, d
336 | res_roi_feat = np.array(self.res_roi_feat['image_features'][video_id]) # T, obj_max_num, d
337 | roi_bbox_feat = np.array(self.res_roi_feat['spatial_features'][video_id]) # T, obj_max_num, 6
338 | i3d_roi_feat = np.array(self.i3d_roi_feat[video_id]) # T, obj_max_num, d
339 | except KeyError: # no img
340 | print('no image', key)
341 | res_avg_feat = np.zeros((1, 2048))
342 | i3d_avg_feat = np.zeros((1, 2048))
343 | res_roi_feat = np.zeros((1, self.obj_max_num, 2048))
344 | roi_bbox_feat = np.zeros((1, self.obj_max_num, 6))
345 | i3d_roi_feat = np.zeros((1, self.obj_max_num, 2048))
346 |
347 | return res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat
348 |
349 | def convert_sentence_to_matrix(self, sentence, eos=True):
350 | '''
351 | Convert the given sentence (sequence of words) into word indices.
352 | Unknown words (not in vocabulary) are mapped to the UNK index (1).
353 |
354 | Args:
355 | sentence: A sequence of T word tokens (already split and normalized)
356 |
357 | Returns:
358 | sentence_word_indices : list of (at most) length T,
359 | each being a word index
360 | '''
361 | sent2indices = [
362 | self.word2idx[w] if w in self.word2idx else 1
363 | for w in sentence
364 | ] # 1 is UNK, unknown
365 | T = len(sent2indices)
366 | length = min(T, self.q_max_length)
367 | return sent2indices[:length]
368 |
369 | def get_question(self, key):
370 | '''
371 | Return question index for given key.
372 | '''
373 | question = self.data_df.loc[key, ['question']].values
374 | question = question[0]
375 |
376 | question = self.split_sentence_into_words(question, eos=False)
377 | q_refine = []
378 | for w in question:
379 | q_refine.append(w)
380 | q_refine = self.convert_sentence_to_matrix(q_refine)
381 | return q_refine
382 |
383 | def get_answer(self, key):
384 | answer = self.data_df.loc[key, ['answer']].values
385 | answer = answer[0]
386 |
387 | return answer
388 |
389 | def get_result(self):
390 | self.padded_all_ques = []
391 | self.answers = []
392 |
393 | self.all_ques_lengths = []
394 | self.keys = self.data_df['video_id'].values.astype(np.int64)
395 |
396 | for index, row in self.data_df.iterrows():
397 | # ====== Question ======
398 | all_ques = self.get_question(index)
399 | all_ques_pad = data_util.question_pad(all_ques, self.q_max_length)
400 | self.padded_all_ques.append(all_ques_pad)
401 | all_ques_length = min(self.q_max_length, len(all_ques))
402 | self.all_ques_lengths.append(all_ques_length)
403 |
404 | answer = self.get_answer(index)
405 | if str(answer) in self.ans2idx:
406 | answer = self.ans2idx[answer]
407 | else:
408 | # unknown token, check later
409 | answer = 1
410 | self.answers.append(np.array(answer, dtype=np.int64))
411 |
412 | def get_item(self, index):
413 | key = self.keys[index]
414 |
415 | res_avg_feat, i3d_avg_feat, res_roi_feat, roi_bbox_feat, i3d_roi_feat \
416 | = self.get_video_feature(key)
417 |
418 | res_avg_pad = data_util.video_pad(res_avg_feat, self.v_max_length).astype(np.float32)
419 | i3d_avg_pad = data_util.video_pad(i3d_avg_feat, self.v_max_length).astype(np.float32)
420 |
421 | res_roi_pad = data_util.video_3d_pad(res_roi_feat, self.obj_max_num,
422 | self.v_max_length).astype(np.float32)
423 | bbox_pad = data_util.video_3d_pad(roi_bbox_feat, self.obj_max_num,
424 | self.v_max_length).astype(np.float32)
425 | i3d_roi_pad = data_util.video_3d_pad(i3d_roi_feat, self.obj_max_num,
426 | self.v_max_length).astype(np.float32)
427 | video_length = min(self.v_max_length, res_avg_feat.shape[0])
428 |
429 | return res_avg_pad, i3d_avg_pad, res_roi_pad, bbox_pad, i3d_roi_pad, video_length, \
430 | self.padded_all_ques[index], self.all_ques_lengths[index], self.answers[index]
431 |
--------------------------------------------------------------------------------
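
`build_word_vocabulary` above loads the entire `glove.42B.300d.txt` with `readlines()` and fills the embedding matrix by vocabulary position. A memory-friendlier alternative with the same result is to stream the file and look each GloVe word up in `word2idx`; the sketch below assumes the index layout built above (indices 0 and 1 reserved for the placeholder tokens, real words starting at 2) and is not part of the repository.

    import numpy as np

    def load_glove_matrix(glove_path, word2idx, embed_size=300):
        # Rows for words missing from GloVe keep their random initialization,
        # mirroring the behaviour of build_word_vocabulary above.
        matrix = np.random.normal(size=(len(word2idx), embed_size))
        with open(glove_path, 'r', encoding='utf-8') as f:
            for line in f:                      # stream instead of readlines()
                parts = line.rstrip().split(' ')
                word, values = parts[0], parts[1:]
                idx = word2idx.get(word)
                if idx is not None and idx > 1 and len(values) == embed_size:
                    matrix[idx] = np.asarray(values, dtype=np.float32)
        return matrix
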
/embed_loss.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | # __all__ = ['MultipleChoiceLoss', 'CountLoss']
11 |
12 |
13 | class MultipleChoiceLoss(nn.Module):
14 |
15 | def __init__(self, num_option=5, margin=1, size_average=True):
16 | super(MultipleChoiceLoss, self).__init__()
17 | self.margin = margin
18 | self.num_option = num_option
19 | self.size_average = size_average
20 |
21 | # score is N x C
22 |
23 | def forward(self, score, target):
24 | N = score.size(0)
25 | C = score.size(1)
26 | assert self.num_option == C
27 |
28 | loss = torch.tensor(0.0).cuda()
29 | zero = torch.tensor(0.0).cuda()
30 |
31 | cnt = 0
32 | #print(N,C)
33 | for b in range(N):
34 | # loop over incorrect answers; the correct answer's score must exceed each wrong one by at least the margin
35 | c0 = target[b]
36 | for c in range(C):
37 | if c == c0:
38 | continue
39 |
40 | # right class and wrong class should have score difference larger than a margin
41 | # see formula under paper Eq(4)
42 | loss += torch.max(zero, 1.0 + score[b, c] - score[b, c0])
43 | cnt += 1
44 |
45 | if cnt == 0:
46 | return loss
47 |
48 | return loss / cnt if self.size_average else loss
49 |
--------------------------------------------------------------------------------
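
`MultipleChoiceLoss` above walks over every (sample, wrong option) pair in Python. For reference, the same hinge objective, the mean over wrong options of max(0, margin + s_wrong - s_correct), can be written with tensor operations; this vectorized variant is a sketch and is not what the repository uses.

    import torch
    import torch.nn.functional as F

    def multiple_choice_hinge_loss(score, target, margin=1.0):
        # score: (N, C) option scores; target: (N,) index of the correct option.
        correct = score.gather(1, target.view(-1, 1).long())      # (N, 1) correct-option score
        losses = F.relu(margin + score - correct)                  # (N, C); equals margin at the correct column
        mask = torch.ones_like(losses)
        mask.scatter_(1, target.view(-1, 1).long(), 0.0)           # drop the correct option
        return (losses * mask).sum() / mask.sum()                  # average over N * (C - 1) wrong options
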
/main.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | import os
7 | import argparse
8 | import torch
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import random
13 | import h5py
14 |
15 | seed = 999
16 |
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | torch.backends.cudnn.enabled = False
22 | # torch.backends.cudnn.benchmark = True
23 | # torch.backends.cudnn.deterministic = True
24 |
25 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
26 |
27 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
28 |
29 |
30 | def _init_fn(worker_id):
31 | np.random.seed(seed)
32 |
33 |
34 | from data_utils.dataset import TGIFQA
35 | from torch.utils.data import DataLoader
36 | from warmup_scheduler import GradualWarmupScheduler
37 |
38 | from model.masn import MASN
39 |
40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41 |
42 |
43 | def main():
44 | """Main script."""
45 |
46 | args.pin_memory = False
47 | args.dataset = 'tgif_qa'
48 | args.log = './logs/%s' % args.model_name
49 | if not os.path.exists(args.log):
50 | os.mkdir(args.log)
51 |
52 | args.val_epoch_step = 1
53 |
54 | args.save_model_path = os.path.join(args.save_path, args.model_name)
55 | if not os.path.exists(args.save_model_path):
56 | os.makedirs(args.save_model_path)
57 |
58 | full_dataset = TGIFQA(
59 | dataset_name='train',
60 | q_max_length=args.q_max_length,
61 | v_max_length=args.v_max_length,
62 | max_n_videos=args.max_n_videos,
63 | data_type=args.task,
64 | csv_dir=args.df_dir,
65 | vocab_dir=args.vc_dir,
66 | feat_dir=args.feat_dir)
67 | test_dataset = TGIFQA(
68 | dataset_name='test',
69 | q_max_length=args.q_max_length,
70 | v_max_length=args.v_max_length,
71 | max_n_videos=args.max_n_videos,
72 | data_type=args.task,
73 | csv_dir=args.df_dir,
74 | vocab_dir=args.vc_dir,
75 | feat_dir=args.feat_dir)
76 |
77 | val_size = int(args.val_ratio * len(full_dataset))
78 | train_size = len(full_dataset) - val_size
79 | train_dataset, val_dataset = torch.utils.data.random_split(
80 | full_dataset, [train_size, val_size])
81 | print(
82 | 'Dataset lengths train/val/test %d/%d/%d' %
83 | (len(train_dataset), len(val_dataset), len(test_dataset)))
84 |
85 | train_dataloader = DataLoader(
86 | train_dataset,
87 | args.batch_size,
88 | shuffle=True,
89 | num_workers=args.num_workers,
90 | pin_memory=args.pin_memory,
91 | worker_init_fn=_init_fn)
92 | val_dataloader = DataLoader(
93 | val_dataset,
94 | args.batch_size,
95 | shuffle=False,
96 | num_workers=args.num_workers,
97 | pin_memory=args.pin_memory,
98 | worker_init_fn=_init_fn)
99 | test_dataloader = DataLoader(
100 | test_dataset,
101 | args.batch_size,
102 | shuffle=False,
103 | num_workers=args.num_workers,
104 | pin_memory=args.pin_memory,
105 | worker_init_fn=_init_fn)
106 |
107 | print('Load data successful.')
108 |
109 | args.resnet_input_size = 2048
110 | args.i3d_input_size = 2048
111 |
112 | args.text_embed_size = train_dataset.dataset.GLOVE_EMBEDDING_SIZE
113 | args.answer_vocab_size = None
114 |
115 | args.word_matrix = train_dataset.dataset.word_matrix
116 | args.voc_len = args.word_matrix.shape[0]
117 | assert args.text_embed_size == args.word_matrix.shape[1]
118 |
119 | VOCABULARY_SIZE = train_dataset.dataset.n_words
120 | assert VOCABULARY_SIZE == args.voc_len
121 |
122 | ### criterions
123 | if args.task == 'Count':
124 | # add L2 loss
125 | criterion = nn.MSELoss().to(device)
126 | elif args.task in ['Action', 'Trans']:
127 | from embed_loss import MultipleChoiceLoss
128 | criterion = MultipleChoiceLoss(
129 | num_option=5, margin=1, size_average=True).to(device)
130 | elif args.task == 'FrameQA':
131 | # add classification loss
132 | args.answer_vocab_size = len(train_dataset.dataset.ans2idx)
133 | print(('Vocabulary size', args.answer_vocab_size, VOCABULARY_SIZE))
134 | criterion = nn.CrossEntropyLoss().to(device)
135 |
136 | if not args.test:
137 | train(
138 | args, train_dataloader, val_dataloader, test_dataloader, criterion)
139 | else:
140 | print(args.checkpoint[:5], args.task[:5])
141 | model = torch.load(os.path.join(args.save_model_path, args.checkpoint))
142 | test(args, model, test_dataloader, 0, criterion)
143 |
144 |
145 | def train(args, train_dataloader, val_dataloader, test_dataloader, criterion):
146 | model = MASN(
147 | args.voc_len,
148 | args.rnn_layers,
149 | args.word_matrix,
150 | args.resnet_input_size,
151 | args.i3d_input_size,
152 | args.hidden_size,
153 | dropout_p=args.dropout,
154 | gcn_layers=args.gcn_layers,
155 | answer_vocab_size=args.answer_vocab_size,
156 | q_max_len=args.q_max_length,
157 | v_max_len=args.v_max_length,
158 | ablation=args.ablation)
159 |
160 | if torch.cuda.device_count() > 1:
161 | print("Let's use", torch.cuda.device_count(), "GPUs!")
162 | model = nn.DataParallel(model)
163 |
164 | model.to(device)
165 |
166 | if args.change_lr == 'none':
167 | optimizer = torch.optim.Adam(
168 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
169 | elif args.change_lr == 'acc':
170 | optimizer = torch.optim.Adam(
171 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
172 | # val plateau scheduler
173 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
174 | optimizer, mode='max', factor=0.1, patience=3, verbose=True)
175 | # target lr = args.lr * multiplier
176 | scheduler_warmup = GradualWarmupScheduler(
177 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
178 | elif args.change_lr == 'loss':
179 | optimizer = torch.optim.Adam(
180 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
181 | # val plateau scheduler
182 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
183 | optimizer, mode='min', factor=0.1, patience=3, verbose=True)
184 | # target lr = args.lr * multiplier
185 | scheduler_warmup = GradualWarmupScheduler(
186 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
187 | elif args.change_lr == 'cos':
188 | optimizer = torch.optim.Adam(
189 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
190 | # cosine annealing
191 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
192 | optimizer, args.max_epoch)
193 | # target lr = args.lr * multiplier
194 | scheduler_warmup = GradualWarmupScheduler(
195 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
196 | elif args.change_lr == 'step':
197 | optimizer = torch.optim.Adam(
198 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
199 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
200 | optimizer, milestones=args.lr_list, gamma=0.1)
201 | # scheduler_warmup = GradualWarmupScheduler(
202 | # optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
203 |
204 | best_val_acc = 0. if args.task != 'Count' else -100.
205 |
206 | for epoch in range(args.max_epoch):
207 | print('Start Training Epoch: {}'.format(epoch))
208 |
209 | model.train()
210 |
211 | loss_list = []
212 | prediction_list = []
213 | correct_answer_list = []
214 |
215 | if args.change_lr == 'cos':
216 | # cosine annealing
217 | scheduler_warmup.step(epoch=epoch)
218 |
219 | for ii, data in enumerate(train_dataloader):
220 | if epoch == 0 and ii == 0:
221 | print([d.dtype for d in data], [d.size() for d in data])
222 | data = [d.to(device) for d in data]
223 |
224 | optimizer.zero_grad()
225 | out, predictions, answers = model(args.task, *data)
226 | loss = criterion(out, answers)
227 | loss.backward()
228 | optimizer.step()
229 |
230 | correct_answer_list.append(answers)
231 | loss_list.append(loss.item())
232 | prediction_list.append(predictions.detach())
233 | if ii % 100 == 0:
234 | print("Batch: ", ii)
235 |
236 | train_loss = np.mean(loss_list)
237 |
238 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
239 | predict_answer = torch.cat(prediction_list, dim=0).long()
240 | assert correct_answer.shape == predict_answer.shape
241 |
242 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
243 | acc = current_num / len(correct_answer) * 100.
244 |
245 | # print('Learning Rate: {}'.format(optimizer.param_groups[0]['lr']))
246 | if args.change_lr == 'acc':
247 | scheduler_warmup.step(epoch, val_acc)
248 | elif args.change_lr == 'loss':
249 | scheduler_warmup.step(epoch, val_loss)
250 | elif args.change_lr == 'step':
251 | scheduler.step()
252 |
253 | print(
254 | "Train|Epoch: {}, Acc : {:.3f}={}/{}, Train Loss: {:.3f}".format(
255 | epoch, acc, current_num, len(correct_answer), train_loss))
256 | if args.task == 'Count':
257 | count_loss = F.mse_loss(
258 | predict_answer.float(), correct_answer.float())
259 | print('Train|Count Real Loss:\t {:.3f}'.format(count_loss))
260 |
261 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
262 | logfile.write(
263 | "Train|Epoch: %d, Acc : %.3f=%d/%d, Train Loss: %.3f\n"
264 | % (epoch, acc, current_num, len(correct_answer), train_loss)
265 | )
266 | if args.task == 'Count':
267 | logfile.write(
268 | "Train|Count Real Loss:\t %.3f\n"%count_loss
269 | )
270 | logfile.close()
271 |
272 | val_acc, val_loss = val(args, model, val_dataloader, epoch, criterion)
273 |
274 | if val_acc > best_val_acc:
275 | print('Best Val Acc ======')
276 | best_val_acc = val_acc
277 | if epoch % args.val_epoch_step == 0 or val_acc >= best_val_acc:
278 | test(args, model, test_dataloader, epoch, criterion)
279 |
280 |
281 | @torch.no_grad()
282 | def val(args, model, val_dataloader, epoch, criterion):
283 | model.eval()
284 |
285 | loss_list = []
286 | prediction_list = []
287 | correct_answer_list = []
288 |
289 | for ii, data in enumerate(val_dataloader):
290 | data = [d.to(device) for d in data]
291 |
292 | out, predictions, answers = model(args.task, *data)
293 | loss = criterion(out, answers)
294 |
295 | correct_answer_list.append(answers)
296 | loss_list.append(loss.item())
297 | prediction_list.append(predictions.detach())
298 |
299 | val_loss = np.mean(loss_list)
300 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
301 | predict_answer = torch.cat(prediction_list, dim=0).long()
302 | assert correct_answer.shape == predict_answer.shape
303 |
304 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
305 |
306 | acc = current_num / len(correct_answer) * 100.
307 |
308 | print(
309 | "VAL|Epoch: {}, Acc: {:3f}={}/{}, Val Loss: {:3f}".format(
310 | epoch, acc, current_num, len(correct_answer), val_loss))
311 |
312 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
313 | logfile.write(
314 | "VAL|Epoch: %d, Acc: %.3f=%d/%d, Val Loss: %.3f\n"
315 | % (epoch, acc, current_num, len(correct_answer), val_loss)
316 | )
317 | logfile.close()
318 |
319 | if args.task == 'Count':
320 | print(
321 | 'VAL|Count Real Loss:\t {:.3f}'.format(
322 | F.mse_loss(predict_answer.float(), correct_answer.float())))
323 | acc = -F.mse_loss(predict_answer.float(), correct_answer.float())
324 |
325 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
326 | logfile.write(
327 | "VAL|Count Real Loss:\t %.3f}\n"
328 | % (F.mse_loss(predict_answer.float(), correct_answer.float()))
329 | )
330 | logfile.close()
331 | return acc, val_loss
332 |
333 |
334 | @torch.no_grad()
335 | def test(args, model, test_dataloader, epoch, criterion):
336 |
337 | model.eval()
338 |
339 | loss_list = []
340 | prediction_list = []
341 | correct_answer_list = []
342 |
343 | for ii, data in enumerate(test_dataloader):
344 | data = [d.to(device) for d in data]
345 |
346 | out, predictions, answers = model(args.task, *data)
347 | loss = criterion(out, answers)
348 |
349 | correct_answer_list.append(answers)
350 | loss_list.append(loss.item())
351 | prediction_list.append(predictions.detach())
352 |
353 | test_loss = np.mean(loss_list)
354 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
355 | predict_answer = torch.cat(prediction_list, dim=0).long()
356 | assert correct_answer.shape == predict_answer.shape
357 |
358 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
359 |
360 | acc = current_num / len(correct_answer) * 100.
361 |
362 | print(
363 | "Test|Epoch: {}, Acc: {:3f}={}/{}, Test Loss: {:3f}".format(
364 | epoch, acc, current_num, len(correct_answer), test_loss))
365 |
366 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
367 | logfile.write(
368 | "Test|Epoch: %d, Acc: %.3f=%d/%d, Test Loss: %.3f\n"
369 | % (epoch, acc, current_num, len(correct_answer), test_loss)
370 | )
371 | logfile.close()
372 |
373 | if args.save:
374 | if (args.task == 'Action' and
375 | acc >= 80) or (args.task == 'Trans' and
376 | acc >= 80) or (args.task == 'FrameQA' and
377 | acc >= 55):
378 | torch.save(
379 | model, os.path.join(args.save_model_path,
380 | args.task + '_' + str(acc.item())[:5] + '.pth'))
381 | print('Save model at ', args.save_model_path)
382 |
383 | if args.task == 'Count':
384 | count_loss = F.mse_loss(predict_answer.float(), correct_answer.float())
385 | print('Test|Count Real Loss:\t {:.3f}'.format(count_loss))
386 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
387 | logfile.write(
388 | 'Test|Count Real Loss:\t %.3f\n' % (count_loss)
389 | )
390 | logfile.close()
391 | if args.save and count_loss <= 4.0:
392 | torch.save(
393 | model, os.path.join(args.save_model_path,
394 | args.task + '_' + str(count_loss.item())[:5] + '.pth'))
395 | print('Save model at ', args.save_model_path)
396 |
397 |
398 | if __name__ == '__main__':
399 | parser = argparse.ArgumentParser()
400 |
401 | parser.add_argument('--model_name', type=str, default='[MASN]')
402 |
403 | ################ path config ################
404 | parser.add_argument('--feat_dir', default='/data/TGIFQA', help='path for resnet and i3d features')
405 | parser.add_argument('--vc_dir', default='/data/TGIFQA/vocab', help='path for vocabulary')
406 | parser.add_argument('--df_dir', default='/data/TGIFQA/question', help='path for tgif question csv files')
407 |
408 | ################ inference config ################
409 | parser.add_argument(
410 | '--checkpoint',
411 | type=str,
412 | default='FrameQA_59.73.pth',
413 | help='path to checkpoint')
414 | parser.add_argument(
415 | '--save_path',
416 | type=str,
417 | default='./saved_models/',
418 | help='path for saving trained models')
419 |
420 | parser.add_argument(
421 | '--save', action='store_true', default=True, help='save models or not')
422 | parser.add_argument(
423 | '--hidden_size',
424 | type=int,
425 | default=512,
426 | help='dimension of model')
427 | parser.add_argument(
428 | '--test', action='store_true', default=False, help='Train or Test')
429 | parser.add_argument('--max_epoch', type=int, default=100)
430 | parser.add_argument('--val_ratio', type=float, default=0.1)
431 | parser.add_argument('--q_max_length', type=int, default=20)
432 | parser.add_argument('--v_max_length', type=int, default=20)
433 |
434 | parser.add_argument(
435 | '--task',
436 | type=str,
437 | default='Count',
438 | help='[Count, Action, FrameQA, Trans]')
439 | parser.add_argument(
440 | '--rnn_layers', type=int, default=1, help='number of layers in lstm')
441 | parser.add_argument(
442 | '--gcn_layers',
443 | type=int,
444 | default=2,
445 | help='number of layers in gcn (+1)')
446 | parser.add_argument('--batch_size', type=int, default=32)
447 | parser.add_argument('--max_n_videos', type=int, default=100000)
448 | parser.add_argument('--num_workers', type=int, default=1)
449 | parser.add_argument('--lr', type=float, default=0.0001)
450 | parser.add_argument('--lr_list', type=list, default=[10, 20, 30, 40])
451 | parser.add_argument('--dropout', type=float, default=0.3)
452 | parser.add_argument(
453 | '--change_lr', type=str, default='none', help='lr schedule: none, acc, loss, cos, or step')
454 | parser.add_argument(
455 | '--weight_decay', type=float, default=0, help='weight_decay')
456 | parser.add_argument('--ablation', type=str, default='none')
457 |
458 | args = parser.parse_args()
459 | print(args)
460 |
461 | main()
--------------------------------------------------------------------------------
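
Example invocations of main.py, assuming the feature, vocabulary, and question paths match the argparse defaults above (otherwise pass --feat_dir, --vc_dir, and --df_dir explicitly); the model_name values are illustrative:

    # train on the TGIF-QA tasks
    python main.py --task FrameQA --model_name MASN_FrameQA
    python main.py --task Count --model_name MASN_Count
    python main.py --task Action --model_name MASN_Action
    # evaluate a saved checkpoint (FrameQA_59.73.pth is the argparse default;
    # checkpoints are read from save_path/model_name)
    python main.py --task FrameQA --test --checkpoint FrameQA_59.73.pth --model_name MASN_FrameQA
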
/main_msrvtt.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | import os
7 | import argparse
8 | import torch
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import random
13 | import h5py
14 |
15 | seed = 999
16 |
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | torch.backends.cudnn.enabled = False
22 | # torch.backends.cudnn.benchmark = True
23 | # torch.backends.cudnn.deterministic = True
24 |
25 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
26 |
27 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
28 |
29 |
30 | def _init_fn(worker_id):
31 | np.random.seed(seed)
32 |
33 |
34 | from data_utils.dataset_msrvtt import MSRVTTQA
35 | from torch.utils.data import DataLoader
36 | from warmup_scheduler import GradualWarmupScheduler
37 |
38 | from model.masn import MASN
39 |
40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41 |
42 |
43 | def main():
44 | """Main script."""
45 |
46 | args.pin_memory = False
47 | args.dataset = 'msrvtt_qa'
48 | args.log = './logs/%s' % args.model_name
49 | if not os.path.exists(args.log):
50 | os.mkdir(args.log)
51 |
52 | args.val_epoch_step = 1
53 |
54 | args.save_model_path = os.path.join(args.save_path, args.model_name)
55 | if not os.path.exists(args.save_model_path):
56 | os.makedirs(args.save_model_path)
57 |
58 | train_dataset = MSRVTTQA(
59 | dataset_name='train',
60 | q_max_length=args.q_max_length,
61 | v_max_length=args.v_max_length,
62 | max_n_videos=args.max_n_videos,
63 | csv_dir=args.df_dir,
64 | vocab_dir=args.vc_dir,
65 | feat_dir=args.feat_dir)
66 | val_dataset = MSRVTTQA(
67 | dataset_name='val',
68 | q_max_length=args.q_max_length,
69 | v_max_length=args.v_max_length,
70 | max_n_videos=args.max_n_videos,
71 | csv_dir=args.df_dir,
72 | vocab_dir=args.vc_dir,
73 | feat_dir=args.feat_dir)
74 | test_dataset = MSRVTTQA(
75 | dataset_name='test',
76 | q_max_length=args.q_max_length,
77 | v_max_length=args.v_max_length,
78 | max_n_videos=args.max_n_videos,
79 | csv_dir=args.df_dir,
80 | vocab_dir=args.vc_dir,
81 | feat_dir=args.feat_dir)
82 |
83 | print(
84 | 'Dataset lengths train/val/test %d/%d/%d' %
85 | (len(train_dataset), len(val_dataset), len(test_dataset)))
86 |
87 | train_dataloader = DataLoader(
88 | train_dataset,
89 | args.batch_size,
90 | shuffle=True,
91 | num_workers=args.num_workers,
92 | pin_memory=args.pin_memory,
93 | worker_init_fn=_init_fn)
94 | val_dataloader = DataLoader(
95 | val_dataset,
96 | args.batch_size,
97 | shuffle=False,
98 | num_workers=args.num_workers,
99 | pin_memory=args.pin_memory,
100 | worker_init_fn=_init_fn)
101 | test_dataloader = DataLoader(
102 | test_dataset,
103 | args.batch_size,
104 | shuffle=False,
105 | num_workers=args.num_workers,
106 | pin_memory=args.pin_memory,
107 | worker_init_fn=_init_fn)
108 |
109 | print('Load data successful.')
110 |
111 | args.resnet_input_size = 2048
112 | args.i3d_input_size = 2048  # MASN() below reads args.i3d_input_size
113 |
114 | args.text_embed_size = train_dataset.GLOVE_EMBEDDING_SIZE
115 | args.answer_vocab_size = None
116 |
117 | args.word_matrix = train_dataset.word_matrix
118 | args.voc_len = args.word_matrix.shape[0]
119 | assert args.text_embed_size == args.word_matrix.shape[1]
120 |
121 | VOCABULARY_SIZE = train_dataset.n_words
122 | assert VOCABULARY_SIZE == args.voc_len
123 |
124 | # add classification loss
125 | args.answer_vocab_size = len(train_dataset.ans2idx)
126 | print(('Vocabulary size', args.answer_vocab_size, VOCABULARY_SIZE))
127 | criterion = nn.CrossEntropyLoss().to(device)
128 |
129 | if not args.test:
130 | train(
131 | args, train_dataloader, val_dataloader, test_dataloader, criterion)
132 | else:
133 | model = torch.load(os.path.join(args.save_model_path, args.checkpoint))
134 | test(args, model, test_dataloader, 0, criterion)
135 |
136 |
137 | def train(args, train_dataloader, val_dataloader, test_dataloader, criterion):
138 | model = MASN(
139 | args.voc_len,
140 | args.rnn_layers,
141 | args.word_matrix,
142 | args.resnet_input_size,
143 | args.i3d_input_size,
144 | args.hidden_size,
145 | dropout_p=args.dropout,
146 | gcn_layers=args.gcn_layers,
147 | answer_vocab_size=args.answer_vocab_size,
148 | q_max_len=args.q_max_length,
149 | v_max_len=args.v_max_length,
150 | ablation=args.ablation)
151 |
152 | if torch.cuda.device_count() > 1:
153 | print("Let's use", torch.cuda.device_count(), "GPUs!")
154 | model = nn.DataParallel(model)
155 |
156 | model.to(device)
157 |
158 | if args.change_lr == 'none':
159 | optimizer = torch.optim.Adam(
160 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
161 | elif args.change_lr == 'acc':
162 | optimizer = torch.optim.Adam(
163 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
164 | # val plateau scheduler
165 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
166 | optimizer, mode='max', factor=0.1, patience=3, verbose=True)
167 | # target lr = args.lr * multiplier
168 | scheduler_warmup = GradualWarmupScheduler(
169 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
170 | elif args.change_lr == 'loss':
171 | optimizer = torch.optim.Adam(
172 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
173 | # val plateau scheduler
174 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
175 | optimizer, mode='min', factor=0.1, patience=3, verbose=True)
176 | # target lr = args.lr * multiplier
177 | scheduler_warmup = GradualWarmupScheduler(
178 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
179 | elif args.change_lr == 'cos':
180 | optimizer = torch.optim.Adam(
181 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
182 | # cosine annealing
183 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
184 | optimizer, args.max_epoch)
185 | # target lr = args.lr * multiplier
186 | scheduler_warmup = GradualWarmupScheduler(
187 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
188 | elif args.change_lr == 'step':
189 | optimizer = torch.optim.Adam(
190 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
191 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
192 | optimizer, milestones=args.lr_list, gamma=0.1)
193 | # scheduler_warmup = GradualWarmupScheduler(
194 | # optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
195 |
196 | best_val_acc = 0.
197 |
198 | for epoch in range(args.max_epoch):
199 | print('Start Training Epoch: {}'.format(epoch))
200 |
201 | model.train()
202 |
203 | loss_list = []
204 | prediction_list = []
205 | correct_answer_list = []
206 |
207 | if args.change_lr == 'cos':
208 | # cosine annealing
209 | scheduler_warmup.step(epoch=epoch)
210 |
211 | for ii, data in enumerate(train_dataloader):
212 | if epoch == 0 and ii == 0:
213 | print([d.dtype for d in data], [d.size() for d in data])
214 | data = [d.to(device) for d in data]
215 |
216 | optimizer.zero_grad()
217 | out, predictions, answers = model(args.task, *data)
218 | loss = criterion(out, answers)
219 | loss.backward()
220 | optimizer.step()
221 |
222 | correct_answer_list.append(answers)
223 | loss_list.append(loss.item())
224 | prediction_list.append(predictions.detach())
225 | if ii % 100 == 0:
226 | print("Batch: ", ii)
227 |
228 | train_loss = np.mean(loss_list)
229 |
230 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
231 | predict_answer = torch.cat(prediction_list, dim=0).long()
232 | assert correct_answer.shape == predict_answer.shape
233 |
234 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
235 | acc = current_num / len(correct_answer) * 100.
236 |
237 | # print('Learning Rate: {}'.format(optimizer.param_groups[0]['lr']))
238 | if args.change_lr == 'acc':
239 | scheduler_warmup.step(epoch, val_acc)
240 | elif args.change_lr == 'loss':
241 | scheduler_warmup.step(epoch, val_loss)
242 | elif args.change_lr == 'step':
243 | scheduler.step()
244 |
245 | print(
246 | "Train|Epoch: {}, Acc : {:.3f}={}/{}, Train Loss: {:.3f}".format(
247 | epoch, acc, current_num, len(correct_answer), train_loss))
248 |
249 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
250 | logfile.write(
251 | "Train|Epoch: %d, Acc : %.3f=%d/%d, Train Loss: %.3f\n"
252 | % (epoch, acc, current_num, len(correct_answer), train_loss)
253 | )
254 | logfile.close()
255 |
256 | val_acc, val_loss = val(args, model, val_dataloader, epoch, criterion)
257 |
258 | if val_acc > best_val_acc:
259 | print('Best Val Acc ======')
260 | best_val_acc = val_acc
261 | if epoch % args.val_epoch_step == 0 or val_acc >= best_val_acc:
262 | test(args, model, test_dataloader, epoch, criterion)
263 |
264 |
265 | @torch.no_grad()
266 | def val(args, model, val_dataloader, epoch, criterion):
267 | model.eval()
268 |
269 | loss_list = []
270 | prediction_list = []
271 | correct_answer_list = []
272 |
273 | for ii, data in enumerate(val_dataloader):
274 | data = [d.to(device) for d in data]
275 |
276 | out, predictions, answers = model(args.task, *data)
277 | loss = criterion(out, answers)
278 |
279 | correct_answer_list.append(answers)
280 | loss_list.append(loss.item())
281 | prediction_list.append(predictions.detach())
282 |
283 | val_loss = np.mean(loss_list)
284 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
285 | predict_answer = torch.cat(prediction_list, dim=0).long()
286 | assert correct_answer.shape == predict_answer.shape
287 |
288 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
289 |
290 | acc = current_num / len(correct_answer) * 100.
291 |
292 | print(
293 | "VAL|Epoch: {}, Acc: {:3f}={}/{}, Val Loss: {:3f}".format(
294 | epoch, acc, current_num, len(correct_answer), val_loss))
295 |
296 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
297 | logfile.write(
298 | "VAL|Epoch: %d, Acc: %.3f=%d/%d, Val Loss: %.3f\n"
299 | % (epoch, acc, current_num, len(correct_answer), val_loss)
300 | )
301 | logfile.close()
302 |
303 | return acc, val_loss
304 |
305 |
306 | @torch.no_grad()
307 | def test(args, model, test_dataloader, epoch, criterion):
308 |
309 | model.eval()
310 |
311 | loss_list = []
312 | prediction_list = []
313 | correct_answer_list = []
314 |
315 | for ii, data in enumerate(test_dataloader):
316 | data = [d.to(device) for d in data]
317 |
318 | out, predictions, answers = model(args.task, *data)
319 | loss = criterion(out, answers)
320 |
321 | correct_answer_list.append(answers)
322 | loss_list.append(loss.item())
323 | prediction_list.append(predictions.detach())
324 |
325 | test_loss = np.mean(loss_list)
326 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
327 | predict_answer = torch.cat(prediction_list, dim=0).long()
328 | assert correct_answer.shape == predict_answer.shape
329 |
330 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
331 |
332 | acc = current_num / len(correct_answer) * 100.
333 |
334 | print(
335 | "Test|Epoch: {}, Acc: {:3f}={}/{}, Test Loss: {:3f}".format(
336 | epoch, acc, current_num, len(correct_answer), test_loss))
337 |
338 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
339 | logfile.write(
340 | "Test|Epoch: %d, Acc: %.3f=%d/%d, Test Loss: %.3f\n"
341 | % (epoch, acc, current_num, len(correct_answer), test_loss)
342 | )
343 | logfile.close()
344 |
345 | if args.save:
346 | if acc >= 30:
347 | torch.save(
348 | model, os.path.join(args.save_model_path,
349 | args.task + '_' + str(acc.item())[:5] + '.pth'))
350 | print('Save model at ', args.save_model_path)
351 |
352 | if __name__ == '__main__':
353 | parser = argparse.ArgumentParser()
354 |
355 | parser.add_argument('--model_name', type=str, default='[MSVD_msrvtt]')
356 |
357 | ################ path config ################
358 | parser.add_argument('--feat_dir', default='/data/MSRVTT-QA/video', help='path for resnet and i3d features')
359 | parser.add_argument('--vc_dir', default='/data/MSRVTT-QA/vocab4000', help='path for vocabulary')
360 | parser.add_argument('--df_dir', default='/data/MSRVTT-QA/question', help='path for msrvtt question json files')
361 |
362 | ################ inference config ################
363 | parser.add_argument(
364 | '--checkpoint',
365 | type=str,
366 | default='',
367 | help='path to checkpoint')
368 | parser.add_argument(
369 | '--save_path',
370 | type=str,
371 | default='./saved_models/',
372 | help='path for saving trained models')
373 |
374 | parser.add_argument(
375 | '--save', action='store_true', default=True, help='save models or not')
376 | parser.add_argument(
377 | '--hidden_size',
378 | type=int,
379 | default=512,
380 | help='dimension of lstm hidden states')
381 | parser.add_argument(
382 | '--test', action='store_true', default=False, help='Train or Test')
383 | parser.add_argument('--max_epoch', type=int, default=100)
384 | parser.add_argument('--val_ratio', type=float, default=0.1)
385 | parser.add_argument('--q_max_length', type=int, default=20)
386 | parser.add_argument('--v_max_length', type=int, default=30)
387 |
388 | parser.add_argument(
389 | '--task',
390 | type=str,
391 | default='MS-QA'
392 | )
393 | parser.add_argument(
394 | '--rnn_layers', type=int, default=1, help='number of layers in lstm')
395 | parser.add_argument(
396 | '--gcn_layers',
397 | type=int,
398 | default=2,
399 | help='number of layers in gcn (+1)')
400 | parser.add_argument('--batch_size', type=int, default=32)
401 | parser.add_argument('--max_n_videos', type=int, default=100000)
402 | parser.add_argument('--num_workers', type=int, default=1)
403 | parser.add_argument('--lr', type=float, default=0.0001)
404 | parser.add_argument('--lr_list', type=list, default=[10, 20, 30, 40])
405 | parser.add_argument('--dropout', type=float, default=0.3)
406 | parser.add_argument(
407 | '--change_lr', type=str, default='none', help='lr schedule: none, acc, loss, cos, or step')
408 | parser.add_argument(
409 | '--weight_decay', type=float, default=0, help='weight_decay')
410 | parser.add_argument('--ablation', type=str, default='none')
411 |
412 | args = parser.parse_args()
413 | print(args)
414 |
415 | main()
--------------------------------------------------------------------------------
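
In each train() above, the optimizer is created with lr = args.lr / 5 when a warmup schedule is selected, and GradualWarmupScheduler is configured with multiplier=5 and total_epoch=5, so the effective learning rate ramps from args.lr / 5 up to args.lr before the after_scheduler (plateau or cosine annealing) takes over. A standalone sketch of that ramp, assuming the linear warmup rule such schedulers commonly use (warmup_scheduler.py in this repository is authoritative):

    def warmup_lr(base_lr, epoch, multiplier=5, total_epoch=5):
        # linear ramp from base_lr to base_lr * multiplier over total_epoch epochs
        if epoch >= total_epoch:
            return base_lr * multiplier
        return base_lr * (1.0 + (multiplier - 1.0) * epoch / total_epoch)

    # With args.lr = 1e-4 the optimizer starts at 2e-5 and the ramp ends at 1e-4:
    for epoch in range(7):
        print(epoch, warmup_lr(2e-5, epoch))
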
/main_msvd.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | import os
7 | import argparse
8 | import torch
9 | import numpy as np
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import random
13 | import h5py
14 |
15 | seed = 999
16 |
17 | random.seed(seed)
18 | np.random.seed(seed)
19 | torch.manual_seed(seed)
20 | torch.cuda.manual_seed_all(seed)
21 | torch.backends.cudnn.enabled = False
22 | # torch.backends.cudnn.benchmark = True
23 | # torch.backends.cudnn.deterministic = True
24 |
25 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
26 |
27 | cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
28 |
29 |
30 | def _init_fn(worker_id):
31 | np.random.seed(seed)
32 |
33 |
34 | from data_utils.dataset_msvd import MSVDQA
35 | from torch.utils.data import DataLoader
36 | from warmup_scheduler import GradualWarmupScheduler
37 |
38 | from model.masn import MASN
39 |
40 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
41 |
42 |
43 | def main():
44 | """Main script."""
45 |
46 | args.pin_memory = False
47 | args.dataset = 'msvd_qa'
48 | args.log = './logs/%s' % args.model_name
49 | if not os.path.exists(args.log):
50 | os.mkdir(args.log)
51 |
52 | args.val_epoch_step = 1
53 |
54 | args.save_model_path = os.path.join(args.save_path, args.model_name)
55 | if not os.path.exists(args.save_model_path):
56 | os.makedirs(args.save_model_path)
57 |
58 | train_dataset = MSVDQA(
59 | dataset_name='train',
60 | q_max_length=args.q_max_length,
61 | v_max_length=args.v_max_length,
62 | max_n_videos=args.max_n_videos,
63 | csv_dir=args.df_dir,
64 | vocab_dir=args.vc_dir,
65 | feat_dir=args.feat_dir)
66 | val_dataset = MSVDQA(
67 | dataset_name='val',
68 | q_max_length=args.q_max_length,
69 | v_max_length=args.v_max_length,
70 | max_n_videos=args.max_n_videos,
71 | csv_dir=args.df_dir,
72 | vocab_dir=args.vc_dir,
73 | feat_dir=args.feat_dir)
74 | test_dataset = MSVDQA(
75 | dataset_name='test',
76 | q_max_length=args.q_max_length,
77 | v_max_length=args.v_max_length,
78 | max_n_videos=args.max_n_videos,
79 | csv_dir=args.df_dir,
80 | vocab_dir=args.vc_dir,
81 | feat_dir=args.feat_dir)
82 |
83 | print(
84 | 'Dataset lengths train/val/test %d/%d/%d' %
85 | (len(train_dataset), len(val_dataset), len(test_dataset)))
86 |
87 | train_dataloader = DataLoader(
88 | train_dataset,
89 | args.batch_size,
90 | shuffle=True,
91 | num_workers=args.num_workers,
92 | pin_memory=args.pin_memory,
93 | worker_init_fn=_init_fn)
94 | val_dataloader = DataLoader(
95 | val_dataset,
96 | args.batch_size,
97 | shuffle=False,
98 | num_workers=args.num_workers,
99 | pin_memory=args.pin_memory,
100 | worker_init_fn=_init_fn)
101 | test_dataloader = DataLoader(
102 | test_dataset,
103 | args.batch_size,
104 | shuffle=False,
105 | num_workers=args.num_workers,
106 | pin_memory=args.pin_memory,
107 | worker_init_fn=_init_fn)
108 |
109 | print('Load data successful.')
110 |
111 | args.resnet_input_size = 2048
112 | args.i3d_input_size = 2048  # MASN() below reads args.i3d_input_size
113 |
114 | args.text_embed_size = train_dataset.GLOVE_EMBEDDING_SIZE
115 | args.answer_vocab_size = None
116 |
117 | args.word_matrix = train_dataset.word_matrix
118 | args.voc_len = args.word_matrix.shape[0]
119 | assert args.text_embed_size == args.word_matrix.shape[1]
120 |
121 | VOCABULARY_SIZE = train_dataset.n_words
122 | assert VOCABULARY_SIZE == args.voc_len
123 |
124 | # add classification loss
125 | args.answer_vocab_size = len(train_dataset.ans2idx)
126 | print(('Vocabulary size', args.answer_vocab_size, VOCABULARY_SIZE))
127 | criterion = nn.CrossEntropyLoss().to(device)
128 |
129 | if not args.test:
130 | train(
131 | args, train_dataloader, val_dataloader, test_dataloader, criterion)
132 | else:
133 | model = torch.load(os.path.join(args.save_model_path, args.checkpoint))
134 | test(args, model, test_dataloader, 0, criterion)
135 |
136 |
137 | def train(args, train_dataloader, val_dataloader, test_dataloader, criterion):
138 | model = MASN(
139 | args.voc_len,
140 | args.rnn_layers,
141 | args.word_matrix,
142 | args.resnet_input_size,
143 | args.i3d_input_size,
144 | args.hidden_size,
145 | dropout_p=args.dropout,
146 | gcn_layers=args.gcn_layers,
147 | answer_vocab_size=args.answer_vocab_size,
148 | q_max_len=args.q_max_length,
149 | v_max_len=args.v_max_length,
150 | ablation=args.ablation)
151 |
152 | if torch.cuda.device_count() > 1:
153 | print("Let's use", torch.cuda.device_count(), "GPUs!")
154 | model = nn.DataParallel(model)
155 |
156 | model.to(device)
157 |
158 | if args.change_lr == 'none':
159 | optimizer = torch.optim.Adam(
160 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
161 | elif args.change_lr == 'acc':
162 | optimizer = torch.optim.Adam(
163 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
164 | # val plateau scheduler
165 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
166 | optimizer, mode='max', factor=0.1, patience=3, verbose=True)
167 | # target lr = args.lr * multiplier
168 | scheduler_warmup = GradualWarmupScheduler(
169 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
170 | elif args.change_lr == 'loss':
171 | optimizer = torch.optim.Adam(
172 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
173 | # val plateau scheduler
174 | scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
175 | optimizer, mode='min', factor=0.1, patience=3, verbose=True)
176 | # target lr = args.lr * multiplier
177 | scheduler_warmup = GradualWarmupScheduler(
178 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
179 | elif args.change_lr == 'cos':
180 | optimizer = torch.optim.Adam(
181 | model.parameters(), lr=args.lr / 5., weight_decay=args.weight_decay)
182 |         # cosine annealing
183 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
184 | optimizer, args.max_epoch)
185 | # target lr = args.lr * multiplier
186 | scheduler_warmup = GradualWarmupScheduler(
187 | optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
188 | elif args.change_lr == 'step':
189 | optimizer = torch.optim.Adam(
190 | model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
191 | scheduler = torch.optim.lr_scheduler.MultiStepLR(
192 | optimizer, milestones=args.lr_list, gamma=0.1)
193 | # scheduler_warmup = GradualWarmupScheduler(
194 | # optimizer, multiplier=5, total_epoch=5, after_scheduler=scheduler)
195 |
196 | best_val_acc = 0.
197 |
198 | for epoch in range(args.max_epoch):
199 | print('Start Training Epoch: {}'.format(epoch))
200 |
201 | model.train()
202 |
203 | loss_list = []
204 | prediction_list = []
205 | correct_answer_list = []
206 |
207 | if args.change_lr == 'cos':
208 |             # cosine annealing
209 | scheduler_warmup.step(epoch=epoch)
210 |
211 | for ii, data in enumerate(train_dataloader):
212 | if epoch == 0 and ii == 0:
213 | print([d.dtype for d in data], [d.size() for d in data])
214 | data = [d.to(device) for d in data]
215 |
216 | optimizer.zero_grad()
217 | out, predictions, answers = model(args.task, *data)
218 | loss = criterion(out, answers)
219 | loss.backward()
220 | optimizer.step()
221 |
222 | correct_answer_list.append(answers)
223 | loss_list.append(loss.item())
224 | prediction_list.append(predictions.detach())
225 | if ii % 100 == 0:
226 | print("Batch: ", ii)
227 |
228 | train_loss = np.mean(loss_list)
229 |
230 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
231 | predict_answer = torch.cat(prediction_list, dim=0).long()
232 | assert correct_answer.shape == predict_answer.shape
233 |
234 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
235 | acc = current_num / len(correct_answer) * 100.
236 |
237 |         # print('Learning Rate: {}'.format(optimizer.param_groups[0]['lr']))
238 |         print(
239 |             "Train|Epoch: {}, Acc : {:.3f}={}/{}, Train Loss: {:.3f}".format(
240 |                 epoch, acc, current_num, len(correct_answer), train_loss))
241 |
242 |         logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
243 |         logfile.write(
244 |             "Train|Epoch: %d, Acc : %.3f=%d/%d, Train Loss: %.3f\n"
245 |             % (epoch, acc, current_num, len(correct_answer), train_loss)
246 |         )
247 |         logfile.close()
248 |
249 |         val_acc, val_loss = val(args, model, val_dataloader, epoch, criterion)
250 |         # the plateau schedulers are stepped after val() so val_acc / val_loss exist
251 |         if args.change_lr == 'acc':
252 |             scheduler_warmup.step(epoch, val_acc)
253 |         elif args.change_lr == 'loss':
254 |             scheduler_warmup.step(epoch, val_loss)
255 |         elif args.change_lr == 'step':
256 |             scheduler.step()
257 |
258 |         if val_acc > best_val_acc:
259 |             print('Best Val Acc ======')
260 |             best_val_acc = val_acc
261 |         if epoch % args.val_epoch_step == 0 or val_acc >= best_val_acc:
262 |             test(args, model, test_dataloader, epoch, criterion)
263 |
264 |
265 | @torch.no_grad()
266 | def val(args, model, val_dataloader, epoch, criterion):
267 | model.eval()
268 |
269 | loss_list = []
270 | prediction_list = []
271 | correct_answer_list = []
272 |
273 | for ii, data in enumerate(val_dataloader):
274 | data = [d.to(device) for d in data]
275 |
276 | out, predictions, answers = model(args.task, *data)
277 | loss = criterion(out, answers)
278 |
279 | correct_answer_list.append(answers)
280 | loss_list.append(loss.item())
281 | prediction_list.append(predictions.detach())
282 |
283 | val_loss = np.mean(loss_list)
284 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
285 | predict_answer = torch.cat(prediction_list, dim=0).long()
286 | assert correct_answer.shape == predict_answer.shape
287 |
288 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
289 |
290 | acc = current_num / len(correct_answer) * 100.
291 |
292 | print(
293 | "VAL|Epoch: {}, Acc: {:3f}={}/{}, Val Loss: {:3f}".format(
294 | epoch, acc, current_num, len(correct_answer), val_loss))
295 |
296 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
297 | logfile.write(
298 | "VAL|Epoch: %d, Acc: %.3f=%d/%d, Val Loss: %.3f\n"
299 | % (epoch, acc, current_num, len(correct_answer), val_loss)
300 | )
301 | logfile.close()
302 |
303 | return acc, val_loss
304 |
305 |
306 | @torch.no_grad()
307 | def test(args, model, test_dataloader, epoch, criterion):
308 |
309 | model.eval()
310 |
311 | loss_list = []
312 | prediction_list = []
313 | correct_answer_list = []
314 |
315 | for ii, data in enumerate(test_dataloader):
316 | data = [d.to(device) for d in data]
317 |
318 | out, predictions, answers = model(args.task, *data)
319 | loss = criterion(out, answers)
320 |
321 | correct_answer_list.append(answers)
322 | loss_list.append(loss.item())
323 | prediction_list.append(predictions.detach())
324 |
325 | test_loss = np.mean(loss_list)
326 | correct_answer = torch.cat(correct_answer_list, dim=0).long()
327 | predict_answer = torch.cat(prediction_list, dim=0).long()
328 | assert correct_answer.shape == predict_answer.shape
329 |
330 | current_num = torch.sum(predict_answer == correct_answer).cpu().numpy()
331 |
332 | acc = current_num / len(correct_answer) * 100.
333 |
334 | print(
335 | "Test|Epoch: {}, Acc: {:3f}={}/{}, Test Loss: {:3f}".format(
336 | epoch, acc, current_num, len(correct_answer), test_loss))
337 |
338 | logfile = open(os.path.join(args.log, args.task + '.txt'), 'a+')
339 | logfile.write(
340 | "Test|Epoch: %d, Acc: %.3f=%d/%d, Test Loss: %.3f\n"
341 | % (epoch, acc, current_num, len(correct_answer), test_loss)
342 | )
343 | logfile.close()
344 |
345 | if args.save:
346 | if acc >= 30:
347 | torch.save(
348 | model, os.path.join(args.save_model_path,
349 | args.task + '_' + str(acc.item())[:5] + '.pth'))
350 | print('Save model at ', args.save_model_path)
351 |
352 | if __name__ == '__main__':
353 | parser = argparse.ArgumentParser()
354 |
355 | parser.add_argument('--exp_name', type=str, default='[MSVD_msvd]')
356 |
357 | ################ path config ################
358 | parser.add_argument('--feat_dir', default='/data/MSVD-QA', help='path for imagenet and c3d features')
359 | parser.add_argument('--vc_dir', default='/data/MSVD-QA/vocab', help='path for vocabulary')
360 |     parser.add_argument('--df_dir', default='/data/MSVD-QA/question', help='path for MSVD-QA question csv files')
361 |
362 | ################ inference config ################
363 | parser.add_argument(
364 | '--checkpoint',
365 | type=str,
366 | default='',
367 | help='path to checkpoint')
368 | parser.add_argument(
369 | '--save_path',
370 | type=str,
371 | default='./saved_models/',
372 | help='path for saving trained models')
373 |
374 | parser.add_argument(
375 | '--save', action='store_true', default=True, help='save models or not')
376 | parser.add_argument(
377 | '--hidden_size',
378 | type=int,
379 | default=512,
380 | help='dimension of lstm hidden states')
381 | parser.add_argument(
382 | '--test', action='store_true', default=False, help='Train or Test')
383 | parser.add_argument('--max_epoch', type=int, default=100)
384 | parser.add_argument('--val_ratio', type=float, default=0.1)
385 | parser.add_argument('--q_max_length', type=int, default=20)
386 | parser.add_argument('--v_max_length', type=int, default=30)
387 |
388 | parser.add_argument(
389 | '--task',
390 | type=str,
391 | default='MS-QA'
392 | )
393 | parser.add_argument(
394 | '--rnn_layers', type=int, default=1, help='number of layers in lstm')
395 | parser.add_argument(
396 | '--gcn_layers',
397 | type=int,
398 | default=2,
399 | help='number of layers in gcn (+1)')
400 | parser.add_argument('--batch_size', type=int, default=32)
401 | parser.add_argument('--max_n_videos', type=int, default=100000)
402 | parser.add_argument('--num_workers', type=int, default=1)
403 | parser.add_argument('--lr', type=float, default=0.0001)
404 |     parser.add_argument('--lr_list', type=int, nargs='+', default=[10, 20, 30, 40])
405 | parser.add_argument('--dropout', type=float, default=0.3)
406 |     parser.add_argument(
407 |         '--change_lr', type=str, default='none', help="one of: none, acc, loss, cos, step")
408 | parser.add_argument(
409 | '--weight_decay', type=float, default=0, help='weight_decay')
410 | parser.add_argument('--ablation', type=str, default='none')
411 |
412 | args = parser.parse_args()
413 | print(args)
414 |
415 | main()
--------------------------------------------------------------------------------
/model/masn.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import functional as F
9 | import numpy as np
10 | from model.modules import linear_weightdrop as dropnn
11 |
12 | from torch.autograd import Variable
13 |
14 | from model.modules.rnn_encoder import SentenceEncoderRNN
15 | from model.modules.gcn import VideoAdjLearner, GCN
16 | from model.modules.position_embedding import PositionEncoding
17 | from model.modules.ban.ban import BAN
18 | from model.modules.fusion.fusion import MotionApprFusion, AttFlat
19 |
20 | # torch.set_printoptions(threshold=np.inf)
21 |
22 |
23 | class MASN(nn.Module):
24 |
25 | def __init__(
26 | self,
27 | vocab_size,
28 | s_layers,
29 | s_embedding,
30 | resnet_input_size,
31 | i3d_input_size,
32 | hidden_size,
33 | dropout_p=0.0,
34 | gcn_layers=2,
35 | answer_vocab_size=None,
36 | q_max_len=35,
37 | v_max_len=80,
38 | ablation='none'):
39 | super().__init__()
40 |
41 | self.ablation = ablation
42 | self.q_max_len = q_max_len
43 | self.v_max_len = v_max_len
44 | self.hidden_size = hidden_size
45 |
46 | self.compress_appr_local = dropnn.WeightDropLinear(
47 | resnet_input_size,
48 | hidden_size,
49 | weight_dropout=dropout_p,
50 | bias=False)
51 | self.compress_motion_local = dropnn.WeightDropLinear(
52 | i3d_input_size,
53 | hidden_size,
54 | weight_dropout=dropout_p,
55 | bias=False)
56 | self.compress_appr_global = dropnn.WeightDropLinear(
57 | resnet_input_size,
58 | hidden_size,
59 | weight_dropout=dropout_p,
60 | bias=False)
61 | self.compress_motion_global = dropnn.WeightDropLinear(
62 | i3d_input_size,
63 | hidden_size,
64 | weight_dropout=dropout_p,
65 | bias=False)
66 |
67 | embedding_dim = s_embedding.shape[1] if s_embedding is not None else hidden_size
68 | self.glove = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
69 | if s_embedding is not None:
70 | print("glove embedding weight is loaded!")
71 | self.glove.weight = nn.Parameter(torch.from_numpy(s_embedding).float())
72 | self.glove.weight.requires_grad = False
73 | self.embedding_proj = nn.Sequential(
74 | nn.Dropout(p=dropout_p),
75 | nn.Linear(embedding_dim, hidden_size, bias=False)
76 | )
77 |
78 | self.sentence_encoder = SentenceEncoderRNN(
79 | vocab_size,
80 | hidden_size,
81 | input_dropout_p=dropout_p,
82 | dropout_p=dropout_p,
83 | n_layers=s_layers,
84 | bidirectional=True,
85 | rnn_cell='lstm'
86 | )
87 |
88 | self.bbox_location_encoding = nn.Linear(6, 64)
89 | self.pos_location_encoding = PositionEncoding(n_filters=64, max_len=self.v_max_len)
90 |
91 | self.appr_local_proj = nn.Linear(hidden_size+128, hidden_size)
92 | self.motion_local_proj = nn.Linear(hidden_size+128, hidden_size)
93 |
94 | self.pos_enc = PositionEncoding(n_filters=512, max_len=self.v_max_len)
95 | self.appr_v = nn.Linear(hidden_size*2, hidden_size)
96 | self.motion_v = nn.Linear(hidden_size*2, hidden_size)
97 |
98 | self.appr_adj = VideoAdjLearner(hidden_size, hidden_size)
99 | self.appr_gcn = GCN(hidden_size, hidden_size, hidden_size, num_layers=gcn_layers)
100 | self.motion_adj = VideoAdjLearner(hidden_size, hidden_size)
101 | self.motion_gcn = GCN(hidden_size, hidden_size, hidden_size, num_layers=gcn_layers)
102 |
103 | self.res_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False)
104 | self.i3d_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False)
105 |
106 | self.appr_vq_interact = BAN(hidden_size, glimpse=4)
107 | self.motion_vq_interact = BAN(hidden_size, glimpse=4)
108 |
109 | self.motion_appr_fusion = MotionApprFusion(hidden_size, hidden_size, n_layer=1)
110 | self.attflat = AttFlat(hidden_size, hidden_size, 1, hidden_size)
111 |
112 | if answer_vocab_size is not None:
113 | self.fc = nn.Linear(hidden_size, answer_vocab_size)
114 | else:
115 | self.fc = nn.Linear(hidden_size, 1)
116 |
117 | def forward(self, task, *args):
118 |         # expected sentence_inputs is of shape (batch_size, sentence_len)
119 | # expected video_inputs is of shape (batch_size, frame_num, video_feature)
120 | self.task = task
121 | if task == 'Count':
122 | return self.forward_count(*args)
123 | elif task == 'FrameQA':
124 | return self.forward_frameqa(*args)
125 | elif task == 'Action' or task == 'Trans':
126 | return self.forward_trans_or_action(*args)
127 | elif task == 'MS-QA':
128 | return self.forward_msqa(*args)
129 |
130 | def model_block(self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp,
131 | video_length, all_sen_inputs, all_ques_length):
132 |
133 | q_mask = self.make_mask(all_sen_inputs, all_ques_length)
134 | v_mask = self.make_mask(res_avg_inp[:,:,0], video_length)
135 |
136 | q_emb = F.relu(self.embedding_proj(self.glove(all_sen_inputs))) # b, q_len, d
137 | q_output, q_hidden = self.sentence_encoder(q_emb, input_lengths=all_ques_length)
138 | q_hidden = q_hidden.squeeze()
139 |
140 | bsz, v_len, obj_num = res_obj_inp.size(0), res_obj_inp.size(1), res_obj_inp.size(2)
141 | q_len = q_output.size(1)
142 | q_mask = q_mask[:,:q_len]
143 |
144 | # make local and global feature
145 | res_obj_inp = self.compress_appr_local(res_obj_inp) # b, v_len, N, d
146 | i3d_obj_inp = self.compress_motion_local(i3d_obj_inp) # b, v_len, N, d
147 | res_avg_inp = self.compress_appr_global(res_avg_inp) # b, v_len, d
148 | i3d_avg_inp = self.compress_motion_global(i3d_avg_inp) # b, v_len, d
149 |
150 | bbox_inp = self.bbox_location_encoding(bbox_inp) # b, v_len, N, d/8
151 | pos_inp = self.pos_location_encoding(res_obj_inp.contiguous().view(bsz*obj_num, v_len, -1)) # 1, v_len, 64
152 | pos_inp = pos_inp.unsqueeze(2).expand(bsz, v_len, obj_num, 64) * v_mask.unsqueeze(2).unsqueeze(3) # b, v_len, N, d/8
153 |
154 | appr_local = self.appr_local_proj(torch.cat([res_obj_inp, bbox_inp, pos_inp], dim=3)) # b, v_len, N, d
155 | motion_local = self.motion_local_proj(torch.cat([i3d_obj_inp, bbox_inp, pos_inp], dim=3)) # b, v_len, N, d
156 |
157 | v_len = appr_local.size(1)
158 | appr_local = appr_local.contiguous().view(bsz*v_len, obj_num, self.hidden_size)
159 | motion_local = motion_local.contiguous().view(bsz*v_len, obj_num, self.hidden_size)
160 |
161 | res_avg_inp = self.pos_enc(res_avg_inp) + res_avg_inp
162 | res_avg_inp = res_avg_inp.contiguous().view(bsz*v_len, self.hidden_size)
163 | res_avg_inp = res_avg_inp.unsqueeze(1).expand_as(appr_local)
164 | appr_v = self.appr_v(torch.cat([appr_local, res_avg_inp], dim=-1))
165 |
166 | i3d_avg_inp = self.pos_enc(i3d_avg_inp) + i3d_avg_inp
167 | i3d_avg_inp = i3d_avg_inp.contiguous().view(bsz*v_len, self.hidden_size)
168 | i3d_avg_inp = i3d_avg_inp.unsqueeze(1).expand_as(motion_local)
169 | motion_v = self.motion_v(torch.cat([motion_local, i3d_avg_inp], dim=-1))
170 |
171 | appr_v = appr_v.contiguous().view(bsz, v_len*obj_num, self.hidden_size)
172 | motion_v = motion_v.contiguous().view(bsz, v_len*obj_num, self.hidden_size)
173 | v_mask_expand = v_mask[:,:v_len].unsqueeze(2).expand(bsz, v_len, obj_num).contiguous().view(bsz, v_len*obj_num)
174 |
175 | # object graph convolution
176 | appr_adj = self.appr_adj(appr_v, v_mask_expand)
177 | appr_gcn = self.appr_gcn(appr_v, appr_adj) # b, v_len*obj_num, d
178 | motion_adj = self.motion_adj(motion_v, v_mask_expand)
179 | motion_gcn = self.motion_gcn(motion_v, motion_adj) # b, v_len*obj_num, d
180 |
181 | # vq interaction
182 | appr_vq, _ = self.appr_vq_interact(appr_gcn, q_output, v_mask_expand, q_mask)
183 | motion_vq, _ = self.motion_vq_interact(motion_gcn, q_output, v_mask_expand, q_mask)
184 |
185 | # motion-appr fusion
186 | U = torch.cat([appr_vq, motion_vq], dim=1) # b, 2*q_len, d
187 | q_mask_ = torch.cat([q_mask, q_mask], dim=1)
188 | U_mask = torch.matmul(q_mask_.unsqueeze(2), q_mask_.unsqueeze(2).transpose(1, 2))
189 |
190 | fusion_out = self.motion_appr_fusion(U, q_hidden, U_mask)
191 | fusion_out = self.attflat(fusion_out, q_mask_)
192 |
193 | out = self.fc(fusion_out).squeeze()
194 | return out
195 |
196 | def make_mask(self, seq, seq_length):
197 | mask = seq
198 | mask = mask.data.new(*mask.size()).fill_(1)
199 | for i, l in enumerate(seq_length):
200 | mask[i][min(mask.size(1)-1, l):] = 0
201 | mask = Variable(mask) # b, seq_len
202 | mask = mask.to(torch.float)
203 | return mask
204 |
205 | def forward_count(
206 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
207 | all_sen_inputs, all_ques_length, answers):
208 | # out of shape (batch_size, )
209 | out = self.model_block(
210 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
211 | all_sen_inputs, all_ques_length)
212 | predictions = torch.clamp(torch.round(out), min=1, max=10).long()
213 | # answers of shape (batch_size, )
214 | return out, predictions, answers
215 |
216 | def forward_frameqa(
217 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
218 | all_sen_inputs, all_ques_length, answers, answer_type):
219 | # out of shape (batch_size, num_class)
220 | out = self.model_block(
221 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
222 | all_sen_inputs, all_ques_length)
223 |
224 | _, max_idx = torch.max(out, 1)
225 | # (batch_size, ), dtype is long
226 | predictions = max_idx
227 | # answers of shape (batch_size, )
228 | return out, predictions, answers
229 |
230 | def forward_trans_or_action(
231 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
232 | all_cand_inputs, all_cand_length, answers, row_index):
233 | all_cand_inputs = all_cand_inputs.permute(1, 0, 2)
234 | all_cand_length = all_cand_length.permute(1, 0)
235 |
236 | all_out = []
237 | for idx in range(5):
238 | # out of shape (batch_size, )
239 | out = self.model_block(
240 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
241 | all_cand_inputs[idx], all_cand_length[idx])
242 | all_out.append(out)
243 | # all_out of shape (batch_size, 5)
244 | all_out = torch.stack(all_out, 0).transpose(1, 0)
245 | _, max_idx = torch.max(all_out, 1)
246 | # (batch_size, )
247 | predictions = max_idx
248 |
249 | # answers of shape (batch_size, )
250 | return all_out, predictions, answers
251 |
252 | def forward_msqa(
253 | self, res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
254 | all_sen_inputs, all_ques_length, answers):
255 | # out of shape (batch_size, num_class)
256 | out = self.model_block(
257 | res_avg_inp, i3d_avg_inp, res_obj_inp, bbox_inp, i3d_obj_inp, video_length,
258 | all_sen_inputs, all_ques_length)
259 |
260 | _, max_idx = torch.max(out, 1)
261 | # (batch_size, ), dtype is long
262 | predictions = max_idx
263 | # answers of shape (batch_size, )
264 | return out, predictions, answers
265 |
--------------------------------------------------------------------------------
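Illustrative note (not part of the repository files): a minimal smoke test of the MS-QA forward path of MASN above with random tensors, assuming the repository root is on PYTHONPATH. All sizes below are hypothetical; hidden_size is fixed to 512 here because the global-feature PositionEncoding inside MASN uses 512 filters.

import torch
from model.masn import MASN

# hypothetical toy sizes: batch 2, 30 frames, 10 objects per frame, 20 question tokens
b, v_len, obj_num, q_len = 2, 30, 10, 20
model = MASN(
    vocab_size=1000, s_layers=1, s_embedding=None,
    resnet_input_size=2048, i3d_input_size=2048, hidden_size=512,
    answer_vocab_size=1000, q_max_len=20, v_max_len=30)

res_avg = torch.randn(b, v_len, 2048)           # global appearance (ResNet) features
i3d_avg = torch.randn(b, v_len, 2048)           # global motion (I3D) features
res_obj = torch.randn(b, v_len, obj_num, 2048)  # object-level appearance features
bbox = torch.rand(b, v_len, obj_num, 6)         # bounding-box location features
i3d_obj = torch.randn(b, v_len, obj_num, 2048)  # object-level motion (RoIAlign) features
video_length = torch.tensor([30, 25])
questions = torch.randint(1, 1000, (b, q_len))
ques_length = torch.tensor([20, 15])
answers = torch.randint(0, 1000, (b,))

out, predictions, _ = model(
    'MS-QA', res_avg, i3d_avg, res_obj, bbox, i3d_obj,
    video_length, questions, ques_length, answers)
print(out.shape, predictions.shape)  # (2, 1000), (2,)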
/model/modules/ban/ban.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from jnhwkim's repository.
3 | # https://github.com/jnhwkim/ban-vqa
4 | # --------------------------------------------------------
5 |
6 | from __future__ import print_function
7 | import math
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn.utils.weight_norm import weight_norm
11 | from model.modules.ban.fc import FCNet
12 |
13 |
14 | class BCNet(nn.Module):
15 | """Simple class for non-linear bilinear connect network
16 | """
17 | def __init__(self, v_dim, q_dim, h_dim, h_out, act='ReLU', dropout=[.2,.5], k=3):
18 | super(BCNet, self).__init__()
19 |
20 | self.c = 32
21 | self.k = k
22 | self.v_dim = v_dim; self.q_dim = q_dim
23 | self.h_dim = h_dim; self.h_out = h_out
24 |
25 | self.v_net = FCNet([v_dim, h_dim * self.k], act=act, dropout=dropout[0])
26 | self.q_net = FCNet([q_dim, h_dim * self.k], act=act, dropout=dropout[0])
27 | self.dropout = nn.Dropout(dropout[1]) # attention
28 | if 1 < k:
29 | self.p_net = nn.AvgPool1d(self.k, stride=self.k)
30 |
31 | if None == h_out:
32 | pass
33 | elif h_out <= self.c:
34 | self.h_mat = nn.Parameter(torch.Tensor(1, h_out, 1, h_dim * self.k).normal_())
35 | self.h_bias = nn.Parameter(torch.Tensor(1, h_out, 1, 1).normal_())
36 | else:
37 | self.h_net = weight_norm(nn.Linear(h_dim * self.k, h_out), dim=None)
38 |
39 | def forward(self, v, q):
40 | if None == self.h_out:
41 | v_ = self.v_net(v)
42 | q_ = self.q_net(q)
43 | logits = torch.einsum('bvk,bqk->bvqk', (v_, q_))
44 | return logits
45 |
46 | # low-rank bilinear pooling using einsum
47 | elif self.h_out <= self.c:
48 | v_ = self.dropout(self.v_net(v))
49 | q_ = self.q_net(q)
50 | logits = torch.einsum('xhyk,bvk,bqk->bhvq', (self.h_mat, v_, q_)) + self.h_bias
51 | return logits # b x h_out x v x q
52 |
53 | # batch outer product, linear projection
54 | # memory efficient but slow computation
55 | else:
56 | v_ = self.dropout(self.v_net(v)).transpose(1,2).unsqueeze(3)
57 | q_ = self.q_net(q).transpose(1,2).unsqueeze(2)
58 | d_ = torch.matmul(v_, q_) # b x h_dim x v x q
59 | logits = self.h_net(d_.transpose(1,2).transpose(2,3)) # b x v x q x h_out
60 | return logits.transpose(2,3).transpose(1,2) # b x h_out x v x q
61 |
62 | def forward_with_weights(self, v, q, w):
63 | v_ = self.v_net(v) # b x v x d
64 | q_ = self.q_net(q) # b x q x d
65 | logits = torch.einsum('bvk,bvq,bqk->bk', (v_, w, q_))
66 | if 1 < self.k:
67 | logits = logits.unsqueeze(1) # b x 1 x d
68 | logits = self.p_net(logits).squeeze(1) * self.k # sum-pooling
69 | return logits
70 |
71 |
72 | class BiAttention(nn.Module):
73 | def __init__(self, x_dim, y_dim, z_dim, glimpse, dropout=[.2,.5]):
74 | super(BiAttention, self).__init__()
75 |
76 | self.glimpse = glimpse
77 | self.logits = weight_norm(BCNet(x_dim, y_dim, z_dim, glimpse, dropout=dropout, k=3), \
78 | name='h_mat', dim=None)
79 |
80 | def forward(self, v, q, v_mask, q_mask):
81 | """
82 | v: [batch, k, vdim]
83 |         q: [batch, q_len, qdim]
84 | """
85 | p, logits = self.forward_all(v, q, v_mask, q_mask)
86 | return p, logits
87 |
88 | def forward_all(self, v, q, v_mask, q_mask, logit=False, mask_with=-float('inf')):
89 | v_num = v.size(1)
90 | q_num = q.size(1)
91 | logits = self.logits(v,q) # b x g x v x q
92 |
93 | if v_mask is not None:
94 | v_mask = v_mask.unsqueeze(1).unsqueeze(3).expand_as(logits)
95 | logits = logits - 1e10*(1 - v_mask)
96 | if q_mask is not None:
97 | q_mask = q_mask.unsqueeze(1).unsqueeze(2).expand_as(logits)
98 | logits = logits - 1e10*(1 - q_mask)
99 |
100 | if not logit:
101 | p = nn.functional.softmax(logits.view(-1, self.glimpse, v_num * q_num), 2)
102 | p = p.view(-1, self.glimpse, v_num, q_num)
103 | if v_mask is not None:
104 | p = p * v_mask
105 | if q_mask is not None:
106 | p = p * q_mask
107 | p = p.masked_fill(p != p, 0.)
108 |
109 | return p, logits
110 |
111 | return logits
112 |
113 |
114 | class BAN(nn.Module):
115 | def __init__(self, num_hid, glimpse):
116 | super(BAN, self).__init__()
117 | self.glimpse = glimpse
118 | self.v_att = BiAttention(num_hid, num_hid, num_hid, glimpse)
119 |
120 | b_net = []
121 | q_prj = []
122 | for i in range(glimpse):
123 | b_net.append(BCNet(num_hid, num_hid, num_hid, None, k=1))
124 | q_prj.append(FCNet([num_hid, num_hid], '', .2))
125 |
126 | self.b_net = nn.ModuleList(b_net)
127 | self.q_prj = nn.ModuleList(q_prj)
128 | self.drop = nn.Dropout(.5)
129 | self.tanh = nn.Tanh()
130 |
131 | def forward(self, v, q_emb, v_mask, q_mask):
132 | """Forward
133 | v: [batch, num_objs, dim]
134 | q: [batch_size, q_len, dim]
135 | v_mask: b, v_len
136 | q_mask: b, q_len
137 | return: logits, not probs
138 | """
139 | # v_emb = v
140 | b_emb = [0] * self.glimpse
141 | att, logits = self.v_att.forward_all(v, q_emb, v_mask, q_mask) # b x g x v x q
142 |
143 | for g in range(self.glimpse):
144 | b_emb[g] = self.b_net[g].forward_with_weights(v, q_emb, att[:,g,:,:]) # b x l x h
145 | q_emb = self.q_prj[g](b_emb[g].unsqueeze(1)) + q_emb
146 |
147 | return q_emb, att
148 |
--------------------------------------------------------------------------------
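Illustrative note (not part of the repository files): a shape-level sketch of how the BAN module above is called, mirroring its use in model/masn.py; the sizes are hypothetical and the repository root is assumed to be on PYTHONPATH.

import torch
from model.modules.ban.ban import BAN

b, num_objs, q_len, d = 2, 10, 20, 512
ban = BAN(num_hid=d, glimpse=4)
v = torch.randn(b, num_objs, d)      # visual nodes (e.g. GCN outputs)
q = torch.randn(b, q_len, d)         # question word features
v_mask = torch.ones(b, num_objs)
q_mask = torch.ones(b, q_len)
q_out, att = ban(v, q, v_mask, q_mask)
print(q_out.shape, att.shape)        # (2, 20, 512), (2, 4, 10, 20)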
/model/modules/ban/fc.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from jnhwkim's repository.
3 | # https://github.com/jnhwkim/ban-vqa
4 | # --------------------------------------------------------
5 |
6 |
7 | from __future__ import print_function
8 | import torch.nn as nn
9 | from torch.nn.utils.weight_norm import weight_norm
10 |
11 |
12 | class FCNet(nn.Module):
13 | """Simple class for non-linear fully connect network
14 | """
15 | def __init__(self, dims, act='ReLU', dropout=0):
16 | super(FCNet, self).__init__()
17 |
18 | layers = []
19 | for i in range(len(dims)-2):
20 | in_dim = dims[i]
21 | out_dim = dims[i+1]
22 | if 0 < dropout:
23 | layers.append(nn.Dropout(dropout))
24 | layers.append(weight_norm(nn.Linear(in_dim, out_dim), dim=None))
25 | if ''!=act:
26 | layers.append(getattr(nn, act)())
27 | if 0 < dropout:
28 | layers.append(nn.Dropout(dropout))
29 | layers.append(weight_norm(nn.Linear(dims[-2], dims[-1]), dim=None))
30 | if ''!=act:
31 | layers.append(getattr(nn, act)())
32 |
33 | self.main = nn.Sequential(*layers)
34 |
35 | def forward(self, x):
36 | return self.main(x)
37 |
--------------------------------------------------------------------------------
/model/modules/fusion/fusion.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from cuiyuhao1996's repository.
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # https://github.com/cuiyuhao1996/mcan-vqa
5 | # --------------------------------------------------------
6 |
7 | from model.modules.fusion.net_utils import FC, MLP, LayerNorm
8 |
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch, math
12 |
13 |
14 | class AttFlat(nn.Module):
15 | def __init__(self, hidden_size, flat_mlp_size, flat_glimpses, flat_out_size, dropout_r=0.1):
16 | super(AttFlat, self).__init__()
17 |
18 | self.flat_glimpses = flat_glimpses
19 |
20 | self.mlp = MLP(
21 | in_size=hidden_size,
22 | mid_size=flat_mlp_size,
23 | out_size=flat_glimpses,
24 | dropout_r=dropout_r,
25 | use_relu=True
26 | )
27 |
28 | self.linear_merge = nn.Linear(
29 | hidden_size * flat_glimpses,
30 | flat_out_size
31 | )
32 |
33 | def forward(self, x, x_mask):
34 | att = self.mlp(x) # b, L, glimpse
35 | if x_mask is not None:
36 | x_mask = x_mask.unsqueeze(2)
37 | att = att - 1e10*(1 - x_mask)
38 | att = F.softmax(att, dim=1)
39 | if x_mask is not None:
40 | att = att * x_mask
41 | att = att.masked_fill(att != att, 0.) # b, L, glimpse
42 |
43 | att_list = []
44 | for i in range(self.flat_glimpses):
45 | att_list.append(
46 | torch.sum(att[:, :, i: i + 1] * x, dim=1)
47 | )
48 |
49 | x_atted = torch.cat(att_list, dim=1) # b, d
50 | x_atted = self.linear_merge(x_atted)
51 |
52 | return x_atted
53 |
54 |
55 | class FFN(nn.Module):
56 | def __init__(self, hidden_size, ff_size, dropout_r=0.1):
57 | super(FFN, self).__init__()
58 |
59 | self.mlp = MLP(
60 | in_size=hidden_size,
61 | mid_size=ff_size,
62 | out_size=hidden_size,
63 | dropout_r=dropout_r,
64 | use_relu=True
65 | )
66 |
67 | def forward(self, x):
68 | return self.mlp(x)
69 |
70 |
71 | class AttAdj(nn.Module):
72 | def __init__(self, hidden_size, dropout_r=0.1):
73 | super(AttAdj, self).__init__()
74 | self.hidden_size = hidden_size
75 | self.linear_k = nn.Linear(hidden_size, hidden_size)
76 | self.linear_q = nn.Linear(hidden_size, hidden_size)
77 |
78 | self.dropout = nn.Dropout(dropout_r)
79 |
80 | def forward(self, k, q, mask=None):
81 | '''
82 | :param k: b, kv_l, d
83 | :param q: b, q_l, d
84 | :param mask: b, q_l, kv_l
85 | '''
86 |
87 | k = self.linear_k(k)
88 | q = self.linear_q(q)
89 |
90 | adj_scores = self.att(k, q, mask)
91 |
92 | return adj_scores
93 |
94 | def att(self, key, query, mask):
95 | d_k = query.size(-1)
96 |
97 | scores = torch.matmul(
98 | query, key.transpose(-2, -1)
99 | ) / math.sqrt(d_k)
100 |
101 | if mask is not None:
102 | scores = scores - 1e10*(1 - mask)
103 |
104 | att_map = F.softmax(scores, dim=-1)
105 | if mask is not None:
106 | att_map = att_map * mask
107 | att_map = att_map.masked_fill(att_map != att_map, 0.)
108 | att_map = self.dropout(att_map)
109 |
110 | return att_map
111 |
112 |
113 | class MotionApprFusion(nn.Module):
114 | def __init__(self, hidden_size, ff_size, n_layer=1, dropout_r=0.1):
115 | super(MotionApprFusion, self).__init__()
116 | self.n_layer = n_layer
117 | self.hidden_size = hidden_size
118 |
119 | self.appr_att_score = AttAdj(hidden_size, dropout_r)
120 | self.motion_att_score = AttAdj(hidden_size, dropout_r)
121 | self.all_att_score = AttAdj(hidden_size, dropout_r)
122 |
123 | self.appr_linear = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layer)])
124 | self.motion_linear = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layer)])
125 | self.all_linear = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(n_layer)])
126 |
127 | self.appr_norm = LayerNorm(hidden_size)
128 | self.motion_norm = LayerNorm(hidden_size)
129 | self.all_norm = LayerNorm(hidden_size)
130 |
131 | self.ques_guide_att = AttAdj(hidden_size, dropout_r)
132 |
133 | self.ffn = FFN(hidden_size, ff_size, dropout_r)
134 | self.dropout = nn.Dropout(dropout_r)
135 | self.norm = LayerNorm(hidden_size)
136 |
137 | def forward(self, U, q_hid, U_mask=None):
138 | '''
139 | :param U: b, 2L, d
140 | :param q_hid: b, d
141 |         :param U_mask: b, 2L, 2L
142 | :return:
143 | '''
144 | residual = U
145 | bsz, seq_len = U.size(0), U.size(1)
146 | qword_len = int(seq_len/2)
147 | device = U.device
148 |
149 | appr_mask = torch.zeros(bsz, seq_len, seq_len).to(torch.float).to(device)
150 | appr_mask[:,:,:qword_len] = 1
151 | motion_mask = torch.zeros(bsz, seq_len, seq_len).to(torch.float).to(device)
152 | motion_mask[:,:,qword_len:] = 1
153 |
154 | if U_mask is not None:
155 | appr_mask = appr_mask * U_mask
156 | motion_mask = motion_mask * U_mask
157 |
158 | appr_att_adj = self.appr_att_score(U, U, appr_mask) # b, 2L, 2L
159 | motion_att_adj = self.motion_att_score(U, U, motion_mask)
160 | if U_mask is not None:
161 | all_att_adj = self.all_att_score(U, U, U_mask)
162 | else:
163 | all_att_adj = self.all_att_score(U, U, None)
164 |
165 | appr_inp, motion_inp, all_inp = [U, ], [U, ], [U, ]
166 | for i in range(self.n_layer):
167 | appr_x = self.appr_linear[i](appr_inp[i]) # b, 2L, d
168 | appr_x = torch.matmul(appr_att_adj, appr_x)
169 | appr_inp.append(appr_x)
170 |
171 | motion_x = self.motion_linear[i](motion_inp[i])
172 | motion_x = torch.matmul(motion_att_adj, motion_x)
173 | motion_inp.append(motion_x)
174 |
175 | all_x = self.all_linear[i](all_inp[i])
176 | all_x = torch.matmul(all_att_adj, all_x)
177 | all_inp.append(all_x)
178 |
179 | appr_x = self.appr_norm(residual + appr_x)
180 | motion_x = self.motion_norm(residual + motion_x)
181 | all_x = self.all_norm(residual + all_x)
182 |
183 | graph_out = torch.cat([appr_x.unsqueeze(1), motion_x.unsqueeze(1), all_x.unsqueeze(1)], dim=1) # b, 3, 2L, d
184 | fusion_k = torch.sum(graph_out, dim=2) # b, 3, d
185 |
186 | fusion_q = q_hid.unsqueeze(1) # b, 1, d
187 | fusion_att_score = self.ques_guide_att(fusion_k, fusion_q).squeeze() # b, 3
188 |
189 | fusion_att = graph_out * fusion_att_score.unsqueeze(2).unsqueeze(3) # b, 3, 2L, d
190 | fusion_att = torch.sum(fusion_att, dim=1) # b, 2L, d
191 |
192 | fusion_out = self.norm(fusion_att + self.ffn(self.dropout(fusion_att)))
193 |
194 | return fusion_out
195 |
196 |
--------------------------------------------------------------------------------
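Illustrative note (not part of the repository files): a minimal sketch of MotionApprFusion followed by AttFlat from the file above, using the same mask construction as model/masn.py; all sizes are hypothetical.

import torch
from model.modules.fusion.fusion import MotionApprFusion, AttFlat

b, L, d = 2, 20, 512
fusion = MotionApprFusion(hidden_size=d, ff_size=d, n_layer=1)
attflat = AttFlat(d, d, 1, d)

U = torch.randn(b, 2 * L, d)   # concatenated appearance / motion question-attended features
q_hid = torch.randn(b, d)      # question summary vector
q_mask = torch.ones(b, 2 * L)
U_mask = torch.matmul(q_mask.unsqueeze(2), q_mask.unsqueeze(1))  # (b, 2L, 2L)

fused = fusion(U, q_hid, U_mask)   # (b, 2L, d)
pooled = attflat(fused, q_mask)    # (b, d)
print(fused.shape, pooled.shape)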
/model/modules/fusion/net_utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from cuiyuhao1996's repository.
3 | # Licensed under The MIT License [see LICENSE for details]
4 | # https://github.com/cuiyuhao1996/mcan-vqa
5 | # --------------------------------------------------------
6 |
7 | import torch.nn as nn
8 | import torch
9 |
10 |
11 | class FC(nn.Module):
12 | def __init__(self, in_size, out_size, dropout_r=0., use_relu=True):
13 | super(FC, self).__init__()
14 | self.dropout_r = dropout_r
15 | self.use_relu = use_relu
16 |
17 | self.linear = nn.Linear(in_size, out_size)
18 |
19 | if use_relu:
20 | self.relu = nn.ReLU(inplace=True)
21 |
22 | if dropout_r > 0:
23 | self.dropout = nn.Dropout(dropout_r)
24 |
25 | def forward(self, x):
26 | x = self.linear(x)
27 |
28 | if self.use_relu:
29 | x = self.relu(x)
30 |
31 | if self.dropout_r > 0:
32 | x = self.dropout(x)
33 |
34 | return x
35 |
36 |
37 | class MLP(nn.Module):
38 | def __init__(self, in_size, mid_size, out_size, dropout_r=0., use_relu=True):
39 | super(MLP, self).__init__()
40 |
41 | self.fc = FC(in_size, mid_size, dropout_r=dropout_r, use_relu=use_relu)
42 | self.linear = nn.Linear(mid_size, out_size)
43 |
44 | def forward(self, x):
45 | return self.linear(self.fc(x))
46 |
47 |
48 | class LayerNorm(nn.Module):
49 | def __init__(self, size, eps=1e-6):
50 | super(LayerNorm, self).__init__()
51 | self.eps = eps
52 |
53 | self.a_2 = nn.Parameter(torch.ones(size))
54 | self.b_2 = nn.Parameter(torch.zeros(size))
55 |
56 | def forward(self, x):
57 | mean = x.mean(-1, keepdim=True)
58 | std = x.std(-1, keepdim=True)
59 |
60 | return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
--------------------------------------------------------------------------------
/model/modules/gcn.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | import math
7 | import torch
8 | import numpy as np
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torch.nn.parameter import Parameter
12 | from torch.nn.modules.module import Module
13 | from torch.autograd import Variable
14 |
15 |
16 | class GraphConvolution(Module):
17 | """
18 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
19 | """
20 |
21 | def __init__(self, in_features, out_features):
22 | super(GraphConvolution, self).__init__()
23 | self.weight = nn.Linear(in_features, out_features, bias=False)
24 | self.layer_norm = nn.LayerNorm(out_features, elementwise_affine=False)
25 |
26 | def forward(self, input, adj):
27 | # self.weight of shape (hidden_size, hidden_size)
28 | support = self.weight(input)
29 | output = torch.bmm(adj, support)
30 | output = self.layer_norm(output)
31 | return output
32 |
33 |
34 | class GCN(nn.Module):
35 |
36 | def __init__(
37 | self, input_size, hidden_size, num_classes, num_layers=1,
38 | dropout=0.1):
39 | super(GCN, self).__init__()
40 | self.layers = nn.ModuleList()
41 | self.layers.append(GraphConvolution(input_size, hidden_size))
42 | for i in range(num_layers - 1):
43 | self.layers.append(GraphConvolution(hidden_size, hidden_size))
44 | self.layers.append(GraphConvolution(hidden_size, num_classes))
45 |
46 | self.layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False)
47 | self.dropout = nn.Dropout(p=dropout)
48 |
49 | def forward(self, x, adj):
50 | x_gcn = [x]
51 | for i, layer in enumerate(self.layers):
52 | x_gcn.append(self.dropout(F.relu(layer(x_gcn[i], adj))))
53 |
54 | x = self.layernorm(x + x_gcn[-1])
55 | return x
56 |
57 |
58 | class VideoAdjLearner(Module):
59 |
60 | def __init__(self, in_feature_dim, hidden_size, dropout=0.1, scale=100):
61 | super().__init__()
62 | self.scale = scale
63 |
64 | self.edge_layer_1 = nn.Linear(in_feature_dim, hidden_size, bias=False)
65 | self.edge_layer_2 = nn.Linear(hidden_size, hidden_size, bias=False)
66 |
67 | # Regularization
68 | self.dropout = nn.Dropout(p=dropout)
69 | self.edge_layer_1 = nn.utils.weight_norm(self.edge_layer_1)
70 | self.edge_layer_2 = nn.utils.weight_norm(self.edge_layer_2)
71 |
72 | def forward(self, v, v_mask=None):
73 | '''
74 | :param v: (b, v_len, d)
75 | :param v_mask: (b, v_len)
76 | :return: adj: (b, v_len, v_len)
77 | '''
78 | # layer 1
79 | h = self.edge_layer_1(v) # b, v_l, d
80 | h = F.relu(h)
81 |
82 | # layer 2
83 | h = self.edge_layer_2(h) # b, v_l, d
84 | h = F.relu(h)
85 |
86 | # outer product
87 | adj = torch.bmm(h, h.transpose(1, 2)) # b, v_l, v_l
88 |
89 | if v_mask is not None:
90 | adj_mask = adj.data.new(*adj.size()).fill_(1)
91 | v_mask_ = torch.matmul(v_mask.unsqueeze(2), v_mask.unsqueeze(2).transpose(1, 2))
92 | adj_mask = adj_mask * v_mask_
93 |
94 | adj_mask = Variable(adj_mask)
95 | adj = adj - 1e10*(1 - adj_mask)
96 |
97 | adj = F.softmax(adj * self.scale, dim=-1)
98 | adj = adj * adj_mask
99 | adj = adj.masked_fill(adj != adj, 0.)
100 | else:
101 | adj = F.softmax(adj * self.scale, dim=-1)
102 |
103 | return adj
104 |
105 |
--------------------------------------------------------------------------------
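Illustrative note (not part of the repository files): a minimal sketch of VideoAdjLearner and GCN from the file above; the node count and feature size are hypothetical.

import torch
from model.modules.gcn import VideoAdjLearner, GCN

b, n, d = 2, 30, 512
adj_learner = VideoAdjLearner(d, d)
gcn = GCN(d, d, d, num_layers=2)

nodes = torch.randn(b, n, d)    # e.g. object features flattened over frames
mask = torch.ones(b, n)
adj = adj_learner(nodes, mask)  # (b, n, n), softmax-normalized over the last dim
out = gcn(nodes, adj)           # (b, n, d), residual + layer norm at the end
print(adj.shape, out.shape)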
/model/modules/linear_weightdrop.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 |
7 |
8 | from torch.nn import Parameter
9 | import torch
10 |
11 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12 |
13 |
14 | class weight_drop():
15 |
16 | def __init__(self, module, weights, dropout):
17 | for name_w in weights:
18 | w = getattr(module, name_w)
19 | del module._parameters[name_w]
20 | module.register_parameter(name_w + '_raw', Parameter(w))
21 |
22 | self.original_module_forward = module.forward
23 |
24 | self.weights = weights
25 | self.module = module
26 | self.dropout = dropout
27 |
28 | def __call__(self, *args, **kwargs):
29 | for name_w in self.weights:
30 | raw_w = getattr(self.module, name_w + '_raw')
31 | w = torch.nn.functional.dropout(
32 | raw_w, p=self.dropout, training=self.module.training)
33 | # module.register_parameter(name_w, Parameter(w))
34 | setattr(self.module, name_w, Parameter(w))
35 |
36 | return self.original_module_forward(*args, **kwargs)
37 |
38 |
39 | def _weight_drop(module, weights, dropout):
40 | setattr(module, 'forward', weight_drop(module, weights, dropout))
41 |
42 |
43 | class WeightDrop(torch.nn.Module):
44 | """
45 | The weight-dropped module applies recurrent regularization through a DropConnect mask on the
46 | hidden-to-hidden recurrent weights.
47 | **Thank you** to Sales Force for their initial implementation of :class:`WeightDrop`. Here is
48 | their `License
49 | `__.
50 | Args:
51 | module (:class:`torch.nn.Module`): Containing module.
52 | weights (:class:`list` of :class:`str`): Names of the module weight parameters to apply a
53 |             dropout to.
54 | dropout (float): The probability a weight will be dropped.
55 | Example:
56 | >>> from torchnlp.nn import WeightDrop
57 | >>> import torch
58 | >>>
59 | >>> torch.manual_seed(123)
60 |         <torch._C.Generator object ...
61 |         >>>
62 |         >>> gru = torch.nn.GRUCell(2, 2)
63 |         >>> weights = ['weight_hh']
64 |         >>> weight_drop_gru = WeightDrop(gru, weights, dropout=0.9)
65 |         >>>
66 |         >>> input_ = torch.randn(3, 2)
67 |         >>> hidden_state = torch.randn(3, 2)
68 |         >>> weight_drop_gru(input_, hidden_state)
69 |         tensor(... grad_fn=<AddBackward0>)
70 | """
71 |
72 | def __init__(self, module, weights, dropout=0.0):
73 | super(WeightDrop, self).__init__()
74 | _weight_drop(module, weights, dropout)
75 | self.forward = module.forward
76 |
77 |
78 | class WeightDropLinear(torch.nn.Linear):
79 | """
80 | Wrapper around :class:`torch.nn.Linear` that adds ``weight_dropout`` named argument.
81 | Args:
82 | weight_dropout (float): The probability a weight will be dropped.
83 | """
84 |
85 | def __init__(self, *args, weight_dropout=0.0, **kwargs):
86 | super().__init__(*args, **kwargs)
87 | weights = ['weight']
88 | _weight_drop(self, weights, weight_dropout)
89 |
--------------------------------------------------------------------------------
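Illustrative note (not part of the repository files): a small sketch of WeightDropLinear above (DropConnect applied to the weight matrix), mirroring how MASN compresses 2048-d features to the hidden size; written against the pytorch 1.2 pin in the README, with hypothetical sizes.

import torch
from model.modules.linear_weightdrop import WeightDropLinear

layer = WeightDropLinear(2048, 512, weight_dropout=0.3, bias=False)
x = torch.randn(4, 2048)

layer.train()
y_train = layer(x)   # a dropped-out copy of the raw weight is used in training mode
layer.eval()
y_eval = layer(x)    # the full (undropped) weight matrix is used at eval time
print(y_train.shape, y_eval.shape)  # torch.Size([4, 512]) twice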
/model/modules/position_embedding.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from jayleicn's repository.
3 | # https://github.com/jayleicn/TVQAplus
4 | # --------------------------------------------------------
5 |
6 | import torch
7 | import math
8 | import torch.nn as nn
9 |
10 |
11 | class PositionEncoding(nn.Module):
12 | def __init__(self, n_filters=128, max_len=500):
13 | super(PositionEncoding, self).__init__()
14 | # Compute the positional encodings once in log space.
15 | pe = torch.zeros(max_len, n_filters) # (L, D)
16 | position = torch.arange(0, max_len).float().unsqueeze(1)
17 | div_term = torch.exp(torch.arange(0, n_filters, 2).float() * - (math.log(10000.0) / n_filters))
18 | pe[:, 0::2] = torch.sin(position * div_term)
19 | pe[:, 1::2] = torch.cos(position * div_term)
20 | self.register_buffer('pe', pe) # buffer is a tensor, not a variable, (L, D)
21 |
22 | def forward(self, x):
23 | """
24 | :Input: (*, L, D)
25 | :Output: (*, L, D) the same size as input
26 | """
27 | pe = self.pe.data[:x.size(-2), :] # (#x.size(-2), n_filters)
28 | extra_dim = len(x.size()) - 2
29 | for _ in range(extra_dim):
30 | pe = pe.unsqueeze(0)
31 |
32 | return pe
33 |
34 |
35 | def test_pos_enc():
36 | mdl = PositionEncoding()
37 |
38 | batch_size = 8
39 | n_channels = 128
40 | n_items = 60
41 |
42 | input = torch.ones(batch_size, n_items, n_channels)
43 |
44 | out = mdl(input)
45 | print(out)
46 |
47 |
48 | if __name__ == '__main__':
49 | test_pos_enc()
50 |
--------------------------------------------------------------------------------
/model/modules/rnn_encoder.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 |
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import torch
9 | import numpy as np
10 |
11 |
12 | class BaseRNN(nn.Module):
13 |
14 | # SYM_MASK = "MASK"
15 | # SYM_EOS = "EOS"
16 |
17 | def __init__(
18 | self, input_size, hidden_size, input_dropout_p, dropout_p, n_layers,
19 | rnn_cell):
20 | super(BaseRNN, self).__init__()
21 | self.vocab_size = input_size
22 | self.hidden_size = hidden_size
23 | self.n_layers = n_layers
24 | self.input_dropout_p = input_dropout_p
25 | self.dropout = nn.Dropout(p=input_dropout_p)
26 |
27 | if rnn_cell.lower() == 'lstm':
28 | self.rnn_cell = nn.LSTM
29 | elif rnn_cell.lower() == 'gru':
30 | self.rnn_cell = nn.GRU
31 | else:
32 | raise ValueError("Unsupported RNN Cell: {}".format(rnn_cell))
33 |
34 | self.dropout_p = dropout_p
35 |
36 | self.compress_hn_layers = nn.Linear(
37 | n_layers * hidden_size, hidden_size, bias=False)
38 | self.compress_hn_layers_bi = nn.Linear(
39 | n_layers * 2 * hidden_size, hidden_size, bias=False)
40 | self.compress_hn_bi = nn.Linear(
41 | 2 * hidden_size, hidden_size, bias=False)
42 |
43 | self.compress_output = nn.Linear(
44 | 2 * hidden_size, hidden_size, bias=False)
45 | self.compress_output_dropout = nn.Dropout(p=input_dropout_p)
46 |
47 | def forward(self, *args, **kwargs):
48 | raise NotImplementedError()
49 |
50 |
51 | class SentenceEncoderRNN(BaseRNN):
52 | r"""
53 | Applies a multi-layer RNN to an input sequence.
54 |
55 | variable_lengths: if use variable length RNN (default: False)
56 |
57 | Args:
58 | input_var (batch, seq_len, dim): glove embedding feature
59 | input_lengths (list of int, optional): A list that contains the lengths of sequences
60 | in the mini-batch
61 |
62 |
63 | Returns: output, hidden
64 | - **output** (batch, seq_len, hidden_size): variable containing the encoded features of the input sequence
65 | - **hidden** (num_layers * num_directions, batch, hidden_size): variable containing the features in the hidden state h
66 | """
67 |
68 | def __init__(
69 | self,
70 | vocab_size,
71 | hidden_size,
72 | input_dropout_p=0,
73 | dropout_p=0,
74 | n_layers=1,
75 | bidirectional=False,
76 | rnn_cell='gru',
77 | variable_lengths=True):
78 | super().__init__(
79 | vocab_size, hidden_size, input_dropout_p, dropout_p, n_layers,
80 | rnn_cell)
81 |
82 | self.variable_lengths = variable_lengths
83 | self.n_layers = n_layers
84 | self.bidirectional = bidirectional
85 | self.rnn_name = rnn_cell
86 |
87 | self.rnn = self.rnn_cell(
88 | hidden_size,
89 | hidden_size,
90 | n_layers,
91 | batch_first=True,
92 | bidirectional=bidirectional,
93 | dropout=dropout_p)
94 |
95 | def forward(self, input_var, h_0=None, input_lengths=None):
96 | batch_size = input_var.size()[0]
97 | embedded = input_var
98 |
99 | if self.variable_lengths:
100 | embedded = nn.utils.rnn.pack_padded_sequence(
101 | embedded, input_lengths, batch_first=True, enforce_sorted=False)
102 |
103 | # output of shape (batch, seq_len, num_directions * hidden_size)
104 | # h_n of shape (num_layers * num_directions, batch, hidden_size)
105 | if self.rnn_name == 'gru':
106 | output, hidden = self.rnn(embedded, h_0)
107 | else:
108 | output, (hidden, _) = self.rnn(embedded, h_0)
109 |
110 | if self.variable_lengths:
111 |             total_length = input_var.size()[1]
112 |             output, _ = nn.utils.rnn.pad_packed_sequence(
113 |                 output, batch_first=True, total_length=total_length)
114 |
115 | if self.n_layers > 1 and self.bidirectional:
116 | output = self.dropout(F.relu(self.compress_output(output)))
117 | hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)
118 | hidden = self.dropout(F.relu(self.compress_hn_layers_bi(hidden)))
119 | elif self.n_layers > 1:
120 | hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)
121 | hidden = self.dropout(F.relu(self.compress_hn_layers(hidden)))
122 | elif self.bidirectional:
123 | output = self.dropout(F.relu(self.compress_output(output)))
124 | hidden = hidden.transpose(0, 1).contiguous().view(batch_size, -1)
125 | hidden = self.dropout(F.relu(self.compress_hn_bi(hidden)))
126 | # output of shape (batch, seq_len, hidden_size) hidden of shape (batch, hidden_size)
127 |
128 | return output, hidden
129 |
--------------------------------------------------------------------------------
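Illustrative note (not part of the repository files): a minimal sketch of SentenceEncoderRNN as configured in model/masn.py (single-layer bidirectional LSTM over pre-embedded question tokens); batch size and lengths are hypothetical.

import torch
from model.modules.rnn_encoder import SentenceEncoderRNN

b, q_len, d = 4, 20, 512
encoder = SentenceEncoderRNN(
    vocab_size=1000, hidden_size=d, n_layers=1,
    bidirectional=True, rnn_cell='lstm')

q_emb = torch.randn(b, q_len, d)         # already-embedded question tokens
lengths = torch.tensor([20, 18, 15, 9])  # true lengths per example
output, hidden = encoder(q_emb, input_lengths=lengths)
print(output.shape, hidden.shape)        # (4, 20, 512), (4, 512)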
/model_overview.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ajseo17/MASN-pytorch/5ca3fc80cf37f7b6124070b1aae5bc599db8fa29/model_overview.jpeg
--------------------------------------------------------------------------------
/util.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # This code is modified from Jumpin2's repository.
3 | # https://github.com/Jumpin2/HGA
4 | # --------------------------------------------------------
5 | """ Common utilities. """
6 |
7 | # Logging
8 | # =======
9 |
10 | import logging
11 | import os, os.path
12 | from colorlog import ColoredFormatter
13 | import torch
14 |
15 | ch = logging.StreamHandler()
16 | ch.setLevel(logging.DEBUG)
17 |
18 | formatter = ColoredFormatter(
19 | "%(log_color)s[%(asctime)s] %(message)s",
20 | # datefmt='%H:%M:%S.%f',
21 | datefmt=None,
22 | reset=True,
23 | log_colors={
24 | 'DEBUG': 'cyan',
25 | 'INFO': 'white,bold',
26 | 'INFOV': 'cyan,bold',
27 | 'WARNING': 'yellow',
28 | 'ERROR': 'red,bold',
29 | 'CRITICAL': 'red,bg_white',
30 | },
31 | secondary_log_colors={},
32 | style='%')
33 | ch.setFormatter(formatter)
34 |
35 | log = logging.getLogger('videocap')
36 | log.setLevel(logging.DEBUG)
37 | log.handlers = [] # No duplicated handlers
38 | log.propagate = False # workaround for duplicated logs in ipython
39 | log.addHandler(ch)
40 |
41 | logging.addLevelName(logging.INFO + 1, 'INFOV')
42 |
43 |
44 | def _infov(self, msg, *args, **kwargs):
45 | self.log(logging.INFO + 1, msg, *args, **kwargs)
46 |
47 |
48 | logging.Logger.infov = _infov
49 |
50 |
51 | class AverageMeter(object):
52 | """Computes and stores the average and current value"""
53 |
54 | def __init__(self):
55 | self.reset()
56 |
57 | def reset(self):
58 | self.val = 0
59 | self.avg = 0
60 | self.sum = 0
61 | self.count = 0
62 |
63 | def update(self, val, n=1):
64 | self.val = val
65 | self.sum += val * n
66 | self.count += n
67 | self.avg = self.sum / self.count
68 |
69 |
70 | class StrToBytes:
71 |
72 | def __init__(self, fileobj):
73 | self.fileobj = fileobj
74 |
75 | def read(self, size):
76 | return self.fileobj.read(size).encode()
77 |
78 | def readline(self, size=-1):
79 | return self.fileobj.readline(size).encode()
80 |
81 |
82 | def get_accuracy(logits, targets):
83 | correct = torch.sum(logits.eq(targets)).float()
84 | return correct * 100.0 / targets.size(0)
85 |
86 |
87 | class nvidia_prefetcher():
88 | def __init__(self, loader):
89 | self.loader = iter(loader)
90 | self.stream = torch.cuda.Stream()
91 | self.preload()
92 |
93 | def preload(self):
94 | try:
95 | self.next_data = next(self.loader)
96 | except StopIteration:
97 | self.next_data = None
98 | return
99 | with torch.cuda.stream(self.stream):
100 | self.next_data = [d.cuda(non_blocking=True) for d in self.next_data]
101 | # With Amp, it isn't necessary to manually convert data to half.
102 | # if args.fp16:
103 | # self.next_input = self.next_input.half()
104 | # else:
105 |
106 | def next(self):
107 | torch.cuda.current_stream().wait_stream(self.stream)
108 | if self.next_data is None:
109 | raise StopIteration
110 | next_data = self.next_data
111 | self.preload()
112 | return next_data
113 |
114 | def __next__(self):
115 | return self.next()
116 |
117 | def __iter__(self):
118 | return self
--------------------------------------------------------------------------------
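Illustrative note (not part of the repository files): get_accuracy above compares predicted indices against targets, and AverageMeter keeps running statistics; a tiny sketch with made-up values (importing util also requires colorlog, as in the original file).

import torch
from util import AverageMeter, get_accuracy

predictions = torch.tensor([1, 0, 2, 2])
targets = torch.tensor([1, 1, 2, 0])

acc = get_accuracy(predictions, targets)     # tensor(50.), 2 of 4 correct
meter = AverageMeter()
meter.update(acc.item(), n=targets.size(0))  # accumulate over batches
print(meter.avg)                             # 50.0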
/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 | """ Gradually warm-up(increasing) learning rate in optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 | Args:
9 | optimizer (Optimizer): Wrapped optimizer.
10 | multiplier: target learning rate = base lr * multiplier
11 |         total_epoch: the target learning rate is reached gradually over total_epoch epochs
12 |         after_scheduler: scheduler to use after total_epoch (e.g. ReduceLROnPlateau)
13 | """
14 |
15 | def __init__(
16 | self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier <= 1.:
19 | raise ValueError('multiplier should be greater than 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super().__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [
30 | base_lr * self.multiplier for base_lr in self.base_lrs
31 | ]
32 | self.finished = True
33 | return self.after_scheduler.get_lr()
34 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
35 |
36 | return [
37 | base_lr *
38 | ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
39 | for base_lr in self.base_lrs
40 | ]
41 |
42 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
43 | if epoch is None:
44 | epoch = self.last_epoch + 1
45 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
46 | if self.last_epoch <= self.total_epoch:
47 | warmup_lr = [
48 | base_lr * (
49 | (self.multiplier - 1.) * self.last_epoch / self.total_epoch
50 | + 1.) for base_lr in self.base_lrs
51 | ]
52 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
53 | param_group['lr'] = lr
54 | else:
55 | if epoch is None:
56 | self.after_scheduler.step(metrics, None)
57 | else:
58 | self.after_scheduler.step(metrics, epoch - self.total_epoch)
59 |
60 | def step(self, epoch=None, metrics=None):
61 | if type(self.after_scheduler) != ReduceLROnPlateau:
62 | if self.finished and self.after_scheduler:
63 | if epoch is None:
64 | self.after_scheduler.step(None)
65 | else:
66 | self.after_scheduler.step(epoch - self.total_epoch)
67 | else:
68 | return super(GradualWarmupScheduler, self).step(epoch)
69 | else:
70 | self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------