├── .gitignore
├── LICENSE
├── README.md
├── config.py
├── data
│   └── thumos_14_annotations
│       ├── Test_Annotation.csv
│       ├── Val_Annotation.csv
│       └── thumos14_test_groundtruth.csv
├── dataset.py
├── inference.py
├── loss_function.py
├── model.py
├── prior_box.py
├── train.py
└── utils.py

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 rhee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# SSAD_pytorch
PyTorch code for "Single Shot Temporal Action Detection".

Download the features from [THUMOS14-ANet-features](https://drive.google.com/file/d/1gCNYPf6Fxeht1HO3eIzuyj84gtbkPETx/view?usp=drive_open).

Note: this PyTorch implementation does not yet reproduce the reported results; there is most likely a bug in the code, and a fix is planned.

For reference, [SSAD_tensorflow](https://github.com/Rheelt/SSAD_tensorflow) reaches mAP@0.5 = 37.8.

Disclaimer: this is reproduced code, not the original code of the paper.

Ref: [Decouple-SSAD](https://github.com/HYPJUDY/Decouple-SSAD)
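A minimal way to drive the pipeline (a sketch, assuming the features are unpacked to `~/THUMOS14_ANET_feature/` as `config.py` expects):

```python
from config import Config
from train import main           # trains on the THUMOS14 validation set
from inference import inference  # writes detections to ./results.txt

config = Config()
main(config)        # checkpoints are saved under ./checkpoint/
inference(config)   # run after training finishes
```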
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
class Config(object):
    """
    Stores all model and training hyper-parameters.
    """

    def __init__(self):
        self.name = "SSAD"
        self.seed = 5
        self.feature_path = "~/THUMOS14_ANET_feature/"
        self.unit_size = 5
        self.feature_dim = 3072
        self.ioa_ratio_threshold = 0.9
        self.window_size = 128
        self.window_step = 64  # 50% overlap between consecutive training windows
        self.inference_window_step = 64  # 50% overlap
        self.num_classes = 21
        self.batch_size = 48

        # self.base_scale = {"AL1": 1. / 16, "AL2": 1. / 8, "AL3": 1. / 4}
        # self.num_cells = {"AL1": 16, "AL2": 8, "AL3": 4}
        # self.aspect_ratios = {"AL1": [0.5, 0.75, 1, 1.5, 2],
        #                       "AL2": [0.5, 0.75, 1, 1.5, 2],
        #                       "AL3": [0.5, 0.75, 1, 1.5, 2]}
        self.layer_names = ['AL1', 'AL2', 'AL3']
        self.base_scale = [1. / 16, 1. / 8, 1. / 4]
        self.num_cells = [16, 8, 4]
        self.aspect_ratios = [[0.5, 0.75, 1., 1.5, 2.],
                              [0.5, 0.75, 1., 1.5, 2.],
                              [0.5, 0.75, 1., 1.5, 2.]]

        self.num_anchors = 5
        self.training_lr = 0.0001
        self.weight_decay = 0.0
        self.checkpoint_path = "./checkpoint/"
        self.epoch = 35
        self.negative_ratio = 1.  # (hard + sampled) negatives per positive in the classification loss
        self.lr_scheduler_step = 30
        self.lr_scheduler_gamma = 0.1

        self.outdf_columns = ['xmin', 'xmax', 'conf', 'score_0', 'score_1', 'score_2',
                              'score_3', 'score_4', 'score_5', 'score_6', 'score_7', 'score_8',
                              'score_9', 'score_10', 'score_11', 'score_12', 'score_13', 'score_14',
                              'score_15', 'score_16', 'score_17', 'score_18', 'score_19', 'score_20']
        self.class_real = [7, 9, 12, 21, 22, 23, 24, 26, 31, 33,
                           36, 40, 45, 51, 68, 79, 85, 92, 93, 97]
        self.nms_threshold = 0.2
        # post-processing: keep only detections whose background score is below this threshold
        self.filter_neg_threshold = 0.7
        # post-processing: keep only detections whose predicted overlap (conf) is above this threshold
        self.filter_conf_threshold = 0.3

--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import pandas as pd
import torch
import torch.utils.data as data
from utils import ioa_with_anchors
from tqdm import tqdm
from config import Config


class THUMOSDataset(data.Dataset):

    def __init__(self, config, mode='Val'):
        self.feature_path = config.feature_path
        self.unit_size = config.unit_size
        self.feature_dim = config.feature_dim
        self.ioa_ratio_threshold = config.ioa_ratio_threshold
        self.window_size = config.window_size
        self.window_step = config.window_step
        self.num_classes = config.num_classes  # action categories + background; 21 for THUMOS14
        self.mode = mode
        self.anno_df = pd.read_csv("./data/thumos_14_annotations/" + mode + "_Annotation.csv")
        self.videoNameList = list(set(self.anno_df.video.values[:]))
        self.samples = []
        self.class_real = [0] + [7, 9, 12, 21, 22, 23, 24, 26, 31, 33,
                                 36, 40, 45, 51, 68, 79, 85, 92, 93, 97]  # THUMOS14 class label indices
        self._preparedata()
        print(
            'The {} set has {} videos and {} samples'.format(mode, len(self.videoNameList),
                                                             len(self.samples)))

    def _preparedata(self):
        print('Preparing data ...')
        for videoName in tqdm(self.videoNameList):
            video_annoDf = self.anno_df[self.anno_df.video == videoName]
            video_annoDf = video_annoDf[video_annoDf.type_idx != 0]  # 0 for Ambiguous

            gt_xmins = video_annoDf.startFrame.values[:]
            gt_xmaxs = video_annoDf.endFrame.values[:]
            gt_type_idx = video_annoDf.type_idx.values[:]

            rgb_feature, flow_feature = self._getVideoFeature(videoName, self.mode.lower())

            numSnippet = min(rgb_feature.shape[0], flow_feature.shape[0])
            frameList = [1 + self.unit_size * i for i in range(numSnippet)]
            df_data = np.concatenate((rgb_feature, flow_feature), axis=1)
            df_snippet = frameList
            window_size = self.window_size
            stride = self.window_step
            n_window = (numSnippet + stride - window_size) / stride
            windows_start = [i * stride for i in range(int(n_window))]
            if numSnippet < window_size:
                windows_start = [0]
                tmp_data = np.zeros((window_size - numSnippet, self.feature_dim))
                df_data = np.concatenate((df_data, tmp_data), axis=0)
                df_snippet.extend([df_snippet[-1] + self.unit_size * (i + 1) for i in range(window_size - numSnippet)])
            elif numSnippet - windows_start[-1] - window_size > 30:
                windows_start.append(numSnippet - window_size)
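            # NOTE (editor's example): with numSnippet = 400, window_size = 128 and stride = 64,
            # n_window = int((400 + 64 - 128) / 64) = 5 windows starting at 0, 64, ..., 256;
            # the 400 - (256 + 128) = 16 leftover snippets are fewer than 30, so no extra window is added.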
            snippet_xmin = df_snippet
            snippet_xmax = df_snippet[1:]
            snippet_xmax.append(df_snippet[-1] + self.unit_size)
            for start in windows_start:
                tmp_data = df_data[start:start + window_size, :]
                tmp_anchor_xmins = snippet_xmin[start:start + window_size]
                tmp_anchor_xmaxs = snippet_xmax[start:start + window_size]
                tmp_gt_bbox = []
                tmp_gt_class = []
                tmp_ioa_list = []
                for idx in range(len(gt_xmins)):
                    tmp_ioa = ioa_with_anchors(gt_xmins[idx], gt_xmaxs[idx], tmp_anchor_xmins[0], tmp_anchor_xmaxs[-1])
                    tmp_ioa_list.append(tmp_ioa)
                    if tmp_ioa > 0:
                        # gt bbox info, clipped to the window and normalized to [0, 1]
                        corrected_start = max(gt_xmins[idx], tmp_anchor_xmins[0]) - tmp_anchor_xmins[0]
                        corrected_end = min(gt_xmaxs[idx], tmp_anchor_xmaxs[-1]) - tmp_anchor_xmins[0]
                        tmp_gt_bbox.append([float(corrected_start) / (self.window_size * self.unit_size),
                                            float(corrected_end) / (self.window_size * self.unit_size)])
                        # gt class label
                        one_hot = [0] * self.num_classes
                        one_hot[self.class_real.index(gt_type_idx[idx])] = 1
                        tmp_gt_class.append(one_hot)
                if len(tmp_gt_bbox) > 0 and max(tmp_ioa_list) > self.ioa_ratio_threshold:
                    # keep the window only if at least one ground truth lies (almost) entirely inside it
                    tmp_results = [torch.transpose(torch.Tensor(tmp_data), 0, 1), np.array(tmp_gt_bbox),
                                   np.array(tmp_gt_class)]
                    self.samples.append(tmp_results)

    def _getVideoFeature(self, videoname, subset):
        # np.load does not expand '~', so expand it explicitly
        appearance_path = os.path.expanduser('~/THUMOS14_ANET_feature/{}_appearance/'.format(subset))
        denseflow_path = os.path.expanduser('~/THUMOS14_ANET_feature/{}_denseflow/'.format(subset))
        rgb_feature = np.load(appearance_path + videoname + '.npy')
        flow_feature = np.load(denseflow_path + videoname + '.npy')

        return rgb_feature, flow_feature

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


def train_collate_fn(batch):
    batch_start_index = [0]
    batch_gt_bbox = []
    batch_gt_class = []
    for iitem in batch:
        batch_start_index.append(batch_start_index[-1] + iitem[1].shape[0])
        batch_gt_bbox.append(iitem[1])
        batch_gt_class.append(iitem[2])
    batch_start_index = np.array(batch_start_index, dtype=np.int32)
    batch_data = torch.cat([x[0].unsqueeze(0) for x in batch])
    batch_gt_bbox = np.vstack(batch_gt_bbox).astype(np.float32)
    batch_gt_class = np.vstack(batch_gt_class).astype(np.int32)

    return batch_data, batch_gt_bbox, batch_gt_class, batch_start_index
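# NOTE (editor's note): train_collate_fn flattens the variable-length ground-truth lists
# of a batch. batch_start_index records where each sample's boxes begin, e.g. three
# windows with 2, 1 and 3 ground truths give batch_start_index = [0, 2, 3, 6], and
# sample i owns rows batch_start_index[i]:batch_start_index[i + 1] of
# batch_gt_bbox / batch_gt_class (this is exactly how build_target slices them).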
class THUMOSInferenceDataset(data.Dataset):

    def __init__(self, config):
        self.feature_path = config.feature_path
        self.unit_size = config.unit_size
        self.feature_dim = config.feature_dim
        self.window_size = config.window_size
        self.inference_window_step = config.inference_window_step
        self.mode = 'Test'
        self.anno_df = pd.read_csv("./data/thumos_14_annotations/" + self.mode + "_Annotation.csv")
        self.videoNameList = list(set(self.anno_df.video.values[:]))
        self.samples = []
        self._preparedata()
        print(
            'The {} set has {} videos and {} samples'.format(self.mode,
                                                             len(self.videoNameList),
                                                             len(self.samples)))

    def _preparedata(self):
        print('Preparing data ...')
        for videoName in tqdm(self.videoNameList):
            rgb_feature, flow_feature = self._getVideoFeature(videoName, self.mode.lower())

            numSnippet = min(rgb_feature.shape[0], flow_feature.shape[0])
            frameList = [1 + self.unit_size * i for i in range(numSnippet)]
            df_data = np.concatenate((rgb_feature, flow_feature), axis=1)
            df_snippet = frameList
            window_size = self.window_size
            stride = self.inference_window_step
            n_window = (numSnippet + stride - window_size) / stride
            windows_start = [i * stride for i in range(int(n_window))]
            if numSnippet < window_size:
                windows_start = [0]
                tmp_data = np.zeros((window_size - numSnippet, self.feature_dim))
                df_data = np.concatenate((df_data, tmp_data), axis=0)
                df_snippet.extend([df_snippet[-1] + self.unit_size * (i + 1) for i in range(window_size - numSnippet)])
            else:
                # always append a final window flush with the end of the video
                windows_start.append(numSnippet - window_size)

            snippet_xmin = df_snippet
            for start in windows_start:
                tmp_data = df_data[start:start + window_size, :]
                tmp_anchor_xmins = snippet_xmin[start:start + window_size]
                tmp_results = [torch.transpose(torch.Tensor(tmp_data), 0, 1), videoName, tmp_anchor_xmins[0]]
                self.samples.append(tmp_results)

    def _getVideoFeature(self, videoname, subset):
        # np.load does not expand '~', so expand it explicitly
        appearance_path = os.path.expanduser('~/THUMOS14_ANET_feature/{}_appearance/'.format(subset))
        denseflow_path = os.path.expanduser('~/THUMOS14_ANET_feature/{}_denseflow/'.format(subset))
        rgb_feature = np.load(appearance_path + videoname + '.npy')
        flow_feature = np.load(denseflow_path + videoname + '.npy')

        return rgb_feature, flow_feature

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


def inference_collate_fn(batch):
    batch_data = torch.cat([x[0].unsqueeze(0) for x in batch])
    batch_video_names = [x[1] for x in batch]
    batch_window_start = [x[2] for x in batch]
    return batch_data, batch_video_names, batch_window_start


if __name__ == '__main__':

    config = Config()
    inference_loader = torch.utils.data.DataLoader(THUMOSInferenceDataset(config),
                                                   batch_size=48, shuffle=False,
                                                   num_workers=8, pin_memory=True, drop_last=False,
                                                   collate_fn=inference_collate_fn)
    for idx, (batch_data, batch_video_names, batch_window_start) in enumerate(inference_loader):
        print(idx)
        print(batch_data.shape[0])

--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
import os
import torch
import torch.nn.functional as F
import random
import numpy as np
import pandas as pd
from config import Config
from dataset import THUMOSInferenceDataset, inference_collate_fn
from model import SSAD
from utils import post_process, temporal_nms

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device('cuda')
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
torch.set_default_tensor_type('torch.FloatTensor')


def inference(config):
    # setup data_loader instances
    inference_loader = torch.utils.data.DataLoader(THUMOSInferenceDataset(config),
                                                   batch_size=config.batch_size, shuffle=False,
                                                   num_workers=8, pin_memory=True, drop_last=False,
                                                   collate_fn=inference_collate_fn)

    # build model architecture and load checkpoint
    model = SSAD(config).to(device)
    checkpoint = torch.load(config.checkpoint_path + "/model_best.pth.tar")
    model.load_state_dict(checkpoint['state_dict'])
    model = model.to(device)
    model.eval()
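    # NOTE (editor's note): the loop below decodes each anchor's normalized (center, width)
    # into absolute frame bounds, then shifts them by the window's start offset so that
    # detections from all sliding windows live on the full-video timeline.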
    '''
    ['xmin', 'xmax', 'conf', 'score_0', 'score_1', 'score_2',
     'score_3', 'score_4', 'score_5', 'score_6', 'score_7', 'score_8',
     'score_9', 'score_10', 'score_11', 'score_12', 'score_13', 'score_14',
     'score_15', 'score_16', 'score_17', 'score_18', 'score_19', 'score_20']
    '''
    results = []
    results_name = []
    with torch.no_grad():
        for n_iter, (batch_data, batch_video_names, batch_window_start) in enumerate(inference_loader):
            batch_data = batch_data.to(device)
            output_x, output_w, output_scores, output_labels = model(batch_data, device)

            output_labels = F.softmax(output_labels, dim=1)
            output_x = output_x.cpu().detach().numpy()
            output_w = output_w.cpu().detach().numpy()
            output_scores = output_scores.cpu().detach().numpy()
            output_labels = output_labels.cpu().detach().numpy()
            output_min = output_x - output_w / 2
            output_max = output_x + output_w / 2
            for ii in range(len(batch_video_names)):
                video_name = batch_video_names[ii]
                window_start = batch_window_start[ii]
                a_min = output_min[ii, :]
                a_max = output_max[ii, :]
                a_scores = output_scores[ii, :]
                a_labels = output_labels[ii, :, :]
                for jj in range(output_min.shape[-1]):
                    corrected_min = max(a_min[jj] * config.window_size * config.unit_size, 0.) + window_start
                    corrected_max = min(a_max[jj] * config.window_size * config.unit_size,
                                        config.window_size * config.unit_size) + window_start
                    results_name.append([video_name])
                    results.append([corrected_min, corrected_max, a_scores[jj]] + a_labels[:, jj].tolist())
    results_name = np.stack(results_name)
    results = np.stack(results)
    df = pd.DataFrame(results, columns=config.outdf_columns)
    df['video_name'] = results_name
    result_file = './results.txt'
    if os.path.isfile(result_file):
        os.remove(result_file)
    df = df[df.score_0 < config.filter_neg_threshold]
    df = df[df.conf > config.filter_conf_threshold]
    video_name_list = list(set(df.video_name.values[:]))

    for video_name in video_name_list:
        tmpdf = df[df.video_name == video_name]
        tmpdf = post_process(tmpdf, config)

        temporal_nms(config, tmpdf, result_file, video_name)


if __name__ == '__main__':
    config = Config()
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    inference(config)
--------------------------------------------------------------------------------
/loss_function.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F


def SSAD_loss_function(all_prediction_x, all_prediction_w, all_prediction_score, all_prediction_label,
                       batch_match_x, batch_match_w, batch_match_scores, batch_match_labels, device, config):
    # build the anchor masks: positives, hard negatives and randomly sampled negatives
    pmask = torch.ge(batch_match_scores, 0.5).float()  # positives: matched with IoU >= 0.5
    num_positive = torch.sum(pmask)
    # print('num_positive', num_positive)
    num_entries = all_prediction_x.shape[0] * all_prediction_x.shape[1]

    hmask = batch_match_scores < 0.5
    hmask = hmask & (all_prediction_score > 0.5)  # hard negatives: low IoU but confidently scored
    hmask = hmask.float()
    num_hard = torch.sum(hmask)

    # sample the remaining negatives so that (hard + sampled) negatives ~= negative_ratio * positives
    r_negative = (config.negative_ratio - num_hard / num_positive) * num_positive / (
            num_entries - num_positive - num_hard)
    r_negative = torch.min(r_negative, torch.Tensor([1.0]).to(device))
    nmask = torch.rand(pmask.size()).to(device)
    nmask = nmask * (1. - pmask)
    nmask = nmask * (1. - hmask)
    nmask = torch.ge(nmask, 1. - r_negative).float()
    # print(r_negative, num_positive, num_hard, torch.sum(nmask))
    # class_loss
    weights = pmask + nmask + hmask
    all_prediction_label = all_prediction_label.transpose(1, 2).contiguous().view(-1, config.num_classes)
    batch_match_labels = batch_match_labels.view(-1)
    class_loss = F.cross_entropy(all_prediction_label, batch_match_labels, reduction='none')
    class_loss = torch.sum(class_loss * weights.view(-1)) / torch.sum(weights)
    # loc_loss: smooth-L1 on the matched segment boundaries, positives only
    weights = pmask
    tmp_anchors_xmin = all_prediction_x - all_prediction_w / 2
    tmp_anchors_xmax = all_prediction_x + all_prediction_w / 2
    tmp_match_xmin = batch_match_x - batch_match_w / 2
    tmp_match_xmax = batch_match_x + batch_match_w / 2

    loc_loss = F.smooth_l1_loss(tmp_anchors_xmin, tmp_match_xmin, reduction='none') + F.smooth_l1_loss(
        tmp_anchors_xmax, tmp_match_xmax, reduction='none')
    loc_loss = torch.sum(loc_loss * weights) / torch.sum(weights)

    # conf loss
    weights = pmask + nmask + hmask
    # match_scores is from jaccard_with_anchors
    conf_loss = F.smooth_l1_loss(all_prediction_score, batch_match_scores, reduction='none')
    conf_loss = torch.sum(conf_loss * weights) / torch.sum(weights)

    loss = class_loss + 10. * loc_loss + 10. * conf_loss
    loss_dict = {"cost": loss, "class_loss": class_loss,
                 "loc_loss": loc_loss, "overlap_loss": conf_loss}
    return loss_dict
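# NOTE (editor's example of the negative sampling rate above): with negative_ratio = 1.,
# 100 positives, 20 hard negatives and 2000 anchors in the batch,
# r_negative = (1. - 20 / 100) * 100 / (2000 - 100 - 20) ≈ 0.043, so about 80 random
# negatives survive the torch.rand mask and negatives ≈ positives in the class loss.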
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import torch
from torch.nn import init
import torch.nn as nn
from collections import OrderedDict
import numpy as np
from config import Config
from prior_box import PriorBox


class SSAD(nn.Module):
    def __init__(self, config):
        super(SSAD, self).__init__()
        self.num_classes = config.num_classes
        self.num_anchors = config.num_anchors
        self.input_feature_dim = config.feature_dim
        self.prediction_output = self.num_anchors * (self.num_classes + 3)  # per anchor: classes + (overlap, x, w)
        self.best_loss = 10000000
        self.prior_box = PriorBox(config)
        # Base Layers: two stride-2 max pools reduce a 128-snippet window to length 32
        self.base_layers = nn.Sequential(OrderedDict([
            ('conv1d_1',
             nn.Conv1d(in_channels=self.input_feature_dim, out_channels=512, kernel_size=9, stride=1, padding=4)),
            ('relu_1', nn.ReLU()),
            ('maxpooling1d_1', nn.MaxPool1d(kernel_size=4, stride=2, padding=1)),
            ('conv1d_2', nn.Conv1d(in_channels=512, out_channels=512, kernel_size=9, stride=1, padding=4)),
            ('relu_2', nn.ReLU()),
            ('maxpooling1d_2', nn.MaxPool1d(kernel_size=4, stride=2, padding=1))
        ]))

        # Anchor Layers: successive stride-2 convs give 16 / 8 / 4 cells, matching config.num_cells
        self.anchor_layer1 = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),
            nn.ReLU())
        self.anchor_layer2 = nn.Sequential(
            nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=1),
            nn.ReLU())
        self.anchor_layer3 = nn.Sequential(
            nn.Conv1d(in_channels=1024, out_channels=1024, kernel_size=3, stride=2, padding=1),
            nn.ReLU())

        # Prediction Layers, one per anchor layer
        self.prediction_layer1 = nn.Conv1d(in_channels=1024, out_channels=self.prediction_output, kernel_size=3,
                                           stride=1, padding=1)
        self.prediction_layer2 = nn.Conv1d(in_channels=1024, out_channels=self.prediction_output, kernel_size=3,
                                           stride=1, padding=1)
        self.prediction_layer3 = nn.Conv1d(in_channels=1024, out_channels=self.prediction_output, kernel_size=3,
                                           stride=1, padding=1)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv1d):
            init.xavier_uniform_(m.weight)
            # init.constant_(m.bias, 0)

    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, input, device):
        """
        Forward pass logic
        :return: Model output
        """
        base_feature = self.base_layers(input)

        anchor1 = self.anchor_layer1(base_feature)
        anchor2 = self.anchor_layer2(anchor1)
        anchor3 = self.anchor_layer3(anchor2)

        # each anchor layer gets its own prediction head
        prediction1 = self.prediction_layer1(anchor1)
        prediction2 = self.prediction_layer2(anchor2)
        prediction3 = self.prediction_layer3(anchor3)

        batch_size = prediction1.shape[0]

        prediction1 = prediction1.view(batch_size, -1, prediction1.shape[-1] * self.num_anchors)
        prediction2 = prediction2.view(batch_size, -1, prediction2.shape[-1] * self.num_anchors)
        prediction3 = prediction3.view(batch_size, -1, prediction3.shape[-1] * self.num_anchors)

        # decode (x, w) against the layer priors; rows -3 / -2 / -1 hold overlap score, center and width
        prediction1_x = prediction1[:, -2, :]
        prediction1_w = prediction1[:, -1, :]
        prediction1_x = prediction1_x * self.prior_box('AL1')[1].to(device) * 0.1 + self.prior_box('AL1')[0].to(device)
        prediction1_w = torch.exp(0.1 * prediction1_w) * self.prior_box('AL1')[1].to(device)
        prediction1_score = prediction1[:, -3, :]
        prediction1_score = torch.sigmoid(prediction1_score)
        prediction1_label = prediction1[:, :self.num_classes, :]

        prediction2_x = prediction2[:, -2, :]
        prediction2_w = prediction2[:, -1, :]
        prediction2_x = prediction2_x * self.prior_box('AL2')[1].to(device) * 0.1 + self.prior_box('AL2')[0].to(device)
        prediction2_w = torch.exp(0.1 * prediction2_w) * self.prior_box('AL2')[1].to(device)
        prediction2_score = prediction2[:, -3, :]
        prediction2_score = torch.sigmoid(prediction2_score)
        prediction2_label = prediction2[:, :self.num_classes, :]

        prediction3_x = prediction3[:, -2, :]
        prediction3_w = prediction3[:, -1, :]
        prediction3_x = prediction3_x * self.prior_box('AL3')[1].to(device) * 0.1 + self.prior_box('AL3')[0].to(device)
        prediction3_w = torch.exp(0.1 * prediction3_w) * self.prior_box('AL3')[1].to(device)
        prediction3_score = prediction3[:, -3, :]
        prediction3_score = torch.sigmoid(prediction3_score)
        prediction3_label = prediction3[:, :self.num_classes, :]

        all_prediction_x = torch.cat((prediction1_x, prediction2_x, prediction3_x), dim=-1)
        all_prediction_w = torch.cat((prediction1_w, prediction2_w, prediction3_w), dim=-1)
        all_prediction_score = torch.cat((prediction1_score, prediction2_score, prediction3_score), dim=-1)
        all_prediction_label = torch.cat((prediction1_label, prediction2_label, prediction3_label), dim=-1)

        return all_prediction_x, all_prediction_w, all_prediction_score, all_prediction_label


if __name__ == '__main__':
    config = Config()
    model = SSAD(config)
    input = torch.Tensor(np.zeros(shape=(4, 3072, 128)))
    model(input, torch.device('cpu'))  # forward needs an explicit device for the priors
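# NOTE (editor's sketch, not part of the original repo): how SSAD.forward decodes a
# single anchor, using the first AL1 prior from prior_box.py (cell 0, aspect ratio 0.5);
# pred_x and pred_w are hypothetical raw network outputs.
import torch

prior_center = 0.5 / 16                 # center_set[0] of the 16-cell layer
prior_width = (1. / 16) * 0.5           # base_scale * aspect_ratio
pred_x, pred_w = torch.tensor(0.2), torch.tensor(-0.1)
x = pred_x * prior_width * 0.1 + prior_center   # same decode as SSAD.forward
w = torch.exp(0.1 * pred_w) * prior_width
print(float(x - w / 2), float(x + w / 2))       # normalized [xmin, xmax] inside the window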
--------------------------------------------------------------------------------
/prior_box.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from config import Config


class PriorBox(nn.Module):
    def __init__(self, config):
        super(PriorBox, self).__init__()
        self.layer_names = config.layer_names
        self.num_cells = config.num_cells
        self.base_scale = config.base_scale
        self.aspect_ratios = config.aspect_ratios
        self.priors_center = {}
        self.priors_width = {}
        self._generating_box()

    def _generating_box(self):
        """Generate SSAD prior boxes.
        """
        for layer_name, layer_step, scale, ratios in zip(self.layer_names, self.num_cells, self.base_scale,
                                                         self.aspect_ratios):
            width_set = [scale * ratio for ratio in ratios]
            center_set = [1. / layer_step * i + 0.5 / layer_step for i in range(layer_step)]
            width_default = []
            center_default = []
            for i in range(layer_step):
                for j in range(len(ratios)):
                    width_default.append(width_set[j])
                    center_default.append(center_set[i])
            width_default = np.array(width_default).reshape(1, -1)
            center_default = np.array(center_default).reshape(1, -1)
            width_default = torch.Tensor(width_default)
            center_default = torch.Tensor(center_default)
            self.priors_center.setdefault(layer_name, center_default)
            self.priors_width.setdefault(layer_name, width_default)

    def forward(self, output_name):
        return self.priors_center[output_name], self.priors_width[output_name]


if __name__ == '__main__':
    config = Config()
    priorBox = PriorBox(config)
    priorBox('AL1')
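# NOTE (editor's sketch): what PriorBox returns for the first anchor layer.
from config import Config
from prior_box import PriorBox

centers, widths = PriorBox(Config())('AL1')
print(centers.shape, widths.shape)  # torch.Size([1, 80]) twice: 16 cells x 5 aspect ratios
# column k of either tensor belongs to cell k // 5 and aspect ratio k % 5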
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
import os
import torch
print(torch.__version__)
import torch.optim as optim
import random
import numpy as np
from config import Config
from dataset import THUMOSDataset, train_collate_fn
from model import SSAD
from utils import ensure_dir, build_target
from loss_function import SSAD_loss_function

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device('cuda')
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = True
torch.set_default_tensor_type('torch.FloatTensor')


def main(config):
    # setup data_loader instances
    train_loader = torch.utils.data.DataLoader(THUMOSDataset(config, mode='Val'),
                                               batch_size=config.batch_size, shuffle=True,
                                               num_workers=8, pin_memory=True, drop_last=True,
                                               collate_fn=train_collate_fn)
    val_loader = torch.utils.data.DataLoader(THUMOSDataset(config, mode='Test'),
                                             batch_size=config.batch_size, shuffle=False,
                                             num_workers=8, pin_memory=True, drop_last=True,
                                             collate_fn=train_collate_fn)

    # build model architecture
    model = SSAD(config).to(device)

    # build the optimizer and learning-rate scheduler
    # (to disable the schedule, delete the lines that reference lr_scheduler)
    # trainable_params = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = optim.Adam(model.parameters(), lr=config.training_lr, weight_decay=config.weight_decay)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=config.lr_scheduler_step,
                                                gamma=config.lr_scheduler_gamma)

    # Save configuration file into checkpoint directory:
    ensure_dir(config.checkpoint_path)

    for epoch in range(config.epoch):
        train_epoch(train_loader, model, optimizer, epoch, config)
        test_epoch(val_loader, model, epoch, config)
        scheduler.step()  # step after the epoch's optimizer updates (PyTorch >= 1.1 convention)


def train_epoch(data_loader, model, optimizer, epoch, config):
    model.train()
    epoch_cost = 0.
    epoch_class_loss = 0.
    epoch_overlap_loss = 0.
    epoch_loc_loss = 0.
    for n_iter, (batch_data, batch_gt_bbox, batch_gt_class, batch_start_index) in enumerate(data_loader):
        batch_data = batch_data.to(device)

        all_prediction_x, all_prediction_w, all_prediction_score, all_prediction_label = model(batch_data, device)

        all_prediction_x_np = all_prediction_x.data.cpu().numpy()
        all_prediction_w_np = all_prediction_w.data.cpu().numpy()
        batch_match_x, batch_match_w, batch_match_scores, batch_match_labels = build_target(all_prediction_x_np,
                                                                                            all_prediction_w_np,
                                                                                            batch_gt_bbox,
                                                                                            batch_gt_class,
                                                                                            batch_start_index, config)
        batch_match_x = torch.Tensor(batch_match_x).to(device)
        batch_match_w = torch.Tensor(batch_match_w).to(device)
        batch_match_scores = torch.Tensor(batch_match_scores).to(device)
        batch_match_labels = torch.LongTensor(batch_match_labels).to(device)

        loss = SSAD_loss_function(all_prediction_x, all_prediction_w, all_prediction_score, all_prediction_label,
                                  batch_match_x, batch_match_w, batch_match_scores, batch_match_labels, device,
                                  config)
        cost = loss["cost"]

        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        epoch_class_loss += loss["class_loss"].cpu().detach().numpy()
        epoch_overlap_loss += loss["overlap_loss"].cpu().detach().numpy()
        epoch_loc_loss += loss["loc_loss"].cpu().detach().numpy()
        epoch_cost += loss["cost"].cpu().detach().numpy()
    print(
        "SSAD training loss(epoch %d): class - %.05f, overlap - %.05f, loc - %.05f, cost - %.05f" % (
            epoch, epoch_class_loss / (n_iter + 1),
            epoch_overlap_loss / (n_iter + 1),
            epoch_loc_loss / (n_iter + 1), epoch_cost / (n_iter + 1)))
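# NOTE (editor's note): targets are rebuilt for every batch because build_target matches
# ground truths against the *decoded predictions* rather than the static priors, so the
# regression targets track the anchors as the network refines them.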
def test_epoch(data_loader, model, epoch, config):
    model.eval()
    epoch_cost = 0.
    epoch_class_loss = 0.
    epoch_overlap_loss = 0.
    epoch_loc_loss = 0.
    for n_iter, (batch_data, batch_gt_bbox, batch_gt_class, batch_start_index) in enumerate(data_loader):
        batch_data = batch_data.to(device)

        all_prediction_x, all_prediction_w, all_prediction_score, all_prediction_label = model(batch_data, device)

        all_prediction_x_np = all_prediction_x.data.cpu().numpy()
        all_prediction_w_np = all_prediction_w.data.cpu().numpy()
        batch_match_x, batch_match_w, batch_match_scores, batch_match_labels = build_target(all_prediction_x_np,
                                                                                            all_prediction_w_np,
                                                                                            batch_gt_bbox,
                                                                                            batch_gt_class,
                                                                                            batch_start_index, config)
        batch_match_x = torch.Tensor(batch_match_x).to(device)
        batch_match_w = torch.Tensor(batch_match_w).to(device)
        batch_match_scores = torch.Tensor(batch_match_scores).to(device)
        batch_match_labels = torch.LongTensor(batch_match_labels).to(device)

        loss = SSAD_loss_function(all_prediction_x, all_prediction_w, all_prediction_score, all_prediction_label,
                                  batch_match_x, batch_match_w, batch_match_scores, batch_match_labels, device,
                                  config)

        epoch_class_loss += loss["class_loss"].cpu().detach().numpy()
        epoch_overlap_loss += loss["overlap_loss"].cpu().detach().numpy()
        epoch_loc_loss += loss["loc_loss"].cpu().detach().numpy()
        epoch_cost += loss["cost"].cpu().detach().numpy()
    print(
        "SSAD validation loss(epoch %d): class - %.05f, overlap - %.05f, loc - %.05f, cost - %.05f" % (
            epoch, epoch_class_loss / (n_iter + 1),
            epoch_overlap_loss / (n_iter + 1),
            epoch_loc_loss / (n_iter + 1), epoch_cost / (n_iter + 1)))

    state = {'epoch': epoch + 1,
             'state_dict': model.state_dict()}
    torch.save(state, config.checkpoint_path + "/model_checkpoint.pth.tar")
    if epoch_cost / (n_iter + 1) < model.best_loss:  # track the best average validation cost
        model.best_loss = epoch_cost / (n_iter + 1)
        torch.save(state, config.checkpoint_path + "/model_best.pth.tar")


if __name__ == '__main__':
    config = Config()
    random.seed(config.seed)
    np.random.seed(config.seed)
    torch.manual_seed(config.seed)
    torch.cuda.manual_seed(config.seed)
    torch.cuda.manual_seed_all(config.seed)
    main(config)

--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import pandas as pd


def iou_with_anchors(anchors_min, anchors_max, box_min, box_max):
    """Compute the Jaccard (IoU) score between a box and the anchors.
    """
    len_anchors = anchors_max - anchors_min
    int_xmin = np.maximum(anchors_min, box_min)
    int_xmax = np.minimum(anchors_max, box_max)
    inter_len = np.maximum(int_xmax - int_xmin, 0.)
    union_len = len_anchors - inter_len + box_max - box_min
    # print inter_len,union_len
    jaccard = np.divide(inter_len, union_len)
    return jaccard
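# NOTE (editor's example, IoU broadcasts over anchors):
#   iou_with_anchors(np.array([0., 10.]), np.array([10., 20.]), 5., 15.)
#   -> array([0.3333, 0.3333])  # each anchor shares 5 frames with the box, union 15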
def ioa_with_anchors(anchors_min, anchors_max, box_min, box_max):
    """Compute the intersection-over-anchor (IoA) score between a box and the anchors.
    """
    len_anchors = anchors_max - anchors_min
    int_xmin = np.maximum(anchors_min, box_min)
    int_xmax = np.minimum(anchors_max, box_max)
    inter_len = np.maximum(int_xmax - int_xmin, 0.)
    scores = np.divide(inter_len, len_anchors)
    return scores


def sigmoid(X):
    # maps [0, 1] -> [0.5, 0.73] almost linearly (and [-1, 0] -> [0.27, 0.5])
    return 1.0 / (1.0 + np.exp(-1.0 * X))


def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)


def build_target(all_prediction_x_np, all_prediction_w_np, batch_gt_bbox, batch_gt_class, batch_start_index, config):
    batch_match_x = []
    batch_match_w = []
    batch_match_scores = []
    batch_match_labels = []

    for idx in range(config.batch_size):
        b_anchors_rx = all_prediction_x_np[idx, ...]
        b_anchors_rw = all_prediction_w_np[idx, ...]

        b_gt_class = batch_gt_class[batch_start_index[idx]:batch_start_index[idx + 1], ...]
        b_gt_bbox = batch_gt_bbox[batch_start_index[idx]:batch_start_index[idx + 1], ...]
        assert b_gt_class.shape[0] == b_gt_bbox.shape[0]

        b_gt_num = b_gt_bbox.shape[0]
        num_all_anchors = all_prediction_x_np.shape[1]
        match_x = np.zeros((num_all_anchors), dtype=np.float32)
        match_w = np.zeros((num_all_anchors), dtype=np.float32)
        match_scores = np.zeros((num_all_anchors), dtype=np.float32)

        # labels start as one-hot background and are overwritten by matched ground truths
        match_labels_other = np.ones((num_all_anchors, 1), dtype=np.int32)
        match_labels_class = np.zeros((num_all_anchors, config.num_classes - 1),
                                      dtype=np.int32)
        match_labels = np.hstack([match_labels_other, match_labels_class])

        for jj in range(b_gt_num):
            a_gt_min = b_gt_bbox[jj, 0]
            a_gt_max = b_gt_bbox[jj, 1]
            a_gt_class = b_gt_class[jj]
            # ground truth
            a_gt_x = (a_gt_max + a_gt_min) / 2
            a_gt_w = (a_gt_max - a_gt_min)

            # predict
            anchors_min = b_anchors_rx - b_anchors_rw / 2
            anchors_max = b_anchors_rx + b_anchors_rw / 2

            jaccards = iou_with_anchors(anchors_min, anchors_max, a_gt_min, a_gt_max)

            # update an anchor when this GT overlaps it more than any previously
            # matched GT and the overlap exceeds the matching threshold
            mask = jaccards > match_scores
            matching_threshold = 0.5
            mask = mask & (jaccards > matching_threshold)
            mask = mask & (match_scores > -0.5)

            imask = mask.astype(np.int32)
            fmask = mask.astype(np.float32)
            # where the mask is set, overwrite the matched target with this GT; keep previous values elsewhere
            match_x = fmask * a_gt_x + (1 - fmask) * match_x
            match_w = fmask * a_gt_w + (1 - fmask) * match_w

            ref_label = np.zeros_like(match_labels, dtype=np.int32)
            ref_label = ref_label + a_gt_class

            # row-wise select: masked rows take the new one-hot label, the rest keep their old label
            match_labels = np.matmul(np.diag(imask), ref_label) + np.matmul(np.diag(1 - imask), match_labels)

            match_scores = np.maximum(jaccards, match_scores)

        batch_match_x.append(np.expand_dims(match_x, axis=0))
        batch_match_w.append(np.expand_dims(match_w, axis=0))
        batch_match_scores.append(np.expand_dims(match_scores, axis=0))
        batch_match_labels.append(np.expand_dims(match_labels, axis=0))
    batch_match_x = np.vstack(batch_match_x)
    batch_match_w = np.vstack(batch_match_w)
    batch_match_scores = np.vstack(batch_match_scores)
    batch_match_labels = np.vstack(batch_match_labels)
    batch_match_labels = np.argmax(batch_match_labels, axis=-1)
    return batch_match_x, batch_match_w, batch_match_scores, batch_match_labels
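# NOTE (editor's note): post_process below expands every detection into its three
# highest conf-weighted action classes, and temporal_nms then suppresses overlapping
# segments independently per class before writing results.txt.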
def post_process(df, config):
    class_scores_class = [(df['score_' + str(i)]).values[:].tolist() for i in range(config.num_classes)]
    class_scores_seg = [[class_scores_class[j][i] for j in range(config.num_classes)] for i in range(len(df))]

    class_real = [0] + config.class_real  # background + 20 action classes = num_classes entries

    # keep the top 2 or 3 scoring classes per detection
    # append the largest score element
    class_type_list = []
    class_score_list = []
    for i in range(len(df)):
        class_score = np.array(class_scores_seg[i][1:]) * df.conf.values[i]
        class_score = class_score.tolist()
        class_type = class_real[class_score.index(max(class_score)) + 1]
        class_type_list.append(class_type)
        class_score_list.append(max(class_score))
    resultDf1 = pd.DataFrame()
    resultDf1['out_type'] = class_type_list
    resultDf1['out_score'] = class_score_list
    resultDf1['start'] = df.xmin.values[:]
    resultDf1['end'] = df.xmax.values[:]

    # append the second largest score element
    class_type_list = []
    class_score_list = []
    for i in range(len(df)):
        class_score = np.array(class_scores_seg[i][1:]) * df.conf.values[i]
        class_score = class_score.tolist()
        class_score[class_score.index(max(class_score))] = 0
        class_type = class_real[class_score.index(max(class_score)) + 1]
        class_type_list.append(class_type)
        class_score_list.append(max(class_score))
    resultDf2 = pd.DataFrame()
    resultDf2['out_type'] = class_type_list
    resultDf2['out_score'] = class_score_list
    resultDf2['start'] = df.xmin.values[:]
    resultDf2['end'] = df.xmax.values[:]
    resultDf1 = pd.concat([resultDf1, resultDf2])

    # append the third largest score element (adds little accuracy and is slow)
    class_type_list = []
    class_score_list = []
    for i in range(len(df)):
        class_score = np.array(class_scores_seg[i][1:]) * df.conf.values[i]
        class_score = class_score.tolist()
        class_score[class_score.index(max(class_score))] = 0
        class_score[class_score.index(max(class_score))] = 0
        class_type = class_real[class_score.index(max(class_score)) + 1]
        class_type_list.append(class_type)
        class_score_list.append(max(class_score))
    resultDf2 = pd.DataFrame()
    resultDf2['out_type'] = class_type_list
    resultDf2['out_score'] = class_score_list
    resultDf2['start'] = df.xmin.values[:]
    resultDf2['end'] = df.xmax.values[:]
    resultDf1 = pd.concat([resultDf1, resultDf2])

    resultDf1 = resultDf1[resultDf1.out_score > 0.0005]

    resultDf1['video_name'] = [df['video_name'].values[0] for _ in range(len(resultDf1))]
    return resultDf1
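# NOTE (editor's example of the greedy 1-D NMS below with nms_threshold = 0.2): segments
# (0, 10, score 0.9), (2, 12, score 0.8) and (20, 30, score 0.7) keep the first and the
# third, because IoU((0, 10), (2, 12)) = 8 / 12 ≈ 0.67 > 0.2 suppresses the second.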
def temporal_nms(config, dfNMS, filename, videoname):
    nms_threshold = config.nms_threshold
    fo = open(filename, 'a')

    typeSet = list(set(dfNMS.out_type.values[:]))
    for t in typeSet:
        tdf = dfNMS[dfNMS.out_type == t]

        t1 = np.array(tdf.start.values[:])
        t2 = np.array(tdf.end.values[:])
        scores = np.array(tdf.out_score.values[:])
        ttype = list(tdf.out_type.values[:])

        durations = t2 - t1
        order = scores.argsort()[::-1]

        keep = []
        while order.size > 0:
            i = order[0]
            keep.append(i)
            tt1 = np.maximum(t1[i], t1[order[1:]])
            tt2 = np.minimum(t2[i], t2[order[1:]])
            intersection = tt2 - tt1
            IoU = intersection / (durations[i] + durations[order[1:]] - intersection).astype(float)

            inds = np.where(IoU <= nms_threshold)[0]
            order = order[inds + 1]

        for idx in keep:
            # class_real contains no class 0 (ambiguous), so ambiguous detections are dropped here
            if ttype[idx] in config.class_real:
                if videoname in ["video_test_0001255", "video_test_0001058",
                                 "video_test_0001459", "video_test_0001195", "video_test_0000950"]:  # 25fps
                    strout = "%s\t%.3f\t%.3f\t%d\t%.4f\n" % (videoname, float(t1[idx]) / 25,
                                                             float(t2[idx]) / 25, ttype[idx], scores[idx])
                elif videoname == "video_test_0001207":  # 24fps
                    strout = "%s\t%.3f\t%.3f\t%d\t%.4f\n" % (videoname, float(t1[idx]) / 24,
                                                             float(t2[idx]) / 24, ttype[idx], scores[idx])
                else:  # most videos are 30fps
                    strout = "%s\t%.3f\t%.3f\t%d\t%.4f\n" % (videoname, float(t1[idx]) / 30,
                                                             float(t2[idx]) / 30, ttype[idx], scores[idx])
                fo.write(strout)
    fo.close()  # flush the appended detections for this video
--------------------------------------------------------------------------------
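# NOTE (editor's sketch, not part of the original repo): temporal_nms writes
# tab-separated rows (video_name, start_sec, end_sec, class_id, score) to results.txt,
# which can be loaded for THUMOS14-style mAP evaluation, e.g.:
import pandas as pd

detections = pd.read_csv('./results.txt', sep='\t',
                         names=['video-id', 't-start', 't-end', 'label', 'score'])
print(detections.head())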