├── src
│   ├── clip
│   │   ├── __init__.py
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── simple_tokenizer.py
│   │   ├── clip.py
│   │   └── model.py
│   ├── xd_option.py
│   ├── ucf_option.py
│   ├── utils
│   │   ├── dataset.py
│   │   ├── lr_warmup.py
│   │   ├── tools.py
│   │   ├── xd_detectionMAP.py
│   │   ├── ucf_detectionMAP.py
│   │   └── layers.py
│   ├── crop.py
│   ├── xd_test.py
│   ├── ucf_test.py
│   ├── xd_train.py
│   ├── ucf_train.py
│   └── model.py
├── list
│   ├── gt.npy
│   ├── gt_ucf.npy
│   ├── gt_label.npy
│   ├── gt_segment.npy
│   ├── gt_label_ucf.npy
│   ├── gt_segment_ucf.npy
│   ├── make_list_xd.py
│   ├── make_list_ucf.py
│   ├── make_gt_xd.py
│   ├── make_gt_mAP_xd.py
│   ├── make_gt_mAP_ucf.py
│   ├── make_gt_ucf.py
│   ├── Anomaly_Test.txt
│   ├── Temporal_Anomaly_Annotation.txt
│   ├── ucf_CLIP_rgbtest.csv
│   └── annotations.txt
├── data
│   └── framework.png
├── README.md
└── LICENSE
/src/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/list/gt.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt.npy
--------------------------------------------------------------------------------
/list/gt_ucf.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_ucf.npy
--------------------------------------------------------------------------------
/data/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/data/framework.png
--------------------------------------------------------------------------------
/list/gt_label.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_label.npy
--------------------------------------------------------------------------------
/list/gt_segment.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_segment.npy
--------------------------------------------------------------------------------
/list/gt_label_ucf.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_label_ucf.npy
--------------------------------------------------------------------------------
/list/gt_segment_ucf.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/list/gt_segment_ucf.npy
--------------------------------------------------------------------------------
/src/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/nwpu-zxr/VadCLIP/HEAD/src/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/list/make_list_xd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | import csv
4 |
5 | root_path = '/home/xbgydx/Desktop/XDTrainClipFeatures' ## the path of features
6 | files = sorted(glob.glob(os.path.join(root_path, "*.npy")))
7 | violents = []
8 | normal = []
9 |
10 | with open('list/xd_CLIP_rgb.csv', 'w+') as f: ## the name of feature list
11 | writer = csv.writer(f)
12 | writer.writerow(['path', 'label'])
13 | for file in files:
14 | if '.npy' in file:
15 | if '_label_A' in file:
16 | normal.append(file)
17 | else:
18 | label = file.split('_label_')[1].split('__')[0]
19 | writer.writerow([file, label])
20 |
21 | for file in normal:
22 | writer.writerow([file, 'A'])
--------------------------------------------------------------------------------
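A minimal sketch of the label parsing above on a hypothetical XD-Violence feature path (the directory and filename are made up; real names follow the `..._label_<classes>__<crop>.npy` pattern this script relies on):

```python
# Hypothetical path; only the '_label_' and '__' markers matter to the parser.
file = '/path/to/XDTrainClipFeatures/Movie.2001__#00-10-00_00-12-00_label_B1-0-0__0.npy'

if '_label_A' in file:                               # 'A' marks a normal video
    label = 'A'
else:
    label = file.split('_label_')[1].split('__')[0]

print(label)  # B1-0-0 -- a multi-label string, later split on '-' in utils/tools.py
```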
/list/make_list_ucf.py:
--------------------------------------------------------------------------------
1 | import os
2 | import csv
3 |
4 | root_path = '/home/xbgydx/Desktop/UCFClipFeatures/'
5 | txt = 'list/Anomaly_Train.txt'
6 | files = list(open(txt))
7 | normal = []
8 | count = 0
9 |
10 | with open('list/ucf_CLIP_rgb.csv', 'w+') as f: ## the name of feature list
11 | writer = csv.writer(f)
12 | writer.writerow(['path', 'label'])
13 | for file in files:
14 | filename = root_path + file[:-5] + '__0.npy'
15 | label = file.split('/')[0]
16 | if os.path.exists(filename):
17 | if 'Normal' in label:
18 | #continue
19 | filename = filename[:-5]
20 | for i in range(0, 10, 1):
21 | normal.append(filename + str(i) + '.npy')
22 | else:
23 | filename = filename[:-5]
24 | for i in range(0, 10, 1):
25 | writer.writerow([filename + str(i) + '.npy', label])
26 | else:
27 | count += 1
28 | print(filename)
29 |
30 | for file in normal:
31 | writer.writerow([file, 'Normal'])
32 |
33 | print(count)
--------------------------------------------------------------------------------
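For reference, a sketch of how one (hypothetical) line of `Anomaly_Train.txt` expands into ten feature paths; the `__0` .. `__9` suffixes correspond to the ten crops produced by `src/crop.py`:

```python
root_path = '/path/to/UCFClipFeatures/'   # placeholder feature directory
line = 'Abuse/Abuse001_x264.mp4\n'        # one raw line, trailing newline included

stem = root_path + line[:-5]              # strips '.mp4' plus the newline
label = line.split('/')[0]                # 'Abuse'
paths = [stem + '__{}.npy'.format(i) for i in range(10)]
print(paths[0])  # /path/to/UCFClipFeatures/Abuse/Abuse001_x264__0.npy
```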
/list/make_gt_xd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import cv2
4 |
5 | clip_len = 16
6 |
7 | # the csv list of testing features
8 | feature_list = 'list/xd_CLIP_rgbtest.csv'
9 | # the ground truth txt
10 |
11 | gt_txt = 'list/annotations.txt' ## the path of test annotations
12 | gt_lines = list(open(gt_txt))
13 | gt = []
14 | lists = pd.read_csv(feature_list)
15 | count = 0
16 |
17 | for idx in range(lists.shape[0]):
18 | name = lists.loc[idx]['path']
19 | if '__0.npy' not in name:
20 | continue
21 | #feature = name.split('label_')[-1]
22 | fea = np.load(name)
23 | lens = (fea.shape[0] + 1) * clip_len
24 | name = name.split('/')[-1]
25 | name = name[:-7]
26 |     # frame-level 0/1 ground truth for this video (1 = abnormal)
27 |
28 | gt_vec = np.zeros(lens).astype(np.float32)
29 | if 'label_A' not in name:
30 | for gt_line in gt_lines:
31 | if name in gt_line:
32 | count += 1
33 | gt_content = gt_line.strip('\n').split()
34 | abnormal_fragment = [[int(gt_content[i]),int(gt_content[j])] for i in range(1,len(gt_content),2) \
35 | for j in range(2,len(gt_content),2) if j==i+1]
36 | if len(abnormal_fragment) != 0:
37 | abnormal_fragment = np.array(abnormal_fragment)
38 | for frag in abnormal_fragment:
39 | gt_vec[frag[0]:frag[1]]=1.0
40 | break
41 | gt.extend(gt_vec[:-clip_len])
42 |
43 | np.save('list/gt_xd.npy', gt) ## note: xd_option.py --gt-path defaults to list/gt.npy
44 |
45 | print(count)
--------------------------------------------------------------------------------
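As a worked example of the frame-level ground truth built above, using a hypothetical annotation line and a simplified (but equivalent) version of the nested fragment comprehension:

```python
import numpy as np

clip_len = 16
num_clips = 20                                        # stand-in for fea.shape[0]
lens = (num_clips + 1) * clip_len                     # 336 frames, one clip over-allocated

gt_content = 'Movie_label_B1 50 80 200 240'.split()  # hypothetical '<name> <s> <e> ...'
fragments = [[int(gt_content[i]), int(gt_content[i + 1])]
             for i in range(1, len(gt_content), 2)]   # [[50, 80], [200, 240]]

gt_vec = np.zeros(lens, dtype=np.float32)
for s, e in fragments:
    gt_vec[s:e] = 1.0
print(int(gt_vec[:-clip_len].sum()))                  # 70 abnormal frames kept
```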
/list/make_gt_mAP_xd.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import os
4 | import cv2
5 | import pandas as pd
6 | import warnings
7 |
8 | clip_len = 16
9 |
10 | # the csv list of testing features
11 | feature_list = 'list/xd_CLIP_rgbtest.csv'
12 |
13 | # the ground truth txt
14 | gt_txt = 'list/annotations_multiclasses.txt'
15 | gt_lines = list(open(gt_txt))
16 |
17 | #warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
18 |
19 | gt_segment = []
20 | gt_label = []
21 | lists = pd.read_csv(feature_list)
22 |
23 | for idx in range(lists.shape[0]):
24 | name = lists.loc[idx]['path']
25 | if '__0.npy' not in name:
26 | continue
27 | segment = []
28 | label = []
29 | if '_label_A' in name:
30 | fea = np.load(name)
31 | lens = fea.shape[0] * clip_len
32 | name = name.split('/')[-1]
33 | name = name[:-7]
34 | segment.append([0, lens])
35 | label.append('A')
36 | else:
37 | name = name.split('/')[-1]
38 | name = name[:-7]
39 | for gt_line in gt_lines:
40 | if name in gt_line:
41 | gt_content = gt_line.strip('\n').split()
42 | for j in range(1, len(gt_content), 3):
43 | print(gt_content, j)
44 | segment.append([gt_content[j + 1], gt_content[j + 2]])
45 | label.append(gt_content[j])
46 | break
47 | gt_segment.append(segment)
48 | gt_label.append(label)
49 |
50 | np.save('list/gt_label.npy', gt_label)
51 | np.save('list/gt_segment.npy', gt_segment)
--------------------------------------------------------------------------------
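Note that `gt_segment.npy` and `gt_label.npy` store ragged Python lists, so they must be read back with `allow_pickle=True`, exactly as `xd_test.py` does:

```python
import numpy as np

gtsegments = np.load('list/gt_segment.npy', allow_pickle=True)
gtlabels = np.load('list/gt_label.npy', allow_pickle=True)
print(gtsegments[0], gtlabels[0])  # e.g. [[0, 4096]] ['A'] for a normal test video
```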
/src/xd_option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description='VadCLIP')
4 | parser.add_argument('--seed', default=234, type=int)
5 |
6 | parser.add_argument('--embed-dim', default=512, type=int)
7 | parser.add_argument('--visual-length', default=256, type=int)
8 | parser.add_argument('--visual-width', default=512, type=int)
9 | parser.add_argument('--visual-head', default=1, type=int)
10 | parser.add_argument('--visual-layers', default=1, type=int)
11 | parser.add_argument('--attn-window', default=64, type=int)
12 | parser.add_argument('--prompt-prefix', default=10, type=int)
13 | parser.add_argument('--prompt-postfix', default=10, type=int)
14 | parser.add_argument('--classes-num', default=7, type=int)
15 |
16 | parser.add_argument('--max-epoch', default=10, type=int)
17 | parser.add_argument('--model-path', default='model/model_xd.pth')
18 | parser.add_argument('--use-checkpoint', default=False, action='store_true')
19 | parser.add_argument('--checkpoint-path', default='model/checkpoint.pth')
20 | parser.add_argument('--batch-size', default=96, type=int)
21 | parser.add_argument('--train-list', default='list/xd_CLIP_rgb.csv')
22 | parser.add_argument('--test-list', default='list/xd_CLIP_rgbtest.csv')
23 | parser.add_argument('--gt-path', default='list/gt.npy')
24 | parser.add_argument('--gt-segment-path', default='list/gt_segment.npy')
25 | parser.add_argument('--gt-label-path', default='list/gt_label.npy')
26 |
27 | parser.add_argument('--lr', default=1e-5, type=float)
28 | parser.add_argument('--scheduler-rate', default=0.1, type=float)
29 | parser.add_argument('--scheduler-milestones', default=[3, 6, 10], type=int, nargs='+')
--------------------------------------------------------------------------------
/src/ucf_option.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser(description='VadCLIP')
4 | parser.add_argument('--seed', default=234, type=int)
5 |
6 | parser.add_argument('--embed-dim', default=512, type=int)
7 | parser.add_argument('--visual-length', default=256, type=int)
8 | parser.add_argument('--visual-width', default=512, type=int)
9 | parser.add_argument('--visual-head', default=1, type=int)
10 | parser.add_argument('--visual-layers', default=2, type=int)
11 | parser.add_argument('--attn-window', default=8, type=int)
12 | parser.add_argument('--prompt-prefix', default=10, type=int)
13 | parser.add_argument('--prompt-postfix', default=10, type=int)
14 | parser.add_argument('--classes-num', default=14, type=int)
15 |
16 | parser.add_argument('--max-epoch', default=10, type=int)
17 | parser.add_argument('--model-path', default='model/model_ucf.pth')
18 | parser.add_argument('--use-checkpoint', default=False, action='store_true')
19 | parser.add_argument('--checkpoint-path', default='model/checkpoint.pth')
20 | parser.add_argument('--batch-size', default=64, type=int)
21 | parser.add_argument('--train-list', default='list/ucf_CLIP_rgb.csv')
22 | parser.add_argument('--test-list', default='list/ucf_CLIP_rgbtest.csv')
23 | parser.add_argument('--gt-path', default='list/gt_ucf.npy')
24 | parser.add_argument('--gt-segment-path', default='list/gt_segment_ucf.npy')
25 | parser.add_argument('--gt-label-path', default='list/gt_label_ucf.npy')
26 |
27 | parser.add_argument('--lr', default=2e-5, type=float)
28 | parser.add_argument('--scheduler-rate', default=0.1, type=float)
29 | parser.add_argument('--scheduler-milestones', default=[4, 8], type=int, nargs='+')
--------------------------------------------------------------------------------
/list/make_gt_mAP_ucf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import os
4 | import cv2
5 | import pandas as pd
6 | import warnings
7 |
8 | clip_len = 16
9 |
10 | feature_list = 'list/ucf_CLIP_rgbtest.csv'
11 |
12 | # the ground truth txt
13 | gt_txt = 'list/Temporal_Anomaly_Annotation.txt'
14 | gt_lines = list(open(gt_txt))
15 |
16 | #warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)
17 |
18 | gt_segment = []
19 | gt_label = []
20 | lists = pd.read_csv(feature_list)
21 |
22 | for idx in range(lists.shape[0]):
23 | name = lists.loc[idx]['path']
24 | label_text = lists.loc[idx]['label']
25 | if '__0.npy' not in name:
26 | continue
27 | segment = []
28 | label = []
29 | if 'Normal' in label_text:
30 | fea = np.load(name)
31 | lens = fea.shape[0] * clip_len
32 | name = name.split('/')[-1]
33 | name = name[:-7]
34 | segment.append([0, lens])
35 |         label.append('Normal')  # must match the classlist in utils/ucf_detectionMAP.py; 'A' would never match
36 | else:
37 | name = name.split('/')[-1]
38 | name = name[:-7]
39 | for gt_line in gt_lines:
40 | if name in gt_line:
41 | gt_content = gt_line.strip('\n').split(' ')
42 | segment.append([gt_content[2], gt_content[3]])
43 | label.append(gt_content[1])
44 | if gt_content[4] != '-1':
45 | segment.append([gt_content[4], gt_content[5]])
46 | label.append(gt_content[1])
47 | break
48 | gt_segment.append(segment)
49 | gt_label.append(label)
50 |
51 | np.save('list/gt_label_ucf.npy', gt_label)
52 | np.save('list/gt_segment_ucf.npy', gt_segment)
--------------------------------------------------------------------------------
/list/make_gt_ucf.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import cv2
4 |
5 | clip_len = 16
6 |
7 | # the csv list of testing features
8 | feature_list = 'list/ucf_CLIP_rgbtest.csv'
9 | # the ground truth txt
10 |
11 | gt_txt = 'list/Temporal_Anomaly_Annotation.txt' ## the path of test annotations
12 | gt_lines = list(open(gt_txt))
13 | gt = []
14 | lists = pd.read_csv(feature_list)
15 | count = 0
16 |
17 | for idx in range(lists.shape[0]):
18 | name = lists.loc[idx]['path']
19 | if '__0.npy' not in name:
20 | continue
21 | #feature = name.split('label_')[-1]
22 | fea = np.load(name)
23 | lens = (fea.shape[0] + 1) * clip_len
24 | name = name.split('/')[-1]
25 | name = name[:-7]
26 |     # frame-level 0/1 ground truth for this video (1 = abnormal)
27 |
28 | gt_vec = np.zeros(lens).astype(np.float32)
29 | if 'Normal' not in name:
30 | for gt_line in gt_lines:
31 | if name in gt_line:
32 | count += 1
33 |                 gt_content = gt_line.strip('\n').split(' ')[1:]  # keep both (start, end) pairs; the second may be (-1, -1)
34 | abnormal_fragment = [[int(gt_content[i]),int(gt_content[j])] for i in range(1,len(gt_content),2) \
35 | for j in range(2,len(gt_content),2) if j==i+1]
36 | if len(abnormal_fragment) != 0:
37 | abnormal_fragment = np.array(abnormal_fragment)
38 | for frag in abnormal_fragment:
39 | if frag[0] != -1 and frag[1] != -1:
40 | gt_vec[frag[0]:frag[1]]=1.0
41 | break
42 | gt.extend(gt_vec[:-clip_len])
43 |
44 | print(count)
45 | np.save('list/gt_ucf.npy', gt)
--------------------------------------------------------------------------------
/src/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.utils.data as data
4 | import pandas as pd
5 | import utils.tools as tools
6 |
7 | class UCFDataset(data.Dataset):
8 | def __init__(self, clip_dim: int, file_path: str, test_mode: bool, label_map: dict, normal: bool = False):
9 | self.df = pd.read_csv(file_path)
10 | self.clip_dim = clip_dim
11 | self.test_mode = test_mode
12 | self.label_map = label_map
13 | self.normal = normal
14 | if normal == True and test_mode == False:
15 | self.df = self.df.loc[self.df['label'] == 'Normal']
16 | self.df = self.df.reset_index()
17 | elif test_mode == False:
18 | self.df = self.df.loc[self.df['label'] != 'Normal']
19 | self.df = self.df.reset_index()
20 |
21 | def __len__(self):
22 | return self.df.shape[0]
23 |
24 | def __getitem__(self, index):
25 | clip_feature = np.load(self.df.loc[index]['path'])
26 | if self.test_mode == False:
27 | clip_feature, clip_length = tools.process_feat(clip_feature, self.clip_dim)
28 | else:
29 | clip_feature, clip_length = tools.process_split(clip_feature, self.clip_dim)
30 |
31 | clip_feature = torch.tensor(clip_feature)
32 | clip_label = self.df.loc[index]['label']
33 | return clip_feature, clip_label, clip_length
34 |
35 | class XDDataset(data.Dataset):
36 | def __init__(self, clip_dim: int, file_path: str, test_mode: bool, label_map: dict):
37 | self.df = pd.read_csv(file_path)
38 | self.clip_dim = clip_dim
39 | self.test_mode = test_mode
40 | self.label_map = label_map
41 |
42 | def __len__(self):
43 | return self.df.shape[0]
44 |
45 | def __getitem__(self, index):
46 | clip_feature = np.load(self.df.loc[index]['path'])
47 | if self.test_mode == False:
48 | clip_feature, clip_length = tools.process_feat(clip_feature, self.clip_dim)
49 | else:
50 | clip_feature, clip_length = tools.process_split(clip_feature, self.clip_dim)
51 |
52 | clip_feature = torch.tensor(clip_feature)
53 | clip_label = self.df.loc[index]['label']
54 | return clip_feature, clip_label, clip_length
--------------------------------------------------------------------------------
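A minimal usage sketch (the CSV path and truncated label map are placeholders). The `normal` flag suggests that training draws normal and abnormal videos from two separate loaders so each batch can mix both:

```python
from torch.utils.data import DataLoader
from utils.dataset import UCFDataset

label_map = {'Normal': 'Normal', 'Abuse': 'Abuse'}   # truncated for illustration

normal_set = UCFDataset(256, 'list/ucf_CLIP_rgb.csv', False, label_map, normal=True)
abnormal_set = UCFDataset(256, 'list/ucf_CLIP_rgb.csv', False, label_map, normal=False)

normal_loader = DataLoader(normal_set, batch_size=32, shuffle=True)
abnormal_loader = DataLoader(abnormal_set, batch_size=32, shuffle=True)

feat, text_label, length = next(iter(normal_loader))
print(feat.shape)   # torch.Size([32, 256, 512]) with visual-length 256 CLIP features
```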
/src/utils/lr_warmup.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
3 |
4 | class WarmupMultiStepLR(MultiStepLR):
5 | r"""
6 | # max_iter = epochs * steps_per_epoch
7 | Args:
8 | optimizer (Optimizer): Wrapped optimizer.
9 | max_iter (int): The total number of steps.
10 |         milestones (list): List of iter indices. Must be increasing.
11 | gamma (float): Multiplicative factor of learning rate decay. Default: 0.1.
12 | pct_start (float): The percentage of the cycle (in number of steps) spent
13 | increasing the learning rate.
14 | Default: 0.3
15 | warmup_factor (float):
16 | last_epoch (int): The index of last epoch. Default: -1.
17 | """
18 | def __init__(self, optimizer, max_iter, milestones, gamma=0.1, pct_start=0.3, warmup_factor=1.0 / 2,
19 | last_epoch=-1):
20 | self.warmup_factor = warmup_factor
21 | self.warmup_iters = int(pct_start * max_iter)
22 | super().__init__(optimizer, milestones, gamma, last_epoch)
23 |
24 | def get_lr(self):
25 | if self.last_epoch <= self.warmup_iters:
26 | alpha = self.last_epoch / self.warmup_iters
27 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
28 | return [lr * warmup_factor for lr in self.base_lrs]
29 | else:
30 | lr = super().get_lr()
31 | return lr
32 |
33 | class WarmupCosineLR(_LRScheduler):
34 | def __init__(self, optimizer, max_iter, pct_start=0.3, warmup_factor=1.0 / 3,
35 | eta_min=0, last_epoch=-1):
36 | self.warmup_factor = warmup_factor
37 | self.warmup_iters = int(pct_start * max_iter)
38 | self.max_iter, self.eta_min = max_iter, eta_min
39 |         super().__init__(optimizer, last_epoch)
40 |
41 | def get_lr(self):
42 | if self.last_epoch <= self.warmup_iters:
43 | alpha = self.last_epoch / self.warmup_iters
44 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
45 | return [lr * warmup_factor for lr in self.base_lrs]
46 | else:
47 | # print ("after warmup")
48 | return [self.eta_min + (base_lr - self.eta_min) *
49 | (1 + math.cos(
50 | math.pi * (self.last_epoch - self.warmup_iters) / (self.max_iter - self.warmup_iters))) / 2
51 | for base_lr in self.base_lrs]
52 |
53 | class WarmupPolyLR(_LRScheduler):
54 |     def __init__(self, optimizer, T_max, pct_start=0.3, warmup_factor=1.0 / 4,
55 |                  eta_min=0, power=0.9, last_epoch=-1):
56 | self.warmup_factor = warmup_factor
57 | self.warmup_iters = int(pct_start * T_max)
58 | self.power = power
59 | self.T_max, self.eta_min = T_max, eta_min
60 |         super().__init__(optimizer, last_epoch)
61 |
62 | def get_lr(self):
63 | if self.last_epoch <= self.warmup_iters:
64 | alpha = self.last_epoch / self.warmup_iters
65 | warmup_factor = self.warmup_factor * (1 - alpha) + alpha
66 | return [lr * warmup_factor for lr in self.base_lrs]
67 | else:
68 | return [self.eta_min + (base_lr - self.eta_min) *
69 | math.pow(1 - (self.last_epoch - self.warmup_iters) / (self.T_max - self.warmup_iters),
70 | self.power) for base_lr in self.base_lrs]
71 |
--------------------------------------------------------------------------------
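These schedulers step per iteration, not per epoch (`max_iter` should be epochs * steps_per_epoch). A minimal sketch with a stand-in model; note the training scripts in this repo actually use `MultiStepLR`:

```python
import torch
from utils.lr_warmup import WarmupCosineLR

model = torch.nn.Linear(512, 1)                      # stand-in model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
scheduler = WarmupCosineLR(optimizer, max_iter=1000, pct_start=0.1)

for it in range(1000):
    optimizer.step()                                 # real training step elided
    scheduler.step()                                 # once per iteration
    if it in (0, 100, 999):
        print(it, scheduler.get_last_lr())           # linear warmup, then cosine decay
```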
/src/crop.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | import torch
5 | from clip import clip
6 | from PIL import Image
7 |
8 | def video_crop(video_frame, type):
9 | l = video_frame.shape[0]
10 | new_frame = []
11 | for i in range(l):
12 | img = cv2.resize(video_frame[i], dsize=(340, 256))
13 | new_frame.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
14 |
15 | #1
16 | img = np.array(new_frame)
17 | if type == 0:
18 | img = img[:, 16:240, 58:282, :]
19 | #2
20 | elif type == 1:
21 | img = img[:, :224, :224, :]
22 | #3
23 | elif type == 2:
24 | img = img[:, :224, -224:, :]
25 | #4
26 | elif type == 3:
27 | img = img[:, -224:, :224, :]
28 | #5
29 | elif type == 4:
30 | img = img[:, -224:, -224:, :]
31 | #6
32 | elif type == 5:
33 | img = img[:, 16:240, 58:282, :]
34 | for i in range(img.shape[0]):
35 | img[i] = cv2.flip(img[i], 1)
36 | #7
37 | elif type == 6:
38 | img = img[:, :224, :224, :]
39 | for i in range(img.shape[0]):
40 | img[i] = cv2.flip(img[i], 1)
41 | #8
42 | elif type == 7:
43 | img = img[:, :224, -224:, :]
44 | for i in range(img.shape[0]):
45 | img[i] = cv2.flip(img[i], 1)
46 | #9
47 | elif type == 8:
48 | img = img[:, -224:, :224, :]
49 | for i in range(img.shape[0]):
50 | img[i] = cv2.flip(img[i], 1)
51 | #10
52 | elif type == 9:
53 | img = img[:, -224:, -224:, :]
54 | for i in range(img.shape[0]):
55 | img[i] = cv2.flip(img[i], 1)
56 |
57 | return img
58 |
59 | def image_crop(image, type):
60 | img = cv2.resize(image, dsize=(340, 256))
61 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
62 | #1
63 | if type == 0:
64 | img = img[16:240, 58:282, :]
65 | #2
66 | elif type == 1:
67 | img = img[:224, :224, :]
68 | #3
69 | elif type == 2:
70 | img = img[:224, -224:, :]
71 | #4
72 | elif type == 3:
73 | img = img[-224:, :224, :]
74 | #5
75 | elif type == 4:
76 | img = img[-224:, -224:, :]
77 | #6
78 | elif type == 5:
79 | img = img[16:240, 58:282, :]
80 | img = cv2.flip(img, 1)
81 | #7
82 | elif type == 6:
83 | img = img[:224, :224, :]
84 | img = cv2.flip(img, 1)
85 | #8
86 | elif type == 7:
87 | img = img[:224, -224:, :]
88 | img = cv2.flip(img, 1)
89 | #9
90 | elif type == 8:
91 | img = img[-224:, :224, :]
92 | img = cv2.flip(img, 1)
93 | #10
94 | elif type == 9:
95 | img = img[-224:, -224:, :]
96 | img = cv2.flip(img, 1)
97 |
98 | return img
99 |
100 | if __name__ == '__main__':
101 | video = np.zeros([3, 320, 240, 3], dtype=np.uint8)
102 |     crop_video = video_crop(video, 0)
103 |
104 | device = "cuda" if torch.cuda.is_available() else "cpu"
105 | model, preprocess = clip.load("ViT-B/16", device)
106 | video_features = torch.zeros(0).to(device)
107 | with torch.no_grad():
108 | for i in range(video.shape[0]):
109 |             img = Image.fromarray(crop_video[i])
110 | img = preprocess(img).unsqueeze(0).to(device)
111 | feature = model.encode_image(img)
112 | video_features = torch.cat([video_features, feature], dim=0)
113 |
114 | video_features = video_features.detach().cpu().numpy()
115 |     np.save('save_path', video_features) ## replace 'save_path' with the output .npy path
--------------------------------------------------------------------------------
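A sketch of how this ten-crop pairs with the `__0.npy` .. `__9.npy` files expected by `list/make_list_ucf.py`; the frame array and output name are placeholders, while `clip.load` and `video_crop` are as defined in this repo:

```python
import numpy as np
import torch
from PIL import Image
from clip import clip
from crop import video_crop

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device)

video = np.zeros([8, 320, 240, 3], dtype=np.uint8)   # stand-in for decoded frames
for crop_type in range(10):                          # one feature file per crop
    cropped = video_crop(video, crop_type)
    feats = []
    with torch.no_grad():
        for frame in cropped:
            img = preprocess(Image.fromarray(frame)).unsqueeze(0).to(device)
            feats.append(model.encode_image(img))
    np.save('Video_x264__{}.npy'.format(crop_type),
            torch.cat(feats, dim=0).cpu().numpy())
```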
/src/utils/tools.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def get_batch_label(texts, prompt_text, label_map: dict):
5 | label_vectors = torch.zeros(0)
6 | if len(label_map) != 7:
7 | if len(label_map) == 2:
8 | for text in texts:
9 | label_vector = torch.zeros(2)
10 | if text == 'Normal':
11 | label_vector[0] = 1
12 | else:
13 | label_vector[1] = 1
14 | label_vector = label_vector.unsqueeze(0)
15 | label_vectors = torch.cat([label_vectors, label_vector], dim=0)
16 | else:
17 | for text in texts:
18 | label_vector = torch.zeros(len(prompt_text))
19 | if text in label_map:
20 | label_text = label_map[text]
21 | label_vector[prompt_text.index(label_text)] = 1
22 |
23 | label_vector = label_vector.unsqueeze(0)
24 | label_vectors = torch.cat([label_vectors, label_vector], dim=0)
25 | else:
26 | for text in texts:
27 | label_vector = torch.zeros(len(prompt_text))
28 | labels = text.split('-')
29 | for label in labels:
30 | if label in label_map:
31 | label_text = label_map[label]
32 | label_vector[prompt_text.index(label_text)] = 1
33 |
34 | label_vector = label_vector.unsqueeze(0)
35 | label_vectors = torch.cat([label_vectors, label_vector], dim=0)
36 |
37 | return label_vectors
38 |
39 | def get_prompt_text(label_map: dict):
40 | prompt_text = []
41 | for v in label_map.values():
42 | prompt_text.append(v)
43 |
44 | return prompt_text
45 |
46 | def get_batch_mask(lengths, maxlen):
47 | batch_size = lengths.shape[0]
48 | mask = torch.empty(batch_size, maxlen)
49 | mask.fill_(0)
50 | for i in range(batch_size):
51 | if lengths[i] < maxlen:
52 | mask[i, lengths[i]:maxlen] = 1
53 |
54 | return mask.bool()
55 |
56 | def random_extract(feat, t_max):
57 | r = np.random.randint(feat.shape[0] - t_max)
58 | return feat[r : r+t_max, :]
59 |
60 | def uniform_extract(feat, t_max, avg: bool = True):
61 | new_feat = np.zeros((t_max, feat.shape[1])).astype(np.float32)
62 | r = np.linspace(0, len(feat), t_max+1, dtype=np.int32)
63 | if avg == True:
64 | for i in range(t_max):
65 | if r[i]!=r[i+1]:
66 | new_feat[i,:] = np.mean(feat[r[i]:r[i+1],:], 0)
67 | else:
68 | new_feat[i,:] = feat[r[i],:]
69 | else:
70 | r = np.linspace(0, feat.shape[0]-1, t_max, dtype=np.uint16)
71 | new_feat = feat[r, :]
72 |
73 | return new_feat
74 |
75 | def pad(feat, min_len):
76 | clip_length = feat.shape[0]
77 | if clip_length <= min_len:
78 | return np.pad(feat, ((0, min_len - clip_length), (0, 0)), mode='constant', constant_values=0)
79 | else:
80 | return feat
81 |
82 | def process_feat(feat, length, is_random=False):
83 | clip_length = feat.shape[0]
84 | if feat.shape[0] > length:
85 | if is_random:
86 | return random_extract(feat, length), length
87 | else:
88 | return uniform_extract(feat, length), length
89 | else:
90 | return pad(feat, length), clip_length
91 |
92 | def process_split(feat, length):
93 | clip_length = feat.shape[0]
94 | if clip_length < length:
95 | return pad(feat, length), clip_length
96 | else:
97 | split_num = int(clip_length / length) + 1
98 | for i in range(split_num):
99 | if i == 0:
100 | split_feat = feat[i*length:i*length+length, :].reshape(1, length, feat.shape[1])
101 | elif i < split_num - 1:
102 | split_feat = np.concatenate([split_feat, feat[i*length:i*length+length, :].reshape(1, length, feat.shape[1])], axis=0)
103 | else:
104 | split_feat = np.concatenate([split_feat, pad(feat[i*length:i*length+length, :], length).reshape(1, length, feat.shape[1])], axis=0)
105 |
106 | return split_feat, clip_length
--------------------------------------------------------------------------------
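The two feature pipelines differ in shape: training compresses or pads each video to a fixed length, while testing splits long videos into fixed-length chunks (zero-padding the last one). A quick check with random features:

```python
import numpy as np
from utils.tools import process_feat, process_split

feat = np.random.rand(300, 512).astype(np.float32)   # hypothetical CLIP features

train_feat, train_len = process_feat(feat, 256)      # uniform downsample to 256
print(train_feat.shape, train_len)                   # (256, 512) 256

test_feat, test_len = process_split(feat, 256)       # 2 chunks, last zero-padded
print(test_feat.shape, test_len)                     # (2, 256, 512) 300
```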
/src/xd_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | from sklearn.metrics import average_precision_score, roc_auc_score
7 |
8 | from model import CLIPVAD
9 | from utils.dataset import XDDataset
10 | from utils.tools import get_batch_mask, get_prompt_text
11 | from utils.xd_detectionMAP import getDetectionMAP as dmAP
12 | import xd_option
13 |
14 | def test(model, testdataloader, maxlen, prompt_text, gt, gtsegments, gtlabels, device):
15 |
16 | model.to(device)
17 | model.eval()
18 |
19 | element_logits2_stack = []
20 |
21 | with torch.no_grad():
22 | for i, item in enumerate(testdataloader):
23 | visual = item[0].squeeze(0)
24 | length = item[2]
25 |
26 | length = int(length)
27 | len_cur = length
28 | if len_cur < maxlen:
29 | visual = visual.unsqueeze(0)
30 |
31 | visual = visual.to(device)
32 |
33 | lengths = torch.zeros(int(length / maxlen) + 1)
34 | for j in range(int(length / maxlen) + 1):
35 | if j == 0 and length < maxlen:
36 | lengths[j] = length
37 | elif j == 0 and length > maxlen:
38 | lengths[j] = maxlen
39 | length -= maxlen
40 | elif length > maxlen:
41 | lengths[j] = maxlen
42 | length -= maxlen
43 | else:
44 | lengths[j] = length
45 | lengths = lengths.to(int)
46 | padding_mask = get_batch_mask(lengths, maxlen).to(device)
47 | _, logits1, logits2 = model(visual, padding_mask, prompt_text, lengths)
48 | logits1 = logits1.reshape(logits1.shape[0] * logits1.shape[1], logits1.shape[2])
49 | logits2 = logits2.reshape(logits2.shape[0] * logits2.shape[1], logits2.shape[2])
50 | prob2 = (1 - logits2[0:len_cur].softmax(dim=-1)[:, 0].squeeze(-1))
51 | prob1 = torch.sigmoid(logits1[0:len_cur].squeeze(-1))
52 |
53 | if i == 0:
54 | ap1 = prob1
55 | ap2 = prob2
56 | else:
57 | ap1 = torch.cat([ap1, prob1], dim=0)
58 | ap2 = torch.cat([ap2, prob2], dim=0)
59 |
60 | element_logits2 = logits2[0:len_cur].softmax(dim=-1).detach().cpu().numpy()
61 | element_logits2 = np.repeat(element_logits2, 16, 0)
62 | element_logits2_stack.append(element_logits2)
63 |
64 | ap1 = ap1.cpu().numpy()
65 | ap2 = ap2.cpu().numpy()
66 | ap1 = ap1.tolist()
67 | ap2 = ap2.tolist()
68 |
69 | ROC1 = roc_auc_score(gt, np.repeat(ap1, 16))
70 | AP1 = average_precision_score(gt, np.repeat(ap1, 16))
71 | ROC2 = roc_auc_score(gt, np.repeat(ap2, 16))
72 | AP2 = average_precision_score(gt, np.repeat(ap2, 16))
73 |
74 | print("AUC1: ", ROC1, " AP1: ", AP1)
75 | print("AUC2: ", ROC2, " AP2:", AP2)
76 |
77 | dmap, iou = dmAP(element_logits2_stack, gtsegments, gtlabels, excludeNormal=False)
78 | averageMAP = 0
79 | for i in range(5):
80 | print('mAP@{0:.1f} ={1:.2f}%'.format(iou[i], dmap[i]))
81 | averageMAP += dmap[i]
82 | averageMAP = averageMAP/(i+1)
83 | print('average MAP: {:.2f}'.format(averageMAP))
84 |
85 |     return ROC1, AP2, 0  # , averageMAP
86 |
87 |
88 | if __name__ == '__main__':
89 | device = "cuda" if torch.cuda.is_available() else "cpu"
90 | args = xd_option.parser.parse_args()
91 |
92 | label_map = dict({'A': 'normal', 'B1': 'fighting', 'B2': 'shooting', 'B4': 'riot', 'B5': 'abuse', 'B6': 'car accident', 'G': 'explosion'})
93 |
94 | test_dataset = XDDataset(args.visual_length, args.test_list, True, label_map)
95 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
96 |
97 | prompt_text = get_prompt_text(label_map)
98 | gt = np.load(args.gt_path)
99 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
100 | gtlabels = np.load(args.gt_label_path, allow_pickle=True)
101 |
102 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device)
103 | model_param = torch.load(args.model_path)
104 | model.load_state_dict(model_param)
105 |
106 | test(model, test_loader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)
--------------------------------------------------------------------------------
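The test loop produces one anomaly score per 16-frame clip and expands it with `np.repeat(..., 16)` to match the frame-level ground truth from `list/make_gt_xd.py`. A toy illustration with made-up scores:

```python
import numpy as np
from sklearn.metrics import roc_auc_score

clip_scores = np.array([0.1, 0.2, 0.9, 0.8])            # one score per 16-frame clip
frame_scores = np.repeat(clip_scores, 16)               # (64,) frame-level curve
frame_gt = np.concatenate([np.zeros(32), np.ones(32)])  # frame-level labels
print(roc_auc_score(frame_gt, frame_scores))            # 1.0 for this toy case
```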
/src/ucf_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | import numpy as np
6 | from sklearn.metrics import average_precision_score, roc_auc_score
7 |
8 | from model import CLIPVAD
9 | from utils.dataset import UCFDataset
10 | from utils.tools import get_batch_mask, get_prompt_text
11 | from utils.ucf_detectionMAP import getDetectionMAP as dmAP
12 | import ucf_option
13 |
14 | def test(model, testdataloader, maxlen, prompt_text, gt, gtsegments, gtlabels, device):
15 |
16 | model.to(device)
17 | model.eval()
18 |
19 | element_logits2_stack = []
20 |
21 | with torch.no_grad():
22 | for i, item in enumerate(testdataloader):
23 | visual = item[0].squeeze(0)
24 | length = item[2]
25 |
26 | length = int(length)
27 | len_cur = length
28 | if len_cur < maxlen:
29 | visual = visual.unsqueeze(0)
30 |
31 | visual = visual.to(device)
32 |
33 | lengths = torch.zeros(int(length / maxlen) + 1)
34 | for j in range(int(length / maxlen) + 1):
35 | if j == 0 and length < maxlen:
36 | lengths[j] = length
37 | elif j == 0 and length > maxlen:
38 | lengths[j] = maxlen
39 | length -= maxlen
40 | elif length > maxlen:
41 | lengths[j] = maxlen
42 | length -= maxlen
43 | else:
44 | lengths[j] = length
45 | lengths = lengths.to(int)
46 | padding_mask = get_batch_mask(lengths, maxlen).to(device)
47 | _, logits1, logits2 = model(visual, padding_mask, prompt_text, lengths)
48 | logits1 = logits1.reshape(logits1.shape[0] * logits1.shape[1], logits1.shape[2])
49 | logits2 = logits2.reshape(logits2.shape[0] * logits2.shape[1], logits2.shape[2])
50 | prob2 = (1 - logits2[0:len_cur].softmax(dim=-1)[:, 0].squeeze(-1))
51 | prob1 = torch.sigmoid(logits1[0:len_cur].squeeze(-1))
52 |
53 | if i == 0:
54 | ap1 = prob1
55 | ap2 = prob2
56 | #ap3 = prob3
57 | else:
58 | ap1 = torch.cat([ap1, prob1], dim=0)
59 | ap2 = torch.cat([ap2, prob2], dim=0)
60 |
61 | element_logits2 = logits2[0:len_cur].softmax(dim=-1).detach().cpu().numpy()
62 | element_logits2 = np.repeat(element_logits2, 16, 0)
63 | element_logits2_stack.append(element_logits2)
64 |
65 | ap1 = ap1.cpu().numpy()
66 | ap2 = ap2.cpu().numpy()
67 | ap1 = ap1.tolist()
68 | ap2 = ap2.tolist()
69 |
70 | ROC1 = roc_auc_score(gt, np.repeat(ap1, 16))
71 | AP1 = average_precision_score(gt, np.repeat(ap1, 16))
72 | ROC2 = roc_auc_score(gt, np.repeat(ap2, 16))
73 | AP2 = average_precision_score(gt, np.repeat(ap2, 16))
74 |
75 | print("AUC1: ", ROC1, " AP1: ", AP1)
76 | print("AUC2: ", ROC2, " AP2:", AP2)
77 |
78 | dmap, iou = dmAP(element_logits2_stack, gtsegments, gtlabels, excludeNormal=False)
79 | averageMAP = 0
80 | for i in range(5):
81 | print('mAP@{0:.1f} ={1:.2f}%'.format(iou[i], dmap[i]))
82 | averageMAP += dmap[i]
83 | averageMAP = averageMAP/(i+1)
84 | print('average MAP: {:.2f}'.format(averageMAP))
85 |
86 | return ROC1, AP1
87 |
88 |
89 | if __name__ == '__main__':
90 | device = "cuda" if torch.cuda.is_available() else "cpu"
91 | args = ucf_option.parser.parse_args()
92 |
93 | label_map = dict({'Normal': 'Normal', 'Abuse': 'Abuse', 'Arrest': 'Arrest', 'Arson': 'Arson', 'Assault': 'Assault', 'Burglary': 'Burglary', 'Explosion': 'Explosion', 'Fighting': 'Fighting', 'RoadAccidents': 'RoadAccidents', 'Robbery': 'Robbery', 'Shooting': 'Shooting', 'Shoplifting': 'Shoplifting', 'Stealing': 'Stealing', 'Vandalism': 'Vandalism'})
94 |
95 | testdataset = UCFDataset(args.visual_length, args.test_list, True, label_map)
96 | testdataloader = DataLoader(testdataset, batch_size=1, shuffle=False)
97 |
98 | prompt_text = get_prompt_text(label_map)
99 | gt = np.load(args.gt_path)
100 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
101 | gtlabels = np.load(args.gt_label_path, allow_pickle=True)
102 |
103 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device)
104 | model_param = torch.load(args.model_path)
105 | model.load_state_dict(model_param)
106 |
107 | test(model, testdataloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VadCLIP
2 | This is the official PyTorch implementation of our paper:
3 | **"VadCLIP: Adapting Vision-Language Models for Weakly Supervised Video Anomaly Detection"** in **AAAI 2024.**
4 | > Peng Wu, Xuerong Zhou, Guansong Pang, Lingru Zhou, Qingsen Yan, Peng Wang, Yanning Zhang
5 |
6 | 
7 |
8 | ## Highlight
9 | - We present a novel paradigm, i.e., VadCLIP, which involves a dual branch to detect video anomalies via visual classification and language-visual alignment, respectively. With the benefit of the dual branch, VadCLIP achieves both coarse-grained and fine-grained WSVAD. To our knowledge, **VadCLIP is the first work to efficiently transfer pre-trained language-visual knowledge to WSVAD**.
10 |
11 | - We propose three key components to address the new challenges introduced by this paradigm: LGT-Adapter captures temporal dependencies from different perspectives; two prompt mechanisms effectively adapt the frozen pre-trained model to the WSVAD task; MIL-Align realizes optimization of the alignment paradigm under weak supervision, so as to preserve the pre-trained knowledge as much as possible.
12 |
13 | - We show the strength and effectiveness of VadCLIP on two large-scale popular benchmarks: VadCLIP achieves state-of-the-art performance, e.g., unprecedented results of 84.51% AP on XD-Violence and 88.02% AUC on UCF-Crime, surpassing current classification-based methods by a large margin.
14 |
15 | ## Training
16 |
17 | ### Setup
18 | We extract CLIP features for the UCF-Crime and XD-Violence datasets, and release these features and pretrained models as follows:
19 |
20 | | Benchmark | CLIP[Baidu] | CLIP | Model[Baidu] | Model |
21 | |--------|----------|-----------|-------------|------------|
22 | | UCF-Crime | [Code: 7yzp](https://pan.baidu.com/s/1OKRIxoLcxt-7RYxWpylgLQ) | [OneDrive](https://stuxidianeducn-my.sharepoint.com/:u:/g/personal/pengwu_stu_xidian_edu_cn/Ea86YOcp5z9KhRFDQm9a8zwBcGiGGg5BuBJtgmCVByazBQ?e=tqHLHt) | [Code: kq5u](https://pan.baidu.com/s/1_9bTC99FklrZRnkmYMuJQw) | [OneDrive](https://stuxidianeducn-my.sharepoint.com/:u:/g/personal/pengwu_stu_xidian_edu_cn/Eaz6sn40RmlFmjELcNHW1IkBV7C0U5OrOaHcuLFzH2S0-Q?e=x8wtVe) |
23 | | XD-Violence | [Code: v8tw](https://pan.baidu.com/s/1q8DiYHcPJtrBQiiJMI7aJw)| [OneDrive](https://stuxidianeducn-my.sharepoint.com/:f:/g/personal/pengwu_stu_xidian_edu_cn/Et5dWQZb2cBDs7zsrp90SrQBL_52vTRNYTdjQW6SMl0ZVA?e=foX4ph) | [Code: apw6](https://pan.baidu.com/s/1O0uwVS3ZyDA1soWUv2VasQ) | [OneDrive](https://stuxidianeducn-my.sharepoint.com/:u:/g/personal/pengwu_stu_xidian_edu_cn/EYlNnn_xfVxBtQZuQgngrMsBHY-i8QHTVOs7PmryzQ2MyA?e=99nxnR) |
24 |
25 |
26 |
27 |
28 | The following files need to be adapted in order to run the code on your own machine:
29 | - Change the file paths to the downloaded features above in `list/xd_CLIP_rgb.csv` and `list/xd_CLIP_rgbtest.csv` (and likewise in the UCF-Crime lists).
30 | - Feel free to change the hyperparameters in `xd_option.py` and `ucf_option.py`.
31 | ### Train and Test
32 | After the setup, simply run the following commands:
33 |
34 |
35 | Training and inference on the XD-Violence dataset
36 | ```
37 | python xd_train.py
38 | python xd_test.py
39 | ```
40 | Training and inference on the UCF-Crime dataset
41 | ```
42 | python ucf_train.py
43 | python ucf_test.py
44 | ```
45 |
46 | ## References
47 | We referenced the repos below for the code.
48 | * [XDVioDet](https://github.com/Roc-Ng/XDVioDet)
49 | * [DeepMIL](https://github.com/Roc-Ng/DeepMIL)
50 |
51 | ## Citation
52 |
53 | If you find this repo useful for your research, please consider citing our paper:
54 |
55 | ```bibtex
56 | @inproceedings{wu2023vadclip,
57 | title={Vadclip: Adapting vision-language models for weakly supervised video anomaly detection},
58 | author={Wu, Peng and Zhou, Xuerong and Pang, Guansong and Zhou, Lingru and Yan, Qingsen and Wang, Peng and Zhang, Yanning},
59 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence (AAAI)},
60 | year={2024}
61 | }
62 |
63 | @article{wu2023open,
64 | title={Open-Vocabulary Video Anomaly Detection},
65 | author={Wu, Peng and Zhou, Xuerong and Pang, Guansong and Sun, Yujia and Liu, Jing and Wang, Peng and Zhang, Yanning},
66 | journal={arXiv preprint arXiv:2311.07042},
67 | year={2023}
68 | }
69 |
70 | ```
71 | ---
72 |
--------------------------------------------------------------------------------
/src/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |     vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
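A round-trip sketch; note that `encode` does not add the `<|startoftext|>`/`<|endoftext|>` tokens (CLIP's `tokenize` wrapper does that), and `decode` leaves a trailing space because each `</w>` becomes ' ':

```python
from clip.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()
ids = tokenizer.encode("a video of a car accident")
print(ids)                    # BPE token ids
print(tokenizer.decode(ids))  # 'a video of a car accident '
```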
/src/utils/xd_detectionMAP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.signal import savgol_filter
3 |
4 | def smooth(v):
5 | return v
6 | # l = min(5, len(v)); l = l - (1-l%2)
7 | # if len(v) <= 3:
8 | # return v
9 | # return savgol_filter(v, l, 1) #savgol_filter(v, l, 1) #0.5*(np.concatenate([v[1:],v[-1:]],axis=0) + v)
10 |
11 | def str2ind(categoryname,classlist):
12 | return [i for i in range(len(classlist)) if categoryname == classlist[i]][0]
13 |
14 | def nms(dets, thresh=0.6, top_k=-1):
15 | """Pure Python NMS baseline."""
16 | # dets: N*2 and sorted by scores
17 | if len(dets) == 0: return []
18 | order = np.arange(0,len(dets),1)
19 | dets = np.array(dets)
20 | x1 = dets[:, 0] # start
21 | x2 = dets[:, 1] # end
22 | lengths = x2 - x1
23 | keep = []
24 | while order.size > 0:
25 | i = order[0] # the first is the best proposal
26 | keep.append(i) # put into the candidate pool
27 | if len(keep) == top_k:
28 | break
29 | xx1 = np.maximum(x1[i], x1[order[1:]])
30 | xx2 = np.minimum(x2[i], x2[order[1:]])
31 | inter = np.maximum(0.0, xx2 - xx1) ## the intersection
32 | ovr = inter / (lengths[i] + lengths[order[1:]] - inter) ## the iou
33 | inds = np.where(ovr <= thresh)[0] # the index of remaining proposals
34 | order = order[inds + 1] # add 1
35 |
36 | return dets[keep], keep
37 |
38 | def getLocMAP(predictions, th, gtsegments, gtlabels, excludeNormal):
39 | if excludeNormal is True:
40 | classes_num = 6
41 | videos_num = 500
42 | predictions = predictions[:videos_num]
43 | else:
44 | classes_num = 7
45 | videos_num = 800
46 |
47 | classlist = ['A', 'B1', 'B2', 'B4', 'B5', 'B6', 'G']
48 | predictions_mod = []
49 | c_score = []
50 | for p in predictions:
51 | pp = - p
52 | [pp[:,i].sort() for i in range(np.shape(pp)[1])]
53 | pp=-pp
54 | idx_temp = int(np.shape(pp)[0]/16)
55 | c_s = np.mean(pp[:idx_temp, :], axis=0)
56 | ind = c_s > 0.0
57 | c_score.append(c_s)
58 | predictions_mod.append(p*ind)
59 | predictions = predictions_mod
60 | ap = []
61 | for c in range(0, 7):
62 | segment_predict = []
63 | # Get list of all predictions for class c
64 | for i in range(len(predictions)):
65 | tmp = smooth(predictions[i][:, c])
66 | segment_predict_multithr = []
67 | thr_set = np.arange(0.6, 0.7, 0.1)
68 | for thr in thr_set:
69 | threshold = np.max(tmp) - (np.max(tmp) - np.min(tmp))*thr ### 0.8 is the best?
70 | vid_pred = np.concatenate([np.zeros(1), (tmp>threshold).astype('float32'), np.zeros(1)], axis=0)
71 | vid_pred_diff = [vid_pred[idt]-vid_pred[idt-1] for idt in range(1, len(vid_pred))]
72 | s = [idk for idk, item in enumerate(vid_pred_diff) if item == 1]
73 | e = [idk for idk, item in enumerate(vid_pred_diff) if item == -1]
74 | for j in range(len(s)):
75 | if e[j]-s[j]>=2:
76 | segment_scores = np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c]
77 | segment_predict_multithr.append([i, s[j], e[j], segment_scores])
78 | # segment_predict.append([i, s[j], e[j], np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c]])
79 | if len(segment_predict_multithr)!=0:
80 | segment_predict_multithr = np.array(segment_predict_multithr)
81 | segment_predict_multithr = segment_predict_multithr[np.argsort(-segment_predict_multithr[:,-1])]
82 | _, keep = nms(segment_predict_multithr[:, 1:-1], 0.6)
83 | segment_predict.extend(list(segment_predict_multithr[keep]))
84 | segment_predict = np.array(segment_predict)
85 |
86 | # Sort the list of predictions for class c based on score
87 | if len(segment_predict) == 0:
88 | return 0
89 | segment_predict = segment_predict[np.argsort(-segment_predict[:,3])]
90 |
91 | # Create gt list
92 | segment_gt = [[i, gtsegments[i][j][0], gtsegments[i][j][1]] for i in range(len(gtsegments))
93 | for j in range(len(gtsegments[i])) if str2ind(gtlabels[i][j], classlist) == c]
94 | gtpos = len(segment_gt)
95 |
96 | # Compare predictions and gt
97 | tp, fp = [], []
98 | for i in range(len(segment_predict)):
99 | flag = 0.
100 | best_iou = 0.0
101 | for j in range(len(segment_gt)):
102 | if segment_predict[i][0]==segment_gt[j][0]:
103 | gt = range(int(segment_gt[j][1]), int(segment_gt[j][2]))
104 | p = range(int(segment_predict[i][1]), int(segment_predict[i][2]))
105 | IoU = float(len(set(gt).intersection(set(p))))/float(len(set(gt).union(set(p))))
106 | if IoU >= th:
107 | flag = 1.
108 | if IoU > best_iou:
109 | best_iou = IoU
110 | best_j = j
111 | if flag > 0:
112 | del segment_gt[best_j]
113 | tp.append(flag)
114 | fp.append(1.-flag)
115 | tp_c = np.cumsum(tp)
116 | fp_c = np.cumsum(fp)
117 | if sum(tp)==0:
118 | prc = 0.
119 | else:
120 | prc = np.sum((tp_c/(fp_c+tp_c))*tp)/gtpos
121 | ap.append(prc)
122 | # print(np.round(prc, 4))
123 | return 100*np.mean(ap)
124 |
125 |
126 | def getDetectionMAP(predictions, segments, labels, excludeNormal=False):
127 | iou_list = [0.1, 0.2, 0.3, 0.4, 0.5]
128 | # iou_list = [0.5]
129 | dmap_list = []
130 | for iou in iou_list:
131 | # print('Testing for IoU {:.1f}'.format(iou))
132 | dmap_list.append(getLocMAP(predictions, iou, segments, labels, excludeNormal))
133 | return dmap_list, iou_list
134 |
135 |
--------------------------------------------------------------------------------
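A toy run of the 1-D `nms` above on three `[start, end]` segments already sorted by descending score; the middle one overlaps the first with IoU 8/12 and is suppressed at threshold 0.5:

```python
import numpy as np
from utils.xd_detectionMAP import nms

dets = [[0, 10], [2, 12], [20, 30]]   # made-up proposals, best score first
kept, keep = nms(dets, thresh=0.5)
print(kept)   # [[ 0 10] [20 30]]
print(keep)   # [0, 2]
```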
/src/utils/ucf_detectionMAP.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.signal import savgol_filter
3 |
4 | def smooth(v):
5 | return v
6 | # l = min(5, len(v)); l = l - (1-l%2)
7 | # if len(v) <= 3:
8 | # return v
9 | # return savgol_filter(v, l, 1) #savgol_filter(v, l, 1) #0.5*(np.concatenate([v[1:],v[-1:]],axis=0) + v)
10 |
11 | def str2ind(categoryname,classlist):
12 | for i in range(len(classlist)):
13 | if categoryname == classlist[i]:
14 | return i
15 |
16 | def nms(dets, thresh=0.6, top_k=-1):
17 | """Pure Python NMS baseline."""
18 | # dets: N*2 and sorted by scores
19 | if len(dets) == 0: return []
20 | order = np.arange(0,len(dets),1)
21 | dets = np.array(dets)
22 | x1 = dets[:, 0] # start
23 | x2 = dets[:, 1] # end
24 | lengths = x2 - x1
25 | keep = []
26 | while order.size > 0:
27 | i = order[0] # the first is the best proposal
28 | keep.append(i) # put into the candidate pool
29 | if len(keep) == top_k:
30 | break
31 | xx1 = np.maximum(x1[i], x1[order[1:]])
32 | xx2 = np.minimum(x2[i], x2[order[1:]])
33 | inter = np.maximum(0.0, xx2 - xx1) ## the intersection
34 | ovr = inter / (lengths[i] + lengths[order[1:]] - inter) ## the iou
35 | inds = np.where(ovr <= thresh)[0] # the index of remaining proposals
36 | order = order[inds + 1] # add 1
37 |
38 | return dets[keep], keep
39 |
40 | def getLocMAP(predictions, th, gtsegments, gtlabels, excludeNormal):
41 | if excludeNormal is True:
42 | classes_num = 13
43 | videos_num = 140
44 | predictions = predictions[:videos_num]
45 | else:
46 | classes_num = 14
47 | videos_num = 290
48 |
49 | classlist = ['Normal', 'Abuse', 'Arrest', 'Arson', 'Assault', 'Burglary', 'Explosion', 'Fighting', 'RoadAccidents', 'Robbery', 'Shooting', 'Shoplifting', 'Stealing', 'Vandalism']
50 | predictions_mod = []
51 | c_score = []
52 | for p in predictions:
53 | pp = - p; [pp[:,i].sort() for i in range(np.shape(pp)[1])]; pp=-pp
54 | c_s = np.mean(pp[:int(np.shape(pp)[0]/16), :], axis=0)
55 | ind = c_s > 0.0
56 | c_score.append(c_s)
57 | predictions_mod.append(p*ind)
58 | predictions = predictions_mod
59 | ap = []
60 | for c in range(0, 14):
61 | segment_predict = []
62 | # Get list of all predictions for class c
63 | for i in range(len(predictions)):
64 | tmp = smooth(predictions[i][:, c])
65 | segment_predict_multithr = []
66 | thr_set = np.arange(0.6, 0.7, 0.1)
67 | for thr in thr_set:
68 | threshold = np.max(tmp) - (np.max(tmp) - np.min(tmp))*thr ### 0.8 is the best?
69 | vid_pred = np.concatenate([np.zeros(1), (tmp>threshold).astype('float32'), np.zeros(1)], axis=0)
70 | vid_pred_diff = [vid_pred[idt]-vid_pred[idt-1] for idt in range(1, len(vid_pred))]
71 | s = [idk for idk, item in enumerate(vid_pred_diff) if item == 1]
72 | e = [idk for idk, item in enumerate(vid_pred_diff) if item == -1]
73 | for j in range(len(s)):
74 | if e[j]-s[j]>=2:
75 | segment_scores = np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c]
76 | segment_predict_multithr.append([i, s[j], e[j], segment_scores])
77 | # segment_predict.append([i, s[j], e[j], np.max(tmp[s[j]:e[j]])+0.7*c_score[i][c]])
78 | if len(segment_predict_multithr)!=0:
79 | segment_predict_multithr = np.array(segment_predict_multithr)
80 | segment_predict_multithr = segment_predict_multithr[np.argsort(-segment_predict_multithr[:,-1])]
81 | _, keep = nms(segment_predict_multithr[:, 1:-1], 0.6)
82 | segment_predict.extend(list(segment_predict_multithr[keep]))
83 | segment_predict = np.array(segment_predict)
84 |
85 | # Sort the list of predictions for class c based on score
86 | if len(segment_predict) == 0:
87 | return 0
88 | segment_predict = segment_predict[np.argsort(-segment_predict[:,3])]
89 |
90 | # Create gt list
91 | segment_gt = [[i, gtsegments[i][j][0], gtsegments[i][j][1]] for i in range(len(gtsegments))
92 | for j in range(len(gtsegments[i])) if str2ind(gtlabels[i][j], classlist) == c]
93 | gtpos = len(segment_gt)
94 |
95 | # Compare predictions and gt
96 | tp, fp = [], []
97 | for i in range(len(segment_predict)):
98 | flag = 0.
99 | best_iou = 0.0
100 | for j in range(len(segment_gt)):
101 | if segment_predict[i][0]==segment_gt[j][0]:
102 | gt = range(int(segment_gt[j][1]), int(segment_gt[j][2]))
103 | p = range(int(segment_predict[i][1]), int(segment_predict[i][2]))
104 | IoU = float(len(set(gt).intersection(set(p))))/float(len(set(gt).union(set(p))))
105 | if IoU >= th:
106 | flag = 1.
107 | if IoU > best_iou:
108 | best_iou = IoU
109 | best_j = j
110 | if flag > 0:
111 | del segment_gt[best_j]
112 | tp.append(flag)
113 | fp.append(1.-flag)
114 | tp_c = np.cumsum(tp)
115 | fp_c = np.cumsum(fp)
116 | if sum(tp)==0:
117 | prc = 0.
118 | else:
119 | prc = np.sum((tp_c/(fp_c+tp_c))*tp)/gtpos
120 | ap.append(prc)
121 | # print(np.round(prc, 4))
122 | return 100*np.mean(ap)
123 |
124 |
125 | def getDetectionMAP(predictions, segments, labels, excludeNormal=False):
126 | iou_list = [0.1, 0.2, 0.3, 0.4, 0.5]
127 | # iou_list = [0.5]
128 | dmap_list = []
129 | for iou in iou_list:
130 | # print('Testing for IoU {:.1f}'.format(iou))
131 | dmap_list.append(getLocMAP(predictions, iou, segments, labels, excludeNormal))
132 | return dmap_list, iou_list
133 |
134 |
--------------------------------------------------------------------------------
/src/xd_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | from torch.optim.lr_scheduler import MultiStepLR
6 | import numpy as np
7 | import random
8 |
9 | from model import CLIPVAD
10 | from xd_test import test
11 | from utils.dataset import XDDataset
12 | from utils.tools import get_prompt_text, get_batch_label
13 | import xd_option
14 |
15 | def CLASM(logits, labels, lengths, device):
16 | instance_logits = torch.zeros(0).to(device)
17 | labels = labels / torch.sum(labels, dim=1, keepdim=True)
18 | labels = labels.to(device)
19 |
20 | for i in range(logits.shape[0]):
21 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True, dim=0)
22 | instance_logits = torch.cat([instance_logits, torch.mean(tmp, 0, keepdim=True)], dim=0)
23 |
24 | milloss = -torch.mean(torch.sum(labels * F.log_softmax(instance_logits, dim=1), dim=1), dim=0)
25 | return milloss
26 |
27 | def CLAS2(logits, labels, lengths, device):
28 | instance_logits = torch.zeros(0).to(device)
29 | labels = 1 - labels[:, 0].reshape(labels.shape[0])
30 | labels = labels.to(device)
31 | logits = torch.sigmoid(logits).reshape(logits.shape[0], logits.shape[1])
32 |
33 | for i in range(logits.shape[0]):
34 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True)
35 | tmp = torch.mean(tmp).view(1)
36 | instance_logits = torch.cat((instance_logits, tmp))
37 |
38 | clsloss = F.binary_cross_entropy(instance_logits, labels)
39 | return clsloss
40 |
41 | def train(model, train_loader, test_loader, args, label_map: dict, device):
42 | model.to(device)
43 |
44 | gt = np.load(args.gt_path)
45 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
46 | gtlabels = np.load(args.gt_label_path, allow_pickle=True)
47 |
48 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
49 | scheduler = MultiStepLR(optimizer, args.scheduler_milestones, args.scheduler_rate)
50 | prompt_text = get_prompt_text(label_map)
51 | ap_best = 0
52 | epoch = 0
53 |
54 |     if args.use_checkpoint:
55 | checkpoint = torch.load(args.checkpoint_path)
56 | model.load_state_dict(checkpoint['model_state_dict'])
57 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
58 | epoch = checkpoint['epoch']
59 | ap_best = checkpoint['ap']
60 | print("checkpoint info:")
61 | print("epoch:", epoch+1, " ap:", ap_best)
62 |
63 | for e in range(args.max_epoch):
64 | model.train()
65 | loss_total1 = 0
66 | loss_total2 = 0
67 | for i, item in enumerate(train_loader):
68 |             # step: number of samples processed so far in this epoch
69 | visual_feat, text_labels, feat_lengths = item
70 | visual_feat = visual_feat.to(device)
71 | feat_lengths = feat_lengths.to(device)
72 | text_labels = get_batch_label(text_labels, prompt_text, label_map).to(device)
73 |
74 | text_features, logits1, logits2 = model(visual_feat, None, prompt_text, feat_lengths)
75 |
76 | loss1 = CLAS2(logits1, text_labels, feat_lengths, device)
77 | loss_total1 += loss1.item()
78 |
79 | loss2 = CLASM(logits2, text_labels, feat_lengths, device)
80 | loss_total2 += loss2.item()
81 |
82 | loss3 = torch.zeros(1).to(device)
83 | text_feature_normal = text_features[0] / text_features[0].norm(dim=-1, keepdim=True)
84 | for j in range(1, text_features.shape[0]):
85 | text_feature_abr = text_features[j] / text_features[j].norm(dim=-1, keepdim=True)
86 | loss3 += torch.abs(text_feature_normal @ text_feature_abr)
87 |             loss3 = loss3 / 6  # average over the 6 abnormal XD-Violence classes
88 |
89 | loss = loss1 + loss2 + loss3 * 1e-4
90 |
91 | optimizer.zero_grad()
92 | loss.backward()
93 | optimizer.step()
94 |             step = i * train_loader.batch_size
95 | if step % 4800 == 0 and step != 0:
96 | print('epoch: ', e+1, '| step: ', step, '| loss1: ', loss_total1 / (i+1), '| loss2: ', loss_total2 / (i+1), '| loss3: ', loss3.item())
97 |
98 | scheduler.step()
99 | AUC, AP, mAP = test(model, test_loader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)
100 |
101 | if AP > ap_best:
102 | ap_best = AP
103 | checkpoint = {
104 | 'epoch': e,
105 | 'model_state_dict': model.state_dict(),
106 | 'optimizer_state_dict': optimizer.state_dict(),
107 | 'ap': ap_best}
108 | torch.save(checkpoint, args.checkpoint_path)
109 |
110 |     # reload the best checkpoint once: restore it and export its weights
111 |     checkpoint = torch.load(args.checkpoint_path)
112 |     model.load_state_dict(checkpoint['model_state_dict'])
113 | 
114 |     torch.save(checkpoint['model_state_dict'], args.model_path)
115 |
116 | def setup_seed(seed):
117 | torch.manual_seed(seed)
118 | torch.cuda.manual_seed_all(seed)
119 | np.random.seed(seed)
120 | random.seed(seed)
121 | #torch.backends.cudnn.deterministic = True
122 |
123 | if __name__ == '__main__':
124 | device = "cuda" if torch.cuda.is_available() else "cpu"
125 | args = xd_option.parser.parse_args()
126 | setup_seed(args.seed)
127 |
128 | label_map = dict({'A': 'normal', 'B1': 'fighting', 'B2': 'shooting', 'B4': 'riot', 'B5': 'abuse', 'B6': 'car accident', 'G': 'explosion'})
129 |
130 | train_dataset = XDDataset(args.visual_length, args.train_list, False, label_map)
131 | train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
132 |
133 | test_dataset = XDDataset(args.visual_length, args.test_list, True, label_map)
134 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
135 |
136 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device)
137 | train(model, train_loader, test_loader, args, label_map, device)
--------------------------------------------------------------------------------
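Note: CLAS2/CLASM above implement top-k MIL pooling: for a video with `length` valid snippets, the video-level score is the mean of the k = int(length / 16 + 1) highest snippet scores. A minimal sketch with toy tensors (not from the repo):

    import torch

    scores = torch.sigmoid(torch.randn(64))   # per-snippet anomaly scores
    length = 64
    k = int(length / 16 + 1)                  # k = 5 for 64 snippets
    topk, _ = torch.topk(scores[:length], k=k, largest=True)
    video_logit = topk.mean()                 # video-level MIL score
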
/src/ucf_train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from torch.utils.data import DataLoader
5 | from torch.optim.lr_scheduler import MultiStepLR
6 | import numpy as np
7 | import random
8 |
9 | from model import CLIPVAD
10 | from ucf_test import test
11 | from utils.dataset import UCFDataset
12 | from utils.tools import get_prompt_text, get_batch_label
13 | import ucf_option
14 |
15 | def CLASM(logits, labels, lengths, device):
16 | instance_logits = torch.zeros(0).to(device)
17 | labels = labels / torch.sum(labels, dim=1, keepdim=True)
18 | labels = labels.to(device)
19 |
20 | for i in range(logits.shape[0]):
21 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True, dim=0)
22 | instance_logits = torch.cat([instance_logits, torch.mean(tmp, 0, keepdim=True)], dim=0)
23 |
24 | milloss = -torch.mean(torch.sum(labels * F.log_softmax(instance_logits, dim=1), dim=1), dim=0)
25 | return milloss
26 |
27 | def CLAS2(logits, labels, lengths, device):
28 | instance_logits = torch.zeros(0).to(device)
29 | labels = 1 - labels[:, 0].reshape(labels.shape[0])
30 | labels = labels.to(device)
31 | logits = torch.sigmoid(logits).reshape(logits.shape[0], logits.shape[1])
32 |
33 | for i in range(logits.shape[0]):
34 | tmp, _ = torch.topk(logits[i, 0:lengths[i]], k=int(lengths[i] / 16 + 1), largest=True)
35 | tmp = torch.mean(tmp).view(1)
36 | instance_logits = torch.cat([instance_logits, tmp], dim=0)
37 |
38 | clsloss = F.binary_cross_entropy(instance_logits, labels)
39 | return clsloss
40 |
41 | def train(model, normal_loader, anomaly_loader, testloader, args, label_map, device):
42 | model.to(device)
43 | gt = np.load(args.gt_path)
44 | gtsegments = np.load(args.gt_segment_path, allow_pickle=True)
45 | gtlabels = np.load(args.gt_label_path, allow_pickle=True)
46 |
47 | optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
48 | scheduler = MultiStepLR(optimizer, args.scheduler_milestones, args.scheduler_rate)
49 | prompt_text = get_prompt_text(label_map)
50 | ap_best = 0
51 | epoch = 0
52 |
53 |     if args.use_checkpoint:
54 | checkpoint = torch.load(args.checkpoint_path)
55 | model.load_state_dict(checkpoint['model_state_dict'])
56 | optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
57 | epoch = checkpoint['epoch']
58 | ap_best = checkpoint['ap']
59 | print("checkpoint info:")
60 | print("epoch:", epoch+1, " ap:", ap_best)
61 |
62 | for e in range(args.max_epoch):
63 | model.train()
64 | loss_total1 = 0
65 | loss_total2 = 0
66 | normal_iter = iter(normal_loader)
67 | anomaly_iter = iter(anomaly_loader)
68 | for i in range(min(len(normal_loader), len(anomaly_loader))):
69 |             # step: number of samples processed so far in this epoch
70 | normal_features, normal_label, normal_lengths = next(normal_iter)
71 | anomaly_features, anomaly_label, anomaly_lengths = next(anomaly_iter)
72 |
73 | visual_features = torch.cat([normal_features, anomaly_features], dim=0).to(device)
74 | text_labels = list(normal_label) + list(anomaly_label)
75 | feat_lengths = torch.cat([normal_lengths, anomaly_lengths], dim=0).to(device)
76 | text_labels = get_batch_label(text_labels, prompt_text, label_map).to(device)
77 |
78 | text_features, logits1, logits2 = model(visual_features, None, prompt_text, feat_lengths)
79 | #loss1
80 | loss1 = CLAS2(logits1, text_labels, feat_lengths, device)
81 | loss_total1 += loss1.item()
82 | #loss2
83 | loss2 = CLASM(logits2, text_labels, feat_lengths, device)
84 | loss_total2 += loss2.item()
85 | #loss3
86 | loss3 = torch.zeros(1).to(device)
87 | text_feature_normal = text_features[0] / text_features[0].norm(dim=-1, keepdim=True)
88 | for j in range(1, text_features.shape[0]):
89 | text_feature_abr = text_features[j] / text_features[j].norm(dim=-1, keepdim=True)
90 | loss3 += torch.abs(text_feature_normal @ text_feature_abr)
91 |             loss3 = loss3 / 13 * 1e-1  # average over the 13 abnormal UCF classes (once, after the loop), down-weighted by 0.1
92 |
93 | loss = loss1 + loss2 + loss3
94 |
95 | optimizer.zero_grad()
96 | loss.backward()
97 | optimizer.step()
98 |             step = i * normal_loader.batch_size * 2
99 | if step % 1280 == 0 and step != 0:
100 | print('epoch: ', e+1, '| step: ', step, '| loss1: ', loss_total1 / (i+1), '| loss2: ', loss_total2 / (i+1), '| loss3: ', loss3.item())
101 | AUC, AP = test(model, testloader, args.visual_length, prompt_text, gt, gtsegments, gtlabels, device)
102 |                 AP = AUC  # on UCF-Crime the best model is selected by AUC
103 |
104 | if AP > ap_best:
105 | ap_best = AP
106 | checkpoint = {
107 | 'epoch': e,
108 | 'model_state_dict': model.state_dict(),
109 | 'optimizer_state_dict': optimizer.state_dict(),
110 | 'ap': ap_best}
111 | torch.save(checkpoint, args.checkpoint_path)
112 |
113 | scheduler.step()
114 |
115 | torch.save(model.state_dict(), 'model/model_cur.pth')
116 |     # reload the best checkpoint once: restore it and export its weights
117 |     checkpoint = torch.load(args.checkpoint_path)
118 |     model.load_state_dict(checkpoint['model_state_dict'])
119 | 
120 |     torch.save(checkpoint['model_state_dict'], args.model_path)
121 |
122 | def setup_seed(seed):
123 | torch.manual_seed(seed)
124 | torch.cuda.manual_seed_all(seed)
125 | np.random.seed(seed)
126 | random.seed(seed)
127 | #torch.backends.cudnn.deterministic = True
128 |
129 | if __name__ == '__main__':
130 | device = "cuda" if torch.cuda.is_available() else "cpu"
131 | args = ucf_option.parser.parse_args()
132 | setup_seed(args.seed)
133 |
134 | label_map = dict({'Normal': 'normal', 'Abuse': 'abuse', 'Arrest': 'arrest', 'Arson': 'arson', 'Assault': 'assault', 'Burglary': 'burglary', 'Explosion': 'explosion', 'Fighting': 'fighting', 'RoadAccidents': 'roadAccidents', 'Robbery': 'robbery', 'Shooting': 'shooting', 'Shoplifting': 'shoplifting', 'Stealing': 'stealing', 'Vandalism': 'vandalism'})
135 |
136 | normal_dataset = UCFDataset(args.visual_length, args.train_list, False, label_map, True)
137 | normal_loader = DataLoader(normal_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
138 | anomaly_dataset = UCFDataset(args.visual_length, args.train_list, False, label_map, False)
139 | anomaly_loader = DataLoader(anomaly_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
140 |
141 | test_dataset = UCFDataset(args.visual_length, args.test_list, True, label_map)
142 | test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
143 |
144 | model = CLIPVAD(args.classes_num, args.embed_dim, args.visual_length, args.visual_width, args.visual_head, args.visual_layers, args.attn_window, args.prompt_prefix, args.prompt_postfix, device)
145 |
146 | train(model, normal_loader, anomaly_loader, test_loader, args, label_map, device)
--------------------------------------------------------------------------------
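Note: unlike xd_train, the UCF training loop above draws balanced batches: each step concatenates one normal and one anomalous mini-batch, so every batch is half normal and half abnormal. A toy sketch of the batching (batch size and feature shapes are assumptions):

    import torch

    normal_features = torch.randn(4, 256, 512)    # B x T x D
    anomaly_features = torch.randn(4, 256, 512)
    visual_features = torch.cat([normal_features, anomaly_features], dim=0)
    print(visual_features.shape)                  # torch.Size([8, 256, 512])
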
/src/utils/layers.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from torch import FloatTensor
3 | from torch.nn.parameter import Parameter
4 | from torch.nn.modules.module import Module
5 | import numpy as np
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from scipy.spatial.distance import pdist, squareform
10 |
11 | class GraphAttentionLayer(nn.Module):
12 | """
13 | Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
14 | """
15 |
16 | def __init__(self, in_features, out_features, dropout, alpha, concat=True):
17 | super(GraphAttentionLayer, self).__init__()
18 | self.dropout = dropout
19 | self.in_features = in_features
20 | self.out_features = out_features
21 | self.alpha = alpha
22 | self.concat = concat
23 |
24 |         self.W = nn.Parameter(nn.init.xavier_uniform_(torch.Tensor(in_features, out_features).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
25 |         self.a = nn.Parameter(nn.init.xavier_uniform_(torch.Tensor(2*out_features, 1).type(torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor), gain=np.sqrt(2.0)), requires_grad=True)
26 |
27 | self.leakyrelu = nn.LeakyReLU(self.alpha)
28 |
29 | def forward(self, input, adj):
30 | h = torch.mm(input, self.W)
31 | N = h.size()[0]
32 |
33 | a_input = torch.cat([h.repeat(1, N).view(N * N, -1), h.repeat(N, 1)], dim=1).view(N, -1, 2 * self.out_features)
34 | e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(2))
35 |
36 | zero_vec = -9e15*torch.ones_like(e)
37 | attention = torch.where(adj > 0, e, zero_vec)
38 | attention = F.softmax(attention, dim=1)
39 | attention = F.dropout(attention, self.dropout, training=self.training)
40 | h_prime = torch.matmul(attention, h)
41 |
42 | if self.concat:
43 | return F.elu(h_prime)
44 | else:
45 | return h_prime
46 |
47 | def __repr__(self):
48 | return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
49 |
50 | class linear(nn.Module):
51 | def __init__(self, in_features, out_features):
52 | super(linear, self).__init__()
53 | self.weight = Parameter(FloatTensor(in_features, out_features))
54 | self.register_parameter('bias', None)
55 | stdv = 1. / sqrt(self.weight.size(1))
56 | self.weight.data.uniform_(-stdv, stdv)
57 | def forward(self, x):
58 | x = x.matmul(self.weight)
59 | return x
60 |
61 | class GraphConvolution(Module):
62 | """
63 | Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
64 | """
65 |
66 | def __init__(self, in_features, out_features, bias=False, residual=True):
67 | super(GraphConvolution, self).__init__()
68 | self.in_features = in_features
69 | self.out_features = out_features
70 | self.weight = Parameter(FloatTensor(in_features, out_features))
71 | if bias:
72 | self.bias = Parameter(FloatTensor(out_features))
73 | else:
74 | self.register_parameter('bias', None)
75 | self.reset_parameters()
76 | if not residual:
77 | self.residual = lambda x: 0
78 | elif (in_features == out_features):
79 | self.residual = lambda x: x
80 | else:
81 | # self.residual = linear(in_features, out_features)
82 | self.residual = nn.Conv1d(in_channels=in_features, out_channels=out_features, kernel_size=5, padding=2)
83 | def reset_parameters(self):
84 | # stdv = 1. / sqrt(self.weight.size(1))
85 | nn.init.xavier_uniform_(self.weight)
86 | if self.bias is not None:
87 | self.bias.data.fill_(0.1)
88 |
89 | def forward(self, input, adj):
90 | # To support batch operations
91 | support = input.matmul(self.weight)
92 | output = adj.matmul(support)
93 |
94 | if self.bias is not None:
95 | output = output + self.bias
96 | if self.in_features != self.out_features and self.residual:
97 | input = input.permute(0,2,1)
98 | res = self.residual(input)
99 | res = res.permute(0,2,1)
100 | output = output + res
101 | else:
102 | output = output + self.residual(input)
103 |
104 | return output
105 |
106 | def __repr__(self):
107 | return self.__class__.__name__ + ' (' \
108 | + str(self.in_features) + ' -> ' \
109 | + str(self.out_features) + ')'
110 |
111 | ######################################################
112 |
113 | class SimilarityAdj(Module):
114 |
115 | def __init__(self, in_features, out_features):
116 | super(SimilarityAdj, self).__init__()
117 | self.in_features = in_features
118 | self.out_features = out_features
119 |
120 | self.weight0 = Parameter(FloatTensor(in_features, out_features))
121 | self.weight1 = Parameter(FloatTensor(in_features, out_features))
122 | self.register_parameter('bias', None)
123 | self.reset_parameters()
124 |
125 | def reset_parameters(self):
126 | # stdv = 1. / sqrt(self.weight0.size(1))
127 | nn.init.xavier_uniform_(self.weight0)
128 | nn.init.xavier_uniform_(self.weight1)
129 |
130 | def forward(self, input, seq_len):
131 | # To support batch operations
132 | theta = torch.matmul(input, self.weight0)
133 |         phi = torch.matmul(input, self.weight0)  # note: self.weight1 is initialized but never used
134 | phi2 = phi.permute(0, 2, 1)
135 | sim_graph = torch.matmul(theta, phi2)
136 |
137 | theta_norm = torch.norm(theta, p=2, dim=2, keepdim=True) # B*T*1
138 | phi_norm = torch.norm(phi, p=2, dim=2, keepdim=True) # B*T*1
139 | x_norm_x = theta_norm.matmul(phi_norm.permute(0, 2, 1))
140 | sim_graph = sim_graph / (x_norm_x + 1e-20)
141 |
142 | output = torch.zeros_like(sim_graph)
143 | if seq_len is None:
144 | for i in range(sim_graph.shape[0]):
145 | tmp = sim_graph[i]
146 | adj2 = tmp
147 | adj2 = F.threshold(adj2, 0.7, 0)
148 | adj2 = F.softmax(adj2, dim=1)
149 | output[i] = adj2
150 | else:
151 | for i in range(len(seq_len)):
152 | tmp = sim_graph[i, :seq_len[i], :seq_len[i]]
153 | adj2 = tmp
154 | adj2 = F.threshold(adj2, 0.7, 0)
155 | adj2 = F.softmax(adj2, dim=1)
156 | output[i, :seq_len[i], :seq_len[i]] = adj2
157 |
158 | return output
159 |
160 | def __repr__(self):
161 | return self.__class__.__name__ + ' (' \
162 | + str(self.in_features) + ' -> ' \
163 | + str(self.out_features) + ')'
164 |
165 | class DistanceAdj(Module):
166 |
167 | def __init__(self):
168 | super(DistanceAdj, self).__init__()
169 | self.sigma = Parameter(FloatTensor(1))
170 | self.sigma.data.fill_(0.1)
171 |
172 | def forward(self, batch_size, max_seqlen):
173 | # To support batch operations
174 | self.arith = np.arange(max_seqlen).reshape(-1, 1)
175 | dist = pdist(self.arith, metric='cityblock').astype(np.float32)
176 |         self.dist = torch.from_numpy(squareform(dist)).to('cuda' if torch.cuda.is_available() else 'cpu')
177 |         self.dist = torch.exp(-self.dist / torch.exp(torch.tensor(1.)))  # decay with temporal distance: exp(-|i-j| / e)
178 |         self.dist = torch.unsqueeze(self.dist, 0).repeat(batch_size, 1, 1)
179 | return self.dist
180 |
181 | if __name__ == '__main__':
182 | d = DistanceAdj()
183 | dist = d(1, 256).squeeze(0)
184 | print(dist.softmax(dim=-1))
--------------------------------------------------------------------------------
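Note: DistanceAdj above builds a fixed temporal adjacency whose entries decay as exp(-|i - j| / e) with the frame distance |i - j|. A standalone NumPy sketch of the same construction:

    import numpy as np
    from scipy.spatial.distance import pdist, squareform

    T = 4
    idx = np.arange(T).reshape(-1, 1)
    dist = squareform(pdist(idx, metric='cityblock'))   # |i - j| matrix
    adj = np.exp(-dist / np.e)                          # e.g. adj[0, 1] ~= 0.692
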
/src/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import numpy as np
4 | import torch
5 | import torch.nn.functional as F
6 | from torch import nn
7 | from clip import clip
8 | from utils.layers import GraphConvolution, DistanceAdj
9 |
10 | class LayerNorm(nn.LayerNorm):
11 |
12 | def forward(self, x: torch.Tensor):
13 | orig_type = x.dtype
14 | ret = super().forward(x.type(torch.float32))
15 | return ret.type(orig_type)
16 |
17 |
18 | class QuickGELU(nn.Module):
19 | def forward(self, x: torch.Tensor):
20 | return x * torch.sigmoid(1.702 * x)
21 |
22 |
23 | class ResidualAttentionBlock(nn.Module):
24 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
25 | super().__init__()
26 |
27 | self.attn = nn.MultiheadAttention(d_model, n_head)
28 | self.ln_1 = LayerNorm(d_model)
29 | self.mlp = nn.Sequential(OrderedDict([
30 | ("c_fc", nn.Linear(d_model, d_model * 4)),
31 | ("gelu", QuickGELU()),
32 | ("c_proj", nn.Linear(d_model * 4, d_model))
33 | ]))
34 | self.ln_2 = LayerNorm(d_model)
35 | self.attn_mask = attn_mask
36 |
37 | def attention(self, x: torch.Tensor, padding_mask: torch.Tensor):
38 | padding_mask = padding_mask.to(dtype=bool, device=x.device) if padding_mask is not None else None
39 | self.attn_mask = self.attn_mask.to(device=x.device) if self.attn_mask is not None else None
40 | return self.attn(x, x, x, need_weights=False, key_padding_mask=padding_mask, attn_mask=self.attn_mask)[0]
41 |
42 | def forward(self, x):
43 | x, padding_mask = x
44 | x = x + self.attention(self.ln_1(x), padding_mask)
45 | x = x + self.mlp(self.ln_2(x))
46 | return (x, padding_mask)
47 |
48 |
49 | class Transformer(nn.Module):
50 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
51 | super().__init__()
52 | self.width = width
53 | self.layers = layers
54 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
55 |
56 | def forward(self, x: torch.Tensor):
57 | return self.resblocks(x)
58 |
59 |
60 | class CLIPVAD(nn.Module):
61 | def __init__(self,
62 | num_class: int,
63 | embed_dim: int,
64 | visual_length: int,
65 | visual_width: int,
66 | visual_head: int,
67 | visual_layers: int,
68 | attn_window: int,
69 | prompt_prefix: int,
70 | prompt_postfix: int,
71 | device):
72 | super().__init__()
73 |
74 | self.num_class = num_class
75 | self.visual_length = visual_length
76 | self.visual_width = visual_width
77 | self.embed_dim = embed_dim
78 | self.attn_window = attn_window
79 | self.prompt_prefix = prompt_prefix
80 | self.prompt_postfix = prompt_postfix
81 | self.device = device
82 |
83 | self.temporal = Transformer(
84 | width=visual_width,
85 | layers=visual_layers,
86 | heads=visual_head,
87 | attn_mask=self.build_attention_mask(self.attn_window)
88 | )
89 |
90 | width = int(visual_width / 2)
91 | self.gc1 = GraphConvolution(visual_width, width, residual=True)
92 | self.gc2 = GraphConvolution(width, width, residual=True)
93 | self.gc3 = GraphConvolution(visual_width, width, residual=True)
94 | self.gc4 = GraphConvolution(width, width, residual=True)
95 | self.disAdj = DistanceAdj()
96 | self.linear = nn.Linear(visual_width, visual_width)
97 | self.gelu = QuickGELU()
98 |
99 | self.mlp1 = nn.Sequential(OrderedDict([
100 | ("c_fc", nn.Linear(visual_width, visual_width * 4)),
101 | ("gelu", QuickGELU()),
102 | ("c_proj", nn.Linear(visual_width * 4, visual_width))
103 | ]))
104 | self.mlp2 = nn.Sequential(OrderedDict([
105 | ("c_fc", nn.Linear(visual_width, visual_width * 4)),
106 | ("gelu", QuickGELU()),
107 | ("c_proj", nn.Linear(visual_width * 4, visual_width))
108 | ]))
109 | self.classifier = nn.Linear(visual_width, 1)
110 |
111 | self.clipmodel, _ = clip.load("ViT-B/16", device)
112 | for clip_param in self.clipmodel.parameters():
113 | clip_param.requires_grad = False
114 |
115 | self.frame_position_embeddings = nn.Embedding(visual_length, visual_width)
116 | self.text_prompt_embeddings = nn.Embedding(77, self.embed_dim)
117 |
118 | self.initialize_parameters()
119 |
120 | def initialize_parameters(self):
121 | nn.init.normal_(self.text_prompt_embeddings.weight, std=0.01)
122 | nn.init.normal_(self.frame_position_embeddings.weight, std=0.01)
123 |
124 | def build_attention_mask(self, attn_window):
125 |         # build a block-diagonal window mask: each frame attends only within its local window of size attn_window
126 |         # pytorch uses an additive attention mask; blocked positions are filled with -inf
127 | mask = torch.empty(self.visual_length, self.visual_length)
128 | mask.fill_(float('-inf'))
129 | for i in range(int(self.visual_length / attn_window)):
130 | if (i + 1) * attn_window < self.visual_length:
131 | mask[i * attn_window: (i + 1) * attn_window, i * attn_window: (i + 1) * attn_window] = 0
132 | else:
133 | mask[i * attn_window: self.visual_length, i * attn_window: self.visual_length] = 0
134 |
135 | return mask
136 |
137 | def adj4(self, x, seq_len):
138 | soft = nn.Softmax(1)
139 | x2 = x.matmul(x.permute(0, 2, 1)) # B*T*T
140 | x_norm = torch.norm(x, p=2, dim=2, keepdim=True) # B*T*1
141 | x_norm_x = x_norm.matmul(x_norm.permute(0, 2, 1))
142 | x2 = x2/(x_norm_x+1e-20)
143 | output = torch.zeros_like(x2)
144 | if seq_len is None:
145 | for i in range(x.shape[0]):
146 | tmp = x2[i]
147 | adj2 = tmp
148 | adj2 = F.threshold(adj2, 0.7, 0)
149 | adj2 = soft(adj2)
150 | output[i] = adj2
151 | else:
152 | for i in range(len(seq_len)):
153 | tmp = x2[i, :seq_len[i], :seq_len[i]]
154 | adj2 = tmp
155 | adj2 = F.threshold(adj2, 0.7, 0)
156 | adj2 = soft(adj2)
157 | output[i, :seq_len[i], :seq_len[i]] = adj2
158 |
159 | return output
160 |
161 | def encode_video(self, images, padding_mask, lengths):
162 | images = images.to(torch.float)
163 | position_ids = torch.arange(self.visual_length, device=self.device)
164 | position_ids = position_ids.unsqueeze(0).expand(images.shape[0], -1)
165 | frame_position_embeddings = self.frame_position_embeddings(position_ids)
166 | frame_position_embeddings = frame_position_embeddings.permute(1, 0, 2)
167 | images = images.permute(1, 0, 2) + frame_position_embeddings
168 |
169 | x, _ = self.temporal((images, None))
170 | x = x.permute(1, 0, 2)
171 |
172 | adj = self.adj4(x, lengths)
173 | disadj = self.disAdj(x.shape[0], x.shape[1])
174 | x1_h = self.gelu(self.gc1(x, adj))
175 | x2_h = self.gelu(self.gc3(x, disadj))
176 |
177 | x1 = self.gelu(self.gc2(x1_h, adj))
178 | x2 = self.gelu(self.gc4(x2_h, disadj))
179 |
180 | x = torch.cat((x1, x2), 2)
181 | x = self.linear(x)
182 |
183 | return x
184 |
185 | def encode_textprompt(self, text):
186 | word_tokens = clip.tokenize(text).to(self.device)
187 | word_embedding = self.clipmodel.encode_token(word_tokens)
188 | text_embeddings = self.text_prompt_embeddings(torch.arange(77).to(self.device)).unsqueeze(0).repeat([len(text), 1, 1])
189 | text_tokens = torch.zeros(len(text), 77).to(self.device)
190 |
191 | for i in range(len(text)):
192 |             ind = torch.argmax(word_tokens[i], -1)  # position of the EOT token (it has the largest token id)
193 | text_embeddings[i, 0] = word_embedding[i, 0]
194 | text_embeddings[i, self.prompt_prefix + 1: self.prompt_prefix + ind] = word_embedding[i, 1: ind]
195 | text_embeddings[i, self.prompt_prefix + ind + self.prompt_postfix] = word_embedding[i, ind]
196 | text_tokens[i, self.prompt_prefix + ind + self.prompt_postfix] = word_tokens[i, ind]
197 |
198 | text_features = self.clipmodel.encode_text(text_embeddings, text_tokens)
199 |
200 | return text_features
201 |
202 | def forward(self, visual, padding_mask, text, lengths):
203 | visual_features = self.encode_video(visual, padding_mask, lengths)
204 | logits1 = self.classifier(visual_features + self.mlp2(visual_features))
205 |
206 | text_features_ori = self.encode_textprompt(text)
207 |
208 | text_features = text_features_ori
209 | logits_attn = logits1.permute(0, 2, 1)
210 | visual_attn = logits_attn @ visual_features
211 | visual_attn = visual_attn / visual_attn.norm(dim=-1, keepdim=True)
212 | visual_attn = visual_attn.expand(visual_attn.shape[0], text_features_ori.shape[0], visual_attn.shape[2])
213 | text_features = text_features_ori.unsqueeze(0)
214 | text_features = text_features.expand(visual_attn.shape[0], text_features.shape[1], text_features.shape[2])
215 | text_features = text_features + visual_attn
216 | text_features = text_features + self.mlp1(text_features)
217 |
218 | visual_features_norm = visual_features / visual_features.norm(dim=-1, keepdim=True)
219 | text_features_norm = text_features / text_features.norm(dim=-1, keepdim=True)
220 | text_features_norm = text_features_norm.permute(0, 2, 1)
221 |         logits2 = visual_features_norm @ text_features_norm.type(visual_features_norm.dtype) / 0.07  # temperature-scaled cosine similarity (0.07, as in CLIP)
222 |
223 | return text_features_ori, logits1, logits2
224 |
--------------------------------------------------------------------------------
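Note: build_attention_mask above produces a block-diagonal window mask (assuming visual_length is a multiple of attn_window, as in the provided configs): frames attend only within their local window, all other positions get an additive -inf. A small sketch:

    import torch

    L, w = 8, 4
    mask = torch.full((L, L), float('-inf'))
    for i in range(L // w):
        mask[i * w:(i + 1) * w, i * w:(i + 1) * w] = 0.0   # attend within the window
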
/src/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
95 | """Load a CLIP model
96 |
97 | Parameters
98 | ----------
99 | name : str
100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101 |
102 | device : Union[str, torch.device]
103 | The device to put the loaded model
104 |
105 | jit : bool
106 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
107 |
108 | download_root: str
109 | path to download the model files; by default, it uses "~/.cache/clip"
110 |
111 | Returns
112 | -------
113 | model : torch.nn.Module
114 | The CLIP model
115 |
116 | preprocess : Callable[[PIL.Image], torch.Tensor]
117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118 | """
119 | if name in _MODELS:
120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121 | elif os.path.isfile(name):
122 | model_path = name
123 | else:
124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125 |
126 | with open(model_path, 'rb') as opened_file:
127 | try:
128 | # loading JIT archive
129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130 | state_dict = None
131 | except RuntimeError:
132 | # loading saved state dict
133 | if jit:
134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135 | jit = False
136 | state_dict = torch.load(opened_file, map_location="cpu")
137 |
138 | if not jit:
139 | model = build_model(state_dict or model.state_dict()).to(device)
140 | if str(device) == "cpu":
141 | model.float()
142 | return model, _transform(model.visual.input_resolution)
143 |
144 | # patch the device names
145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147 |
148 | def patch_device(module):
149 | try:
150 | graphs = [module.graph] if hasattr(module, "graph") else []
151 | except RuntimeError:
152 | graphs = []
153 |
154 | if hasattr(module, "forward1"):
155 | graphs.append(module.forward1.graph)
156 |
157 | for graph in graphs:
158 | for node in graph.findAllNodes("prim::Constant"):
159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160 | node.copyAttributes(device_node)
161 |
162 | model.apply(patch_device)
163 | patch_device(model.encode_image)
164 | patch_device(model.encode_text)
165 |
166 | # patch dtype to float32 on CPU
167 | if str(device) == "cpu":
168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170 | float_node = float_input.node()
171 |
172 | def patch_float(module):
173 | try:
174 | graphs = [module.graph] if hasattr(module, "graph") else []
175 | except RuntimeError:
176 | graphs = []
177 |
178 | if hasattr(module, "forward1"):
179 | graphs.append(module.forward1.graph)
180 |
181 | for graph in graphs:
182 | for node in graph.findAllNodes("aten::to"):
183 | inputs = list(node.inputs())
184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185 | if inputs[i].node()["value"] == 5:
186 | inputs[i].node().copyAttributes(float_node)
187 |
188 | model.apply(patch_float)
189 | patch_float(model.encode_image)
190 | patch_float(model.encode_text)
191 |
192 | model.float()
193 |
194 | return model, _transform(model.input_resolution.item())
195 |
196 |
197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
198 | """
199 | Returns the tokenized representation of given input string(s)
200 |
201 | Parameters
202 | ----------
203 | texts : Union[str, List[str]]
204 | An input string or a list of input strings to tokenize
205 |
206 | context_length : int
207 | The context length to use; all CLIP models use 77 as the context length
208 |
209 | truncate: bool
210 | Whether to truncate the text in case its encoding is longer than the context length
211 |
212 | Returns
213 | -------
214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216 | """
217 | if isinstance(texts, str):
218 | texts = [texts]
219 |
220 | sot_token = _tokenizer.encoder["<|startoftext|>"]
221 | eot_token = _tokenizer.encoder["<|endoftext|>"]
222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225 | else:
226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227 |
228 | for i, tokens in enumerate(all_tokens):
229 | if len(tokens) > context_length:
230 | if truncate:
231 | tokens = tokens[:context_length]
232 | tokens[-1] = eot_token
233 | else:
234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235 | result[i, :len(tokens)] = torch.tensor(tokens)
236 |
237 | return result
238 |
--------------------------------------------------------------------------------
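Note: typical usage of the bundled loader and tokenizer documented above (device choice and prompt strings are illustrative):

    import torch
    from clip import clip

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/16", device)
    tokens = clip.tokenize(["normal", "fighting", "explosion"])
    print(tokens.shape)   # torch.Size([3, 77]); 77 is CLIP's context length
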
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/list/Anomaly_Test.txt:
--------------------------------------------------------------------------------
1 | Abuse/Abuse028_x264.mp4
2 | Abuse/Abuse030_x264.mp4
3 | Arrest/Arrest001_x264.mp4
4 | Arrest/Arrest007_x264.mp4
5 | Arrest/Arrest024_x264.mp4
6 | Arrest/Arrest030_x264.mp4
7 | Arrest/Arrest039_x264.mp4
8 | Arson/Arson007_x264.mp4
9 | Arson/Arson009_x264.mp4
10 | Arson/Arson010_x264.mp4
11 | Arson/Arson011_x264.mp4
12 | Arson/Arson016_x264.mp4
13 | Arson/Arson018_x264.mp4
14 | Arson/Arson022_x264.mp4
15 | Arson/Arson035_x264.mp4
16 | Arson/Arson041_x264.mp4
17 | Assault/Assault006_x264.mp4
18 | Assault/Assault010_x264.mp4
19 | Assault/Assault011_x264.mp4
20 | Burglary/Burglary005_x264.mp4
21 | Burglary/Burglary017_x264.mp4
22 | Burglary/Burglary018_x264.mp4
23 | Burglary/Burglary021_x264.mp4
24 | Burglary/Burglary024_x264.mp4
25 | Burglary/Burglary032_x264.mp4
26 | Burglary/Burglary033_x264.mp4
27 | Burglary/Burglary035_x264.mp4
28 | Burglary/Burglary037_x264.mp4
29 | Burglary/Burglary061_x264.mp4
30 | Burglary/Burglary076_x264.mp4
31 | Burglary/Burglary079_x264.mp4
32 | Burglary/Burglary092_x264.mp4
33 | Explosion/Explosion002_x264.mp4
34 | Explosion/Explosion004_x264.mp4
35 | Explosion/Explosion007_x264.mp4
36 | Explosion/Explosion008_x264.mp4
37 | Explosion/Explosion010_x264.mp4
38 | Explosion/Explosion011_x264.mp4
39 | Explosion/Explosion013_x264.mp4
40 | Explosion/Explosion016_x264.mp4
41 | Explosion/Explosion017_x264.mp4
42 | Explosion/Explosion020_x264.mp4
43 | Explosion/Explosion021_x264.mp4
44 | Explosion/Explosion022_x264.mp4
45 | Explosion/Explosion025_x264.mp4
46 | Explosion/Explosion027_x264.mp4
47 | Explosion/Explosion028_x264.mp4
48 | Explosion/Explosion029_x264.mp4
49 | Explosion/Explosion033_x264.mp4
50 | Explosion/Explosion035_x264.mp4
51 | Explosion/Explosion036_x264.mp4
52 | Explosion/Explosion039_x264.mp4
53 | Explosion/Explosion043_x264.mp4
54 | Fighting/Fighting003_x264.mp4
55 | Fighting/Fighting018_x264.mp4
56 | Fighting/Fighting033_x264.mp4
57 | Fighting/Fighting042_x264.mp4
58 | Fighting/Fighting047_x264.mp4
59 | RoadAccidents/RoadAccidents001_x264.mp4
60 | RoadAccidents/RoadAccidents002_x264.mp4
61 | RoadAccidents/RoadAccidents004_x264.mp4
62 | RoadAccidents/RoadAccidents009_x264.mp4
63 | RoadAccidents/RoadAccidents010_x264.mp4
64 | RoadAccidents/RoadAccidents011_x264.mp4
65 | RoadAccidents/RoadAccidents012_x264.mp4
66 | RoadAccidents/RoadAccidents016_x264.mp4
67 | RoadAccidents/RoadAccidents017_x264.mp4
68 | RoadAccidents/RoadAccidents019_x264.mp4
69 | RoadAccidents/RoadAccidents020_x264.mp4
70 | RoadAccidents/RoadAccidents021_x264.mp4
71 | RoadAccidents/RoadAccidents022_x264.mp4
72 | RoadAccidents/RoadAccidents121_x264.mp4
73 | RoadAccidents/RoadAccidents122_x264.mp4
74 | RoadAccidents/RoadAccidents123_x264.mp4
75 | RoadAccidents/RoadAccidents124_x264.mp4
76 | RoadAccidents/RoadAccidents125_x264.mp4
77 | RoadAccidents/RoadAccidents127_x264.mp4
78 | RoadAccidents/RoadAccidents128_x264.mp4
79 | RoadAccidents/RoadAccidents131_x264.mp4
80 | RoadAccidents/RoadAccidents132_x264.mp4
81 | RoadAccidents/RoadAccidents133_x264.mp4
82 | Robbery/Robbery048_x264.mp4
83 | Robbery/Robbery050_x264.mp4
84 | Robbery/Robbery102_x264.mp4
85 | Robbery/Robbery106_x264.mp4
86 | Robbery/Robbery137_x264.mp4
87 | Shooting/Shooting002_x264.mp4
88 | Shooting/Shooting004_x264.mp4
89 | Shooting/Shooting007_x264.mp4
90 | Shooting/Shooting008_x264.mp4
91 | Shooting/Shooting010_x264.mp4
92 | Shooting/Shooting011_x264.mp4
93 | Shooting/Shooting013_x264.mp4
94 | Shooting/Shooting015_x264.mp4
95 | Shooting/Shooting018_x264.mp4
96 | Shooting/Shooting019_x264.mp4
97 | Shooting/Shooting021_x264.mp4
98 | Shooting/Shooting022_x264.mp4
99 | Shooting/Shooting024_x264.mp4
100 | Shooting/Shooting026_x264.mp4
101 | Shooting/Shooting028_x264.mp4
102 | Shooting/Shooting032_x264.mp4
103 | Shooting/Shooting033_x264.mp4
104 | Shooting/Shooting034_x264.mp4
105 | Shooting/Shooting037_x264.mp4
106 | Shooting/Shooting043_x264.mp4
107 | Shooting/Shooting046_x264.mp4
108 | Shooting/Shooting047_x264.mp4
109 | Shooting/Shooting048_x264.mp4
110 | Shoplifting/Shoplifting001_x264.mp4
111 | Shoplifting/Shoplifting004_x264.mp4
112 | Shoplifting/Shoplifting005_x264.mp4
113 | Shoplifting/Shoplifting007_x264.mp4
114 | Shoplifting/Shoplifting010_x264.mp4
115 | Shoplifting/Shoplifting015_x264.mp4
116 | Shoplifting/Shoplifting016_x264.mp4
117 | Shoplifting/Shoplifting017_x264.mp4
118 | Shoplifting/Shoplifting020_x264.mp4
119 | Shoplifting/Shoplifting021_x264.mp4
120 | Shoplifting/Shoplifting022_x264.mp4
121 | Shoplifting/Shoplifting027_x264.mp4
122 | Shoplifting/Shoplifting028_x264.mp4
123 | Shoplifting/Shoplifting029_x264.mp4
124 | Shoplifting/Shoplifting031_x264.mp4
125 | Shoplifting/Shoplifting033_x264.mp4
126 | Shoplifting/Shoplifting034_x264.mp4
127 | Shoplifting/Shoplifting037_x264.mp4
128 | Shoplifting/Shoplifting039_x264.mp4
129 | Shoplifting/Shoplifting044_x264.mp4
130 | Shoplifting/Shoplifting049_x264.mp4
131 | Stealing/Stealing019_x264.mp4
132 | Stealing/Stealing036_x264.mp4
133 | Stealing/Stealing058_x264.mp4
134 | Stealing/Stealing062_x264.mp4
135 | Stealing/Stealing079_x264.mp4
136 | Testing_Normal_Videos_Anomaly/Normal_Videos_003_x264.mp4
137 | Testing_Normal_Videos_Anomaly/Normal_Videos_006_x264.mp4
138 | Testing_Normal_Videos_Anomaly/Normal_Videos_010_x264.mp4
139 | Testing_Normal_Videos_Anomaly/Normal_Videos_014_x264.mp4
140 | Testing_Normal_Videos_Anomaly/Normal_Videos_015_x264.mp4
141 | Testing_Normal_Videos_Anomaly/Normal_Videos_018_x264.mp4
142 | Testing_Normal_Videos_Anomaly/Normal_Videos_019_x264.mp4
143 | Testing_Normal_Videos_Anomaly/Normal_Videos_024_x264.mp4
144 | Testing_Normal_Videos_Anomaly/Normal_Videos_025_x264.mp4
145 | Testing_Normal_Videos_Anomaly/Normal_Videos_027_x264.mp4
146 | Testing_Normal_Videos_Anomaly/Normal_Videos_033_x264.mp4
147 | Testing_Normal_Videos_Anomaly/Normal_Videos_034_x264.mp4
148 | Testing_Normal_Videos_Anomaly/Normal_Videos_041_x264.mp4
149 | Testing_Normal_Videos_Anomaly/Normal_Videos_042_x264.mp4
150 | Testing_Normal_Videos_Anomaly/Normal_Videos_048_x264.mp4
151 | Testing_Normal_Videos_Anomaly/Normal_Videos_050_x264.mp4
152 | Testing_Normal_Videos_Anomaly/Normal_Videos_051_x264.mp4
153 | Testing_Normal_Videos_Anomaly/Normal_Videos_056_x264.mp4
154 | Testing_Normal_Videos_Anomaly/Normal_Videos_059_x264.mp4
155 | Testing_Normal_Videos_Anomaly/Normal_Videos_063_x264.mp4
156 | Testing_Normal_Videos_Anomaly/Normal_Videos_067_x264.mp4
157 | Testing_Normal_Videos_Anomaly/Normal_Videos_070_x264.mp4
158 | Testing_Normal_Videos_Anomaly/Normal_Videos_100_x264.mp4
159 | Testing_Normal_Videos_Anomaly/Normal_Videos_129_x264.mp4
160 | Testing_Normal_Videos_Anomaly/Normal_Videos_150_x264.mp4
161 | Testing_Normal_Videos_Anomaly/Normal_Videos_168_x264.mp4
162 | Testing_Normal_Videos_Anomaly/Normal_Videos_175_x264.mp4
163 | Testing_Normal_Videos_Anomaly/Normal_Videos_182_x264.mp4
164 | Testing_Normal_Videos_Anomaly/Normal_Videos_189_x264.mp4
165 | Testing_Normal_Videos_Anomaly/Normal_Videos_196_x264.mp4
166 | Testing_Normal_Videos_Anomaly/Normal_Videos_203_x264.mp4
167 | Testing_Normal_Videos_Anomaly/Normal_Videos_210_x264.mp4
168 | Testing_Normal_Videos_Anomaly/Normal_Videos_217_x264.mp4
169 | Testing_Normal_Videos_Anomaly/Normal_Videos_224_x264.mp4
170 | Testing_Normal_Videos_Anomaly/Normal_Videos_246_x264.mp4
171 | Testing_Normal_Videos_Anomaly/Normal_Videos_247_x264.mp4
172 | Testing_Normal_Videos_Anomaly/Normal_Videos_248_x264.mp4
173 | Testing_Normal_Videos_Anomaly/Normal_Videos_251_x264.mp4
174 | Testing_Normal_Videos_Anomaly/Normal_Videos_289_x264.mp4
175 | Testing_Normal_Videos_Anomaly/Normal_Videos_310_x264.mp4
176 | Testing_Normal_Videos_Anomaly/Normal_Videos_312_x264.mp4
177 | Testing_Normal_Videos_Anomaly/Normal_Videos_317_x264.mp4
178 | Testing_Normal_Videos_Anomaly/Normal_Videos_345_x264.mp4
179 | Testing_Normal_Videos_Anomaly/Normal_Videos_352_x264.mp4
180 | Testing_Normal_Videos_Anomaly/Normal_Videos_360_x264.mp4
181 | Testing_Normal_Videos_Anomaly/Normal_Videos_365_x264.mp4
182 | Testing_Normal_Videos_Anomaly/Normal_Videos_401_x264.mp4
183 | Testing_Normal_Videos_Anomaly/Normal_Videos_417_x264.mp4
184 | Testing_Normal_Videos_Anomaly/Normal_Videos_439_x264.mp4
185 | Testing_Normal_Videos_Anomaly/Normal_Videos_452_x264.mp4
186 | Testing_Normal_Videos_Anomaly/Normal_Videos_453_x264.mp4
187 | Testing_Normal_Videos_Anomaly/Normal_Videos_478_x264.mp4
188 | Testing_Normal_Videos_Anomaly/Normal_Videos_576_x264.mp4
189 | Testing_Normal_Videos_Anomaly/Normal_Videos_597_x264.mp4
190 | Testing_Normal_Videos_Anomaly/Normal_Videos_603_x264.mp4
191 | Testing_Normal_Videos_Anomaly/Normal_Videos_606_x264.mp4
192 | Testing_Normal_Videos_Anomaly/Normal_Videos_621_x264.mp4
193 | Testing_Normal_Videos_Anomaly/Normal_Videos_634_x264.mp4
194 | Testing_Normal_Videos_Anomaly/Normal_Videos_641_x264.mp4
195 | Testing_Normal_Videos_Anomaly/Normal_Videos_656_x264.mp4
196 | Testing_Normal_Videos_Anomaly/Normal_Videos_686_x264.mp4
197 | Testing_Normal_Videos_Anomaly/Normal_Videos_696_x264.mp4
198 | Testing_Normal_Videos_Anomaly/Normal_Videos_702_x264.mp4
199 | Testing_Normal_Videos_Anomaly/Normal_Videos_704_x264.mp4
200 | Testing_Normal_Videos_Anomaly/Normal_Videos_710_x264.mp4
201 | Testing_Normal_Videos_Anomaly/Normal_Videos_717_x264.mp4
202 | Testing_Normal_Videos_Anomaly/Normal_Videos_722_x264.mp4
203 | Testing_Normal_Videos_Anomaly/Normal_Videos_725_x264.mp4
204 | Testing_Normal_Videos_Anomaly/Normal_Videos_745_x264.mp4
205 | Testing_Normal_Videos_Anomaly/Normal_Videos_758_x264.mp4
206 | Testing_Normal_Videos_Anomaly/Normal_Videos_778_x264.mp4
207 | Testing_Normal_Videos_Anomaly/Normal_Videos_780_x264.mp4
208 | Testing_Normal_Videos_Anomaly/Normal_Videos_781_x264.mp4
209 | Testing_Normal_Videos_Anomaly/Normal_Videos_782_x264.mp4
210 | Testing_Normal_Videos_Anomaly/Normal_Videos_783_x264.mp4
211 | Testing_Normal_Videos_Anomaly/Normal_Videos_798_x264.mp4
212 | Testing_Normal_Videos_Anomaly/Normal_Videos_801_x264.mp4
213 | Testing_Normal_Videos_Anomaly/Normal_Videos_828_x264.mp4
214 | Testing_Normal_Videos_Anomaly/Normal_Videos_831_x264.mp4
215 | Testing_Normal_Videos_Anomaly/Normal_Videos_866_x264.mp4
216 | Testing_Normal_Videos_Anomaly/Normal_Videos_867_x264.mp4
217 | Testing_Normal_Videos_Anomaly/Normal_Videos_868_x264.mp4
218 | Testing_Normal_Videos_Anomaly/Normal_Videos_869_x264.mp4
219 | Testing_Normal_Videos_Anomaly/Normal_Videos_870_x264.mp4
220 | Testing_Normal_Videos_Anomaly/Normal_Videos_871_x264.mp4
221 | Testing_Normal_Videos_Anomaly/Normal_Videos_872_x264.mp4
222 | Testing_Normal_Videos_Anomaly/Normal_Videos_873_x264.mp4
223 | Testing_Normal_Videos_Anomaly/Normal_Videos_874_x264.mp4
224 | Testing_Normal_Videos_Anomaly/Normal_Videos_875_x264.mp4
225 | Testing_Normal_Videos_Anomaly/Normal_Videos_876_x264.mp4
226 | Testing_Normal_Videos_Anomaly/Normal_Videos_877_x264.mp4
227 | Testing_Normal_Videos_Anomaly/Normal_Videos_878_x264.mp4
228 | Testing_Normal_Videos_Anomaly/Normal_Videos_879_x264.mp4
229 | Testing_Normal_Videos_Anomaly/Normal_Videos_880_x264.mp4
230 | Testing_Normal_Videos_Anomaly/Normal_Videos_881_x264.mp4
231 | Testing_Normal_Videos_Anomaly/Normal_Videos_882_x264.mp4
232 | Testing_Normal_Videos_Anomaly/Normal_Videos_883_x264.mp4
233 | Testing_Normal_Videos_Anomaly/Normal_Videos_884_x264.mp4
234 | Testing_Normal_Videos_Anomaly/Normal_Videos_885_x264.mp4
235 | Testing_Normal_Videos_Anomaly/Normal_Videos_886_x264.mp4
236 | Testing_Normal_Videos_Anomaly/Normal_Videos_887_x264.mp4
237 | Testing_Normal_Videos_Anomaly/Normal_Videos_888_x264.mp4
238 | Testing_Normal_Videos_Anomaly/Normal_Videos_889_x264.mp4
239 | Testing_Normal_Videos_Anomaly/Normal_Videos_890_x264.mp4
240 | Testing_Normal_Videos_Anomaly/Normal_Videos_891_x264.mp4
241 | Testing_Normal_Videos_Anomaly/Normal_Videos_892_x264.mp4
242 | Testing_Normal_Videos_Anomaly/Normal_Videos_893_x264.mp4
243 | Testing_Normal_Videos_Anomaly/Normal_Videos_894_x264.mp4
244 | Testing_Normal_Videos_Anomaly/Normal_Videos_895_x264.mp4
245 | Testing_Normal_Videos_Anomaly/Normal_Videos_896_x264.mp4
246 | Testing_Normal_Videos_Anomaly/Normal_Videos_897_x264.mp4
247 | Testing_Normal_Videos_Anomaly/Normal_Videos_898_x264.mp4
248 | Testing_Normal_Videos_Anomaly/Normal_Videos_899_x264.mp4
249 | Testing_Normal_Videos_Anomaly/Normal_Videos_900_x264.mp4
250 | Testing_Normal_Videos_Anomaly/Normal_Videos_901_x264.mp4
251 | Testing_Normal_Videos_Anomaly/Normal_Videos_902_x264.mp4
252 | Testing_Normal_Videos_Anomaly/Normal_Videos_903_x264.mp4
253 | Testing_Normal_Videos_Anomaly/Normal_Videos_904_x264.mp4
254 | Testing_Normal_Videos_Anomaly/Normal_Videos_905_x264.mp4
255 | Testing_Normal_Videos_Anomaly/Normal_Videos_906_x264.mp4
256 | Testing_Normal_Videos_Anomaly/Normal_Videos_907_x264.mp4
257 | Testing_Normal_Videos_Anomaly/Normal_Videos_908_x264.mp4
258 | Testing_Normal_Videos_Anomaly/Normal_Videos_909_x264.mp4
259 | Testing_Normal_Videos_Anomaly/Normal_Videos_910_x264.mp4
260 | Testing_Normal_Videos_Anomaly/Normal_Videos_911_x264.mp4
261 | Testing_Normal_Videos_Anomaly/Normal_Videos_912_x264.mp4
262 | Testing_Normal_Videos_Anomaly/Normal_Videos_913_x264.mp4
263 | Testing_Normal_Videos_Anomaly/Normal_Videos_914_x264.mp4
264 | Testing_Normal_Videos_Anomaly/Normal_Videos_915_x264.mp4
265 | Testing_Normal_Videos_Anomaly/Normal_Videos_923_x264.mp4
266 | Testing_Normal_Videos_Anomaly/Normal_Videos_924_x264.mp4
267 | Testing_Normal_Videos_Anomaly/Normal_Videos_925_x264.mp4
268 | Testing_Normal_Videos_Anomaly/Normal_Videos_926_x264.mp4
269 | Testing_Normal_Videos_Anomaly/Normal_Videos_927_x264.mp4
270 | Testing_Normal_Videos_Anomaly/Normal_Videos_928_x264.mp4
271 | Testing_Normal_Videos_Anomaly/Normal_Videos_929_x264.mp4
272 | Testing_Normal_Videos_Anomaly/Normal_Videos_930_x264.mp4
273 | Testing_Normal_Videos_Anomaly/Normal_Videos_931_x264.mp4
274 | Testing_Normal_Videos_Anomaly/Normal_Videos_932_x264.mp4
275 | Testing_Normal_Videos_Anomaly/Normal_Videos_933_x264.mp4
276 | Testing_Normal_Videos_Anomaly/Normal_Videos_934_x264.mp4
277 | Testing_Normal_Videos_Anomaly/Normal_Videos_935_x264.mp4
278 | Testing_Normal_Videos_Anomaly/Normal_Videos_936_x264.mp4
279 | Testing_Normal_Videos_Anomaly/Normal_Videos_937_x264.mp4
280 | Testing_Normal_Videos_Anomaly/Normal_Videos_938_x264.mp4
281 | Testing_Normal_Videos_Anomaly/Normal_Videos_939_x264.mp4
282 | Testing_Normal_Videos_Anomaly/Normal_Videos_940_x264.mp4
283 | Testing_Normal_Videos_Anomaly/Normal_Videos_941_x264.mp4
284 | Testing_Normal_Videos_Anomaly/Normal_Videos_943_x264.mp4
285 | Testing_Normal_Videos_Anomaly/Normal_Videos_944_x264.mp4
286 | Vandalism/Vandalism007_x264.mp4
287 | Vandalism/Vandalism015_x264.mp4
288 | Vandalism/Vandalism017_x264.mp4
289 | Vandalism/Vandalism028_x264.mp4
290 | Vandalism/Vandalism036_x264.mp4
291 |
--------------------------------------------------------------------------------
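The block above (list/Anomaly_Test.txt) enumerates the UCF-Crime test videos as Category/VideoName_x264.mp4 paths, so each test video's event class can be recovered from its directory prefix, with Testing_Normal_Videos_Anomaly mapping to Normal. A minimal parsing sketch under that assumption (the helper name load_test_split is illustrative, not part of this repo):

def load_test_split(path='list/Anomaly_Test.txt'):
    """Return (video_id, class_name) pairs from the UCF-Crime test list."""
    items = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            folder, filename = line.split('/')
            video_id = filename[:-len('.mp4')]  # e.g. 'Vandalism007_x264'
            label = 'Normal' if folder.startswith('Testing_Normal') else folder
            items.append((video_id, label))
    return items
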
/list/Temporal_Anomaly_Annotation.txt:
--------------------------------------------------------------------------------
1 | Abuse028_x264 Abuse 165 240 -1 -1
2 | Abuse030_x264 Abuse 1275 1360 -1 -1
3 | Arrest001_x264 Arrest 1185 1485 -1 -1
4 | Arrest007_x264 Arrest 1530 2160 -1 -1
5 | Arrest024_x264 Arrest 1005 3105 -1 -1
6 | Arrest030_x264 Arrest 5535 7200 -1 -1
7 | Arrest039_x264 Arrest 7215 10335 -1 -1
8 | Arson007_x264 Arson 2250 5700 -1 -1
9 | Arson009_x264 Arson 220 315 -1 -1
10 | Arson010_x264 Arson 885 1230 -1 -1
11 | Arson011_x264 Arson 150 420 680 1267
12 | Arson016_x264 Arson 1000 1796 -1 -1
13 | Arson018_x264 Arson 270 600 -1 -1
14 | Arson022_x264 Arson 3500 4000 -1 -1
15 | Arson035_x264 Arson 600 900 -1 -1
16 | Arson041_x264 Arson 2130 3615 -1 -1
17 | Assault006_x264 Assault 1185 8096 -1 -1
18 | Assault010_x264 Assault 11330 11680 12260 12930
19 | Assault011_x264 Assault 375 960 -1 -1
20 | Burglary005_x264 Burglary 4710 5040 -1 -1
21 | Burglary017_x264 Burglary 150 600 -1 -1
22 | Burglary018_x264 Burglary 720 1050 -1 -1
23 | Burglary021_x264 Burglary 60 200 840 1340
24 | Burglary024_x264 Burglary 60 1230 -1 -1
25 | Burglary032_x264 Burglary 1290 3690 -1 -1
26 | Burglary033_x264 Burglary 60 330 -1 -1
27 | Burglary035_x264 Burglary 1 1740 -1 -1
28 | Burglary037_x264 Burglary 240 390 540 1800
29 | Burglary061_x264 Burglary 4200 5700 -1 -1
30 | Burglary076_x264 Burglary 1590 4300 -1 -1
31 | Burglary079_x264 Burglary 7750 10710 -1 -1
32 | Burglary092_x264 Burglary 240 420 -1 -1
33 | Explosion002_x264 Explosion 1500 2100 -1 -1
34 | Explosion004_x264 Explosion 75 225 -1 -1
35 | Explosion007_x264 Explosion 1590 2280 -1 -1
36 | Explosion008_x264 Explosion 1005 1245 -1 -1
37 | Explosion010_x264 Explosion 285 1080 -1 -1
38 | Explosion011_x264 Explosion 795 945 -1 -1
39 | Explosion013_x264 Explosion 2520 2970 -1 -1
40 | Explosion016_x264 Explosion 180 450 -1 -1
41 | Explosion017_x264 Explosion 990 1440 -1 -1
42 | Explosion020_x264 Explosion 60 270 -1 -1
43 | Explosion021_x264 Explosion 135 270 -1 -1
44 | Explosion022_x264 Explosion 2230 2420 -1 -1
45 | Explosion025_x264 Explosion 260 420 -1 -1
46 | Explosion027_x264 Explosion 105 180 -1 -1
47 | Explosion028_x264 Explosion 280 700 -1 -1
48 | Explosion029_x264 Explosion 1830 2020 -1 -1
49 | Explosion033_x264 Explosion 970 1350 1550 3156
50 | Explosion035_x264 Explosion 250 350 -1 -1
51 | Explosion036_x264 Explosion 1950 2070 -1 -1
52 | Explosion039_x264 Explosion 60 150 -1 -1
53 | Explosion043_x264 Explosion 4460 4600 -1 -1
54 | Fighting003_x264 Fighting 1820 3103 -1 -1
55 | Fighting018_x264 Fighting 80 420 -1 -1
56 | Fighting033_x264 Fighting 570 840 -1 -1
57 | Fighting042_x264 Fighting 290 1200 -1 -1
58 | Fighting047_x264 Fighting 200 1830 -1 -1
59 | Normal_Videos_003_x264 Normal -1 -1 -1 -1
60 | Normal_Videos_006_x264 Normal -1 -1 -1 -1
61 | Normal_Videos_010_x264 Normal -1 -1 -1 -1
62 | Normal_Videos_014_x264 Normal -1 -1 -1 -1
63 | Normal_Videos_015_x264 Normal -1 -1 -1 -1
64 | Normal_Videos_018_x264 Normal -1 -1 -1 -1
65 | Normal_Videos_019_x264 Normal -1 -1 -1 -1
66 | Normal_Videos_024_x264 Normal -1 -1 -1 -1
67 | Normal_Videos_025_x264 Normal -1 -1 -1 -1
68 | Normal_Videos_027_x264 Normal -1 -1 -1 -1
69 | Normal_Videos_033_x264 Normal -1 -1 -1 -1
70 | Normal_Videos_034_x264 Normal -1 -1 -1 -1
71 | Normal_Videos_041_x264 Normal -1 -1 -1 -1
72 | Normal_Videos_042_x264 Normal -1 -1 -1 -1
73 | Normal_Videos_048_x264 Normal -1 -1 -1 -1
74 | Normal_Videos_050_x264 Normal -1 -1 -1 -1
75 | Normal_Videos_051_x264 Normal -1 -1 -1 -1
76 | Normal_Videos_056_x264 Normal -1 -1 -1 -1
77 | Normal_Videos_059_x264 Normal -1 -1 -1 -1
78 | Normal_Videos_063_x264 Normal -1 -1 -1 -1
79 | Normal_Videos_067_x264 Normal -1 -1 -1 -1
80 | Normal_Videos_070_x264 Normal -1 -1 -1 -1
81 | Normal_Videos_100_x264 Normal -1 -1 -1 -1
82 | Normal_Videos_129_x264 Normal -1 -1 -1 -1
83 | Normal_Videos_150_x264 Normal -1 -1 -1 -1
84 | Normal_Videos_168_x264 Normal -1 -1 -1 -1
85 | Normal_Videos_175_x264 Normal -1 -1 -1 -1
86 | Normal_Videos_182_x264 Normal -1 -1 -1 -1
87 | Normal_Videos_189_x264 Normal -1 -1 -1 -1
88 | Normal_Videos_196_x264 Normal -1 -1 -1 -1
89 | Normal_Videos_203_x264 Normal -1 -1 -1 -1
90 | Normal_Videos_210_x264 Normal -1 -1 -1 -1
91 | Normal_Videos_217_x264 Normal -1 -1 -1 -1
92 | Normal_Videos_224_x264 Normal -1 -1 -1 -1
93 | Normal_Videos_246_x264 Normal -1 -1 -1 -1
94 | Normal_Videos_247_x264 Normal -1 -1 -1 -1
95 | Normal_Videos_248_x264 Normal -1 -1 -1 -1
96 | Normal_Videos_251_x264 Normal -1 -1 -1 -1
97 | Normal_Videos_289_x264 Normal -1 -1 -1 -1
98 | Normal_Videos_310_x264 Normal -1 -1 -1 -1
99 | Normal_Videos_312_x264 Normal -1 -1 -1 -1
100 | Normal_Videos_317_x264 Normal -1 -1 -1 -1
101 | Normal_Videos_345_x264 Normal -1 -1 -1 -1
102 | Normal_Videos_352_x264 Normal -1 -1 -1 -1
103 | Normal_Videos_360_x264 Normal -1 -1 -1 -1
104 | Normal_Videos_365_x264 Normal -1 -1 -1 -1
105 | Normal_Videos_401_x264 Normal -1 -1 -1 -1
106 | Normal_Videos_417_x264 Normal -1 -1 -1 -1
107 | Normal_Videos_439_x264 Normal -1 -1 -1 -1
108 | Normal_Videos_452_x264 Normal -1 -1 -1 -1
109 | Normal_Videos_453_x264 Normal -1 -1 -1 -1
110 | Normal_Videos_478_x264 Normal -1 -1 -1 -1
111 | Normal_Videos_576_x264 Normal -1 -1 -1 -1
112 | Normal_Videos_597_x264 Normal -1 -1 -1 -1
113 | Normal_Videos_603_x264 Normal -1 -1 -1 -1
114 | Normal_Videos_606_x264 Normal -1 -1 -1 -1
115 | Normal_Videos_621_x264 Normal -1 -1 -1 -1
116 | Normal_Videos_634_x264 Normal -1 -1 -1 -1
117 | Normal_Videos_641_x264 Normal -1 -1 -1 -1
118 | Normal_Videos_656_x264 Normal -1 -1 -1 -1
119 | Normal_Videos_686_x264 Normal -1 -1 -1 -1
120 | Normal_Videos_696_x264 Normal -1 -1 -1 -1
121 | Normal_Videos_702_x264 Normal -1 -1 -1 -1
122 | Normal_Videos_704_x264 Normal -1 -1 -1 -1
123 | Normal_Videos_710_x264 Normal -1 -1 -1 -1
124 | Normal_Videos_717_x264 Normal -1 -1 -1 -1
125 | Normal_Videos_722_x264 Normal -1 -1 -1 -1
126 | Normal_Videos_725_x264 Normal -1 -1 -1 -1
127 | Normal_Videos_745_x264 Normal -1 -1 -1 -1
128 | Normal_Videos_758_x264 Normal -1 -1 -1 -1
129 | Normal_Videos_778_x264 Normal -1 -1 -1 -1
130 | Normal_Videos_780_x264 Normal -1 -1 -1 -1
131 | Normal_Videos_781_x264 Normal -1 -1 -1 -1
132 | Normal_Videos_782_x264 Normal -1 -1 -1 -1
133 | Normal_Videos_783_x264 Normal -1 -1 -1 -1
134 | Normal_Videos_798_x264 Normal -1 -1 -1 -1
135 | Normal_Videos_801_x264 Normal -1 -1 -1 -1
136 | Normal_Videos_828_x264 Normal -1 -1 -1 -1
137 | Normal_Videos_831_x264 Normal -1 -1 -1 -1
138 | Normal_Videos_866_x264 Normal -1 -1 -1 -1
139 | Normal_Videos_867_x264 Normal -1 -1 -1 -1
140 | Normal_Videos_868_x264 Normal -1 -1 -1 -1
141 | Normal_Videos_869_x264 Normal -1 -1 -1 -1
142 | Normal_Videos_870_x264 Normal -1 -1 -1 -1
143 | Normal_Videos_871_x264 Normal -1 -1 -1 -1
144 | Normal_Videos_872_x264 Normal -1 -1 -1 -1
145 | Normal_Videos_873_x264 Normal -1 -1 -1 -1
146 | Normal_Videos_874_x264 Normal -1 -1 -1 -1
147 | Normal_Videos_875_x264 Normal -1 -1 -1 -1
148 | Normal_Videos_876_x264 Normal -1 -1 -1 -1
149 | Normal_Videos_877_x264 Normal -1 -1 -1 -1
150 | Normal_Videos_878_x264 Normal -1 -1 -1 -1
151 | Normal_Videos_879_x264 Normal -1 -1 -1 -1
152 | Normal_Videos_880_x264 Normal -1 -1 -1 -1
153 | Normal_Videos_881_x264 Normal -1 -1 -1 -1
154 | Normal_Videos_882_x264 Normal -1 -1 -1 -1
155 | Normal_Videos_883_x264 Normal -1 -1 -1 -1
156 | Normal_Videos_884_x264 Normal -1 -1 -1 -1
157 | Normal_Videos_885_x264 Normal -1 -1 -1 -1
158 | Normal_Videos_886_x264 Normal -1 -1 -1 -1
159 | Normal_Videos_887_x264 Normal -1 -1 -1 -1
160 | Normal_Videos_888_x264 Normal -1 -1 -1 -1
161 | Normal_Videos_889_x264 Normal -1 -1 -1 -1
162 | Normal_Videos_890_x264 Normal -1 -1 -1 -1
163 | Normal_Videos_891_x264 Normal -1 -1 -1 -1
164 | Normal_Videos_892_x264 Normal -1 -1 -1 -1
165 | Normal_Videos_893_x264 Normal -1 -1 -1 -1
166 | Normal_Videos_894_x264 Normal -1 -1 -1 -1
167 | Normal_Videos_895_x264 Normal -1 -1 -1 -1
168 | Normal_Videos_896_x264 Normal -1 -1 -1 -1
169 | Normal_Videos_897_x264 Normal -1 -1 -1 -1
170 | Normal_Videos_898_x264 Normal -1 -1 -1 -1
171 | Normal_Videos_899_x264 Normal -1 -1 -1 -1
172 | Normal_Videos_900_x264 Normal -1 -1 -1 -1
173 | Normal_Videos_901_x264 Normal -1 -1 -1 -1
174 | Normal_Videos_902_x264 Normal -1 -1 -1 -1
175 | Normal_Videos_903_x264 Normal -1 -1 -1 -1
176 | Normal_Videos_904_x264 Normal -1 -1 -1 -1
177 | Normal_Videos_905_x264 Normal -1 -1 -1 -1
178 | Normal_Videos_906_x264 Normal -1 -1 -1 -1
179 | Normal_Videos_907_x264 Normal -1 -1 -1 -1
180 | Normal_Videos_908_x264 Normal -1 -1 -1 -1
181 | Normal_Videos_909_x264 Normal -1 -1 -1 -1
182 | Normal_Videos_910_x264 Normal -1 -1 -1 -1
183 | Normal_Videos_911_x264 Normal -1 -1 -1 -1
184 | Normal_Videos_912_x264 Normal -1 -1 -1 -1
185 | Normal_Videos_913_x264 Normal -1 -1 -1 -1
186 | Normal_Videos_914_x264 Normal -1 -1 -1 -1
187 | Normal_Videos_915_x264 Normal -1 -1 -1 -1
188 | Normal_Videos_923_x264 Normal -1 -1 -1 -1
189 | Normal_Videos_924_x264 Normal -1 -1 -1 -1
190 | Normal_Videos_925_x264 Normal -1 -1 -1 -1
191 | Normal_Videos_926_x264 Normal -1 -1 -1 -1
192 | Normal_Videos_927_x264 Normal -1 -1 -1 -1
193 | Normal_Videos_928_x264 Normal -1 -1 -1 -1
194 | Normal_Videos_929_x264 Normal -1 -1 -1 -1
195 | Normal_Videos_930_x264 Normal -1 -1 -1 -1
196 | Normal_Videos_931_x264 Normal -1 -1 -1 -1
197 | Normal_Videos_932_x264 Normal -1 -1 -1 -1
198 | Normal_Videos_933_x264 Normal -1 -1 -1 -1
199 | Normal_Videos_934_x264 Normal -1 -1 -1 -1
200 | Normal_Videos_935_x264 Normal -1 -1 -1 -1
201 | Normal_Videos_936_x264 Normal -1 -1 -1 -1
202 | Normal_Videos_937_x264 Normal -1 -1 -1 -1
203 | Normal_Videos_938_x264 Normal -1 -1 -1 -1
204 | Normal_Videos_939_x264 Normal -1 -1 -1 -1
205 | Normal_Videos_940_x264 Normal -1 -1 -1 -1
206 | Normal_Videos_941_x264 Normal -1 -1 -1 -1
207 | Normal_Videos_943_x264 Normal -1 -1 -1 -1
208 | Normal_Videos_944_x264 Normal -1 -1 -1 -1
209 | RoadAccidents001_x264 RoadAccidents 210 300 -1 -1
210 | RoadAccidents002_x264 RoadAccidents 240 300 -1 -1
211 | RoadAccidents004_x264 RoadAccidents 140 189 -1 -1
212 | RoadAccidents009_x264 RoadAccidents 210 240 -1 -1
213 | RoadAccidents010_x264 RoadAccidents 230 270 -1 -1
214 | RoadAccidents011_x264 RoadAccidents 260 300 -1 -1
215 | RoadAccidents012_x264 RoadAccidents 250 390 -1 -1
216 | RoadAccidents016_x264 RoadAccidents 530 720 -1 -1
217 | RoadAccidents017_x264 RoadAccidents 60 130 -1 -1
218 | RoadAccidents019_x264 RoadAccidents 750 900 -1 -1
219 | RoadAccidents020_x264 RoadAccidents 610 730 -1 -1
220 | RoadAccidents021_x264 RoadAccidents 30 90 -1 -1
221 | RoadAccidents022_x264 RoadAccidents 120 220 -1 -1
222 | RoadAccidents121_x264 RoadAccidents 330 390 -1 -1
223 | RoadAccidents122_x264 RoadAccidents 300 360 -1 -1
224 | RoadAccidents123_x264 RoadAccidents 130 210 -1 -1
225 | RoadAccidents124_x264 RoadAccidents 250 420 -1 -1
226 | RoadAccidents125_x264 RoadAccidents 490 600 -1 -1
227 | RoadAccidents127_x264 RoadAccidents 2160 2300 -1 -1
228 | RoadAccidents128_x264 RoadAccidents 90 200 -1 -1
229 | RoadAccidents131_x264 RoadAccidents 180 240 -1 -1
230 | RoadAccidents132_x264 RoadAccidents 220 320 -1 -1
231 | RoadAccidents133_x264 RoadAccidents 270 450 -1 -1
232 | Robbery048_x264 Robbery 450 930 -1 -1
233 | Robbery050_x264 Robbery 495 1410 -1 -1
234 | Robbery102_x264 Robbery 1080 1560 -1 -1
235 | Robbery106_x264 Robbery 480 600 -1 -1
236 | Robbery137_x264 Robbery 135 1950 -1 -1
237 | Shooting002_x264 Shooting 1020 1100 -1 -1
238 | Shooting004_x264 Shooting 500 660 -1 -1
239 | Shooting007_x264 Shooting 45 165 -1 -1
240 | Shooting008_x264 Shooting 75 315 -1 -1
241 | Shooting010_x264 Shooting 1095 1260 -1 -1
242 | Shooting011_x264 Shooting 1480 1750 -1 -1
243 | Shooting013_x264 Shooting 860 945 -1 -1
244 | Shooting015_x264 Shooting 855 1715 -1 -1
245 | Shooting018_x264 Shooting 315 480 -1 -1
246 | Shooting019_x264 Shooting 1020 1455 -1 -1
247 | Shooting021_x264 Shooting 480 630 -1 -1
248 | Shooting022_x264 Shooting 2850 3300 -1 -1
249 | Shooting024_x264 Shooting 720 1305 -1 -1
250 | Shooting026_x264 Shooting 195 600 -1 -1
251 | Shooting028_x264 Shooting 285 555 -1 -1
252 | Shooting032_x264 Shooting 7995 8205 -1 -1
253 | Shooting033_x264 Shooting 1680 2000 -1 -1
254 | Shooting034_x264 Shooting 960 1050 -1 -1
255 | Shooting037_x264 Shooting 140 260 -1 -1
256 | Shooting043_x264 Shooting 945 1230 -1 -1
257 | Shooting046_x264 Shooting 4005 4230 4760 5088
258 | Shooting047_x264 Shooting 2160 3900 4860 6600
259 | Shooting048_x264 Shooting 1410 1730 -1 -1
260 | Shoplifting001_x264 Shoplifting 1550 2000 -1 -1
261 | Shoplifting004_x264 Shoplifting 2200 4900 -1 -1
262 | Shoplifting005_x264 Shoplifting 720 930 -1 -1
263 | Shoplifting007_x264 Shoplifting 550 760 4630 4920
264 | Shoplifting010_x264 Shoplifting 750 920 1550 1970
265 | Shoplifting015_x264 Shoplifting 2010 2160 -1 -1
266 | Shoplifting016_x264 Shoplifting 630 720 -1 -1
267 | Shoplifting017_x264 Shoplifting 360 420 -1 -1
268 | Shoplifting020_x264 Shoplifting 2340 2460 -1 -1
269 | Shoplifting021_x264 Shoplifting 2070 2220 -1 -1
270 | Shoplifting022_x264 Shoplifting 270 420 1440 1560
271 | Shoplifting027_x264 Shoplifting 1080 1160 1470 1710
272 | Shoplifting028_x264 Shoplifting 570 840 -1 -1
273 | Shoplifting029_x264 Shoplifting 1020 1470 -1 -1
274 | Shoplifting031_x264 Shoplifting 120 330 -1 -1
275 | Shoplifting033_x264 Shoplifting 630 750 -1 -1
276 | Shoplifting034_x264 Shoplifting 7350 7470 -1 -1
277 | Shoplifting037_x264 Shoplifting 1140 1200 -1 -1
278 | Shoplifting039_x264 Shoplifting 2190 2340 -1 -1
279 | Shoplifting044_x264 Shoplifting 11070 11250 -1 -1
280 | Shoplifting049_x264 Shoplifting 1020 1350 -1 -1
281 | Stealing019_x264 Stealing 2730 2790 4170 4350
282 | Stealing036_x264 Stealing 1260 1590 -1 -1
283 | Stealing058_x264 Stealing 570 3660 -1 -1
284 | Stealing062_x264 Stealing 360 1050 -1 -1
285 | Stealing079_x264 Stealing 2550 3210 3510 4500
286 | Vandalism007_x264 Vandalism 240 750 -1 -1
287 | Vandalism015_x264 Vandalism 2010 2700 -1 -1
288 | Vandalism017_x264 Vandalism 270 330 780 840
289 | Vandalism028_x264 Vandalism 1830 1980 2400 2670
290 | Vandalism036_x264 Vandalism 540 780 990 1080
291 |
--------------------------------------------------------------------------------
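Each row of list/Temporal_Anomaly_Annotation.txt above uses the standard UCF-Crime temporal annotation format: video name, event class, then two candidate anomalous frame intervals (start1 end1 start2 end2), with -1 marking an unused bound; normal videos therefore carry four -1 fields. A minimal sketch that converts a row into a list of (start, end) frame intervals, assuming this whitespace-separated six-column layout (the function name is illustrative):

def parse_annotation_line(line):
    """Parse one annotation row into (video_id, label, [(start, end), ...])."""
    fields = line.split()
    video_id, label = fields[0], fields[1]
    bounds = list(map(int, fields[2:6]))
    # -1 is the sentinel for "no interval"; keep only the real ones
    intervals = [(s, e) for s, e in zip(bounds[::2], bounds[1::2]) if s != -1]
    return video_id, label, intervals

For example, parse_annotation_line('Arson011_x264 Arson 150 420 680 1267') yields ('Arson011_x264', 'Arson', [(150, 420), (680, 1267)]).
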
/src/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.relu1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.relu2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.relu3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.relu1(self.bn1(self.conv1(x)))
46 | out = self.relu2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x[:1], key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 | return x.squeeze(0)
92 |
93 |
94 | class ModifiedResNet(nn.Module):
95 | """
96 | A ResNet class that is similar to torchvision's but contains the following changes:
97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99 | - The final pooling layer is a QKV attention instead of an average pool
100 | """
101 |
102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103 | super().__init__()
104 | self.output_dim = output_dim
105 | self.input_resolution = input_resolution
106 |
107 | # the 3-layer stem
108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109 | self.bn1 = nn.BatchNorm2d(width // 2)
110 | self.relu1 = nn.ReLU(inplace=True)
111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112 | self.bn2 = nn.BatchNorm2d(width // 2)
113 | self.relu2 = nn.ReLU(inplace=True)
114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115 | self.bn3 = nn.BatchNorm2d(width)
116 | self.relu3 = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(2)
118 |
119 | # residual layers
120 | self._inplanes = width # this is a *mutable* variable used during construction
121 | self.layer1 = self._make_layer(width, layers[0])
122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125 |
126 | embed_dim = width * 32 # the ResNet feature dimension
127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128 |
129 | def _make_layer(self, planes, blocks, stride=1):
130 | layers = [Bottleneck(self._inplanes, planes, stride)]
131 |
132 | self._inplanes = planes * Bottleneck.expansion
133 | for _ in range(1, blocks):
134 | layers.append(Bottleneck(self._inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | def stem(x):
140 | x = self.relu1(self.bn1(self.conv1(x)))
141 | x = self.relu2(self.bn2(self.conv2(x)))
142 | x = self.relu3(self.bn3(self.conv3(x)))
143 | x = self.avgpool(x)
144 | return x
145 |
146 | x = x.type(self.conv1.weight.dtype)
147 | x = stem(x)
148 | x = self.layer1(x)
149 | x = self.layer2(x)
150 | x = self.layer3(x)
151 | x = self.layer4(x)
152 | x = self.attnpool(x)
153 |
154 | return x
155 |
156 |
157 | class LayerNorm(nn.LayerNorm):
158 | """Subclass torch's LayerNorm to handle fp16."""
159 |
160 | def forward(self, x: torch.Tensor):
161 | orig_type = x.dtype
162 | ret = super().forward(x.type(torch.float32))
163 | return ret.type(orig_type)
164 |
165 |
166 | class QuickGELU(nn.Module):
167 | def forward(self, x: torch.Tensor):
168 | return x * torch.sigmoid(1.702 * x)
169 |
170 |
171 | class ResidualAttentionBlock(nn.Module):
172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173 | super().__init__()
174 |
175 | self.attn = nn.MultiheadAttention(d_model, n_head)
176 | self.ln_1 = LayerNorm(d_model)
177 | self.mlp = nn.Sequential(OrderedDict([
178 | ("c_fc", nn.Linear(d_model, d_model * 4)),
179 | ("gelu", QuickGELU()),
180 | ("c_proj", nn.Linear(d_model * 4, d_model))
181 | ]))
182 | self.ln_2 = LayerNorm(d_model)
183 | self.attn_mask = attn_mask
184 |
185 | def attention(self, x: torch.Tensor):
186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188 |
189 | def forward(self, x: torch.Tensor):
190 | x = x + self.attention(self.ln_1(x))
191 | x = x + self.mlp(self.ln_2(x))
192 | return x
193 |
194 |
195 | class Transformer(nn.Module):
196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197 | super().__init__()
198 | self.width = width
199 | self.layers = layers
200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201 |
202 | def forward(self, x: torch.Tensor):
203 | return self.resblocks(x)
204 |
205 |
206 | class VisionTransformer(nn.Module):
207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
208 | super().__init__()
209 | self.input_resolution = input_resolution
210 | self.output_dim = output_dim
211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212 |
213 | scale = width ** -0.5
214 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216 | self.ln_pre = LayerNorm(width)
217 |
218 | self.transformer = Transformer(width, layers, heads)
219 |
220 | self.ln_post = LayerNorm(width)
221 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
222 |
223 | def forward(self, x: torch.Tensor):
224 | x = self.conv1(x) # shape = [*, width, grid, grid]
225 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
226 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
227 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
228 | x = x + self.positional_embedding.to(x.dtype)
229 | x = self.ln_pre(x)
230 |
231 | x = x.permute(1, 0, 2) # NLD -> LND
232 | x = self.transformer(x)
233 | x = x.permute(1, 0, 2) # LND -> NLD
234 |
235 | x = self.ln_post(x[:, 0, :])
236 |
237 | if self.proj is not None:
238 | x = x @ self.proj
239 |
240 | return x
241 |
242 |
243 | class CLIP(nn.Module):
244 | def __init__(self,
245 | embed_dim: int,
246 | # vision
247 | image_resolution: int,
248 | vision_layers: Union[Tuple[int, int, int, int], int],
249 | vision_width: int,
250 | vision_patch_size: int,
251 | # text
252 | context_length: int,
253 | vocab_size: int,
254 | transformer_width: int,
255 | transformer_heads: int,
256 | transformer_layers: int
257 | ):
258 | super().__init__()
259 |
260 | self.context_length = context_length
261 |
262 | if isinstance(vision_layers, (tuple, list)):
263 | vision_heads = vision_width * 32 // 64
264 | self.visual = ModifiedResNet(
265 | layers=vision_layers,
266 | output_dim=embed_dim,
267 | heads=vision_heads,
268 | input_resolution=image_resolution,
269 | width=vision_width
270 | )
271 | else:
272 | vision_heads = vision_width // 64
273 | self.visual = VisionTransformer(
274 | input_resolution=image_resolution,
275 | patch_size=vision_patch_size,
276 | width=vision_width,
277 | layers=vision_layers,
278 | heads=vision_heads,
279 | output_dim=embed_dim
280 | )
281 |
282 | self.transformer = Transformer(
283 | width=transformer_width,
284 | layers=transformer_layers,
285 | heads=transformer_heads,
286 | attn_mask=self.build_attention_mask()
287 | )
288 |
289 | self.vocab_size = vocab_size
290 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
291 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
292 | self.ln_final = LayerNorm(transformer_width)
293 |
294 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
295 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
296 |
297 | self.initialize_parameters()
298 |
299 | def initialize_parameters(self):
300 | nn.init.normal_(self.token_embedding.weight, std=0.02)
301 | nn.init.normal_(self.positional_embedding, std=0.01)
302 |
303 | if isinstance(self.visual, ModifiedResNet):
304 | if self.visual.attnpool is not None:
305 | std = self.visual.attnpool.c_proj.in_features ** -0.5
306 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
307 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
308 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
309 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
310 |
311 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
312 | for name, param in resnet_block.named_parameters():
313 | if name.endswith("bn3.weight"):
314 | nn.init.zeros_(param)
315 |
316 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
317 | attn_std = self.transformer.width ** -0.5
318 | fc_std = (2 * self.transformer.width) ** -0.5
319 | for block in self.transformer.resblocks:
320 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
321 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
322 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
323 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
324 |
325 | if self.text_projection is not None:
326 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
327 |
328 | def build_attention_mask(self):
329 |         # lazily create the causal attention mask for the text transformer
330 |         # pytorch uses an additive attention mask; fill with -inf
331 |         mask = torch.empty(self.context_length, self.context_length)
332 |         mask.fill_(float("-inf"))
333 |         mask.triu_(1)  # zero out the diagonal and the lower triangle
334 | return mask
335 |
336 | @property
337 | def dtype(self):
338 | return self.visual.conv1.weight.dtype
339 |
340 | def encode_image(self, image):
341 | return self.visual(image.type(self.dtype))
342 |
343 | def encode_token(self, token):
344 | x = self.token_embedding(token)
345 | return x
346 |
347 |     def encode_text(self, text, token):
348 |         # `text` already holds token embeddings (see encode_token); stock CLIP would embed here: x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]
349 |         x = text.type(self.dtype) + self.positional_embedding.type(self.dtype)
350 | x = x.permute(1, 0, 2) # NLD -> LND
351 | x = self.transformer(x)
352 | x = x.permute(1, 0, 2) # LND -> NLD
353 | x = self.ln_final(x).type(self.dtype)
354 |
355 | # x.shape = [batch_size, n_ctx, transformer.width]
356 | # take features from the eot embedding (eot_token is the highest number in each sequence)
357 | x = x[torch.arange(x.shape[0]), token.argmax(dim=-1)] @ self.text_projection
358 |
359 | return x
360 |
361 |     def forward(self, image, text, token):
362 |         image_features = self.encode_image(image)
363 |         text_features = self.encode_text(text, token)  # token ids are needed to locate the EOT position
364 |
365 | # normalized features
366 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
367 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
368 |
369 | # cosine similarity as logits
370 | logit_scale = self.logit_scale.exp()
371 | logits_per_image = logit_scale * image_features @ text_features.t()
372 | logits_per_text = logits_per_image.t()
373 |
374 | # shape = [global_batch_size, global_batch_size]
375 | return logits_per_image, logits_per_text
376 |
377 |
378 | def convert_weights(model: nn.Module):
379 | """Convert applicable model parameters to fp16"""
380 |
381 | def _convert_weights_to_fp16(l):
382 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
383 | l.weight.data = l.weight.data.half()
384 | if l.bias is not None:
385 | l.bias.data = l.bias.data.half()
386 |
387 | if isinstance(l, nn.MultiheadAttention):
388 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
389 | tensor = getattr(l, attr)
390 | if tensor is not None:
391 | tensor.data = tensor.data.half()
392 |
393 | for name in ["text_projection", "proj"]:
394 | if hasattr(l, name):
395 | attr = getattr(l, name)
396 | if attr is not None:
397 | attr.data = attr.data.half()
398 |
399 | model.apply(_convert_weights_to_fp16)
400 |
401 |
402 | def build_model(state_dict: dict):
403 | vit = "visual.proj" in state_dict
404 |
405 | if vit:
406 | vision_width = state_dict["visual.conv1.weight"].shape[0]
407 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
408 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
409 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
410 | image_resolution = vision_patch_size * grid_size
411 | else:
412 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
413 | vision_layers = tuple(counts)
414 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
415 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
416 | vision_patch_size = None
417 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
418 | image_resolution = output_width * 32
419 |
420 | embed_dim = state_dict["text_projection"].shape[1]
421 | context_length = state_dict["positional_embedding"].shape[0]
422 | vocab_size = state_dict["token_embedding.weight"].shape[0]
423 | transformer_width = state_dict["ln_final.weight"].shape[0]
424 | transformer_heads = transformer_width // 64
425 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
426 |
427 | model = CLIP(
428 | embed_dim,
429 | image_resolution, vision_layers, vision_width, vision_patch_size,
430 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
431 | )
432 |
433 | for key in ["input_resolution", "context_length", "vocab_size"]:
434 | if key in state_dict:
435 | del state_dict[key]
436 |
437 |     # convert_weights(model)  # fp16 conversion left disabled, so the model stays in fp32
438 | model.load_state_dict(state_dict)
439 | return model.eval()
440 |
--------------------------------------------------------------------------------
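One thing worth flagging in the model.py listing above: this copy deviates from stock OpenAI CLIP in that encode_text takes pre-computed token embeddings together with the raw token ids (the ids are only used to locate the EOT position via argmax), while encode_token exposes the embedding lookup separately. Splitting the two steps makes it possible to modify the embeddings (e.g. to insert learnable prompt vectors) before they enter the text transformer. A minimal sketch of the resulting call pattern, assuming model is the CLIP instance returned by build_model and that tokenize comes from src/clip/clip.py:

import torch
from src.clip import tokenize  # tokenize lives in src/clip/clip.py

tokens = tokenize(['a video of a road accident'])       # [1, 77] token ids
with torch.no_grad():
    embeds = model.encode_token(tokens)                 # [1, 77, transformer_width]
    # prompt-learning code would edit `embeds` here, before the transformer
    text_features = model.encode_text(embeds, tokens)   # [1, embed_dim]
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
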
/list/ucf_CLIP_rgbtest.csv:
--------------------------------------------------------------------------------
1 | path,label
2 | /home/xbgydx/Desktop/UCFClipFeatures/Abuse/Abuse028_x264__5.npy,Abuse
3 | /home/xbgydx/Desktop/UCFClipFeatures/Abuse/Abuse030_x264__5.npy,Abuse
4 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest001_x264__5.npy,Arrest
5 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest007_x264__5.npy,Arrest
6 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest024_x264__5.npy,Arrest
7 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest030_x264__5.npy,Arrest
8 | /home/xbgydx/Desktop/UCFClipFeatures/Arrest/Arrest039_x264__5.npy,Arrest
9 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson007_x264__5.npy,Arson
10 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson009_x264__5.npy,Arson
11 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson010_x264__5.npy,Arson
12 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson011_x264__5.npy,Arson
13 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson016_x264__5.npy,Arson
14 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson018_x264__5.npy,Arson
15 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson022_x264__5.npy,Arson
16 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson035_x264__5.npy,Arson
17 | /home/xbgydx/Desktop/UCFClipFeatures/Arson/Arson041_x264__5.npy,Arson
18 | /home/xbgydx/Desktop/UCFClipFeatures/Assault/Assault006_x264__5.npy,Assault
19 | /home/xbgydx/Desktop/UCFClipFeatures/Assault/Assault010_x264__5.npy,Assault
20 | /home/xbgydx/Desktop/UCFClipFeatures/Assault/Assault011_x264__5.npy,Assault
21 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary005_x264__5.npy,Burglary
22 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary017_x264__5.npy,Burglary
23 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary018_x264__5.npy,Burglary
24 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary021_x264__5.npy,Burglary
25 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary024_x264__5.npy,Burglary
26 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary032_x264__5.npy,Burglary
27 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary033_x264__5.npy,Burglary
28 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary035_x264__5.npy,Burglary
29 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary037_x264__5.npy,Burglary
30 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary061_x264__5.npy,Burglary
31 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary076_x264__5.npy,Burglary
32 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary079_x264__5.npy,Burglary
33 | /home/xbgydx/Desktop/UCFClipFeatures/Burglary/Burglary092_x264__5.npy,Burglary
34 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion002_x264__5.npy,Explosion
35 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion004_x264__5.npy,Explosion
36 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion007_x264__5.npy,Explosion
37 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion008_x264__5.npy,Explosion
38 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion010_x264__5.npy,Explosion
39 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion011_x264__5.npy,Explosion
40 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion013_x264__5.npy,Explosion
41 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion016_x264__5.npy,Explosion
42 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion017_x264__5.npy,Explosion
43 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion020_x264__5.npy,Explosion
44 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion021_x264__5.npy,Explosion
45 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion022_x264__5.npy,Explosion
46 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion025_x264__5.npy,Explosion
47 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion027_x264__5.npy,Explosion
48 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion028_x264__5.npy,Explosion
49 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion029_x264__5.npy,Explosion
50 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion033_x264__5.npy,Explosion
51 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion035_x264__5.npy,Explosion
52 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion036_x264__5.npy,Explosion
53 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion039_x264__5.npy,Explosion
54 | /home/xbgydx/Desktop/UCFClipFeatures/Explosion/Explosion043_x264__5.npy,Explosion
55 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting003_x264__5.npy,Fighting
56 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting018_x264__5.npy,Fighting
57 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting033_x264__5.npy,Fighting
58 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting042_x264__5.npy,Fighting
59 | /home/xbgydx/Desktop/UCFClipFeatures/Fighting/Fighting047_x264__5.npy,Fighting
60 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents001_x264__5.npy,RoadAccidents
61 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents002_x264__5.npy,RoadAccidents
62 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents004_x264__5.npy,RoadAccidents
63 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents009_x264__5.npy,RoadAccidents
64 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents010_x264__5.npy,RoadAccidents
65 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents011_x264__5.npy,RoadAccidents
66 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents012_x264__5.npy,RoadAccidents
67 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents016_x264__5.npy,RoadAccidents
68 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents017_x264__5.npy,RoadAccidents
69 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents019_x264__5.npy,RoadAccidents
70 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents020_x264__5.npy,RoadAccidents
71 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents021_x264__5.npy,RoadAccidents
72 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents022_x264__5.npy,RoadAccidents
73 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents121_x264__5.npy,RoadAccidents
74 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents122_x264__5.npy,RoadAccidents
75 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents123_x264__5.npy,RoadAccidents
76 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents124_x264__5.npy,RoadAccidents
77 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents125_x264__5.npy,RoadAccidents
78 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents127_x264__5.npy,RoadAccidents
79 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents128_x264__5.npy,RoadAccidents
80 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents131_x264__5.npy,RoadAccidents
81 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents132_x264__5.npy,RoadAccidents
82 | /home/xbgydx/Desktop/UCFClipFeatures/RoadAccidents/RoadAccidents133_x264__5.npy,RoadAccidents
83 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery048_x264__5.npy,Robbery
84 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery050_x264__5.npy,Robbery
85 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery102_x264__5.npy,Robbery
86 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery106_x264__5.npy,Robbery
87 | /home/xbgydx/Desktop/UCFClipFeatures/Robbery/Robbery137_x264__5.npy,Robbery
88 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting002_x264__5.npy,Shooting
89 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting004_x264__5.npy,Shooting
90 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting007_x264__5.npy,Shooting
91 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting008_x264__5.npy,Shooting
92 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting010_x264__5.npy,Shooting
93 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting011_x264__5.npy,Shooting
94 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting013_x264__5.npy,Shooting
95 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting015_x264__5.npy,Shooting
96 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting018_x264__5.npy,Shooting
97 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting019_x264__5.npy,Shooting
98 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting021_x264__5.npy,Shooting
99 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting022_x264__5.npy,Shooting
100 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting024_x264__5.npy,Shooting
101 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting026_x264__5.npy,Shooting
102 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting028_x264__5.npy,Shooting
103 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting032_x264__5.npy,Shooting
104 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting033_x264__5.npy,Shooting
105 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting034_x264__5.npy,Shooting
106 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting037_x264__5.npy,Shooting
107 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting043_x264__5.npy,Shooting
108 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting046_x264__5.npy,Shooting
109 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting047_x264__5.npy,Shooting
110 | /home/xbgydx/Desktop/UCFClipFeatures/Shooting/Shooting048_x264__5.npy,Shooting
111 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting001_x264__5.npy,Shoplifting
112 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting004_x264__5.npy,Shoplifting
113 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting005_x264__5.npy,Shoplifting
114 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting007_x264__5.npy,Shoplifting
115 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting010_x264__5.npy,Shoplifting
116 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting015_x264__5.npy,Shoplifting
117 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting016_x264__5.npy,Shoplifting
118 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting017_x264__5.npy,Shoplifting
119 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting020_x264__5.npy,Shoplifting
120 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting021_x264__5.npy,Shoplifting
121 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting022_x264__5.npy,Shoplifting
122 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting027_x264__5.npy,Shoplifting
123 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting028_x264__5.npy,Shoplifting
124 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting029_x264__5.npy,Shoplifting
125 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting031_x264__5.npy,Shoplifting
126 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting033_x264__5.npy,Shoplifting
127 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting034_x264__5.npy,Shoplifting
128 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting037_x264__5.npy,Shoplifting
129 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting039_x264__5.npy,Shoplifting
130 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting044_x264__5.npy,Shoplifting
131 | /home/xbgydx/Desktop/UCFClipFeatures/Shoplifting/Shoplifting049_x264__5.npy,Shoplifting
132 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing019_x264__5.npy,Stealing
133 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing036_x264__5.npy,Stealing
134 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing058_x264__5.npy,Stealing
135 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing062_x264__5.npy,Stealing
136 | /home/xbgydx/Desktop/UCFClipFeatures/Stealing/Stealing079_x264__5.npy,Stealing
137 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism007_x264__5.npy,Vandalism
138 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism015_x264__5.npy,Vandalism
139 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism017_x264__5.npy,Vandalism
140 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism028_x264__5.npy,Vandalism
141 | /home/xbgydx/Desktop/UCFClipFeatures/Vandalism/Vandalism036_x264__5.npy,Vandalism
142 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_003_x264__5.npy,Normal
143 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_006_x264__5.npy,Normal
144 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_010_x264__5.npy,Normal
145 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_014_x264__5.npy,Normal
146 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_015_x264__5.npy,Normal
147 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_018_x264__5.npy,Normal
148 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_019_x264__5.npy,Normal
149 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_024_x264__5.npy,Normal
150 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_025_x264__5.npy,Normal
151 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_027_x264__5.npy,Normal
152 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_033_x264__5.npy,Normal
153 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_034_x264__5.npy,Normal
154 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_041_x264__5.npy,Normal
155 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_042_x264__5.npy,Normal
156 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_048_x264__5.npy,Normal
157 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_050_x264__5.npy,Normal
158 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_051_x264__5.npy,Normal
159 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_056_x264__5.npy,Normal
160 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_059_x264__5.npy,Normal
161 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_063_x264__5.npy,Normal
162 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_067_x264__5.npy,Normal
163 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_070_x264__5.npy,Normal
164 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_100_x264__5.npy,Normal
165 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_129_x264__5.npy,Normal
166 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_150_x264__5.npy,Normal
167 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_168_x264__5.npy,Normal
168 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_175_x264__5.npy,Normal
169 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_182_x264__5.npy,Normal
170 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_189_x264__5.npy,Normal
171 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_196_x264__5.npy,Normal
172 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_203_x264__5.npy,Normal
173 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_210_x264__5.npy,Normal
174 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_217_x264__5.npy,Normal
175 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_224_x264__5.npy,Normal
176 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_246_x264__5.npy,Normal
177 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_247_x264__5.npy,Normal
178 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_248_x264__5.npy,Normal
179 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_251_x264__5.npy,Normal
180 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_289_x264__5.npy,Normal
181 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_310_x264__5.npy,Normal
182 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_312_x264__5.npy,Normal
183 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_317_x264__5.npy,Normal
184 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_345_x264__5.npy,Normal
185 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_352_x264__5.npy,Normal
186 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_360_x264__5.npy,Normal
187 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_365_x264__5.npy,Normal
188 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_401_x264__5.npy,Normal
189 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_417_x264__5.npy,Normal
190 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_439_x264__5.npy,Normal
191 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_452_x264__5.npy,Normal
192 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_453_x264__5.npy,Normal
193 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_478_x264__5.npy,Normal
194 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_576_x264__5.npy,Normal
195 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_597_x264__5.npy,Normal
196 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_603_x264__5.npy,Normal
197 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_606_x264__5.npy,Normal
198 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_621_x264__5.npy,Normal
199 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_634_x264__5.npy,Normal
200 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_641_x264__5.npy,Normal
201 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_656_x264__5.npy,Normal
202 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_686_x264__5.npy,Normal
203 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_696_x264__5.npy,Normal
204 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_702_x264__5.npy,Normal
205 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_704_x264__5.npy,Normal
206 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_710_x264__5.npy,Normal
207 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_717_x264__5.npy,Normal
208 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_722_x264__5.npy,Normal
209 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_725_x264__5.npy,Normal
210 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_745_x264__5.npy,Normal
211 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_758_x264__5.npy,Normal
212 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_778_x264__5.npy,Normal
213 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_780_x264__5.npy,Normal
214 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_781_x264__5.npy,Normal
215 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_782_x264__5.npy,Normal
216 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_783_x264__5.npy,Normal
217 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_798_x264__5.npy,Normal
218 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_801_x264__5.npy,Normal
219 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_828_x264__5.npy,Normal
220 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_831_x264__5.npy,Normal
221 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_866_x264__5.npy,Normal
222 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_867_x264__5.npy,Normal
223 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_868_x264__5.npy,Normal
224 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_869_x264__5.npy,Normal
225 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_870_x264__5.npy,Normal
226 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_871_x264__5.npy,Normal
227 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_872_x264__5.npy,Normal
228 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_873_x264__5.npy,Normal
229 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_874_x264__5.npy,Normal
230 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_875_x264__5.npy,Normal
231 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_876_x264__5.npy,Normal
232 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_877_x264__5.npy,Normal
233 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_878_x264__5.npy,Normal
234 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_879_x264__5.npy,Normal
235 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_880_x264__5.npy,Normal
236 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_881_x264__5.npy,Normal
237 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_882_x264__5.npy,Normal
238 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_883_x264__5.npy,Normal
239 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_884_x264__5.npy,Normal
240 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_885_x264__5.npy,Normal
241 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_886_x264__5.npy,Normal
242 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_887_x264__5.npy,Normal
243 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_888_x264__5.npy,Normal
244 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_889_x264__5.npy,Normal
245 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_890_x264__5.npy,Normal
246 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_891_x264__5.npy,Normal
247 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_892_x264__5.npy,Normal
248 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_893_x264__5.npy,Normal
249 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_894_x264__5.npy,Normal
250 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_895_x264__5.npy,Normal
251 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_896_x264__5.npy,Normal
252 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_897_x264__5.npy,Normal
253 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_898_x264__5.npy,Normal
254 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_899_x264__5.npy,Normal
255 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_900_x264__5.npy,Normal
256 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_901_x264__5.npy,Normal
257 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_902_x264__5.npy,Normal
258 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_903_x264__5.npy,Normal
259 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_904_x264__5.npy,Normal
260 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_905_x264__5.npy,Normal
261 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_906_x264__5.npy,Normal
262 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_907_x264__5.npy,Normal
263 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_908_x264__5.npy,Normal
264 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_909_x264__5.npy,Normal
265 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_910_x264__5.npy,Normal
266 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_911_x264__5.npy,Normal
267 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_912_x264__5.npy,Normal
268 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_913_x264__5.npy,Normal
269 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_914_x264__5.npy,Normal
270 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_915_x264__5.npy,Normal
271 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_923_x264__5.npy,Normal
272 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_924_x264__5.npy,Normal
273 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_925_x264__5.npy,Normal
274 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_926_x264__5.npy,Normal
275 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_927_x264__5.npy,Normal
276 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_928_x264__5.npy,Normal
277 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_929_x264__5.npy,Normal
278 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_930_x264__5.npy,Normal
279 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_931_x264__5.npy,Normal
280 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_932_x264__5.npy,Normal
281 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_933_x264__5.npy,Normal
282 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_934_x264__5.npy,Normal
283 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_935_x264__5.npy,Normal
284 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_936_x264__5.npy,Normal
285 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_937_x264__5.npy,Normal
286 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_938_x264__5.npy,Normal
287 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_939_x264__5.npy,Normal
288 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_940_x264__5.npy,Normal
289 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_941_x264__5.npy,Normal
290 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_943_x264__5.npy,Normal
291 | /home/xbgydx/Desktop/UCFClipFeatures/Testing_Normal_Videos_Anomaly/Normal_Videos_944_x264__5.npy,Normal
292 |
--------------------------------------------------------------------------------
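A quick way to consume the test list above: each row pairs a per-video CLIP feature .npy file with its class name, and everything other than "Normal" is treated as anomalous. Below is a minimal loading sketch using only the standard library, assuming the CSV carries the same `path,label` header that the make_list scripts write (the absolute feature paths are machine-specific and would need regenerating locally):

import csv

# Load the UCF-Crime test index; columns assumed to be 'path' and 'label'.
with open('list/ucf_CLIP_rgbtest.csv') as f:
    rows = list(csv.DictReader(f))

normal = [r['path'] for r in rows if r['label'] == 'Normal']
abnormal = [r['path'] for r in rows if r['label'] != 'Normal']
print(f'{len(rows)} test videos: {len(abnormal)} anomalous, {len(normal)} normal')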
/list/annotations.txt:
--------------------------------------------------------------------------------
1 | v=S-7rRLrxnVQ__#1_label_B4-0-0 0 1517 1970 3038
2 | v=u5SF4SlqNDQ__#00-00-00_00-01-00_label_G-0-0 624 912 950 1441
3 | v=u5SF4SlqNDQ__#00-02-39_00-03-41_label_G-0-0 77 1430
4 | v=cEOM18n8fhU__#1_label_G-0-0 2133 2750
5 | v=NnmqkS1e88s__#1_label_B4-0-0 0 1750
6 | v=vaSOMEIe1Bg__#1_label_G-0-0 300 1380
7 | v=ZnFlL84K7HE__#00-27-53_00-34-45_label_B4-0-0 0 9523
8 | v=wVey5JDRf_g__#00-00-00_00-01-20_label_B6-0-0 51 210 273 410 544 720 828 970 1135 1270 1413 1500 1618 1710 1822 1920
9 | v=wVey5JDRf_g__#00-01-30_00-02-41_label_B6-0-0 142 296 447 508 606 710 800 870 1000 1185 1414 1550 1635 1700
10 | v=wVey5JDRf_g__#00-04-09_00-05-06_label_B6-0-0 85 200 352 438 547 690 777 833 918 952 1055 1170 1314 1365
11 | v=m8EkFsaGPzU__#1_label_B4-0-0 0 3380 4523 5700
12 | v=5BBoVfKOyeM__#00-08-19_00-10-10_label_B6-0-0 1067 1200 1650 1734
13 | v=BQjKQbYgUBA__#1_label_B1-0-0 635 667 733 759 932 1020 1234 1297 1977 2373 3195 3422
14 | v=bfOlheR5nUQ__#00-00-00_00-04-33_label_B4-0-0 0 6553
15 | v=bfOlheR5nUQ__#00-04-33_00-06-02_label_B4-0-0 0 2135
16 | v=YKtw6uLEGps__#1_label_B4-0-0 807 2118
17 | v=7xQr5wYwNPg__#1_label_B4-0-0 0 403 600 1008
18 | v=-fOWSLV6Esw__#1_label_B4-0-0 20 772
19 | v=b0MUjeKAGZw__#1_label_B4-0-0 0 2790
20 | Bad.Boys.1995__#01-11-55_01-12-40_label_G-B2-B6 157 180 185 244 250 360 582 810
21 | Bad.Boys.1995__#01-33-51_01-34-37_label_B2-0-0 57 670 728 800
22 | Bad.Boys.II.2003__#00-06-42_00-10-00_label_B2-G-0 2610 2670 2760 3060 3167 4310
23 | v=DD3jfKr8e-k__#00-00-00_00-01-40_label_B4-0-0 0 2400
24 | v=2jwO15SMyuo__#1_label_B4-0-0 0 2374
25 | v=jYENhkzdpO8__#00-00-00_00-01-31_label_B6-0-0 245 335 1187 1252 1836 2000
26 | v=1djrJ0wxlYo__#1_label_B4-0-0 0 758 880 1533
27 | v=7zEFHKKBA0g__#1_label_B4-0-0 1192 1600 2031 2690
28 | v=hxyhulJYz5I__#00-06-09_00-07-00_label_B6-0-0 25 70 256 360 520 640 680 855 927 1000 1140 1220
29 | v=bhZs3ALdL7Y__#1_label_G-0-0 0 70
30 | Black.Hawk.Down.2001__#01-13-59_01-14-49_label_B2-0-0 188 1190
31 | Black.Hawk.Down.2001__#01-32-40_01-34-00_label_B4-0-0 0 269 518 882 1090 1180 1220 1919
32 | Black.Hawk.Down.2001__#01-42-58_01-43-58_label_G-0-0 406 505
33 | Black.Hawk.Down.2001__#02-00-12_02-01-29_label_B2-0-0 10 1848
34 | v=8tqeeBGjnPg__#00-03-12_00-04-15_label_B6-0-0 160 300 690 980 1100 1270
35 | v=6nfo9c7a5pE__#1_label_B1-0-0 0 150 295 700
36 | v=rJz4NXm6vis__#1_label_B4-0-0 0 1210 1351 1797
37 | Braveheart.1995__#00-56-30_00-57-20_label_B1-0-0 0 360
38 | Braveheart.1995__#01-26-50_01-32-30_label_B1-0-0 1011 7515
39 | Braveheart.1995__#02-05-34_02-06-40_label_B1-0-0 256 1580
40 | Braveheart.1995__#02-07-00_02-08-15_label_B1-0-0 450 600 1027 1300
41 | v=zI0q5UDP47g__#1_label_B4-0-0 66 3560
42 | Brick.Mansions.2014__#00-11-57_00-12-12_label_B2-0-0 308 338
43 | Brick.Mansions.2014__#00-16-26_00-17-12_label_B1-0-0 30 1100
44 | Brick.Mansions.2014__#00-41-25_00-42-36_label_B1-0-0 523 610 900 1700
45 | Brick.Mansions.2014__#01-02-00_01-02-43_label_B2-0-0 280 550
46 | v=73uRcX0Dvfc__#00-00-00_00-01-51_label_B6-G-0 177 217 450 540 958 1260 1507 1670 2053 2180 2460 2589
47 | v=Ia9ATKNeUbY__#00-01-52_00-03-10_label_B6-0-0 90 145 470 570 1130 1200 1493 1660
48 | v=Ia9ATKNeUbY__#00-04-19_00-05-11_label_B6-0-0 27 133 274 510 1137 1170
49 | v=Ia9ATKNeUbY__#00-05-19_00-06-09_label_B6-0-0 120 180 800 855 1112 1185
50 | v=Ia9ATKNeUbY__#00-06-19_00-06-50_label_B6-0-0 246 290 427 527 687 744
51 | Bullet.in.the.Head.1990__#00-17-20_00-18-55_label_B1-0-0 138 280 546 1915
52 | Bullet.in.the.Head.1990__#00-41-30_00-44-16_label_B4-G-0 388 3425 3626 3668
53 | Bullet.in.the.Head.1990__#01-26-30_01-27-34_label_B2-0-0 695 820
54 | Bullet.in.the.Head.1990__#01-45-21_01-46-20_label_B2-0-0 400 502
55 | Bullet.in.the.Head.1990__#02-02-00_02-05-22_label_B2-B6-G 0 35 203 860 1005 1465 2130 2380 2860 3250 3541 4346
56 | v=EOAfl8yMiN8__#00-00-00_00-00-51_label_B6-0-0 119 150 330 380 515 595 708 760 940 1200
57 | v=EOAfl8yMiN8__#00-03-00_00-04-09_label_B6-0-0 245 325 470 520 716 850 960 1034 1165 1200 1370 1454 1614 1652
58 | v=kVl-6-A9ZO4__#00-08-01_00-09-35_label_B6-0-0 65 140 428 500 693 800 986 1117 1550 1700 2038 2218
59 | v=iHuggczItBk__#00-01-00_00-02-45_label_B6-0-0 70 140 600 700 1080 1223 1459 1474 1750 1950 2257 2315 2474 2520
60 | v=iHuggczItBk__#00-03-00_00-03-55_label_B6-0-0 180 240 450 555 880 960 1200 1300
61 | v=iHuggczItBk__#00-04-00_00-04-55_label_B6-0-0 56 173 385 435 780 830
62 | v=0qtIjyt-7wg__#00-00-00_00-00-51_label_B6-0-0 205 270 537 610 1028 1115
63 | v=0qtIjyt-7wg__#00-01-00_00-02-45_label_B6-0-0 264 340 760 900 1337 1430 1612 1656 1806 1940 2390 2480
64 | v=0qtIjyt-7wg__#00-04-00_00-05-15_label_B6-0-0 25 55 265 315
65 | v=vhACO_m5pH0__#00-09-42_00-10-40_label_B6-0-0 238 314 475 520 911 960 1084 1208
66 | v=vhACO_m5pH0__#00-10-42_00-11-40_label_B6-0-0 48 55 303 354 890 1024
67 | v=vhACO_m5pH0__#00-14-42_00-15-40_label_B6-0-0 20 140 460 500 735 820 970 1070 1225 1266
68 | v=vhACO_m5pH0__#00-17-42_00-18-40_label_B6-0-0 0 45 148 218 440 486 640 710 860 900 1010 1060 1248 1300
69 | v=vhACO_m5pH0__#00-20-42_00-22-40_label_B6-0-0 30 130 300 380 625 680 830 920 1015 1085 1346 1368 1596 1642 1820 1869 2020 2100 2258 2345 2530 2560 2670 2749
70 | v=tT-SrQC6Ddw__#00-04-00_00-05-15_label_B6-0-0 185 226 1237 1320 1485 1640
71 | v=ZoS8gm5OcOM__#00-06-20_00-07-30_label_B6-0-0 364 627 1464 1680
72 | v=ZoS8gm5OcOM__#00-07-50_00-08-55_label_B6-0-0 267 320 795 850 1156 1454
73 | v=D1eP4Bn4hDQ__#00-09-20_00-10-45_label_B6-0-0 319 375 1050 1105 1390 1445
74 | v=BFPkt0AuSFA__#1_label_B4-0-0 28 3350
75 | v=251___mEwZA__#1_label_B1-0-0 183 1359 1453 2487 3749 4540
76 | Casino.Royale.2006__#00-18-30_00-19-20_label_G-B2-0 261 580
77 | Casino.Royale.2006__#00-50-05_00-51-16_label_B1-B2-B6 43 180 383 625 702 830 860 1006 1066 1096 1527 1634
78 | Casino.Royale.2006__#00-51-16_00-52-41_label_B1-B6-0 37 400 977 1400 1590 1752
79 | Casino.Royale.2006__#01-17-29_01-17-46_label_B1-0-0 216 350
80 | Casino.Royale.2006__#01-46-40_01-47-14_label_B6-0-0 197 600
81 | v=ROrpKx3aIjA__#1_label_G-0-0 713 830
82 | v=Hnfy9XhlIPM__#1_label_B4-0-0 66 4620
83 | v=J_ZsNd97wXw__#1_label_G-0-0 60 660 1025 1090
84 | v=zQbnQBCTSiA__#1_label_B1-0-0 663 1530 2020 2553 3319 3710 4157 4478 4597 4930
85 | City.of.God.2002__#00-37-20_00-38-02_label_B5-0-0 530 750
86 | City.of.God.2002__#00-40-16_00-41-30_label_B2-0-0 460 490 858 910 1142 1330
87 | City.of.God.2002__#01-52-20_01-54-32_label_B2-0-0 24 54 585 611 2892 3167
88 | City.Of.Men.2007__#00-51-50_00-53-31_label_B2-0-0 254 540 1519 1545
89 | City.Of.Men.2007__#00-57-37_00-58-27_label_B2-0-0 218 280 883 920
90 | City.Of.Men.2007__#01-17-27_01-17-59_label_B2-0-0 197 394
91 | City.Of.Men.2007__#01-35-30_01-36-00_label_B2-0-0 135 265
92 | v=qrKfaX1lCUM__#1_label_G-0-0 247 280 1601 1656 1684 1736 1928 2106 2323 2355 2428 2471 2645 2875 5600 5645 5812 6000
93 | v=ZkUciDD55kA__#00-00-00_00-00-30_label_G-0-0 124 617
94 | Crank.Dircut.2006__#0-27-42_0-29-01_label_B1-0-0 347 445 560 1427 1886 1894
95 | v=OOjjPGN8jSU__#1_label_B6-0-0 884 911 994 1030 1083 1102 1235 1266 1375 1483 1680 1745 1830 1890 2070 2250 2320 2520 2557 2650 2798 3030 3175 3265 3622 3670 3935 3999 4089 4150 4435 4515 4590 4687
96 | v=uQY15O3LKI0__#1_label_B6-0-0 297 336 600 750 811 912 1068 1208 1320 1350 1425 1470 1593 1640 1742 1860 1949 2020 2140 2350 2550 2700 2867 2980 3073 3120 3255 3350 3489 3526 3623 3700 3804 3965
97 | Deadpool.2.2018__#0-04-46_0-05-01_label_B2-0-0 113 190 246 300
98 | Deadpool.2.2018__#0-50-30_0-51-20_label_B1-0-0 220 320 548 630 837 1125
99 | Deadpool.2.2018__#01-03-09_01-03-54_label_B4-0-0 390 608
100 | Deadpool.2016__#0-18-58_0-19-20_label_B1-0-0 80 130 290 510
101 | Death.Proof.2007__#00-45-05_00-47-36_label_B5-0-0 1268 1600 3418 3450
102 | Death.Proof.2007__#01-40-41_01-42-17_label_B5-B6-0 1606 1666 1760 2070
103 | v=MqrNCb2N5to__#1_label_B1-0-0 247 425 1495 1666 1852 1950 2219 2450
104 | Desperado.1995__#00-16-48_00-18-52_label_B1-0-0 40 110 160 200 786 805 946 2030 2190 2310
105 | Desperado.1995__#00-38-36_00-39-21_label_B1-B2-0 325 350 381 498 531 570 580 700
106 | Desperado.1995__#01-14-11_01-17-28_label_B2-G-0 838 1300 2640 2900 2957 3015 3310 3340 3760 4290 4478 4600
107 | v=PWBZiM-rkoE__#1_label_G-0-0 2050 2300 2499 2550
108 | v=CLFYvG6MsZU__#00-02-16_00-05-06_label_B2-0-0 600 1060 1577 2030 2636 3140
109 | v=3xWiBCIxjIk__#1_label_B4-0-0 122 3700
110 | v=yrpZJ8Vr3aA__#1_label_B1-0-0 120 1270 2180 2352
111 | Election.2005__#00-47-57_00-50-54_label_B5-0-0 170 480 1219 1620 2063 2150
112 | v=abATVcjFumY__#1_label_G-0-0 143 450 1560 1830
113 | v=d5lTTPvJLpw__#1_label_B4-0-0 188 368 508 810 920 5880
114 | v=J67oj92maC0__#1_label_G-0-0 293 740 806 1200 1346 1650 1698 2086 2213 2830 3260 3960 4058 4670
115 | Fast.Five.2011__#00-01-41_00-01-57_label_B6-0-0 70 310
116 | Fast.Five.2011__#00-32-05_00-32-24_label_B1-B2-0 104 155 194 360
117 | Fast.Five.2011__#00-32-56_00-33-26_label_B2-0-0 155 300 450 686
118 | Fast.Five.2011__#01-29-27_01-32-05_label_B1-0-0 110 2934
119 | Fast.Furious.6.2013__#00-45-40_00-47-13_label_B2-0-0 160 295 425 580 700 890 985 1038 1180 1280 2030 2190
120 | Fast.Furious.6.2013__#01-03-25_01-04-59_label_B1-0-0 460 1470 1550 2120
121 | Fast.Furious.6.2013__#01-38-33_01-39-20_label_B1-0-0 140 170 366 490
122 | Fast.Furious.2009__#00-21-06_00-21-57_label_B6-B2-0 170 599 1170 1194
123 | Fast.Furious.2009__#00-25-59_00-26-25_label_B1-0-0 240 622
124 | Fast.Furious.2009__#00-42-10_00-42-41_label_B6-0-0 35 727
125 | Fast.Furious.2009__#01-18-00_01-19-00_label_B2-B6-0 691 1078 1145 1323
126 | Fast.Furious.2009__#01-29-10_01-29-55_label_B2-0-0 877 1058
127 | v=O9g1uuI62m0__#1_label_B4-0-0 0 4460
128 | v=wG-2t0M8CPc__#1_label_B1-0-0 781 1255 1325 1700
129 | v=wPkBaI6vOv0__#00-00-00_00-01-00_label_G-0-0 25 250 468 530 618 860 1350 1435
130 | v=pFamvR9CpYw__#00-00-00_00-06-36_label_B4-0-0 135 1300 1830 2580 3120 3220 3680 7050 7720 9500
131 | v=pFamvR9CpYw__#00-06-36_00-15-20_label_B4-0-0 0 1985 2450 6360 7460 11600 12200 12575
132 | v=y1z12D5bx7c__#1_label_B4-0-0 0 4847
133 | v=pHZ9gOfmY_k__#1_label_B4-0-0 0 5177
134 | v=gbRIKogNZvE__#1_label_B4-0-0 0 2988
135 | v=T-LH8Gv5zzY__#1_label_B4-0-0 0 2425
136 | v=CoxsDkB-rWU__#00-07-06_00-10-30_label_B4-0-0 0 4896
137 | v=TfS-MJoVNjM__#1_label_B4-0-0 0 1309
138 | Fury.2014__#00-07-30_00-07-58_label_G-0-0 125 180 417 490
139 | Fury.2014__#00-39-36_00-40-44_label_B1-B2-0 78 1424 1550 1575
140 | Mission.Impossible.V.Rogue.Nation.2015__#01-54-23_01-55-10_label_B2-0-0 723 992
141 | v=5Ftl4nOSytc__#1_label_G-0-0 166 290
142 | v=38GQ9L2meyE__#1_label_B6-0-0 26 80 210 288 377 396 450 517 597 628 650 850 895 973 1106 1226 1330 1400 1490 1675 1713 1820 1890 1980 2025 2100 2177 2290 2375 2450 2579 2655 2657 3045 3091 3170 3259 3475 3571 4060 4143 4288 4364 4450
143 | v=GqMh_HBNWZE__#00-01-00_00-02-45_label_B6-0-0 44 148 180 220 393 500 810 940 1325 1465 1582 1614 1857 1917 2259 2308 2480 2520
144 | v=RUjrMYWhLng__#1_label_B6-0-0 330 2620 3526 3760 3875 4420 4695 6540 6758 7220 8440 9111 9280 10350
145 | v=UK--hvgP2uY__#1_label_G-0-0 2310 2595
146 | v=Z12t5h2mBJc__#1_label_B1-0-0 22 780 1089 1275 1442 1710 2535 2950 2975 3600 3965 4300 4488 4664
147 | New.Kids.Turbo.2010__#00-22-35_00-23-14_label_B1-0-0 94 134 670 741
148 | New.Kids.Turbo.2010__#00-50-13_00-50-34_label_B1-0-0 344 420
149 | v=gsz_P8t-KM4__#1_label_B4-0-0 0 9695
150 | v=oipd63DGadU__#1_label_B1-0-0 73 760 990 1040
151 | v=xbPWQKZspfU__#1_label_B1-0-0 356 1200 2709 3544 3927 4782
152 | v=afhttnv46Y4__#1_label_B1-0-0 304 860 1435 2020
153 | v=0yHBkMBE8r4__#1_label_B1-0-0 498 720 3260 3570
154 | v=Q8K7roZu3WU__#1_label_B1-0-0 0 1294
155 | v=gENp4SyNxkI__#1_label_B1-0-0 0 1250
156 | v=sE-DC2trBkI__#00-01-00_00-02-45_label_B6-0-0 225 322 640 760 991 1160 1450 1550 1815 1883 2205 2245
157 | v=sE-DC2trBkI__#00-03-00_00-04-09_label_B6-0-0 258 410 700 802 942 1215 1435 1540
158 | v=sE-DC2trBkI__#00-09-00_00-10-05_label_B6-0-0 196 340 495 605 880 1023 1145 1230 1433 1493
159 | v=15wDrZJQpsw__#00-00-00_00-00-51_label_B6-0-0 434 562 749 866
160 | v=15wDrZJQpsw__#00-05-20_00-06-20_label_B6-0-0 95 233 465 516 852 924 1330 1441
161 | v=15wDrZJQpsw__#00-09-00_00-10-55_label_B6-0-0 130 650 1030 1140 1625 1685 1875 2330 2580 2620
162 | v=HSisjzLESak__#00-00-00_00-00-51_label_B6-0-0 578 700 864 960 1153 1224
163 | v=HSisjzLESak__#00-01-00_00-02-45_label_B6-0-0 0 122 640 900 1163 1330 1633 1681
164 | v=HSisjzLESak__#00-07-40_00-08-50_label_B6-0-0 322 522 1340 1545
165 | v=X4I8FhpGJwo__#00-07-40_00-08-50_label_B6-0-0 258 300 552 598 1150 1299
166 | v=X4I8FhpGJwo__#00-09-00_00-10-45_label_B6-0-0 105 165 550 633 1273 1440 1640 1700 2017 2050 2362 2450
167 | v=waIS8TaJxts__#00-00-00_00-00-51_label_B6-0-0 490 625 990 1075 1113 1220
168 | v=waIS8TaJxts__#00-01-30_00-02-45_label_B6-0-0 189 255 1145 1800
169 | v=waIS8TaJxts__#00-05-50_00-06-20_label_B6-0-0 285 536
170 | v=waIS8TaJxts__#00-07-40_00-08-50_label_B6-0-0 329 375 724 920 1276 1460
171 | v=waIS8TaJxts__#00-09-00_00-10-45_label_B6-0-0 257 540 1082 1245 1468 1805 2303 2490
172 | v=H5W58Loofks__#00-05-50_00-06-20_label_B6-0-0 0 90 521 720
173 | v=TOEVc-pOEwM__#00-08-40_00-10-40_label_B6-0-0 40 95 350 530 796 1224 1512 1658 1983 2166 2444 2532 2845 2880
174 | Operation.Red.Sea.2018__#0-02-17_0-02-37_label_B2-0-0 180 397
175 | Operation.Red.Sea.2018__#0-16-06_0-16-46_label_G-0-0 630 774
176 | Operation.Red.Sea.2018__#0-25-36_0-26-03_label_G-0-0 585 620
177 | Operation.Red.Sea.2018__#0-29-36_0-30-20_label_B2-G-0 222 450 460 533
178 | Operation.Red.Sea.2018__#0-56-26_0-57-10_label_G-0-0 330 430 504 533 1038 1055
179 | Operation.Red.Sea.2018__#01-20-58_01-22-00_label_B5-0-0 1271 1336
180 | Operation.Red.Sea.2018__#01-37-22_01-37-36_label_G-B2-0 150 215 233 275
181 | v=sbEHU4GskX4__#1_label_B1-0-0 608 2030
182 | v=3y2OOc6WTrg__#1_label_B1-0-0 400 790 1345 1700 2684 3273
183 | v=iWfUqxibqG0__#1_label_G-0-0 176 628
184 | v=7isLhaNYgkU__#1_label_B4-0-0 0 5705
185 | v=OB0-MXNPUIU__#1_label_B4-0-0 0 4020
186 | v=FRj2K0ulD8Q__#00-00-00_00-00-57_label_B4-0-0 0 1202
187 | v=FRj2K0ulD8Q__#00-04-16_00-06-05_label_B4-0-0 0 2617
188 | v=FRj2K0ulD8Q__#00-06-05_00-08-25_label_B4-0-0 0 2547 2765 3329
189 | v=ICnreR1hxP0__#1_label_B4-0-0 0 3976
190 | v=MJLylzPRvyw__#1_label_B4-0-0 135 3473
191 | v=OFxPTJxA5pg__#1_label_B4-0-0 0 15000
192 | v=fhiAyxpDQMU__#1_label_B4-0-0 0 3975
193 | v=SE5fxK7SasU__#00-03-23_00-04-08_label_B4-0-0 0 1081
194 | v=_tsSKAsVZfo__#1_label_B4-0-0 153 1090 1170 2480
195 | Quantum.Of.Solace.2008__#01-19-30_01-20-05_label_B1-0-0 131 295
196 | Quantum.Of.Solace.2008__#01-30-26_01-30-42_label_B2-B1-0 30 55 114 144
197 | v=k_cvJa1kNaM__#1_label_B4-0-0 900 4505 4810 5625
198 | v=pL2HjXvWuPc__#1_label_B1-0-0 1090 3392
199 | v=U6MC8JJJZSY__#1_label_G-0-0.mp4 150 779
200 | v=FW_qbiPH7UA__#1_label_B1-0-0.mp4 1654 2080 2367 2745
201 | The.Fast.and.the.Furious.2001__#00-29-00_00-29-52_label_B2-G-0.mp4 310 515 790 1045
202 | The.Fast.and.the.Furious.2001__#01-24-18_01-25-05_label_B6-0-0.mp4 147 347
203 | The.Fast.and.the.Furious.2001__#01-37-40_01-38-30_label_B6-0-0.mp4 167 540
204 | v=aWPWHU8x6kE__#1_label_B4-0-0.mp4 3368 4698
205 | v=LJ0Pu5_Mefs__#1_label_G-0-0.mp4 694 851
206 | v=BpargJW29Wo__#00-00-50_00-01-30_label_B1-0-0.mp4 127 788
207 | v=Y1hxr1MWLjM__#1_label_B1-0-0.mp4 1920 2483
208 | v=UYFR-XbyZQc__#00-01-19_00-02-42_label_G-0-0.mp4 33 333
209 | v=Q9Re4CnFJRg__#00-01-21_00-01-47_label_G-0-0.mp4 53 263 460 615
210 | v=PRH5rYUHoVU__#00-00-25_00-02-50_label_B6-0-0.mp4 27 144 230 440 527 600 715 905 958 1072 1235 1290 1485 1565 1820 1910 2020 3030 3061 3160 3255 3387
211 | v=qmsQ-obL1Z4__#00-03-26_00-04-04_label_B6-0-0.mp4 95 220 615 875
212 | v=f6j3YWgVBto__#00-06-16_00-06-51_label_B6-0-0.mp4 240 370
213 | v=yDqThVpu1AM__#1_label_B4-0-0 0 743
214 | v=KUeUxbsBO6s__#1_label_B1-0-0 1042 2000
215 | v=qqsd-cZr01k__#00-01-00_00-02-45_label_B6-0-0 79 103 381 500 750 825 1026 1060 1340 1380 1700 1880 2100 2255 2430 2480
216 | v=qqsd-cZr01k__#00-06-10_00-07-19_label_B6-0-0 0 66 214 410 665 705 852 900 1120 1160 1285 1330 1350 1385 1468 1510 1598 1657
217 | v=6TR2rcgHm4g__#1_label_B4-0-0 0 344 497 790 1025 1195 1246 1590
218 | v=utQ5AvXtNLA__#1_label_B4-0-0 0 893
219 | v=QWwDsPSe7iM__#1_label_B4-0-0 0 2214
220 | v=pMtu7fOHdII__#1_label_B1-0-0 650 850
221 | Gladiator.2000__#01-10-27_01-11-59_label_B1-0-0 745 1663
222 | God.Bless.America.2011__#00-12-40_00-13-12_label_B2-0-0 122 322
223 | God.Bless.America.2011__#00-38-25_00-39-37_label_B2-B5-0 110 160 300 560 810 1150
224 | God.Bless.America.2011__#01-32-00_01-32-50_label_B2-0-0 122 180
225 | GoldenEye.1995__#00-10-00_00-10-40_label_G-0-0 621 920
226 | GoldenEye.1995__#00-24-53_00-25-06_label_B1-0-0 75 270
227 | GoldenEye.1995__#01-17-05_01-19-57_label_B2-B1-0 232 260 728 990 1200 1591 1700 1765 2959 3402 3631 3769 3995 4122
228 | GoldenEye.1995__#02-02-15_02-02-47_label_B1-0-0 100 310 380 690
229 | GoldenEye.1995__#02-03-40_02-04-13_label_G-0-0 64 777
230 | v=ovQ1VTJ_IUI__#1_label_B1-0-0 140 370 660 861
231 | v=tivXK3PGByk__#1_label_B1-0-0 1037 1800 2300 2750
232 | v=m5zK-tzYCQM__#1_label_B4-0-0 126 1630
233 | v=Q6zIX29HkSo__#1_label_B4-0-0 43 2100
234 | Haywire.2011__#00-03-51_00-05-12_label_B1-B2-0 261 797 799 810 872 1470
235 | Haywire.2011__#00-16-07_00-16-49_label_B1-0-0 153 640
236 | Haywire.2011__#00-39-07_00-41-48_label_B1-B2-0 111 2139 2459 3210 3504 3530
237 | Haywire.2011__#01-20-20_01-21-57_label_B1-0-0 578 2225
238 | v=3wxWNAM8Cso__#1_label_G-0-0 4318 4612 4881 5495
239 | v=YdrISbwy_zI__#1_label_G-0-0 530 1362
240 | v=8ewWXhYRUNs__#1_label_B1-0-0 88 910
241 | v=zoIn2hOrUIM__#00-01-00_00-02-45_label_B6-0-0 156 240 521 580 1315 1370 1694 1766 2067 2146 2460 2521
242 | v=zoIn2hOrUIM__#00-06-10_00-07-19_label_B6-0-0 226 270 369 410 790 825 1078 1110 1462 1473
243 | v=yy-KIWDcBr4__#00-08-35_00-10-25_label_B6-0-0 0 187 270 360 852 930 1155 1230 1501 1556 2100 2200
244 | v=yy-KIWDcBr4__#00-16-10_00-17-19_label_B6-0-0 0 90 350 424 514 550 856 956 1040 1103 1576 1657
245 | v=5Tnl7_8RqlA__#00-11-00_00-12-45_label_B6-0-0 110 137 391 465 746 990 1060 1180 1430 1490 1850 1880 2128 2158 2390 2470
246 | v=BzwNU2xmT64__#00-01-00_00-02-45_label_B6-0-0 0 45 246 285 760 1006 1180 1276 1632 1738 2241 2360
247 | v=BzwNU2xmT64__#00-04-10_00-06-09_label_B6-0-0 27 160 500 560 843 888 1060 1160 1360 1510 1637 1685 1840 1880 2233 2320
248 | v=X1arMZmYhsk__#00-13-00_00-14-09_label_B6-0-0 329 450 660 746 943 990 1200 1250 1552 1620
249 | v=uBXprlsmd18__#00-03-00_00-04-09_label_B6-0-0 107 170 974 1020 1600 1657
250 | v=uBXprlsmd18__#00-06-10_00-07-19_label_B6-0-0 356 529 1075 1156 1440 1657
251 | v=7rDRFFSUrPI__#00-01-50_00-02-32_label_G-0-0 655 700 840 860 912 930
252 | v=EH_QB6cm6BE__#1_label_G-0-0 158 400 535 656
253 | v=q_DhkdHGXos__#1_label_B1-0-0 0 1420
254 | v=Dh-3BlhAOnE__#00-00-00_00-13-36_label_B4-0-0 0 170 14846 19116
255 | v=rAlZRFZTwxM__#00-03-13_00-07-08_label_G-0-0 264 304 566 859 2391 2422 2780 3190
256 | v=Dn_SDu22WpM__#00-08-35_00-10-05_label_B6-0-0 55 100 253 313 437 500 755 788 989 1040 1812 1893
257 | v=9Jk2sIp5MRQ__#1_label_G-0-0 175 215 273 567
258 | v=UK2w9Sh47fM__#1_label_G-0-0 210 500
259 | v=vv-MFJPi4Qs__#1_label_G-0-0 649 680
260 | v=hAHXCMRvY_I__#1_label_G-0-0 165 278
261 | v=w16YmhAZ1_Y__#1_label_G-0-0 64 315
262 | v=78u9uBJBqIw__#1_label_B4-B5-0 1330 1784 2059 2360
263 | IP.Man.2.2010__#00-07-14_00-09-13_label_B1-0-0 579 742 1010 1350 1745 1870 1995 2190
264 | IP.Man.2.2010__#01-35-55_01-39-11_label_B1-0-0 637 748 890 1150 2564 3962
265 | Ip.Man.3.2015__#00-17-51_00-18-51_label_B1-0-0 0 145 305 470
266 | Ip.Man.3.2015__#00-44-45_00-45-08_label_B1-0-0 193 365
267 | Ip.Man.3.2015__#00-46-08_00-54-20_label_B1-0-0 4826 4964 5050 5134 5712 6300 6572 10880 11440 11530 11700 11807
268 | Ip.Man.2008__#00-46-51_00-47-30_label_B1-0-0 310 520
269 | Ip.Man.2008__#00-48-57_00-50-47_label_B1-0-0 12 148 502 730 1260 2047
270 | Ip.Man.2008__#01-14-31_01-15-24_label_B1-0-0 177 462
271 | Ip.Man.2008__#01-40-37_01-42-07_label_B4-0-0 105 1482
272 | v=cO1UefhG7AY__#1_label_G-0-0 1048 1200 1940 2190 4180 4300
273 | v=v_LxqgpRouM__#1_label_G-0-0 148 888
274 | v=CUr-Mg8Hmh0__#1_label_G-0-0 3668 3853 4409 4600
275 | v=u0Tnw01PNxc__#1_label_B4-0-0 0 1376
276 | Jason.Bourne.2016__#0-17-25_0-18-53_label_B4-0-0 85 485 640 1280 1522 1950
277 | Jason.Bourne.2016__#0-29-48_0-30-29_label_B4-0-0 117 300 500 660
278 | Jason.Bourne.2016__#0-50-20_0-50-30_label_G-0-0 12 77
279 | v=PuZy1PQIgrA__#1_label_B1-0-0 528 1120 2042 2590
280 | v=Q7NpJXOTiMY__#1_label_G-0-0 10 900
281 | v=JfLYNEsrTew__#1_label_G-0-0 52 430
282 | v=w2NgAYJHnS0__#1_label_G-0-0 50 300
283 | v=9Ydg5IeZpFI__#1_label_B1-0-0 55 660 1131 1710 2231 2475 2799 3291 3830 4489
284 | Kill.Bill.Vol.1.2003__#01-17-20_01-20-20_label_B1-0-0 1835 4273
285 | Kill.Bill.Vol.1.2003__#01-21-39_01-26-55_label_B1-0-0 1829 6854
286 | Kill.Bill.Vol.1.2003__#01-26-56_01-28-18_label_B1-0-0 297 1880
287 | Kill.Bill.Vol.2.2004__#0-25-20_0-26-20_label_B2-0-0 970 1010
288 | Kingsman.The.Golden.Circle.2017__#00-35-55_00-36-35_label_B1-0-0 287 693
289 | Kingsman.The.Golden.Circle.2017__#00-41-22_00-41-32_label_B2-0-0 14 44
290 | Kingsman.The.Secret.Service.2014__#00-22-10_00-23-10_label_B2-0-0 0 230
291 | Kingsman.The.Secret.Service.2014__#01-43-40_01-44-50_label_B2-G-0 96 117 1394 1438 1555 1580
292 | Law.Abiding.Citizen.2009__#01-26-34_01-27-25_label_B2-0-0 626 990
293 | v=SMy2_qNO2Y0__#00-01-50_00-03-13_label_G-0-0 257 310 1040 1300
294 | v=qyTGi0N4OVg__#1_label_G-0-0 646 1120
295 | v=CvMpD1yqfxI__#1_label_G-0-0 194 338 560 720
296 | v=ZCjNPgNu42g__#1_label_B4-0-0 1060 2970
297 | v=gp_D8r-2hwk__#1_label_G-0-0 1365 1600
298 | Lord.of.War__#00-50-10_00-50-50_label_G-0-0 119 230
299 | Lord.of.War__#01-41-04_01-41-53_label_B1-0-0 295 318 433 467
300 | v=cuUK-3MNqtw__#1_label_B1-0-0 1280 1450
301 | Love.Death.and.Robots.S01E10__#0-01-00_0-02-13_label_B2-G-0 560 600 846 905 1050 1295
302 | Love.Death.and.Robots.S01E15__#0-05-38_0-06-18_label_G-0-0 340 430
303 | v=OE9t3XImErk__#1_label_B1-0-0 193 290
304 | v=3e3C-LQWF2s__#1_label_B1-0-0 141 1000 1503 1665 1860 2200 3170 3300 3882 3960 4085 4420
305 | v=HUL83wCQATc__#1_label_B1-0-0 145 402 1000 1403
306 | Mindhunters.2004__#01-19-00_01-20-05_label_B2-0-0 0 375 546 604
307 | Mission.Impossible.Fallout.2018__#00-31-21_00-32-50_label_B1-0-0 96 2125
308 | Mission.Impossible.Fallout.2018__#00-39-18_00-40-36_label_B1-0-0 327 1680
309 | Mission.Impossible.Fallout.2018__#02-03-50_02-04-35_label_B1-0-0 222 976
310 | Mission.Impossible.Ghost.Protocol.2011__#00-01-18_00-01-40_label_B2-0-0 63 200 371 410
311 | Mission.Impossible.Ghost.Protocol.2011__#01-12-38_01-13-43_label_B1-0-0 180 630 1086 1122 1260 1390
312 | Mission.Impossible.Ghost.Protocol.2011__#01-54-24_01-54-50_label_B1-0-0 87 347
313 | Mission.Impossible.Ghost.Protocol.2011__#01-56-52_01-57-03_label_B1-0-0 54 120
314 | Mission.Impossible.II.2000__#01-18-55_01-20-53_label_B2-0-0 159 330
315 | Mission.Impossible.II.2000__#01-25-08_01-25-31_label_B2-0-0 9 117
316 | Mission.Impossible.II.2000__#01-29-30_01-29-44_label_B1-0-0 136 280
317 | Mission.Impossible.II.2000__#01-32-56_01-33-25_label_B1-0-0 460 600
318 | Mission.Impossible.II.2000__#01-38-25_01-38-41_label_B2-B5-0 130 165 260 295
319 | Mission.Impossible.II.2000__#01-46-20_01-47-11_label_B2-B6-0 10 56 84 135 184 413 608 919
320 | Mission.Impossible.III.2006__#00-19-56_00-20-26_label_B2-0-0 19 560
321 | Mission.Impossible.III.2006__#01-06-52_01-07-15_label_G-0-0 259 346
322 | Mission.Impossible.III.2006__#01-54-02_01-54-40_label_B2-0-0 206 300 554 735
323 | Mission.Impossible.V.Rogue.Nation.2015__#00-14-40_00-17-12_label_B1-0-0 170 222 395 420 538 558 1445 3296
324 | Mission.Impossible.V.Rogue.Nation.2015__#00-17-27_00-17-55_label_B2-0-0 476 579
325 | Mission.Impossible.V.Rogue.Nation.2015__#01-15-52_01-16-48_label_B6-B2-0 84 216 328 407 526 995
326 | Mission.Impossible.V.Rogue.Nation.2015__#01-19-48_01-20-20_label_B6-0-0 192 332
327 | v=iegHZ_UWWsA__#00-14-05_00-20-33_label_B4-0-0 1108 5466 5921 7130
328 | v=8oJMGgaww70__#1_label_B4-0-0 0 3578
329 | v=Y43UkfAe_bc__#00-00-00_00-01-40_label_B4-0-0 0 2400
330 | v=aQ3qMpgZjwg__#00-00-00_00-01-48_label_B4-0-0 637 1530 1800 2100
331 | v=3xkwZk44VQs__#1_label_B4-0-0 0 771 792 1489 1509 1978
332 | v=yvnj5VIDsNI__#1_label_B4-0-0 0 1197 3496 3710 3765 3910
333 | v=GbMQ2CxeNHU__#1_label_B4-0-0 0 5489
334 | v=3zgaqeVSXuI__#1_label_B4-0-0 0 663 670 1320 1500 1604 1762 1848
335 | v=ADQNFs9bfFk__#00-00-00_00-00-58_label_B4-0-0 0 350 465 1379
336 | v=fIlmgvc-bUk__#1_label_B4-0-0 123 1370
337 | v=kPWdgckIhLI__#1_label_B4-0-0 0 558
338 | Rush.Hour.3.2007.BluRay__#00-04-59_00-05-31_label_B2-0-0 573 623
339 | Rush.Hour.3.2007.BluRay__#00-34-11_00-35-46_label_B1-0-0 416 738 805 1034 1118 1390 1461 1788 1862 2055 2232 2258
340 | Rush.Hour.3.2007.BluRay__#00-40-56_00-41-52_label_B2-0-0 665 684 1145 1195
341 | Rush.Hour.3.2007.BluRay__#01-12-16_01-12-36_label_B1-0-0 0 450
342 | Rush.Hour.1998.BluRay__#00-08-08_00-09-15_label_B2-B6-G 115 174 411 445 457 494 594 620 640 830 1000 1127 1248 1445
343 | Rush.Hour.1998.BluRay__#00-11-05_00-12-10_label_B2-B1-0 140 175 275 575 935 1360
344 | Rush.Hour.1998.BluRay__#01-20-28_01-21-38_label_B2-0-0 122 200 220 300 316 357 762 848 998 1049 1105 1174 1323 1475
345 | Rush.Hour.1998.BluRay__#01-25-21_01-25-46_label_B2-0-0 437 519
346 | v=wnd3IYH7x1o__#1_label_G-0-0 260 440 746 1000
347 | v=lcBUb7EOQ4o__#1_label_G-0-0 913 1050 1132 1270 1306 2235
348 | Salt.2010__#00-17-40_00-18-16_label_B1-0-0 530 864
349 | Salt.2010__#01-06-27_01-07-07_label_G-B2-0 25 92 119 137 190 270 319 526 690 730
350 | Salt.2010__#01-12-05_01-12-32_label_B2-G-0 123 296 340 456
351 | Saving.Private.Ryan.1998__#00-47-50_00-49-35_label_B2-0-0 483 670 801 887 1211 1431
352 | Saving.Private.Ryan.1998__#02-23-00_02-23-28_label_B1-0-0 38 672
353 | Saving.Private.Ryan.1998__#02-23-58_02-24-18_label_G-0-0 228 311
354 | Saving.Private.Ryan.1998__#02-24-18_02-25-42_label_B1-0-0 72 428 550 830 1085 2016
355 | Saving.Private.Ryan.1998__#02-25-42_02-26-03_label_B2-0-0 247 504
356 | v=Qc3pAEeK16A__#1_label_B1-0-0 70 237 980 1218
357 | v=6KQFxxHmVMc__#1_label_B1-0-0 144 590 1991 2100 2220 2340 2479 2640
358 | Shoot.Em.Up.2007__#00-32-40_00-35-10_label_B2-0-0 43 74 114 159 234 371 560 781 905 1426 1522 1696 1790 2185 2374 2434 2517 2697 2795 3432
359 | Shoot.Em.Up.2007__#00-53-25_00-56-47_label_B2-0-0 101 200 705 760 2650 2709 2875 3120 3217 3469 4021 4265 4770 4820
360 | Shoot.Em.Up.2007__#00-56-56_00-57-15_label_B2-0-0 0 380
361 | Sin.City.2005__#0-06-42_0-07-06_label_B1-0-0 340 380
362 | Sin.City.2005__#0-10-40_0-11-40_label_B2-0-0 50 106 355 392 466 480 555 580 650 750
363 | Sin.City.2005__#0-22-04_0-22-18_label_B5-0-0 0 336
364 | Sin.City.2005__#0-29-34_0-29-56_label_B2-0-0 355 425 470 510
365 | Sin.City.2005__#0-32-00_0-32-16_label_B2-0-0 125 159 280 320
366 | Sin.City.2005__#0-43-45_0-44-01_label_B2-0-0 52 148
367 | Sin.City.2005__#01-19-08_01-19-23_label_G-0-0 75 150
368 | Sin.City.2005__#01-51-59_01-52-29_label_B2-0-0 35 80 555 610
369 | Skyfall.2012__#00-03-22_00-03-40_label_B6-0-0 180 432
370 | Skyfall.2012__#00-06-10_00-06-25_label_B6-0-0 165 255
371 | Skyfall.2012__#01-20-05_01-20-20_label_B1-B2-0 45 260
372 | Skyfall.2012__#01-55-44_01-56-57_label_B2-G-0 54 298 370 467 548 596 687 740 944 1076 1153 1182 1235 1449
373 | Skyfall.2012__#01-57-40_01-58-00_label_B2-0-0 0 168
374 | Skyfall.2012__#02-08-44_02-09-17_label_B1-0-0 78 655
375 | v=-kwNh1lMU-w__#1_label_B4-0-0 0 837 1360 4080
376 | v=_BgJEXQkjNQ__#1_label_G-0-0 1720 5145 5302 7938
377 | v=xe4ee56aHSg__#1_label_G-0-0 270 531
378 | v=2al4t5inObA__#1_label_B4-0-0 0 3322
379 | Spectre.2015__#00-44-05_00-44-40_label_B1-B2-0 38 217 254 329 539 817
380 | Spectre.2015__#01-08-58_01-09-20_label_B1-B2-0 50 100 295 320 330 355
381 | Spectre.2015__#01-10-35_01-10-53_label_B2-0-0 170 237 270 337
382 | Spectre.2015__#01-32-55_01-35-40_label_B1-B2-0 92 132 134 154 177 2870 2960 2980 3126 3743
383 | Spectre.2015__#02-02-57_02-03-27_label_B1-B2-0 384 430
384 | Spectre.2015__#02-15-30_02-16-05_label_G-0-0 430 571
385 | v=rXWOpZ7W2fA__#1_label_B1-0-0 1500 1570 1665 1800 1957 2137
386 | Taken.2.UNRATED.EXTENDED.2012__#00-13-42_00-14-16_label_B5-0-0 40 424
387 | Taken.2.UNRATED.EXTENDED.2012__#00-45-50_00-46-20_label_G-0-0 374 418
388 | Taken.2.UNRATED.EXTENDED.2012__#01-00-00_01-00-16_label_B2-0-0 327 368
389 | Taken.2.UNRATED.EXTENDED.2012__#01-06-28_01-06-49_label_B2-0-0 49 228
390 | Taken.2.UNRATED.EXTENDED.2012__#01-18-25_01-18-40_label_B2-0-0 49 260
391 | Taken.2.UNRATED.EXTENDED.2012__#01-18-40_01-19-00_label_B1-0-0 18 68
392 | Taken.3.2014__#00-03-58_00-04-16_label_B2-0-0 44 70
393 | Taken.3.2014__#01-04-50_01-05-05_label_B1-0-0 109 147
394 | Taken.3.2014__#01-12-48_01-14-30_label_B2-B1-0 47 585 685 1516 2060 2080
395 | Taken.3.2014__#01-19-13_01-19-30_label_B1-0-0 39 204
396 | Taken.Extended.Cut.2008__#00-33-51_00-34-03_label_B1-0-0 139 185 206 281
397 | Taken.Extended.Cut.2008__#00-50-45_00-51-04_label_B2-B6-0 103 190 430 455
398 | v=rKpXqwE2rg8__#00-01-57_00-03-06_label_B4-0-0 0 1657
399 | v=96N4XAJ7FKM__#1_label_B1-0-0 922 1936 2831 3281
400 | v=nLAapCIlr-o__#00-07-22_00-08-19_label_B6-0-0 85 140 399 650 921 1076
401 | v=nLAapCIlr-o__#00-08-35_00-10-25_label_B6-0-0 105 200 662 840 1295 1386 1852 1913 2106 2197 2590 2630
402 | v=nLAapCIlr-o__#00-11-00_00-12-45_label_B6-0-0 35 75 319 362 854 917 1523 1596 1747 1800 2013 2098 2417 2521
403 | v=nLAapCIlr-o__#00-16-10_00-18-09_label_B6-0-0 15 120 692 770 1337 1486 1874 2000 2240 2400 2793 2830
404 | v=osjdmjNJUdg__#1_label_B1-0-0 191 856 1380 1985
405 | The.Bourne.Identity.2002__#0-19-53_0-21-48_label_B1-0-0 1931 2115 2641 2750
406 | The.Bourne.Identity.2002__#0-54-45_0-56-25_label_B6-0-0 504 590 633 745 872 945 2052 2106
407 | The.Bourne.Identity.2002__#01-27-39_01-29-19_label_B2-0-0 530 560 1455 1514 1794 1830
408 | The.Bourne.Identity.2002__#01-45-00_01-45-46_label_B1-0-0 240 296 525 559 1045 1086
409 | The.Bourne.Legacy.2012__#0-39-56_0-40-15_label_G-0-0 24 119
410 | The.Bourne.Legacy.2012__#0-59-18_01-01-52_label_B2-B1-0 46 879 946 1018 1111 1160 1212 1354 1426 1660 3300 3330 3340 3610
411 | The.Bourne.Ultimatum.2007__#01-26-27_01-28-03_label_B2-B6-0 475 569 740 900 983 1250 1531 1590 1754 2039
412 | The.Bourne.Ultimatum.2007__#01-44-02_01-44-11_label_B2-0-0 60 80
413 | The.Fast.and.the.Furious.2001__#00-52-18_00-53-36_label_B1-0-0 542 1033 1734 1800
414 | The.Fast.and.the.Furious.2001__#01-14-04_01-15-20_label_B1-0-0 715 979 1030 1072 1221 1250 1333 1460
415 | v=0W8LohxH9nI__#1_label_B1-0-0 933 1250
416 | The.Hurt.Locker.2008__#0-08-45_0-09-57_label_G-0-0 858 1560
417 | The.Hurt.Locker.2008__#0-19-22_0-22-32_label_B2-B1-0 2094 2125 2244 2270 2532 2575 3580 3750
418 | The.Hurt.Locker.2008__#01-02-10_01-03-55_label_B2-0-0 448 520 755 830 1225 1374 1968 1990
419 | The.Hurt.Locker.2008__#01-27-42_01-28-09_label_G-0-0 270 413
420 | v=7QOdBpDrmKk__#00-01-00_00-02-45_label_B6-0-0 32 59 240 281 771 797 1100 1130 1894 1988 2230 2270
421 | v=7QOdBpDrmKk__#00-10-30_00-10-51_label_B6-0-0 332 412
422 | v=k7R_Qo-BiAw__#00-10-30_00-10-51_label_B6-0-0 0 150 296 505
423 | The.World.Is.Not.Enough.1999__#00-07-04_00-07-45_label_G-B2-0 270 480 847 944
424 | The.World.Is.Not.Enough.1999__#00-10-22_00-10-40_label_G-0-0 6 65
425 | v=DRGMajXo_PI__#00-02-15_00-03-02_label_G-0-0 0 1111
426 | v=ONsmJAyFAAw__#1_label_G-0-0 0 1523
427 | v=Rc9EOFjtj0c__#1_label_B6-0-0 564 625 894 988 1307 1389 1615 1820 1852 1980 2263 2328 2701 2738 2901 2980 3109 3270 3301 3342
428 | v=FfrXebzAwOg__#00-06-10_00-07-19_label_B6-0-0 165 202 341 410 599 700 780 852 996 1022 1117 1181 1442 1516
429 | v=DtxU8UYiFws__#1_label_G-0-0 272 415
430 | Tropa.de.Elite.2.2010__#00-13-41_00-14-56_label_B4-0-0 0 535 1677 1735
431 | Tropa.de.Elite.2.2010__#01-33-01_01-33-30_label_B2-0-0 138 176 211 260
432 | Tropa.de.Elite.2.2010__#00-42-00_00-43-00_label_B2-0-0 777 900
433 | Tropa.de.Elite.2.2010__#00-46-11_00-46-50_label_B2-0-0 109 265
434 | v=Lt9M_tSij_4__#00-00-00_00-04-24_label_B6-0-0 170 361 502 657 910 1044 1173 1700 1883 2000 2103 2158 2573 2642 2982 3217 3350 3461 3653 3770 4092 4220 4483 4548 4859 4982 5187 5291 5515 5724 6289 6335
435 | v=BXR3d22BhHs__#00-06-00_00-07-00_label_B4-0-0 1 1425
436 | v=BXR3d22BhHs__#00-07-00_00-08-00_label_B4-0-0 1 1431
437 | v=tP19WuyY3IY__#00-00-00_00-00-51_label_B6-0-0 104 189 524 635 888 966
438 | v=tP19WuyY3IY__#00-06-10_00-07-51_label_B6-0-0 1243 1270 1631 1766
439 | v=YANmwpkuWRo__#1_label_G-0-0 968 1072
440 | v=Q-u3Y9TkZhQ__#1_label_B4-0-0 5 1031
441 | v=IWzI9V3WSnc__#1_label_B4-0-0 4 2687
442 | v=X6Tbn8_X-Os__#1_label_B1-0-0 200 914 1790 3450
443 | v=1GFBuaGCeTo__#1_label_G-0-0 802 1230
444 | v=GN-7_ye5k4E__#1_label_G-0-0 572 1300
445 | v=eqtJjxsTgtg__#1_label_G-0-0 343 480
446 | v=k98yVPXiYoU__#1_label_G-0-0 64 252
447 | v=bDm2-1NZBLw__#1_label_B4-0-0 1 2209
448 | v=2Rr21qkZEDQ__#00-03-00_00-04-09_label_B6-0-0 548 647 1300 1410
449 | v=2Rr21qkZEDQ__#00-14-10_00-16-09_label_B6-0-0 482 576 922 1032 1389 1444 2035 2160 2472 2670
450 | v=CcD-YNlkMPY__#1_label_G-0-0 1300 1356
451 | v=V7FtjJStY1M__#1_label_G-0-0 44 359
452 | v=v-Plzx73K68__#1_label_B4-0-0 316 921
453 | v=AF9p0YA0hyA__#00-00-45_00-01-44_label_B4-0-0 3 1417
454 | v=AF9p0YA0hyA__#00-03-03_00-04-20_label_B4-0-0 1 1841
455 | v=_5Kmb4tqMxs__#1_label_B4-0-0 23 1250
456 | v=TsuuRo6PEXM__#1_label_B1-0-0 307 1218 2306 3200
457 | wangted.2008__#0-22-55_0-23-40_label_B2-0-0 395 560 632 697
458 | wangted.2008__#0-45-25_0-47-01_label_B1-0-0 28 1245
459 | wangted.2008__#0-48-30_0-49-00_label_B2-0-0 346 511
460 | v=YhCSFZNwfHc__#00-01-15_00-01-42_label_G-B2-0 0 65 285 500
461 | v=_YobflFU_HU__#00-06-10_00-07-51_label_B6-0-0 0 65 300 439 630 712 1130 1240 1500 1620 1935 2058 2290 2398
462 | v=g2v3EkBj9Zc__#1_label_B4-0-0 0 530
463 | v=uprb6aBzymw__#00-03-03_00-05-13_label_B4-0-0 0 125 530 790 2045 2887
464 | Yellow.Sea.2010__#01-04-00_01-04-50_label_B1-0-0 0 356
465 | Yellow.Sea.2010__#01-05-40_01-06-47_label_B1-B2-0 140 407 735 1180
466 | Yellow.Sea.2010__#01-58-30_01-59-27_label_B5-0-0 0 787
467 | v=WiS3TIvykeY__#1_label_B4-0-0 0 1560
468 | v=UzxuX79xq4s__#1_label_B4-0-0 270 840 1312 2989 3430 4660
469 | v=lpkL0Y1MhA8__#1_label_B4-0-0 0 248 397 585 1165 1940 2393 2523
470 | v=1io9Uh54Vhg__#1_label_B4-0-0 0 1976
471 | v=m5Ya6z6qroo__#1_label_B4-0-0 138 5505
472 | Young.And.Dangerous.I.1996__#0-20-50_0-22-15_label_B1-0-0 425 1079
473 | Young.And.Dangerous.I.1996__#0-46-25_0-46-55_label_B6-0-0 350 570
474 | Young.And.Dangerous.I.1996__#0-46-57_0-48-57_label_B1-0-0 0 2454
475 | Young.And.Dangerous.III.1996__#00-23-00_00-25-54_label_B4-0-0 2497 3051
476 | Young.And.Dangerous.IV.1997__#01-07-50_01-09-45_label_B1-0-0 368 2734
477 | Young.And.Dangerous.IV.1997__#01-32-04_01-32-58_label_B1-0-0 650 1100
478 | Taken.3.2014__#01-09-26_01-09-54_label_G-0-0 22 258
479 | Bullet.in.the.Head.1990__#00-23-31_00-24-40_label_G-0-0 1325 1400
480 | v=DVdXoVUNkhg__#1_label_G-0-0 990 1145
481 | v=CczYDDDf22A__#1_label_G-0-0 190 240
482 | v=9CWJd1SezkA__#1_label_G-0-0 464 614
483 | v=8oTjTufJnXI__#1_label_G-0-0 977 1610
484 | Mission.Impossible.III.2006__#00-56-23_00-56-40_label_G-0-0 260 363
485 | v=ntaXjnpusxM__#1_label_G-0-0 312 500
486 | Saving.Private.Ryan.1998__#02-20-00_02-20-13_label_G-0-0 115 215
487 | Saving.Private.Ryan.1998__#02-16-58_02-17-13_label_G-0-0 147 300
488 | Love.Death.and.Robots.S01E13__#0-11-48_0-12-07_label_G-0-0 60 240
489 | v=35zynOrkvMk__#1_label_G-0-0 114 164
490 | v=1q5V6DKH3bw__#1_label_B4-0-0 0 425 680 1515 2030 2353 2500 2950
491 | v=3WmMhircZOc__#1_label_B4-0-0 580 4150
492 | v=-etV57xZ4_I__#1_label_B4-0-0 0 260
493 | v=tv0rI-5ycBU__#1_label_B4-0-0 0 130 206 2945
494 | v=2WkuPNLfl5s__#1_label_B4-0-0 225 1008
495 | v=CFfc4UWzVB8__#00-00-00_00-00-30_label_B4-0-0 508 581 605 721
496 | v=fQQzg2VfJME__#1_label_B4-0-0 117 1300
497 | v=0N0PbzjEg0U__#1_label_B4-0-0 0 550
498 | v=p7kvQ4OpDgU__#1_label_B4-0-0 270 1370
499 | v=GI49YSCruwY__#1_label_B4-0-0 0 1410 1795 1965 2320 2980
500 | City.of.God.2002__#01-24-10_01-25-10_label_B2-0-0 340 1090
--------------------------------------------------------------------------------
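The annotations.txt file above is the XD-Violence temporal ground truth read by the make_gt_xd.py / make_gt_mAP_xd.py scripts: each line starts with a video identifier whose name embeds a clip time window and a `label_X-Y-Z` class code, followed by an even count of frame indices that pair up into anomalous (start, end) segments. Below is a minimal parsing sketch under those assumptions; the `parse_annotations` helper is hypothetical rather than part of the repo, and the class-code glosses in the comment follow the usual XD-Violence conventions and should be double-checked against the dataset:

def parse_annotations(path='list/annotations.txt'):
    """Yield (video_id, labels, segments) for each annotation line."""
    # Class codes per XD-Violence convention (assumed): B1 fighting,
    # B2 shooting, B4 riot, B5 abuse, B6 car accident, G explosion;
    # '0' marks an unused label slot.
    with open(path) as f:
        for line in f:
            parts = line.split()
            if not parts:
                continue
            video_id, frames = parts[0], [int(x) for x in parts[1:]]
            # Up to three class codes sit after '_label_' in the id;
            # a few entries carry a stray '.mp4' suffix, stripped here.
            codes = video_id.split('_label_')[1].replace('.mp4', '')
            labels = [c for c in codes.split('-') if c != '0']
            # Frame indices pair up into anomalous (start, end) spans.
            segments = list(zip(frames[0::2], frames[1::2]))
            yield video_id, labels, segments

For example, the entry for v=wVey5JDRf_g__#00-00-00_00-01-20_label_B6-0-0 above would yield labels ['B6'] and eight (start, end) frame segments.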