├── README.md ├── compute.py ├── convert.py ├── data └── tmp.txt ├── dataset.py ├── make_proposal.py ├── mixer.py ├── model.py ├── oad_recurrent_main.py ├── oracle_hungarian.py ├── requirements.txt ├── t_oad_model └── tmp.txt ├── test.sh ├── thumos14_classifier_model └── tmp.txt ├── thumos14_v2.json ├── train.sh ├── utils.py └── yamls └── canonical.yaml /README.md: -------------------------------------------------------------------------------- 1 | "ActionSwitch: Class-agnostic Detection of Simultaneous Actions in Streaming Videos" [ECCV2024] 2 | ============= 3 | [Paper](https://arxiv.org/abs/2407.12987) [Project page](https://musicaloffering.github.io/ActionSwitch-release/) 4 | 5 | 6 | Training and assessing the model 7 | ============= 8 | 1. Check `requirements.txt` and install the necessary packages (there are very few!). 9 | 10 | 2. Prepare the data (features and labels) in the `data` directory. 11 | In this code, we will work with the THUMOS14 dataset. 12 | Features and labels are available [here](https://drive.google.com/file/d/1AUyo2YiDYMsU99G18cxWA2BcG-lh59Bf/view?usp=sharing). 13 | Extract the downloaded file into the `data` directory. 14 | After all, the file structure should be as follows: 15 | ``` 16 | data 17 | |--tmp.txt 18 | |--thumos14 19 | | |--thumos14_4state_label_1 20 | | | |--video_test_0000004.npy 21 | | | |--... 22 | | | |--... 23 | | |--thumos14_features 24 | | | |--video_test_0000004.npy 25 | | | |--... 26 | | | |--... 27 | | |--thumos14_oracle_proposals.pkl 28 | | |--thumos14_v2.json 29 | ``` 30 | 3. Download the classifier [here](https://drive.google.com/file/d/1oElmzpwHPxMvyAZtjN-_AWJsm59lCziM/view?usp=sharing), and place it in the `thumos14_classifier_model` directory. 31 | 32 | 4. [Optional] Download [checkpoint.pt](https://drive.google.com/file/d/1WEdEERZH-uw2yA9YT1fWuXgn30b9X6cX/view?usp=sharing) and place it in the `t_oad_model` directory. 33 | 34 | 5. For training, run the following command: 35 | ``` 36 | bash train.sh 37 | ``` 38 | This script will train the model, make proposals with the model, and evaluate the Hungarian F1 score of the model's predictions. 39 | 6. To test the given checkpoint, complete step 4 and run the following command: 40 | ``` 41 | bash test.sh 42 | ``` 43 | This will result in a `53.2` hungarian f1 score. 44 | -------------------------------------------------------------------------------- /compute.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import json 4 | import pandas as pd 5 | from joblib import Parallel, delayed 6 | from utils import compute_average_precision_detection, get_blocked_videos 7 | 8 | #python compute.py thumos14_v2.json canonical.json 9 | 10 | parser = argparse.ArgumentParser(description='This script allows you to evaluate the ActivityNet ' 11 | 'detection task which is intended to evaluate the ability ' 12 | 'of algorithms to temporally localize activities in ' 13 | 'untrimmed video sequences.') 14 | parser.add_argument('ground_truth_filename', 15 | help='Full path to json file containing the ground truth.') 16 | parser.add_argument('prediction_filename', 17 | help='Full path to json file containing the predictions.') 18 | parser.add_argument('--dataset', default='Thumos', 19 | help=('Dataset that wants to get map values')) 20 | parser.add_argument('--subset', default='test', 21 | help='String indicating subset to evaluate: ') 22 | parser.add_argument('--tiou_thresholds', type=float, default=0.7, 23 | help='Temporal intersection over union threshold.') 24 | parser.add_argument('--verbose', type=bool, default=True) 25 | parser.add_argument('--check_status', type=bool, default=True) 26 | 27 | def main(): 28 | args = parser.parse_args() 29 | ground_truth_filename = args.ground_truth_filename 30 | prediction_filename = args.prediction_filename 31 | dataset = args.dataset 32 | subset = args.subset 33 | tiou_thresholds = args.tiou_thresholds 34 | verbose = args.verbose 35 | check_status = args.check_status 36 | 37 | if dataset == 'Thumos': 38 | detection = THUMOSdetection(ground_truth_filename, prediction_filename, 39 | subset=subset, tiou_thresholds=tiou_thresholds, 40 | verbose=verbose, check_status=False) 41 | 42 | detection.evaluate() 43 | 44 | class THUMOSdetection(object): 45 | GROUND_TRUTH_FIELDS = ['database'] 46 | PREDICTION_FIELDS = ['results'] 47 | 48 | def __init__(self, ground_truth_filename=None, prediction_filename=None, 49 | ground_truth_fields=GROUND_TRUTH_FIELDS, 50 | prediction_fields=PREDICTION_FIELDS, 51 | tiou_thresholds=np.linspace(0.5, 0.95, 10), 52 | subset="test", verbose=False, 53 | check_status=False): 54 | if not ground_truth_filename: 55 | raise IOError('Please input a valid ground truth file.') 56 | if not prediction_filename: 57 | raise IOError('Please input a valid prediction file.') 58 | self.subset = subset 59 | self.tiou_thresholds = [tiou_thresholds] 60 | self.verbose = verbose 61 | self.gt_fields = ground_truth_fields 62 | self.pred_fields = prediction_fields 63 | self.ap = None 64 | self.check_status = check_status 65 | if self.check_status: 66 | self.blocked_videos = get_blocked_videos() 67 | else: 68 | self.blocked_videos = list() 69 | self.ground_truth, self.activity_index = self._import_ground_truth(ground_truth_filename) 70 | self.prediction = self._import_prediction(prediction_filename) 71 | 72 | if self.verbose: 73 | print ('[INIT] Loaded annotations from {} subset.'.format(subset)) 74 | nr_gt = len(self.ground_truth) 75 | print ('\tNumber of ground truth instances: {}'.format(nr_gt)) 76 | nr_pred = len(self.prediction) 77 | print ('\tNumber of predictions: {}'.format(nr_pred)) 78 | print ('\tFixed threshold for tiou score: {}'.format(self.tiou_thresholds)) 79 | 80 | def _import_ground_truth(self, ground_truth_filename): 81 | """Reads ground truth file, checks if it is well formatted, and returns 82 | the ground truth instances and the activity classes. 83 | Parameters 84 | ---------- 85 | ground_truth_filename : str 86 | Full path to the ground truth json file. 87 | Outputs 88 | ------- 89 | ground_truth : df 90 | Data frame containing the ground truth instances. 91 | activity_index : dict 92 | Dictionary containing class index. 93 | """ 94 | with open(ground_truth_filename, 'r', encoding='utf-8') as fobj: 95 | data = json.load(fobj) 96 | 97 | if not all([field in data.keys() for field in self.gt_fields]): 98 | raise IOError('Please input a valid ground truth file.') 99 | 100 | activity_index, cidx = {}, 0 101 | video_lst, t_start_lst, t_end_lst, label_lst = [], [], [], [] 102 | for vidname, v in data['database'].items(): 103 | if self.subset != v['subset']: 104 | continue 105 | if vidname in ['video_test_0000270', 'video_test_0001496']: 106 | continue 107 | 108 | for ann in v['annotations']: 109 | if ann['label'] not in activity_index: 110 | activity_index[ann['label']] = cidx 111 | cidx += 1 112 | video_lst.append(vidname) 113 | #t_start_lst.append(float(ann['segment_frame'][0] )) # bbdb / 6 | THUMOS / 5 | HACS / 2 114 | #t_end_lst.append(float(ann['segment_frame'][1] )) 115 | t_start_lst.append(ann['segment'][0]) 116 | t_end_lst.append(ann['segment'][1]) 117 | label_lst.append(activity_index[ann['label']]) 118 | ground_truth = pd.DataFrame({'video-id': video_lst, 119 | 't-start': t_start_lst, 120 | 't-end': t_end_lst, 121 | 'label': label_lst}) 122 | print(activity_index) 123 | return ground_truth, activity_index 124 | 125 | def _import_prediction(self, prediction_filename): 126 | """Reads prediction file, checks if it is well formatted, and returns 127 | the prediction instances. 128 | Parameters 129 | ---------- 130 | prediction_filename : str 131 | Full path to the prediction json file. 132 | Outputs 133 | ------- 134 | prediction : df 135 | Data frame containing the prediction instances. 136 | """ 137 | with open(prediction_filename, 'r') as fobj: 138 | data = json.load(fobj) 139 | 140 | if not all([field in data.keys() for field in self.pred_fields]): 141 | raise IOError('Please input a valid prediction file.') 142 | 143 | video_lst, t_start_lst, t_end_lst = [], [], [] 144 | label_lst, score_lst = [], [] 145 | for videoid, v in data['results'].items(): 146 | 147 | if videoid in ['video_test_0000270', 'video_test_0001496']: 148 | continue 149 | 150 | if videoid in self.blocked_videos: 151 | continue 152 | for result in v: 153 | try: 154 | label = self.activity_index[result['label']] 155 | except KeyError: 156 | label = 0 157 | 158 | video_lst.append(videoid) 159 | t_start_lst.append(float(result['segment'][0])) 160 | t_end_lst.append(float(result['segment'][1])) 161 | label_lst.append(label) 162 | score_lst.append(result['score']) 163 | prediction = pd.DataFrame({'video-id': video_lst, 164 | 't-start': t_start_lst, 165 | 't-end': t_end_lst, 166 | 'label': label_lst, 167 | 'score': score_lst}) 168 | return prediction 169 | 170 | def _get_predictions_with_label(self, prediction_by_label, label_name, cidx): 171 | """Get all predicitons of the given label. Return empty DataFrame if there 172 | is no predcitions with the given label. 173 | """ 174 | try: 175 | return prediction_by_label.get_group(cidx).reset_index(drop=True) 176 | except: 177 | print ('Warning: No predictions of label \'%s\' were provdied.' % label_name) 178 | return pd.DataFrame() 179 | 180 | def wrapper_compute_average_precision(self): 181 | """Computes average precision for each class in the subset. 182 | """ 183 | ap = np.zeros((len(self.tiou_thresholds), len(self.activity_index))) 184 | 185 | # Adaptation to query faster 186 | ground_truth_by_label = self.ground_truth.groupby('label') 187 | prediction_by_label = self.prediction.groupby('label') 188 | n_jobs = min(30, len(self.activity_index)) 189 | results = Parallel(n_jobs=n_jobs)( 190 | delayed(compute_average_precision_detection)( 191 | ground_truth=ground_truth_by_label.get_group(cidx).reset_index(drop=True), 192 | prediction=self._get_predictions_with_label(prediction_by_label, label_name, cidx), 193 | tiou_thresholds=self.tiou_thresholds, 194 | ) for label_name, cidx in self.activity_index.items()) 195 | 196 | for i, cidx in enumerate(self.activity_index.values()): 197 | ap[:,cidx] = results[i] 198 | 199 | return ap 200 | 201 | def evaluate(self): 202 | """Evaluates a prediction file. For the detection task we measure the 203 | interpolated mean average precision to measure the performance of a 204 | method. 205 | """ 206 | self.ap = self.wrapper_compute_average_precision() 207 | 208 | self.mAP = self.ap.mean(axis=1) 209 | self.average_mAP = self.mAP.mean() 210 | formated_ap = [np.around((elem), 3) * 100 for elem in self.ap] 211 | 212 | if self.verbose: 213 | print ('[RESULTS] Performance on THUMOS detection task.') 214 | for key in self.activity_index: 215 | print('class: {} - ap: {}'.format(key, self.ap[0][self.activity_index[key]] * 100 )) 216 | print ('\tAverage-mAP: {}'.format(self.average_mAP)) 217 | 218 | if __name__ == '__main__': 219 | main() 220 | -------------------------------------------------------------------------------- /convert.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pydoc import classname 4 | import torch 5 | import yaml 6 | import random 7 | import argparse 8 | import numpy as np 9 | 10 | from yaml.loader import FullLoader 11 | from model import Classifier 12 | from dataset import resize_feature 13 | from tqdm import tqdm 14 | from utils import get_idx_and_confidence, get_idx_and_confidences 15 | 16 | 17 | ''' 18 | IN: segment json 19 | { 20 | results: { 21 | vid_name: [ 22 | { 23 | "segment": [ 24 | 2, 25 | 7 26 | ] 27 | }, 28 | { 29 | ... 30 | }, 31 | ] 32 | } 33 | } 34 | 35 | OUT: labeled json 36 | { 37 | results: { 38 | vid_name: [ 39 | { 40 | "label": "quarrel", 41 | "segment": [ 42 | 2.33, 43 | 7.66 44 | ], 45 | "score": 0.90012 46 | }, 47 | { 48 | ... 49 | }, 50 | ] 51 | } 52 | } 53 | ''' 54 | 55 | if __name__ == '__main__': 56 | parser = argparse.ArgumentParser() 57 | parser.add_argument('--yaml_path', type=str, default='yamls/canonical.yaml') 58 | parser.add_argument('--source', type=str, default='canonical.json') 59 | parser.add_argument('--target', type=str, default='canonical.json') 60 | parser.add_argument('--load_model', type=str, default='standard_EqFalse_Epoch12.pt') 61 | parser.add_argument('--duplicate_proposal_num', type=int, default=1) 62 | args = parser.parse_args() 63 | yaml_path = args.yaml_path 64 | source = args.source 65 | target = args.target 66 | load_model_name = args.load_model 67 | duplicate_proposal_num = args.duplicate_proposal_num 68 | with open(yaml_path, encoding='utf-8') as f: 69 | config = yaml.load(f, FullLoader) 70 | with open(config['annotation_path'], encoding='utf-8') as f: 71 | meta = json.load(f) 72 | with open(source, encoding='utf-8') as f: 73 | source = json.load(f) 74 | classes = meta['classes'] 75 | classifier = Classifier(config) 76 | classifier.load_state_dict(torch.load(os.path.join(config['classifier_model_path'], load_model_name))) 77 | classifier.eval() 78 | out_dict = {'results': dict()} 79 | for filename in tqdm(source['results']): 80 | segment_list = [] 81 | meta_duration = meta['database'][filename]["duration"] 82 | feature = np.load(os.path.join(config['feature_path'], f'{filename}.npy')) 83 | shrink_ratio = meta_duration/len(feature) 84 | for segment in source['results'][filename]: 85 | st, ed = segment['segment'] 86 | if st == 0 and ed == 0: 87 | segment_list.append({'label': list(classes.keys())[0], 'segment': [st, ed], 'score': 0}) 88 | continue 89 | proposal_feature = feature[st:ed] 90 | if len(proposal_feature) < config['unit_length']: 91 | #padding 92 | proposal_feature_len = len(proposal_feature) 93 | padding_size = config['unit_length'] - len(proposal_feature) 94 | padding = np.zeros([padding_size, proposal_feature.shape[1]]) 95 | proposal_feature = np.concatenate([proposal_feature, padding], axis=0) 96 | elif len(proposal_feature) > config['unit_length']: 97 | #linear interpolate 98 | proposal_feature_len = config['unit_length'] 99 | proposal_feature = resize_feature(proposal_feature, config['unit_length']) 100 | else: 101 | #pittari! 102 | proposal_feature_len = config['unit_length'] 103 | mask = np.array([False for _ in range(config['unit_length'])]) 104 | mask[proposal_feature_len:] = True 105 | proposal_feature = torch.from_numpy(proposal_feature.astype(np.float32)).to(config['device']).unsqueeze(0) 106 | mask = torch.from_numpy(mask).to(config['device']).unsqueeze(0) 107 | with torch.no_grad(): 108 | score = classifier(proposal_feature, mask) 109 | score = torch.softmax(score, dim=1).squeeze().cpu().numpy()[1:] 110 | if duplicate_proposal_num == 1: 111 | #standard assignment 112 | class_idx, confidence = get_idx_and_confidence(score, [float(st*shrink_ratio), float(ed*shrink_ratio)], segment_list, classes) 113 | for key in classes.keys(): 114 | idx = classes[key] 115 | if idx == class_idx: 116 | class_name = key 117 | break 118 | _st, _ed = float(st*shrink_ratio), float(ed*shrink_ratio) 119 | segment_list.append({'label': class_name, 'segment': [_st, _ed], 'score': confidence}) 120 | elif duplicate_proposal_num > 1: 121 | #duplicated assignment 122 | class_idxes_and_confidences = get_idx_and_confidences(score, n=duplicate_proposal_num) 123 | for i_c in class_idxes_and_confidences: 124 | class_idx = i_c[0] 125 | confidence = i_c[1] 126 | for key in classes.keys(): 127 | idx = classes[key] 128 | if idx == class_idx: 129 | class_name = key 130 | break 131 | _st, _ed = float(st*shrink_ratio), float(ed*shrink_ratio) 132 | segment_list.append({'label': class_name, 'segment': [_st, _ed], 'score': confidence}) 133 | else: 134 | raise NotImplementedError() 135 | out_dict['results'][filename] = segment_list 136 | with open(target, 'w', encoding='utf-8') as f: 137 | json.dump(out_dict, f) 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | -------------------------------------------------------------------------------- /data/tmp.txt: -------------------------------------------------------------------------------- 1 | dummy file -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import math 4 | import random 5 | import yaml 6 | import json 7 | import numpy as np 8 | import torch.utils.data as data 9 | import os 10 | import os.path as osp 11 | import scipy.interpolate 12 | 13 | from yaml.loader import FullLoader 14 | from torch.utils.data import DataLoader 15 | 16 | def resize_feature(input_data, new_size): 17 | assert len(input_data) > 1 18 | x = np.arange(len(input_data)) 19 | f = scipy.interpolate.interp1d(x, input_data, axis=0) 20 | x_new = [i*float(len(input_data)-1)/(new_size-1) for i in range(new_size)] 21 | return f(x_new) 22 | 23 | def calculate_iou(prediction:list, answer:list): 24 | intersection = -1 25 | s1, e1 = prediction 26 | s2, e2 = answer 27 | if s1 > s2: 28 | s1, s2 = s2, s1 29 | e1, e2 = e2, e1 30 | if e1 <= s2: 31 | intersection = 0 32 | else: 33 | if e2 <= e1: 34 | intersection = (e2 - s2) 35 | else: 36 | intersection = (e1 - s2) 37 | l1 = e1 - s1 38 | l2 = e2 - s2 39 | iou = intersection/(l1 + l2 - intersection + 1e-8) 40 | return iou 41 | 42 | class Classifier_Dataset(data.Dataset): 43 | def __init__(self, config, subset): 44 | super().__init__() 45 | self.config = config 46 | self.subset = subset 47 | with open(config['annotation_path'], encoding='utf-8') as f: 48 | self.meta = json.load(f) 49 | self.idxes = self.get_db_idxes(subset) 50 | 51 | def __len__(self): 52 | return len(self.idxes) 53 | 54 | def __getitem__(self, index): 55 | name, _ = self.idxes[index] 56 | feature = np.load(os.path.join(self.config['feature_path'], f'{name}.npy')) 57 | original_feature_len = len(feature) 58 | duration = self.meta['database'][name]['duration'] 59 | annotations = self.meta['database'][name]['annotations'] 60 | cnt = 0 61 | randflag = False 62 | while True: 63 | annotation = random.choice(annotations) 64 | st, ed = annotation['segment'] 65 | st_idx = int((st/duration)*original_feature_len) 66 | ed_idx = int((ed/duration)*original_feature_len) 67 | if ed <= duration and st_idx != ed_idx: 68 | break 69 | cnt += 1 70 | if cnt > 10: 71 | randflag = True 72 | break 73 | if random.random() > self.config['foreground_ratio'] or len(annotations) == 0 or randflag: 74 | #random 75 | if random.random() > self.config['long_instance_ratio']: 76 | #short instance 77 | instance_length = max(1, min(int(random.random() * self.config['unit_length']), len(feature))) 78 | else: 79 | #long instance 80 | instance_length = min(max(self.config['unit_length'], int(random.random() * self.config['max_length'])), len(feature)) 81 | instance_start_idx = random.randint(0, len(feature) - instance_length) 82 | instance_end_idx = instance_start_idx+instance_length 83 | feature = feature[instance_start_idx:instance_end_idx] 84 | iou = 0 85 | label_idx = 0 86 | for annotation in annotations: 87 | st, ed = annotation['segment'] 88 | st_idx = int((st/duration)*original_feature_len) 89 | ed_idx = int((ed/duration)*original_feature_len) 90 | tmp_iou = calculate_iou([instance_start_idx, instance_end_idx], [st_idx, ed_idx]) 91 | if tmp_iou > iou: 92 | iou = tmp_iou 93 | if iou > self.config['iou_threshold']: 94 | label_idx = annotation['labelIndex'] + 1 95 | else: 96 | #foreground 97 | cnt = 0 98 | while True: 99 | st, ed = annotation['segment'] 100 | st_idx = int((st/duration)*original_feature_len) 101 | ed_idx = int((ed/duration)*original_feature_len) 102 | if ed <= duration and st_idx != ed_idx: 103 | break 104 | cnt += 1 105 | if cnt > 200: 106 | print(name) 107 | raise Exception() 108 | st, ed = annotation['segment'] 109 | center = (st+ed)/2 110 | instance_length = (ed-st) 111 | start_candidate = [max(0, st - (instance_length/2)), center] 112 | end_candidate = [center, min(ed + (instance_length/2), duration)] 113 | cnt = 0 114 | while True: 115 | start = random.uniform(*start_candidate) 116 | end = random.uniform(*end_candidate) 117 | if calculate_iou([start, end], annotation['segment']) > self.config['iou_threshold']: 118 | instance_start_idx = int((start/duration)*original_feature_len) 119 | instance_end_idx = int((end/duration)*original_feature_len) 120 | if instance_end_idx != instance_start_idx: 121 | iou = calculate_iou([start, end], annotation['segment']) 122 | label_idx = annotation['labelIndex'] + 1 123 | break 124 | cnt += 1 125 | if cnt > 200: 126 | print(name) 127 | raise Exception() 128 | instance_start_idx = int((start/duration)*original_feature_len) 129 | instance_end_idx = int((end/duration)*original_feature_len) 130 | feature = feature[instance_start_idx:instance_end_idx] 131 | if len(feature) < self.config['unit_length']: 132 | #padding 133 | feature_len = len(feature) 134 | padding_size = self.config['unit_length'] - len(feature) 135 | padding = np.zeros([padding_size, feature.shape[1]]) 136 | feature = np.concatenate([feature, padding], axis=0) 137 | elif len(feature) > self.config['unit_length']: 138 | #linear interpolate 139 | feature_len = self.config['unit_length'] 140 | feature = resize_feature(feature, self.config['unit_length']) 141 | else: 142 | #pittari! 143 | feature_len = self.config['unit_length'] 144 | mask = np.array([False for _ in range(self.config['unit_length'])]) 145 | mask[feature_len:] = True 146 | feature = feature.astype(np.float32) 147 | return feature, mask, label_idx 148 | 149 | def get_db_idxes(self, subset): 150 | idxes = [] 151 | # [ [file_name, cnt_idx] ... ] 152 | for name in self.meta['database']: 153 | if self.meta['database'][name]['subset'] == subset: 154 | feature_len = len(np.load(os.path.join(self.config['feature_path'], f'{name}.npy'))) 155 | chunk_num = feature_len//self.config['unit_length'] 156 | idxes.extend([[name, i] for i in range(chunk_num)]) 157 | return idxes 158 | 159 | class Sliced_Dataset(data.Dataset): 160 | def __init__(self, config, subset): 161 | super().__init__() 162 | self.config = config 163 | self.subset = subset 164 | with open(config['annotation_path'], encoding='utf-8') as f: 165 | self.meta = json.load(f) 166 | self.idxes = self.get_db_idxes(config['classifier_epoch_iter']) 167 | self.folder_list = os.listdir(osp.join(config['classifier_feature_path'], subset)) 168 | folder_prob = dict() 169 | cnt = 0 170 | for folder in self.folder_list: 171 | path = osp.join(config['classifier_feature_path'], subset, folder) 172 | cnt += len(os.listdir(path)) 173 | for folder in self.folder_list: 174 | path = osp.join(config['classifier_feature_path'], subset, folder) 175 | folder_prob[int(folder)] = len(os.listdir(path))/cnt 176 | keys = sorted(list(folder_prob.keys())) 177 | self.folder_prob = [] 178 | for key in keys: 179 | self.folder_prob.append(folder_prob[key]) 180 | 181 | 182 | def __len__(self): 183 | return len(self.idxes) 184 | 185 | def __getitem__(self, index): 186 | eqaul_sampling = self.config['classifier_equal_sampling'] 187 | if eqaul_sampling or self.subset == 'test': 188 | folder = random.sample(self.folder_list, 1)[0] 189 | base_path = osp.join(self.config['classifier_feature_path'], self.subset, folder) 190 | filename = random.sample(os.listdir(base_path), 1)[0] 191 | full_path = osp.join(base_path, filename) 192 | tmp = np.load(full_path, allow_pickle=True).item() 193 | #print(tmp['feature'].shape, tmp['a_b'], tmp['labelIndex']) 194 | else: 195 | p = np.array(self.folder_prob) 196 | p = p**self.config['classifier_temperature'] 197 | 198 | p /= sum(p) 199 | 200 | folder = str(np.random.choice(len(self.folder_list), 1, p=p)[0]) 201 | base_path = osp.join(self.config['classifier_feature_path'], self.subset, folder) 202 | filename = random.sample(os.listdir(base_path), 1)[0] 203 | full_path = osp.join(base_path, filename) 204 | tmp = np.load(full_path, allow_pickle=True).item() 205 | original_feature = tmp['feature'] 206 | a_b = tmp['a_b'] 207 | labelIndex = tmp['labelIndex'] 208 | if len(original_feature) > 5: 209 | while True: 210 | a = random.normalvariate(self.config['a'], self.config['a_std']) 211 | while a < 0 or a > 1: 212 | a = random.normalvariate(self.config['a'], self.config['a_std']) 213 | b = random.normalvariate(self.config['b'], self.config['b_std']) 214 | while b < 0 or b > 1: 215 | b = random.normalvariate(self.config['b'], self.config['b_std']) 216 | if a > b: 217 | a, b = b, a 218 | st_idx, ed_idx = round(len(original_feature)*a), round(len(original_feature)*b) 219 | if ed_idx - st_idx > 0: 220 | break 221 | else: 222 | a, b = a_b[0], a_b[1] 223 | st_idx, ed_idx = round(len(original_feature)*a), round(len(original_feature)*b) 224 | assert ed_idx - st_idx > 0 225 | feature = original_feature[st_idx:ed_idx] 226 | a, b = a_b[0], a_b[1] 227 | iou = calculate_iou([round(len(original_feature)*a), round(len(original_feature)*b)], [st_idx, ed_idx]) 228 | if iou > self.config['iou_threshold']: 229 | label_idx = labelIndex + 1 230 | else: 231 | label_idx = 0 232 | if len(feature) < self.config['unit_length']: 233 | #padding 234 | feature_len = len(feature) 235 | padding_size = self.config['unit_length'] - len(feature) 236 | padding = np.zeros([padding_size, feature.shape[1]]) 237 | feature = np.concatenate([feature, padding], axis=0) 238 | elif len(feature) > self.config['unit_length']: 239 | #linear interpolate 240 | feature_len = self.config['unit_length'] 241 | feature = resize_feature(feature, self.config['unit_length']) 242 | else: 243 | #pittari! 244 | feature_len = self.config['unit_length'] 245 | mask = np.array([False for _ in range(self.config['unit_length'])]) 246 | mask[feature_len:] = True 247 | feature = feature.astype(np.float32) 248 | return feature, mask, label_idx 249 | 250 | def get_db_idxes(self, length): 251 | idxes = [0 for _ in range(length)] 252 | return idxes 253 | 254 | class OAD_Dataset(data.Dataset): 255 | def __init__(self, config, mode='train'): 256 | super().__init__() 257 | self.mode = mode 258 | 259 | self.dt_iteration = config['dt_iteration'] 260 | self.feature_path = config['feature_path'] 261 | self.label_path = config['state_label_path'] 262 | self.oracle_label_path = self.label_path 263 | self.n_feature = config['n_feature'] 264 | self.history_length = config['history_length'] 265 | self.length_proportional_sampling = config['length_proportional_sampling'] 266 | self.n_state = config['n_state'] 267 | if config['dataset'] == 'fineaction': 268 | print('load feature_len...') 269 | with open(config['feature_len_path'], 'r', encoding='utf-8') as f: 270 | self.feature_len = json.load(f) 271 | with open(config['annotation_path'], 'r', encoding='utf-8') as f: 272 | self.meta = json.load(f)['database'] 273 | feature_names = [] 274 | for key in self.meta: 275 | if self.meta[key]['subset'] == mode: 276 | feature_names.append(f'{key}.npy') 277 | self.feature_names = feature_names 278 | prob_dict = {} 279 | if self.length_proportional_sampling and self.mode == 'train': 280 | cum_length = 0 281 | for feature_name in feature_names: 282 | path = os.path.join(self.feature_path, feature_name) 283 | if config['dataset'] == 'fineaction': 284 | vidname = feature_name[:-4] 285 | feature_length = self.feature_len[vidname] 286 | else: 287 | feature_length = len(np.load(path)) 288 | prob_dict[feature_name] = feature_length 289 | cum_length += feature_length 290 | for feature_name in feature_names: 291 | prob_dict[feature_name] = prob_dict[feature_name]/cum_length 292 | elif self.mode != 'train': 293 | prob_dict = None 294 | else: 295 | for feature_name in feature_names: 296 | prob_dict[feature_name] = 1/len(feature_names) 297 | self.prob_dict = prob_dict 298 | 299 | def __len__(self): 300 | if self.mode == 'train': 301 | return self.dt_iteration 302 | else: 303 | return len(self.feature_names) 304 | 305 | def __getitem__(self, index): 306 | if self.mode == 'train': 307 | name = random.choices(list(self.prob_dict.keys()), weights=self.prob_dict.values(), k=1)[0] 308 | path = os.path.join(self.feature_path, name) 309 | feature = np.load(path, mmap_mode='r') 310 | label_path_1 = f'{self.label_path}_1' 311 | label_path_2 = f'{self.label_path}_1' 312 | label_1 = np.load(os.path.join(label_path_1, name)) 313 | label_2 = np.load(os.path.join(label_path_2, name)) 314 | if len(feature) <= self.history_length: 315 | s = np.zeros((self.history_length, self.n_feature), dtype=np.float32) 316 | target_a_1 = np.zeros((self.history_length,), dtype=np.int64) 317 | target_a_2 = np.zeros((self.history_length,), dtype=np.int64) 318 | start_idx = 0 319 | s[start_idx:len(feature)] = feature[start_idx:len(feature)] 320 | target_a_1[start_idx:len(feature)] = label_1[start_idx:len(feature)] 321 | target_a_2[start_idx:len(feature)] = label_2[start_idx:len(feature)] 322 | else: 323 | start_idx = random.randint(0, len(feature)-self.history_length) 324 | s = feature[start_idx:start_idx+self.history_length] 325 | target_a_1 = label_1[start_idx:start_idx+self.history_length] 326 | target_a_2 = label_2[start_idx:start_idx+self.history_length] 327 | s = s.astype(np.float32) 328 | target_a_1 = target_a_1.astype(np.int64) 329 | target_a_2 = target_a_2.astype(np.int64) 330 | assert len(s) == len(target_a_1) and len(s) == len(target_a_2) 331 | return s, target_a_1, target_a_2 332 | else: 333 | #test 334 | name = self.feature_names[index] 335 | path = os.path.join(self.feature_path, name) 336 | s = np.load(path, mmap_mode='r') 337 | s = s.astype(np.float32) 338 | name = name[:-4] 339 | return s, name -------------------------------------------------------------------------------- /make_proposal.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import json 3 | import pickle 4 | import numpy as np 5 | import torch 6 | import matplotlib.pyplot as plt 7 | import argparse 8 | import os 9 | 10 | from dataset import OAD_Dataset 11 | from yaml.loader import FullLoader 12 | from torch.utils.data import DataLoader 13 | from model import OADModel 14 | from utils import get_4state_proposals, get_hungarian_score 15 | from tqdm import tqdm 16 | 17 | if __name__ == '__main__': 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--yaml_path', type=str, default='yamls/canonical.yaml') 20 | parser.add_argument('--json_name', type=str, default='canonical.json') 21 | parser.add_argument('--load_model', type=str, default='checkpoint.pt') 22 | args = parser.parse_args() 23 | yaml_path = args.yaml_path 24 | json_name = args.json_name 25 | load_model = args.load_model 26 | with open(yaml_path, encoding='utf-8') as f: 27 | config = yaml.load(f, FullLoader) 28 | #Agent 29 | agent = OADModel(config) 30 | agent.load_state_dict(torch.load(os.path.join(config['model_save_path'], load_model), map_location=config['device'])) 31 | agent.eval() 32 | agent.to(config['device']) 33 | #Dataset 34 | test_dataset = OAD_Dataset(config, mode='test') 35 | testloader = DataLoader(test_dataset, batch_size=1, shuffle=False) 36 | results = {} 37 | cnt = 0 38 | tp_cum = 0 39 | p_cum = 0 40 | a_cum = 0 41 | with open(config['oracle_proposal_path'], 'rb') as f: 42 | oracle_proposal_dict = pickle.load(f) 43 | for feature, video_name in tqdm(testloader): 44 | video_name = video_name[0] 45 | atsumari = [] 46 | feature = feature.squeeze().to(config['device']) 47 | duration = len(feature) 48 | hs = [torch.zeros(1, config['n_recurrent_hidden']).to(config['device']) for _ in range(config['n_recurrent_layer'])] 49 | for i in range(duration): 50 | with torch.no_grad(): 51 | snippet = feature[i].unsqueeze(0) 52 | hs, score = agent.encode(snippet, hs) 53 | a = torch.argmax(score, dim=1).squeeze().cpu().numpy().item() 54 | atsumari.append(a) 55 | oracle_proposals = oracle_proposal_dict[video_name] 56 | pred_proposals = get_4state_proposals(atsumari) 57 | hungarian_results = get_hungarian_score(oracle_proposals, pred_proposals) 58 | tp_cum += hungarian_results['tp'] 59 | p_cum += hungarian_results['p'] 60 | a_cum += hungarian_results['a'] 61 | l = [] 62 | if len(pred_proposals) == 0: 63 | pred_proposals.append([0,0]) 64 | for instance in pred_proposals: 65 | tmp = dict() 66 | tmp["segment"] = instance 67 | l.append(tmp) 68 | results[video_name] = l 69 | precision = tp_cum/(p_cum+1e-8) 70 | recall = tp_cum/(a_cum+1e-8) 71 | f1 = (2*precision*recall)/(precision+recall+1e-8) 72 | #print(f'hungarian f1, recall: {f1}, {recall}') 73 | final_dict = dict() 74 | final_dict['results'] = results 75 | with open(json_name, 'w', encoding='utf-8') as f: 76 | json.dump(final_dict, f) 77 | -------------------------------------------------------------------------------- /mixer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from functools import partial 3 | from einops.layers.torch import Rearrange, Reduce 4 | 5 | pair = lambda x: x if isinstance(x, tuple) else (x, x) 6 | 7 | class PreNormResidual(nn.Module): 8 | def __init__(self, dim, fn): 9 | super().__init__() 10 | self.fn = fn 11 | self.norm = nn.LayerNorm(dim, elementwise_affine=False) 12 | 13 | def forward(self, x): 14 | return self.fn(self.norm(x)) + x 15 | 16 | def FeedForward(dim, expansion_factor = 4, dropout = 0., dense = nn.Linear): 17 | inner_dim = int(dim * expansion_factor) 18 | return nn.Sequential( 19 | dense(dim, inner_dim), 20 | nn.GELU(), 21 | nn.Dropout(dropout), 22 | dense(inner_dim, dim), 23 | nn.Dropout(dropout) 24 | ) 25 | 26 | def MLPMixer(*,token_num, dim, depth, num_classes, expansion_factor = 4, expansion_factor_token = 4, dropout = 0.): 27 | chan_first, chan_last = partial(nn.Conv1d, kernel_size = 1), nn.Linear 28 | 29 | return nn.Sequential( 30 | *[nn.Sequential( 31 | PreNormResidual(dim, FeedForward(token_num, expansion_factor, dropout, chan_first)), 32 | PreNormResidual(dim, FeedForward(dim, expansion_factor_token, dropout, chan_last)) 33 | ) for _ in range(depth)], 34 | nn.LayerNorm(dim), 35 | Reduce('b n c -> b c', 'mean'), 36 | nn.Linear(dim, num_classes) 37 | ) -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import yaml 3 | import torch 4 | import torch.nn as nn 5 | 6 | from torch.nn import functional as F 7 | from yaml.loader import FullLoader 8 | from mixer import MLPMixer 9 | from typing import List 10 | 11 | class MultiCrossEntropyLoss(nn.Module): 12 | def __init__(self): 13 | super(MultiCrossEntropyLoss, self).__init__() 14 | 15 | def forward(self, pred, target): 16 | assert pred.size() == target.size() 17 | assert len(pred.size()) == 2 18 | target_sum = torch.sum(target, dim=1) 19 | target_div = torch.where(target_sum != 0, target_sum, torch.ones_like(target_sum)).unsqueeze(1) 20 | target = target/target_div 21 | logsoftmax = nn.LogSoftmax(dim=1).to(pred.device) 22 | output = torch.sum(-target * logsoftmax(pred), 1) 23 | return torch.mean(output) 24 | 25 | class PositionalEncoding(nn.Module): 26 | def __init__(self, d_model: int, max_len: int = 5000): 27 | super().__init__() 28 | position = torch.arange(max_len).unsqueeze(1) 29 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 30 | pe = torch.zeros(max_len, 1, d_model) 31 | pe[:, 0, 0::2] = torch.sin(position * div_term) 32 | pe[:, 0, 1::2] = torch.cos(position * div_term) 33 | self.register_buffer('pe', pe) 34 | 35 | def forward(self, x): 36 | """ 37 | Args: 38 | x: Tensor, shape [batch_size, seq_len, embedding_dim] 39 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 40 | """ 41 | x = x.permute(1,0,2) 42 | x = x + self.pe[:x.size(0)] 43 | x = x.permute(1,0,2).contiguous() 44 | return x 45 | 46 | class LearnablePositionalEncoding(nn.Module): 47 | def __init__(self, d_model: int, max_len: int = 100): 48 | super().__init__() 49 | self.embeddings = nn.Embedding(max_len, d_model) 50 | 51 | def forward(self, x): 52 | """ 53 | Args: 54 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 55 | """ 56 | positions = torch.arange(x.size(0)).to(self.embeddings.weight.device) 57 | pe = self.embeddings(positions).unsqueeze(1).to(x.device) 58 | x = x + pe 59 | return x 60 | 61 | class NewGELU(nn.Module): 62 | """ 63 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 64 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 65 | """ 66 | def forward(self, x): 67 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 68 | 69 | 70 | class Classifier(nn.Module): 71 | def __init__(self, config): 72 | super().__init__() 73 | self.config = config 74 | self.first = nn.Sequential( 75 | nn.Conv1d(self.config['feature_dim'], self.config['classifier_transformer_unit'], kernel_size=5, padding=2), 76 | NewGELU(), 77 | nn.Conv1d(self.config['classifier_transformer_unit'], self.config['classifier_transformer_unit'], kernel_size=5, padding=2), 78 | NewGELU(), 79 | ) 80 | self.class_token = nn.Embedding(num_embeddings=1, embedding_dim=self.config['classifier_transformer_unit']) 81 | self.layernorm = nn.LayerNorm(self.config['classifier_transformer_unit']) 82 | encoder_layer = nn.TransformerEncoderLayer(d_model=config['classifier_transformer_unit'], nhead=8, dim_feedforward=config['classifier_transformer_fc_unit'], activation='gelu') 83 | self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=config['classifier_transformer_layer_num']) 84 | self.mlp_head = nn.Sequential( 85 | nn.Linear(self.config['classifier_transformer_unit'], self.config['classifier_fc_unit']), 86 | NewGELU(), 87 | nn.Linear(self.config['classifier_fc_unit'], self.config['class_num']) 88 | ) 89 | self.to(config['device']) 90 | 91 | def forward(self, x, mask): 92 | #IN: [B, L, E], [B, L] 93 | #for transformer, [L,B,E] 94 | #for conv, [B, L, E]->[B,E,L] 95 | x = x.permute(0,2,1) 96 | x = self.first(x) 97 | #[B,E,L] -> [B,L,E] 98 | x = x.permute(0,2,1) 99 | x = torch.nn.functional.normalize(x, dim=-1) 100 | #[B,L,E] -> [L,B,E] 101 | x = x.permute(1,0,2) 102 | x = self.encoder(x, src_key_padding_mask=mask) 103 | 104 | #[L,B,E] -> [B,L,E] 105 | x = x.permute(1,0,2) 106 | x[mask] = torch.zeros_like(x[0,0]) 107 | valid = torch.logical_not(mask) 108 | valid_length = torch.sum(valid, dim=1).unsqueeze(1) 109 | x = torch.sum(x, dim=1)/(valid_length+1e-8) 110 | 111 | x = self.mlp_head(x) 112 | return x 113 | 114 | 115 | 116 | 117 | class CustomCell(nn.Module): 118 | def __init__(self, config): 119 | super(CustomCell, self).__init__() 120 | self.x_layernorm = nn.LayerNorm(config['n_recurrent_hidden']) 121 | self.h_layernorm = nn.LayerNorm(config['n_recurrent_hidden']) 122 | self.gru = nn.GRUCell(config['n_recurrent_hidden'], config['n_recurrent_hidden']) 123 | self.dropout = nn.Dropout(config['recurrent_pdrop']) 124 | self.linear = nn.Linear(config['n_recurrent_hidden'], config['n_recurrent_hidden']) 125 | self.processed_layernorm = nn.LayerNorm(config['n_recurrent_hidden']) 126 | self.act = nn.GELU() 127 | 128 | def forward(self, x, h): 129 | ''' 130 | IN: x: [B, E], h: [B, E] 131 | OUT: y: [B, E], next_h: [B, E] 132 | ''' 133 | normalized_x = self.x_layernorm(x) 134 | normalized_h = self.h_layernorm(h) 135 | next_h = self.gru(normalized_x, normalized_h) 136 | processed = self.dropout(next_h) 137 | y = processed 138 | return y, next_h 139 | 140 | 141 | 142 | class OADModel(nn.Module): 143 | def __init__(self, config): 144 | super(OADModel, self).__init__() 145 | self.config = config 146 | 147 | self.preprocess = nn.Sequential( 148 | nn.Linear(config['n_feature'], config['n_recurrent_hidden']), 149 | nn.LayerNorm(config['n_recurrent_hidden']), 150 | nn.GELU(), 151 | ) 152 | self.cells = nn.ModuleList([CustomCell(config) for _ in range(config['n_recurrent_layer'])]) 153 | self.proj = nn.Linear(config["n_recurrent_hidden"], config["n_projection_hidden"]) 154 | 155 | self.classifier = MLPMixer( 156 | token_num=config["n_recurrent_layer"], dim=config['n_projection_hidden'], 157 | depth=config['n_classifier_layer'], num_classes=config['n_state'], expansion_factor=4, expansion_factor_token=4 158 | ) 159 | 160 | self.opt = torch.optim.AdamW(self.parameters(), lr=config['oad_base_lr'], betas=config['oad_betas'], weight_decay=config['oad_weight_decay']) 161 | assert config['n_state'] == len(config['cross_entropy_weight']) 162 | self.cross_entropy_loss = nn.CrossEntropyLoss(torch.Tensor(config['cross_entropy_weight'])) 163 | self.focal_loss = None 164 | 165 | def encode(self, feature_in:torch.FloatTensor, hs:List[torch.FloatTensor]): 166 | x = self.preprocess(feature_in) 167 | n_layer = len(hs) 168 | next_hs = [] 169 | projected_xs = [] 170 | for i in range(n_layer): 171 | h = hs[i] 172 | x, next_h = self.cells[i](x, h) 173 | next_hs.append(next_h) 174 | projected_xs.append(self.proj(x)) 175 | if self.config['use_projection']: 176 | x = torch.stack(projected_xs, dim=1) 177 | score = self.classifier(x) 178 | return next_hs, score 179 | 180 | def forward(self, feature_in:torch.FloatTensor, target_a_1:torch.LongTensor, target_a_2:torch.LongTensor): 181 | if target_a_1 is not None: 182 | self.train() 183 | batch_size = feature_in.size(0) 184 | feature_len = feature_in.size(1) 185 | device = feature_in.device 186 | hs = [torch.zeros(batch_size, self.config['n_recurrent_hidden']).to(device) for _ in range(self.config['n_recurrent_layer'])] 187 | score_stack = [] 188 | 189 | for step in range(feature_len): 190 | hs, score = self.encode( 191 | feature_in[:, step], hs 192 | ) 193 | score_stack.append(score) 194 | logits = torch.stack(score_stack, dim=1) 195 | # if we are given some desired targets, calculate the loss 196 | loss = None 197 | if target_a_1 is not None: 198 | if 'focal' in self.config['prefix']: 199 | raise NotImplementedError() 200 | else: 201 | loss_1 = self.cross_entropy_loss(logits.view(-1, logits.size(-1)), target_a_1.view(-1)) 202 | loss_2 = self.cross_entropy_loss(logits.view(-1, logits.size(-1)), target_a_2.view(-1)) 203 | loss = torch.min(loss_1, loss_2) 204 | return logits, loss -------------------------------------------------------------------------------- /oad_recurrent_main.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import pickle 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | import argparse 7 | import os 8 | 9 | from dataset import OAD_Dataset 10 | from model import OADModel 11 | from yaml.loader import FullLoader 12 | from torch.utils.data import DataLoader 13 | from utils import get_aux_loss, soft_update, get_4state_proposals, get_hungarian_score, g 14 | from tqdm import tqdm 15 | 16 | if __name__ == '__main__': 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument('--yaml_path', type=str, default='yamls/canonical.yaml') 19 | args = parser.parse_args() 20 | yaml_path = args.yaml_path 21 | with open(yaml_path, encoding='utf-8') as f: 22 | config = yaml.load(f, FullLoader) 23 | 24 | if not os.path.exists(config['model_save_path']): 25 | os.mkdir(config['model_save_path']) 26 | 27 | ceweight_str = "[" 28 | for i in config["cross_entropy_weight"]: 29 | ceweight_str += str(i) 30 | ceweight_str += '_' 31 | ceweight_str = ceweight_str[:-1] 32 | ceweight_str += ']' 33 | identifier = f'{config["prefix"]}_{config["dataset"]}_{config["n_state"]}_p{config["penalty_coef"]}' 34 | tb_path = f'{identifier}' 35 | count = 1 36 | while(True): 37 | if not os.path.exists(os.path.join('runs', tb_path, str(count))): 38 | tb_path = os.path.join('runs', tb_path, str(count)) 39 | break 40 | else: 41 | count += 1 42 | 43 | 44 | #Agent 45 | model = OADModel(config).to(config['device']) 46 | 47 | print(f'identifier: {identifier}') 48 | print(f'dt_iteration: {config["dt_iteration"]}') 49 | print(f'n_state: {config["n_state"]}') 50 | print(f'aux_loss: {config["aux_loss"]}') 51 | print(f'use_target: {config["use_target"]}') 52 | print(f'history_length: {config["history_length"]}') 53 | print(f'penalty_coef: {config["penalty_coef"]}') 54 | print(f'cross_entropy_coef: {config["cross_entropy_weight"]}') 55 | 56 | if config['use_target']: 57 | print('use target network...') 58 | target = OADModel(config).to(config['device']) 59 | scheduler = scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=model.opt, lr_lambda=g(1., 0., config['oad_final_lr']/config['oad_base_lr'], config['oad_warmup_iter'], config['oad_final_iter'])) 60 | train_dataset = OAD_Dataset(config, mode='train') 61 | test_dataset = OAD_Dataset(config, mode='test') 62 | trainloader = DataLoader(train_dataset, batch_size=config['dt_batch_size'], shuffle=True, num_workers=config['num_workers']) 63 | testloader = DataLoader(test_dataset, batch_size=1, shuffle=False) 64 | iter_cnt = 0 65 | save_cnt = 0 66 | eval_cnt = 0 67 | plot_cnt = 0 68 | f1_cnt = 0 69 | 70 | iter_cnt = 0 71 | tp_cum = 0 72 | p_cum = 0 73 | a_cum = 0 74 | 75 | for epoch in range(config['epoch']): 76 | print(f'training in epoch {epoch+1}') 77 | total_correct = 0 78 | for s, target_a_1, target_a_2 in tqdm(trainloader): 79 | iter_cnt += 1 80 | model.train() 81 | s = s.to(config['device']) 82 | target_a_1 = target_a_1.to(config['device']) 83 | target_a_2 = target_a_2.to(config['device']) 84 | logits, loss = model(s, target_a_1, target_a_2) 85 | if config['aux_loss'] == 'cross_entropy': 86 | penalty = get_aux_loss(logits) 87 | else: 88 | raise NotImplementedError('invalid aux_loss') 89 | loss += config['penalty_coef']*penalty 90 | answer_sheet = torch.argmax(logits, dim=-1).detach().cpu().numpy() 91 | answer_1 = target_a_1.detach().cpu().numpy() 92 | answer_2 = target_a_2.detach().cpu().numpy() 93 | correct_1 = np.sum(np.where(answer_sheet == answer_1, 1, 0)) 94 | correct_2 = np.sum(np.where(answer_sheet == answer_2, 1, 0)) 95 | correct = np.max([correct_1, correct_2]) 96 | total_correct += correct 97 | model.opt.zero_grad() 98 | loss.backward() 99 | torch.nn.utils.clip_grad_norm_(model.parameters(), 2.5) 100 | model.opt.step() 101 | scheduler.step() 102 | if config['use_target']: 103 | soft_update(model, target, config['tau']) 104 | loss = loss.detach() 105 | if epoch+1 == config['epoch']: 106 | print(f'eval in epoch{epoch+1}') 107 | with open(config['oracle_proposal_path'], 'rb') as f: 108 | oracle_proposal_dict = pickle.load(f) 109 | eval_cnt = 0 110 | tp_cum = 0 111 | p_cum = 0 112 | a_cum = 0 113 | for feature, name in tqdm(testloader): 114 | name = name[0] 115 | atsumari = [] 116 | feature = feature.squeeze().to(config['device']) 117 | duration = len(feature) 118 | hs = [torch.zeros(1, config['n_recurrent_hidden']).to(config['device']) for _ in range(config['n_recurrent_layer'])] 119 | for i in range(duration): 120 | with torch.no_grad(): 121 | snippet = feature[i].unsqueeze(0) 122 | if config['use_target']: 123 | target.eval() 124 | hs, score = target.encode(snippet, hs) 125 | else: 126 | model.eval() 127 | hs, score = model.encode(snippet, hs) 128 | a = torch.argmax(score, dim=1).squeeze().cpu().numpy() 129 | atsumari.append(a) 130 | oracle_proposals = oracle_proposal_dict[name] 131 | pred_proposals = get_4state_proposals(atsumari) 132 | hungarian_results = get_hungarian_score(oracle_proposals, pred_proposals) 133 | tp_cum += hungarian_results['tp'] 134 | p_cum += hungarian_results['p'] 135 | a_cum += hungarian_results['a'] 136 | eval_cnt += 1 137 | if eval_cnt == config['eval_num']: 138 | break 139 | precision = tp_cum/(p_cum+1e-8) 140 | recall = tp_cum/(a_cum+1e-8) 141 | f1 = (2*precision*recall)/(precision+recall+1e-8) 142 | print(f'hungarian f1, recall: {f1}, {recall}') 143 | print(f'lr: {scheduler.get_last_lr()}') 144 | 145 | if config['use_target']: 146 | torch.save(target.state_dict(), os.path.join(config['model_save_path'], f'{identifier}.pt')) 147 | else: 148 | torch.save(model.state_dict(), os.path.join(config['model_save_path'], f'{identifier}.pt')) 149 | 150 | 151 | 152 | -------------------------------------------------------------------------------- /oracle_hungarian.py: -------------------------------------------------------------------------------- 1 | import json 2 | import argparse 3 | from utils import get_hungarian_score 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--gt', type=str, default='thumos14_v2.json') 8 | parser.add_argument('--pred', type=str, default='canonical.json') 9 | parser.add_argument('--tiou', type=float, default=0.5) 10 | args = parser.parse_args() 11 | gt = args.gt 12 | pred = args.pred 13 | tiou = args.tiou 14 | with open(gt, encoding='utf-8') as f: 15 | gt = json.load(f)['database'] 16 | with open(pred, encoding='utf-8') as f: 17 | pred = json.load(f)['results'] 18 | tp_cum = 0 19 | p_cum = 0 20 | a_cum = 0 21 | for vidname in pred: 22 | oracle_proposals = [] 23 | annotations = gt[vidname]['annotations'] 24 | for anno in annotations: 25 | oracle_proposals.append(anno['segment']) 26 | pred_proposals = [] 27 | for anno in pred[vidname]: 28 | pred_proposals.append(anno['segment']) 29 | hungarian_results = get_hungarian_score(oracle_proposals, pred_proposals, tiou) 30 | tp_cum += hungarian_results['tp'] 31 | p_cum += hungarian_results['p'] 32 | a_cum += hungarian_results['a'] 33 | precision = tp_cum/(p_cum+1e-8) 34 | recall = tp_cum/(a_cum+1e-8) 35 | f1 = (2*precision*recall)/(precision+recall+1e-8) 36 | print(f'tiou: {tiou}') 37 | print(f'p_cum: {p_cum}') 38 | print(f'a_cum: {a_cum}') 39 | print(f'hungarian f1, precision, recall: {f1}, {precision}, {recall}') -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops==0.8.0 2 | joblib==1.3.2 3 | matplotlib==3.5.3 4 | numpy==1.21.2 5 | pandas==2.2.2 6 | PyYAML==6.0.1 7 | scipy==1.7.3 8 | torch==1.10.0 9 | torchvision==0.11.0 10 | tqdm==4.61.2 11 | -------------------------------------------------------------------------------- /t_oad_model/tmp.txt: -------------------------------------------------------------------------------- 1 | dummy -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python make_proposal.py 4 | python convert.py 5 | python oracle_hungarian.py -------------------------------------------------------------------------------- /thumos14_classifier_model/tmp.txt: -------------------------------------------------------------------------------- 1 | dummy file -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | python oad_recurrent_main.py 4 | python make_proposal.py --yaml_path=yamls/canonical.yaml --json_name=main.json --load_model=main_thumos14_4_p0.025.pt 5 | python convert.py --yaml_path=yamls/canonical.yaml --source=main.json --target=main.json 6 | python oracle_hungarian.py --pred=main.json -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import urllib 3 | import json 4 | from random import uniform 5 | from scipy import optimize 6 | 7 | 8 | def get_2state_proposals(action_list): 9 | action_list = np.array(action_list) 10 | if np.any(action_list > 1): 11 | print(action_list) 12 | raise Exception('invalid action_list') 13 | action_list = action_list.tolist() 14 | cur_state = 0 15 | st = 0 16 | ed = 0 17 | proposals = [] 18 | for i, state in enumerate(action_list): 19 | if cur_state == 0 and state == 0: 20 | cur_state = 0 21 | elif cur_state == 0 and state == 1: 22 | st = i 23 | cur_state = 1 24 | elif cur_state == 1 and state == 0: 25 | ed = i 26 | proposals.append([st, ed]) 27 | cur_state = 0 28 | else: 29 | cur_state = 1 30 | return proposals 31 | 32 | def get_4state_proposals(action_list): 33 | action_list = np.array(action_list) 34 | action_list = np.where(action_list == 4, 0, action_list) 35 | if np.any(action_list > 3): 36 | print(action_list) 37 | raise Exception('invalid action_list') 38 | action_list = action_list.tolist() 39 | button1 = [] 40 | for i in action_list: 41 | if i == 2: 42 | i = 0 43 | elif i == 3: 44 | i = 1 45 | button1.append(i) 46 | button1_proposals = get_2state_proposals(button1) 47 | button2 = [] 48 | for i in action_list: 49 | if i == 1: 50 | i = 0 51 | elif i == 3: 52 | i = 2 53 | button2.append(i) 54 | button2_proposals = get_2state_proposals([1 if i == 2 else 0 for i in button2]) 55 | button1_proposals.extend(button2_proposals) 56 | button1_proposals.sort(key=lambda x: x[1]) 57 | return button1_proposals 58 | 59 | def calculate_iou(prediction:list, answer:list): 60 | intersection = -1 61 | s1 = prediction[0] 62 | e1 = prediction[1] 63 | s2 = answer[0] 64 | e2 = answer[1] 65 | if s1 > s2: 66 | s1, s2 = s2, s1 67 | e1, e2 = e2, e1 68 | if e1 <= s2: 69 | intersection = 0 70 | else: 71 | if e2 <= e1: 72 | intersection = (e2 - s2) 73 | else: 74 | intersection = (e1 - s2) 75 | l1 = e1 - s1 76 | l2 = e2 - s2 77 | iou = intersection/((l1 + l2 - intersection) + 1e-8) 78 | return iou 79 | 80 | def get_moving_average(l, moving_average_range): 81 | if len(l) <= moving_average_range: 82 | return l 83 | ret = [] 84 | for i in range(moving_average_range, len(l)): 85 | ret.append(np.mean(l[i-moving_average_range:i])) 86 | return ret 87 | 88 | def get_idx_and_confidence(score, target_segment, segment_list, class_names, iou_threshold=0.5): 89 | max_iou = -1 90 | registered_segment = None 91 | for segment in segment_list: 92 | iou = calculate_iou(target_segment, segment['segment']) 93 | if iou > max_iou: 94 | max_iou = iou 95 | registered_segment = segment 96 | if max_iou < iou_threshold: 97 | registered_segment = None 98 | if registered_segment is None: 99 | class_idx = np.argmax(score) 100 | confidence = float(score[class_idx]) 101 | else: 102 | i = class_names[registered_segment['label']] 103 | score[i] = 0 104 | class_idx = np.argmax(score) 105 | confidence = float(score[class_idx]) 106 | return class_idx, confidence 107 | 108 | def get_idx_and_confidences(score, n): 109 | class_idxes_and_confidences = [] 110 | for _ in range(n): 111 | class_idx = np.argmax(score) 112 | confidence = float(score[class_idx]) 113 | class_idxes_and_confidences.append([class_idx, confidence]) 114 | score[class_idx] = 0 115 | return class_idxes_and_confidences 116 | 117 | def get_hungarian_score(answer:list, prediction:list, iou_threshold=0.5): 118 | #IN: answer[[st,ed], [st,ed]...], prediction[[st,ed],[st,ed]...] 119 | #OUT: dict: tp(True positive), p(Positive), a(Answer) 120 | no_answer_flag = False 121 | no_pred_flag = False 122 | if len(answer) == 0: 123 | no_answer_flag = True 124 | answer.append([0,0]) 125 | if len(prediction) == 0: 126 | no_pred_flag = True 127 | prediction.append([0,0]) 128 | answer = np.array(answer) 129 | prediction = np.array(prediction) 130 | profit = np.zeros((len(answer), len(prediction))) 131 | for i in range(len(answer)): 132 | for j in range(len(prediction)): 133 | profit[i][j] = calculate_iou(answer[i], prediction[j]) 134 | r, c = optimize.linear_sum_assignment(profit, maximize=True) 135 | tp = np.sum(np.where(profit[r, c] >= iou_threshold, 1, 0)) 136 | a = answer.shape[0] 137 | p = prediction.shape[0] 138 | if no_answer_flag: 139 | a = 0 140 | if no_pred_flag: 141 | p = 0 142 | return {'tp':tp, 'p':p, 'a':a} 143 | 144 | def soft_update(local_model, target_model, tau): 145 | """Soft update model parameters. 146 | θ_target = τ*θ_local + (1 - τ)*θ_target 147 | Params 148 | ====== 149 | local_model (PyTorch model): weights will be copied from 150 | target_model (PyTorch model): weights will be copied to 151 | tau (float): interpolation parameter 152 | """ 153 | for target_param, local_param in zip(target_model.parameters(), local_model.parameters()): 154 | target_param.data.copy_(tau*local_param.data + (1.0-tau)*target_param.data) 155 | 156 | 157 | def get_aux_loss(logits): 158 | ''' 159 | logits: [B,L,n_state] 160 | return penalty loss 161 | -- loss which penalizes context change -- 162 | ''' 163 | import torch 164 | max_idx = torch.argmax(logits, dim=-1) 165 | #[B,L] 166 | dummy = torch.full((logits.size(0),1), fill_value=-1).to(logits.device) 167 | front = torch.cat([dummy, max_idx], dim=1) 168 | back = torch.cat([max_idx, dummy], dim=1) 169 | boolidx = (front != back)[:,:-1] 170 | boolidx[:, 0] = False 171 | front = front[:, :-1] 172 | target = front[boolidx] 173 | logits = logits[boolidx] 174 | max_idx = torch.argmax(logits, dim=-1) 175 | is_target_not4 = target != 4 176 | is_logits_not4 = max_idx != 4 177 | exclusive_mask = torch.logical_and(is_target_not4, is_logits_not4) 178 | target = target[exclusive_mask] 179 | logits = logits[exclusive_mask] 180 | loss = torch.nn.functional.cross_entropy(logits, target) 181 | return loss 182 | 183 | def g(base_value, start_warmup_value, final_value, warmup_iters, final_iter): 184 | import math 185 | def f(i): 186 | if i <= warmup_iters: 187 | value = (base_value - start_warmup_value)*(i/warmup_iters) + start_warmup_value 188 | elif i < final_iter: 189 | value = final_value + 0.5 * (base_value - final_value) * (1 + math.cos(math.pi * (i-warmup_iters) / (final_iter - warmup_iters))) 190 | else: 191 | value = final_value 192 | return value 193 | return f 194 | 195 | def get_blocked_videos(api='http://ec2-52-25-205-214.us-west-2.compute.amazonaws.com/challenge19/api.py'): 196 | api_url = '{}?action=get_blocked'.format(api) 197 | response = urllib.request.urlopen(api_url) 198 | return json.loads(response.read()) 199 | 200 | def interpolated_prec_rec(prec, rec): 201 | """Interpolated AP - VOCdevkit from VOC 2011. 202 | """ 203 | mprec = np.hstack([[0], prec, [0]]) 204 | mrec = np.hstack([[0], rec, [1]]) 205 | for i in range(len(mprec) - 1)[::-1]: 206 | mprec[i] = max(mprec[i], mprec[i + 1]) 207 | idx = np.where(mrec[1::] != mrec[0:-1])[0] + 1 208 | ap = np.sum((mrec[idx] - mrec[idx - 1]) * mprec[idx]) 209 | return ap 210 | 211 | def segment_iou(target_segment, candidate_segments): 212 | """Compute the temporal intersection over union between a 213 | target segment and all the test segments. 214 | 215 | Parameters 216 | ---------- 217 | target_segment : 1d array 218 | Temporal target segment containing [starting, ending] times. 219 | candidate_segments : 2d array 220 | Temporal candidate segments containing N x [starting, ending] times. 221 | 222 | Outputs 223 | ------- 224 | tiou : 1d array 225 | Temporal intersection over union score of the N's candidate segments. 226 | """ 227 | tt1 = np.maximum(target_segment[0], candidate_segments[:, 0]) 228 | tt2 = np.minimum(target_segment[1], candidate_segments[:, 1]) 229 | # Intersection including Non-negative overlap score. 230 | segments_intersection = (tt2 - tt1).clip(0) 231 | # Segment union. 232 | segments_union = (candidate_segments[:, 1] - candidate_segments[:, 0]) \ 233 | + (target_segment[1] - target_segment[0]) - segments_intersection 234 | # Compute overlap as the ratio of the intersection 235 | # over union of two segments. 236 | tIoU = segments_intersection.astype(float) / segments_union 237 | 238 | return tIoU 239 | 240 | def compute_average_precision_detection(ground_truth, prediction, tiou_thresholds=np.linspace(0.5, 0.95, 10)): 241 | """Compute average precision (detection task) between ground truth and 242 | predictions data frames. If multiple predictions occurs for the same 243 | predicted segment, only the one with highest score is matches as 244 | true positive. This code is greatly inspired by Pascal VOC devkit. 245 | Parameters 246 | ---------- 247 | ground_truth : df 248 | Data frame containing the ground truth instances. 249 | Required fields: ['video-id', 't-start', 't-end'] 250 | prediction : df 251 | Data frame containing the prediction instances. 252 | Required fields: ['video-id, 't-start', 't-end', 'score'] 253 | tiou_thresholds : 1darray, optional 254 | Temporal intersection over union threshold. 255 | Outputs 256 | ------- 257 | ap : float 258 | Average precision score. 259 | """ 260 | ap = np.zeros(len(tiou_thresholds)) 261 | if prediction.empty: 262 | return ap 263 | 264 | npos = float(len(ground_truth)) 265 | lock_gt = np.ones((len(tiou_thresholds),len(ground_truth))) * -1 266 | # Sort predictions by decreasing score order. 267 | sort_idx = prediction['score'].values.argsort()[::-1] 268 | prediction = prediction.loc[sort_idx].reset_index(drop=True) 269 | 270 | # Initialize true positive and false positive vectors. 271 | tp = np.zeros((len(tiou_thresholds), len(prediction))) 272 | fp = np.zeros((len(tiou_thresholds), len(prediction))) 273 | 274 | # Adaptation to query faster 275 | ground_truth_gbvn = ground_truth.groupby('video-id') 276 | 277 | # Assigning true positive to truly grount truth instances. 278 | for idx, this_pred in prediction.iterrows(): 279 | 280 | try: 281 | # Check if there is at least one ground truth in the video associated. 282 | ground_truth_videoid = ground_truth_gbvn.get_group(this_pred['video-id']) 283 | except Exception as e: 284 | fp[:, idx] = 1 285 | continue 286 | 287 | this_gt = ground_truth_videoid.reset_index() 288 | tiou_arr = segment_iou(this_pred[['t-start', 't-end']].values, 289 | this_gt[['t-start', 't-end']].values) 290 | # We would like to retrieve the predictions with highest tiou score. 291 | tiou_sorted_idx = tiou_arr.argsort()[::-1] 292 | for tidx, tiou_thr in enumerate(tiou_thresholds): 293 | for jdx in tiou_sorted_idx: 294 | if tiou_arr[jdx] < tiou_thr: 295 | fp[tidx, idx] = 1 296 | break 297 | if lock_gt[tidx, this_gt.loc[jdx]['index']] >= 0: 298 | continue 299 | # Assign as true positive after the filters above. 300 | tp[tidx, idx] = 1 301 | lock_gt[tidx, this_gt.loc[jdx]['index']] = idx 302 | break 303 | 304 | if fp[tidx, idx] == 0 and tp[tidx, idx] == 0: 305 | fp[tidx, idx] = 1 306 | 307 | tp_cumsum = np.cumsum(tp, axis=1).astype(np.float) 308 | fp_cumsum = np.cumsum(fp, axis=1).astype(np.float) 309 | recall_cumsum = tp_cumsum / npos 310 | 311 | precision_cumsum = tp_cumsum / (tp_cumsum + fp_cumsum) 312 | 313 | for tidx in range(len(tiou_thresholds)): 314 | ap[tidx] = interpolated_prec_rec(precision_cumsum[tidx,:], recall_cumsum[tidx,:]) 315 | 316 | return ap 317 | -------------------------------------------------------------------------------- /yamls/canonical.yaml: -------------------------------------------------------------------------------- 1 | #path 2 | dataset: 'thumos14' 3 | #dataset: 'muses' 4 | feature_path: "/workspace/data/thumos14/thumos14_features" 5 | 6 | ######EDIT####### 7 | state_label_path: "/workspace/data/thumos14/thumos14_4state_label" 8 | ################# 9 | class_label_path: "/workspace/data/thumos14/thumos14_class_label" 10 | oracle_proposal_path: "/workspace/data/thumos14/thumos14_oracle_proposals.pkl" 11 | model_save_path: "./t_oad_model" 12 | annotation_path: "/workspace/data/thumos14/thumos14_v2.json" 13 | 14 | #classifier 15 | classifier_id: "standard" 16 | unit_length: 64 17 | max_length: 256 18 | a: 0.333 19 | a_std: 0.15 20 | b: 0.666 21 | b_std: 0.15 22 | foreground_ratio: 0.5 23 | long_instance_ratio: 0.1 24 | iou_threshold: 0.5 25 | gamma: 2. 26 | alpha: -1 27 | 28 | classifier_model_path: './thumos14_classifier_model/' 29 | classifier_feature_path: './data/thumos14/thumos14_class_features' 30 | classifier_epoch_iter: 16384 31 | classifier_equal_sampling: false 32 | classifier_temperature: 0.5 33 | classifier_save_interval: 5 34 | classifier_warmup_iter: 500 35 | classifier_final_iter: 10000 36 | classifier_base_lr: 0.0001 37 | classifier_final_lr: 0.000001 38 | classifier_batch_size: 64 39 | classifier_epoch: 1001 40 | feature_dim: 4096 41 | classifier_transformer_unit: 512 42 | classifier_transformer_fc_unit: 2048 43 | classifier_transformer_layer_num: 3 44 | classifier_fc_unit: 512 45 | class_num: 21 46 | 47 | #name of experiment 48 | prefix: "main" 49 | 50 | #hard parameters 51 | #If you change n_state, make sure to change state_label_path too. 52 | n_state: 4 53 | n_feature: 4096 54 | 55 | device: 'cuda:0' 56 | 57 | 58 | 59 | #lstm 60 | use_projection: True 61 | n_recurrent_layer: 4 62 | n_classifier_layer: 2 63 | history_length: 48 64 | n_recurrent_hidden: 2048 65 | n_projection_hidden: 512 66 | recurrent_pdrop: 0. 67 | 68 | #optimizers 69 | oad_warmup_iter: 512 70 | oad_final_iter: 3072 71 | oad_base_lr: 0.0003 72 | oad_final_lr: 0.000001 73 | oad_weight_decay: 0.01 74 | oad_betas: [0.9, 0.9] 75 | #training setting 76 | epoch: 8 77 | num_workers: 5 78 | length_proportional_sampling: True 79 | label_augmentation: False 80 | dt_iteration: 16384 81 | dt_batch_size: 64 82 | eval_num: 250 83 | 84 | use_target: True 85 | #######EDIT####### 86 | aux_loss: "cross_entropy" 87 | tau: 0.001 88 | #######EDIT####### 89 | penalty_coef: 0.025 90 | cross_entropy_weight: [0.1,0.1,0.1,0.1,] 91 | 92 | 93 | #stat 94 | moving_average_range: 25 95 | eval_epoch: 2 96 | --------------------------------------------------------------------------------