├── README.md ├── annotations └── charades │ ├── charades_sta_test_pos_original_simple_sent.json │ └── charades_train_pseudo_supervision_TEP_PS.json ├── config_parsers ├── simple_model_config.py └── simple_model_cross_modality_twostage_attention_config.py ├── configs └── cha_simple_model │ └── simplemodel_cha_BS256_two-stage_attention.yml ├── dataset ├── anetcap_basic.py └── charades_basic.py ├── inference.py ├── media └── task-1.png ├── models ├── simple_model.py ├── simple_model_cross_modal_two_stage.py └── simple_model_cross_modal_two_stage_temperature.py ├── requirements.txt ├── train.py └── utils ├── eval_utils.py └── loss.py /README.md: -------------------------------------------------------------------------------- 1 | # Zero-shot Natural Language Video Localization (ZSNLVL) by Pseudo-Supervised Video Localization (PSVL) 2 | 3 | This repository is for [Zero-shot Natural Language Video Localization](https://openaccess.thecvf.com/content/ICCV2021/papers/Nam_Zero-Shot_Natural_Language_Video_Localization_ICCV_2021_paper.pdf) (ICCV 2021, Oral). 4 | 5 | 6 | We first propose a novel task of zero-shot natural language video localization. The proposed task setup does not require any paired annotations for the NLVL task; it only requires easily available text corpora, an off-the-shelf object detector, and a collection of videos to localize. To address the task, we propose a **P**seudo-**S**upervised **V**ideo **L**ocalization method, called **PSVL**, that can generate pseudo-supervision for training an NLVL model. Benchmarked on two widely used NLVL datasets, the proposed method exhibits competitive performance and performs on par with or outperforms models trained with stronger supervision. 7 | 8 | ![task_nlvl](media/task-1.png) 9 | 10 | 11 | --- 12 | ## Environment 13 | This repository is implemented based on [PyTorch](http://pytorch.org/) with Anaconda.
14 | Refer to the instructions below, or use the **Docker** image (dcahn/psvl:latest).
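If you go the Docker route, a typical workflow looks like the sketch below. The image tag comes from the line above, but the container options and mount paths are illustrative assumptions, not something this repo prescribes.

```bash
# Pull the image mentioned above and start a GPU-enabled container.
# The mounted paths are examples -- adjust them to your local code/feature locations.
docker pull dcahn/psvl:latest
docker run --gpus all -it \
    -v "$(pwd)":/workspace/PSVL \
    -v /path/to/video_features:/dataset \
    dcahn/psvl:latest /bin/bash
```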
15 | 16 | 17 | ### Get the code 18 | - Clone this repo with git: 19 | ```bash 20 | git clone https://github.com/gistvision/PSVL.git 21 | ``` 22 | 23 | - Make your own environment (if you use the Docker environment, you only need to clone the code and run it): 24 | ```bash 25 | conda create --name PSVL --file requirements.txt 26 | conda activate PSVL 27 | ``` 28 | 29 | #### Working environment 30 | - RTX2080Ti (11G) 31 | - Ubuntu 18.04.5 32 | - PyTorch 1.5.1 33 | 34 | ## Download 35 | 36 | ### Dataset & Pretrained model 37 | 38 | - Download the video features used in this paper from this [link](https://drive.google.com/file/d/1Vjgm2XA3TYcc4h9IWR5k5efU-bXNir5f/view?usp=sharing).
39 | After downloading the video features, set the data path (`DATASET.DATA_PATH`) in the config file.
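For reference, the Charades config shipped with this repo (shown in full later in this dump) reads these values from its `DATASET` section; below is a minimal excerpt of the fields you would typically edit, with the paths acting as placeholders for your local setup.

```yaml
# configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml (excerpt)
DATASET:
  DATA_PATH: "/dataset/charades_feats"   # change to the directory holding the downloaded video features
  TRAIN_ANNO_PATH: annotations/charades/charades_train_pseudo_supervision_TEP_PS.json
  TEST_ANNO_PATH: annotations/charades/charades_sta_test_pos_original_simple_sent.json
  VID_PATH: ''                           # if non-empty, overrides the feature path derived from DATA_PATH
```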
40 | 41 | - Download the pre-trained model from this [link](https://drive.google.com/file/d/1M2FX2qkEvyked50LSc9Y5r87GBnpohSX/view?usp=sharing). 42 | 43 | For ActivityNet-Captions, see the ActivityNet-Captions section of this document. 44 | 45 | ## Evaluating pre-trained models 46 | 47 | To evaluate the pre-trained model, use the command below. 48 | 49 | ```bash 50 | python inference.py --model CrossModalityTwostageAttention --config "YOUR CONFIG PATH" --pre_trained "YOUR MODEL PATH" 51 | ``` 52 | 53 | ## Training models from scratch 54 | 55 | To train PSVL, run `train.py` with the command below. 56 | 57 | ```bash 58 | # Training from scratch 59 | python train.py --model CrossModalityTwostageAttention --config "YOUR CONFIG PATH" 60 | # Evaluation 61 | python inference.py --model CrossModalityTwostageAttention --config "YOUR CONFIG PATH" --pre_trained "YOUR MODEL PATH" 62 | ``` 63 | 64 | ## ActivityNet-Captions 65 |
67 | Place the data under `/dataset/lgi_video_feature/anet_feats`.
68 | 69 | - Other data can be downloaded from [this link](https://drive.google.com/file/d/1yXnVHslpV51zqd9TRJAjFXPZZESCFrPm/view?usp=sharing). 70 | 71 | Please download the file, unzip it, and type followings to train/inference with the data. 72 | 73 | To train the model, please run: 74 | ```bash 75 | python train.py --model CrossModalityTwostageAttention --config configs/anet_simple_model/simplemodel_anet_BS256_two-stage_attention.yml --dataset anet 76 | ``` 77 | 78 | To inference with test set, please run: 79 | ```bash 80 | python inference.py --model CrossModalityTwostageAttention --config configs/anet_simple_model/simplemodel_anet_BS256_two-stage_attention.yml --pre_trained anet_pretrained_best.pth 81 | ``` 82 | 83 | ## Lisence 84 | MIT Lisence 85 | 86 | ## Citation 87 | 88 | If you use this code, please cite: 89 | ``` 90 | @inproceedings{nam2021zero, 91 | title={Zero-shot Natural Language Video Localization}, 92 | author={Nam, Jinwoo and Ahn, Daechul and Kang, Dongyeop and Ha, Seong Jong and Choi, Jonghyun}, 93 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 94 | pages={1470-1479}, 95 | year={2021} 96 | } 97 | ``` 98 | 99 | ## Contact 100 | If you have any questions, please send e-mail to me (skaws2012@gmail.com, daechulahn@gm.gist.ac.kr) 101 | -------------------------------------------------------------------------------- /config_parsers/simple_model_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | simple_model_config.py 3 | **** 4 | config for charades dataset 5 | """ 6 | from yacs.config import CfgNode as CN 7 | _C = CN() 8 | _C.EXP_NAME = "SimpleModel-verbpos" 9 | #_C.EXP_NAME = "Debug" 10 | 11 | # training options 12 | _C.DATASET = CN() 13 | _C.DATASET.NAME = "Charades" 14 | _C.DATASET.SHOW_TOP_VOCAB = 1 15 | _C.DATASET.BATCH_SIZE = 100 16 | _C.DATASET.MAX_LENGTH = 15 17 | _C.DATASET.NUM_SEGMENT = 128 18 | _C.DATASET.DATA_PATH = "/home/data/anet" 19 | _C.DATASET.TRAIN_ANNO_PATH = "annotations/charades_sta_train.json" 20 | _C.DATASET.TEST_ANNO_PATH = "annotations/charades_sta_test.json" 21 | _C.DATASET.VID_PATH = "" 22 | 23 | # model options 24 | _C.MODEL = CN() 25 | _C.MODEL.QUERY = CN() 26 | _C.MODEL.QUERY.EMB_IDIM = -1 27 | _C.MODEL.QUERY.EMB_ODIM = 300 28 | _C.MODEL.QUERY.GRU_HDIM = 256 29 | _C.MODEL.VIDEO = CN() 30 | _C.MODEL.VIDEO.IDIM = 1024 31 | _C.MODEL.VIDEO.GRU_HDIM = 256 32 | _C.MODEL.FUSION = CN() 33 | _C.MODEL.FUSION.EMB_DIM = 256 34 | _C.MODEL.FUSION.NUM_HEAD = 8 35 | _C.MODEL.FUSION.NUM_LAYERS = 3 36 | _C.MODEL.FUSION.CONVBNRELU = CN() 37 | _C.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE = 3 38 | _C.MODEL.FUSION.CONVBNRELU.PADDING = 1 39 | _C.MODEL.NONLOCAL = CN() 40 | _C.MODEL.NONLOCAL.NUM_LAYERS = 2 41 | _C.MODEL.NONLOCAL.NUM_HEAD = 4 42 | _C.MODEL.NONLOCAL.USE_BIAS = True 43 | _C.MODEL.NONLOCAL.DROPOUT = 0.0 44 | 45 | # training options 46 | _C.TRAIN = CN() 47 | _C.TRAIN.LR = 0.0001 48 | _C.TRAIN.NUM_EPOCH = 300 49 | _C.TRAIN.BATCH_SIZE = 100 50 | _C.TRAIN.NUM_WORKERS = 4 51 | _C.TRAIN.IOU_THRESH = [0.3,0.5,0.7] 52 | -------------------------------------------------------------------------------- /config_parsers/simple_model_cross_modality_twostage_attention_config.py: -------------------------------------------------------------------------------- 1 | """ 2 | simple_model_config.py 3 | **** 4 | config for charades dataset 5 | """ 6 | from yacs.config import CfgNode as CN 7 | _C = CN() 8 | _C.EXP_NAME = "SimpleModel-verbpos" 9 | #_C.EXP_NAME = "Debug" 10 | 11 | # training 
options 12 | _C.DATASET = CN() 13 | _C.DATASET.NAME = "Charades" 14 | _C.DATASET.SHOW_TOP_VOCAB = 1 15 | _C.DATASET.BATCH_SIZE = 100 16 | _C.DATASET.MAX_LENGTH = 15 17 | _C.DATASET.NUM_SEGMENT = 128 18 | _C.DATASET.DATA_PATH = "/home/data/charades" 19 | _C.DATASET.TRAIN_ANNO_PATH = "annotations/charades_train_pos.json" 20 | _C.DATASET.TEST_ANNO_PATH = "annotations/charades_test_pos.json" 21 | _C.DATASET.VID_PATH = "" 22 | 23 | # model options 24 | _C.MODEL = CN() 25 | _C.MODEL.QUERY = CN() 26 | _C.MODEL.QUERY.EMB_IDIM = -1 27 | _C.MODEL.QUERY.TRANSFORMER_DIM = 300 28 | _C.MODEL.QUERY.EMB_ODIM = 300 29 | _C.MODEL.QUERY.GRU_HDIM = 256 30 | _C.MODEL.QUERY.TEMPERATURE = 1.0 31 | _C.MODEL.VIDEO = CN() 32 | _C.MODEL.VIDEO.IDIM = 1024 33 | _C.MODEL.VIDEO.GRU_HDIM = 256 34 | _C.MODEL.VIDEO.ANET_TRANSFORMER_DIM = 500 35 | _C.MODEL.VIDEO.CHA_TRANSFORMER_DIM = 1024 36 | _C.MODEL.FUSION = CN() 37 | _C.MODEL.FUSION.EMB_DIM = 256 38 | _C.MODEL.FUSION.NUM_HEAD = 8 39 | _C.MODEL.FUSION.NUM_LAYERS = 3 40 | _C.MODEL.FUSION.USE_RESBLOCK = False 41 | _C.MODEL.FUSION.RESBLOCK = CN() 42 | _C.MODEL.FUSION.RESBLOCK.KERNEL_SIZE = 3 43 | _C.MODEL.FUSION.RESBLOCK.PADDING = 1 44 | _C.MODEL.FUSION.RESBLOCK.NB_ITER = 1 45 | _C.MODEL.FUSION.CONVBNRELU = CN() 46 | _C.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE = 3 47 | _C.MODEL.FUSION.CONVBNRELU.PADDING = 1 48 | 49 | _C.MODEL.NONLOCAL = CN() 50 | _C.MODEL.NONLOCAL.NUM_LAYERS = 2 51 | _C.MODEL.NONLOCAL.NUM_HEAD = 4 52 | _C.MODEL.NONLOCAL.USE_BIAS = True 53 | _C.MODEL.NONLOCAL.DROPOUT = 0.0 54 | 55 | # training options 56 | _C.TRAIN = CN() 57 | #_C.TRAIN.USE_DETERMINISTIC = True 58 | _C.TRAIN.LR = 0.0001 59 | _C.TRAIN.NUM_EPOCH = 300 60 | _C.TRAIN.BATCH_SIZE = 100 61 | _C.TRAIN.NUM_WORKERS = 4 62 | _C.TRAIN.IOU_THRESH = [0.3,0.5,0.7] 63 | -------------------------------------------------------------------------------- /configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | BATCH_SIZE: 256 3 | DATA_PATH: "/dataset/charades_feats" 4 | MAX_LENGTH: 10 5 | NAME: Charades 6 | NUM_SEGMENT: 128 7 | SHOW_TOP_VOCAB: 10 8 | TEST_ANNO_PATH: annotations/charades/charades_sta_test_pos_original_simple_sent.json 9 | TRAIN_ANNO_PATH: annotations/charades/charades_train_pseudo_supervision_TEP_PS.json 10 | VID_PATH: '' 11 | EXP_NAME: SimpleModel_twostage_attention 12 | MODEL: 13 | FUSION: 14 | CONVBNRELU: 15 | KERNEL_SIZE: 3 16 | PADDING: 1 17 | EMB_DIM: 256 18 | NUM_HEAD: 1 19 | NUM_LAYERS: 3 20 | NONLOCAL: 21 | DROPOUT: 0.0 22 | NUM_HEAD: 4 23 | NUM_LAYERS: 2 24 | USE_BIAS: true 25 | QUERY: 26 | EMB_IDIM: 290 27 | EMB_ODIM: 300 28 | GRU_HDIM: 256 29 | VIDEO: 30 | GRU_HDIM: 256 31 | IDIM: 1024 32 | TRAIN: 33 | BATCH_SIZE: 256 34 | IOU_THRESH: 35 | - 0.1 36 | - 0.3 37 | - 0.5 38 | - 0.7 39 | LR: 0.0004 40 | NUM_EPOCH: 500 41 | NUM_WORKERS: 4 42 | -------------------------------------------------------------------------------- /dataset/anetcap_basic.py: -------------------------------------------------------------------------------- 1 | #%% 2 | """ 3 | charades_basic.py 4 | **** 5 | the basic charades dataset class. 
6 | """ 7 | #%% 8 | # import things 9 | import os 10 | from os.path import join 11 | import json 12 | import h5py 13 | import numpy as np 14 | from tqdm import tqdm 15 | import torch 16 | from torch.utils.data import Dataset 17 | from torch.utils.data import DataLoader 18 | from random import random, randint, choice 19 | from collections import Counter 20 | 21 | import sys 22 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 23 | from config_parsers.simple_model_config import _C as TestCfg 24 | 25 | 26 | #%% 27 | # dataset builder class 28 | class AnetCapBasicDatasetBuilder(): 29 | def __init__(self,cfg,data_path=None,anno_path=None,vid_path=None): 30 | # make variables that will be used in the future 31 | self.cfg = cfg 32 | self.splits = ["train","test"] # list of splits 33 | # make paths 34 | if len(data_path) > 0: 35 | self.vid_path = join(data_path,"sub_activitynet_v1-3.c3d.hdf5") 36 | self.anno_path = {"train": join(data_path,"normal_set","captions","train_simplified_sentence.json"), 37 | "test": [join(data_path,"normal_set","captions","val_1_simplified_sentence.json"), 38 | join(data_path,"normal_set","captions","val_2_simplified_sentence.json")]} 39 | if len(vid_path) > 0: 40 | self.vid_path = vid_path 41 | if len(anno_path) > 0: 42 | self.anno_path = anno_path 43 | self.anno_path['test'] = [self.anno_path['test'].format(i) for i in [1,2]] 44 | # read annotations 45 | self.annos = self._read_annos(self.anno_path) 46 | # make dictionary of word-index correspondence 47 | self.wtoi, self.itow = self._make_word_dictionary(self.annos) 48 | 49 | def _read_annos(self,anno_path): 50 | # read annotations 51 | annos = {s: None for s in self.splits} 52 | for s in self.splits: 53 | if isinstance(anno_path[s],str): 54 | with open(anno_path[s],'r') as f: 55 | #annos[s] = json.load(f)[:100] 56 | annos[s] = json.load(f) 57 | elif isinstance(anno_path[s],list): 58 | annos[s] = [] 59 | for path in anno_path[s]: 60 | with open(path,'r') as f: 61 | annos[s] += json.load(f) 62 | return annos 63 | 64 | def _make_word_dictionary(self,annos): 65 | """ 66 | makes word tokens - number idx correspondences 67 | ARGS: 68 | - annos: annotations read 69 | RETURNS: 70 | - wtoi: word -> index dictionary 71 | - itow: index -> word dictionary 72 | PARAMS: 73 | - DATASET.SHOW_TOP_VOCAB: the number of top-n tokens to print 74 | """ 75 | # get training annos 76 | train_annos = self.annos["train"] 77 | # read tokens 78 | tokens_list = [] 79 | for ann in train_annos: 80 | tokens_list += [tk for tk in ann["tokens"]] 81 | # print results: count tokens and show top-n 82 | print("Top-{} tokens list:".format(self.cfg.DATASET.SHOW_TOP_VOCAB)) 83 | tokens_count = sorted(Counter(tokens_list).items(), key=lambda x:x[1]) 84 | for tk in tokens_count[-self.cfg.DATASET.SHOW_TOP_VOCAB:]: 85 | print("\t- {}: {}".format(tk[0],tk[1])) 86 | # make wtoi, itow 87 | wtoi = {} 88 | wtoi[""], wtoi[""] = 0, 1 89 | wtoi[""], wtoi[""] = 2, 3 90 | for i,(tk,cnt) in enumerate(tokens_count): 91 | idx = i+4 # idx start at 4 92 | wtoi[tk] = idx 93 | itow = {v:k for k,v in wtoi.items()} 94 | self.cfg.MODEL.QUERY.EMB_IDIM = len(wtoi) 95 | return wtoi, itow 96 | 97 | def make_dataloaders(self): 98 | """ 99 | makes actual dataset class 100 | RETURNS: 101 | - dataloaders: dataset classes for each splits. 
dictionary of {split: dataset} 102 | """ 103 | # read annotations 104 | annos = self._read_annos(self.anno_path) 105 | # make dictionary of word-index correspondence 106 | wtoi, itow = self._make_word_dictionary(annos) 107 | batch_size = self.cfg.TRAIN.BATCH_SIZE 108 | num_workers = self.cfg.TRAIN.NUM_WORKERS 109 | dataloaders = {} 110 | for s in self.splits: 111 | if "train" in s: 112 | dataset = AnetCapBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s]) 113 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=True, shuffle=True) 114 | else: 115 | dataset = AnetCapBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s]) 116 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=False, shuffle=False) 117 | return dataloaders 118 | 119 | 120 | #%% 121 | # Charades Dataset Class 122 | class AnetCapBasicDataset(Dataset): 123 | def __init__(self, cfg, vid_path, split, wtoi, itow, annos): 124 | self.cfg = cfg 125 | self.vid_path = vid_path 126 | self.split = split 127 | self.wtoi = wtoi 128 | self.itow = itow 129 | self.annos = annos 130 | self.feats = self._load_vid_feats() 131 | self.num_segments = self.cfg.DATASET.NUM_SEGMENT 132 | self.sentence_max_length = self.cfg.DATASET.MAX_LENGTH 133 | 134 | def _load_vid_feats(self): 135 | feats = {} 136 | hfile = h5py.File(self.vid_path,'r') 137 | vid_list = set([x['vid'] for x in self.annos]) 138 | for vid in tqdm(vid_list, desc="loading video features"): 139 | feats[vid] = hfile.get(vid).get("c3d_features")[()] 140 | return feats 141 | 142 | def _tokens_to_index(self,tokens): 143 | """ 144 | translates list of tokens into trainable index format. also does padding. 145 | """ 146 | wids = [] 147 | for tk in tokens: 148 | if tk in self.wtoi.keys(): 149 | wids.append(self.wtoi[tk]) 150 | else: 151 | wids.append(1) # 152 | for _ in range(self.sentence_max_length - len(wids)): 153 | wids.append(0) 154 | if len(wids) > self.sentence_max_length: 155 | wids = wids[:self.sentence_max_length] 156 | return wids 157 | 158 | def get_fixed_length_feat(self, feat, num_segment, start_pos, end_pos): 159 | """ 160 | makes fixed length feature. adopted from LGI code. 
161 | """ 162 | nfeats = feat[:,:].shape[0] 163 | if nfeats <= self.num_segments: 164 | stride = 1 165 | else: 166 | stride = nfeats * 1.0 / num_segment 167 | if self.split != "train": 168 | spos = 0 169 | else: 170 | random_end = -0.5 + stride 171 | if random_end == np.floor(random_end): 172 | random_end = random_end - 1.0 173 | spos = np.random.random_integers(0,random_end) 174 | s = np.round( np.arange(spos, nfeats-0.5, stride) ).astype(int) 175 | start_pos = float(nfeats-1.0) * start_pos 176 | end_pos = float(nfeats-1.0) * end_pos 177 | 178 | if not (nfeats < self.num_segments and len(s) == nfeats) \ 179 | and not (nfeats >= self.num_segments and len(s) == num_segment): 180 | s = s[:num_segment] # ignore last one 181 | assert (nfeats < self.num_segments and len(s) == nfeats) \ 182 | or (nfeats >= self.num_segments and len(s) == num_segment), \ 183 | "{} != {} or {} != {}".format(len(s), nfeats, len(s), num_segment) 184 | 185 | start_index, end_index = None, None 186 | for i in range(len(s)-1): 187 | if s[i] <= end_pos < s[i+1]: 188 | end_index = i 189 | if s[i] <= start_pos < s[i+1]: 190 | start_index = i 191 | 192 | if start_index is None: 193 | start_index = 0 194 | if end_index is None: 195 | end_index = num_segment-1 196 | 197 | cur_feat = feat[s, :] 198 | nfeats = min(nfeats, num_segment) 199 | out = np.zeros((num_segment, cur_feat.shape[1])) 200 | out [:nfeats,:] = cur_feat 201 | return out, nfeats, start_index, end_index 202 | 203 | def make_attention_mask(self,start_index,end_index): 204 | attn_mask = np.zeros([self.num_segments]) 205 | attn_mask[start_index:end_index+1] = 1 206 | attn_mask = torch.Tensor(attn_mask) 207 | return attn_mask 208 | 209 | def __getitem__(self,idx): 210 | anno = self.annos[idx] 211 | vid = anno["vid"] 212 | duration = anno['duration'] 213 | timestamp = [x*duration for x in anno['timestamp']] 214 | start_pos, end_pos = anno['timestamp'] 215 | query_label = self._tokens_to_index(anno['tokens']) 216 | query_length = len(anno['tokens']) 217 | vid_feat = self.feats[vid] 218 | 219 | fixed_vid_feat, nfeats, start_index, end_index = self.get_fixed_length_feat(vid_feat, self.num_segments, start_pos, end_pos) 220 | # get video masks 221 | vid_mask = np.zeros((self.num_segments, 1)) 222 | vid_mask[:nfeats] = 1 223 | # make attn mask 224 | instance = { 225 | "vids": vid, 226 | "qids": idx, 227 | "timestamps": timestamp, # GT location [s, e] (second) 228 | "duration": duration, # video span (second) 229 | "query_lengths": query_length, 230 | "query_labels": torch.LongTensor(query_label).unsqueeze(0), # [1,L_q_max] 231 | "query_masks": (torch.FloatTensor(query_label)>0).unsqueeze(0), # [1,L_q_max] 232 | "grounding_start_pos": torch.FloatTensor([start_pos]), # [1]; normalized 233 | "grounding_end_pos": torch.FloatTensor([end_pos]), # [1]; normalized 234 | "nfeats": torch.FloatTensor([nfeats]), 235 | "video_feats": torch.FloatTensor(fixed_vid_feat), # [L_v,D_v] 236 | "video_masks": torch.ByteTensor(vid_mask), # [L_v,1] 237 | "attention_masks": self.make_attention_mask(start_index,end_index), 238 | } 239 | return instance 240 | 241 | def collate_fn(self, data): 242 | seq_items = ["video_feats", "video_masks","attention_masks"] 243 | tensor_items = [ 244 | "query_labels", "query_masks", "nfeats", 245 | "grounding_start_pos", "grounding_end_pos" 246 | ] 247 | batch = {k: [d[k] for d in data] for k in data[0].keys()} 248 | if len(data) == 1: 249 | for k,v in batch.items(): 250 | if k in tensor_items: 251 | batch[k] = torch.cat(batch[k], 0) 252 | elif k in seq_items: 253 
| batch[k] = torch.nn.utils.rnn.pad_sequence( 254 | batch[k], batch_first=True) 255 | else: 256 | batch[k] = batch[k][0] 257 | else: 258 | for k in tensor_items: 259 | batch[k] = torch.cat(batch[k], 0) 260 | for k in seq_items: 261 | batch[k] = torch.nn.utils.rnn.pad_sequence(batch[k], batch_first=True) 262 | return batch 263 | 264 | def batch_to_device(self,batch,device): 265 | for k,v in batch.items(): 266 | if isinstance(v,torch.Tensor): 267 | batch[k] = v.to(device) 268 | 269 | def __len__(self): 270 | return len(self.annos) 271 | 272 | 273 | #%% 274 | # test 275 | if __name__ == "__main__": 276 | DATA_PATH = "/home/data/anet" 277 | dataloaders = AnetCapBasicDatasetBuilder(TestCfg,DATA_PATH).make_dataloaders() 278 | for item in dataloaders['train']: 279 | _ = item['vids'] 280 | 281 | -------------------------------------------------------------------------------- /dataset/charades_basic.py: -------------------------------------------------------------------------------- 1 | #%% 2 | """ 3 | charades_basic.py 4 | **** 5 | the basic charades dataset class. 6 | """ 7 | #%% 8 | # import things 9 | import os 10 | from os.path import join 11 | import json 12 | import h5py 13 | import numpy as np 14 | from tqdm import tqdm 15 | import torch 16 | from torch.utils.data import Dataset 17 | from torch.utils.data import DataLoader 18 | from random import random, randint, choice 19 | from collections import Counter 20 | 21 | import sys 22 | sys.path.append(os.path.dirname(os.path.dirname(__file__))) 23 | from config_parsers.simple_model_config import _C as TestCfg 24 | 25 | 26 | #%% 27 | # dataset builder class 28 | class CharadesBasicDatasetBuilder(): 29 | def __init__(self,cfg,data_path=None,anno_path=None,vid_path=None): 30 | # make variables that will be used in the future 31 | self.cfg = cfg 32 | self.splits = ["train","test"] # list of splits 33 | # make paths 34 | if len(data_path) > 0: 35 | self.vid_path = join(data_path,"i3d_finetuned") 36 | self.anno_path = {s: join(data_path,"annotations","charades_sta_{}_pos_original.json".format(s)) for s in self.splits} 37 | if len(vid_path) > 0: 38 | self.vid_path = vid_path 39 | if len(anno_path) > 0: 40 | self.anno_path = anno_path 41 | # read annotations 42 | self.annos = self._read_annos(self.anno_path) 43 | # make dictionary of word-index correspondence 44 | self.wtoi, self.itow = self._make_word_dictionary(self.annos) 45 | 46 | def _read_annos(self,anno_path): 47 | # read annotations 48 | 49 | annos = {s: None for s in self.splits} 50 | for s in self.splits: 51 | with open(anno_path[s],'r') as f: 52 | #annos[s] = json.load(f)[:100] 53 | annos[s] = json.load(f) 54 | return annos 55 | 56 | def _make_word_dictionary(self,annos): 57 | """ 58 | makes word tokens - number idx correspondences 59 | ARGS: 60 | - annos: annotations read 61 | RETURNS: 62 | - wtoi: word -> index dictionary 63 | - itow: index -> word dictionary 64 | PARAMS: 65 | - DATASET.SHOW_TOP_VOCAB: the number of top-n tokens to print 66 | """ 67 | # get training annos 68 | train_annos = self.annos["train"] 69 | # read tokens 70 | tokens_list = [] 71 | for ann in train_annos: 72 | tokens_list += [tk for tk in ann["tokens"]] 73 | # print results: count tokens and show top-n 74 | print("Top-{} tokens list:".format(self.cfg.DATASET.SHOW_TOP_VOCAB)) 75 | tokens_count = sorted(Counter(tokens_list).items(), key=lambda x:x[1]) 76 | for tk in tokens_count[-self.cfg.DATASET.SHOW_TOP_VOCAB:]: 77 | print("\t- {}: {}".format(tk[0],tk[1])) 78 | # make wtoi, itow 79 | wtoi = {} 80 | wtoi[""], 
wtoi[""] = 0, 1 81 | wtoi[""], wtoi[""] = 2, 3 82 | for i,(tk,cnt) in enumerate(tokens_count): 83 | idx = i+4 # idx start at 4 84 | wtoi[tk] = idx 85 | itow = {v:k for k,v in wtoi.items()} 86 | self.cfg.MODEL.QUERY.EMB_IDIM = len(wtoi) 87 | return wtoi, itow 88 | 89 | def make_dataloaders(self): 90 | """ 91 | makes actual dataset class 92 | RETURNS: 93 | - dataloaders: dataset classes for each splits. dictionary of {split: dataset} 94 | """ 95 | # read annotations 96 | annos = self._read_annos(self.anno_path) 97 | # make dictionary of word-index correspondence 98 | wtoi, itow = self._make_word_dictionary(annos) 99 | batch_size = self.cfg.TRAIN.BATCH_SIZE 100 | num_workers = self.cfg.TRAIN.NUM_WORKERS 101 | dataloaders = {} 102 | for s in self.splits: 103 | if "train" in s: 104 | dataset = CharadesBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s]) 105 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=True, shuffle=True) 106 | else: 107 | dataset = CharadesBasicDataset(self.cfg, self.vid_path, s, wtoi, itow, annos[s]) 108 | dataloaders[s] = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=dataset.collate_fn, drop_last=False, shuffle=False) 109 | return dataloaders 110 | 111 | 112 | #%% 113 | # Charades Dataset Class 114 | class CharadesBasicDataset(Dataset): 115 | def __init__(self, cfg, vid_path, split, wtoi, itow, annos): 116 | self.cfg = cfg 117 | self.vid_path = vid_path 118 | self.split = split 119 | self.wtoi = wtoi 120 | self.itow = itow 121 | self.annos = annos 122 | self.feats = self._load_vid_feats() 123 | self.num_segments = self.cfg.DATASET.NUM_SEGMENT 124 | self.sentence_max_length = self.cfg.DATASET.MAX_LENGTH 125 | 126 | def _load_vid_feats(self): 127 | feats = {} 128 | vid_list = [x['vid'] for x in self.annos] 129 | for vid in tqdm(vid_list, desc="loading video features"): 130 | feats[vid] = np.load(join(self.vid_path,"{}.npy".format(vid))).squeeze() 131 | return feats 132 | 133 | def _tokens_to_index(self,tokens): 134 | """ 135 | translates list of tokens into trainable index format. also does padding. 136 | """ 137 | wids = [] 138 | for tk in tokens: 139 | if tk in self.wtoi.keys(): 140 | wids.append(self.wtoi[tk]) 141 | else: 142 | wids.append(1) # 143 | for _ in range(self.sentence_max_length - len(wids)): 144 | wids.append(0) 145 | if len(wids) > self.sentence_max_length: 146 | wids = wids[:self.sentence_max_length] 147 | return wids 148 | 149 | def get_fixed_length_feat(self, feat, num_segment, start_pos, end_pos): 150 | """ 151 | makes fixed length feature. adopted from LGI code. 
152 | """ 153 | nfeats = feat[:,:].shape[0] 154 | if nfeats <= self.num_segments: 155 | stride = 1 156 | else: 157 | stride = nfeats * 1.0 / num_segment 158 | if self.split != "train": 159 | spos = 0 160 | else: 161 | random_end = -0.5 + stride 162 | if random_end == np.floor(random_end): 163 | random_end = random_end - 1.0 164 | spos = np.random.random_integers(0,random_end) 165 | s = np.round( np.arange(spos, nfeats-0.5, stride) ).astype(int) 166 | start_pos = float(nfeats-1.0) * start_pos 167 | end_pos = float(nfeats-1.0) * end_pos 168 | 169 | if not (nfeats < self.num_segments and len(s) == nfeats) \ 170 | and not (nfeats >= self.num_segments and len(s) == num_segment): 171 | s = s[:num_segment] # ignore last one 172 | assert (nfeats < self.num_segments and len(s) == nfeats) \ 173 | or (nfeats >= self.num_segments and len(s) == num_segment), \ 174 | "{} != {} or {} != {}".format(len(s), nfeats, len(s), num_segment) 175 | 176 | start_index, end_index = None, None 177 | for i in range(len(s)-1): 178 | if s[i] <= end_pos < s[i+1]: 179 | end_index = i 180 | if s[i] <= start_pos < s[i+1]: 181 | start_index = i 182 | 183 | if start_index is None: 184 | start_index = 0 185 | if end_index is None: 186 | end_index = num_segment-1 187 | 188 | cur_feat = feat[s, :] 189 | nfeats = min(nfeats, num_segment) 190 | out = np.zeros((num_segment, cur_feat.shape[1])) 191 | out [:nfeats,:] = cur_feat 192 | return out, nfeats, start_index, end_index 193 | 194 | def make_attention_mask(self,start_index,end_index): 195 | attn_mask = np.zeros([self.num_segments]) 196 | attn_mask[start_index:end_index+1] = 1 197 | attn_mask = torch.Tensor(attn_mask) 198 | return attn_mask 199 | 200 | def __getitem__(self,idx): 201 | anno = self.annos[idx] 202 | vid = anno["vid"] 203 | duration = anno['duration'] 204 | timestamp = [x*duration for x in anno['timestamp']] 205 | start_pos, end_pos = anno['timestamp'] 206 | query_label = self._tokens_to_index(anno['tokens']) 207 | query_length = len(anno['tokens']) 208 | vid_feat = self.feats[vid] 209 | 210 | fixed_vid_feat, nfeats, start_index, end_index = self.get_fixed_length_feat(vid_feat, self.num_segments, start_pos, end_pos) 211 | # get video masks 212 | vid_mask = np.zeros((self.num_segments, 1)) 213 | vid_mask[:nfeats] = 1 214 | # make attn mask 215 | instance = { 216 | "vids": vid, 217 | "qids": idx, 218 | "timestamps": timestamp, # GT location [s, e] (second) 219 | "duration": duration, # video span (second) 220 | "query_lengths": query_length, 221 | "query_labels": torch.LongTensor(query_label).unsqueeze(0), # [1,L_q_max] 222 | "query_masks": (torch.FloatTensor(query_label)>0).unsqueeze(0), # [1,L_q_max] 223 | "grounding_start_pos": torch.FloatTensor([start_pos]), # [1]; normalized 224 | "grounding_end_pos": torch.FloatTensor([end_pos]), # [1]; normalized 225 | "nfeats": torch.FloatTensor([nfeats]), 226 | "video_feats": torch.FloatTensor(fixed_vid_feat), # [L_v,D_v] 227 | "video_masks": torch.ByteTensor(vid_mask), # [L_v,1] 228 | "attention_masks": self.make_attention_mask(start_index,end_index), 229 | } 230 | return instance 231 | 232 | def collate_fn(self, data): 233 | seq_items = ["video_feats", "video_masks","attention_masks"] 234 | tensor_items = [ 235 | "query_labels", "query_masks", "nfeats", 236 | "grounding_start_pos", "grounding_end_pos" 237 | ] 238 | batch = {k: [d[k] for d in data] for k in data[0].keys()} 239 | if len(data) == 1: 240 | for k,v in batch.items(): 241 | if k in tensor_items: 242 | batch[k] = torch.cat(batch[k], 0) 243 | elif k in seq_items: 244 
| batch[k] = torch.nn.utils.rnn.pad_sequence( 245 | batch[k], batch_first=True) 246 | else: 247 | batch[k] = batch[k][0] 248 | else: 249 | for k in tensor_items: 250 | batch[k] = torch.cat(batch[k], 0) 251 | for k in seq_items: 252 | batch[k] = torch.nn.utils.rnn.pad_sequence(batch[k], batch_first=True) 253 | return batch 254 | 255 | def batch_to_device(self,batch,device): 256 | for k,v in batch.items(): 257 | if isinstance(v,torch.Tensor): 258 | batch[k] = v.to(device) 259 | 260 | def __len__(self): 261 | return len(self.annos) 262 | 263 | 264 | #%% 265 | # test 266 | if __name__ == "__main__": 267 | DATA_PATH = "/home/skaws2003/projects/didemo/localglobal/data/charades" 268 | dataloaders = CharadesBasicDatasetBuilder(TestCfg,DATA_PATH).make_dataloaders() 269 | for item in dataloaders['train']: 270 | _ = item['vids'] 271 | 272 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #%% 2 | # import things 3 | # model code 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from utils.eval_utils import NLVLEvaluator 7 | from utils.loss import NLVLLoss 8 | 9 | # logging code 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | # etc 13 | import random 14 | import numpy as np 15 | import os 16 | from os.path import join 17 | from tqdm import tqdm 18 | from copy import deepcopy 19 | from yacs.config import CfgNode 20 | from yaml import dump as dump_yaml 21 | import argparse 22 | 23 | #%% 24 | # argparse 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--model",'-m',type=str,default="CrossModalityTwostageAttention") 27 | parser.add_argument("--config",'-c',type=str,default="configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml") 28 | parser.add_argument("--pre_trained", type=str, default="pretrained_weight.pth") 29 | parser.add_argument("--seed", '-s',type=int,default=0) 30 | parser.add_argument("--reg_w", type=float, default=1.0) 31 | args = parser.parse_args() 32 | 33 | dict_args = { 34 | "model": args.model, 35 | "confg": args.config 36 | } 37 | 38 | random.seed(int(args.seed)) 39 | os.environ['PYTHONHASHSEED'] = str(args.seed) 40 | np.random.seed(int(args.seed)) 41 | torch.manual_seed(int(args.seed)) 42 | torch.cuda.manual_seed(int(args.seed)) 43 | torch.cuda.manual_seed_all(int(args.seed)) 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | torch.backends.cudnn.enabled = False 47 | 48 | #%% 49 | # load things according to arguments 50 | if args.model == "SimpleModel": 51 | from models.simple_model import SimpleModel as Model 52 | from config_parsers.simple_model_config import _C as TestCfg 53 | elif args.model == "CrossModalityTwostageAttention": 54 | from models.simple_model_cross_modal_two_stage import SimpleModel as Model 55 | from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as TestCfg 56 | else: 57 | raise ValueError("No such model: {}".format(args.model)) 58 | 59 | 60 | #%% 61 | # constants 62 | CONFIG_PATH = args.config 63 | print("config file path: ", CONFIG_PATH) 64 | TestCfg.merge_from_file(CONFIG_PATH) 65 | cfg = TestCfg 66 | device = torch.device("cuda") 67 | DATA_PATH = cfg.DATASET.DATA_PATH 68 | ANNO_PATH ={"train": cfg.DATASET.TRAIN_ANNO_PATH, 69 | "test": cfg.DATASET.TEST_ANNO_PATH} 70 | VID_PATH = cfg.DATASET.VID_PATH 71 | 72 | #%% 73 | # function for dumping yaml 74 | def cfg_to_dict(cfg): 75 | dict_cfg = 
dict(cfg) 76 | for k,v in dict_cfg.items(): 77 | if isinstance(v,CfgNode): 78 | dict_cfg[k] = cfg_to_dict(v) 79 | return dict_cfg 80 | 81 | #%% 82 | # Load dataloader 83 | dataset_name = cfg.DATASET.NAME 84 | if dataset_name == "Charades": 85 | from dataset.charades_basic import CharadesBasicDatasetBuilder 86 | dataloaders = CharadesBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders() 87 | elif dataset_name == "AnetCap": 88 | from dataset.anetcap_basic import AnetCapBasicDatasetBuilder 89 | dataloaders = AnetCapBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders() 90 | else: 91 | raise ValueError("No such dataset: {}".format(dataset_name)) 92 | 93 | #%% 94 | # load training stuff 95 | ## model and loss 96 | model = Model(cfg).to(device) 97 | model.load_state_dict(torch.load(args.pre_trained)) 98 | 99 | ## evaluator 100 | evaluator = NLVLEvaluator(cfg) 101 | batch_to_device = dataloaders['test'].dataset.batch_to_device 102 | 103 | # information print out 104 | print("====="*10) 105 | print(args) 106 | print("====="*10) 107 | print(cfg) 108 | print("====="*10) 109 | 110 | #%% 111 | # test loop 112 | model.eval() 113 | pbar = tqdm(range(1)) 114 | for epoch in pbar: 115 | eval_results_list = [] 116 | for batch_idx,item in enumerate(dataloaders['test']): 117 | # update progress bar 118 | pbar.set_description("test {}/{}".format(batch_idx,len(dataloaders['test']))) 119 | # make evaluation 120 | with torch.no_grad(): 121 | batch_to_device(item,device) 122 | model_outputs = model(item) 123 | eval_results = evaluator(model_outputs,item) 124 | # add to mean eval results 125 | eval_results_list.append(eval_results) 126 | # write test information 127 | ## make mean metric dict 128 | mean_eval_results_dict = {} 129 | for k in list(eval_results_list[0].keys()): 130 | mean_eval_results_dict[k] = torch.mean(torch.Tensor([x[k] for x in eval_results_list])) 131 | 132 | # print test information 133 | tqdm.write("****** epoch:{} ******".format(epoch)) 134 | for k,v in mean_eval_results_dict.items(): 135 | tqdm.write("\t{}: {}".format(k,v)) 136 | -------------------------------------------------------------------------------- /media/task-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/PSVL/50b1041d472d4f140f7d98248a3bbfbb16d680da/media/task-1.png -------------------------------------------------------------------------------- /models/simple_model.py: -------------------------------------------------------------------------------- 1 | #%% 2 | """ 3 | simple_model.py 4 | **** 5 | simple, basic model for NLVL. 
6 | - Query-Video matching with (Multi-Head Attention + ConvBNReLU) with residual connection 7 | - Video Encoding with simple GRU 8 | """ 9 | 10 | #%% 11 | # import things 12 | import torch 13 | import torch.nn as nn 14 | 15 | #%% 16 | # model 17 | class SimpleSentenceEmbeddingModule(nn.Module): 18 | """ 19 | A Simple Query Embedding class 20 | """ 21 | def __init__(self, cfg): 22 | super().__init__() 23 | # config params 24 | self.cfg = cfg 25 | self.query_length = self.cfg.DATASET.MAX_LENGTH 26 | # embedding Layer 27 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM 28 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM 29 | self.embedding = nn.Embedding(emb_idim, emb_odim) 30 | # RNN Layer 31 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM 32 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True) 33 | # feature adjust 34 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 35 | self.feature_aggregation = nn.Sequential( 36 | nn.Linear(in_features=gru_hidden*2,out_features=emb_dim), 37 | nn.ReLU(), 38 | nn.Dropout(0.5)) 39 | 40 | def forward(self, query_labels, query_masks): 41 | """ 42 | encode query sequence using RNN and return logits over proposals. 43 | code adopted from LGI 44 | Args: 45 | query_labels: query_labels vectors of query; [B, vocab_size] 46 | query_masks: mask for query; [B,L] 47 | out_type: output type [word-level | sentenve-level | both] 48 | Returns: 49 | w_feats: word-level features; [B,L,2*h] 50 | s_feats: sentence-level feature; [B,2*h] 51 | """ 52 | # embedding query_labels data 53 | wemb = self.embedding(query_labels) # [B,L,emb_odim] 54 | # encoding query_labels data. 55 | max_len = query_labels.size(1) # == L 56 | # make word-wise feature 57 | length = query_masks.sum(1) # [B,] 58 | pack_wemb = nn.utils.rnn.pack_padded_sequence(wemb, length, batch_first=True, enforce_sorted=False) 59 | w_feats, _ = self.gru(pack_wemb) 60 | w_feats, max_ = nn.utils.rnn.pad_packed_sequence(w_feats, batch_first=True, total_length=max_len) 61 | w_feats = w_feats.contiguous() # [B,L,2*h] 62 | # get sentence feature 63 | B, L, H = w_feats.size() 64 | idx = (length-1).long() # 0-indexed 65 | idx = idx.view(B, 1, 1).expand(B, 1, H//2) 66 | fLSTM = w_feats[:,:,:H//2].gather(1, idx).view(B, H//2) 67 | bLSTM = w_feats[:,0,H//2:].view(B,H//2) 68 | s_feats = torch.cat([fLSTM, bLSTM], dim=1) 69 | # aggregae features 70 | w_feats = self.feature_aggregation(w_feats) 71 | return w_feats, s_feats 72 | 73 | 74 | class SimpleVideoEmbeddingModule(nn.Module): 75 | """ 76 | A simple Video Embedding Class 77 | """ 78 | def __init__(self, cfg): 79 | super().__init__() # Must call super __init__() 80 | # get configuration 81 | self.cfg = cfg 82 | # video gru 83 | vid_idim = self.cfg.MODEL.VIDEO.IDIM 84 | vid_gru_hdim = self.cfg.MODEL.VIDEO.GRU_HDIM 85 | self.gru = nn.GRU(input_size=vid_idim,hidden_size=vid_gru_hdim,batch_first=True,dropout=0.5,bidirectional=True) 86 | # video feature aggregation module 87 | catted_dim = vid_idim + vid_gru_hdim*2 88 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 89 | self.feature_aggregation = nn.Sequential( 90 | nn.Linear(in_features=catted_dim,out_features=emb_dim), 91 | nn.ReLU(), 92 | nn.Dropout(0.5), 93 | ) 94 | 95 | def forward(self, vid_feats, vid_masks): 96 | """ 97 | encode video features. Utilizes GRU. 
98 | Args: 99 | vid_feats: video features 100 | vid_masks: mask for video 101 | Return: 102 | vid_features: hidden state features of the video 103 | """ 104 | length = vid_masks.sum(1).squeeze(1) 105 | packed_vid = nn.utils.rnn.pack_padded_sequence(vid_feats, length, batch_first=True, enforce_sorted=False) 106 | vid_hiddens, _ = self.gru(packed_vid) 107 | vid_hiddens, max_ = nn.utils.rnn.pad_packed_sequence(vid_hiddens, batch_first=True, total_length=vid_feats.shape[1]) 108 | vid_catted = torch.cat([vid_feats,vid_hiddens],dim=2) 109 | vid_output = self.feature_aggregation(vid_catted) 110 | return vid_output 111 | 112 | 113 | class FusionConvBNReLU(nn.Module): 114 | def __init__(self,cfg): 115 | super().__init__() 116 | # get configuration 117 | self.cfg = cfg 118 | # modules 119 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 120 | kernel_size = self.cfg.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE 121 | padding = self.cfg.MODEL.FUSION.CONVBNRELU.PADDING 122 | self.module = nn.Sequential( 123 | nn.Conv1d(in_channels=emb_dim,out_channels=emb_dim,kernel_size=kernel_size,padding=padding), 124 | nn.BatchNorm1d(num_features=emb_dim), 125 | nn.ReLU()) 126 | 127 | def forward(self,feature): 128 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first) 129 | convolved_feature = self.module(transposed_feature) 130 | return torch.transpose(convolved_feature,1,2) 131 | 132 | 133 | class AttentionBlock(nn.Module): 134 | def __init__(self,cfg): 135 | super().__init__() 136 | # get configuration 137 | self.cfg = cfg 138 | # modules 139 | emb_dim = emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 140 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD 141 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head) 142 | self.convbnrelu = FusionConvBNReLU(cfg) 143 | 144 | def forward(self,vid_feats,query_feats,query_masks): 145 | # attnetion 146 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 147 | attended_feature, weights = self.attention( 148 | query=torch.transpose(vid_feats,0,1), 149 | key=torch.transpose(query_feats,0,1), 150 | value=torch.transpose(query_feats,0,1), 151 | key_padding_mask=key_padding_mask,) 152 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format 153 | # convolution 154 | convolved_feature = self.convbnrelu(attended_feature) + vid_feats 155 | return convolved_feature 156 | 157 | 158 | class SimpleFusionModule(nn.Module): 159 | def __init__(self, cfg): 160 | super().__init__() 161 | # get configuration 162 | self.cfg = cfg 163 | # attention module 164 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS 165 | self.layers = [] 166 | for _ in range(num_layers): 167 | self.layers.append(AttentionBlock(cfg)) 168 | self.layers = nn.ModuleList(self.layers) 169 | 170 | def forward(self, query_feats, query_masks, vid_feats, vid_masks): 171 | attended_vid_feats = vid_feats 172 | for attn_layer in self.layers: 173 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks) 174 | return attended_vid_feats 175 | 176 | 177 | class NonLocalBlock(nn.Module): 178 | """ 179 | Nonlocal block used for obtaining global feature. 
180 | code borrowed from LGI 181 | """ 182 | def __init__(self, cfg): 183 | super(NonLocalBlock, self).__init__() 184 | self.cfg = cfg 185 | # dims 186 | self.idim = self.cfg.MODEL.FUSION.EMB_DIM 187 | self.odim = self.cfg.MODEL.FUSION.EMB_DIM 188 | self.nheads = self.cfg.MODEL.NONLOCAL.NUM_HEAD 189 | 190 | # options 191 | self.use_bias = self.cfg.MODEL.NONLOCAL.USE_BIAS 192 | 193 | # layers 194 | self.c_lin = nn.Linear(self.idim, self.odim*2, bias=self.use_bias) 195 | self.v_lin = nn.Linear(self.idim, self.odim, bias=self.use_bias) 196 | 197 | self.relu = nn.ReLU() 198 | self.sigmoid = nn.Sigmoid() 199 | self.drop = nn.Dropout(self.cfg.MODEL.NONLOCAL.DROPOUT) 200 | 201 | def forward(self, m_feats, mask): 202 | """ 203 | Inputs: 204 | m_feats: segment-level multimodal feature [B,nseg,*] 205 | mask: mask [B,nseg] 206 | Outputs: 207 | updated_m: updated multimodal feature [B,nseg,*] 208 | """ 209 | 210 | mask = mask.float() 211 | B, nseg = mask.size() 212 | 213 | # key, query, value 214 | m_k = self.v_lin(self.drop(m_feats)) # [B,num_seg,*] 215 | m_trans = self.c_lin(self.drop(m_feats)) # [B,nseg,2*] 216 | m_q, m_v = torch.split(m_trans, m_trans.size(2) // 2, dim=2) 217 | 218 | new_mq = m_q 219 | new_mk = m_k 220 | 221 | # applying multi-head attention 222 | w_list = [] 223 | mk_set = torch.split(new_mk, new_mk.size(2) // self.nheads, dim=2) 224 | mq_set = torch.split(new_mq, new_mq.size(2) // self.nheads, dim=2) 225 | mv_set = torch.split(m_v, m_v.size(2) // self.nheads, dim=2) 226 | for i in range(self.nheads): 227 | mk_slice, mq_slice, mv_slice = mk_set[i], mq_set[i], mv_set[i] # [B, nseg, *] 228 | 229 | # compute relation matrix; [B,nseg,nseg] 230 | m2m = mk_slice @ mq_slice.transpose(1,2) / ((self.odim // self.nheads) ** 0.5) 231 | m2m = m2m.masked_fill(mask.unsqueeze(1).eq(0), -1e9) # [B,nseg,nseg] 232 | m2m_w = torch.nn.functional.softmax(m2m, dim=2) # [B,nseg,nseg] 233 | w_list.append(m2m_w) 234 | 235 | # compute relation vector for each segment 236 | r = m2m_w @ mv_slice if (i==0) else torch.cat((r, m2m_w @ mv_slice), dim=2) 237 | 238 | updated_m =m_feats + r 239 | return updated_m 240 | 241 | 242 | class AttentivePooling(nn.Module): 243 | def __init__(self, cfg): 244 | self.cfg = cfg 245 | super(AttentivePooling, self).__init__() 246 | self.att_n = 1 247 | self.feat_dim = self.cfg.MODEL.FUSION.EMB_DIM 248 | self.att_hid_dim = self.cfg.MODEL.FUSION.EMB_DIM // 2 249 | self.use_embedding = True 250 | 251 | self.feat2att = nn.Linear(self.feat_dim, self.att_hid_dim, bias=False) 252 | self.to_alpha = nn.Linear(self.att_hid_dim, self.att_n, bias=False) 253 | if self.use_embedding: 254 | edim = self.cfg.MODEL.FUSION.EMB_DIM 255 | self.fc = nn.Linear(self.feat_dim, edim) 256 | 257 | def forward(self, feats, f_masks=None): 258 | """ 259 | Compute attention weights and attended feature (weighted sum) 260 | Args: 261 | feats: features where attention weights are computed; [B, A, D] 262 | f_masks: mask for effective features; [B, A] 263 | """ 264 | # check inputs 265 | assert len(feats.size()) == 3 or len(feats.size()) == 4 266 | assert f_masks is None or len(f_masks.size()) == 2 267 | 268 | # dealing with dimension 4 269 | if len(feats.size()) == 4: 270 | B, W, H, D = feats.size() 271 | feats = feats.view(B, W*H, D) 272 | 273 | # embedding feature vectors 274 | attn_f = self.feat2att(feats) # [B,A,hdim] 275 | 276 | # compute attention weights 277 | dot = torch.tanh(attn_f) # [B,A,hdim] 278 | alpha = self.to_alpha(dot) # [B,A,att_n] 279 | if f_masks is not None: 280 | alpha = 
alpha.masked_fill(f_masks.float().unsqueeze(2).eq(0), -1e9) 281 | attw = torch.nn.functional.softmax(alpha.transpose(1,2), dim=2) # [B,att_n,A] 282 | 283 | att_feats = attw @ feats # [B,att_n,D] 284 | att_feats = att_feats.squeeze(1) 285 | attw = attw.squeeze(1) 286 | if self.use_embedding: att_feats = self.fc(att_feats) 287 | 288 | return att_feats, attw 289 | 290 | 291 | class AttentionLocRegressor(nn.Module): 292 | def __init__(self, cfg): 293 | super(AttentionLocRegressor, self).__init__() 294 | self.cfg = cfg 295 | self.tatt = AttentivePooling(self.cfg) 296 | # Regression layer 297 | idim = self.cfg.MODEL.FUSION.EMB_DIM 298 | gdim = self.cfg.MODEL.FUSION.EMB_DIM 299 | nn_list = [ nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()] 300 | self.MLP_reg = nn.Sequential(*nn_list) 301 | 302 | def forward(self, semantic_aware_seg_feats, masks): 303 | # perform Eq. (13) and (14) 304 | summarized_vfeat, att_w = self.tatt(semantic_aware_seg_feats, masks) 305 | # perform Eq. (15) 306 | loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e] 307 | return loc, att_w 308 | 309 | 310 | class SimpleModel(nn.Module): 311 | def __init__(self,cfg): 312 | super().__init__() 313 | self.cfg = cfg 314 | self.query_encoder = SimpleSentenceEmbeddingModule(cfg) 315 | self.video_encoder = SimpleVideoEmbeddingModule(cfg) 316 | self.fusor = SimpleFusionModule(cfg) 317 | self.n_non_local = self.cfg.MODEL.NONLOCAL.NUM_LAYERS 318 | self.non_locals = nn.ModuleList([NonLocalBlock(cfg) for _ in range(self.n_non_local)]) 319 | self.loc_regressor = AttentionLocRegressor(cfg) 320 | 321 | def forward(self,inputs): 322 | # encode query 323 | query_labels = inputs['query_labels'] 324 | query_masks = inputs['query_masks'] 325 | encoded_query, encoded_sentence = self.query_encoder(query_labels, query_masks) 326 | # encode video 327 | vid_feats = inputs['video_feats'] 328 | vid_masks = inputs['video_masks'] 329 | encoded_video = self.video_encoder(vid_feats,vid_masks) 330 | attended_vid = self.fusor(encoded_query, query_masks, encoded_video, vid_masks) 331 | global_vid = attended_vid 332 | for non_local_layer in self.non_locals: 333 | global_vid = non_local_layer(global_vid,vid_masks.squeeze(2)) 334 | loc,attn_weight = self.loc_regressor(global_vid,vid_masks.squeeze(2)) 335 | return {"timestamps": loc, 336 | "attention_weights": attn_weight} 337 | 338 | 339 | 340 | 341 | -------------------------------------------------------------------------------- /models/simple_model_cross_modal_two_stage.py: -------------------------------------------------------------------------------- 1 | #%% 2 | """ 3 | simple_model.py 4 | **** 5 | simple, basic model for NLVL. 
6 | - Query-Video matching with (Multi-Head Attention + ConvBNReLU) with residual connection 7 | - Video Encoding with simple GRU 8 | """ 9 | 10 | #%% 11 | # import things 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | #%% 17 | # model 18 | class SimpleSentenceEmbeddingModule(nn.Module): 19 | """ 20 | A Simple Query Embedding class 21 | """ 22 | def __init__(self, cfg): 23 | super().__init__() 24 | # config params 25 | self.cfg = cfg 26 | self.query_length = self.cfg.DATASET.MAX_LENGTH 27 | # embedding Layer 28 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM 29 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM 30 | self.embedding = nn.Embedding(emb_idim, emb_odim) 31 | # RNN Layer 32 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM 33 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True) 34 | # feature adjust 35 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 36 | self.feature_aggregation = nn.Sequential( 37 | nn.Linear(in_features=gru_hidden*2,out_features=emb_dim), 38 | nn.ReLU(), 39 | nn.Dropout(0.5)) 40 | 41 | def forward(self, query_labels, query_masks): 42 | """ 43 | encode query sequence using RNN and return logits over proposals. 44 | code adopted from LGI 45 | Args: 46 | query_labels: query_labels vectors of query; [B, vocab_size] 47 | query_masks: mask for query; [B,L] 48 | out_type: output type [word-level | sentenve-level | both] 49 | Returns: 50 | w_feats: word-level features; [B,L,2*h] 51 | s_feats: sentence-level feature; [B,2*h] 52 | """ 53 | # embedding query_labels data 54 | wemb = self.embedding(query_labels) # [B,L,emb_odim] 55 | # encoding query_labels data. 56 | max_len = query_labels.size(1) # == L 57 | # make word-wise feature 58 | length = query_masks.sum(1) # [B,] 59 | pack_wemb = nn.utils.rnn.pack_padded_sequence(wemb, length, batch_first=True, enforce_sorted=False) 60 | w_feats, _ = self.gru(pack_wemb) 61 | w_feats, max_ = nn.utils.rnn.pad_packed_sequence(w_feats, batch_first=True, total_length=max_len) 62 | w_feats = w_feats.contiguous() # [B,L,2*h] 63 | 64 | # get sentence feature 65 | B, L, H = w_feats.size() 66 | idx = (length-1).long() # 0-indexed 67 | idx = idx.view(B, 1, 1).expand(B, 1, H//2) 68 | fLSTM = w_feats[:,:,:H//2].gather(1, idx).view(B, H//2) 69 | bLSTM = w_feats[:,0,H//2:].view(B,H//2) 70 | s_feats = torch.cat([fLSTM, bLSTM], dim=1) 71 | 72 | # aggregae features 73 | w_feats = self.feature_aggregation(w_feats) 74 | return w_feats, s_feats 75 | 76 | 77 | class TransformerSentenceEmbeddingModule(nn.Module): 78 | """ 79 | A Simple Query Embedding class 80 | """ 81 | def __init__(self, cfg): 82 | super().__init__() 83 | # config params 84 | self.cfg = cfg 85 | self.query_length = self.cfg.DATASET.MAX_LENGTH 86 | # embedding Layer 87 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM 88 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM 89 | self.embedding = nn.Embedding(emb_idim, emb_odim) 90 | 91 | # RNN Layer 92 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM 93 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True) 94 | 95 | # Attention layer 96 | t_emb_dim = self.cfg.MODEL.QUERY.TRANSFORMER_DIM # 300 97 | #t_emb_dim = gru_hidden * 2 # 256 * 2 98 | self.attention = nn.MultiheadAttention(embed_dim=t_emb_dim, num_heads=4) 99 | 100 | # feature adjust 101 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 102 | 103 | self.feature_aggregation = nn.Sequential( 104 | nn.Linear(in_features=t_emb_dim, out_features=emb_dim), 105 | nn.ReLU(), 106 
| nn.Dropout(0,5)) 107 | 108 | 109 | def forward(self, query_labels, query_masks): 110 | """ 111 | encode query sequence using RNN and return logits over proposals. 112 | code adopted from LGI 113 | Args: 114 | query_labels: query_labels vectors of query; [B, vocab_size] 115 | query_masks: mask for query; [B,L] 116 | out_type: output type [word-level | sentenve-level | both] 117 | Returns: 118 | w_feats: word-level features; [B,L,2*h] 119 | s_feats: sentence-level feature; [B,2*h] 120 | """ 121 | # embedding query_labels data 122 | wemb = self.embedding(query_labels) # [B,L,emb_odim] 123 | 124 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 125 | # [B, L, D] -> [L, B, D] 126 | attended_feature, weights = self.attention( 127 | query=torch.transpose(wemb, 0,1), 128 | key=torch.transpose(wemb, 0,1), 129 | value=torch.transpose(wemb, 0,1), 130 | key_padding_mask=key_padding_mask,) 131 | 132 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format 133 | # convolution? 134 | 135 | # aggregae features 136 | w_feats = self.feature_aggregation(attended_feature) 137 | #return w_feats, s_feats 138 | return w_feats 139 | 140 | class SimpleVideoEmbeddingModule(nn.Module): 141 | """ 142 | A simple Video Embedding Class 143 | """ 144 | def __init__(self, cfg): 145 | super().__init__() # Must call super __init__() 146 | # get configuration 147 | self.cfg = cfg 148 | # video gru 149 | vid_idim = self.cfg.MODEL.VIDEO.IDIM 150 | vid_gru_hdim = self.cfg.MODEL.VIDEO.GRU_HDIM 151 | self.gru = nn.GRU(input_size=vid_idim,hidden_size=vid_gru_hdim,batch_first=True,dropout=0.5,bidirectional=True) 152 | 153 | 154 | # video feature aggregation module 155 | catted_dim = vid_idim + vid_gru_hdim*2 156 | #catted_dim = vid_gru_hdim *2 157 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 158 | self.feature_aggregation = nn.Sequential( 159 | nn.Linear(in_features=catted_dim,out_features=emb_dim), 160 | nn.ReLU(), 161 | nn.Dropout(0.5), 162 | ) 163 | 164 | def forward(self, vid_feats, vid_masks): 165 | """ 166 | encode video features. Utilizes GRU. 
167 | Args: 168 | vid_feats: video features 169 | vid_masks: mask for video 170 | Return: 171 | vid_features: hidden state features of the video 172 | """ 173 | length = vid_masks.sum(1).squeeze(1) 174 | packed_vid = nn.utils.rnn.pack_padded_sequence(vid_feats, length, batch_first=True, enforce_sorted=False) 175 | vid_hiddens, _ = self.gru(packed_vid) 176 | vid_hiddens, max_ = nn.utils.rnn.pad_packed_sequence(vid_hiddens, batch_first=True, total_length=vid_feats.shape[1]) 177 | #vid_output = self.feature_aggregation(vid_hiddens) 178 | 179 | vid_catted = torch.cat([vid_feats,vid_hiddens],dim=2) 180 | vid_output = self.feature_aggregation(vid_catted) 181 | return vid_output 182 | 183 | 184 | class TransformerVideoEmbeddingModule(nn.Module): 185 | """ 186 | A simple Video Embedding Class 187 | """ 188 | def __init__(self, cfg): 189 | super().__init__() # Must call super __init__() 190 | # get configuration 191 | self.cfg = cfg 192 | 193 | # video transformer 194 | vid_idim = self.cfg.MODEL.VIDEO.IDIM 195 | vid_transformer_hdim = self.cfg.MODEL.VIDEO.ANET.TRANSFORMER_DIM # 1024(charades), 1000 (anet) 196 | self.attention = nn.MultiheadAttention(embed_dim=vid_idim, num_heads=4) 197 | 198 | # video feature aggregation module 199 | catted_dim = vid_idim + vid_transformer_hdim 200 | 201 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 202 | self.feature_aggregation = nn.Sequential( 203 | nn.Linear(in_features=catted_dim,out_features=emb_dim), 204 | nn.ReLU(), 205 | nn.Dropout(0.5), 206 | ) 207 | 208 | def forward(self, vid_feats, vid_masks): 209 | """ 210 | encode video features. Utilizes GRU. 211 | Args: 212 | vid_feats: video features 213 | vid_masks: mask for video 214 | Return: 215 | vid_features: hidden state features of the video 216 | """ 217 | 218 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 219 | # [B, L, D] -> [L, B, D] 220 | attended_feature, weights = self.attention( 221 | query=torch.transpose(vid_feats, 0,1), 222 | key=torch.transpose(vid_feats, 0,1), 223 | value=torch.transpose(vid_feats, 0,1), 224 | key_padding_mask=key_padding_mask.squeeze(),) 225 | 226 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format 227 | # convolution? 
228 | 229 | # aggregae features 230 | vid_catted = torch.cat([vid_feats,attended_feature],dim=2) 231 | vid_output = self.feature_aggregation(vid_catted) 232 | 233 | return vid_output 234 | 235 | class FusionConvBNReLU(nn.Module): 236 | def __init__(self,cfg): 237 | super().__init__() 238 | # get configuration 239 | self.cfg = cfg 240 | # modules 241 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 242 | kernel_size = self.cfg.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE 243 | padding = self.cfg.MODEL.FUSION.CONVBNRELU.PADDING 244 | self.module = nn.Sequential( 245 | nn.Conv1d(in_channels=emb_dim,out_channels=emb_dim,kernel_size=kernel_size,padding=padding), 246 | nn.BatchNorm1d(num_features=emb_dim), 247 | nn.ReLU()) 248 | 249 | def forward(self,feature): 250 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first) 251 | convolved_feature = self.module(transposed_feature) 252 | 253 | return torch.transpose(convolved_feature,1,2) 254 | 255 | def basic_block(idim, odim, ksize=3): 256 | layers = [] 257 | # 1st conv 258 | p = ksize // 2 259 | layers.append(nn.Conv1d(idim, odim, ksize, 1, p, bias=False)) 260 | layers.append(nn.BatchNorm1d(odim)) 261 | layers.append(nn.ReLU(inplace=True)) 262 | # 2nd conv 263 | layers.append(nn.Conv1d(odim, odim, ksize, 1, p, bias=False)) 264 | layers.append(nn.BatchNorm1d(odim)) 265 | 266 | return nn.Sequential(*layers) 267 | 268 | class FusionResBlock(nn.Module): 269 | def __init__(self, cfg): 270 | super().__init__() 271 | # get configuration 272 | self.cfg = cfg 273 | # modules 274 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 275 | kernel_size = self.cfg.MODEL.FUSION.RESBLOCK.KERNEL_SIZE 276 | padding = self.cfg.MODEL.FUSION.RESBLOCK.PADDING 277 | self.nblocks = self.cfg.MODEL.FUSION.RESBLOCK.NB_ITER 278 | 279 | # set layers 280 | self.blocks = nn.ModuleList() 281 | for i in range(self.nblocks): 282 | cur_block = basic_block(emb_dim, emb_dim, kernel_size) 283 | self.blocks.append(cur_block) 284 | 285 | def forward(self, feature): 286 | """ 287 | Args: 288 | inp: [B, input-Dim, L] 289 | out: [B, output-Dim, L] 290 | """ 291 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first) 292 | residual = transposed_feature 293 | for i in range(self.nblocks): 294 | out = self.blocks[i](residual) 295 | out += residual 296 | out = F.relu(out) 297 | residual = out 298 | 299 | return torch.transpose(out,1,2) 300 | 301 | class AttentionBlockS2V(nn.Module): 302 | def __init__(self,cfg): 303 | super().__init__() 304 | # get configuration 305 | self.cfg = cfg 306 | # modules 307 | emb_dim = emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 308 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD 309 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head) 310 | 311 | if self.cfg.MODEL.FUSION.USE_RESBLOCK: 312 | self.convbnrelu = FusionResBlock(cfg) 313 | else: 314 | self.convbnrelu = FusionConvBNReLU(cfg) 315 | 316 | def forward(self,vid_feats,query_feats,query_masks): 317 | # attnetion 318 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 
319 | attended_feature, weights = self.attention( 320 | query=torch.transpose(vid_feats,0,1), 321 | key=torch.transpose(query_feats,0,1), 322 | value=torch.transpose(query_feats,0,1), 323 | key_padding_mask=key_padding_mask,) 324 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format 325 | # convolution 326 | convolved_feature = self.convbnrelu(attended_feature) + vid_feats 327 | return convolved_feature 328 | 329 | class AttentionBlockV2S(nn.Module): 330 | def __init__(self,cfg): 331 | super().__init__() 332 | # get configuration 333 | self.cfg = cfg 334 | # modules 335 | emb_dim = emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 336 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD 337 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head) 338 | 339 | if self.cfg.MODEL.FUSION.USE_RESBLOCK: 340 | self.convbnrelu = FusionResBlock(cfg) 341 | else: 342 | self.convbnrelu = FusionConvBNReLU(cfg) 343 | 344 | def forward(self,vid_feats,query_feats,vid_masks): 345 | # attnetion 346 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 347 | key_padding_mask = key_padding_mask.squeeze() 348 | attended_feature, weights = self.attention( 349 | query=torch.transpose(query_feats,0,1), 350 | key=torch.transpose(vid_feats,0,1), 351 | value=torch.transpose(vid_feats,0,1), 352 | key_padding_mask=key_padding_mask,) 353 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format 354 | # convolution 355 | convolved_feature = self.convbnrelu(attended_feature) + query_feats 356 | return convolved_feature 357 | 358 | class SimpleFusionModule(nn.Module): 359 | def __init__(self, cfg): 360 | super().__init__() 361 | # get configuration 362 | self.cfg = cfg 363 | # attention module 364 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS 365 | self.layers = [] 366 | for _ in range(num_layers): 367 | self.layers.append(AttentionBlockS2V(cfg)) 368 | self.layers = nn.ModuleList(self.layers) 369 | 370 | def forward(self, query_feats, query_masks, vid_feats, vid_masks): 371 | attended_vid_feats = vid_feats 372 | for attn_layer in self.layers: 373 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks) 374 | return attended_vid_feats 375 | 376 | class SimpleFusionModuleSent(nn.Module): 377 | def __init__(self, cfg): 378 | super().__init__() 379 | # get configuration 380 | self.cfg = cfg 381 | # attention module 382 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS 383 | self.layers = [] 384 | for _ in range(num_layers): 385 | self.layers.append(AttentionBlockV2S(cfg)) 386 | self.layers = nn.ModuleList(self.layers) 387 | 388 | def forward(self, query_feats, query_masks, vid_feats, vid_masks): 389 | attended_query_feats = query_feats 390 | for attn_layer in self.layers: 391 | attended_query_feats = attn_layer(vid_feats=vid_feats, query_feats=attended_query_feats, vid_masks=vid_masks) 392 | return attended_query_feats 393 | 394 | class TwostageSimpleFusionModule(nn.Module): 395 | def __init__(self, cfg): 396 | super().__init__() 397 | # get configuration 398 | self.cfg = cfg 399 | # attention module 400 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS 401 | self.layers = [] 402 | for _ in range(num_layers): 403 | self.layers.append(AttentionBlockS2V(cfg)) 404 | self.layers = nn.ModuleList(self.layers) 405 | 406 | def forward(self, query_feats, query_masks, vid_feats, vid_masks): 407 | attended_vid_feats = vid_feats 408 | for attn_layer in self.layers: 409 | attended_vid_feats = 
attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks) 410 | return attended_vid_feats 411 | 412 | class NonLocalBlock(nn.Module): 413 | """ 414 | Nonlocal block used for obtaining global feature. 415 | code borrowed from LGI 416 | """ 417 | def __init__(self, cfg): 418 | super(NonLocalBlock, self).__init__() 419 | self.cfg = cfg 420 | # dims 421 | self.idim = self.cfg.MODEL.FUSION.EMB_DIM 422 | self.odim = self.cfg.MODEL.FUSION.EMB_DIM 423 | self.nheads = self.cfg.MODEL.NONLOCAL.NUM_HEAD 424 | 425 | # options 426 | self.use_bias = self.cfg.MODEL.NONLOCAL.USE_BIAS 427 | 428 | # layers 429 | self.c_lin = nn.Linear(self.idim, self.odim*2, bias=self.use_bias) 430 | self.v_lin = nn.Linear(self.idim, self.odim, bias=self.use_bias) 431 | 432 | self.relu = nn.ReLU() 433 | self.sigmoid = nn.Sigmoid() 434 | self.drop = nn.Dropout(self.cfg.MODEL.NONLOCAL.DROPOUT) 435 | 436 | def forward(self, m_feats, mask): 437 | """ 438 | Inputs: 439 | m_feats: segment-level multimodal feature [B,nseg,*] 440 | mask: mask [B,nseg] 441 | Outputs: 442 | updated_m: updated multimodal feature [B,nseg,*] 443 | """ 444 | 445 | mask = mask.float() 446 | B, nseg = mask.size() 447 | 448 | # key, query, value 449 | m_k = self.v_lin(self.drop(m_feats)) # [B,num_seg,*] 450 | m_trans = self.c_lin(self.drop(m_feats)) # [B,nseg,2*] 451 | m_q, m_v = torch.split(m_trans, m_trans.size(2) // 2, dim=2) 452 | 453 | new_mq = m_q 454 | new_mk = m_k 455 | 456 | # applying multi-head attention 457 | w_list = [] 458 | mk_set = torch.split(new_mk, new_mk.size(2) // self.nheads, dim=2) 459 | mq_set = torch.split(new_mq, new_mq.size(2) // self.nheads, dim=2) 460 | mv_set = torch.split(m_v, m_v.size(2) // self.nheads, dim=2) 461 | 462 | for i in range(self.nheads): 463 | mk_slice, mq_slice, mv_slice = mk_set[i], mq_set[i], mv_set[i] # [B, nseg, *] 464 | 465 | # compute relation matrix; [B,nseg,nseg] 466 | m2m = mk_slice @ mq_slice.transpose(1,2) / ((self.odim // self.nheads) ** 0.5) 467 | m2m = m2m.masked_fill(mask.unsqueeze(1).eq(0), -1e9) # [B,nseg,nseg] 468 | m2m_w = torch.nn.functional.softmax(m2m, dim=2) # [B,nseg,nseg] 469 | w_list.append(m2m_w) 470 | 471 | # compute relation vector for each segment 472 | r = m2m_w @ mv_slice if (i==0) else torch.cat((r, m2m_w @ mv_slice), dim=2) 473 | 474 | updated_m =m_feats + r 475 | return updated_m 476 | 477 | class AttentivePooling(nn.Module): 478 | def __init__(self, cfg): 479 | self.cfg = cfg 480 | super(AttentivePooling, self).__init__() 481 | self.att_n = 1 482 | self.feat_dim = self.cfg.MODEL.FUSION.EMB_DIM 483 | self.att_hid_dim = self.cfg.MODEL.FUSION.EMB_DIM // 2 484 | self.use_embedding = True 485 | 486 | self.feat2att = nn.Linear(self.feat_dim, self.att_hid_dim, bias=False) 487 | self.to_alpha = nn.Linear(self.att_hid_dim, self.att_n, bias=False) 488 | if self.use_embedding: 489 | edim = self.cfg.MODEL.FUSION.EMB_DIM 490 | self.fc = nn.Linear(self.feat_dim, edim) 491 | 492 | def forward(self, feats, f_masks=None): 493 | """ 494 | Compute attention weights and attended feature (weighted sum) 495 | Args: 496 | feats: features where attention weights are computed; [B, A, D] 497 | f_masks: mask for effective features; [B, A] 498 | """ 499 | # check inputs 500 | assert len(feats.size()) == 3 or len(feats.size()) == 4 501 | assert f_masks is None or len(f_masks.size()) == 2 502 | 503 | # dealing with dimension 4 504 | if len(feats.size()) == 4: 505 | B, W, H, D = feats.size() 506 | feats = feats.view(B, W*H, D) 507 | 508 | # embedding feature vectors 509 | 
attn_f = self.feat2att(feats) # [B,A,hdim] 510 | 511 | # compute attention weights 512 | dot = torch.tanh(attn_f) # [B,A,hdim] 513 | alpha = self.to_alpha(dot) # [B,A,att_n] 514 | if f_masks is not None: 515 | alpha = alpha.masked_fill(f_masks.float().unsqueeze(2).eq(0), -1e9) 516 | attw = torch.nn.functional.softmax(alpha.transpose(1,2), dim=2) # [B,att_n,A] 517 | 518 | att_feats = attw @ feats # [B,att_n,D] 519 | att_feats = att_feats.squeeze(1) 520 | attw = attw.squeeze(1) 521 | if self.use_embedding: att_feats = self.fc(att_feats) 522 | 523 | return att_feats, attw 524 | 525 | 526 | class AttentionLocRegressor(nn.Module): 527 | def __init__(self, cfg): 528 | super(AttentionLocRegressor, self).__init__() 529 | self.cfg = cfg 530 | self.tatt_vid = AttentivePooling(self.cfg) 531 | self.tatt_query = AttentivePooling(self.cfg) 532 | # Regression layer 533 | idim = self.cfg.MODEL.FUSION.EMB_DIM * 2 534 | gdim = self.cfg.MODEL.FUSION.EMB_DIM 535 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()] 536 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)] 537 | self.MLP_reg = nn.Sequential(*nn_list) 538 | 539 | 540 | def forward(self, semantic_aware_seg_vid_feats, vid_masks, semantic_aware_seg_query_feat, query_masks): 541 | # perform Eq. (13) and (14) 542 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks) 543 | summarized_qfeat, att_w_q = self.tatt_query(semantic_aware_seg_query_feat, query_masks) 544 | # perform Eq. (15) 545 | summarized_feats = torch.cat((summarized_vfeat, summarized_qfeat), dim=1) 546 | #loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e] 547 | loc = self.MLP_reg(summarized_feats) # loc = [t^s, t^e] 548 | return loc, att_w 549 | 550 | 551 | class TwostageAttentionLocRegressor(nn.Module): 552 | def __init__(self, cfg): 553 | super(TwostageAttentionLocRegressor, self).__init__() 554 | self.cfg = cfg 555 | self.tatt_vid = AttentivePooling(self.cfg) 556 | # Regression layer 557 | idim = self.cfg.MODEL.FUSION.EMB_DIM 558 | gdim = self.cfg.MODEL.FUSION.EMB_DIM 559 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()] 560 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)] 561 | self.MLP_reg = nn.Sequential(*nn_list) 562 | 563 | def forward(self, semantic_aware_seg_vid_feats, vid_masks): 564 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks) 565 | loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e] 566 | return loc, att_w 567 | 568 | 569 | class SimpleModel(nn.Module): 570 | def __init__(self,cfg): 571 | super().__init__() 572 | self.cfg = cfg 573 | self.query_encoder = SimpleSentenceEmbeddingModule(cfg) 574 | self.video_encoder = SimpleVideoEmbeddingModule(cfg) 575 | self.v_fusor = SimpleFusionModule(cfg) 576 | self.s_fusor = SimpleFusionModuleSent(cfg) 577 | self.cv_fusor = TwostageSimpleFusionModule(cfg) 578 | self.n_non_local = self.cfg.MODEL.NONLOCAL.NUM_LAYERS 579 | self.non_locals_layer = nn.ModuleList([NonLocalBlock(cfg) for _ in range(self.n_non_local)]) 580 | self.loc_regressor = AttentionLocRegressor(cfg) 581 | self.loc_regressor_two_stage = TwostageAttentionLocRegressor(cfg) 582 | 583 | def forward(self,inputs): 584 | # encode query 585 | query_labels = inputs['query_labels'] 586 | query_masks = inputs['query_masks'] 587 | encoded_query, encoded_sentence = self.query_encoder(query_labels, query_masks) # encoded_query [B, L, D] D = 256 588 | 589 | # encode video 590 | vid_feats = inputs['video_feats'] 591 | vid_masks = 
inputs['video_masks'] 592 | encoded_video = self.video_encoder(vid_feats,vid_masks) 593 | 594 | # Crossmodality Attention 595 | attended_sent = self.s_fusor(encoded_query, query_masks, encoded_video, vid_masks) 596 | attended_vid = self.v_fusor(encoded_query, query_masks, encoded_video, vid_masks) 597 | two_stage_attended_vid = self.cv_fusor(attended_sent, query_masks, attended_vid, vid_masks) 598 | 599 | global_two_stage_vid = two_stage_attended_vid 600 | for non_local_layer in self.non_locals_layer: 601 | global_two_stage_vid = non_local_layer(global_two_stage_vid, vid_masks.squeeze(2)) 602 | 603 | loc, temporal_attn_weight = self.loc_regressor_two_stage(global_two_stage_vid, vid_masks.squeeze(2)) 604 | 605 | return {"timestamps": loc, "attention_weights": temporal_attn_weight} -------------------------------------------------------------------------------- /models/simple_model_cross_modal_two_stage_temperature.py: -------------------------------------------------------------------------------- 1 | #%% 2 | """ 3 | simple_model.py 4 | **** 5 | simple, basic model for NLVL. 6 | - Query-Video matching with (Multi-Head Attention + ConvBNReLU) with residual connection 7 | - Video Encoding with simple GRU 8 | """ 9 | 10 | #%% 11 | # import things 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | 16 | #%% 17 | # model 18 | class SimpleSentenceEmbeddingModule(nn.Module): 19 | """ 20 | A Simple Query Embedding class 21 | """ 22 | def __init__(self, cfg): 23 | super().__init__() 24 | # config params 25 | self.cfg = cfg 26 | self.query_length = self.cfg.DATASET.MAX_LENGTH 27 | # embedding Layer 28 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM 29 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM 30 | self.embedding = nn.Embedding(emb_idim, emb_odim) 31 | # RNN Layer 32 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM 33 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True) 34 | # feature adjust 35 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 36 | self.feature_aggregation = nn.Sequential( 37 | nn.Linear(in_features=gru_hidden*2,out_features=emb_dim), 38 | nn.ReLU(), 39 | nn.Dropout(0.5)) 40 | 41 | def forward(self, query_labels, query_masks): 42 | """ 43 | encode query sequence using RNN and return logits over proposals. 44 | code adopted from LGI 45 | Args: 46 | query_labels: query_labels vectors of query; [B, vocab_size] 47 | query_masks: mask for query; [B,L] 48 | out_type: output type [word-level | sentenve-level | both] 49 | Returns: 50 | w_feats: word-level features; [B,L,2*h] 51 | s_feats: sentence-level feature; [B,2*h] 52 | """ 53 | # embedding query_labels data 54 | wemb = self.embedding(query_labels) # [B,L,emb_odim] 55 | # encoding query_labels data. 
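# The bidirectional GRU below runs on a packed sequence so padded positions do not contaminate
# the hidden states; the sentence-level feature s_feats concatenates the last valid forward
# hidden state with the first backward hidden state of the padded output.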
56 | max_len = query_labels.size(1) # == L 57 | # make word-wise feature 58 | length = query_masks.sum(1) # [B,] 59 | pack_wemb = nn.utils.rnn.pack_padded_sequence(wemb, length, batch_first=True, enforce_sorted=False) 60 | w_feats, _ = self.gru(pack_wemb) 61 | w_feats, max_ = nn.utils.rnn.pad_packed_sequence(w_feats, batch_first=True, total_length=max_len) 62 | w_feats = w_feats.contiguous() # [B,L,2*h] 63 | 64 | # get sentence feature 65 | B, L, H = w_feats.size() 66 | idx = (length-1).long() # 0-indexed 67 | idx = idx.view(B, 1, 1).expand(B, 1, H//2) 68 | fLSTM = w_feats[:,:,:H//2].gather(1, idx).view(B, H//2) 69 | bLSTM = w_feats[:,0,H//2:].view(B,H//2) 70 | s_feats = torch.cat([fLSTM, bLSTM], dim=1) 71 | 72 | # aggregae features 73 | w_feats = self.feature_aggregation(w_feats) 74 | return w_feats, s_feats 75 | 76 | 77 | class TransformerSentenceEmbeddingModule(nn.Module): 78 | """ 79 | A Simple Query Embedding class 80 | """ 81 | def __init__(self, cfg): 82 | super().__init__() 83 | # config params 84 | self.cfg = cfg 85 | self.query_length = self.cfg.DATASET.MAX_LENGTH 86 | # embedding Layer 87 | emb_idim = self.cfg.MODEL.QUERY.EMB_IDIM 88 | emb_odim = self.cfg.MODEL.QUERY.EMB_ODIM 89 | self.embedding = nn.Embedding(emb_idim, emb_odim) 90 | 91 | # RNN Layer 92 | gru_hidden = self.cfg.MODEL.QUERY.GRU_HDIM 93 | self.gru = nn.GRU(input_size=emb_odim,hidden_size=gru_hidden,num_layers=1,batch_first=True,bidirectional=True) 94 | 95 | # Attention layer 96 | t_emb_dim = self.cfg.MODEL.QUERY.TRANSFORMER_DIM # 300 97 | #t_emb_dim = gru_hidden * 2 # 256 * 2 98 | self.attention = nn.MultiheadAttention(embed_dim=t_emb_dim, num_heads=4) 99 | 100 | # feature adjust 101 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 102 | 103 | self.feature_aggregation = nn.Sequential( 104 | nn.Linear(in_features=t_emb_dim, out_features=emb_dim), 105 | nn.ReLU(), 106 | nn.Dropout(0,5)) 107 | 108 | 109 | def forward(self, query_labels, query_masks): 110 | """ 111 | encode query sequence using RNN and return logits over proposals. 112 | code adopted from LGI 113 | Args: 114 | query_labels: query_labels vectors of query; [B, vocab_size] 115 | query_masks: mask for query; [B,L] 116 | out_type: output type [word-level | sentenve-level | both] 117 | Returns: 118 | w_feats: word-level features; [B,L,2*h] 119 | s_feats: sentence-level feature; [B,2*h] 120 | """ 121 | # embedding query_labels data 122 | wemb = self.embedding(query_labels) # [B,L,emb_odim] 123 | 124 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 125 | # [B, L, D] -> [L, B, D] 126 | attended_feature, weights = self.attention( 127 | query=torch.transpose(wemb, 0,1), 128 | key=torch.transpose(wemb, 0,1), 129 | value=torch.transpose(wemb, 0,1), 130 | key_padding_mask=key_padding_mask,) 131 | 132 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format 133 | # convolution? 
134 | 135 | # aggregae features 136 | w_feats = self.feature_aggregation(attended_feature) 137 | #return w_feats, s_feats 138 | return w_feats 139 | 140 | class SimpleVideoEmbeddingModule(nn.Module): 141 | """ 142 | A simple Video Embedding Class 143 | """ 144 | def __init__(self, cfg): 145 | super().__init__() # Must call super __init__() 146 | # get configuration 147 | self.cfg = cfg 148 | # video gru 149 | vid_idim = self.cfg.MODEL.VIDEO.IDIM 150 | vid_gru_hdim = self.cfg.MODEL.VIDEO.GRU_HDIM 151 | self.gru = nn.GRU(input_size=vid_idim,hidden_size=vid_gru_hdim,batch_first=True,dropout=0.5,bidirectional=True) 152 | 153 | 154 | # video feature aggregation module 155 | catted_dim = vid_idim + vid_gru_hdim*2 156 | #catted_dim = vid_gru_hdim *2 157 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 158 | self.feature_aggregation = nn.Sequential( 159 | nn.Linear(in_features=catted_dim,out_features=emb_dim), 160 | nn.ReLU(), 161 | nn.Dropout(0.5), 162 | ) 163 | 164 | def forward(self, vid_feats, vid_masks): 165 | """ 166 | encode video features. Utilizes GRU. 167 | Args: 168 | vid_feats: video features 169 | vid_masks: mask for video 170 | Return: 171 | vid_features: hidden state features of the video 172 | """ 173 | length = vid_masks.sum(1).squeeze(1) 174 | packed_vid = nn.utils.rnn.pack_padded_sequence(vid_feats, length, batch_first=True, enforce_sorted=False) 175 | vid_hiddens, _ = self.gru(packed_vid) 176 | vid_hiddens, max_ = nn.utils.rnn.pad_packed_sequence(vid_hiddens, batch_first=True, total_length=vid_feats.shape[1]) 177 | #vid_output = self.feature_aggregation(vid_hiddens) 178 | 179 | vid_catted = torch.cat([vid_feats,vid_hiddens],dim=2) 180 | vid_output = self.feature_aggregation(vid_catted) 181 | return vid_output 182 | 183 | 184 | class TransformerVideoEmbeddingModule(nn.Module): 185 | """ 186 | A simple Video Embedding Class 187 | """ 188 | def __init__(self, cfg): 189 | super().__init__() # Must call super __init__() 190 | # get configuration 191 | self.cfg = cfg 192 | 193 | # video transformer 194 | vid_idim = self.cfg.MODEL.VIDEO.IDIM 195 | vid_transformer_hdim = self.cfg.MODEL.VIDEO.ANET.TRANSFORMER_DIM # 1024(charades), 1000 (anet) 196 | self.attention = nn.MultiheadAttention(embed_dim=vid_idim, num_heads=4) 197 | 198 | # video feature aggregation module 199 | catted_dim = vid_idim + vid_transformer_hdim 200 | 201 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 202 | self.feature_aggregation = nn.Sequential( 203 | nn.Linear(in_features=catted_dim,out_features=emb_dim), 204 | nn.ReLU(), 205 | nn.Dropout(0.5), 206 | ) 207 | 208 | def forward(self, vid_feats, vid_masks): 209 | """ 210 | encode video features. Utilizes GRU. 211 | Args: 212 | vid_feats: video features 213 | vid_masks: mask for video 214 | Return: 215 | vid_features: hidden state features of the video 216 | """ 217 | 218 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 219 | # [B, L, D] -> [L, B, D] 220 | attended_feature, weights = self.attention( 221 | query=torch.transpose(vid_feats, 0,1), 222 | key=torch.transpose(vid_feats, 0,1), 223 | value=torch.transpose(vid_feats, 0,1), 224 | key_padding_mask=key_padding_mask.squeeze(),) 225 | 226 | attended_feature = torch.transpose(attended_feature, 0, 1) # to [B, L, D] format 227 | # convolution? 
228 | 229 | # aggregae features 230 | vid_catted = torch.cat([vid_feats,attended_feature],dim=2) 231 | vid_output = self.feature_aggregation(vid_catted) 232 | 233 | return vid_output 234 | 235 | class FusionConvBNReLU(nn.Module): 236 | def __init__(self,cfg): 237 | super().__init__() 238 | # get configuration 239 | self.cfg = cfg 240 | # modules 241 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 242 | kernel_size = self.cfg.MODEL.FUSION.CONVBNRELU.KERNEL_SIZE 243 | padding = self.cfg.MODEL.FUSION.CONVBNRELU.PADDING 244 | self.module = nn.Sequential( 245 | nn.Conv1d(in_channels=emb_dim,out_channels=emb_dim,kernel_size=kernel_size,padding=padding), 246 | nn.BatchNorm1d(num_features=emb_dim), 247 | nn.ReLU()) 248 | 249 | def forward(self,feature): 250 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first) 251 | convolved_feature = self.module(transposed_feature) 252 | 253 | return torch.transpose(convolved_feature,1,2) 254 | 255 | def basic_block(idim, odim, ksize=3): 256 | layers = [] 257 | # 1st conv 258 | p = ksize // 2 259 | layers.append(nn.Conv1d(idim, odim, ksize, 1, p, bias=False)) 260 | layers.append(nn.BatchNorm1d(odim)) 261 | layers.append(nn.ReLU(inplace=True)) 262 | # 2nd conv 263 | layers.append(nn.Conv1d(odim, odim, ksize, 1, p, bias=False)) 264 | layers.append(nn.BatchNorm1d(odim)) 265 | 266 | return nn.Sequential(*layers) 267 | 268 | class FusionResBlock(nn.Module): 269 | def __init__(self, cfg): 270 | super().__init__() 271 | # get configuration 272 | self.cfg = cfg 273 | # modules 274 | emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 275 | kernel_size = self.cfg.MODEL.FUSION.RESBLOCK.KERNEL_SIZE 276 | padding = self.cfg.MODEL.FUSION.RESBLOCK.PADDING 277 | self.nblocks = self.cfg.MODEL.FUSION.RESBLOCK.NB_ITER 278 | 279 | # set layers 280 | self.blocks = nn.ModuleList() 281 | for i in range(self.nblocks): 282 | cur_block = basic_block(emb_dim, emb_dim, kernel_size) 283 | self.blocks.append(cur_block) 284 | 285 | def forward(self, feature): 286 | """ 287 | Args: 288 | inp: [B, input-Dim, L] 289 | out: [B, output-Dim, L] 290 | """ 291 | transposed_feature = torch.transpose(feature,1,2) # to [B,D,L] format (channels first) 292 | residual = transposed_feature 293 | for i in range(self.nblocks): 294 | out = self.blocks[i](residual) 295 | out += residual 296 | out = F.relu(out) 297 | residual = out 298 | 299 | return torch.transpose(out,1,2) 300 | 301 | class AttentionBlockS2V(nn.Module): 302 | def __init__(self,cfg): 303 | super().__init__() 304 | # get configuration 305 | self.cfg = cfg 306 | # modules 307 | emb_dim = emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 308 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD 309 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head) 310 | 311 | if self.cfg.MODEL.FUSION.USE_RESBLOCK: 312 | self.convbnrelu = FusionResBlock(cfg) 313 | else: 314 | self.convbnrelu = FusionConvBNReLU(cfg) 315 | 316 | def forward(self,vid_feats,query_feats,query_masks): 317 | # attnetion 318 | key_padding_mask = query_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 
319 | attended_feature, weights = self.attention( 320 | query=torch.transpose(vid_feats,0,1), 321 | key=torch.transpose(query_feats,0,1), 322 | value=torch.transpose(query_feats,0,1), 323 | key_padding_mask=key_padding_mask,) 324 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format 325 | # convolution 326 | convolved_feature = self.convbnrelu(attended_feature) + vid_feats 327 | return convolved_feature 328 | 329 | class AttentionBlockV2S(nn.Module): 330 | def __init__(self,cfg): 331 | super().__init__() 332 | # get configuration 333 | self.cfg = cfg 334 | self.temp = cfg.MODEL.QUERY.TEMPERATURE 335 | # modules 336 | emb_dim = emb_dim = self.cfg.MODEL.FUSION.EMB_DIM 337 | num_head = self.cfg.MODEL.FUSION.NUM_HEAD 338 | self.attention = nn.MultiheadAttention(embed_dim=emb_dim,num_heads=num_head) 339 | 340 | if self.cfg.MODEL.FUSION.USE_RESBLOCK: 341 | self.convbnrelu = FusionResBlock(cfg) 342 | else: 343 | self.convbnrelu = FusionConvBNReLU(cfg) 344 | 345 | def forward(self,vid_feats,query_feats,vid_masks): 346 | # attnetion 347 | key_padding_mask = vid_masks < 0.1 # if true, not allowed to attend. if false, attend to it. 348 | key_padding_mask = key_padding_mask.squeeze() 349 | attended_feature, weights = self.attention( 350 | query=torch.transpose(query_feats,0,1) / self.temp, 351 | key=torch.transpose(vid_feats,0,1), 352 | value=torch.transpose(vid_feats,0,1), 353 | key_padding_mask=key_padding_mask,) 354 | attended_feature = torch.transpose(attended_feature,0,1) # to [B,L,D] format 355 | # convolution 356 | convolved_feature = self.convbnrelu(attended_feature) + query_feats 357 | return convolved_feature 358 | 359 | class SimpleFusionModule(nn.Module): 360 | def __init__(self, cfg): 361 | super().__init__() 362 | # get configuration 363 | self.cfg = cfg 364 | # attention module 365 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS 366 | self.layers = [] 367 | for _ in range(num_layers): 368 | self.layers.append(AttentionBlockS2V(cfg)) 369 | self.layers = nn.ModuleList(self.layers) 370 | 371 | def forward(self, query_feats, query_masks, vid_feats, vid_masks): 372 | attended_vid_feats = vid_feats 373 | for attn_layer in self.layers: 374 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks) 375 | return attended_vid_feats 376 | 377 | class SimpleFusionModuleSent(nn.Module): 378 | def __init__(self, cfg): 379 | super().__init__() 380 | # get configuration 381 | self.cfg = cfg 382 | # attention module 383 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS 384 | self.layers = [] 385 | for _ in range(num_layers): 386 | self.layers.append(AttentionBlockV2S(cfg)) 387 | self.layers = nn.ModuleList(self.layers) 388 | 389 | def forward(self, query_feats, query_masks, vid_feats, vid_masks): 390 | attended_query_feats = query_feats 391 | for attn_layer in self.layers: 392 | attended_query_feats = attn_layer(vid_feats=vid_feats, query_feats=attended_query_feats, vid_masks=vid_masks) 393 | return attended_query_feats 394 | 395 | class TwostageSimpleFusionModule(nn.Module): 396 | def __init__(self, cfg): 397 | super().__init__() 398 | # get configuration 399 | self.cfg = cfg 400 | # attention module 401 | num_layers = self.cfg.MODEL.FUSION.NUM_LAYERS 402 | self.layers = [] 403 | for _ in range(num_layers): 404 | self.layers.append(AttentionBlockS2V(cfg)) 405 | self.layers = nn.ModuleList(self.layers) 406 | 407 | def forward(self, query_feats, query_masks, vid_feats, vid_masks): 408 | attended_vid_feats = vid_feats 409 | for 
attn_layer in self.layers: 410 | attended_vid_feats = attn_layer(vid_feats=attended_vid_feats, query_feats=query_feats, query_masks=query_masks) 411 | return attended_vid_feats 412 | 413 | class NonLocalBlock(nn.Module): 414 | """ 415 | Nonlocal block used for obtaining global feature. 416 | code borrowed from LGI 417 | """ 418 | def __init__(self, cfg): 419 | super(NonLocalBlock, self).__init__() 420 | self.cfg = cfg 421 | # dims 422 | self.idim = self.cfg.MODEL.FUSION.EMB_DIM 423 | self.odim = self.cfg.MODEL.FUSION.EMB_DIM 424 | self.nheads = self.cfg.MODEL.NONLOCAL.NUM_HEAD 425 | 426 | # options 427 | self.use_bias = self.cfg.MODEL.NONLOCAL.USE_BIAS 428 | 429 | # layers 430 | self.c_lin = nn.Linear(self.idim, self.odim*2, bias=self.use_bias) 431 | self.v_lin = nn.Linear(self.idim, self.odim, bias=self.use_bias) 432 | 433 | self.relu = nn.ReLU() 434 | self.sigmoid = nn.Sigmoid() 435 | self.drop = nn.Dropout(self.cfg.MODEL.NONLOCAL.DROPOUT) 436 | 437 | def forward(self, m_feats, mask): 438 | """ 439 | Inputs: 440 | m_feats: segment-level multimodal feature [B,nseg,*] 441 | mask: mask [B,nseg] 442 | Outputs: 443 | updated_m: updated multimodal feature [B,nseg,*] 444 | """ 445 | 446 | mask = mask.float() 447 | B, nseg = mask.size() 448 | 449 | # key, query, value 450 | m_k = self.v_lin(self.drop(m_feats)) # [B,num_seg,*] 451 | m_trans = self.c_lin(self.drop(m_feats)) # [B,nseg,2*] 452 | m_q, m_v = torch.split(m_trans, m_trans.size(2) // 2, dim=2) 453 | 454 | new_mq = m_q 455 | new_mk = m_k 456 | 457 | # applying multi-head attention 458 | w_list = [] 459 | mk_set = torch.split(new_mk, new_mk.size(2) // self.nheads, dim=2) 460 | mq_set = torch.split(new_mq, new_mq.size(2) // self.nheads, dim=2) 461 | mv_set = torch.split(m_v, m_v.size(2) // self.nheads, dim=2) 462 | 463 | for i in range(self.nheads): 464 | mk_slice, mq_slice, mv_slice = mk_set[i], mq_set[i], mv_set[i] # [B, nseg, *] 465 | 466 | # compute relation matrix; [B,nseg,nseg] 467 | m2m = mk_slice @ mq_slice.transpose(1,2) / ((self.odim // self.nheads) ** 0.5) 468 | m2m = m2m.masked_fill(mask.unsqueeze(1).eq(0), -1e9) # [B,nseg,nseg] 469 | m2m_w = torch.nn.functional.softmax(m2m, dim=2) # [B,nseg,nseg] 470 | w_list.append(m2m_w) 471 | 472 | # compute relation vector for each segment 473 | r = m2m_w @ mv_slice if (i==0) else torch.cat((r, m2m_w @ mv_slice), dim=2) 474 | 475 | updated_m =m_feats + r 476 | return updated_m 477 | 478 | class AttentivePooling(nn.Module): 479 | def __init__(self, cfg): 480 | self.cfg = cfg 481 | super(AttentivePooling, self).__init__() 482 | self.att_n = 1 483 | self.feat_dim = self.cfg.MODEL.FUSION.EMB_DIM 484 | self.att_hid_dim = self.cfg.MODEL.FUSION.EMB_DIM // 2 485 | self.use_embedding = True 486 | 487 | self.feat2att = nn.Linear(self.feat_dim, self.att_hid_dim, bias=False) 488 | self.to_alpha = nn.Linear(self.att_hid_dim, self.att_n, bias=False) 489 | if self.use_embedding: 490 | edim = self.cfg.MODEL.FUSION.EMB_DIM 491 | self.fc = nn.Linear(self.feat_dim, edim) 492 | 493 | def forward(self, feats, f_masks=None): 494 | """ 495 | Compute attention weights and attended feature (weighted sum) 496 | Args: 497 | feats: features where attention weights are computed; [B, A, D] 498 | f_masks: mask for effective features; [B, A] 499 | """ 500 | # check inputs 501 | assert len(feats.size()) == 3 or len(feats.size()) == 4 502 | assert f_masks is None or len(f_masks.size()) == 2 503 | 504 | # dealing with dimension 4 505 | if len(feats.size()) == 4: 506 | B, W, H, D = feats.size() 507 | feats = feats.view(B, 
W*H, D) 508 | 509 | # embedding feature vectors 510 | attn_f = self.feat2att(feats) # [B,A,hdim] 511 | 512 | # compute attention weights 513 | dot = torch.tanh(attn_f) # [B,A,hdim] 514 | alpha = self.to_alpha(dot) # [B,A,att_n] 515 | if f_masks is not None: 516 | alpha = alpha.masked_fill(f_masks.float().unsqueeze(2).eq(0), -1e9) 517 | attw = torch.nn.functional.softmax(alpha.transpose(1,2), dim=2) # [B,att_n,A] 518 | 519 | att_feats = attw @ feats # [B,att_n,D] 520 | att_feats = att_feats.squeeze(1) 521 | attw = attw.squeeze(1) 522 | if self.use_embedding: att_feats = self.fc(att_feats) 523 | 524 | return att_feats, attw 525 | 526 | 527 | class AttentionLocRegressor(nn.Module): 528 | def __init__(self, cfg): 529 | super(AttentionLocRegressor, self).__init__() 530 | self.cfg = cfg 531 | self.tatt_vid = AttentivePooling(self.cfg) 532 | self.tatt_query = AttentivePooling(self.cfg) 533 | # Regression layer 534 | idim = self.cfg.MODEL.FUSION.EMB_DIM * 2 535 | gdim = self.cfg.MODEL.FUSION.EMB_DIM 536 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()] 537 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)] 538 | self.MLP_reg = nn.Sequential(*nn_list) 539 | 540 | 541 | def forward(self, semantic_aware_seg_vid_feats, vid_masks, semantic_aware_seg_query_feat, query_masks): 542 | # perform Eq. (13) and (14) 543 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks) 544 | summarized_qfeat, att_w_q = self.tatt_query(semantic_aware_seg_query_feat, query_masks) 545 | # perform Eq. (15) 546 | summarized_feats = torch.cat((summarized_vfeat, summarized_qfeat), dim=1) 547 | #loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e] 548 | loc = self.MLP_reg(summarized_feats) # loc = [t^s, t^e] 549 | return loc, att_w 550 | 551 | 552 | class TwostageAttentionLocRegressor(nn.Module): 553 | def __init__(self, cfg): 554 | super(TwostageAttentionLocRegressor, self).__init__() 555 | self.cfg = cfg 556 | self.tatt_vid = AttentivePooling(self.cfg) 557 | # Regression layer 558 | idim = self.cfg.MODEL.FUSION.EMB_DIM 559 | gdim = self.cfg.MODEL.FUSION.EMB_DIM 560 | nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2), nn.ReLU()] 561 | #nn_list = [nn.Linear(idim, gdim), nn.ReLU(), nn.Linear(gdim, 2)] 562 | self.MLP_reg = nn.Sequential(*nn_list) 563 | 564 | def forward(self, semantic_aware_seg_vid_feats, vid_masks): 565 | summarized_vfeat, att_w = self.tatt_vid(semantic_aware_seg_vid_feats, vid_masks) 566 | loc = self.MLP_reg(summarized_vfeat) # loc = [t^s, t^e] 567 | return loc, att_w 568 | 569 | 570 | class SimpleModel(nn.Module): 571 | def __init__(self,cfg): 572 | super().__init__() 573 | self.cfg = cfg 574 | self.query_encoder = SimpleSentenceEmbeddingModule(cfg) 575 | self.video_encoder = SimpleVideoEmbeddingModule(cfg) 576 | self.v_fusor = SimpleFusionModule(cfg) 577 | self.s_fusor = SimpleFusionModuleSent(cfg) 578 | self.cv_fusor = TwostageSimpleFusionModule(cfg) 579 | self.n_non_local = self.cfg.MODEL.NONLOCAL.NUM_LAYERS 580 | self.sent_temp = self.cfg.MODEL.QUERY.TEMPERATURE 581 | self.non_locals_layer = nn.ModuleList([NonLocalBlock(cfg) for _ in range(self.n_non_local)]) 582 | self.loc_regressor = AttentionLocRegressor(cfg) 583 | self.loc_regressor_two_stage = TwostageAttentionLocRegressor(cfg) 584 | 585 | def forward(self,inputs): 586 | # encode query 587 | query_labels = inputs['query_labels'] 588 | query_masks = inputs['query_masks'] 589 | encoded_query, encoded_sentence = self.query_encoder(query_labels, query_masks) # 
encoded_query [B, L, D] D = 256 590 | 591 | # encode video 592 | vid_feats = inputs['video_feats'] 593 | vid_masks = inputs['video_masks'] 594 | encoded_video = self.video_encoder(vid_feats,vid_masks) 595 | 596 | # Crossmodality Attention 597 | attended_sent = self.s_fusor(encoded_query, query_masks, encoded_video, vid_masks) 598 | attended_vid = self.v_fusor(encoded_query, query_masks, encoded_video, vid_masks) 599 | two_stage_attended_vid = self.cv_fusor(attended_sent, query_masks, attended_vid, vid_masks) 600 | 601 | global_two_stage_vid = two_stage_attended_vid 602 | for non_local_layer in self.non_locals_layer: 603 | global_two_stage_vid = non_local_layer(global_two_stage_vid, vid_masks.squeeze(2)) 604 | 605 | loc, temporal_attn_weight = self.loc_regressor_two_stage(global_two_stage_vid, vid_masks.squeeze(2)) 606 | 607 | return {"timestamps": loc, "attention_weights": temporal_attn_weight} 608 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | absl-py=0.11.0=pypi_0 6 | argon2-cffi=20.1.0=pypi_0 7 | astor=0.8.1=pypi_0 8 | async-generator=1.10=pypi_0 9 | attrs=20.3.0=pypi_0 10 | backcall=0.2.0=py_0 11 | beautifulsoup4=4.9.1=py37_0 12 | blas=1.0=mkl 13 | bleach=3.2.2=pypi_0 14 | bzip2=1.0.8=h7b6447c_0 15 | ca-certificates=2020.1.1=0 16 | cached-property=1.5.2=pypi_0 17 | certifi=2020.4.5.2=py37_0 18 | cffi=1.14.0=py37he30daa8_1 19 | chardet=3.0.4=py37_1003 20 | conda=4.8.3=py37_0 21 | conda-build=3.18.11=py37_0 22 | conda-package-handling=1.6.1=py37h7b6447c_0 23 | cryptography=2.9.2=py37h1ba5d50_0 24 | cudatoolkit=10.1.243=h6bb024c_0 25 | decorator=4.4.2=py_0 26 | defusedxml=0.6.0=pypi_0 27 | entrypoints=0.3=pypi_0 28 | filelock=3.0.12=py_0 29 | freetype=2.9.1=h8a8886c_1 30 | gast=0.2.2=pypi_0 31 | glob2=0.7=py_0 32 | google-pasta=0.2.0=pypi_0 33 | grpcio=1.34.0=pypi_0 34 | h5py=3.1.0=pypi_0 35 | icu=58.2=he6710b0_3 36 | idna=2.9=py_1 37 | importlib-metadata=3.3.0=pypi_0 38 | intel-openmp=2020.1=217 39 | ipykernel=5.4.3=pypi_0 40 | ipython=7.15.0=py37_0 41 | ipython_genutils=0.2.0=py37_0 42 | ipywidgets=7.6.3=pypi_0 43 | jedi=0.17.0=py37_0 44 | jinja2=2.11.2=py_0 45 | jpeg=9b=h024ee3a_2 46 | jsonschema=3.2.0=pypi_0 47 | jupyter=1.0.0=pypi_0 48 | jupyter-client=6.1.11=pypi_0 49 | jupyter-console=6.2.0=pypi_0 50 | jupyter-core=4.7.0=pypi_0 51 | jupyterlab-pygments=0.1.2=pypi_0 52 | jupyterlab-widgets=1.0.0=pypi_0 53 | keras-applications=1.0.8=pypi_0 54 | keras-preprocessing=1.1.2=pypi_0 55 | ld_impl_linux-64=2.33.1=h53a641e_7 56 | libarchive=3.4.2=h62408e4_0 57 | libedit=3.1.20181209=hc058e9b_0 58 | libffi=3.3=he6710b0_1 59 | libgcc-ng=9.1.0=hdf63c60_0 60 | libgfortran-ng=7.3.0=hdf63c60_0 61 | liblief=0.10.1=he6710b0_0 62 | libpng=1.6.37=hbc83047_0 63 | libstdcxx-ng=9.1.0=hdf63c60_0 64 | libtiff=4.1.0=h2733197_1 65 | libxml2=2.9.10=he19cac6_1 66 | lz4-c=1.9.2=he6710b0_0 67 | markdown=3.3.3=pypi_0 68 | markupsafe=1.1.1=py37h7b6447c_0 69 | mistune=0.8.4=pypi_0 70 | mkl=2020.1=217 71 | mkl-service=2.3.0=py37he904b0f_0 72 | mkl_fft=1.1.0=py37h23d657b_0 73 | mkl_random=1.1.1=py37h0573a6f_0 74 | nbclient=0.5.1=pypi_0 75 | nbconvert=6.0.7=pypi_0 76 | nbformat=5.1.2=pypi_0 77 | ncurses=6.2=he6710b0_1 78 | nest-asyncio=1.4.3=pypi_0 79 | ninja=1.9.0=py37hfd86e86_0 80 | notebook=6.2.0=pypi_0 81 | 
numpy=1.18.1=py37h4f9e942_0 82 | numpy-base=1.18.1=py37hde5b4d6_1 83 | olefile=0.46=py37_0 84 | openssl=1.1.1g=h7b6447c_0 85 | opt-einsum=3.3.0=pypi_0 86 | packaging=20.8=pypi_0 87 | pandocfilters=1.4.3=pypi_0 88 | parso=0.7.0=py_0 89 | patchelf=0.11=he6710b0_0 90 | pexpect=4.8.0=py37_0 91 | pickleshare=0.7.5=py37_0 92 | pillow=7.1.2=py37hb39fc2d_0 93 | pip=20.0.2=py37_3 94 | pkginfo=1.5.0.1=py37_0 95 | prometheus-client=0.9.0=pypi_0 96 | prompt-toolkit=3.0.5=py_0 97 | protobuf=3.14.0=pypi_0 98 | psutil=5.7.0=py37h7b6447c_0 99 | ptyprocess=0.6.0=py37_0 100 | py-lief=0.10.1=py37h403a769_0 101 | pycosat=0.6.3=py37h7b6447c_0 102 | pycparser=2.20=py_0 103 | pygments=2.6.1=py_0 104 | pyopenssl=19.1.0=py37_0 105 | pyparsing=2.4.7=pypi_0 106 | pyrsistent=0.17.3=pypi_0 107 | pysocks=1.7.1=py37_0 108 | python=3.7.7=hcff3b4d_5 109 | python-dateutil=2.8.1=pypi_0 110 | python-libarchive-c=2.9=py_0 111 | pytorch=1.5.1=py3.7_cuda10.1.243_cudnn7.6.3_0 112 | pytz=2020.1=py_0 113 | pyyaml=5.3.1=py37h7b6447c_0 114 | pyzmq=21.0.1=pypi_0 115 | qtconsole=5.0.2=pypi_0 116 | qtpy=1.9.0=pypi_0 117 | readline=8.0=h7b6447c_0 118 | requests=2.23.0=py37_0 119 | ripgrep=11.0.2=he32d670_0 120 | ruamel_yaml=0.15.87=py37h7b6447c_0 121 | send2trash=1.5.0=pypi_0 122 | setuptools=46.4.0=py37_0 123 | six=1.14.0=py37_0 124 | soupsieve=2.0.1=py_0 125 | sqlite=3.31.1=h62c20be_1 126 | tensorboard=1.15.0=pypi_0 127 | tensorflow=1.15.0=pypi_0 128 | tensorflow-estimator=1.15.1=pypi_0 129 | termcolor=1.1.0=pypi_0 130 | terminado=0.9.2=pypi_0 131 | testpath=0.4.4=pypi_0 132 | tk=8.6.8=hbc83047_0 133 | torchvision=0.6.1=py37_cu101 134 | tornado=6.1=pypi_0 135 | tqdm=4.46.0=py_0 136 | traitlets=4.3.3=py37_0 137 | typing-extensions=3.7.4.3=pypi_0 138 | urllib3=1.25.8=py37_0 139 | wcwidth=0.2.4=py_0 140 | webencodings=0.5.1=pypi_0 141 | werkzeug=1.0.1=pypi_0 142 | wheel=0.34.2=py37_0 143 | widgetsnbextension=3.5.1=pypi_0 144 | wrapt=1.12.1=pypi_0 145 | xz=5.2.5=h7b6447c_0 146 | yacs=0.1.8=pypi_0 147 | yaml=0.1.7=had09818_2 148 | zipp=3.4.0=pypi_0 149 | zlib=1.2.11=h7b6447c_3 150 | zstd=1.4.4=h0b5b093_3 151 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #%% 2 | # import things 3 | # model code 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from utils.eval_utils import NLVLEvaluator, renew_best_score 7 | from utils.loss import NLVLLoss 8 | 9 | # logging code 10 | from torch.utils.tensorboard import SummaryWriter 11 | 12 | # etc 13 | import random 14 | import numpy as np 15 | import os 16 | from os.path import join 17 | from tqdm import tqdm 18 | from copy import deepcopy 19 | from yacs.config import CfgNode 20 | from yaml import dump as dump_yaml 21 | import argparse 22 | 23 | #%% 24 | # argparse 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("--model",'-m',type=str,default="CrossModalityTwostageAttention") 27 | parser.add_argument("--config",'-c',type=str,default="configs/cha_simple_model/simplemodel_cha_BS256_two-stage_attention.yml") 28 | parser.add_argument("--seed", '-s',type=int,default=38) 29 | parser.add_argument("--reg_w", type=float, default=1.0) 30 | parser.add_argument("--temp",'-t',type=float,default=1.0) 31 | args = parser.parse_args() 32 | 33 | dict_args = { 34 | "model": args.model, 35 | "confg": args.config 36 | } 37 | 38 | random.seed(int(args.seed)) 39 | os.environ['PYTHONHASHSEED'] = str(args.seed) 40 | np.random.seed(int(args.seed)) 41 | 
torch.manual_seed(int(args.seed)) 42 | torch.cuda.manual_seed(int(args.seed)) 43 | torch.cuda.manual_seed_all(int(args.seed)) 44 | torch.backends.cudnn.deterministic = True 45 | torch.backends.cudnn.benchmark = False 46 | torch.backends.cudnn.enabled = False 47 | 48 | #%% 49 | # load things according to arguments 50 | if args.model == "SimpleModel": 51 | from models.simple_model import SimpleModel as Model 52 | from config_parsers.simple_model_config import _C as TestCfg 53 | elif args.model == "CrossModalityTwostageAttention": 54 | from models.simple_model_cross_modal_two_stage import SimpleModel as Model 55 | from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as TestCfg 56 | elif args.model == "CrossModalityTwostageAttentionTemperature": 57 | from models.simple_model_cross_modal_two_stage_temperature import SimpleModel as Model 58 | from config_parsers.simple_model_cross_modality_twostage_attention_config import _C as TestCfg 59 | else: 60 | raise ValueError("No such model: {}".format(args.model)) 61 | 62 | 63 | #%% 64 | # constants 65 | CONFIG_PATH = args.config 66 | print("config file path: ", CONFIG_PATH) 67 | TestCfg.merge_from_file(CONFIG_PATH) 68 | cfg = TestCfg 69 | cfg.MODEL.QUERY.TEMPERATURE = args.temp 70 | device = torch.device("cuda") 71 | DATA_PATH = cfg.DATASET.DATA_PATH 72 | ANNO_PATH ={"train": cfg.DATASET.TRAIN_ANNO_PATH, 73 | "test": cfg.DATASET.TEST_ANNO_PATH} 74 | VID_PATH = cfg.DATASET.VID_PATH 75 | 76 | #%% 77 | # function for dumping yaml 78 | def cfg_to_dict(cfg): 79 | dict_cfg = dict(cfg) 80 | for k,v in dict_cfg.items(): 81 | if isinstance(v,CfgNode): 82 | dict_cfg[k] = cfg_to_dict(v) 83 | return dict_cfg 84 | 85 | #%% 86 | # Load dataloader 87 | dataset_name = cfg.DATASET.NAME 88 | if dataset_name == "Charades": 89 | from dataset.charades_basic import CharadesBasicDatasetBuilder 90 | dataloaders = CharadesBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders() 91 | elif dataset_name == "AnetCap": 92 | from dataset.anetcap_basic import AnetCapBasicDatasetBuilder 93 | dataloaders = AnetCapBasicDatasetBuilder(cfg,data_path=DATA_PATH,anno_path=ANNO_PATH,vid_path=VID_PATH).make_dataloaders() 94 | else: 95 | raise ValueError("No such dataset: {}".format(dataset_name)) 96 | 97 | #%% 98 | # load training stuff 99 | ## model and loss 100 | model = Model(cfg).to(device) 101 | loss_fn = NLVLLoss(cfg).to(device) 102 | 103 | ## evaluator 104 | evaluator = NLVLEvaluator(cfg) 105 | ## optimizer 106 | lr = cfg.TRAIN.LR 107 | optimizer = torch.optim.Adam(params=model.parameters(), lr=lr, betas=(0.9,0.999)) 108 | ## batch to device function 109 | batch_to_device = dataloaders['train'].dataset.batch_to_device 110 | ## Logging Functions 111 | exp_name = cfg.EXP_NAME 112 | 113 | log_dir = join("results", dataset_name, args.model, exp_name + "_temp_" + str(args.temp)) 114 | logger = SummaryWriter(log_dir=log_dir,flush_secs=60) 115 | ## arxiv cfg 116 | dict_cfg = cfg_to_dict(cfg) 117 | with open(join(log_dir,"config.yml"),'w') as f: 118 | dump_yaml(dict_cfg,f) 119 | ## arxiv args 120 | with open(join(log_dir,"args.yml"),'w') as f: 121 | dump_yaml(dict_args,f) 122 | 123 | # information print out 124 | print("====="*10) 125 | print(args) 126 | print("====="*10) 127 | print(cfg) 128 | print("====="*10) 129 | 130 | #%% 131 | curr_best = 0 132 | # training loop 133 | pbar = tqdm(range(cfg.TRAIN.NUM_EPOCH)) 134 | for epoch in pbar: 135 | # training loop 136 | model.train() 137 | losses_list = [] 138 | for 
batch_idx,item in enumerate(dataloaders['train']): 139 | # update progress bar 140 | pbar.set_description("training {}/{}".format(batch_idx,len(dataloaders['train']))) 141 | # make prediction 142 | optimizer.zero_grad() 143 | batch_to_device(item, device) 144 | model_outputs = model(item) 145 | # loss fn 146 | losses = loss_fn(model_outputs,item) 147 | sum_loss = sum([v for k,v in losses.items()]) 148 | sum_loss.backward() 149 | # update model 150 | optimizer.step() 151 | # add to losses list 152 | losses_list.append(losses) 153 | # write training information 154 | ## make mean losses dict 155 | mean_losses_dict = {} 156 | for k in list(losses_list[0].keys()): 157 | mean_losses_dict[k] = torch.mean(torch.Tensor([x[k] for x in losses_list])) 158 | for loss_name,val in mean_losses_dict.items(): 159 | logger.add_scalar("train/{}".format(loss_name), val, global_step=epoch) 160 | 161 | # test loop 162 | model.eval() 163 | eval_results_list = [] 164 | for batch_idx,item in enumerate(dataloaders['test']): 165 | # update progress bar 166 | pbar.set_description("test {}/{}".format(batch_idx,len(dataloaders['test']))) 167 | # make evaluation 168 | with torch.no_grad(): 169 | batch_to_device(item,device) 170 | model_outputs = model(item) 171 | eval_results = evaluator(model_outputs,item) 172 | # add to mean eval results 173 | eval_results_list.append(eval_results) 174 | # write test information 175 | ## make mean metric dict 176 | mean_eval_results_dict = {} 177 | for k in list(eval_results_list[0].keys()): 178 | mean_eval_results_dict[k] = torch.mean(torch.Tensor([x[k] for x in eval_results_list])) 179 | 180 | for metric_name,val in mean_eval_results_dict.items(): 181 | logger.add_scalar("test/{}".format(metric_name), val, global_step=epoch) 182 | curr_best = renew_best_score(mean_eval_results_dict['Recall@0.5'], curr_best, model) 183 | 184 | # print test information 185 | tqdm.write("****** epoch:{} ******".format(epoch)) 186 | for k,v in mean_eval_results_dict.items(): 187 | tqdm.write("\t{}: {}".format(k,v)) 188 | -------------------------------------------------------------------------------- /utils/eval_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch 5 | 6 | def compute_tiou(pred, gt): 7 | intersection = max(0, min(pred[1], gt[1]) - max(pred[0], gt[0])) 8 | union = max(pred[1], gt[1]) - min(pred[0], gt[0]) 9 | if union == 0.0: 10 | if intersection > 0.0: 11 | return 1.0 12 | else: 13 | return 0.0 14 | return float(intersection) / union 15 | 16 | def renew_best_score(cur, best, model): 17 | if best > cur: 18 | return best 19 | torch.save(model.state_dict(), 'pretrained_best.pth') 20 | 21 | return cur 22 | 23 | class NLVLEvaluator(): 24 | def __init__(self,cfg): 25 | self.cfg = cfg 26 | self.iou_thresh = self.cfg.TRAIN.IOU_THRESH 27 | self.metrics = ["mIoU"] + ["Recall@{:.1f}".format(x) for x in self.iou_thresh] 28 | 29 | def __call__(self,model_outputs,batch): 30 | gt = torch.cat([batch["grounding_start_pos"].unsqueeze(1),batch["grounding_end_pos"].unsqueeze(1)],dim=1) 31 | pred = model_outputs['timestamps'] 32 | scores = {x:[] for x in self.metrics} 33 | for p,g in zip(pred,gt): 34 | tiou = compute_tiou(p,g) 35 | scores["mIoU"].append(tiou) 36 | for thresh in self.iou_thresh: 37 | scores["Recall@{:.1f}".format(thresh)].append(tiou >= thresh) 38 | for k,v in scores.items(): 39 | scores[k] = torch.mean(torch.Tensor(v)) 40 | return scores 41 | 
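# Worked example for compute_tiou: with pred = [0.2, 0.6] and gt = [0.4, 0.8],
# intersection = min(0.6, 0.8) - max(0.2, 0.4) = 0.2 and union = max(0.6, 0.8) - min(0.2, 0.4) = 0.6,
# so the tIoU is 0.2 / 0.6 ≈ 0.33. NLVLEvaluator averages these values into mIoU and, for every
# threshold in cfg.TRAIN.IOU_THRESH, into a Recall@threshold hit rate; train.py tracks the
# Recall@0.5 entry via renew_best_score to keep the best checkpoint.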
-------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn 4 | import torch 5 | 6 | class TAGLoss(nn.Module): 7 | def __init__(self): 8 | super(TAGLoss, self).__init__() 9 | 10 | def forward(self, w, mask): 11 | ac_loss = (-mask*torch.log(w+1e-8)).sum(1) / mask.sum(1) 12 | ac_loss = ac_loss.mean(0) 13 | 14 | return ac_loss 15 | 16 | class TGRegressionCriterion(nn.Module): 17 | def __init__(self): 18 | super(TGRegressionCriterion, self).__init__() 19 | 20 | self.regloss1 = nn.SmoothL1Loss() 21 | self.regloss2 = nn.SmoothL1Loss() 22 | 23 | def forward(self, loc, s_gt, e_gt): 24 | 25 | total_loss = self.regloss1(loc[:,0], s_gt) + self.regloss2(loc[:,1], e_gt) 26 | 27 | return total_loss 28 | 29 | class NLVLLoss(nn.Module): 30 | def __init__(self,cfg, reg_w=1): 31 | super().__init__() 32 | self.temporal_localization_loss = TGRegressionCriterion() 33 | self.temporal_attention_loss2 = TAGLoss() 34 | self.reg_w = reg_w 35 | 36 | def forward(self,model_outputs,batch): 37 | # position loss 38 | timestamps = model_outputs['timestamps'] # [B,2] 39 | gt_start_pos = batch["grounding_start_pos"] 40 | gt_end_pos = batch["grounding_end_pos"] 41 | gt_timestamps = torch.cat([gt_start_pos.unsqueeze(1),gt_end_pos.unsqueeze(1)],dim=1) # [B,2] 42 | 43 | localization_loss = self.temporal_localization_loss(timestamps, gt_start_pos, gt_end_pos) 44 | localization_loss = localization_loss * self.reg_w 45 | 46 | # attention loss 47 | attention_weights = model_outputs['attention_weights'] # [B,128] 48 | attention_masks = batch["attention_masks"] # [B,128] 49 | attention_loss = self.temporal_attention_loss2(attention_weights,attention_masks) 50 | 51 | loss_dict = { 52 | "localization_loss": localization_loss, 53 | "attention_loss": attention_loss 54 | } 55 | return loss_dict 56 | --------------------------------------------------------------------------------
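For reference, below is a minimal, self-contained sketch of how the `NLVLLoss` above combines the boundary-regression and temporal-attention terms. It is not part of the repository: the dictionary keys mirror the ones used in `train.py`, but the tensor sizes (B = 4, T = 128) and the dummy ground truth are placeholder values, and it assumes the script is run from the repository root so that `utils.loss` is importable.

```python
import torch
from utils.loss import NLVLLoss

B, T = 4, 128  # batch size and number of video segments (placeholder values)

# fake model outputs: predicted [t^s, t^e] and temporal attention weights over T segments
timestamps = torch.rand(B, 2, requires_grad=True)
attn_logits = torch.rand(B, T, requires_grad=True)
model_outputs = {
    "timestamps": timestamps,
    "attention_weights": torch.softmax(attn_logits, dim=1),
}

# fake batch: normalized ground-truth boundaries and a binary mask over the pseudo GT segments
attention_masks = torch.zeros(B, T)
attention_masks[:, 30:60] = 1.0  # pretend segments 30-59 fall inside the pseudo ground truth
batch = {
    "grounding_start_pos": torch.rand(B),
    "grounding_end_pos": torch.rand(B),
    "attention_masks": attention_masks,
}

loss_fn = NLVLLoss(cfg=None)            # cfg is accepted but not read inside NLVLLoss
losses = loss_fn(model_outputs, batch)  # {"localization_loss": ..., "attention_loss": ...}
total = sum(losses.values())            # same reduction used in the train.py loop
total.backward()
print({k: float(v) for k, v in losses.items()})
```

The `reg_w` argument (defaulting to 1) scales only the localization term; note that `train.py` exposes a `--reg_w` flag but constructs the loss as `NLVLLoss(cfg)`, so the default weight is used there.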