├── .DS_Store
├── README.md
├── config.py
├── ctrl.py
├── dataset.py
├── exp_data
│   ├── .DS_Store
│   └── .gitkeep.txt
├── main.py
├── utils.py
└── video_allframes_info.pkl
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/iworldtong/TALL.pytorch/fea4bea203f6d52453d8cc56588e34042729e2b6/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TALL.pytorch
2 | PyTorch implementation of ["TALL: Temporal Activity Localization via Language Query", Gao et al., ICCV 2017](https://arxiv.org/abs/1705.02101).
3 | 
4 | This implementation is largely based on the official TensorFlow code, [jiyanggao/TALL](https://github.com/jiyanggao/TALL).
5 | 
6 | ### Requirements
7 | 
8 | - Python 3.6
9 | - PyTorch 1.0
10 | - CPU | GPU supported.
11 | 
12 | ### Visual Features on TACoS
13 | 
14 | Download the C3D features for the [training set](https://drive.google.com/file/d/1zQp0aYGFCm8PqqHOh4UtXfy2U3pJMBeu/view?usp=sharing) and [test set](https://drive.google.com/file/d/1zC-UrspRf42Qiu5prQw4fQrbgLQfJN-P/view?usp=sharing) of the TACoS dataset, then set the paths to the feature folders in `config.py`.
15 | 
16 | ### Sentence Embeddings on TACoS
17 | 
18 | Download the Skip-Thought sentence embeddings and sample files for the TACoS dataset from [here](https://drive.google.com/file/d/1HF-hNFPvLrHwI5O7YvYKZWTeTxC5Mg1K/view?usp=sharing), and put them under the `exp_data` folder.
19 | 
20 | ### Reproduce the results on TACoS
21 | 
22 | `python main.py`
23 | 
24 | 
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | class CONFIG(object):
4 |     def __init__(self):
5 |         super(CONFIG, self).__init__()
6 | 
7 |         self.phase = 'train'
8 | 
9 |         self.seed = 2019
10 |         self.device = [0]
11 | 
12 |         self.max_epoch = 20
13 |         self.batch_size = 56
14 |         self.test_batch_size = 1
15 | 
16 |         # Dataset setting
17 |         self.num_worker = 8
18 |         self.test_csv_path = "./exp_data/TACoS/test_clip-sentvec.pkl"
19 |         self.train_csv_path = "./exp_data/TACoS/train_clip-sentvec.pkl"
20 |         self.test_feature_dir = "../TACoS/Interval128_256_overlap0.8_c3d_fc6/"
21 |         self.train_feature_dir = "../TACoS/Interval64_128_256_512_overlap0.8_c3d_fc6/"
22 | 
23 |         self.movie_length_info_path = "./video_allframes_info.pkl"
24 | 
25 |         self.context_num = 1
26 |         self.context_size = 128
27 | 
28 |         # Model setting
29 |         self.visual_dim = 4096 * 3
30 |         self.sentence_embed_dim = 4800
31 |         self.semantic_dim = 1024    # dimension of the shared space in which visual and sentence features are compared
32 |         self.middle_layer_dim = 1024
33 | 
34 |         self.IoU = 0.5
35 |         self.nIoU = 0.15
36 | 
37 |         # Optimizer setting
38 |         self.optimizer = 'Adam'
39 |         self.vs_lr = 5e-3
40 |         self.weight_decay = 1e-5
41 | 
42 |         self.lambda_reg = 0.01
43 |         self.alpha = 1.0 / self.batch_size
44 | 
45 | 
46 |         # Logging setting
47 |         self.save_log = False
48 |         self.log_path = './log.txt'
49 | 
50 |         self.test_output_path = "./ctrl_test_results.txt"
51 | 
52 | 
53 | 
54 | 
55 | 
--------------------------------------------------------------------------------
/ctrl.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | 
3 | import torch
4 | import torch.nn as nn
5 | 
6 | 
7 | class CTRL(nn.Module):
8 |     def __init__(self,
9 |                  visual_dim,
10 |                  sentence_embed_dim,
11 |                  semantic_dim,
12 |                  middle_layer_dim,
13 |                  dropout_rate=0.,
14 |                  ):
15 |         super(CTRL, self).__init__()
16 |         self.semantic_dim = 
semantic_dim 17 | 18 | self.v2s_fc = nn.Linear(visual_dim, semantic_dim) 19 | self.s2s_fc = nn.Linear(sentence_embed_dim, semantic_dim) 20 | self.fc1 = nn.Conv2d(semantic_dim * 4, middle_layer_dim, kernel_size=1, stride=1) 21 | self.fc2 = nn.Conv2d(middle_layer_dim, 3, kernel_size=1, stride=1) 22 | 23 | self.relu = nn.ReLU(inplace=True) 24 | self.dropout = nn.Dropout(dropout_rate) 25 | 26 | def forward(self, visual_feature, sentence_embed): 27 | batch_size, _ = visual_feature.size() 28 | 29 | transformed_clip = self.v2s_fc(visual_feature) 30 | transformed_sentence = self.s2s_fc(sentence_embed) 31 | 32 | transformed_clip_norm = transformed_clip / transformed_clip.norm(2, dim=1, keepdim=True) # by row 33 | transformed_sentence_norm = transformed_sentence / transformed_sentence.norm(2, dim=1, keepdim=True) # by row 34 | 35 | # Cross modal combine: [mul, add, concat] 36 | vv_f = transformed_clip_norm.repeat(batch_size, 1).reshape(batch_size, batch_size, self.semantic_dim) 37 | ss_f = transformed_sentence_norm.repeat(1, batch_size).reshape(batch_size, batch_size, self.semantic_dim) 38 | mul_feature = vv_f * ss_f 39 | add_feature = vv_f + ss_f 40 | cat_feature = torch.cat((vv_f, ss_f), 2) 41 | cross_modal_vec = torch.cat((mul_feature, add_feature, cat_feature), 2) 42 | 43 | # vs_multilayer 44 | out = cross_modal_vec.unsqueeze(0).permute(0,3,1,2) # match conv op 45 | out = self.fc1(out) 46 | out = self.relu(out) 47 | out = self.fc2(out) 48 | out = out.permute(0,2,3,1).squeeze(0) 49 | 50 | return out 51 | 52 | class CTRL_loss(nn.Module): 53 | def __init__(self, lambda_reg): 54 | super(CTRL_loss, self).__init__() 55 | self.lambda_reg = lambda_reg 56 | 57 | def forward(self, net, offset_label): 58 | batch_size = net.size()[0] 59 | sim_score_mat, p_reg_mat, l_reg_mat = net.split(1, dim=2) 60 | sim_score_mat = sim_score_mat.reshape(batch_size, batch_size) 61 | p_reg_mat = p_reg_mat.reshape(batch_size, batch_size) 62 | l_reg_mat = l_reg_mat.reshape(batch_size, batch_size) 63 | 64 | # make mask mat 65 | I_2 = 2.0 * torch.eye(batch_size) 66 | all1 = torch.ones([batch_size, batch_size]) 67 | mask = all1 - I_2 68 | 69 | # loss cls, not considering iou 70 | I = torch.eye(batch_size) 71 | batch_para_mat = torch.ones([batch_size, batch_size]) / batch_size 72 | para_mat = I + batch_para_mat 73 | 74 | loss_mat = torch.log(all1 + torch.exp(torch.mul(mask, sim_score_mat))) 75 | loss_mat = torch.mul(loss_mat, para_mat) 76 | loss_align = torch.mean(loss_mat) 77 | 78 | # regression loss 79 | l_reg_diag = torch.mm(torch.mul(l_reg_mat, I), torch.ones([batch_size, 1])) 80 | p_reg_diag = torch.mm(torch.mul(p_reg_mat, I), torch.ones([batch_size, 1])) 81 | offset_pred = torch.cat((p_reg_diag, l_reg_diag), 1) 82 | loss_reg = torch.mean(torch.abs(offset_pred - offset_label)) 83 | 84 | loss = loss_align + self.lambda_reg * loss_reg 85 | 86 | return loss 87 | 88 | 89 | 90 | 91 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import os 3 | import sys 4 | import numpy as np 5 | import pickle 6 | 7 | # torch 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torchvision import datasets, transforms 12 | 13 | 14 | 15 | def calc_IoU(i0, i1): 16 | union = (min(i0[0], i1[0]), max(i0[1], i1[1])) 17 | inter = (max(i0[0], i1[0]), min(i0[1], i1[1])) 18 | iou = 1.0 * (inter[1] - inter[0]) / (union[1] - union[0]) 19 | return iou 20 | 21 | 22 | 
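# A worked example (editorial sketch, not part of the original code) of the two
# tests that TrainDataset.load_data() applies when matching a sliding window to a
# ground-truth clip, using calc_IoU above and calc_nIoL below.
# With ground truth (128, 256) and sliding window (96, 224):
#   calc_IoU((96, 224), (128, 256))  = (224 - 128) / (256 - 96) = 0.60  -> passes iou > IoU threshold (0.5)
#   calc_nIoL((128, 256), (96, 224)) = (128 - 96) / (224 - 96)  = 0.25  -> fails nIoL < nIoU threshold (0.15)
# Both tests must pass for a window-sentence pair to enter clip_sentence_pairs_iou,
# so this particular window would be discarded.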
def calc_nIoL(base, sliding_clip): 23 | ''' 24 | The reason we use nIoL is that we want the the most part of the sliding 25 | window clip to overlap with the assigned sentence, and simply increasing 26 | IoU threshold would harm regression layers ( regression aims to move the 27 | clip from low IoU to high IoU). 28 | ''' 29 | inter = (max(base[0], sliding_clip[0]), min(base[1], sliding_clip[1])) 30 | inter_l = inter[1] - inter[0] 31 | sliding_l = sliding_clip[1] - sliding_clip[0] 32 | nIoL = 1.0 * (sliding_l - inter_l) / sliding_l 33 | return nIoL 34 | 35 | 36 | class TrainDataset(torch.utils.data.Dataset): 37 | def __init__(self, 38 | sliding_dir, 39 | it_path, 40 | visual_dim, 41 | sentence_embed_dim, 42 | IoU=0.5, 43 | nIoU=0.15, 44 | context_num=1, 45 | context_size=128, 46 | ): 47 | self.sliding_dir = sliding_dir 48 | self.it_path = it_path 49 | self.visual_dim = visual_dim 50 | self.sentence_embed_dim = sentence_embed_dim 51 | self.IoU = IoU 52 | self.nIoU = nIoU 53 | self.context_num = context_num 54 | self.context_size = context_size 55 | 56 | self.load_data() 57 | 58 | def load_data(self): 59 | ''' 60 | Note: 61 | self.clip_sentence_pairs : list of (ori_clip_name, sent_vec) 62 | self.clip_sentence_pairs_iou : list of (ori_clip_name, sent_vec, clip_name(with ".npy"), s_o, e_o) —— not all ground truth 63 | ''' 64 | # movie_length_info = pickle.load(open("./video_allframes_info.pkl", 'rb'), encoding='iso-8859-1') 65 | print("Reading training data list from " + self.it_path) 66 | csv = pickle.load(open(self.it_path, 'rb'), encoding='iso-8859-1') 67 | self.clip_sentence_pairs = [] 68 | for l in csv: 69 | clip_name = l[0] 70 | sent_vecs = l[1] 71 | for sent_vec in sent_vecs: 72 | self.clip_sentence_pairs.append((clip_name, sent_vec)) 73 | 74 | movie_names_set = set() 75 | self.movie_clip_names = {} 76 | # read groundtruth sentence-clip pairs 77 | for k in range(len(self.clip_sentence_pairs)): 78 | clip_name = self.clip_sentence_pairs[k][0] 79 | movie_name = clip_name.split("_")[0] 80 | if not movie_name in movie_names_set: 81 | movie_names_set.add(movie_name) 82 | self.movie_clip_names[movie_name] = [] 83 | self.movie_clip_names[movie_name].append(k) 84 | self.movie_names = list(movie_names_set) 85 | self.num_samples = len(self.clip_sentence_pairs) 86 | print(str(len(self.clip_sentence_pairs))+" clip-sentence pairs are readed") 87 | 88 | # read sliding windows, and match them with the groundtruths to make training samples 89 | sliding_clips_tmp = os.listdir(self.sliding_dir) 90 | self.clip_sentence_pairs_iou = [] 91 | for clip_name in sliding_clips_tmp: 92 | if clip_name.split(".")[2]=="npy": 93 | movie_name = clip_name.split("_")[0] 94 | for clip_sentence in self.clip_sentence_pairs: 95 | original_clip_name = clip_sentence[0] 96 | original_movie_name = original_clip_name.split("_")[0] 97 | if original_movie_name == movie_name: 98 | start = int(clip_name.split("_")[1]) 99 | end = int(clip_name.split("_")[2].split(".")[0]) 100 | o_start = int(original_clip_name.split("_")[1]) 101 | o_end = int(original_clip_name.split("_")[2].split(".")[0]) 102 | iou = calc_IoU((start, end), (o_start, o_end)) 103 | if iou > self.IoU: 104 | nIoL = calc_nIoL((o_start, o_end), (start, end)) 105 | if nIoL < self.nIoU: 106 | # movie_length = movie_length_info[movie_name.split(".")[0]] 107 | start_offset = o_start - start 108 | end_offset = o_end - end 109 | self.clip_sentence_pairs_iou.append((clip_sentence[0], clip_sentence[1], clip_name, start_offset, end_offset)) 110 | self.num_samples_iou = 
len(self.clip_sentence_pairs_iou) 111 | print(str(len(self.clip_sentence_pairs_iou))+" iou clip-sentence pairs are readed") 112 | 113 | def __len__(self): 114 | return self.num_samples_iou 115 | 116 | def __getitem__(self, index): 117 | # read context features 118 | left_context_feat, right_context_feat = self.get_context_window(self.clip_sentence_pairs_iou[index][2]) 119 | feat_path = os.path.join(self.sliding_dir, self.clip_sentence_pairs_iou[index][2]) 120 | featmap = np.load(feat_path) 121 | vis = np.hstack((left_context_feat, featmap, right_context_feat)) 122 | 123 | sent = self.clip_sentence_pairs_iou[index][1][:self.sentence_embed_dim] 124 | 125 | p_offset = self.clip_sentence_pairs_iou[index][3] 126 | l_offset = self.clip_sentence_pairs_iou[index][4] 127 | offset = np.array([p_offset, l_offset], dtype=np.float32) 128 | 129 | data_torch = { 130 | 'vis' : torch.from_numpy(vis), 131 | 'sent' : torch.from_numpy(sent), 132 | 'offset': torch.from_numpy(offset), 133 | } 134 | return data_torch 135 | 136 | 137 | def get_context_window(self, clip_name): 138 | movie_name = clip_name.split("_")[0] 139 | start = int(clip_name.split("_")[1]) 140 | end = int(clip_name.split("_")[2].split(".")[0]) 141 | left_context_feats = np.zeros([self.context_num, self.visual_dim // 3], dtype=np.float32) 142 | right_context_feats = np.zeros([self.context_num, self.visual_dim // 3], dtype=np.float32) 143 | last_left_feat = np.load(os.path.join(self.sliding_dir, clip_name)) 144 | last_right_feat = np.load(os.path.join(self.sliding_dir, clip_name)) 145 | for k in range(self.context_num): 146 | left_context_start = start - self.context_size * (k + 1) 147 | left_context_end = start - self.context_size * k 148 | right_context_start = end + self.context_size * k 149 | right_context_end = end + self.context_size * (k + 1) 150 | left_context_name = movie_name + "_" + str(left_context_start) + "_" + str(left_context_end) + ".npy" 151 | right_context_name = movie_name + "_" + str(right_context_start) + "_" + str(right_context_end) + ".npy" 152 | 153 | left_context_path = os.path.join(self.sliding_dir, left_context_name) 154 | if os.path.exists(left_context_path): 155 | left_context_feat = np.load(left_context_path) 156 | last_left_feat = left_context_feat 157 | else: 158 | left_context_feat = last_left_feat 159 | 160 | right_context_path = os.path.join(self.sliding_dir, right_context_name) 161 | if os.path.exists(right_context_path): 162 | right_context_feat = np.load(right_context_path) 163 | last_right_feat = right_context_feat 164 | else: 165 | right_context_feat = last_right_feat 166 | 167 | left_context_feats[k] = left_context_feat 168 | right_context_feats[k] = right_context_feat 169 | return np.mean(left_context_feats, axis=0), np.mean(right_context_feats, axis=0) 170 | 171 | 172 | 173 | class TestingDataSet(object): 174 | def __init__(self, img_dir, csv_path, batch_size): 175 | #il_path: image_label_file path 176 | #self.index_in_epoch = 0 177 | #self.epochs_completed = 0 178 | self.batch_size = batch_size 179 | self.image_dir = img_dir 180 | print("Reading testing data list from "+csv_path) 181 | self.semantic_size = 4800 182 | csv = pickle.load(open(csv_path, 'rb'), encoding='iso-8859-1') 183 | self.clip_sentence_pairs = [] 184 | for l in csv: 185 | clip_name = l[0] 186 | sent_vecs = l[1] 187 | for sent_vec in sent_vecs: 188 | self.clip_sentence_pairs.append((clip_name, sent_vec)) 189 | print(str(len(self.clip_sentence_pairs))+" pairs are readed") 190 | movie_names_set = set() 191 | self.movie_clip_names = {} 
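        # The loop below groups clip-sentence pair indices by movie name,
        # mirroring the bookkeeping done in TrainDataset.load_data().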
192 | for k in range(len(self.clip_sentence_pairs)): 193 | clip_name = self.clip_sentence_pairs[k][0] 194 | movie_name = clip_name.split("_")[0] 195 | if not movie_name in movie_names_set: 196 | movie_names_set.add(movie_name) 197 | self.movie_clip_names[movie_name] = [] 198 | self.movie_clip_names[movie_name].append(k) 199 | self.movie_names = list(movie_names_set) 200 | 201 | self.clip_num_per_movie_max = 0 202 | for movie_name in self.movie_clip_names: 203 | if len(self.movie_clip_names[movie_name])>self.clip_num_per_movie_max: self.clip_num_per_movie_max = len(self.movie_clip_names[movie_name]) 204 | print("Max number of clips in a movie is "+str(self.clip_num_per_movie_max)) 205 | 206 | self.sliding_clip_path = img_dir 207 | sliding_clips_tmp = os.listdir(self.sliding_clip_path) 208 | self.sliding_clip_names = [] 209 | for clip_name in sliding_clips_tmp: 210 | if clip_name.split(".")[2]=="npy": 211 | movie_name = clip_name.split("_")[0] 212 | if movie_name in self.movie_clip_names: 213 | self.sliding_clip_names.append(clip_name.split(".")[0]+"."+clip_name.split(".")[1]) 214 | self.num_samples = len(self.clip_sentence_pairs) 215 | print("sliding clips number: "+str(len(self.sliding_clip_names))) 216 | assert self.batch_size <= self.num_samples 217 | 218 | 219 | def get_clip_sample(self, sample_num, movie_name, clip_name): 220 | length=len(os.listdir(self.image_dir+movie_name+"/"+clip_name)) 221 | sample_step=1.0*length/sample_num 222 | sample_pos=np.floor(sample_step*np.array(range(sample_num))) 223 | sample_pos_str=[] 224 | img_names=os.listdir(self.image_dir+movie_name+"/"+clip_name) 225 | # sort is very important! to get a correct sequence order 226 | img_names.sort() 227 | # print img_names 228 | for pos in sample_pos: 229 | sample_pos_str.append(self.image_dir+movie_name+"/"+clip_name+"/"+img_names[int(pos)]) 230 | return sample_pos_str 231 | 232 | def get_context_window(self, clip_name, win_length): 233 | movie_name = clip_name.split("_")[0] 234 | start = int(clip_name.split("_")[1]) 235 | end = int(clip_name.split("_")[2].split(".")[0]) 236 | clip_length = 128#end-start 237 | left_context_feats = np.zeros([win_length,4096], dtype=np.float32) 238 | right_context_feats = np.zeros([win_length,4096], dtype=np.float32) 239 | last_left_feat = np.load(self.sliding_clip_path+clip_name) 240 | last_right_feat = np.load(self.sliding_clip_path+clip_name) 241 | for k in range(win_length): 242 | left_context_start = start-clip_length*(k+1) 243 | left_context_end = start-clip_length*k 244 | right_context_start = end+clip_length*k 245 | right_context_end = end+clip_length*(k+1) 246 | left_context_name = movie_name+"_"+str(left_context_start)+"_"+str(left_context_end)+".npy" 247 | right_context_name = movie_name+"_"+str(right_context_start)+"_"+str(right_context_end)+".npy" 248 | if os.path.exists(self.sliding_clip_path+left_context_name): 249 | left_context_feat = np.load(self.sliding_clip_path+left_context_name) 250 | last_left_feat = left_context_feat 251 | else: 252 | left_context_feat = last_left_feat 253 | if os.path.exists(self.sliding_clip_path+right_context_name): 254 | right_context_feat = np.load(self.sliding_clip_path+right_context_name) 255 | last_right_feat = right_context_feat 256 | else: 257 | right_context_feat = last_right_feat 258 | left_context_feats[k] = left_context_feat 259 | right_context_feats[k] = right_context_feat 260 | 261 | return np.mean(left_context_feats, axis=0), np.mean(right_context_feats, axis=0) 262 | 263 | 264 | def load_movie(self, movie_name): 265 | 
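        # Note: load_movie() references attributes (self.clip_names, self.sent_vecs,
        # self.sentences, self.movie_frames) and a load_image() helper that are never
        # defined in this class; it is unused code apparently carried over from the
        # original TALL repository. main.py only calls load_movie_slidingclip().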
movie_clip_sentences=[] 266 | for k in range(len(self.clip_names)): 267 | if movie_name in self.clip_names[k]: 268 | movie_clip_sentences.append((self.clip_names[k], self.sent_vecs[k][:2400], self.sentences[k])) 269 | 270 | movie_clip_imgs=[] 271 | for k in range(len(self.movie_frames[movie_name])): 272 | # print str(k)+"/"+str(len(self.movie_frames[movie_name])) 273 | if os.path.isfile(self.movie_frames[movie_name][k][1]) and os.path.getsize(self.movie_frames[movie_name][k][1])!=0: 274 | img=load_image(self.movie_frames[movie_name][k][1]) 275 | movie_clip_imgs.append((self.movie_frames[movie_name][k][0],img)) 276 | 277 | return movie_clip_imgs, movie_clip_sentences 278 | 279 | def load_movie_byclip(self,movie_name,sample_num): 280 | movie_clip_sentences=[] 281 | movie_clip_featmap=[] 282 | clip_set=set() 283 | for k in range(len(self.clip_sentence_pairs)): 284 | if movie_name in self.clip_sentence_pairs[k][0]: 285 | movie_clip_sentences.append((self.clip_sentence_pairs[k][0],self.clip_sentence_pairs[k][1][:self.semantic_size])) 286 | 287 | if not self.clip_sentence_pairs[k][0] in clip_set: 288 | clip_set.add(self.clip_sentence_pairs[k][0]) 289 | # print str(k)+"/"+str(len(self.movie_clip_names[movie_name])) 290 | visual_feature_path=self.image_dir+self.clip_sentence_pairs[k][0]+".npy" 291 | feature_data=np.load(visual_feature_path) 292 | movie_clip_featmap.append((self.clip_sentence_pairs[k][0],feature_data)) 293 | return movie_clip_featmap, movie_clip_sentences 294 | 295 | def load_movie_slidingclip(self, movie_name, sample_num): 296 | movie_clip_sentences = [] 297 | movie_clip_featmap = [] 298 | clip_set = set() 299 | for k in range(len(self.clip_sentence_pairs)): 300 | if movie_name in self.clip_sentence_pairs[k][0]: 301 | movie_clip_sentences.append((self.clip_sentence_pairs[k][0], self.clip_sentence_pairs[k][1][:self.semantic_size])) 302 | for k in range(len(self.sliding_clip_names)): 303 | if movie_name in self.sliding_clip_names[k]: 304 | # print str(k)+"/"+str(len(self.movie_clip_names[movie_name])) 305 | visual_feature_path = self.sliding_clip_path+self.sliding_clip_names[k]+".npy" 306 | #context_feat=self.get_context(self.sliding_clip_names[k]+".npy") 307 | left_context_feat,right_context_feat = self.get_context_window(self.sliding_clip_names[k]+".npy",1) 308 | feature_data = np.load(visual_feature_path) 309 | #comb_feat=np.hstack((context_feat,feature_data)) 310 | comb_feat = np.hstack((left_context_feat,feature_data,right_context_feat)) 311 | movie_clip_featmap.append((self.sliding_clip_names[k], comb_feat)) 312 | return movie_clip_featmap, movie_clip_sentences 313 | 314 | -------------------------------------------------------------------------------- /exp_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iworldtong/TALL.pytorch/fea4bea203f6d52453d8cc56588e34042729e2b6/exp_data/.DS_Store -------------------------------------------------------------------------------- /exp_data/.gitkeep.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/iworldtong/TALL.pytorch/fea4bea203f6d52453d8cc56588e34042729e2b6/exp_data/.gitkeep.txt -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import time 2 | import pickle 3 | import numpy as np 4 | from collections import OrderedDict 5 | 6 | import torch 7 | import torch.nn as 
nn 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | 11 | 12 | from ctrl import * 13 | from utils import * 14 | from dataset import * 15 | 16 | from config import CONFIG 17 | cfg = CONFIG() 18 | 19 | 20 | class Processor(): 21 | def __init__(self): 22 | self.load_data() 23 | self.load_model() 24 | self.load_optimizer() 25 | 26 | def load_data(self): 27 | self.data_loader = dict() 28 | if cfg.phase == 'train': 29 | self.data_loader['train'] = torch.utils.data.DataLoader( 30 | dataset=TrainDataset(cfg.train_feature_dir, 31 | cfg.train_csv_path, 32 | cfg.visual_dim, 33 | cfg.sentence_embed_dim, 34 | cfg.IoU, 35 | cfg.nIoU, 36 | cfg.context_num, 37 | cfg.context_size, 38 | ), 39 | batch_size=cfg.batch_size, 40 | shuffle=True, 41 | num_workers=cfg.num_worker) 42 | 43 | self.testDataset = TestingDataSet(cfg.test_feature_dir, cfg.test_csv_path, cfg.test_batch_size) 44 | # self.data_loader['test'] = torch.utils.data.DataLoader( 45 | # dataset=TestDataset(cfg.train_feature_dir, 46 | # cfg.train_csv_path, 47 | # cfg.visual_dim, 48 | # cfg.sentence_embed_dim, 49 | # cfg.IoU, 50 | # cfg.nIoU, 51 | # cfg.context_num, 52 | # cfg.context_size, 53 | # ), 54 | # batch_size=cfg.test_batch_size, 55 | # shuffle=False, 56 | # num_workers=cfg.num_worker) 57 | 58 | def load_model(self): 59 | torch.manual_seed(cfg.seed) 60 | if torch.cuda.is_available(): 61 | if type(cfg.device) is list and len(cfg.device) > 1: 62 | torch.cuda.manual_seed_all(cfg.seed) 63 | else: 64 | torch.cuda.manual_seed(cfg.seed) 65 | 66 | self.output_device = cfg.device[0] if type(cfg.device) is list else cfg.device 67 | 68 | self.model = CTRL(cfg.visual_dim, cfg.sentence_embed_dim, cfg.semantic_dim, cfg.middle_layer_dim) 69 | self.loss = CTRL_loss(cfg.lambda_reg) 70 | if torch.cuda.is_available(): 71 | self.model.cuda(self.output_device) 72 | self.loss.cuda(self.output_device) 73 | 74 | if torch.cuda.is_available() and type(cfg.device) is list and len(cfg.device) > 1: 75 | self.model = nn.DataParallel(self.model, device_ids=cfg.device, output_device=self.output_device) 76 | 77 | def load_optimizer(self): 78 | if cfg.optimizer == 'Adam': 79 | self.optimizer = optim.Adam( 80 | self.model.parameters(), 81 | lr=cfg.vs_lr, 82 | weight_decay=cfg.weight_decay, 83 | ) 84 | else: 85 | raise ValueError() 86 | 87 | 88 | def train(self): 89 | losses = [] 90 | for epoch in range(cfg.max_epoch): 91 | for step, data_torch in enumerate(self.data_loader['train']): 92 | self.model.train() 93 | self.record_time() 94 | 95 | # forward 96 | output = self.model(data_torch['vis'], data_torch['sent']) 97 | loss = self.loss(output, data_torch['offset']) 98 | 99 | # backward 100 | self.optimizer.zero_grad() 101 | loss.backward() 102 | self.optimizer.step() 103 | losses.append(loss.item()) 104 | 105 | duration = self.split_time() 106 | 107 | if (step+1) % 5 == 0 or step == 0: 108 | self.print_log('Epoch %d, Step %d: loss = %.3f (%.3f sec)' % (epoch+1, step+1, losses[-1], duration)) 109 | 110 | if (step+1) % 2000 == 0: 111 | self.print_log('Testing:') 112 | movie_length_info = pickle.load(open(cfg.movie_length_info_path, 'rb'), encoding='iso-8859-1') 113 | self.eval(movie_length_info, step + 1, cfg.test_output_path) 114 | 115 | 116 | 117 | def eval(self, movie_length_info, step, test_output_path): 118 | self.model.eval() 119 | IoU_thresh = [0.1, 0.2, 0.3, 0.4, 0.5] 120 | all_correct_num_10 = [0.0] * 5 121 | all_correct_num_5 = [0.0] * 5 122 | all_correct_num_1 = [0.0] * 5 123 | all_retrievd = 0.0 124 | 125 | for movie_name in 
self.testDataset.movie_names:
126 |             movie_length = movie_length_info[movie_name.split(".")[0]]
127 |             self.print_log("Test movie: " + movie_name + "....loading movie data")
128 |             movie_clip_featmaps, movie_clip_sentences = self.testDataset.load_movie_slidingclip(movie_name, 16)
129 |             self.print_log("sentences: " + str(len(movie_clip_sentences)))
130 |             self.print_log("clips: " + str(len(movie_clip_featmaps)))
131 |             sentence_image_mat = np.zeros([len(movie_clip_sentences), len(movie_clip_featmaps)])
132 |             sentence_image_reg_mat = np.zeros([len(movie_clip_sentences), len(movie_clip_featmaps), 2])
133 |             for k in range(len(movie_clip_sentences)):
134 |                 sent_vec = movie_clip_sentences[k][1]
135 |                 sent_vec = np.reshape(sent_vec, [1, sent_vec.shape[0]])
136 |                 for t in range(len(movie_clip_featmaps)):
137 |                     featmap = movie_clip_featmaps[t][1]
138 |                     visual_clip_name = movie_clip_featmaps[t][0]
139 |                     start = float(visual_clip_name.split("_")[1])
140 |                     end = float(visual_clip_name.split("_")[2].split("_")[0])
141 |                     featmap = np.reshape(featmap, [1, featmap.shape[0]])
142 | 
143 |                     output = self.model(torch.from_numpy(featmap), torch.from_numpy(sent_vec))
144 |                     output_np = output.detach().numpy()[0][0]
145 | 
146 |                     sentence_image_mat[k, t] = output_np[0]
147 |                     reg_clip_length = (end - start) * (10 ** output_np[2])    # computed but not used below
148 |                     reg_mid_point = (start + end) / 2.0 + movie_length * output_np[1]    # computed but not used below
149 |                     reg_end = end + output_np[2]
150 |                     reg_start = start + output_np[1]
151 | 
152 |                     sentence_image_reg_mat[k, t, 0] = reg_start
153 |                     sentence_image_reg_mat[k, t, 1] = reg_end
154 | 
155 |             iclips = [b[0] for b in movie_clip_featmaps]
156 |             sclips = [b[0] for b in movie_clip_sentences]
157 | 
158 |             # calculate Recall@m, IoU=n
159 |             for k in range(len(IoU_thresh)):
160 |                 IoU = IoU_thresh[k]
161 |                 correct_num_10 = compute_IoU_recall_top_n_forreg(10, IoU, sentence_image_mat, sentence_image_reg_mat, sclips, iclips)
162 |                 correct_num_5 = compute_IoU_recall_top_n_forreg(5, IoU, sentence_image_mat, sentence_image_reg_mat, sclips, iclips)
163 |                 correct_num_1 = compute_IoU_recall_top_n_forreg(1, IoU, sentence_image_mat, sentence_image_reg_mat, sclips, iclips)
164 |                 self.print_log(movie_name+" IoU="+str(IoU)+", R@10: "+str(correct_num_10/len(sclips))+"; IoU="+str(IoU)+", R@5: "+str(correct_num_5/len(sclips))+"; IoU="+str(IoU)+", R@1: "+str(correct_num_1/len(sclips)))
165 |                 all_correct_num_10[k] += correct_num_10
166 |                 all_correct_num_5[k] += correct_num_5
167 |                 all_correct_num_1[k] += correct_num_1
168 |             all_retrievd += len(sclips)
169 | 
170 |         for k in range(len(IoU_thresh)):
171 |             self.print_log("IoU="+str(IoU_thresh[k])+", R@10: "+str(all_correct_num_10[k]/all_retrievd)+"; IoU="+str(IoU_thresh[k])+", R@5: "+str(all_correct_num_5[k]/all_retrievd)+"; IoU="+str(IoU_thresh[k])+", R@1: "+str(all_correct_num_1[k]/all_retrievd))
172 |             with open(test_output_path, "a") as f:    # append so the results for every IoU threshold are kept
173 |                 f.write("Step "+str(step)+": IoU="+str(IoU_thresh[k])+", R@10: "+str(all_correct_num_10[k]/all_retrievd)+"; IoU="+str(IoU_thresh[k])+", R@5: "+str(all_correct_num_5[k]/all_retrievd)+"; IoU="+str(IoU_thresh[k])+", R@1: "+str(all_correct_num_1[k]/all_retrievd)+"\n")
174 | 
175 | 
176 |     def record_time(self):
177 |         self.cur_time = time.time()
178 |         return self.cur_time
179 | 
180 |     def split_time(self):
181 |         split_time = time.time() - self.cur_time
182 |         self.record_time()
183 |         return split_time
184 | 
185 |     def print_log(self, line, print_time=True):
186 |         if print_time:
187 |             localtime = time.asctime(time.localtime(time.time()))
188 |             line = "[ " + localtime + ' ] ' + line
189 |         print(line)
190 |         if cfg.save_log:
191 |             with open(cfg.log_path, 'a') as f:
192 |                 print(line, file=f)
193 | 
194 |     def print_time(self):
195 |         localtime = time.asctime(time.localtime(time.time()))
196 |         self.print_log("Local current time : " + localtime)
197 | 
198 | 
199 | if __name__ == '__main__':
200 |     processor = Processor()
201 |     processor.train()
202 | 
203 | 
204 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from sklearn.metrics import average_precision_score
3 | import numpy as np
4 | import operator
5 | 
6 | 
7 | 
8 | def dense_to_one_hot(labels_dense, num_classes):
9 |     """Convert class labels from scalars to one-hot vectors."""
10 |     num_labels = labels_dense.shape[0]
11 |     index_offset = np.arange(num_labels) * num_classes
12 |     labels_one_hot = np.zeros((num_labels, num_classes))
13 |     labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
14 |     return labels_one_hot
15 | 
16 | def compute_ap(class_score_matrix, labels):
17 |     num_classes = class_score_matrix.shape[1]
18 |     one_hot_labels = dense_to_one_hot(labels, num_classes)
19 |     predictions = np.array(class_score_matrix > 0, dtype="int32")
20 |     average_precision = []
21 |     for i in range(num_classes):
22 |         ps = average_precision_score(one_hot_labels[:, i], class_score_matrix[:, i])
23 |         # if not np.isnan(ps):
24 |         average_precision.append(ps)
25 |     return np.array(average_precision)
26 | 
27 | def calculate_IoU(i0, i1):
28 |     union = (min(i0[0], i1[0]), max(i0[1], i1[1]))
29 |     inter = (max(i0[0], i1[0]), min(i0[1], i1[1]))
30 |     iou = 1.0 * (inter[1] - inter[0]) / (union[1] - union[0])
31 |     return iou
32 | 
33 | def nms_temporal(x1, x2, s, overlap):
34 |     pick = []
35 |     assert len(x1) == len(s)
36 |     assert len(x2) == len(s)
37 |     if len(x1) == 0:
38 |         return pick
39 | 
40 |     union = list(map(operator.sub, x2, x1))  # union = x2-x1
41 |     I = [i[0] for i in sorted(enumerate(s), key=lambda x: x[1])]  # sort and get index
42 | 
43 |     while len(I) > 0:
44 |         i = I[-1]
45 |         pick.append(i)
46 | 
47 |         xx1 = [max(x1[i], x1[j]) for j in I[:-1]]
48 |         xx2 = [min(x2[i], x2[j]) for j in I[:-1]]
49 |         inter = [max(0.0, k2 - k1) for k1, k2 in zip(xx1, xx2)]
50 |         o = [inter[u] / (union[i] + union[I[u]] - inter[u]) for u in range(len(I) - 1)]
51 |         I_new = []
52 |         for j in range(len(o)):
53 |             if o[j] <= overlap:
54 |                 I_new.append(I[j])
55 |         I = I_new
56 |     return pick
57 | 
58 | '''
59 | compute recall at certain IoU
60 | '''
61 | def compute_IoU_recall_top_n_forreg(top_n, iou_thresh, sentence_image_mat, sentence_image_reg_mat, sclips, iclips):
62 |     correct_num = 0.0
63 |     for k in range(sentence_image_mat.shape[0]):
64 |         gt = sclips[k]
65 |         gt_start = float(gt.split("_")[1])
66 |         gt_end = float(gt.split("_")[2])
67 |         #print gt +" "+str(gt_start)+" "+str(gt_end)
68 |         sim_v = [v for v in sentence_image_mat[k]]
69 |         starts = [s for s in sentence_image_reg_mat[k, :, 0]]
70 |         ends = [e for e in sentence_image_reg_mat[k, :, 1]]
71 |         picks = nms_temporal(starts, ends, sim_v, iou_thresh - 0.05)
72 |         #sim_argsort=np.argsort(sim_v)[::-1][0:top_n]
73 |         if top_n < len(picks): picks = picks[0:top_n]
74 |         for index in picks:
75 |             pred_start = starts[index]
76 |             pred_end = ends[index]
77 |             iou = calculate_IoU((gt_start, gt_end), (pred_start, pred_end))
78 |             if iou >= iou_thresh:
79 |                 correct_num += 1
80 |                 break
81 |     return correct_num
82 | 
83 | 
84 | import datetime
85 | def print_log(msg='', end='\n'):
86 |     now = datetime.datetime.now()
87 |     t = str(now.year) + '/' + str(now.month) + '/' + str(now.day) + ' ' \
88 |         + str(now.hour).zfill(2) + ':' + str(now.minute).zfill(2) + ':' + str(now.second).zfill(2)
89 | 
90 |     if isinstance(msg, str):
91 |         lines = msg.split('\n')
92 |     else:
93 | 
lines = [msg] 94 | 95 | for line in lines: 96 | if line == lines[-1]: 97 | print('[' + t + '] ' + str(line), end=end) 98 | else: 99 | print('[' + t + '] ' + str(line)) 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | -------------------------------------------------------------------------------- /video_allframes_info.pkl: -------------------------------------------------------------------------------- 1 | (dp0 2 | S's30-d43' 3 | p1 4 | I19807 5 | sS's30-d40' 6 | p2 7 | I4911 8 | sS's30-d41' 9 | p3 10 | I19163 11 | sS's22-d55' 12 | p4 13 | I1530 14 | sS's25-d52' 15 | p5 16 | I5676 17 | sS's25-d51' 18 | p6 19 | I3409 20 | sS's28-d39' 21 | p7 22 | I22908 23 | sS's35-d55' 24 | p8 25 | I6758 26 | sS's24-d41' 27 | p9 28 | I17753 29 | sS's24-d40' 30 | p10 31 | I6322 32 | sS's24-d48' 33 | p11 34 | I2397 35 | sS's21-d29' 36 | p12 37 | I1881 38 | sS's21-d39' 39 | p13 40 | I3774 41 | sS's22-d29' 42 | p14 43 | I3253 44 | sS's33-d27' 45 | p15 46 | I20003 47 | sS's22-d46' 48 | p16 49 | I14844 50 | sS's32-d70' 51 | p17 52 | I7357 53 | sS's13-d52' 54 | p18 55 | I2767 56 | sS's22-d48' 57 | p19 58 | I2467 59 | sS's13-d54' 60 | p20 61 | I8199 62 | sS's37-d39' 63 | p21 64 | I9169 65 | sS's25-d23' 66 | p22 67 | I3808 68 | sS's27-d70' 69 | p23 70 | I8218 71 | sS's36-d50' 72 | p24 73 | I5219 74 | sS's26-d70' 75 | p25 76 | I5847 77 | sS's22-d53' 78 | p26 79 | I6013 80 | sS's21-d23' 81 | p27 82 | I3452 83 | sS's21-d21' 84 | p28 85 | I4450 86 | sS's24-d53' 87 | p29 88 | I5237 89 | sS's21-d28' 90 | p30 91 | I3693 92 | sS's32-d69' 93 | p31 94 | I12409 95 | sS's35-d48' 96 | p32 97 | I7628 98 | sS's31-d28' 99 | p33 100 | I5841 101 | sS's31-d25' 102 | p34 103 | I2650 104 | sS's22-d25' 105 | p35 106 | I1621 107 | sS's29-d42' 108 | p36 109 | I22198 110 | sS's35-d40' 111 | p37 112 | I13864 113 | sS's23-d31' 114 | p38 115 | I3216 116 | sS's17-d48' 117 | p39 118 | I2273 119 | sS's36-d42' 120 | p40 121 | I18133 122 | sS's34-d34' 123 | p41 124 | I17675 125 | sS's35-d41' 126 | p42 127 | I27306 128 | sS's34-d41' 129 | p43 130 | I15419 131 | sS's37-d46' 132 | p44 133 | I13677 134 | sS's21-d53' 135 | p45 136 | I4683 137 | sS's30-d26' 138 | p46 139 | I17130 140 | sS's24-d28' 141 | p47 142 | I7230 143 | sS's21-d55' 144 | p48 145 | I2133 146 | sS's14-d35' 147 | p49 148 | I2421 149 | sS's30-d29' 150 | p50 151 | I6115 152 | sS's36-d70' 153 | p51 154 | I7055 155 | sS's23-d51' 156 | p52 157 | I6830 158 | sS's31-d31' 159 | p53 160 | I6273 161 | sS's23-d54' 162 | p54 163 | I9474 164 | sS's24-d34' 165 | p55 166 | I5430 167 | sS's17-d55' 168 | p56 169 | I2335 170 | sS's27-d50' 171 | p57 172 | I2430 173 | sS's17-d53' 174 | p58 175 | I7031 176 | sS's22-d43' 177 | p59 178 | I3315 179 | sS's25-d69' 180 | p60 181 | I14351 182 | sS's27-d34' 183 | p61 184 | I1772 185 | sS's23-d46' 186 | p62 187 | I10819 188 | sS's21-d42' 189 | p63 190 | I6090 191 | sS's21-d43' 192 | p64 193 | I4033 194 | sS's23-d42' 195 | p65 196 | I12416 197 | sS's21-d45' 198 | p66 199 | I2866 200 | sS's14-d26' 201 | p67 202 | I12483 203 | sS's14-d27' 204 | p68 205 | I6134 206 | sS's21-d40' 207 | p69 208 | I3238 209 | sS's26-d26' 210 | p70 211 | I41240 212 | sS's26-d23' 213 | p71 214 | I2842 215 | sS's17-d42' 216 | p72 217 | I15601 218 | sS's34-d28' 219 | p73 220 | I11816 221 | sS's23-d45' 222 | p74 223 | I4184 224 | sS's29-d52' 225 | p75 226 | I5871 227 | sS's15-d70' 228 | p76 229 | I8315 230 | sS's27-d45' 231 | p77 232 | I4757 233 | sS's29-d50' 234 | p78 235 | I2847 236 | sS's27-d29' 237 | p79 238 | I8024 239 | sS's21-d35' 240 | p80 241 | I1875 242 | sS's27-d21' 243 | p81 
244 | I4408 245 | sS's37-d25' 246 | p82 247 | I1436 248 | sS's37-d21' 249 | p83 250 | I11130 251 | sS's23-d34' 252 | p84 253 | I5231 254 | sS's23-d39' 255 | p85 256 | I5444 257 | sS's37-d29' 258 | p86 259 | I3012 260 | sS's27-d54' 261 | p87 262 | I10969 263 | sS's14-d51' 264 | p88 265 | I8323 266 | sS's28-d46' 267 | p89 268 | I11162 269 | sS's23-d21' 270 | p90 271 | I4866 272 | sS's13-d48' 273 | p91 274 | I3228 275 | sS's32-d27' 276 | p92 277 | I22541 278 | sS's13-d21' 279 | p93 280 | I2955 281 | sS's13-d25' 282 | p94 283 | I2951 284 | sS's13-d28' 285 | p95 286 | I5629 287 | sS's14-d46' 288 | p96 289 | I9249 290 | sS's14-d43' 291 | p97 292 | I4667 293 | sS's22-d34' 294 | p98 295 | I2401 296 | sS's21-d50' 297 | p99 298 | I1607 299 | sS's24-d23' 300 | p100 301 | I5735 302 | sS's29-d31' 303 | p101 304 | I3164 305 | sS's17-d69' 306 | p102 307 | I12454 308 | sS's15-d26' 309 | p103 310 | I33264 311 | sS's29-d39' 312 | p104 313 | I9630 314 | sS's32-d55' 315 | p105 316 | I5689 317 | sS's32-d52' 318 | p106 319 | I5508 320 | sS's13-d31' 321 | p107 322 | I5093 323 | sS's36-d31' 324 | p108 325 | I7071 326 | sS's36-d43' 327 | p109 328 | I26924 329 | sS's33-d45' 330 | p110 331 | I7698 332 | sS's22-d26' 333 | p111 334 | I9365 335 | sS's28-d27' 336 | p112 337 | I17235 338 | sS's28-d25' 339 | p113 340 | I3220 341 | sS's30-d53' 342 | p114 343 | I13379 344 | sS's30-d52' 345 | p115 346 | I7362 347 | sS's33-d54' 348 | p116 349 | I10898 350 | sS's22-d35' 351 | p117 352 | I2555 353 | sS's15-d35' 354 | p118 355 | I6472 356 | sS's33-d50' 357 | p119 358 | I3466 359 | sS's13-d40' 360 | p120 361 | I3347 362 | sS's13-d45' 363 | p121 364 | I4354 365 | sS's36-d23' 366 | p122 367 | I19412 368 | sS's25-d35' 369 | p123 370 | I3900 371 | sS's28-d51' 372 | p124 373 | I19725 374 | sS's26-d69' 375 | p125 376 | I20978 377 | sS's36-d27' 378 | p126 379 | I4390 380 | sS's34-d69' 381 | p127 382 | I12952 383 | s. --------------------------------------------------------------------------------
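A minimal sketch (editorial addition) of how video_allframes_info.pkl can be read, assuming it sits in the repository root as configured in config.py. It is a protocol-0 pickle, written from Python 2, which is why main.py and dataset.py pass encoding='iso-8859-1' when loading .pkl files; it maps each TACoS movie id to its total frame count, and eval() looks this value up per movie:

    import pickle

    with open("./video_allframes_info.pkl", "rb") as f:
        frame_counts = pickle.load(f, encoding="iso-8859-1")

    print(len(frame_counts))        # number of TACoS videos covered
    print(frame_counts["s30-d43"])  # e.g. 19807 frames for movie s30-d43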