├── LICENSE ├── README.md ├── config.py ├── dataset └── THUMOS14 │ ├── gt.json │ ├── split_test.txt │ └── split_train.txt ├── eval ├── eval_detection.py └── utils_eval.py ├── main.py ├── main_eval.py ├── model.py ├── options.py ├── requirements.txt ├── run.sh ├── run_eval.sh ├── test.py ├── thumos_features.py ├── train.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 pilhyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # BaSNet-pytorch 2 | ### Official Pytorch Implementation of '[Background Suppression Network for Weakly-supervised Temporal Action Localization](https://arxiv.org/abs/1911.09963)' (AAAI 2020 Spotlight) 3 | 4 | ![BaS-Net architecture](https://user-images.githubusercontent.com/16102333/78222568-69945500-7500-11ea-9468-22b1da6d1d77.png) 5 | 6 | > **Background Suppression Network for Weakly-supervised Temporal Action Localization**
7 | > Pilhyeon Lee (Yonsei Univ.), Youngjung Uh (Clova AI, NAVER Corp.), Hyeran Byun (Yonsei Univ.) 8 | > 9 | > Paper: https://arxiv.org/abs/1911.09963 10 | > 11 | > **Abstract:** *Weakly-supervised temporal action localization is a very challenging problem because frame-wise labels are not given in the training stage while the only hint is video-level labels: whether each video contains action frames of interest. Previous methods aggregate frame-level class scores to produce video-level prediction and learn from video-level action labels. This formulation does not fully model the problem in that background frames are forced to be misclassified as action classes to predict video-level labels accurately. In this paper, we design Background Suppression Network (BaS-Net) which introduces an auxiliary class for background and has a two-branch weight-sharing architecture with an asymmetrical training strategy. This enables BaS-Net to suppress activations from background frames to improve localization performance. Extensive experiments demonstrate the effectiveness of BaS-Net and its superiority over the state-of-the-art methods on the most popular benchmarks - THUMOS'14 and ActivityNet.* 12 | 13 | ## (2020/06/16) Our new model is available now! 14 | ### Weakly-supervised Temporal Action Localization by Uncertainty Modeling [[Paper](https://arxiv.org/abs/2006.07006)] [[Code](https://github.com/Pilhyeon/WTAL-Uncertainty-Modeling)] 15 | 16 | ## Prerequisites 17 | ### Recommended Environment 18 | * Python 3.5 19 | * Pytorch 1.0 20 | * Tensorflow 1.15 (for Tensorboard) 21 | 22 | ### Depencencies 23 | You can set up the environments by using `$ pip3 install -r requirements.txt`. 24 | 25 | ### Data Preparation 26 | 1. Prepare [THUMOS'14](https://www.crcv.ucf.edu/THUMOS14/) dataset. 27 | - We excluded three test videos (270, 1292, 1496) as previous work did. 28 | 29 | 2. Extract features with two-stream I3D networks 30 | - We recommend extracting features using [this repo](https://github.com/piergiaj/pytorch-i3d). 31 | - For convenience, we provide the features we used. You can find them [here](https://drive.google.com/file/d/19BIRy53w2H5J2Nc_mpAbYPVzElReJswe/view?usp=sharing). 32 | 33 | 3. Place the features inside the `dataset` folder. 34 | - Please ensure the data structure is as below. 35 | 36 | ~~~~ 37 | ├── dataset 38 | └── THUMOS14 39 | ├── gt.json 40 | ├── split_train.txt 41 | ├── split_test.txt 42 | └── features 43 | ├── train 44 | ├── rgb 45 | ├── video_validation_0000051.npy 46 | ├── video_validation_0000052.npy 47 | └── ... 48 | └── flow 49 | ├── video_validation_0000051.npy 50 | ├── video_validation_0000052.npy 51 | └── ... 52 | └── test 53 | ├── rgb 54 | ├── video_test_0000004.npy 55 | ├── video_test_0000006.npy 56 | └── ... 57 | └── flow 58 | ├── video_test_0000004.npy 59 | ├── video_test_0000006.npy 60 | └── ... 61 | ~~~~ 62 | 63 | ## Usage 64 | 65 | ### Running 66 | You can easily train and evaluate BaS-Net by running the script below. 67 | 68 | If you want to try other training options, please refer to `options.py`. 69 | 70 | ~~~~ 71 | $ bash run.sh 72 | ~~~~ 73 | 74 | ### Evaulation 75 | The pre-trained model can be found [here](https://drive.google.com/file/d/1W9uVOTEvJAOj99RWRUqrk9NS4ahgSOE6/view?usp=sharing). 76 | You can evaluate the model by running the command below. 77 | 78 | ~~~~ 79 | $ bash run_eval.sh 80 | ~~~~ 81 | 82 | ## References 83 | We referenced the repos below for the code. 
84 | 85 | * [STPN](https://github.com/bellos1203/STPN) 86 | * [ActivityNet](https://github.com/activitynet/ActivityNet) 87 | 88 | ## Citation 89 | If you find this code useful, please cite our paper. 90 | 91 | ~~~~ 92 | @inproceedings{lee2020BaS-Net, 93 | title={Background Suppression Network for Weakly-supervised Temporal Action Localization}, 94 | author={Lee, Pilhyeon and Uh, Youngjung and Byun, Hyeran}, 95 | booktitle={The 34th AAAI Conference on Artificial Intelligence}, 96 | pages={11320--11327}, 97 | year={2020} 98 | } 99 | ~~~~ 100 | 101 | ## Contact 102 | If you have any question or comment, please contact the first author of the paper - Pilhyeon Lee (lph1114@yonsei.ac.kr). 103 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | class Config(object): 5 | def __init__(self, args): 6 | self.lr = eval(args.lr) 7 | self.num_iters = len(self.lr) 8 | self.num_classes = 20 9 | self.modal = args.modal 10 | if self.modal == 'all': 11 | self.len_feature = 2048 12 | else: 13 | self.len_feature = 1024 14 | self.batch_size = args.batch_size 15 | self.data_path = args.data_path 16 | self.model_path = args.model_path 17 | self.output_path = args.output_path 18 | self.log_path = args.log_path 19 | self.num_workers = args.num_workers 20 | self.alpha = args.alpha 21 | self.class_thresh = args.class_th 22 | self.act_thresh = np.arange(0.0, 0.25, 0.025) 23 | self.scale = 24 24 | self.gt_path = os.path.join(self.data_path, 'gt.json') 25 | self.model_file = args.model_file 26 | self.seed = args.seed 27 | self.feature_fps = 25 28 | self.num_segments = 750 29 | 30 | 31 | class_dict = {0: 'BaseballPitch', 32 | 1: 'BasketballDunk', 33 | 2: 'Billiards', 34 | 3: 'CleanAndJerk', 35 | 4: 'CliffDiving', 36 | 5: 'CricketBowling', 37 | 6: 'CricketShot', 38 | 7: 'Diving', 39 | 8: 'FrisbeeCatch', 40 | 9: 'GolfSwing', 41 | 10: 'HammerThrow', 42 | 11: 'HighJump', 43 | 12: 'JavelinThrow', 44 | 13: 'LongJump', 45 | 14: 'PoleVault', 46 | 15: 'Shotput', 47 | 16: 'SoccerPenalty', 48 | 17: 'TennisSwing', 49 | 18: 'ThrowDiscus', 50 | 19: 'VolleyballSpiking'} -------------------------------------------------------------------------------- /dataset/THUMOS14/split_test.txt: -------------------------------------------------------------------------------- 1 | video_test_0000004 2 | video_test_0000006 3 | video_test_0000007 4 | video_test_0000011 5 | video_test_0000026 6 | video_test_0000028 7 | video_test_0000039 8 | video_test_0000045 9 | video_test_0000046 10 | video_test_0000051 11 | video_test_0000058 12 | video_test_0000062 13 | video_test_0000073 14 | video_test_0000085 15 | video_test_0000113 16 | video_test_0000129 17 | video_test_0000131 18 | video_test_0000173 19 | video_test_0000179 20 | video_test_0000188 21 | video_test_0000211 22 | video_test_0000220 23 | video_test_0000238 24 | video_test_0000242 25 | video_test_0000250 26 | video_test_0000254 27 | video_test_0000273 28 | video_test_0000278 29 | video_test_0000285 30 | video_test_0000292 31 | video_test_0000293 32 | video_test_0000308 33 | video_test_0000319 34 | video_test_0000324 35 | video_test_0000353 36 | video_test_0000355 37 | video_test_0000357 38 | video_test_0000367 39 | video_test_0000372 40 | video_test_0000374 41 | video_test_0000379 42 | video_test_0000392 43 | video_test_0000405 44 | video_test_0000412 45 | video_test_0000413 46 | video_test_0000423 47 | video_test_0000426 48 | 
video_test_0000429 49 | video_test_0000437 50 | video_test_0000442 51 | video_test_0000443 52 | video_test_0000444 53 | video_test_0000448 54 | video_test_0000450 55 | video_test_0000461 56 | video_test_0000464 57 | video_test_0000504 58 | video_test_0000505 59 | video_test_0000538 60 | video_test_0000541 61 | video_test_0000549 62 | video_test_0000556 63 | video_test_0000558 64 | video_test_0000560 65 | video_test_0000569 66 | video_test_0000577 67 | video_test_0000591 68 | video_test_0000593 69 | video_test_0000601 70 | video_test_0000602 71 | video_test_0000611 72 | video_test_0000615 73 | video_test_0000617 74 | video_test_0000622 75 | video_test_0000624 76 | video_test_0000626 77 | video_test_0000635 78 | video_test_0000664 79 | video_test_0000665 80 | video_test_0000671 81 | video_test_0000672 82 | video_test_0000673 83 | video_test_0000689 84 | video_test_0000691 85 | video_test_0000698 86 | video_test_0000701 87 | video_test_0000714 88 | video_test_0000716 89 | video_test_0000718 90 | video_test_0000723 91 | video_test_0000724 92 | video_test_0000730 93 | video_test_0000737 94 | video_test_0000740 95 | video_test_0000756 96 | video_test_0000762 97 | video_test_0000765 98 | video_test_0000767 99 | video_test_0000771 100 | video_test_0000785 101 | video_test_0000786 102 | video_test_0000793 103 | video_test_0000796 104 | video_test_0000798 105 | video_test_0000807 106 | video_test_0000814 107 | video_test_0000839 108 | video_test_0000844 109 | video_test_0000846 110 | video_test_0000847 111 | video_test_0000854 112 | video_test_0000864 113 | video_test_0000873 114 | video_test_0000882 115 | video_test_0000887 116 | video_test_0000896 117 | video_test_0000897 118 | video_test_0000903 119 | video_test_0000940 120 | video_test_0000946 121 | video_test_0000950 122 | video_test_0000964 123 | video_test_0000981 124 | video_test_0000987 125 | video_test_0000989 126 | video_test_0000991 127 | video_test_0001008 128 | video_test_0001038 129 | video_test_0001039 130 | video_test_0001040 131 | video_test_0001058 132 | video_test_0001064 133 | video_test_0001066 134 | video_test_0001072 135 | video_test_0001075 136 | video_test_0001076 137 | video_test_0001078 138 | video_test_0001079 139 | video_test_0001080 140 | video_test_0001081 141 | video_test_0001098 142 | video_test_0001114 143 | video_test_0001118 144 | video_test_0001123 145 | video_test_0001127 146 | video_test_0001129 147 | video_test_0001134 148 | video_test_0001135 149 | video_test_0001146 150 | video_test_0001153 151 | video_test_0001159 152 | video_test_0001162 153 | video_test_0001163 154 | video_test_0001164 155 | video_test_0001168 156 | video_test_0001174 157 | video_test_0001182 158 | video_test_0001194 159 | video_test_0001195 160 | video_test_0001201 161 | video_test_0001202 162 | video_test_0001207 163 | video_test_0001209 164 | video_test_0001219 165 | video_test_0001223 166 | video_test_0001229 167 | video_test_0001235 168 | video_test_0001247 169 | video_test_0001255 170 | video_test_0001257 171 | video_test_0001267 172 | video_test_0001268 173 | video_test_0001270 174 | video_test_0001276 175 | video_test_0001281 176 | video_test_0001307 177 | video_test_0001309 178 | video_test_0001313 179 | video_test_0001314 180 | video_test_0001319 181 | video_test_0001324 182 | video_test_0001325 183 | video_test_0001339 184 | video_test_0001343 185 | video_test_0001358 186 | video_test_0001369 187 | video_test_0001389 188 | video_test_0001391 189 | video_test_0001409 190 | video_test_0001431 191 | video_test_0001433 192 | 
video_test_0001446 193 | video_test_0001447 194 | video_test_0001452 195 | video_test_0001459 196 | video_test_0001460 197 | video_test_0001463 198 | video_test_0001468 199 | video_test_0001483 200 | video_test_0001484 201 | video_test_0001495 202 | video_test_0001508 203 | video_test_0001512 204 | video_test_0001522 205 | video_test_0001527 206 | video_test_0001531 207 | video_test_0001532 208 | video_test_0001549 209 | video_test_0001556 210 | video_test_0001558 211 | -------------------------------------------------------------------------------- /dataset/THUMOS14/split_train.txt: -------------------------------------------------------------------------------- 1 | video_validation_0000051 2 | video_validation_0000052 3 | video_validation_0000053 4 | video_validation_0000054 5 | video_validation_0000055 6 | video_validation_0000056 7 | video_validation_0000057 8 | video_validation_0000058 9 | video_validation_0000059 10 | video_validation_0000060 11 | video_validation_0000151 12 | video_validation_0000152 13 | video_validation_0000153 14 | video_validation_0000154 15 | video_validation_0000155 16 | video_validation_0000156 17 | video_validation_0000157 18 | video_validation_0000158 19 | video_validation_0000159 20 | video_validation_0000160 21 | video_validation_0000161 22 | video_validation_0000162 23 | video_validation_0000163 24 | video_validation_0000164 25 | video_validation_0000165 26 | video_validation_0000166 27 | video_validation_0000167 28 | video_validation_0000168 29 | video_validation_0000169 30 | video_validation_0000170 31 | video_validation_0000171 32 | video_validation_0000172 33 | video_validation_0000173 34 | video_validation_0000174 35 | video_validation_0000175 36 | video_validation_0000176 37 | video_validation_0000177 38 | video_validation_0000178 39 | video_validation_0000179 40 | video_validation_0000180 41 | video_validation_0000181 42 | video_validation_0000182 43 | video_validation_0000183 44 | video_validation_0000184 45 | video_validation_0000185 46 | video_validation_0000186 47 | video_validation_0000187 48 | video_validation_0000188 49 | video_validation_0000189 50 | video_validation_0000190 51 | video_validation_0000201 52 | video_validation_0000202 53 | video_validation_0000203 54 | video_validation_0000204 55 | video_validation_0000205 56 | video_validation_0000206 57 | video_validation_0000207 58 | video_validation_0000208 59 | video_validation_0000209 60 | video_validation_0000210 61 | video_validation_0000261 62 | video_validation_0000262 63 | video_validation_0000263 64 | video_validation_0000264 65 | video_validation_0000265 66 | video_validation_0000266 67 | video_validation_0000267 68 | video_validation_0000268 69 | video_validation_0000269 70 | video_validation_0000270 71 | video_validation_0000281 72 | video_validation_0000282 73 | video_validation_0000283 74 | video_validation_0000284 75 | video_validation_0000285 76 | video_validation_0000286 77 | video_validation_0000287 78 | video_validation_0000288 79 | video_validation_0000289 80 | video_validation_0000290 81 | video_validation_0000311 82 | video_validation_0000312 83 | video_validation_0000313 84 | video_validation_0000314 85 | video_validation_0000315 86 | video_validation_0000316 87 | video_validation_0000317 88 | video_validation_0000318 89 | video_validation_0000319 90 | video_validation_0000320 91 | video_validation_0000361 92 | video_validation_0000362 93 | video_validation_0000363 94 | video_validation_0000364 95 | video_validation_0000365 96 | video_validation_0000366 97 | 
video_validation_0000367 98 | video_validation_0000368 99 | video_validation_0000369 100 | video_validation_0000370 101 | video_validation_0000411 102 | video_validation_0000412 103 | video_validation_0000413 104 | video_validation_0000414 105 | video_validation_0000415 106 | video_validation_0000416 107 | video_validation_0000417 108 | video_validation_0000418 109 | video_validation_0000419 110 | video_validation_0000420 111 | video_validation_0000481 112 | video_validation_0000482 113 | video_validation_0000483 114 | video_validation_0000484 115 | video_validation_0000485 116 | video_validation_0000486 117 | video_validation_0000487 118 | video_validation_0000488 119 | video_validation_0000489 120 | video_validation_0000490 121 | video_validation_0000661 122 | video_validation_0000662 123 | video_validation_0000663 124 | video_validation_0000664 125 | video_validation_0000665 126 | video_validation_0000666 127 | video_validation_0000667 128 | video_validation_0000668 129 | video_validation_0000669 130 | video_validation_0000670 131 | video_validation_0000681 132 | video_validation_0000682 133 | video_validation_0000683 134 | video_validation_0000684 135 | video_validation_0000685 136 | video_validation_0000686 137 | video_validation_0000687 138 | video_validation_0000688 139 | video_validation_0000689 140 | video_validation_0000690 141 | video_validation_0000781 142 | video_validation_0000782 143 | video_validation_0000783 144 | video_validation_0000784 145 | video_validation_0000785 146 | video_validation_0000786 147 | video_validation_0000787 148 | video_validation_0000788 149 | video_validation_0000789 150 | video_validation_0000790 151 | video_validation_0000851 152 | video_validation_0000852 153 | video_validation_0000853 154 | video_validation_0000854 155 | video_validation_0000855 156 | video_validation_0000856 157 | video_validation_0000857 158 | video_validation_0000858 159 | video_validation_0000859 160 | video_validation_0000860 161 | video_validation_0000901 162 | video_validation_0000902 163 | video_validation_0000903 164 | video_validation_0000904 165 | video_validation_0000905 166 | video_validation_0000906 167 | video_validation_0000907 168 | video_validation_0000908 169 | video_validation_0000909 170 | video_validation_0000910 171 | video_validation_0000931 172 | video_validation_0000932 173 | video_validation_0000933 174 | video_validation_0000934 175 | video_validation_0000935 176 | video_validation_0000936 177 | video_validation_0000937 178 | video_validation_0000938 179 | video_validation_0000939 180 | video_validation_0000940 181 | video_validation_0000941 182 | video_validation_0000942 183 | video_validation_0000943 184 | video_validation_0000944 185 | video_validation_0000945 186 | video_validation_0000946 187 | video_validation_0000947 188 | video_validation_0000948 189 | video_validation_0000949 190 | video_validation_0000950 191 | video_validation_0000981 192 | video_validation_0000982 193 | video_validation_0000983 194 | video_validation_0000984 195 | video_validation_0000985 196 | video_validation_0000986 197 | video_validation_0000987 198 | video_validation_0000988 199 | video_validation_0000989 200 | video_validation_0000990 201 | -------------------------------------------------------------------------------- /eval/eval_detection.py: -------------------------------------------------------------------------------- 1 | # This code is originally from the official ActivityNet repo 2 | # https://github.com/activitynet/ActivityNet 3 | # Small modification from 
ActivityNet Code 4 | 5 | import json 6 | import numpy as np 7 | import pandas as pd 8 | from joblib import Parallel, delayed 9 | 10 | from .utils_eval import get_blocked_videos 11 | from .utils_eval import interpolated_prec_rec 12 | from .utils_eval import segment_iou 13 | 14 | import warnings 15 | warnings.filterwarnings("ignore", message="numpy.dtype size changed") 16 | warnings.filterwarnings("ignore", message="numpy.ufunc size changed") 17 | 18 | 19 | 20 | class ANETdetection(object): 21 | GROUND_TRUTH_FIELDS = ['database'] 22 | # GROUND_TRUTH_FIELDS = ['database', 'taxonomy', 'version'] 23 | PREDICTION_FIELDS = ['results', 'version', 'external_data'] 24 | 25 | def __init__(self, ground_truth_filename=None, prediction_filename=None, 26 | ground_truth_fields=GROUND_TRUTH_FIELDS, 27 | prediction_fields=PREDICTION_FIELDS, 28 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 29 | subset='validation', verbose=False, 30 | check_status=False): 31 | if not ground_truth_filename: 32 | raise IOError('Please input a valid ground truth file.') 33 | if not prediction_filename: 34 | raise IOError('Please input a valid prediction file.') 35 | self.subset = subset 36 | self.tiou_thresholds = tiou_thresholds 37 | self.verbose = verbose 38 | self.gt_fields = ground_truth_fields 39 | self.pred_fields = prediction_fields 40 | self.ap = None 41 | self.check_status = check_status 42 | # Retrieve blocked videos from server. 43 | 44 | if self.check_status: 45 | self.blocked_videos = get_blocked_videos() 46 | else: 47 | self.blocked_videos = list() 48 | 49 | # Import ground truth and predictions. 50 | self.ground_truth, self.activity_index = self._import_ground_truth( 51 | ground_truth_filename) 52 | self.prediction = self._import_prediction(prediction_filename) 53 | 54 | if self.verbose: 55 | print ('[INIT] Loaded annotations from {} subset.'.format(subset)) 56 | nr_gt = len(self.ground_truth) 57 | print ('\tNumber of ground truth instances: {}'.format(nr_gt)) 58 | nr_pred = len(self.prediction) 59 | print ('\tNumber of predictions: {}'.format(nr_pred)) 60 | print ('\tFixed threshold for tiou score: {}'.format(self.tiou_thresholds)) 61 | 62 | def _import_ground_truth(self, ground_truth_filename): 63 | """Reads ground truth file, checks if it is well formatted, and returns 64 | the ground truth instances and the activity classes. 65 | 66 | Parameters 67 | ---------- 68 | ground_truth_filename : str 69 | Full path to the ground truth json file. 70 | 71 | Outputs 72 | ------- 73 | ground_truth : df 74 | Data frame containing the ground truth instances. 75 | activity_index : dict 76 | Dictionary containing class index. 77 | """ 78 | with open(ground_truth_filename, 'r') as fobj: 79 | data = json.load(fobj) 80 | # Checking format 81 | if not all([field in data.keys() for field in self.gt_fields]): 82 | raise IOError('Please input a valid ground truth file.') 83 | 84 | # Read ground truth data. 
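        # Build the label -> class-index mapping on the fly while flattening every
        # annotation of the requested subset into parallel lists (video id, start
        # time, end time, label index); blocked videos are skipped.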
85 | activity_index, cidx = {}, 0 86 | video_lst, t_start_lst, t_end_lst, label_lst = [], [], [], [] 87 | for videoid, v in data['database'].items(): 88 | # print(v) 89 | if self.subset != v['subset']: 90 | continue 91 | if videoid in self.blocked_videos: 92 | continue 93 | for ann in v['annotations']: 94 | if ann['label'] not in activity_index: 95 | activity_index[ann['label']] = cidx 96 | cidx += 1 97 | video_lst.append(videoid) 98 | t_start_lst.append(float(ann['segment'][0])) 99 | t_end_lst.append(float(ann['segment'][1])) 100 | label_lst.append(activity_index[ann['label']]) 101 | 102 | ground_truth = pd.DataFrame({'video-id': video_lst, 103 | 't-start': t_start_lst, 104 | 't-end': t_end_lst, 105 | 'label': label_lst}) 106 | if self.verbose: 107 | print(activity_index) 108 | return ground_truth, activity_index 109 | 110 | def _import_prediction(self, prediction_filename): 111 | """Reads prediction file, checks if it is well formatted, and returns 112 | the prediction instances. 113 | 114 | Parameters 115 | ---------- 116 | prediction_filename : str 117 | Full path to the prediction json file. 118 | 119 | Outputs 120 | ------- 121 | prediction : df 122 | Data frame containing the prediction instances. 123 | """ 124 | with open(prediction_filename, 'r') as fobj: 125 | data = json.load(fobj) 126 | # Checking format... 127 | if not all([field in data.keys() for field in self.pred_fields]): 128 | raise IOError('Please input a valid prediction file.') 129 | 130 | # Read predictions. 131 | video_lst, t_start_lst, t_end_lst = [], [], [] 132 | label_lst, score_lst = [], [] 133 | for videoid, v in data['results'].items(): 134 | if videoid in self.blocked_videos: 135 | continue 136 | for result in v: 137 | label = self.activity_index[result['label']] 138 | video_lst.append(videoid) 139 | t_start_lst.append(float(result['segment'][0])) 140 | t_end_lst.append(float(result['segment'][1])) 141 | label_lst.append(label) 142 | score_lst.append(result['score']) 143 | prediction = pd.DataFrame({'video-id': video_lst, 144 | 't-start': t_start_lst, 145 | 't-end': t_end_lst, 146 | 'label': label_lst, 147 | 'score': score_lst}) 148 | return prediction 149 | 150 | def _get_predictions_with_label(self, prediction_by_label, label_name, cidx): 151 | """Get all predicitons of the given label. Return empty DataFrame if there 152 | is no predcitions with the given label. 153 | """ 154 | try: 155 | return prediction_by_label.get_group(cidx).reset_index(drop=True) 156 | except: 157 | if self.verbose: 158 | print ('Warning: No predictions of label \'%s\' were provdied.' % label_name) 159 | return pd.DataFrame() 160 | 161 | def wrapper_compute_average_precision(self): 162 | """Computes average precision for each class in the subset. 
163 | """ 164 | ap = np.zeros((len(self.tiou_thresholds), len(self.activity_index))) 165 | 166 | # Adaptation to query faster 167 | ground_truth_by_label = self.ground_truth.groupby('label') 168 | prediction_by_label = self.prediction.groupby('label') 169 | 170 | results = Parallel(n_jobs=len(self.activity_index))( 171 | delayed(compute_average_precision_detection)( 172 | ground_truth=ground_truth_by_label.get_group(cidx).reset_index(drop=True), 173 | prediction=self._get_predictions_with_label(prediction_by_label, label_name, cidx), 174 | tiou_thresholds=self.tiou_thresholds, 175 | ) for label_name, cidx in self.activity_index.items()) 176 | 177 | for i, cidx in enumerate(self.activity_index.values()): 178 | ap[:,cidx] = results[i] 179 | 180 | return ap 181 | 182 | def evaluate(self): 183 | """Evaluates a prediction file. For the detection task we measure the 184 | interpolated mean average precision to measure the performance of a 185 | method. 186 | """ 187 | self.ap = self.wrapper_compute_average_precision() 188 | 189 | self.mAP = self.ap.mean(axis=1) 190 | self.average_mAP = self.mAP.mean() 191 | 192 | if self.verbose: 193 | print ('[RESULTS] Performance on ActivityNet detection task.') 194 | print ('Average-mAP: {}'.format(self.average_mAP)) 195 | 196 | return self.mAP, self.average_mAP 197 | 198 | 199 | def compute_average_precision_detection(ground_truth, prediction, tiou_thresholds=np.linspace(0.5, 0.95, 10)): 200 | """Compute average precision (detection task) between ground truth and 201 | predictions data frames. If multiple predictions occurs for the same 202 | predicted segment, only the one with highest score is matches as 203 | true positive. This code is greatly inspired by Pascal VOC devkit. 204 | 205 | Parameters 206 | ---------- 207 | ground_truth : df 208 | Data frame containing the ground truth instances. 209 | Required fields: ['video-id', 't-start', 't-end'] 210 | prediction : df 211 | Data frame containing the prediction instances. 212 | Required fields: ['video-id, 't-start', 't-end', 'score'] 213 | tiou_thresholds : 1darray, optional 214 | Temporal intersection over union threshold. 215 | 216 | Outputs 217 | ------- 218 | ap : float 219 | Average precision score. 220 | """ 221 | ap = np.zeros(len(tiou_thresholds)) 222 | if prediction.empty: 223 | return ap 224 | 225 | npos = float(len(ground_truth)) 226 | lock_gt = np.ones((len(tiou_thresholds),len(ground_truth))) * -1 227 | # Sort predictions by decreasing score order. 228 | sort_idx = prediction['score'].values.argsort()[::-1] 229 | prediction = prediction.loc[sort_idx].reset_index(drop=True) 230 | 231 | # Initialize true positive and false positive vectors. 232 | tp = np.zeros((len(tiou_thresholds), len(prediction))) 233 | fp = np.zeros((len(tiou_thresholds), len(prediction))) 234 | 235 | # Adaptation to query faster 236 | ground_truth_gbvn = ground_truth.groupby('video-id') 237 | 238 | # Assigning true positive to truly grount truth instances. 239 | for idx, this_pred in prediction.iterrows(): 240 | 241 | try: 242 | # Check if there is at least one ground truth in the video associated. 243 | ground_truth_videoid = ground_truth_gbvn.get_group(this_pred['video-id']) 244 | except Exception as e: 245 | fp[:, idx] = 1 246 | continue 247 | 248 | this_gt = ground_truth_videoid.reset_index() 249 | tiou_arr = segment_iou(this_pred[['t-start', 't-end']].values, 250 | this_gt[['t-start', 't-end']].values) 251 | # We would like to retrieve the predictions with highest tiou score. 
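        # For each tIoU threshold, walk the ground-truth segments from best to worst
        # overlap: the first unclaimed segment above the threshold turns this
        # prediction into a true positive (and locks that segment); otherwise the
        # prediction is counted as a false positive.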
252 | tiou_sorted_idx = tiou_arr.argsort()[::-1] 253 | for tidx, tiou_thr in enumerate(tiou_thresholds): 254 | for jdx in tiou_sorted_idx: 255 | if tiou_arr[jdx] < tiou_thr: 256 | fp[tidx, idx] = 1 257 | break 258 | if lock_gt[tidx, this_gt.loc[jdx]['index']] >= 0: 259 | continue 260 | # Assign as true positive after the filters above. 261 | tp[tidx, idx] = 1 262 | lock_gt[tidx, this_gt.loc[jdx]['index']] = idx 263 | break 264 | 265 | if fp[tidx, idx] == 0 and tp[tidx, idx] == 0: 266 | fp[tidx, idx] = 1 267 | 268 | tp_cumsum = np.cumsum(tp, axis=1).astype(np.float) 269 | fp_cumsum = np.cumsum(fp, axis=1).astype(np.float) 270 | recall_cumsum = tp_cumsum / npos 271 | 272 | precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum) 273 | 274 | for tidx in range(len(tiou_thresholds)): 275 | ap[tidx] = interpolated_prec_rec(precision_cumsum[tidx,:], recall_cumsum[tidx,:]) 276 | 277 | 278 | return ap 279 | -------------------------------------------------------------------------------- /eval/utils_eval.py: -------------------------------------------------------------------------------- 1 | # This code is originally from the official ActivityNet repo 2 | # https://github.com/activitynet/ActivityNet 3 | 4 | import json 5 | import urllib.request 6 | 7 | import numpy as np 8 | 9 | API = 'http://ec2-52-11-11-89.us-west-2.compute.amazonaws.com/challenge17/api.py' 10 | 11 | def get_blocked_videos(api=API): 12 | api_url = '{}?action=get_blocked'.format(api) 13 | req = urllib.request.Request(api_url) 14 | response = urllib.request.urlopen(req) 15 | return json.loads(response.read().decode('utf-8')) 16 | 17 | def interpolated_prec_rec(prec, rec): 18 | """Interpolated AP - VOCdevkit from VOC 2011. 19 | """ 20 | mprec = np.hstack([[0], prec, [0]]) 21 | mrec = np.hstack([[0], rec, [1]]) 22 | for i in range(len(mprec) - 1)[::-1]: 23 | mprec[i] = max(mprec[i], mprec[i + 1]) 24 | idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1 25 | ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx]) 26 | return ap 27 | 28 | def segment_iou(target_segment, candidate_segments): 29 | """Compute the temporal intersection over union between a 30 | target segment and all the test segments. 31 | 32 | Parameters 33 | ---------- 34 | target_segment : 1d array 35 | Temporal target segment containing [starting, ending] times. 36 | candidate_segments : 2d array 37 | Temporal candidate segments containing N x [starting, ending] times. 38 | 39 | Outputs 40 | ------- 41 | tiou : 1d array 42 | Temporal intersection over union score of the N's candidate segments. 43 | """ 44 | tt1 = np.maximum(target_segment[0], candidate_segments[:, 0]) 45 | tt2 = np.minimum(target_segment[1], candidate_segments[:, 1]) 46 | # Intersection including Non-negative overlap score. 47 | segments_intersection = (tt2 - tt1).clip(0) 48 | # Segment union. 49 | segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \ 50 | + (target_segment[1] - target_segment[0]) - segments_intersection 51 | # Compute overlap as the ratio of the intersection 52 | # over union of two segments. 53 | tIoU = segments_intersection.astype(float) / segments_union 54 | return tIoU 55 | 56 | def wrapper_segment_iou(target_segments, candidate_segments): 57 | """Compute intersection over union btw segments 58 | Parameters 59 | ---------- 60 | target_segments : ndarray 61 | 2-dim array in format [m x 2:=[init, end]] 62 | candidate_segments : ndarray 63 | 2-dim array in format [n x 2:=[init, end]] 64 | Outputs 65 | ------- 66 | tiou : ndarray 67 | 2-dim array [n x m] with IOU ratio. 
68 | Note: It assumes that candidate-segments are more scarce that target-segments 69 | """ 70 | if candidate_segments.ndim != 2 or target_segments.ndim != 2: 71 | raise ValueError('Dimension of arguments is incorrect') 72 | 73 | n, m = candidate_segments.shape[0], target_segments.shape[0] 74 | tiou = np.empty((n, m)) 75 | for i in xrange(m): 76 | tiou[:, i] = segment_iou(target_segments[i,:], candidate_segments) 77 | 78 | return tiou 79 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | import numpy as np 3 | import torch.utils.data as data 4 | import utils 5 | from options import * 6 | from config import * 7 | from train import * 8 | from test import * 9 | from model import * 10 | from tensorboard_logger import Logger 11 | from thumos_features import * 12 | 13 | 14 | if __name__ == "__main__": 15 | args = parse_args() 16 | if args.debug: 17 | pdb.set_trace() 18 | 19 | config = Config(args) 20 | worker_init_fn = None 21 | 22 | if config.seed >= 0: 23 | utils.set_seed(config.seed) 24 | worker_init_fn = np.random.seed(config.seed) 25 | 26 | net = BaS_Net(config.len_feature, config.num_classes, config.num_segments) 27 | net = net.cuda() 28 | 29 | train_loader = data.DataLoader( 30 | ThumosFeature(data_path=config.data_path, mode='train', 31 | modal=config.modal, feature_fps=config.feature_fps, 32 | num_segments=config.num_segments, len_feature=config.len_feature, 33 | seed=config.seed, sampling='random'), 34 | batch_size=config.batch_size, 35 | shuffle=True, num_workers=config.num_workers, 36 | worker_init_fn=worker_init_fn) 37 | 38 | test_loader = data.DataLoader( 39 | ThumosFeature(data_path=config.data_path, mode='test', 40 | modal=config.modal, feature_fps=config.feature_fps, 41 | num_segments=config.num_segments, len_feature=config.len_feature, 42 | seed=config.seed, sampling='uniform'), 43 | batch_size=1, 44 | shuffle=False, num_workers=config.num_workers, 45 | worker_init_fn=worker_init_fn) 46 | 47 | test_info = {"step": [], "test_acc": [], "average_mAP": [], 48 | "mAP@0.1": [], "mAP@0.2": [], "mAP@0.3": [], 49 | "mAP@0.4": [], "mAP@0.5": [], "mAP@0.6": [], 50 | "mAP@0.7": [], "mAP@0.8": [], "mAP@0.9": []} 51 | 52 | best_mAP = -1 53 | 54 | criterion = BaS_Net_loss(config.alpha) 55 | 56 | optimizer = torch.optim.Adam(net.parameters(), lr=config.lr[0], 57 | betas=(0.9, 0.999), weight_decay=0.0005) 58 | 59 | logger = Logger(config.log_path) 60 | 61 | loader_iter = iter(train_loader) 62 | 63 | for step in tqdm( 64 | range(1, config.num_iters + 1), 65 | total = config.num_iters, 66 | dynamic_ncols = True 67 | ): 68 | if step > 1 and config.lr[step - 1] != config.lr[step - 2]: 69 | for param_group in optimizer.param_groups: 70 | param_group["lr"] = config.lr[step - 1] 71 | 72 | train(net, train_loader, loader_iter, optimizer, criterion, logger, step) 73 | 74 | test(net, config, logger, test_loader, test_info, step) 75 | 76 | if test_info["average_mAP"][-1] > best_mAP: 77 | best_mAP = test_info["average_mAP"][-1] 78 | 79 | utils.save_best_record_thumos(test_info, 80 | os.path.join(config.output_path, "best_record_seed_{}.txt".format(config.seed))) 81 | 82 | torch.save(net.state_dict(), os.path.join(args.model_path, \ 83 | "BaS_Net_model_seed_{}.pkl".format(config.seed))) 84 | 85 | -------------------------------------------------------------------------------- /main_eval.py: 
-------------------------------------------------------------------------------- 1 | import pdb 2 | import numpy as np 3 | import torch.utils.data as data 4 | import utils 5 | from options import * 6 | from config import * 7 | from test import * 8 | from model import * 9 | from tensorboard_logger import Logger 10 | from thumos_features import * 11 | 12 | 13 | if __name__ == "__main__": 14 | args = parse_args() 15 | if args.debug: 16 | pdb.set_trace() 17 | 18 | config = Config(args) 19 | worker_init_fn = None 20 | 21 | if config.seed >= 0: 22 | utils.set_seed(config.seed) 23 | worker_init_fn = np.random.seed(config.seed) 24 | 25 | net = BaS_Net(config.len_feature, config.num_classes, config.num_segments) 26 | net = net.cuda() 27 | 28 | test_loader = data.DataLoader( 29 | ThumosFeature(data_path=config.data_path, mode='test', 30 | modal=config.modal, feature_fps=config.feature_fps, 31 | num_segments=config.num_segments, len_feature=config.len_feature, 32 | seed=config.seed, sampling='uniform'), 33 | batch_size=1, 34 | shuffle=False, num_workers=config.num_workers, 35 | worker_init_fn=worker_init_fn) 36 | 37 | test_info = {"step": [], "test_acc": [], "average_mAP": [], 38 | "mAP@0.1": [], "mAP@0.2": [], "mAP@0.3": [], 39 | "mAP@0.4": [], "mAP@0.5": [], "mAP@0.6": [], 40 | "mAP@0.7": [], "mAP@0.8": [], "mAP@0.9": []} 41 | 42 | logger = Logger(config.log_path) 43 | 44 | test(net, config, logger, test_loader, test_info, 0, model_file=config.model_file) 45 | 46 | utils.save_best_record_thumos(test_info, 47 | os.path.join(config.output_path, "best_record.txt")) 48 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Filter_Module(nn.Module): 5 | def __init__(self, len_feature): 6 | super(Filter_Module, self).__init__() 7 | self.len_feature = len_feature 8 | self.conv_1 = nn.Sequential( 9 | nn.Conv1d(in_channels=self.len_feature, out_channels=512, kernel_size=1, 10 | stride=1, padding=0), 11 | nn.LeakyReLU() 12 | ) 13 | self.conv_2 = nn.Sequential( 14 | nn.Conv1d(in_channels=512, out_channels=1, kernel_size=1, 15 | stride=1, padding=0), 16 | nn.Sigmoid() 17 | ) 18 | 19 | def forward(self, x): 20 | # x: (B, T, F) 21 | out = x.permute(0, 2, 1) 22 | # out: (B, F, T) 23 | out = self.conv_1(out) 24 | out = self.conv_2(out) 25 | out = out.permute(0, 2, 1) 26 | # out: (B, T, 1) 27 | return out 28 | 29 | 30 | class CAS_Module(nn.Module): 31 | def __init__(self, len_feature, num_classes): 32 | super(CAS_Module, self).__init__() 33 | self.len_feature = len_feature 34 | self.conv_1 = nn.Sequential( 35 | nn.Conv1d(in_channels=self.len_feature, out_channels=2048, kernel_size=3, 36 | stride=1, padding=1), 37 | nn.LeakyReLU() 38 | ) 39 | 40 | self.conv_2 = nn.Sequential( 41 | nn.Conv1d(in_channels=2048, out_channels=2048, kernel_size=3, 42 | stride=1, padding=1), 43 | nn.LeakyReLU() 44 | ) 45 | 46 | self.conv_3 = nn.Sequential( 47 | nn.Conv1d(in_channels=2048, out_channels=num_classes + 1, kernel_size=1, 48 | stride=1, padding=0, bias=False) 49 | ) 50 | self.drop_out = nn.Dropout(p=0.7) 51 | 52 | def forward(self, x): 53 | # x: (B, T, F) 54 | out = x.permute(0, 2, 1) 55 | # out: (B, F, T) 56 | out = self.conv_1(out) 57 | out = self.conv_2(out) 58 | out = self.drop_out(out) 59 | out = self.conv_3(out) 60 | out = out.permute(0, 2, 1) 61 | # out: (B, T, C + 1) 62 | return out 63 | 64 | class BaS_Net(nn.Module): 65 | def __init__(self, 
len_feature, num_classes, num_segments): 66 | super(BaS_Net, self).__init__() 67 | self.filter_module = Filter_Module(len_feature) 68 | self.len_feature = len_feature 69 | self.num_classes = num_classes 70 | 71 | self.cas_module = CAS_Module(len_feature, num_classes) 72 | 73 | self.softmax = nn.Softmax(dim=1) 74 | 75 | self.num_segments = num_segments 76 | self.k = num_segments // 8 77 | 78 | 79 | def forward(self, x): 80 | fore_weights = self.filter_module(x) 81 | 82 | x_supp = fore_weights * x 83 | 84 | cas_base = self.cas_module(x) 85 | cas_supp = self.cas_module(x_supp) 86 | 87 | # slicing after sorting is much faster than torch.topk (https://github.com/pytorch/pytorch/issues/22812) 88 | # score_base = torch.mean(torch.topk(cas_base, self.k, dim=1)[0], dim=1) 89 | sorted_scores_base, _= cas_base.sort(descending=True, dim=1) 90 | topk_scores_base = sorted_scores_base[:, :self.k, :] 91 | score_base = torch.mean(topk_scores_base, dim=1) 92 | 93 | # score_supp = torch.mean(torch.topk(cas_supp, self.k, dim=1)[0], dim=1) 94 | sorted_scores_supp, _= cas_supp.sort(descending=True, dim=1) 95 | topk_scores_supp = sorted_scores_supp[:, :self.k, :] 96 | score_supp = torch.mean(topk_scores_supp, dim=1) 97 | 98 | score_base = self.softmax(score_base) 99 | score_supp = self.softmax(score_supp) 100 | 101 | return score_base, cas_base, score_supp, cas_supp, fore_weights 102 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import shutil 3 | import os 4 | 5 | def parse_args(): 6 | descript = 'Pytorch Implementation of Background Suppression Network (BaS-Net)' 7 | parser = argparse.ArgumentParser(description=descript) 8 | 9 | parser.add_argument('--data_path', type=str, default='./dataset/THUMOS14') 10 | parser.add_argument('--model_path', type=str, default='./models/BaSnet') 11 | parser.add_argument('--output_path', type=str, default='./outputs/BaSnet') 12 | parser.add_argument('--log_path', type=str, default='./logs/BaSnet') 13 | parser.add_argument('--modal', type=str, default='all', choices=['rgb', 'flow', 'all']) 14 | parser.add_argument('--alpha', type=float, default=0.0001) 15 | parser.add_argument('--class_th', type=float, default=0.25) 16 | parser.add_argument('--lr', type=str, default='[0.0001]*1500', help='learning rates for steps(list form)') 17 | parser.add_argument('--batch_size', type=int, default=16) 18 | parser.add_argument('--num_workers', type=int, default=8) 19 | parser.add_argument('--seed', type=int, default=-1, help='random seed (-1 for no manual seed)') 20 | parser.add_argument('--model_file', type=str, default=None, help='the path of pre-trained model file') 21 | parser.add_argument('--debug', action='store_true') 22 | 23 | return init_args(parser.parse_args()) 24 | 25 | 26 | def init_args(args): 27 | if not os.path.exists(args.model_path): 28 | os.makedirs(args.model_path) 29 | 30 | if os.path.exists(args.log_path): 31 | shutil.rmtree(args.log_path) 32 | if not os.path.exists(args.log_path): 33 | os.makedirs(args.log_path) 34 | 35 | if not os.path.exists(args.output_path): 36 | os.makedirs(args.output_path) 37 | 38 | return args -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | joblib==0.13.0 2 | numpy==1.14.5 3 | pandas==0.23.4 4 | scikit-learn==0.20.0 5 | scipy==1.1.0 6 | tensorboard==1.15.0 7 | 
tensorboard-logger==0.1.0 8 | tensorflow==1.15.4 9 | tensorflow-estimator==1.13.0 10 | torch==1.0.0 11 | torchvision==0.2.1 12 | tqdm==4.31.1 13 | -------------------------------------------------------------------------------- /run.sh: -------------------------------------------------------------------------------- 1 | model_path="./models/BaSNet" 2 | output_path="./outputs/BaSNet" 3 | log_path="./logs/BaSNet" 4 | seed=-1 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -W ignore ./main.py --model_path ${model_path} --output_path ${output_path} --log_path ${log_path} --seed ${seed} -------------------------------------------------------------------------------- /run_eval.sh: -------------------------------------------------------------------------------- 1 | model_path="./models/BaSNet_eval" 2 | output_path="./outputs/BaSNet_eval" 3 | log_path="./logs/BaSNet_eval" 4 | model_file='./BaSNet_model_best.pkl' 5 | 6 | CUDA_VISIBLE_DEVICES=0 python -W ignore ./main_eval.py --model_path ${model_path} --output_path ${output_path} --log_path ${log_path} --model_file ${model_file} 7 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import utils 5 | import os 6 | import json 7 | from eval.eval_detection import ANETdetection 8 | from tqdm import tqdm 9 | 10 | def test(net, config, logger, test_loader, test_info, step, model_file=None): 11 | with torch.no_grad(): 12 | net.eval() 13 | 14 | if model_file is not None: 15 | net.load_state_dict(torch.load(model_file)) 16 | 17 | final_res = {} 18 | final_res['version'] = 'VERSION 1.3' 19 | final_res['results'] = {} 20 | final_res['external_data'] = {'used': True, 'details': 'Features from I3D Network'} 21 | 22 | num_correct = 0. 23 | num_total = 0. 
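        # Evaluate one video at a time (the test loader uses batch_size=1):
        # video-level class scores feed the classification accuracy, and the
        # suppressed CAS is turned into temporal proposals for detection mAP.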
24 | 25 | load_iter = iter(test_loader) 26 | 27 | for i in range(len(test_loader.dataset)): 28 | 29 | _data, _label, _, vid_name, vid_num_seg = next(load_iter) 30 | 31 | _data = _data.cuda() 32 | _label = _label.cuda() 33 | 34 | _, cas_base, score_supp, cas_supp, fore_weights = net(_data) 35 | 36 | label_np = _label.cpu().numpy() 37 | score_np = score_supp[0,:-1].cpu().data.numpy() 38 | 39 | score_np[np.where(score_np < config.class_thresh)] = 0 40 | score_np[np.where(score_np >= config.class_thresh)] = 1 41 | 42 | correct_pred = np.sum(label_np == score_np, axis=1) 43 | 44 | num_correct += np.sum((correct_pred == config.num_classes).astype(np.float32)) 45 | num_total += correct_pred.shape[0] 46 | 47 | cas_base = utils.minmax_norm(cas_base) 48 | cas_supp = utils.minmax_norm(cas_supp) 49 | 50 | pred = np.where(score_np > config.class_thresh)[0] 51 | 52 | if pred.any(): 53 | cas_pred = cas_supp[0].cpu().numpy()[:, pred] 54 | cas_pred = np.reshape(cas_pred, (config.num_segments, -1, 1)) 55 | 56 | cas_pred = utils.upgrade_resolution(cas_pred, config.scale) 57 | 58 | proposal_dict = {} 59 | 60 | for i in range(len(config.act_thresh)): 61 | cas_temp = cas_pred.copy() 62 | 63 | zero_location = np.where(cas_temp[:, :, 0] < config.act_thresh[i]) 64 | cas_temp[zero_location] = 0 65 | 66 | seg_list = [] 67 | for c in range(len(pred)): 68 | pos = np.where(cas_temp[:, c, 0] > 0) 69 | seg_list.append(pos) 70 | 71 | proposals = utils.get_proposal_oic(seg_list, cas_temp, score_np, pred, config.scale, \ 72 | vid_num_seg[0].cpu().item(), config.feature_fps, config.num_segments) 73 | 74 | for i in range(len(proposals)): 75 | class_id = proposals[i][0][0] 76 | 77 | if class_id not in proposal_dict.keys(): 78 | proposal_dict[class_id] = [] 79 | 80 | proposal_dict[class_id] += proposals[i] 81 | 82 | final_proposals = [] 83 | for class_id in proposal_dict.keys(): 84 | final_proposals.append(utils.nms(proposal_dict[class_id], 0.7)) 85 | 86 | final_res['results'][vid_name[0]] = utils.result2json(final_proposals) 87 | 88 | test_acc = num_correct / num_total 89 | 90 | json_path = os.path.join(config.output_path, 'temp_result.json') 91 | with open(json_path, 'w') as f: 92 | json.dump(final_res, f) 93 | f.close() 94 | 95 | tIoU_thresh = np.linspace(0.1, 0.9, 9) 96 | anet_detection = ANETdetection(config.gt_path, json_path, 97 | subset='test', tiou_thresholds=tIoU_thresh, 98 | verbose=False, check_status=False) 99 | mAP, average_mAP = anet_detection.evaluate() 100 | 101 | logger.log_value('Test accuracy', test_acc, step) 102 | 103 | for i in range(tIoU_thresh.shape[0]): 104 | logger.log_value('mAP@{:.1f}'.format(tIoU_thresh[i]), mAP[i], step) 105 | 106 | logger.log_value('Average mAP', average_mAP, step) 107 | 108 | test_info["step"].append(step) 109 | test_info["test_acc"].append(test_acc) 110 | test_info["average_mAP"].append(average_mAP) 111 | 112 | for i in range(tIoU_thresh.shape[0]): 113 | test_info["mAP@{:.1f}".format(tIoU_thresh[i])].append(mAP[i]) 114 | -------------------------------------------------------------------------------- /thumos_features.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import os 3 | import csv 4 | import json 5 | import numpy as np 6 | import torch 7 | import pdb 8 | import time 9 | import random 10 | import utils 11 | import config 12 | 13 | 14 | class ThumosFeature(data.Dataset): 15 | def __init__(self, data_path, mode, modal, feature_fps, num_segments, len_feature, sampling, seed=-1, supervision='weak'): 
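        # Dataset of pre-extracted THUMOS'14 I3D features.
        #   mode: 'train' or 'test' (selects split_{mode}.txt and features/{mode}/)
        #   modal: 'rgb', 'flow', or 'all' ('all' concatenates RGB and flow features)
        #   sampling: 'random' for training, 'uniform' for testing
        #   supervision: 'weak' returns only video-level labels; otherwise a
        #                frame-level annotation matrix is built in get_label()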
16 | if seed >= 0: 17 | utils.set_seed(seed) 18 | 19 | self.mode = mode 20 | self.modal = modal 21 | self.feature_fps = feature_fps 22 | self.num_segments = num_segments 23 | self.len_feature = len_feature 24 | 25 | if self.modal == 'all': 26 | self.feature_path = [] 27 | for _modal in ['rgb', 'flow']: 28 | self.feature_path.append(os.path.join(data_path, 'features', self.mode, _modal)) 29 | else: 30 | self.feature_path = os.path.join(data_path, 'features', self.mode, self.modal) 31 | 32 | split_path = os.path.join(data_path, 'split_{}.txt'.format(self.mode)) 33 | split_file = open(split_path, 'r') 34 | self.vid_list = [] 35 | for line in split_file: 36 | self.vid_list.append(line.strip()) 37 | split_file.close() 38 | 39 | anno_path = os.path.join(data_path, 'gt.json') 40 | anno_file = open(anno_path, 'r') 41 | self.anno = json.load(anno_file) 42 | anno_file.close() 43 | 44 | self.class_name_to_idx = dict((v, k) for k, v in config.class_dict.items()) 45 | self.num_classes = len(self.class_name_to_idx.keys()) 46 | 47 | self.supervision = supervision 48 | self.sampling = sampling 49 | 50 | 51 | def __len__(self): 52 | return len(self.vid_list) 53 | 54 | def __getitem__(self, index): 55 | data, vid_num_seg, sample_idx = self.get_data(index) 56 | label, temp_anno = self.get_label(index, vid_num_seg, sample_idx) 57 | 58 | return data, label, temp_anno, self.vid_list[index], vid_num_seg 59 | 60 | def get_data(self, index): 61 | vid_name = self.vid_list[index] 62 | 63 | vid_num_seg = 0 64 | 65 | if self.modal == 'all': 66 | rgb_feature = np.load(os.path.join(self.feature_path[0], 67 | vid_name + '.npy')).astype(np.float32) 68 | flow_feature = np.load(os.path.join(self.feature_path[1], 69 | vid_name + '.npy')).astype(np.float32) 70 | 71 | vid_num_seg = rgb_feature.shape[0] 72 | 73 | if self.sampling == 'random': 74 | sample_idx = self.random_perturb(rgb_feature.shape[0]) 75 | elif self.sampling == 'uniform': 76 | sample_idx = self.uniform_sampling(rgb_feature.shape[0]) 77 | else: 78 | raise AssertionError('Not supported sampling !') 79 | 80 | rgb_feature = rgb_feature[sample_idx] 81 | flow_feature = flow_feature[sample_idx] 82 | 83 | feature = np.concatenate((rgb_feature, flow_feature), axis=1) 84 | else: 85 | feature = np.load(os.path.join(self.feature_path, 86 | vid_name + '.npy')).astype(np.float32) 87 | 88 | vid_num_seg = feature.shape[0] 89 | 90 | if self.sampling == 'random': 91 | sample_idx = self.random_perturb(feature.shape[0]) 92 | elif self.sampling == 'uniform': 93 | sample_idx = self.uniform_sampling(feature.shape[0]) 94 | else: 95 | raise AssertionError('Not supported sampling !') 96 | 97 | feature = feature[sample_idx] 98 | 99 | return torch.from_numpy(feature), vid_num_seg, sample_idx 100 | 101 | def get_label(self, index, vid_num_seg, sample_idx): 102 | vid_name = self.vid_list[index] 103 | anno_list = self.anno['database'][vid_name]['annotations'] 104 | label = np.zeros([self.num_classes], dtype=np.float32) 105 | 106 | classwise_anno = [[]] * self.num_classes 107 | 108 | for _anno in anno_list: 109 | label[self.class_name_to_idx[_anno['label']]] = 1 110 | classwise_anno[self.class_name_to_idx[_anno['label']]].append(_anno) 111 | 112 | if self.supervision == 'weak': 113 | return label, torch.Tensor(0) 114 | else: 115 | temp_anno = np.zeros([vid_num_seg, self.num_classes]) 116 | t_factor = self.feature_fps / 16 117 | 118 | for class_idx in range(self.num_classes): 119 | if label[class_idx] != 1: 120 | continue 121 | 122 | for _anno in classwise_anno[class_idx]: 123 | tmp_start_sec 
= float(_anno['segment'][0]) 124 | tmp_end_sec = float(_anno['segment'][1]) 125 | 126 | tmp_start = round(tmp_start_sec * t_factor) 127 | tmp_end = round(tmp_end_sec * t_factor) 128 | 129 | temp_anno[tmp_start:tmp_end+1, class_idx] = 1 130 | 131 | temp_anno = temp_anno[sample_idx, :] 132 | 133 | return label, torch.from_numpy(temp_anno) 134 | 135 | 136 | def random_perturb(self, length): 137 | if self.num_segments == length: 138 | return np.arange(self.num_segments).astype(int) 139 | samples = np.arange(self.num_segments) * length / self.num_segments 140 | for i in range(self.num_segments): 141 | if i < self.num_segments - 1: 142 | if int(samples[i]) != int(samples[i + 1]): 143 | samples[i] = np.random.choice(range(int(samples[i]), int(samples[i + 1]) + 1)) 144 | else: 145 | samples[i] = int(samples[i]) 146 | else: 147 | if int(samples[i]) < length - 1: 148 | samples[i] = np.random.choice(range(int(samples[i]), length)) 149 | else: 150 | samples[i] = int(samples[i]) 151 | return samples.astype(int) 152 | 153 | 154 | def uniform_sampling(self, length): 155 | if self.num_segments == length: 156 | return np.arange(self.num_segments).astype(int) 157 | samples = np.arange(self.num_segments) * length / self.num_segments 158 | samples = np.floor(samples) 159 | return samples.astype(int) 160 | 161 | 162 | 163 | 164 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | class BaS_Net_loss(nn.Module): 6 | def __init__(self, alpha): 7 | super(BaS_Net_loss, self).__init__() 8 | self.alpha = alpha 9 | self.ce_criterion = nn.BCELoss() 10 | 11 | def forward(self, score_base, score_supp, fore_weights, label): 12 | loss = {} 13 | 14 | label_base = torch.cat((label, torch.ones((label.shape[0], 1)).cuda()), dim=1) 15 | label_supp = torch.cat((label, torch.zeros((label.shape[0], 1)).cuda()), dim=1) 16 | 17 | label_base = label_base / torch.sum(label_base, dim=1, keepdim=True) 18 | label_supp = label_supp / torch.sum(label_supp, dim=1, keepdim=True) 19 | 20 | loss_base = self.ce_criterion(score_base, label_base) 21 | loss_supp = self.ce_criterion(score_supp, label_supp) 22 | loss_norm = torch.mean(torch.norm(fore_weights, p=1, dim=1)) 23 | 24 | loss_total = loss_base + loss_supp + self.alpha * loss_norm 25 | 26 | loss["loss_base"] = loss_base 27 | loss["loss_supp"] = loss_supp 28 | loss["loss_norm"] = loss_norm 29 | loss["loss_total"] = loss_total 30 | 31 | return loss_total, loss 32 | 33 | def train(net, train_loader, loader_iter, optimizer, criterion, logger, step): 34 | net.train() 35 | try: 36 | _data, _label, _, _, _ = next(loader_iter) 37 | except: 38 | loader_iter = iter(train_loader) 39 | _data, _label, _, _, _ = next(loader_iter) 40 | 41 | _data = _data.cuda() 42 | _label = _label.cuda() 43 | 44 | optimizer.zero_grad() 45 | 46 | score_base, _, score_supp, _, fore_weights = net(_data) 47 | 48 | cost, loss = criterion(score_base, score_supp, fore_weights, _label) 49 | 50 | cost.backward() 51 | optimizer.step() 52 | 53 | for key in loss.keys(): 54 | logger.log_value(key, loss[key].cpu().item(), step) -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from scipy.interpolate import interp1d 5 | import os 6 | import sys 7 | 
import random 8 | import config 9 | 10 | 11 | def upgrade_resolution(arr, scale): 12 | x = np.arange(0, arr.shape[0]) 13 | f = interp1d(x, arr, kind='linear', axis=0, fill_value='extrapolate') 14 | scale_x = np.arange(0, arr.shape[0], 1 / scale) 15 | up_scale = f(scale_x) 16 | return up_scale 17 | 18 | 19 | def get_proposal_oic(tList, wtcam, final_score, c_pred, scale, v_len, sampling_frames, num_segments, lambda_=0.25, gamma=0.2): 20 | t_factor = (16 * v_len) / (scale * num_segments * sampling_frames) 21 | temp = [] 22 | for i in range(len(tList)): 23 | c_temp = [] 24 | temp_list = np.array(tList[i])[0] 25 | if temp_list.any(): 26 | grouped_temp_list = grouping(temp_list) 27 | for j in range(len(grouped_temp_list)): 28 | inner_score = np.mean(wtcam[grouped_temp_list[j], i, 0]) 29 | 30 | len_proposal = len(grouped_temp_list[j]) 31 | outer_s = max(0, int(grouped_temp_list[j][0] - lambda_ * len_proposal)) 32 | outer_e = min(int(wtcam.shape[0] - 1), int(grouped_temp_list[j][-1] + lambda_ * len_proposal)) 33 | 34 | outer_temp_list = list(range(outer_s, int(grouped_temp_list[j][0]))) + list(range(int(grouped_temp_list[j][-1] + 1), outer_e + 1)) 35 | 36 | if len(outer_temp_list) == 0: 37 | outer_score = 0 38 | else: 39 | outer_score = np.mean(wtcam[outer_temp_list, i, 0]) 40 | 41 | c_score = inner_score - outer_score + gamma * final_score[c_pred[i]] 42 | t_start = grouped_temp_list[j][0] * t_factor 43 | t_end = (grouped_temp_list[j][-1] + 1) * t_factor 44 | c_temp.append([c_pred[i], c_score, t_start, t_end]) 45 | temp.append(c_temp) 46 | return temp 47 | 48 | 49 | def result2json(result): 50 | result_file = [] 51 | for i in range(len(result)): 52 | for j in range(len(result[i])): 53 | line = {'label': config.class_dict[result[i][j][0]], 'score': result[i][j][1], 54 | 'segment': [result[i][j][2], result[i][j][3]]} 55 | result_file.append(line) 56 | return result_file 57 | 58 | 59 | def grouping(arr): 60 | return np.split(arr, np.where(np.diff(arr) != 1)[0] + 1) 61 | 62 | 63 | def save_best_record_thumos(test_info, file_path): 64 | fo = open(file_path, "w") 65 | fo.write("Step: {}\n".format(test_info["step"][-1])) 66 | fo.write("Test_acc: {:.2f}\n".format(test_info["test_acc"][-1])) 67 | fo.write("average_mAP: {:.4f}\n".format(test_info["average_mAP"][-1])) 68 | 69 | tIoU_thresh = np.linspace(0.1, 0.9, 9) 70 | for i in range(len(tIoU_thresh)): 71 | fo.write("mAP@{:.1f}: {:.4f}\n".format(tIoU_thresh[i], test_info["mAP@{:.1f}".format(tIoU_thresh[i])][-1])) 72 | 73 | fo.close() 74 | 75 | 76 | def minmax_norm(act_map): 77 | max_val = nn.ReLU()(torch.max(act_map, dim=1)[0]) 78 | min_val = nn.ReLU()(torch.min(act_map, dim=1)[0]) 79 | delta = max_val - min_val 80 | delta[delta <=0] = 1 81 | ret = (act_map - min_val) / delta 82 | 83 | return ret 84 | 85 | 86 | def nms(proposals, thresh): 87 | proposals = np.array(proposals) 88 | x1 = proposals[:, 2] 89 | x2 = proposals[:, 3] 90 | scores = proposals[:, 1] 91 | 92 | areas = x2 - x1 + 1 93 | order = scores.argsort()[::-1] 94 | 95 | keep = [] 96 | while order.size > 0: 97 | i = order[0] 98 | keep.append(proposals[i].tolist()) 99 | xx1 = np.maximum(x1[i], x1[order[1:]]) 100 | xx2 = np.minimum(x2[i], x2[order[1:]]) 101 | 102 | inter = np.maximum(0.0, xx2 - xx1 + 1) 103 | 104 | iou = inter / (areas[i] + areas[order[1:]] - inter) 105 | 106 | inds = np.where(iou < thresh)[0] 107 | order = order[inds + 1] 108 | 109 | return keep 110 | 111 | 112 | def set_seed(seed): 113 | torch.manual_seed(seed) 114 | np.random.seed(seed) 115 | torch.cuda.manual_seed_all(seed) 116 
| random.seed(seed) 117 | torch.backends.cudnn.deterministic=True 118 | torch.backends.cudnn.benchmark=False 119 | --------------------------------------------------------------------------------
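A minimal usage sketch (not part of the repository) of the proposal-level NMS helper in `utils.py`: the values below are synthetic and only illustrate the `[class_id, score, t_start, t_end]` proposal format and the 0.7 IoU threshold that `test.py` passes to `utils.nms`.

~~~~
from utils import nms

# Synthetic proposals for one class: [class_id, score, t_start (s), t_end (s)]
proposals = [
    [5, 0.90, 10.0, 20.0],
    [5, 0.75, 11.0, 21.0],   # heavy overlap with the first -> suppressed at IoU 0.7
    [5, 0.60, 40.0, 50.0],   # disjoint from the first -> kept
]

kept = nms(proposals, thresh=0.7)
# kept == [[5.0, 0.9, 10.0, 20.0], [5.0, 0.6, 40.0, 50.0]]
~~~~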