├── README.md
├── annotations
│   └── charades
│       ├── charades_sta_test_pos_original_simple_sent.json
│       └── charades_train_pseudo_supervision_TEP_PS.json
├── config_parsers
│   ├── simple_model_config.py
│   └── simple_model_cross_modality_twostage_attention_config.py
├── configs
│   └── cha_simple_model
│       └── simplemodel_cha_BS256_two-stage_attention.yml
├── dataset
│   ├── anetcap_basic.py
│   └── charades_basic.py
├── inference.py
├── media
│   └── task-1.png
├── models
│   ├── simple_model.py
│   ├── simple_model_cross_modal_two_stage.py
│   └── simple_model_cross_modal_two_stage_temperature.py
├── requirements.txt
├── train.py
└── utils
    ├── eval_utils.py
    └── loss.py
/README.md:
--------------------------------------------------------------------------------
1 | # Zero-shot Natural Language Video Localization (ZSNLVL) by Pseudo-Supervised Video Localization (PSVL)
2 |
3 | This repository is for [Zero-shot Natural Language Video Localization](https://openaccess.thecvf.com/content/ICCV2021/papers/Nam_Zero-Shot_Natural_Language_Video_Localization_ICCV_2021_paper.pdf). (ICCV 2021, Oral)
4 |
5 |
6 | We first propose a novel task of zero-shot natural language video localization. The proposed task setup does not require any paired annotations for the NLVL task; it only requires easily available text corpora, an off-the-shelf object detector, and a collection of videos to localize. To address the task, we propose a **P**seudo-**S**upervised **V**ideo **L**ocalization method, called **PSVL**, that can generate pseudo-supervision for training an NLVL model. Benchmarked on two widely used NLVL datasets, the proposed method exhibits competitive performance and performs on par with or outperforms models trained with stronger supervision.
7 |
8 |
9 |
10 |
11 | ---
12 | ## Environment
13 | This repository is implemented based on [PyTorch](http://pytorch.org/) with Anaconda.
14 | Refer to the instructions below, or use **Docker** (dcahn/psvl:latest).
15 |
16 |
17 | ### Get the code
18 | - To clone this repo with git, run:
19 | ```bash
20 | git clone https://github.com/gistvision/PSVL.git
21 | ```
22 |
23 | - Make your own environment (if you use the Docker environment, you only need to clone the code and run it):
24 | ```bash
25 | conda create --name PSVL --file requirements.txt
26 | conda activate PSVL
27 | ```
28 |
29 | #### Working environment
30 | - RTX2080Ti (11G)
31 | - Ubuntu 18.04.5
32 | - pytorch 1.5.1
33 |
34 | ## Download
35 |
36 | ### Dataset & Pretrained model
37 |
38 | - Download the video features used in this paper from this [link](https://drive.google.com/file/d/1Vjgm2XA3TYcc4h9IWR5k5efU-bXNir5f/view?usp=sharing).
39 |   After downloading the video features, set the `data path` (`DATASET.DATA_PATH`) in a config file; see the sketch at the end of this section.
40 |
41 | - Download the pre-trained model from this [link](https://drive.google.com/file/d/1M2FX2qkEvyked50LSc9Y5r87GBnpohSX/view?usp=sharing).
42 |
43 | For ActivityNet-Captions, see the ActivityNet-Captions section of this document.
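
The downloaded feature directory can be set either by editing `DATA_PATH` in the YAML config or programmatically through the yacs config objects in `config_parsers`. A minimal sketch (the feature path below is a placeholder):

```python
from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as cfg

# merge the experiment YAML on top of the defaults, then point DATA_PATH at the downloaded features
cfg.merge_from_file("configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml")
cfg.DATASET.DATA_PATH = "/path/to/charades_feats"  # placeholder path
print(cfg.DATASET.DATA_PATH)
```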
44 |
45 | ## Evaluating pre-trained models
46 |
47 | To evaluate the pre-trained model, use the command below.
48 |
49 | ```bash
50 | python inference.py --model CrossModalityTwostageAttention --config "YOUR CONFIG PATH" --pre_trained "YOUR MODEL PATH"
51 | ```
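
For reference, the steps `inference.py` performs under the hood are roughly the following (a condensed sketch, not a drop-in replacement; the config and checkpoint paths are placeholders). Note that the dataloaders are built before the model because the dataset builder fills in `MODEL.QUERY.EMB_IDIM` (the vocabulary size):

```python
import torch
from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as cfg
from models.simple_model_cross_modal_two_stage import SimpleModel
from dataset.charades_basic import CharadesBasicDatasetBuilder
from utils.eval_utils import NLVLEvaluator

cfg.merge_from_file("YOUR CONFIG PATH")
dataloaders = CharadesBasicDatasetBuilder(
    cfg,
    data_path=cfg.DATASET.DATA_PATH,
    anno_path={"train": cfg.DATASET.TRAIN_ANNO_PATH, "test": cfg.DATASET.TEST_ANNO_PATH},
    vid_path=cfg.DATASET.VID_PATH,
).make_dataloaders()

device = torch.device("cuda")
model = SimpleModel(cfg).to(device)
model.load_state_dict(torch.load("YOUR MODEL PATH"))  # the --pre_trained checkpoint
model.eval()

evaluator = NLVLEvaluator(cfg)
eval_results = []
with torch.no_grad():
    for batch in dataloaders["test"]:
        dataloaders["test"].dataset.batch_to_device(batch, device)
        eval_results.append(evaluator(model(batch), batch))  # per-batch IoU metrics
```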
52 |
53 | ## Training models from scratch
54 |
55 | To train PSVL, run `train.py` with the command below.
56 |
57 | ```bash
58 | # Training from scratch
59 | python train.py --model CrossModalityTwostageAttention --config "YOUR CONFIG PATH"
60 | # Evaluation
61 | python inference.py --model CrossModalityTwostageAttention --config "YOUR CONFIG PATH" --pre_trained "YOUR MODEL PATH"
62 | ```
63 |
64 | ## ActivityNet-Captions
65 |
66 | - Go to [this repository](https://github.com/JonghwanMun/LGI4temporalgrounding) and download the video features for ActivityNet-Captions.
67 | Place the data under `/dataset/lgi_video_feature/anet_feats`.
68 |
69 | - Other data can be downloaded from [this link](https://drive.google.com/file/d/1yXnVHslpV51zqd9TRJAjFXPZZESCFrPm/view?usp=sharing).
70 |
71 | Download the file, unzip it, and run the following commands to train or run inference with the data.
72 |
73 | To train the model, please run:
74 | ```bash
75 | python train.py --model CrossModalityTwostageAttention --config configs/anet_simple_model/simplemodel_anet_BS256_two-stage_attention.yml --dataset anet
76 | ```
77 |
78 | To run inference on the test set, please run:
79 | ```bash
80 | python inference.py --model CrossModalityTwostageAttention --config configs/anet_simple_model/simplemodel_anet_BS256_two-stage_attention.yml --pre_trained anet_pretrained_best.pth
81 | ```
82 |
83 | ## License
84 | MIT License
85 |
86 | ## Citation
87 |
88 | If you use this code, please cite:
89 | ```
90 | @inproceedings{nam2021zero,
91 | title={Zero-shot Natural Language Video Localization},
92 | author={Nam, Jinwoo and Ahn, Daechul and Kang, Dongyeop and Ha, Seong Jong and Choi, Jonghyun},
93 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
94 | pages={1470-1479},
95 | year={2021}
96 | }
97 | ```
98 |
99 | ## Contact
100 | If you have any questions, please send an e-mail to us (skaws2012@gmail.com, daechulahn@gm.gist.ac.kr).
101 |
--------------------------------------------------------------------------------
/config_parsers/simple_model_config.py:
--------------------------------------------------------------------------------
1 | """
2 | simple_model_config.py
3 | ****
4 | config for charades dataset
5 | """
6 | from yacs.config import CfgNode as CN
7 | _C = CN()
8 | _C.EXP_NAME = "SimpleModel-verbpos"
9 | #_C.EXP_NAME = "Debug"
10 |
11 | # training options
12 | _C.DATASET = CN()
13 | _C.DATASET.NAME = "Charades"
14 | _C.DATASET.SHOW_TOP_VOCAB = 1
15 | _C.DATASET.BATCH_SIZE = 100
16 | _C.DATASET.MAX_LENGTH = 15
17 | _C.DATASET.NUM_SEGMENT = 128
18 | _C.DATASET.DATA_PATH = "/home/data/anet"
19 | _C.DATASET.TRAIN_ANNO_PATH = "annotations/charades_sta_train.json"
20 | _C.DATASET.TEST_ANNO_PATH = "annotations/charades_sta_test.json"
21 | _C.DATASET.VID_PATH = ""
22 |
23 | # model options
24 | _C.MODEL = CN()
25 | _C.MODEL.QUERY = CN()
26 | _C.MODEL.QUERY.EMB_IDIM = -1
27 | _C.MODEL.QUERY.EMB_ODIM = 300
28 | _C.MODEL.QUERY.GRU_HDIM = 256
29 | _C.MODEL.VIDEO = CN()
30 | _C.MODEL.VIDEO.IDIM = 1024
31 | _C.MODEL.VIDEO.GRU_HDIM = 256
32 | _C.MODEL.FUSION = CN()
33 | _C.MODEL.FUSION.EMB_DIM = 256
34 | _C.MODEL.FUSION.NUM_HEAD = 8
35 | _C.MODEL.FUSION.NUM_LAYERS = 3
36 | _C.MODEL.FUSION.CONVBNRELU = CN()
37 | _C.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE = 3
38 | _C.MODEL.FUSION.CONVBNRELU.PADDING = 1
39 | _C.MODEL.NONLOCAL = CN()
40 | _C.MODEL.NONLOCAL.NUM_LAYERS = 2
41 | _C.MODEL.NONLOCAL.NUM_HEAD = 4
42 | _C.MODEL.NONLOCAL.USE_BIAS = True
43 | _C.MODEL.NONLOCAL.DROPOUT = 0.0
44 |
45 | # training options
46 | _C.TRAIN = CN()
47 | _C.TRAIN.LR = 0.0001
48 | _C.TRAIN.NUM_EPOCH = 300
49 | _C.TRAIN.BATCH_SIZE = 100
50 | _C.TRAIN.NUM_WORKERS = 4
51 | _C.TRAIN.IOU_THRESH = [0.3,0.5,0.7]
52 |
--------------------------------------------------------------------------------
/config_parsers/simple_model_cross_modality_twostage_attention_config.py:
--------------------------------------------------------------------------------
1 | """
2 | simple_model_cross_modality_twostage_attention_config.py
3 | ****
4 | config for charades dataset
5 | """
6 | from yacs.config import CfgNode as CN
7 | _C = CN()
8 | _C.EXP_NAME = "SimpleModel-verbpos"
9 | #_C.EXP_NAME = "Debug"
10 |
11 | # training options
12 | _C.DATASET = CN()
13 | _C.DATASET.NAME = "Charades"
14 | _C.DATASET.SHOW_TOP_VOCAB = 1
15 | _C.DATASET.BATCH_SIZE = 100
16 | _C.DATASET.MAX_LENGTH = 15
17 | _C.DATASET.NUM_SEGMENT = 128
18 | _C.DATASET.DATA_PATH = "/home/data/charades"
19 | _C.DATASET.TRAIN_ANNO_PATH = "annotations/charades_train_pos.json"
20 | _C.DATASET.TEST_ANNO_PATH = "annotations/charades_test_pos.json"
21 | _C.DATASET.VID_PATH = ""
22 |
23 | # model options
24 | _C.MODEL = CN()
25 | _C.MODEL.QUERY = CN()
26 | _C.MODEL.QUERY.EMB_IDIM = -1
27 | _C.MODEL.QUERY.TRANSFORMER_DIM = 300
28 | _C.MODEL.QUERY.EMB_ODIM = 300
29 | _C.MODEL.QUERY.GRU_HDIM = 256
30 | _C.MODEL.QUERY.TEMPERATURE = 1.0
31 | _C.MODEL.VIDEO = CN()
32 | _C.MODEL.VIDEO.IDIM = 1024
33 | _C.MODEL.VIDEO.GRU_HDIM = 256
34 | _C.MODEL.VIDEO.ANET_TRANSFORMER_DIM = 500
35 | _C.MODEL.VIDEO.CHA_TRANSFORMER_DIM = 1024
36 | _C.MODEL.FUSION = CN()
37 | _C.MODEL.FUSION.EMB_DIM = 256
38 | _C.MODEL.FUSION.NUM_HEAD = 8
39 | _C.MODEL.FUSION.NUM_LAYERS = 3
40 | _C.MODEL.FUSION.USE_RESBLOCK = False
41 | _C.MODEL.FUSION.RESBLOCK = CN()
42 | _C.MODEL.FUSION.RESBLOCK.KERNEL_SIZE = 3
43 | _C.MODEL.FUSION.RESBLOCK.PADDING = 1
44 | _C.MODEL.FUSION.RESBLOCK.NB_ITER = 1
45 | _C.MODEL.FUSION.CONVBNRELU = CN()
46 | _C.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE = 3
47 | _C.MODEL.FUSION.CONVBNRELU.PADDING = 1
48 |
49 | _C.MODEL.NONLOCAL = CN()
50 | _C.MODEL.NONLOCAL.NUM_LAYERS = 2
51 | _C.MODEL.NONLOCAL.NUM_HEAD = 4
52 | _C.MODEL.NONLOCAL.USE_BIAS = True
53 | _C.MODEL.NONLOCAL.DROPOUT = 0.0
54 |
55 | # training options
56 | _C.TRAIN = CN()
57 | #_C.TRAIN.USE_DETERMINISTIC = True
58 | _C.TRAIN.LR = 0.0001
59 | _C.TRAIN.NUM_EPOCH = 300
60 | _C.TRAIN.BATCH_SIZE = 100
61 | _C.TRAIN.NUM_WORKERS = 4
62 | _C.TRAIN.IOU_THRESH = [0.3,0.5,0.7]
63 |
--------------------------------------------------------------------------------
/configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | BATCH_SIZE: 256
3 | DATA_PATH: "/dataset/charades_feats"
4 | MAX_LENGTH: 10
5 | NAME: Charades
6 | NUM_SEGMENT: 128
7 | SHOW_TOP_VOCAB: 10
8 | TEST_ANNO_PATH: annotations/charades/charades_sta_test_pos_original_simple_sent.json
9 | TRAIN_ANNO_PATH: annotations/charades/charades_train_pseudo_supervision_TEP_PS.json
10 | VID_PATH: ''
11 | EXP_NAME: SimpleModel_twostage_attention
12 | MODEL:
13 | FUSION:
14 | CONVBNRELU:
15 | KERNEL_SIZE: 3
16 | PADDING: 1
17 | EMB_DIM: 256
18 | NUM_HEAD: 1
19 | NUM_LAYERS: 3
20 | NONLOCAL:
21 | DROPOUT: 0.0
22 | NUM_HEAD: 4
23 | NUM_LAYERS: 2
24 | USE_BIAS: true
25 | QUERY:
26 | EMB_IDIM: 290
27 | EMB_ODIM: 300
28 | GRU_HDIM: 256
29 | VIDEO:
30 | GRU_HDIM: 256
31 | IDIM: 1024
32 | TRAIN:
33 | BATCH_SIZE: 256
34 | IOU_THRESH:
35 | - 0.1
36 | - 0.3
37 | - 0.5
38 | - 0.7
39 | LR: 0.0004
40 | NUM_EPOCH: 500
41 | NUM_WORKERS: 4
42 |
--------------------------------------------------------------------------------
/dataset/anetcap_basic.py:
--------------------------------------------------------------------------------
1 | #%%
2 | """
3 | anetcap_basic.py
4 | ****
5 | the basic ActivityNet Captions dataset class.
6 | """
7 | #%%
8 | # import things
9 | import os
10 | from os.path import join
11 | import json
12 | import h5py
13 | import numpy as np
14 | from tqdm import tqdm
15 | import torch
16 | from torch.utils.data import Dataset
17 | from torch.utils.data import DataLoader
18 | from random import random, randint, choice
19 | from collections import Counter
20 |
21 | import sys
22 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
23 | from config_parsers.simple_model_config import _C as TestCfg
24 |
25 |
26 | #%%
27 | # dataset builder class
28 | class AnetCapBasicDatasetBuilder():
29 | def __init__(self,cfg,data_path=None,anno_path=None,vid_path=None):
30 | # make variables that will be used in the future
31 | self.cfg = cfg
32 | self.splits = ["train","test"] # list of splits
33 | # make paths
34 | if len(data_path) > 0:
35 | self.vid_path = join(data_path,"sub_activitynet_v1-3.c3d.hdf5")
36 | self.anno_path = {"train": join(data_path,"normal_set","captions","train_simplified_sentence.json"),
37 | "test": [join(data_path,"normal_set","captions","val_1_simplified_sentence.json"),
38 | join(data_path,"normal_set","captions","val_2_simplified_sentence.json")]}
39 | if len(vid_path) > 0:
40 | self.vid_path = vid_path
41 | if len(anno_path) > 0:
42 | self.anno_path = anno_path
43 | self.anno_path['test'] = [self.anno_path['test'].format(i) for i in [1,2]]
44 | # read annotations
45 | self.annos = self._read_annos(self.anno_path)
46 | # make dictionary of word-index correspondence
47 | self.wtoi, self.itow = self._make_word_dictionary(self.annos)
48 |
49 | def _read_annos(self,anno_path):
50 | # read annotations
51 | annos = {s: None for s in self.splits}
52 | for s in self.splits:
53 | if isinstance(anno_path[s],str):
54 | with open(anno_path[s],'r') as f:
55 | #annos[s] = json.load(f)[:100]
56 | annos[s] = json.load(f)
57 | elif isinstance(anno_path[s],list):
58 | annos[s] = []
59 | for path in anno_path[s]:
60 | with open(path,'r') as f:
61 | annos[s] += json.load(f)
62 | return annos
63 |
64 | def _make_word_dictionary(self,annos):
65 | """
66 | makes word tokens - number idx correspondences
67 | ARGS:
68 | - annos: annotations read
69 | RETURNS:
70 | - wtoi: word -> index dictionary
71 | - itow: index -> word dictionary
72 | PARAMS:
73 | - DATASET.SHOW_TOP_VOCAB: the number of top-n tokens to print
74 | """
75 | # get training annos
76 | train_annos = self.annos["train"]
77 | # read tokens
78 | tokens_list = []
79 | for ann in train_annos:
80 | tokens_list += [tk for tk in ann["tokens"]]
81 | # print results: count tokens and show top-n
82 | print("Top-{} tokens list:".format(self.cfg.DATASET.SHOW_TOP_VOCAB))
83 | tokens_count = sorted(Counter(tokens_list).items(), key=lambda x:x[1])
84 | for tk in tokens_count[-self.cfg.DATASET.SHOW_TOP_VOCAB:]:
85 | print("\t- {}: {}".format(tk[0],tk[1]))
86 | # make wtoi, itow
87 | wtoi = {}
88 | wtoi["<PAD>"], wtoi["<UNK>"] = 0, 1  # special tokens: padding / unknown word
89 | wtoi["<S>"], wtoi["<E>"] = 2, 3  # special tokens: start / end of sentence
90 | for i,(tk,cnt) in enumerate(tokens_count):
91 | idx = i+4 # idx start at 4
92 | wtoi[tk] = idx
93 | itow = {v:k for k,v in wtoi.items()}
94 | self.cfg.MODEL.QUERY.EMB_IDIM = len(wtoi)
95 | return wtoi, itow
96 |
97 | def make_dataloaders(self):
98 | """
99 | makes actual dataset class
100 | RETURNS:
101 | - dataloaders: dataset classes for each splits. dictionary of {split: dataset}
102 | """
103 | # read annotations
104 | annos = self._read_annos(self.anno_path)
105 | # make dictionary of word-index correspondence
106 | wtoi, itow = self._make_word_dictionary(annos)
107 | batch_size = self.cfg.TRAIN.BATCH_SIZE
108 | num_workers = self.cfg.TRAIN.NUM_WORKERS
109 | dataloaders = {}
110 | for s in self.splits:
111 | if "train" in s:
112 | dataset = AnetCapBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s])
113 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=True, shuffle=True)
114 | else:
115 | dataset = AnetCapBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s])
116 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=False, shuffle=False)
117 | return dataloaders
118 |
119 |
120 | #%%
121 | # ActivityNet Captions Dataset Class
122 | class AnetCapBasicDataset(Dataset):
123 | def __init__(self, cfg, vid_path, split, wtoi, itow, annos):
124 | self.cfg = cfg
125 | self.vid_path = vid_path
126 | self.split = split
127 | self.wtoi = wtoi
128 | self.itow = itow
129 | self.annos = annos
130 | self.feats = self._load_vid_feats()
131 | self.num_segments = self.cfg.DATASET.NUM_SEGMENT
132 | self.sentence_max_length = self.cfg.DATASET.MAX_LENGTH
133 |
134 | def _load_vid_feats(self):
135 | feats = {}
136 | hfile = h5py.File(self.vid_path,'r')
137 | vid_list = set([x['vid'] for x in self.annos])
138 | for vid in tqdm(vid_list, desc="loading video features"):
139 | feats[vid] = hfile.get(vid).get("c3d_features")[()]
140 | return feats
141 |
142 | def _tokens_to_index(self,tokens):
143 | """
144 | translates list of tokens into trainable index format. also does padding.
145 | """
146 | wids = []
147 | for tk in tokens:
148 | if tk in self.wtoi.keys():
149 | wids.append(self.wtoi[tk])
150 | else:
151 | wids.append(1) # index 1 is the unknown-word token
152 | for _ in range(self.sentence_max_length - len(wids)):
153 | wids.append(0)
154 | if len(wids) > self.sentence_max_length:
155 | wids = wids[:self.sentence_max_length]
156 | return wids
157 |
158 | def get_fixed_length_feat(self, feat, num_segment, start_pos, end_pos):
159 | """
160 | makes fixed length feature. adopted from LGI code.
161 | """
162 | nfeats = feat[:,:].shape[0]
163 | if nfeats <= self.num_segments:
164 | stride = 1
165 | else:
166 | stride = nfeats * 1.0 / num_segment
167 | if self.split != "train":
168 | spos = 0
169 | else:
170 | random_end = -0.5 + stride
171 | if random_end == np.floor(random_end):
172 | random_end = random_end - 1.0
173 | spos = np.random.random_integers(0,random_end)
174 | s = np.round( np.arange(spos, nfeats-0.5, stride) ).astype(int)
175 | start_pos = float(nfeats-1.0) * start_pos
176 | end_pos = float(nfeats-1.0) * end_pos
177 |
178 | if not (nfeats < self.num_segments and len(s) == nfeats) \
179 | and not (nfeats >= self.num_segments and len(s) == num_segment):
180 | s = s[:num_segment] # ignore last one
181 | assert (nfeats < self.num_segments and len(s) == nfeats) \
182 | or (nfeats >= self.num_segments and len(s) == num_segment), \
183 | "{} != {} or {} != {}".format(len(s), nfeats, len(s), num_segment)
184 |
185 | start_index, end_index = None, None
186 | for i in range(len(s)-1):
187 | if s[i] <= end_pos < s[i+1]:
188 | end_index = i
189 | if s[i] <= start_pos < s[i+1]:
190 | start_index = i
191 |
192 | if start_index is None:
193 | start_index = 0
194 | if end_index is None:
195 | end_index = num_segment-1
196 |
197 | cur_feat = feat[s, :]
198 | nfeats = min(nfeats, num_segment)
199 | out = np.zeros((num_segment, cur_feat.shape[1]))
200 | out [:nfeats,:] = cur_feat
201 | return out, nfeats, start_index, end_index
202 |
203 | def make_attention_mask(self,start_index,end_index):
204 | attn_mask = np.zeros([self.num_segments])
205 | attn_mask[start_index:end_index+1] = 1
206 | attn_mask = torch.Tensor(attn_mask)
207 | return attn_mask
208 |
209 | def __getitem__(self,idx):
210 | anno = self.annos[idx]
211 | vid = anno["vid"]
212 | duration = anno['duration']
213 | timestamp = [x*duration for x in anno['timestamp']]
214 | start_pos, end_pos = anno['timestamp']
215 | query_label = self._tokens_to_index(anno['tokens'])
216 | query_length = len(anno['tokens'])
217 | vid_feat = self.feats[vid]
218 |
219 | fixed_vid_feat, nfeats, start_index, end_index = self.get_fixed_length_feat(vid_feat, self.num_segments, start_pos, end_pos)
220 | # get video masks
221 | vid_mask = np.zeros((self.num_segments, 1))
222 | vid_mask[:nfeats] = 1
223 | # make attn mask
224 | instance = {
225 | "vids": vid,
226 | "qids": idx,
227 | "timestamps": timestamp, # GT location [s, e] (second)
228 | "duration": duration, # video span (second)
229 | "query_lengths": query_length,
230 | "query_labels": torch.LongTensor(query_label).unsqueeze(0), # [1,L_q_max]
231 | "query_masks": (torch.FloatTensor(query_label)>0).unsqueeze(0), # [1,L_q_max]
232 | "grounding_start_pos": torch.FloatTensor([start_pos]), # [1]; normalized
233 | "grounding_end_pos": torch.FloatTensor([end_pos]), # [1]; normalized
234 | "nfeats": torch.FloatTensor([nfeats]),
235 | "video_feats": torch.FloatTensor(fixed_vid_feat), # [L_v,D_v]
236 | "video_masks": torch.ByteTensor(vid_mask), # [L_v,1]
237 | "attention_masks": self.make_attention_mask(start_index,end_index),
238 | }
239 | return instance
240 |
241 | def collate_fn(self, data):
242 | seq_items = ["video_feats", "video_masks","attention_masks"]
243 | tensor_items = [
244 | "query_labels", "query_masks", "nfeats",
245 | "grounding_start_pos", "grounding_end_pos"
246 | ]
247 | batch = {k: [d[k] for d in data] for k in data[0].keys()}
248 | if len(data) == 1:
249 | for k,v in batch.items():
250 | if k in tensor_items:
251 | batch[k] = torch.cat(batch[k], 0)
252 | elif k in seq_items:
253 | batch[k] = torch.nn.utils.rnn.pad_sequence(
254 | batch[k], batch_first=True)
255 | else:
256 | batch[k] = batch[k][0]
257 | else:
258 | for k in tensor_items:
259 | batch[k] = torch.cat(batch[k], 0)
260 | for k in seq_items:
261 | batch[k] = torch.nn.utils.rnn.pad_sequence(batch[k], batch_first=True)
262 | return batch
263 |
264 | def batch_to_device(self,batch,device):
265 | for k,v in batch.items():
266 | if isinstance(v,torch.Tensor):
267 | batch[k] = v.to(device)
268 |
269 | def __len__(self):
270 | return len(self.annos)
271 |
272 |
273 | #%%
274 | # test
275 | if __name__ == "__main__":
276 | DATA_PATH = "/home/data/anet"
277 | dataloaders = AnetCapBasicDatasetBuilder(TestCfg,DATA_PATH).make_dataloaders()
278 | for item in dataloaders['train']:
279 | _ = item['vids']
280 |
281 |
--------------------------------------------------------------------------------
/dataset/charades_basic.py:
--------------------------------------------------------------------------------
1 | #%%
2 | """
3 | charades_basic.py
4 | ****
5 | the basic charades dataset class.
6 | """
7 | #%%
8 | # import things
9 | import os
10 | from os.path import join
11 | import json
12 | import h5py
13 | import numpy as np
14 | from tqdm import tqdm
15 | import torch
16 | from torch.utils.data import Dataset
17 | from torch.utils.data import DataLoader
18 | from random import random, randint, choice
19 | from collections import Counter
20 |
21 | import sys
22 | sys.path.append(os.path.dirname(os.path.dirname(__file__)))
23 | from config_parsers.simple_model_config import _C as TestCfg
24 |
25 |
26 | #%%
27 | # dataset builder class
28 | class CharadesBasicDatasetBuilder():
29 | def __init__(self,cfg,data_path=None,anno_path=None,vid_path=None):
30 | # make variables that will be used in the future
31 | self.cfg = cfg
32 | self.splits = ["train","test"] # list of splits
33 | # make paths
34 | if len(data_path) > 0:
35 | self.vid_path = join(data_path,"i3d_finetuned")
36 | self.anno_path = {s: join(data_path,"annotations","charades_sta_{}_pos_original.json".format(s)) for s in self.splits}
37 | if len(vid_path) > 0:
38 | self.vid_path = vid_path
39 | if len(anno_path) > 0:
40 | self.anno_path = anno_path
41 | # read annotations
42 | self.annos = self._read_annos(self.anno_path)
43 | # make dictionary of word-index correspondence
44 | self.wtoi, self.itow = self._make_word_dictionary(self.annos)
45 |
46 | def _read_annos(self,anno_path):
47 | # read annotations
48 |
49 | annos = {s: None for s in self.splits}
50 | for s in self.splits:
51 | with open(anno_path[s],'r') as f:
52 | #annos[s] = json.load(f)[:100]
53 | annos[s] = json.load(f)
54 | return annos
55 |
56 | def _make_word_dictionary(self,annos):
57 | """
58 | makes word tokens - number idx correspondences
59 | ARGS:
60 | - annos: annotations read
61 | RETURNS:
62 | - wtoi: word -> index dictionary
63 | - itow: index -> word dictionary
64 | PARAMS:
65 | - DATASET.SHOW_TOP_VOCAB: the number of top-n tokens to print
66 | """
67 | # get training annos
68 | train_annos = self.annos["train"]
69 | # read tokens
70 | tokens_list = []
71 | for ann in train_annos:
72 | tokens_list += [tk for tk in ann["tokens"]]
73 | # print results: count tokens and show top-n
74 | print("Top-{} tokens list:".format(self.cfg.DATASET.SHOW_TOP_VOCAB))
75 | tokens_count = sorted(Counter(tokens_list).items(), key=lambda x:x[1])
76 | for tk in tokens_count[-self.cfg.DATASET.SHOW_TOP_VOCAB:]:
77 | print("\t- {}: {}".format(tk[0],tk[1]))
78 | # make wtoi, itow
79 | wtoi = {}
80 | wtoi["<PAD>"], wtoi["<UNK>"] = 0, 1  # special tokens: padding / unknown word
81 | wtoi["<S>"], wtoi["<E>"] = 2, 3  # special tokens: start / end of sentence
82 | for i,(tk,cnt) in enumerate(tokens_count):
83 | idx = i+4 # idx start at 4
84 | wtoi[tk] = idx
85 | itow = {v:k for k,v in wtoi.items()}
86 | self.cfg.MODEL.QUERY.EMB_IDIM = len(wtoi)
87 | return wtoi, itow
88 |
89 | def make_dataloaders(self):
90 | """
91 | makes actual dataset class
92 | RETURNS:
93 | - dataloaders: dataset classes for each splits. dictionary of {split: dataset}
94 | """
95 | # read annotations
96 | annos = self._read_annos(self.anno_path)
97 | # make dictionary of word-index correspondence
98 | wtoi, itow = self._make_word_dictionary(annos)
99 | batch_size = self.cfg.TRAIN.BATCH_SIZE
100 | num_workers = self.cfg.TRAIN.NUM_WORKERS
101 | dataloaders = {}
102 | for s in self.splits:
103 | if "train" in s:
104 | dataset = CharadesBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s])
105 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=True, shuffle=True)
106 | else:
107 | dataset = CharadesBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s])
108 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=False, shuffle=False)
109 | return dataloaders
110 |
111 |
112 | #%%
113 | # Charades Dataset Class
114 | class CharadesBasicDataset(Dataset):
115 | def __init__(self, cfg, vid_path, split, wtoi, itow, annos):
116 | self.cfg = cfg
117 | self.vid_path = vid_path
118 | self.split = split
119 | self.wtoi = wtoi
120 | self.itow = itow
121 | self.annos = annos
122 | self.feats = self._load_vid_feats()
123 | self.num_segments = self.cfg.DATASET.NUM_SEGMENT
124 | self.sentence_max_length = self.cfg.DATASET.MAX_LENGTH
125 |
126 | def _load_vid_feats(self):
127 | feats = {}
128 | vid_list = [x['vid'] for x in self.annos]
129 | for vid in tqdm(vid_list, desc="loading video features"):
130 | feats[vid] = np.load(join(self.vid_path,"{}.npy".format(vid))).squeeze()
131 | return feats
132 |
133 | def _tokens_to_index(self,tokens):
134 | """
135 | translates list of tokens into trainable index format. also does padding.
136 | """
137 | wids = []
138 | for tk in tokens:
139 | if tk in self.wtoi.keys():
140 | wids.append(self.wtoi[tk])
141 | else:
142 | wids.append(1) # index 1 is the unknown-word token
143 | for _ in range(self.sentence_max_length - len(wids)):
144 | wids.append(0)
145 | if len(wids) > self.sentence_max_length:
146 | wids = wids[:self.sentence_max_length]
147 | return wids
148 |
149 | def get_fixed_length_feat(self, feat, num_segment, start_pos, end_pos):
150 | """
151 | makes fixed length feature. adopted from LGI code.
152 | """
153 | nfeats = feat[:,:].shape[0]
154 | if nfeats <= self.num_segments:
155 | stride = 1
156 | else:
157 | stride = nfeats * 1.0 / num_segment
158 | if self.split != "train":
159 | spos = 0
160 | else:
161 | random_end = -0.5 + stride
162 | if random_end == np.floor(random_end):
163 | random_end = random_end - 1.0
164 | spos = np.random.random_integers(0,random_end)
165 | s = np.round( np.arange(spos, nfeats-0.5, stride) ).astype(int)
166 | start_pos = float(nfeats-1.0) * start_pos
167 | end_pos = float(nfeats-1.0) * end_pos
168 |
169 | if not (nfeats < self.num_segments and len(s) == nfeats) \
170 | and not (nfeats >= self.num_segments and len(s) == num_segment):
171 | s = s[:num_segment] # ignore last one
172 | assert (nfeats < self.num_segments and len(s) == nfeats) \
173 | or (nfeats >= self.num_segments and len(s) == num_segment), \
174 | "{} != {} or {} != {}".format(len(s), nfeats, len(s), num_segment)
175 |
176 | start_index, end_index = None, None
177 | for i in range(len(s)-1):
178 | if s[i] <= end_pos < s[i+1]:
179 | end_index = i
180 | if s[i] <= start_pos < s[i+1]:
181 | start_index = i
182 |
183 | if start_index is None:
184 | start_index = 0
185 | if end_index is None:
186 | end_index = num_segment-1
187 |
188 | cur_feat = feat[s, :]
189 | nfeats = min(nfeats, num_segment)
190 | out = np.zeros((num_segment, cur_feat.shape[1]))
191 | out [:nfeats,:] = cur_feat
192 | return out, nfeats, start_index, end_index
193 |
194 | def make_attention_mask(self,start_index,end_index):
195 | attn_mask = np.zeros([self.num_segments])
196 | attn_mask[start_index:end_index+1] = 1
197 | attn_mask = torch.Tensor(attn_mask)
198 | return attn_mask
199 |
200 | def __getitem__(self,idx):
201 | anno = self.annos[idx]
202 | vid = anno["vid"]
203 | duration = anno['duration']
204 | timestamp = [x*duration for x in anno['timestamp']]
205 | start_pos, end_pos = anno['timestamp']
206 | query_label = self._tokens_to_index(anno['tokens'])
207 | query_length = len(anno['tokens'])
208 | vid_feat = self.feats[vid]
209 |
210 | fixed_vid_feat, nfeats, start_index, end_index = self.get_fixed_length_feat(vid_feat, self.num_segments, start_pos, end_pos)
211 | # get video masks
212 | vid_mask = np.zeros((self.num_segments, 1))
213 | vid_mask[:nfeats] = 1
214 | # make attn mask
215 | instance = {
216 | "vids": vid,
217 | "qids": idx,
218 | "timestamps": timestamp, # GT location [s, e] (second)
219 | "duration": duration, # video span (second)
220 | "query_lengths": query_length,
221 | "query_labels": torch.LongTensor(query_label).unsqueeze(0), # [1,L_q_max]
222 | "query_masks": (torch.FloatTensor(query_label)>0).unsqueeze(0), # [1,L_q_max]
223 | "grounding_start_pos": torch.FloatTensor([start_pos]), # [1]; normalized
224 | "grounding_end_pos": torch.FloatTensor([end_pos]), # [1]; normalized
225 | "nfeats": torch.FloatTensor([nfeats]),
226 | "video_feats": torch.FloatTensor(fixed_vid_feat), # [L_v,D_v]
227 | "video_masks": torch.ByteTensor(vid_mask), # [L_v,1]
228 | "attention_masks": self.make_attention_mask(start_index,end_index),
229 | }
230 | return instance
231 |
232 | def collate_fn(self, data):
233 | seq_items = ["video_feats", "video_masks","attention_masks"]
234 | tensor_items = [
235 | "query_labels", "query_masks", "nfeats",
236 | "grounding_start_pos", "grounding_end_pos"
237 | ]
238 | batch = {k: [d[k] for d in data] for k in data[0].keys()}
239 | if len(data) == 1:
240 | for k,v in batch.items():
241 | if k in tensor_items:
242 | batch[k] = torch.cat(batch[k], 0)
243 | elif k in seq_items:
244 | batch[k] = torch.nn.utils.rnn.pad_sequence(
245 | batch[k], batch_first=True)
246 | else:
247 | batch[k] = batch[k][0]
248 | else:
249 | for k in tensor_items:
250 | batch[k] = torch.cat(batch[k], 0)
251 | for k in seq_items:
252 | batch[k] = torch.nn.utils.rnn.pad_sequence(batch[k], batch_first=True)
253 | return batch
254 |
255 | def batch_to_device(self,batch,device):
256 | for k,v in batch.items():
257 | if isinstance(v,torch.Tensor):
258 | batch[k] = v.to(device)
259 |
260 | def __len__(self):
261 | return len(self.annos)
262 |
263 |
264 | #%%
265 | # test
266 | if __name__ == "__main__":
267 | DATA_PATH = "/home/skaws2003/projects/didemo/localglobal/data/charades"
268 | dataloaders = CharadesBasicDatasetBuilder(TestCfg,DATA_PATH).make_dataloaders()
269 | for item in dataloaders['train']:
270 | _ = item['vids']
271 |
272 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | #%%
2 | # import things
3 | # model code
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from utils.eval_utils import NLVLEvaluator
7 | from utils.loss import NLVLLoss
8 |
9 | # logging code
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | # etc
13 | import random
14 | import numpy as np
15 | import os
16 | from os.path import join
17 | from tqdm import tqdm
18 | from copy import deepcopy
19 | from yacs.config import CfgNode
20 | from yaml import dump as dump_yaml
21 | import argparse
22 |
23 | #%%
24 | # argparse
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--model",'-m',type=str,default="CrossModalityTwostageAttention")
27 | parser.add_argument("--config",'-c',type=str,default="configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml")
28 | parser.add_argument("--pre_trained", type=str, default="pretrained_weight.pth")
29 | parser.add_argument("--seed", '-s',type=int,default=0)
30 | parser.add_argument("--reg_w", type=float, default=1.0)
31 | args = parser.parse_args()
32 |
33 | dict_args = {
34 | "model": args.model,
35 | "config": args.config
36 | }
37 |
38 | random.seed(int(args.seed))
39 | os.environ['PYTHONHASHSEED'] = str(args.seed)
40 | np.random.seed(int(args.seed))
41 | torch.manual_seed(int(args.seed))
42 | torch.cuda.manual_seed(int(args.seed))
43 | torch.cuda.manual_seed_all(int(args.seed))
44 | torch.backends.cudnn.deterministic = True
45 | torch.backends.cudnn.benchmark = False
46 | torch.backends.cudnn.enabled = False
47 |
48 | #%%
49 | # load things according to arguments
50 | if args.model == "SimpleModel":
51 | from models.simple_model import SimpleModel as Model
52 | from config_parsers.simple_model_config import _C as TestCfg
53 | elif args.model == "CrossModalityTwostageAttention":
54 | from models.simple_model_cross_modal_two_stage import SimpleModel as Model
55 | from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as TestCfg
56 | else:
57 | raise ValueError("No such model: {}".format(args.model))
58 |
59 |
60 | #%%
61 | # constants
62 | CONFIG_PATH = args.config
63 | print("config file path: ", CONFIG_PATH)
64 | TestCfg.merge_from_file(CONFIG_PATH)
65 | cfg = TestCfg
66 | device = torch.device("cuda")
67 | DATA_PATH = cfg.DATASET.DATA_PATH
68 | ANNO_PATH ={"train": cfg.DATASET.TRAIN_ANNO_PATH,
69 | "test": cfg.DATASET.TEST_ANNO_PATH}
70 | VID_PATH = cfg.DATASET.VID_PATH
71 |
72 | #%%
73 | # function for dumping yaml
74 | def cfg_to_dict(cfg):
75 | dict_cfg = dict(cfg)
76 | for k,v in dict_cfg.items():
77 | if isinstance(v,CfgNode):
78 | dict_cfg[k] = cfg_to_dict(v)
79 | return dict_cfg
80 |
81 | #%%
82 | # Load dataloader
83 | dataset_name = cfg.DATASET.NAME
84 | if dataset_name == "Charades":
85 | from dataset.charades_basic import CharadesBasicDatasetBuilder
86 | dataloaders = CharadesBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders()
87 | elif dataset_name == "AnetCap":
88 | from dataset.anetcap_basic import AnetCapBasicDatasetBuilder
89 | dataloaders = AnetCapBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders()
90 | else:
91 | raise ValueError("No such dataset: {}".format(dataset_name))
92 |
93 | #%%
94 | # load training stuff
95 | ## model and loss
96 | model = Model(cfg).to(device)
97 | model.load_state_dict(torch.load(args.pre_trained))
98 |
99 | ## evaluator
100 | evaluator = NLVLEvaluator(cfg)
101 | batch_to_device = dataloaders['test'].dataset.batch_to_device
102 |
103 | # information print out
104 | print("====="*10)
105 | print(args)
106 | print("====="*10)
107 | print(cfg)
108 | print("====="*10)
109 |
110 | #%%
111 | # test loop
112 | model.eval()
113 | pbar = tqdm(range(1))
114 | for epoch in pbar:
115 | eval_results_list = []
116 | for batch_idx,item in enumerate(dataloaders['test']):
117 | # update progress bar
118 | pbar.set_description("test {}/{}".format(batch_idx,len(dataloaders['test'])))
119 | # make evaluation
120 | with torch.no_grad():
121 | batch_to_device(item,device)
122 | model_outputs = model(item)
123 | eval_results = evaluator(model_outputs,item)
124 | # add to mean eval results
125 | eval_results_list.append(eval_results)
126 | # write test information
127 | ## make mean metric dict
128 | mean_eval_results_dict = {}
129 | for k in list(eval_results_list[0].keys()):
130 | mean_eval_results_dict[k] = torch.mean(torch.Tensor([x[k] for x in eval_results_list]))
131 |
132 | # print test information
133 | tqdm.write("****** epoch:{} ******".format(epoch))
134 | for k,v in mean_eval_results_dict.items():
135 | tqdm.write("\t{}: {}".format(k,v))
136 |
--------------------------------------------------------------------------------
/media/task-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gistvision/PSVL/50b1041d472d4f140f7d98248a3bbfbb16d680da/media/task-1.png
--------------------------------------------------------------------------------
/models/simple_model.py:
--------------------------------------------------------------------------------
1 | #%%
2 | """
3 | simple_model.py
4 | ****
5 | simple, basic model for NLVL.
6 | - Query-Video matching with (Multi-Head Attention + ConvBNReLU) with residual connection
7 | - Video Encoding with simple GRU
8 | """
9 |
10 | #%%
11 | # import things
12 | import torch
13 | import torch.nn as nn
14 |
15 | #%%
16 | # model
17 | class SimpleSentenceEmbeddingModule(nn.Module):
18 | """
19 | A Simple Query Embedding class
20 | """
21 | def __init__(self, cfg):
22 | super().__init__()
23 | # config params
24 | self.cfg = cfg
25 | self.query_length = self.cfg.DATASET.MAX_LENGTH
26 | # embedding Layer
27 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM
28 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM
29 | self.embedding = nn.Embedding(emb_idim, emb_odim)
30 | # RNN Layer
31 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM
32 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True)
33 | # feature adjust
34 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
35 | self.feature_aggregation = nn.Sequential(
36 | nn.Linear(in_features=gru_hidden*2,out_features=emb_dim),
37 | nn.ReLU(),
38 | nn.Dropout(0.5))
39 |
40 | def forward(self, query_labels, query_masks):
41 | """
42 | encode query sequence using RNN and return logits over proposals.
43 | code adopted from LGI
44 | Args:
45 | query_labels: query_labels vectors of query; [B, vocab_size]
46 | query_masks: mask for query; [B,L]
47 | out_type: output type [word-level | sentence-level | both]
48 | Returns:
49 | w_feats: word-level features; [B,L,2*h]
50 | s_feats: sentence-level feature; [B,2*h]
51 | """
52 | # embedding query_labels data
53 | wemb = self.embedding(query_labels) # [B,L,emb_odim]
54 | # encoding query_labels data.
55 | max_len = query_labels.size(1) # == L
56 | # make word-wise feature
57 | length = query_masks.sum(1) # [B,]
58 | pack_wemb = nn.utils.rnn.pack_padded_sequence(wemb, length, batch_first=True, enforce_sorted=False)
59 | w_feats, _ = self.gru(pack_wemb)
60 | w_feats, max_ = nn.utils.rnn.pad_packed_sequence(w_feats, batch_first=True, total_length=max_len)
61 | w_feats = w_feats.contiguous() # [B,L,2*h]
62 | # get sentence feature
63 | B, L, H = w_feats.size()
64 | idx = (length-1).long() # 0-indexed
65 | idx = idx.view(B, 1, 1).expand(B, 1, H//2)
66 | fLSTM = w_feats[:,:,:H//2].gather(1, idx).view(B, H//2)
67 | bLSTM = w_feats[:,0,H//2:].view(B,H//2)
68 | s_feats = torch.cat([fLSTM, bLSTM], dim=1)
69 | # aggregate features
70 | w_feats = self.feature_aggregation(w_feats)
71 | return w_feats, s_feats
72 |
73 |
74 | class SimpleVideoEmbeddingModule(nn.Module):
75 | """
76 | A simple Video Embedding Class
77 | """
78 | def __init__(self, cfg):
79 | super().__init__() # Must call super __init__()
80 | # get configuration
81 | self.cfg = cfg
82 | # video gru
83 | vid_idim = self.cfg.MODEL.VIDEO.IDIM
84 | vid_gru_hdim = self.cfg.MODEL.VIDEO.GRU_HDIM
85 | self.gru = nn.GRU(input_size=vid_idim,hidden_size=vid_gru_hdim,batch_first=True,dropout=0.5,bidirectional=True)
86 | # video feature aggregation module
87 | catted_dim = vid_idim + vid_gru_hdim*2
88 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
89 | self.feature_aggregation = nn.Sequential(
90 | nn.Linear(in_features=catted_dim,out_features=emb_dim),
91 | nn.ReLU(),
92 | nn.Dropout(0.5),
93 | )
94 |
95 | def forward(self, vid_feats, vid_masks):
96 | """
97 | encode video features. Utilizes GRU.
98 | Args:
99 | vid_feats: video features
100 | vid_masks: mask for video
101 | Return:
102 | vid_features: hidden state features of the video
103 | """
104 | length = vid_masks.sum(1).squeeze(1)
105 | packed_vid = nn.utils.rnn.pack_padded_sequence(vid_feats, length, batch_first=True, enforce_sorted=False)
106 | vid_hiddens, _ = self.gru(packed_vid)
107 | vid_hiddens, max_ = nn.utils.rnn.pad_packed_sequence(vid_hiddens, batch_first=True, total_length=vid_feats.shape[1])
108 | vid_catted = torch.cat([vid_feats,vid_hiddens],dim=2)
109 | vid_output = self.feature_aggregation(vid_catted)
110 | return vid_output
111 |
112 |
113 | class FusionConvBNReLU(nn.Module):
114 | def __init__(self,cfg):
115 | super().__init__()
116 | # get configuration
117 | self.cfg = cfg
118 | # modules
119 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
120 | kernel_size = self.cfg.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE
121 | padding = self.cfg.MODEL.FUSION.CONVBNRELU.PADDING
122 | self.module = nn.Sequential(
123 | nn.Conv1d(in_channels=emb_dim,out_channels=emb_dim,kernel_size=kernel_size,padding=padding),
124 | nn.BatchNorm1d(num_features=emb_dim),
125 | nn.ReLU())
126 |
127 | def forward(self,feature):
128 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first)
129 | convolved_feature = self.module(transposed_feature)
130 | return torch.transpose(convolved_feature,1,2)
131 |
132 |
133 | class AttentionBlock(nn.Module):
134 | def __init__(self,cfg):
135 | super().__init__()
136 | # get configuration
137 | self.cfg = cfg
138 | # modules
139 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
140 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD
141 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head)
142 | self.convbnrelu = FusionConvBNReLU(cfg)
143 |
144 | def forward(self,vid_feats,query_feats,query_masks):
145 | # attention
146 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
147 | attended_feature, weights = self.attention(
148 | query=torch.transpose(vid_feats,0,1),
149 | key=torch.transpose(query_feats,0,1),
150 | value=torch.transpose(query_feats,0,1),
151 | key_padding_mask=key_padding_mask,)
152 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format
153 | # convolution
154 | convolved_feature = self.convbnrelu(attended_feature) + vid_feats
155 | return convolved_feature
156 |
157 |
158 | class SimpleFusionModule(nn.Module):
159 | def __init__(self, cfg):
160 | super().__init__()
161 | # get configuration
162 | self.cfg = cfg
163 | # attention module
164 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS
165 | self.layers = []
166 | for _ in range(num_layers):
167 | self.layers.append(AttentionBlock(cfg))
168 | self.layers = nn.ModuleList(self.layers)
169 |
170 | def forward(self, query_feats, query_masks, vid_feats, vid_masks):
171 | attended_vid_feats = vid_feats
172 | for attn_layer in self.layers:
173 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks)
174 | return attended_vid_feats
175 |
176 |
177 | class NonLocalBlock(nn.Module):
178 | """
179 | Nonlocal block used for obtaining global feature.
180 | code borrowed from LGI
181 | """
182 | def __init__(self, cfg):
183 | super(NonLocalBlock, self).__init__()
184 | self.cfg = cfg
185 | # dims
186 | self.idim = self.cfg.MODEL.FUSION.EMB_DIM
187 | self.odim = self.cfg.MODEL.FUSION.EMB_DIM
188 | self.nheads = self.cfg.MODEL.NONLOCAL.NUM_HEAD
189 |
190 | # options
191 | self.use_bias = self.cfg.MODEL.NONLOCAL.USE_BIAS
192 |
193 | # layers
194 | self.c_lin = nn.Linear(self.idim, self.odim*2, bias=self.use_bias)
195 | self.v_lin = nn.Linear(self.idim, self.odim, bias=self.use_bias)
196 |
197 | self.relu = nn.ReLU()
198 | self.sigmoid = nn.Sigmoid()
199 | self.drop = nn.Dropout(self.cfg.MODEL.NONLOCAL.DROPOUT)
200 |
201 | def forward(self, m_feats, mask):
202 | """
203 | Inputs:
204 | m_feats: segment-level multimodal feature [B,nseg,*]
205 | mask: mask [B,nseg]
206 | Outputs:
207 | updated_m: updated multimodal feature [B,nseg,*]
208 | """
209 |
210 | mask = mask.float()
211 | B, nseg = mask.size()
212 |
213 | # key, query, value
214 | m_k = self.v_lin(self.drop(m_feats)) # [B,num_seg,*]
215 | m_trans = self.c_lin(self.drop(m_feats)) # [B,nseg,2*]
216 | m_q, m_v = torch.split(m_trans, m_trans.size(2) // 2, dim=2)
217 |
218 | new_mq = m_q
219 | new_mk = m_k
220 |
221 | # applying multi-head attention
222 | w_list = []
223 | mk_set = torch.split(new_mk, new_mk.size(2) // self.nheads, dim=2)
224 | mq_set = torch.split(new_mq, new_mq.size(2) // self.nheads, dim=2)
225 | mv_set = torch.split(m_v, m_v.size(2) // self.nheads, dim=2)
226 | for i in range(self.nheads):
227 | mk_slice, mq_slice, mv_slice = mk_set[i], mq_set[i], mv_set[i] # [B, nseg, *]
228 |
229 | # compute relation matrix; [B,nseg,nseg]
230 | m2m = mk_slice @ mq_slice.transpose(1,2) / ((self.odim // self.nheads) ** 0.5)
231 | m2m = m2m.masked_fill(mask.unsqueeze(1).eq(0), -1e9) # [B,nseg,nseg]
232 | m2m_w = torch.nn.functional.softmax(m2m, dim=2) # [B,nseg,nseg]
233 | w_list.append(m2m_w)
234 |
235 | # compute relation vector for each segment
236 | r = m2m_w @ mv_slice if (i==0) else torch.cat((r, m2m_w @ mv_slice), dim=2)
237 |
238 | updated_m =m_feats + r
239 | return updated_m
240 |
241 |
242 | class AttentivePooling(nn.Module):
243 | def __init__(self, cfg):
244 | self.cfg = cfg
245 | super(AttentivePooling, self).__init__()
246 | self.att_n = 1
247 | self.feat_dim = self.cfg.MODEL.FUSION.EMB_DIM
248 | self.att_hid_dim = self.cfg.MODEL.FUSION.EMB_DIM // 2
249 | self.use_embedding = True
250 |
251 | self.feat2att = nn.Linear(self.feat_dim, self.att_hid_dim, bias=False)
252 | self.to_alpha = nn.Linear(self.att_hid_dim, self.att_n, bias=False)
253 | if self.use_embedding:
254 | edim = self.cfg.MODEL.FUSION.EMB_DIM
255 | self.fc = nn.Linear(self.feat_dim, edim)
256 |
257 | def forward(self, feats, f_masks=None):
258 | """
259 | Compute attention weights and attended feature (weighted sum)
260 | Args:
261 | feats: features where attention weights are computed; [B, A, D]
262 | f_masks: mask for effective features; [B, A]
263 | """
264 | # check inputs
265 | assert len(feats.size()) == 3 or len(feats.size()) == 4
266 | assert f_masks is None or len(f_masks.size()) == 2
267 |
268 | # dealing with dimension 4
269 | if len(feats.size()) == 4:
270 | B, W, H, D = feats.size()
271 | feats = feats.view(B, W*H, D)
272 |
273 | # embedding feature vectors
274 | attn_f = self.feat2att(feats) # [B,A,hdim]
275 |
276 | # compute attention weights
277 | dot = torch.tanh(attn_f) # [B,A,hdim]
278 | alpha = self.to_alpha(dot) # [B,A,att_n]
279 | if f_masks is not None:
280 | alpha = alpha.masked_fill(f_masks.float().unsqueeze(2).eq(0), -1e9)
281 | attw = torch.nn.functional.softmax(alpha.transpose(1,2), dim=2) # [B,att_n,A]
282 |
283 | att_feats = attw @ feats # [B,att_n,D]
284 | att_feats = att_feats.squeeze(1)
285 | attw = attw.squeeze(1)
286 | if self.use_embedding: att_feats = self.fc(att_feats)
287 |
288 | return att_feats, attw
289 |
290 |
291 | class AttentionLocRegressor(nn.Module):
292 | def __init__(self, cfg):
293 | super(AttentionLocRegressor, self).__init__()
294 | self.cfg = cfg
295 | self.tatt = AttentivePooling(self.cfg)
296 | # Regression layer
297 | idim = self.cfg.MODEL.FUSION.EMB_DIM
298 | gdim = self.cfg.MODEL.FUSION.EMB_DIM
299 | nn_list = [ nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()]
300 | self.MLP_reg = nn.Sequential(*nn_list)
301 |
302 | def forward(self, semantic_aware_seg_feats, masks):
303 | # perform Eq. (13) and (14)
304 | summarized_vfeat, att_w = self.tatt(semantic_aware_seg_feats, masks)
305 | # perform Eq. (15)
306 | loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e]
307 | return loc, att_w
308 |
309 |
310 | class SimpleModel(nn.Module):
311 | def __init__(self,cfg):
312 | super().__init__()
313 | self.cfg = cfg
314 | self.query_encoder = SimpleSentenceEmbeddingModule(cfg)
315 | self.video_encoder = SimpleVideoEmbeddingModule(cfg)
316 | self.fusor = SimpleFusionModule(cfg)
317 | self.n_non_local = self.cfg.MODEL.NONLOCAL.NUM_LAYERS
318 | self.non_locals = nn.ModuleList([NonLocalBlock(cfg) for _ in range(self.n_non_local)])
319 | self.loc_regressor = AttentionLocRegressor(cfg)
320 |
321 | def forward(self,inputs):
322 | # encode query
323 | query_labels = inputs['query_labels']
324 | query_masks = inputs['query_masks']
325 | encoded_query, encoded_sentence = self.query_encoder(query_labels, query_masks)
326 | # encode video
327 | vid_feats = inputs['video_feats']
328 | vid_masks = inputs['video_masks']
329 | encoded_video = self.video_encoder(vid_feats,vid_masks)
330 | attended_vid = self.fusor(encoded_query, query_masks, encoded_video, vid_masks)
331 | global_vid = attended_vid
332 | for non_local_layer in self.non_locals:
333 | global_vid = non_local_layer(global_vid,vid_masks.squeeze(2))
334 | loc,attn_weight = self.loc_regressor(global_vid,vid_masks.squeeze(2))
335 | return {"timestamps": loc,
336 | "attention_weights": attn_weight}
337 |
338 |
339 |
340 |
341 |
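#%%
# Minimal smoke test (a sketch): builds the model from the default config with a dummy
# vocabulary size and random tensors, just to check the forward-pass output shapes.
# In real training, MODEL.QUERY.EMB_IDIM is set by the dataset builder instead.
if __name__ == "__main__":
    import os, sys
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
    from config_parsers.simple_model_config import _C as TestCfg

    cfg = TestCfg.clone()
    cfg.MODEL.QUERY.EMB_IDIM = 100  # dummy vocabulary size (placeholder for this sketch)
    model = SimpleModel(cfg).eval()

    B = 2
    dummy_batch = {
        "query_labels": torch.randint(1, 100, (B, cfg.DATASET.MAX_LENGTH)),
        "query_masks": torch.ones(B, cfg.DATASET.MAX_LENGTH),
        "video_feats": torch.randn(B, cfg.DATASET.NUM_SEGMENT, cfg.MODEL.VIDEO.IDIM),
        "video_masks": torch.ones(B, cfg.DATASET.NUM_SEGMENT, 1, dtype=torch.uint8),
    }
    with torch.no_grad():
        out = model(dummy_batch)
    print(out["timestamps"].shape)         # [B, 2]: predicted (start, end)
    print(out["attention_weights"].shape)  # [B, NUM_SEGMENT]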
--------------------------------------------------------------------------------
/models/simple_model_cross_modal_two_stage.py:
--------------------------------------------------------------------------------
1 | #%%
2 | """
3 | simple_model_cross_modal_two_stage.py
4 | ****
5 | simple, basic model for NLVL.
6 | - Query-Video matching with (Multi-Head Attention + ConvBNReLU) with residual connection
7 | - Video Encoding with simple GRU
8 | """
9 |
10 | #%%
11 | # import things
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | #%%
17 | # model
18 | class SimpleSentenceEmbeddingModule(nn.Module):
19 | """
20 | A Simple Query Embedding class
21 | """
22 | def __init__(self, cfg):
23 | super().__init__()
24 | # config params
25 | self.cfg = cfg
26 | self.query_length = self.cfg.DATASET.MAX_LENGTH
27 | # embedding Layer
28 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM
29 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM
30 | self.embedding = nn.Embedding(emb_idim, emb_odim)
31 | # RNN Layer
32 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM
33 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True)
34 | # feature adjust
35 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
36 | self.feature_aggregation = nn.Sequential(
37 | nn.Linear(in_features=gru_hidden*2,out_features=emb_dim),
38 | nn.ReLU(),
39 | nn.Dropout(0.5))
40 |
41 | def forward(self, query_labels, query_masks):
42 | """
43 | encode query sequence using RNN and return logits over proposals.
44 | code adopted from LGI
45 | Args:
46 | query_labels: query_labels vectors of query; [B, vocab_size]
47 | query_masks: mask for query; [B,L]
48 | out_type: output type [word-level | sentence-level | both]
49 | Returns:
50 | w_feats: word-level features; [B,L,2*h]
51 | s_feats: sentence-level feature; [B,2*h]
52 | """
53 | # embedding query_labels data
54 | wemb = self.embedding(query_labels) # [B,L,emb_odim]
55 | # encoding query_labels data.
56 | max_len = query_labels.size(1) # == L
57 | # make word-wise feature
58 | length = query_masks.sum(1) # [B,]
59 | pack_wemb = nn.utils.rnn.pack_padded_sequence(wemb, length, batch_first=True, enforce_sorted=False)
60 | w_feats, _ = self.gru(pack_wemb)
61 | w_feats, max_ = nn.utils.rnn.pad_packed_sequence(w_feats, batch_first=True, total_length=max_len)
62 | w_feats = w_feats.contiguous() # [B,L,2*h]
63 |
64 | # get sentence feature
65 | B, L, H = w_feats.size()
66 | idx = (length-1).long() # 0-indexed
67 | idx = idx.view(B, 1, 1).expand(B, 1, H//2)
68 | fLSTM = w_feats[:,:,:H//2].gather(1, idx).view(B, H//2)
69 | bLSTM = w_feats[:,0,H//2:].view(B,H//2)
70 | s_feats = torch.cat([fLSTM, bLSTM], dim=1)
71 |
72 | # aggregate features
73 | w_feats = self.feature_aggregation(w_feats)
74 | return w_feats, s_feats
75 |
76 |
77 | class TransformerSentenceEmbeddingModule(nn.Module):
78 | """
79 | A Simple Query Embedding class
80 | """
81 | def __init__(self, cfg):
82 | super().__init__()
83 | # config params
84 | self.cfg = cfg
85 | self.query_length = self.cfg.DATASET.MAX_LENGTH
86 | # embedding Layer
87 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM
88 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM
89 | self.embedding = nn.Embedding(emb_idim, emb_odim)
90 |
91 | # RNN Layer
92 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM
93 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True)
94 |
95 | # Attention layer
96 | t_emb_dim = self.cfg.MODEL.QUERY.TRANSFORMER_DIM # 300
97 | #t_emb_dim = gru_hidden * 2 # 256 * 2
98 | self.attention = nn.MultiheadAttention(embed_dim=t_emb_dim, num_heads=4)
99 |
100 | # feature adjust
101 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
102 |
103 | self.feature_aggregation = nn.Sequential(
104 | nn.Linear(in_features=t_emb_dim, out_features=emb_dim),
105 | nn.ReLU(),
106 | nn.Dropout(0.5))
107 |
108 |
109 | def forward(self, query_labels, query_masks):
110 | """
111 | encode query sequence using RNN and return logits over proposals.
112 | code adopted from LGI
113 | Args:
114 | query_labels: query_labels vectors of query; [B, vocab_size]
115 | query_masks: mask for query; [B,L]
117 | out_type: output type [word-level | sentence-level | both]
117 | Returns:
118 | w_feats: word-level features; [B,L,2*h]
119 | s_feats: sentence-level feature; [B,2*h]
120 | """
121 | # embedding query_labels data
122 | wemb = self.embedding(query_labels) # [B,L,emb_odim]
123 |
124 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
125 | # [B, L, D] -> [L, B, D]
126 | attended_feature, weights = self.attention(
127 | query=torch.transpose(wemb, 0,1),
128 | key=torch.transpose(wemb, 0,1),
129 | value=torch.transpose(wemb, 0,1),
130 | key_padding_mask=key_padding_mask,)
131 |
132 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format
133 | # convolution?
134 |
135 | # aggregate features
136 | w_feats = self.feature_aggregation(attended_feature)
137 | #return w_feats, s_feats
138 | return w_feats
139 |
140 | class SimpleVideoEmbeddingModule(nn.Module):
141 | """
142 | A simple Video Embedding Class
143 | """
144 | def __init__(self, cfg):
145 | super().__init__() # Must call super __init__()
146 | # get configuration
147 | self.cfg = cfg
148 | # video gru
149 | vid_idim = self.cfg.MODEL.VIDEO.IDIM
150 | vid_gru_hdim = self.cfg.MODEL.VIDEO.GRU_HDIM
151 | self.gru = nn.GRU(input_size=vid_idim,hidden_size=vid_gru_hdim,batch_first=True,dropout=0.5,bidirectional=True)
152 |
153 |
154 | # video feature aggregation module
155 | catted_dim = vid_idim + vid_gru_hdim*2
156 | #catted_dim = vid_gru_hdim *2
157 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
158 | self.feature_aggregation = nn.Sequential(
159 | nn.Linear(in_features=catted_dim,out_features=emb_dim),
160 | nn.ReLU(),
161 | nn.Dropout(0.5),
162 | )
163 |
164 | def forward(self, vid_feats, vid_masks):
165 | """
166 | encode video features. Utilizes GRU.
167 | Args:
168 | vid_feats: video features
169 | vid_masks: mask for video
170 | Return:
171 | vid_features: hidden state features of the video
172 | """
173 | length = vid_masks.sum(1).squeeze(1)
174 | packed_vid = nn.utils.rnn.pack_padded_sequence(vid_feats, length, batch_first=True, enforce_sorted=False)
175 | vid_hiddens, _ = self.gru(packed_vid)
176 | vid_hiddens, max_ = nn.utils.rnn.pad_packed_sequence(vid_hiddens, batch_first=True, total_length=vid_feats.shape[1])
177 | #vid_output = self.feature_aggregation(vid_hiddens)
178 |
179 | vid_catted = torch.cat([vid_feats,vid_hiddens],dim=2)
180 | vid_output = self.feature_aggregation(vid_catted)
181 | return vid_output
182 |
183 |
184 | class TransformerVideoEmbeddingModule(nn.Module):
185 | """
186 | A simple Video Embedding Class
187 | """
188 | def __init__(self, cfg):
189 | super().__init__() # Must call super __init__()
190 | # get configuration
191 | self.cfg = cfg
192 |
193 | # video transformer
194 | vid_idim = self.cfg.MODEL.VIDEO.IDIM
195 | vid_transformer_hdim = self.cfg.MODEL.VIDEO.ANET_TRANSFORMER_DIM # ANET_TRANSFORMER_DIM or CHA_TRANSFORMER_DIM in the config, depending on the dataset
196 | self.attention = nn.MultiheadAttention(embed_dim=vid_idim, num_heads=4)
197 |
198 | # video feature aggregation module
199 | catted_dim = vid_idim + vid_transformer_hdim
200 |
201 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
202 | self.feature_aggregation = nn.Sequential(
203 | nn.Linear(in_features=catted_dim,out_features=emb_dim),
204 | nn.ReLU(),
205 | nn.Dropout(0.5),
206 | )
207 |
208 | def forward(self, vid_feats, vid_masks):
209 | """
210 | encode video features. Utilizes GRU.
211 | Args:
212 | vid_feats: video features
213 | vid_masks: mask for video
214 | Return:
215 | vid_features: hidden state features of the video
216 | """
217 |
218 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
219 | # [B, L, D] -> [L, B, D]
220 | attended_feature, weights = self.attention(
221 | query=torch.transpose(vid_feats, 0,1),
222 | key=torch.transpose(vid_feats, 0,1),
223 | value=torch.transpose(vid_feats, 0,1),
224 | key_padding_mask=key_padding_mask.squeeze(),)
225 |
226 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format
227 | # convolution?
228 |
229 | # aggregate features
230 | vid_catted = torch.cat([vid_feats,attended_feature],dim=2)
231 | vid_output = self.feature_aggregation(vid_catted)
232 |
233 | return vid_output
234 |
235 | class FusionConvBNReLU(nn.Module):
236 | def __init__(self,cfg):
237 | super().__init__()
238 | # get configuration
239 | self.cfg = cfg
240 | # modules
241 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
242 | kernel_size = self.cfg.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE
243 | padding = self.cfg.MODEL.FUSION.CONVBNRELU.PADDING
244 | self.module = nn.Sequential(
245 | nn.Conv1d(in_channels=emb_dim,out_channels=emb_dim,kernel_size=kernel_size,padding=padding),
246 | nn.BatchNorm1d(num_features=emb_dim),
247 | nn.ReLU())
248 |
249 | def forward(self,feature):
250 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first)
251 | convolved_feature = self.module(transposed_feature)
252 |
253 | return torch.transpose(convolved_feature,1,2)
254 |
255 | def basic_block(idim, odim, ksize=3):
256 | layers = []
257 | # 1st conv
258 | p = ksize // 2
259 | layers.append(nn.Conv1d(idim, odim, ksize, 1, p, bias=False))
260 | layers.append(nn.BatchNorm1d(odim))
261 | layers.append(nn.ReLU(inplace=True))
262 | # 2nd conv
263 | layers.append(nn.Conv1d(odim, odim, ksize, 1, p, bias=False))
264 | layers.append(nn.BatchNorm1d(odim))
265 |
266 | return nn.Sequential(*layers)
267 |
268 | class FusionResBlock(nn.Module):
269 | def __init__(self, cfg):
270 | super().__init__()
271 | # get configuration
272 | self.cfg = cfg
273 | # modules
274 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
275 | kernel_size = self.cfg.MODEL.FUSION.RESBLOCK.KERNEL_SIZE
276 | padding = self.cfg.MODEL.FUSION.RESBLOCK.PADDING
277 | self.nblocks = self.cfg.MODEL.FUSION.RESBLOCK.NB_ITER
278 |
279 | # set layers
280 | self.blocks = nn.ModuleList()
281 | for i in range(self.nblocks):
282 | cur_block = basic_block(emb_dim, emb_dim, kernel_size)
283 | self.blocks.append(cur_block)
284 |
285 | def forward(self, feature):
286 | """
287 | Args:
288 | inp: [B, input-Dim, L]
289 | out: [B, output-Dim, L]
290 | """
291 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first)
292 | residual = transposed_feature
293 | for i in range(self.nblocks):
294 | out = self.blocks[i](residual)
295 | out += residual
296 | out = F.relu(out)
297 | residual = out
298 |
299 | return torch.transpose(out,1,2)
300 |
301 | class AttentionBlockS2V(nn.Module):
302 | def __init__(self,cfg):
303 | super().__init__()
304 | # get configuration
305 | self.cfg = cfg
306 | # modules
307 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
308 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD
309 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head)
310 |
311 | if self.cfg.MODEL.FUSION.USE_RESBLOCK:
312 | self.convbnrelu = FusionResBlock(cfg)
313 | else:
314 | self.convbnrelu = FusionConvBNReLU(cfg)
315 |
316 | def forward(self,vid_feats,query_feats,query_masks):
317 | # attention
318 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
319 | attended_feature, weights = self.attention(
320 | query=torch.transpose(vid_feats,0,1),
321 | key=torch.transpose(query_feats,0,1),
322 | value=torch.transpose(query_feats,0,1),
323 | key_padding_mask=key_padding_mask,)
324 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format
325 | # convolution
326 | convolved_feature = self.convbnrelu(attended_feature) + vid_feats
327 | return convolved_feature
328 |
329 | class AttentionBlockV2S(nn.Module):
330 | def __init__(self,cfg):
331 | super().__init__()
332 | # get configuration
333 | self.cfg = cfg
334 | # modules
335 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
336 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD
337 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head)
338 |
339 | if self.cfg.MODEL.FUSION.USE_RESBLOCK:
340 | self.convbnrelu = FusionResBlock(cfg)
341 | else:
342 | self.convbnrelu = FusionConvBNReLU(cfg)
343 |
344 | def forward(self,vid_feats,query_feats,vid_masks):
345 | # attention
346 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
347 | key_padding_mask = key_padding_mask.squeeze()
348 | attended_feature, weights = self.attention(
349 | query=torch.transpose(query_feats,0,1),
350 | key=torch.transpose(vid_feats,0,1),
351 | value=torch.transpose(vid_feats,0,1),
352 | key_padding_mask=key_padding_mask,)
353 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format
354 | # convolution
355 | convolved_feature = self.convbnrelu(attended_feature) + query_feats
356 | return convolved_feature
357 |
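# Summary of the two cross-modal directions defined above and used below:
#   AttentionBlockS2V: each video segment attends over the query words
#     (query=video, key/value=words), then a conv block with a residual back to the video stream.
#   AttentionBlockV2S: each query word attends over the video segments, with the
#     residual going back to the query stream.
# SimpleModel chains both and then re-attends the refined query onto the refined video,
# giving the "two-stage" cross-modal attention.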
358 | class SimpleFusionModule(nn.Module):
359 | def __init__(self, cfg):
360 | super().__init__()
361 | # get configuration
362 | self.cfg = cfg
363 | # attention module
364 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS
365 | self.layers = []
366 | for _ in range(num_layers):
367 | self.layers.append(AttentionBlockS2V(cfg))
368 | self.layers = nn.ModuleList(self.layers)
369 |
370 | def forward(self, query_feats, query_masks, vid_feats, vid_masks):
371 | attended_vid_feats = vid_feats
372 | for attn_layer in self.layers:
373 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks)
374 | return attended_vid_feats
375 |
376 | class SimpleFusionModuleSent(nn.Module):
377 | def __init__(self, cfg):
378 | super().__init__()
379 | # get configuration
380 | self.cfg = cfg
381 | # attention module
382 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS
383 | self.layers = []
384 | for _ in range(num_layers):
385 | self.layers.append(AttentionBlockV2S(cfg))
386 | self.layers = nn.ModuleList(self.layers)
387 |
388 | def forward(self, query_feats, query_masks, vid_feats, vid_masks):
389 | attended_query_feats = query_feats
390 | for attn_layer in self.layers:
391 | attended_query_feats = attn_layer(vid_feats=vid_feats, query_feats=attended_query_feats, vid_masks=vid_masks)
392 | return attended_query_feats
393 |
394 | class TwostageSimpleFusionModule(nn.Module):
395 | def __init__(self, cfg):
396 | super().__init__()
397 | # get configuration
398 | self.cfg = cfg
399 | # attention module
400 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS
401 | self.layers = []
402 | for _ in range(num_layers):
403 | self.layers.append(AttentionBlockS2V(cfg))
404 | self.layers = nn.ModuleList(self.layers)
405 |
406 | def forward(self, query_feats, query_masks, vid_feats, vid_masks):
407 | attended_vid_feats = vid_feats
408 | for attn_layer in self.layers:
409 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks)
410 | return attended_vid_feats
411 |
412 | class NonLocalBlock(nn.Module):
413 | """
414 | Nonlocal block used for obtaining global feature.
415 | code borrowed from LGI
416 | """
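# In short: multi-head scaled dot-product self-attention over the temporal segments,
# with padded segments masked out and the attended result added back residually
# (updated_m = m_feats + r).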
417 | def __init__(self, cfg):
418 | super(NonLocalBlock, self).__init__()
419 | self.cfg = cfg
420 | # dims
421 | self.idim = self.cfg.MODEL.FUSION.EMB_DIM
422 | self.odim = self.cfg.MODEL.FUSION.EMB_DIM
423 | self.nheads = self.cfg.MODEL.NONLOCAL.NUM_HEAD
424 |
425 | # options
426 | self.use_bias = self.cfg.MODEL.NONLOCAL.USE_BIAS
427 |
428 | # layers
429 | self.c_lin = nn.Linear(self.idim, self.odim*2, bias=self.use_bias)
430 | self.v_lin = nn.Linear(self.idim, self.odim, bias=self.use_bias)
431 |
432 | self.relu = nn.ReLU()
433 | self.sigmoid = nn.Sigmoid()
434 | self.drop = nn.Dropout(self.cfg.MODEL.NONLOCAL.DROPOUT)
435 |
436 | def forward(self, m_feats, mask):
437 | """
438 | Inputs:
439 | m_feats: segment-level multimodal feature [B,nseg,*]
440 | mask: mask [B,nseg]
441 | Outputs:
442 | updated_m: updated multimodal feature [B,nseg,*]
443 | """
444 |
445 | mask = mask.float()
446 | B, nseg = mask.size()
447 |
448 | # key, query, value
449 | m_k = self.v_lin(self.drop(m_feats)) # [B,num_seg,*]
450 | m_trans = self.c_lin(self.drop(m_feats)) # [B,nseg,2*]
451 | m_q, m_v = torch.split(m_trans, m_trans.size(2) // 2, dim=2)
452 |
453 | new_mq = m_q
454 | new_mk = m_k
455 |
456 | # applying multi-head attention
457 | w_list = []
458 | mk_set = torch.split(new_mk, new_mk.size(2) // self.nheads, dim=2)
459 | mq_set = torch.split(new_mq, new_mq.size(2) // self.nheads, dim=2)
460 | mv_set = torch.split(m_v, m_v.size(2) // self.nheads, dim=2)
461 |
462 | for i in range(self.nheads):
463 | mk_slice, mq_slice, mv_slice = mk_set[i], mq_set[i], mv_set[i] # [B, nseg, *]
464 |
465 | # compute relation matrix; [B,nseg,nseg]
466 | m2m = mk_slice @ mq_slice.transpose(1,2) / ((self.odim // self.nheads) ** 0.5)
467 | m2m = m2m.masked_fill(mask.unsqueeze(1).eq(0), -1e9) # [B,nseg,nseg]
468 | m2m_w = torch.nn.functional.softmax(m2m, dim=2) # [B,nseg,nseg]
469 | w_list.append(m2m_w)
470 |
471 | # compute relation vector for each segment
472 | r = m2m_w @ mv_slice if (i==0) else torch.cat((r, m2m_w @ mv_slice), dim=2)
473 |
474 | updated_m = m_feats + r
475 | return updated_m
476 |
477 | class AttentivePooling(nn.Module):
478 | def __init__(self, cfg):
479 | self.cfg = cfg
480 | super(AttentivePooling, self).__init__()
481 | self.att_n = 1
482 | self.feat_dim = self.cfg.MODEL.FUSION.EMB_DIM
483 | self.att_hid_dim = self.cfg.MODEL.FUSION.EMB_DIM // 2
484 | self.use_embedding = True
485 |
486 | self.feat2att = nn.Linear(self.feat_dim, self.att_hid_dim, bias=False)
487 | self.to_alpha = nn.Linear(self.att_hid_dim, self.att_n, bias=False)
488 | if self.use_embedding:
489 | edim = self.cfg.MODEL.FUSION.EMB_DIM
490 | self.fc = nn.Linear(self.feat_dim, edim)
491 |
492 | def forward(self, feats, f_masks=None):
493 | """
494 | Compute attention weights and attended feature (weighted sum)
495 | Args:
496 | feats: features where attention weights are computed; [B, A, D]
497 | f_masks: mask for effective features; [B, A]
498 | """
499 | # check inputs
500 | assert len(feats.size()) == 3 or len(feats.size()) == 4
501 | assert f_masks is None or len(f_masks.size()) == 2
502 |
503 | # dealing with dimension 4
504 | if len(feats.size()) == 4:
505 | B, W, H, D = feats.size()
506 | feats = feats.view(B, W*H, D)
507 |
508 | # embedding feature vectors
509 | attn_f = self.feat2att(feats) # [B,A,hdim]
510 |
511 | # compute attention weights
512 | dot = torch.tanh(attn_f) # [B,A,hdim]
513 | alpha = self.to_alpha(dot) # [B,A,att_n]
514 | if f_masks is not None:
515 | alpha = alpha.masked_fill(f_masks.float().unsqueeze(2).eq(0), -1e9)
516 | attw = torch.nn.functional.softmax(alpha.transpose(1,2), dim=2) # [B,att_n,A]
517 |
518 | att_feats = attw @ feats # [B,att_n,D]
519 | att_feats = att_feats.squeeze(1)
520 | attw = attw.squeeze(1)
521 | if self.use_embedding: att_feats = self.fc(att_feats)
522 |
523 | return att_feats, attw
524 |
525 |
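# The regressors below use AttentivePooling to collapse the temporal axis into one
# summary vector; the pooling weights are also returned and surface as
# 'attention_weights' in the model output, which utils/loss.py supervises with the
# attention masks provided by the dataloader.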
526 | class AttentionLocRegressor(nn.Module):
527 | def __init__(self, cfg):
528 | super(AttentionLocRegressor, self).__init__()
529 | self.cfg = cfg
530 | self.tatt_vid = AttentivePooling(self.cfg)
531 | self.tatt_query = AttentivePooling(self.cfg)
532 | # Regression layer
533 | idim = self.cfg.MODEL.FUSION.EMB_DIM * 2
534 | gdim = self.cfg.MODEL.FUSION.EMB_DIM
535 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()]
536 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)]
537 | self.MLP_reg = nn.Sequential(*nn_list)
538 |
539 |
540 | def forward(self, semantic_aware_seg_vid_feats, vid_masks, semantic_aware_seg_query_feat, query_masks):
541 | # perform Eq. (13) and (14)
542 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks)
543 | summarized_qfeat, att_w_q = self.tatt_query(semantic_aware_seg_query_feat, query_masks)
544 | # perform Eq. (15)
545 | summarized_feats = torch.cat((summarized_vfeat, summarized_qfeat), dim=1)
546 | #loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e]
547 | loc = self.MLP_reg(summarized_feats) # loc = [t^s, t^e]
548 | return loc, att_w
549 |
550 |
551 | class TwostageAttentionLocRegressor(nn.Module):
552 | def __init__(self, cfg):
553 | super(TwostageAttentionLocRegressor, self).__init__()
554 | self.cfg = cfg
555 | self.tatt_vid = AttentivePooling(self.cfg)
556 | # Regression layer
557 | idim = self.cfg.MODEL.FUSION.EMB_DIM
558 | gdim = self.cfg.MODEL.FUSION.EMB_DIM
559 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()]
560 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)]
561 | self.MLP_reg = nn.Sequential(*nn_list)
562 |
563 | def forward(self, semantic_aware_seg_vid_feats, vid_masks):
564 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks)
565 | loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e]
566 | return loc, att_w
567 |
568 |
569 | class SimpleModel(nn.Module):
570 | def __init__(self,cfg):
571 | super().__init__()
572 | self.cfg = cfg
573 | self.query_encoder = SimpleSentenceEmbeddingModule(cfg)
574 | self.video_encoder = SimpleVideoEmbeddingModule(cfg)
575 | self.v_fusor = SimpleFusionModule(cfg)
576 | self.s_fusor = SimpleFusionModuleSent(cfg)
577 | self.cv_fusor = TwostageSimpleFusionModule(cfg)
578 | self.n_non_local = self.cfg.MODEL.NONLOCAL.NUM_LAYERS
579 | self.non_locals_layer = nn.ModuleList([NonLocalBlock(cfg) for _ in range(self.n_non_local)])
580 | self.loc_regressor = AttentionLocRegressor(cfg)
581 | self.loc_regressor_two_stage = TwostageAttentionLocRegressor(cfg)
582 |
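# Expected inputs/outputs (shapes inferred from how forward() consumes them):
#   inputs['query_labels'] : [B, L] word indices for the query
#   inputs['query_masks']  : [B, L] 1 for valid words, 0 for padding
#   inputs['video_feats']  : [B, T, MODEL.VIDEO.IDIM] segment features
#   inputs['video_masks']  : [B, T, 1] 1 for valid segments, 0 for padding
#   returns {'timestamps': [B, 2] predicted (start, end),
#            'attention_weights': [B, T] temporal attention over segments}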
583 | def forward(self,inputs):
584 | # encode query
585 | query_labels = inputs['query_labels']
586 | query_masks = inputs['query_masks']
587 | encoded_query, encoded_sentence = self.query_encoder(query_labels, query_masks) # encoded_query [B, L, D] D = 256
588 |
589 | # encode video
590 | vid_feats = inputs['video_feats']
591 | vid_masks = inputs['video_masks']
592 | encoded_video = self.video_encoder(vid_feats,vid_masks)
593 |
594 | # Crossmodality Attention
595 | attended_sent = self.s_fusor(encoded_query, query_masks, encoded_video, vid_masks)
596 | attended_vid = self.v_fusor(encoded_query, query_masks, encoded_video, vid_masks)
597 | two_stage_attended_vid = self.cv_fusor(attended_sent, query_masks, attended_vid, vid_masks)
598 |
599 | global_two_stage_vid = two_stage_attended_vid
600 | for non_local_layer in self.non_locals_layer:
601 | global_two_stage_vid = non_local_layer(global_two_stage_vid, vid_masks.squeeze(2))
602 |
603 | loc, temporal_attn_weight = self.loc_regressor_two_stage(global_two_stage_vid, vid_masks.squeeze(2))
604 |
605 | return {"timestamps": loc, "attention_weights": temporal_attn_weight}
--------------------------------------------------------------------------------
/models/simple_model_cross_modal_two_stage_temperature.py:
--------------------------------------------------------------------------------
1 | #%%
2 | """
3 | simple_model_cross_modal_two_stage_temperature.py
4 | ****
5 | simple, basic model for NLVL with temperature-scaled video-to-sentence attention.
6 | - Query-Video matching with (Multi-Head Attention + ConvBNReLU) with residual connection
7 | - Video Encoding with simple GRU
8 | """
9 |
10 | #%%
11 | # import things
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 |
16 | #%%
17 | # model
18 | class SimpleSentenceEmbeddingModule(nn.Module):
19 | """
20 | A Simple Query Embedding class
21 | """
22 | def __init__(self, cfg):
23 | super().__init__()
24 | # config params
25 | self.cfg = cfg
26 | self.query_length = self.cfg.DATASET.MAX_LENGTH
27 | # embedding Layer
28 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM
29 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM
30 | self.embedding = nn.Embedding(emb_idim, emb_odim)
31 | # RNN Layer
32 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM
33 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True)
34 | # feature adjust
35 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
36 | self.feature_aggregation = nn.Sequential(
37 | nn.Linear(in_features=gru_hidden*2,out_features=emb_dim),
38 | nn.ReLU(),
39 | nn.Dropout(0.5))
40 |
41 | def forward(self, query_labels, query_masks):
42 | """
43 | encode query sequence using a bidirectional GRU and return word- and sentence-level features.
44 | code adopted from LGI
45 | Args:
46 | query_labels: word-index tensor of the query; [B,L]
47 | query_masks: mask for query; [B,L]
48 |
49 | Returns:
50 | w_feats: word-level features after projection; [B,L,emb_dim]
51 | s_feats: sentence-level feature; [B,2*h]
52 | """
53 | # embedding query_labels data
54 | wemb = self.embedding(query_labels) # [B,L,emb_odim]
55 | # encoding query_labels data.
56 | max_len = query_labels.size(1) # == L
57 | # make word-wise feature
58 | length = query_masks.sum(1) # [B,]
59 | pack_wemb = nn.utils.rnn.pack_padded_sequence(wemb, length, batch_first=True, enforce_sorted=False)
60 | w_feats, _ = self.gru(pack_wemb)
61 | w_feats, max_ = nn.utils.rnn.pad_packed_sequence(w_feats, batch_first=True, total_length=max_len)
62 | w_feats = w_feats.contiguous() # [B,L,2*h]
63 |
64 | # get sentence feature
65 | B, L, H = w_feats.size()
66 | idx = (length-1).long() # 0-indexed
67 | idx = idx.view(B, 1, 1).expand(B, 1, H//2)
68 | fLSTM = w_feats[:,:,:H//2].gather(1, idx).view(B, H//2)
69 | bLSTM = w_feats[:,0,H//2:].view(B,H//2)
70 | s_feats = torch.cat([fLSTM, bLSTM], dim=1)
71 |
72 | # aggregate features
73 | w_feats = self.feature_aggregation(w_feats)
74 | return w_feats, s_feats
75 |
76 |
77 | class TransformerSentenceEmbeddingModule(nn.Module):
78 | """
79 | A Simple Query Embedding class
80 | """
81 | def __init__(self, cfg):
82 | super().__init__()
83 | # config params
84 | self.cfg = cfg
85 | self.query_length = self.cfg.DATASET.MAX_LENGTH
86 | # embedding Layer
87 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM
88 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM
89 | self.embedding = nn.Embedding(emb_idim, emb_odim)
90 |
91 | # RNN Layer
92 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM
93 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True)
94 |
95 | # Attention layer
96 | t_emb_dim = self.cfg.MODEL.QUERY.TRANSFORMER_DIM # 300
97 | #t_emb_dim = gru_hidden * 2 # 256 * 2
98 | self.attention = nn.MultiheadAttention(embed_dim=t_emb_dim, num_heads=4)
99 |
100 | # feature adjust
101 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
102 |
103 | self.feature_aggregation = nn.Sequential(
104 | nn.Linear(in_features=t_emb_dim, out_features=emb_dim),
105 | nn.ReLU(),
106 | nn.Dropout(0.5))
107 |
108 |
109 | def forward(self, query_labels, query_masks):
110 | """
111 | encode query sequence using multi-head self-attention and return word-level features.
112 | code adopted from LGI
113 | Args:
114 | query_labels: word-index tensor of the query; [B,L]
115 | query_masks: mask for query; [B,L]
116 |
117 | Returns:
118 | w_feats: word-level features after projection; [B,L,emb_dim]
119 |
120 | """
121 | # embedding query_labels data
122 | wemb = self.embedding(query_labels) # [B,L,emb_odim]
123 |
124 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
125 | # [B, L, D] -> [L, B, D]
126 | attended_feature, weights = self.attention(
127 | query=torch.transpose(wemb, 0,1),
128 | key=torch.transpose(wemb, 0,1),
129 | value=torch.transpose(wemb, 0,1),
130 | key_padding_mask=key_padding_mask,)
131 |
132 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format
133 | # convolution?
134 |
135 | # aggregate features
136 | w_feats = self.feature_aggregation(attended_feature)
137 | #return w_feats, s_feats
138 | return w_feats
139 |
140 | class SimpleVideoEmbeddingModule(nn.Module):
141 | """
142 | A simple Video Embedding Class
143 | """
144 | def __init__(self, cfg):
145 | super().__init__() # Must call super __init__()
146 | # get configuration
147 | self.cfg = cfg
148 | # video gru
149 | vid_idim = self.cfg.MODEL.VIDEO.IDIM
150 | vid_gru_hdim = self.cfg.MODEL.VIDEO.GRU_HDIM
151 | self.gru = nn.GRU(input_size=vid_idim,hidden_size=vid_gru_hdim,batch_first=True,dropout=0.5,bidirectional=True)
152 |
153 |
154 | # video feature aggregation module
155 | catted_dim = vid_idim + vid_gru_hdim*2
156 | #catted_dim = vid_gru_hdim *2
157 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
158 | self.feature_aggregation = nn.Sequential(
159 | nn.Linear(in_features=catted_dim,out_features=emb_dim),
160 | nn.ReLU(),
161 | nn.Dropout(0.5),
162 | )
163 |
164 | def forward(self, vid_feats, vid_masks):
165 | """
166 | encode video features. Utilizes GRU.
167 | Args:
168 | vid_feats: video features
169 | vid_masks: mask for video
170 | Return:
171 | vid_features: hidden state features of the video
172 | """
173 | length = vid_masks.sum(1).squeeze(1)
174 | packed_vid = nn.utils.rnn.pack_padded_sequence(vid_feats, length, batch_first=True, enforce_sorted=False)
175 | vid_hiddens, _ = self.gru(packed_vid)
176 | vid_hiddens, max_ = nn.utils.rnn.pad_packed_sequence(vid_hiddens, batch_first=True, total_length=vid_feats.shape[1])
177 | #vid_output = self.feature_aggregation(vid_hiddens)
178 |
179 | vid_catted = torch.cat([vid_feats,vid_hiddens],dim=2)
180 | vid_output = self.feature_aggregation(vid_catted)
181 | return vid_output
182 |
183 |
184 | class TransformerVideoEmbeddingModule(nn.Module):
185 | """
186 | A simple Video Embedding Class
187 | """
188 | def __init__(self, cfg):
189 | super().__init__() # Must call super __init__()
190 | # get configuration
191 | self.cfg = cfg
192 |
193 | # video transformer
194 | vid_idim = self.cfg.MODEL.VIDEO.IDIM
195 | vid_transformer_hdim = self.cfg.MODEL.VIDEO.ANET.TRANSFORMER_DIM # 1024(charades), 1000 (anet)
196 | self.attention = nn.MultiheadAttention(embed_dim=vid_idim, num_heads=4)
197 |
198 | # video feature aggregation module
199 | catted_dim = vid_idim + vid_transformer_hdim
200 |
201 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
202 | self.feature_aggregation = nn.Sequential(
203 | nn.Linear(in_features=catted_dim,out_features=emb_dim),
204 | nn.ReLU(),
205 | nn.Dropout(0.5),
206 | )
207 |
208 | def forward(self, vid_feats, vid_masks):
209 | """
210 | encode video features. Utilizes multi-head self-attention.
211 | Args:
212 | vid_feats: video features
213 | vid_masks: mask for video
214 | Return:
215 | vid_features: hidden state features of the video
216 | """
217 |
218 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
219 | # [B, L, D] -> [L, B, D]
220 | attended_feature, weights = self.attention(
221 | query=torch.transpose(vid_feats, 0,1),
222 | key=torch.transpose(vid_feats, 0,1),
223 | value=torch.transpose(vid_feats, 0,1),
224 | key_padding_mask=key_padding_mask.squeeze(),)
225 |
226 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format
227 | # convolution?
228 |
229 | # aggregate features
230 | vid_catted = torch.cat([vid_feats,attended_feature],dim=2)
231 | vid_output = self.feature_aggregation(vid_catted)
232 |
233 | return vid_output
234 |
235 | class FusionConvBNReLU(nn.Module):
236 | def __init__(self,cfg):
237 | super().__init__()
238 | # get configuration
239 | self.cfg = cfg
240 | # modules
241 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
242 | kernel_size = self.cfg.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE
243 | padding = self.cfg.MODEL.FUSION.CONVBNRELU.PADDING
244 | self.module = nn.Sequential(
245 | nn.Conv1d(in_channels=emb_dim,out_channels=emb_dim,kernel_size=kernel_size,padding=padding),
246 | nn.BatchNorm1d(num_features=emb_dim),
247 | nn.ReLU())
248 |
249 | def forward(self,feature):
250 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first)
251 | convolved_feature = self.module(transposed_feature)
252 |
253 | return torch.transpose(convolved_feature,1,2)
254 |
255 | def basic_block(idim, odim, ksize=3):
256 | layers = []
257 | # 1st conv
258 | p = ksize // 2
259 | layers.append(nn.Conv1d(idim, odim, ksize, 1, p, bias=False))
260 | layers.append(nn.BatchNorm1d(odim))
261 | layers.append(nn.ReLU(inplace=True))
262 | # 2nd conv
263 | layers.append(nn.Conv1d(odim, odim, ksize, 1, p, bias=False))
264 | layers.append(nn.BatchNorm1d(odim))
265 |
266 | return nn.Sequential(*layers)
267 |
268 | class FusionResBlock(nn.Module):
269 | def __init__(self, cfg):
270 | super().__init__()
271 | # get configuration
272 | self.cfg = cfg
273 | # modules
274 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
275 | kernel_size = self.cfg.MODEL.FUSION.RESBLOCK.KERNEL_SIZE
276 | padding = self.cfg.MODEL.FUSION.RESBLOCK.PADDING
277 | self.nblocks = self.cfg.MODEL.FUSION.RESBLOCK.NB_ITER
278 |
279 | # set layers
280 | self.blocks = nn.ModuleList()
281 | for i in range(self.nblocks):
282 | cur_block = basic_block(emb_dim, emb_dim, kernel_size)
283 | self.blocks.append(cur_block)
284 |
285 | def forward(self, feature):
286 | """
287 | Args:
288 | inp: [B, input-Dim, L]
289 | out: [B, output-Dim, L]
290 | """
291 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first)
292 | residual = transposed_feature
293 | for i in range(self.nblocks):
294 | out = self.blocks[i](residual)
295 | out += residual
296 | out = F.relu(out)
297 | residual = out
298 |
299 | return torch.transpose(out,1,2)
300 |
301 | class AttentionBlockS2V(nn.Module):
302 | def __init__(self,cfg):
303 | super().__init__()
304 | # get configuration
305 | self.cfg = cfg
306 | # modules
307 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
308 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD
309 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head)
310 |
311 | if self.cfg.MODEL.FUSION.USE_RESBLOCK:
312 | self.convbnrelu = FusionResBlock(cfg)
313 | else:
314 | self.convbnrelu = FusionConvBNReLU(cfg)
315 |
316 | def forward(self,vid_feats,query_feats,query_masks):
317 | # attention
318 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
319 | attended_feature, weights = self.attention(
320 | query=torch.transpose(vid_feats,0,1),
321 | key=torch.transpose(query_feats,0,1),
322 | value=torch.transpose(query_feats,0,1),
323 | key_padding_mask=key_padding_mask,)
324 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format
325 | # convolution
326 | convolved_feature = self.convbnrelu(attended_feature) + vid_feats
327 | return convolved_feature
328 |
329 | class AttentionBlockV2S(nn.Module):
330 | def __init__(self,cfg):
331 | super().__init__()
332 | # get configuration
333 | self.cfg = cfg
334 | self.temp = cfg.MODEL.QUERY.TEMPERATURE
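# self.temp divides the query in forward() below; since attention logits are
# Q K^T / sqrt(d), this (up to the in-projection bias) scales the logits by
# 1/temperature: values > 1 soften the video-to-sentence attention, values < 1 sharpen it.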
335 | # modules
336 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM
337 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD
338 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head)
339 |
340 | if self.cfg.MODEL.FUSION.USE_RESBLOCK:
341 | self.convbnrelu = FusionResBlock(cfg)
342 | else:
343 | self.convbnrelu = FusionConvBNReLU(cfg)
344 |
345 | def forward(self,vid_feats,query_feats,vid_masks):
346 | # attention
347 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it.
348 | key_padding_mask = key_padding_mask.squeeze()
349 | attended_feature, weights = self.attention(
350 | query=torch.transpose(query_feats,0,1) / self.temp,
351 | key=torch.transpose(vid_feats,0,1),
352 | value=torch.transpose(vid_feats,0,1),
353 | key_padding_mask=key_padding_mask,)
354 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format
355 | # convolution
356 | convolved_feature = self.convbnrelu(attended_feature) + query_feats
357 | return convolved_feature
358 |
359 | class SimpleFusionModule(nn.Module):
360 | def __init__(self, cfg):
361 | super().__init__()
362 | # get configuration
363 | self.cfg = cfg
364 | # attention module
365 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS
366 | self.layers = []
367 | for _ in range(num_layers):
368 | self.layers.append(AttentionBlockS2V(cfg))
369 | self.layers = nn.ModuleList(self.layers)
370 |
371 | def forward(self, query_feats, query_masks, vid_feats, vid_masks):
372 | attended_vid_feats = vid_feats
373 | for attn_layer in self.layers:
374 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks)
375 | return attended_vid_feats
376 |
377 | class SimpleFusionModuleSent(nn.Module):
378 | def __init__(self, cfg):
379 | super().__init__()
380 | # get configuration
381 | self.cfg = cfg
382 | # attention module
383 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS
384 | self.layers = []
385 | for _ in range(num_layers):
386 | self.layers.append(AttentionBlockV2S(cfg))
387 | self.layers = nn.ModuleList(self.layers)
388 |
389 | def forward(self, query_feats, query_masks, vid_feats, vid_masks):
390 | attended_query_feats = query_feats
391 | for attn_layer in self.layers:
392 | attended_query_feats = attn_layer(vid_feats=vid_feats, query_feats=attended_query_feats, vid_masks=vid_masks)
393 | return attended_query_feats
394 |
395 | class TwostageSimpleFusionModule(nn.Module):
396 | def __init__(self, cfg):
397 | super().__init__()
398 | # get configuration
399 | self.cfg = cfg
400 | # attention module
401 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS
402 | self.layers = []
403 | for _ in range(num_layers):
404 | self.layers.append(AttentionBlockS2V(cfg))
405 | self.layers = nn.ModuleList(self.layers)
406 |
407 | def forward(self, query_feats, query_masks, vid_feats, vid_masks):
408 | attended_vid_feats = vid_feats
409 | for attn_layer in self.layers:
410 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks)
411 | return attended_vid_feats
412 |
413 | class NonLocalBlock(nn.Module):
414 | """
415 | Nonlocal block used for obtaining global feature.
416 | code borrowed from LGI
417 | """
418 | def __init__(self, cfg):
419 | super(NonLocalBlock, self).__init__()
420 | self.cfg = cfg
421 | # dims
422 | self.idim = self.cfg.MODEL.FUSION.EMB_DIM
423 | self.odim = self.cfg.MODEL.FUSION.EMB_DIM
424 | self.nheads = self.cfg.MODEL.NONLOCAL.NUM_HEAD
425 |
426 | # options
427 | self.use_bias = self.cfg.MODEL.NONLOCAL.USE_BIAS
428 |
429 | # layers
430 | self.c_lin = nn.Linear(self.idim, self.odim*2, bias=self.use_bias)
431 | self.v_lin = nn.Linear(self.idim, self.odim, bias=self.use_bias)
432 |
433 | self.relu = nn.ReLU()
434 | self.sigmoid = nn.Sigmoid()
435 | self.drop = nn.Dropout(self.cfg.MODEL.NONLOCAL.DROPOUT)
436 |
437 | def forward(self, m_feats, mask):
438 | """
439 | Inputs:
440 | m_feats: segment-level multimodal feature [B,nseg,*]
441 | mask: mask [B,nseg]
442 | Outputs:
443 | updated_m: updated multimodal feature [B,nseg,*]
444 | """
445 |
446 | mask = mask.float()
447 | B, nseg = mask.size()
448 |
449 | # key, query, value
450 | m_k = self.v_lin(self.drop(m_feats)) # [B,num_seg,*]
451 | m_trans = self.c_lin(self.drop(m_feats)) # [B,nseg,2*]
452 | m_q, m_v = torch.split(m_trans, m_trans.size(2) // 2, dim=2)
453 |
454 | new_mq = m_q
455 | new_mk = m_k
456 |
457 | # applying multi-head attention
458 | w_list = []
459 | mk_set = torch.split(new_mk, new_mk.size(2) // self.nheads, dim=2)
460 | mq_set = torch.split(new_mq, new_mq.size(2) // self.nheads, dim=2)
461 | mv_set = torch.split(m_v, m_v.size(2) // self.nheads, dim=2)
462 |
463 | for i in range(self.nheads):
464 | mk_slice, mq_slice, mv_slice = mk_set[i], mq_set[i], mv_set[i] # [B, nseg, *]
465 |
466 | # compute relation matrix; [B,nseg,nseg]
467 | m2m = mk_slice @ mq_slice.transpose(1,2) / ((self.odim // self.nheads) ** 0.5)
468 | m2m = m2m.masked_fill(mask.unsqueeze(1).eq(0), -1e9) # [B,nseg,nseg]
469 | m2m_w = torch.nn.functional.softmax(m2m, dim=2) # [B,nseg,nseg]
470 | w_list.append(m2m_w)
471 |
472 | # compute relation vector for each segment
473 | r = m2m_w @ mv_slice if (i==0) else torch.cat((r, m2m_w @ mv_slice), dim=2)
474 |
474 | updated_m = m_feats + r
476 | return updated_m
477 |
478 | class AttentivePooling(nn.Module):
479 | def __init__(self, cfg):
480 | self.cfg = cfg
481 | super(AttentivePooling, self).__init__()
482 | self.att_n = 1
483 | self.feat_dim = self.cfg.MODEL.FUSION.EMB_DIM
484 | self.att_hid_dim = self.cfg.MODEL.FUSION.EMB_DIM // 2
485 | self.use_embedding = True
486 |
487 | self.feat2att = nn.Linear(self.feat_dim, self.att_hid_dim, bias=False)
488 | self.to_alpha = nn.Linear(self.att_hid_dim, self.att_n, bias=False)
489 | if self.use_embedding:
490 | edim = self.cfg.MODEL.FUSION.EMB_DIM
491 | self.fc = nn.Linear(self.feat_dim, edim)
492 |
493 | def forward(self, feats, f_masks=None):
494 | """
495 | Compute attention weights and attended feature (weighted sum)
496 | Args:
497 | feats: features where attention weights are computed; [B, A, D]
498 | f_masks: mask for effective features; [B, A]
499 | """
500 | # check inputs
501 | assert len(feats.size()) == 3 or len(feats.size()) == 4
502 | assert f_masks is None or len(f_masks.size()) == 2
503 |
504 | # dealing with dimension 4
505 | if len(feats.size()) == 4:
506 | B, W, H, D = feats.size()
507 | feats = feats.view(B, W*H, D)
508 |
509 | # embedding feature vectors
510 | attn_f = self.feat2att(feats) # [B,A,hdim]
511 |
512 | # compute attention weights
513 | dot = torch.tanh(attn_f) # [B,A,hdim]
514 | alpha = self.to_alpha(dot) # [B,A,att_n]
515 | if f_masks is not None:
516 | alpha = alpha.masked_fill(f_masks.float().unsqueeze(2).eq(0), -1e9)
517 | attw = torch.nn.functional.softmax(alpha.transpose(1,2), dim=2) # [B,att_n,A]
518 |
519 | att_feats = attw @ feats # [B,att_n,D]
520 | att_feats = att_feats.squeeze(1)
521 | attw = attw.squeeze(1)
522 | if self.use_embedding: att_feats = self.fc(att_feats)
523 |
524 | return att_feats, attw
525 |
526 |
527 | class AttentionLocRegressor(nn.Module):
528 | def __init__(self, cfg):
529 | super(AttentionLocRegressor, self).__init__()
530 | self.cfg = cfg
531 | self.tatt_vid = AttentivePooling(self.cfg)
532 | self.tatt_query = AttentivePooling(self.cfg)
533 | # Regression layer
534 | idim = self.cfg.MODEL.FUSION.EMB_DIM * 2
535 | gdim = self.cfg.MODEL.FUSION.EMB_DIM
536 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()]
537 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)]
538 | self.MLP_reg = nn.Sequential(*nn_list)
539 |
540 |
541 | def forward(self, semantic_aware_seg_vid_feats, vid_masks, semantic_aware_seg_query_feat, query_masks):
542 | # perform Eq. (13) and (14)
543 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks)
544 | summarized_qfeat, att_w_q = self.tatt_query(semantic_aware_seg_query_feat, query_masks)
545 | # perform Eq. (15)
546 | summarized_feats = torch.cat((summarized_vfeat, summarized_qfeat), dim=1)
547 | #loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e]
548 | loc = self.MLP_reg(summarized_feats) # loc = [t^s, t^e]
549 | return loc, att_w
550 |
551 |
552 | class TwostageAttentionLocRegressor(nn.Module):
553 | def __init__(self, cfg):
554 | super(TwostageAttentionLocRegressor, self).__init__()
555 | self.cfg = cfg
556 | self.tatt_vid = AttentivePooling(self.cfg)
557 | # Regression layer
558 | idim = self.cfg.MODEL.FUSION.EMB_DIM
559 | gdim = self.cfg.MODEL.FUSION.EMB_DIM
560 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()]
561 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)]
562 | self.MLP_reg = nn.Sequential(*nn_list)
563 |
564 | def forward(self, semantic_aware_seg_vid_feats, vid_masks):
565 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks)
566 | loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e]
567 | return loc, att_w
568 |
569 |
570 | class SimpleModel(nn.Module):
571 | def __init__(self,cfg):
572 | super().__init__()
573 | self.cfg = cfg
574 | self.query_encoder = SimpleSentenceEmbeddingModule(cfg)
575 | self.video_encoder = SimpleVideoEmbeddingModule(cfg)
576 | self.v_fusor = SimpleFusionModule(cfg)
577 | self.s_fusor = SimpleFusionModuleSent(cfg)
578 | self.cv_fusor = TwostageSimpleFusionModule(cfg)
579 | self.n_non_local = self.cfg.MODEL.NONLOCAL.NUM_LAYERS
580 | self.sent_temp = self.cfg.MODEL.QUERY.TEMPERATURE
581 | self.non_locals_layer = nn.ModuleList([NonLocalBlock(cfg) for _ in range(self.n_non_local)])
582 | self.loc_regressor = AttentionLocRegressor(cfg)
583 | self.loc_regressor_two_stage = TwostageAttentionLocRegressor(cfg)
584 |
585 | def forward(self,inputs):
586 | # encode query
587 | query_labels = inputs['query_labels']
588 | query_masks = inputs['query_masks']
589 | encoded_query, encoded_sentence = self.query_encoder(query_labels, query_masks) # encoded_query [B, L, D] D = 256
590 |
591 | # encode video
592 | vid_feats = inputs['video_feats']
593 | vid_masks = inputs['video_masks']
594 | encoded_video = self.video_encoder(vid_feats,vid_masks)
595 |
596 | # Crossmodality Attention
597 | attended_sent = self.s_fusor(encoded_query, query_masks, encoded_video, vid_masks)
598 | attended_vid = self.v_fusor(encoded_query, query_masks, encoded_video, vid_masks)
599 | two_stage_attended_vid = self.cv_fusor(attended_sent, query_masks, attended_vid, vid_masks)
600 |
601 | global_two_stage_vid = two_stage_attended_vid
602 | for non_local_layer in self.non_locals_layer:
603 | global_two_stage_vid = non_local_layer(global_two_stage_vid, vid_masks.squeeze(2))
604 |
605 | loc, temporal_attn_weight = self.loc_regressor_two_stage(global_two_stage_vid, vid_masks.squeeze(2))
606 |
607 | return {"timestamps": loc, "attention_weights": temporal_attn_weight}
608 |
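#%%
# Minimal standalone sketch (illustration only, not used by the training pipeline):
# dividing the query by a temperature before nn.MultiheadAttention shrinks the query
# projection and thereby (up to the in-projection bias) scales the attention logits
# Q K^T / sqrt(d) by 1/temperature, so a larger temperature gives flatter attention.
if __name__ == "__main__":
    torch.manual_seed(0)
    demo_attn = nn.MultiheadAttention(embed_dim=8, num_heads=1)
    q = torch.randn(4, 2, 8)   # [L_query, B, D]
    kv = torch.randn(6, 2, 8)  # [L_video, B, D]
    _, w_base = demo_attn(q, kv, kv)        # temperature = 1.0
    _, w_soft = demo_attn(q / 4.0, kv, kv)  # temperature = 4.0 -> typically flatter weights
    print("max attention weight:", w_base.max().item(), "->", w_soft.max().item())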
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | # This file may be used to create an environment using:
2 | # $ conda create --name <env> --file <this file>
3 | # platform: linux-64
4 | _libgcc_mutex=0.1=main
5 | absl-py=0.11.0=pypi_0
6 | argon2-cffi=20.1.0=pypi_0
7 | astor=0.8.1=pypi_0
8 | async-generator=1.10=pypi_0
9 | attrs=20.3.0=pypi_0
10 | backcall=0.2.0=py_0
11 | beautifulsoup4=4.9.1=py37_0
12 | blas=1.0=mkl
13 | bleach=3.2.2=pypi_0
14 | bzip2=1.0.8=h7b6447c_0
15 | ca-certificates=2020.1.1=0
16 | cached-property=1.5.2=pypi_0
17 | certifi=2020.4.5.2=py37_0
18 | cffi=1.14.0=py37he30daa8_1
19 | chardet=3.0.4=py37_1003
20 | conda=4.8.3=py37_0
21 | conda-build=3.18.11=py37_0
22 | conda-package-handling=1.6.1=py37h7b6447c_0
23 | cryptography=2.9.2=py37h1ba5d50_0
24 | cudatoolkit=10.1.243=h6bb024c_0
25 | decorator=4.4.2=py_0
26 | defusedxml=0.6.0=pypi_0
27 | entrypoints=0.3=pypi_0
28 | filelock=3.0.12=py_0
29 | freetype=2.9.1=h8a8886c_1
30 | gast=0.2.2=pypi_0
31 | glob2=0.7=py_0
32 | google-pasta=0.2.0=pypi_0
33 | grpcio=1.34.0=pypi_0
34 | h5py=3.1.0=pypi_0
35 | icu=58.2=he6710b0_3
36 | idna=2.9=py_1
37 | importlib-metadata=3.3.0=pypi_0
38 | intel-openmp=2020.1=217
39 | ipykernel=5.4.3=pypi_0
40 | ipython=7.15.0=py37_0
41 | ipython_genutils=0.2.0=py37_0
42 | ipywidgets=7.6.3=pypi_0
43 | jedi=0.17.0=py37_0
44 | jinja2=2.11.2=py_0
45 | jpeg=9b=h024ee3a_2
46 | jsonschema=3.2.0=pypi_0
47 | jupyter=1.0.0=pypi_0
48 | jupyter-client=6.1.11=pypi_0
49 | jupyter-console=6.2.0=pypi_0
50 | jupyter-core=4.7.0=pypi_0
51 | jupyterlab-pygments=0.1.2=pypi_0
52 | jupyterlab-widgets=1.0.0=pypi_0
53 | keras-applications=1.0.8=pypi_0
54 | keras-preprocessing=1.1.2=pypi_0
55 | ld_impl_linux-64=2.33.1=h53a641e_7
56 | libarchive=3.4.2=h62408e4_0
57 | libedit=3.1.20181209=hc058e9b_0
58 | libffi=3.3=he6710b0_1
59 | libgcc-ng=9.1.0=hdf63c60_0
60 | libgfortran-ng=7.3.0=hdf63c60_0
61 | liblief=0.10.1=he6710b0_0
62 | libpng=1.6.37=hbc83047_0
63 | libstdcxx-ng=9.1.0=hdf63c60_0
64 | libtiff=4.1.0=h2733197_1
65 | libxml2=2.9.10=he19cac6_1
66 | lz4-c=1.9.2=he6710b0_0
67 | markdown=3.3.3=pypi_0
68 | markupsafe=1.1.1=py37h7b6447c_0
69 | mistune=0.8.4=pypi_0
70 | mkl=2020.1=217
71 | mkl-service=2.3.0=py37he904b0f_0
72 | mkl_fft=1.1.0=py37h23d657b_0
73 | mkl_random=1.1.1=py37h0573a6f_0
74 | nbclient=0.5.1=pypi_0
75 | nbconvert=6.0.7=pypi_0
76 | nbformat=5.1.2=pypi_0
77 | ncurses=6.2=he6710b0_1
78 | nest-asyncio=1.4.3=pypi_0
79 | ninja=1.9.0=py37hfd86e86_0
80 | notebook=6.2.0=pypi_0
81 | numpy=1.18.1=py37h4f9e942_0
82 | numpy-base=1.18.1=py37hde5b4d6_1
83 | olefile=0.46=py37_0
84 | openssl=1.1.1g=h7b6447c_0
85 | opt-einsum=3.3.0=pypi_0
86 | packaging=20.8=pypi_0
87 | pandocfilters=1.4.3=pypi_0
88 | parso=0.7.0=py_0
89 | patchelf=0.11=he6710b0_0
90 | pexpect=4.8.0=py37_0
91 | pickleshare=0.7.5=py37_0
92 | pillow=7.1.2=py37hb39fc2d_0
93 | pip=20.0.2=py37_3
94 | pkginfo=1.5.0.1=py37_0
95 | prometheus-client=0.9.0=pypi_0
96 | prompt-toolkit=3.0.5=py_0
97 | protobuf=3.14.0=pypi_0
98 | psutil=5.7.0=py37h7b6447c_0
99 | ptyprocess=0.6.0=py37_0
100 | py-lief=0.10.1=py37h403a769_0
101 | pycosat=0.6.3=py37h7b6447c_0
102 | pycparser=2.20=py_0
103 | pygments=2.6.1=py_0
104 | pyopenssl=19.1.0=py37_0
105 | pyparsing=2.4.7=pypi_0
106 | pyrsistent=0.17.3=pypi_0
107 | pysocks=1.7.1=py37_0
108 | python=3.7.7=hcff3b4d_5
109 | python-dateutil=2.8.1=pypi_0
110 | python-libarchive-c=2.9=py_0
111 | pytorch=1.5.1=py3.7_cuda10.1.243_cudnn7.6.3_0
112 | pytz=2020.1=py_0
113 | pyyaml=5.3.1=py37h7b6447c_0
114 | pyzmq=21.0.1=pypi_0
115 | qtconsole=5.0.2=pypi_0
116 | qtpy=1.9.0=pypi_0
117 | readline=8.0=h7b6447c_0
118 | requests=2.23.0=py37_0
119 | ripgrep=11.0.2=he32d670_0
120 | ruamel_yaml=0.15.87=py37h7b6447c_0
121 | send2trash=1.5.0=pypi_0
122 | setuptools=46.4.0=py37_0
123 | six=1.14.0=py37_0
124 | soupsieve=2.0.1=py_0
125 | sqlite=3.31.1=h62c20be_1
126 | tensorboard=1.15.0=pypi_0
127 | tensorflow=1.15.0=pypi_0
128 | tensorflow-estimator=1.15.1=pypi_0
129 | termcolor=1.1.0=pypi_0
130 | terminado=0.9.2=pypi_0
131 | testpath=0.4.4=pypi_0
132 | tk=8.6.8=hbc83047_0
133 | torchvision=0.6.1=py37_cu101
134 | tornado=6.1=pypi_0
135 | tqdm=4.46.0=py_0
136 | traitlets=4.3.3=py37_0
137 | typing-extensions=3.7.4.3=pypi_0
138 | urllib3=1.25.8=py37_0
139 | wcwidth=0.2.4=py_0
140 | webencodings=0.5.1=pypi_0
141 | werkzeug=1.0.1=pypi_0
142 | wheel=0.34.2=py37_0
143 | widgetsnbextension=3.5.1=pypi_0
144 | wrapt=1.12.1=pypi_0
145 | xz=5.2.5=h7b6447c_0
146 | yacs=0.1.8=pypi_0
147 | yaml=0.1.7=had09818_2
148 | zipp=3.4.0=pypi_0
149 | zlib=1.2.11=h7b6447c_3
150 | zstd=1.4.4=h0b5b093_3
151 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #%%
2 | # import things
3 | # model code
4 | import torch
5 | from torch.utils.data import DataLoader
6 | from utils.eval_utils import NLVLEvaluator, renew_best_score
7 | from utils.loss import NLVLLoss
8 |
9 | # logging code
10 | from torch.utils.tensorboard import SummaryWriter
11 |
12 | # etc
13 | import random
14 | import numpy as np
15 | import os
16 | from os.path import join
17 | from tqdm import tqdm
18 | from copy import deepcopy
19 | from yacs.config import CfgNode
20 | from yaml import dump as dump_yaml
21 | import argparse
22 |
23 | #%%
24 | # argparse
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("--model",'-m',type=str,default="CrossModalityTwostageAttention")
27 | parser.add_argument("--config",'-c',type=str,default="configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml")
28 | parser.add_argument("--seed", '-s',type=int,default=38)
29 | parser.add_argument("--reg_w", type=float, default=1.0)
30 | parser.add_argument("--temp",'-t',type=float,default=1.0)
31 | args = parser.parse_args()
32 |
33 | dict_args = {
34 | "model": args.model,
35 | "confg": args.config
36 | }
37 |
38 | random.seed(int(args.seed))
39 | os.environ['PYTHONHASHSEED'] = str(args.seed)
40 | np.random.seed(int(args.seed))
41 | torch.manual_seed(int(args.seed))
42 | torch.cuda.manual_seed(int(args.seed))
43 | torch.cuda.manual_seed_all(int(args.seed))
44 | torch.backends.cudnn.deterministic = True
45 | torch.backends.cudnn.benchmark = False
46 | torch.backends.cudnn.enabled = False
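# note: cudnn.enabled = False disables cuDNN entirely, trading speed for
# reproducibility on top of deterministic=True / benchmark=False.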
47 |
48 | #%%
49 | # load things according to arguments
50 | if args.model == "SimpleModel":
51 | from models.simple_model import SimpleModel as Model
52 | from config_parsers.simple_model_config import _C as TestCfg
53 | elif args.model == "CrossModalityTwostageAttention":
54 | from models.simple_model_cross_modal_two_stage import SimpleModel as Model
55 | from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as TestCfg
56 | elif args.model == "CrossModalityTwostageAttentionTemperature":
57 | from models.simple_model_cross_modal_two_stage_temperature import SimpleModel as Model
58 | from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as TestCfg
59 | else:
60 | raise ValueError("No such model: {}".format(args.model))
61 |
62 |
63 | #%%
64 | # constants
65 | CONFIG_PATH = args.config
66 | print("config file path: ", CONFIG_PATH)
67 | TestCfg.merge_from_file(CONFIG_PATH)
68 | cfg = TestCfg
69 | cfg.MODEL.QUERY.TEMPERATURE = args.temp
70 | device = torch.device("cuda")
71 | DATA_PATH = cfg.DATASET.DATA_PATH
72 | ANNO_PATH ={"train": cfg.DATASET.TRAIN_ANNO_PATH,
73 | "test": cfg.DATASET.TEST_ANNO_PATH}
74 | VID_PATH = cfg.DATASET.VID_PATH
75 |
76 | #%%
77 | # function for dumping yaml
78 | def cfg_to_dict(cfg):
79 | dict_cfg = dict(cfg)
80 | for k,v in dict_cfg.items():
81 | if isinstance(v,CfgNode):
82 | dict_cfg[k] = cfg_to_dict(v)
83 | return dict_cfg
84 |
85 | #%%
86 | # Load dataloader
87 | dataset_name = cfg.DATASET.NAME
88 | if dataset_name == "Charades":
89 | from dataset.charades_basic import CharadesBasicDatasetBuilder
90 | dataloaders = CharadesBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders()
91 | elif dataset_name == "AnetCap":
92 | from dataset.anetcap_basic import AnetCapBasicDatasetBuilder
93 | dataloaders = AnetCapBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders()
94 | else:
95 | raise ValueError("No such dataset: {}".format(dataset_name))
96 |
97 | #%%
98 | # load training stuff
99 | ## model and loss
100 | model = Model(cfg).to(device)
101 | loss_fn = NLVLLoss(cfg, reg_w=args.reg_w).to(device)  # wire the --reg_w flag into the loss
102 |
103 | ## evaluator
104 | evaluator = NLVLEvaluator(cfg)
105 | ## optimizer
106 | lr = cfg.TRAIN.LR
107 | optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, betas=(0.9,0.999))
108 | ## batch to device function
109 | batch_to_device = dataloaders['train'].dataset.batch_to_device
110 | ## Logging Functions
111 | exp_name = cfg.EXP_NAME
112 |
113 | log_dir = join("results", dataset_name, args.model, exp_name + "_temp_" + str(args.temp))
114 | logger = SummaryWriter(log_dir=log_dir,flush_secs=60)
115 | ## archive cfg
116 | dict_cfg = cfg_to_dict(cfg)
117 | with open(join(log_dir,"config.yml"),'w') as f:
118 | dump_yaml(dict_cfg,f)
119 | ## archive args
120 | with open(join(log_dir,"args.yml"),'w') as f:
121 | dump_yaml(dict_args,f)
122 |
123 | # information print out
124 | print("====="*10)
125 | print(args)
126 | print("====="*10)
127 | print(cfg)
128 | print("====="*10)
129 |
130 | #%%
131 | curr_best = 0
132 | # training loop
133 | pbar = tqdm(range(cfg.TRAIN.NUM_EPOCH))
134 | for epoch in pbar:
135 | # training loop
136 | model.train()
137 | losses_list = []
138 | for batch_idx,item in enumerate(dataloaders['train']):
139 | # update progress bar
140 | pbar.set_description("training {}/{}".format(batch_idx,len(dataloaders['train'])))
141 | # make prediction
142 | optimizer.zero_grad()
143 | batch_to_device(item, device)
144 | model_outputs = model(item)
145 | # loss fn
146 | losses = loss_fn(model_outputs,item)
147 | sum_loss = sum([v for k,v in losses.items()])
148 | sum_loss.backward()
149 | # update model
150 | optimizer.step()
151 | # add to losses list
152 | losses_list.append(losses)
153 | # write training information
154 | ## make mean losses dict
155 | mean_losses_dict = {}
156 | for k in list(losses_list[0].keys()):
157 | mean_losses_dict[k] = torch.mean(torch.Tensor([x[k] for x in losses_list]))
158 | for loss_name,val in mean_losses_dict.items():
159 | logger.add_scalar("train/{}".format(loss_name), val, global_step=epoch)
160 |
161 | # test loop
162 | model.eval()
163 | eval_results_list = []
164 | for batch_idx,item in enumerate(dataloaders['test']):
165 | # update progress bar
166 | pbar.set_description("test {}/{}".format(batch_idx,len(dataloaders['test'])))
167 | # make evaluation
168 | with torch.no_grad():
169 | batch_to_device(item,device)
170 | model_outputs = model(item)
171 | eval_results = evaluator(model_outputs,item)
172 | # add to mean eval results
173 | eval_results_list.append(eval_results)
174 | # write test information
175 | ## make mean metric dict
176 | mean_eval_results_dict = {}
177 | for k in list(eval_results_list[0].keys()):
178 | mean_eval_results_dict[k] = torch.mean(torch.Tensor([x[k] for x in eval_results_list]))
179 |
180 | for metric_name,val in mean_eval_results_dict.items():
181 | logger.add_scalar("test/{}".format(metric_name), val, global_step=epoch)
182 | curr_best = renew_best_score(mean_eval_results_dict['Recall@0.5'], curr_best, model)
183 |
184 | # print test information
185 | tqdm.write("****** epoch:{} ******".format(epoch))
186 | for k,v in mean_eval_results_dict.items():
187 | tqdm.write("\t{}: {}".format(k,v))
188 |
--------------------------------------------------------------------------------
/utils/eval_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
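# Temporal IoU; note that the union here is the hull of the two intervals
# (latest end minus earliest start), which equals the exact union whenever the
# prediction and the ground truth overlap.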
6 | def compute_tiou(pred, gt):
7 | intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0]))
8 | union = max(pred[1], gt[1]) - min(pred[0], gt[0])
9 | if union == 0.0:
10 | if intersection > 0.0:
11 | return 1.0
12 | else:
13 | return 0.0
14 | return float(intersection) / union
15 |
16 | def renew_best_score(cur, best, model):
17 | if best > cur:
18 | return best
19 | torch.save(model.state_dict(), 'pretrained_best.pth')
20 |
21 | return cur
22 |
23 | class NLVLEvaluator():
24 | def __init__(self,cfg):
25 | self.cfg = cfg
26 | self.iou_thresh = self.cfg.TRAIN.IOU_THRESH
27 | self.metrics = ["mIoU"] + ["Recall@{:.1f}".format(x) for x in self.iou_thresh]
28 |
29 | def __call__(self,model_outputs,batch):
30 | gt = torch.cat([batch["grounding_start_pos"].unsqueeze(1),batch["grounding_end_pos"].unsqueeze(1)],dim=1)
31 | pred = model_outputs['timestamps']
32 | scores = {x:[] for x in self.metrics}
33 | for p,g in zip(pred,gt):
34 | tiou = compute_tiou(p,g)
35 | scores["mIoU"].append(tiou)
36 | for thresh in self.iou_thresh:
37 | scores["Recall@{:.1f}".format(thresh)].append(tiou >= thresh)
38 | for k,v in scores.items():
39 | scores[k] = torch.mean(torch.Tensor(v))
40 | return scores
41 |
--------------------------------------------------------------------------------
/utils/loss.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from torch import nn
4 |
5 |
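# TAGLoss guides the temporal attention: it takes the negative log of the predicted
# attention weights at positions where the mask is 1 (inside the supervision segment),
# normalizes by the segment length, and averages over the batch, pushing attention
# mass inside that segment.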
6 | class TAGLoss(nn.Module):
7 | def __init__(self):
8 | super(TAGLoss, self).__init__()
9 |
10 | def forward(self, w, mask):
11 | ac_loss = (-mask*torch.log(w+1e-8)).sum(1) / mask.sum(1)
12 | ac_loss = ac_loss.mean(0)
13 |
14 | return ac_loss
15 |
16 | class TGRegressionCriterion(nn.Module):
17 | def __init__(self):
18 | super(TGRegressionCriterion, self).__init__()
19 |
20 | self.regloss1 = nn.SmoothL1Loss()
21 | self.regloss2 = nn.SmoothL1Loss()
22 |
23 | def forward(self, loc, s_gt, e_gt):
24 |
25 | total_loss = self.regloss1(loc[:,0], s_gt) + self.regloss2(loc[:,1], e_gt)
26 |
27 | return total_loss
28 |
29 | class NLVLLoss(nn.Module):
30 | def __init__(self,cfg, reg_w=1):
31 | super().__init__()
32 | self.temporal_localization_loss = TGRegressionCriterion()
33 | self.temporal_attention_loss2 = TAGLoss()
34 | self.reg_w = reg_w
35 |
36 | def forward(self,model_outputs,batch):
37 | # position loss
38 | timestamps = model_outputs['timestamps'] # [B,2]
39 | gt_start_pos = batch["grounding_start_pos"]
40 | gt_end_pos = batch["grounding_end_pos"]
41 | gt_timestamps = torch.cat([gt_start_pos.unsqueeze(1),gt_end_pos.unsqueeze(1)],dim=1) # [B,2]
42 |
43 | localization_loss = self.temporal_localization_loss(timestamps, gt_start_pos, gt_end_pos)
44 | localization_loss = localization_loss * self.reg_w
45 |
46 | # attention loss
47 | attention_weights = model_outputs['attention_weights'] # [B,128]
48 | attention_masks = batch["attention_masks"] # [B,128]
49 | attention_loss = self.temporal_attention_loss2(attention_weights,attention_masks)
50 |
51 | loss_dict = {
52 | "localization_loss": localization_loss,
53 | "attention_loss": attention_loss
54 | }
55 | return loss_dict
56 |
--------------------------------------------------------------------------------