├── src
│   ├── clip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── simple_tokenizer.py
│   │   ├── clip.py
│   │   └── model.py
│   ├── xd_option.py
│   ├── ucf_option.py
│   ├── utils
│   │   ├── dataset.py
│   │   ├── lr_warmup.py
│   │   ├── tools.py
│   │   ├── xd_detectionMAP.py
│   │   ├── ucf_detectionMAP.py
│   │   └── layers.py
│   ├── crop.py
│   ├── xd_test.py
│   ├── ucf_test.py
│   ├── xd_train.py
│   ├── ucf_train.py
│   └── model.py
├── list
│   ├── gt.npy
│   ├── gt_ucf.npy
│   ├── gt_label.npy
│   ├── gt_segment.npy
│   ├── gt_label_ucf.npy
│   ├── gt_segment_ucf.npy
│   ├── make_list_xd.py
│   ├── make_list_ucf.py
│   ├── make_gt_xd.py
│   ├── make_gt_mAP_xd.py
│   ├── make_gt_mAP_ucf.py
│   ├── make_gt_ucf.py
│   ├── Anomaly_Test.txt
│   ├── Temporal_Anomaly_Annotation.txt
│   ├── ucf_CLIP_rgbtest.csv
│   └── annotations.txt
├── data
│   └── framework.png
├── README.md
└── LICENSE

/src/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
--------------------------------------------------------------------------------

/list/gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt.npy
--------------------------------------------------------------------------------

/list/gt_ucf.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_ucf.npy
--------------------------------------------------------------------------------

/data/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/data/framework.png
--------------------------------------------------------------------------------

/list/gt_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_label.npy
--------------------------------------------------------------------------------

/list/gt_segment.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_segment.npy
--------------------------------------------------------------------------------

/list/gt_label_ucf.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_label_ucf.npy
--------------------------------------------------------------------------------

/list/gt_segment_ucf.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_segment_ucf.npy
--------------------------------------------------------------------------------

/src/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/src/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------

/list/make_list_xd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import csv
4 | 
5 | root_path = '/home/xbgydx/Desktop/XDTrainClipFeatures' ## the path of features
6 | files = sorted(glob.glob(os.path.join(root_path, "*.npy")))
7 | violents = []
8 | normal = []
9 | 
10 | with open('list/xd_CLIP_rgb.csv', 'w+') as f: ## the name of feature list
11 |     writer = csv.writer(f)
12 |
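# The loop below relies on the XD-Violence feature-file naming convention,
# where each name embeds its labels as `..._label_<classes>__<crop>.npy`
# (multi-label names appear to join classes with '-', cf. get_batch_label in
# src/utils/tools.py). A minimal sketch with a hypothetical file name:
#
#   fname = 'movie_label_B1-B2-0__0.npy'   # hypothetical example, not a real file
#   label = fname.split('_label_')[1].split('__')[0]
#   assert label == 'B1-B2-0'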
writer.writerow(['path', 'label']) 13 | for file in files: 14 | if '.npy' in file: 15 | if '_label_A' in file: 16 | normal.append(file) 17 | else: 18 | label = file.split('_label_')[1].split('__')[0] 19 | writer.writerow([file, label]) 20 | 21 | for file in normal: 22 | writer.writerow([file, 'A']) -------------------------------------------------------------------------------- /list/make_list_ucf.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | 4 | root_path = '/home/xbgydx/Desktop/UCFClipFeatures/' 5 | txt = 'list/Anomaly_Train.txt' 6 | files = list(open(txt)) 7 | normal = [] 8 | count = 0 9 | 10 | with open('list/ucf_CLIP_rgb.csv', 'w+') as f: ## the name of feature list 11 | writer = csv.writer(f) 12 | writer.writerow(['path', 'label']) 13 | for file in files: 14 | filename = root_path + file[:-5] + '__0.npy' 15 | label = file.split('/')[0] 16 | if os.path.exists(filename): 17 | if 'Normal' in label: 18 | #continue 19 | filename = filename[:-5] 20 | for i in range(0, 10, 1): 21 | normal.append(filename + str(i) + '.npy') 22 | else: 23 | filename = filename[:-5] 24 | for i in range(0, 10, 1): 25 | writer.writerow([filename + str(i) + '.npy', label]) 26 | else: 27 | count += 1 28 | print(filename) 29 | 30 | for file in normal: 31 | writer.writerow([file, 'Normal']) 32 | 33 | print(count) -------------------------------------------------------------------------------- /list/make_gt_xd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import cv2 4 | 5 | clip_len = 16 6 | 7 | # the dir of testing images 8 | feature_list = 'list/xd_CLIP_rgbtest.csv' 9 | # the ground truth txt 10 | 11 | gt_txt = 'list/annotations.txt' ## the path of test annotations 12 | gt_lines = list(open(gt_txt)) 13 | gt = [] 14 | lists = pd.read_csv(feature_list) 15 | count = 0 16 | 17 | for idx in range(lists.shape[0]): 18 | name = lists.loc[idx]['path'] 19 | if '__0.npy' not in name: 20 | continue 21 | #feature = name.split('label_')[-1] 22 | fea = np.load(name) 23 | lens = (fea.shape[0] + 1) * clip_len 24 | name = name.split('/')[-1] 25 | name = name[:-7] 26 | # the number of testing images in this sub-dir 27 | 28 | gt_vec = np.zeros(lens).astype(np.float32) 29 | if 'label_A' not in name: 30 | for gt_line in gt_lines: 31 | if name in gt_line: 32 | count += 1 33 | gt_content = gt_line.strip('\n').split() 34 | abnormal_fragment = [[int(gt_content[i]),int(gt_content[j])] for i in range(1,len(gt_content),2) \ 35 | for j in range(2,len(gt_content),2) if j==i+1] 36 | if len(abnormal_fragment) != 0: 37 | abnormal_fragment = np.array(abnormal_fragment) 38 | for frag in abnormal_fragment: 39 | gt_vec[frag[0]:frag[1]]=1.0 40 | break 41 | gt.extend(gt_vec[:-clip_len]) 42 | 43 | np.save('list/gt_xd.npy', gt) 44 | 45 | print(count) -------------------------------------------------------------------------------- /list/make_gt_mAP_xd.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import cv2 5 | import pandas as pd 6 | import warnings 7 | 8 | clip_len = 16 9 | 10 | # the dir of testing images 11 | feature_list = 'list/xd_CLIP_rgbtest.csv' 12 | 13 | # the ground truth txt 14 | gt_txt = 'list/annotations_multiclasses.txt' 15 | gt_lines = list(open(gt_txt)) 16 | 17 | #warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 18 | 19 | gt_segment = [] 20 | gt_label = [] 21 | 
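# Side note on make_gt_xd.py above: its nested comprehension pairs consecutive
# (start, end) annotation fields into abnormal fragments. An equivalent, more
# direct sketch (the annotation fields here are hypothetical values):
#
#   gt_content = ['video_name', '10', '50', '120', '200']
#   fragments = [[int(s), int(e)]
#                for s, e in zip(gt_content[1::2], gt_content[2::2])]
#   assert fragments == [[10, 50], [120, 200]]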
lists = pd.read_csv(feature_list) 22 | 23 | for idx in range(lists.shape[0]): 24 | name = lists.loc[idx]['path'] 25 | if '__0.npy' not in name: 26 | continue 27 | segment = [] 28 | label = [] 29 | if '_label_A' in name: 30 | fea = np.load(name) 31 | lens = fea.shape[0] * clip_len 32 | name = name.split('/')[-1] 33 | name = name[:-7] 34 | segment.append([0, lens]) 35 | label.append('A') 36 | else: 37 | name = name.split('/')[-1] 38 | name = name[:-7] 39 | for gt_line in gt_lines: 40 | if name in gt_line: 41 | gt_content = gt_line.strip('\n').split() 42 | for j in range(1, len(gt_content), 3): 43 | print(gt_content, j) 44 | segment.append([gt_content[j + 1], gt_content[j + 2]]) 45 | label.append(gt_content[j]) 46 | break 47 | gt_segment.append(segment) 48 | gt_label.append(label) 49 | 50 | np.save('list/gt_label.npy', gt_label) 51 | np.save('list/gt_segment.npy', gt_segment) -------------------------------------------------------------------------------- /src/xd_option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='VadCLIP') 4 | parser.add_argument('--seed', default=234, type=int) 5 | 6 | parser.add_argument('--embed-dim', default=512, type=int) 7 | parser.add_argument('--visual-length', default=256, type=int) 8 | parser.add_argument('--visual-width', default=512, type=int) 9 | parser.add_argument('--visual-head', default=1, type=int) 10 | parser.add_argument('--visual-layers', default=1, type=int) 11 | parser.add_argument('--attn-window', default=64, type=int) 12 | parser.add_argument('--prompt-prefix', default=10, type=int) 13 | parser.add_argument('--prompt-postfix', default=10, type=int) 14 | parser.add_argument('--classes-num', default=7, type=int) 15 | 16 | parser.add_argument('--max-epoch', default=10, type=int) 17 | parser.add_argument('--model-path', default='model/model_xd.pth') 18 | parser.add_argument('--use-checkpoint', default=False, type=bool) 19 | parser.add_argument('--checkpoint-path', default='model/checkpoint.pth') 20 | parser.add_argument('--batch-size', default=96, type=int) 21 | parser.add_argument('--train-list', default='list/xd_CLIP_rgb.csv') 22 | parser.add_argument('--test-list', default='list/xd_CLIP_rgbtest.csv') 23 | parser.add_argument('--gt-path', default='list/gt.npy') 24 | parser.add_argument('--gt-segment-path', default='list/gt_segment.npy') 25 | parser.add_argument('--gt-label-path', default='list/gt_label.npy') 26 | 27 | parser.add_argument('--lr', default=1e-5) 28 | parser.add_argument('--scheduler-rate', default=0.1) 29 | parser.add_argument('--scheduler-milestones', default=[3, 6, 10]) -------------------------------------------------------------------------------- /src/ucf_option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser(description='VadCLIP') 4 | parser.add_argument('--seed', default=234, type=int) 5 | 6 | parser.add_argument('--embed-dim', default=512, type=int) 7 | parser.add_argument('--visual-length', default=256, type=int) 8 | parser.add_argument('--visual-width', default=512, type=int) 9 | parser.add_argument('--visual-head', default=1, type=int) 10 | parser.add_argument('--visual-layers', default=2, type=int) 11 | parser.add_argument('--attn-window', default=8, type=int) 12 | parser.add_argument('--prompt-prefix', default=10, type=int) 13 | parser.add_argument('--prompt-postfix', default=10, type=int) 14 | 
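# A minimal usage sketch of these option modules, as done in the train/test
# scripts (e.g. `args = xd_option.parser.parse_args()` in xd_test.py). Note
# that argparse maps dashed flags to underscored attributes:
#
#   import ucf_option
#   args = ucf_option.parser.parse_args([])       # [] -> defaults only
#   print(args.visual_length, args.attn_window)   # 256 8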
parser.add_argument('--classes-num', default=14, type=int) 15 | 16 | parser.add_argument('--max-epoch', default=10, type=int) 17 | parser.add_argument('--model-path', default='model/model_ucf.pth') 18 | parser.add_argument('--use-checkpoint', default=False, type=bool) 19 | parser.add_argument('--checkpoint-path', default='model/checkpoint.pth') 20 | parser.add_argument('--batch-size', default=64, type=int) 21 | parser.add_argument('--train-list', default='list/ucf_CLIP_rgb.csv') 22 | parser.add_argument('--test-list', default='list/ucf_CLIP_rgbtest.csv') 23 | parser.add_argument('--gt-path', default='list/gt_ucf.npy') 24 | parser.add_argument('--gt-segment-path', default='list/gt_segment_ucf.npy') 25 | parser.add_argument('--gt-label-path', default='list/gt_label_ucf.npy') 26 | 27 | parser.add_argument('--lr', default=2e-5) 28 | parser.add_argument('--scheduler-rate', default=0.1) 29 | parser.add_argument('--scheduler-milestones', default=[4, 8]) -------------------------------------------------------------------------------- /list/make_gt_mAP_ucf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | import cv2 5 | import pandas as pd 6 | import warnings 7 | 8 | clip_len = 16 9 | 10 | feature_list = 'list/ucf_CLIP_rgbtest.csv' 11 | 12 | # the ground truth txt 13 | gt_txt = 'list/Temporal_Anomaly_Annotation.txt' 14 | gt_lines = list(open(gt_txt)) 15 | 16 | #warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 17 | 18 | gt_segment = [] 19 | gt_label = [] 20 | lists = pd.read_csv(feature_list) 21 | 22 | for idx in range(lists.shape[0]): 23 | name = lists.loc[idx]['path'] 24 | label_text = lists.loc[idx]['label'] 25 | if '__0.npy' not in name: 26 | continue 27 | segment = [] 28 | label = [] 29 | if 'Normal' in label_text: 30 | fea = np.load(name) 31 | lens = fea.shape[0] * clip_len 32 | name = name.split('/')[-1] 33 | name = name[:-7] 34 | segment.append([0, lens]) 35 | label.append('A') 36 | else: 37 | name = name.split('/')[-1] 38 | name = name[:-7] 39 | for gt_line in gt_lines: 40 | if name in gt_line: 41 | gt_content = gt_line.strip('\n').split(' ') 42 | segment.append([gt_content[2], gt_content[3]]) 43 | label.append(gt_content[1]) 44 | if gt_content[4] != '-1': 45 | segment.append([gt_content[4], gt_content[5]]) 46 | label.append(gt_content[1]) 47 | break 48 | gt_segment.append(segment) 49 | gt_label.append(label) 50 | 51 | np.save('list/gt_label_ucf.npy', gt_label) 52 | np.save('list/gt_segment_ucf.npy', gt_segment) -------------------------------------------------------------------------------- /list/make_gt_ucf.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import cv2 4 | 5 | clip_len = 16 6 | 7 | # the dir of testing images 8 | feature_list = 'list/ucf_CLIP_rgbtest.csv' 9 | # the ground truth txt 10 | 11 | gt_txt = 'list/Temporal_Anomaly_Annotation.txt' ## the path of test annotations 12 | gt_lines = list(open(gt_txt)) 13 | gt = [] 14 | lists = pd.read_csv(feature_list) 15 | count = 0 16 | 17 | for idx in range(lists.shape[0]): 18 | name = lists.loc[idx]['path'] 19 | if '__0.npy' not in name: 20 | continue 21 | #feature = name.split('label_')[-1] 22 | fea = np.load(name) 23 | lens = (fea.shape[0] + 1) * clip_len 24 | name = name.split('/')[-1] 25 | name = name[:-7] 26 | # the number of testing images in this sub-dir 27 | 28 | gt_vec = np.zeros(lens).astype(np.float32) 29 | if 'Normal' not 
in name: 30 | for gt_line in gt_lines: 31 | if name in gt_line: 32 | count += 1 33 | gt_content = gt_line.strip('\n').split(' ')[1:-1] 34 | abnormal_fragment = [[int(gt_content[i]),int(gt_content[j])] for i in range(1,len(gt_content),2) \ 35 | for j in range(2,len(gt_content),2) if j==i+1] 36 | if len(abnormal_fragment) != 0: 37 | abnormal_fragment = np.array(abnormal_fragment) 38 | for frag in abnormal_fragment: 39 | if frag[0] != -1 and frag[1] != -1: 40 | gt_vec[frag[0]:frag[1]]=1.0 41 | break 42 | gt.extend(gt_vec[:-clip_len]) 43 | 44 | print(count) 45 | np.save('list/gt_ucf.npy', gt) -------------------------------------------------------------------------------- /src/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import pandas as pd 5 | import utils.tools as tools 6 | 7 | class UCFDataset(data.Dataset): 8 | def __init__(self, clip_dim: int, file_path: str, test_mode: bool, label_map: dict, normal: bool = False): 9 | self.df = pd.read_csv(file_path) 10 | self.clip_dim = clip_dim 11 | self.test_mode = test_mode 12 | self.label_map = label_map 13 | self.normal = normal 14 | if normal == True and test_mode == False: 15 | self.df = self.df.loc[self.df['label'] == 'Normal'] 16 | self.df = self.df.reset_index() 17 | elif test_mode == False: 18 | self.df = self.df.loc[self.df['label'] != 'Normal'] 19 | self.df = self.df.reset_index() 20 | 21 | def __len__(self): 22 | return self.df.shape[0] 23 | 24 | def __getitem__(self, index): 25 | clip_feature = np.load(self.df.loc[index]['path']) 26 | if self.test_mode == False: 27 | clip_feature, clip_length = tools.process_feat(clip_feature, self.clip_dim) 28 | else: 29 | clip_feature, clip_length = tools.process_split(clip_feature, self.clip_dim) 30 | 31 | clip_feature = torch.tensor(clip_feature) 32 | clip_label = self.df.loc[index]['label'] 33 | return clip_feature, clip_label, clip_length 34 | 35 | class XDDataset(data.Dataset): 36 | def __init__(self, clip_dim: int, file_path: str, test_mode: bool, label_map: dict): 37 | self.df = pd.read_csv(file_path) 38 | self.clip_dim = clip_dim 39 | self.test_mode = test_mode 40 | self.label_map = label_map 41 | 42 | def __len__(self): 43 | return self.df.shape[0] 44 | 45 | def __getitem__(self, index): 46 | clip_feature = np.load(self.df.loc[index]['path']) 47 | if self.test_mode == False: 48 | clip_feature, clip_length = tools.process_feat(clip_feature, self.clip_dim) 49 | else: 50 | clip_feature, clip_length = tools.process_split(clip_feature, self.clip_dim) 51 | 52 | clip_feature = torch.tensor(clip_feature) 53 | clip_label = self.df.loc[index]['label'] 54 | return clip_feature, clip_label, clip_length -------------------------------------------------------------------------------- /src/utils/lr_warmup.py: -------------------------------------------------------------------------------- 1 | import math 2 | from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler 3 | 4 | class WarmupMultiStepLR(MultiStepLR): 5 | r""" 6 | # max_iter = epochs * steps_per_epoch 7 | Args: 8 | optimizer (Optimizer): Wrapped optimizer. 9 | max_iter (int): The total number of steps. 10 | milestones (list) – List of iter indices. Must be increasing. 11 | gamma (float): Multiplicative factor of learning rate decay. Default: 0.1. 12 | pct_start (float): The percentage of the cycle (in number of steps) spent 13 | increasing the learning rate. 
14 | Default: 0.3 15 | warmup_factor (float): 16 | last_epoch (int): The index of last epoch. Default: -1. 17 | """ 18 | def __init__(self, optimizer, max_iter, milestones, gamma=0.1, pct_start=0.3, warmup_factor=1.0 / 2, 19 | last_epoch=-1): 20 | self.warmup_factor = warmup_factor 21 | self.warmup_iters = int(pct_start * max_iter) 22 | super().__init__(optimizer, milestones, gamma, last_epoch) 23 | 24 | def get_lr(self): 25 | if self.last_epoch <= self.warmup_iters: 26 | alpha = self.last_epoch / self.warmup_iters 27 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 28 | return [lr * warmup_factor for lr in self.base_lrs] 29 | else: 30 | lr = super().get_lr() 31 | return lr 32 | 33 | class WarmupCosineLR(_LRScheduler): 34 | def __init__(self, optimizer, max_iter, pct_start=0.3, warmup_factor=1.0 / 3, 35 | eta_min=0, last_epoch=-1): 36 | self.warmup_factor = warmup_factor 37 | self.warmup_iters = int(pct_start * max_iter) 38 | self.max_iter, self.eta_min = max_iter, eta_min 39 | super().__init__(optimizer) 40 | 41 | def get_lr(self): 42 | if self.last_epoch <= self.warmup_iters: 43 | alpha = self.last_epoch / self.warmup_iters 44 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 45 | return [lr * warmup_factor for lr in self.base_lrs] 46 | else: 47 | # print ("after warmup") 48 | return [self.eta_min + (base_lr - self.eta_min) * 49 | (1 + math.cos( 50 | math.pi * (self.last_epoch - self.warmup_iters) / (self.max_iter - self.warmup_iters))) / 2 51 | for base_lr in self.base_lrs] 52 | 53 | class WarmupPolyLR(_LRScheduler): 54 | def __init__(self, optimizer, T_max, pct_start=0.3, warmup_factor=1.0 / 4, 55 | eta_min=0, power=0.9): 56 | self.warmup_factor = warmup_factor 57 | self.warmup_iters = int(pct_start * T_max) 58 | self.power = power 59 | self.T_max, self.eta_min = T_max, eta_min 60 | super().__init__(optimizer) 61 | 62 | def get_lr(self): 63 | if self.last_epoch <= self.warmup_iters: 64 | alpha = self.last_epoch / self.warmup_iters 65 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha 66 | return [lr * warmup_factor for lr in self.base_lrs] 67 | else: 68 | return [self.eta_min + (base_lr - self.eta_min) * 69 | math.pow(1 - (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters), 70 | self.power) for base_lr in self.base_lrs] 71 | -------------------------------------------------------------------------------- /src/crop.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import torch 5 | from clip import clip 6 | from PIL import Image 7 | 8 | def video_crop(video_frame, type): 9 | l = video_frame.shape[0] 10 | new_frame = [] 11 | for i in range(l): 12 | img = cv2.resize(video_frame[i], dsize=(340, 256)) 13 | new_frame.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) 14 | 15 | #1 16 | img = np.array(new_frame) 17 | if type == 0: 18 | img = img[:, 16:240, 58:282, :] 19 | #2 20 | elif type == 1: 21 | img = img[:, :224, :224, :] 22 | #3 23 | elif type == 2: 24 | img = img[:, :224, -224:, :] 25 | #4 26 | elif type == 3: 27 | img = img[:, -224:, :224, :] 28 | #5 29 | elif type == 4: 30 | img = img[:, -224:, -224:, :] 31 | #6 32 | elif type == 5: 33 | img = img[:, 16:240, 58:282, :] 34 | for i in range(img.shape[0]): 35 | img[i] = cv2.flip(img[i], 1) 36 | #7 37 | elif type == 6: 38 | img = img[:, :224, :224, :] 39 | for i in range(img.shape[0]): 40 | img[i] = cv2.flip(img[i], 1) 41 | #8 42 | elif type == 7: 43 | img = img[:, :224, -224:, :] 44 | for i in range(img.shape[0]): 45 
| img[i] = cv2.flip(img[i], 1) 46 | #9 47 | elif type == 8: 48 | img = img[:, -224:, :224, :] 49 | for i in range(img.shape[0]): 50 | img[i] = cv2.flip(img[i], 1) 51 | #10 52 | elif type == 9: 53 | img = img[:, -224:, -224:, :] 54 | for i in range(img.shape[0]): 55 | img[i] = cv2.flip(img[i], 1) 56 | 57 | return img 58 | 59 | def image_crop(image, type): 60 | img = cv2.resize(image, dsize=(340, 256)) 61 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 62 | #1 63 | if type == 0: 64 | img = img[16:240, 58:282, :] 65 | #2 66 | elif type == 1: 67 | img = img[:224, :224, :] 68 | #3 69 | elif type == 2: 70 | img = img[:224, -224:, :] 71 | #4 72 | elif type == 3: 73 | img = img[-224:, :224, :] 74 | #5 75 | elif type == 4: 76 | img = img[-224:, -224:, :] 77 | #6 78 | elif type == 5: 79 | img = img[16:240, 58:282, :] 80 | img = cv2.flip(img, 1) 81 | #7 82 | elif type == 6: 83 | img = img[:224, :224, :] 84 | img = cv2.flip(img, 1) 85 | #8 86 | elif type == 7: 87 | img = img[:224, -224:, :] 88 | img = cv2.flip(img, 1) 89 | #9 90 | elif type == 8: 91 | img = img[-224:, :224, :] 92 | img = cv2.flip(img, 1) 93 | #10 94 | elif type == 9: 95 | img = img[-224:, -224:, :] 96 | img = cv2.flip(img, 1) 97 | 98 | return img 99 | 100 | if __name__ == '__main__': 101 | video = np.zeros([3, 320, 240, 3], dtype=np.uint8) 102 | corp_video = video_crop(video, 0) 103 | 104 | device = "cuda" if torch.cuda.is_available() else "cpu" 105 | model, preprocess = clip.load("ViT-B/16", device) 106 | video_features = torch.zeros(0).to(device) 107 | with torch.no_grad(): 108 | for i in range(video.shape[0]): 109 | img = Image.fromarray(corp_video[i]) 110 | img = preprocess(img).unsqueeze(0).to(device) 111 | feature = model.encode_image(img) 112 | video_features = torch.cat([video_features, feature], dim=0) 113 | 114 | video_features = video_features.detach().cpu().numpy() 115 | np.save('save_path', video_features) -------------------------------------------------------------------------------- /src/utils/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def get_batch_label(texts, prompt_text, label_map: dict): 5 | label_vectors = torch.zeros(0) 6 | if len(label_map) != 7: 7 | if len(label_map) == 2: 8 | for text in texts: 9 | label_vector = torch.zeros(2) 10 | if text == 'Normal': 11 | label_vector[0] = 1 12 | else: 13 | label_vector[1] = 1 14 | label_vector = label_vector.unsqueeze(0) 15 | label_vectors = torch.cat([label_vectors, label_vector], dim=0) 16 | else: 17 | for text in texts: 18 | label_vector = torch.zeros(len(prompt_text)) 19 | if text in label_map: 20 | label_text = label_map[text] 21 | label_vector[prompt_text.index(label_text)] = 1 22 | 23 | label_vector = label_vector.unsqueeze(0) 24 | label_vectors = torch.cat([label_vectors, label_vector], dim=0) 25 | else: 26 | for text in texts: 27 | label_vector = torch.zeros(len(prompt_text)) 28 | labels = text.split('-') 29 | for label in labels: 30 | if label in label_map: 31 | label_text = label_map[label] 32 | label_vector[prompt_text.index(label_text)] = 1 33 | 34 | label_vector = label_vector.unsqueeze(0) 35 | label_vectors = torch.cat([label_vectors, label_vector], dim=0) 36 | 37 | return label_vectors 38 | 39 | def get_prompt_text(label_map: dict): 40 | prompt_text = [] 41 | for v in label_map.values(): 42 | prompt_text.append(v) 43 | 44 | return prompt_text 45 | 46 | def get_batch_mask(lengths, maxlen): 47 | batch_size = lengths.shape[0] 48 | mask = torch.empty(batch_size, 
maxlen) 49 | mask.fill_(0) 50 | for i in range(batch_size): 51 | if lengths[i] < maxlen: 52 | mask[i, lengths[i]:maxlen] = 1 53 | 54 | return mask.bool() 55 | 56 | def random_extract(feat, t_max): 57 | r = np.random.randint(feat.shape[0] - t_max) 58 | return feat[r : r+t_max, :] 59 | 60 | def uniform_extract(feat, t_max, avg: bool = True): 61 | new_feat = np.zeros((t_max, feat.shape[1])).astype(np.float32) 62 | r = np.linspace(0, len(feat), t_max+1, dtype=np.int32) 63 | if avg == True: 64 | for i in range(t_max): 65 | if r[i]!=r[i+1]: 66 | new_feat[i,:] = np.mean(feat[r[i]:r[i+1],:], 0) 67 | else: 68 | new_feat[i,:] = feat[r[i],:] 69 | else: 70 | r = np.linspace(0, feat.shape[0]-1, t_max, dtype=np.uint16) 71 | new_feat = feat[r, :] 72 | 73 | return new_feat 74 | 75 | def pad(feat, min_len): 76 | clip_length = feat.shape[0] 77 | if clip_length <= min_len: 78 | return np.pad(feat, ((0, min_len - clip_length), (0, 0)), mode='constant', constant_values=0) 79 | else: 80 | return feat 81 | 82 | def process_feat(feat, length, is_random=False): 83 | clip_length = feat.shape[0] 84 | if feat.shape[0] > length: 85 | if is_random: 86 | return random_extract(feat, length), length 87 | else: 88 | return uniform_extract(feat, length), length 89 | else: 90 | return pad(feat, length), clip_length 91 | 92 | def process_split(feat, length): 93 | clip_length = feat.shape[0] 94 | if clip_length < length: 95 | return pad(feat, length), clip_length 96 | else: 97 | split_num = int(clip_length / length) + 1 98 | for i in range(split_num): 99 | if i == 0: 100 | split_feat = feat[i*length:i*length+length, :].reshape(1, length, feat.shape[1]) 101 | elif i < split_num - 1: 102 | split_feat = np.concatenate([split_feat, feat[i*length:i*length+length, :].reshape(1, length, feat.shape[1])], axis=0) 103 | else: 104 | split_feat = np.concatenate([split_feat, pad(feat[i*length:i*length+length, :], length).reshape(1, length, feat.shape[1])], axis=0) 105 | 106 | return split_feat, clip_length -------------------------------------------------------------------------------- /src/xd_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | from sklearn.metrics import average_precision_score, roc_auc_score 7 | 8 | from model import CLIPVAD 9 | from utils.dataset import XDDataset 10 | from utils.tools import get_batch_mask, get_prompt_text 11 | from utils.xd_detectionMAP import getDetectionMAP as dmAP 12 | import xd_option 13 | 14 | def test(model, testdataloader, maxlen, prompt_text, gt, gtsegments, gtlabels, device): 15 | 16 | model.to(device) 17 | model.eval() 18 | 19 | element_logits2_stack = [] 20 | 21 | with torch.no_grad(): 22 | for i, item in enumerate(testdataloader): 23 | visual = item[0].squeeze(0) 24 | length = item[2] 25 | 26 | length = int(length) 27 | len_cur = length 28 | if len_cur < maxlen: 29 | visual = visual.unsqueeze(0) 30 | 31 | visual = visual.to(device) 32 | 33 | lengths = torch.zeros(int(length / maxlen) + 1) 34 | for j in range(int(length / maxlen) + 1): 35 | if j == 0 and length < maxlen: 36 | lengths[j] = length 37 | elif j == 0 and length > maxlen: 38 | lengths[j] = maxlen 39 | length -= maxlen 40 | elif length > maxlen: 41 | lengths[j] = maxlen 42 | length -= maxlen 43 | else: 44 | lengths[j] = length 45 | lengths = lengths.to(int) 46 | padding_mask = get_batch_mask(lengths, maxlen).to(device) 47 | _, logits1, logits2 = 
model(visual, padding_mask, prompt_text, lengths) 48 | logits1 = logits1.reshape(logits1.shape[0] * logits1.shape[1], logits1.shape[2]) 49 | logits2 = logits2.reshape(logits2.shape[0] * logits2.shape[1], logits2.shape[2]) 50 | prob2 = (1 - logits2[0:len_cur].softmax(dim=-1)[:, 0].squeeze(-1)) 51 | prob1 = torch.sigmoid(logits1[0:len_cur].squeeze(-1)) 52 | 53 | if i == 0: 54 | ap1 = prob1 55 | ap2 = prob2 56 | else: 57 | ap1 = torch.cat([ap1, prob1], dim=0) 58 | ap2 = torch.cat([ap2, prob2], dim=0) 59 | 60 | element_logits2 = logits2[0:len_cur].softmax(dim=-1).detach().cpu().numpy() 61 | element_logits2 = np.repeat(element_logits2, 16, 0) 62 | element_logits2_stack.append(element_logits2) 63 | 64 | ap1 = ap1.cpu().numpy() 65 | ap2 = ap2.cpu().numpy() 66 | ap1 = ap1.tolist() 67 | ap2 = ap2.tolist() 68 | 69 | ROC1 = roc_auc_score(gt, np.repeat(ap1, 16)) 70 | AP1 = average_precision_score(gt, np.repeat(ap1, 16)) 71 | ROC2 = roc_auc_score(gt, np.repeat(ap2, 16)) 72 | AP2 = average_precision_score(gt, np.repeat(ap2, 16)) 73 | 74 | print("AUC1: ", ROC1, " AP1: ", AP1) 75 | print("AUC2: ", ROC2, " AP2:", AP2) 76 | 77 | dmap, iou = dmAP(element_logits2_stack, gtsegments, gtlabels, excludeNormal=False) 78 | averageMAP = 0 79 | for i in range(5): 80 | print('mAP@{0:.1f} ={1:.2f}%'.format(iou[i], dmap[i])) 81 | averageMAP += dmap[i] 82 | averageMAP = averageMAP/(i+1) 83 | print('average MAP: {:.2f}'.format(averageMAP)) 84 | 85 | return ROC1, AP2 ,0#, averageMAP 86 | 87 | 88 | if __name__ == '__main__': 89 | device = "cuda" if torch.cuda.is_available() else "cpu" 90 | args = xd_option.parser.parse_args() 91 | 92 | label_map = dict({'A': 'normal', 'B1': 'fighting', 'B2': 'shooting', 'B4': 'riot', 'B5': 'abuse', 'B6': 'car accident', 'G': 'explosion'}) 93 | 94 | test_dataset = XDDataset(args.visual_length, args.test_list, True, label_map) 95 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) 96 | 97 | prompt_text = get_prompt_text(label_map) 98 | gt = np.load(args.gt_path) 99 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True) 100 | gtlabels = np.load(args.gt_label_path, allow_pickle=True) 101 | 102 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device) 103 | model_param = torch.load(args.model_path) 104 | model.load_state_dict(model_param) 105 | 106 | test(model, test_loader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device) -------------------------------------------------------------------------------- /src/ucf_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | import numpy as np 6 | from sklearn.metrics import average_precision_score, roc_auc_score 7 | 8 | from model import CLIPVAD 9 | from utils.dataset import UCFDataset 10 | from utils.tools import get_batch_mask, get_prompt_text 11 | from utils.ucf_detectionMAP import getDetectionMAP as dmAP 12 | import ucf_option 13 | 14 | def test(model, testdataloader, maxlen, prompt_text, gt, gtsegments, gtlabels, device): 15 | 16 | model.to(device) 17 | model.eval() 18 | 19 | element_logits2_stack = [] 20 | 21 | with torch.no_grad(): 22 | for i, item in enumerate(testdataloader): 23 | visual = item[0].squeeze(0) 24 | length = item[2] 25 | 26 | length = int(length) 27 | len_cur = length 28 | if len_cur < maxlen: 
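# A compact sketch of the window-length bookkeeping done by the loop below:
# a video longer than `maxlen` (= args.visual_length, 256 by default) arrives
# from utils.tools.process_split() as int(T / maxlen) + 1 stacked windows, and
# each window's valid length is recorded so get_batch_mask() can mask the
# zero-padding. For a hypothetical T = 600:
#
#   T, maxlen = 600, 256
#   n = T // maxlen + 1                                       # 3 windows
#   lengths = [min(maxlen, T - k * maxlen) for k in range(n)]
#   assert lengths == [256, 256, 88]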
29 | visual = visual.unsqueeze(0) 30 | 31 | visual = visual.to(device) 32 | 33 | lengths = torch.zeros(int(length / maxlen) + 1) 34 | for j in range(int(length / maxlen) + 1): 35 | if j == 0 and length < maxlen: 36 | lengths[j] = length 37 | elif j == 0 and length > maxlen: 38 | lengths[j] = maxlen 39 | length -= maxlen 40 | elif length > maxlen: 41 | lengths[j] = maxlen 42 | length -= maxlen 43 | else: 44 | lengths[j] = length 45 | lengths = lengths.to(int) 46 | padding_mask = get_batch_mask(lengths, maxlen).to(device) 47 | _, logits1, logits2 = model(visual, padding_mask, prompt_text, lengths) 48 | logits1 = logits1.reshape(logits1.shape[0] * logits1.shape[1], logits1.shape[2]) 49 | logits2 = logits2.reshape(logits2.shape[0] * logits2.shape[1], logits2.shape[2]) 50 | prob2 = (1 - logits2[0:len_cur].softmax(dim=-1)[:, 0].squeeze(-1)) 51 | prob1 = torch.sigmoid(logits1[0:len_cur].squeeze(-1)) 52 | 53 | if i == 0: 54 | ap1 = prob1 55 | ap2 = prob2 56 | #ap3 = prob3 57 | else: 58 | ap1 = torch.cat([ap1, prob1], dim=0) 59 | ap2 = torch.cat([ap2, prob2], dim=0) 60 | 61 | element_logits2 = logits2[0:len_cur].softmax(dim=-1).detach().cpu().numpy() 62 | element_logits2 = np.repeat(element_logits2, 16, 0) 63 | element_logits2_stack.append(element_logits2) 64 | 65 | ap1 = ap1.cpu().numpy() 66 | ap2 = ap2.cpu().numpy() 67 | ap1 = ap1.tolist() 68 | ap2 = ap2.tolist() 69 | 70 | ROC1 = roc_auc_score(gt, np.repeat(ap1, 16)) 71 | AP1 = average_precision_score(gt, np.repeat(ap1, 16)) 72 | ROC2 = roc_auc_score(gt, np.repeat(ap2, 16)) 73 | AP2 = average_precision_score(gt, np.repeat(ap2, 16)) 74 | 75 | print("AUC1: ", ROC1, " AP1: ", AP1) 76 | print("AUC2: ", ROC2, " AP2:", AP2) 77 | 78 | dmap, iou = dmAP(element_logits2_stack, gtsegments, gtlabels, excludeNormal=False) 79 | averageMAP = 0 80 | for i in range(5): 81 | print('mAP@{0:.1f} ={1:.2f}%'.format(iou[i], dmap[i])) 82 | averageMAP += dmap[i] 83 | averageMAP = averageMAP/(i+1) 84 | print('average MAP: {:.2f}'.format(averageMAP)) 85 | 86 | return ROC1, AP1 87 | 88 | 89 | if __name__ == '__main__': 90 | device = "cuda" if torch.cuda.is_available() else "cpu" 91 | args = ucf_option.parser.parse_args() 92 | 93 | label_map = dict({'Normal': 'Normal', 'Abuse': 'Abuse', 'Arrest': 'Arrest', 'Arson': 'Arson', 'Assault': 'Assault', 'Burglary': 'Burglary', 'Explosion': 'Explosion', 'Fighting': 'Fighting', 'RoadAccidents': 'RoadAccidents', 'Robbery': 'Robbery', 'Shooting': 'Shooting', 'Shoplifting': 'Shoplifting', 'Stealing': 'Stealing', 'Vandalism': 'Vandalism'}) 94 | 95 | testdataset = UCFDataset(args.visual_length, args.test_list, True, label_map) 96 | testdataloader = DataLoader(testdataset, batch_size=1, shuffle=False) 97 | 98 | prompt_text = get_prompt_text(label_map) 99 | gt = np.load(args.gt_path) 100 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True) 101 | gtlabels = np.load(args.gt_label_path, allow_pickle=True) 102 | 103 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device) 104 | model_param = torch.load(args.model_path) 105 | model.load_state_dict(model_param) 106 | 107 | test(model, testdataloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VadCLIP 2 | This is the official Pytorch implementation of 
our paper:
3 | **"VadCLIP: Adapting Vision-Language Models for Weakly Supervised Video Anomaly Detection"** in **AAAI 2024.**
4 | > Peng Wu, Xuerong Zhou, Guansong Pang, Lingru Zhou, Qingsen Yan, Peng Wang, Yanning Zhang
5 | 
6 | ![framework](data/framework.png)
7 | 
8 | ## Highlight
9 | - We present a novel paradigm, i.e., VadCLIP, which involves a dual branch to detect video anomalies in a visual-classification manner and a language-visual alignment manner, respectively. With the benefit of the dual branch, VadCLIP achieves both coarse-grained and fine-grained WSVAD. To our knowledge, **VadCLIP is the first work to efficiently transfer pre-trained language-visual knowledge to WSVAD**.
10 | 
11 | - We propose three non-trivial components to address the new challenges raised by this paradigm. LGT-Adapter is used to capture temporal dependencies from different perspectives; two prompt mechanisms are devised to effectively adapt the frozen pre-trained model to the WSVAD task; MIL-Align realizes the optimization of the alignment paradigm under weak supervision, so as to preserve the pre-trained knowledge as much as possible.
12 | 
13 | - We show the strength and effectiveness of VadCLIP on two large-scale popular benchmarks: VadCLIP achieves state-of-the-art performance, e.g., unprecedented results of 84.51\% AP on XD-Violence and 88.02\% AUC on UCF-Crime, surpassing current classification-based methods by a large margin.
14 | 
15 | ## Training
16 | 
17 | ### Setup
18 | We extract CLIP features for the UCF-Crime and XD-Violence datasets, and release these features and pretrained models as follows:
19 | 
20 | | Benchmark | CLIP Features (Baidu) | CLIP Features (OneDrive) | Model (Baidu) | Model (OneDrive) |
21 | |--------|----------|-----------|-------------|------------|
22 | | UCF-Crime | [Code: 7yzp](https://pan.baidu.com/s/1OKRIxoLcxt-7RYxWpylgLQ) | [OneDrive](https://stuxidianeducn-my.sharepoint.com/:u:/g/personal/pengwu_stu_xidian_edu_cn/Ea86YOcp5z9KhRFDQm9a8zwBcGiGGg5BuBJtgmCVByazBQ?e=tqHLHt) | [Code: kq5u](https://pan.baidu.com/s/1_9bTC99FklrZRnkmYMuJQw) | [OneDrive](https://stuxidianeducn-my.sharepoint.com/:u:/g/personal/pengwu_stu_xidian_edu_cn/Eaz6sn40RmlFmjELcNHW1IkBV7C0U5OrOaHcuLFzH2S0-Q?e=x8wtVe) |
23 | | XD-Violence | [Code: v8tw](https://pan.baidu.com/s/1q8DiYHcPJtrBQiiJMI7aJw) | [OneDrive](https://stuxidianeducn-my.sharepoint.com/:f:/g/personal/pengwu_stu_xidian_edu_cn/Et5dWQZb2cBDs7zsrp90SrQBL_52vTRNYTdjQW6SMl0ZVA?e=foX4ph) | [Code: apw6](https://pan.baidu.com/s/1O0uwVS3ZyDA1soWUv2VasQ) | [OneDrive](https://stuxidianeducn-my.sharepoint.com/:u:/g/personal/pengwu_stu_xidian_edu_cn/EYlNnn_xfVxBtQZuQgngrMsBHY-i8QHTVOs7PmryzQ2MyA?e=99nxnR) |
24 | 
25 | 
26 | 
27 | 
28 | The following files need to be adapted in order to run the code on your own machine:
29 | - Change the file paths to the downloaded features above in `list/xd_CLIP_rgb.csv` and `list/xd_CLIP_rgbtest.csv`.
30 | - Feel free to change the hyperparameters in `xd_option.py` and `ucf_option.py`.
31 | ### Train and Test
32 | After the setup, simply run the following commands:
33 | 
34 | 
35 | Training and inference for the XD-Violence dataset:
36 | ```
37 | python xd_train.py
38 | python xd_test.py
39 | ```
40 | Training and inference for the UCF-Crime dataset:
41 | ```
42 | python ucf_train.py
43 | python ucf_test.py
44 | ```
45 |
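For intuition, the MIL-Align mechanism from the highlights corresponds to the `CLASM` loss in `xd_train.py`/`ucf_train.py`: per class, the top-k frame-to-prompt alignment scores of a video are averaged into a video-level score, which is trained with a soft cross-entropy against the video-level label. A minimal, self-contained sketch of the idea (the shapes and sampled tensors are illustrative, not the actual training code):

```python
import torch
import torch.nn.functional as F

# logits: (T, C) frame-vs-class-prompt alignment scores for one video
# label:  (C,)  normalized multi-hot video-level label (illustrative)
T, C = 256, 7
logits = torch.randn(T, C)
label = F.one_hot(torch.tensor(0), C).float()

k = T // 16 + 1                            # top-k size, as in CLASM
topk, _ = torch.topk(logits, k=k, dim=0)   # (k, C) best frames per class
video_logits = topk.mean(dim=0)            # (C,) video-level scores
loss = -(label * F.log_softmax(video_logits, dim=0)).sum()
```

Because only the highest-scoring frames per class contribute, the alignment objective can be optimized from video-level (weak) labels alone.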
46 | ## References
47 | We referenced the repos below for the code.
48 | * [XDVioDet](https://github.com/Roc-Ng/XDVioDet)
49 | * [DeepMIL](https://github.com/Roc-Ng/DeepMIL)
50 | 
51 | ## Citation
52 | 
53 | If you find this repo useful for your research, please consider citing our paper:
54 | 
55 | ```bibtex
56 | @inproceedings{wu2023vadclip,
57 |   title={VadCLIP: Adapting Vision-Language Models for Weakly Supervised Video Anomaly Detection},
58 |   author={Wu, Peng and Zhou, Xuerong and Pang, Guansong and Zhou, Lingru and Yan, Qingsen and Wang, Peng and Zhang, Yanning},
59 |   booktitle={Proceedings of the AAAI Conference on Artificial Intelligence (AAAI)},
60 |   year={2024}
61 | }
62 | 
63 | @article{wu2023open,
64 |   title={Open-Vocabulary Video Anomaly Detection},
65 |   author={Wu, Peng and Zhou, Xuerong and Pang, Guansong and Sun, Yujia and Liu, Jing and Wang, Peng and Zhang, Yanning},
66 |   journal={arXiv preprint arXiv:2311.07042},
67 |   year={2023}
68 | }
69 | 
70 | ```
71 | ---
72 | 
--------------------------------------------------------------------------------

/src/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 | 
6 | import ftfy
7 | import regex as re
8 | 
9 | 
10 | @lru_cache()
11 | def default_bpe():
12 |     return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 | 
14 | 
15 | @lru_cache()
16 | def bytes_to_unicode():
17 |     """
18 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
19 |     The reversible bpe codes work on unicode strings.
20 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 |     And avoids mapping to whitespace/control characters the bpe code barfs on.
25 |     """
26 |     bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 |     cs = bs[:]
28 |     n = 0
29 |     for b in range(2**8):
30 |         if b not in bs:
31 |             bs.append(b)
32 |             cs.append(2**8+n)
33 |             n += 1
34 |     cs = [chr(n) for n in cs]
35 |     return dict(zip(bs, cs))
36 | 
37 |
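# A minimal usage sketch of this tokenizer (round-trip through BPE ids);
# assumes the default vocab file sits next to this module, as default_bpe()
# expects:
#
#   from clip.simple_tokenizer import SimpleTokenizer
#   tok = SimpleTokenizer()
#   ids = tok.encode("a scene of a car accident")
#   text = tok.decode(ids)   # decode maps the '</w>' end-of-word marker back to spaces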
38 | def get_pairs(word):
39 |     """Return set of symbol pairs in a word.
40 |     Word is represented as tuple of symbols (symbols being variable-length strings).
41 |     """
42 |     pairs = set()
43 |     prev_char = word[0]
44 |     for char in word[1:]:
45 |         pairs.add((prev_char, char))
46 |         prev_char = char
47 |     return pairs
48 | 
49 | 
50 | def basic_clean(text):
51 |     text = ftfy.fix_text(text)
52 |     text = html.unescape(html.unescape(text))
53 |     return text.strip()
54 | 
55 | 
56 | def whitespace_clean(text):
57 |     text = re.sub(r'\s+', ' ', text)
58 |     text = text.strip()
59 |     return text
60 | 
61 | 
62 | class SimpleTokenizer(object):
63 |     def __init__(self, bpe_path: str = default_bpe()):
64 |         self.byte_encoder = bytes_to_unicode()
65 |         self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 |         merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 |         merges = merges[1:49152-256-2+1]
68 |         merges = [tuple(merge.split()) for merge in merges]
69 |         vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 |         for merge in merges:
72 |             vocab.append(''.join(merge))
73 |         vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 |         self.encoder = dict(zip(vocab, range(len(vocab))))
75 |         self.decoder = {v: k for k, v in self.encoder.items()}
76 |         self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 |         self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 |         self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 | 
80 |     def bpe(self, token):
81 |         if token in self.cache:
82 |             return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 |         pairs = get_pairs(word)
85 | 
86 |         if not pairs:
87 |             return token+'</w>'
88 | 
89 |         while True:
90 |             bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 |             if bigram not in self.bpe_ranks:
92 |                 break
93 |             first, second = bigram
94 |             new_word = []
95 |             i = 0
96 |             while i < len(word):
97 |                 try:
98 |                     j = word.index(first, i)
99 |                     new_word.extend(word[i:j])
100 |                     i = j
101 |                 except:
102 |                     new_word.extend(word[i:])
103 |                     break
104 | 
105 |                 if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 |                     new_word.append(first+second)
107 |                     i += 2
108 |                 else:
109 |                     new_word.append(word[i])
110 |                     i += 1
111 |             new_word = tuple(new_word)
112 |             word = new_word
113 |             if len(word) == 1:
114 |                 break
115 |             else:
116 |                 pairs = get_pairs(word)
117 |         word = ' '.join(word)
118 |         self.cache[token] = word
119 |         return word
120 | 
121 |     def encode(self, text):
122 |         bpe_tokens = []
123 |         text = whitespace_clean(basic_clean(text)).lower()
124 |         for token in re.findall(self.pat, text):
125 |             token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 |             bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 |         return bpe_tokens
128 | 
129 |     def decode(self, tokens):
130 |         text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 |         return text
133 | 
--------------------------------------------------------------------------------

/src/utils/xd_detectionMAP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.signal import savgol_filter
3 | 
4 | def smooth(v):
5 |     return v
6 |     # l = min(5, len(v)); l = l - (1-l%2)
7 |     # if len(v) <= 3:
8 |     #     return v
9 |     # return savgol_filter(v, l, 1) #savgol_filter(v, l, 1) #0.5*(np.concatenate([v[1:],v[-1:]],axis=0) + v)
10 | 
11 | def str2ind(categoryname,classlist):
12 |     return [i for i in
range(len(classlist)) if categoryname == classlist[i]][0] 13 | 14 | def nms(dets, thresh=0.6, top_k=-1): 15 | """Pure Python NMS baseline.""" 16 | # dets: N*2 and sorted by scores 17 | if len(dets) == 0: return [] 18 | order = np.arange(0,len(dets),1) 19 | dets = np.array(dets) 20 | x1 = dets[:, 0] # start 21 | x2 = dets[:, 1] # end 22 | lengths = x2 - x1 23 | keep = [] 24 | while order.size > 0: 25 | i = order[0] # the first is the best proposal 26 | keep.append(i) # put into the candidate pool 27 | if len(keep) == top_k: 28 | break 29 | xx1 = np.maximum(x1[i], x1[order[1:]]) 30 | xx2 = np.minimum(x2[i], x2[order[1:]]) 31 | inter = np.maximum(0.0, xx2 - xx1) ## the intersection 32 | ovr = inter / (lengths[i] + lengths[order[1:]] - inter) ## the iou 33 | inds = np.where(ovr <= thresh)[0] # the index of remaining proposals 34 | order = order[inds + 1] # add 1 35 | 36 | return dets[keep], keep 37 | 38 | def getLocMAP(predictions, th, gtsegments, gtlabels, excludeNormal): 39 | if excludeNormal is True: 40 | classes_num = 6 41 | videos_num = 500 42 | predictions = predictions[:videos_num] 43 | else: 44 | classes_num = 7 45 | videos_num = 800 46 | 47 | classlist = ['A', 'B1', 'B2', 'B4', 'B5', 'B6', 'G'] 48 | predictions_mod = [] 49 | c_score = [] 50 | for p in predictions: 51 | pp = - p 52 | [pp[:,i].sort() for i in range(np.shape(pp)[1])] 53 | pp=-pp 54 | idx_temp = int(np.shape(pp)[0]/16) 55 | c_s = np.mean(pp[:idx_temp, :], axis=0) 56 | ind = c_s > 0.0 57 | c_score.append(c_s) 58 | predictions_mod.append(p*ind) 59 | predictions = predictions_mod 60 | ap = [] 61 | for c in range(0, 7): 62 | segment_predict = [] 63 | # Get list of all predictions for class c 64 | for i in range(len(predictions)): 65 | tmp = smooth(predictions[i][:, c]) 66 | segment_predict_multithr = [] 67 | thr_set = np.arange(0.6, 0.7, 0.1) 68 | for thr in thr_set: 69 | threshold = np.max(tmp) - (np.max(tmp) - np.min(tmp))*thr ### 0.8 is the best? 70 | vid_pred = np.concatenate([np.zeros(1), (tmp>threshold).astype('float32'), np.zeros(1)], axis=0) 71 | vid_pred_diff = [vid_pred[idt]-vid_pred[idt-1] for idt in range(1, len(vid_pred))] 72 | s = [idk for idk, item in enumerate(vid_pred_diff) if item == 1] 73 | e = [idk for idk, item in enumerate(vid_pred_diff) if item == -1] 74 | for j in range(len(s)): 75 | if e[j]-s[j]>=2: 76 | segment_scores = np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c] 77 | segment_predict_multithr.append([i, s[j], e[j], segment_scores]) 78 | # segment_predict.append([i, s[j], e[j], np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c]]) 79 | if len(segment_predict_multithr)!=0: 80 | segment_predict_multithr = np.array(segment_predict_multithr) 81 | segment_predict_multithr = segment_predict_multithr[np.argsort(-segment_predict_multithr[:,-1])] 82 | _, keep = nms(segment_predict_multithr[:, 1:-1], 0.6) 83 | segment_predict.extend(list(segment_predict_multithr[keep])) 84 | segment_predict = np.array(segment_predict) 85 | 86 | # Sort the list of predictions for class c based on score 87 | if len(segment_predict) == 0: 88 | return 0 89 | segment_predict = segment_predict[np.argsort(-segment_predict[:,3])] 90 | 91 | # Create gt list 92 | segment_gt = [[i, gtsegments[i][j][0], gtsegments[i][j][1]] for i in range(len(gtsegments)) 93 | for j in range(len(gtsegments[i])) if str2ind(gtlabels[i][j], classlist) == c] 94 | gtpos = len(segment_gt) 95 | 96 | # Compare predictions and gt 97 | tp, fp = [], [] 98 | for i in range(len(segment_predict)): 99 | flag = 0. 
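# The IoU computed a few lines below materializes frame-index sets; for 1-D
# segments the same value has a closed form, sketched here for reference:
#
#   def interval_iou(s1, e1, s2, e2):
#       inter = max(0, min(e1, e2) - max(s1, s2))
#       union = (e1 - s1) + (e2 - s2) - inter
#       return inter / union if union > 0 else 0.0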
100 | best_iou = 0.0 101 | for j in range(len(segment_gt)): 102 | if segment_predict[i][0]==segment_gt[j][0]: 103 | gt = range(int(segment_gt[j][1]), int(segment_gt[j][2])) 104 | p = range(int(segment_predict[i][1]), int(segment_predict[i][2])) 105 | IoU = float(len(set(gt).intersection(set(p))))/float(len(set(gt).union(set(p)))) 106 | if IoU >= th: 107 | flag = 1. 108 | if IoU > best_iou: 109 | best_iou = IoU 110 | best_j = j 111 | if flag > 0: 112 | del segment_gt[best_j] 113 | tp.append(flag) 114 | fp.append(1.-flag) 115 | tp_c = np.cumsum(tp) 116 | fp_c = np.cumsum(fp) 117 | if sum(tp)==0: 118 | prc = 0. 119 | else: 120 | prc = np.sum((tp_c/(fp_c+tp_c))*tp)/gtpos 121 | ap.append(prc) 122 | # print(np.round(prc, 4)) 123 | return 100*np.mean(ap) 124 | 125 | 126 | def getDetectionMAP(predictions, segments, labels, excludeNormal=False): 127 | iou_list = [0.1, 0.2, 0.3, 0.4, 0.5] 128 | # iou_list = [0.5] 129 | dmap_list = [] 130 | for iou in iou_list: 131 | # print('Testing for IoU {:.1f}'.format(iou)) 132 | dmap_list.append(getLocMAP(predictions, iou, segments, labels, excludeNormal)) 133 | return dmap_list, iou_list 134 | 135 | -------------------------------------------------------------------------------- /src/utils/ucf_detectionMAP.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.signal import savgol_filter 3 | 4 | def smooth(v): 5 | return v 6 | # l = min(5, len(v)); l = l - (1-l%2) 7 | # if len(v) <= 3: 8 | # return v 9 | # return savgol_filter(v, l, 1) #savgol_filter(v, l, 1) #0.5*(np.concatenate([v[1:],v[-1:]],axis=0) + v) 10 | 11 | def str2ind(categoryname,classlist): 12 | for i in range(len(classlist)): 13 | if categoryname == classlist[i]: 14 | return i 15 | 16 | def nms(dets, thresh=0.6, top_k=-1): 17 | """Pure Python NMS baseline.""" 18 | # dets: N*2 and sorted by scores 19 | if len(dets) == 0: return [] 20 | order = np.arange(0,len(dets),1) 21 | dets = np.array(dets) 22 | x1 = dets[:, 0] # start 23 | x2 = dets[:, 1] # end 24 | lengths = x2 - x1 25 | keep = [] 26 | while order.size > 0: 27 | i = order[0] # the first is the best proposal 28 | keep.append(i) # put into the candidate pool 29 | if len(keep) == top_k: 30 | break 31 | xx1 = np.maximum(x1[i], x1[order[1:]]) 32 | xx2 = np.minimum(x2[i], x2[order[1:]]) 33 | inter = np.maximum(0.0, xx2 - xx1) ## the intersection 34 | ovr = inter / (lengths[i] + lengths[order[1:]] - inter) ## the iou 35 | inds = np.where(ovr <= thresh)[0] # the index of remaining proposals 36 | order = order[inds + 1] # add 1 37 | 38 | return dets[keep], keep 39 | 40 | def getLocMAP(predictions, th, gtsegments, gtlabels, excludeNormal): 41 | if excludeNormal is True: 42 | classes_num = 13 43 | videos_num = 140 44 | predictions = predictions[:videos_num] 45 | else: 46 | classes_num = 14 47 | videos_num = 290 48 | 49 | classlist = ['Normal', 'Abuse', 'Arrest', 'Arson', 'Assault', 'Burglary', 'Explosion', 'Fighting', 'RoadAccidents', 'Robbery', 'Shooting', 'Shoplifting', 'Stealing', 'Vandalism'] 50 | predictions_mod = [] 51 | c_score = [] 52 | for p in predictions: 53 | pp = - p; [pp[:,i].sort() for i in range(np.shape(pp)[1])]; pp=-pp 54 | c_s = np.mean(pp[:int(np.shape(pp)[0]/16), :], axis=0) 55 | ind = c_s > 0.0 56 | c_score.append(c_s) 57 | predictions_mod.append(p*ind) 58 | predictions = predictions_mod 59 | ap = [] 60 | for c in range(0, 14): 61 | segment_predict = [] 62 | # Get list of all predictions for class c 63 | for i in range(len(predictions)): 64 | tmp = 
smooth(predictions[i][:, c]) 65 | segment_predict_multithr = [] 66 | thr_set = np.arange(0.6, 0.7, 0.1) 67 | for thr in thr_set: 68 | threshold = np.max(tmp) - (np.max(tmp) - np.min(tmp))*thr ### 0.8 is the best? 69 | vid_pred = np.concatenate([np.zeros(1), (tmp>threshold).astype('float32'), np.zeros(1)], axis=0) 70 | vid_pred_diff = [vid_pred[idt]-vid_pred[idt-1] for idt in range(1, len(vid_pred))] 71 | s = [idk for idk, item in enumerate(vid_pred_diff) if item == 1] 72 | e = [idk for idk, item in enumerate(vid_pred_diff) if item == -1] 73 | for j in range(len(s)): 74 | if e[j]-s[j]>=2: 75 | segment_scores = np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c] 76 | segment_predict_multithr.append([i, s[j], e[j], segment_scores]) 77 | # segment_predict.append([i, s[j], e[j], np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c]]) 78 | if len(segment_predict_multithr)!=0: 79 | segment_predict_multithr = np.array(segment_predict_multithr) 80 | segment_predict_multithr = segment_predict_multithr[np.argsort(-segment_predict_multithr[:,-1])] 81 | _, keep = nms(segment_predict_multithr[:, 1:-1], 0.6) 82 | segment_predict.extend(list(segment_predict_multithr[keep])) 83 | segment_predict = np.array(segment_predict) 84 | 85 | # Sort the list of predictions for class c based on score 86 | if len(segment_predict) == 0: 87 | return 0 88 | segment_predict = segment_predict[np.argsort(-segment_predict[:,3])] 89 | 90 | # Create gt list 91 | segment_gt = [[i, gtsegments[i][j][0], gtsegments[i][j][1]] for i in range(len(gtsegments)) 92 | for j in range(len(gtsegments[i])) if str2ind(gtlabels[i][j], classlist) == c] 93 | gtpos = len(segment_gt) 94 | 95 | # Compare predictions and gt 96 | tp, fp = [], [] 97 | for i in range(len(segment_predict)): 98 | flag = 0. 99 | best_iou = 0.0 100 | for j in range(len(segment_gt)): 101 | if segment_predict[i][0]==segment_gt[j][0]: 102 | gt = range(int(segment_gt[j][1]), int(segment_gt[j][2])) 103 | p = range(int(segment_predict[i][1]), int(segment_predict[i][2])) 104 | IoU = float(len(set(gt).intersection(set(p))))/float(len(set(gt).union(set(p)))) 105 | if IoU >= th: 106 | flag = 1. 107 | if IoU > best_iou: 108 | best_iou = IoU 109 | best_j = j 110 | if flag > 0: 111 | del segment_gt[best_j] 112 | tp.append(flag) 113 | fp.append(1.-flag) 114 | tp_c = np.cumsum(tp) 115 | fp_c = np.cumsum(fp) 116 | if sum(tp)==0: 117 | prc = 0. 
118 | else: 119 | prc = np.sum((tp_c/(fp_c+tp_c))*tp)/gtpos 120 | ap.append(prc) 121 | # print(np.round(prc, 4)) 122 | return 100*np.mean(ap) 123 | 124 | 125 | def getDetectionMAP(predictions, segments, labels, excludeNormal=False): 126 | iou_list = [0.1, 0.2, 0.3, 0.4, 0.5] 127 | # iou_list = [0.5] 128 | dmap_list = [] 129 | for iou in iou_list: 130 | # print('Testing for IoU {:.1f}'.format(iou)) 131 | dmap_list.append(getLocMAP(predictions, iou, segments, labels, excludeNormal)) 132 | return dmap_list, iou_list 133 | 134 | -------------------------------------------------------------------------------- /src/xd_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | from torch.optim.lr_scheduler import MultiStepLR 6 | import numpy as np 7 | import random 8 | 9 | from model import CLIPVAD 10 | from xd_test import test 11 | from utils.dataset import XDDataset 12 | from utils.tools import get_prompt_text, get_batch_label 13 | import xd_option 14 | 15 | def CLASM(logits, labels, lengths, device): 16 | instance_logits = torch.zeros(0).to(device) 17 | labels = labels / torch.sum(labels, dim=1, keepdim=True) 18 | labels = labels.to(device) 19 | 20 | for i in range(logits.shape[0]): 21 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True, dim=0) 22 | instance_logits = torch.cat([instance_logits, torch.mean(tmp, 0, keepdim=True)], dim=0) 23 | 24 | milloss = -torch.mean(torch.sum(labels * F.log_softmax(instance_logits, dim=1), dim=1), dim=0) 25 | return milloss 26 | 27 | def CLAS2(logits, labels, lengths, device): 28 | instance_logits = torch.zeros(0).to(device) 29 | labels = 1 - labels[:, 0].reshape(labels.shape[0]) 30 | labels = labels.to(device) 31 | logits = torch.sigmoid(logits).reshape(logits.shape[0], logits.shape[1]) 32 | 33 | for i in range(logits.shape[0]): 34 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True) 35 | tmp = torch.mean(tmp).view(1) 36 | instance_logits = torch.cat((instance_logits, tmp)) 37 | 38 | clsloss = F.binary_cross_entropy(instance_logits, labels) 39 | return clsloss 40 | 41 | def train(model, train_loader, test_loader, args, label_map: dict, device): 42 | model.to(device) 43 | 44 | gt = np.load(args.gt_path) 45 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True) 46 | gtlabels = np.load(args.gt_label_path, allow_pickle=True) 47 | 48 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 49 | scheduler = MultiStepLR(optimizer, args.scheduler_milestones, args.scheduler_rate) 50 | prompt_text = get_prompt_text(label_map) 51 | ap_best = 0 52 | epoch = 0 53 | 54 | if args.use_checkpoint == True: 55 | checkpoint = torch.load(args.checkpoint_path) 56 | model.load_state_dict(checkpoint['model_state_dict']) 57 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 58 | epoch = checkpoint['epoch'] 59 | ap_best = checkpoint['ap'] 60 | print("checkpoint info:") 61 | print("epoch:", epoch+1, " ap:", ap_best) 62 | 63 | for e in range(args.max_epoch): 64 | model.train() 65 | loss_total1 = 0 66 | loss_total2 = 0 67 | for i, item in enumerate(train_loader): 68 | step = 0 69 | visual_feat, text_labels, feat_lengths = item 70 | visual_feat = visual_feat.to(device) 71 | feat_lengths = feat_lengths.to(device) 72 | text_labels = get_batch_label(text_labels, prompt_text, label_map).to(device) 73 | 74 | text_features, logits1, 
logits2 = model(visual_feat, None, prompt_text, feat_lengths) 75 | 76 | loss1 = CLAS2(logits1, text_labels, feat_lengths, device) 77 | loss_total1 += loss1.item() 78 | 79 | loss2 = CLASM(logits2, text_labels, feat_lengths, device) 80 | loss_total2 += loss2.item() 81 | 82 | loss3 = torch.zeros(1).to(device) 83 | text_feature_normal = text_features[0] / text_features[0].norm(dim=-1, keepdim=True) 84 | for j in range(1, text_features.shape[0]): 85 | text_feature_abr = text_features[j] / text_features[j].norm(dim=-1, keepdim=True) 86 | loss3 += torch.abs(text_feature_normal @ text_feature_abr) 87 | loss3 = loss3 / 6 88 | 89 | loss = loss1 + loss2 + loss3 * 1e-4 90 | 91 | optimizer.zero_grad() 92 | loss.backward() 93 | optimizer.step() 94 | step += i * train_loader.batch_size 95 | if step % 4800 == 0 and step != 0: 96 | print('epoch: ', e+1, '| step: ', step, '| loss1: ', loss_total1 / (i+1), '| loss2: ', loss_total2 / (i+1), '| loss3: ', loss3.item()) 97 | 98 | scheduler.step() 99 | AUC, AP, mAP = test(model, test_loader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device) 100 | 101 | if AP > ap_best: 102 | ap_best = AP 103 | checkpoint = { 104 | 'epoch': e, 105 | 'model_state_dict': model.state_dict(), 106 | 'optimizer_state_dict': optimizer.state_dict(), 107 | 'ap': ap_best} 108 | torch.save(checkpoint, args.checkpoint_path) 109 | 110 | checkpoint = torch.load(args.checkpoint_path) 111 | model.load_state_dict(checkpoint['model_state_dict']) 112 | 113 | checkpoint = torch.load(args.checkpoint_path) 114 | torch.save(checkpoint['model_state_dict'], args.model_path) 115 | 116 | def setup_seed(seed): 117 | torch.manual_seed(seed) 118 | torch.cuda.manual_seed_all(seed) 119 | np.random.seed(seed) 120 | random.seed(seed) 121 | #torch.backends.cudnn.deterministic = True 122 | 123 | if __name__ == '__main__': 124 | device = "cuda" if torch.cuda.is_available() else "cpu" 125 | args = xd_option.parser.parse_args() 126 | setup_seed(args.seed) 127 | 128 | label_map = dict({'A': 'normal', 'B1': 'fighting', 'B2': 'shooting', 'B4': 'riot', 'B5': 'abuse', 'B6': 'car accident', 'G': 'explosion'}) 129 | 130 | train_dataset = XDDataset(args.visual_length, args.train_list, False, label_map) 131 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True) 132 | 133 | test_dataset = XDDataset(args.visual_length, args.test_list, True, label_map) 134 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) 135 | 136 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device) 137 | train(model, train_loader, test_loader, args, label_map, device) -------------------------------------------------------------------------------- /src/ucf_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from torch.utils.data import DataLoader 5 | from torch.optim.lr_scheduler import MultiStepLR 6 | import numpy as np 7 | import random 8 | 9 | from model import CLIPVAD 10 | from ucf_test import test 11 | from utils.dataset import UCFDataset 12 | from utils.tools import get_prompt_text, get_batch_label 13 | import ucf_option 14 | 15 | def CLASM(logits, labels, lengths, device): 16 | instance_logits = torch.zeros(0).to(device) 17 | labels = labels / torch.sum(labels, dim=1, keepdim=True) 18 | labels = labels.to(device) 19 | 20 | for i 
in range(logits.shape[0]): 21 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True, dim=0) 22 | instance_logits = torch.cat([instance_logits, torch.mean(tmp, 0, keepdim=True)], dim=0) 23 | 24 | milloss = -torch.mean(torch.sum(labels * F.log_softmax(instance_logits, dim=1), dim=1), dim=0) 25 | return milloss 26 | 27 | def CLAS2(logits, labels, lengths, device): 28 | instance_logits = torch.zeros(0).to(device) 29 | labels = 1 - labels[:, 0].reshape(labels.shape[0]) 30 | labels = labels.to(device) 31 | logits = torch.sigmoid(logits).reshape(logits.shape[0], logits.shape[1]) 32 | 33 | for i in range(logits.shape[0]): 34 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True) 35 | tmp = torch.mean(tmp).view(1) 36 | instance_logits = torch.cat([instance_logits, tmp], dim=0) 37 | 38 | clsloss = F.binary_cross_entropy(instance_logits, labels) 39 | return clsloss 40 | 41 | def train(model, normal_loader, anomaly_loader, testloader, args, label_map, device): 42 | model.to(device) 43 | gt = np.load(args.gt_path) 44 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True) 45 | gtlabels = np.load(args.gt_label_path, allow_pickle=True) 46 | 47 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr) 48 | scheduler = MultiStepLR(optimizer, args.scheduler_milestones, args.scheduler_rate) 49 | prompt_text = get_prompt_text(label_map) 50 | ap_best = 0 51 | epoch = 0 52 | 53 | if args.use_checkpoint == True: 54 | checkpoint = torch.load(args.checkpoint_path) 55 | model.load_state_dict(checkpoint['model_state_dict']) 56 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 57 | epoch = checkpoint['epoch'] 58 | ap_best = checkpoint['ap'] 59 | print("checkpoint info:") 60 | print("epoch:", epoch+1, " ap:", ap_best) 61 | 62 | for e in range(args.max_epoch): 63 | model.train() 64 | loss_total1 = 0 65 | loss_total2 = 0 66 | normal_iter = iter(normal_loader) 67 | anomaly_iter = iter(anomaly_loader) 68 | for i in range(min(len(normal_loader), len(anomaly_loader))): 69 | step = 0 70 | normal_features, normal_label, normal_lengths = next(normal_iter) 71 | anomaly_features, anomaly_label, anomaly_lengths = next(anomaly_iter) 72 | 73 | visual_features = torch.cat([normal_features, anomaly_features], dim=0).to(device) 74 | text_labels = list(normal_label) + list(anomaly_label) 75 | feat_lengths = torch.cat([normal_lengths, anomaly_lengths], dim=0).to(device) 76 | text_labels = get_batch_label(text_labels, prompt_text, label_map).to(device) 77 | 78 | text_features, logits1, logits2 = model(visual_features, None, prompt_text, feat_lengths) 79 | #loss1 80 | loss1 = CLAS2(logits1, text_labels, feat_lengths, device) 81 | loss_total1 += loss1.item() 82 | #loss2 83 | loss2 = CLASM(logits2, text_labels, feat_lengths, device) 84 | loss_total2 += loss2.item() 85 | #loss3 86 | loss3 = torch.zeros(1).to(device) 87 | text_feature_normal = text_features[0] / text_features[0].norm(dim=-1, keepdim=True) 88 | for j in range(1, text_features.shape[0]): 89 | text_feature_abr = text_features[j] / text_features[j].norm(dim=-1, keepdim=True) 90 | loss3 += torch.abs(text_feature_normal @ text_feature_abr) 91 | loss3 = loss3 / 13 * 1e-1 92 | 93 | loss = loss1 + loss2 + loss3 94 | 95 | optimizer.zero_grad() 96 | loss.backward() 97 | optimizer.step() 98 | step += i * normal_loader.batch_size * 2 99 | if step % 1280 == 0 and step != 0: 100 | print('epoch: ', e+1, '| step: ', step, '| loss1: ', loss_total1 / (i+1), '| loss2: ', loss_total2 / (i+1), '| loss3: 
', loss3.item()) 101 | AUC, AP = test(model, testloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device) 102 | AP = AUC 103 | 104 | if AP > ap_best: 105 | ap_best = AP 106 | checkpoint = { 107 | 'epoch': e, 108 | 'model_state_dict': model.state_dict(), 109 | 'optimizer_state_dict': optimizer.state_dict(), 110 | 'ap': ap_best} 111 | torch.save(checkpoint, args.checkpoint_path) 112 | 113 | scheduler.step() 114 | 115 | torch.save(model.state_dict(), 'model/model_cur.pth') 116 | checkpoint = torch.load(args.checkpoint_path) 117 | model.load_state_dict(checkpoint['model_state_dict']) 118 | 119 | checkpoint = torch.load(args.checkpoint_path) 120 | torch.save(checkpoint['model_state_dict'], args.model_path) 121 | 122 | def setup_seed(seed): 123 | torch.manual_seed(seed) 124 | torch.cuda.manual_seed_all(seed) 125 | np.random.seed(seed) 126 | random.seed(seed) 127 | #torch.backends.cudnn.deterministic = True 128 | 129 | if __name__ == '__main__': 130 | device = "cuda" if torch.cuda.is_available() else "cpu" 131 | args = ucf_option.parser.parse_args() 132 | setup_seed(args.seed) 133 | 134 | label_map = dict({'Normal': 'normal', 'Abuse': 'abuse', 'Arrest': 'arrest', 'Arson': 'arson', 'Assault': 'assault', 'Burglary': 'burglary', 'Explosion': 'explosion', 'Fighting': 'fighting', 'RoadAccidents': 'roadAccidents', 'Robbery': 'robbery', 'Shooting': 'shooting', 'Shoplifting': 'shoplifting', 'Stealing': 'stealing', 'Vandalism': 'vandalism'}) 135 | 136 | normal_dataset = UCFDataset(args.visual_length, args.train_list, False, label_map, True) 137 | normal_loader = DataLoader(normal_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 138 | anomaly_dataset = UCFDataset(args.visual_length, args.train_list, False, label_map, False) 139 | anomaly_loader = DataLoader(anomaly_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True) 140 | 141 | test_dataset = UCFDataset(args.visual_length, args.test_list, True, label_map) 142 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False) 143 | 144 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device) 145 | 146 | train(model, normal_loader, anomaly_loader, test_loader, args, label_map, device) -------------------------------------------------------------------------------- /src/utils/layers.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from torch import FloatTensor 3 | from torch.nn.parameter import Parameter 4 | from torch.nn.modules.module import Module 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from scipy.spatial.distance import pdist, squareform 10 | 11 | class GraphAttentionLayer(nn.Module): 12 | """ 13 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903 14 | """ 15 | 16 | def __init__(self, in_features, out_features, dropout, alpha, concat=True): 17 | super(GraphAttentionLayer, self).__init__() 18 | self.dropout = dropout 19 | self.in_features = in_features 20 | self.out_features = out_features 21 | self.alpha = alpha 22 | self.concat = concat 23 | 24 | self.W = nn.Parameter(nn.init.xavier_uniform(torch.Tensor(in_features, out_features).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True) 25 | self.a = 
nn.Parameter(nn.init.xavier_uniform(torch.Tensor(2*out_features, 1).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True) 26 | 27 | self.leakyrelu = nn.LeakyReLU(self.alpha) 28 | 29 | def forward(self, input, adj): 30 | h = torch.mm(input, self.W) 31 | N = h.size()[0] 32 | 33 | a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features) 34 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2)) 35 | 36 | zero_vec = -9e15*torch.ones_like(e) 37 | attention = torch.where(adj > 0, e, zero_vec) 38 | attention = F.softmax(attention, dim=1) 39 | attention = F.dropout(attention, self.dropout, training=self.training) 40 | h_prime = torch.matmul(attention, h) 41 | 42 | if self.concat: 43 | return F.elu(h_prime) 44 | else: 45 | return h_prime 46 | 47 | def __repr__(self): 48 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')' 49 | 50 | class linear(nn.Module): 51 | def __init__(self, in_features, out_features): 52 | super(linear, self).__init__() 53 | self.weight = Parameter(FloatTensor(in_features, out_features)) 54 | self.register_parameter('bias', None) 55 | stdv = 1. / sqrt(self.weight.size(1)) 56 | self.weight.data.uniform_(-stdv, stdv) 57 | def forward(self, x): 58 | x = x.matmul(self.weight) 59 | return x 60 | 61 | class GraphConvolution(Module): 62 | """ 63 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907 64 | """ 65 | 66 | def __init__(self, in_features, out_features, bias=False, residual=True): 67 | super(GraphConvolution, self).__init__() 68 | self.in_features = in_features 69 | self.out_features = out_features 70 | self.weight = Parameter(FloatTensor(in_features, out_features)) 71 | if bias: 72 | self.bias = Parameter(FloatTensor(out_features)) 73 | else: 74 | self.register_parameter('bias', None) 75 | self.reset_parameters() 76 | if not residual: 77 | self.residual = lambda x: 0 78 | elif (in_features == out_features): 79 | self.residual = lambda x: x 80 | else: 81 | # self.residual = linear(in_features, out_features) 82 | self.residual = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=5, padding=2) 83 | def reset_parameters(self): 84 | # stdv = 1. 
/ sqrt(self.weight.size(1)) 85 | nn.init.xavier_uniform_(self.weight) 86 | if self.bias is not None: 87 | self.bias.data.fill_(0.1) 88 | 89 | def forward(self, input, adj): 90 | # To support batch operations 91 | support = input.matmul(self.weight) 92 | output = adj.matmul(support) 93 | 94 | if self.bias is not None: 95 | output = output + self.bias 96 | if self.in_features != self.out_features and self.residual: 97 | input = input.permute(0,2,1) 98 | res = self.residual(input) 99 | res = res.permute(0,2,1) 100 | output = output + res 101 | else: 102 | output = output + self.residual(input) 103 | 104 | return output 105 | 106 | def __repr__(self): 107 | return self.__class__.__name__ + ' (' \ 108 | + str(self.in_features) + ' -> ' \ 109 | + str(self.out_features) + ')' 110 | 111 | ###################################################### 112 | 113 | class SimilarityAdj(Module): 114 | 115 | def __init__(self, in_features, out_features): 116 | super(SimilarityAdj, self).__init__() 117 | self.in_features = in_features 118 | self.out_features = out_features 119 | 120 | self.weight0 = Parameter(FloatTensor(in_features, out_features)) 121 | self.weight1 = Parameter(FloatTensor(in_features, out_features)) 122 | self.register_parameter('bias', None) 123 | self.reset_parameters() 124 | 125 | def reset_parameters(self): 126 | # stdv = 1. / sqrt(self.weight0.size(1)) 127 | nn.init.xavier_uniform_(self.weight0) 128 | nn.init.xavier_uniform_(self.weight1) 129 | 130 | def forward(self, input, seq_len): 131 | # To support batch operations 132 | theta = torch.matmul(input, self.weight0) 133 | phi = torch.matmul(input, self.weight0) 134 | phi2 = phi.permute(0, 2, 1) 135 | sim_graph = torch.matmul(theta, phi2) 136 | 137 | theta_norm = torch.norm(theta, p=2, dim=2, keepdim=True) # B*T*1 138 | phi_norm = torch.norm(phi, p=2, dim=2, keepdim=True) # B*T*1 139 | x_norm_x = theta_norm.matmul(phi_norm.permute(0, 2, 1)) 140 | sim_graph = sim_graph / (x_norm_x + 1e-20) 141 | 142 | output = torch.zeros_like(sim_graph) 143 | if seq_len is None: 144 | for i in range(sim_graph.shape[0]): 145 | tmp = sim_graph[i] 146 | adj2 = tmp 147 | adj2 = F.threshold(adj2, 0.7, 0) 148 | adj2 = F.softmax(adj2, dim=1) 149 | output[i] = adj2 150 | else: 151 | for i in range(len(seq_len)): 152 | tmp = sim_graph[i, :seq_len[i], :seq_len[i]] 153 | adj2 = tmp 154 | adj2 = F.threshold(adj2, 0.7, 0) 155 | adj2 = F.softmax(adj2, dim=1) 156 | output[i, :seq_len[i], :seq_len[i]] = adj2 157 | 158 | return output 159 | 160 | def __repr__(self): 161 | return self.__class__.__name__ + ' (' \ 162 | + str(self.in_features) + ' -> ' \ 163 | + str(self.out_features) + ')' 164 | 165 | class DistanceAdj(Module): 166 | 167 | def __init__(self): 168 | super(DistanceAdj, self).__init__() 169 | self.sigma = Parameter(FloatTensor(1)) 170 | self.sigma.data.fill_(0.1) 171 | 172 | def forward(self, batch_size, max_seqlen): 173 | # To support batch operations 174 | self.arith = np.arange(max_seqlen).reshape(-1, 1) 175 | dist = pdist(self.arith, metric='cityblock').astype(np.float32) 176 | self.dist = torch.from_numpy(squareform(dist)).to('cuda') 177 | self.dist = torch.exp(-self.dist / torch.exp(torch.tensor(1.))) 178 | self.dist = torch.unsqueeze(self.dist, 0).repeat(batch_size, 1, 1).to('cuda') 179 | return self.dist 180 | 181 | if __name__ == '__main__': 182 | d = DistanceAdj() 183 | dist = d(1, 256).squeeze(0) 184 | print(dist.softmax(dim=-1)) -------------------------------------------------------------------------------- /src/model.py: 
-------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn 7 | from clip import clip 8 | from utils.layers import GraphConvolution, DistanceAdj 9 | 10 | class LayerNorm(nn.LayerNorm): 11 | 12 | def forward(self, x: torch.Tensor): 13 | orig_type = x.dtype 14 | ret = super().forward(x.type(torch.float32)) 15 | return ret.type(orig_type) 16 | 17 | 18 | class QuickGELU(nn.Module): 19 | def forward(self, x: torch.Tensor): 20 | return x * torch.sigmoid(1.702 * x) 21 | 22 | 23 | class ResidualAttentionBlock(nn.Module): 24 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 25 | super().__init__() 26 | 27 | self.attn = nn.MultiheadAttention(d_model, n_head) 28 | self.ln_1 = LayerNorm(d_model) 29 | self.mlp = nn.Sequential(OrderedDict([ 30 | ("c_fc", nn.Linear(d_model, d_model * 4)), 31 | ("gelu", QuickGELU()), 32 | ("c_proj", nn.Linear(d_model * 4, d_model)) 33 | ])) 34 | self.ln_2 = LayerNorm(d_model) 35 | self.attn_mask = attn_mask 36 | 37 | def attention(self, x: torch.Tensor, padding_mask: torch.Tensor): 38 | padding_mask = padding_mask.to(dtype=bool, device=x.device) if padding_mask is not None else None 39 | self.attn_mask = self.attn_mask.to(device=x.device) if self.attn_mask is not None else None 40 | return self.attn(x, x, x, need_weights=False, key_padding_mask=padding_mask, attn_mask=self.attn_mask)[0] 41 | 42 | def forward(self, x): 43 | x, padding_mask = x 44 | x = x + self.attention(self.ln_1(x), padding_mask) 45 | x = x + self.mlp(self.ln_2(x)) 46 | return (x, padding_mask) 47 | 48 | 49 | class Transformer(nn.Module): 50 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 51 | super().__init__() 52 | self.width = width 53 | self.layers = layers 54 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 55 | 56 | def forward(self, x: torch.Tensor): 57 | return self.resblocks(x) 58 | 59 | 60 | class CLIPVAD(nn.Module): 61 | def __init__(self, 62 | num_class: int, 63 | embed_dim: int, 64 | visual_length: int, 65 | visual_width: int, 66 | visual_head: int, 67 | visual_layers: int, 68 | attn_window: int, 69 | prompt_prefix: int, 70 | prompt_postfix: int, 71 | device): 72 | super().__init__() 73 | 74 | self.num_class = num_class 75 | self.visual_length = visual_length 76 | self.visual_width = visual_width 77 | self.embed_dim = embed_dim 78 | self.attn_window = attn_window 79 | self.prompt_prefix = prompt_prefix 80 | self.prompt_postfix = prompt_postfix 81 | self.device = device 82 | 83 | self.temporal = Transformer( 84 | width=visual_width, 85 | layers=visual_layers, 86 | heads=visual_head, 87 | attn_mask=self.build_attention_mask(self.attn_window) 88 | ) 89 | 90 | width = int(visual_width / 2) 91 | self.gc1 = GraphConvolution(visual_width, width, residual=True) 92 | self.gc2 = GraphConvolution(width, width, residual=True) 93 | self.gc3 = GraphConvolution(visual_width, width, residual=True) 94 | self.gc4 = GraphConvolution(width, width, residual=True) 95 | self.disAdj = DistanceAdj() 96 | self.linear = nn.Linear(visual_width, visual_width) 97 | self.gelu = QuickGELU() 98 | 99 | self.mlp1 = nn.Sequential(OrderedDict([ 100 | ("c_fc", nn.Linear(visual_width, visual_width * 4)), 101 | ("gelu", QuickGELU()), 102 | ("c_proj", nn.Linear(visual_width * 4, visual_width)) 103 | ])) 104 | self.mlp2 = 
nn.Sequential(OrderedDict([ 105 | ("c_fc", nn.Linear(visual_width, visual_width * 4)), 106 | ("gelu", QuickGELU()), 107 | ("c_proj", nn.Linear(visual_width * 4, visual_width)) 108 | ])) 109 | self.classifier = nn.Linear(visual_width, 1) 110 | 111 | self.clipmodel, _ = clip.load("ViT-B/16", device) 112 | for clip_param in self.clipmodel.parameters(): 113 | clip_param.requires_grad = False 114 | 115 | self.frame_position_embeddings = nn.Embedding(visual_length, visual_width) 116 | self.text_prompt_embeddings = nn.Embedding(77, self.embed_dim) 117 | 118 | self.initialize_parameters() 119 | 120 | def initialize_parameters(self): 121 | nn.init.normal_(self.text_prompt_embeddings.weight, std=0.01) 122 | nn.init.normal_(self.frame_position_embeddings.weight, std=0.01) 123 | 124 | def build_attention_mask(self, attn_window): 125 | # lazily create causal attention mask, with full attention between the vision tokens 126 | # pytorch uses additive attention mask; fill with -inf 127 | mask = torch.empty(self.visual_length, self.visual_length) 128 | mask.fill_(float('-inf')) 129 | for i in range(int(self.visual_length / attn_window)): 130 | if (i + 1) * attn_window < self.visual_length: 131 | mask[i * attn_window: (i + 1) * attn_window, i * attn_window: (i + 1) * attn_window] = 0 132 | else: 133 | mask[i * attn_window: self.visual_length, i * attn_window: self.visual_length] = 0 134 | 135 | return mask 136 | 137 | def adj4(self, x, seq_len): 138 | soft = nn.Softmax(1) 139 | x2 = x.matmul(x.permute(0, 2, 1)) # B*T*T 140 | x_norm = torch.norm(x, p=2, dim=2, keepdim=True) # B*T*1 141 | x_norm_x = x_norm.matmul(x_norm.permute(0, 2, 1)) 142 | x2 = x2/(x_norm_x+1e-20) 143 | output = torch.zeros_like(x2) 144 | if seq_len is None: 145 | for i in range(x.shape[0]): 146 | tmp = x2[i] 147 | adj2 = tmp 148 | adj2 = F.threshold(adj2, 0.7, 0) 149 | adj2 = soft(adj2) 150 | output[i] = adj2 151 | else: 152 | for i in range(len(seq_len)): 153 | tmp = x2[i, :seq_len[i], :seq_len[i]] 154 | adj2 = tmp 155 | adj2 = F.threshold(adj2, 0.7, 0) 156 | adj2 = soft(adj2) 157 | output[i, :seq_len[i], :seq_len[i]] = adj2 158 | 159 | return output 160 | 161 | def encode_video(self, images, padding_mask, lengths): 162 | images = images.to(torch.float) 163 | position_ids = torch.arange(self.visual_length, device=self.device) 164 | position_ids = position_ids.unsqueeze(0).expand(images.shape[0], -1) 165 | frame_position_embeddings = self.frame_position_embeddings(position_ids) 166 | frame_position_embeddings = frame_position_embeddings.permute(1, 0, 2) 167 | images = images.permute(1, 0, 2) + frame_position_embeddings 168 | 169 | x, _ = self.temporal((images, None)) 170 | x = x.permute(1, 0, 2) 171 | 172 | adj = self.adj4(x, lengths) 173 | disadj = self.disAdj(x.shape[0], x.shape[1]) 174 | x1_h = self.gelu(self.gc1(x, adj)) 175 | x2_h = self.gelu(self.gc3(x, disadj)) 176 | 177 | x1 = self.gelu(self.gc2(x1_h, adj)) 178 | x2 = self.gelu(self.gc4(x2_h, disadj)) 179 | 180 | x = torch.cat((x1, x2), 2) 181 | x = self.linear(x) 182 | 183 | return x 184 | 185 | def encode_textprompt(self, text): 186 | word_tokens = clip.tokenize(text).to(self.device) 187 | word_embedding = self.clipmodel.encode_token(word_tokens) 188 | text_embeddings = self.text_prompt_embeddings(torch.arange(77).to(self.device)).unsqueeze(0).repeat([len(text), 1, 1]) 189 | text_tokens = torch.zeros(len(text), 77).to(self.device) 190 | 191 | for i in range(len(text)): 192 | ind = torch.argmax(word_tokens[i], -1) 193 | text_embeddings[i, 0] = word_embedding[i, 0] 194 | 
text_embeddings[i, self.prompt_prefix + 1: self.prompt_prefix + ind] = word_embedding[i, 1: ind] 195 | text_embeddings[i, self.prompt_prefix + ind + self.prompt_postfix] = word_embedding[i, ind] 196 | text_tokens[i, self.prompt_prefix + ind + self.prompt_postfix] = word_tokens[i, ind] 197 | 198 | text_features = self.clipmodel.encode_text(text_embeddings, text_tokens) 199 | 200 | return text_features 201 | 202 | def forward(self, visual, padding_mask, text, lengths): 203 | visual_features = self.encode_video(visual, padding_mask, lengths) 204 | logits1 = self.classifier(visual_features + self.mlp2(visual_features)) 205 | 206 | text_features_ori = self.encode_textprompt(text) 207 | 208 | text_features = text_features_ori 209 | logits_attn = logits1.permute(0, 2, 1) 210 | visual_attn = logits_attn @ visual_features 211 | visual_attn = visual_attn / visual_attn.norm(dim=-1, keepdim=True) 212 | visual_attn = visual_attn.expand(visual_attn.shape[0], text_features_ori.shape[0], visual_attn.shape[2]) 213 | text_features = text_features_ori.unsqueeze(0) 214 | text_features = text_features.expand(visual_attn.shape[0], text_features.shape[1], text_features.shape[2]) 215 | text_features = text_features + visual_attn 216 | text_features = text_features + self.mlp1(text_features) 217 | 218 | visual_features_norm = visual_features / visual_features.norm(dim=-1, keepdim=True) 219 | text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True) 220 | text_features_norm = text_features_norm.permute(0, 2, 1) 221 | logits2 = visual_features_norm @ text_features_norm.type(visual_features_norm.dtype) / 0.07 222 | 223 | return text_features_ori, logits1, logits2 224 | -------------------------------------------------------------------------------- /src/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": 
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 
107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict()).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 
205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
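The license is followed by list/Anomaly_Test.txt, the UCF-Crime test split: one relative video path per line, with the leading directory doubling as the class label (Testing_Normal_Videos_Anomaly holds the normal test videos). A minimal, illustrative reader for this format (not part of the repo):

with open('list/Anomaly_Test.txt') as f:
    for line in f:
        path = line.strip()
        if not path:
            continue
        label = path.split('/')[0]
        if label == 'Testing_Normal_Videos_Anomaly':
            label = 'Normal'
        print(path, label)  # e.g. Abuse/Abuse028_x264.mp4 Abuse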
202 | -------------------------------------------------------------------------------- /list/Anomaly_Test.txt: -------------------------------------------------------------------------------- 1 | Abuse/Abuse028_x264.mp4 2 | Abuse/Abuse030_x264.mp4 3 | Arrest/Arrest001_x264.mp4 4 | Arrest/Arrest007_x264.mp4 5 | Arrest/Arrest024_x264.mp4 6 | Arrest/Arrest030_x264.mp4 7 | Arrest/Arrest039_x264.mp4 8 | Arson/Arson007_x264.mp4 9 | Arson/Arson009_x264.mp4 10 | Arson/Arson010_x264.mp4 11 | Arson/Arson011_x264.mp4 12 | Arson/Arson016_x264.mp4 13 | Arson/Arson018_x264.mp4 14 | Arson/Arson022_x264.mp4 15 | Arson/Arson035_x264.mp4 16 | Arson/Arson041_x264.mp4 17 | Assault/Assault006_x264.mp4 18 | Assault/Assault010_x264.mp4 19 | Assault/Assault011_x264.mp4 20 | Burglary/Burglary005_x264.mp4 21 | Burglary/Burglary017_x264.mp4 22 | Burglary/Burglary018_x264.mp4 23 | Burglary/Burglary021_x264.mp4 24 | Burglary/Burglary024_x264.mp4 25 | Burglary/Burglary032_x264.mp4 26 | Burglary/Burglary033_x264.mp4 27 | Burglary/Burglary035_x264.mp4 28 | Burglary/Burglary037_x264.mp4 29 | Burglary/Burglary061_x264.mp4 30 | Burglary/Burglary076_x264.mp4 31 | Burglary/Burglary079_x264.mp4 32 | Burglary/Burglary092_x264.mp4 33 | Explosion/Explosion002_x264.mp4 34 | Explosion/Explosion004_x264.mp4 35 | Explosion/Explosion007_x264.mp4 36 | Explosion/Explosion008_x264.mp4 37 | Explosion/Explosion010_x264.mp4 38 | Explosion/Explosion011_x264.mp4 39 | Explosion/Explosion013_x264.mp4 40 | Explosion/Explosion016_x264.mp4 41 | Explosion/Explosion017_x264.mp4 42 | Explosion/Explosion020_x264.mp4 43 | Explosion/Explosion021_x264.mp4 44 | Explosion/Explosion022_x264.mp4 45 | Explosion/Explosion025_x264.mp4 46 | Explosion/Explosion027_x264.mp4 47 | Explosion/Explosion028_x264.mp4 48 | Explosion/Explosion029_x264.mp4 49 | Explosion/Explosion033_x264.mp4 50 | Explosion/Explosion035_x264.mp4 51 | Explosion/Explosion036_x264.mp4 52 | Explosion/Explosion039_x264.mp4 53 | Explosion/Explosion043_x264.mp4 54 | Fighting/Fighting003_x264.mp4 55 | Fighting/Fighting018_x264.mp4 56 | Fighting/Fighting033_x264.mp4 57 | Fighting/Fighting042_x264.mp4 58 | Fighting/Fighting047_x264.mp4 59 | RoadAccidents/RoadAccidents001_x264.mp4 60 | RoadAccidents/RoadAccidents002_x264.mp4 61 | RoadAccidents/RoadAccidents004_x264.mp4 62 | RoadAccidents/RoadAccidents009_x264.mp4 63 | RoadAccidents/RoadAccidents010_x264.mp4 64 | RoadAccidents/RoadAccidents011_x264.mp4 65 | RoadAccidents/RoadAccidents012_x264.mp4 66 | RoadAccidents/RoadAccidents016_x264.mp4 67 | RoadAccidents/RoadAccidents017_x264.mp4 68 | RoadAccidents/RoadAccidents019_x264.mp4 69 | RoadAccidents/RoadAccidents020_x264.mp4 70 | RoadAccidents/RoadAccidents021_x264.mp4 71 | RoadAccidents/RoadAccidents022_x264.mp4 72 | RoadAccidents/RoadAccidents121_x264.mp4 73 | RoadAccidents/RoadAccidents122_x264.mp4 74 | RoadAccidents/RoadAccidents123_x264.mp4 75 | RoadAccidents/RoadAccidents124_x264.mp4 76 | RoadAccidents/RoadAccidents125_x264.mp4 77 | RoadAccidents/RoadAccidents127_x264.mp4 78 | RoadAccidents/RoadAccidents128_x264.mp4 79 | RoadAccidents/RoadAccidents131_x264.mp4 80 | RoadAccidents/RoadAccidents132_x264.mp4 81 | RoadAccidents/RoadAccidents133_x264.mp4 82 | Robbery/Robbery048_x264.mp4 83 | Robbery/Robbery050_x264.mp4 84 | Robbery/Robbery102_x264.mp4 85 | Robbery/Robbery106_x264.mp4 86 | Robbery/Robbery137_x264.mp4 87 | Shooting/Shooting002_x264.mp4 88 | Shooting/Shooting004_x264.mp4 89 | Shooting/Shooting007_x264.mp4 90 | Shooting/Shooting008_x264.mp4 91 | Shooting/Shooting010_x264.mp4 92 | 
Shooting/Shooting011_x264.mp4 93 | Shooting/Shooting013_x264.mp4 94 | Shooting/Shooting015_x264.mp4 95 | Shooting/Shooting018_x264.mp4 96 | Shooting/Shooting019_x264.mp4 97 | Shooting/Shooting021_x264.mp4 98 | Shooting/Shooting022_x264.mp4 99 | Shooting/Shooting024_x264.mp4 100 | Shooting/Shooting026_x264.mp4 101 | Shooting/Shooting028_x264.mp4 102 | Shooting/Shooting032_x264.mp4 103 | Shooting/Shooting033_x264.mp4 104 | Shooting/Shooting034_x264.mp4 105 | Shooting/Shooting037_x264.mp4 106 | Shooting/Shooting043_x264.mp4 107 | Shooting/Shooting046_x264.mp4 108 | Shooting/Shooting047_x264.mp4 109 | Shooting/Shooting048_x264.mp4 110 | Shoplifting/Shoplifting001_x264.mp4 111 | Shoplifting/Shoplifting004_x264.mp4 112 | Shoplifting/Shoplifting005_x264.mp4 113 | Shoplifting/Shoplifting007_x264.mp4 114 | Shoplifting/Shoplifting010_x264.mp4 115 | Shoplifting/Shoplifting015_x264.mp4 116 | Shoplifting/Shoplifting016_x264.mp4 117 | Shoplifting/Shoplifting017_x264.mp4 118 | Shoplifting/Shoplifting020_x264.mp4 119 | Shoplifting/Shoplifting021_x264.mp4 120 | Shoplifting/Shoplifting022_x264.mp4 121 | Shoplifting/Shoplifting027_x264.mp4 122 | Shoplifting/Shoplifting028_x264.mp4 123 | Shoplifting/Shoplifting029_x264.mp4 124 | Shoplifting/Shoplifting031_x264.mp4 125 | Shoplifting/Shoplifting033_x264.mp4 126 | Shoplifting/Shoplifting034_x264.mp4 127 | Shoplifting/Shoplifting037_x264.mp4 128 | Shoplifting/Shoplifting039_x264.mp4 129 | Shoplifting/Shoplifting044_x264.mp4 130 | Shoplifting/Shoplifting049_x264.mp4 131 | Stealing/Stealing019_x264.mp4 132 | Stealing/Stealing036_x264.mp4 133 | Stealing/Stealing058_x264.mp4 134 | Stealing/Stealing062_x264.mp4 135 | Stealing/Stealing079_x264.mp4 136 | Testing_Normal_Videos_Anomaly/Normal_Videos_003_x264.mp4 137 | Testing_Normal_Videos_Anomaly/Normal_Videos_006_x264.mp4 138 | Testing_Normal_Videos_Anomaly/Normal_Videos_010_x264.mp4 139 | Testing_Normal_Videos_Anomaly/Normal_Videos_014_x264.mp4 140 | Testing_Normal_Videos_Anomaly/Normal_Videos_015_x264.mp4 141 | Testing_Normal_Videos_Anomaly/Normal_Videos_018_x264.mp4 142 | Testing_Normal_Videos_Anomaly/Normal_Videos_019_x264.mp4 143 | Testing_Normal_Videos_Anomaly/Normal_Videos_024_x264.mp4 144 | Testing_Normal_Videos_Anomaly/Normal_Videos_025_x264.mp4 145 | Testing_Normal_Videos_Anomaly/Normal_Videos_027_x264.mp4 146 | Testing_Normal_Videos_Anomaly/Normal_Videos_033_x264.mp4 147 | Testing_Normal_Videos_Anomaly/Normal_Videos_034_x264.mp4 148 | Testing_Normal_Videos_Anomaly/Normal_Videos_041_x264.mp4 149 | Testing_Normal_Videos_Anomaly/Normal_Videos_042_x264.mp4 150 | Testing_Normal_Videos_Anomaly/Normal_Videos_048_x264.mp4 151 | Testing_Normal_Videos_Anomaly/Normal_Videos_050_x264.mp4 152 | Testing_Normal_Videos_Anomaly/Normal_Videos_051_x264.mp4 153 | Testing_Normal_Videos_Anomaly/Normal_Videos_056_x264.mp4 154 | Testing_Normal_Videos_Anomaly/Normal_Videos_059_x264.mp4 155 | Testing_Normal_Videos_Anomaly/Normal_Videos_063_x264.mp4 156 | Testing_Normal_Videos_Anomaly/Normal_Videos_067_x264.mp4 157 | Testing_Normal_Videos_Anomaly/Normal_Videos_070_x264.mp4 158 | Testing_Normal_Videos_Anomaly/Normal_Videos_100_x264.mp4 159 | Testing_Normal_Videos_Anomaly/Normal_Videos_129_x264.mp4 160 | Testing_Normal_Videos_Anomaly/Normal_Videos_150_x264.mp4 161 | Testing_Normal_Videos_Anomaly/Normal_Videos_168_x264.mp4 162 | Testing_Normal_Videos_Anomaly/Normal_Videos_175_x264.mp4 163 | Testing_Normal_Videos_Anomaly/Normal_Videos_182_x264.mp4 164 | Testing_Normal_Videos_Anomaly/Normal_Videos_189_x264.mp4 165 | 
Testing_Normal_Videos_Anomaly/Normal_Videos_196_x264.mp4 166 | Testing_Normal_Videos_Anomaly/Normal_Videos_203_x264.mp4 167 | Testing_Normal_Videos_Anomaly/Normal_Videos_210_x264.mp4 168 | Testing_Normal_Videos_Anomaly/Normal_Videos_217_x264.mp4 169 | Testing_Normal_Videos_Anomaly/Normal_Videos_224_x264.mp4 170 | Testing_Normal_Videos_Anomaly/Normal_Videos_246_x264.mp4 171 | Testing_Normal_Videos_Anomaly/Normal_Videos_247_x264.mp4 172 | Testing_Normal_Videos_Anomaly/Normal_Videos_248_x264.mp4 173 | Testing_Normal_Videos_Anomaly/Normal_Videos_251_x264.mp4 174 | Testing_Normal_Videos_Anomaly/Normal_Videos_289_x264.mp4 175 | Testing_Normal_Videos_Anomaly/Normal_Videos_310_x264.mp4 176 | Testing_Normal_Videos_Anomaly/Normal_Videos_312_x264.mp4 177 | Testing_Normal_Videos_Anomaly/Normal_Videos_317_x264.mp4 178 | Testing_Normal_Videos_Anomaly/Normal_Videos_345_x264.mp4 179 | Testing_Normal_Videos_Anomaly/Normal_Videos_352_x264.mp4 180 | Testing_Normal_Videos_Anomaly/Normal_Videos_360_x264.mp4 181 | Testing_Normal_Videos_Anomaly/Normal_Videos_365_x264.mp4 182 | Testing_Normal_Videos_Anomaly/Normal_Videos_401_x264.mp4 183 | Testing_Normal_Videos_Anomaly/Normal_Videos_417_x264.mp4 184 | Testing_Normal_Videos_Anomaly/Normal_Videos_439_x264.mp4 185 | Testing_Normal_Videos_Anomaly/Normal_Videos_452_x264.mp4 186 | Testing_Normal_Videos_Anomaly/Normal_Videos_453_x264.mp4 187 | Testing_Normal_Videos_Anomaly/Normal_Videos_478_x264.mp4 188 | Testing_Normal_Videos_Anomaly/Normal_Videos_576_x264.mp4 189 | Testing_Normal_Videos_Anomaly/Normal_Videos_597_x264.mp4 190 | Testing_Normal_Videos_Anomaly/Normal_Videos_603_x264.mp4 191 | Testing_Normal_Videos_Anomaly/Normal_Videos_606_x264.mp4 192 | Testing_Normal_Videos_Anomaly/Normal_Videos_621_x264.mp4 193 | Testing_Normal_Videos_Anomaly/Normal_Videos_634_x264.mp4 194 | Testing_Normal_Videos_Anomaly/Normal_Videos_641_x264.mp4 195 | Testing_Normal_Videos_Anomaly/Normal_Videos_656_x264.mp4 196 | Testing_Normal_Videos_Anomaly/Normal_Videos_686_x264.mp4 197 | Testing_Normal_Videos_Anomaly/Normal_Videos_696_x264.mp4 198 | Testing_Normal_Videos_Anomaly/Normal_Videos_702_x264.mp4 199 | Testing_Normal_Videos_Anomaly/Normal_Videos_704_x264.mp4 200 | Testing_Normal_Videos_Anomaly/Normal_Videos_710_x264.mp4 201 | Testing_Normal_Videos_Anomaly/Normal_Videos_717_x264.mp4 202 | Testing_Normal_Videos_Anomaly/Normal_Videos_722_x264.mp4 203 | Testing_Normal_Videos_Anomaly/Normal_Videos_725_x264.mp4 204 | Testing_Normal_Videos_Anomaly/Normal_Videos_745_x264.mp4 205 | Testing_Normal_Videos_Anomaly/Normal_Videos_758_x264.mp4 206 | Testing_Normal_Videos_Anomaly/Normal_Videos_778_x264.mp4 207 | Testing_Normal_Videos_Anomaly/Normal_Videos_780_x264.mp4 208 | Testing_Normal_Videos_Anomaly/Normal_Videos_781_x264.mp4 209 | Testing_Normal_Videos_Anomaly/Normal_Videos_782_x264.mp4 210 | Testing_Normal_Videos_Anomaly/Normal_Videos_783_x264.mp4 211 | Testing_Normal_Videos_Anomaly/Normal_Videos_798_x264.mp4 212 | Testing_Normal_Videos_Anomaly/Normal_Videos_801_x264.mp4 213 | Testing_Normal_Videos_Anomaly/Normal_Videos_828_x264.mp4 214 | Testing_Normal_Videos_Anomaly/Normal_Videos_831_x264.mp4 215 | Testing_Normal_Videos_Anomaly/Normal_Videos_866_x264.mp4 216 | Testing_Normal_Videos_Anomaly/Normal_Videos_867_x264.mp4 217 | Testing_Normal_Videos_Anomaly/Normal_Videos_868_x264.mp4 218 | Testing_Normal_Videos_Anomaly/Normal_Videos_869_x264.mp4 219 | Testing_Normal_Videos_Anomaly/Normal_Videos_870_x264.mp4 220 | Testing_Normal_Videos_Anomaly/Normal_Videos_871_x264.mp4 221 | 
Testing_Normal_Videos_Anomaly/Normal_Videos_872_x264.mp4 222 | Testing_Normal_Videos_Anomaly/Normal_Videos_873_x264.mp4 223 | Testing_Normal_Videos_Anomaly/Normal_Videos_874_x264.mp4 224 | Testing_Normal_Videos_Anomaly/Normal_Videos_875_x264.mp4 225 | Testing_Normal_Videos_Anomaly/Normal_Videos_876_x264.mp4 226 | Testing_Normal_Videos_Anomaly/Normal_Videos_877_x264.mp4 227 | Testing_Normal_Videos_Anomaly/Normal_Videos_878_x264.mp4 228 | Testing_Normal_Videos_Anomaly/Normal_Videos_879_x264.mp4 229 | Testing_Normal_Videos_Anomaly/Normal_Videos_880_x264.mp4 230 | Testing_Normal_Videos_Anomaly/Normal_Videos_881_x264.mp4 231 | Testing_Normal_Videos_Anomaly/Normal_Videos_882_x264.mp4 232 | Testing_Normal_Videos_Anomaly/Normal_Videos_883_x264.mp4 233 | Testing_Normal_Videos_Anomaly/Normal_Videos_884_x264.mp4 234 | Testing_Normal_Videos_Anomaly/Normal_Videos_885_x264.mp4 235 | Testing_Normal_Videos_Anomaly/Normal_Videos_886_x264.mp4 236 | Testing_Normal_Videos_Anomaly/Normal_Videos_887_x264.mp4 237 | Testing_Normal_Videos_Anomaly/Normal_Videos_888_x264.mp4 238 | Testing_Normal_Videos_Anomaly/Normal_Videos_889_x264.mp4 239 | Testing_Normal_Videos_Anomaly/Normal_Videos_890_x264.mp4 240 | Testing_Normal_Videos_Anomaly/Normal_Videos_891_x264.mp4 241 | Testing_Normal_Videos_Anomaly/Normal_Videos_892_x264.mp4 242 | Testing_Normal_Videos_Anomaly/Normal_Videos_893_x264.mp4 243 | Testing_Normal_Videos_Anomaly/Normal_Videos_894_x264.mp4 244 | Testing_Normal_Videos_Anomaly/Normal_Videos_895_x264.mp4 245 | Testing_Normal_Videos_Anomaly/Normal_Videos_896_x264.mp4 246 | Testing_Normal_Videos_Anomaly/Normal_Videos_897_x264.mp4 247 | Testing_Normal_Videos_Anomaly/Normal_Videos_898_x264.mp4 248 | Testing_Normal_Videos_Anomaly/Normal_Videos_899_x264.mp4 249 | Testing_Normal_Videos_Anomaly/Normal_Videos_900_x264.mp4 250 | Testing_Normal_Videos_Anomaly/Normal_Videos_901_x264.mp4 251 | Testing_Normal_Videos_Anomaly/Normal_Videos_902_x264.mp4 252 | Testing_Normal_Videos_Anomaly/Normal_Videos_903_x264.mp4 253 | Testing_Normal_Videos_Anomaly/Normal_Videos_904_x264.mp4 254 | Testing_Normal_Videos_Anomaly/Normal_Videos_905_x264.mp4 255 | Testing_Normal_Videos_Anomaly/Normal_Videos_906_x264.mp4 256 | Testing_Normal_Videos_Anomaly/Normal_Videos_907_x264.mp4 257 | Testing_Normal_Videos_Anomaly/Normal_Videos_908_x264.mp4 258 | Testing_Normal_Videos_Anomaly/Normal_Videos_909_x264.mp4 259 | Testing_Normal_Videos_Anomaly/Normal_Videos_910_x264.mp4 260 | Testing_Normal_Videos_Anomaly/Normal_Videos_911_x264.mp4 261 | Testing_Normal_Videos_Anomaly/Normal_Videos_912_x264.mp4 262 | Testing_Normal_Videos_Anomaly/Normal_Videos_913_x264.mp4 263 | Testing_Normal_Videos_Anomaly/Normal_Videos_914_x264.mp4 264 | Testing_Normal_Videos_Anomaly/Normal_Videos_915_x264.mp4 265 | Testing_Normal_Videos_Anomaly/Normal_Videos_923_x264.mp4 266 | Testing_Normal_Videos_Anomaly/Normal_Videos_924_x264.mp4 267 | Testing_Normal_Videos_Anomaly/Normal_Videos_925_x264.mp4 268 | Testing_Normal_Videos_Anomaly/Normal_Videos_926_x264.mp4 269 | Testing_Normal_Videos_Anomaly/Normal_Videos_927_x264.mp4 270 | Testing_Normal_Videos_Anomaly/Normal_Videos_928_x264.mp4 271 | Testing_Normal_Videos_Anomaly/Normal_Videos_929_x264.mp4 272 | Testing_Normal_Videos_Anomaly/Normal_Videos_930_x264.mp4 273 | Testing_Normal_Videos_Anomaly/Normal_Videos_931_x264.mp4 274 | Testing_Normal_Videos_Anomaly/Normal_Videos_932_x264.mp4 275 | Testing_Normal_Videos_Anomaly/Normal_Videos_933_x264.mp4 276 | Testing_Normal_Videos_Anomaly/Normal_Videos_934_x264.mp4 277 | 
Testing_Normal_Videos_Anomaly/Normal_Videos_935_x264.mp4 278 | Testing_Normal_Videos_Anomaly/Normal_Videos_936_x264.mp4 279 | Testing_Normal_Videos_Anomaly/Normal_Videos_937_x264.mp4 280 | Testing_Normal_Videos_Anomaly/Normal_Videos_938_x264.mp4 281 | Testing_Normal_Videos_Anomaly/Normal_Videos_939_x264.mp4 282 | Testing_Normal_Videos_Anomaly/Normal_Videos_940_x264.mp4 283 | Testing_Normal_Videos_Anomaly/Normal_Videos_941_x264.mp4 284 | Testing_Normal_Videos_Anomaly/Normal_Videos_943_x264.mp4 285 | Testing_Normal_Videos_Anomaly/Normal_Videos_944_x264.mp4 286 | Vandalism/Vandalism007_x264.mp4 287 | Vandalism/Vandalism015_x264.mp4 288 | Vandalism/Vandalism017_x264.mp4 289 | Vandalism/Vandalism028_x264.mp4 290 | Vandalism/Vandalism036_x264.mp4 291 | -------------------------------------------------------------------------------- /list/Temporal_Anomaly_Annotation.txt: -------------------------------------------------------------------------------- 1 | Abuse028_x264 Abuse 165 240 -1 -1 2 | Abuse030_x264 Abuse 1275 1360 -1 -1 3 | Arrest001_x264 Arrest 1185 1485 -1 -1 4 | Arrest007_x264 Arrest 1530 2160 -1 -1 5 | Arrest024_x264 Arrest 1005 3105 -1 -1 6 | Arrest030_x264 Arrest 5535 7200 -1 -1 7 | Arrest039_x264 Arrest 7215 10335 -1 -1 8 | Arson007_x264 Arson 2250 5700 -1 -1 9 | Arson009_x264 Arson 220 315 -1 -1 10 | Arson010_x264 Arson 885 1230 -1 -1 11 | Arson011_x264 Arson 150 420 680 1267 12 | Arson016_x264 Arson 1000 1796 -1 -1 13 | Arson018_x264 Arson 270 600 -1 -1 14 | Arson022_x264 Arson 3500 4000 -1 -1 15 | Arson035_x264 Arson 600 900 -1 -1 16 | Arson041_x264 Arson 2130 3615 -1 -1 17 | Assault006_x264 Assault 1185 8096 -1 -1 18 | Assault010_x264 Assault 11330 11680 12260 12930 19 | Assault011_x264 Assault 375 960 -1 -1 20 | Burglary005_x264 Burglary 4710 5040 -1 -1 21 | Burglary017_x264 Burglary 150 600 -1 -1 22 | Burglary018_x264 Burglary 720 1050 -1 -1 23 | Burglary021_x264 Burglary 60 200 840 1340 24 | Burglary024_x264 Burglary 60 1230 -1 -1 25 | Burglary032_x264 Burglary 1290 3690 -1 -1 26 | Burglary033_x264 Burglary 60 330 -1 -1 27 | Burglary035_x264 Burglary 1 1740 -1 -1 28 | Burglary037_x264 Burglary 240 390 540 1800 29 | Burglary061_x264 Burglary 4200 5700 -1 -1 30 | Burglary076_x264 Burglary 1590 4300 -1 -1 31 | Burglary079_x264 Burglary 7750 10710 -1 -1 32 | Burglary092_x264 Burglary 240 420 -1 -1 33 | Explosion002_x264 Explosion 1500 2100 -1 -1 34 | Explosion004_x264 Explosion 75 225 -1 -1 35 | Explosion007_x264 Explosion 1590 2280 -1 -1 36 | Explosion008_x264 Explosion 1005 1245 -1 -1 37 | Explosion010_x264 Explosion 285 1080 -1 -1 38 | Explosion011_x264 Explosion 795 945 -1 -1 39 | Explosion013_x264 Explosion 2520 2970 -1 -1 40 | Explosion016_x264 Explosion 180 450 -1 -1 41 | Explosion017_x264 Explosion 990 1440 -1 -1 42 | Explosion020_x264 Explosion 60 270 -1 -1 43 | Explosion021_x264 Explosion 135 270 -1 -1 44 | Explosion022_x264 Explosion 2230 2420 -1 -1 45 | Explosion025_x264 Explosion 260 420 -1 -1 46 | Explosion027_x264 Explosion 105 180 -1 -1 47 | Explosion028_x264 Explosion 280 700 -1 -1 48 | Explosion029_x264 Explosion 1830 2020 -1 -1 49 | Explosion033_x264 Explosion 970 1350 1550 3156 50 | Explosion035_x264 Explosion 250 350 -1 -1 51 | Explosion036_x264 Explosion 1950 2070 -1 -1 52 | Explosion039_x264 Explosion 60 150 -1 -1 53 | Explosion043_x264 Explosion 4460 4600 -1 -1 54 | Fighting003_x264 Fighting 1820 3103 -1 -1 55 | Fighting018_x264 Fighting 80 420 -1 -1 56 | Fighting033_x264 Fighting 570 840 -1 -1 57 | Fighting042_x264 Fighting 290 1200 -1 -1 58 | 
Fighting047_x264 Fighting 200 1830 -1 -1 59 | Normal_Videos_003_x264 Normal -1 -1 -1 -1 60 | Normal_Videos_006_x264 Normal -1 -1 -1 -1 61 | Normal_Videos_010_x264 Normal -1 -1 -1 -1 62 | Normal_Videos_014_x264 Normal -1 -1 -1 -1 63 | Normal_Videos_015_x264 Normal -1 -1 -1 -1 64 | Normal_Videos_018_x264 Normal -1 -1 -1 -1 65 | Normal_Videos_019_x264 Normal -1 -1 -1 -1 66 | Normal_Videos_024_x264 Normal -1 -1 -1 -1 67 | Normal_Videos_025_x264 Normal -1 -1 -1 -1 68 | Normal_Videos_027_x264 Normal -1 -1 -1 -1 69 | Normal_Videos_033_x264 Normal -1 -1 -1 -1 70 | Normal_Videos_034_x264 Normal -1 -1 -1 -1 71 | Normal_Videos_041_x264 Normal -1 -1 -1 -1 72 | Normal_Videos_042_x264 Normal -1 -1 -1 -1 73 | Normal_Videos_048_x264 Normal -1 -1 -1 -1 74 | Normal_Videos_050_x264 Normal -1 -1 -1 -1 75 | Normal_Videos_051_x264 Normal -1 -1 -1 -1 76 | Normal_Videos_056_x264 Normal -1 -1 -1 -1 77 | Normal_Videos_059_x264 Normal -1 -1 -1 -1 78 | Normal_Videos_063_x264 Normal -1 -1 -1 -1 79 | Normal_Videos_067_x264 Normal -1 -1 -1 -1 80 | Normal_Videos_070_x264 Normal -1 -1 -1 -1 81 | Normal_Videos_100_x264 Normal -1 -1 -1 -1 82 | Normal_Videos_129_x264 Normal -1 -1 -1 -1 83 | Normal_Videos_150_x264 Normal -1 -1 -1 -1 84 | Normal_Videos_168_x264 Normal -1 -1 -1 -1 85 | Normal_Videos_175_x264 Normal -1 -1 -1 -1 86 | Normal_Videos_182_x264 Normal -1 -1 -1 -1 87 | Normal_Videos_189_x264 Normal -1 -1 -1 -1 88 | Normal_Videos_196_x264 Normal -1 -1 -1 -1 89 | Normal_Videos_203_x264 Normal -1 -1 -1 -1 90 | Normal_Videos_210_x264 Normal -1 -1 -1 -1 91 | Normal_Videos_217_x264 Normal -1 -1 -1 -1 92 | Normal_Videos_224_x264 Normal -1 -1 -1 -1 93 | Normal_Videos_246_x264 Normal -1 -1 -1 -1 94 | Normal_Videos_247_x264 Normal -1 -1 -1 -1 95 | Normal_Videos_248_x264 Normal -1 -1 -1 -1 96 | Normal_Videos_251_x264 Normal -1 -1 -1 -1 97 | Normal_Videos_289_x264 Normal -1 -1 -1 -1 98 | Normal_Videos_310_x264 Normal -1 -1 -1 -1 99 | Normal_Videos_312_x264 Normal -1 -1 -1 -1 100 | Normal_Videos_317_x264 Normal -1 -1 -1 -1 101 | Normal_Videos_345_x264 Normal -1 -1 -1 -1 102 | Normal_Videos_352_x264 Normal -1 -1 -1 -1 103 | Normal_Videos_360_x264 Normal -1 -1 -1 -1 104 | Normal_Videos_365_x264 Normal -1 -1 -1 -1 105 | Normal_Videos_401_x264 Normal -1 -1 -1 -1 106 | Normal_Videos_417_x264 Normal -1 -1 -1 -1 107 | Normal_Videos_439_x264 Normal -1 -1 -1 -1 108 | Normal_Videos_452_x264 Normal -1 -1 -1 -1 109 | Normal_Videos_453_x264 Normal -1 -1 -1 -1 110 | Normal_Videos_478_x264 Normal -1 -1 -1 -1 111 | Normal_Videos_576_x264 Normal -1 -1 -1 -1 112 | Normal_Videos_597_x264 Normal -1 -1 -1 -1 113 | Normal_Videos_603_x264 Normal -1 -1 -1 -1 114 | Normal_Videos_606_x264 Normal -1 -1 -1 -1 115 | Normal_Videos_621_x264 Normal -1 -1 -1 -1 116 | Normal_Videos_634_x264 Normal -1 -1 -1 -1 117 | Normal_Videos_641_x264 Normal -1 -1 -1 -1 118 | Normal_Videos_656_x264 Normal -1 -1 -1 -1 119 | Normal_Videos_686_x264 Normal -1 -1 -1 -1 120 | Normal_Videos_696_x264 Normal -1 -1 -1 -1 121 | Normal_Videos_702_x264 Normal -1 -1 -1 -1 122 | Normal_Videos_704_x264 Normal -1 -1 -1 -1 123 | Normal_Videos_710_x264 Normal -1 -1 -1 -1 124 | Normal_Videos_717_x264 Normal -1 -1 -1 -1 125 | Normal_Videos_722_x264 Normal -1 -1 -1 -1 126 | Normal_Videos_725_x264 Normal -1 -1 -1 -1 127 | Normal_Videos_745_x264 Normal -1 -1 -1 -1 128 | Normal_Videos_758_x264 Normal -1 -1 -1 -1 129 | Normal_Videos_778_x264 Normal -1 -1 -1 -1 130 | Normal_Videos_780_x264 Normal -1 -1 -1 -1 131 | Normal_Videos_781_x264 Normal -1 -1 -1 -1 132 | Normal_Videos_782_x264 Normal -1 -1 -1 -1 
133 | Normal_Videos_783_x264 Normal -1 -1 -1 -1 134 | Normal_Videos_798_x264 Normal -1 -1 -1 -1 135 | Normal_Videos_801_x264 Normal -1 -1 -1 -1 136 | Normal_Videos_828_x264 Normal -1 -1 -1 -1 137 | Normal_Videos_831_x264 Normal -1 -1 -1 -1 138 | Normal_Videos_866_x264 Normal -1 -1 -1 -1 139 | Normal_Videos_867_x264 Normal -1 -1 -1 -1 140 | Normal_Videos_868_x264 Normal -1 -1 -1 -1 141 | Normal_Videos_869_x264 Normal -1 -1 -1 -1 142 | Normal_Videos_870_x264 Normal -1 -1 -1 -1 143 | Normal_Videos_871_x264 Normal -1 -1 -1 -1 144 | Normal_Videos_872_x264 Normal -1 -1 -1 -1 145 | Normal_Videos_873_x264 Normal -1 -1 -1 -1 146 | Normal_Videos_874_x264 Normal -1 -1 -1 -1 147 | Normal_Videos_875_x264 Normal -1 -1 -1 -1 148 | Normal_Videos_876_x264 Normal -1 -1 -1 -1 149 | Normal_Videos_877_x264 Normal -1 -1 -1 -1 150 | Normal_Videos_878_x264 Normal -1 -1 -1 -1 151 | Normal_Videos_879_x264 Normal -1 -1 -1 -1 152 | Normal_Videos_880_x264 Normal -1 -1 -1 -1 153 | Normal_Videos_881_x264 Normal -1 -1 -1 -1 154 | Normal_Videos_882_x264 Normal -1 -1 -1 -1 155 | Normal_Videos_883_x264 Normal -1 -1 -1 -1 156 | Normal_Videos_884_x264 Normal -1 -1 -1 -1 157 | Normal_Videos_885_x264 Normal -1 -1 -1 -1 158 | Normal_Videos_886_x264 Normal -1 -1 -1 -1 159 | Normal_Videos_887_x264 Normal -1 -1 -1 -1 160 | Normal_Videos_888_x264 Normal -1 -1 -1 -1 161 | Normal_Videos_889_x264 Normal -1 -1 -1 -1 162 | Normal_Videos_890_x264 Normal -1 -1 -1 -1 163 | Normal_Videos_891_x264 Normal -1 -1 -1 -1 164 | Normal_Videos_892_x264 Normal -1 -1 -1 -1 165 | Normal_Videos_893_x264 Normal -1 -1 -1 -1 166 | Normal_Videos_894_x264 Normal -1 -1 -1 -1 167 | Normal_Videos_895_x264 Normal -1 -1 -1 -1 168 | Normal_Videos_896_x264 Normal -1 -1 -1 -1 169 | Normal_Videos_897_x264 Normal -1 -1 -1 -1 170 | Normal_Videos_898_x264 Normal -1 -1 -1 -1 171 | Normal_Videos_899_x264 Normal -1 -1 -1 -1 172 | Normal_Videos_900_x264 Normal -1 -1 -1 -1 173 | Normal_Videos_901_x264 Normal -1 -1 -1 -1 174 | Normal_Videos_902_x264 Normal -1 -1 -1 -1 175 | Normal_Videos_903_x264 Normal -1 -1 -1 -1 176 | Normal_Videos_904_x264 Normal -1 -1 -1 -1 177 | Normal_Videos_905_x264 Normal -1 -1 -1 -1 178 | Normal_Videos_906_x264 Normal -1 -1 -1 -1 179 | Normal_Videos_907_x264 Normal -1 -1 -1 -1 180 | Normal_Videos_908_x264 Normal -1 -1 -1 -1 181 | Normal_Videos_909_x264 Normal -1 -1 -1 -1 182 | Normal_Videos_910_x264 Normal -1 -1 -1 -1 183 | Normal_Videos_911_x264 Normal -1 -1 -1 -1 184 | Normal_Videos_912_x264 Normal -1 -1 -1 -1 185 | Normal_Videos_913_x264 Normal -1 -1 -1 -1 186 | Normal_Videos_914_x264 Normal -1 -1 -1 -1 187 | Normal_Videos_915_x264 Normal -1 -1 -1 -1 188 | Normal_Videos_923_x264 Normal -1 -1 -1 -1 189 | Normal_Videos_924_x264 Normal -1 -1 -1 -1 190 | Normal_Videos_925_x264 Normal -1 -1 -1 -1 191 | Normal_Videos_926_x264 Normal -1 -1 -1 -1 192 | Normal_Videos_927_x264 Normal -1 -1 -1 -1 193 | Normal_Videos_928_x264 Normal -1 -1 -1 -1 194 | Normal_Videos_929_x264 Normal -1 -1 -1 -1 195 | Normal_Videos_930_x264 Normal -1 -1 -1 -1 196 | Normal_Videos_931_x264 Normal -1 -1 -1 -1 197 | Normal_Videos_932_x264 Normal -1 -1 -1 -1 198 | Normal_Videos_933_x264 Normal -1 -1 -1 -1 199 | Normal_Videos_934_x264 Normal -1 -1 -1 -1 200 | Normal_Videos_935_x264 Normal -1 -1 -1 -1 201 | Normal_Videos_936_x264 Normal -1 -1 -1 -1 202 | Normal_Videos_937_x264 Normal -1 -1 -1 -1 203 | Normal_Videos_938_x264 Normal -1 -1 -1 -1 204 | Normal_Videos_939_x264 Normal -1 -1 -1 -1 205 | Normal_Videos_940_x264 Normal -1 -1 -1 -1 206 | Normal_Videos_941_x264 Normal -1 -1 -1 -1 
207 | Normal_Videos_943_x264 Normal -1 -1 -1 -1 208 | Normal_Videos_944_x264 Normal -1 -1 -1 -1 209 | RoadAccidents001_x264 RoadAccidents 210 300 -1 -1 210 | RoadAccidents002_x264 RoadAccidents 240 300 -1 -1 211 | RoadAccidents004_x264 RoadAccidents 140 189 -1 -1 212 | RoadAccidents009_x264 RoadAccidents 210 240 -1 -1 213 | RoadAccidents010_x264 RoadAccidents 230 270 -1 -1 214 | RoadAccidents011_x264 RoadAccidents 260 300 -1 -1 215 | RoadAccidents012_x264 RoadAccidents 250 390 -1 -1 216 | RoadAccidents016_x264 RoadAccidents 530 720 -1 -1 217 | RoadAccidents017_x264 RoadAccidents 60 130 -1 -1 218 | RoadAccidents019_x264 RoadAccidents 750 900 -1 -1 219 | RoadAccidents020_x264 RoadAccidents 610 730 -1 -1 220 | RoadAccidents021_x264 RoadAccidents 30 90 -1 -1 221 | RoadAccidents022_x264 RoadAccidents 120 220 -1 -1 222 | RoadAccidents121_x264 RoadAccidents 330 390 -1 -1 223 | RoadAccidents122_x264 RoadAccidents 300 360 -1 -1 224 | RoadAccidents123_x264 RoadAccidents 130 210 -1 -1 225 | RoadAccidents124_x264 RoadAccidents 250 420 -1 -1 226 | RoadAccidents125_x264 RoadAccidents 490 600 -1 -1 227 | RoadAccidents127_x264 RoadAccidents 2160 2300 -1 -1 228 | RoadAccidents128_x264 RoadAccidents 90 200 -1 -1 229 | RoadAccidents131_x264 RoadAccidents 180 240 -1 -1 230 | RoadAccidents132_x264 RoadAccidents 220 320 -1 -1 231 | RoadAccidents133_x264 RoadAccidents 270 450 -1 -1 232 | Robbery048_x264 Robbery 450 930 -1 -1 233 | Robbery050_x264 Robbery 495 1410 -1 -1 234 | Robbery102_x264 Robbery 1080 1560 -1 -1 235 | Robbery106_x264 Robbery 480 600 -1 -1 236 | Robbery137_x264 Robbery 135 1950 -1 -1 237 | Shooting002_x264 Shooting 1020 1100 -1 -1 238 | Shooting004_x264 Shooting 500 660 -1 -1 239 | Shooting007_x264 Shooting 45 165 -1 -1 240 | Shooting008_x264 Shooting 75 315 -1 -1 241 | Shooting010_x264 Shooting 1095 1260 -1 -1 242 | Shooting011_x264 Shooting 1480 1750 -1 -1 243 | Shooting013_x264 Shooting 860 945 -1 -1 244 | Shooting015_x264 Shooting 855 1715 -1 -1 245 | Shooting018_x264 Shooting 315 480 -1 -1 246 | Shooting019_x264 Shooting 1020 1455 -1 -1 247 | Shooting021_x264 Shooting 480 630 -1 -1 248 | Shooting022_x264 Shooting 2850 3300 -1 -1 249 | Shooting024_x264 Shooting 720 1305 -1 -1 250 | Shooting026_x264 Shooting 195 600 -1 -1 251 | Shooting028_x264 Shooting 285 555 -1 -1 252 | Shooting032_x264 Shooting 7995 8205 -1 -1 253 | Shooting033_x264 Shooting 1680 2000 -1 -1 254 | Shooting034_x264 Shooting 960 1050 -1 -1 255 | Shooting037_x264 Shooting 140 260 -1 -1 256 | Shooting043_x264 Shooting 945 1230 -1 -1 257 | Shooting046_x264 Shooting 4005 4230 4760 5088 258 | Shooting047_x264 Shooting 2160 3900 4860 6600 259 | Shooting048_x264 Shooting 1410 1730 -1 -1 260 | Shoplifting001_x264 Shoplifting 1550 2000 -1 -1 261 | Shoplifting004_x264 Shoplifting 2200 4900 -1 -1 262 | Shoplifting005_x264 Shoplifting 720 930 -1 -1 263 | Shoplifting007_x264 Shoplifting 550 760 4630 4920 264 | Shoplifting010_x264 Shoplifting 750 920 1550 1970 265 | Shoplifting015_x264 Shoplifting 2010 2160 -1 -1 266 | Shoplifting016_x264 Shoplifting 630 720 -1 -1 267 | Shoplifting017_x264 Shoplifting 360 420 -1 -1 268 | Shoplifting020_x264 Shoplifting 2340 2460 -1 -1 269 | Shoplifting021_x264 Shoplifting 2070 2220 -1 -1 270 | Shoplifting022_x264 Shoplifting 270 420 1440 1560 271 | Shoplifting027_x264 Shoplifting 1080 1160 1470 1710 272 | Shoplifting028_x264 Shoplifting 570 840 -1 -1 273 | Shoplifting029_x264 Shoplifting 1020 1470 -1 -1 274 | Shoplifting031_x264 Shoplifting 120 330 -1 -1 275 | Shoplifting033_x264 Shoplifting 630 750 -1 
-1 276 | Shoplifting034_x264 Shoplifting 7350 7470 -1 -1 277 | Shoplifting037_x264 Shoplifting 1140 1200 -1 -1 278 | Shoplifting039_x264 Shoplifting 2190 2340 -1 -1 279 | Shoplifting044_x264 Shoplifting 11070 11250 -1 -1 280 | Shoplifting049_x264 Shoplifting 1020 1350 -1 -1 281 | Stealing019_x264 Stealing 2730 2790 4170 4350 282 | Stealing036_x264 Stealing 1260 1590 -1 -1 283 | Stealing058_x264 Stealing 570 3660 -1 -1 284 | Stealing062_x264 Stealing 360 1050 -1 -1 285 | Stealing079_x264 Stealing 2550 3210 3510 4500 286 | Vandalism007_x264 Vandalism 240 750 -1 -1 287 | Vandalism015_x264 Vandalism 2010 2700 -1 -1 288 | Vandalism017_x264 Vandalism 270 330 780 840 289 | Vandalism028_x264 Vandalism 1830 1980 2400 2670 290 | Vandalism036_x264 Vandalism 540 780 990 1080 291 | -------------------------------------------------------------------------------- /src/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # 
(HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x 
* torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + self.attention(self.ln_1(x)) 191 | x = x + self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | return self.resblocks(x) 204 | 205 | 206 | class VisionTransformer(nn.Module): 207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 208 | super().__init__() 209 | self.input_resolution = input_resolution 210 | self.output_dim = output_dim 211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 212 | 213 | scale = width ** -0.5 214 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 216 | self.ln_pre = LayerNorm(width) 217 | 218 | self.transformer = Transformer(width, layers, heads) 219 | 220 | self.ln_post = LayerNorm(width) 221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 222 | 223 | def forward(self, x: torch.Tensor): 224 | x = self.conv1(x) # shape = [*, width, grid, grid] 225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 228 | x = x + self.positional_embedding.to(x.dtype) 229 | x = self.ln_pre(x) 230 | 231 | x = x.permute(1, 0, 2) # NLD -> LND 232 | x = self.transformer(x) 233 | x = x.permute(1, 0, 2) # LND -> NLD 234 | 235 | x = self.ln_post(x[:, 0, :]) 236 | 237 | if self.proj is not None: 238 | x = x @ self.proj 239 | 240 | return x 241 | 242 | 243 | class CLIP(nn.Module): 244 | def __init__(self, 245 | embed_dim: int, 246 | # vision 247 | image_resolution: int, 248 | vision_layers: Union[Tuple[int, int, int, int], int], 249 | vision_width: int, 250 | vision_patch_size: int, 251 | # text 252 | context_length: int, 253 | vocab_size: int, 254 | transformer_width: int, 255 | transformer_heads: int, 256 | transformer_layers: int 257 | ): 258 | super().__init__() 259 | 260 | self.context_length = context_length 261 | 262 | if isinstance(vision_layers, (tuple, list)): 263 | vision_heads = vision_width * 32 // 64 264 | self.visual = ModifiedResNet( 265 
| layers=vision_layers, 266 | output_dim=embed_dim, 267 | heads=vision_heads, 268 | input_resolution=image_resolution, 269 | width=vision_width 270 | ) 271 | else: 272 | vision_heads = vision_width // 64 273 | self.visual = VisionTransformer( 274 | input_resolution=image_resolution, 275 | patch_size=vision_patch_size, 276 | width=vision_width, 277 | layers=vision_layers, 278 | heads=vision_heads, 279 | output_dim=embed_dim 280 | ) 281 | 282 | self.transformer = Transformer( 283 | width=transformer_width, 284 | layers=transformer_layers, 285 | heads=transformer_heads, 286 | attn_mask=self.build_attention_mask() 287 | ) 288 | 289 | self.vocab_size = vocab_size 290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 292 | self.ln_final = LayerNorm(transformer_width) 293 | 294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 296 | 297 | self.initialize_parameters() 298 | 299 | def initialize_parameters(self): 300 | nn.init.normal_(self.token_embedding.weight, std=0.02) 301 | nn.init.normal_(self.positional_embedding, std=0.01) 302 | 303 | if isinstance(self.visual, ModifiedResNet): 304 | if self.visual.attnpool is not None: 305 | std = self.visual.attnpool.c_proj.in_features ** -0.5 306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 310 | 311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 312 | for name, param in resnet_block.named_parameters(): 313 | if name.endswith("bn3.weight"): 314 | nn.init.zeros_(param) 315 | 316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 317 | attn_std = self.transformer.width ** -0.5 318 | fc_std = (2 * self.transformer.width) ** -0.5 319 | for block in self.transformer.resblocks: 320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 324 | 325 | if self.text_projection is not None: 326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 327 | 328 | def build_attention_mask(self): 329 | # lazily create causal attention mask, with full attention between the vision tokens 330 | # pytorch uses additive attention mask; fill with -inf 331 | mask = torch.empty(self.context_length, self.context_length) 332 | mask.fill_(float("-inf")) 333 | mask.triu_(1) # zero out the lower diagonal 334 | return mask 335 | 336 | @property 337 | def dtype(self): 338 | return self.visual.conv1.weight.dtype 339 | 340 | def encode_image(self, image): 341 | return self.visual(image.type(self.dtype)) 342 | 343 | def encode_token(self, token): 344 | x = self.token_embedding(token) 345 | return x 346 | 347 | def encode_text(self, text, token): 348 | #x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 349 | x = text.type(self.dtype) + self.positional_embedding.type(self.dtype) 350 | x = x.permute(1, 0, 2) # NLD -> LND 351 | x = self.transformer(x) 352 | x = x.permute(1, 0, 2) # LND -> NLD 353 | x = self.ln_final(x).type(self.dtype) 
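# Unlike stock CLIP, `text` here is assumed to already be token embeddings (e.g. the output of
# encode_token, possibly mixed with learnable prompt vectors), not raw token ids; the raw ids in
# `token` are kept only so that each sequence's EOT position can be recovered via argmax below.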
354 | 355 | # x.shape = [batch_size, n_ctx, transformer.width] 356 | # take features from the eot embedding (eot_token is the highest number in each sequence) 357 | x = x[torch.arange(x.shape[0]), token.argmax(dim=-1)] @ self.text_projection 358 | 359 | return x 360 | 361 | def forward(self, image, text, token): 362 | image_features = self.encode_image(image) 363 | text_features = self.encode_text(text, token) 364 | 365 | # normalized features 366 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 367 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 368 | 369 | # cosine similarity as logits 370 | logit_scale = self.logit_scale.exp() 371 | logits_per_image = logit_scale * image_features @ text_features.t() 372 | logits_per_text = logits_per_image.t() 373 | 374 | # shape = [global_batch_size, global_batch_size] 375 | return logits_per_image, logits_per_text 376 | 377 | 378 | def convert_weights(model: nn.Module): 379 | """Convert applicable model parameters to fp16""" 380 | 381 | def _convert_weights_to_fp16(l): 382 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 383 | l.weight.data = l.weight.data.half() 384 | if l.bias is not None: 385 | l.bias.data = l.bias.data.half() 386 | 387 | if isinstance(l, nn.MultiheadAttention): 388 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 389 | tensor = getattr(l, attr) 390 | if tensor is not None: 391 | tensor.data = tensor.data.half() 392 | 393 | for name in ["text_projection", "proj"]: 394 | if hasattr(l, name): 395 | attr = getattr(l, name) 396 | if attr is not None: 397 | attr.data = attr.data.half() 398 | 399 | model.apply(_convert_weights_to_fp16) 400 | 401 | 402 | def build_model(state_dict: dict): 403 | vit = "visual.proj" in state_dict 404 | 405 | if vit: 406 | vision_width = state_dict["visual.conv1.weight"].shape[0] 407 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 408 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 409 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 410 | image_resolution = vision_patch_size * grid_size 411 | else: 412 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 413 | vision_layers = tuple(counts) 414 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 415 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 416 | vision_patch_size = None 417 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 418 | image_resolution = output_width * 32 419 | 420 | embed_dim = state_dict["text_projection"].shape[1] 421 | context_length = state_dict["positional_embedding"].shape[0] 422 | vocab_size = state_dict["token_embedding.weight"].shape[0] 423 | transformer_width = state_dict["ln_final.weight"].shape[0] 424 | transformer_heads = transformer_width // 64 425 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 426 | 427 | model = CLIP( 428 | embed_dim, 429 | image_resolution, vision_layers, vision_width, vision_patch_size, 430 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 431 | ) 432 | 433 | for key in ["input_resolution", "context_length", "vocab_size"]: 434 | if key in state_dict: 435 | del state_dict[key] 436 | 437 | 
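# The fp16 conversion below is left disabled in this fork, so the rebuilt CLIP stays in fp32,
# presumably for numerical stability when its features are reused downstream; re-enabling
# convert_weights(model) would roughly halve memory at some precision cost.
# A typical loading sketch (the checkpoint filename is hypothetical):
#     model = build_model(torch.jit.load("ViT-B-16.pt").state_dict())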
#convert_weights(model) 438 | model.load_state_dict(state_dict) 439 | return model.eval() 440 | -------------------------------------------------------------------------------- /list/ucf_CLIP_rgbtest.csv: -------------------------------------------------------------------------------- 1 | path,label 2 | /home/xbgydx/Desktop/UCFClipFeatures/Abuse/Abuse028_x264__5.npy,Abuse 3 | /home/xbgydx/Desktop/UCFClipFeatures/Abuse/Abuse030_x264__5.npy,Abuse 4 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest001_x264__5.npy,Arrest 5 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest007_x264__5.npy,Arrest 6 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest024_x264__5.npy,Arrest 7 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest030_x264__5.npy,Arrest 8 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest039_x264__5.npy,Arrest 9 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson007_x264__5.npy,Arson 10 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson009_x264__5.npy,Arson 11 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson010_x264__5.npy,Arson 12 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson011_x264__5.npy,Arson 13 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson016_x264__5.npy,Arson 14 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson018_x264__5.npy,Arson 15 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson022_x264__5.npy,Arson 16 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson035_x264__5.npy,Arson 17 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson041_x264__5.npy,Arson 18 | /home/xbgydx/Desktop/UCFClipFeatures/Assault/Assault006_x264__5.npy,Assault 19 | /home/xbgydx/Desktop/UCFClipFeatures/Assault/Assault010_x264__5.npy,Assault 20 | /home/xbgydx/Desktop/UCFClipFeatures/Assault/Assault011_x264__5.npy,Assault 21 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary005_x264__5.npy,Burglary 22 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary017_x264__5.npy,Burglary 23 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary018_x264__5.npy,Burglary 24 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary021_x264__5.npy,Burglary 25 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary024_x264__5.npy,Burglary 26 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary032_x264__5.npy,Burglary 27 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary033_x264__5.npy,Burglary 28 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary035_x264__5.npy,Burglary 29 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary037_x264__5.npy,Burglary 30 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary061_x264__5.npy,Burglary 31 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary076_x264__5.npy,Burglary 32 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary079_x264__5.npy,Burglary 33 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary092_x264__5.npy,Burglary 34 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion002_x264__5.npy,Explosion 35 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion004_x264__5.npy,Explosion 36 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion007_x264__5.npy,Explosion 37 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion008_x264__5.npy,Explosion 38 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion010_x264__5.npy,Explosion 39 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion011_x264__5.npy,Explosion 40 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion013_x264__5.npy,Explosion 41 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion016_x264__5.npy,Explosion 42 | 
/home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion017_x264__5.npy,Explosion 43 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion020_x264__5.npy,Explosion 44 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion021_x264__5.npy,Explosion 45 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion022_x264__5.npy,Explosion 46 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion025_x264__5.npy,Explosion 47 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion027_x264__5.npy,Explosion 48 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion028_x264__5.npy,Explosion 49 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion029_x264__5.npy,Explosion 50 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion033_x264__5.npy,Explosion 51 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion035_x264__5.npy,Explosion 52 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion036_x264__5.npy,Explosion 53 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion039_x264__5.npy,Explosion 54 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion043_x264__5.npy,Explosion 55 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting003_x264__5.npy,Fighting 56 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting018_x264__5.npy,Fighting 57 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting033_x264__5.npy,Fighting 58 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting042_x264__5.npy,Fighting 59 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting047_x264__5.npy,Fighting 60 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents001_x264__5.npy,RoadAccidents 61 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents002_x264__5.npy,RoadAccidents 62 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents004_x264__5.npy,RoadAccidents 63 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents009_x264__5.npy,RoadAccidents 64 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents010_x264__5.npy,RoadAccidents 65 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents011_x264__5.npy,RoadAccidents 66 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents012_x264__5.npy,RoadAccidents 67 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents016_x264__5.npy,RoadAccidents 68 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents017_x264__5.npy,RoadAccidents 69 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents019_x264__5.npy,RoadAccidents 70 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents020_x264__5.npy,RoadAccidents 71 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents021_x264__5.npy,RoadAccidents 72 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents022_x264__5.npy,RoadAccidents 73 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents121_x264__5.npy,RoadAccidents 74 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents122_x264__5.npy,RoadAccidents 75 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents123_x264__5.npy,RoadAccidents 76 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents124_x264__5.npy,RoadAccidents 77 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents125_x264__5.npy,RoadAccidents 78 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents127_x264__5.npy,RoadAccidents 79 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents128_x264__5.npy,RoadAccidents 80 | 
/home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents131_x264__5.npy,RoadAccidents 81 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents132_x264__5.npy,RoadAccidents 82 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents133_x264__5.npy,RoadAccidents 83 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery048_x264__5.npy,Robbery 84 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery050_x264__5.npy,Robbery 85 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery102_x264__5.npy,Robbery 86 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery106_x264__5.npy,Robbery 87 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery137_x264__5.npy,Robbery 88 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting002_x264__5.npy,Shooting 89 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting004_x264__5.npy,Shooting 90 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting007_x264__5.npy,Shooting 91 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting008_x264__5.npy,Shooting 92 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting010_x264__5.npy,Shooting 93 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting011_x264__5.npy,Shooting 94 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting013_x264__5.npy,Shooting 95 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting015_x264__5.npy,Shooting 96 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting018_x264__5.npy,Shooting 97 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting019_x264__5.npy,Shooting 98 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting021_x264__5.npy,Shooting 99 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting022_x264__5.npy,Shooting 100 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting024_x264__5.npy,Shooting 101 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting026_x264__5.npy,Shooting 102 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting028_x264__5.npy,Shooting 103 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting032_x264__5.npy,Shooting 104 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting033_x264__5.npy,Shooting 105 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting034_x264__5.npy,Shooting 106 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting037_x264__5.npy,Shooting 107 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting043_x264__5.npy,Shooting 108 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting046_x264__5.npy,Shooting 109 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting047_x264__5.npy,Shooting 110 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting048_x264__5.npy,Shooting 111 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting001_x264__5.npy,Shoplifting 112 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting004_x264__5.npy,Shoplifting 113 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting005_x264__5.npy,Shoplifting 114 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting007_x264__5.npy,Shoplifting 115 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting010_x264__5.npy,Shoplifting 116 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting015_x264__5.npy,Shoplifting 117 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting016_x264__5.npy,Shoplifting 118 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting017_x264__5.npy,Shoplifting 119 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting020_x264__5.npy,Shoplifting 120 | 
/home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting021_x264__5.npy,Shoplifting 121 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting022_x264__5.npy,Shoplifting 122 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting027_x264__5.npy,Shoplifting 123 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting028_x264__5.npy,Shoplifting 124 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting029_x264__5.npy,Shoplifting 125 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting031_x264__5.npy,Shoplifting 126 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting033_x264__5.npy,Shoplifting 127 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting034_x264__5.npy,Shoplifting 128 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting037_x264__5.npy,Shoplifting 129 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting039_x264__5.npy,Shoplifting 130 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting044_x264__5.npy,Shoplifting 131 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting049_x264__5.npy,Shoplifting 132 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing019_x264__5.npy,Stealing 133 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing036_x264__5.npy,Stealing 134 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing058_x264__5.npy,Stealing 135 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing062_x264__5.npy,Stealing 136 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing079_x264__5.npy,Stealing 137 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism007_x264__5.npy,Vandalism 138 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism015_x264__5.npy,Vandalism 139 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism017_x264__5.npy,Vandalism 140 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism028_x264__5.npy,Vandalism 141 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism036_x264__5.npy,Vandalism 142 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_003_x264__5.npy,Normal 143 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_006_x264__5.npy,Normal 144 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_010_x264__5.npy,Normal 145 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_014_x264__5.npy,Normal 146 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_015_x264__5.npy,Normal 147 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_018_x264__5.npy,Normal 148 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_019_x264__5.npy,Normal 149 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_024_x264__5.npy,Normal 150 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_025_x264__5.npy,Normal 151 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_027_x264__5.npy,Normal 152 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_033_x264__5.npy,Normal 153 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_034_x264__5.npy,Normal 154 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_041_x264__5.npy,Normal 155 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_042_x264__5.npy,Normal 156 | 
/home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_048_x264__5.npy,Normal 157 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_050_x264__5.npy,Normal 158 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_051_x264__5.npy,Normal 159 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_056_x264__5.npy,Normal 160 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_059_x264__5.npy,Normal 161 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_063_x264__5.npy,Normal 162 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_067_x264__5.npy,Normal 163 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_070_x264__5.npy,Normal 164 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_100_x264__5.npy,Normal 165 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_129_x264__5.npy,Normal 166 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_150_x264__5.npy,Normal 167 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_168_x264__5.npy,Normal 168 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_175_x264__5.npy,Normal 169 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_182_x264__5.npy,Normal 170 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_189_x264__5.npy,Normal 171 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_196_x264__5.npy,Normal 172 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_203_x264__5.npy,Normal 173 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_210_x264__5.npy,Normal 174 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_217_x264__5.npy,Normal 175 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_224_x264__5.npy,Normal 176 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_246_x264__5.npy,Normal 177 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_247_x264__5.npy,Normal 178 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_248_x264__5.npy,Normal 179 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_251_x264__5.npy,Normal 180 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_289_x264__5.npy,Normal 181 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_310_x264__5.npy,Normal 182 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_312_x264__5.npy,Normal 183 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_317_x264__5.npy,Normal 184 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_345_x264__5.npy,Normal 185 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_352_x264__5.npy,Normal 186 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_360_x264__5.npy,Normal 187 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_365_x264__5.npy,Normal 188 | 
/home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_401_x264__5.npy,Normal 189 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_417_x264__5.npy,Normal 190 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_439_x264__5.npy,Normal 191 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_452_x264__5.npy,Normal 192 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_453_x264__5.npy,Normal 193 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_478_x264__5.npy,Normal 194 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_576_x264__5.npy,Normal 195 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_597_x264__5.npy,Normal 196 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_603_x264__5.npy,Normal 197 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_606_x264__5.npy,Normal 198 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_621_x264__5.npy,Normal 199 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_634_x264__5.npy,Normal 200 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_641_x264__5.npy,Normal 201 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_656_x264__5.npy,Normal 202 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_686_x264__5.npy,Normal 203 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_696_x264__5.npy,Normal 204 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_702_x264__5.npy,Normal 205 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_704_x264__5.npy,Normal 206 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_710_x264__5.npy,Normal 207 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_717_x264__5.npy,Normal 208 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_722_x264__5.npy,Normal 209 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_725_x264__5.npy,Normal 210 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_745_x264__5.npy,Normal 211 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_758_x264__5.npy,Normal 212 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_778_x264__5.npy,Normal 213 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_780_x264__5.npy,Normal 214 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_781_x264__5.npy,Normal 215 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_782_x264__5.npy,Normal 216 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_783_x264__5.npy,Normal 217 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_798_x264__5.npy,Normal 218 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_801_x264__5.npy,Normal 219 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_828_x264__5.npy,Normal 220 | 
/home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_831_x264__5.npy,Normal 221 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_866_x264__5.npy,Normal 222 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_867_x264__5.npy,Normal 223 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_868_x264__5.npy,Normal 224 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_869_x264__5.npy,Normal 225 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_870_x264__5.npy,Normal 226 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_871_x264__5.npy,Normal 227 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_872_x264__5.npy,Normal 228 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_873_x264__5.npy,Normal 229 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_874_x264__5.npy,Normal 230 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_875_x264__5.npy,Normal 231 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_876_x264__5.npy,Normal 232 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_877_x264__5.npy,Normal 233 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_878_x264__5.npy,Normal 234 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_879_x264__5.npy,Normal 235 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_880_x264__5.npy,Normal 236 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_881_x264__5.npy,Normal 237 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_882_x264__5.npy,Normal 238 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_883_x264__5.npy,Normal 239 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_884_x264__5.npy,Normal 240 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_885_x264__5.npy,Normal 241 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_886_x264__5.npy,Normal 242 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_887_x264__5.npy,Normal 243 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_888_x264__5.npy,Normal 244 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_889_x264__5.npy,Normal 245 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_890_x264__5.npy,Normal 246 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_891_x264__5.npy,Normal 247 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_892_x264__5.npy,Normal 248 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_893_x264__5.npy,Normal 249 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_894_x264__5.npy,Normal 250 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_895_x264__5.npy,Normal 251 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_896_x264__5.npy,Normal 252 | 
/home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_897_x264__5.npy,Normal 253 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_898_x264__5.npy,Normal 254 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_899_x264__5.npy,Normal 255 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_900_x264__5.npy,Normal 256 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_901_x264__5.npy,Normal 257 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_902_x264__5.npy,Normal 258 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_903_x264__5.npy,Normal 259 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_904_x264__5.npy,Normal 260 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_905_x264__5.npy,Normal 261 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_906_x264__5.npy,Normal 262 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_907_x264__5.npy,Normal 263 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_908_x264__5.npy,Normal 264 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_909_x264__5.npy,Normal 265 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_910_x264__5.npy,Normal 266 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_911_x264__5.npy,Normal 267 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_912_x264__5.npy,Normal 268 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_913_x264__5.npy,Normal 269 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_914_x264__5.npy,Normal 270 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_915_x264__5.npy,Normal 271 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_923_x264__5.npy,Normal 272 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_924_x264__5.npy,Normal 273 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_925_x264__5.npy,Normal 274 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_926_x264__5.npy,Normal 275 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_927_x264__5.npy,Normal 276 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_928_x264__5.npy,Normal 277 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_929_x264__5.npy,Normal 278 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_930_x264__5.npy,Normal 279 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_931_x264__5.npy,Normal 280 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_932_x264__5.npy,Normal 281 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_933_x264__5.npy,Normal 282 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_934_x264__5.npy,Normal 283 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_935_x264__5.npy,Normal 284 | 
/home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_936_x264__5.npy,Normal 285 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_937_x264__5.npy,Normal 286 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_938_x264__5.npy,Normal 287 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_939_x264__5.npy,Normal 288 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_940_x264__5.npy,Normal 289 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_941_x264__5.npy,Normal 290 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_943_x264__5.npy,Normal 291 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_944_x264__5.npy,Normal 292 | -------------------------------------------------------------------------------- /list/annotations.txt: -------------------------------------------------------------------------------- 1 | v=S-7rRLrxnVQ__#1_label_B4-0-0 0 1517 1970 3038 2 | v=u5SF4SlqNDQ__#00-00-00_00-01-00_label_G-0-0 624 912 950 1441 3 | v=u5SF4SlqNDQ__#00-02-39_00-03-41_label_G-0-0 77 1430 4 | v=cEOM18n8fhU__#1_label_G-0-0 2133 2750 5 | v=NnmqkS1e88s__#1_label_B4-0-0 0 1750 6 | v=vaSOMEIe1Bg__#1_label_G-0-0 300 1380 7 | v=ZnFlL84K7HE__#00-27-53_00-34-45_label_B4-0-0 0 9523 8 | v=wVey5JDRf_g__#00-00-00_00-01-20_label_B6-0-0 51 210 273 410 544 720 828 970 1135 1270 1413 1500 1618 1710 1822 1920 9 | v=wVey5JDRf_g__#00-01-30_00-02-41_label_B6-0-0 142 296 447 508 606 710 800 870 1000 1185 1414 1550 1635 1700 10 | v=wVey5JDRf_g__#00-04-09_00-05-06_label_B6-0-0 85 200 352 438 547 690 777 833 918 952 1055 1170 1314 1365 11 | v=m8EkFsaGPzU__#1_label_B4-0-0 0 3380 4523 5700 12 | v=5BBoVfKOyeM__#00-08-19_00-10-10_label_B6-0-0 1067 1200 1650 1734 13 | v=BQjKQbYgUBA__#1_label_B1-0-0 635 667 733 759 932 1020 1234 1297 1977 2373 3195 3422 14 | v=bfOlheR5nUQ__#00-00-00_00-04-33_label_B4-0-0 0 6553 15 | v=bfOlheR5nUQ__#00-04-33_00-06-02_label_B4-0-0 0 2135 16 | v=YKtw6uLEGps__#1_label_B4-0-0 807 2118 17 | v=7xQr5wYwNPg__#1_label_B4-0-0 0 403 600 1008 18 | v=-fOWSLV6Esw__#1_label_B4-0-0 20 772 19 | v=b0MUjeKAGZw__#1_label_B4-0-0 0 2790 20 | Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6 157 180 185 244 250 360 582 810 21 | Bad.Boys.1995__#01-33-51_01-34-37_label_B2-0-0 57 670 728 800 22 | Bad.Boys.II.2003__#00-06-42_00-10-00_label_B2-G-0 2610 2670 2760 3060 3167 4310 23 | v=DD3jfKr8e-k__#00-00-00_00-01-40_label_B4-0-0 0 2400 24 | v=2jwO15SMyuo__#1_label_B4-0-0 0 2374 25 | v=jYENhkzdpO8__#00-00-00_00-01-31_label_B6-0-0 245 335 1187 1252 1836 2000 26 | v=1djrJ0wxlYo__#1_label_B4-0-0 0 758 880 1533 27 | v=7zEFHKKBA0g__#1_label_B4-0-0 1192 1600 2031 2690 28 | v=hxyhulJYz5I__#00-06-09_00-07-00_label_B6-0-0 25 70 256 360 520 640 680 855 927 1000 1140 1220 29 | v=bhZs3ALdL7Y__#1_label_G-0-0 0 70 30 | Black.Hawk.Down.2001__#01-13-59_01-14-49_label_B2-0-0 188 1190 31 | Black.Hawk.Down.2001__#01-32-40_01-34-00_label_B4-0-0 0 269 518 882 1090 1180 1220 1919 32 | Black.Hawk.Down.2001__#01-42-58_01-43-58_label_G-0-0 406 505 33 | Black.Hawk.Down.2001__#02-00-12_02-01-29_label_B2-0-0 10 1848 34 | v=8tqeeBGjnPg__#00-03-12_00-04-15_label_B6-0-0 160 300 690 980 1100 1270 35 | v=6nfo9c7a5pE__#1_label_B1-0-0 0 150 295 700 36 | v=rJz4NXm6vis__#1_label_B4-0-0 0 1210 1351 1797 37 | Braveheart.1995__#00-56-30_00-57-20_label_B1-0-0 0 360 38 | Braveheart.1995__#01-26-50_01-32-30_label_B1-0-0 1011 
7515 39 | Braveheart.1995__#02-05-34_02-06-40_label_B1-0-0 256 1580 40 | Braveheart.1995__#02-07-00_02-08-15_label_B1-0-0 450 600 1027 1300 41 | v=zI0q5UDP47g__#1_label_B4-0-0 66 3560 42 | Brick.Mansions.2014__#00-11-57_00-12-12_label_B2-0-0 308 338 43 | Brick.Mansions.2014__#00-16-26_00-17-12_label_B1-0-0 30 1100 44 | Brick.Mansions.2014__#00-41-25_00-42-36_label_B1-0-0 523 610 900 1700 45 | Brick.Mansions.2014__#01-02-00_01-02-43_label_B2-0-0 280 550 46 | v=73uRcX0Dvfc__#00-00-00_00-01-51_label_B6-G-0 177 217 450 540 958 1260 1507 1670 2053 2180 2460 2589 47 | v=Ia9ATKNeUbY__#00-01-52_00-03-10_label_B6-0-0 90 145 470 570 1130 1200 1493 1660 48 | v=Ia9ATKNeUbY__#00-04-19_00-05-11_label_B6-0-0 27 133 274 510 1137 1170 49 | v=Ia9ATKNeUbY__#00-05-19_00-06-09_label_B6-0-0 120 180 800 855 1112 1185 50 | v=Ia9ATKNeUbY__#00-06-19_00-06-50_label_B6-0-0 246 290 427 527 687 744 51 | Bullet.in.the.Head.1990__#00-17-20_00-18-55_label_B1-0-0 138 280 546 1915 52 | Bullet.in.the.Head.1990__#00-41-30_00-44-16_label_B4-G-0 388 3425 3626 3668 53 | Bullet.in.the.Head.1990__#01-26-30_01-27-34_label_B2-0-0 695 820 54 | Bullet.in.the.Head.1990__#01-45-21_01-46-20_label_B2-0-0 400 502 55 | Bullet.in.the.Head.1990__#02-02-00_02-05-22_label_B2-B6-G 0 35 203 860 1005 1465 2130 2380 2860 3250 3541 4346 56 | v=EOAfl8yMiN8__#00-00-00_00-00-51_label_B6-0-0 119 150 330 380 515 595 708 760 940 1200 57 | v=EOAfl8yMiN8__#00-03-00_00-04-09_label_B6-0-0 245 325 470 520 716 850 960 1034 1165 1200 1370 1454 1614 1652 58 | v=kVl-6-A9ZO4__#00-08-01_00-09-35_label_B6-0-0 65 140 428 500 693 800 986 1117 1550 1700 2038 2218 59 | v=iHuggczItBk__#00-01-00_00-02-45_label_B6-0-0 70 140 600 700 1080 1223 1459 1474 1750 1950 2257 2315 2474 2520 60 | v=iHuggczItBk__#00-03-00_00-03-55_label_B6-0-0 180 240 450 555 880 960 1200 1300 61 | v=iHuggczItBk__#00-04-00_00-04-55_label_B6-0-0 56 173 385 435 780 830 62 | v=0qtIjyt-7wg__#00-00-00_00-00-51_label_B6-0-0 205 270 537 610 1028 1115 63 | v=0qtIjyt-7wg__#00-01-00_00-02-45_label_B6-0-0 264 340 760 900 1337 1430 1612 1656 1806 1940 2390 2480 64 | v=0qtIjyt-7wg__#00-04-00_00-05-15_label_B6-0-0 25 55 265 315 65 | v=vhACO_m5pH0__#00-09-42_00-10-40_label_B6-0-0 238 314 475 520 911 960 1084 1208 66 | v=vhACO_m5pH0__#00-10-42_00-11-40_label_B6-0-0 48 55 303 354 890 1024 67 | v=vhACO_m5pH0__#00-14-42_00-15-40_label_B6-0-0 20 140 460 500 735 820 970 1070 1225 1266 68 | v=vhACO_m5pH0__#00-17-42_00-18-40_label_B6-0-0 0 45 148 218 440 486 640 710 860 900 1010 1060 1248 1300 69 | v=vhACO_m5pH0__#00-20-42_00-22-40_label_B6-0-0 30 130 300 380 625 680 830 920 1015 1085 1346 1368 1596 1642 1820 1869 2020 2100 2258 2345 2530 2560 2670 2749 70 | v=tT-SrQC6Ddw__#00-04-00_00-05-15_label_B6-0-0 185 226 1237 1320 1485 1640 71 | v=ZoS8gm5OcOM__#00-06-20_00-07-30_label_B6-0-0 364 627 1464 1680 72 | v=ZoS8gm5OcOM__#00-07-50_00-08-55_label_B6-0-0 267 320 795 850 1156 1454 73 | v=D1eP4Bn4hDQ__#00-09-20_00-10-45_label_B6-0-0 319 375 1050 1105 1390 1445 74 | v=BFPkt0AuSFA__#1_label_B4-0-0 28 3350 75 | v=251___mEwZA__#1_label_B1-0-0 183 1359 1453 2487 3749 4540 76 | Casino.Royale.2006__#00-18-30_00-19-20_label_G-B2-0 261 580 77 | Casino.Royale.2006__#00-50-05_00-51-16_label_B1-B2-B6 43 180 383 625 702 830 860 1006 1066 1096 1527 1634 78 | Casino.Royale.2006__#00-51-16_00-52-41_label_B1-B6-0 37 400 977 1400 1590 1752 79 | Casino.Royale.2006__#01-17-29_01-17-46_label_B1-0-0 216 350 80 | Casino.Royale.2006__#01-46-40_01-47-14_label_B6-0-0 197 600 81 | v=ROrpKx3aIjA__#1_label_G-0-0 713 830 82 | v=Hnfy9XhlIPM__#1_label_B4-0-0 
66 4620 83 | v=J_ZsNd97wXw__#1_label_G-0-0 60 660 1025 1090 84 | v=zQbnQBCTSiA__#1_label_B1-0-0 663 1530 2020 2553 3319 3710 4157 4478 4597 4930 85 | City.of.God.2002__#00-37-20_00-38-02_label_B5-0-0 530 750 86 | City.of.God.2002__#00-40-16_00-41-30_label_B2-0-0 460 490 858 910 1142 1330 87 | City.of.God.2002__#01-52-20_01-54-32_label_B2-0-0 24 54 585 611 2892 3167 88 | City.Of.Men.2007__#00-51-50_00-53-31_label_B2-0-0 254 540 1519 1545 89 | City.Of.Men.2007__#00-57-37_00-58-27_label_B2-0-0 218 280 883 920 90 | City.Of.Men.2007__#01-17-27_01-17-59_label_B2-0-0 197 394 91 | City.Of.Men.2007__#01-35-30_01-36-00_label_B2-0-0 135 265 92 | v=qrKfaX1lCUM__#1_label_G-0-0 247 280 1601 1656 1684 1736 1928 2106 2323 2355 2428 2471 2645 2875 5600 5645 5812 6000 93 | v=ZkUciDD55kA__#00-00-00_00-00-30_label_G-0-0 124 617 94 | Crank.Dircut.2006__#0-27-42_0-29-01_label_B1-0-0 347 445 560 1427 1886 1894 95 | v=OOjjPGN8jSU__#1_label_B6-0-0 884 911 994 1030 1083 1102 1235 1266 1375 1483 1680 1745 1830 1890 2070 2250 2320 2520 2557 2650 2798 3030 3175 3265 3622 3670 3935 3999 4089 4150 4435 4515 4590 4687 96 | v=uQY15O3LKI0__#1_label_B6-0-0 297 336 600 750 811 912 1068 1208 1320 1350 1425 1470 1593 1640 1742 1860 1949 2020 2140 2350 2550 2700 2867 2980 3073 3120 3255 3350 3489 3526 3623 3700 3804 3965 97 | Deadpool.2.2018__#0-04-46_0-05-01_label_B2-0-0 113 190 246 300 98 | Deadpool.2.2018__#0-50-30_0-51-20_label_B1-0-0 220 320 548 630 837 1125 99 | Deadpool.2.2018__#01-03-09_01-03-54_label_B4-0-0 390 608 100 | Deadpool.2016__#0-18-58_0-19-20_label_B1-0-0 80 130 290 510 101 | Death.Proof.2007__#00-45-05_00-47-36_label_B5-0-0 1268 1600 3418 3450 102 | Death.Proof.2007__#01-40-41_01-42-17_label_B5-B6-0 1606 1666 1760 2070 103 | v=MqrNCb2N5to__#1_label_B1-0-0 247 425 1495 1666 1852 1950 2219 2450 104 | Desperado.1995__#00-16-48_00-18-52_label_B1-0-0 40 110 160 200 786 805 946 2030 2190 2310 105 | Desperado.1995__#00-38-36_00-39-21_label_B1-B2-0 325 350 381 498 531 570 580 700 106 | Desperado.1995__#01-14-11_01-17-28_label_B2-G-0 838 1300 2640 2900 2957 3015 3310 3340 3760 4290 4478 4600 107 | v=PWBZiM-rkoE__#1_label_G-0-0 2050 2300 2499 2550 108 | v=CLFYvG6MsZU__#00-02-16_00-05-06_label_B2-0-0 600 1060 1577 2030 2636 3140 109 | v=3xWiBCIxjIk__#1_label_B4-0-0 122 3700 110 | v=yrpZJ8Vr3aA__#1_label_B1-0-0 120 1270 2180 2352 111 | Election.2005__#00-47-57_00-50-54_label_B5-0-0 170 480 1219 1620 2063 2150 112 | v=abATVcjFumY__#1_label_G-0-0 143 450 1560 1830 113 | v=d5lTTPvJLpw__#1_label_B4-0-0 188 368 508 810 920 5880 114 | v=J67oj92maC0__#1_label_G-0-0 293 740 806 1200 1346 1650 1698 2086 2213 2830 3260 3960 4058 4670 115 | Fast.Five.2011__#00-01-41_00-01-57_label_B6-0-0 70 310 116 | Fast.Five.2011__#00-32-05_00-32-24_label_B1-B2-0 104 155 194 360 117 | Fast.Five.2011__#00-32-56_00-33-26_label_B2-0-0 155 300 450 686 118 | Fast.Five.2011__#01-29-27_01-32-05_label_B1-0-0 110 2934 119 | Fast.Furious.6.2013__#00-45-40_00-47-13_label_B2-0-0 160 295 425 580 700 890 985 1038 1180 1280 2030 2190 120 | Fast.Furious.6.2013__#01-03-25_01-04-59_label_B1-0-0 460 1470 1550 2120 121 | Fast.Furious.6.2013__#01-38-33_01-39-20_label_B1-0-0 140 170 366 490 122 | Fast.Furious.2009__#00-21-06_00-21-57_label_B6-B2-0 170 599 1170 1194 123 | Fast.Furious.2009__#00-25-59_00-26-25_label_B1-0-0 240 622 124 | Fast.Furious.2009__#00-42-10_00-42-41_label_B6-0-0 35 727 125 | Fast.Furious.2009__#01-18-00_01-19-00_label_B2-B6-0 691 1078 1145 1323 126 | Fast.Furious.2009__#01-29-10_01-29-55_label_B2-0-0 877 1058 127 | 
v=O9g1uuI62m0__#1_label_B4-0-0 0 4460 128 | v=wG-2t0M8CPc__#1_label_B1-0-0 781 1255 1325 1700 129 | v=wPkBaI6vOv0__#00-00-00_00-01-00_label_G-0-0 25 250 468 530 618 860 1350 1435 130 | v=pFamvR9CpYw__#00-00-00_00-06-36_label_B4-0-0 135 1300 1830 2580 3120 3220 3680 7050 7720 9500 131 | v=pFamvR9CpYw__#00-06-36_00-15-20_label_B4-0-0 0 1985 2450 6360 7460 11600 12200 12575 132 | v=y1z12D5bx7c__#1_label_B4-0-0 0 4847 133 | v=pHZ9gOfmY_k__#1_label_B4-0-0 0 5177 134 | v=gbRIKogNZvE__#1_label_B4-0-0 0 2988 135 | v=T-LH8Gv5zzY__#1_label_B4-0-0 0 2425 136 | v=CoxsDkB-rWU__#00-07-06_00-10-30_label_B4-0-0 0 4896 137 | v=TfS-MJoVNjM__#1_label_B4-0-0 0 1309 138 | Fury.2014__#00-07-30_00-07-58_label_G-0-0 125 180 417 490 139 | Fury.2014__#00-39-36_00-40-44_label_B1-B2-0 78 1424 1550 1575 140 | Mission.Impossible.V.Rogue.Nation.2015__#01-54-23_01-55-10_label_B2-0-0 723 992 141 | v=5Ftl4nOSytc__#1_label_G-0-0 166 290 142 | v=38GQ9L2meyE__#1_label_B6-0-0 26 80 210 288 377 396 450 517 597 628 650 850 895 973 1106 1226 1330 1400 1490 1675 1713 1820 1890 1980 2025 2100 2177 2290 2375 2450 2579 2655 2657 3045 3091 3170 3259 3475 3571 4060 4143 4288 4364 4450 143 | v=GqMh_HBNWZE__#00-01-00_00-02-45_label_B6-0-0 44 148 180 220 393 500 810 940 1325 1465 1582 1614 1857 1917 2259 2308 2480 2520 144 | v=RUjrMYWhLng__#1_label_B6-0-0 330 2620 3526 3760 3875 4420 4695 6540 6758 7220 8440 9111 9280 10350 145 | v=UK--hvgP2uY__#1_label_G-0-0 2310 2595 146 | v=Z12t5h2mBJc__#1_label_B1-0-0 22 780 1089 1275 1442 1710 2535 2950 2975 3600 3965 4300 4488 4664 147 | New.Kids.Turbo.2010__#00-22-35_00-23-14_label_B1-0-0 94 134 670 741 148 | New.Kids.Turbo.2010__#00-50-13_00-50-34_label_B1-0-0 344 420 149 | v=gsz_P8t-KM4__#1_label_B4-0-0 0 9695 150 | v=oipd63DGadU__#1_label_B1-0-0 73 760 990 1040 151 | v=xbPWQKZspfU__#1_label_B1-0-0 356 1200 2709 3544 3927 4782 152 | v=afhttnv46Y4__#1_label_B1-0-0 304 860 1435 2020 153 | v=0yHBkMBE8r4__#1_label_B1-0-0 498 720 3260 3570 154 | v=Q8K7roZu3WU__#1_label_B1-0-0 0 1294 155 | v=gENp4SyNxkI__#1_label_B1-0-0 0 1250 156 | v=sE-DC2trBkI__#00-01-00_00-02-45_label_B6-0-0 225 322 640 760 991 1160 1450 1550 1815 1883 2205 2245 157 | v=sE-DC2trBkI__#00-03-00_00-04-09_label_B6-0-0 258 410 700 802 942 1215 1435 1540 158 | v=sE-DC2trBkI__#00-09-00_00-10-05_label_B6-0-0 196 340 495 605 880 1023 1145 1230 1433 1493 159 | v=15wDrZJQpsw__#00-00-00_00-00-51_label_B6-0-0 434 562 749 866 160 | v=15wDrZJQpsw__#00-05-20_00-06-20_label_B6-0-0 95 233 465 516 852 924 1330 1441 161 | v=15wDrZJQpsw__#00-09-00_00-10-55_label_B6-0-0 130 650 1030 1140 1625 1685 1875 2330 2580 2620 162 | v=HSisjzLESak__#00-00-00_00-00-51_label_B6-0-0 578 700 864 960 1153 1224 163 | v=HSisjzLESak__#00-01-00_00-02-45_label_B6-0-0 0 122 640 900 1163 1330 1633 1681 164 | v=HSisjzLESak__#00-07-40_00-08-50_label_B6-0-0 322 522 1340 1545 165 | v=X4I8FhpGJwo__#00-07-40_00-08-50_label_B6-0-0 258 300 552 598 1150 1299 166 | v=X4I8FhpGJwo__#00-09-00_00-10-45_label_B6-0-0 105 165 550 633 1273 1440 1640 1700 2017 2050 2362 2450 167 | v=waIS8TaJxts__#00-00-00_00-00-51_label_B6-0-0 490 625 990 1075 1113 1220 168 | v=waIS8TaJxts__#00-01-30_00-02-45_label_B6-0-0 189 255 1145 1800 169 | v=waIS8TaJxts__#00-05-50_00-06-20_label_B6-0-0 285 536 170 | v=waIS8TaJxts__#00-07-40_00-08-50_label_B6-0-0 329 375 724 920 1276 1460 171 | v=waIS8TaJxts__#00-09-00_00-10-45_label_B6-0-0 257 540 1082 1245 1468 1805 2303 2490 172 | v=H5W58Loofks__#00-05-50_00-06-20_label_B6-0-0 0 90 521 720 173 | v=TOEVc-pOEwM__#00-08-40_00-10-40_label_B6-0-0 40 95 350 530 796 1224 1512 
1658 1983 2166 2444 2532 2845 2880 174 | Operation.Red.Sea.2018__#0-02-17_0-02-37_label_B2-0-0 180 397 175 | Operation.Red.Sea.2018__#0-16-06_0-16-46_label_G-0-0 630 774 176 | Operation.Red.Sea.2018__#0-25-36_0-26-03_label_G-0-0 585 620 177 | Operation.Red.Sea.2018__#0-29-36_0-30-20_label_B2-G-0 222 450 460 533 178 | Operation.Red.Sea.2018__#0-56-26_0-57-10_label_G-0-0 330 430 504 533 1038 1055 179 | Operation.Red.Sea.2018__#01-20-58_01-22-00_label_B5-0-0 1271 1336 180 | Operation.Red.Sea.2018__#01-37-22_01-37-36_label_G-B2-0 150 215 233 275 181 | v=sbEHU4GskX4__#1_label_B1-0-0 608 2030 182 | v=3y2OOc6WTrg__#1_label_B1-0-0 400 790 1345 1700 2684 3273 183 | v=iWfUqxibqG0__#1_label_G-0-0 176 628 184 | v=7isLhaNYgkU__#1_label_B4-0-0 0 5705 185 | v=OB0-MXNPUIU__#1_label_B4-0-0 0 4020 186 | v=FRj2K0ulD8Q__#00-00-00_00-00-57_label_B4-0-0 0 1202 187 | v=FRj2K0ulD8Q__#00-04-16_00-06-05_label_B4-0-0 0 2617 188 | v=FRj2K0ulD8Q__#00-06-05_00-08-25_label_B4-0-0 0 2547 2765 3329 189 | v=ICnreR1hxP0__#1_label_B4-0-0 0 3976 190 | v=MJLylzPRvyw__#1_label_B4-0-0 135 3473 191 | v=OFxPTJxA5pg__#1_label_B4-0-0 0 15000 192 | v=fhiAyxpDQMU__#1_label_B4-0-0 0 3975 193 | v=SE5fxK7SasU__#00-03-23_00-04-08_label_B4-0-0 0 1081 194 | v=_tsSKAsVZfo__#1_label_B4-0-0 153 1090 1170 2480 195 | Quantum.Of.Solace.2008__#01-19-30_01-20-05_label_B1-0-0 131 295 196 | Quantum.Of.Solace.2008__#01-30-26_01-30-42_label_B2-B1-0 30 55 114 144 197 | v=k_cvJa1kNaM__#1_label_B4-0-0 900 4505 4810 5625 198 | v=pL2HjXvWuPc__#1_label_B1-0-0 1090 3392 199 | v=U6MC8JJJZSY__#1_label_G-0-0.mp4 150 779 200 | v=FW_qbiPH7UA__#1_label_B1-0-0.mp4 1654 2080 2367 2745 201 | The.Fast.and.the.Furious.2001__#00-29-00_00-29-52_label_B2-G-0.mp4 310 515 790 1045 202 | The.Fast.and.the.Furious.2001__#01-24-18_01-25-05_label_B6-0-0.mp4 147 347 203 | The.Fast.and.the.Furious.2001__#01-37-40_01-38-30_label_B6-0-0.mp4 167 540 204 | v=aWPWHU8x6kE__#1_label_B4-0-0.mp4 3368 4698 205 | v=LJ0Pu5_Mefs__#1_label_G-0-0.mp4 694 851 206 | v=BpargJW29Wo__#00-00-50_00-01-30_label_B1-0-0.mp4 127 788 207 | v=Y1hxr1MWLjM__#1_label_B1-0-0.mp4 1920 2483 208 | v=UYFR-XbyZQc__#00-01-19_00-02-42_label_G-0-0.mp4 33 333 209 | v=Q9Re4CnFJRg__#00-01-21_00-01-47_label_G-0-0.mp4 53 263 460 615 210 | v=PRH5rYUHoVU__#00-00-25_00-02-50_label_B6-0-0.mp4 27 144 230 440 527 600 715 905 958 1072 1235 1290 1485 1565 1820 1910 2020 3030 3061 3160 3255 3387 211 | v=qmsQ-obL1Z4__#00-03-26_00-04-04_label_B6-0-0.mp4 95 220 615 875 212 | v=f6j3YWgVBto__#00-06-16_00-06-51_label_B6-0-0.mp4 240 370 213 | v=yDqThVpu1AM__#1_label_B4-0-0 0 743 214 | v=KUeUxbsBO6s__#1_label_B1-0-0 1042 2000 215 | v=qqsd-cZr01k__#00-01-00_00-02-45_label_B6-0-0 79 103 381 500 750 825 1026 1060 1340 1380 1700 1880 2100 2255 2430 2480 216 | v=qqsd-cZr01k__#00-06-10_00-07-19_label_B6-0-0 0 66 214 410 665 705 852 900 1120 1160 1285 1330 1350 1385 1468 1510 1598 1657 217 | v=6TR2rcgHm4g__#1_label_B4-0-0 0 344 497 790 1025 1195 1246 1590 218 | v=utQ5AvXtNLA__#1_label_B4-0-0 0 893 219 | v=QWwDsPSe7iM__#1_label_B4-0-0 0 2214 220 | v=pMtu7fOHdII__#1_label_B1-0-0 650 850 221 | Gladiator.2000__#01-10-27_01-11-59_label_B1-0-0 745 1663 222 | God.Bless.America.2011__#00-12-40_00-13-12_label_B2-0-0 122 322 223 | God.Bless.America.2011__#00-38-25_00-39-37_label_B2-B5-0 110 160 300 560 810 1150 224 | God.Bless.America.2011__#01-32-00_01-32-50_label_B2-0-0 122 180 225 | GoldenEye.1995__#00-10-00_00-10-40_label_G-0-0 621 920 226 | GoldenEye.1995__#00-24-53_00-25-06_label_B1-0-0 75 270 227 | GoldenEye.1995__#01-17-05_01-19-57_label_B2-B1-0 232 
260 728 990 1200 1591 1700 1765 2959 3402 3631 3769 3995 4122 228 | GoldenEye.1995__#02-02-15_02-02-47_label_B1-0-0 100 310 380 690 229 | GoldenEye.1995__#02-03-40_02-04-13_label_G-0-0 64 777 230 | v=ovQ1VTJ_IUI__#1_label_B1-0-0 140 370 660 861 231 | v=tivXK3PGByk__#1_label_B1-0-0 1037 1800 2300 2750 232 | v=m5zK-tzYCQM__#1_label_B4-0-0 126 1630 233 | v=Q6zIX29HkSo__#1_label_B4-0-0 43 2100 234 | Haywire.2011__#00-03-51_00-05-12_label_B1-B2-0 261 797 799 810 872 1470 235 | Haywire.2011__#00-16-07_00-16-49_label_B1-0-0 153 640 236 | Haywire.2011__#00-39-07_00-41-48_label_B1-B2-0 111 2139 2459 3210 3504 3530 237 | Haywire.2011__#01-20-20_01-21-57_label_B1-0-0 578 2225 238 | v=3wxWNAM8Cso__#1_label_G-0-0 4318 4612 4881 5495 239 | v=YdrISbwy_zI__#1_label_G-0-0 530 1362 240 | v=8ewWXhYRUNs__#1_label_B1-0-0 88 910 241 | v=zoIn2hOrUIM__#00-01-00_00-02-45_label_B6-0-0 156 240 521 580 1315 1370 1694 1766 2067 2146 2460 2521 242 | v=zoIn2hOrUIM__#00-06-10_00-07-19_label_B6-0-0 226 270 369 410 790 825 1078 1110 1462 1473 243 | v=yy-KIWDcBr4__#00-08-35_00-10-25_label_B6-0-0 0 187 270 360 852 930 1155 1230 1501 1556 2100 2200 244 | v=yy-KIWDcBr4__#00-16-10_00-17-19_label_B6-0-0 0 90 350 424 514 550 856 956 1040 1103 1576 1657 245 | v=5Tnl7_8RqlA__#00-11-00_00-12-45_label_B6-0-0 110 137 391 465 746 990 1060 1180 1430 1490 1850 1880 2128 2158 2390 2470 246 | v=BzwNU2xmT64__#00-01-00_00-02-45_label_B6-0-0 0 45 246 285 760 1006 1180 1276 1632 1738 2241 2360 247 | v=BzwNU2xmT64__#00-04-10_00-06-09_label_B6-0-0 27 160 500 560 843 888 1060 1160 1360 1510 1637 1685 1840 1880 2233 2320 248 | v=X1arMZmYhsk__#00-13-00_00-14-09_label_B6-0-0 329 450 660 746 943 990 1200 1250 1552 1620 249 | v=uBXprlsmd18__#00-03-00_00-04-09_label_B6-0-0 107 170 974 1020 1600 1657 250 | v=uBXprlsmd18__#00-06-10_00-07-19_label_B6-0-0 356 529 1075 1156 1440 1657 251 | v=7rDRFFSUrPI__#00-01-50_00-02-32_label_G-0-0 655 700 840 860 912 930 252 | v=EH_QB6cm6BE__#1_label_G-0-0 158 400 535 656 253 | v=q_DhkdHGXos__#1_label_B1-0-0 0 1420 254 | v=Dh-3BlhAOnE__#00-00-00_00-13-36_label_B4-0-0 0 170 14846 19116 255 | v=rAlZRFZTwxM__#00-03-13_00-07-08_label_G-0-0 264 304 566 859 2391 2422 2780 3190 256 | v=Dn_SDu22WpM__#00-08-35_00-10-05_label_B6-0-0 55 100 253 313 437 500 755 788 989 1040 1812 1893 257 | v=9Jk2sIp5MRQ__#1_label_G-0-0 175 215 273 567 258 | v=UK2w9Sh47fM__#1_label_G-0-0 210 500 259 | v=vv-MFJPi4Qs__#1_label_G-0-0 649 680 260 | v=hAHXCMRvY_I__#1_label_G-0-0 165 278 261 | v=w16YmhAZ1_Y__#1_label_G-0-0 64 315 262 | v=78u9uBJBqIw__#1_label_B4-B5-0 1330 1784 2059 2360 263 | IP.Man.2.2010__#00-07-14_00-09-13_label_B1-0-0 579 742 1010 1350 1745 1870 1995 2190 264 | IP.Man.2.2010__#01-35-55_01-39-11_label_B1-0-0 637 748 890 1150 2564 3962 265 | Ip.Man.3.2015__#00-17-51_00-18-51_label_B1-0-0 0 145 305 470 266 | Ip.Man.3.2015__#00-44-45_00-45-08_label_B1-0-0 193 365 267 | Ip.Man.3.2015__#00-46-08_00-54-20_label_B1-0-0 4826 4964 5050 5134 5712 6300 6572 10880 11440 11530 11700 11807 268 | Ip.Man.2008__#00-46-51_00-47-30_label_B1-0-0 310 520 269 | Ip.Man.2008__#00-48-57_00-50-47_label_B1-0-0 12 148 502 730 1260 2047 270 | Ip.Man.2008__#01-14-31_01-15-24_label_B1-0-0 177 462 271 | Ip.Man.2008__#01-40-37_01-42-07_label_B4-0-0 105 1482 272 | v=cO1UefhG7AY__#1_label_G-0-0 1048 1200 1940 2190 4180 4300 273 | v=v_LxqgpRouM__#1_label_G-0-0 148 888 274 | v=CUr-Mg8Hmh0__#1_label_G-0-0 3668 3853 4409 4600 275 | v=u0Tnw01PNxc__#1_label_B4-0-0 0 1376 276 | Jason.Bourne.2016__#0-17-25_0-18-53_label_B4-0-0 85 485 640 1280 1522 1950 277 | 
Jason.Bourne.2016__#0-29-48_0-30-29_label_B4-0-0 117 300 500 660 278 | Jason.Bourne.2016__#0-50-20_0-50-30_label_G-0-0 12 77 279 | v=PuZy1PQIgrA__#1_label_B1-0-0 528 1120 2042 2590 280 | v=Q7NpJXOTiMY__#1_label_G-0-0 10 900 281 | v=JfLYNEsrTew__#1_label_G-0-0 52 430 282 | v=w2NgAYJHnS0__#1_label_G-0-0 50 300 283 | v=9Ydg5IeZpFI__#1_label_B1-0-0 55 660 1131 1710 2231 2475 2799 3291 3830 4489 284 | Kill.Bill.Vol.1.2003__#01-17-20_01-20-20_label_B1-0-0 1835 4273 285 | Kill.Bill.Vol.1.2003__#01-21-39_01-26-55_label_B1-0-0 1829 6854 286 | Kill.Bill.Vol.1.2003__#01-26-56_01-28-18_label_B1-0-0 297 1880 287 | Kill.Bill.Vol.2.2004__#0-25-20_0-26-20_label_B2-0-0 970 1010 288 | Kingsman.The.Golden.Circle.2017__#00-35-55_00-36-35_label_B1-0-0 287 693 289 | Kingsman.The.Golden.Circle.2017__#00-41-22_00-41-32_label_B2-0-0 14 44 290 | Kingsman.The.Secret.Service.2014__#00-22-10_00-23-10_label_B2-0-0 0 230 291 | Kingsman.The.Secret.Service.2014__#01-43-40_01-44-50_label_B2-G-0 96 117 1394 1438 1555 1580 292 | Law.Abiding.Citizen.2009__#01-26-34_01-27-25_label_B2-0-0 626 990 293 | v=SMy2_qNO2Y0__#00-01-50_00-03-13_label_G-0-0 257 310 1040 1300 294 | v=qyTGi0N4OVg__#1_label_G-0-0 646 1120 295 | v=CvMpD1yqfxI__#1_label_G-0-0 194 338 560 720 296 | v=ZCjNPgNu42g__#1_label_B4-0-0 1060 2970 297 | v=gp_D8r-2hwk__#1_label_G-0-0 1365 1600 298 | Lord.of.War__#00-50-10_00-50-50_label_G-0-0 119 230 299 | Lord.of.War__#01-41-04_01-41-53_label_B1-0-0 295 318 433 467 300 | v=cuUK-3MNqtw__#1_label_B1-0-0 1280 1450 301 | Love.Death.and.Robots.S01E10__#0-01-00_0-02-13_label_B2-G-0 560 600 846 905 1050 1295 302 | Love.Death.and.Robots.S01E15__#0-05-38_0-06-18_label_G-0-0 340 430 303 | v=OE9t3XImErk__#1_label_B1-0-0 193 290 304 | v=3e3C-LQWF2s__#1_label_B1-0-0 141 1000 1503 1665 1860 2200 3170 3300 3882 3960 4085 4420 305 | v=HUL83wCQATc__#1_label_B1-0-0 145 402 1000 1403 306 | Mindhunters.2004__#01-19-00_01-20-05_label_B2-0-0 0 375 546 604 307 | Mission.Impossible.Fallout.2018__#00-31-21_00-32-50_label_B1-0-0 96 2125 308 | Mission.Impossible.Fallout.2018__#00-39-18_00-40-36_label_B1-0-0 327 1680 309 | Mission.Impossible.Fallout.2018__#02-03-50_02-04-35_label_B1-0-0 222 976 310 | Mission.Impossible.Ghost.Protocol.2011__#00-01-18_00-01-40_label_B2-0-0 63 200 371 410 311 | Mission.Impossible.Ghost.Protocol.2011__#01-12-38_01-13-43_label_B1-0-0 180 630 1086 1122 1260 1390 312 | Mission.Impossible.Ghost.Protocol.2011__#01-54-24_01-54-50_label_B1-0-0 87 347 313 | Mission.Impossible.Ghost.Protocol.2011__#01-56-52_01-57-03_label_B1-0-0 54 120 314 | Mission.Impossible.II.2000__#01-18-55_01-20-53_label_B2-0-0 159 330 315 | Mission.Impossible.II.2000__#01-25-08_01-25-31_label_B2-0-0 9 117 316 | Mission.Impossible.II.2000__#01-29-30_01-29-44_label_B1-0-0 136 280 317 | Mission.Impossible.II.2000__#01-32-56_01-33-25_label_B1-0-0 460 600 318 | Mission.Impossible.II.2000__#01-38-25_01-38-41_label_B2-B5-0 130 165 260 295 319 | Mission.Impossible.II.2000__#01-46-20_01-47-11_label_B2-B6-0 10 56 84 135 184 413 608 919 320 | Mission.Impossible.III.2006__#00-19-56_00-20-26_label_B2-0-0 19 560 321 | Mission.Impossible.III.2006__#01-06-52_01-07-15_label_G-0-0 259 346 322 | Mission.Impossible.III.2006__#01-54-02_01-54-40_label_B2-0-0 206 300 554 735 323 | Mission.Impossible.V.Rogue.Nation.2015__#00-14-40_00-17-12_label_B1-0-0 170 222 395 420 538 558 1445 3296 324 | Mission.Impossible.V.Rogue.Nation.2015__#00-17-27_00-17-55_label_B2-0-0 476 579 325 | Mission.Impossible.V.Rogue.Nation.2015__#01-15-52_01-16-48_label_B6-B2-0 84 216 328 407 526 995 326 | 
Mission.Impossible.V.Rogue.Nation.2015__#01-19-48_01-20-20_label_B6-0-0 192 332 327 | v=iegHZ_UWWsA__#00-14-05_00-20-33_label_B4-0-0 1108 5466 5921 7130 328 | v=8oJMGgaww70__#1_label_B4-0-0 0 3578 329 | v=Y43UkfAe_bc__#00-00-00_00-01-40_label_B4-0-0 0 2400 330 | v=aQ3qMpgZjwg__#00-00-00_00-01-48_label_B4-0-0 637 1530 1800 2100 331 | v=3xkwZk44VQs__#1_label_B4-0-0 0 771 792 1489 1509 1978 332 | v=yvnj5VIDsNI__#1_label_B4-0-0 0 1197 3496 3710 3765 3910 333 | v=GbMQ2CxeNHU__#1_label_B4-0-0 0 5489 334 | v=3zgaqeVSXuI__#1_label_B4-0-0 0 663 670 1320 1500 1604 1762 1848 335 | v=ADQNFs9bfFk__#00-00-00_00-00-58_label_B4-0-0 0 350 465 1379 336 | v=fIlmgvc-bUk__#1_label_B4-0-0 123 1370 337 | v=kPWdgckIhLI__#1_label_B4-0-0 0 558 338 | Rush.Hour.3.2007.BluRay__#00-04-59_00-05-31_label_B2-0-0 573 623 339 | Rush.Hour.3.2007.BluRay__#00-34-11_00-35-46_label_B1-0-0 416 738 805 1034 1118 1390 1461 1788 1862 2055 2232 2258 340 | Rush.Hour.3.2007.BluRay__#00-40-56_00-41-52_label_B2-0-0 665 684 1145 1195 341 | Rush.Hour.3.2007.BluRay__#01-12-16_01-12-36_label_B1-0-0 0 450 342 | Rush.Hour.1998.BluRay__#00-08-08_00-09-15_label_B2-B6-G 115 174 411 445 457 494 594 620 640 830 1000 1127 1248 1445 343 | Rush.Hour.1998.BluRay__#00-11-05_00-12-10_label_B2-B1-0 140 175 275 575 935 1360 344 | Rush.Hour.1998.BluRay__#01-20-28_01-21-38_label_B2-0-0 122 200 220 300 316 357 762 848 998 1049 1105 1174 1323 1475 345 | Rush.Hour.1998.BluRay__#01-25-21_01-25-46_label_B2-0-0 437 519 346 | v=wnd3IYH7x1o__#1_label_G-0-0 260 440 746 1000 347 | v=lcBUb7EOQ4o__#1_label_G-0-0 913 1050 1132 1270 1306 2235 348 | Salt.2010__#00-17-40_00-18-16_label_B1-0-0 530 864 349 | Salt.2010__#01-06-27_01-07-07_label_G-B2-0 25 92 119 137 190 270 319 526 690 730 350 | Salt.2010__#01-12-05_01-12-32_label_B2-G-0 123 296 340 456 351 | Saving.Private.Ryan.1998__#00-47-50_00-49-35_label_B2-0-0 483 670 801 887 1211 1431 352 | Saving.Private.Ryan.1998__#02-23-00_02-23-28_label_B1-0-0 38 672 353 | Saving.Private.Ryan.1998__#02-23-58_02-24-18_label_G-0-0 228 311 354 | Saving.Private.Ryan.1998__#02-24-18_02-25-42_label_B1-0-0 72 428 550 830 1085 2016 355 | Saving.Private.Ryan.1998__#02-25-42_02-26-03_label_B2-0-0 247 504 356 | v=Qc3pAEeK16A__#1_label_B1-0-0 70 237 980 1218 357 | v=6KQFxxHmVMc__#1_label_B1-0-0 144 590 1991 2100 2220 2340 2479 2640 358 | Shoot.Em.Up.2007__#00-32-40_00-35-10_label_B2-0-0 43 74 114 159 234 371 560 781 905 1426 1522 1696 1790 2185 2374 2434 2517 2697 2795 3432 359 | Shoot.Em.Up.2007__#00-53-25_00-56-47_label_B2-0-0 101 200 705 760 2650 2709 2875 3120 3217 3469 4021 4265 4770 4820 360 | Shoot.Em.Up.2007__#00-56-56_00-57-15_label_B2-0-0 0 380 361 | Sin.City.2005__#0-06-42_0-07-06_label_B1-0-0 340 380 362 | Sin.City.2005__#0-10-40_0-11-40_label_B2-0-0 50 106 355 392 466 480 555 580 650 750 363 | Sin.City.2005__#0-22-04_0-22-18_label_B5-0-0 0 336 364 | Sin.City.2005__#0-29-34_0-29-56_label_B2-0-0 355 425 470 510 365 | Sin.City.2005__#0-32-00_0-32-16_label_B2-0-0 125 159 280 320 366 | Sin.City.2005__#0-43-45_0-44-01_label_B2-0-0 52 148 367 | Sin.City.2005__#01-19-08_01-19-23_label_G-0-0 75 150 368 | Sin.City.2005__#01-51-59_01-52-29_label_B2-0-0 35 80 555 610 369 | Skyfall.2012__#00-03-22_00-03-40_label_B6-0-0 180 432 370 | Skyfall.2012__#00-06-10_00-06-25_label_B6-0-0 165 255 371 | Skyfall.2012__#01-20-05_01-20-20_label_B1-B2-0 45 260 372 | Skyfall.2012__#01-55-44_01-56-57_label_B2-G-0 54 298 370 467 548 596 687 740 944 1076 1153 1182 1235 1449 373 | Skyfall.2012__#01-57-40_01-58-00_label_B2-0-0 0 168 374 | 
Skyfall.2012__#02-08-44_02-09-17_label_B1-0-0 78 655 375 | v=-kwNh1lMU-w__#1_label_B4-0-0 0 837 1360 4080 376 | v=_BgJEXQkjNQ__#1_label_G-0-0 1720 5145 5302 7938 377 | v=xe4ee56aHSg__#1_label_G-0-0 270 531 378 | v=2al4t5inObA__#1_label_B4-0-0 0 3322 379 | Spectre.2015__#00-44-05_00-44-40_label_B1-B2-0 38 217 254 329 539 817 380 | Spectre.2015__#01-08-58_01-09-20_label_B1-B2-0 50 100 295 320 330 355 381 | Spectre.2015__#01-10-35_01-10-53_label_B2-0-0 170 237 270 337 382 | Spectre.2015__#01-32-55_01-35-40_label_B1-B2-0 92 132 134 154 177 2870 2960 2980 3126 3743 383 | Spectre.2015__#02-02-57_02-03-27_label_B1-B2-0 384 430 384 | Spectre.2015__#02-15-30_02-16-05_label_G-0-0 430 571 385 | v=rXWOpZ7W2fA__#1_label_B1-0-0 1500 1570 1665 1800 1957 2137 386 | Taken.2.UNRATED.EXTENDED.2012__#00-13-42_00-14-16_label_B5-0-0 40 424 387 | Taken.2.UNRATED.EXTENDED.2012__#00-45-50_00-46-20_label_G-0-0 374 418 388 | Taken.2.UNRATED.EXTENDED.2012__#01-00-00_01-00-16_label_B2-0-0 327 368 389 | Taken.2.UNRATED.EXTENDED.2012__#01-06-28_01-06-49_label_B2-0-0 49 228 390 | Taken.2.UNRATED.EXTENDED.2012__#01-18-25_01-18-40_label_B2-0-0 49 260 391 | Taken.2.UNRATED.EXTENDED.2012__#01-18-40_01-19-00_label_B1-0-0 18 68 392 | Taken.3.2014__#00-03-58_00-04-16_label_B2-0-0 44 70 393 | Taken.3.2014__#01-04-50_01-05-05_label_B1-0-0 109 147 394 | Taken.3.2014__#01-12-48_01-14-30_label_B2-B1-0 47 585 685 1516 2060 2080 395 | Taken.3.2014__#01-19-13_01-19-30_label_B1-0-0 39 204 396 | Taken.Extended.Cut.2008__#00-33-51_00-34-03_label_B1-0-0 139 185 206 281 397 | Taken.Extended.Cut.2008__#00-50-45_00-51-04_label_B2-B6-0 103 190 430 455 398 | v=rKpXqwE2rg8__#00-01-57_00-03-06_label_B4-0-0 0 1657 399 | v=96N4XAJ7FKM__#1_label_B1-0-0 922 1936 2831 3281 400 | v=nLAapCIlr-o__#00-07-22_00-08-19_label_B6-0-0 85 140 399 650 921 1076 401 | v=nLAapCIlr-o__#00-08-35_00-10-25_label_B6-0-0 105 200 662 840 1295 1386 1852 1913 2106 2197 2590 2630 402 | v=nLAapCIlr-o__#00-11-00_00-12-45_label_B6-0-0 35 75 319 362 854 917 1523 1596 1747 1800 2013 2098 2417 2521 403 | v=nLAapCIlr-o__#00-16-10_00-18-09_label_B6-0-0 15 120 692 770 1337 1486 1874 2000 2240 2400 2793 2830 404 | v=osjdmjNJUdg__#1_label_B1-0-0 191 856 1380 1985 405 | The.Bourne.Identity.2002__#0-19-53_0-21-48_label_B1-0-0 1931 2115 2641 2750 406 | The.Bourne.Identity.2002__#0-54-45_0-56-25_label_B6-0-0 504 590 633 745 872 945 2052 2106 407 | The.Bourne.Identity.2002__#01-27-39_01-29-19_label_B2-0-0 530 560 1455 1514 1794 1830 408 | The.Bourne.Identity.2002__#01-45-00_01-45-46_label_B1-0-0 240 296 525 559 1045 1086 409 | The.Bourne.Legacy.2012__#0-39-56_0-40-15_label_G-0-0 24 119 410 | The.Bourne.Legacy.2012__#0-59-18_01-01-52_label_B2-B1-0 46 879 946 1018 1111 1160 1212 1354 1426 1660 3300 3330 3340 3610 411 | The.Bourne.Ultimatum.2007__#01-26-27_01-28-03_label_B2-B6-0 475 569 740 900 983 1250 1531 1590 1754 2039 412 | The.Bourne.Ultimatum.2007__#01-44-02_01-44-11_label_B2-0-0 60 80 413 | The.Fast.and.the.Furious.2001__#00-52-18_00-53-36_label_B1-0-0 542 1033 1734 1800 414 | The.Fast.and.the.Furious.2001__#01-14-04_01-15-20_label_B1-0-0 715 979 1030 1072 1221 1250 1333 1460 415 | v=0W8LohxH9nI__#1_label_B1-0-0 933 1250 416 | The.Hurt.Locker.2008__#0-08-45_0-09-57_label_G-0-0 858 1560 417 | The.Hurt.Locker.2008__#0-19-22_0-22-32_label_B2-B1-0 2094 2125 2244 2270 2532 2575 3580 3750 418 | The.Hurt.Locker.2008__#01-02-10_01-03-55_label_B2-0-0 448 520 755 830 1225 1374 1968 1990 419 | The.Hurt.Locker.2008__#01-27-42_01-28-09_label_G-0-0 270 413 420 | 
v=7QOdBpDrmKk__#00-01-00_00-02-45_label_B6-0-0 32 59 240 281 771 797 1100 1130 1894 1988 2230 2270 421 | v=7QOdBpDrmKk__#00-10-30_00-10-51_label_B6-0-0 332 412 422 | v=k7R_Qo-BiAw__#00-10-30_00-10-51_label_B6-0-0 0 150 296 505 423 | The.World.Is.Not.Enough.1999__#00-07-04_00-07-45_label_G-B2-0 270 480 847 944 424 | The.World.Is.Not.Enough.1999__#00-10-22_00-10-40_label_G-0-0 6 65 425 | v=DRGMajXo_PI__#00-02-15_00-03-02_label_G-0-0 0 1111 426 | v=ONsmJAyFAAw__#1_label_G-0-0 0 1523 427 | v=Rc9EOFjtj0c__#1_label_B6-0-0 564 625 894 988 1307 1389 1615 1820 1852 1980 2263 2328 2701 2738 2901 2980 3109 3270 3301 3342 428 | v=FfrXebzAwOg__#00-06-10_00-07-19_label_B6-0-0 165 202 341 410 599 700 780 852 996 1022 1117 1181 1442 1516 429 | v=DtxU8UYiFws__#1_label_G-0-0 272 415 430 | Tropa.de.Elite.2.2010__#00-13-41_00-14-56_label_B4-0-0 0 535 1677 1735 431 | Tropa.de.Elite.2.2010__#01-33-01_01-33-30_label_B2-0-0 138 176 211 260 432 | Tropa.de.Elite.2.2010__#00-42-00_00-43-00_label_B2-0-0 777 900 433 | Tropa.de.Elite.2.2010__#00-46-11_00-46-50_label_B2-0-0 109 265 434 | v=Lt9M_tSij_4__#00-00-00_00-04-24_label_B6-0-0 170 361 502 657 910 1044 1173 1700 1883 2000 2103 2158 2573 2642 2982 3217 3350 3461 3653 3770 4092 4220 4483 4548 4859 4982 5187 5291 5515 5724 6289 6335 435 | v=BXR3d22BhHs__#00-06-00_00-07-00_label_B4-0-0 1 1425 436 | v=BXR3d22BhHs__#00-07-00_00-08-00_label_B4-0-0 1 1431 437 | v=tP19WuyY3IY__#00-00-00_00-00-51_label_B6-0-0 104 189 524 635 888 966 438 | v=tP19WuyY3IY__#00-06-10_00-07-51_label_B6-0-0 1243 1270 1631 1766 439 | v=YANmwpkuWRo__#1_label_G-0-0 968 1072 440 | v=Q-u3Y9TkZhQ__#1_label_B4-0-0 5 1031 441 | v=IWzI9V3WSnc__#1_label_B4-0-0 4 2687 442 | v=X6Tbn8_X-Os__#1_label_B1-0-0 200 914 1790 3450 443 | v=1GFBuaGCeTo__#1_label_G-0-0 802 1230 444 | v=GN-7_ye5k4E__#1_label_G-0-0 572 1300 445 | v=eqtJjxsTgtg__#1_label_G-0-0 343 480 446 | v=k98yVPXiYoU__#1_label_G-0-0 64 252 447 | v=bDm2-1NZBLw__#1_label_B4-0-0 1 2209 448 | v=2Rr21qkZEDQ__#00-03-00_00-04-09_label_B6-0-0 548 647 1300 1410 449 | v=2Rr21qkZEDQ__#00-14-10_00-16-09_label_B6-0-0 482 576 922 1032 1389 1444 2035 2160 2472 2670 450 | v=CcD-YNlkMPY__#1_label_G-0-0 1300 1356 451 | v=V7FtjJStY1M__#1_label_G-0-0 44 359 452 | v=v-Plzx73K68__#1_label_B4-0-0 316 921 453 | v=AF9p0YA0hyA__#00-00-45_00-01-44_label_B4-0-0 3 1417 454 | v=AF9p0YA0hyA__#00-03-03_00-04-20_label_B4-0-0 1 1841 455 | v=_5Kmb4tqMxs__#1_label_B4-0-0 23 1250 456 | v=TsuuRo6PEXM__#1_label_B1-0-0 307 1218 2306 3200 457 | wangted.2008__#0-22-55_0-23-40_label_B2-0-0 395 560 632 697 458 | wangted.2008__#0-45-25_0-47-01_label_B1-0-0 28 1245 459 | wangted.2008__#0-48-30_0-49-00_label_B2-0-0 346 511 460 | v=YhCSFZNwfHc__#00-01-15_00-01-42_label_G-B2-0 0 65 285 500 461 | v=_YobflFU_HU__#00-06-10_00-07-51_label_B6-0-0 0 65 300 439 630 712 1130 1240 1500 1620 1935 2058 2290 2398 462 | v=g2v3EkBj9Zc__#1_label_B4-0-0 0 530 463 | v=uprb6aBzymw__#00-03-03_00-05-13_label_B4-0-0 0 125 530 790 2045 2887 464 | Yellow.Sea.2010__#01-04-00_01-04-50_label_B1-0-0 0 356 465 | Yellow.Sea.2010__#01-05-40_01-06-47_label_B1-B2-0 140 407 735 1180 466 | Yellow.Sea.2010__#01-58-30_01-59-27_label_B5-0-0 0 787 467 | v=WiS3TIvykeY__#1_label_B4-0-0 0 1560 468 | v=UzxuX79xq4s__#1_label_B4-0-0 270 840 1312 2989 3430 4660 469 | v=lpkL0Y1MhA8__#1_label_B4-0-0 0 248 397 585 1165 1940 2393 2523 470 | v=1io9Uh54Vhg__#1_label_B4-0-0 0 1976 471 | v=m5Ya6z6qroo__#1_label_B4-0-0 138 5505 472 | Young.And.Dangerous.I.1996__#0-20-50_0-22-15_label_B1-0-0 425 1079 473 | 
Young.And.Dangerous.I.1996__#0-46-25_0-46-55_label_B6-0-0 350 570 474 | Young.And.Dangerous.I.1996__#0-46-57_0-48-57_label_B1-0-0 0 2454 475 | Young.And.Dangerous.III.1996__#00-23-00_00-25-54_label_B4-0-0 2497 3051 476 | Young.And.Dangerous.IV.1997__#01-07-50_01-09-45_label_B1-0-0 368 2734 477 | Young.And.Dangerous.IV.1997__#01-32-04_01-32-58_label_B1-0-0 650 1100 478 | Taken.3.2014__#01-09-26_01-09-54_label_G-0-0 22 258 479 | Bullet.in.the.Head.1990__#00-23-31_00-24-40_label_G-0-0 1325 1400 480 | v=DVdXoVUNkhg__#1_label_G-0-0 990 1145 481 | v=CczYDDDf22A__#1_label_G-0-0 190 240 482 | v=9CWJd1SezkA__#1_label_G-0-0 464 614 483 | v=8oTjTufJnXI__#1_label_G-0-0 977 1610 484 | Mission.Impossible.III.2006__#00-56-23_00-56-40_label_G-0-0 260 363 485 | v=ntaXjnpusxM__#1_label_G-0-0 312 500 486 | Saving.Private.Ryan.1998__#02-20-00_02-20-13_label_G-0-0 115 215 487 | Saving.Private.Ryan.1998__#02-16-58_02-17-13_label_G-0-0 147 300 488 | Love.Death.and.Robots.S01E13__#0-11-48_0-12-07_label_G-0-0 60 240 489 | v=35zynOrkvMk__#1_label_G-0-0 114 164 490 | v=1q5V6DKH3bw__#1_label_B4-0-0 0 425 680 1515 2030 2353 2500 2950 491 | v=3WmMhircZOc__#1_label_B4-0-0 580 4150 492 | v=-etV57xZ4_I__#1_label_B4-0-0 0 260 493 | v=tv0rI-5ycBU__#1_label_B4-0-0 0 130 206 2945 494 | v=2WkuPNLfl5s__#1_label_B4-0-0 225 1008 495 | v=CFfc4UWzVB8__#00-00-00_00-00-30_label_B4-0-0 508 581 605 721 496 | v=fQQzg2VfJME__#1_label_B4-0-0 117 1300 497 | v=0N0PbzjEg0U__#1_label_B4-0-0 0 550 498 | v=p7kvQ4OpDgU__#1_label_B4-0-0 270 1370 499 | v=GI49YSCruwY__#1_label_B4-0-0 0 1410 1795 1965 2320 2980 500 | City.of.God.2002__#01-24-10_01-25-10_label_B2-0-0 340 1090 --------------------------------------------------------------------------------
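Note on the annotation format above: each line of list/annotations.txt is whitespace-separated. The first token is the clip name, with the violence category code embedded after '_label_' (e.g. B4-0-0 or B2-G-0; a few entries carry a stray '.mp4' suffix), and the remaining tokens are alternating start/end frame indices of the abnormal segments. A minimal parsing sketch under that reading (the helper name is hypothetical, not part of this repo):

def parse_annotation_line(line: str):
    # first token: clip name; remaining tokens: alternating start/end frame indices
    tokens = line.strip().split()
    name = tokens[0]
    if name.endswith('.mp4'):  # a few lines include the extension; normalize it away
        name = name[:-len('.mp4')]
    label = name.split('_label_')[-1]                 # e.g. 'B4-0-0'
    frames = list(map(int, tokens[1:]))
    segments = list(zip(frames[0::2], frames[1::2]))  # [(start, end), ...]
    return name, label, segments

# example:
# parse_annotation_line('v=cEOM18n8fhU__#1_label_G-0-0 2133 2750')
# -> ('v=cEOM18n8fhU__#1_label_G-0-0', 'G-0-0', [(2133, 2750)])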