├── .gitignore ├── LICENSE ├── README.md ├── configs ├── __init__.py └── settings.py ├── data ├── __init__.py └── loader │ ├── __init__.py │ └── data_loader.py ├── eval.py ├── figures ├── Entity.png ├── HMN.png ├── MSRVTT-performance.png ├── MSVD-performance.png └── motivation.png ├── hmn.yaml ├── main.py ├── models ├── __init__.py ├── caption_models │ ├── __init__.py │ ├── caption_module.py │ └── hierarchical_model.py ├── decoder.py ├── encoders │ ├── __init__.py │ ├── entity_level.py │ ├── predicate_level.py │ ├── sentence_level.py │ └── transformer.py └── hungary.py ├── scripts └── split_data.py ├── test.py ├── train.py └── utils ├── __init__.py ├── build_loaders.py ├── build_model.py └── loss.py /.gitignore: -------------------------------------------------------------------------------- 1 | # => data dir 2 | data/MSRVTT/ 3 | data/MSVD/ 4 | 5 | # => checkpoints 6 | checkpoints/ 7 | 8 | # => results 9 | results/ 10 | 11 | # => pycache 12 | */__pycache__ 13 | __pycache__/ 14 | 15 | # => eval metrics 16 | utils/coco_caption 17 | 18 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 MarcusNerva 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # HMN 2 | [[Paper]](https://arxiv.org/abs/2111.12476) 3 | 4 | Official code for **Hierarchical Modular Network for Video Captioning**.
5 | 6 | *Hanhua Ye, Guorong Li, Yuankai Qi, Shuhui Wang, Qingming Huang, Ming-Hsuan Yang* 7 | 8 | Accepted by CVPR2022
9 | 10 | motivation 11 | 12 |
Figure 1. Motivation
13 | 14 | Representation learning plays a crucial role in the video captioning task. Hierarchical Modular Network learns a discriminative video representation by bridging video content and linguistic captions at three levels: 15 | 16 | 1. Entity level, which highlights the objects that are most likely to be mentioned in captions and is supervised by the *entities* in ground-truth captions. 17 | 2. Predicate level, which learns the actions conditioned on the highlighted objects and is supervised by the *predicate* in the ground-truth caption. 18 | 3. Sentence level, which learns the global video representation and is supervised by the whole ground-truth *sentence*. 19 | 20 | As there are a large number of objects in a video but only a few are mentioned in captions, we propose a novel entity module that learns to highlight these principal objects adaptively. Experimental results demonstrate that highlighting principal video objects improves performance significantly. 21 | 22 | 23 | 24 | ## Methodology 25 | 26 | As shown in Figure 2, our model follows the conventional **Encoder-Decoder** paradigm, where the proposed Hierarchical Modular Network (HMN) serves as the encoder. HMN consists of the entity, predicate, and sentence modules, which are designed to bridge video representations and linguistic semantics at three levels. Our model operates as follows. First, taking all detected objects as input, the entity module outputs the features of the principal objects. The predicate module then encodes actions by combining the features of these principal objects with the video motion. Next, the sentence module encodes a global representation of the entire video content, considering the global context as well as the previously generated object and action features. Finally, all features are concatenated and fed into the decoder to generate captions. Each module has its own input and its own linguistic supervision extracted from the captions; a minimal code sketch of this flow is given below Figure 2. 27 | 28 | HMN 29 | 30 | 
Figure 2. Hierarchical Modular Network
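
To make this flow concrete, here is a minimal sketch of how the three modules feed the decoder at each decoding step. It loosely paraphrases `forward_encoder`/`forward_decoder` in `models/caption_models/hierarchical_model.py`; the helper names (`encode_video`, `decode_step`) are illustrative only, and beam search, sampling, and the auxiliary supervision losses are omitted.

```python
# Illustrative sketch only -- see models/caption_models/hierarchical_model.py for the real code.
def encode_video(model, objects, objects_mask, feature2ds, feature3ds):
    # Entity level: highlight principal objects among all detected objects.
    objects_feats, objects_sem = model.entity_level(feature2ds, feature3ds, objects, objects_mask)
    # Predicate level: encode actions from motion features, conditioned on the principal objects.
    action_feats, action_sem = model.predicate_level(feature3ds, objects_feats, objects_mask)
    # Sentence level: encode a global video representation from context, actions and objects.
    video_feats, video_sem = model.sentence_level(feature2ds, action_feats, objects_feats, objects_mask)
    return (objects_feats, action_feats, video_feats), (objects_sem, action_sem, video_sem)

def decode_step(model, feats, semantics, prev_word_embed, prev_state):
    # The decoder attends over all three feature streams (plus their semantic embeddings)
    # and returns log-probabilities over the vocabulary for the next word.
    log_probs, state = model.decoder(*feats, *semantics, prev_word_embed, prev_state)
    return log_probs, state
```

Each module is additionally trained against its own linguistic supervision (entity words, the predicate, and the whole sentence), which is what the `*_sem` outputs above are matched to during training.
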
31 | 32 | Figure 3 illustrates the main architecture of our entity module, which consists of a transformer encoder and transformer decoder. This design is motivated by [DETR](https://arxiv.org/abs/2005.12872), which utilizes a transformer encoder-decoder architecture to learn a fixed set of object queries to directly predict object bounding boxes for the object detection task. Instead of simply detecting objects, we aim to determine the important ones in the video. 33 | 34 | Entity 35 | 36 |
Figure 3. Main architecture of the entity module
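
The core idea can be reduced to a few lines: a fixed set of learned query embeddings cross-attends, through a transformer encoder-decoder, over all detected object features and yields `max_objects` slots of "principal object" features. The sketch below is an illustrative simplification built on PyTorch's stock `nn.Transformer`; the actual `EntityLevelEncoder` in `models/encoders/entity_level.py` additionally projects the 2D/3D context features and fuses them with a BiLSTM before the transformer.

```python
import torch
import torch.nn as nn

class EntityQuerySketch(nn.Module):
    """Illustrative reduction of the DETR-style entity module (not the repo implementation)."""
    def __init__(self, object_dim=2048, hidden_dim=512, max_objects=8, nheads=8, num_layers=2):
        super().__init__()
        self.input_proj = nn.Linear(object_dim, hidden_dim)        # project detected object features
        self.query_embed = nn.Embedding(max_objects, hidden_dim)   # learned "principal object" queries
        self.transformer = nn.Transformer(d_model=hidden_dim, nhead=nheads,
                                          num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers)

    def forward(self, objects, objects_mask):
        # objects: (bsz, num_detected, object_dim); objects_mask: (bsz, num_detected), 1.0 = padding
        bsz = objects.size(0)
        src = self.input_proj(objects).permute(1, 0, 2)                # (num_detected, bsz, hidden)
        tgt = self.query_embed.weight.unsqueeze(1).repeat(1, bsz, 1)   # (max_objects, bsz, hidden)
        out = self.transformer(src, tgt,
                               src_key_padding_mask=objects_mask.bool(),
                               memory_key_padding_mask=objects_mask.bool())
        return out.permute(1, 0, 2)                                    # (bsz, max_objects, hidden)
```

In the full model these `max_objects` feature slots are also mapped through a linear layer into the word-embedding space, so that each slot can be matched (via the Hungarian matcher in `models/hungary.py`) against an entity word from the ground-truth caption — this is the entity-level supervision described above.
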
37 | 38 | 39 | 40 | ## Usage 41 | 42 | Our proposed HMN is implemented with PyTorch. 43 | 44 | #### Environment 45 | 46 | - Python = 3.7 47 | - PyTorch = 1.4 48 | 49 | 50 | 51 | #### 1.Installation 52 | 53 | - Clone this repo: 54 | 55 | ``` 56 | git clone https://github.com/MarcusNerva/HMN.git 57 | cd HMN 58 | ``` 59 | 60 | + Clone a python3-version coco_caption repo under the utilis/ 61 | 62 | 63 | 64 | #### 2.Download datasets 65 | 66 | **MSR-VTT Dataset:** 67 | 68 | - Context features (2D CNN features) : [MSRVTT-InceptionResNetV2](https://1drv.ms/u/s!ArYBhHmSAbFOc20zPEg-aSP7_cI?e=fhT1lN) 69 | - Motion features (3D CNN features) : [MSRVTT-C3D](https://1drv.ms/u/s!ArYBhHmSAbFOdKU9iZgHFGFHCAE?e=H5DyOE) 70 | - Object features (Extracted by Faster-RCNN) : [MSRVTT-Faster-RCNN](https://1drv.ms/u/s!ArYBhHmSAbFOdVQnfilWp6_epv4?e=Am9OXT) 71 | - Linguistic supervision: [MSRVTT-Language](https://1drv.ms/u/s!ArYBhHmSAbFOe0dX-SBDdxJ9RHM?e=ZlNbBQ) 72 | - Splits: [MSRVTT-Splits](https://1drv.ms/u/s!ArYBhHmSAbFOgURAYMc4f2-TeI0U?e=tde70k) 73 | 74 | **MSVD Dataset:** 75 | 76 | - Context features (2D CNN features) : [MSVD-InceptionResNetV2](https://1drv.ms/u/s!ArYBhHmSAbFOeMT-jksQPhkzYHA?e=mO2DTu) 77 | - Motion features (3D CNN features) : [MSVD-C3D](https://1drv.ms/u/s!ArYBhHmSAbFOd8H6ciT2CYwqFaE?e=VeWdS8) 78 | - Object features (Extracted by Faster-RCNN) : [MSVD-Faster-RCNN](https://1drv.ms/u/s!ArYBhHmSAbFOef5wZTxndFlz7bQ?e=fBPFHG) 79 | - Linguistic supervision: [MSVD-Language](https://1drv.ms/u/s!ArYBhHmSAbFOetaEHJnITH8q-eE?e=ePZlcn) 80 | - Splits: [MSVD-Splits](https://1drv.ms/u/s!ArYBhHmSAbFOgUj3RjfW982_KntY?e=hPMWYI) 81 | 82 | 83 | 84 | #### 3.Prepare training data 85 | 86 | - Organize visual and linguistic features under `data/` 87 | 88 | ```bash 89 | data 90 | ├── __init__.py 91 | ├── loader 92 | │   ├── data_loader.py 93 | │   └── __init__.py 94 | ├── MSRVTT 95 | │   ├── language 96 | │   │   ├── embedding_weights.pkl 97 | │   │   ├── idx2word.pkl 98 | │   │   ├── vid2groundtruth.pkl 99 | │   │   ├── vid2language.pkl 100 | │   │   ├── word2idx.pkl 101 | │   │   └── vid2fillmask_MSRVTT.pkl 102 | │   ├── MSRVTT_splits 103 | │   │   ├── MSRVTT_test_list.pkl 104 | │   │   ├── MSRVTT_train_list.pkl 105 | │   │   └── MSRVTT_valid_list.pkl 106 | │   └── visual 107 | │   ├── MSRVTT_C3D_test.hdf5 108 | │   ├── MSRVTT_C3D_train.hdf5 109 | │   ├── MSRVTT_C3D_valid.hdf5 110 | │   ├── MSRVTT_inceptionresnetv2_test.hdf5 111 | │   ├── MSRVTT_inceptionresnetv2_train.hdf5 112 | │   ├── MSRVTT_inceptionresnetv2_valid.hdf5 113 | │   ├── MSRVTT_vg_objects_test.hdf5 114 | │   ├── MSRVTT_vg_objects_train.hdf5 115 | │   └── MSRVTT_vg_objects_valid.hdf5 116 | └── MSVD 117 | ├── language 118 | │   ├── embedding_weights.pkl 119 | │   ├── idx2word.pkl 120 | │   ├── vid2groundtruth.pkl 121 | │   ├── vid2language.pkl 122 |    │   ├── word2idx.pkl 123 |    │   └── vid2fillmask_MSVD.pkl 124 | ├── MSVD_splits 125 | │   ├── MSVD_test_list.pkl 126 | │   ├── MSVD_train_list.pkl 127 | │   └── MSVD_valid_list.pkl 128 | └── visual 129 | ├── MSVD_C3D_test.hdf5 130 | ├── MSVD_C3D_train.hdf5 131 | ├── MSVD_C3D_valid.hdf5 132 | ├── MSVD_inceptionresnetv2_test.hdf5 133 | ├── MSVD_inceptionresnetv2_train.hdf5 134 | ├── MSVD_inceptionresnetv2_valid.hdf5 135 | ├── MSVD_vg_objects_test.hdf5 136 | ├── MSVD_vg_objects_train.hdf5 137 | └── MSVD_vg_objects_valid.hdf5 138 | ``` 139 | 140 | 141 | 142 | ## Pretrained Model 143 | 144 | [Pretrained model on MSR-VTT](https://1drv.ms/u/s!ArYBhHmSAbFOgTicz9UR4ljs_JD4?e=YC8gKW) 145 | 146 | [Pretrained 
model on MSVD](https://1drv.ms/u/s!ArYBhHmSAbFOgT1ahP77Tij-6yUy?e=KeoDZV) 147 | 148 | Download the pretrained models for MSR-VTT and MSVD via the links above, and place them under the checkpoints dir: 149 | 150 | ``` 151 | mkdir -p checkpoints/MSRVTT 152 | mkdir -p checkpoints/MSVD 153 | ``` 154 | 155 | You should end up with: 156 | 157 | ``` 158 | checkpoints/ 159 | ├── MSRVTT 160 | │ └── HMN_MSRVTT_model.ckpt 161 | └── MSVD 162 | └── HMN_MSVD_model.ckpt 163 | ``` 164 | 165 | 166 | 167 | ## Training & Testing 168 | 169 | #### Training: MSR-VTT 170 | 171 | ```bash 172 | python -u main.py --dataset_name MSRVTT --entity_encoder_layer 3 --entity_decoder_layer 3 --max_objects 9 \ 173 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \ 174 | --backbone_3d_name C3D --backbone_3d_dim 2048 \ 175 | --object_name vg_objects --object_dim 2048 \ 176 | --max_epochs 16 --save_checkpoints_every 500 \ 177 | --data_dir ./data --model_name HMN \ 178 | --language_dir_name language \ 179 | --learning_rate 7e-5 --lambda_entity 0.1 --lambda_predicate 6.9 --lambda_sentence 6.9 --lambda_soft 3.5 180 | ``` 181 | 182 | 183 | 184 | #### Training: MSVD 185 | 186 | ```bash 187 | python -u main.py --dataset_name MSVD --entity_encoder_layer 2 --entity_decoder_layer 2 --max_objects 8 \ 188 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \ 189 | --backbone_3d_name C3D --backbone_3d_dim 2048 \ 190 | --object_name vg_objects --object_dim 2048 \ 191 | --max_epochs 20 --save_checkpoints_every 500 \ 192 | --data_dir ./data --model_name HMN \ 193 | --language_dir_name language --language_package_name vid2language_old \ 194 | --learning_rate 1e-4 --lambda_entity 0.6 --lambda_predicate 0.3 --lambda_sentence 1.0 --lambda_soft 0.5 195 | ``` 196 | 197 | 198 | 199 | #### Testing MSR-VTT & MSVD 200 | 201 | To run testing only, first comment out the following `train_fn` call in `main.py`: 
202 | 203 | ```python 204 | model = train_fn(cfgs, cfgs.model_name, model, hungary_matcher, train_loader, valid_loader, device) 205 | ``` 206 | 207 | 208 | 209 | For MSR-VTT: 210 | 211 | ``` 212 | python3 main.py --dataset_name MSRVTT \ 213 | --entity_encoder_layer 3 --entity_decoder_layer 3 --max_objects 9 \ 214 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \ 215 | --backbone_3d_name C3D --backbone_3d_dim 2048 \ 216 | --object_name vg_objects --object_dim 2048 \ 217 | --max_epochs 16 --save_checkpoints_every 500 \ 218 | --data_dir ./data --model_name HMN --learning_rate 7e-5 \ 219 | --lambda_entity 0.1 --lambda_predicate 6.9 --lambda_sentence 6.9 \ 220 | --lambda_soft 3.5 \ 221 | --save_checkpoints_path checkpoints/MSRVTT/HMN_MSRVTT_model.ckpt 222 | ``` 223 | 224 | Expected performance: 225 | 226 | MSRVTT-performance 227 | 228 | 229 | 230 | For MSVD: 231 | 232 | ``` 233 | python3 main.py --dataset_name MSVD \ 234 | --entity_encoder_layer 2 --entity_decoder_layer 2 --max_objects 8 \ 235 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \ 236 | --backbone_3d_name C3D --backbone_3d_dim 2048 \ 237 | --object_name vg_objects --object_dim 2048 \ 238 | --max_epochs 20 --save_checkpoints_every 500 \ 239 | --data_dir ./data --model_name HMN --learning_rate 1e-4 \ 240 | --lambda_entity 0.6 --lambda_predicate 0.3 --lambda_sentence 1.0 \ 241 | --lambda_soft 0.5 \ 242 | --save_checkpoints_path checkpoints/MSVD/HMN_MSVD_model.ckpt 243 | ``` 244 | 245 | Expected performance: 246 | 247 | MSVD-performance 248 | 249 | 250 | 251 | ## Citation 252 | 253 | If our research and this repository are helpful to your work, please cite our paper: 254 | 255 | ``` 256 | @InProceedings{Ye_2022_CVPR, 257 | author = {Ye, Hanhua and Li, Guorong and Qi, Yuankai and Wang, Shuhui and Huang, Qingming and Yang, Ming-Hsuan}, 258 | title = {Hierarchical Modular Network for Video Captioning}, 259 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 260 | month = {June}, 261 | year = {2022}, 262 | pages = {17939-17948} 263 | } 264 | ``` 265 | 266 | 267 | 268 | ## Acknowledgements 269 | 270 | The code for the decoding part is based on [POS-CG](https://github.com/vsislab/Controllable_XGating). 
271 | 272 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/configs/__init__.py -------------------------------------------------------------------------------- /configs/settings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | __all__ = ['TotalConfigs', 'get_settings'] 4 | 5 | 6 | def _settings(): 7 | parser = argparse.ArgumentParser() 8 | 9 | """ 10 | =========================General Settings=========================== 11 | """ 12 | parser.add_argument('--seed', type=int, default=1) 13 | parser.add_argument('--drop_prob', type=float, default=0.5) 14 | parser.add_argument('--bsz', type=int, default=64, help='batch size') 15 | parser.add_argument('--sample_numb', type=int, default=15, help='how many frames would you like to sample from a given video') 16 | parser.add_argument('--model_name', type=str, default='HMN', help='which model you would like to train/test?') 17 | 18 | """ 19 | =========================Data Settings=========================== 20 | """ 21 | parser.add_argument('--data_dir', type=str, default='./data') 22 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints') 23 | parser.add_argument('--result_dir', type=str, default='-1') 24 | parser.add_argument('--dataset_name', type=str, default='-1') 25 | parser.add_argument('--backbone_2d_name', type=str, default='-1', help='2d backbone name (InceptionResNetV2)') 26 | parser.add_argument('--backbone_3d_name', type=str, default='-1', help='3d backbone name (C3D)') 27 | parser.add_argument('--object_name', type=str, default='-1', help='object features name (vg_objects)') 28 | parser.add_argument('--semantics_dim', type=int, default=768, help='semantics embedding dim') 29 | 30 | """ 31 | =========================Encoder Settings=========================== 32 | """ 33 | parser.add_argument('--backbone_2d_dim', type=int, default=2048, help='dimention for inceptionresnetv2') 34 | parser.add_argument('--backbone_3d_dim', type=int, default=2048, help='dimention for C3D') 35 | parser.add_argument('--object_dim', type=int, default=2048, help='dimention for vg_objects') 36 | parser.add_argument('--max_objects', type=int, default=8) 37 | 38 | parser.add_argument('--nheads', type=int, default=8) 39 | parser.add_argument('--entity_encoder_layer', type=int, default=2) 40 | parser.add_argument('--entity_decoder_layer', type=int, default=2) 41 | parser.add_argument('--dim_feedforward', type=int, default=2048) 42 | parser.add_argument('--transformer_activation', type=str, default='relu') 43 | parser.add_argument('--d_model', type=int, default=512) 44 | parser.add_argument('--transformer_dropout', type=float, default=0.1) 45 | 46 | """ 47 | =========================Decoder Settings=========================== 48 | """ 49 | parser.add_argument('--word_embedding_dim', type=int, default=300) 50 | parser.add_argument('--hidden_dim', type=int, default=512) 51 | parser.add_argument('--num_layers', type=int, default=1) 52 | 53 | 54 | """ 55 | =========================Word Dict Settings=========================== 56 | """ 57 | parser.add_argument('--eos_idx', type=int, default=0) 58 | parser.add_argument('--sos_idx', type=int, default=1) 59 | parser.add_argument('--unk_idx', type=int, default=2) 60 | parser.add_argument('--n_vocab', 
type=int, default=-1, help='how many different words are there in the dataset') 61 | 62 | """ 63 | =========================Training Settings=========================== 64 | """ 65 | parser.add_argument('--grad_clip', type=float, default=5.0) 66 | parser.add_argument('--learning_rate', type=float, default=1e-4) 67 | parser.add_argument('--lambda_entity', type=float, default=1.0) 68 | parser.add_argument('--lambda_predicate', type=float, default=1.0) 69 | parser.add_argument('--lambda_sentence', type=float, default=1.0) 70 | parser.add_argument('--lambda_soft', type=float, default=0.1) 71 | parser.add_argument('--max_epochs', type=int, default=20) 72 | parser.add_argument('--visualize_every', type=int, default=10) 73 | parser.add_argument('--save_checkpoints_every', type=int, default=200) 74 | parser.add_argument('--save_checkpoints_path', type=str, default='-1') 75 | 76 | """ 77 | =========================Testing Settings=========================== 78 | """ 79 | parser.add_argument('--beam_size', type=int, default=5) 80 | parser.add_argument('--max_caption_len', type=int, default=20 + 2) 81 | parser.add_argument('--temperature', type=float, default=1.0) 82 | parser.add_argument('--result_path', type=str, default='-1') 83 | 84 | args = parser.parse_args() 85 | return args 86 | 87 | 88 | class TotalConfigs: 89 | def __init__(self, args): 90 | self.data = DataConfigs(args) 91 | self.dict = DictConfigs(args) 92 | self.encoder = EncoderConfigs(args) 93 | self.decoder = DecoderConfigs(args) 94 | self.train = TrainingConfigs(args) 95 | self.test = TestConfigs(args) 96 | 97 | self.seed = args.seed 98 | self.bsz = args.bsz 99 | self.drop_prob = args.drop_prob 100 | self.model_name = args.model_name 101 | self.sample_numb = args.sample_numb 102 | 103 | 104 | class DataConfigs: 105 | def __init__(self, args): 106 | self.data_dir = args.data_dir 107 | self.checkpoints_dir = args.checkpoints_dir 108 | self.dataset_name = args.dataset_name 109 | self.backbone_2d_name = args.backbone_2d_name 110 | self.backbone_3d_name = args.backbone_3d_name 111 | self.object_name = args.object_name 112 | self.word_dim = args.word_embedding_dim 113 | 114 | assert self.dataset_name != '-1', 'Please set argument dataset_name' 115 | assert self.backbone_2d_name != '-1', 'Please set argument backbone_2d_name' 116 | assert self.backbone_3d_name != '-1', 'Please set argument backbone_3d_name' 117 | assert self.object_name != '-1', 'Please set argument object_name' 118 | 119 | # data dir 120 | self.data_dir = os.path.join(self.data_dir, self.dataset_name) 121 | 122 | # language part 123 | self.language_dir = os.path.join(self.data_dir, 'language') 124 | self.vid2language_path = os.path.join(self.language_dir, 'vid2language.pkl') 125 | self.vid2fillmask_path = os.path.join(self.data_dir, 'vid2fillmask_{}.pkl'.format(self.dataset_name)) 126 | self.word2idx_path = os.path.join(self.language_dir, 'word2idx.pkl') 127 | self.idx2word_path = os.path.join(self.language_dir, 'idx2word.pkl') 128 | self.embedding_weights_path = os.path.join(self.language_dir, 'embedding_weights.pkl') 129 | self.vid2groundtruth_path = os.path.join(self.language_dir, 'vid2groundtruth.pkl') 130 | 131 | # visual part 132 | self.visual_dir = os.path.join(self.data_dir, 'visual') 133 | self.backbone2d_path_tpl = os.path.join(self.visual_dir, '{}_{}_{}.hdf5'.format(args.dataset_name, args.backbone_2d_name, '{}')) 134 | self.backbone3d_path_tpl = os.path.join(self.visual_dir, '{}_{}_{}.hdf5'.format(args.dataset_name, args.backbone_3d_name, '{}')) 135 | 
self.objects_path_tpl = os.path.join(self.visual_dir, '{}_{}_{}.hdf5'.format(args.dataset_name, args.object_name, '{}')) 136 | 137 | # dataset split part 138 | self.split_dir = os.path.join(self.data_dir, '{dataset_name}_splits'.format(dataset_name=self.dataset_name)) 139 | self.videos_split_path_tpl = os.path.join(self.split_dir, '{}_{}_list.pkl'.format(self.dataset_name, '{}')) 140 | 141 | 142 | class DictConfigs: 143 | def __init__(self, args): 144 | self.eos_idx = args.eos_idx 145 | self.sos_idx = args.sos_idx 146 | self.unk_idx = args.unk_idx 147 | self.n_vocab = args.n_vocab 148 | 149 | 150 | class EncoderConfigs: 151 | def __init__(self, args): 152 | self.backbone_2d_dim = args.backbone_2d_dim 153 | self.backbone_3d_dim = args.backbone_3d_dim 154 | self.semantics_dim = args.semantics_dim 155 | self.object_dim = args.object_dim 156 | self.max_objects = args.max_objects 157 | 158 | self.nheads = args.nheads 159 | self.entity_encoder_layer = args.entity_encoder_layer 160 | self.entity_decoder_layer = args.entity_decoder_layer 161 | self.dim_feedforward = args.dim_feedforward 162 | self.transformer_activation = args.transformer_activation 163 | self.d_model = args.d_model 164 | self.trans_dropout = args.transformer_dropout 165 | 166 | 167 | class DecoderConfigs: 168 | def __init__(self, args): 169 | self.hidden_dim = args.hidden_dim 170 | self.num_layers = args.num_layers 171 | self.n_vocab = -1 172 | 173 | 174 | class TrainingConfigs: 175 | def __init__(self, args): 176 | self.grad_clip = args.grad_clip 177 | self.learning_rate = args.learning_rate 178 | self.lambda_entity = args.lambda_entity 179 | self.lambda_predicate = args.lambda_predicate 180 | self.lambda_sentence = args.lambda_sentence 181 | self.lambda_soft = args.lambda_soft 182 | self.max_epochs = args.max_epochs 183 | self.visualize_every = args.visualize_every 184 | self.checkpoints_dir = os.path.join(args.checkpoints_dir, args.dataset_name) 185 | self.save_checkpoints_every = args.save_checkpoints_every 186 | self.save_checkpoints_path = os.path.join(self.checkpoints_dir, 187 | '{model_name}_epochs_{max_epochs}_lr_{lr}_entity_{obj}_predicate_{act}_sentence_{v}_soft{s}_ne_{ne}_nd_{nd}_max_objects_{mo}.ckpt'.format( 188 | model_name=args.model_name, 189 | max_epochs=args.max_epochs, 190 | lr=self.learning_rate, 191 | obj=self.lambda_entity, 192 | act=self.lambda_predicate, 193 | v=self.lambda_sentence, 194 | s=self.lambda_soft, 195 | ne=args.entity_encoder_layer, 196 | nd=args.entity_decoder_layer, 197 | mo=args.max_objects)) 198 | if not os.path.exists(self.checkpoints_dir): 199 | os.makedirs(self.checkpoints_dir) 200 | if args.save_checkpoints_path != '-1': 201 | self.save_checkpoints_path = args.save_checkpoints_path 202 | 203 | 204 | class TestConfigs: 205 | def __init__(self, args): 206 | self.beam_size = args.beam_size 207 | self.max_caption_len = args.max_caption_len 208 | self.temperature = args.temperature 209 | self.result_dir = os.path.join('./results/{dataset_name}'.format(dataset_name=args.dataset_name)) 210 | if args.result_dir != '-1': 211 | self.result_dir = args.result_dir 212 | if not os.path.exists(self.result_dir): 213 | os.makedirs(self.result_dir) 214 | self.result_path = os.path.join( 215 | self.result_dir, 216 | '{model_name}_epochs_{max_epochs}_lr_{lr}_entity_{obj}_predicate_{act}_sentence_{v}_soft_{s}_ne_{ne}_nd_{nd}_max_objects_{mo}.pkl'.format( 217 | model_name=args.model_name, 218 | max_epochs=args.max_epochs, 219 | lr=args.learning_rate, 220 | obj=args.lambda_entity, 221 | 
act=args.lambda_predicate, 222 | v=args.lambda_sentence, 223 | s=args.lambda_soft, 224 | ne=args.entity_encoder_layer, 225 | nd=args.entity_decoder_layer, 226 | mo=args.max_objects) 227 | ) 228 | if not os.path.exists(self.result_dir): 229 | os.makedirs(self.result_dir) 230 | if args.result_path != '-1': 231 | self.result_path = args.result_path 232 | 233 | 234 | def get_settings(): 235 | args = _settings() 236 | configs = TotalConfigs(args=args) 237 | return configs 238 | 239 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /data/loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/data/loader/__init__.py -------------------------------------------------------------------------------- /data/loader/data_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os 5 | import h5py 6 | import pickle 7 | from collections import defaultdict 8 | import sys 9 | sys.path.append('../../') 10 | from configs.settings import TotalConfigs 11 | 12 | 13 | def get_ids_and_probs(fillmask_steps, max_caption_len): 14 | if fillmask_steps is None: 15 | return None, None, None 16 | 17 | ret_ids, ret_probs = [], [] 18 | ret_mask = torch.zeros(max_caption_len) 19 | 20 | for step in fillmask_steps: 21 | step_ids, step_probs = zip(*step) 22 | step_ids, step_probs = torch.Tensor(step_ids).long(), torch.Tensor(step_probs).float() 23 | ret_ids.append(step_ids) 24 | ret_probs.append(step_probs) 25 | gap = max_caption_len - len(fillmask_steps) 26 | for i in range(gap): 27 | zero_ids, zero_probs = torch.zeros(50).long(), torch.zeros(50).float() 28 | ret_ids.append(zero_ids) 29 | ret_probs.append(zero_probs) 30 | 31 | ret_ids = torch.cat([item[None, ...] for item in ret_ids], dim=0) 32 | ret_probs = torch.cat([item[None, ...] for item in ret_probs], dim=0) 33 | ret_mask[:len(fillmask_steps)] = 1 34 | 35 | return ret_ids, ret_probs, ret_mask 36 | 37 | 38 | class CaptionDataset(Dataset): 39 | def __init__(self, cfgs: TotalConfigs, mode, save_on_disk=False, is_total=False): 40 | """ 41 | Args: 42 | args: configurations. 43 | mode: train/valid/test. 44 | save_on_disk: whether save the prediction on disk or not. 45 | True->Each video only appears once. 46 | False->The number of times each video appears depends on 47 | the number of its corresponding captions. 48 | """ 49 | super(CaptionDataset, self).__init__() 50 | self.mode = mode 51 | self.save_on_disk = save_on_disk 52 | self.is_total = is_total 53 | sample_numb = cfgs.sample_numb # how many frames are sampled to perform video captioning? 
54 | max_caption_len = cfgs.test.max_caption_len 55 | 56 | # language part 57 | vid2language_path = cfgs.data.vid2language_path 58 | vid2fillmask_path = cfgs.data.vid2fillmask_path 59 | 60 | # visual part 61 | backbone2d_path = cfgs.data.backbone2d_path_tpl.format(mode) 62 | backbone3d_path = cfgs.data.backbone3d_path_tpl.format(mode) 63 | objects_path = cfgs.data.objects_path_tpl.format(mode) 64 | 65 | # dataset split part 66 | videos_split_path = cfgs.data.videos_split_path_tpl.format(mode) 67 | 68 | with open(videos_split_path, 'rb') as f: 69 | video_ids = pickle.load(f) 70 | 71 | self.video_ids = video_ids 72 | self.corresponding_vid = [] 73 | 74 | self.backbone_2d_dict = {} 75 | self.backbone_3d_dict = {} 76 | self.objects_dict = {} 77 | self.total_entries = [] # (numberic words, original caption) 78 | self.vid2captions = defaultdict(list) 79 | 80 | # feature 2d dict 81 | with h5py.File(backbone2d_path, 'r') as f: 82 | for vid in video_ids: 83 | temp_feat = f[vid][()] 84 | sampled_idxs = np.linspace(0, len(temp_feat) - 1, sample_numb, dtype=int) 85 | self.backbone_2d_dict[vid] = temp_feat[sampled_idxs] 86 | 87 | # feature 3d dict 88 | with h5py.File(backbone3d_path, 'r') as f: 89 | for vid in video_ids: 90 | temp_feat = f[vid][()] 91 | sampled_idxs = np.linspace(0, len(temp_feat) - 1, sample_numb, dtype=int) 92 | self.backbone_3d_dict[vid] = temp_feat[sampled_idxs] 93 | 94 | # feature object dict 95 | with h5py.File(objects_path, 'r') as f: 96 | for vid in video_ids: 97 | temp_feat = f[vid]['feats'][()] 98 | self.objects_dict[vid] = temp_feat 99 | 100 | with open(vid2language_path, 'rb') as f: 101 | self.vid2language = pickle.load(f) 102 | 103 | if cfgs.train.lambda_soft > 0 and not save_on_disk: 104 | with open(vid2fillmask_path, 'rb') as f: 105 | self.vid2fillmask = pickle.load(f) 106 | 107 | for vid in video_ids: 108 | fillmask_dict = self.vid2fillmask[vid] if cfgs.train.lambda_soft > 0 and not save_on_disk and vid in self.vid2fillmask else None 109 | for item in self.vid2language[vid]: 110 | caption, numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec = item 111 | current_mask = fillmask_dict[caption] if fillmask_dict is not None else None 112 | vocab_ids, vocab_probs, fillmasks = get_ids_and_probs(current_mask, max_caption_len) 113 | self.total_entries.append((numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec, vocab_ids, vocab_probs, fillmasks)) 114 | self.corresponding_vid.append(vid) 115 | self.vid2captions[vid].append(caption) 116 | 117 | def __getitem__(self, idx): 118 | """ 119 | Returns: 120 | feature2d: (sample_numb, dim2d) 121 | feature3d: (sample_numb, dim3d) 122 | objects: (sample_numb * object_num, dim_obj) or (object_num_per_video, dim_obj) 123 | numberic: (max_caption_len, ) 124 | captions: List[str] 125 | vid: str 126 | """ 127 | vid = self.corresponding_vid[idx] if (self.mode == 'train' and not self.save_on_disk) or self.is_total else self.video_ids[idx] 128 | choose_idx = 0 129 | 130 | feature2d = self.backbone_2d_dict[vid] 131 | feature3d = self.backbone_3d_dict[vid] 132 | objects = self.objects_dict[vid] 133 | 134 | if (self.mode == 'train' and not self.save_on_disk) or self.is_total: 135 | numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec, vocab_ids, vocab_probs, fillmasks = self.total_entries[idx] 136 | else: 137 | numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec = self.vid2language[vid][choose_idx][1:] 138 | vocab_ids, vocab_probs, fillmasks = None, None, None 139 | 140 | captions = self.vid2captions[vid] 
141 | nouns_dict = {'nouns': nouns, 'vec': torch.FloatTensor(nouns_vec)} 142 | 143 | return torch.FloatTensor(feature2d), torch.FloatTensor(feature3d), torch.FloatTensor(objects), \ 144 | torch.LongTensor(numberic_cap), \ 145 | torch.FloatTensor(vp_semantics), \ 146 | torch.FloatTensor(caption_semantics), captions, nouns_dict, vid, \ 147 | vocab_ids, vocab_probs, fillmasks 148 | 149 | def __len__(self): 150 | if (self.mode == 'train' and not self.save_on_disk) or self.is_total: 151 | return len(self.total_entries) 152 | else: 153 | return len(self.video_ids) 154 | 155 | 156 | def collate_fn_caption(batch): 157 | feature2ds, feature3ds, objects, numberic_caps, \ 158 | vp_semantics, caption_semantics, captions, nouns_dict_list, vids, \ 159 | vocab_ids, vocab_probs, fillmasks = zip(*batch) 160 | 161 | bsz, obj_dim = len(feature2ds), objects[0].shape[-1] 162 | longest_objects_num = max([item.shape[0] for item in objects]) 163 | ret_objects = torch.zeros([bsz, longest_objects_num, obj_dim]) 164 | ret_objects_mask = torch.ones([bsz, longest_objects_num]) 165 | for i in range(bsz): 166 | ret_objects[i, :objects[i].shape[0], :] = objects[i] 167 | ret_objects_mask[i, :objects[i].shape[0]] = 0.0 168 | 169 | feature2ds = torch.cat([item[None, ...] for item in feature2ds], dim=0) # (bsz, sample_numb, dim_2d) 170 | feature3ds = torch.cat([item[None, ...] for item in feature3ds], dim=0) # (bsz, sample_numb, dim_3d) 171 | 172 | vp_semantics = torch.cat([item[None, ...] for item in vp_semantics], dim=0) # (bsz, dim_sem) 173 | caption_semantics = torch.cat([item[None, ...] for item in caption_semantics], dim=0) # (bsz, dim_sem) 174 | 175 | numberic_caps = torch.cat([item[None, ...] for item in numberic_caps], dim=0) # (bsz, seq_len) 176 | masks = numberic_caps > 0 177 | 178 | captions = [item for item in captions] 179 | nouns = list(nouns_dict_list) 180 | vids = list(vids) 181 | vocab_ids = torch.cat([item[None, ...] for item in vocab_ids], dim=0).long() if vocab_ids[0] is not None else None # (bsz, seq_len, 50) 182 | vocab_probs = torch.cat([item[None, ...] for item in vocab_probs], dim=0).float() if vocab_probs[0] is not None else None # (bsz, seq_len, 50) 183 | fillmasks = torch.cat([item[None, ...] for item in fillmasks], dim=0).float() if fillmasks[0] is not None else None # (bsz, seq_len) 184 | 185 | return feature2ds.float(), feature3ds.float(), ret_objects.float(), ret_objects_mask.float(), \ 186 | vp_semantics.float(), caption_semantics.float(), \ 187 | numberic_caps.long(), masks.float(), captions, nouns, vids, \ 188 | vocab_ids, vocab_probs, fillmasks 189 | 190 | 191 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | from collections import OrderedDict, defaultdict 4 | 5 | from utils.coco_caption.pycocoevalcap.bleu.bleu import Bleu 6 | from utils.coco_caption.pycocoevalcap.cider.cider import Cider 7 | from utils.coco_caption.pycocoevalcap.meteor.meteor import Meteor 8 | from utils.coco_caption.pycocoevalcap.rouge.rouge import Rouge 9 | from configs.settings import TotalConfigs 10 | 11 | def language_eval(sample_seqs, groundtruth_seqs): 12 | assert len(sample_seqs) == len(groundtruth_seqs), 'length of sampled seqs is different from that of groundtruth seqs!' 
13 | 14 | references, predictions = OrderedDict(), OrderedDict() 15 | for i in range(len(groundtruth_seqs)): 16 | references[i] = [groundtruth_seqs[i][j] for j in range(len(groundtruth_seqs[i]))] 17 | for i in range(len(sample_seqs)): 18 | predictions[i] = [sample_seqs[i]] 19 | 20 | predictions = {i: predictions[i] for i in range(len(sample_seqs))} 21 | references = {i: references[i] for i in range(len(groundtruth_seqs))} 22 | 23 | avg_bleu_score, bleu_score = Bleu(4).compute_score(references, predictions) 24 | print('avg_bleu_score == ', avg_bleu_score) 25 | avg_cider_score, cider_score = Cider().compute_score(references, predictions) 26 | print('avg_cider_score == ', avg_cider_score) 27 | avg_meteor_score, meteor_score = Meteor().compute_score(references, predictions) 28 | print('avg_meteor_score == ', avg_meteor_score) 29 | avg_rouge_score, rouge_score = Rouge().compute_score(references, predictions) 30 | print('avg_rouge_score == ', avg_rouge_score) 31 | 32 | return {'BLEU': avg_bleu_score, 'CIDEr': avg_cider_score, 'METEOR': avg_meteor_score, 'ROUGE': avg_rouge_score} 33 | 34 | 35 | def decode_idx(seq, itow, eos_idx): 36 | ret = '' 37 | length = seq.shape[0] 38 | for i in range(length): 39 | if seq[i] == eos_idx: break 40 | if i > 0: ret += ' ' 41 | ret += itow[seq[i]] 42 | return ret 43 | 44 | 45 | @torch.no_grad() 46 | def eval_fn(model, loader, device, idx2word, save_on_disk, cfgs: TotalConfigs, vid2groundtruth)->dict: 47 | model.eval() 48 | if save_on_disk: 49 | result_dict = {} 50 | predictions, gts = [], [] 51 | 52 | for i, (feature2ds, feature3ds, object_feats, object_masks, \ 53 | vp_semantics, caption_semantics, numberic_caps, masks, \ 54 | captions, nouns_dict_list, vids, vocab_ids, vocab_probs, fillmasks) \ 55 | in enumerate(loader): 56 | feature2ds = feature2ds.to(device) 57 | feature3ds = feature3ds.to(device) 58 | object_feats = object_feats.to(device) 59 | object_masks = object_masks.to(device) 60 | vp_semantics = vp_semantics.to(device) 61 | caption_semantics = caption_semantics.to(device) 62 | numberic_caps = numberic_caps.to(device) 63 | masks = masks.to(device) 64 | 65 | pred, seq_probabilities = model.sample(object_feats, object_masks, feature2ds, feature3ds) 66 | pred = pred.cpu().numpy() 67 | batch_pred = [decode_idx(single_seq, idx2word, cfgs.dict.eos_idx) for single_seq in pred] 68 | predictions += batch_pred 69 | batch_gts = [vid2groundtruth[id] for id in vids] if save_on_disk else [item for item in captions] 70 | gts += batch_gts 71 | 72 | if save_on_disk: 73 | assert len(batch_pred) == len(vids), \ 74 | 'expect len(batch_pred) == len(vids), ' \ 75 | 'but got len(batch_pred) == {} and len(vids) == {}'.format(len(batch_pred), len(vids)) 76 | for vid, pred in zip(vids, batch_pred): 77 | result_dict[vid] = pred 78 | 79 | model.train() 80 | score_states = language_eval(sample_seqs=predictions, groundtruth_seqs=gts) 81 | 82 | if save_on_disk: 83 | with open(cfgs.test.result_path, 'wb') as f: 84 | pickle.dump(result_dict, f) 85 | 86 | return score_states 87 | 88 | -------------------------------------------------------------------------------- /figures/Entity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/Entity.png -------------------------------------------------------------------------------- /figures/HMN.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/HMN.png -------------------------------------------------------------------------------- /figures/MSRVTT-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/MSRVTT-performance.png -------------------------------------------------------------------------------- /figures/MSVD-performance.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/MSVD-performance.png -------------------------------------------------------------------------------- /figures/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/motivation.png -------------------------------------------------------------------------------- /hmn.yaml: -------------------------------------------------------------------------------- 1 | name: hmn_env 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2022.07.19=h06a4308_0 10 | - certifi=2022.6.15=py37h06a4308_0 11 | - cudatoolkit=10.1.243=h6bb024c_0 12 | - fftw=3.3.9=h27cfd23_1 13 | - freetype=2.11.0=h70c0345_0 14 | - giflib=5.2.1=h7b6447c_0 15 | - h5py=3.7.0=py37h737f45e_0 16 | - hdf5=1.10.6=h3ffc7dd_1 17 | - intel-openmp=2021.4.0=h06a4308_3561 18 | - jpeg=9e=h7f8727e_0 19 | - lcms2=2.12=h3be6417_0 20 | - ld_impl_linux-64=2.38=h1181459_1 21 | - lerc=3.0=h295c915_0 22 | - libdeflate=1.8=h7f8727e_5 23 | - libffi=3.3=he6710b0_2 24 | - libgcc-ng=11.2.0=h1234567_1 25 | - libgfortran-ng=11.2.0=h00389a5_1 26 | - libgfortran5=11.2.0=h1234567_1 27 | - libgomp=11.2.0=h1234567_1 28 | - libpng=1.6.37=hbc83047_0 29 | - libstdcxx-ng=11.2.0=h1234567_1 30 | - libtiff=4.4.0=hecacb30_0 31 | - libwebp=1.2.2=h55f646e_0 32 | - libwebp-base=1.2.2=h7f8727e_0 33 | - lz4-c=1.9.3=h295c915_1 34 | - mkl=2021.4.0=h06a4308_640 35 | - mkl-service=2.4.0=py37h7f8727e_0 36 | - mkl_fft=1.3.1=py37hd3c417c_0 37 | - mkl_random=1.2.2=py37h51133e4_0 38 | - ncurses=6.3=h5eee18b_3 39 | - ninja=1.10.2=h06a4308_5 40 | - ninja-base=1.10.2=hd09550d_5 41 | - numpy=1.21.5=py37h6c91a56_3 42 | - numpy-base=1.21.5=py37ha15fc14_3 43 | - openssl=1.1.1q=h7f8727e_0 44 | - pillow=9.2.0=py37hace64e9_1 45 | - pip=22.1.2=py37h06a4308_0 46 | - python=3.7.13=h12debd9_0 47 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0 48 | - readline=8.1.2=h7f8727e_1 49 | - scipy=1.7.3=py37h6c91a56_2 50 | - setuptools=63.4.1=py37h06a4308_0 51 | - six=1.16.0=pyhd3eb1b0_1 52 | - sqlite=3.39.2=h5082296_0 53 | - tk=8.6.12=h1ccaba5_0 54 | - torchvision=0.5.0=py37_cu101 55 | - wheel=0.37.1=pyhd3eb1b0_0 56 | - xz=5.2.5=h7f8727e_1 57 | - zlib=1.2.12=h5eee18b_3 58 | - zstd=1.5.2=ha4553b6_0 59 | prefix: /home/marcusnerva/anaconda3/envs/hmn_env 60 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | import numpy as np 5 | import os 6 | 7 | from train import train_fn 8 | from test import test_fn 9 | from utils.build_loaders import build_loaders 10 | from utils.build_model 
import build_model 11 | from configs.settings import get_settings 12 | from models.hungary import HungarianMatcher 13 | 14 | def set_random_seed(seed): 15 | random.seed(seed) 16 | os.environ['PYTHONHASHSEED'] = str(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed_all(seed) 20 | torch.backends.cudnn.deterministic = True 21 | torch.backends.cudnn.benchmark = False 22 | 23 | 24 | if __name__ == '__main__': 25 | cfgs = get_settings() 26 | set_random_seed(seed=cfgs.seed) 27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 28 | 29 | train_loader, valid_loader, test_loader = build_loaders(cfgs) 30 | 31 | hungary_matcher = HungarianMatcher() 32 | model = build_model(cfgs) 33 | model = model.float() 34 | model.to(device) 35 | 36 | model = train_fn(cfgs, cfgs.model_name, model, hungary_matcher, train_loader, valid_loader, device) 37 | model.load_state_dict(torch.load(cfgs.train.save_checkpoints_path)) 38 | model.eval() 39 | test_fn(cfgs, model, test_loader, device) 40 | 41 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/models/__init__.py -------------------------------------------------------------------------------- /models/caption_models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/caption_models/caption_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CaptionModule(nn.Module): 6 | """ 7 | CaptionModule and its child classes are complementary. 
8 | """ 9 | def __init__(self, beam_size): 10 | super(CaptionModule, self).__init__() 11 | self.beam_size = beam_size 12 | 13 | def __beam_step(self, t, logprobs, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state): 14 | beam_size = self.beam_size 15 | 16 | probs, idx = torch.sort(logprobs, dim=1, descending=True) 17 | candidates = [] 18 | rows = beam_size if t >= 1 else 1 19 | cols = min(beam_size, probs.size(1)) 20 | 21 | for r in range(rows): 22 | for c in range(cols): 23 | tmp_logprob = probs[r, c] 24 | tmp_sum = beam_logprobs_sum[r] + tmp_logprob 25 | tmp_idx = idx[r, c] 26 | candidates.append({'sum': tmp_sum, 'logprob': tmp_logprob, 'ix': tmp_idx, 'beam': r}) 27 | 28 | candidates = sorted(candidates, key=lambda x: -x['sum']) 29 | prev_seq = beam_seq[:, :t].clone() 30 | prev_seq_probs = beam_seq_logprobs[:, :t].clone() 31 | prev_logprobs_sum = beam_logprobs_sum.clone() 32 | new_state = [_.clone() for _ in state] 33 | 34 | for i in range(beam_size): 35 | candidate_i = candidates[i] 36 | beam = candidate_i['beam'] 37 | ix = candidate_i['ix'] 38 | logprob = candidate_i['logprob'] 39 | 40 | beam_seq[i, :t] = prev_seq[beam, :] 41 | beam_seq_logprobs[i, :t] = prev_seq_probs[beam, :] 42 | beam_seq[i, t] = ix 43 | beam_seq_logprobs[i, t] = logprob 44 | beam_logprobs_sum[i] = prev_logprobs_sum[beam] + logprob 45 | for j in range(len(new_state)): 46 | new_state[j][:, i, :] = state[j][:, beam, :] 47 | 48 | return beam_seq, beam_seq_logprobs, beam_logprobs_sum, new_state 49 | 50 | def __beam_search(self, objects_feats, action_feats, caption_feats, objects_pending, action_pending, caption_pending, state): 51 | beam_size = self.beam_size 52 | device = caption_feats.device if caption_feats is not None else objects_feats.device 53 | 54 | beam_seq = torch.LongTensor(beam_size, self.max_caption_len).fill_(self.eos_idx) 55 | beam_seq_logprobs = torch.FloatTensor(beam_size, self.max_caption_len).zero_() 56 | beam_logprobs_sum = torch.zeros(beam_size) 57 | ret = [] 58 | 59 | it = torch.LongTensor(beam_size).fill_(self.sos_idx).to(device) 60 | it_embed = self.embedding(it) 61 | output_prob, state = self.forward_decoder(objects_feats, action_feats, caption_feats, objects_pending, action_pending, caption_pending, it_embed, state) 62 | logprob = output_prob 63 | 64 | for t in range(self.max_caption_len): 65 | # suppress UNK tokens in the decoding. 
So the probs of 'UNK' are extremely low 66 | logprob[:, self.unk_idx] = logprob[:, self.unk_idx] - 1000.0 67 | beam_seq, beam_seq_logprobs, beam_logprobs_sum, state = self.__beam_step(t=t, 68 | logprobs=logprob, 69 | beam_seq=beam_seq, 70 | beam_seq_logprobs=beam_seq_logprobs, 71 | beam_logprobs_sum=beam_logprobs_sum, 72 | state=state) 73 | 74 | for j in range(beam_size): 75 | if beam_seq[j, t] == self.eos_idx or t == self.max_caption_len - 1: 76 | final_beam = { 77 | 'seq': beam_seq[j, :].clone(), 78 | 'seq_logprob': beam_seq_logprobs[j, :].clone(), 79 | 'sum_logprob': beam_logprobs_sum[j].clone() 80 | } 81 | ret.append(final_beam) 82 | beam_logprobs_sum[j] = -1000.0 83 | 84 | it = beam_seq[:, t].to(device) 85 | it_embed = self.embedding(it).to(device) 86 | output_prob, state = self.forward_decoder(objects_feats, action_feats, caption_feats, objects_pending, action_pending, caption_pending, it_embed, state) 87 | logprob = output_prob 88 | 89 | ret = sorted(ret, key=lambda x: -x['sum_logprob'])[:beam_size] 90 | return ret 91 | 92 | def sample_beam(self, objects_feats, vps_feats, caption_feats, objects_semantics, action_semantics, caption_semantics): 93 | beam_size = self.beam_size 94 | batch_size = caption_feats.shape[0] if caption_feats is not None else objects_feats.shape[0] 95 | hidden_dim = caption_feats.shape[-1] if caption_feats is not None else objects_feats.shape[-1] 96 | device = caption_feats.device if caption_feats is not None else objects_feats.device 97 | 98 | seq = torch.LongTensor(batch_size, self.max_caption_len).fill_(self.eos_idx) 99 | seq_probabilities = torch.FloatTensor(batch_size, self.max_caption_len) 100 | done_beam = [[] for _ in range(batch_size)] 101 | 102 | for i in range(batch_size): 103 | single_objects_feats = objects_feats[i, ...][None, ...] if objects_feats is not None else None # (1, sample_numb, obj_per_frame, hidden_dim) 104 | single_vps_feats = vps_feats[i, ...][None, ...] if vps_feats is not None else None # (1, sample_numb, hidden_dim) 105 | single_caption_feats = caption_feats[i, ...][None, ...] if caption_feats is not None else None # (1, sample_numb, hidden_dim) 106 | single_objects_semantics = objects_semantics[i, ...][None, ...] if objects_semantics is not None else None # (1, max_objects, word_dim) 107 | single_action_semantics = action_semantics[i, ...][None, ...] if action_semantics is not None else None # (1, semantics_dim) 108 | single_caption_semantics = caption_semantics[i, ...][None, ...] 
if caption_semantics is not None else None # (1, semantics_dim) 109 | # print('====={}'.format(single_objects_semantics.shape)) 110 | 111 | single_objects_feats = single_objects_feats.repeat(beam_size, 1, 1) if single_objects_feats is not None else None # (beam_size, max_objects, hidden_dim) 112 | single_vps_feats = single_vps_feats.repeat(beam_size, 1, 1) if single_vps_feats is not None else None # (beam_size, sample_numb, hidden_dim) 113 | single_caption_feats = single_caption_feats.repeat(beam_size, 1, 1) if single_caption_feats is not None else None # (beam_size, sample_numb, hidden_dim) 114 | single_objects_semantics = single_objects_semantics.repeat(beam_size, 1, 1) if single_objects_semantics is not None else None # (beam_size, max_objects, word_dim) 115 | single_action_semantics = single_action_semantics.repeat(beam_size, 1) if single_action_semantics is not None else None # (beam_size, semantics_dim) 116 | single_caption_semantics = single_caption_semantics.repeat(beam_size, 1) if single_caption_semantics is not None else None # (beam_size, semantics_dim) 117 | 118 | state = self.get_rnn_init_hidden(beam_size, hidden_dim, device) 119 | 120 | done_beam[i] = self.__beam_search(single_objects_feats, single_vps_feats, 121 | single_caption_feats, single_objects_semantics, 122 | single_action_semantics, single_caption_semantics, 123 | state) 124 | seq[i, ...] = done_beam[i][0]['seq'] 125 | seq_probabilities[i, ...] = done_beam[i][0]['seq_logprob'] 126 | 127 | return seq, seq_probabilities 128 | 129 | def sample(self, objects, object_masks, feature2ds, feature3ds, is_sample_max=True): 130 | beam_size = self.beam_size 131 | temperature = self.temperature 132 | batch_size = feature2ds.shape[0] 133 | device = feature2ds.device 134 | 135 | objects_feats, action_feats, caption_feats, \ 136 | objects_semantics, action_semantics, caption_semantics = self.forward_encoder(objects, object_masks, feature2ds, feature3ds) 137 | 138 | if beam_size > 1: 139 | return self.sample_beam(objects_feats, action_feats, caption_feats, 140 | objects_semantics, action_semantics, caption_semantics) 141 | 142 | state = self.get_rnn_init_hidden(batch_size, device) 143 | seq, seq_probabilities = [], [] 144 | 145 | for t in range(self.max_caption_len): 146 | if t == 0: 147 | it = objects_feats.new(batch_size).fill_(self.sos_idx).long() 148 | elif is_sample_max: 149 | sampleLogprobs, it = torch.max(log_probabilities.detach(), 1) 150 | it = it.view(-1).long() 151 | else: 152 | prev_probabilities = torch.exp(torch.div(log_probabilities.detach(), temperature)) 153 | it = torch.multinomial(prev_probabilities, 1) 154 | sampleLogprobs = log_probabilities.gather(1, it) 155 | it = it.view(-1).long() 156 | 157 | it_embed = self.embedding(it) 158 | 159 | if t >= 1: 160 | if t == 1: 161 | unfinished = it > 0 162 | else: 163 | unfinished = unfinished * (it > 0) 164 | # if unfinished.sum() == 0: break 165 | it = it * unfinished.type_as(it) 166 | seq.append(it) 167 | seq_probabilities.append(sampleLogprobs.view(-1)) 168 | 169 | it_embed = it_embed.to(device) 170 | log_probabilities, state = self.forward_decoder(objects_feats, action_feats, caption_feats, 171 | objects_semantics, action_semantics, 172 | caption_semantics, it_embed, state) 173 | 174 | seq.append(it.new(batch_size).long().fill_(self.eos_idx)) 175 | seq_probabilities.append(sampleLogprobs.view(-1)) 176 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seq_probabilities], 1) 177 | 
-------------------------------------------------------------------------------- /models/caption_models/hierarchical_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.caption_models.caption_module import CaptionModule 4 | 5 | 6 | class HierarchicalModel(CaptionModule): 7 | def __init__(self, entity_level: nn.Module, predicate_level: nn.Module, sentence_level: nn.Module, 8 | decoder: nn.Module, word_embedding_weights, max_caption_len: int, beam_size: int, pad_idx=0, 9 | temperature=1, eos_idx=0, sos_idx=-1, unk_idx=-1): 10 | """ 11 | Args: 12 | entity_level: for encoding objects information. 13 | predicate_level: for encoding action information. 14 | sentence_level: for encoding the whole video information. 15 | decoder: for generating words. 16 | word_embedding_weights: pretrained word embedding weight. 17 | max_caption_len: generated sentences are no longer than max_caption_len. 18 | pad_idx: corresponding index of ''. 19 | """ 20 | super(HierarchicalModel, self).__init__(beam_size=beam_size) 21 | self.entity_level = entity_level 22 | self.predicate_level = predicate_level 23 | self.sentence_level = sentence_level 24 | self.decoder = decoder 25 | self.max_caption_len = max_caption_len 26 | self.temperature = temperature 27 | self.eos_idx = eos_idx 28 | self.sos_idx = sos_idx 29 | self.unk_idx = unk_idx 30 | self.num_layers = decoder.num_layers 31 | self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(word_embedding_weights), 32 | freeze=False, padding_idx=pad_idx) 33 | 34 | def get_rnn_init_hidden(self, bsz, hidden_size, device): 35 | # (hidden_state, cell_state) 36 | return (torch.zeros(self.num_layers, bsz, hidden_size).to(device), 37 | torch.zeros(self.num_layers, bsz, hidden_size).to(device)) 38 | 39 | def forward_encoder(self, objects, objects_mask, feature2ds, feature3ds): 40 | """ 41 | 42 | Args: 43 | objects: (bsz, max_objects_per_video, object_dim) 44 | objects_mask: (bsz, max_objects_per_video) 45 | feature2ds: (bsz, sample_numb, feature2d_dim) 46 | feature3ds: (bsz, sample_numb, feature3d_dim) 47 | 48 | Returns: 49 | objects_feats: (bsz, max_objects, hidden_dim) 50 | action_feats: (bsz, sample_numb, hidden_dim) 51 | video_feats: (bsz, sample_numb, hidden_dim) 52 | 53 | objects_semantics: (bsz, max_objects, word_dim) 54 | action_semantics: (bsz, semantics_dim) 55 | video_semantics: (bsz, semantics_dim) 56 | """ 57 | objects_feats, objects_semantics = self.entity_level(feature2ds, feature3ds, objects, objects_mask) 58 | action_feats, action_semantics = self.predicate_level(feature3ds, objects_feats, objects_mask) 59 | video_feats, video_semantics = self.sentence_level(feature2ds, action_feats, objects_feats, objects_mask) 60 | 61 | return objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics 62 | 63 | def forward_decoder(self, objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics, pre_embedding, pre_state): 64 | """ 65 | 66 | Args: 67 | objects_feats: (bsz, max_objects, hidden_dim) 68 | action_feats: (bsz, sample_numb, hidden_dim) 69 | video_feats: (bsz, sample_numb, hidden_dim) 70 | objects_semantics: (bsz, max_objects, word_dim) 71 | action_semantics: (bsz, semantics_dim) 72 | video_semantics: (bsz, semantics_dim) 73 | 74 | pre_embedding: (bsz, word_embed_dim) 75 | pre_state: (hidden_state, cell_state) 76 | 77 | Returns: 78 | output_prob: (bsz, n_vocab) 79 | current_state: (hidden_state, 
cell_state) 80 | """ 81 | output_prob, current_state = self.decoder(objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics, pre_embedding, pre_state) 82 | return output_prob, current_state 83 | 84 | def forward(self, objects_feats, objects_mask, feature2ds, feature3ds, numberic_captions): 85 | """ 86 | 87 | Args: 88 | numberic_captions: (bsz, max_caption_len) 89 | 90 | Returns: 91 | ret_seq: (bsz, max_caption_len, n_vocab) 92 | """ 93 | bsz, n_vocab = feature2ds.shape[0], self.decoder.n_vocab 94 | device = objects_feats.device 95 | objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics = self.forward_encoder(objects_feats, objects_mask, feature2ds, feature3ds) 96 | state = self.get_rnn_init_hidden(bsz=bsz, hidden_size=self.decoder.hidden_dim, device=device) 97 | outputs = [] 98 | 99 | for i in range(self.max_caption_len): 100 | if i > 0 and numberic_captions[:, i].sum() == 0: 101 | output_word = torch.zeros([bsz, n_vocab]).cuda() 102 | outputs.append(output_word) 103 | continue 104 | 105 | it = numberic_captions[:, i].clone() 106 | it_embeded = self.embedding(it) 107 | output_word, state = self.forward_decoder(objects_feats, action_feats, video_feats, 108 | objects_semantics, action_semantics, 109 | video_semantics, it_embeded, state) 110 | outputs.append(output_word) 111 | 112 | ret_seq = torch.stack(outputs, dim=1) 113 | return ret_seq, objects_semantics, action_semantics, video_semantics 114 | 115 | -------------------------------------------------------------------------------- /models/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Decoder(nn.Module): 5 | def __init__(self, semantics_dim, hidden_dim, num_layers, embed_dim, n_vocab, with_objects, with_action, with_video, with_objects_semantics, with_action_semantics, with_video_semantics): 6 | super(Decoder, self).__init__() 7 | self.n_vocab = n_vocab 8 | self.hidden_dim = hidden_dim 9 | self.num_layers = num_layers 10 | self.W = nn.Linear(hidden_dim, hidden_dim, bias=False) 11 | 12 | self.with_objects_semantics = with_objects_semantics 13 | self.with_action_semantics = with_action_semantics 14 | self.with_video_semantics = with_video_semantics 15 | 16 | total_visual_dim = 0 17 | total_semantics_dim = 0 18 | 19 | # objects visual features and corresponding semantics 20 | if with_objects: 21 | setattr(self, 'Uo', nn.Linear(hidden_dim, hidden_dim, bias=False)) 22 | setattr(self, 'bo', nn.Parameter(torch.ones(hidden_dim), requires_grad=True)) 23 | setattr(self, 'wo', nn.Linear(hidden_dim, 1, bias=False)) 24 | total_visual_dim += hidden_dim 25 | if with_objects_semantics: 26 | setattr(self, 'Uos', nn.Linear(semantics_dim, hidden_dim, bias=False)) 27 | setattr(self, 'bos', nn.Parameter(torch.ones(hidden_dim), requires_grad=True)) 28 | setattr(self, 'wos', nn.Linear(hidden_dim, 1, bias=False)) 29 | total_semantics_dim += semantics_dim 30 | 31 | # action visual features and corresponding semantics 32 | if with_action: 33 | setattr(self, 'Um', nn.Linear(hidden_dim, hidden_dim, bias=False)) 34 | setattr(self, 'bm', nn.Parameter(torch.ones(hidden_dim), requires_grad=True)) 35 | setattr(self, 'wm', nn.Linear(hidden_dim, 1, bias=False)) 36 | total_visual_dim += hidden_dim 37 | if with_action_semantics: 38 | total_semantics_dim += semantics_dim 39 | 40 | # video visual features and corresponding semantics 41 | if with_video: 42 | setattr(self, 'Uv', 
nn.Linear(hidden_dim, hidden_dim, bias=False)) 43 | setattr(self, 'bv', nn.Parameter(torch.ones(hidden_dim), requires_grad=True)) 44 | setattr(self, 'wv', nn.Linear(hidden_dim, 1, bias=False)) 45 | total_visual_dim += hidden_dim 46 | if with_video_semantics: 47 | total_semantics_dim += semantics_dim 48 | 49 | # fuse visual features together 50 | if total_visual_dim != hidden_dim: 51 | setattr(self, 'linear_visual_layer', nn.Linear(total_visual_dim, hidden_dim)) 52 | 53 | # fuse semantics features together 54 | if with_objects_semantics or with_action_semantics or with_video_semantics: 55 | setattr(self, 'linear_semantics_layer', nn.Linear(total_semantics_dim, hidden_dim)) 56 | 57 | with_semantics = with_objects_semantics or with_action_semantics or with_video_semantics 58 | self.lstm = nn.LSTM(input_size=hidden_dim * 2 + embed_dim if with_semantics else hidden_dim + embed_dim, 59 | hidden_size=hidden_dim, 60 | num_layers=num_layers) 61 | self.to_word = nn.Linear(hidden_dim, embed_dim) 62 | self.logit = nn.Linear(embed_dim, n_vocab) 63 | self.__init_weight() 64 | 65 | def __init_weight(self): 66 | init_range = 0.1 67 | self.logit.bias.data.fill_(0) 68 | self.logit.weight.data.uniform_(-init_range, init_range) 69 | 70 | def forward(self, objects, action, video, object_semantics, action_semantics, video_semantics, embed, last_states): 71 | last_hidden = last_states[0][0] # (bsz, hidden_dim) 72 | Wh = self.W(last_hidden) # (bsz, hidden_dim) 73 | U_obj = self.Uo(objects) if hasattr(self, 'Uo') else None # (bsz, max_objects, hidden_dim) 74 | U_objs = self.Uos(object_semantics) if hasattr(self, 'Uos') else None # (bsz, max_objects, emb_dim) 75 | U_action = self.Um(action) if hasattr(self, 'Um') else None # (bsz, sample_numb, hidden_dim) 76 | U_video = self.Uv(video) if hasattr(self, 'Uv') else None # (bsz, sample_numb, hidden_dim) 77 | 78 | # for visual features 79 | if U_obj is not None: 80 | attn_weights = self.wo(torch.tanh(Wh[:, None, :] + U_obj + self.bo)) 81 | attn_weights = attn_weights.softmax(dim=1) # (bsz, max_objects, 1) 82 | attn_objects = attn_weights * objects # (bsz, max_objects, hidden_dim) 83 | attn_objects = attn_objects.sum(dim=1) # (bsz, hidden_dim) 84 | else: 85 | attn_objects = None 86 | 87 | if U_action is not None: 88 | attn_weights = self.wm(torch.tanh(Wh[:, None, :] + U_action + self.bm)) 89 | attn_weights = attn_weights.softmax(dim=1) # (bsz, sample_numb, 1) 90 | attn_motion = attn_weights * action # (bsz, sample_numb, hidden_dim) 91 | attn_motion = attn_motion.sum(dim=1) # (bsz, hidden_dim) 92 | else: 93 | attn_motion = None 94 | 95 | if U_video is not None: 96 | attn_weights = self.wv(torch.tanh(Wh[:, None, :] + U_video + self.bv)) 97 | attn_weights = attn_weights.softmax(dim=1) # (bsz, sample_numb, 1) 98 | attn_video = attn_weights * video # (bsz, sample_numb, hidden_dim) 99 | attn_video = attn_video.sum(dim=1) # (bsz, hidden_dim) 100 | else: 101 | attn_video = None 102 | 103 | feats_list = [] 104 | if attn_video is not None: 105 | feats_list.append(attn_video) 106 | if attn_motion is not None: 107 | feats_list.append(attn_motion) 108 | if attn_objects is not None: 109 | feats_list.append(attn_objects) 110 | visual_feats = torch.cat(feats_list, dim=-1) 111 | visual_feats = self.linear_visual_layer(visual_feats) if hasattr(self, 'linear_visual_layer') else visual_feats 112 | 113 | # for semantic features 114 | semantics_list = [] 115 | if self.with_objects_semantics: 116 | attn_weights = self.wos(torch.tanh(Wh[:, None, :] + U_objs + self.bos)) 117 | attn_weights = 
attn_weights.softmax(dim=1) # (bsz, max_objects, 1) 118 | attn_objs = attn_weights * object_semantics # (bsz, max_objects, emb_dim) 119 | attn_objs = attn_objs.sum(dim=1) # (bsz, emb_dim) 120 | semantics_list.append(attn_objs) 121 | if self.with_action_semantics: semantics_list.append(action_semantics) 122 | if self.with_video_semantics: semantics_list.append(video_semantics) 123 | semantics_feats = torch.cat(semantics_list, dim=-1) if len(semantics_list) > 0 else None 124 | semantics_feats = self.linear_semantics_layer(semantics_feats) if semantics_feats is not None else None 125 | 126 | # in addition to the lastest generated word, fuse visual features and semantic features together 127 | if semantics_feats is not None: 128 | input_feats = torch.cat([visual_feats, semantics_feats, embed], dim=-1) # (bsz, hidden_dim * 2 + embed_dim) 129 | else: 130 | input_feats = torch.cat([visual_feats, embed], dim=-1) # (bsz, hidden_dim + embed_dim) 131 | output, states = self.lstm(input_feats[None, ...], last_states) 132 | output = output.squeeze(0) # (bsz, hidden_dim) 133 | output = self.to_word(output) # (bsz, embed_dim) 134 | output_prob = self.logit(output) # (bsz, n_vocab) 135 | output_prob = torch.log_softmax(output_prob, dim=1) # (bsz, n_vocab) 136 | 137 | return output_prob, states 138 | 139 | -------------------------------------------------------------------------------- /models/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | -------------------------------------------------------------------------------- /models/encoders/entity_level.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | 6 | class EntityLevelEncoder(nn.Module): 7 | def __init__(self, transformer, max_objects, object_dim, feature2d_dim, feature3d_dim, hidden_dim, word_dim): 8 | super(EntityLevelEncoder, self).__init__() 9 | self.max_objects = max_objects 10 | 11 | self.query_embed = nn.Embedding(max_objects, hidden_dim) 12 | self.input_proj = nn.Sequential( 13 | nn.Linear(object_dim, hidden_dim * 2), 14 | nn.BatchNorm1d(hidden_dim * 2), 15 | nn.ReLU(True), 16 | nn.Dropout(0.5), 17 | nn.Linear(hidden_dim * 2, hidden_dim) 18 | ) 19 | self.feature2d_proj = nn.Sequential( 20 | nn.Linear(feature2d_dim, 2 * hidden_dim), 21 | nn.BatchNorm1d(hidden_dim * 2), 22 | nn.ReLU(True), 23 | nn.Dropout(0.5), 24 | nn.Linear(hidden_dim * 2, hidden_dim) 25 | ) 26 | self.feature3d_proj = nn.Sequential( 27 | nn.Linear(feature3d_dim, 2 * hidden_dim), 28 | nn.BatchNorm1d(hidden_dim * 2), 29 | nn.ReLU(True), 30 | nn.Dropout(0.5), 31 | nn.Linear(hidden_dim * 2, hidden_dim) 32 | ) 33 | self.bilstm = nn.LSTM(input_size=hidden_dim * 2, hidden_size=hidden_dim//2, 34 | batch_first=True, bidirectional=True) 35 | self.transformer = transformer 36 | self.fc_layer = nn.Linear(hidden_dim, word_dim) 37 | 38 | def forward(self, features_2d: Tensor, features_3d: Tensor, objects: Tensor, objects_mask: Tensor): 39 | """ 40 | 41 | Args: 42 | features_2d: (bsz, sample_numb, feature2d_dim) 43 | features_3d: (bsz, sample_numb, feature3d_dim) 44 | objects: (bsz, max_objects_per_video, object_dim) 45 | objects_mask: (bsz, max_objects_per_video) 46 | 47 | Returns: 48 | salient_info: (bsz, max_objects, hidden_dim) 49 | object_pending: (bsz, max_objects, word_dim) 50 | """ 51 | device = objects.device 52 | bsz, sample_numb, max_objects_per_video = features_2d.shape[0], features_3d.shape[1], 
objects.shape[1] 53 | features_2d = self.feature2d_proj(features_2d.view(-1, features_2d.shape[-1])) 54 | features_2d = features_2d.view(bsz, sample_numb, -1).contiguous() # (bsz, sample_numb, hidden_dim) 55 | features_3d = self.feature3d_proj(features_3d.view(-1, features_3d.shape[-1])) 56 | features_3d = features_3d.view(bsz, sample_numb, -1).contiguous() # (bsz, sample_numb, hidden_dim) 57 | content_vectors = torch.cat([features_2d, features_3d], dim=-1) # (bsz, sample_numb, hidden_dim * 2) 58 | 59 | content_vectors, _ = self.bilstm(content_vectors) # (bsz, sample_numb, hidden_dim) 60 | content_vectors = torch.max(content_vectors, dim=1)[0] # (bsz, hidden_dim) 61 | 62 | tgt = content_vectors[None, ...].repeat(self.max_objects, 1, 1) # (max_objects, bsz, hidden_dim) 63 | objects = self.input_proj(objects.view(-1, objects.shape[-1])) 64 | objects = objects.view(bsz, max_objects_per_video, -1).contiguous() # (bsz, max_objects_per_video, hidden_dim) 65 | 66 | mask = objects_mask.to(device).bool() # (bsz, max_objects_per_video) 67 | salient_objects = self.transformer(objects, tgt, mask, self.query_embed.weight)[0][0] # (bsz, max_objects, hidden_dim) 68 | object_pending = self.fc_layer(salient_objects) 69 | return salient_objects, object_pending 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /models/encoders/predicate_level.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch import Tensor 4 | 5 | 6 | class PredicateLevelEncoder(nn.Module): 7 | def __init__(self, feature3d_dim, hidden_dim, semantics_dim, useless_objects): 8 | super(PredicateLevelEncoder, self).__init__() 9 | self.linear_layer = nn.Linear(feature3d_dim, hidden_dim) 10 | self.W = nn.Linear(hidden_dim, hidden_dim) 11 | self.U = nn.Linear(hidden_dim, hidden_dim) 12 | self.b = nn.Parameter(torch.ones(hidden_dim), requires_grad=True) 13 | self.w = nn.Linear(hidden_dim, 1) 14 | self.inf = 1e5 15 | self.useless_objects = useless_objects 16 | 17 | self.bilstm = nn.LSTM(input_size=hidden_dim + hidden_dim, 18 | hidden_size=hidden_dim // 2, 19 | num_layers=1, bidirectional=True, batch_first=True) 20 | self.fc_layer = nn.Linear(hidden_dim, semantics_dim) 21 | 22 | def forward(self, features3d: Tensor, objects: Tensor, objects_mask: Tensor): 23 | """ 24 | 25 | Args: 26 | features3d: (bsz, sample_numb, 3d_dim) 27 | objects: (bsz, max_objects, hidden_dim) 28 | objects_mask: (bsz, max_objects_per_video) 29 | 30 | Returns: 31 | action_features: (bsz, sample_numb, hidden_dim * 2) 32 | action_pending: (bsz, semantics_dim) 33 | """ 34 | sample_numb = features3d.shape[1] 35 | features3d = self.linear_layer(features3d) # (bsz, sample_numb, hidden_dim) 36 | Wf3d = self.W(features3d) # (bsz, sample_numb, hidden_dim) 37 | Uobjs = self.U(objects) # (bsz, max_objects, hidden_dim) 38 | 39 | attn_feat = Wf3d.unsqueeze(2) + Uobjs.unsqueeze(1) + self.b # (bsz, sample_numb, max_objects, hidden_dim) 40 | attn_weights = self.w(torch.tanh(attn_feat)) # (bsz, sample_numb, max_objects, 1) 41 | objects_mask = objects_mask[:, None, :, None].repeat(1, sample_numb, 1, 1) # (bsz, sample_numb, max_objects_per_video, 1) 42 | if self.useless_objects: 43 | attn_weights = attn_weights - objects_mask.float() * self.inf 44 | attn_weights = attn_weights.softmax(dim=-2) # (bsz, sample_numb, max_objects, 1) 45 | attn_objects = attn_weights * attn_feat 46 | attn_objects = attn_objects.sum(dim=-2) # (bsz, sample_numb, hidden_dim) 47 | 48 
| features = torch.cat([features3d, attn_objects], dim=-1) # (bsz, sample_numb, hidden_dim * 2) 49 | output, states = self.bilstm(features) # (bsz, sample_numb, hidden_dim) 50 | action = torch.max(output, dim=1)[0] # (bsz, hidden_dim) 51 | action_pending = self.fc_layer(action) # (bsz, semantics_dim) 52 | action_features = output # (bsz, sample_numb, hidden_dim) 53 | 54 | return action_features, action_pending 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /models/encoders/sentence_level.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | 4 | 5 | class SentenceLevelEncoder(nn.Module): 6 | def __init__(self, feature2d_dim, hidden_dim, semantics_dim, useless_objects): 7 | super(SentenceLevelEncoder, self).__init__() 8 | self.inf = 1e5 9 | self.useless_objects = useless_objects 10 | self.linear_2d = nn.Linear(feature2d_dim, hidden_dim) 11 | 12 | self.W = nn.Linear(hidden_dim, hidden_dim) 13 | 14 | self.Uo = nn.Linear(hidden_dim, hidden_dim) 15 | self.Um = nn.Linear(hidden_dim, hidden_dim) 16 | 17 | self.bo = nn.Parameter(torch.ones(hidden_dim), requires_grad=True) 18 | self.bm = nn.Parameter(torch.ones(hidden_dim), requires_grad=True) 19 | 20 | self.wo = nn.Linear(hidden_dim, 1) 21 | self.wm = nn.Linear(hidden_dim, 1) 22 | 23 | self.lstm = nn.LSTM(input_size=hidden_dim + hidden_dim + hidden_dim, 24 | hidden_size=hidden_dim // 2, 25 | num_layers=1, bidirectional=True, batch_first=True) 26 | self.fc_layer = nn.Linear(hidden_dim, semantics_dim) 27 | 28 | def forward(self, feature2ds: Tensor, vp_features: Tensor, object_features: Tensor, objects_mask: Tensor): 29 | """ 30 | 31 | Args: 32 | feature2ds: (bsz, sample_numb, hidden_dim) 33 | vp_features: (bsz, sample_numb, hidden_dim) 34 | object_features: (bsz, max_objects, hidden_dim) 35 | objects_mask: (bsz, max_objects_per_video) 36 | 37 | Returns: 38 | video_features: (bsz, sample_numb, hidden_dim) 39 | video_pending: (bsz, semantics_dim) 40 | """ 41 | sample_numb = feature2ds.shape[1] 42 | feature2ds = self.linear_2d(feature2ds) 43 | W_f2d = self.W(feature2ds) 44 | U_objs = self.Uo(object_features) 45 | U_motion = self.Um(vp_features) 46 | 47 | attn_feat = W_f2d.unsqueeze(2) + U_objs.unsqueeze(1) + self.bo # (bsz, sample_numb, max_objects, hidden_dim) 48 | attn_weights = self.wo(torch.tanh(attn_feat)) # (bsz, sample_numb, max_objects, 1) 49 | objects_mask = objects_mask[:, None, :, None].repeat(1, sample_numb, 1, 1) # (bsz, sample, max_objects_per_video, 1) 50 | if self.useless_objects: 51 | attn_weights = attn_weights - objects_mask.float() * self.inf 52 | attn_weights = attn_weights.softmax(dim=-2) # (bsz, sample_numb, max_objects, 1) 53 | attn_objects = attn_weights * attn_feat 54 | attn_objects = attn_objects.sum(dim=-2) # (bsz, sample_numb, hidden_dim) 55 | 56 | attn_feat = W_f2d.unsqueeze(2) + U_motion.unsqueeze(1) + self.bm # (bsz, sample_numb, sample_numb, hidden_dim) 57 | attn_weights = self.wm(torch.tanh(attn_feat)) # (bsz, sample_numb, sample_numb, 1) 58 | attn_weights = attn_weights.softmax(dim=-2) # (bsz, sample_numb, sample_numb, 1) 59 | attn_motion = attn_weights * attn_feat 60 | attn_motion = attn_motion.sum(dim=-2) # (bsz, sample_numb, hidden_dim) 61 | 62 | features = torch.cat([feature2ds, attn_motion, attn_objects], dim=-1) # (bsz, sample_numb, hidden_dim * 3) 63 | output, states = self.lstm(features) # (bsz, sample_numb, hidden_dim) 64 | video = torch.max(output, dim=1)[0] # (bsz, 
hidden_dim) 65 | video_pending = self.fc_layer(video) # (bsz, semantics_dim) 66 | video_features = output # (bsz, sample_numb, hidden_dim) 67 | 68 | return video_features, video_pending 69 | 70 | 71 | -------------------------------------------------------------------------------- /models/encoders/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from typing import Optional, List 4 | from torch import nn, Tensor 5 | import copy 6 | 7 | 8 | class Transformer(nn.Module): 9 | 10 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 11 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 12 | activation="relu", normalize_before=False, 13 | return_intermediate_dec=False): 14 | super().__init__() 15 | 16 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 17 | dropout, activation, normalize_before) 18 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 19 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 20 | 21 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 22 | dropout, activation, normalize_before) 23 | decoder_norm = nn.LayerNorm(d_model) 24 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 25 | return_intermediate=return_intermediate_dec) 26 | 27 | self._reset_parameters() 28 | 29 | self.d_model = d_model 30 | self.nhead = nhead 31 | 32 | def _reset_parameters(self): 33 | for p in self.parameters(): 34 | if p.dim() > 1: 35 | nn.init.xavier_uniform_(p) 36 | 37 | def forward(self, src, tgt, mask, query_embed, pos_embed=None): 38 | bs, detected_num, c = src.shape 39 | src = src.permute(1, 0, 2) # (max_objects_per_video, bsz, object_feats_dim) 40 | # pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 41 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 42 | mask = mask.flatten(1) 43 | 44 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 45 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 46 | pos=pos_embed, query_pos=query_embed) 47 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, detected_num) 48 | 49 | 50 | class TransformerEncoder(nn.Module): 51 | 52 | def __init__(self, encoder_layer, num_layers, norm=None): 53 | super().__init__() 54 | self.layers = _get_clones(encoder_layer, num_layers) 55 | self.num_layers = num_layers 56 | self.norm = norm 57 | 58 | def forward(self, src, 59 | mask: Optional[Tensor] = None, 60 | src_key_padding_mask: Optional[Tensor] = None, 61 | pos: Optional[Tensor] = None): 62 | output = src 63 | 64 | for layer in self.layers: 65 | output = layer(output, src_mask=mask, 66 | src_key_padding_mask=src_key_padding_mask, pos=pos) 67 | 68 | if self.norm is not None: 69 | output = self.norm(output) 70 | 71 | return output 72 | 73 | 74 | class TransformerDecoder(nn.Module): 75 | 76 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 77 | super().__init__() 78 | self.layers = _get_clones(decoder_layer, num_layers) 79 | self.num_layers = num_layers 80 | self.norm = norm 81 | self.return_intermediate = return_intermediate 82 | 83 | def forward(self, tgt, memory, 84 | tgt_mask: Optional[Tensor] = None, 85 | memory_mask: Optional[Tensor] = None, 86 | tgt_key_padding_mask: Optional[Tensor] = None, 87 | memory_key_padding_mask: Optional[Tensor] = None, 88 | pos: Optional[Tensor] = None, 89 | query_pos: Optional[Tensor] = None): 90 | 
output = tgt 91 | 92 | intermediate = [] 93 | 94 | for layer in self.layers: 95 | output = layer(output, memory, tgt_mask=tgt_mask, 96 | memory_mask=memory_mask, 97 | tgt_key_padding_mask=tgt_key_padding_mask, 98 | memory_key_padding_mask=memory_key_padding_mask, 99 | pos=pos, query_pos=query_pos) 100 | if self.return_intermediate: 101 | intermediate.append(self.norm(output)) 102 | 103 | if self.norm is not None: 104 | output = self.norm(output) 105 | if self.return_intermediate: 106 | intermediate.pop() 107 | intermediate.append(output) 108 | 109 | if self.return_intermediate: 110 | return torch.stack(intermediate) 111 | 112 | return output.unsqueeze(0) 113 | 114 | 115 | class TransformerEncoderLayer(nn.Module): 116 | 117 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 118 | activation="relu", normalize_before=False): 119 | super().__init__() 120 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 121 | # Implementation of Feedforward model 122 | self.linear1 = nn.Linear(d_model, dim_feedforward) 123 | self.dropout = nn.Dropout(dropout) 124 | self.linear2 = nn.Linear(dim_feedforward, d_model) 125 | 126 | self.norm1 = nn.LayerNorm(d_model) 127 | self.norm2 = nn.LayerNorm(d_model) 128 | self.dropout1 = nn.Dropout(dropout) 129 | self.dropout2 = nn.Dropout(dropout) 130 | 131 | self.activation = _get_activation_fn(activation) 132 | self.normalize_before = normalize_before 133 | 134 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 135 | return tensor if pos is None else tensor + pos 136 | 137 | def forward_post(self, 138 | src, 139 | src_mask: Optional[Tensor] = None, 140 | src_key_padding_mask: Optional[Tensor] = None, 141 | pos: Optional[Tensor] = None): 142 | q = k = self.with_pos_embed(src, pos) 143 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 144 | key_padding_mask=src_key_padding_mask)[0] 145 | src = src + self.dropout1(src2) 146 | src = self.norm1(src) 147 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 148 | src = src + self.dropout2(src2) 149 | src = self.norm2(src) 150 | return src 151 | 152 | def forward_pre(self, src, 153 | src_mask: Optional[Tensor] = None, 154 | src_key_padding_mask: Optional[Tensor] = None, 155 | pos: Optional[Tensor] = None): 156 | src2 = self.norm1(src) 157 | q = k = self.with_pos_embed(src2, pos) 158 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 159 | key_padding_mask=src_key_padding_mask)[0] 160 | src = src + self.dropout1(src2) 161 | src2 = self.norm2(src) 162 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 163 | src = src + self.dropout2(src2) 164 | return src 165 | 166 | def forward(self, src, 167 | src_mask: Optional[Tensor] = None, 168 | src_key_padding_mask: Optional[Tensor] = None, 169 | pos: Optional[Tensor] = None): 170 | if self.normalize_before: 171 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 172 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 173 | 174 | 175 | class TransformerDecoderLayer(nn.Module): 176 | 177 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 178 | activation="relu", normalize_before=False): 179 | super().__init__() 180 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 181 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 182 | # Implementation of Feedforward model 183 | self.linear1 = nn.Linear(d_model, dim_feedforward) 184 | self.dropout = nn.Dropout(dropout) 185 | 
self.linear2 = nn.Linear(dim_feedforward, d_model) 186 | 187 | self.norm1 = nn.LayerNorm(d_model) 188 | self.norm2 = nn.LayerNorm(d_model) 189 | self.norm3 = nn.LayerNorm(d_model) 190 | self.dropout1 = nn.Dropout(dropout) 191 | self.dropout2 = nn.Dropout(dropout) 192 | self.dropout3 = nn.Dropout(dropout) 193 | 194 | self.activation = _get_activation_fn(activation) 195 | self.normalize_before = normalize_before 196 | 197 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 198 | return tensor if pos is None else tensor + pos 199 | 200 | def forward_post(self, tgt, memory, 201 | tgt_mask: Optional[Tensor] = None, 202 | memory_mask: Optional[Tensor] = None, 203 | tgt_key_padding_mask: Optional[Tensor] = None, 204 | memory_key_padding_mask: Optional[Tensor] = None, 205 | pos: Optional[Tensor] = None, 206 | query_pos: Optional[Tensor] = None): 207 | q = k = self.with_pos_embed(tgt, query_pos) 208 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 209 | key_padding_mask=tgt_key_padding_mask)[0] 210 | tgt = tgt + self.dropout1(tgt2) 211 | tgt = self.norm1(tgt) 212 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 213 | key=self.with_pos_embed(memory, pos), 214 | value=memory, attn_mask=memory_mask, 215 | key_padding_mask=memory_key_padding_mask)[0] 216 | tgt = tgt + self.dropout2(tgt2) 217 | tgt = self.norm2(tgt) 218 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 219 | tgt = tgt + self.dropout3(tgt2) 220 | tgt = self.norm3(tgt) 221 | return tgt 222 | 223 | def forward_pre(self, tgt, memory, 224 | tgt_mask: Optional[Tensor] = None, 225 | memory_mask: Optional[Tensor] = None, 226 | tgt_key_padding_mask: Optional[Tensor] = None, 227 | memory_key_padding_mask: Optional[Tensor] = None, 228 | pos: Optional[Tensor] = None, 229 | query_pos: Optional[Tensor] = None): 230 | tgt2 = self.norm1(tgt) 231 | q = k = self.with_pos_embed(tgt2, query_pos) 232 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 233 | key_padding_mask=tgt_key_padding_mask)[0] 234 | tgt = tgt + self.dropout1(tgt2) 235 | tgt2 = self.norm2(tgt) 236 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 237 | key=self.with_pos_embed(memory, pos), 238 | value=memory, attn_mask=memory_mask, 239 | key_padding_mask=memory_key_padding_mask)[0] 240 | tgt = tgt + self.dropout2(tgt2) 241 | tgt2 = self.norm3(tgt) 242 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 243 | tgt = tgt + self.dropout3(tgt2) 244 | return tgt 245 | 246 | def forward(self, tgt, memory, 247 | tgt_mask: Optional[Tensor] = None, 248 | memory_mask: Optional[Tensor] = None, 249 | tgt_key_padding_mask: Optional[Tensor] = None, 250 | memory_key_padding_mask: Optional[Tensor] = None, 251 | pos: Optional[Tensor] = None, 252 | query_pos: Optional[Tensor] = None): 253 | if self.normalize_before: 254 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 255 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 256 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 257 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 258 | 259 | 260 | def _get_clones(module, N): 261 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 262 | 263 | 264 | def build_transformer(args): 265 | return Transformer( 266 | d_model=args.hidden_dim, 267 | dropout=args.dropout, 268 | nhead=args.nheads, 269 | dim_feedforward=args.dim_feedforward, 270 | num_encoder_layers=args.enc_layers, 271 | num_decoder_layers=args.dec_layers, 272 | 
normalize_before=args.pre_norm, 273 | return_intermediate_dec=True, 274 | ) 275 | 276 | 277 | def _get_activation_fn(activation): 278 | """Return an activation function given a string""" 279 | if activation == "relu": 280 | return F.relu 281 | if activation == "gelu": 282 | return F.gelu 283 | if activation == "glu": 284 | return F.glu 285 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 286 | 287 | 288 | -------------------------------------------------------------------------------- /models/hungary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, Tensor 3 | import numpy as np 4 | from scipy.optimize import linear_sum_assignment 5 | 6 | class HungarianMatcher(nn.Module): 7 | """This class computes an assignment between the targets and the predictions of the network 8 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 9 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 10 | while the others are un-matched (and thus treated as non-objects). 11 | """ 12 | 13 | def __init__(self): 14 | super(HungarianMatcher, self).__init__() 15 | self.eps = 1e-6 16 | 17 | @torch.no_grad() 18 | def forward(self, salient_objects: Tensor, nouns_dict_list: list): 19 | """ Performs the matching 20 | Args: 21 | salient_objects: (bsz, max_objects, word_dim) 22 | nouns_dict_list: List[{'vec': nouns_vec, 'nouns': nouns}, ...] 23 | Returns: 24 | A list of size batch_size, containing tuples of (index_i, index_j) where: 25 | - index_i is the indices of the selected predictions (in order) 26 | - index_j is the indices of the corresponding selected targets (in order) 27 | For each batch element, it holds: 28 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 29 | """ 30 | bsz, max_objects = salient_objects.shape[:2] 31 | device = salient_objects.device 32 | sizes = [len(item['nouns']) for item in nouns_dict_list] 33 | nouns_semantics = torch.cat([item['vec'][:len(item['nouns'])] for item in nouns_dict_list]).to(device) # (\sigma nouns, word_dim) 34 | nouns_length = torch.norm(nouns_semantics, dim=-1, keepdim=True) # (\sigma nouns, 1) 35 | salient_objects = salient_objects.flatten(0, 1) # (bsz * max_objects, word_dim) 36 | salient_length = torch.norm(salient_objects, dim=-1, keepdim=True) # (bsz * max_objects, 1) 37 | matrix_length = salient_length * nouns_length.permute([1, 0]) + self.eps # (bsz * max_objects, \sigma nouns) 38 | 39 | 40 | cos_matrix = torch.mm(salient_objects, nouns_semantics.permute([1, 0])) # (bsz * max_objects, \sigma nouns) 41 | cos_matrix = -cos_matrix / matrix_length # (bsz * max_objects, \sigma nouns) 42 | cos_matrix = cos_matrix.view([bsz, max_objects, -1]) # (bsz, max_objects, \sigma nouns) 43 | indices = [linear_sum_assignment(c[i].detach().cpu().numpy()) for i, c in enumerate(cos_matrix.split(sizes, -1))] 44 | 45 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 46 | 47 | -------------------------------------------------------------------------------- /scripts/split_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import h5py 4 | import argparse 5 | 6 | if __name__ == '__main__': 7 | parser = argparse.ArgumentParser() 8 | parser.add_argument('--dataset_name', type=str, default='-1', help='dataset name (MSVD | MSRVTT)') 9 | 
parser.add_argument('--split_dir', type=str, default='-1', help='dir contains split_list') 10 | parser.add_argument('--data_path', type=str, default='-1', help='the path of unsplited data') 11 | parser.add_argument('--data_name', type=str, default='-1', help='name of unsplited data') 12 | parser.add_argument('--target_dir', type=str, default='-1', help='where do you want to place your data') 13 | parser.add_argument('--split_objects', action='store_true') 14 | 15 | args = parser.parse_args() 16 | dataset_name = args.dataset_name 17 | split_dir = args.split_dir 18 | data_path = args.data_path 19 | data_name = args.data_name 20 | target_dir = args.target_dir 21 | 22 | assert dataset_name != '-1', 'Please set dataset_name!' 23 | assert split_dir != '-1', 'Please set split_dir!' 24 | assert data_path != '-1', 'Please set data_path!' 25 | assert data_name != '-1', 'Please set data_name!' 26 | assert target_dir != '-1', 'Please set target_dir!' 27 | 28 | splits_list_path_tpl = '{dataset_name}_{part}_list.pkl'.format(dataset_name=dataset_name, part='{}') 29 | dataset_split_path_tpl = '{}_{}_{}.hdf5'.format(dataset_name, data_name, '{}') 30 | split_part_list = ['train', 'valid', 'test'] 31 | print('[split begin]', '=' * 20) 32 | with h5py.File(data_path, 'r') as f: 33 | for split in split_part_list: 34 | cur_vid_list_path = os.path.join(split_dir, splits_list_path_tpl.format(split)) 35 | dataset_split_save_path = os.path.join(target_dir, dataset_split_path_tpl.format(split)) 36 | with open(cur_vid_list_path, 'rb') as v: 37 | vid_list = pickle.load(v) 38 | with h5py.File(dataset_split_save_path, 'w') as t: 39 | for vid in vid_list: 40 | if args.split_objects: 41 | t[vid] = f[vid][()] 42 | else: 43 | temp_group = t.create_group(vid) 44 | temp_group['feats'] = f[vid]['feats'][()] 45 | # temp_group['bboxes'] = f[vid]['bboxes'][()] 46 | # temp_group['kinds'] = f[vid]['kinds'][()] 47 | print('[split end]', '=' * 20) 48 | 49 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | 4 | from configs.settings import TotalConfigs 5 | from eval import eval_fn 6 | 7 | 8 | def test_fn(cfgs: TotalConfigs, model, loader, device): 9 | print('##############n_vocab is {}##############'.format(cfgs.decoder.n_vocab)) 10 | with open(cfgs.data.idx2word_path, 'rb') as f: 11 | idx2word = pickle.load(f) 12 | with open(cfgs.data.vid2groundtruth_path, 'rb') as f: 13 | vid2groundtruth = pickle.load(f) 14 | scores = eval_fn(model=model, loader=loader, device=device, 15 | idx2word=idx2word, save_on_disk=True, cfgs=cfgs, 16 | vid2groundtruth=vid2groundtruth) 17 | print('===================Testing is finished====================') 18 | 19 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | import pickle 6 | 7 | from utils.loss import LanguageModelCriterion, CosineCriterion, SoftCriterion 8 | from eval import eval_fn 9 | from configs.settings import TotalConfigs 10 | from models.hungary import HungarianMatcher 11 | 12 | 13 | def _get_src_permutation_idx(indices): 14 | # permute predictions following indices 15 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 16 | src_idx = torch.cat([src for (src, _) in indices]) 17 | 
return batch_idx, src_idx 18 | 19 | 20 | def _get_tgt_permutation_idx(indices): 21 | # permute targets following indices 22 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 23 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 24 | return batch_idx, tgt_idx 25 | 26 | 27 | def train_fn(cfgs: TotalConfigs, model_name: str, model: nn.Module, matcher: HungarianMatcher, train_loader, valid_loader, device): 28 | optimizer = optim.Adam(model.parameters(), lr=cfgs.train.learning_rate) 29 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfgs.train.max_epochs, eta_min=0, last_epoch=-1) 30 | language_loss = LanguageModelCriterion() 31 | soft_loss = SoftCriterion() 32 | cos_loss = CosineCriterion() 33 | language_loss.to(device) 34 | best_score, cnt = None, 0 35 | loss_store, loss_entity, loss_predicate, loss_sentence, loss_xe, loss_soft_target = [], [], [], [], [], [] 36 | 37 | with open(cfgs.data.idx2word_path, 'rb') as f: 38 | idx2word = pickle.load(f) 39 | with open(cfgs.data.vid2groundtruth_path, 'rb') as f: 40 | vid2groundtruth = pickle.load(f) 41 | 42 | print('===================Training begin====================') 43 | print(cfgs.train.save_checkpoints_path) 44 | for epoch in range(cfgs.train.max_epochs): 45 | print('\n{}[EPOCH {}]{}'.format('='*15, epoch, '='*15)) 46 | 47 | for i, (feature2ds, feature3ds, objects, object_masks, vp_semantics, caption_semantics, numberic_caps, masks, captions, nouns, vids, vocab_ids, vocab_probs, fillmasks) in enumerate(train_loader): 48 | cnt += 1 49 | feature2ds = feature2ds.to(device) 50 | feature3ds = feature3ds.to(device) 51 | objects = objects.to(device) 52 | object_masks = object_masks.to(device) 53 | vp_semantics = vp_semantics.to(device) 54 | caption_semantics = caption_semantics.to(device) 55 | numberic_caps = numberic_caps.to(device) 56 | masks = masks.to(device) 57 | vocab_ids = vocab_ids.to(device) if vocab_ids is not None else None 58 | vocab_probs = vocab_probs.to(device) if vocab_probs is not None else None 59 | fillmasks = fillmasks.to(device) if fillmasks is not None else None 60 | 61 | # bsz, sample_numb, obj_numb, obj_dim = objects.shape 62 | # objects = objects.reshape([bsz, sample_numb * obj_numb, obj_dim]) 63 | 64 | optimizer.zero_grad() 65 | 66 | preds, objects_pending, action_pending, video_pending = model(objects, object_masks, feature2ds, feature3ds, numberic_caps) 67 | xe_loss, s_loss, ent_loss, pred_loss, sent_loss = None, None, None, None, None 68 | 69 | # cross entropy loss 70 | loss_hard = language_loss(preds, numberic_caps, masks, cfgs.dict.eos_idx) 71 | loss = loss_hard 72 | xe_loss = loss_hard.detach().item() 73 | 74 | # soft loss 75 | if cfgs.train.lambda_soft > 0: 76 | loss_soft = soft_loss(preds, vocab_ids, vocab_probs, fillmasks) 77 | loss = loss + loss_soft * cfgs.train.lambda_soft 78 | s_loss = loss_soft.detach().item() 79 | 80 | # object module loss 81 | if cfgs.train.lambda_entity > 0: 82 | indices = matcher(objects_pending, nouns) 83 | src_idx = _get_src_permutation_idx(indices) 84 | objects = objects_pending[src_idx] 85 | targets = torch.cat([t['vec'][i] for t, (_, i) in zip(nouns, indices)], dim=0).to(device) 86 | if np.any(np.isnan(objects.detach().cpu().numpy())): 87 | raise RuntimeError 88 | object_loss = cos_loss(objects, targets) 89 | loss = loss + object_loss * cfgs.train.lambda_entity 90 | ent_loss = object_loss.detach().item() 91 | 92 | # action module loss 93 | if cfgs.train.lambda_predicate > 0: 94 | action_loss = cos_loss(action_pending, 
vp_semantics) 95 | loss = loss + action_loss * cfgs.train.lambda_predicate 96 | pred_loss = action_loss.detach().item() 97 | 98 | # video module loss 99 | if cfgs.train.lambda_sentence > 0: 100 | sent_loss = cos_loss(video_pending, caption_semantics) 101 | loss = loss + sent_loss * cfgs.train.lambda_sentence 102 | sent_loss = sent_loss.detach().item() 103 | 104 | loss.backward() 105 | loss_store.append(loss.detach().item()) 106 | loss_xe.append(xe_loss) 107 | loss_entity.append(ent_loss) 108 | loss_predicate.append(pred_loss) 109 | loss_sentence.append(sent_loss) 110 | loss_soft_target.append(s_loss) 111 | nn.utils.clip_grad_norm_(model.parameters(), cfgs.train.grad_clip) 112 | optimizer.step() 113 | 114 | if cnt % cfgs.train.visualize_every == 0: 115 | loss_store, loss_xe, loss_entity, loss_predicate, loss_sentence, loss_soft_target = \ 116 | loss_store[-10:], loss_xe[-10:], loss_entity[-10:], loss_predicate[-10:], loss_sentence[-10:], loss_soft_target[-10:] 117 | loss_value = np.array(loss_store).mean() 118 | xe_value = np.array(loss_xe).mean() if loss_xe[0] is not None else 0 119 | soft_value = np.array(loss_soft_target).mean() if loss_soft_target[0] is not None else 0 120 | entity_value = np.array(loss_entity).mean() if loss_entity[0] is not None else 0 121 | predicate_value = np.array(loss_predicate).mean() if loss_predicate[0] is not None else 0 122 | sentence_value = np.array(loss_sentence).mean() if loss_sentence[0] is not None else 0 123 | 124 | print('[EPOCH {};ITER {}]:loss[{:.3f}]=hard_loss[{:.3f}]*1+soft_loss[{:.3f}]*{:.2f}+entity[{:.3f}]*{:.2f}+predicate[{:.3f}]*{:.2f}+sentence[{:.3f}]*{:.2f}' 125 | .format(epoch, i, loss_value, xe_value, 126 | soft_value, cfgs.train.lambda_soft, 127 | entity_value, cfgs.train.lambda_entity, 128 | predicate_value, cfgs.train.lambda_predicate, 129 | sentence_value, cfgs.train.lambda_sentence)) 130 | 131 | if cnt % cfgs.train.save_checkpoints_every == 0: 132 | ckpt_path = cfgs.train.save_checkpoints_path 133 | scores = eval_fn(model=model, loader=valid_loader, device=device, 134 | idx2word=idx2word, save_on_disk=False, cfgs=cfgs, 135 | vid2groundtruth=vid2groundtruth) 136 | cider_score = scores['CIDEr'] 137 | if best_score is None or cider_score > best_score: 138 | best_score = cider_score 139 | torch.save(model.state_dict(), ckpt_path) 140 | print('=' * 10, 141 | '[EPOCH{epoch} iter{it}] :Best Cider is {bs}, Current Cider is {cs}'. 
142 | format(epoch=epoch, it=i, bs=best_score, cs=cider_score), 143 | '=' * 10) 144 | 145 | lr_scheduler.step() 146 | print('===================Training is finished====================') 147 | return model 148 | 149 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/utils/__init__.py -------------------------------------------------------------------------------- /utils/build_loaders.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from data.loader.data_loader import CaptionDataset, collate_fn_caption 3 | from configs.settings import TotalConfigs 4 | 5 | 6 | def build_loaders(cfgs: TotalConfigs, is_total=False): 7 | train_dataset = CaptionDataset(cfgs=cfgs, mode='train', save_on_disk=False, is_total=is_total) 8 | valid_dataset = CaptionDataset(cfgs=cfgs, mode='valid', save_on_disk=False, is_total=is_total) 9 | test_dataset = CaptionDataset(cfgs=cfgs, mode='test', save_on_disk=True, is_total=is_total) 10 | 11 | train_loader = DataLoader(dataset=train_dataset, batch_size=cfgs.bsz, shuffle=True, 12 | collate_fn=collate_fn_caption, num_workers=0) 13 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=cfgs.bsz, shuffle=True, 14 | collate_fn=collate_fn_caption, num_workers=0) 15 | test_loader = DataLoader(dataset=test_dataset, batch_size=cfgs.bsz, shuffle=False, 16 | collate_fn=collate_fn_caption, num_workers=0) 17 | 18 | return train_loader, valid_loader, test_loader 19 | 20 | 21 | def get_test_loader(cfgs: TotalConfigs, is_total=False): 22 | test_dataset = CaptionDataset(cfgs=cfgs, mode='test', save_on_disk=True, is_total=is_total) 23 | test_loader = DataLoader(dataset=test_dataset, batch_size=cfgs.bsz, 24 | shuffle=False, collate_fn=collate_fn_caption, 25 | num_workers=0) 26 | return test_loader 27 | 28 | 29 | def get_train_loader(cfgs: TotalConfigs, save_on_disk=True): 30 | train_dataset = CaptionDataset(cfgs=cfgs, mode='train', save_on_disk=save_on_disk, is_total=False) 31 | train_loader = DataLoader(dataset=train_dataset, batch_size=cfgs.bsz, 32 | shuffle=False, collate_fn=collate_fn_caption, 33 | num_workers=0) 34 | return train_loader 35 | 36 | def get_valid_loader(cfgs: TotalConfigs, is_total=False): 37 | valid_dataset = CaptionDataset(cfgs=cfgs, mode='valid', save_on_disk=True, is_total=is_total) 38 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=cfgs.bsz, 39 | shuffle=False, collate_fn=collate_fn_caption, 40 | num_workers=0) 41 | return valid_loader 42 | 43 | -------------------------------------------------------------------------------- /utils/build_model.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | from models.caption_models.hierarchical_model import HierarchicalModel 4 | from models.decoder import Decoder 5 | 6 | from models.encoders.transformer import Transformer 7 | from models.encoders.entity_level import EntityLevelEncoder 8 | from models.encoders.predicate_level import PredicateLevelEncoder 9 | from models.encoders.sentence_level import SentenceLevelEncoder 10 | 11 | from configs.settings import TotalConfigs 12 | 13 | 14 | def build_model(cfgs: TotalConfigs): 15 | model_name = cfgs.model_name 16 | embedding_weights_path = cfgs.data.embedding_weights_path 17 | max_caption_len = 
cfgs.test.max_caption_len 18 | temperature = cfgs.test.temperature 19 | beam_size = cfgs.test.beam_size 20 | pad_idx = cfgs.dict.eos_idx 21 | eos_idx = cfgs.dict.eos_idx 22 | sos_idx = cfgs.dict.sos_idx 23 | unk_idx = cfgs.dict.unk_idx 24 | with open(embedding_weights_path, 'rb') as f: 25 | embedding_weights = pickle.load(f) 26 | n_vocab = embedding_weights.shape[0] 27 | cfgs.decoder.n_vocab = n_vocab 28 | 29 | feature2d_dim = cfgs.encoder.backbone_2d_dim 30 | feature3d_dim = cfgs.encoder.backbone_3d_dim 31 | object_dim = cfgs.encoder.object_dim 32 | semantics_dim = cfgs.encoder.semantics_dim 33 | hidden_dim = cfgs.decoder.hidden_dim 34 | decoder_num_layers = cfgs.decoder.num_layers 35 | embed_dim = cfgs.data.word_dim 36 | max_objects = cfgs.encoder.max_objects 37 | 38 | nheads = cfgs.encoder.nheads 39 | trans_num_encoder_layers = cfgs.encoder.entity_encoder_layer 40 | trans_num_decoder_layers = cfgs.encoder.entity_encoder_layer 41 | dim_feedforward = cfgs.encoder.dim_feedforward 42 | transformer_activation = cfgs.encoder.transformer_activation 43 | d_model = cfgs.encoder.d_model 44 | trans_dropout = cfgs.encoder.trans_dropout 45 | 46 | if model_name == 'HMN': 47 | # encoders 48 | transformer = Transformer(d_model=d_model, nhead=nheads, 49 | num_encoder_layers=trans_num_encoder_layers, 50 | num_decoder_layers=trans_num_decoder_layers, 51 | dim_feedforward=dim_feedforward, 52 | dropout=trans_dropout, 53 | activation=transformer_activation) 54 | entity_level_encoder = EntityLevelEncoder(transformer=transformer, 55 | max_objects=max_objects, 56 | object_dim=object_dim, 57 | feature2d_dim=feature2d_dim, 58 | feature3d_dim=feature3d_dim, 59 | hidden_dim=hidden_dim, 60 | word_dim=semantics_dim) 61 | predicate_level_encoder = PredicateLevelEncoder(feature3d_dim=feature3d_dim, 62 | hidden_dim=hidden_dim, 63 | semantics_dim=semantics_dim, 64 | useless_objects=False) 65 | sentence_level_encoder = SentenceLevelEncoder(feature2d_dim=feature2d_dim, 66 | hidden_dim=hidden_dim, 67 | semantics_dim=semantics_dim, 68 | useless_objects=False) 69 | 70 | # decoder 71 | decoder = Decoder(semantics_dim=semantics_dim, hidden_dim=hidden_dim, 72 | num_layers=decoder_num_layers, embed_dim=embed_dim, n_vocab=n_vocab, 73 | with_objects=True, with_action=True, with_video=True, 74 | with_objects_semantics=True, 75 | with_action_semantics=True, 76 | with_video_semantics=True) 77 | 78 | else: 79 | raise NotImplementedError 80 | 81 | # HMN 82 | model = HierarchicalModel(entity_level=entity_level_encoder, 83 | predicate_level=predicate_level_encoder, 84 | sentence_level=sentence_level_encoder, 85 | decoder=decoder, 86 | word_embedding_weights=embedding_weights, 87 | max_caption_len=max_caption_len, 88 | beam_size=beam_size, pad_idx=pad_idx, 89 | temperature=temperature, 90 | eos_idx=eos_idx, 91 | sos_idx=sos_idx, 92 | unk_idx=unk_idx) 93 | 94 | return model 95 | 96 | 97 | -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def to_contiguous(tensor): 6 | if tensor.is_contiguous(): 7 | return tensor 8 | else: 9 | return tensor.contiguous() 10 | 11 | 12 | class LanguageModelCriterion(nn.Module): 13 | def __init__(self): 14 | super(LanguageModelCriterion, self).__init__() 15 | 16 | def forward(self, pred, target, mask, eos_idx): 17 | 18 | pred = to_contiguous(pred).view(-1, pred.size(-1)) 19 | target = torch.cat([target[:, 1:], target[:, 
0].unsqueeze(1).fill_(eos_idx)], dim=1) 20 | target = to_contiguous(target).view(-1, 1) 21 | mask = to_contiguous(mask).view(-1, 1) 22 | 23 | output = -1. * pred.gather(1, target) * mask 24 | output = torch.sum(output) / torch.sum(mask) 25 | return output.float() 26 | 27 | 28 | class SoftCriterion(nn.Module): 29 | def __init__(self): 30 | super(SoftCriterion, self).__init__() 31 | 32 | def forward(self, pred, idxs, soft_target, mask): 33 | topk = -1.0 * pred.gather(-1, idxs) * mask[..., None] 34 | output = soft_target * topk 35 | output = torch.sum(output) / torch.sum(mask) 36 | return output.float() 37 | 38 | 39 | class CosineCriterion(nn.Module): 40 | def __init__(self): 41 | super(CosineCriterion, self).__init__() 42 | self.eps = 1e-12 43 | 44 | def forward(self, pred, target): 45 | assert pred.shape == target.shape and pred.dim() == 2, \ 46 | 'expected pred.shape == target.shape, ' \ 47 | 'but got pred.shape == {} and target.shape == {}'.format(pred.shape, target.shape) 48 | pred_denom = torch.norm(pred, p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(pred) 49 | pred = pred / pred_denom 50 | target_denom = torch.norm(target, p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(target) 51 | target = target / target_denom 52 | 53 | ret = pred * target 54 | ret = 1.0 - ret.sum(dim=-1) 55 | ret = ret.sum() 56 | return ret 57 | 58 | 59 | 60 | --------------------------------------------------------------------------------