├── .gitignore ├── CRNUnit.png ├── DataLoader.py ├── LICENSE ├── README.md ├── config.py ├── configs ├── msrvtt_qa.yml ├── msvd_qa.yml ├── tgif_qa_action.yml ├── tgif_qa_count.yml ├── tgif_qa_frameqa.yml └── tgif_qa_transition.yml ├── data ├── glove │ └── txt2pickle.py └── tgif-qa │ └── action │ ├── tgif-qa_action_test_questions.pt │ ├── tgif-qa_action_train_questions.pt │ ├── tgif-qa_action_val_questions.pt │ └── tgif-qa_action_vocab.json ├── model ├── CRN.py ├── HCRN.py └── utils.py ├── overview.png ├── preprocess ├── datautils │ ├── msrvtt_qa.py │ ├── msvd_qa.py │ ├── tgif_qa.py │ └── utils.py ├── models │ ├── densenet.py │ ├── pre_act_resnet.py │ ├── resnet.py │ ├── resnext.py │ └── wide_resnet.py ├── preprocess_features.py └── preprocess_questions.py ├── requirements.txt ├── train.py ├── utils.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pth 2 | *.h5 3 | *.tsv 4 | *.jpg 5 | *.zip 6 | *.pkl 7 | *.xml 8 | *.p 9 | *.svg 10 | *.model 11 | 12 | .idea/* 13 | results/* 14 | __pycache__/ 15 | -------------------------------------------------------------------------------- /CRNUnit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thaolmk54/hcrn-videoqa/eb92c9b21aaa00f912fe5c2e5188abc47fda5211/CRNUnit.png -------------------------------------------------------------------------------- /DataLoader.py: -------------------------------------------------------------------------------- 1 | # DISTRIBUTION STATEMENT A. Approved for public release: distribution unlimited. 2 | # 3 | # This material is based upon work supported by the Assistant Secretary of Defense for Research and 4 | # Engineering under Air Force Contract No. FA8721-05-C-0002 and/or FA8702-15-D-0001. Any opinions, 5 | # findings, conclusions or recommendations expressed in this material are those of the author(s) and 6 | # do not necessarily reflect the views of the Assistant Secretary of Defense for Research and 7 | # Engineering. 8 | # 9 | # © 2017 Massachusetts Institute of Technology. 10 | # 11 | # MIT Proprietary, Subject to FAR52.227-11 Patent Rights - Ownership by the contractor (May 2014) 12 | # 13 | # The software/firmware is provided to you on an As-Is basis 14 | # 15 | # Delivered to the U.S. Government with Unlimited Rights, as defined in DFARS Part 252.227-7013 or 16 | # 7014 (Feb 2014). Notwithstanding any copyright notice, U.S. Government rights in this work are 17 | # defined by DFARS 252.227-7013 or DFARS 252.227-7014 as detailed above. Use of this work other than 18 | # as specifically authorized by the U.S. Government may violate any copyrights that exist in this 19 | # work. 
20 | 21 | import numpy as np 22 | import json 23 | import pickle 24 | import torch 25 | import math 26 | import h5py 27 | from torch.utils.data import Dataset, DataLoader 28 | 29 | 30 | def invert_dict(d): 31 | return {v: k for k, v in d.items()} 32 | 33 | 34 | def load_vocab(path): 35 | with open(path, 'r') as f: 36 | vocab = json.load(f) 37 | vocab['question_idx_to_token'] = invert_dict(vocab['question_token_to_idx']) 38 | vocab['answer_idx_to_token'] = invert_dict(vocab['answer_token_to_idx']) 39 | vocab['question_answer_idx_to_token'] = invert_dict(vocab['question_answer_token_to_idx']) 40 | return vocab 41 | 42 | 43 | class VideoQADataset(Dataset): 44 | 45 | def __init__(self, answers, ans_candidates, ans_candidates_len, questions, questions_len, video_ids, q_ids, 46 | app_feature_h5, app_feat_id_to_index, motion_feature_h5, motion_feat_id_to_index): 47 | # convert data to tensor 48 | self.all_answers = answers 49 | self.all_questions = torch.LongTensor(np.asarray(questions)) 50 | self.all_questions_len = torch.LongTensor(np.asarray(questions_len)) 51 | self.all_video_ids = torch.LongTensor(np.asarray(video_ids)) 52 | self.all_q_ids = q_ids 53 | self.app_feature_h5 = app_feature_h5 54 | self.motion_feature_h5 = motion_feature_h5 55 | self.app_feat_id_to_index = app_feat_id_to_index 56 | self.motion_feat_id_to_index = motion_feat_id_to_index 57 | 58 | if not np.any(ans_candidates): 59 | self.question_type = 'openended' 60 | else: 61 | self.question_type = 'mulchoices' 62 | self.all_ans_candidates = torch.LongTensor(np.asarray(ans_candidates)) 63 | self.all_ans_candidates_len = torch.LongTensor(np.asarray(ans_candidates_len)) 64 | 65 | def __getitem__(self, index): 66 | answer = self.all_answers[index] if self.all_answers is not None else None 67 | ans_candidates = torch.zeros(5) 68 | ans_candidates_len = torch.zeros(5) 69 | if self.question_type == 'mulchoices': 70 | ans_candidates = self.all_ans_candidates[index] 71 | ans_candidates_len = self.all_ans_candidates_len[index] 72 | question = self.all_questions[index] 73 | question_len = self.all_questions_len[index] 74 | video_idx = self.all_video_ids[index].item() 75 | question_idx = self.all_q_ids[index] 76 | app_index = self.app_feat_id_to_index[str(video_idx)] 77 | motion_index = self.motion_feat_id_to_index[str(video_idx)] 78 | with h5py.File(self.app_feature_h5, 'r') as f_app: 79 | appearance_feat = f_app['resnet_features'][app_index] # (8, 16, 2048) 80 | with h5py.File(self.motion_feature_h5, 'r') as f_motion: 81 | motion_feat = f_motion['resnext_features'][motion_index] # (8, 2048) 82 | appearance_feat = torch.from_numpy(appearance_feat) 83 | motion_feat = torch.from_numpy(motion_feat) 84 | return ( 85 | video_idx, question_idx, answer, ans_candidates, ans_candidates_len, appearance_feat, motion_feat, question, 86 | question_len) 87 | 88 | def __len__(self): 89 | return len(self.all_questions) 90 | 91 | 92 | class VideoQADataLoader(DataLoader): 93 | 94 | def __init__(self, **kwargs): 95 | vocab_json_path = str(kwargs.pop('vocab_json')) 96 | print('loading vocab from %s' % (vocab_json_path)) 97 | vocab = load_vocab(vocab_json_path) 98 | 99 | question_pt_path = str(kwargs.pop('question_pt')) 100 | print('loading questions from %s' % (question_pt_path)) 101 | question_type = kwargs.pop('question_type') 102 | with open(question_pt_path, 'rb') as f: 103 | obj = pickle.load(f) 104 | questions = obj['questions'] 105 | questions_len = obj['questions_len'] 106 | video_ids = obj['video_ids'] 107 | q_ids = obj['question_id'] 108 | answers = 
obj['answers'] 109 | glove_matrix = obj['glove'] 110 | ans_candidates = np.zeros(5) 111 | ans_candidates_len = np.zeros(5) 112 | if question_type in ['action', 'transition']: 113 | ans_candidates = obj['ans_candidates'] 114 | ans_candidates_len = obj['ans_candidates_len'] 115 | 116 | if 'train_num' in kwargs: 117 | trained_num = kwargs.pop('train_num') 118 | if trained_num > 0: 119 | questions = questions[:trained_num] 120 | questions_len = questions_len[:trained_num] 121 | video_ids = video_ids[:trained_num] 122 | q_ids = q_ids[:trained_num] 123 | answers = answers[:trained_num] 124 | if question_type in ['action', 'transition']: 125 | ans_candidates = ans_candidates[:trained_num] 126 | ans_candidates_len = ans_candidates_len[:trained_num] 127 | if 'val_num' in kwargs: 128 | val_num = kwargs.pop('val_num') 129 | if val_num > 0: 130 | questions = questions[:val_num] 131 | questions_len = questions_len[:val_num] 132 | video_ids = video_ids[:val_num] 133 | q_ids = q_ids[:val_num] 134 | answers = answers[:val_num] 135 | if question_type in ['action', 'transition']: 136 | ans_candidates = ans_candidates[:val_num] 137 | ans_candidates_len = ans_candidates_len[:val_num] 138 | if 'test_num' in kwargs: 139 | test_num = kwargs.pop('test_num') 140 | if test_num > 0: 141 | questions = questions[:test_num] 142 | questions_len = questions_len[:test_num] 143 | video_ids = video_ids[:test_num] 144 | q_ids = q_ids[:test_num] 145 | answers = answers[:test_num] 146 | if question_type in ['action', 'transition']: 147 | ans_candidates = ans_candidates[:test_num] 148 | ans_candidates_len = ans_candidates_len[:test_num] 149 | 150 | print('loading appearance feature from %s' % (kwargs['appearance_feat'])) 151 | with h5py.File(kwargs['appearance_feat'], 'r') as app_features_file: 152 | app_video_ids = app_features_file['ids'][()] 153 | app_feat_id_to_index = {str(id): i for i, id in enumerate(app_video_ids)} 154 | print('loading motion feature from %s' % (kwargs['motion_feat'])) 155 | with h5py.File(kwargs['motion_feat'], 'r') as motion_features_file: 156 | motion_video_ids = motion_features_file['ids'][()] 157 | motion_feat_id_to_index = {str(id): i for i, id in enumerate(motion_video_ids)} 158 | self.app_feature_h5 = kwargs.pop('appearance_feat') 159 | self.motion_feature_h5 = kwargs.pop('motion_feat') 160 | self.dataset = VideoQADataset(answers, ans_candidates, ans_candidates_len, questions, questions_len, 161 | video_ids, q_ids, 162 | self.app_feature_h5, app_feat_id_to_index, self.motion_feature_h5, 163 | motion_feat_id_to_index) 164 | 165 | self.vocab = vocab 166 | self.batch_size = kwargs['batch_size'] 167 | self.glove_matrix = glove_matrix 168 | 169 | super().__init__(self.dataset, **kwargs) 170 | 171 | def __len__(self): 172 | return math.ceil(len(self.dataset) / self.batch_size) 173 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 
14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Hierarchical Conditional Relation Networks for Video Question Answering (HCRN-VideoQA) 2 | 3 | We introduce a general-purpose reusable neural unit called Conditional Relation Network (CRN) that encapsulates and transforms an array of tensorial objects into a new array of the same kind, conditioned on a contextual feature. The flexibility of CRN units is then examined in solving Video Question Answering, a challenging problem requiring joint comprehension of video content and natural language. 4 | 5 | Illustrations of the CRN unit and the resulting HCRN model for VideoQA: 6 | 7 | CRN Unit | HCRN Architecture 8 | :-------------------------:|:-------------------------: 9 | ![](CRNUnit.png) | ![](overview.png) 10 | 11 | Check out our [paper](https://arxiv.org/abs/2002.10698) for details. 12 | 13 | ## Setups 14 | 1. Clone the repository: 15 | ``` 16 | git clone https://github.com/thaolmk54/hcrn-videoqa.git 17 | ``` 18 | 19 | 2. Download the [TGIF-QA](https://github.com/YunseokJANG/tgif-qa) and [MSRVTT-QA, MSVD-QA](https://github.com/xudejing/video-question-answering) datasets and edit the absolute paths in `preprocess/preprocess_features.py` and `preprocess/preprocess_questions.py` according to where you store your data. The default paths are `/ceph-g/lethao/datasets/{dataset_name}/`. 20 | 21 | 3. Install dependencies: 22 | ```bash 23 | conda create -n hcrn_videoqa python=3.6 24 | conda activate hcrn_videoqa 25 | conda install -c conda-forge ffmpeg 26 | conda install -c conda-forge scikit-video 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ## Experiments with TGIF-QA 31 | Depending on the task, choose `question_type` from one of 4 options: `action, transition, count, frameqa`. 32 | #### Preprocessing visual features 33 | 1. To extract appearance features: 34 | 35 | ``` 36 | python preprocess/preprocess_features.py --gpu_id 2 --dataset tgif-qa --model resnet101 --question_type {question_type} 37 | ``` 38 | 39 | 2. To extract motion features: 40 | 41 | Download the ResNeXt-101 [pretrained model](https://drive.google.com/drive/folders/1zvl89AgFAApbH0At-gMuZSeQB_LpNP-M) (resnext-101-kinetics.pth) and place it in `data/preprocess/pretrained/`. 42 | 43 | ``` 44 | python preprocess/preprocess_features.py --dataset tgif-qa --model resnext101 --image_height 112 --image_width 112 --question_type {question_type} 45 | ``` 46 | 47 | **Note**: Extracting visual features takes a long time. You can download our pre-extracted features from [here](https://bit.ly/2TX9rlZ) and save them in `data/tgif-qa/{question_type}/`. Please use the following command to join split files: 48 | 49 | ``` 50 | cat tgif-qa_{question_type}_appearance_feat.h5.part* > tgif-qa_{question_type}_appearance_feat.h5 51 | ``` 52 | 53 | #### Preprocess linguistic features 54 | 1. 
Download the [GloVe pretrained 300d word vectors](http://nlp.stanford.edu/data/glove.840B.300d.zip) to `data/glove/` and process them into a pickle file (run the script from `data/glove/`): 55 | 56 | ``` 57 | python txt2pickle.py 58 | ``` 59 | 2. Preprocess train/val/test questions: 60 | ``` 61 | python preprocess/preprocess_questions.py --dataset tgif-qa --question_type {question_type} --glove_pt data/glove/glove.840.300d.pkl --mode train 62 | 63 | python preprocess/preprocess_questions.py --dataset tgif-qa --question_type {question_type} --mode test 64 | ``` 65 | #### Training 66 | Choose the config file in `configs/{task}.yml` that matches one of the 4 tasks: `action, transition, count, frameqa`. For example, to train on the action task, run the following command: 67 | ```bash 68 | python train.py --cfg configs/tgif_qa_action.yml 69 | ``` 70 | 71 | #### Evaluation 72 | To evaluate the trained model, run the following: 73 | ```bash 74 | python validate.py --cfg configs/tgif_qa_action.yml 75 | ``` 76 | **Note**: A pretrained model for the action task is available [here](https://drive.google.com/open?id=1xzD4JbuoFYAgJG41eAwBo77i3oVrbKyg). Save the file in `results/expTGIF-QAAction/ckpt/` for evaluation. 77 | ## Experiments with MSRVTT-QA and MSVD-QA 78 | The following shows how to run experiments with the MSRVTT-QA dataset; replace `msrvtt-qa` with `msvd-qa` to run with the MSVD-QA dataset. 79 | #### Preprocessing visual features 80 | 1. To extract appearance features: 81 | ``` 82 | python preprocess/preprocess_features.py --gpu_id 2 --dataset msrvtt-qa --model resnet101 83 | ``` 84 | 2. To extract motion features: 85 | ``` 86 | python preprocess/preprocess_features.py --dataset msrvtt-qa --model resnext101 --image_height 112 --image_width 112 87 | ``` 88 | 89 | #### Preprocess linguistic features 90 | Preprocess train/val/test questions: 91 | ``` 92 | python preprocess/preprocess_questions.py --dataset msrvtt-qa --glove_pt data/glove/glove.840.300d.pkl --mode train 93 | 94 | python preprocess/preprocess_questions.py --dataset msrvtt-qa --question_type {question_type} --mode val 95 | 96 | python preprocess/preprocess_questions.py --dataset msrvtt-qa --question_type {question_type} --mode test 97 | ``` 98 | 99 | #### Training 100 | ```bash 101 | python train.py --cfg configs/msrvtt_qa.yml 102 | ``` 103 | 104 | #### Evaluation 105 | To evaluate the trained model, run the following: 106 | ```bash 107 | python validate.py --cfg configs/msrvtt_qa.yml 108 | ``` 109 | ## Citations 110 | If you make use of this repository for your research, please cite the following paper: 111 | ``` 112 | @article{le2020hierarchical, 113 | title={Hierarchical Conditional Relation Networks for Video Question Answering}, 114 | author={Le, Thao Minh and Le, Vuong and Venkatesh, Svetha and Tran, Truyen}, 115 | journal={arXiv preprint arXiv:2002.10698}, 116 | year={2020} 117 | } 118 | ``` 119 | ## Acknowledgement 120 | - For motion feature extraction, we adapt the ResNeXt-101 model from this [repo](https://github.com/kenshohara/video-classification-3d-cnn-pytorch). Thanks to @kenshohara for releasing the code and the pretrained models. 121 | - We refer to this [repo](https://github.com/facebookresearch/clevr-iep) for preprocessing. 122 | - Our implementation of the dataloader is based on this [repo](https://github.com/shijx12/XNM-Net). 
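## Appendix: using the CRN unit

For reference, the sketch below shows how the CRN unit in `model/CRN.py` consumes an array of tensors conditioned on a context feature, mirroring the way `InputUnitVisual` in `model/HCRN.py` conditions frame-level appearance features on a clip-level motion feature. The batch size and shapes here are illustrative only.

```python
import torch
from model.CRN import CRN

# 16 frame-level appearance features for a batch of 2 clips, already projected to module_dim=512
frames = [torch.randn(2, 512) for _ in range(16)]
motion = torch.randn(2, 512)   # clip-level motion feature used as the conditioning input

crn = CRN(module_dim=512, num_objects=16, max_subset_size=16, gating=False, spl_resolution=1)
outputs = crn(frames, motion)  # list of relation features, one entry per considered subset size

print(len(outputs), outputs[0].shape)  # with these settings: 14 tensors of shape (2, 512)
```

The returned array can itself be fed into another CRN (e.g. the question-conditioned CRN in `InputUnitVisual`), which is what makes the unit stackable into the hierarchical HCRN model.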
123 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | 4 | import numpy as np 5 | from easydict import EasyDict as edict 6 | 7 | __C = edict() 8 | cfg = __C 9 | 10 | __C.gpu_id = 0 11 | __C.num_workers = 4 12 | __C.multi_gpus = False 13 | __C.seed = 666 14 | # training options 15 | __C.train = edict() 16 | __C.train.restore = False 17 | __C.train.lr = 0.0001 18 | __C.train.batch_size = 32 19 | __C.train.max_epochs = 25 20 | __C.train.vision_dim = 2048 21 | __C.train.word_dim = 300 22 | __C.train.module_dim = 512 23 | __C.train.train_num = 0 # Default 0 for full train set 24 | __C.train.restore = False 25 | __C.train.glove = True 26 | __C.train.k_max_frame_level = 16 27 | __C.train.k_max_clip_level = 8 28 | __C.train.spl_resolution = 1 29 | __C.train = dict(__C.train) 30 | 31 | # validation 32 | __C.val = edict() 33 | __C.val.flag = True 34 | __C.val.val_num = 0 # Default 0 for full val set 35 | __C.val = dict(__C.val) 36 | # test 37 | __C.test = edict() 38 | __C.test.test_num = 0 # Default 0 for full test set 39 | __C.test.write_preds = False 40 | __C.test = dict(__C.test) 41 | # dataset options 42 | __C.dataset = edict() 43 | __C.dataset.name = 'tgif-qa' # ['tgif-qa', 'msrvtt-qa', 'msvd-qa'] 44 | __C.dataset.question_type = 'none' #['frameqa', 'count', 'transition', 'action', 'none'] 45 | __C.dataset.data_dir = '' 46 | __C.dataset.appearance_feat = '{}_{}_appearance_feat.h5' 47 | __C.dataset.motion_feat = '{}_{}_motion_feat.h5' 48 | __C.dataset.vocab_json = '{}_{}_vocab.json' 49 | __C.dataset.train_question_pt = '{}_{}_train_questions.pt' 50 | __C.dataset.val_question_pt = '{}_{}_val_questions.pt' 51 | __C.dataset.test_question_pt = '{}_{}_test_questions.pt' 52 | __C.dataset.save_dir = '' 53 | __C.dataset = dict(__C.dataset) 54 | 55 | # experiment name 56 | __C.exp_name = 'defaultExp' 57 | 58 | # credit https://github.com/tohinz/pytorch-mac-network/blob/master/code/config.py 59 | def merge_cfg(yaml_cfg, cfg): 60 | if type(yaml_cfg) is not edict: 61 | return 62 | 63 | for k, v in yaml_cfg.items(): 64 | if not k in cfg: 65 | raise KeyError('{} is not a valid config key'.format(k)) 66 | 67 | old_type = type(cfg[k]) 68 | if old_type is not type(v): 69 | if isinstance(cfg[k], np.ndarray): 70 | v = np.array(v, dtype=cfg[k].dtype) 71 | elif isinstance(cfg[k], list): 72 | v = v.split(",") 73 | v = [int(_v) for _v in v] 74 | elif cfg[k] is None: 75 | if v == "None": 76 | continue 77 | else: 78 | v = v 79 | else: 80 | raise ValueError(('Type mismatch ({} vs. 
{}) ' 81 | 'for config key: {}').format(type(cfg[k]), 82 | type(v), k)) 83 | # recursively merge dicts 84 | if type(v) is edict: 85 | try: 86 | merge_cfg(yaml_cfg[k], cfg[k]) 87 | except: 88 | print('Error under config key: {}'.format(k)) 89 | raise 90 | else: 91 | cfg[k] = v 92 | 93 | 94 | 95 | def cfg_from_file(file_name): 96 | import yaml 97 | with open(file_name, 'r') as f: 98 | yaml_cfg = edict(yaml.load(f)) 99 | 100 | merge_cfg(yaml_cfg, __C) -------------------------------------------------------------------------------- /configs/msrvtt_qa.yml: -------------------------------------------------------------------------------- 1 | gpu_id: 1 2 | multi_gpus: False 3 | num_workers: 4 4 | seed: 666 5 | exp_name: 'expMSRVTT-QA' 6 | 7 | train: 8 | lr: 0.0001 9 | batch_size: 32 10 | restore: False 11 | max_epochs: 25 12 | word_dim: 300 13 | module_dim: 512 14 | glove: True 15 | k_max_frame_level: 16 16 | k_max_clip_level: 8 17 | spl_resolution: 1 18 | 19 | 20 | val: 21 | flag: True 22 | 23 | test: 24 | test_num: 0 25 | write_preds: False 26 | 27 | dataset: 28 | name: 'msrvtt-qa' 29 | question_type: 'none' 30 | data_dir: 'data/msrvtt-qa' 31 | save_dir: 'results/' -------------------------------------------------------------------------------- /configs/msvd_qa.yml: -------------------------------------------------------------------------------- 1 | gpu_id: 3 2 | multi_gpus: False 3 | num_workers: 4 4 | seed: 666 5 | exp_name: 'expMSVD-QA' 6 | 7 | train: 8 | lr: 0.0001 9 | batch_size: 32 10 | restore: False 11 | max_epochs: 25 12 | word_dim: 300 13 | module_dim: 512 14 | glove: True 15 | k_max_frame_level: 16 16 | k_max_clip_level: 8 17 | spl_resolution: 1 18 | 19 | val: 20 | flag: True 21 | 22 | test: 23 | test_num: 0 24 | write_preds: False 25 | 26 | dataset: 27 | name: 'msvd-qa' 28 | question_type: 'none' 29 | data_dir: 'data/msvd-qa' 30 | save_dir: 'results/' -------------------------------------------------------------------------------- /configs/tgif_qa_action.yml: -------------------------------------------------------------------------------- 1 | gpu_id: 2 2 | multi_gpus: False 3 | num_workers: 4 4 | seed: 666 5 | exp_name: 'expTGIF-QAAction' 6 | 7 | train: 8 | lr: 0.0001 9 | batch_size: 32 10 | restore: False 11 | max_epochs: 25 12 | word_dim: 300 13 | module_dim: 512 14 | glove: True 15 | k_max_frame_level: 16 16 | k_max_clip_level: 8 17 | spl_resolution: 1 18 | 19 | val: 20 | flag: True 21 | 22 | test: 23 | test_num: 0 24 | write_preds: False 25 | 26 | dataset: 27 | name: 'tgif-qa' 28 | question_type: 'action' 29 | data_dir: 'data/tgif-qa/action' 30 | save_dir: 'results/' -------------------------------------------------------------------------------- /configs/tgif_qa_count.yml: -------------------------------------------------------------------------------- 1 | gpu_id: 1 2 | multi_gpus: False 3 | num_workers: 4 4 | seed: 666 5 | exp_name: 'expTGIF-QACount' 6 | 7 | train: 8 | lr: 0.0001 9 | batch_size: 32 10 | restore: False 11 | max_epochs: 25 12 | word_dim: 300 13 | module_dim: 512 14 | glove: True 15 | k_max_frame_level: 16 16 | k_max_clip_level: 8 17 | spl_resolution: 1 18 | 19 | val: 20 | flag: True 21 | 22 | test: 23 | test_num: 0 24 | write_preds: False 25 | 26 | dataset: 27 | name: 'tgif-qa' 28 | question_type: 'count' 29 | data_dir: 'data/tgif-qa/count' 30 | save_dir: 'results/' -------------------------------------------------------------------------------- /configs/tgif_qa_frameqa.yml: -------------------------------------------------------------------------------- 1 | 
gpu_id: 1 2 | multi_gpus: False 3 | num_workers: 4 4 | seed: 666 5 | exp_name: 'expTGIF-QAFrameQA' 6 | 7 | train: 8 | lr: 0.0001 9 | batch_size: 32 10 | restore: False 11 | max_epochs: 25 12 | word_dim: 300 13 | module_dim: 512 14 | glove: True 15 | k_max_frame_level: 16 16 | k_max_clip_level: 8 17 | spl_resolution: 1 18 | 19 | 20 | val: 21 | flag: True 22 | 23 | test: 24 | test_num: 0 25 | write_preds: False 26 | 27 | dataset: 28 | name: 'tgif-qa' 29 | question_type: 'frameqa' 30 | data_dir: 'data/tgif-qa/frameqa' 31 | save_dir: 'results/' -------------------------------------------------------------------------------- /configs/tgif_qa_transition.yml: -------------------------------------------------------------------------------- 1 | gpu_id: 1 2 | multi_gpus: False 3 | num_workers: 4 4 | seed: 666 5 | exp_name: 'expTGIF-QATransition' 6 | 7 | train: 8 | lr: 0.0001 9 | batch_size: 32 10 | restore: False 11 | max_epochs: 25 12 | word_dim: 300 13 | module_dim: 512 14 | glove: True 15 | k_max_frame_level: 16 16 | k_max_clip_level: 8 17 | spl_resolution: 1 18 | 19 | val: 20 | flag: True 21 | 22 | test: 23 | test_num: 0 24 | write_preds: False 25 | 26 | dataset: 27 | name: 'tgif-qa' 28 | question_type: 'transition' 29 | data_dir: 'data/tgif-qa/transition' 30 | save_dir: 'results/' -------------------------------------------------------------------------------- /data/glove/txt2pickle.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import pickle 3 | import csv 4 | 5 | df = pd.read_csv('glove.840B.300d.txt', sep =" ", quoting=3, header=None, index_col=0) 6 | print("glove file loaded!") 7 | glove = {key: val.values for key, val in df.T.items()} 8 | 9 | with open('glove.840.300d.pkl', 'wb') as fp: 10 | pickle.dump(glove, fp) -------------------------------------------------------------------------------- /data/tgif-qa/action/tgif-qa_action_test_questions.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thaolmk54/hcrn-videoqa/eb92c9b21aaa00f912fe5c2e5188abc47fda5211/data/tgif-qa/action/tgif-qa_action_test_questions.pt -------------------------------------------------------------------------------- /data/tgif-qa/action/tgif-qa_action_train_questions.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thaolmk54/hcrn-videoqa/eb92c9b21aaa00f912fe5c2e5188abc47fda5211/data/tgif-qa/action/tgif-qa_action_train_questions.pt -------------------------------------------------------------------------------- /data/tgif-qa/action/tgif-qa_action_val_questions.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thaolmk54/hcrn-videoqa/eb92c9b21aaa00f912fe5c2e5188abc47fda5211/data/tgif-qa/action/tgif-qa_action_val_questions.pt -------------------------------------------------------------------------------- /model/CRN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import itertools 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.modules.module import Module 7 | 8 | 9 | class CRN(Module): 10 | def __init__(self, module_dim, num_objects, max_subset_size, gating=False, spl_resolution=1): 11 | super(CRN, self).__init__() 12 | self.module_dim = module_dim 13 | self.gating = gating 14 | 15 | self.k_objects_fusion = nn.ModuleList() 16 | if self.gating: 17 | self.gate_k_objects_fusion 
= nn.ModuleList() 18 | for i in range(min(num_objects, max_subset_size + 1), 1, -1): 19 | self.k_objects_fusion.append(nn.Linear(2 * module_dim, module_dim)) 20 | if self.gating: 21 | self.gate_k_objects_fusion.append(nn.Linear(2 * module_dim, module_dim)) 22 | self.spl_resolution = spl_resolution 23 | self.activation = nn.ELU() 24 | self.max_subset_size = max_subset_size 25 | 26 | def forward(self, object_list, cond_feat): 27 | """ 28 | :param object_list: list of tensors or vectors 29 | :param cond_feat: conditioning feature 30 | :return: list of output objects 31 | """ 32 | scales = [i for i in range(len(object_list), 1, -1)] 33 | 34 | relations_scales = [] 35 | subsample_scales = [] 36 | for scale in scales: 37 | relations_scale = self.relationset(len(object_list), scale) 38 | relations_scales.append(relations_scale) 39 | subsample_scales.append(min(self.spl_resolution, len(relations_scale))) 40 | 41 | crn_feats = [] 42 | if len(scales) > 1 and self.max_subset_size == len(object_list): 43 | start_scale = 1 44 | else: 45 | start_scale = 0 46 | for scaleID in range(start_scale, min(len(scales), self.max_subset_size)): 47 | idx_relations_randomsample = np.random.choice(len(relations_scales[scaleID]), 48 | subsample_scales[scaleID], replace=False) 49 | mono_scale_features = 0 50 | for id_choice, idx in enumerate(idx_relations_randomsample): 51 | clipFeatList = [object_list[obj].unsqueeze(1) for obj in relations_scales[scaleID][idx]] 52 | # g_theta 53 | g_feat = torch.cat(clipFeatList, dim=1) 54 | g_feat = g_feat.mean(1) 55 | if len(g_feat.size()) == 2: 56 | h_feat = torch.cat((g_feat, cond_feat), dim=-1) 57 | elif len(g_feat.size()) == 3: 58 | cond_feat_repeat = cond_feat.repeat(1, g_feat.size(1), 1) 59 | h_feat = torch.cat((g_feat, cond_feat_repeat), dim=-1) 60 | if self.gating: 61 | h_feat = self.activation(self.k_objects_fusion[scaleID](h_feat)) * torch.sigmoid( 62 | self.gate_k_objects_fusion[scaleID](h_feat)) 63 | else: 64 | h_feat = self.activation(self.k_objects_fusion[scaleID](h_feat)) 65 | mono_scale_features += h_feat 66 | crn_feats.append(mono_scale_features / len(idx_relations_randomsample)) 67 | return crn_feats 68 | 69 | def relationset(self, num_objects, num_object_relation): 70 | return list(itertools.combinations([i for i in range(num_objects)], num_object_relation)) 71 | -------------------------------------------------------------------------------- /model/HCRN.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.nn import functional as F 3 | 4 | from .utils import * 5 | from .CRN import CRN 6 | 7 | 8 | class FeatureAggregation(nn.Module): 9 | def __init__(self, module_dim=512): 10 | super(FeatureAggregation, self).__init__() 11 | self.module_dim = module_dim 12 | 13 | self.q_proj = nn.Linear(module_dim, module_dim, bias=False) 14 | self.v_proj = nn.Linear(module_dim, module_dim, bias=False) 15 | 16 | self.cat = nn.Linear(2 * module_dim, module_dim) 17 | self.attn = nn.Linear(module_dim, 1) 18 | 19 | self.activation = nn.ELU() 20 | self.dropout = nn.Dropout(0.15) 21 | 22 | def forward(self, question_rep, visual_feat): 23 | visual_feat = self.dropout(visual_feat) 24 | q_proj = self.q_proj(question_rep) 25 | v_proj = self.v_proj(visual_feat) 26 | 27 | v_q_cat = torch.cat((v_proj, q_proj.unsqueeze(1) * v_proj), dim=-1) 28 | v_q_cat = self.cat(v_q_cat) 29 | v_q_cat = self.activation(v_q_cat) 30 | 31 | attn = self.attn(v_q_cat) # (bz, k, 1) 32 | attn = F.softmax(attn, dim=1) # (bz, k, 1) 33 | 34 | v_distill 
= (attn * visual_feat).sum(1) 35 | 36 | return v_distill 37 | 38 | 39 | class InputUnitLinguistic(nn.Module): 40 | def __init__(self, vocab_size, wordvec_dim=300, rnn_dim=512, module_dim=512, bidirectional=True): 41 | super(InputUnitLinguistic, self).__init__() 42 | 43 | self.dim = module_dim 44 | 45 | self.bidirectional = bidirectional 46 | if bidirectional: 47 | rnn_dim = rnn_dim // 2 48 | 49 | self.encoder_embed = nn.Embedding(vocab_size, wordvec_dim) 50 | self.tanh = nn.Tanh() 51 | self.encoder = nn.LSTM(wordvec_dim, rnn_dim, batch_first=True, bidirectional=bidirectional) 52 | self.embedding_dropout = nn.Dropout(p=0.15) 53 | self.question_dropout = nn.Dropout(p=0.18) 54 | 55 | self.module_dim = module_dim 56 | 57 | def forward(self, questions, question_len): 58 | """ 59 | Args: 60 | question: [Tensor] (batch_size, max_question_length) 61 | question_len: [Tensor] (batch_size) 62 | return: 63 | question representation [Tensor] (batch_size, module_dim) 64 | """ 65 | questions_embedding = self.encoder_embed(questions) # (batch_size, seq_len, dim_word) 66 | embed = self.tanh(self.embedding_dropout(questions_embedding)) 67 | embed = nn.utils.rnn.pack_padded_sequence(embed, question_len, batch_first=True, 68 | enforce_sorted=False) 69 | 70 | self.encoder.flatten_parameters() 71 | _, (question_embedding, _) = self.encoder(embed) 72 | if self.bidirectional: 73 | question_embedding = torch.cat([question_embedding[0], question_embedding[1]], -1) 74 | question_embedding = self.question_dropout(question_embedding) 75 | 76 | return question_embedding 77 | 78 | 79 | class InputUnitVisual(nn.Module): 80 | def __init__(self, k_max_frame_level, k_max_clip_level, spl_resolution, vision_dim, module_dim=512): 81 | super(InputUnitVisual, self).__init__() 82 | 83 | self.clip_level_motion_cond = CRN(module_dim, k_max_frame_level, k_max_frame_level, gating=False, spl_resolution=spl_resolution) 84 | self.clip_level_question_cond = CRN(module_dim, k_max_frame_level-2, k_max_frame_level-2, gating=True, spl_resolution=spl_resolution) 85 | self.video_level_motion_cond = CRN(module_dim, k_max_clip_level, k_max_clip_level, gating=False, spl_resolution=spl_resolution) 86 | self.video_level_question_cond = CRN(module_dim, k_max_clip_level-2, k_max_clip_level-2, gating=True, spl_resolution=spl_resolution) 87 | 88 | self.sequence_encoder = nn.LSTM(vision_dim, module_dim, batch_first=True, bidirectional=False) 89 | self.clip_level_motion_proj = nn.Linear(vision_dim, module_dim) 90 | self.video_level_motion_proj = nn.Linear(module_dim, module_dim) 91 | self.appearance_feat_proj = nn.Linear(vision_dim, module_dim) 92 | 93 | self.question_embedding_proj = nn.Linear(module_dim, module_dim) 94 | 95 | self.module_dim = module_dim 96 | self.activation = nn.ELU() 97 | 98 | def forward(self, appearance_video_feat, motion_video_feat, question_embedding): 99 | """ 100 | Args: 101 | appearance_video_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 102 | motion_video_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 103 | question_embedding: [Tensor] (batch_size, module_dim) 104 | return: 105 | encoded video feature: [Tensor] (batch_size, N, module_dim) 106 | """ 107 | batch_size = appearance_video_feat.size(0) 108 | clip_level_crn_outputs = [] 109 | question_embedding_proj = self.question_embedding_proj(question_embedding) 110 | for i in range(appearance_video_feat.size(1)): 111 | clip_level_motion = motion_video_feat[:, i, :] # (bz, 2048) 112 | clip_level_motion_proj = 
self.clip_level_motion_proj(clip_level_motion) 113 | 114 | clip_level_appearance = appearance_video_feat[:, i, :, :] # (bz, 16, 2048) 115 | clip_level_appearance_proj = self.appearance_feat_proj(clip_level_appearance) # (bz, 16, 512) 116 | # clip level CRNs 117 | clip_level_crn_motion = self.clip_level_motion_cond(torch.unbind(clip_level_appearance_proj, dim=1), 118 | clip_level_motion_proj) 119 | clip_level_crn_question = self.clip_level_question_cond(clip_level_crn_motion, question_embedding_proj) 120 | 121 | clip_level_crn_output = torch.cat( 122 | [frame_relation.unsqueeze(1) for frame_relation in clip_level_crn_question], 123 | dim=1) 124 | clip_level_crn_output = clip_level_crn_output.view(batch_size, -1, self.module_dim) 125 | clip_level_crn_outputs.append(clip_level_crn_output) 126 | 127 | # Encode video level motion 128 | _, (video_level_motion, _) = self.sequence_encoder(motion_video_feat) 129 | video_level_motion = video_level_motion.transpose(0, 1) 130 | video_level_motion_feat_proj = self.video_level_motion_proj(video_level_motion) 131 | # video level CRNs 132 | video_level_crn_motion = self.video_level_motion_cond(clip_level_crn_outputs, video_level_motion_feat_proj) 133 | video_level_crn_question = self.video_level_question_cond(video_level_crn_motion, 134 | question_embedding_proj.unsqueeze(1)) 135 | 136 | video_level_crn_output = torch.cat([clip_relation.unsqueeze(1) for clip_relation in video_level_crn_question], 137 | dim=1) 138 | video_level_crn_output = video_level_crn_output.view(batch_size, -1, self.module_dim) 139 | 140 | return video_level_crn_output 141 | 142 | 143 | class OutputUnitOpenEnded(nn.Module): 144 | def __init__(self, module_dim=512, num_answers=1000): 145 | super(OutputUnitOpenEnded, self).__init__() 146 | 147 | self.question_proj = nn.Linear(module_dim, module_dim) 148 | 149 | self.classifier = nn.Sequential(nn.Dropout(0.15), 150 | nn.Linear(module_dim * 2, module_dim), 151 | nn.ELU(), 152 | nn.BatchNorm1d(module_dim), 153 | nn.Dropout(0.15), 154 | nn.Linear(module_dim, num_answers)) 155 | 156 | def forward(self, question_embedding, visual_embedding): 157 | question_embedding = self.question_proj(question_embedding) 158 | out = torch.cat([visual_embedding, question_embedding], 1) 159 | out = self.classifier(out) 160 | 161 | return out 162 | 163 | 164 | class OutputUnitMultiChoices(nn.Module): 165 | def __init__(self, module_dim=512): 166 | super(OutputUnitMultiChoices, self).__init__() 167 | 168 | self.question_proj = nn.Linear(module_dim, module_dim) 169 | 170 | self.ans_candidates_proj = nn.Linear(module_dim, module_dim) 171 | 172 | self.classifier = nn.Sequential(nn.Dropout(0.15), 173 | nn.Linear(module_dim * 4, module_dim), 174 | nn.ELU(), 175 | nn.BatchNorm1d(module_dim), 176 | nn.Dropout(0.15), 177 | nn.Linear(module_dim, 1)) 178 | 179 | def forward(self, question_embedding, q_visual_embedding, ans_candidates_embedding, 180 | a_visual_embedding): 181 | question_embedding = self.question_proj(question_embedding) 182 | ans_candidates_embedding = self.ans_candidates_proj(ans_candidates_embedding) 183 | out = torch.cat([q_visual_embedding, question_embedding, a_visual_embedding, 184 | ans_candidates_embedding], 1) 185 | out = self.classifier(out) 186 | 187 | return out 188 | 189 | 190 | class OutputUnitCount(nn.Module): 191 | def __init__(self, module_dim=512): 192 | super(OutputUnitCount, self).__init__() 193 | 194 | self.question_proj = nn.Linear(module_dim, module_dim) 195 | 196 | self.regression = nn.Sequential(nn.Dropout(0.15), 197 | 
nn.Linear(module_dim * 2, module_dim), 198 | nn.ELU(), 199 | nn.BatchNorm1d(module_dim), 200 | nn.Dropout(0.15), 201 | nn.Linear(module_dim, 1)) 202 | 203 | def forward(self, question_embedding, visual_embedding): 204 | question_embedding = self.question_proj(question_embedding) 205 | out = torch.cat([visual_embedding, question_embedding], 1) 206 | out = self.regression(out) 207 | 208 | return out 209 | 210 | 211 | class HCRNNetwork(nn.Module): 212 | def __init__(self, vision_dim, module_dim, word_dim, k_max_frame_level, k_max_clip_level, spl_resolution, vocab, question_type): 213 | super(HCRNNetwork, self).__init__() 214 | 215 | self.question_type = question_type 216 | self.feature_aggregation = FeatureAggregation(module_dim) 217 | 218 | if self.question_type in ['action', 'transition']: 219 | encoder_vocab_size = len(vocab['question_answer_token_to_idx']) 220 | self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, 221 | module_dim=module_dim, rnn_dim=module_dim) 222 | self.visual_input_unit = InputUnitVisual(k_max_frame_level=k_max_frame_level, k_max_clip_level=k_max_clip_level, spl_resolution=spl_resolution, vision_dim=vision_dim, module_dim=module_dim) 223 | self.output_unit = OutputUnitMultiChoices(module_dim=module_dim) 224 | 225 | elif self.question_type == 'count': 226 | encoder_vocab_size = len(vocab['question_token_to_idx']) 227 | self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, 228 | module_dim=module_dim, rnn_dim=module_dim) 229 | self.visual_input_unit = InputUnitVisual(k_max_frame_level=k_max_frame_level, k_max_clip_level=k_max_clip_level, spl_resolution=spl_resolution, vision_dim=vision_dim, module_dim=module_dim) 230 | self.output_unit = OutputUnitCount(module_dim=module_dim) 231 | else: 232 | encoder_vocab_size = len(vocab['question_token_to_idx']) 233 | self.num_classes = len(vocab['answer_token_to_idx']) 234 | self.linguistic_input_unit = InputUnitLinguistic(vocab_size=encoder_vocab_size, wordvec_dim=word_dim, 235 | module_dim=module_dim, rnn_dim=module_dim) 236 | self.visual_input_unit = InputUnitVisual(k_max_frame_level=k_max_frame_level, k_max_clip_level=k_max_clip_level, spl_resolution=spl_resolution, vision_dim=vision_dim, module_dim=module_dim) 237 | self.output_unit = OutputUnitOpenEnded(num_answers=self.num_classes) 238 | 239 | init_modules(self.modules(), w_init="xavier_uniform") 240 | nn.init.uniform_(self.linguistic_input_unit.encoder_embed.weight, -1.0, 1.0) 241 | 242 | def forward(self, ans_candidates, ans_candidates_len, video_appearance_feat, video_motion_feat, question, 243 | question_len): 244 | """ 245 | Args: 246 | ans_candidates: [Tensor] (batch_size, 5, max_ans_candidates_length) 247 | ans_candidates_len: [Tensor] (batch_size, 5) 248 | video_appearance_feat: [Tensor] (batch_size, num_clips, num_frames, visual_inp_dim) 249 | video_motion_feat: [Tensor] (batch_size, num_clips, visual_inp_dim) 250 | question: [Tensor] (batch_size, max_question_length) 251 | question_len: [Tensor] (batch_size) 252 | return: 253 | logits. 
254 | """ 255 | batch_size = question.size(0) 256 | if self.question_type in ['frameqa', 'count', 'none']: 257 | # get image, word, and sentence embeddings 258 | question_embedding = self.linguistic_input_unit(question, question_len) 259 | visual_embedding = self.visual_input_unit(video_appearance_feat, video_motion_feat, question_embedding) 260 | 261 | visual_embedding = self.feature_aggregation(question_embedding, visual_embedding) 262 | 263 | out = self.output_unit(question_embedding, visual_embedding) 264 | else: 265 | question_embedding = self.linguistic_input_unit(question, question_len) 266 | visual_embedding = self.visual_input_unit(video_appearance_feat, video_motion_feat, question_embedding) 267 | 268 | q_visual_embedding = self.feature_aggregation(question_embedding, visual_embedding) 269 | 270 | # ans_candidates: (batch_size, num_choices, max_len) 271 | ans_candidates_agg = ans_candidates.view(-1, ans_candidates.size(2)) 272 | ans_candidates_len_agg = ans_candidates_len.view(-1) 273 | 274 | batch_agg = np.reshape( 275 | np.tile(np.expand_dims(np.arange(batch_size), axis=1), [1, 5]), [-1]) 276 | 277 | ans_candidates_embedding = self.linguistic_input_unit(ans_candidates_agg, ans_candidates_len_agg) 278 | 279 | a_visual_embedding = self.feature_aggregation(ans_candidates_embedding, visual_embedding[batch_agg]) 280 | out = self.output_unit(question_embedding[batch_agg], q_visual_embedding[batch_agg], 281 | ans_candidates_embedding, 282 | a_visual_embedding) 283 | return out 284 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | from torch.nn import init 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | def init_modules(modules, w_init='kaiming_uniform'): 7 | if w_init == "normal": 8 | _init = init.normal_ 9 | elif w_init == "xavier_normal": 10 | _init = init.xavier_normal_ 11 | elif w_init == "xavier_uniform": 12 | _init = init.xavier_uniform_ 13 | elif w_init == "kaiming_normal": 14 | _init = init.kaiming_normal_ 15 | elif w_init == "kaiming_uniform": 16 | _init = init.kaiming_uniform_ 17 | elif w_init == "orthogonal": 18 | _init = init.orthogonal_ 19 | else: 20 | raise NotImplementedError 21 | for m in modules: 22 | if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): 23 | _init(m.weight) 24 | if m.bias is not None: 25 | torch.nn.init.zeros_(m.bias) 26 | if isinstance(m, (nn.LSTM, nn.GRU)): 27 | for name, param in m.named_parameters(): 28 | if 'bias' in name: 29 | nn.init.zeros_(param) 30 | elif 'weight' in name: 31 | _init(param) 32 | -------------------------------------------------------------------------------- /overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thaolmk54/hcrn-videoqa/eb92c9b21aaa00f912fe5c2e5188abc47fda5211/overview.png -------------------------------------------------------------------------------- /preprocess/datautils/msrvtt_qa.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datautils import utils 3 | import nltk 4 | from collections import Counter 5 | 6 | import pickle 7 | import numpy as np 8 | 9 | 10 | def load_video_paths(args): 11 | ''' Load a list of (path,image_id tuples).''' 12 | video_paths = [] 13 | modes = ['train', 'val', 'test'] 14 | for mode in modes: 15 | with open(args.annotation_file.format(mode), 'r') as anno_file: 16 | instances = json.load(anno_file) 17 
| video_ids = [instance['video_id'] for instance in instances] 18 | video_ids = set(video_ids) 19 | if mode in ['train', 'val']: 20 | for video_id in video_ids: 21 | video_paths.append((args.video_dir + 'TrainValVideo/video{}.mp4'.format(video_id), video_id)) 22 | else: 23 | for video_id in video_ids: 24 | video_paths.append((args.video_dir + 'TestVideo/video{}.mp4'.format(video_id), video_id)) 25 | 26 | return video_paths 27 | 28 | 29 | def process_questions(args): 30 | ''' Encode question tokens''' 31 | print('Loading data') 32 | with open(args.annotation_file, 'r') as dataset_file: 33 | instances = json.load(dataset_file) 34 | 35 | # Either create the vocab or load it from disk 36 | if args.mode in ['train']: 37 | print('Building vocab') 38 | answer_cnt = {} 39 | for instance in instances: 40 | answer = instance['answer'] 41 | answer_cnt[answer] = answer_cnt.get(answer, 0) + 1 42 | 43 | answer_token_to_idx = {'<UNK0>': 0, '<UNK1>': 1} 44 | answer_counter = Counter(answer_cnt) 45 | frequent_answers = answer_counter.most_common(args.answer_top) 46 | total_ans = sum(item[1] for item in answer_counter.items()) 47 | total_freq_ans = sum(item[1] for item in frequent_answers) 48 | print("Number of unique answers:", len(answer_counter)) 49 | print("Total number of answers:", total_ans) 50 | print("Top %i answers account for %f%%" % (len(frequent_answers), total_freq_ans * 100.0 / total_ans)) 51 | 52 | for token, cnt in Counter(answer_cnt).most_common(args.answer_top): 53 | answer_token_to_idx[token] = len(answer_token_to_idx) 54 | print('Get answer_token_to_idx, num: %d' % len(answer_token_to_idx)) 55 | 56 | question_token_to_idx = {'<NULL>': 0, '<UNK>': 1} 57 | for i, instance in enumerate(instances): 58 | question = instance['question'].lower()[:-1] 59 | for token in nltk.word_tokenize(question): 60 | if token not in question_token_to_idx: 61 | question_token_to_idx[token] = len(question_token_to_idx) 62 | print('Get question_token_to_idx') 63 | print(len(question_token_to_idx)) 64 | 65 | vocab = { 66 | 'question_token_to_idx': question_token_to_idx, 67 | 'answer_token_to_idx': answer_token_to_idx, 68 | 'question_answer_token_to_idx': {'<NULL>': 0, '<UNK>': 1} 69 | } 70 | 71 | print('Write into %s' % args.vocab_json.format(args.dataset, args.dataset)) 72 | with open(args.vocab_json.format(args.dataset, args.dataset), 'w') as f: 73 | json.dump(vocab, f, indent=4) 74 | else: 75 | print('Loading vocab') 76 | with open(args.vocab_json.format(args.dataset, args.dataset), 'r') as f: 77 | vocab = json.load(f) 78 | 79 | # Encode all questions 80 | print('Encoding data') 81 | questions_encoded = [] 82 | questions_len = [] 83 | question_ids = [] 84 | video_ids_tbw = [] 85 | video_names_tbw = [] 86 | all_answers = [] 87 | for idx, instance in enumerate(instances): 88 | question = instance['question'].lower()[:-1] 89 | question_tokens = nltk.word_tokenize(question) 90 | question_encoded = utils.encode(question_tokens, vocab['question_token_to_idx'], allow_unk=True) 91 | questions_encoded.append(question_encoded) 92 | questions_len.append(len(question_encoded)) 93 | question_ids.append(idx) 94 | im_name = instance['video_id'] 95 | video_ids_tbw.append(im_name) 96 | video_names_tbw.append(im_name) 97 | 98 | if instance['answer'] in vocab['answer_token_to_idx']: 99 | answer = vocab['answer_token_to_idx'][instance['answer']] 100 | elif args.mode in ['train']: 101 | answer = 0 102 | elif args.mode in ['val', 'test']: 103 | answer = 1 104 | 105 | all_answers.append(answer) 106 | max_question_length = max(len(x) for x in questions_encoded) 107 
| for qe in questions_encoded: 108 | while len(qe) < max_question_length: 109 | qe.append(vocab['question_token_to_idx']['']) 110 | 111 | questions_encoded = np.asarray(questions_encoded, dtype=np.int32) 112 | questions_len = np.asarray(questions_len, dtype=np.int32) 113 | print(questions_encoded.shape) 114 | 115 | glove_matrix = None 116 | if args.mode == 'train': 117 | token_itow = {i: w for w, i in vocab['question_token_to_idx'].items()} 118 | print("Load glove from %s" % args.glove_pt) 119 | glove = pickle.load(open(args.glove_pt, 'rb')) 120 | dim_word = glove['the'].shape[0] 121 | glove_matrix = [] 122 | for i in range(len(token_itow)): 123 | vector = glove.get(token_itow[i], np.zeros((dim_word,))) 124 | glove_matrix.append(vector) 125 | glove_matrix = np.asarray(glove_matrix, dtype=np.float32) 126 | print(glove_matrix.shape) 127 | 128 | print('Writing', args.output_pt.format(args.dataset, args.dataset, args.mode)) 129 | obj = { 130 | 'questions': questions_encoded, 131 | 'questions_len': questions_len, 132 | 'question_id': question_ids, 133 | 'video_ids': np.asarray(video_ids_tbw), 134 | 'video_names': np.array(video_names_tbw), 135 | 'answers': all_answers, 136 | 'glove': glove_matrix, 137 | } 138 | with open(args.output_pt.format(args.dataset, args.dataset, args.mode), 'wb') as f: 139 | pickle.dump(obj, f) 140 | -------------------------------------------------------------------------------- /preprocess/datautils/msvd_qa.py: -------------------------------------------------------------------------------- 1 | import json 2 | from datautils import utils 3 | import nltk 4 | from collections import Counter 5 | 6 | import pickle 7 | import numpy as np 8 | 9 | 10 | def load_video_paths(args): 11 | ''' Load a list of (path,image_id tuples).''' 12 | video_paths = [] 13 | video_ids = [] 14 | modes = ['train', 'val', 'test'] 15 | for mode in modes: 16 | with open(args.annotation_file.format(mode), 'r') as anno_file: 17 | instances = json.load(anno_file) 18 | [video_ids.append(instance['video_id']) for instance in instances] 19 | video_ids = set(video_ids) 20 | with open(args.video_name_mapping, 'r') as mapping: 21 | mapping_pairs = mapping.read().split('\n') 22 | mapping_dict = {} 23 | for idx in range(len(mapping_pairs)): 24 | cur_pair = mapping_pairs[idx].split(' ') 25 | mapping_dict[cur_pair[1]] = cur_pair[0] 26 | for video_id in video_ids: 27 | video_paths.append((args.video_dir + 'YouTubeClips/{}.avi'.format(mapping_dict['vid' + str(video_id)]), video_id)) 28 | return video_paths 29 | 30 | 31 | def process_questions(args): 32 | ''' Encode question tokens''' 33 | print('Loading data') 34 | with open(args.annotation_file, 'r') as dataset_file: 35 | instances = json.load(dataset_file) 36 | 37 | # Either create the vocab or load it from disk 38 | if args.mode in ['train']: 39 | print('Building vocab') 40 | answer_cnt = {} 41 | for instance in instances: 42 | answer = instance['answer'] 43 | answer_cnt[answer] = answer_cnt.get(answer, 0) + 1 44 | 45 | answer_token_to_idx = {'': 0, '': 1} 46 | answer_counter = Counter(answer_cnt) 47 | frequent_answers = answer_counter.most_common(args.answer_top) 48 | total_ans = sum(item[1] for item in answer_counter.items()) 49 | total_freq_ans = sum(item[1] for item in frequent_answers) 50 | print("Number of unique answers:", len(answer_counter)) 51 | print("Total number of answers:", total_ans) 52 | print("Top %i answers account for %f%%" % (len(frequent_answers), total_freq_ans * 100.0 / total_ans)) 53 | 54 | for token, cnt in 
Counter(answer_cnt).most_common(args.answer_top): 55 | answer_token_to_idx[token] = len(answer_token_to_idx) 56 | print('Get answer_token_to_idx, num: %d' % len(answer_token_to_idx)) 57 | 58 | question_token_to_idx = {'': 0, '': 1} 59 | for i, instance in enumerate(instances): 60 | question = instance['question'].lower()[:-1] 61 | for token in nltk.word_tokenize(question): 62 | if token not in question_token_to_idx: 63 | question_token_to_idx[token] = len(question_token_to_idx) 64 | print('Get question_token_to_idx') 65 | print(len(question_token_to_idx)) 66 | 67 | vocab = { 68 | 'question_token_to_idx': question_token_to_idx, 69 | 'answer_token_to_idx': answer_token_to_idx, 70 | 'question_answer_token_to_idx': {'': 0, '': 1} 71 | } 72 | 73 | print('Write into %s' % args.vocab_json.format(args.dataset, args.dataset)) 74 | with open(args.vocab_json.format(args.dataset, args.dataset), 'w') as f: 75 | json.dump(vocab, f, indent=4) 76 | else: 77 | print('Loading vocab') 78 | with open(args.vocab_json.format(args.dataset, args.dataset), 'r') as f: 79 | vocab = json.load(f) 80 | 81 | # Encode all questions 82 | print('Encoding data') 83 | questions_encoded = [] 84 | questions_len = [] 85 | question_ids = [] 86 | video_ids_tbw = [] 87 | video_names_tbw = [] 88 | all_answers = [] 89 | for idx, instance in enumerate(instances): 90 | question = instance['question'].lower()[:-1] 91 | question_tokens = nltk.word_tokenize(question) 92 | question_encoded = utils.encode(question_tokens, vocab['question_token_to_idx'], allow_unk=True) 93 | questions_encoded.append(question_encoded) 94 | questions_len.append(len(question_encoded)) 95 | question_ids.append(idx) 96 | im_name = instance['video_id'] 97 | video_ids_tbw.append(im_name) 98 | video_names_tbw.append(im_name) 99 | 100 | if instance['answer'] in vocab['answer_token_to_idx']: 101 | answer = vocab['answer_token_to_idx'][instance['answer']] 102 | elif args.mode in ['train']: 103 | answer = 0 104 | elif args.mode in ['val', 'test']: 105 | answer = 1 106 | 107 | all_answers.append(answer) 108 | max_question_length = max(len(x) for x in questions_encoded) 109 | for qe in questions_encoded: 110 | while len(qe) < max_question_length: 111 | qe.append(vocab['question_token_to_idx']['']) 112 | 113 | questions_encoded = np.asarray(questions_encoded, dtype=np.int32) 114 | questions_len = np.asarray(questions_len, dtype=np.int32) 115 | print(questions_encoded.shape) 116 | 117 | glove_matrix = None 118 | if args.mode == 'train': 119 | token_itow = {i: w for w, i in vocab['question_token_to_idx'].items()} 120 | print("Load glove from %s" % args.glove_pt) 121 | glove = pickle.load(open(args.glove_pt, 'rb')) 122 | dim_word = glove['the'].shape[0] 123 | glove_matrix = [] 124 | for i in range(len(token_itow)): 125 | vector = glove.get(token_itow[i], np.zeros((dim_word,))) 126 | glove_matrix.append(vector) 127 | glove_matrix = np.asarray(glove_matrix, dtype=np.float32) 128 | print(glove_matrix.shape) 129 | 130 | print('Writing', args.output_pt.format(args.dataset, args.dataset, args.mode)) 131 | obj = { 132 | 'questions': questions_encoded, 133 | 'questions_len': questions_len, 134 | 'question_id': question_ids, 135 | 'video_ids': np.asarray(video_ids_tbw), 136 | 'video_names': np.array(video_names_tbw), 137 | 'answers': all_answers, 138 | 'glove': glove_matrix, 139 | } 140 | with open(args.output_pt.format(args.dataset, args.dataset, args.mode), 'wb') as f: 141 | pickle.dump(obj, f) 142 | -------------------------------------------------------------------------------- 
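Both process_questions implementations above (MSRVTT-QA and MSVD-QA) write the same pickle layout, so the rest of the pipeline can treat their outputs uniformly. The snippet below is a minimal sketch, not part of the repository, for sanity-checking that output; the two paths are hypothetical examples of what the --vocab_json and --output_pt templates might expand to.

import json
import pickle

# Hypothetical paths; substitute whatever --vocab_json / --output_pt were set to.
vocab_path = 'data/msvd-qa/msvd-qa_vocab.json'
questions_path = 'data/msvd-qa/msvd-qa_train_questions.pt'

with open(vocab_path, 'r') as f:
    vocab = json.load(f)
with open(questions_path, 'rb') as f:
    obj = pickle.load(f)

print(len(vocab['question_token_to_idx']))  # question vocabulary size
print(obj['questions'].shape)               # (num_questions, max_question_length), int32
print(obj['questions_len'].shape)           # (num_questions,)
print(len(obj['answers']))                  # one label per question
print(obj['glove'].shape)                   # (vocab_size, word_dim); None for val/test splits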
/preprocess/datautils/tgif_qa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pandas as pd 3 | import json 4 | from datautils import utils 5 | import nltk 6 | 7 | import pickle 8 | import numpy as np 9 | 10 | 11 | def load_video_paths(args): 12 | ''' Load a list of (path,image_id tuples).''' 13 | input_paths = [] 14 | annotation = pd.read_csv(args.annotation_file.format(args.question_type), delimiter='\t') 15 | gif_names = list(annotation['gif_name']) 16 | keys = list(annotation['key']) 17 | print("Number of questions: {}".format(len(gif_names))) 18 | for idx, gif in enumerate(gif_names): 19 | gif_abs_path = os.path.join(args.video_dir, ''.join([gif, '.gif'])) 20 | input_paths.append((gif_abs_path, keys[idx])) 21 | input_paths = list(set(input_paths)) 22 | print("Number of unique videos: {}".format(len(input_paths))) 23 | 24 | return input_paths 25 | 26 | 27 | def openeded_encoding_data(args, vocab, questions, video_names, video_ids, answers, mode='train'): 28 | ''' Encode question tokens''' 29 | print('Encoding data') 30 | questions_encoded = [] 31 | questions_len = [] 32 | video_ids_tbw = [] 33 | video_names_tbw = [] 34 | all_answers = [] 35 | question_ids = [] 36 | for idx, question in enumerate(questions): 37 | question = question.lower()[:-1] 38 | question_tokens = nltk.word_tokenize(question) 39 | question_encoded = utils.encode(question_tokens, vocab['question_token_to_idx'], allow_unk=True) 40 | questions_encoded.append(question_encoded) 41 | questions_len.append(len(question_encoded)) 42 | question_ids.append(idx) 43 | video_names_tbw.append(video_names[idx]) 44 | video_ids_tbw.append(video_ids[idx]) 45 | 46 | if args.question_type == "frameqa": 47 | answer = answers[idx] 48 | if answer in vocab['answer_token_to_idx']: 49 | answer = vocab['answer_token_to_idx'][answer] 50 | elif mode in ['train']: 51 | answer = 0 52 | elif mode in ['val', 'test']: 53 | answer = 1 54 | else: 55 | answer = max(int(answers[idx]), 1) 56 | all_answers.append(answer) 57 | 58 | # Pad encoded questions 59 | max_question_length = max(len(x) for x in questions_encoded) 60 | for qe in questions_encoded: 61 | while len(qe) < max_question_length: 62 | qe.append(vocab['question_token_to_idx']['']) 63 | 64 | questions_encoded = np.asarray(questions_encoded, dtype=np.int32) 65 | questions_len = np.asarray(questions_len, dtype=np.int32) 66 | print(questions_encoded.shape) 67 | 68 | glove_matrix = None 69 | if mode == 'train': 70 | token_itow = {i: w for w, i in vocab['question_token_to_idx'].items()} 71 | print("Load glove from %s" % args.glove_pt) 72 | glove = pickle.load(open(args.glove_pt, 'rb')) 73 | dim_word = glove['the'].shape[0] 74 | glove_matrix = [] 75 | for i in range(len(token_itow)): 76 | vector = glove.get(token_itow[i], np.zeros((dim_word,))) 77 | glove_matrix.append(vector) 78 | glove_matrix = np.asarray(glove_matrix, dtype=np.float32) 79 | print(glove_matrix.shape) 80 | 81 | print('Writing ', args.output_pt.format(args.question_type, args.question_type, mode)) 82 | obj = { 83 | 'questions': questions_encoded, 84 | 'questions_len': questions_len, 85 | 'question_id': question_ids, 86 | 'video_ids': np.asarray(video_ids_tbw), 87 | 'video_names': np.array(video_names_tbw), 88 | 'answers': all_answers, 89 | 'glove': glove_matrix, 90 | } 91 | with open(args.output_pt.format(args.question_type, args.question_type, mode), 'wb') as f: 92 | pickle.dump(obj, f) 93 | 94 | def multichoice_encoding_data(args, vocab, questions, video_names, video_ids, 
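                               # multichoice_encoding_data mirrors openeded_encoding_data above, but it
                               # encodes the question and its five answer candidates with the joint
                               # vocab['question_answer_token_to_idx'], pads both to fixed lengths, and
                               # additionally stores 'ans_candidates' of shape
                               # (num_questions, 5, max_candidate_length) and 'ans_candidates_len'
                               # in the output pickle.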
answers, ans_candidates, mode='train'): 95 | # Encode all questions 96 | print('Encoding data') 97 | questions_encoded = [] 98 | questions_len = [] 99 | question_ids = [] 100 | all_answer_cands_encoded = [] 101 | all_answer_cands_len = [] 102 | video_ids_tbw = [] 103 | video_names_tbw = [] 104 | correct_answers = [] 105 | for idx, question in enumerate(questions): 106 | question = question.lower()[:-1] 107 | question_tokens = nltk.word_tokenize(question) 108 | question_encoded = utils.encode(question_tokens, vocab['question_answer_token_to_idx'], allow_unk=True) 109 | questions_encoded.append(question_encoded) 110 | questions_len.append(len(question_encoded)) 111 | question_ids.append(idx) 112 | video_names_tbw.append(video_names[idx]) 113 | video_ids_tbw.append(video_ids[idx]) 114 | # grounthtruth 115 | answer = int(answers[idx]) 116 | correct_answers.append(answer) 117 | # answer candidates 118 | candidates = ans_candidates[idx] 119 | candidates_encoded = [] 120 | candidates_len = [] 121 | for ans in candidates: 122 | ans = ans.lower() 123 | ans_tokens = nltk.word_tokenize(ans) 124 | cand_encoded = utils.encode(ans_tokens, vocab['question_answer_token_to_idx'], allow_unk=True) 125 | candidates_encoded.append(cand_encoded) 126 | candidates_len.append(len(cand_encoded)) 127 | all_answer_cands_encoded.append(candidates_encoded) 128 | all_answer_cands_len.append(candidates_len) 129 | 130 | # Pad encoded questions 131 | max_question_length = max(len(x) for x in questions_encoded) 132 | for qe in questions_encoded: 133 | while len(qe) < max_question_length: 134 | qe.append(vocab['question_answer_token_to_idx']['']) 135 | 136 | questions_encoded = np.asarray(questions_encoded, dtype=np.int32) 137 | questions_len = np.asarray(questions_len, dtype=np.int32) 138 | print(questions_encoded.shape) 139 | 140 | # Pad encoded answer candidates 141 | max_answer_cand_length = max(max(len(x) for x in candidate) for candidate in all_answer_cands_encoded) 142 | for ans_cands in all_answer_cands_encoded: 143 | for ans in ans_cands: 144 | while len(ans) < max_answer_cand_length: 145 | ans.append(vocab['question_answer_token_to_idx']['']) 146 | all_answer_cands_encoded = np.asarray(all_answer_cands_encoded, dtype=np.int32) 147 | all_answer_cands_len = np.asarray(all_answer_cands_len, dtype=np.int32) 148 | print(all_answer_cands_encoded.shape) 149 | 150 | glove_matrix = None 151 | if mode in ['train']: 152 | token_itow = {i: w for w, i in vocab['question_answer_token_to_idx'].items()} 153 | print("Load glove from %s" % args.glove_pt) 154 | glove = pickle.load(open(args.glove_pt, 'rb')) 155 | dim_word = glove['the'].shape[0] 156 | glove_matrix = [] 157 | for i in range(len(token_itow)): 158 | vector = glove.get(token_itow[i], np.zeros((dim_word,))) 159 | glove_matrix.append(vector) 160 | glove_matrix = np.asarray(glove_matrix, dtype=np.float32) 161 | print(glove_matrix.shape) 162 | 163 | print('Writing ', args.output_pt.format(args.question_type, args.question_type, mode)) 164 | obj = { 165 | 'questions': questions_encoded, 166 | 'questions_len': questions_len, 167 | 'question_id': question_ids, 168 | 'video_ids': np.asarray(video_ids_tbw), 169 | 'video_names': np.array(video_names_tbw), 170 | 'ans_candidates': all_answer_cands_encoded, 171 | 'ans_candidates_len': all_answer_cands_len, 172 | 'answers': correct_answers, 173 | 'glove': glove_matrix, 174 | } 175 | with open(args.output_pt.format(args.question_type, args.question_type, mode), 'wb') as f: 176 | pickle.dump(obj, f) 177 | 178 | def 
process_questions_openended(args): 179 | print('Loading data') 180 | if args.mode in ["train"]: 181 | csv_data = pd.read_csv(args.annotation_file.format("Train", args.question_type), delimiter='\t') 182 | else: 183 | csv_data = pd.read_csv(args.annotation_file.format("Test", args.question_type), delimiter='\t') 184 | csv_data = csv_data.iloc[np.random.permutation(len(csv_data))] 185 | questions = list(csv_data['question']) 186 | answers = list(csv_data['answer']) 187 | video_names = list(csv_data['gif_name']) 188 | video_ids = list(csv_data['key']) 189 | 190 | print('number of questions: %s' % len(questions)) 191 | # Either create the vocab or load it from disk 192 | if args.mode in ['train']: 193 | print('Building vocab') 194 | answer_cnt = {} 195 | 196 | if args.question_type == "frameqa": 197 | for i, answer in enumerate(answers): 198 | answer_cnt[answer] = answer_cnt.get(answer, 0) + 1 199 | 200 | answer_token_to_idx = {'': 0} 201 | for token in answer_cnt: 202 | answer_token_to_idx[token] = len(answer_token_to_idx) 203 | print('Get answer_token_to_idx, num: %d' % len(answer_token_to_idx)) 204 | elif args.question_type == 'count': 205 | answer_token_to_idx = {'': 0} 206 | 207 | question_token_to_idx = {'': 0, '': 1} 208 | for i, q in enumerate(questions): 209 | question = q.lower()[:-1] 210 | for token in nltk.word_tokenize(question): 211 | if token not in question_token_to_idx: 212 | question_token_to_idx[token] = len(question_token_to_idx) 213 | print('Get question_token_to_idx') 214 | print(len(question_token_to_idx)) 215 | 216 | vocab = { 217 | 'question_token_to_idx': question_token_to_idx, 218 | 'answer_token_to_idx': answer_token_to_idx, 219 | 'question_answer_token_to_idx': {'': 0, '': 1} 220 | } 221 | 222 | print('Write into %s' % args.vocab_json.format(args.question_type, args.question_type)) 223 | with open(args.vocab_json.format(args.question_type, args.question_type), 'w') as f: 224 | json.dump(vocab, f, indent=4) 225 | 226 | # split 10% of questions for evaluation 227 | split = int(0.9 * len(questions)) 228 | train_questions = questions[:split] 229 | train_answers = answers[:split] 230 | train_video_names = video_names[:split] 231 | train_video_ids = video_ids[:split] 232 | 233 | val_questions = questions[split:] 234 | val_answers = answers[split:] 235 | val_video_names = video_names[split:] 236 | val_video_ids = video_ids[split:] 237 | 238 | openeded_encoding_data(args, vocab, train_questions, train_video_names, train_video_ids, train_answers, mode='train') 239 | openeded_encoding_data(args, vocab, val_questions, val_video_names, val_video_ids, val_answers, mode='val') 240 | else: 241 | print('Loading vocab') 242 | with open(args.vocab_json.format(args.question_type, args.question_type), 'r') as f: 243 | vocab = json.load(f) 244 | openeded_encoding_data(args, vocab, questions, video_names, video_ids, answers, mode='test') 245 | 246 | 247 | 248 | 249 | def process_questions_mulchoices(args): 250 | print('Loading data') 251 | if args.mode in ["train", "val"]: 252 | csv_data = pd.read_csv(args.annotation_file.format("Train", args.question_type), delimiter='\t') 253 | else: 254 | csv_data = pd.read_csv(args.annotation_file.format("Test", args.question_type), delimiter='\t') 255 | csv_data = csv_data.iloc[np.random.permutation(len(csv_data))] 256 | questions = list(csv_data['question']) 257 | answers = list(csv_data['answer']) 258 | video_names = list(csv_data['gif_name']) 259 | video_ids = list(csv_data['key']) 260 | ans_candidates = np.asarray( 261 | [csv_data['a1'], 
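                 # The five candidate columns a1 ... a5 are stacked into an array of
                 # shape (5, num_questions) and transposed to (num_questions, 5) just
                 # below, so each row holds the answer options for one question.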
csv_data['a2'], csv_data['a3'], csv_data['a4'], csv_data['a5']]) 262 | ans_candidates = ans_candidates.transpose() 263 | print(ans_candidates.shape) 264 | # ans_candidates: (num_ques, 5) 265 | print('number of questions: %s' % len(questions)) 266 | # Either create the vocab or load it from disk 267 | if args.mode in ['train']: 268 | print('Building vocab') 269 | 270 | answer_token_to_idx = {'': 0, '': 1} 271 | question_answer_token_to_idx = {'': 0, '': 1} 272 | for candidates in ans_candidates: 273 | for ans in candidates: 274 | ans = ans.lower() 275 | for token in nltk.word_tokenize(ans): 276 | if token not in answer_token_to_idx: 277 | answer_token_to_idx[token] = len(answer_token_to_idx) 278 | if token not in question_answer_token_to_idx: 279 | question_answer_token_to_idx[token] = len(question_answer_token_to_idx) 280 | print('Get answer_token_to_idx, num: %d' % len(answer_token_to_idx)) 281 | 282 | question_token_to_idx = {'': 0, '': 1} 283 | for i, q in enumerate(questions): 284 | question = q.lower()[:-1] 285 | for token in nltk.word_tokenize(question): 286 | if token not in question_token_to_idx: 287 | question_token_to_idx[token] = len(question_token_to_idx) 288 | if token not in question_answer_token_to_idx: 289 | question_answer_token_to_idx[token] = len(question_answer_token_to_idx) 290 | 291 | print('Get question_token_to_idx') 292 | print(len(question_token_to_idx)) 293 | print('Get question_answer_token_to_idx') 294 | print(len(question_answer_token_to_idx)) 295 | 296 | vocab = { 297 | 'question_token_to_idx': question_token_to_idx, 298 | 'answer_token_to_idx': answer_token_to_idx, 299 | 'question_answer_token_to_idx': question_answer_token_to_idx, 300 | } 301 | 302 | print('Write into %s' % args.vocab_json.format(args.question_type, args.question_type)) 303 | with open(args.vocab_json.format(args.question_type, args.question_type), 'w') as f: 304 | json.dump(vocab, f, indent=4) 305 | 306 | # split 10% of questions for evaluation 307 | split = int(0.9 * len(questions)) 308 | train_questions = questions[:split] 309 | train_answers = answers[:split] 310 | train_video_names = video_names[:split] 311 | train_video_ids = video_ids[:split] 312 | train_ans_candidates = ans_candidates[:split, :] 313 | 314 | val_questions = questions[split:] 315 | val_answers = answers[split:] 316 | val_video_names = video_names[split:] 317 | val_video_ids = video_ids[split:] 318 | val_ans_candidates = ans_candidates[split:, :] 319 | 320 | multichoice_encoding_data(args, vocab, train_questions, train_video_names, train_video_ids, train_answers, train_ans_candidates, mode='train') 321 | multichoice_encoding_data(args, vocab, val_questions, val_video_names, val_video_ids, val_answers, 322 | val_ans_candidates, mode='val') 323 | else: 324 | print('Loading vocab') 325 | with open(args.vocab_json.format(args.question_type, args.question_type), 'r') as f: 326 | vocab = json.load(f) 327 | multichoice_encoding_data(args, vocab, questions, video_names, video_ids, answers, 328 | ans_candidates, mode='test') 329 | -------------------------------------------------------------------------------- /preprocess/datautils/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | def encode(seq_tokens, token_to_idx, allow_unk=False): 4 | seq_idx = [] 5 | for token in seq_tokens: 6 | if token not in token_to_idx: 7 | if allow_unk: 8 | token = '' 9 | else: 10 | raise KeyError('Token "%s" not in vocab' % token) 11 | seq_idx.append(token_to_idx[token]) 12 | return 
seq_idx 13 | 14 | 15 | def decode(seq_idx, idx_to_token, delim=None, stop_at_end=True): 16 | tokens = [] 17 | for idx in seq_idx: 18 | tokens.append(idx_to_token[idx]) 19 | if stop_at_end and tokens[-1] == '': 20 | break 21 | if delim is None: 22 | return tokens 23 | else: 24 | return delim.join(tokens) 25 | 26 | # -------------------------------------------------------- 27 | # Fast R-CNN 28 | # Copyright (c) 2015 Microsoft 29 | # Licensed under The MIT License [see LICENSE for details] 30 | # Written by Ross Girshick 31 | # -------------------------------------------------------- 32 | 33 | class Timer(object): 34 | """A simple timer.""" 35 | def __init__(self): 36 | self.total_time = 0. 37 | self.calls = 0 38 | self.start_time = 0. 39 | self.diff = 0. 40 | self.average_time = 0. 41 | 42 | def tic(self): 43 | # using time.time instead of time.clock because time time.clock 44 | # does not normalize for multithreading 45 | self.start_time = time.time() 46 | 47 | def toc(self, average=True): 48 | self.diff = time.time() - self.start_time 49 | self.total_time += self.diff 50 | self.calls += 1 51 | self.average_time = self.total_time / self.calls 52 | if average: 53 | return self.average_time 54 | else: 55 | return self.diff -------------------------------------------------------------------------------- /preprocess/models/densenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from collections import OrderedDict 5 | import math 6 | 7 | __all__ = ['DenseNet', 'densenet121', 'densenet169', 'densenet201', 'densenet264'] 8 | 9 | 10 | def densenet121(**kwargs): 11 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 12 | **kwargs) 13 | return model 14 | 15 | 16 | def densenet169(**kwargs): 17 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 18 | **kwargs) 19 | return model 20 | 21 | 22 | def densenet201(**kwargs): 23 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 24 | **kwargs) 25 | return model 26 | 27 | 28 | def densenet264(**kwargs): 29 | model = DenseNet(num_init_features=64, growth_rate=32, block_config=(6, 12, 64, 48), 30 | **kwargs) 31 | return model 32 | 33 | 34 | def get_fine_tuning_parameters(model, ft_begin_index): 35 | if ft_begin_index == 0: 36 | return model.parameters() 37 | 38 | ft_module_names = [] 39 | for i in range(ft_begin_index, 5): 40 | ft_module_names.append('denseblock{}'.format(ft_begin_index)) 41 | ft_module_names.append('transition{}'.format(ft_begin_index)) 42 | ft_module_names.append('norm5') 43 | ft_module_names.append('classifier') 44 | 45 | parameters = [] 46 | for k, v in model.named_parameters(): 47 | for ft_module in ft_module_names: 48 | if ft_module in k: 49 | parameters.append({'params': v}) 50 | break 51 | else: 52 | parameters.append({'params': v, 'lr': 0.0}) 53 | 54 | return parameters 55 | 56 | 57 | class _DenseLayer(nn.Sequential): 58 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 59 | super(_DenseLayer, self).__init__() 60 | self.add_module('norm.1', nn.BatchNorm3d(num_input_features)) 61 | self.add_module('relu.1', nn.ReLU(inplace=True)) 62 | self.add_module('conv.1', nn.Conv3d(num_input_features, bn_size * growth_rate, 63 | kernel_size=1, stride=1, bias=False)) 64 | self.add_module('norm.2', nn.BatchNorm3d(bn_size * growth_rate)) 65 | self.add_module('relu.2', nn.ReLU(inplace=True)) 66 | 
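        # Bottleneck layout of a DenseNet layer: the 1x1x1 convolution above expands the
        # input to bn_size * growth_rate channels, and the 3x3x3 convolution below reduces
        # it back to growth_rate channels; forward() then concatenates these new features
        # onto the input along the channel dimension.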
self.add_module('conv.2', nn.Conv3d(bn_size * growth_rate, growth_rate, 67 | kernel_size=3, stride=1, padding=1, bias=False)) 68 | self.drop_rate = drop_rate 69 | 70 | def forward(self, x): 71 | new_features = super(_DenseLayer, self).forward(x) 72 | if self.drop_rate > 0: 73 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 74 | return torch.cat([x, new_features], 1) 75 | 76 | 77 | class _DenseBlock(nn.Sequential): 78 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 79 | super(_DenseBlock, self).__init__() 80 | for i in range(num_layers): 81 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 82 | self.add_module('denselayer%d' % (i + 1), layer) 83 | 84 | 85 | class _Transition(nn.Sequential): 86 | def __init__(self, num_input_features, num_output_features): 87 | super(_Transition, self).__init__() 88 | self.add_module('norm', nn.BatchNorm3d(num_input_features)) 89 | self.add_module('relu', nn.ReLU(inplace=True)) 90 | self.add_module('conv', nn.Conv3d(num_input_features, num_output_features, 91 | kernel_size=1, stride=1, bias=False)) 92 | self.add_module('pool', nn.AvgPool3d(kernel_size=2, stride=2)) 93 | 94 | 95 | class DenseNet(nn.Module): 96 | """Densenet-BC model class 97 | Args: 98 | growth_rate (int) - how many filters to add each layer (k in paper) 99 | block_config (list of 4 ints) - how many layers in each pooling block 100 | num_init_features (int) - the number of filters to learn in the first convolution layer 101 | bn_size (int) - multiplicative factor for number of bottle neck layers 102 | (i.e. bn_size * k features in the bottleneck layer) 103 | drop_rate (float) - dropout rate after each dense layer 104 | num_classes (int) - number of classification classes 105 | """ 106 | def __init__(self, sample_size, sample_duration, growth_rate=32, block_config=(6, 12, 24, 16), 107 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000, last_fc=True): 108 | 109 | super(DenseNet, self).__init__() 110 | 111 | self.last_fc = last_fc 112 | 113 | self.sample_size = sample_size 114 | self.sample_duration = sample_duration 115 | 116 | # First convolution 117 | self.features = nn.Sequential(OrderedDict([ 118 | ('conv0', nn.Conv3d(3, num_init_features, kernel_size=7, 119 | stride=(1, 2, 2), padding=(3, 3, 3), bias=False)), 120 | ('norm0', nn.BatchNorm3d(num_init_features)), 121 | ('relu0', nn.ReLU(inplace=True)), 122 | ('pool0', nn.MaxPool3d(kernel_size=3, stride=2, padding=1)), 123 | ])) 124 | 125 | # Each denseblock 126 | num_features = num_init_features 127 | for i, num_layers in enumerate(block_config): 128 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 129 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 130 | self.features.add_module('denseblock%d' % (i + 1), block) 131 | num_features = num_features + num_layers * growth_rate 132 | if i != len(block_config) - 1: 133 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 134 | self.features.add_module('transition%d' % (i + 1), trans) 135 | num_features = num_features // 2 136 | 137 | # Final batch norm 138 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 139 | 140 | # Linear layer 141 | self.classifier = nn.Linear(num_features, num_classes) 142 | 143 | def forward(self, x): 144 | features = self.features(x) 145 | out = F.relu(features, inplace=True) 146 | last_duration = math.ceil(self.sample_duration / 16) 147 | 
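        # The global average-pooling window is derived from the input geometry: the
        # temporal extent is divided by 16 and the spatial extent by 32 across the
        # network, so with, e.g., sample_duration=16 and sample_size=112 the kernel
        # below is (1, 3, 3) and the pooled output is a single vector per clip.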
last_size = math.floor(self.sample_size / 32) 148 | out = F.avg_pool3d(out, kernel_size=(last_duration, last_size, last_size)).view(features.size(0), -1) 149 | if self.last_fc: 150 | out = self.classifier(out) 151 | return out 152 | -------------------------------------------------------------------------------- /preprocess/models/pre_act_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['PreActivationResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet200'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, 14 | stride=stride, padding=1, bias=False) 15 | 16 | 17 | def downsample_basic_block(x, planes, stride): 18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), 20 | out.size(2), out.size(3), 21 | out.size(4)).zero_() 22 | if isinstance(out.data, torch.cuda.FloatTensor): 23 | zero_pads = zero_pads.cuda() 24 | 25 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 26 | 27 | return out 28 | 29 | 30 | class PreActivationBasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(PreActivationBasicBlock, self).__init__() 35 | self.bn1 = nn.BatchNorm3d(inplanes) 36 | self.conv1 = conv3x3x3(inplanes, planes, stride) 37 | self.bn2 = nn.BatchNorm3d(planes) 38 | self.conv2 = conv3x3x3(planes, planes) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.bn1(x) 47 | out = self.relu(out) 48 | out = self.conv1(out) 49 | 50 | out = self.bn2(out) 51 | out = self.relu(out) 52 | out = self.conv2(out) 53 | 54 | if self.downsample is not None: 55 | residual = self.downsample(x) 56 | 57 | out += residual 58 | 59 | return out 60 | 61 | 62 | class PreActivationBottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(PreActivationBottleneck, self).__init__() 67 | self.bn1 = nn.BatchNorm3d(inplanes) 68 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 69 | self.bn2 = nn.BatchNorm3d(planes) 70 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, 71 | padding=1, bias=False) 72 | self.bn3 = nn.BatchNorm3d(planes) 73 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.bn1(x) 82 | out = self.relu(out) 83 | out = self.conv1(out) 84 | 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | out = self.conv2(out) 88 | 89 | out = self.bn3(out) 90 | out = self.relu(out) 91 | out = self.conv3(out) 92 | 93 | if self.downsample is not None: 94 | residual = self.downsample(x) 95 | 96 | out += residual 97 | 98 | return out 99 | 100 | 101 | class PreActivationResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, sample_size, sample_duration, shortcut_type='B', num_classes=400, last_fc=True): 104 | self.last_fc = last_fc 105 | 106 | self.inplanes = 64 107 | super(PreActivationResNet, self).__init__() 108 | 
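        # Stem shared by the 3-D ResNet variants in this directory: a 7x7x7 convolution
        # with stride (1, 2, 2) keeps the full temporal resolution while halving the
        # spatial resolution, followed by a stride-2 max pool before the four residual
        # stages built by _make_layer below.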
self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), 109 | padding=(3, 3, 3), bias=False) 110 | self.bn1 = nn.BatchNorm3d(64) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 114 | self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type, stride=2) 117 | last_duration = math.ceil(sample_duration / 16) 118 | last_size = math.ceil(sample_size / 32) 119 | self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) 120 | self.fc = nn.Linear(512 * block.expansion, num_classes) 121 | 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv3d): 124 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 125 | m.weight.data.normal_(0, math.sqrt(2. / n)) 126 | elif isinstance(m, nn.BatchNorm3d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | 130 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 131 | downsample = None 132 | if stride != 1 or self.inplanes != planes * block.expansion: 133 | if shortcut_type == 'A': 134 | downsample = partial(downsample_basic_block, 135 | planes=planes * block.expansion, 136 | stride=stride) 137 | else: 138 | downsample = nn.Sequential( 139 | nn.Conv3d(self.inplanes, planes * block.expansion, 140 | kernel_size=1, stride=stride, bias=False), 141 | nn.BatchNorm3d(planes * block.expansion) 142 | ) 143 | 144 | layers = [] 145 | layers.append(block(self.inplanes, planes, stride, downsample)) 146 | self.inplanes = planes * block.expansion 147 | for i in range(1, blocks): 148 | layers.append(block(self.inplanes, planes)) 149 | 150 | return nn.Sequential(*layers) 151 | 152 | def forward(self, x): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.relu(x) 156 | x = self.maxpool(x) 157 | 158 | x = self.layer1(x) 159 | x = self.layer2(x) 160 | x = self.layer3(x) 161 | x = self.layer4(x) 162 | 163 | x = self.avgpool(x) 164 | 165 | x = x.view(x.size(0), -1) 166 | if self.last_fc: 167 | x = self.fc(x) 168 | 169 | return x 170 | 171 | def get_fine_tuning_parameters(model, ft_begin_index): 172 | if ft_begin_index == 0: 173 | return model.parameters() 174 | 175 | ft_module_names = [] 176 | for i in range(ft_begin_index, 5): 177 | ft_module_names.append('layer{}'.format(ft_begin_index)) 178 | ft_module_names.append('fc') 179 | 180 | parameters = [] 181 | for k, v in model.named_parameters(): 182 | for ft_module in ft_module_names: 183 | if ft_module in k: 184 | parameters.append({'params': v}) 185 | break 186 | else: 187 | parameters.append({'params': v, 'lr': 0.0}) 188 | 189 | return parameters 190 | 191 | def resnet18(**kwargs): 192 | """Constructs a ResNet-18 model. 193 | """ 194 | model = PreActivationResNet(PreActivationBasicBlock, [2, 2, 2, 2], **kwargs) 195 | return model 196 | 197 | def resnet34(**kwargs): 198 | """Constructs a ResNet-34 model. 199 | """ 200 | model = PreActivationResNet(PreActivationBasicBlock, [3, 4, 6, 3], **kwargs) 201 | return model 202 | 203 | 204 | def resnet50(**kwargs): 205 | """Constructs a ResNet-50 model. 206 | """ 207 | model = PreActivationResNet(PreActivationBottleneck, [3, 4, 6, 3], **kwargs) 208 | return model 209 | 210 | def resnet101(**kwargs): 211 | """Constructs a ResNet-101 model. 
212 | """ 213 | model = PreActivationResNet(PreActivationBottleneck, [3, 4, 23, 3], **kwargs) 214 | return model 215 | 216 | def resnet152(**kwargs): 217 | """Constructs a ResNet-101 model. 218 | """ 219 | model = PreActivationResNet(PreActivationBottleneck, [3, 8, 36, 3], **kwargs) 220 | return model 221 | 222 | def resnet200(**kwargs): 223 | """Constructs a ResNet-101 model. 224 | """ 225 | model = PreActivationResNet(PreActivationBottleneck, [3, 24, 36, 3], **kwargs) 226 | return model 227 | -------------------------------------------------------------------------------- /preprocess/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['ResNet', 'resnet10', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'resnet200'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, 14 | stride=stride, padding=1, bias=False) 15 | 16 | 17 | def downsample_basic_block(x, planes, stride): 18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), 20 | out.size(2), out.size(3), 21 | out.size(4)).zero_() 22 | if isinstance(out.data, torch.cuda.FloatTensor): 23 | zero_pads = zero_pads.cuda() 24 | 25 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 26 | 27 | return out 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(BasicBlock, self).__init__() 35 | self.conv1 = conv3x3x3(inplanes, planes, stride) 36 | self.bn1 = nn.BatchNorm3d(planes) 37 | self.relu = nn.ReLU(inplace=True) 38 | self.conv2 = conv3x3x3(planes, planes) 39 | self.bn2 = nn.BatchNorm3d(planes) 40 | self.downsample = downsample 41 | self.stride = stride 42 | 43 | def forward(self, x): 44 | residual = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | residual = self.downsample(x) 55 | 56 | out += residual 57 | out = self.relu(out) 58 | 59 | return out 60 | 61 | 62 | class Bottleneck(nn.Module): 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None): 66 | super(Bottleneck, self).__init__() 67 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 68 | self.bn1 = nn.BatchNorm3d(planes) 69 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, 70 | padding=1, bias=False) 71 | self.bn2 = nn.BatchNorm3d(planes) 72 | self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False) 73 | self.bn3 = nn.BatchNorm3d(planes * 4) 74 | self.relu = nn.ReLU(inplace=True) 75 | self.downsample = downsample 76 | self.stride = stride 77 | 78 | def forward(self, x): 79 | residual = x 80 | 81 | out = self.conv1(x) 82 | out = self.bn1(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv2(out) 86 | out = self.bn2(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv3(out) 90 | out = self.bn3(out) 91 | 92 | if self.downsample is not None: 93 | residual = self.downsample(x) 94 | 95 | out += residual 96 | out = self.relu(out) 97 | 98 | return out 99 | 100 | 101 | class ResNet(nn.Module): 102 | 103 | def __init__(self, block, layers, 
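                 # sample_size and sample_duration (spatial crop size and clip length in
                 # frames) only influence the size of the final average-pooling window.
                 # For motion features, preprocess_features.py builds the ResNeXt variant
                 # of this family with sample_size=112, sample_duration=16 and
                 # last_fc=False, so forward() returns pooled 2048-d clip features rather
                 # than the 400-way Kinetics logits.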
sample_size, sample_duration, shortcut_type='B', num_classes=400, last_fc=True): 104 | self.last_fc = last_fc 105 | 106 | self.inplanes = 64 107 | super(ResNet, self).__init__() 108 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), 109 | padding=(3, 3, 3), bias=False) 110 | self.bn1 = nn.BatchNorm3d(64) 111 | self.relu = nn.ReLU(inplace=True) 112 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 113 | self.layer1 = self._make_layer(block, 64, layers[0], shortcut_type) 114 | self.layer2 = self._make_layer(block, 128, layers[1], shortcut_type, stride=2) 115 | self.layer3 = self._make_layer(block, 256, layers[2], shortcut_type, stride=2) 116 | self.layer4 = self._make_layer(block, 512, layers[3], shortcut_type, stride=2) 117 | last_duration = math.ceil(sample_duration / 16) 118 | last_size = math.ceil(sample_size / 32) 119 | self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) 120 | self.fc = nn.Linear(512 * block.expansion, num_classes) 121 | 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv3d): 124 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 125 | m.weight.data.normal_(0, math.sqrt(2. / n)) 126 | elif isinstance(m, nn.BatchNorm3d): 127 | m.weight.data.fill_(1) 128 | m.bias.data.zero_() 129 | 130 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 131 | downsample = None 132 | if stride != 1 or self.inplanes != planes * block.expansion: 133 | if shortcut_type == 'A': 134 | downsample = partial(downsample_basic_block, 135 | planes=planes * block.expansion, 136 | stride=stride) 137 | else: 138 | downsample = nn.Sequential( 139 | nn.Conv3d(self.inplanes, planes * block.expansion, 140 | kernel_size=1, stride=stride, bias=False), 141 | nn.BatchNorm3d(planes * block.expansion) 142 | ) 143 | 144 | layers = [] 145 | layers.append(block(self.inplanes, planes, stride, downsample)) 146 | self.inplanes = planes * block.expansion 147 | for i in range(1, blocks): 148 | layers.append(block(self.inplanes, planes)) 149 | 150 | return nn.Sequential(*layers) 151 | 152 | def forward(self, x): 153 | x = self.conv1(x) 154 | x = self.bn1(x) 155 | x = self.relu(x) 156 | x = self.maxpool(x) 157 | 158 | x = self.layer1(x) 159 | x = self.layer2(x) 160 | x = self.layer3(x) 161 | x = self.layer4(x) 162 | 163 | x = self.avgpool(x) 164 | 165 | x = x.view(x.size(0), -1) 166 | if self.last_fc: 167 | x = self.fc(x) 168 | 169 | return x 170 | 171 | 172 | def get_fine_tuning_parameters(model, ft_begin_index): 173 | if ft_begin_index == 0: 174 | return model.parameters() 175 | 176 | ft_module_names = [] 177 | for i in range(ft_begin_index, 5): 178 | ft_module_names.append('layer{}'.format(ft_begin_index)) 179 | ft_module_names.append('fc') 180 | 181 | parameters = [] 182 | for k, v in model.named_parameters(): 183 | for ft_module in ft_module_names: 184 | if ft_module in k: 185 | parameters.append({'params': v}) 186 | break 187 | else: 188 | parameters.append({'params': v, 'lr': 0.0}) 189 | 190 | return parameters 191 | 192 | 193 | def resnet10(**kwargs): 194 | """Constructs a ResNet-18 model. 195 | """ 196 | model = ResNet(BasicBlock, [1, 1, 1, 1], **kwargs) 197 | return model 198 | 199 | def resnet18(**kwargs): 200 | """Constructs a ResNet-18 model. 201 | """ 202 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 203 | return model 204 | 205 | def resnet34(**kwargs): 206 | """Constructs a ResNet-34 model. 
207 | """ 208 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 209 | return model 210 | 211 | def resnet50(**kwargs): 212 | """Constructs a ResNet-50 model. 213 | """ 214 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 215 | return model 216 | 217 | def resnet101(**kwargs): 218 | """Constructs a ResNet-101 model. 219 | """ 220 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 221 | return model 222 | 223 | def resnet152(**kwargs): 224 | """Constructs a ResNet-101 model. 225 | """ 226 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 227 | return model 228 | 229 | def resnet200(**kwargs): 230 | """Constructs a ResNet-101 model. 231 | """ 232 | model = ResNet(Bottleneck, [3, 24, 36, 3], **kwargs) 233 | return model 234 | -------------------------------------------------------------------------------- /preprocess/models/resnext.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['ResNeXt', 'resnet50', 'resnet101'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, 14 | stride=stride, padding=1, bias=False) 15 | 16 | 17 | def downsample_basic_block(x, planes, stride): 18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), 20 | out.size(2), out.size(3), 21 | out.size(4)).zero_() 22 | if isinstance(out.data, torch.cuda.FloatTensor): 23 | zero_pads = zero_pads.cuda() 24 | 25 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 26 | 27 | return out 28 | 29 | 30 | class ResNeXtBottleneck(nn.Module): 31 | expansion = 2 32 | 33 | def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None): 34 | super(ResNeXtBottleneck, self).__init__() 35 | mid_planes = cardinality * int(planes / 32) 36 | self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=False) 37 | self.bn1 = nn.BatchNorm3d(mid_planes) 38 | self.conv2 = nn.Conv3d(mid_planes, mid_planes, kernel_size=3, stride=stride, 39 | padding=1, groups=cardinality, bias=False) 40 | self.bn2 = nn.BatchNorm3d(mid_planes) 41 | self.conv3 = nn.Conv3d(mid_planes, planes * self.expansion, kernel_size=1, bias=False) 42 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 43 | self.relu = nn.ReLU(inplace=True) 44 | self.downsample = downsample 45 | self.stride = stride 46 | 47 | def forward(self, x): 48 | residual = x 49 | 50 | out = self.conv1(x) 51 | out = self.bn1(out) 52 | out = self.relu(out) 53 | 54 | out = self.conv2(out) 55 | out = self.bn2(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv3(out) 59 | out = self.bn3(out) 60 | 61 | if self.downsample is not None: 62 | residual = self.downsample(x) 63 | 64 | out += residual 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class ResNeXt(nn.Module): 71 | 72 | def __init__(self, block, layers, sample_size, sample_duration, shortcut_type='B', cardinality=32, num_classes=400, last_fc=True): 73 | self.last_fc = last_fc 74 | 75 | self.inplanes = 64 76 | super(ResNeXt, self).__init__() 77 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), 78 | padding=(3, 3, 3), bias=False) 79 | self.bn1 = nn.BatchNorm3d(64) 80 | self.relu = nn.ReLU(inplace=True) 81 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 82 | self.layer1 = 
self._make_layer(block, 128, layers[0], shortcut_type, cardinality) 83 | self.layer2 = self._make_layer(block, 256, layers[1], shortcut_type, cardinality, stride=2) 84 | self.layer3 = self._make_layer(block, 512, layers[2], shortcut_type, cardinality, stride=2) 85 | self.layer4 = self._make_layer(block, 1024, layers[3], shortcut_type, cardinality, stride=2) 86 | last_duration = math.ceil(sample_duration / 16) 87 | last_size = math.ceil(sample_size / 32) 88 | self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) 89 | self.fc = nn.Linear(cardinality * 32 * block.expansion, num_classes) 90 | 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv3d): 93 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 94 | m.weight.data.normal_(0, math.sqrt(2. / n)) 95 | elif isinstance(m, nn.BatchNorm3d): 96 | m.weight.data.fill_(1) 97 | m.bias.data.zero_() 98 | 99 | def _make_layer(self, block, planes, blocks, shortcut_type, cardinality, stride=1): 100 | downsample = None 101 | if stride != 1 or self.inplanes != planes * block.expansion: 102 | if shortcut_type == 'A': 103 | downsample = partial(downsample_basic_block, 104 | planes=planes * block.expansion, 105 | stride=stride) 106 | else: 107 | downsample = nn.Sequential( 108 | nn.Conv3d(self.inplanes, planes * block.expansion, 109 | kernel_size=1, stride=stride, bias=False), 110 | nn.BatchNorm3d(planes * block.expansion) 111 | ) 112 | 113 | layers = [] 114 | layers.append(block(self.inplanes, planes, cardinality, stride, downsample)) 115 | self.inplanes = planes * block.expansion 116 | for i in range(1, blocks): 117 | layers.append(block(self.inplanes, planes, cardinality)) 118 | 119 | return nn.Sequential(*layers) 120 | 121 | def forward(self, x): 122 | x = self.conv1(x) 123 | x = self.bn1(x) 124 | x = self.relu(x) 125 | x = self.maxpool(x) 126 | 127 | x = self.layer1(x) 128 | x = self.layer2(x) 129 | x = self.layer3(x) 130 | x = self.layer4(x) 131 | 132 | x = self.avgpool(x) 133 | # 134 | x = x.view(x.size(0), -1) 135 | if self.last_fc: 136 | x = self.fc(x) 137 | 138 | return x 139 | 140 | def get_fine_tuning_parameters(model, ft_begin_index): 141 | if ft_begin_index == 0: 142 | return model.parameters() 143 | 144 | ft_module_names = [] 145 | for i in range(ft_begin_index, 5): 146 | ft_module_names.append('layer{}'.format(ft_begin_index)) 147 | ft_module_names.append('fc') 148 | 149 | parameters = [] 150 | for k, v in model.named_parameters(): 151 | for ft_module in ft_module_names: 152 | if ft_module in k: 153 | parameters.append({'params': v}) 154 | break 155 | else: 156 | parameters.append({'params': v, 'lr': 0.0}) 157 | 158 | return parameters 159 | 160 | def resnet50(**kwargs): 161 | """Constructs a ResNet-50 model. 162 | """ 163 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 6, 3], **kwargs) 164 | return model 165 | 166 | def resnet101(**kwargs): 167 | """Constructs a ResNet-101 model. 168 | """ 169 | model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs) 170 | return model 171 | 172 | def resnet152(**kwargs): 173 | """Constructs a ResNet-101 model. 
174 | """ 175 | model = ResNeXt(ResNeXtBottleneck, [3, 8, 36, 3], **kwargs) 176 | return model 177 | -------------------------------------------------------------------------------- /preprocess/models/wide_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | import math 6 | from functools import partial 7 | 8 | __all__ = ['WideResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101'] 9 | 10 | 11 | def conv3x3x3(in_planes, out_planes, stride=1): 12 | # 3x3x3 convolution with padding 13 | return nn.Conv3d(in_planes, out_planes, kernel_size=3, 14 | stride=stride, padding=1, bias=False) 15 | 16 | 17 | def downsample_basic_block(x, planes, stride): 18 | out = F.avg_pool3d(x, kernel_size=1, stride=stride) 19 | zero_pads = torch.Tensor(out.size(0), planes - out.size(1), 20 | out.size(2), out.size(3), 21 | out.size(4)).zero_() 22 | if isinstance(out.data, torch.cuda.FloatTensor): 23 | zero_pads = zero_pads.cuda() 24 | 25 | out = Variable(torch.cat([out.data, zero_pads], dim=1)) 26 | 27 | return out 28 | 29 | 30 | class WideBottleneck(nn.Module): 31 | expansion = 2 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None): 34 | super(WideBottleneck, self).__init__() 35 | self.conv1 = nn.Conv3d(inplanes, planes, kernel_size=1, bias=False) 36 | self.bn1 = nn.BatchNorm3d(planes) 37 | self.conv2 = nn.Conv3d(planes, planes, kernel_size=3, stride=stride, 38 | padding=1, bias=False) 39 | self.bn2 = nn.BatchNorm3d(planes) 40 | self.conv3 = nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False) 41 | self.bn3 = nn.BatchNorm3d(planes * self.expansion) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.downsample = downsample 44 | self.stride = stride 45 | 46 | def forward(self, x): 47 | residual = x 48 | 49 | out = self.conv1(x) 50 | out = self.bn1(out) 51 | out = self.relu(out) 52 | 53 | out = self.conv2(out) 54 | out = self.bn2(out) 55 | out = self.relu(out) 56 | 57 | out = self.conv3(out) 58 | out = self.bn3(out) 59 | 60 | if self.downsample is not None: 61 | residual = self.downsample(x) 62 | 63 | out += residual 64 | out = self.relu(out) 65 | 66 | return out 67 | 68 | 69 | class WideResNet(nn.Module): 70 | 71 | def __init__(self, block, layers, sample_size, sample_duration, k=1, shortcut_type='B', num_classes=400, last_fc=True): 72 | self.last_fc = last_fc 73 | 74 | self.inplanes = 64 75 | super(WideResNet, self).__init__() 76 | self.conv1 = nn.Conv3d(3, 64, kernel_size=7, stride=(1, 2, 2), 77 | padding=(3, 3, 3), bias=False) 78 | self.bn1 = nn.BatchNorm3d(64) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1) 81 | self.layer1 = self._make_layer(block, 64 * k, layers[0], shortcut_type) 82 | self.layer2 = self._make_layer(block, 128 * k, layers[1], shortcut_type, stride=2) 83 | self.layer3 = self._make_layer(block, 256 * k, layers[2], shortcut_type, stride=2) 84 | self.layer4 = self._make_layer(block, 512 * k, layers[3], shortcut_type, stride=2) 85 | last_duration = math.ceil(sample_duration / 16) 86 | last_size = math.ceil(sample_size / 32) 87 | self.avgpool = nn.AvgPool3d((last_duration, last_size, last_size), stride=1) 88 | self.fc = nn.Linear(512 * k * block.expansion, num_classes) 89 | 90 | for m in self.modules(): 91 | if isinstance(m, nn.Conv3d): 92 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 93 | m.weight.data.normal_(0, 
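                    # He-style initialization: zero-mean normal with std sqrt(2 / n). Note
                    # that n above counts only two of the three kernel dimensions
                    # (kernel_size[0] * kernel_size[1] * out_channels); the same convention
                    # appears in the other 3-D ResNet variants in this directory.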
math.sqrt(2. / n)) 94 | elif isinstance(m, nn.BatchNorm3d): 95 | m.weight.data.fill_(1) 96 | m.bias.data.zero_() 97 | 98 | def _make_layer(self, block, planes, blocks, shortcut_type, stride=1): 99 | downsample = None 100 | if stride != 1 or self.inplanes != planes * block.expansion: 101 | if shortcut_type == 'A': 102 | downsample = partial(downsample_basic_block, 103 | planes=planes * block.expansion, 104 | stride=stride) 105 | else: 106 | downsample = nn.Sequential( 107 | nn.Conv3d(self.inplanes, planes * block.expansion, 108 | kernel_size=1, stride=stride, bias=False), 109 | nn.BatchNorm3d(planes * block.expansion) 110 | ) 111 | 112 | layers = [] 113 | layers.append(block(self.inplanes, planes, stride, downsample)) 114 | self.inplanes = planes * block.expansion 115 | for i in range(1, blocks): 116 | layers.append(block(self.inplanes, planes)) 117 | 118 | return nn.Sequential(*layers) 119 | 120 | def forward(self, x): 121 | x = self.conv1(x) 122 | x = self.bn1(x) 123 | x = self.relu(x) 124 | x = self.maxpool(x) 125 | 126 | x = self.layer1(x) 127 | x = self.layer2(x) 128 | x = self.layer3(x) 129 | x = self.layer4(x) 130 | 131 | x = self.avgpool(x) 132 | 133 | x = x.view(x.size(0), -1) 134 | if self.last_fc: 135 | x = self.fc(x) 136 | 137 | return x 138 | 139 | def get_fine_tuning_parameters(model, ft_begin_index): 140 | if ft_begin_index == 0: 141 | return model.parameters() 142 | 143 | ft_module_names = [] 144 | for i in range(ft_begin_index, 5): 145 | ft_module_names.append('layer{}'.format(ft_begin_index)) 146 | ft_module_names.append('fc') 147 | 148 | parameters = [] 149 | for k, v in model.named_parameters(): 150 | for ft_module in ft_module_names: 151 | if ft_module in k: 152 | parameters.append({'params': v}) 153 | break 154 | else: 155 | parameters.append({'params': v, 'lr': 0.0}) 156 | 157 | return parameters 158 | 159 | def resnet50(**kwargs): 160 | """Constructs a ResNet-50 model. 
161 | """ 162 | model = WideResNet(WideBottleneck, [3, 4, 6, 3], **kwargs) 163 | return model 164 | -------------------------------------------------------------------------------- /preprocess/preprocess_features.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | import h5py 3 | from scipy.misc import imresize 4 | import skvideo.io 5 | from PIL import Image 6 | 7 | import torch 8 | from torch import nn 9 | import torchvision 10 | import random 11 | import numpy as np 12 | 13 | from models import resnext 14 | from datautils import utils 15 | from datautils import tgif_qa 16 | from datautils import msrvtt_qa 17 | from datautils import msvd_qa 18 | 19 | 20 | def build_resnet(): 21 | if not hasattr(torchvision.models, args.model): 22 | raise ValueError('Invalid model "%s"' % args.model) 23 | if not 'resnet' in args.model: 24 | raise ValueError('Feature extraction only supports ResNets') 25 | cnn = getattr(torchvision.models, args.model)(pretrained=True) 26 | model = torch.nn.Sequential(*list(cnn.children())[:-1]) 27 | model.cuda() 28 | model.eval() 29 | return model 30 | 31 | 32 | def build_resnext(): 33 | model = resnext.resnet101(num_classes=400, shortcut_type='B', cardinality=32, 34 | sample_size=112, sample_duration=16, 35 | last_fc=False) 36 | model = model.cuda() 37 | model = nn.DataParallel(model, device_ids=None) 38 | assert os.path.exists('preprocess/pretrained/resnext-101-kinetics.pth') 39 | model_data = torch.load('preprocess/pretrained/resnext-101-kinetics.pth', map_location='cpu') 40 | model.load_state_dict(model_data['state_dict']) 41 | model.eval() 42 | return model 43 | 44 | 45 | def run_batch(cur_batch, model): 46 | """ 47 | Args: 48 | cur_batch: treat a video as a batch of images 49 | model: ResNet model for feature extraction 50 | Returns: 51 | ResNet extracted feature. 52 | """ 53 | mean = np.array([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1) 54 | std = np.array([0.229, 0.224, 0.224]).reshape(1, 3, 1, 1) 55 | 56 | image_batch = np.concatenate(cur_batch, 0).astype(np.float32) 57 | image_batch = (image_batch / 255.0 - mean) / std 58 | image_batch = torch.FloatTensor(image_batch).cuda() 59 | with torch.no_grad(): 60 | image_batch = torch.autograd.Variable(image_batch) 61 | 62 | feats = model(image_batch) 63 | feats = feats.data.cpu().clone().numpy() 64 | 65 | return feats 66 | 67 | 68 | def extract_clips_with_consecutive_frames(path, num_clips, num_frames_per_clip): 69 | """ 70 | Args: 71 | path: path of a video 72 | num_clips: expected numbers of splitted clips 73 | num_frames_per_clip: number of frames in a single clip, pretrained model only supports 16 frames 74 | Returns: 75 | A list of raw features of clips. 
76 | """ 77 | valid = True 78 | clips = list() 79 | try: 80 | video_data = skvideo.io.vread(path) 81 | except Exception: 82 | print('file {} error'.format(path)) 83 | valid = False 84 | if args.model == 'resnext101': 85 | return list(np.zeros(shape=(num_clips, 3, num_frames_per_clip, 112, 112))), valid 86 | else: 87 | return list(np.zeros(shape=(num_clips, num_frames_per_clip, 3, 224, 224))), valid 88 | total_frames = video_data.shape[0] 89 | img_size = (args.image_height, args.image_width) 90 | for i in np.linspace(0, total_frames, num_clips + 2, dtype=np.int32)[1:num_clips + 1]: 91 | clip_start = int(i) - int(num_frames_per_clip / 2) 92 | clip_end = int(i) + int(num_frames_per_clip / 2) 93 | if clip_start < 0: 94 | clip_start = 0 95 | if clip_end > total_frames: 96 | clip_end = total_frames - 1 97 | clip = video_data[clip_start:clip_end] 98 | if clip_start == 0: 99 | shortage = num_frames_per_clip - (clip_end - clip_start) 100 | added_frames = [] 101 | for _ in range(shortage): 102 | added_frames.append(np.expand_dims(video_data[clip_start], axis=0)) 103 | if len(added_frames) > 0: 104 | added_frames = np.concatenate(added_frames, axis=0) 105 | clip = np.concatenate((added_frames, clip), axis=0) 106 | if clip_end == (total_frames - 1): 107 | shortage = num_frames_per_clip - (clip_end - clip_start) 108 | added_frames = [] 109 | for _ in range(shortage): 110 | added_frames.append(np.expand_dims(video_data[clip_end], axis=0)) 111 | if len(added_frames) > 0: 112 | added_frames = np.concatenate(added_frames, axis=0) 113 | clip = np.concatenate((clip, added_frames), axis=0) 114 | new_clip = [] 115 | for j in range(num_frames_per_clip): 116 | frame_data = clip[j] 117 | img = Image.fromarray(frame_data) 118 | img = imresize(img, img_size, interp='bicubic') 119 | img = img.transpose(2, 0, 1)[None] 120 | frame_data = np.array(img) 121 | new_clip.append(frame_data) 122 | new_clip = np.asarray(new_clip) # (num_frames_per_clip, 1, channels, height, width) 123 | if args.model in ['resnext101']: 124 | new_clip = np.squeeze(new_clip) 125 | new_clip = np.transpose(new_clip, axes=(1, 0, 2, 3)) 126 | clips.append(new_clip) 127 | return clips, valid 128 | 129 | 130 | def generate_h5(model, video_ids, num_clips, outfile): 131 | """ 132 | Args: 133 | model: loaded pretrained model for feature extraction 134 | video_ids: list of (video_path, video_id) pairs 135 | num_clips: expected number of split clips 136 | outfile: path of output file to be written 137 | Returns: 138 | h5 file containing visual features of the split clips. 
139 | """ 140 | if args.dataset == "tgif-qa": 141 | if not os.path.exists('data/tgif-qa/{}'.format(args.question_type)): 142 | os.makedirs('data/tgif-qa/{}'.format(args.question_type)) 143 | else: 144 | if not os.path.exists('data/{}'.format(args.dataset)): 145 | os.makedirs('data/{}'.format(args.dataset)) 146 | 147 | dataset_size = len(video_ids) 148 | 149 | with h5py.File(outfile, 'w') as fd: 150 | feat_dset = None 151 | video_ids_dset = None 152 | i0 = 0 153 | _t = {'misc': utils.Timer()} 154 | for i, (video_path, video_id) in enumerate(video_ids): 155 | _t['misc'].tic() 156 | clips, valid = extract_clips_with_consecutive_frames(video_path, num_clips=num_clips, num_frames_per_clip=16) 157 | if args.feature_type == 'appearance': 158 | clip_feat = [] 159 | if valid: 160 | for clip_id, clip in enumerate(clips): 161 | feats = run_batch(clip, model) # (16, 2048) 162 | feats = feats.squeeze() 163 | clip_feat.append(feats) 164 | else: 165 | clip_feat = np.zeros(shape=(num_clips, 16, 2048)) 166 | clip_feat = np.asarray(clip_feat) # (8, 16, 2048) 167 | if feat_dset is None: 168 | C, F, D = clip_feat.shape 169 | feat_dset = fd.create_dataset('resnet_features', (dataset_size, C, F, D), 170 | dtype=np.float32) 171 | video_ids_dset = fd.create_dataset('ids', shape=(dataset_size,), dtype=np.int) 172 | elif args.feature_type == 'motion': 173 | clip_torch = torch.FloatTensor(np.asarray(clips)).cuda() 174 | if valid: 175 | clip_feat = model(clip_torch) # (8, 2048) 176 | clip_feat = clip_feat.squeeze() 177 | clip_feat = clip_feat.detach().cpu().numpy() 178 | else: 179 | clip_feat = np.zeros(shape=(num_clips, 2048)) 180 | if feat_dset is None: 181 | C, D = clip_feat.shape 182 | feat_dset = fd.create_dataset('resnext_features', (dataset_size, C, D), 183 | dtype=np.float32) 184 | video_ids_dset = fd.create_dataset('ids', shape=(dataset_size,), dtype=np.int) 185 | 186 | i1 = i0 + 1 187 | feat_dset[i0:i1] = clip_feat 188 | video_ids_dset[i0:i1] = video_id 189 | i0 = i1 190 | _t['misc'].toc() 191 | if (i % 1000 == 0): 192 | print('{:d}/{:d} {:.3f}s (projected finish: {:.2f} hours)' \ 193 | .format(i1, dataset_size, _t['misc'].average_time, 194 | _t['misc'].average_time * (dataset_size - i1) / 3600)) 195 | 196 | 197 | if __name__ == '__main__': 198 | parser = argparse.ArgumentParser() 199 | parser.add_argument('--gpu_id', type=int, default=2, help='specify which gpu will be used') 200 | # dataset info 201 | parser.add_argument('--dataset', default='tgif-qa', choices=['tgif-qa', 'msvd-qa', 'msrvtt-qa'], type=str) 202 | parser.add_argument('--question_type', default='none', choices=['frameqa', 'count', 'transition', 'action', 'none'], type=str) 203 | # output 204 | parser.add_argument('--out', dest='outfile', 205 | help='output filepath', 206 | default="data/{}/{}_{}_feat.h5", type=str) 207 | # image sizes 208 | parser.add_argument('--num_clips', default=8, type=int) 209 | parser.add_argument('--image_height', default=224, type=int) 210 | parser.add_argument('--image_width', default=224, type=int) 211 | 212 | # network params 213 | parser.add_argument('--model', default='resnet101', choices=['resnet101', 'resnext101'], type=str) 214 | parser.add_argument('--seed', default='666', type=int, help='random seed') 215 | args = parser.parse_args() 216 | if args.model == 'resnet101': 217 | args.feature_type = 'appearance' 218 | elif args.model == 'resnext101': 219 | args.feature_type = 'motion' 220 | else: 221 | raise Exception('Feature type not supported!') 222 | # set gpu 223 | if args.model != 'resnext101': 224 | 
torch.cuda.set_device(args.gpu_id) 225 | torch.manual_seed(args.seed) 226 | np.random.seed(args.seed) 227 | 228 | # annotation files 229 | if args.dataset == 'tgif-qa': 230 | args.annotation_file = '/ceph-g/lethao/datasets/tgif-qa/csv/Total_{}_question.csv' 231 | args.video_dir = '/ceph-g/lethao/datasets/tgif-qa/gifs' 232 | args.outfile = 'data/{}/{}/{}_{}_{}_feat.h5' 233 | video_paths = tgif_qa.load_video_paths(args) 234 | random.shuffle(video_paths) 235 | # load model 236 | if args.model == 'resnet101': 237 | model = build_resnet() 238 | elif args.model == 'resnext101': 239 | model = build_resnext() 240 | generate_h5(model, video_paths, args.num_clips, 241 | args.outfile.format(args.dataset, args.question_type, args.dataset, args.question_type, args.feature_type)) 242 | elif args.dataset == 'msrvtt-qa': 243 | args.annotation_file = '/ceph-g/lethao/datasets/msrvtt/annotations/{}_qa.json' 244 | args.video_dir = '/ceph-g/lethao/datasets/msrvtt/videos/' 245 | video_paths = msrvtt_qa.load_video_paths(args) 246 | random.shuffle(video_paths) 247 | # load model 248 | if args.model == 'resnet101': 249 | model = build_resnet() 250 | elif args.model == 'resnext101': 251 | model = build_resnext() 252 | generate_h5(model, video_paths, args.num_clips, 253 | args.outfile.format(args.dataset, args.dataset, args.feature_type)) 254 | 255 | elif args.dataset == 'msvd-qa': 256 | args.annotation_file = '/ceph-g/lethao/datasets/msvd/MSVD-QA/{}_qa.json' 257 | args.video_dir = '/ceph-g/lethao/datasets/msvd/MSVD-QA/video/' 258 | args.video_name_mapping = '/ceph-g/lethao/datasets/msvd/youtube_mapping.txt' 259 | video_paths = msvd_qa.load_video_paths(args) 260 | random.shuffle(video_paths) 261 | # load model 262 | if args.model == 'resnet101': 263 | model = build_resnet() 264 | elif args.model == 'resnext101': 265 | model = build_resnext() 266 | generate_h5(model, video_paths, args.num_clips, 267 | args.outfile.format(args.dataset, args.dataset, args.feature_type)) 268 | -------------------------------------------------------------------------------- /preprocess/preprocess_questions.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import os 4 | 5 | from datautils import tgif_qa 6 | from datautils import msrvtt_qa 7 | from datautils import msvd_qa 8 | 9 | if __name__ == '__main__': 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument('--dataset', default='tgif-qa', choices=['tgif-qa', 'msrvtt-qa', 'msvd-qa'], type=str) 12 | parser.add_argument('--answer_top', default=4000, type=int) 13 | parser.add_argument('--glove_pt', 14 | help='glove pickle file, should be a map whose key are words and value are word vectors represented by numpy arrays. 
Only needed in train mode') 15 | parser.add_argument('--output_pt', type=str, default='data/{}/{}_{}_questions.pt') 16 | parser.add_argument('--vocab_json', type=str, default='data/{}/{}_vocab.json') 17 | parser.add_argument('--mode', choices=['train', 'val', 'test']) 18 | parser.add_argument('--question_type', choices=['frameqa', 'action', 'transition', 'count', 'none'], default='none') 19 | parser.add_argument('--seed', type=int, default=666) 20 | 21 | args = parser.parse_args() 22 | np.random.seed(args.seed) 23 | 24 | if args.dataset == 'tgif-qa': 25 | args.annotation_file = '/ceph-g/lethao/datasets/tgif-qa/csv/{}_{}_question.csv' 26 | args.output_pt = 'data/tgif-qa/{}/tgif-qa_{}_{}_questions.pt' 27 | args.vocab_json = 'data/tgif-qa/{}/tgif-qa_{}_vocab.json' 28 | # check if data folder exists 29 | if not os.path.exists('data/tgif-qa/{}'.format(args.question_type)): 30 | os.makedirs('data/tgif-qa/{}'.format(args.question_type)) 31 | 32 | if args.question_type in ['frameqa', 'count']: 33 | tgif_qa.process_questions_openended(args) 34 | else: 35 | tgif_qa.process_questions_mulchoices(args) 36 | elif args.dataset == 'msrvtt-qa': 37 | args.annotation_file = '/ceph-g/lethao/datasets/msrvtt/annotations/{}_qa.json'.format(args.mode) 38 | # check if data folder exists 39 | if not os.path.exists('data/{}'.format(args.dataset)): 40 | os.makedirs('data/{}'.format(args.dataset)) 41 | msrvtt_qa.process_questions(args) 42 | elif args.dataset == 'msvd-qa': 43 | args.annotation_file = '/ceph-g/lethao/datasets/msvd/MSVD-QA/{}_qa.json'.format(args.mode) 44 | # check if data folder exists 45 | if not os.path.exists('data/{}'.format(args.dataset)): 46 | os.makedirs('data/{}'.format(args.dataset)) 47 | msvd_qa.process_questions(args) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict 2 | h5py 3 | nltk 4 | numpy 5 | Pillow 6 | torch==1.2.0 7 | torchvision==0.2.1 8 | tqdm 9 | termcolor 10 | pyyaml -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn as nn 6 | import numpy as np 7 | import argparse 8 | import time 9 | import logging 10 | from termcolor import colored 11 | 12 | logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s') 13 | logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s') 14 | rootLogger = logging.getLogger() 15 | 16 | from DataLoader import VideoQADataLoader 17 | from utils import todevice 18 | from validate import validate 19 | 20 | import model.HCRN as HCRN 21 | 22 | from utils import todevice 23 | 24 | from config import cfg, cfg_from_file 25 | 26 | 27 | def train(cfg): 28 | logging.info("Create train_loader and val_loader.........") 29 | train_loader_kwargs = { 30 | 'question_type': cfg.dataset.question_type, 31 | 'question_pt': cfg.dataset.train_question_pt, 32 | 'vocab_json': cfg.dataset.vocab_json, 33 | 'appearance_feat': cfg.dataset.appearance_feat, 34 | 'motion_feat': cfg.dataset.motion_feat, 35 | 'train_num': cfg.train.train_num, 36 | 'batch_size': cfg.train.batch_size, 37 | 'num_workers': cfg.num_workers, 38 | 'shuffle': True 39 | } 40 | train_loader = VideoQADataLoader(**train_loader_kwargs) 41 | logging.info("number of train instances: {}".format(len(train_loader.dataset))) 42 
| if cfg.val.flag: 43 | val_loader_kwargs = { 44 | 'question_type': cfg.dataset.question_type, 45 | 'question_pt': cfg.dataset.val_question_pt, 46 | 'vocab_json': cfg.dataset.vocab_json, 47 | 'appearance_feat': cfg.dataset.appearance_feat, 48 | 'motion_feat': cfg.dataset.motion_feat, 49 | 'val_num': cfg.val.val_num, 50 | 'batch_size': cfg.train.batch_size, 51 | 'num_workers': cfg.num_workers, 52 | 'shuffle': False 53 | } 54 | val_loader = VideoQADataLoader(**val_loader_kwargs) 55 | logging.info("number of val instances: {}".format(len(val_loader.dataset))) 56 | 57 | logging.info("Create model.........") 58 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 59 | model_kwargs = { 60 | 'vision_dim': cfg.train.vision_dim, 61 | 'module_dim': cfg.train.module_dim, 62 | 'word_dim': cfg.train.word_dim, 63 | 'k_max_frame_level': cfg.train.k_max_frame_level, 64 | 'k_max_clip_level': cfg.train.k_max_clip_level, 65 | 'spl_resolution': cfg.train.spl_resolution, 66 | 'vocab': train_loader.vocab, 67 | 'question_type': cfg.dataset.question_type 68 | } 69 | model_kwargs_tosave = {k: v for k, v in model_kwargs.items() if k != 'vocab'} 70 | model = HCRN.HCRNNetwork(**model_kwargs).to(device) 71 | pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 72 | logging.info('num of params: {}'.format(pytorch_total_params)) 73 | logging.info(model) 74 | 75 | if cfg.train.glove: 76 | logging.info('load glove vectors') 77 | train_loader.glove_matrix = torch.FloatTensor(train_loader.glove_matrix).to(device) 78 | with torch.no_grad(): 79 | model.linguistic_input_unit.encoder_embed.weight.set_(train_loader.glove_matrix) 80 | if torch.cuda.device_count() > 1 and cfg.multi_gpus: 81 | model = model.cuda() 82 | logging.info("Using {} GPUs".format(torch.cuda.device_count())) 83 | model = nn.DataParallel(model, device_ids=None) 84 | 85 | optimizer = optim.Adam(model.parameters(), cfg.train.lr) 86 | 87 | start_epoch = 0 88 | if cfg.dataset.question_type == 'count': 89 | best_val = 100.0 90 | else: 91 | best_val = 0 92 | if cfg.train.restore: 93 | print("Restore checkpoint and optimizer...") 94 | ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt') 95 | ckpt = torch.load(ckpt, map_location=lambda storage, loc: storage) 96 | start_epoch = ckpt['epoch'] + 1 97 | model.load_state_dict(ckpt['state_dict']) 98 | optimizer.load_state_dict(ckpt['optimizer']) 99 | if cfg.dataset.question_type in ['frameqa', 'none']: 100 | criterion = nn.CrossEntropyLoss().to(device) 101 | elif cfg.dataset.question_type == 'count': 102 | criterion = nn.MSELoss().to(device) 103 | logging.info("Start training........") 104 | for epoch in range(start_epoch, cfg.train.max_epochs): 105 | logging.info('>>>>>> epoch {epoch} <<<<<<'.format(epoch=colored("{}".format(epoch), "green", attrs=["bold"]))) 106 | model.train() 107 | total_acc, count = 0, 0 108 | batch_mse_sum = 0.0 109 | total_loss, avg_loss = 0.0, 0.0 110 | avg_loss = 0 111 | train_accuracy = 0 112 | for i, batch in enumerate(iter(train_loader)): 113 | progress = epoch + i / len(train_loader) 114 | _, _, answers, *batch_input = [todevice(x, device) for x in batch] 115 | answers = answers.cuda().squeeze() 116 | batch_size = answers.size(0) 117 | optimizer.zero_grad() 118 | logits = model(*batch_input) 119 | if cfg.dataset.question_type in ['action', 'transition']: 120 | batch_agg = np.concatenate(np.tile(np.arange(batch_size).reshape([batch_size, 1]), 121 | [1, 5])) * 5 # [0, 0, 0, 0, 0, 5, 5, 5, 5, 1, ...] 
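# Note: for 'action'/'transition' questions the model scores 5 answer candidates per question, so
# logits is a flat (batch_size * 5,) vector. batch_agg holds each question's offset into that vector
# (five 0s, then five 5s, then five 10s, ...); answers_agg below repeats each ground-truth index 5
# times, so logits[answers_agg + batch_agg] is the ground-truth candidate's logit for every row of
# the pairwise hinge loss max(0, 1 + logit_candidate - logit_groundtruth).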
122 | answers_agg = tile(answers, 0, 5) 123 | loss = torch.max(torch.tensor(0.0).cuda(), 124 | 1.0 + logits - logits[answers_agg + torch.from_numpy(batch_agg).cuda()]) 125 | loss = loss.sum() 126 | loss.backward() 127 | total_loss += loss.detach() 128 | avg_loss = total_loss / (i + 1) 129 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=12) 130 | optimizer.step() 131 | preds = torch.argmax(logits.view(batch_size, 5), dim=1) 132 | aggreeings = (preds == answers) 133 | elif cfg.dataset.question_type == 'count': 134 | answers = answers.unsqueeze(-1) 135 | loss = criterion(logits, answers.float()) 136 | loss.backward() 137 | total_loss += loss.detach() 138 | avg_loss = total_loss / (i + 1) 139 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=12) 140 | optimizer.step() 141 | preds = (logits + 0.5).long().clamp(min=1, max=10) 142 | batch_mse = (preds - answers) ** 2 143 | else: 144 | loss = criterion(logits, answers) 145 | loss.backward() 146 | total_loss += loss.detach() 147 | avg_loss = total_loss / (i + 1) 148 | nn.utils.clip_grad_norm_(model.parameters(), max_norm=12) 149 | optimizer.step() 150 | aggreeings = batch_accuracy(logits, answers) 151 | 152 | if cfg.dataset.question_type == 'count': 153 | batch_avg_mse = batch_mse.sum().item() / answers.size(0) 154 | batch_mse_sum += batch_mse.sum().item() 155 | count += answers.size(0) 156 | avg_mse = batch_mse_sum / count 157 | sys.stdout.write( 158 | "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_mse = {train_mse} avg_mse = {avg_mse} exp: {exp_name}".format( 159 | progress=colored("{:.3f}".format(progress), "green", attrs=['bold']), 160 | ce_loss=colored("{:.4f}".format(loss.item()), "blue", attrs=['bold']), 161 | avg_loss=colored("{:.4f}".format(avg_loss), "red", attrs=['bold']), 162 | train_mse=colored("{:.4f}".format(batch_avg_mse), "blue", 163 | attrs=['bold']), 164 | avg_mse=colored("{:.4f}".format(avg_mse), "red", attrs=['bold']), 165 | exp_name=cfg.exp_name)) 166 | sys.stdout.flush() 167 | else: 168 | total_acc += aggreeings.sum().item() 169 | count += answers.size(0) 170 | train_accuracy = total_acc / count 171 | sys.stdout.write( 172 | "\rProgress = {progress} ce_loss = {ce_loss} avg_loss = {avg_loss} train_acc = {train_acc} avg_acc = {avg_acc} exp: {exp_name}".format( 173 | progress=colored("{:.3f}".format(progress), "green", attrs=['bold']), 174 | ce_loss=colored("{:.4f}".format(loss.item()), "blue", attrs=['bold']), 175 | avg_loss=colored("{:.4f}".format(avg_loss), "red", attrs=['bold']), 176 | train_acc=colored("{:.4f}".format(aggreeings.float().mean().cpu().numpy()), "blue", 177 | attrs=['bold']), 178 | avg_acc=colored("{:.4f}".format(train_accuracy), "red", attrs=['bold']), 179 | exp_name=cfg.exp_name)) 180 | sys.stdout.flush() 181 | sys.stdout.write("\n") 182 | if cfg.dataset.question_type == 'count': 183 | if (epoch + 1) % 5 == 0: 184 | optimizer = step_decay(cfg, optimizer) 185 | else: 186 | if (epoch + 1) % 10 == 0: 187 | optimizer = step_decay(cfg, optimizer) 188 | sys.stdout.flush() 189 | logging.info("Epoch = %s avg_loss = %.3f avg_acc = %.3f" % (epoch, avg_loss, train_accuracy)) 190 | 191 | if cfg.val.flag: 192 | output_dir = os.path.join(cfg.dataset.save_dir, 'preds') 193 | if not os.path.exists(output_dir): 194 | os.makedirs(output_dir) 195 | else: 196 | assert os.path.isdir(output_dir) 197 | valid_acc = validate(cfg, model, val_loader, device, write_preds=False) 198 | if (valid_acc > best_val and cfg.dataset.question_type != 'count') or (valid_acc < best_val and 
cfg.dataset.question_type == 'count'): 199 | best_val = valid_acc 200 | # Save best model 201 | ckpt_dir = os.path.join(cfg.dataset.save_dir, 'ckpt') 202 | if not os.path.exists(ckpt_dir): 203 | os.makedirs(ckpt_dir) 204 | else: 205 | assert os.path.isdir(ckpt_dir) 206 | save_checkpoint(epoch, model, optimizer, model_kwargs_tosave, os.path.join(ckpt_dir, 'model.pt')) 207 | sys.stdout.write('\n >>>>>> save to %s <<<<<< \n' % (ckpt_dir)) 208 | sys.stdout.flush() 209 | 210 | logging.info('~~~~~~ Valid Accuracy: %.4f ~~~~~~~' % valid_acc) 211 | sys.stdout.write('~~~~~~ Valid Accuracy: {valid_acc} ~~~~~~~\n'.format( 212 | valid_acc=colored("{:.4f}".format(valid_acc), "red", attrs=['bold']))) 213 | sys.stdout.flush() 214 | 215 | # Credit https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/4 216 | def tile(a, dim, n_tile): 217 | init_dim = a.size(dim) 218 | repeat_idx = [1] * a.dim() 219 | repeat_idx[dim] = n_tile 220 | a = a.repeat(*(repeat_idx)) 221 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).cuda() 222 | return torch.index_select(a, dim, order_index) 223 | 224 | 225 | def step_decay(cfg, optimizer): 226 | # compute the new learning rate based on decay rate 227 | cfg.train.lr *= 0.5 228 | logging.info("Reduced learning rate to {}".format(cfg.train.lr)) 229 | sys.stdout.flush() 230 | for param_group in optimizer.param_groups: 231 | param_group['lr'] = cfg.train.lr 232 | 233 | return optimizer 234 | 235 | 236 | def batch_accuracy(predicted, true): 237 | """ Compute the accuracies for a batch of predictions and answers """ 238 | predicted = predicted.detach().argmax(1) 239 | agreeing = (predicted == true) 240 | return agreeing 241 | 242 | 243 | def save_checkpoint(epoch, model, optimizer, model_kwargs, filename): 244 | state = { 245 | 'epoch': epoch, 246 | 'state_dict': model.state_dict(), 247 | 'optimizer': optimizer.state_dict(), 248 | 'model_kwargs': model_kwargs, 249 | } 250 | time.sleep(10) 251 | torch.save(state, filename) 252 | 253 | 254 | def main(): 255 | parser = argparse.ArgumentParser() 256 | parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='tgif_qa_action.yml', type=str) 257 | args = parser.parse_args() 258 | if args.cfg_file is not None: 259 | cfg_from_file(args.cfg_file) 260 | 261 | assert cfg.dataset.name in ['tgif-qa', 'msrvtt-qa', 'msvd-qa'] 262 | assert cfg.dataset.question_type in ['frameqa', 'count', 'transition', 'action', 'none'] 263 | # check if the data folder exists 264 | assert os.path.exists(cfg.dataset.data_dir) 265 | # check if k_max is set correctly 266 | assert cfg.train.k_max_frame_level <= 16 267 | assert cfg.train.k_max_clip_level <= 8 268 | 269 | 270 | if not cfg.multi_gpus: 271 | torch.cuda.set_device(cfg.gpu_id) 272 | # make logging.info display into both shell and file 273 | cfg.dataset.save_dir = os.path.join(cfg.dataset.save_dir, cfg.exp_name) 274 | if not os.path.exists(cfg.dataset.save_dir): 275 | os.makedirs(cfg.dataset.save_dir) 276 | else: 277 | assert os.path.isdir(cfg.dataset.save_dir) 278 | log_file = os.path.join(cfg.dataset.save_dir, "log") 279 | if not cfg.train.restore and not os.path.exists(log_file): 280 | os.mkdir(log_file) 281 | else: 282 | assert os.path.isdir(log_file) 283 | 284 | fileHandler = logging.FileHandler(os.path.join(log_file, 'stdout.log'), 'w+') 285 | fileHandler.setFormatter(logFormatter) 286 | rootLogger.addHandler(fileHandler) 287 | # args display 288 | for k, v in vars(cfg).items(): 289 | logging.info(k + ':' + str(v)) 
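# Note: the cfg.dataset.* entries handled below are filename templates from the yml config that get
# formatted with the dataset name (and, for TGIF-QA, the question type) and joined with data_dir;
# they are presumably expected to match the files written by preprocess_questions.py and
# preprocess_features.py (e.g. data/tgif-qa/action/tgif-qa_action_vocab.json).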
290 | # concat absolute path of input files 291 | 292 | if cfg.dataset.name == 'tgif-qa': 293 | cfg.dataset.train_question_pt = os.path.join(cfg.dataset.data_dir, 294 | cfg.dataset.train_question_pt.format(cfg.dataset.name, cfg.dataset.question_type)) 295 | cfg.dataset.val_question_pt = os.path.join(cfg.dataset.data_dir, 296 | cfg.dataset.val_question_pt.format(cfg.dataset.name, cfg.dataset.question_type)) 297 | cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name, cfg.dataset.question_type)) 298 | 299 | cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name, cfg.dataset.question_type)) 300 | cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name, cfg.dataset.question_type)) 301 | else: 302 | cfg.dataset.question_type = 'none' 303 | cfg.dataset.appearance_feat = '{}_appearance_feat.h5' 304 | cfg.dataset.motion_feat = '{}_motion_feat.h5' 305 | cfg.dataset.vocab_json = '{}_vocab.json' 306 | cfg.dataset.train_question_pt = '{}_train_questions.pt' 307 | cfg.dataset.val_question_pt = '{}_val_questions.pt' 308 | cfg.dataset.train_question_pt = os.path.join(cfg.dataset.data_dir, 309 | cfg.dataset.train_question_pt.format(cfg.dataset.name)) 310 | cfg.dataset.val_question_pt = os.path.join(cfg.dataset.data_dir, 311 | cfg.dataset.val_question_pt.format(cfg.dataset.name)) 312 | cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name)) 313 | 314 | cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name)) 315 | cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name)) 316 | 317 | # set random seed 318 | torch.manual_seed(cfg.seed) 319 | np.random.seed(cfg.seed) 320 | if torch.cuda.is_available(): 321 | torch.cuda.manual_seed_all(cfg.seed) 322 | 323 | train(cfg) 324 | 325 | 326 | if __name__ == '__main__': 327 | main() 328 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def todevice(tensor, device): 4 | if isinstance(tensor, list) or isinstance(tensor, tuple): 5 | assert isinstance(tensor[0], torch.Tensor) 6 | return [todevice(t, device) for t in tensor] 7 | elif isinstance(tensor, torch.Tensor): 8 | return tensor.to(device) -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from tqdm import tqdm 4 | import argparse 5 | import os, sys 6 | import json 7 | import pickle 8 | from termcolor import colored 9 | 10 | from DataLoader import VideoQADataLoader 11 | from utils import todevice 12 | 13 | import model.HCRN as HCRN 14 | 15 | from config import cfg, cfg_from_file 16 | 17 | 18 | def validate(cfg, model, data, device, write_preds=False): 19 | model.eval() 20 | print('validating...') 21 | total_acc, count = 0.0, 0 22 | all_preds = [] 23 | gts = [] 24 | v_ids = [] 25 | q_ids = [] 26 | with torch.no_grad(): 27 | for batch in tqdm(data, total=len(data)): 28 | video_ids, question_ids, answers, *batch_input = [todevice(x, device) for x in batch] 29 | if cfg.train.batch_size == 1: 30 | answers = answers.to(device) 31 | else: 32 | answers = answers.to(device).squeeze() 
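# Note: when cfg.train.batch_size == 1, squeeze() would also drop the batch dimension and break
# answers.size(0) below, which is presumably why the un-squeezed branch is kept above.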
33 | batch_size = answers.size(0) 34 | logits = model(*batch_input).to(device) 35 | if cfg.dataset.question_type in ['action', 'transition']: 36 | preds = torch.argmax(logits.view(batch_size, 5), dim=1) 37 | agreeings = (preds == answers) 38 | elif cfg.dataset.question_type == 'count': 39 | answers = answers.unsqueeze(-1) 40 | preds = (logits + 0.5).long().clamp(min=1, max=10) 41 | batch_mse = (preds - answers) ** 2 42 | else: 43 | preds = logits.detach().argmax(1) 44 | agreeings = (preds == answers) 45 | if write_preds: 46 | if cfg.dataset.question_type not in ['action', 'transition', 'count']: 47 | preds = logits.argmax(1) 48 | if cfg.dataset.question_type in ['action', 'transition']: 49 | answer_vocab = data.vocab['question_answer_idx_to_token'] 50 | else: 51 | answer_vocab = data.vocab['answer_idx_to_token'] 52 | for predict in preds: 53 | if cfg.dataset.question_type in ['count', 'transition', 'action']: 54 | all_preds.append(predict.item()) 55 | else: 56 | all_preds.append(answer_vocab[predict.item()]) 57 | for gt in answers: 58 | if cfg.dataset.question_type in ['count', 'transition', 'action']: 59 | gts.append(gt.item()) 60 | else: 61 | gts.append(answer_vocab[gt.item()]) 62 | for id in video_ids: 63 | v_ids.append(id.cpu().numpy()) 64 | for ques_id in question_ids: 65 | q_ids.append(ques_id) 66 | 67 | if cfg.dataset.question_type == 'count': 68 | total_acc += batch_mse.float().sum().item() 69 | count += answers.size(0) 70 | else: 71 | total_acc += agreeings.float().sum().item() 72 | count += answers.size(0) 73 | acc = total_acc / count 74 | if not write_preds: 75 | return acc 76 | else: 77 | return acc, all_preds, gts, v_ids, q_ids 78 | 79 | 80 | if __name__ == '__main__': 81 | parser = argparse.ArgumentParser() 82 | parser.add_argument('--cfg', dest='cfg_file', help='optional config file', default='tgif_qa_action.yml', type=str) 83 | args = parser.parse_args() 84 | if args.cfg_file is not None: 85 | cfg_from_file(args.cfg_file) 86 | 87 | assert cfg.dataset.name in ['tgif-qa', 'msrvtt-qa', 'msvd-qa'] 88 | assert cfg.dataset.question_type in ['frameqa', 'count', 'transition', 'action', 'none'] 89 | # check if the data folder exists 90 | assert os.path.exists(cfg.dataset.data_dir) 91 | 92 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 93 | cfg.dataset.save_dir = os.path.join(cfg.dataset.save_dir, cfg.exp_name) 94 | ckpt = os.path.join(cfg.dataset.save_dir, 'ckpt', 'model.pt') 95 | assert os.path.exists(ckpt) 96 | # load pretrained model 97 | loaded = torch.load(ckpt, map_location='cpu') 98 | model_kwargs = loaded['model_kwargs'] 99 | 100 | if cfg.dataset.name == 'tgif-qa': 101 | cfg.dataset.test_question_pt = os.path.join(cfg.dataset.data_dir, 102 | cfg.dataset.test_question_pt.format(cfg.dataset.name, cfg.dataset.question_type)) 103 | cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name, cfg.dataset.question_type)) 104 | 105 | cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name, cfg.dataset.question_type)) 106 | cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name, cfg.dataset.question_type)) 107 | else: 108 | cfg.dataset.question_type = 'none' 109 | cfg.dataset.appearance_feat = '{}_appearance_feat.h5' 110 | cfg.dataset.motion_feat = '{}_motion_feat.h5' 111 | cfg.dataset.vocab_json = '{}_vocab.json' 112 | cfg.dataset.test_question_pt = '{}_test_questions.pt' 113 | 114 | cfg.dataset.test_question_pt 
= os.path.join(cfg.dataset.data_dir, 115 | cfg.dataset.test_question_pt.format(cfg.dataset.name)) 116 | cfg.dataset.vocab_json = os.path.join(cfg.dataset.data_dir, cfg.dataset.vocab_json.format(cfg.dataset.name)) 117 | 118 | cfg.dataset.appearance_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.appearance_feat.format(cfg.dataset.name)) 119 | cfg.dataset.motion_feat = os.path.join(cfg.dataset.data_dir, cfg.dataset.motion_feat.format(cfg.dataset.name)) 120 | 121 | test_loader_kwargs = { 122 | 'question_type': cfg.dataset.question_type, 123 | 'question_pt': cfg.dataset.test_question_pt, 124 | 'vocab_json': cfg.dataset.vocab_json, 125 | 'appearance_feat': cfg.dataset.appearance_feat, 126 | 'motion_feat': cfg.dataset.motion_feat, 127 | 'test_num': cfg.test.test_num, 128 | 'batch_size': cfg.train.batch_size, 129 | 'num_workers': cfg.num_workers, 130 | 'shuffle': False 131 | } 132 | test_loader = VideoQADataLoader(**test_loader_kwargs) 133 | model_kwargs.update({'vocab': test_loader.vocab}) 134 | model = HCRN.HCRNNetwork(**model_kwargs).to(device) 135 | model.load_state_dict(loaded['state_dict']) 136 | 137 | if cfg.test.write_preds: 138 | acc, preds, gts, v_ids, q_ids = validate(cfg, model, test_loader, device, cfg.test.write_preds) 139 | 140 | sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format( 141 | test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold']))) 142 | sys.stdout.flush() 143 | 144 | # write predictions for visualization purposes 145 | output_dir = os.path.join(cfg.dataset.save_dir, 'preds') 146 | if not os.path.exists(output_dir): 147 | os.makedirs(output_dir) 148 | else: 149 | assert os.path.isdir(output_dir) 150 | preds_file = os.path.join(output_dir, "test_preds.json") 151 | 152 | if cfg.dataset.question_type in ['action', 'transition']: \ 153 | # Find groundtruth questions and corresponding answer candidates 154 | vocab = test_loader.vocab['question_answer_idx_to_token'] 155 | dict = {} 156 | with open(cfg.dataset.test_question_pt, 'rb') as f: 157 | obj = pickle.load(f) 158 | questions = obj['questions'] 159 | org_v_ids = obj['video_ids'] 160 | org_v_names = obj['video_names'] 161 | org_q_ids = obj['question_id'] 162 | ans_candidates = obj['ans_candidates'] 163 | 164 | for idx in range(len(org_q_ids)): 165 | dict[str(org_q_ids[idx])] = [org_v_names[idx], questions[idx], ans_candidates[idx]] 166 | instances = [ 167 | {'video_id': video_id, 'question_id': q_id, 'video_name': dict[str(q_id)][0], 'question': [vocab[word.item()] for word in dict[str(q_id)][1] if word != 0], 168 | 'answer': answer, 169 | 'prediction': pred} for video_id, q_id, answer, pred in 170 | zip(np.hstack(v_ids).tolist(), np.hstack(q_ids).tolist(), gts, preds)] 171 | # write preditions to json file 172 | with open(preds_file, 'w') as f: 173 | json.dump(instances, f) 174 | sys.stdout.write('Display 10 samples...\n') 175 | # Display 10 samples 176 | for idx in range(10): 177 | print('Video name: {}'.format(dict[str(q_ids[idx].item())][0])) 178 | cur_question = [vocab[word.item()] for word in dict[str(q_ids[idx].item())][1] if word != 0] 179 | print('Question: ' + ' '.join(cur_question) + '?') 180 | all_answer_cands = dict[str(q_ids[idx].item())][2] 181 | for cand_id in range(len(all_answer_cands)): 182 | cur_answer_cands = [vocab[word.item()] for word in all_answer_cands[cand_id] if word 183 | != 0] 184 | print('({}): '.format(cand_id) + ' '.join(cur_answer_cands)) 185 | print('Prediction: {}'.format(preds[idx])) 186 | print('Groundtruth: {}'.format(gts[idx])) 187 | else: 188 | vocab = 
test_loader.vocab['question_idx_to_token'] 189 | dict = {} 190 | with open(cfg.dataset.test_question_pt, 'rb') as f: 191 | obj = pickle.load(f) 192 | questions = obj['questions'] 193 | org_v_ids = obj['video_ids'] 194 | org_v_names = obj['video_names'] 195 | org_q_ids = obj['question_id'] 196 | 197 | for idx in range(len(org_q_ids)): 198 | dict[str(org_q_ids[idx])] = [org_v_names[idx], questions[idx]] 199 | instances = [ 200 | {'video_id': video_id, 'question_id': q_id, 'video_name': str(dict[str(q_id)][0]), 'question': [vocab[word.item()] for word in dict[str(q_id)][1] if word != 0], 201 | 'answer': answer, 202 | 'prediction': pred} for video_id, q_id, answer, pred in 203 | zip(np.hstack(v_ids).tolist(), np.hstack(q_ids).tolist(), gts, preds)] 204 | # write preditions to json file 205 | with open(preds_file, 'w') as f: 206 | json.dump(instances, f) 207 | sys.stdout.write('Display 10 samples...\n') 208 | # Display 10 examples 209 | for idx in range(10): 210 | print('Video name: {}'.format(dict[str(q_ids[idx].item())][0])) 211 | cur_question = [vocab[word.item()] for word in dict[str(q_ids[idx].item())][1] if word != 0] 212 | print('Question: ' + ' '.join(cur_question) + '?') 213 | print('Prediction: {}'.format(preds[idx])) 214 | print('Groundtruth: {}'.format(gts[idx])) 215 | else: 216 | acc = validate(cfg, model, test_loader, device, cfg.test.write_preds) 217 | sys.stdout.write('~~~~~~ Test Accuracy: {test_acc} ~~~~~~~\n'.format( 218 | test_acc=colored("{:.4f}".format(acc), "red", attrs=['bold']))) 219 | sys.stdout.flush() 220 | --------------------------------------------------------------------------------
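A minimal sanity-check sketch (not part of the repository itself): after running preprocess/preprocess_features.py, the generated HDF5 files can be inspected with h5py. The dataset names ('resnet_features', 'resnext_features', 'ids') and per-video shapes follow generate_h5(); the example paths below assume the default output pattern for the TGIF-QA 'action' split with the default --num_clips of 8, and may differ in other setups.

import h5py

# Appearance features from resnet101: one (num_clips, 16, 2048) block per video.
app_path = 'data/tgif-qa/action/tgif-qa_action_appearance_feat.h5'
# Motion features from resnext101: one (num_clips, 2048) block per video.
mot_path = 'data/tgif-qa/action/tgif-qa_action_motion_feat.h5'

with h5py.File(app_path, 'r') as f:
    print(f['resnet_features'].shape)  # expected (num_videos, 8, 16, 2048)
    print(f['ids'][:5])                # video ids, aligned row-for-row with the features

with h5py.File(mot_path, 'r') as f:
    print(f['resnext_features'].shape)  # expected (num_videos, 8, 2048)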