├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── __init__.py
│   └── settings.py
├── data
│   ├── __init__.py
│   └── loader
│       ├── __init__.py
│       └── data_loader.py
├── eval.py
├── figures
│   ├── Entity.png
│   ├── HMN.png
│   ├── MSRVTT-performance.png
│   ├── MSVD-performance.png
│   └── motivation.png
├── hmn.yaml
├── main.py
├── models
│   ├── __init__.py
│   ├── caption_models
│   │   ├── __init__.py
│   │   ├── caption_module.py
│   │   └── hierarchical_model.py
│   ├── decoder.py
│   ├── encoders
│   │   ├── __init__.py
│   │   ├── entity_level.py
│   │   ├── predicate_level.py
│   │   ├── sentence_level.py
│   │   └── transformer.py
│   └── hungary.py
├── scripts
│   └── split_data.py
├── test.py
├── train.py
└── utils
    ├── __init__.py
    ├── build_loaders.py
    ├── build_model.py
    └── loss.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # => data dir
2 | data/MSRVTT/
3 | data/MSVD/
4 |
5 | # => checkpoints
6 | checkpoints/
7 |
8 | # => results
9 | results/
10 |
11 | # => pycache
12 | */__pycache__
13 | __pycache__/
14 |
15 | # => eval metrics
16 | utils/coco_caption
17 |
18 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 MarcusNerva
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # HMN
2 | [[Paper]](https://arxiv.org/abs/2111.12476)
3 |
4 | Official code for **Hierarchical Modular Network for Video Captioning**.
5 |
6 | *Hanhua Ye, Guorong Li, Yuankai Qi, Shuhui Wang, Qingming Huang, Ming-Hsuan Yang*
7 |
8 | Accepted by CVPR 2022
9 |
10 |
11 |
12 |
Figure 1. Motivation
13 |
14 | Representation learning plays a crucial role in the video captioning task. Hierarchical Modular Network learns a discriminative video representation by bridging video content and linguistic captions at three levels:
15 |
16 | 1. Entity level, which highlights objects that are most likely to be mentioned in captions and is supervised by *entities* in ground-truth captions.
17 | 2. Predicate level, which learns the actions conditioned on highlighted objects and is supervised by the *predicate* in the ground-truth caption.
18 | 3. Sentence level, which learns the global video representation supervised by the whole ground-truth *sentence*.
19 |
20 | As there are a large number of objects in a video but only a few are mentioned in captions, we propose a novel entity module that learns to highlight these principal objects adaptively. Experimental results demonstrate that highlighting principal video objects improves performance significantly.
21 |
22 |
23 |
24 | ## Methodology
25 |
26 | As shown in Figure 2, our model follows the conventional **Encoder-Decoder** paradigm, where the proposed Hierarchical Modular Network (HMN) serves as the encoder. HMN consists of the entity, predicate, and sentence modules. These modules are designed to bridge video representations and linguistic semantics from three levels. Our model operates as follows. First, taking all detected objects as input, the entity module outputs the features of principal objects. The predicate module encodes actions by combining features of principal objects and the video motion. Next, the sentence module encodes a global representation for the entire video content considering the global context and features of previously generated objects and actions. Finally, all features are concatenated together and fed into the decoder to generate captions. Each module has its own input and linguistic supervision extracted from captions.
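
The data flow above can be summarized with a minimal, self-contained sketch. The linear layers and the mean-pooling below are illustrative stand-ins, not the actual modules under `models/encoders/`; the dimensions follow the MSR-VTT settings used later in this README:

```python
import torch
import torch.nn as nn

# Illustrative stand-ins only: the real entity/predicate/sentence modules live under
# models/encoders/ and are composed in HierarchicalModel.forward_encoder
# (models/caption_models/hierarchical_model.py). This sketch shows only the data flow.
bsz, sample_numb, max_objects = 2, 15, 9
dim_2d, dim_3d, dim_obj, hidden = 1536, 2048, 2048, 512

entity_module    = nn.Linear(dim_obj, hidden)               # objects -> principal object features
predicate_module = nn.Linear(dim_3d + hidden, hidden)       # motion + objects -> action features
sentence_module  = nn.Linear(dim_2d + 2 * hidden, hidden)   # context + actions + objects -> video feature

context = torch.randn(bsz, sample_numb, dim_2d)   # 2D CNN features (InceptionResNetV2)
motion  = torch.randn(bsz, sample_numb, dim_3d)   # 3D CNN features (C3D)
objects = torch.randn(bsz, max_objects, dim_obj)  # detected object features (Faster-RCNN)

obj_feats  = entity_module(objects)                                        # (bsz, max_objects, hidden)
obj_pooled = obj_feats.mean(dim=1, keepdim=True).expand(-1, sample_numb, -1)
act_feats  = predicate_module(torch.cat([motion, obj_pooled], dim=-1))     # (bsz, sample_numb, hidden)
vid_feats  = sentence_module(torch.cat([context, act_feats, obj_pooled], dim=-1))

# All three levels of features are concatenated and handed to the decoder.
decoder_context = torch.cat([obj_feats.mean(1), act_feats.mean(1), vid_feats.mean(1)], dim=-1)
print(decoder_context.shape)  # torch.Size([2, 1536])
```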
27 |
28 |
29 |
30 | Figure 2. Hierarchical Modular Network
31 |
32 | Figure 3 illustrates the main architecture of our entity module, which consists of a transformer encoder and transformer decoder. This design is motivated by [DETR](https://arxiv.org/abs/2005.12872), which utilizes a transformer encoder-decoder architecture to learn a fixed set of object queries to directly predict object bounding boxes for the object detection task. Instead of simply detecting objects, we aim to determine the important ones in the video.
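
A minimal sketch of this learned-query idea, using PyTorch's built-in transformer layers, is shown below. The dimensions, layer counts, and input handling are illustrative; the actual implementation lives in `models/encoders/entity_level.py`:

```python
import torch
import torch.nn as nn

# DETR-style learned queries (illustrative): a fixed set of query embeddings attends
# over the encoded object features and yields a fixed number of "principal object" slots.
d_model, nheads, num_queries, bsz, num_detected = 512, 8, 9, 2, 40

encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=2048, dropout=0.1),
    num_layers=3)
decoder = nn.TransformerDecoder(
    nn.TransformerDecoderLayer(d_model=d_model, nhead=nheads, dim_feedforward=2048, dropout=0.1),
    num_layers=3)
object_queries = nn.Embedding(num_queries, d_model)  # learned object queries

# (sequence, batch, d_model) layout, as expected by torch.nn transformers in PyTorch 1.4
object_feats = torch.randn(num_detected, bsz, d_model)      # detected objects, already projected
memory = encoder(object_feats)
tgt = object_queries.weight.unsqueeze(1).repeat(1, bsz, 1)  # (num_queries, bsz, d_model)
principal_objects = decoder(tgt, memory)                    # (num_queries, bsz, d_model)
print(principal_objects.shape)  # torch.Size([9, 2, 512])
```

In the full model, these query outputs are supervised by entities extracted from the ground-truth captions; `models/hungary.py` provides the `HungarianMatcher` that `main.py` passes to training for this alignment.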
33 |
34 |
35 |
36 | Figure 3. Main architecture of the entity module
37 |
38 |
39 |
40 | ## Usage
41 |
42 | Our proposed HMN is implemented with PyTorch.
43 |
44 | #### Environment
45 |
46 | - Python = 3.7
47 | - PyTorch = 1.4
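
The full conda environment is also provided in `hmn.yaml` (environment name `hmn_env`); it can be created with `conda env create -f hmn.yaml`.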
48 |
49 |
50 |
51 | #### 1. Installation
52 |
53 | - Clone this repo:
54 |
55 | ```
56 | git clone https://github.com/MarcusNerva/HMN.git
57 | cd HMN
58 | ```
59 |
60 | - Clone a Python 3 version of the coco_caption evaluation code into `utils/coco_caption` (it is imported by `eval.py` as `utils.coco_caption`)
61 |
62 |
63 |
64 | #### 2. Download datasets
65 |
66 | **MSR-VTT Dataset:**
67 |
68 | - Context features (2D CNN features) : [MSRVTT-InceptionResNetV2](https://1drv.ms/u/s!ArYBhHmSAbFOc20zPEg-aSP7_cI?e=fhT1lN)
69 | - Motion features (3D CNN features) : [MSRVTT-C3D](https://1drv.ms/u/s!ArYBhHmSAbFOdKU9iZgHFGFHCAE?e=H5DyOE)
70 | - Object features (Extracted by Faster-RCNN) : [MSRVTT-Faster-RCNN](https://1drv.ms/u/s!ArYBhHmSAbFOdVQnfilWp6_epv4?e=Am9OXT)
71 | - Linguistic supervision: [MSRVTT-Language](https://1drv.ms/u/s!ArYBhHmSAbFOe0dX-SBDdxJ9RHM?e=ZlNbBQ)
72 | - Splits: [MSRVTT-Splits](https://1drv.ms/u/s!ArYBhHmSAbFOgURAYMc4f2-TeI0U?e=tde70k)
73 |
74 | **MSVD Dataset:**
75 |
76 | - Context features (2D CNN features) : [MSVD-InceptionResNetV2](https://1drv.ms/u/s!ArYBhHmSAbFOeMT-jksQPhkzYHA?e=mO2DTu)
77 | - Motion features (3D CNN features) : [MSVD-C3D](https://1drv.ms/u/s!ArYBhHmSAbFOd8H6ciT2CYwqFaE?e=VeWdS8)
78 | - Object features (Extracted by Faster-RCNN) : [MSVD-Faster-RCNN](https://1drv.ms/u/s!ArYBhHmSAbFOef5wZTxndFlz7bQ?e=fBPFHG)
79 | - Linguistic supervision: [MSVD-Language](https://1drv.ms/u/s!ArYBhHmSAbFOetaEHJnITH8q-eE?e=ePZlcn)
80 | - Splits: [MSVD-Splits](https://1drv.ms/u/s!ArYBhHmSAbFOgUj3RjfW982_KntY?e=hPMWYI)
81 |
82 |
83 |
84 | #### 3. Prepare training data
85 |
86 | - Organize visual and linguistic features under `data/`
87 |
88 | ```bash
89 | data
90 | ├── __init__.py
91 | ├── loader
92 | │ ├── data_loader.py
93 | │ └── __init__.py
94 | ├── MSRVTT
95 | │ ├── language
96 | │ │ ├── embedding_weights.pkl
97 | │ │ ├── idx2word.pkl
98 | │ │ ├── vid2groundtruth.pkl
99 | │ │ ├── vid2language.pkl
100 | │ │ ├── word2idx.pkl
101 | │ │ └── vid2fillmask_MSRVTT.pkl
102 | │ ├── MSRVTT_splits
103 | │ │ ├── MSRVTT_test_list.pkl
104 | │ │ ├── MSRVTT_train_list.pkl
105 | │ │ └── MSRVTT_valid_list.pkl
106 | │ └── visual
107 | │ ├── MSRVTT_C3D_test.hdf5
108 | │ ├── MSRVTT_C3D_train.hdf5
109 | │ ├── MSRVTT_C3D_valid.hdf5
110 | │ ├── MSRVTT_inceptionresnetv2_test.hdf5
111 | │ ├── MSRVTT_inceptionresnetv2_train.hdf5
112 | │ ├── MSRVTT_inceptionresnetv2_valid.hdf5
113 | │ ├── MSRVTT_vg_objects_test.hdf5
114 | │ ├── MSRVTT_vg_objects_train.hdf5
115 | │ └── MSRVTT_vg_objects_valid.hdf5
116 | └── MSVD
117 | ├── language
118 | │ ├── embedding_weights.pkl
119 | │ ├── idx2word.pkl
120 | │ ├── vid2groundtruth.pkl
121 | │ ├── vid2language.pkl
122 | │ ├── word2idx.pkl
123 | │ └── vid2fillmask_MSVD.pkl
124 | ├── MSVD_splits
125 | │ ├── MSVD_test_list.pkl
126 | │ ├── MSVD_train_list.pkl
127 | │ └── MSVD_valid_list.pkl
128 | └── visual
129 | ├── MSVD_C3D_test.hdf5
130 | ├── MSVD_C3D_train.hdf5
131 | ├── MSVD_C3D_valid.hdf5
132 | ├── MSVD_inceptionresnetv2_test.hdf5
133 | ├── MSVD_inceptionresnetv2_train.hdf5
134 | ├── MSVD_inceptionresnetv2_valid.hdf5
135 | ├── MSVD_vg_objects_test.hdf5
136 | ├── MSVD_vg_objects_train.hdf5
137 | └── MSVD_vg_objects_valid.hdf5
138 | ```
139 |
140 |
141 |
142 | ## Pretrained Model
143 |
144 | [Pretrained model on MSR-VTT](https://1drv.ms/u/s!ArYBhHmSAbFOgTicz9UR4ljs_JD4?e=YC8gKW)
145 |
146 | [Pretrained model on MSVD](https://1drv.ms/u/s!ArYBhHmSAbFOgT1ahP77Tij-6yUy?e=KeoDZV)
147 |
148 | Download the pretrained models on MSR-VTT and MSVD via the links above, and place them under the `checkpoints` dir:
149 |
150 | ```
151 | mkdir -p checkpoints/MSRVTT
152 | mkdir -p checkpoints/MSVD
153 | ```
154 |
155 | The final layout should be:
156 |
157 | ```
158 | checkpoints/
159 | ├── MSRVTT
160 | │ └── HMN_MSRVTT_model.ckpt
161 | └── MSVD
162 | └── HMN_MSVD_model.ckpt
163 | ```
164 |
165 |
166 |
167 | ## Training & Testing
168 |
169 | #### Training: MSR-VTT
170 |
171 | ```bash
172 | python -u main.py --dataset_name MSRVTT --entity_encoder_layer 3 --entity_decoder_layer 3 --max_objects 9 \
173 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \
174 | --backbone_3d_name C3D --backbone_3d_dim 2048 \
175 | --object_name vg_objects --object_dim 2048 \
176 | --max_epochs 16 --save_checkpoints_every 500 \
177 |        --data_dir ./data --model_name HMN \
178 | --language_dir_name language \
179 | --learning_rate 7e-5 --lambda_entity 0.1 --lambda_predicate 6.9 --lambda_sentence 6.9 --lambda_soft 3.5
180 | ```
181 |
182 |
183 |
184 | #### Training: MSVD
185 |
186 | ```bash
187 | python -u main.py --dataset_name MSVD --entity_encoder_layer 2 --entity_decoder_layer 2 --max_objects 8 \
188 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \
189 | --backbone_3d_name C3D --backbone_3d_dim 2048 \
190 | --object_name vg_objects --object_dim 2048 \
191 | --max_epochs 20 --save_checkpoints_every 500 \
192 | --data_dir ./data --model_name HMN \
193 | --language_dir_name language --language_package_name vid2language_old \
194 | --learning_rate 1e-4 --lambda_entity 0.6 --lambda_predicate 0.3 --lambda_sentence 1.0 --lambda_soft 0.5
195 | ```
196 |
197 |
198 |
199 | #### Testing MSR-VTT & MSVD
200 |
201 | Comment out the `train_fn` call in `main.py` first:
202 |
203 | ```python
204 | model = train_fn(cfgs, cfgs.model_name, model, hungary_matcher, train_loader, valid_loader, device)
205 | ```
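
After this change, the end of `main.py` would look roughly as follows, loading the checkpoint passed via `--save_checkpoints_path`:

```python
# main.py (testing only): the training call is commented out, so the checkpoint
# given by --save_checkpoints_path is loaded directly before evaluation.
# model = train_fn(cfgs, cfgs.model_name, model, hungary_matcher, train_loader, valid_loader, device)
model.load_state_dict(torch.load(cfgs.train.save_checkpoints_path))
model.eval()
test_fn(cfgs, model, test_loader, device)
```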
206 |
207 |
208 |
209 | For MSR-VTT:
210 |
211 | ```
212 | python3 main.py --dataset_name MSRVTT \
213 | --entity_encoder_layer 3 --entity_decoder_layer 3 --max_objects 9 \
214 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \
215 | --backbone_3d_name C3D --backbone_3d_dim 2048 \
216 | --object_name vg_objects --object_dim 2048 \
217 | --max_epochs 16 --save_checkpoints_every 500 \
218 | --data_dir ./data --model_name HMN --learning_rate 7e-5 \
219 | --lambda_entity 0.1 --lambda_predicate 6.9 --lambda_sentence 6.9 \
220 | --lambda_soft 3.5 \
221 | --save_checkpoints_path checkpoints/MSRVTT/HMN_MSRVTT_model.ckpt
222 | ```
223 |
224 | Expected performance is shown in `figures/MSRVTT-performance.png`.
225 |
226 |
227 |
228 |
229 |
230 | For MSVD:
231 |
232 | ```
233 | python3 main.py --dataset_name MSVD \
234 | --entity_encoder_layer 2 --entity_decoder_layer 2 --max_objects 8 \
235 | --backbone_2d_name inceptionresnetv2 --backbone_2d_dim 1536 \
236 | --backbone_3d_name C3D --backbone_3d_dim 2048 \
237 | --object_name vg_objects --object_dim 2048 \
238 | --max_epochs 20 --save_checkpoints_every 500 \
239 | --data_dir ./data --model_name HMN --learning_rate 1e-4 \
240 |        --lambda_entity 0.6 --lambda_predicate 0.3 --lambda_sentence 1.0 \
241 | --lambda_soft 0.5 \
242 | --save_checkpoints_path checkpoints/MSVD/HMN_MSVD_model.ckpt
243 | ```
244 |
245 | Expected performance is shown in `figures/MSVD-performance.png`.
246 |
247 |
248 |
249 |
250 |
251 | ## Citation
252 |
253 | If our research and this repository are helpful to your work, please cite:
254 |
255 | ```
256 | @InProceedings{Ye_2022_CVPR,
257 | author = {Ye, Hanhua and Li, Guorong and Qi, Yuankai and Wang, Shuhui and Huang, Qingming and Yang, Ming-Hsuan},
258 | title = {Hierarchical Modular Network for Video Captioning},
259 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
260 | month = {June},
261 | year = {2022},
262 | pages = {17939-17948}
263 | }
264 | ```
265 |
266 |
267 |
268 | ## Acknowledgements
269 |
270 | The code of the decoding part is based on [POS-CG](https://github.com/vsislab/Controllable_XGating).
271 |
272 |
--------------------------------------------------------------------------------
/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/configs/__init__.py
--------------------------------------------------------------------------------
/configs/settings.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | __all__ = ['TotalConfigs', 'get_settings']
4 |
5 |
6 | def _settings():
7 | parser = argparse.ArgumentParser()
8 |
9 | """
10 | =========================General Settings===========================
11 | """
12 | parser.add_argument('--seed', type=int, default=1)
13 | parser.add_argument('--drop_prob', type=float, default=0.5)
14 | parser.add_argument('--bsz', type=int, default=64, help='batch size')
15 | parser.add_argument('--sample_numb', type=int, default=15, help='how many frames would you like to sample from a given video')
16 | parser.add_argument('--model_name', type=str, default='HMN', help='which model you would like to train/test?')
17 |
18 | """
19 | =========================Data Settings===========================
20 | """
21 | parser.add_argument('--data_dir', type=str, default='./data')
22 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints')
23 | parser.add_argument('--result_dir', type=str, default='-1')
24 | parser.add_argument('--dataset_name', type=str, default='-1')
25 | parser.add_argument('--backbone_2d_name', type=str, default='-1', help='2d backbone name (InceptionResNetV2)')
26 | parser.add_argument('--backbone_3d_name', type=str, default='-1', help='3d backbone name (C3D)')
27 | parser.add_argument('--object_name', type=str, default='-1', help='object features name (vg_objects)')
28 | parser.add_argument('--semantics_dim', type=int, default=768, help='semantics embedding dim')
29 |
30 | """
31 | =========================Encoder Settings===========================
32 | """
33 |     parser.add_argument('--backbone_2d_dim', type=int, default=2048, help='dimension for inceptionresnetv2')
34 |     parser.add_argument('--backbone_3d_dim', type=int, default=2048, help='dimension for C3D')
35 |     parser.add_argument('--object_dim', type=int, default=2048, help='dimension for vg_objects')
36 | parser.add_argument('--max_objects', type=int, default=8)
37 |
38 | parser.add_argument('--nheads', type=int, default=8)
39 | parser.add_argument('--entity_encoder_layer', type=int, default=2)
40 | parser.add_argument('--entity_decoder_layer', type=int, default=2)
41 | parser.add_argument('--dim_feedforward', type=int, default=2048)
42 | parser.add_argument('--transformer_activation', type=str, default='relu')
43 | parser.add_argument('--d_model', type=int, default=512)
44 | parser.add_argument('--transformer_dropout', type=float, default=0.1)
45 |
46 | """
47 | =========================Decoder Settings===========================
48 | """
49 | parser.add_argument('--word_embedding_dim', type=int, default=300)
50 | parser.add_argument('--hidden_dim', type=int, default=512)
51 | parser.add_argument('--num_layers', type=int, default=1)
52 |
53 |
54 | """
55 | =========================Word Dict Settings===========================
56 | """
57 | parser.add_argument('--eos_idx', type=int, default=0)
58 | parser.add_argument('--sos_idx', type=int, default=1)
59 | parser.add_argument('--unk_idx', type=int, default=2)
60 | parser.add_argument('--n_vocab', type=int, default=-1, help='how many different words are there in the dataset')
61 |
62 | """
63 | =========================Training Settings===========================
64 | """
65 | parser.add_argument('--grad_clip', type=float, default=5.0)
66 | parser.add_argument('--learning_rate', type=float, default=1e-4)
67 | parser.add_argument('--lambda_entity', type=float, default=1.0)
68 | parser.add_argument('--lambda_predicate', type=float, default=1.0)
69 | parser.add_argument('--lambda_sentence', type=float, default=1.0)
70 | parser.add_argument('--lambda_soft', type=float, default=0.1)
71 | parser.add_argument('--max_epochs', type=int, default=20)
72 | parser.add_argument('--visualize_every', type=int, default=10)
73 | parser.add_argument('--save_checkpoints_every', type=int, default=200)
74 | parser.add_argument('--save_checkpoints_path', type=str, default='-1')
75 |
76 | """
77 | =========================Testing Settings===========================
78 | """
79 | parser.add_argument('--beam_size', type=int, default=5)
80 | parser.add_argument('--max_caption_len', type=int, default=20 + 2)
81 | parser.add_argument('--temperature', type=float, default=1.0)
82 | parser.add_argument('--result_path', type=str, default='-1')
83 |
84 | args = parser.parse_args()
85 | return args
86 |
87 |
88 | class TotalConfigs:
89 | def __init__(self, args):
90 | self.data = DataConfigs(args)
91 | self.dict = DictConfigs(args)
92 | self.encoder = EncoderConfigs(args)
93 | self.decoder = DecoderConfigs(args)
94 | self.train = TrainingConfigs(args)
95 | self.test = TestConfigs(args)
96 |
97 | self.seed = args.seed
98 | self.bsz = args.bsz
99 | self.drop_prob = args.drop_prob
100 | self.model_name = args.model_name
101 | self.sample_numb = args.sample_numb
102 |
103 |
104 | class DataConfigs:
105 | def __init__(self, args):
106 | self.data_dir = args.data_dir
107 | self.checkpoints_dir = args.checkpoints_dir
108 | self.dataset_name = args.dataset_name
109 | self.backbone_2d_name = args.backbone_2d_name
110 | self.backbone_3d_name = args.backbone_3d_name
111 | self.object_name = args.object_name
112 | self.word_dim = args.word_embedding_dim
113 |
114 | assert self.dataset_name != '-1', 'Please set argument dataset_name'
115 | assert self.backbone_2d_name != '-1', 'Please set argument backbone_2d_name'
116 | assert self.backbone_3d_name != '-1', 'Please set argument backbone_3d_name'
117 | assert self.object_name != '-1', 'Please set argument object_name'
118 |
119 | # data dir
120 | self.data_dir = os.path.join(self.data_dir, self.dataset_name)
121 |
122 | # language part
123 | self.language_dir = os.path.join(self.data_dir, 'language')
124 | self.vid2language_path = os.path.join(self.language_dir, 'vid2language.pkl')
125 | self.vid2fillmask_path = os.path.join(self.data_dir, 'vid2fillmask_{}.pkl'.format(self.dataset_name))
126 | self.word2idx_path = os.path.join(self.language_dir, 'word2idx.pkl')
127 | self.idx2word_path = os.path.join(self.language_dir, 'idx2word.pkl')
128 | self.embedding_weights_path = os.path.join(self.language_dir, 'embedding_weights.pkl')
129 | self.vid2groundtruth_path = os.path.join(self.language_dir, 'vid2groundtruth.pkl')
130 |
131 |         # visual part: each path template keeps a trailing '{}' that is filled with the split name (train/valid/test) in data/loader/data_loader.py
132 | self.visual_dir = os.path.join(self.data_dir, 'visual')
133 | self.backbone2d_path_tpl = os.path.join(self.visual_dir, '{}_{}_{}.hdf5'.format(args.dataset_name, args.backbone_2d_name, '{}'))
134 | self.backbone3d_path_tpl = os.path.join(self.visual_dir, '{}_{}_{}.hdf5'.format(args.dataset_name, args.backbone_3d_name, '{}'))
135 | self.objects_path_tpl = os.path.join(self.visual_dir, '{}_{}_{}.hdf5'.format(args.dataset_name, args.object_name, '{}'))
136 |
137 | # dataset split part
138 | self.split_dir = os.path.join(self.data_dir, '{dataset_name}_splits'.format(dataset_name=self.dataset_name))
139 | self.videos_split_path_tpl = os.path.join(self.split_dir, '{}_{}_list.pkl'.format(self.dataset_name, '{}'))
140 |
141 |
142 | class DictConfigs:
143 | def __init__(self, args):
144 | self.eos_idx = args.eos_idx
145 | self.sos_idx = args.sos_idx
146 | self.unk_idx = args.unk_idx
147 | self.n_vocab = args.n_vocab
148 |
149 |
150 | class EncoderConfigs:
151 | def __init__(self, args):
152 | self.backbone_2d_dim = args.backbone_2d_dim
153 | self.backbone_3d_dim = args.backbone_3d_dim
154 | self.semantics_dim = args.semantics_dim
155 | self.object_dim = args.object_dim
156 | self.max_objects = args.max_objects
157 |
158 | self.nheads = args.nheads
159 | self.entity_encoder_layer = args.entity_encoder_layer
160 | self.entity_decoder_layer = args.entity_decoder_layer
161 | self.dim_feedforward = args.dim_feedforward
162 | self.transformer_activation = args.transformer_activation
163 | self.d_model = args.d_model
164 | self.trans_dropout = args.transformer_dropout
165 |
166 |
167 | class DecoderConfigs:
168 | def __init__(self, args):
169 | self.hidden_dim = args.hidden_dim
170 | self.num_layers = args.num_layers
171 | self.n_vocab = -1
172 |
173 |
174 | class TrainingConfigs:
175 | def __init__(self, args):
176 | self.grad_clip = args.grad_clip
177 | self.learning_rate = args.learning_rate
178 | self.lambda_entity = args.lambda_entity
179 | self.lambda_predicate = args.lambda_predicate
180 | self.lambda_sentence = args.lambda_sentence
181 | self.lambda_soft = args.lambda_soft
182 | self.max_epochs = args.max_epochs
183 | self.visualize_every = args.visualize_every
184 | self.checkpoints_dir = os.path.join(args.checkpoints_dir, args.dataset_name)
185 | self.save_checkpoints_every = args.save_checkpoints_every
186 | self.save_checkpoints_path = os.path.join(self.checkpoints_dir,
187 | '{model_name}_epochs_{max_epochs}_lr_{lr}_entity_{obj}_predicate_{act}_sentence_{v}_soft{s}_ne_{ne}_nd_{nd}_max_objects_{mo}.ckpt'.format(
188 | model_name=args.model_name,
189 | max_epochs=args.max_epochs,
190 | lr=self.learning_rate,
191 | obj=self.lambda_entity,
192 | act=self.lambda_predicate,
193 | v=self.lambda_sentence,
194 | s=self.lambda_soft,
195 | ne=args.entity_encoder_layer,
196 | nd=args.entity_decoder_layer,
197 | mo=args.max_objects))
198 | if not os.path.exists(self.checkpoints_dir):
199 | os.makedirs(self.checkpoints_dir)
200 | if args.save_checkpoints_path != '-1':
201 | self.save_checkpoints_path = args.save_checkpoints_path
202 |
203 |
204 | class TestConfigs:
205 | def __init__(self, args):
206 | self.beam_size = args.beam_size
207 | self.max_caption_len = args.max_caption_len
208 | self.temperature = args.temperature
209 | self.result_dir = os.path.join('./results/{dataset_name}'.format(dataset_name=args.dataset_name))
210 | if args.result_dir != '-1':
211 | self.result_dir = args.result_dir
212 | if not os.path.exists(self.result_dir):
213 | os.makedirs(self.result_dir)
214 | self.result_path = os.path.join(
215 | self.result_dir,
216 | '{model_name}_epochs_{max_epochs}_lr_{lr}_entity_{obj}_predicate_{act}_sentence_{v}_soft_{s}_ne_{ne}_nd_{nd}_max_objects_{mo}.pkl'.format(
217 | model_name=args.model_name,
218 | max_epochs=args.max_epochs,
219 | lr=args.learning_rate,
220 | obj=args.lambda_entity,
221 | act=args.lambda_predicate,
222 | v=args.lambda_sentence,
223 | s=args.lambda_soft,
224 | ne=args.entity_encoder_layer,
225 | nd=args.entity_decoder_layer,
226 | mo=args.max_objects)
227 | )
228 | if not os.path.exists(self.result_dir):
229 | os.makedirs(self.result_dir)
230 | if args.result_path != '-1':
231 | self.result_path = args.result_path
232 |
233 |
234 | def get_settings():
235 | args = _settings()
236 | configs = TotalConfigs(args=args)
237 | return configs
238 |
239 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/data/loader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/data/loader/__init__.py
--------------------------------------------------------------------------------
/data/loader/data_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 | import numpy as np
4 | import os
5 | import h5py
6 | import pickle
7 | from collections import defaultdict
8 | import sys
9 | sys.path.append('../../')
10 | from configs.settings import TotalConfigs
11 |
12 |
13 | def get_ids_and_probs(fillmask_steps, max_caption_len):
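    # Pads the per-step fill-mask candidates (token ids and probabilities, 50 per step)
    # to max_caption_len steps and returns a mask marking the real (non-padded) steps.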
14 | if fillmask_steps is None:
15 | return None, None, None
16 |
17 | ret_ids, ret_probs = [], []
18 | ret_mask = torch.zeros(max_caption_len)
19 |
20 | for step in fillmask_steps:
21 | step_ids, step_probs = zip(*step)
22 | step_ids, step_probs = torch.Tensor(step_ids).long(), torch.Tensor(step_probs).float()
23 | ret_ids.append(step_ids)
24 | ret_probs.append(step_probs)
25 | gap = max_caption_len - len(fillmask_steps)
26 | for i in range(gap):
27 | zero_ids, zero_probs = torch.zeros(50).long(), torch.zeros(50).float()
28 | ret_ids.append(zero_ids)
29 | ret_probs.append(zero_probs)
30 |
31 | ret_ids = torch.cat([item[None, ...] for item in ret_ids], dim=0)
32 | ret_probs = torch.cat([item[None, ...] for item in ret_probs], dim=0)
33 | ret_mask[:len(fillmask_steps)] = 1
34 |
35 | return ret_ids, ret_probs, ret_mask
36 |
37 |
38 | class CaptionDataset(Dataset):
39 | def __init__(self, cfgs: TotalConfigs, mode, save_on_disk=False, is_total=False):
40 | """
41 | Args:
42 |             cfgs: configurations.
43 |             mode: train/valid/test.
44 |             save_on_disk: whether to save the predictions on disk or not.
45 | True->Each video only appears once.
46 | False->The number of times each video appears depends on
47 | the number of its corresponding captions.
48 | """
49 | super(CaptionDataset, self).__init__()
50 | self.mode = mode
51 | self.save_on_disk = save_on_disk
52 | self.is_total = is_total
53 | sample_numb = cfgs.sample_numb # how many frames are sampled to perform video captioning?
54 | max_caption_len = cfgs.test.max_caption_len
55 |
56 | # language part
57 | vid2language_path = cfgs.data.vid2language_path
58 | vid2fillmask_path = cfgs.data.vid2fillmask_path
59 |
60 | # visual part
61 | backbone2d_path = cfgs.data.backbone2d_path_tpl.format(mode)
62 | backbone3d_path = cfgs.data.backbone3d_path_tpl.format(mode)
63 | objects_path = cfgs.data.objects_path_tpl.format(mode)
64 |
65 | # dataset split part
66 | videos_split_path = cfgs.data.videos_split_path_tpl.format(mode)
67 |
68 | with open(videos_split_path, 'rb') as f:
69 | video_ids = pickle.load(f)
70 |
71 | self.video_ids = video_ids
72 | self.corresponding_vid = []
73 |
74 | self.backbone_2d_dict = {}
75 | self.backbone_3d_dict = {}
76 | self.objects_dict = {}
77 |         self.total_entries = []  # (numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec, vocab_ids, vocab_probs, fillmasks)
78 | self.vid2captions = defaultdict(list)
79 |
80 | # feature 2d dict
81 | with h5py.File(backbone2d_path, 'r') as f:
82 | for vid in video_ids:
83 | temp_feat = f[vid][()]
84 | sampled_idxs = np.linspace(0, len(temp_feat) - 1, sample_numb, dtype=int)
85 | self.backbone_2d_dict[vid] = temp_feat[sampled_idxs]
86 |
87 | # feature 3d dict
88 | with h5py.File(backbone3d_path, 'r') as f:
89 | for vid in video_ids:
90 | temp_feat = f[vid][()]
91 | sampled_idxs = np.linspace(0, len(temp_feat) - 1, sample_numb, dtype=int)
92 | self.backbone_3d_dict[vid] = temp_feat[sampled_idxs]
93 |
94 | # feature object dict
95 | with h5py.File(objects_path, 'r') as f:
96 | for vid in video_ids:
97 | temp_feat = f[vid]['feats'][()]
98 | self.objects_dict[vid] = temp_feat
99 |
100 | with open(vid2language_path, 'rb') as f:
101 | self.vid2language = pickle.load(f)
102 |
103 | if cfgs.train.lambda_soft > 0 and not save_on_disk:
104 | with open(vid2fillmask_path, 'rb') as f:
105 | self.vid2fillmask = pickle.load(f)
106 |
107 | for vid in video_ids:
108 | fillmask_dict = self.vid2fillmask[vid] if cfgs.train.lambda_soft > 0 and not save_on_disk and vid in self.vid2fillmask else None
109 | for item in self.vid2language[vid]:
110 | caption, numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec = item
111 | current_mask = fillmask_dict[caption] if fillmask_dict is not None else None
112 | vocab_ids, vocab_probs, fillmasks = get_ids_and_probs(current_mask, max_caption_len)
113 | self.total_entries.append((numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec, vocab_ids, vocab_probs, fillmasks))
114 | self.corresponding_vid.append(vid)
115 | self.vid2captions[vid].append(caption)
116 |
117 | def __getitem__(self, idx):
118 | """
119 | Returns:
120 | feature2d: (sample_numb, dim2d)
121 | feature3d: (sample_numb, dim3d)
122 | objects: (sample_numb * object_num, dim_obj) or (object_num_per_video, dim_obj)
123 | numberic: (max_caption_len, )
124 | captions: List[str]
125 | vid: str
126 | """
127 | vid = self.corresponding_vid[idx] if (self.mode == 'train' and not self.save_on_disk) or self.is_total else self.video_ids[idx]
128 | choose_idx = 0
129 |
130 | feature2d = self.backbone_2d_dict[vid]
131 | feature3d = self.backbone_3d_dict[vid]
132 | objects = self.objects_dict[vid]
133 |
134 | if (self.mode == 'train' and not self.save_on_disk) or self.is_total:
135 | numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec, vocab_ids, vocab_probs, fillmasks = self.total_entries[idx]
136 | else:
137 | numberic_cap, vp_semantics, caption_semantics, nouns, nouns_vec = self.vid2language[vid][choose_idx][1:]
138 | vocab_ids, vocab_probs, fillmasks = None, None, None
139 |
140 | captions = self.vid2captions[vid]
141 | nouns_dict = {'nouns': nouns, 'vec': torch.FloatTensor(nouns_vec)}
142 |
143 | return torch.FloatTensor(feature2d), torch.FloatTensor(feature3d), torch.FloatTensor(objects), \
144 | torch.LongTensor(numberic_cap), \
145 | torch.FloatTensor(vp_semantics), \
146 | torch.FloatTensor(caption_semantics), captions, nouns_dict, vid, \
147 | vocab_ids, vocab_probs, fillmasks
148 |
149 | def __len__(self):
150 | if (self.mode == 'train' and not self.save_on_disk) or self.is_total:
151 | return len(self.total_entries)
152 | else:
153 | return len(self.video_ids)
154 |
155 |
156 | def collate_fn_caption(batch):
157 | feature2ds, feature3ds, objects, numberic_caps, \
158 | vp_semantics, caption_semantics, captions, nouns_dict_list, vids, \
159 | vocab_ids, vocab_probs, fillmasks = zip(*batch)
160 |
161 | bsz, obj_dim = len(feature2ds), objects[0].shape[-1]
162 | longest_objects_num = max([item.shape[0] for item in objects])
163 | ret_objects = torch.zeros([bsz, longest_objects_num, obj_dim])
164 | ret_objects_mask = torch.ones([bsz, longest_objects_num])
165 | for i in range(bsz):
166 | ret_objects[i, :objects[i].shape[0], :] = objects[i]
167 | ret_objects_mask[i, :objects[i].shape[0]] = 0.0
168 |
169 | feature2ds = torch.cat([item[None, ...] for item in feature2ds], dim=0) # (bsz, sample_numb, dim_2d)
170 | feature3ds = torch.cat([item[None, ...] for item in feature3ds], dim=0) # (bsz, sample_numb, dim_3d)
171 |
172 | vp_semantics = torch.cat([item[None, ...] for item in vp_semantics], dim=0) # (bsz, dim_sem)
173 | caption_semantics = torch.cat([item[None, ...] for item in caption_semantics], dim=0) # (bsz, dim_sem)
174 |
175 | numberic_caps = torch.cat([item[None, ...] for item in numberic_caps], dim=0) # (bsz, seq_len)
176 | masks = numberic_caps > 0
177 |
178 | captions = [item for item in captions]
179 | nouns = list(nouns_dict_list)
180 | vids = list(vids)
181 | vocab_ids = torch.cat([item[None, ...] for item in vocab_ids], dim=0).long() if vocab_ids[0] is not None else None # (bsz, seq_len, 50)
182 | vocab_probs = torch.cat([item[None, ...] for item in vocab_probs], dim=0).float() if vocab_probs[0] is not None else None # (bsz, seq_len, 50)
183 | fillmasks = torch.cat([item[None, ...] for item in fillmasks], dim=0).float() if fillmasks[0] is not None else None # (bsz, seq_len)
184 |
185 | return feature2ds.float(), feature3ds.float(), ret_objects.float(), ret_objects_mask.float(), \
186 | vp_semantics.float(), caption_semantics.float(), \
187 | numberic_caps.long(), masks.float(), captions, nouns, vids, \
188 | vocab_ids, vocab_probs, fillmasks
189 |
190 |
191 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 | from collections import OrderedDict, defaultdict
4 |
5 | from utils.coco_caption.pycocoevalcap.bleu.bleu import Bleu
6 | from utils.coco_caption.pycocoevalcap.cider.cider import Cider
7 | from utils.coco_caption.pycocoevalcap.meteor.meteor import Meteor
8 | from utils.coco_caption.pycocoevalcap.rouge.rouge import Rouge
9 | from configs.settings import TotalConfigs
10 |
11 | def language_eval(sample_seqs, groundtruth_seqs):
12 | assert len(sample_seqs) == len(groundtruth_seqs), 'length of sampled seqs is different from that of groundtruth seqs!'
13 |
14 | references, predictions = OrderedDict(), OrderedDict()
15 | for i in range(len(groundtruth_seqs)):
16 | references[i] = [groundtruth_seqs[i][j] for j in range(len(groundtruth_seqs[i]))]
17 | for i in range(len(sample_seqs)):
18 | predictions[i] = [sample_seqs[i]]
19 |
20 | predictions = {i: predictions[i] for i in range(len(sample_seqs))}
21 | references = {i: references[i] for i in range(len(groundtruth_seqs))}
22 |
23 | avg_bleu_score, bleu_score = Bleu(4).compute_score(references, predictions)
24 | print('avg_bleu_score == ', avg_bleu_score)
25 | avg_cider_score, cider_score = Cider().compute_score(references, predictions)
26 | print('avg_cider_score == ', avg_cider_score)
27 | avg_meteor_score, meteor_score = Meteor().compute_score(references, predictions)
28 | print('avg_meteor_score == ', avg_meteor_score)
29 | avg_rouge_score, rouge_score = Rouge().compute_score(references, predictions)
30 | print('avg_rouge_score == ', avg_rouge_score)
31 |
32 | return {'BLEU': avg_bleu_score, 'CIDEr': avg_cider_score, 'METEOR': avg_meteor_score, 'ROUGE': avg_rouge_score}
33 |
34 |
35 | def decode_idx(seq, itow, eos_idx):
36 | ret = ''
37 | length = seq.shape[0]
38 | for i in range(length):
39 | if seq[i] == eos_idx: break
40 | if i > 0: ret += ' '
41 | ret += itow[seq[i]]
42 | return ret
43 |
44 |
45 | @torch.no_grad()
46 | def eval_fn(model, loader, device, idx2word, save_on_disk, cfgs: TotalConfigs, vid2groundtruth)->dict:
47 | model.eval()
48 | if save_on_disk:
49 | result_dict = {}
50 | predictions, gts = [], []
51 |
52 | for i, (feature2ds, feature3ds, object_feats, object_masks, \
53 | vp_semantics, caption_semantics, numberic_caps, masks, \
54 | captions, nouns_dict_list, vids, vocab_ids, vocab_probs, fillmasks) \
55 | in enumerate(loader):
56 | feature2ds = feature2ds.to(device)
57 | feature3ds = feature3ds.to(device)
58 | object_feats = object_feats.to(device)
59 | object_masks = object_masks.to(device)
60 | vp_semantics = vp_semantics.to(device)
61 | caption_semantics = caption_semantics.to(device)
62 | numberic_caps = numberic_caps.to(device)
63 | masks = masks.to(device)
64 |
65 | pred, seq_probabilities = model.sample(object_feats, object_masks, feature2ds, feature3ds)
66 | pred = pred.cpu().numpy()
67 | batch_pred = [decode_idx(single_seq, idx2word, cfgs.dict.eos_idx) for single_seq in pred]
68 | predictions += batch_pred
69 | batch_gts = [vid2groundtruth[id] for id in vids] if save_on_disk else [item for item in captions]
70 | gts += batch_gts
71 |
72 | if save_on_disk:
73 | assert len(batch_pred) == len(vids), \
74 | 'expect len(batch_pred) == len(vids), ' \
75 | 'but got len(batch_pred) == {} and len(vids) == {}'.format(len(batch_pred), len(vids))
76 | for vid, pred in zip(vids, batch_pred):
77 | result_dict[vid] = pred
78 |
79 | model.train()
80 | score_states = language_eval(sample_seqs=predictions, groundtruth_seqs=gts)
81 |
82 | if save_on_disk:
83 | with open(cfgs.test.result_path, 'wb') as f:
84 | pickle.dump(result_dict, f)
85 |
86 | return score_states
87 |
88 |
--------------------------------------------------------------------------------
/figures/Entity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/Entity.png
--------------------------------------------------------------------------------
/figures/HMN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/HMN.png
--------------------------------------------------------------------------------
/figures/MSRVTT-performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/MSRVTT-performance.png
--------------------------------------------------------------------------------
/figures/MSVD-performance.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/MSVD-performance.png
--------------------------------------------------------------------------------
/figures/motivation.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/figures/motivation.png
--------------------------------------------------------------------------------
/hmn.yaml:
--------------------------------------------------------------------------------
1 | name: hmn_env
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=5.1=1_gnu
8 | - blas=1.0=mkl
9 | - ca-certificates=2022.07.19=h06a4308_0
10 | - certifi=2022.6.15=py37h06a4308_0
11 | - cudatoolkit=10.1.243=h6bb024c_0
12 | - fftw=3.3.9=h27cfd23_1
13 | - freetype=2.11.0=h70c0345_0
14 | - giflib=5.2.1=h7b6447c_0
15 | - h5py=3.7.0=py37h737f45e_0
16 | - hdf5=1.10.6=h3ffc7dd_1
17 | - intel-openmp=2021.4.0=h06a4308_3561
18 | - jpeg=9e=h7f8727e_0
19 | - lcms2=2.12=h3be6417_0
20 | - ld_impl_linux-64=2.38=h1181459_1
21 | - lerc=3.0=h295c915_0
22 | - libdeflate=1.8=h7f8727e_5
23 | - libffi=3.3=he6710b0_2
24 | - libgcc-ng=11.2.0=h1234567_1
25 | - libgfortran-ng=11.2.0=h00389a5_1
26 | - libgfortran5=11.2.0=h1234567_1
27 | - libgomp=11.2.0=h1234567_1
28 | - libpng=1.6.37=hbc83047_0
29 | - libstdcxx-ng=11.2.0=h1234567_1
30 | - libtiff=4.4.0=hecacb30_0
31 | - libwebp=1.2.2=h55f646e_0
32 | - libwebp-base=1.2.2=h7f8727e_0
33 | - lz4-c=1.9.3=h295c915_1
34 | - mkl=2021.4.0=h06a4308_640
35 | - mkl-service=2.4.0=py37h7f8727e_0
36 | - mkl_fft=1.3.1=py37hd3c417c_0
37 | - mkl_random=1.2.2=py37h51133e4_0
38 | - ncurses=6.3=h5eee18b_3
39 | - ninja=1.10.2=h06a4308_5
40 | - ninja-base=1.10.2=hd09550d_5
41 | - numpy=1.21.5=py37h6c91a56_3
42 | - numpy-base=1.21.5=py37ha15fc14_3
43 | - openssl=1.1.1q=h7f8727e_0
44 | - pillow=9.2.0=py37hace64e9_1
45 | - pip=22.1.2=py37h06a4308_0
46 | - python=3.7.13=h12debd9_0
47 | - pytorch=1.4.0=py3.7_cuda10.1.243_cudnn7.6.3_0
48 | - readline=8.1.2=h7f8727e_1
49 | - scipy=1.7.3=py37h6c91a56_2
50 | - setuptools=63.4.1=py37h06a4308_0
51 | - six=1.16.0=pyhd3eb1b0_1
52 | - sqlite=3.39.2=h5082296_0
53 | - tk=8.6.12=h1ccaba5_0
54 | - torchvision=0.5.0=py37_cu101
55 | - wheel=0.37.1=pyhd3eb1b0_0
56 | - xz=5.2.5=h7f8727e_1
57 | - zlib=1.2.12=h5eee18b_3
58 | - zstd=1.5.2=ha4553b6_0
59 | prefix: /home/marcusnerva/anaconda3/envs/hmn_env
60 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import random
4 | import numpy as np
5 | import os
6 |
7 | from train import train_fn
8 | from test import test_fn
9 | from utils.build_loaders import build_loaders
10 | from utils.build_model import build_model
11 | from configs.settings import get_settings
12 | from models.hungary import HungarianMatcher
13 |
14 | def set_random_seed(seed):
15 | random.seed(seed)
16 | os.environ['PYTHONHASHSEED'] = str(seed)
17 | np.random.seed(seed)
18 | torch.manual_seed(seed)
19 | torch.cuda.manual_seed_all(seed)
20 | torch.backends.cudnn.deterministic = True
21 | torch.backends.cudnn.benchmark = False
22 |
23 |
24 | if __name__ == '__main__':
25 | cfgs = get_settings()
26 | set_random_seed(seed=cfgs.seed)
27 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
28 |
29 | train_loader, valid_loader, test_loader = build_loaders(cfgs)
30 |
31 | hungary_matcher = HungarianMatcher()
32 | model = build_model(cfgs)
33 | model = model.float()
34 | model.to(device)
35 |
36 | model = train_fn(cfgs, cfgs.model_name, model, hungary_matcher, train_loader, valid_loader, device)
37 | model.load_state_dict(torch.load(cfgs.train.save_checkpoints_path))
38 | model.eval()
39 | test_fn(cfgs, model, test_loader, device)
40 |
41 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/models/__init__.py
--------------------------------------------------------------------------------
/models/caption_models/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/models/caption_models/caption_module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class CaptionModule(nn.Module):
6 | """
7 | CaptionModule and its child classes are complementary.
8 | """
9 | def __init__(self, beam_size):
10 | super(CaptionModule, self).__init__()
11 | self.beam_size = beam_size
12 |
13 | def __beam_step(self, t, logprobs, beam_seq, beam_seq_logprobs, beam_logprobs_sum, state):
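        # For each live beam, consider its top candidate next tokens, rank all
        # (beam, token) extensions by cumulative log-probability, and keep the
        # best `beam_size` of them (copying the matching decoder states).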
14 | beam_size = self.beam_size
15 |
16 | probs, idx = torch.sort(logprobs, dim=1, descending=True)
17 | candidates = []
18 | rows = beam_size if t >= 1 else 1
19 | cols = min(beam_size, probs.size(1))
20 |
21 | for r in range(rows):
22 | for c in range(cols):
23 | tmp_logprob = probs[r, c]
24 | tmp_sum = beam_logprobs_sum[r] + tmp_logprob
25 | tmp_idx = idx[r, c]
26 | candidates.append({'sum': tmp_sum, 'logprob': tmp_logprob, 'ix': tmp_idx, 'beam': r})
27 |
28 | candidates = sorted(candidates, key=lambda x: -x['sum'])
29 | prev_seq = beam_seq[:, :t].clone()
30 | prev_seq_probs = beam_seq_logprobs[:, :t].clone()
31 | prev_logprobs_sum = beam_logprobs_sum.clone()
32 | new_state = [_.clone() for _ in state]
33 |
34 | for i in range(beam_size):
35 | candidate_i = candidates[i]
36 | beam = candidate_i['beam']
37 | ix = candidate_i['ix']
38 | logprob = candidate_i['logprob']
39 |
40 | beam_seq[i, :t] = prev_seq[beam, :]
41 | beam_seq_logprobs[i, :t] = prev_seq_probs[beam, :]
42 | beam_seq[i, t] = ix
43 | beam_seq_logprobs[i, t] = logprob
44 | beam_logprobs_sum[i] = prev_logprobs_sum[beam] + logprob
45 | for j in range(len(new_state)):
46 | new_state[j][:, i, :] = state[j][:, beam, :]
47 |
48 | return beam_seq, beam_seq_logprobs, beam_logprobs_sum, new_state
49 |
50 | def __beam_search(self, objects_feats, action_feats, caption_feats, objects_pending, action_pending, caption_pending, state):
51 | beam_size = self.beam_size
52 | device = caption_feats.device if caption_feats is not None else objects_feats.device
53 |
54 | beam_seq = torch.LongTensor(beam_size, self.max_caption_len).fill_(self.eos_idx)
55 | beam_seq_logprobs = torch.FloatTensor(beam_size, self.max_caption_len).zero_()
56 | beam_logprobs_sum = torch.zeros(beam_size)
57 | ret = []
58 |
59 | it = torch.LongTensor(beam_size).fill_(self.sos_idx).to(device)
60 | it_embed = self.embedding(it)
61 | output_prob, state = self.forward_decoder(objects_feats, action_feats, caption_feats, objects_pending, action_pending, caption_pending, it_embed, state)
62 | logprob = output_prob
63 |
64 | for t in range(self.max_caption_len):
65 |             # suppress UNK tokens during decoding by making their log-probability extremely low
66 | logprob[:, self.unk_idx] = logprob[:, self.unk_idx] - 1000.0
67 | beam_seq, beam_seq_logprobs, beam_logprobs_sum, state = self.__beam_step(t=t,
68 | logprobs=logprob,
69 | beam_seq=beam_seq,
70 | beam_seq_logprobs=beam_seq_logprobs,
71 | beam_logprobs_sum=beam_logprobs_sum,
72 | state=state)
73 |
74 | for j in range(beam_size):
75 | if beam_seq[j, t] == self.eos_idx or t == self.max_caption_len - 1:
76 | final_beam = {
77 | 'seq': beam_seq[j, :].clone(),
78 | 'seq_logprob': beam_seq_logprobs[j, :].clone(),
79 | 'sum_logprob': beam_logprobs_sum[j].clone()
80 | }
81 | ret.append(final_beam)
82 | beam_logprobs_sum[j] = -1000.0
83 |
84 | it = beam_seq[:, t].to(device)
85 | it_embed = self.embedding(it).to(device)
86 | output_prob, state = self.forward_decoder(objects_feats, action_feats, caption_feats, objects_pending, action_pending, caption_pending, it_embed, state)
87 | logprob = output_prob
88 |
89 | ret = sorted(ret, key=lambda x: -x['sum_logprob'])[:beam_size]
90 | return ret
91 |
92 | def sample_beam(self, objects_feats, vps_feats, caption_feats, objects_semantics, action_semantics, caption_semantics):
93 | beam_size = self.beam_size
94 | batch_size = caption_feats.shape[0] if caption_feats is not None else objects_feats.shape[0]
95 | hidden_dim = caption_feats.shape[-1] if caption_feats is not None else objects_feats.shape[-1]
96 | device = caption_feats.device if caption_feats is not None else objects_feats.device
97 |
98 | seq = torch.LongTensor(batch_size, self.max_caption_len).fill_(self.eos_idx)
99 | seq_probabilities = torch.FloatTensor(batch_size, self.max_caption_len)
100 | done_beam = [[] for _ in range(batch_size)]
101 |
102 | for i in range(batch_size):
103 | single_objects_feats = objects_feats[i, ...][None, ...] if objects_feats is not None else None # (1, sample_numb, obj_per_frame, hidden_dim)
104 | single_vps_feats = vps_feats[i, ...][None, ...] if vps_feats is not None else None # (1, sample_numb, hidden_dim)
105 | single_caption_feats = caption_feats[i, ...][None, ...] if caption_feats is not None else None # (1, sample_numb, hidden_dim)
106 | single_objects_semantics = objects_semantics[i, ...][None, ...] if objects_semantics is not None else None # (1, max_objects, word_dim)
107 | single_action_semantics = action_semantics[i, ...][None, ...] if action_semantics is not None else None # (1, semantics_dim)
108 | single_caption_semantics = caption_semantics[i, ...][None, ...] if caption_semantics is not None else None # (1, semantics_dim)
109 | # print('====={}'.format(single_objects_semantics.shape))
110 |
111 | single_objects_feats = single_objects_feats.repeat(beam_size, 1, 1) if single_objects_feats is not None else None # (beam_size, max_objects, hidden_dim)
112 | single_vps_feats = single_vps_feats.repeat(beam_size, 1, 1) if single_vps_feats is not None else None # (beam_size, sample_numb, hidden_dim)
113 | single_caption_feats = single_caption_feats.repeat(beam_size, 1, 1) if single_caption_feats is not None else None # (beam_size, sample_numb, hidden_dim)
114 | single_objects_semantics = single_objects_semantics.repeat(beam_size, 1, 1) if single_objects_semantics is not None else None # (beam_size, max_objects, word_dim)
115 | single_action_semantics = single_action_semantics.repeat(beam_size, 1) if single_action_semantics is not None else None # (beam_size, semantics_dim)
116 | single_caption_semantics = single_caption_semantics.repeat(beam_size, 1) if single_caption_semantics is not None else None # (beam_size, semantics_dim)
117 |
118 | state = self.get_rnn_init_hidden(beam_size, hidden_dim, device)
119 |
120 | done_beam[i] = self.__beam_search(single_objects_feats, single_vps_feats,
121 | single_caption_feats, single_objects_semantics,
122 | single_action_semantics, single_caption_semantics,
123 | state)
124 | seq[i, ...] = done_beam[i][0]['seq']
125 | seq_probabilities[i, ...] = done_beam[i][0]['seq_logprob']
126 |
127 | return seq, seq_probabilities
128 |
129 | def sample(self, objects, object_masks, feature2ds, feature3ds, is_sample_max=True):
130 | beam_size = self.beam_size
131 | temperature = self.temperature
132 | batch_size = feature2ds.shape[0]
133 | device = feature2ds.device
134 |
135 | objects_feats, action_feats, caption_feats, \
136 | objects_semantics, action_semantics, caption_semantics = self.forward_encoder(objects, object_masks, feature2ds, feature3ds)
137 |
138 | if beam_size > 1:
139 | return self.sample_beam(objects_feats, action_feats, caption_feats,
140 | objects_semantics, action_semantics, caption_semantics)
141 |
142 |         state = self.get_rnn_init_hidden(batch_size, caption_feats.shape[-1] if caption_feats is not None else objects_feats.shape[-1], device)
143 | seq, seq_probabilities = [], []
144 |
145 | for t in range(self.max_caption_len):
146 | if t == 0:
147 | it = objects_feats.new(batch_size).fill_(self.sos_idx).long()
148 | elif is_sample_max:
149 | sampleLogprobs, it = torch.max(log_probabilities.detach(), 1)
150 | it = it.view(-1).long()
151 | else:
152 | prev_probabilities = torch.exp(torch.div(log_probabilities.detach(), temperature))
153 | it = torch.multinomial(prev_probabilities, 1)
154 | sampleLogprobs = log_probabilities.gather(1, it)
155 | it = it.view(-1).long()
156 |
157 | it_embed = self.embedding(it)
158 |
159 | if t >= 1:
160 | if t == 1:
161 | unfinished = it > 0
162 | else:
163 | unfinished = unfinished * (it > 0)
164 | # if unfinished.sum() == 0: break
165 | it = it * unfinished.type_as(it)
166 | seq.append(it)
167 | seq_probabilities.append(sampleLogprobs.view(-1))
168 |
169 | it_embed = it_embed.to(device)
170 | log_probabilities, state = self.forward_decoder(objects_feats, action_feats, caption_feats,
171 | objects_semantics, action_semantics,
172 | caption_semantics, it_embed, state)
173 |
174 | seq.append(it.new(batch_size).long().fill_(self.eos_idx))
175 | seq_probabilities.append(sampleLogprobs.view(-1))
176 | return torch.cat([_.unsqueeze(1) for _ in seq], 1), torch.cat([_.unsqueeze(1) for _ in seq_probabilities], 1)
177 |
--------------------------------------------------------------------------------
/models/caption_models/hierarchical_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from models.caption_models.caption_module import CaptionModule
4 |
5 |
6 | class HierarchicalModel(CaptionModule):
7 | def __init__(self, entity_level: nn.Module, predicate_level: nn.Module, sentence_level: nn.Module,
8 | decoder: nn.Module, word_embedding_weights, max_caption_len: int, beam_size: int, pad_idx=0,
9 | temperature=1, eos_idx=0, sos_idx=-1, unk_idx=-1):
10 | """
11 | Args:
12 | entity_level: for encoding objects information.
13 | predicate_level: for encoding action information.
14 | sentence_level: for encoding the whole video information.
15 | decoder: for generating words.
16 | word_embedding_weights: pretrained word embedding weight.
17 | max_caption_len: generated sentences are no longer than max_caption_len.
18 | pad_idx: corresponding index of ''.
19 | """
20 | super(HierarchicalModel, self).__init__(beam_size=beam_size)
21 | self.entity_level = entity_level
22 | self.predicate_level = predicate_level
23 | self.sentence_level = sentence_level
24 | self.decoder = decoder
25 | self.max_caption_len = max_caption_len
26 | self.temperature = temperature
27 | self.eos_idx = eos_idx
28 | self.sos_idx = sos_idx
29 | self.unk_idx = unk_idx
30 | self.num_layers = decoder.num_layers
31 | self.embedding = nn.Embedding.from_pretrained(torch.FloatTensor(word_embedding_weights),
32 | freeze=False, padding_idx=pad_idx)
33 |
34 | def get_rnn_init_hidden(self, bsz, hidden_size, device):
35 | # (hidden_state, cell_state)
36 | return (torch.zeros(self.num_layers, bsz, hidden_size).to(device),
37 | torch.zeros(self.num_layers, bsz, hidden_size).to(device))
38 |
39 | def forward_encoder(self, objects, objects_mask, feature2ds, feature3ds):
40 | """
41 |
42 | Args:
43 | objects: (bsz, max_objects_per_video, object_dim)
44 | objects_mask: (bsz, max_objects_per_video)
45 | feature2ds: (bsz, sample_numb, feature2d_dim)
46 | feature3ds: (bsz, sample_numb, feature3d_dim)
47 |
48 | Returns:
49 | objects_feats: (bsz, max_objects, hidden_dim)
50 | action_feats: (bsz, sample_numb, hidden_dim)
51 | video_feats: (bsz, sample_numb, hidden_dim)
52 |
53 | objects_semantics: (bsz, max_objects, word_dim)
54 | action_semantics: (bsz, semantics_dim)
55 | video_semantics: (bsz, semantics_dim)
56 | """
57 | objects_feats, objects_semantics = self.entity_level(feature2ds, feature3ds, objects, objects_mask)
58 | action_feats, action_semantics = self.predicate_level(feature3ds, objects_feats, objects_mask)
59 | video_feats, video_semantics = self.sentence_level(feature2ds, action_feats, objects_feats, objects_mask)
60 |
61 | return objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics
62 |
63 | def forward_decoder(self, objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics, pre_embedding, pre_state):
64 | """
65 |
66 | Args:
67 | objects_feats: (bsz, max_objects, hidden_dim)
68 | action_feats: (bsz, sample_numb, hidden_dim)
69 | video_feats: (bsz, sample_numb, hidden_dim)
70 | objects_semantics: (bsz, max_objects, word_dim)
71 | action_semantics: (bsz, semantics_dim)
72 | video_semantics: (bsz, semantics_dim)
73 |
74 | pre_embedding: (bsz, word_embed_dim)
75 | pre_state: (hidden_state, cell_state)
76 |
77 | Returns:
78 | output_prob: (bsz, n_vocab)
79 | current_state: (hidden_state, cell_state)
80 | """
81 | output_prob, current_state = self.decoder(objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics, pre_embedding, pre_state)
82 | return output_prob, current_state
83 |
84 | def forward(self, objects_feats, objects_mask, feature2ds, feature3ds, numberic_captions):
85 | """
86 |
87 | Args:
88 | numberic_captions: (bsz, max_caption_len)
89 |
90 | Returns:
91 | ret_seq: (bsz, max_caption_len, n_vocab)
92 | """
93 | bsz, n_vocab = feature2ds.shape[0], self.decoder.n_vocab
94 | device = objects_feats.device
95 | objects_feats, action_feats, video_feats, objects_semantics, action_semantics, video_semantics = self.forward_encoder(objects_feats, objects_mask, feature2ds, feature3ds)
96 | state = self.get_rnn_init_hidden(bsz=bsz, hidden_size=self.decoder.hidden_dim, device=device)
97 | outputs = []
98 |
99 | for i in range(self.max_caption_len):
100 | if i > 0 and numberic_captions[:, i].sum() == 0:
101 |                 output_word = torch.zeros([bsz, n_vocab]).to(device)
102 | outputs.append(output_word)
103 | continue
104 |
105 | it = numberic_captions[:, i].clone()
106 | it_embeded = self.embedding(it)
107 | output_word, state = self.forward_decoder(objects_feats, action_feats, video_feats,
108 | objects_semantics, action_semantics,
109 | video_semantics, it_embeded, state)
110 | outputs.append(output_word)
111 |
112 | ret_seq = torch.stack(outputs, dim=1)
113 | return ret_seq, objects_semantics, action_semantics, video_semantics
114 |
115 |
--------------------------------------------------------------------------------
/models/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | class Decoder(nn.Module):
5 | def __init__(self, semantics_dim, hidden_dim, num_layers, embed_dim, n_vocab, with_objects, with_action, with_video, with_objects_semantics, with_action_semantics, with_video_semantics):
6 | super(Decoder, self).__init__()
7 | self.n_vocab = n_vocab
8 | self.hidden_dim = hidden_dim
9 | self.num_layers = num_layers
10 | self.W = nn.Linear(hidden_dim, hidden_dim, bias=False)
11 |
12 | self.with_objects_semantics = with_objects_semantics
13 | self.with_action_semantics = with_action_semantics
14 | self.with_video_semantics = with_video_semantics
15 |
16 | total_visual_dim = 0
17 | total_semantics_dim = 0
18 |
19 | # objects visual features and corresponding semantics
20 | if with_objects:
21 | setattr(self, 'Uo', nn.Linear(hidden_dim, hidden_dim, bias=False))
22 | setattr(self, 'bo', nn.Parameter(torch.ones(hidden_dim), requires_grad=True))
23 | setattr(self, 'wo', nn.Linear(hidden_dim, 1, bias=False))
24 | total_visual_dim += hidden_dim
25 | if with_objects_semantics:
26 | setattr(self, 'Uos', nn.Linear(semantics_dim, hidden_dim, bias=False))
27 | setattr(self, 'bos', nn.Parameter(torch.ones(hidden_dim), requires_grad=True))
28 | setattr(self, 'wos', nn.Linear(hidden_dim, 1, bias=False))
29 | total_semantics_dim += semantics_dim
30 |
31 | # action visual features and corresponding semantics
32 | if with_action:
33 | setattr(self, 'Um', nn.Linear(hidden_dim, hidden_dim, bias=False))
34 | setattr(self, 'bm', nn.Parameter(torch.ones(hidden_dim), requires_grad=True))
35 | setattr(self, 'wm', nn.Linear(hidden_dim, 1, bias=False))
36 | total_visual_dim += hidden_dim
37 | if with_action_semantics:
38 | total_semantics_dim += semantics_dim
39 |
40 | # video visual features and corresponding semantics
41 | if with_video:
42 | setattr(self, 'Uv', nn.Linear(hidden_dim, hidden_dim, bias=False))
43 | setattr(self, 'bv', nn.Parameter(torch.ones(hidden_dim), requires_grad=True))
44 | setattr(self, 'wv', nn.Linear(hidden_dim, 1, bias=False))
45 | total_visual_dim += hidden_dim
46 | if with_video_semantics:
47 | total_semantics_dim += semantics_dim
48 |
49 | # fuse visual features together
50 | if total_visual_dim != hidden_dim:
51 | setattr(self, 'linear_visual_layer', nn.Linear(total_visual_dim, hidden_dim))
52 |
53 | # fuse semantics features together
54 | if with_objects_semantics or with_action_semantics or with_video_semantics:
55 | setattr(self, 'linear_semantics_layer', nn.Linear(total_semantics_dim, hidden_dim))
56 |
57 | with_semantics = with_objects_semantics or with_action_semantics or with_video_semantics
58 | self.lstm = nn.LSTM(input_size=hidden_dim * 2 + embed_dim if with_semantics else hidden_dim + embed_dim,
59 | hidden_size=hidden_dim,
60 | num_layers=num_layers)
61 | self.to_word = nn.Linear(hidden_dim, embed_dim)
62 | self.logit = nn.Linear(embed_dim, n_vocab)
63 | self.__init_weight()
64 |
65 | def __init_weight(self):
66 | init_range = 0.1
67 | self.logit.bias.data.fill_(0)
68 | self.logit.weight.data.uniform_(-init_range, init_range)
69 |
70 | def forward(self, objects, action, video, object_semantics, action_semantics, video_semantics, embed, last_states):
71 | last_hidden = last_states[0][0] # (bsz, hidden_dim)
72 | Wh = self.W(last_hidden) # (bsz, hidden_dim)
73 | U_obj = self.Uo(objects) if hasattr(self, 'Uo') else None # (bsz, max_objects, hidden_dim)
74 |         U_objs = self.Uos(object_semantics) if hasattr(self, 'Uos') else None  # (bsz, max_objects, hidden_dim)
75 | U_action = self.Um(action) if hasattr(self, 'Um') else None # (bsz, sample_numb, hidden_dim)
76 | U_video = self.Uv(video) if hasattr(self, 'Uv') else None # (bsz, sample_numb, hidden_dim)
77 |
78 | # for visual features
79 | if U_obj is not None:
80 | attn_weights = self.wo(torch.tanh(Wh[:, None, :] + U_obj + self.bo))
81 | attn_weights = attn_weights.softmax(dim=1) # (bsz, max_objects, 1)
82 | attn_objects = attn_weights * objects # (bsz, max_objects, hidden_dim)
83 | attn_objects = attn_objects.sum(dim=1) # (bsz, hidden_dim)
84 | else:
85 | attn_objects = None
86 |
87 | if U_action is not None:
88 | attn_weights = self.wm(torch.tanh(Wh[:, None, :] + U_action + self.bm))
89 | attn_weights = attn_weights.softmax(dim=1) # (bsz, sample_numb, 1)
90 | attn_motion = attn_weights * action # (bsz, sample_numb, hidden_dim)
91 | attn_motion = attn_motion.sum(dim=1) # (bsz, hidden_dim)
92 | else:
93 | attn_motion = None
94 |
95 | if U_video is not None:
96 | attn_weights = self.wv(torch.tanh(Wh[:, None, :] + U_video + self.bv))
97 | attn_weights = attn_weights.softmax(dim=1) # (bsz, sample_numb, 1)
98 | attn_video = attn_weights * video # (bsz, sample_numb, hidden_dim)
99 | attn_video = attn_video.sum(dim=1) # (bsz, hidden_dim)
100 | else:
101 | attn_video = None
102 |
103 | feats_list = []
104 | if attn_video is not None:
105 | feats_list.append(attn_video)
106 | if attn_motion is not None:
107 | feats_list.append(attn_motion)
108 | if attn_objects is not None:
109 | feats_list.append(attn_objects)
110 | visual_feats = torch.cat(feats_list, dim=-1)
111 | visual_feats = self.linear_visual_layer(visual_feats) if hasattr(self, 'linear_visual_layer') else visual_feats
112 |
113 | # for semantic features
114 | semantics_list = []
115 | if self.with_objects_semantics:
116 | attn_weights = self.wos(torch.tanh(Wh[:, None, :] + U_objs + self.bos))
117 | attn_weights = attn_weights.softmax(dim=1) # (bsz, max_objects, 1)
118 | attn_objs = attn_weights * object_semantics # (bsz, max_objects, emb_dim)
119 | attn_objs = attn_objs.sum(dim=1) # (bsz, emb_dim)
120 | semantics_list.append(attn_objs)
121 | if self.with_action_semantics: semantics_list.append(action_semantics)
122 | if self.with_video_semantics: semantics_list.append(video_semantics)
123 | semantics_feats = torch.cat(semantics_list, dim=-1) if len(semantics_list) > 0 else None
124 | semantics_feats = self.linear_semantics_layer(semantics_feats) if semantics_feats is not None else None
125 |
126 |         # in addition to the latest generated word, fuse visual features and semantic features together
127 | if semantics_feats is not None:
128 | input_feats = torch.cat([visual_feats, semantics_feats, embed], dim=-1) # (bsz, hidden_dim * 2 + embed_dim)
129 | else:
130 | input_feats = torch.cat([visual_feats, embed], dim=-1) # (bsz, hidden_dim + embed_dim)
131 | output, states = self.lstm(input_feats[None, ...], last_states)
132 | output = output.squeeze(0) # (bsz, hidden_dim)
133 | output = self.to_word(output) # (bsz, embed_dim)
134 | output_prob = self.logit(output) # (bsz, n_vocab)
135 | output_prob = torch.log_softmax(output_prob, dim=1) # (bsz, n_vocab)
136 |
137 | return output_prob, states
138 |
139 |
--------------------------------------------------------------------------------
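The three visual branches in `Decoder.forward` share the same additive-attention form: score_k = w(tanh(W h_{t-1} + U x_k + b)), followed by a softmax over the sequence axis and a weighted sum. A standalone sketch of one branch follows; the sizes are toy values, not the repo's configuration.

```python
import torch
import torch.nn as nn

bsz, max_objects, hidden_dim = 2, 5, 16
W = nn.Linear(hidden_dim, hidden_dim, bias=False)   # applied to the previous hidden state
Uo = nn.Linear(hidden_dim, hidden_dim, bias=False)  # applied to the attended features
bo = nn.Parameter(torch.ones(hidden_dim))
wo = nn.Linear(hidden_dim, 1, bias=False)

last_hidden = torch.randn(bsz, hidden_dim)          # h_{t-1} from the LSTM
objects = torch.randn(bsz, max_objects, hidden_dim)

Wh = W(last_hidden)                                               # (bsz, hidden_dim)
attn_weights = wo(torch.tanh(Wh[:, None, :] + Uo(objects) + bo))  # (bsz, max_objects, 1)
attn_weights = attn_weights.softmax(dim=1)                        # normalize over objects
context = (attn_weights * objects).sum(dim=1)                     # (bsz, hidden_dim)
print(context.shape)                                              # torch.Size([2, 16])
```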
/models/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
--------------------------------------------------------------------------------
/models/encoders/entity_level.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 |
5 |
6 | class EntityLevelEncoder(nn.Module):
7 | def __init__(self, transformer, max_objects, object_dim, feature2d_dim, feature3d_dim, hidden_dim, word_dim):
8 | super(EntityLevelEncoder, self).__init__()
9 | self.max_objects = max_objects
10 |
11 | self.query_embed = nn.Embedding(max_objects, hidden_dim)
12 | self.input_proj = nn.Sequential(
13 | nn.Linear(object_dim, hidden_dim * 2),
14 | nn.BatchNorm1d(hidden_dim * 2),
15 | nn.ReLU(True),
16 | nn.Dropout(0.5),
17 | nn.Linear(hidden_dim * 2, hidden_dim)
18 | )
19 | self.feature2d_proj = nn.Sequential(
20 | nn.Linear(feature2d_dim, 2 * hidden_dim),
21 | nn.BatchNorm1d(hidden_dim * 2),
22 | nn.ReLU(True),
23 | nn.Dropout(0.5),
24 | nn.Linear(hidden_dim * 2, hidden_dim)
25 | )
26 | self.feature3d_proj = nn.Sequential(
27 | nn.Linear(feature3d_dim, 2 * hidden_dim),
28 | nn.BatchNorm1d(hidden_dim * 2),
29 | nn.ReLU(True),
30 | nn.Dropout(0.5),
31 | nn.Linear(hidden_dim * 2, hidden_dim)
32 | )
33 | self.bilstm = nn.LSTM(input_size=hidden_dim * 2, hidden_size=hidden_dim//2,
34 | batch_first=True, bidirectional=True)
35 | self.transformer = transformer
36 | self.fc_layer = nn.Linear(hidden_dim, word_dim)
37 |
38 | def forward(self, features_2d: Tensor, features_3d: Tensor, objects: Tensor, objects_mask: Tensor):
39 | """
40 |
41 | Args:
42 | features_2d: (bsz, sample_numb, feature2d_dim)
43 | features_3d: (bsz, sample_numb, feature3d_dim)
44 | objects: (bsz, max_objects_per_video, object_dim)
45 | objects_mask: (bsz, max_objects_per_video)
46 |
47 | Returns:
48 | salient_info: (bsz, max_objects, hidden_dim)
49 | object_pending: (bsz, max_objects, word_dim)
50 | """
51 | device = objects.device
52 | bsz, sample_numb, max_objects_per_video = features_2d.shape[0], features_3d.shape[1], objects.shape[1]
53 | features_2d = self.feature2d_proj(features_2d.view(-1, features_2d.shape[-1]))
54 | features_2d = features_2d.view(bsz, sample_numb, -1).contiguous() # (bsz, sample_numb, hidden_dim)
55 | features_3d = self.feature3d_proj(features_3d.view(-1, features_3d.shape[-1]))
56 | features_3d = features_3d.view(bsz, sample_numb, -1).contiguous() # (bsz, sample_numb, hidden_dim)
57 | content_vectors = torch.cat([features_2d, features_3d], dim=-1) # (bsz, sample_numb, hidden_dim * 2)
58 |
59 | content_vectors, _ = self.bilstm(content_vectors) # (bsz, sample_numb, hidden_dim)
60 | content_vectors = torch.max(content_vectors, dim=1)[0] # (bsz, hidden_dim)
61 |
62 | tgt = content_vectors[None, ...].repeat(self.max_objects, 1, 1) # (max_objects, bsz, hidden_dim)
63 | objects = self.input_proj(objects.view(-1, objects.shape[-1]))
64 | objects = objects.view(bsz, max_objects_per_video, -1).contiguous() # (bsz, max_objects_per_video, hidden_dim)
65 |
66 | mask = objects_mask.to(device).bool() # (bsz, max_objects_per_video)
67 | salient_objects = self.transformer(objects, tgt, mask, self.query_embed.weight)[0][0] # (bsz, max_objects, hidden_dim)
68 | object_pending = self.fc_layer(salient_objects)
69 | return salient_objects, object_pending
70 |
71 |
72 |
73 |
--------------------------------------------------------------------------------
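`EntityLevelEncoder` flattens `(bsz, seq, dim)` inputs to `(bsz * seq, dim)` before each projection because `nn.BatchNorm1d` treats dimension 1 as the channel axis, so the feature axis must directly follow the batch axis; the result is reshaped back afterwards. A small sketch of that reshape-project-reshape pattern with toy sizes:

```python
import torch
import torch.nn as nn

bsz, sample_numb, feature2d_dim, hidden_dim = 2, 4, 32, 16
proj = nn.Sequential(
    nn.Linear(feature2d_dim, hidden_dim * 2),
    nn.BatchNorm1d(hidden_dim * 2),   # normalizes over dim 1, hence the flatten below
    nn.ReLU(True),
    nn.Dropout(0.5),
    nn.Linear(hidden_dim * 2, hidden_dim),
)

x = torch.randn(bsz, sample_numb, feature2d_dim)
y = proj(x.view(-1, feature2d_dim)).view(bsz, sample_numb, hidden_dim)  # flatten, project, reshape back
print(y.shape)  # torch.Size([2, 4, 16])
```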
/models/encoders/predicate_level.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch import Tensor
4 |
5 |
6 | class PredicateLevelEncoder(nn.Module):
7 | def __init__(self, feature3d_dim, hidden_dim, semantics_dim, useless_objects):
8 | super(PredicateLevelEncoder, self).__init__()
9 | self.linear_layer = nn.Linear(feature3d_dim, hidden_dim)
10 | self.W = nn.Linear(hidden_dim, hidden_dim)
11 | self.U = nn.Linear(hidden_dim, hidden_dim)
12 | self.b = nn.Parameter(torch.ones(hidden_dim), requires_grad=True)
13 | self.w = nn.Linear(hidden_dim, 1)
14 | self.inf = 1e5
15 | self.useless_objects = useless_objects
16 |
17 | self.bilstm = nn.LSTM(input_size=hidden_dim + hidden_dim,
18 | hidden_size=hidden_dim // 2,
19 | num_layers=1, bidirectional=True, batch_first=True)
20 | self.fc_layer = nn.Linear(hidden_dim, semantics_dim)
21 |
22 | def forward(self, features3d: Tensor, objects: Tensor, objects_mask: Tensor):
23 | """
24 |
25 | Args:
26 | features3d: (bsz, sample_numb, 3d_dim)
27 | objects: (bsz, max_objects, hidden_dim)
28 | objects_mask: (bsz, max_objects_per_video)
29 |
30 | Returns:
31 |             action_features: (bsz, sample_numb, hidden_dim)
32 | action_pending: (bsz, semantics_dim)
33 | """
34 | sample_numb = features3d.shape[1]
35 | features3d = self.linear_layer(features3d) # (bsz, sample_numb, hidden_dim)
36 | Wf3d = self.W(features3d) # (bsz, sample_numb, hidden_dim)
37 | Uobjs = self.U(objects) # (bsz, max_objects, hidden_dim)
38 |
39 | attn_feat = Wf3d.unsqueeze(2) + Uobjs.unsqueeze(1) + self.b # (bsz, sample_numb, max_objects, hidden_dim)
40 | attn_weights = self.w(torch.tanh(attn_feat)) # (bsz, sample_numb, max_objects, 1)
41 | objects_mask = objects_mask[:, None, :, None].repeat(1, sample_numb, 1, 1) # (bsz, sample_numb, max_objects_per_video, 1)
42 | if self.useless_objects:
43 | attn_weights = attn_weights - objects_mask.float() * self.inf
44 | attn_weights = attn_weights.softmax(dim=-2) # (bsz, sample_numb, max_objects, 1)
45 | attn_objects = attn_weights * attn_feat
46 | attn_objects = attn_objects.sum(dim=-2) # (bsz, sample_numb, hidden_dim)
47 |
48 | features = torch.cat([features3d, attn_objects], dim=-1) # (bsz, sample_numb, hidden_dim * 2)
49 | output, states = self.bilstm(features) # (bsz, sample_numb, hidden_dim)
50 | action = torch.max(output, dim=1)[0] # (bsz, hidden_dim)
51 | action_pending = self.fc_layer(action) # (bsz, semantics_dim)
52 | action_features = output # (bsz, sample_numb, hidden_dim)
53 |
54 | return action_features, action_pending
55 |
56 |
57 |
58 |
--------------------------------------------------------------------------------
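The frame-to-object attention in `PredicateLevelEncoder.forward` relies on broadcasting: `Wf3d.unsqueeze(2)` has shape `(bsz, T, 1, H)` and `Uobjs.unsqueeze(1)` has shape `(bsz, 1, K, H)`, so their sum produces one score per (frame, object) pair. A toy shape walk-through (illustrative sizes only):

```python
import torch
import torch.nn as nn

bsz, T, K, H = 2, 4, 5, 16          # T = sample_numb, K = max_objects
Wf3d = torch.randn(bsz, T, H)       # projected 3D frame features
Uobjs = torch.randn(bsz, K, H)      # projected object features
b = torch.ones(H)
w = nn.Linear(H, 1)

attn_feat = Wf3d.unsqueeze(2) + Uobjs.unsqueeze(1) + b    # (bsz, T, K, H) via broadcasting
attn_weights = w(torch.tanh(attn_feat)).softmax(dim=-2)   # softmax over the K objects per frame
attn_objects = (attn_weights * attn_feat).sum(dim=-2)     # (bsz, T, H)
print(attn_objects.shape)                                 # torch.Size([2, 4, 16])
```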
/models/encoders/sentence_level.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 |
4 |
5 | class SentenceLevelEncoder(nn.Module):
6 | def __init__(self, feature2d_dim, hidden_dim, semantics_dim, useless_objects):
7 | super(SentenceLevelEncoder, self).__init__()
8 | self.inf = 1e5
9 | self.useless_objects = useless_objects
10 | self.linear_2d = nn.Linear(feature2d_dim, hidden_dim)
11 |
12 | self.W = nn.Linear(hidden_dim, hidden_dim)
13 |
14 | self.Uo = nn.Linear(hidden_dim, hidden_dim)
15 | self.Um = nn.Linear(hidden_dim, hidden_dim)
16 |
17 | self.bo = nn.Parameter(torch.ones(hidden_dim), requires_grad=True)
18 | self.bm = nn.Parameter(torch.ones(hidden_dim), requires_grad=True)
19 |
20 | self.wo = nn.Linear(hidden_dim, 1)
21 | self.wm = nn.Linear(hidden_dim, 1)
22 |
23 | self.lstm = nn.LSTM(input_size=hidden_dim + hidden_dim + hidden_dim,
24 | hidden_size=hidden_dim // 2,
25 | num_layers=1, bidirectional=True, batch_first=True)
26 | self.fc_layer = nn.Linear(hidden_dim, semantics_dim)
27 |
28 | def forward(self, feature2ds: Tensor, vp_features: Tensor, object_features: Tensor, objects_mask: Tensor):
29 | """
30 |
31 | Args:
32 | feature2ds: (bsz, sample_numb, hidden_dim)
33 | vp_features: (bsz, sample_numb, hidden_dim)
34 | object_features: (bsz, max_objects, hidden_dim)
35 | objects_mask: (bsz, max_objects_per_video)
36 |
37 | Returns:
38 | video_features: (bsz, sample_numb, hidden_dim)
39 | video_pending: (bsz, semantics_dim)
40 | """
41 | sample_numb = feature2ds.shape[1]
42 | feature2ds = self.linear_2d(feature2ds)
43 | W_f2d = self.W(feature2ds)
44 | U_objs = self.Uo(object_features)
45 | U_motion = self.Um(vp_features)
46 |
47 | attn_feat = W_f2d.unsqueeze(2) + U_objs.unsqueeze(1) + self.bo # (bsz, sample_numb, max_objects, hidden_dim)
48 | attn_weights = self.wo(torch.tanh(attn_feat)) # (bsz, sample_numb, max_objects, 1)
49 | objects_mask = objects_mask[:, None, :, None].repeat(1, sample_numb, 1, 1) # (bsz, sample, max_objects_per_video, 1)
50 | if self.useless_objects:
51 | attn_weights = attn_weights - objects_mask.float() * self.inf
52 | attn_weights = attn_weights.softmax(dim=-2) # (bsz, sample_numb, max_objects, 1)
53 | attn_objects = attn_weights * attn_feat
54 | attn_objects = attn_objects.sum(dim=-2) # (bsz, sample_numb, hidden_dim)
55 |
56 | attn_feat = W_f2d.unsqueeze(2) + U_motion.unsqueeze(1) + self.bm # (bsz, sample_numb, sample_numb, hidden_dim)
57 | attn_weights = self.wm(torch.tanh(attn_feat)) # (bsz, sample_numb, sample_numb, 1)
58 | attn_weights = attn_weights.softmax(dim=-2) # (bsz, sample_numb, sample_numb, 1)
59 | attn_motion = attn_weights * attn_feat
60 | attn_motion = attn_motion.sum(dim=-2) # (bsz, sample_numb, hidden_dim)
61 |
62 | features = torch.cat([feature2ds, attn_motion, attn_objects], dim=-1) # (bsz, sample_numb, hidden_dim * 3)
63 | output, states = self.lstm(features) # (bsz, sample_numb, hidden_dim)
64 | video = torch.max(output, dim=1)[0] # (bsz, hidden_dim)
65 | video_pending = self.fc_layer(video) # (bsz, semantics_dim)
66 | video_features = output # (bsz, sample_numb, hidden_dim)
67 |
68 | return video_features, video_pending
69 |
70 |
71 |
--------------------------------------------------------------------------------
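Both this encoder and the predicate-level one produce their global vector the same way: a bidirectional LSTM with `hidden_size=hidden_dim // 2` keeps the per-step output at `hidden_dim` (the two directions are concatenated), and a max-pool over time yields the single vector passed to `fc_layer`. A toy check of those shapes:

```python
import torch
import torch.nn as nn

bsz, T, hidden_dim = 2, 4, 16
bilstm = nn.LSTM(input_size=hidden_dim * 3, hidden_size=hidden_dim // 2,
                 num_layers=1, bidirectional=True, batch_first=True)

features = torch.randn(bsz, T, hidden_dim * 3)   # [appearance; motion-attended; object-attended]
output, _ = bilstm(features)                     # (bsz, T, hidden_dim): 2 x (hidden_dim // 2)
video = output.max(dim=1)[0]                     # (bsz, hidden_dim): max-pool over time
print(output.shape, video.shape)
```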
/models/encoders/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from typing import Optional, List
4 | from torch import nn, Tensor
5 | import copy
6 |
7 |
8 | class Transformer(nn.Module):
9 |
10 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
11 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
12 | activation="relu", normalize_before=False,
13 | return_intermediate_dec=False):
14 | super().__init__()
15 |
16 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
17 | dropout, activation, normalize_before)
18 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
19 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
20 |
21 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
22 | dropout, activation, normalize_before)
23 | decoder_norm = nn.LayerNorm(d_model)
24 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
25 | return_intermediate=return_intermediate_dec)
26 |
27 | self._reset_parameters()
28 |
29 | self.d_model = d_model
30 | self.nhead = nhead
31 |
32 | def _reset_parameters(self):
33 | for p in self.parameters():
34 | if p.dim() > 1:
35 | nn.init.xavier_uniform_(p)
36 |
37 | def forward(self, src, tgt, mask, query_embed, pos_embed=None):
38 | bs, detected_num, c = src.shape
39 | src = src.permute(1, 0, 2) # (max_objects_per_video, bsz, object_feats_dim)
40 | # pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
41 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
42 | mask = mask.flatten(1)
43 |
44 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
45 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask,
46 | pos=pos_embed, query_pos=query_embed)
47 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, detected_num)
48 |
49 |
50 | class TransformerEncoder(nn.Module):
51 |
52 | def __init__(self, encoder_layer, num_layers, norm=None):
53 | super().__init__()
54 | self.layers = _get_clones(encoder_layer, num_layers)
55 | self.num_layers = num_layers
56 | self.norm = norm
57 |
58 | def forward(self, src,
59 | mask: Optional[Tensor] = None,
60 | src_key_padding_mask: Optional[Tensor] = None,
61 | pos: Optional[Tensor] = None):
62 | output = src
63 |
64 | for layer in self.layers:
65 | output = layer(output, src_mask=mask,
66 | src_key_padding_mask=src_key_padding_mask, pos=pos)
67 |
68 | if self.norm is not None:
69 | output = self.norm(output)
70 |
71 | return output
72 |
73 |
74 | class TransformerDecoder(nn.Module):
75 |
76 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
77 | super().__init__()
78 | self.layers = _get_clones(decoder_layer, num_layers)
79 | self.num_layers = num_layers
80 | self.norm = norm
81 | self.return_intermediate = return_intermediate
82 |
83 | def forward(self, tgt, memory,
84 | tgt_mask: Optional[Tensor] = None,
85 | memory_mask: Optional[Tensor] = None,
86 | tgt_key_padding_mask: Optional[Tensor] = None,
87 | memory_key_padding_mask: Optional[Tensor] = None,
88 | pos: Optional[Tensor] = None,
89 | query_pos: Optional[Tensor] = None):
90 | output = tgt
91 |
92 | intermediate = []
93 |
94 | for layer in self.layers:
95 | output = layer(output, memory, tgt_mask=tgt_mask,
96 | memory_mask=memory_mask,
97 | tgt_key_padding_mask=tgt_key_padding_mask,
98 | memory_key_padding_mask=memory_key_padding_mask,
99 | pos=pos, query_pos=query_pos)
100 | if self.return_intermediate:
101 | intermediate.append(self.norm(output))
102 |
103 | if self.norm is not None:
104 | output = self.norm(output)
105 | if self.return_intermediate:
106 | intermediate.pop()
107 | intermediate.append(output)
108 |
109 | if self.return_intermediate:
110 | return torch.stack(intermediate)
111 |
112 | return output.unsqueeze(0)
113 |
114 |
115 | class TransformerEncoderLayer(nn.Module):
116 |
117 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
118 | activation="relu", normalize_before=False):
119 | super().__init__()
120 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
121 | # Implementation of Feedforward model
122 | self.linear1 = nn.Linear(d_model, dim_feedforward)
123 | self.dropout = nn.Dropout(dropout)
124 | self.linear2 = nn.Linear(dim_feedforward, d_model)
125 |
126 | self.norm1 = nn.LayerNorm(d_model)
127 | self.norm2 = nn.LayerNorm(d_model)
128 | self.dropout1 = nn.Dropout(dropout)
129 | self.dropout2 = nn.Dropout(dropout)
130 |
131 | self.activation = _get_activation_fn(activation)
132 | self.normalize_before = normalize_before
133 |
134 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
135 | return tensor if pos is None else tensor + pos
136 |
137 | def forward_post(self,
138 | src,
139 | src_mask: Optional[Tensor] = None,
140 | src_key_padding_mask: Optional[Tensor] = None,
141 | pos: Optional[Tensor] = None):
142 | q = k = self.with_pos_embed(src, pos)
143 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
144 | key_padding_mask=src_key_padding_mask)[0]
145 | src = src + self.dropout1(src2)
146 | src = self.norm1(src)
147 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
148 | src = src + self.dropout2(src2)
149 | src = self.norm2(src)
150 | return src
151 |
152 | def forward_pre(self, src,
153 | src_mask: Optional[Tensor] = None,
154 | src_key_padding_mask: Optional[Tensor] = None,
155 | pos: Optional[Tensor] = None):
156 | src2 = self.norm1(src)
157 | q = k = self.with_pos_embed(src2, pos)
158 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
159 | key_padding_mask=src_key_padding_mask)[0]
160 | src = src + self.dropout1(src2)
161 | src2 = self.norm2(src)
162 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
163 | src = src + self.dropout2(src2)
164 | return src
165 |
166 | def forward(self, src,
167 | src_mask: Optional[Tensor] = None,
168 | src_key_padding_mask: Optional[Tensor] = None,
169 | pos: Optional[Tensor] = None):
170 | if self.normalize_before:
171 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
172 | return self.forward_post(src, src_mask, src_key_padding_mask, pos)
173 |
174 |
175 | class TransformerDecoderLayer(nn.Module):
176 |
177 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
178 | activation="relu", normalize_before=False):
179 | super().__init__()
180 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
181 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
182 | # Implementation of Feedforward model
183 | self.linear1 = nn.Linear(d_model, dim_feedforward)
184 | self.dropout = nn.Dropout(dropout)
185 | self.linear2 = nn.Linear(dim_feedforward, d_model)
186 |
187 | self.norm1 = nn.LayerNorm(d_model)
188 | self.norm2 = nn.LayerNorm(d_model)
189 | self.norm3 = nn.LayerNorm(d_model)
190 | self.dropout1 = nn.Dropout(dropout)
191 | self.dropout2 = nn.Dropout(dropout)
192 | self.dropout3 = nn.Dropout(dropout)
193 |
194 | self.activation = _get_activation_fn(activation)
195 | self.normalize_before = normalize_before
196 |
197 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
198 | return tensor if pos is None else tensor + pos
199 |
200 | def forward_post(self, tgt, memory,
201 | tgt_mask: Optional[Tensor] = None,
202 | memory_mask: Optional[Tensor] = None,
203 | tgt_key_padding_mask: Optional[Tensor] = None,
204 | memory_key_padding_mask: Optional[Tensor] = None,
205 | pos: Optional[Tensor] = None,
206 | query_pos: Optional[Tensor] = None):
207 | q = k = self.with_pos_embed(tgt, query_pos)
208 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask,
209 | key_padding_mask=tgt_key_padding_mask)[0]
210 | tgt = tgt + self.dropout1(tgt2)
211 | tgt = self.norm1(tgt)
212 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
213 | key=self.with_pos_embed(memory, pos),
214 | value=memory, attn_mask=memory_mask,
215 | key_padding_mask=memory_key_padding_mask)[0]
216 | tgt = tgt + self.dropout2(tgt2)
217 | tgt = self.norm2(tgt)
218 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
219 | tgt = tgt + self.dropout3(tgt2)
220 | tgt = self.norm3(tgt)
221 | return tgt
222 |
223 | def forward_pre(self, tgt, memory,
224 | tgt_mask: Optional[Tensor] = None,
225 | memory_mask: Optional[Tensor] = None,
226 | tgt_key_padding_mask: Optional[Tensor] = None,
227 | memory_key_padding_mask: Optional[Tensor] = None,
228 | pos: Optional[Tensor] = None,
229 | query_pos: Optional[Tensor] = None):
230 | tgt2 = self.norm1(tgt)
231 | q = k = self.with_pos_embed(tgt2, query_pos)
232 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
233 | key_padding_mask=tgt_key_padding_mask)[0]
234 | tgt = tgt + self.dropout1(tgt2)
235 | tgt2 = self.norm2(tgt)
236 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
237 | key=self.with_pos_embed(memory, pos),
238 | value=memory, attn_mask=memory_mask,
239 | key_padding_mask=memory_key_padding_mask)[0]
240 | tgt = tgt + self.dropout2(tgt2)
241 | tgt2 = self.norm3(tgt)
242 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
243 | tgt = tgt + self.dropout3(tgt2)
244 | return tgt
245 |
246 | def forward(self, tgt, memory,
247 | tgt_mask: Optional[Tensor] = None,
248 | memory_mask: Optional[Tensor] = None,
249 | tgt_key_padding_mask: Optional[Tensor] = None,
250 | memory_key_padding_mask: Optional[Tensor] = None,
251 | pos: Optional[Tensor] = None,
252 | query_pos: Optional[Tensor] = None):
253 | if self.normalize_before:
254 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
255 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
256 | return self.forward_post(tgt, memory, tgt_mask, memory_mask,
257 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
258 |
259 |
260 | def _get_clones(module, N):
261 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
262 |
263 |
264 | def build_transformer(args):
265 | return Transformer(
266 | d_model=args.hidden_dim,
267 | dropout=args.dropout,
268 | nhead=args.nheads,
269 | dim_feedforward=args.dim_feedforward,
270 | num_encoder_layers=args.enc_layers,
271 | num_decoder_layers=args.dec_layers,
272 | normalize_before=args.pre_norm,
273 | return_intermediate_dec=True,
274 | )
275 |
276 |
277 | def _get_activation_fn(activation):
278 | """Return an activation function given a string"""
279 | if activation == "relu":
280 | return F.relu
281 | if activation == "gelu":
282 | return F.gelu
283 | if activation == "glu":
284 | return F.glu
285 |     raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
286 |
287 |
288 |
--------------------------------------------------------------------------------
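For reference, a shape sketch of how `EntityLevelEncoder` drives this `Transformer`: detected-object features act as the encoder memory, the repeated video-content vectors as the decoder target, and the learned query embeddings as `query_pos`. Sizes are toy values, and the import assumes the repo root is on `PYTHONPATH`.

```python
import torch
from models.encoders.transformer import Transformer

bsz, num_objects, max_queries, d_model = 2, 7, 4, 32
model = Transformer(d_model=d_model, nhead=4, num_encoder_layers=2,
                    num_decoder_layers=2, dim_feedforward=64)

src = torch.randn(bsz, num_objects, d_model)             # detected-object features (encoder memory)
tgt = torch.randn(max_queries, bsz, d_model)             # repeated video-content vectors
mask = torch.zeros(bsz, num_objects, dtype=torch.bool)   # True would mark padded object slots
query_embed = torch.randn(max_queries, d_model)          # stands in for nn.Embedding(...).weight

hs, memory = model(src, tgt, mask, query_embed)
print(hs.shape)      # (1, bsz, max_queries, d_model) with return_intermediate_dec=False
print(memory.shape)  # (bsz, d_model, num_objects)
```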
/models/hungary.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, Tensor
3 | import numpy as np
4 | from scipy.optimize import linear_sum_assignment
5 |
6 | class HungarianMatcher(nn.Module):
7 | """This class computes an assignment between the targets and the predictions of the network
8 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
9 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
10 | while the others are un-matched (and thus treated as non-objects).
11 | """
12 |
13 | def __init__(self):
14 | super(HungarianMatcher, self).__init__()
15 | self.eps = 1e-6
16 |
17 | @torch.no_grad()
18 | def forward(self, salient_objects: Tensor, nouns_dict_list: list):
19 | """ Performs the matching
20 | Args:
21 | salient_objects: (bsz, max_objects, word_dim)
22 | nouns_dict_list: List[{'vec': nouns_vec, 'nouns': nouns}, ...]
23 | Returns:
24 | A list of size batch_size, containing tuples of (index_i, index_j) where:
25 | - index_i is the indices of the selected predictions (in order)
26 | - index_j is the indices of the corresponding selected targets (in order)
27 | For each batch element, it holds:
28 |                 len(index_i) = len(index_j) = min(num_queries, num_target_nouns)
29 | """
30 | bsz, max_objects = salient_objects.shape[:2]
31 | device = salient_objects.device
32 | sizes = [len(item['nouns']) for item in nouns_dict_list]
33 | nouns_semantics = torch.cat([item['vec'][:len(item['nouns'])] for item in nouns_dict_list]).to(device) # (\sigma nouns, word_dim)
34 | nouns_length = torch.norm(nouns_semantics, dim=-1, keepdim=True) # (\sigma nouns, 1)
35 | salient_objects = salient_objects.flatten(0, 1) # (bsz * max_objects, word_dim)
36 | salient_length = torch.norm(salient_objects, dim=-1, keepdim=True) # (bsz * max_objects, 1)
37 | matrix_length = salient_length * nouns_length.permute([1, 0]) + self.eps # (bsz * max_objects, \sigma nouns)
38 |
39 |
40 | cos_matrix = torch.mm(salient_objects, nouns_semantics.permute([1, 0])) # (bsz * max_objects, \sigma nouns)
41 | cos_matrix = -cos_matrix / matrix_length # (bsz * max_objects, \sigma nouns)
42 | cos_matrix = cos_matrix.view([bsz, max_objects, -1]) # (bsz, max_objects, \sigma nouns)
43 | indices = [linear_sum_assignment(c[i].detach().cpu().numpy()) for i, c in enumerate(cos_matrix.split(sizes, -1))]
44 |
45 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
46 |
47 |
--------------------------------------------------------------------------------
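A toy illustration of what `HungarianMatcher` computes for a single sample: a negative-cosine cost matrix between predicted object vectors and ground-truth noun vectors, solved as an optimal one-to-one assignment by `scipy.optimize.linear_sum_assignment`. The sizes and vectors below are random placeholders.

```python
import torch
import torch.nn.functional as F
from scipy.optimize import linear_sum_assignment

max_objects, num_nouns, word_dim = 4, 2, 8
pred = F.normalize(torch.randn(max_objects, word_dim), dim=-1)   # predicted object semantics
nouns = F.normalize(torch.randn(num_nouns, word_dim), dim=-1)    # ground-truth noun vectors

cost = -(pred @ nouns.t())                            # negative cosine similarity, lower = better
row_idx, col_idx = linear_sum_assignment(cost.numpy())
print(list(zip(row_idx.tolist(), col_idx.tolist())))  # which query is matched to which noun
```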
/scripts/split_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import h5py
4 | import argparse
5 |
6 | if __name__ == '__main__':
7 | parser = argparse.ArgumentParser()
8 | parser.add_argument('--dataset_name', type=str, default='-1', help='dataset name (MSVD | MSRVTT)')
9 | parser.add_argument('--split_dir', type=str, default='-1', help='dir contains split_list')
10 |     parser.add_argument('--data_path', type=str, default='-1', help='the path of the unsplit data')
11 |     parser.add_argument('--data_name', type=str, default='-1', help='name of the unsplit data')
12 | parser.add_argument('--target_dir', type=str, default='-1', help='where do you want to place your data')
13 | parser.add_argument('--split_objects', action='store_true')
14 |
15 | args = parser.parse_args()
16 | dataset_name = args.dataset_name
17 | split_dir = args.split_dir
18 | data_path = args.data_path
19 | data_name = args.data_name
20 | target_dir = args.target_dir
21 |
22 | assert dataset_name != '-1', 'Please set dataset_name!'
23 | assert split_dir != '-1', 'Please set split_dir!'
24 | assert data_path != '-1', 'Please set data_path!'
25 | assert data_name != '-1', 'Please set data_name!'
26 | assert target_dir != '-1', 'Please set target_dir!'
27 |
28 | splits_list_path_tpl = '{dataset_name}_{part}_list.pkl'.format(dataset_name=dataset_name, part='{}')
29 | dataset_split_path_tpl = '{}_{}_{}.hdf5'.format(dataset_name, data_name, '{}')
30 | split_part_list = ['train', 'valid', 'test']
31 | print('[split begin]', '=' * 20)
32 | with h5py.File(data_path, 'r') as f:
33 | for split in split_part_list:
34 | cur_vid_list_path = os.path.join(split_dir, splits_list_path_tpl.format(split))
35 | dataset_split_save_path = os.path.join(target_dir, dataset_split_path_tpl.format(split))
36 | with open(cur_vid_list_path, 'rb') as v:
37 | vid_list = pickle.load(v)
38 | with h5py.File(dataset_split_save_path, 'w') as t:
39 | for vid in vid_list:
40 | if args.split_objects:
41 | t[vid] = f[vid][()]
42 | else:
43 | temp_group = t.create_group(vid)
44 | temp_group['feats'] = f[vid]['feats'][()]
45 | # temp_group['bboxes'] = f[vid]['bboxes'][()]
46 | # temp_group['kinds'] = f[vid]['kinds'][()]
47 | print('[split end]', '=' * 20)
48 |
49 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import pickle
3 |
4 | from configs.settings import TotalConfigs
5 | from eval import eval_fn
6 |
7 |
8 | def test_fn(cfgs: TotalConfigs, model, loader, device):
9 | print('##############n_vocab is {}##############'.format(cfgs.decoder.n_vocab))
10 | with open(cfgs.data.idx2word_path, 'rb') as f:
11 | idx2word = pickle.load(f)
12 | with open(cfgs.data.vid2groundtruth_path, 'rb') as f:
13 | vid2groundtruth = pickle.load(f)
14 | scores = eval_fn(model=model, loader=loader, device=device,
15 | idx2word=idx2word, save_on_disk=True, cfgs=cfgs,
16 | vid2groundtruth=vid2groundtruth)
17 | print('===================Testing is finished====================')
18 |
19 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.optim as optim
4 | import numpy as np
5 | import pickle
6 |
7 | from utils.loss import LanguageModelCriterion, CosineCriterion, SoftCriterion
8 | from eval import eval_fn
9 | from configs.settings import TotalConfigs
10 | from models.hungary import HungarianMatcher
11 |
12 |
13 | def _get_src_permutation_idx(indices):
14 | # permute predictions following indices
15 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
16 | src_idx = torch.cat([src for (src, _) in indices])
17 | return batch_idx, src_idx
18 |
19 |
20 | def _get_tgt_permutation_idx(indices):
21 | # permute targets following indices
22 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
23 | tgt_idx = torch.cat([tgt for (_, tgt) in indices])
24 | return batch_idx, tgt_idx
25 |
26 |
27 | def train_fn(cfgs: TotalConfigs, model_name: str, model: nn.Module, matcher: HungarianMatcher, train_loader, valid_loader, device):
28 | optimizer = optim.Adam(model.parameters(), lr=cfgs.train.learning_rate)
29 | lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, cfgs.train.max_epochs, eta_min=0, last_epoch=-1)
30 | language_loss = LanguageModelCriterion()
31 | soft_loss = SoftCriterion()
32 | cos_loss = CosineCriterion()
33 | language_loss.to(device)
34 | best_score, cnt = None, 0
35 | loss_store, loss_entity, loss_predicate, loss_sentence, loss_xe, loss_soft_target = [], [], [], [], [], []
36 |
37 | with open(cfgs.data.idx2word_path, 'rb') as f:
38 | idx2word = pickle.load(f)
39 | with open(cfgs.data.vid2groundtruth_path, 'rb') as f:
40 | vid2groundtruth = pickle.load(f)
41 |
42 | print('===================Training begin====================')
43 | print(cfgs.train.save_checkpoints_path)
44 | for epoch in range(cfgs.train.max_epochs):
45 | print('\n{}[EPOCH {}]{}'.format('='*15, epoch, '='*15))
46 |
47 | for i, (feature2ds, feature3ds, objects, object_masks, vp_semantics, caption_semantics, numberic_caps, masks, captions, nouns, vids, vocab_ids, vocab_probs, fillmasks) in enumerate(train_loader):
48 | cnt += 1
49 | feature2ds = feature2ds.to(device)
50 | feature3ds = feature3ds.to(device)
51 | objects = objects.to(device)
52 | object_masks = object_masks.to(device)
53 | vp_semantics = vp_semantics.to(device)
54 | caption_semantics = caption_semantics.to(device)
55 | numberic_caps = numberic_caps.to(device)
56 | masks = masks.to(device)
57 | vocab_ids = vocab_ids.to(device) if vocab_ids is not None else None
58 | vocab_probs = vocab_probs.to(device) if vocab_probs is not None else None
59 | fillmasks = fillmasks.to(device) if fillmasks is not None else None
60 |
61 | # bsz, sample_numb, obj_numb, obj_dim = objects.shape
62 | # objects = objects.reshape([bsz, sample_numb * obj_numb, obj_dim])
63 |
64 | optimizer.zero_grad()
65 |
66 | preds, objects_pending, action_pending, video_pending = model(objects, object_masks, feature2ds, feature3ds, numberic_caps)
67 | xe_loss, s_loss, ent_loss, pred_loss, sent_loss = None, None, None, None, None
68 |
69 | # cross entropy loss
70 | loss_hard = language_loss(preds, numberic_caps, masks, cfgs.dict.eos_idx)
71 | loss = loss_hard
72 | xe_loss = loss_hard.detach().item()
73 |
74 | # soft loss
75 | if cfgs.train.lambda_soft > 0:
76 | loss_soft = soft_loss(preds, vocab_ids, vocab_probs, fillmasks)
77 | loss = loss + loss_soft * cfgs.train.lambda_soft
78 | s_loss = loss_soft.detach().item()
79 |
80 | # object module loss
81 | if cfgs.train.lambda_entity > 0:
82 | indices = matcher(objects_pending, nouns)
83 | src_idx = _get_src_permutation_idx(indices)
 84 |                 matched_objects = objects_pending[src_idx]
 85 |                 targets = torch.cat([t['vec'][i] for t, (_, i) in zip(nouns, indices)], dim=0).to(device)
 86 |                 if np.any(np.isnan(matched_objects.detach().cpu().numpy())):
 87 |                     raise RuntimeError('NaN detected in entity-level predictions')
 88 |                 object_loss = cos_loss(matched_objects, targets)
89 | loss = loss + object_loss * cfgs.train.lambda_entity
90 | ent_loss = object_loss.detach().item()
91 |
92 | # action module loss
93 | if cfgs.train.lambda_predicate > 0:
94 | action_loss = cos_loss(action_pending, vp_semantics)
95 | loss = loss + action_loss * cfgs.train.lambda_predicate
96 | pred_loss = action_loss.detach().item()
97 |
98 | # video module loss
99 | if cfgs.train.lambda_sentence > 0:
100 |                 video_loss = cos_loss(video_pending, caption_semantics)
101 |                 loss = loss + video_loss * cfgs.train.lambda_sentence
102 |                 sent_loss = video_loss.detach().item()
103 |
104 | loss.backward()
105 | loss_store.append(loss.detach().item())
106 | loss_xe.append(xe_loss)
107 | loss_entity.append(ent_loss)
108 | loss_predicate.append(pred_loss)
109 | loss_sentence.append(sent_loss)
110 | loss_soft_target.append(s_loss)
111 | nn.utils.clip_grad_norm_(model.parameters(), cfgs.train.grad_clip)
112 | optimizer.step()
113 |
114 | if cnt % cfgs.train.visualize_every == 0:
115 | loss_store, loss_xe, loss_entity, loss_predicate, loss_sentence, loss_soft_target = \
116 | loss_store[-10:], loss_xe[-10:], loss_entity[-10:], loss_predicate[-10:], loss_sentence[-10:], loss_soft_target[-10:]
117 | loss_value = np.array(loss_store).mean()
118 | xe_value = np.array(loss_xe).mean() if loss_xe[0] is not None else 0
119 | soft_value = np.array(loss_soft_target).mean() if loss_soft_target[0] is not None else 0
120 | entity_value = np.array(loss_entity).mean() if loss_entity[0] is not None else 0
121 | predicate_value = np.array(loss_predicate).mean() if loss_predicate[0] is not None else 0
122 | sentence_value = np.array(loss_sentence).mean() if loss_sentence[0] is not None else 0
123 |
124 | print('[EPOCH {};ITER {}]:loss[{:.3f}]=hard_loss[{:.3f}]*1+soft_loss[{:.3f}]*{:.2f}+entity[{:.3f}]*{:.2f}+predicate[{:.3f}]*{:.2f}+sentence[{:.3f}]*{:.2f}'
125 | .format(epoch, i, loss_value, xe_value,
126 | soft_value, cfgs.train.lambda_soft,
127 | entity_value, cfgs.train.lambda_entity,
128 | predicate_value, cfgs.train.lambda_predicate,
129 | sentence_value, cfgs.train.lambda_sentence))
130 |
131 | if cnt % cfgs.train.save_checkpoints_every == 0:
132 | ckpt_path = cfgs.train.save_checkpoints_path
133 | scores = eval_fn(model=model, loader=valid_loader, device=device,
134 | idx2word=idx2word, save_on_disk=False, cfgs=cfgs,
135 | vid2groundtruth=vid2groundtruth)
136 | cider_score = scores['CIDEr']
137 | if best_score is None or cider_score > best_score:
138 | best_score = cider_score
139 | torch.save(model.state_dict(), ckpt_path)
140 | print('=' * 10,
141 | '[EPOCH{epoch} iter{it}] :Best Cider is {bs}, Current Cider is {cs}'.
142 | format(epoch=epoch, it=i, bs=best_score, cs=cider_score),
143 | '=' * 10)
144 |
145 | lr_scheduler.step()
146 | print('===================Training is finished====================')
147 | return model
148 |
149 |
--------------------------------------------------------------------------------
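A quick check of what `_get_src_permutation_idx` produces: it flattens the matcher's per-sample `(src, tgt)` pairs into `(batch_idx, src_idx)` tensors that can index a `(bsz, max_objects, ...)` tensor directly. The indices below are toy values.

```python
import torch

# per-sample (src, tgt) pairs as returned by HungarianMatcher
indices = [(torch.tensor([0, 2]), torch.tensor([1, 0])),   # sample 0: queries 0 and 2 are matched
           (torch.tensor([1]), torch.tensor([0]))]         # sample 1: query 1 is matched

batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx = torch.cat([src for (src, _) in indices])
print(batch_idx.tolist(), src_idx.tolist())   # [0, 0, 1] [0, 2, 1]

objects_pending = torch.randn(2, 3, 8)         # (bsz, max_objects, word_dim)
matched = objects_pending[batch_idx, src_idx]  # (3, 8): only the matched queries remain
print(matched.shape)
```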
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MarcusNerva/HMN/8de1a51e7b59ea964f39b0157246a838c2ae05e5/utils/__init__.py
--------------------------------------------------------------------------------
/utils/build_loaders.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import DataLoader
2 | from data.loader.data_loader import CaptionDataset, collate_fn_caption
3 | from configs.settings import TotalConfigs
4 |
5 |
6 | def build_loaders(cfgs: TotalConfigs, is_total=False):
7 | train_dataset = CaptionDataset(cfgs=cfgs, mode='train', save_on_disk=False, is_total=is_total)
8 | valid_dataset = CaptionDataset(cfgs=cfgs, mode='valid', save_on_disk=False, is_total=is_total)
9 | test_dataset = CaptionDataset(cfgs=cfgs, mode='test', save_on_disk=True, is_total=is_total)
10 |
11 | train_loader = DataLoader(dataset=train_dataset, batch_size=cfgs.bsz, shuffle=True,
12 | collate_fn=collate_fn_caption, num_workers=0)
13 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=cfgs.bsz, shuffle=True,
14 | collate_fn=collate_fn_caption, num_workers=0)
15 | test_loader = DataLoader(dataset=test_dataset, batch_size=cfgs.bsz, shuffle=False,
16 | collate_fn=collate_fn_caption, num_workers=0)
17 |
18 | return train_loader, valid_loader, test_loader
19 |
20 |
21 | def get_test_loader(cfgs: TotalConfigs, is_total=False):
22 | test_dataset = CaptionDataset(cfgs=cfgs, mode='test', save_on_disk=True, is_total=is_total)
23 | test_loader = DataLoader(dataset=test_dataset, batch_size=cfgs.bsz,
24 | shuffle=False, collate_fn=collate_fn_caption,
25 | num_workers=0)
26 | return test_loader
27 |
28 |
29 | def get_train_loader(cfgs: TotalConfigs, save_on_disk=True):
30 | train_dataset = CaptionDataset(cfgs=cfgs, mode='train', save_on_disk=save_on_disk, is_total=False)
31 | train_loader = DataLoader(dataset=train_dataset, batch_size=cfgs.bsz,
32 | shuffle=False, collate_fn=collate_fn_caption,
33 | num_workers=0)
34 | return train_loader
35 |
36 | def get_valid_loader(cfgs: TotalConfigs, is_total=False):
37 | valid_dataset = CaptionDataset(cfgs=cfgs, mode='valid', save_on_disk=True, is_total=is_total)
38 | valid_loader = DataLoader(dataset=valid_dataset, batch_size=cfgs.bsz,
39 | shuffle=False, collate_fn=collate_fn_caption,
40 | num_workers=0)
41 | return valid_loader
42 |
43 |
--------------------------------------------------------------------------------
/utils/build_model.py:
--------------------------------------------------------------------------------
1 | import pickle
2 |
3 | from models.caption_models.hierarchical_model import HierarchicalModel
4 | from models.decoder import Decoder
5 |
6 | from models.encoders.transformer import Transformer
7 | from models.encoders.entity_level import EntityLevelEncoder
8 | from models.encoders.predicate_level import PredicateLevelEncoder
9 | from models.encoders.sentence_level import SentenceLevelEncoder
10 |
11 | from configs.settings import TotalConfigs
12 |
13 |
14 | def build_model(cfgs: TotalConfigs):
15 | model_name = cfgs.model_name
16 | embedding_weights_path = cfgs.data.embedding_weights_path
17 | max_caption_len = cfgs.test.max_caption_len
18 | temperature = cfgs.test.temperature
19 | beam_size = cfgs.test.beam_size
20 | pad_idx = cfgs.dict.eos_idx
21 | eos_idx = cfgs.dict.eos_idx
22 | sos_idx = cfgs.dict.sos_idx
23 | unk_idx = cfgs.dict.unk_idx
24 | with open(embedding_weights_path, 'rb') as f:
25 | embedding_weights = pickle.load(f)
26 | n_vocab = embedding_weights.shape[0]
27 | cfgs.decoder.n_vocab = n_vocab
28 |
29 | feature2d_dim = cfgs.encoder.backbone_2d_dim
30 | feature3d_dim = cfgs.encoder.backbone_3d_dim
31 | object_dim = cfgs.encoder.object_dim
32 | semantics_dim = cfgs.encoder.semantics_dim
33 | hidden_dim = cfgs.decoder.hidden_dim
34 | decoder_num_layers = cfgs.decoder.num_layers
35 | embed_dim = cfgs.data.word_dim
36 | max_objects = cfgs.encoder.max_objects
37 |
38 | nheads = cfgs.encoder.nheads
39 | trans_num_encoder_layers = cfgs.encoder.entity_encoder_layer
40 | trans_num_decoder_layers = cfgs.encoder.entity_encoder_layer
41 | dim_feedforward = cfgs.encoder.dim_feedforward
42 | transformer_activation = cfgs.encoder.transformer_activation
43 | d_model = cfgs.encoder.d_model
44 | trans_dropout = cfgs.encoder.trans_dropout
45 |
46 | if model_name == 'HMN':
47 | # encoders
48 | transformer = Transformer(d_model=d_model, nhead=nheads,
49 | num_encoder_layers=trans_num_encoder_layers,
50 | num_decoder_layers=trans_num_decoder_layers,
51 | dim_feedforward=dim_feedforward,
52 | dropout=trans_dropout,
53 | activation=transformer_activation)
54 | entity_level_encoder = EntityLevelEncoder(transformer=transformer,
55 | max_objects=max_objects,
56 | object_dim=object_dim,
57 | feature2d_dim=feature2d_dim,
58 | feature3d_dim=feature3d_dim,
59 | hidden_dim=hidden_dim,
60 | word_dim=semantics_dim)
61 | predicate_level_encoder = PredicateLevelEncoder(feature3d_dim=feature3d_dim,
62 | hidden_dim=hidden_dim,
63 | semantics_dim=semantics_dim,
64 | useless_objects=False)
65 | sentence_level_encoder = SentenceLevelEncoder(feature2d_dim=feature2d_dim,
66 | hidden_dim=hidden_dim,
67 | semantics_dim=semantics_dim,
68 | useless_objects=False)
69 |
70 | # decoder
71 | decoder = Decoder(semantics_dim=semantics_dim, hidden_dim=hidden_dim,
72 | num_layers=decoder_num_layers, embed_dim=embed_dim, n_vocab=n_vocab,
73 | with_objects=True, with_action=True, with_video=True,
74 | with_objects_semantics=True,
75 | with_action_semantics=True,
76 | with_video_semantics=True)
77 |
78 | else:
79 | raise NotImplementedError
80 |
81 | # HMN
82 | model = HierarchicalModel(entity_level=entity_level_encoder,
83 | predicate_level=predicate_level_encoder,
84 | sentence_level=sentence_level_encoder,
85 | decoder=decoder,
86 | word_embedding_weights=embedding_weights,
87 | max_caption_len=max_caption_len,
88 | beam_size=beam_size, pad_idx=pad_idx,
89 | temperature=temperature,
90 | eos_idx=eos_idx,
91 | sos_idx=sos_idx,
92 | unk_idx=unk_idx)
93 |
94 | return model
95 |
96 |
97 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | def to_contiguous(tensor):
6 | if tensor.is_contiguous():
7 | return tensor
8 | else:
9 | return tensor.contiguous()
10 |
11 |
12 | class LanguageModelCriterion(nn.Module):
13 | def __init__(self):
14 | super(LanguageModelCriterion, self).__init__()
15 |
16 | def forward(self, pred, target, mask, eos_idx):
17 |
18 | pred = to_contiguous(pred).view(-1, pred.size(-1))
19 | target = torch.cat([target[:, 1:], target[:, 0].unsqueeze(1).fill_(eos_idx)], dim=1)
20 | target = to_contiguous(target).view(-1, 1)
21 | mask = to_contiguous(mask).view(-1, 1)
22 |
23 | output = -1. * pred.gather(1, target) * mask
24 | output = torch.sum(output) / torch.sum(mask)
25 | return output.float()
26 |
27 |
28 | class SoftCriterion(nn.Module):
29 | def __init__(self):
30 | super(SoftCriterion, self).__init__()
31 |
32 | def forward(self, pred, idxs, soft_target, mask):
33 | topk = -1.0 * pred.gather(-1, idxs) * mask[..., None]
34 | output = soft_target * topk
35 | output = torch.sum(output) / torch.sum(mask)
36 | return output.float()
37 |
38 |
39 | class CosineCriterion(nn.Module):
40 | def __init__(self):
41 | super(CosineCriterion, self).__init__()
42 | self.eps = 1e-12
43 |
44 | def forward(self, pred, target):
45 | assert pred.shape == target.shape and pred.dim() == 2, \
46 | 'expected pred.shape == target.shape, ' \
47 | 'but got pred.shape == {} and target.shape == {}'.format(pred.shape, target.shape)
48 | pred_denom = torch.norm(pred, p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(pred)
49 | pred = pred / pred_denom
50 | target_denom = torch.norm(target, p=2, dim=-1, keepdim=True).clamp_min(self.eps).expand_as(target)
51 | target = target / target_denom
52 |
53 | ret = pred * target
54 | ret = 1.0 - ret.sum(dim=-1)
55 | ret = ret.sum()
56 | return ret
57 |
58 |
59 |
60 |
--------------------------------------------------------------------------------
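For orientation, a toy run of the two criteria used in `train.py` (random values, illustrative shapes only; the import assumes the repo root is on `PYTHONPATH`). `LanguageModelCriterion` expects log-probabilities and a 0/1 mask and shifts the target left internally, so step t is scored against token t+1; `CosineCriterion` sums `1 - cos` over the batch.

```python
import torch
from utils.loss import LanguageModelCriterion, CosineCriterion

bsz, T, n_vocab, dim, eos_idx = 2, 5, 11, 7, 0
pred = torch.log_softmax(torch.randn(bsz, T, n_vocab), dim=-1)   # decoder log-probabilities
target = torch.randint(1, n_vocab, (bsz, T))                     # ground-truth token indices
mask = torch.ones(bsz, T)                                        # 1 for real tokens, 0 for padding

print(LanguageModelCriterion()(pred, target, mask, eos_idx))            # masked cross-entropy (scalar)
print(CosineCriterion()(torch.randn(bsz, dim), torch.randn(bsz, dim)))  # sum of (1 - cosine)
```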