├── README.md
├── dataset.py
├── extract.py
├── main.py
├── model.py
├── result
│   ├── ambiguous.txt
│   ├── model.png
│   └── vis.png
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # ACRNet
2 |
3 | A PyTorch implementation of ACRNet based on the ICME 2023 paper
4 | [Weakly-supervised Temporal Action Localization with Adaptive Clustering and Refining Network](https://ieeexplore.ieee.org/abstract/document/10219653).
5 |
6 | ![Network Architecture](result/model.png)
7 |
8 | ## Requirements
9 |
10 | - [Anaconda](https://www.anaconda.com/download/)
11 | - [PyTorch](https://pytorch.org)
12 |
13 | ```
14 | conda install pytorch torchvision torchaudio pytorch-cuda=11.6 -c pytorch -c nvidia
15 | ```
16 |
17 | - [MMAction2](https://mmaction2.readthedocs.io)
18 |
19 | ```
20 | pip install openmim
21 | mim install mmaction2 -f https://github.com/open-mmlab/mmaction2.git
22 | ```
23 |
24 | ## Datasets
25 |
26 | [THUMOS 14](http://crcv.ucf.edu/THUMOS14/download.html) and [ActivityNet](http://activity-net.org/download.html)
27 | datasets are used in this repo; you should download them from the official websites. The RGB and Flow features of
28 | these datasets are extracted by [dataset.py](dataset.py) at `25 FPS`. You should follow
29 | [this link](https://gist.github.com/raulqf/f42c718a658cddc16f9df07ecc627be7) to install OpenCV4 with CUDA, then
30 | compile [denseFlow_GPU](https://github.com/daveboat/denseFlow_GPU) and put the executable in this directory. The
31 | options can be found in [dataset.py](dataset.py); note that this script takes a long time to extract the features.
32 | Finally, the I3D features of these datasets are extracted by
33 | [this repo](https://github.com/Finspire13/pytorch-i3d-feature-extraction), whose `extract_features.py` file should be
34 | replaced with [extract.py](extract.py); the options can be found in [extract.py](extract.py). To make this research
35 | easier to reproduce, we uploaded these I3D features to [MEGA](https://mega.nz/folder/6sFxjaZB#Jtx69Kb2RHu2ldXoNzsODQ).
36 | You can download them from there, and make sure the data directory structure is organized as follows:
37 |
38 | ```
39 | ├── thumos14                                          |  ├── activitynet
40 | │   ├── features                                      |  │   ├── features
41 | │   │   ├── val                                       |  │   │   ├── training
42 | │   │   │   ├── video_validation_0000051_flow.npy     |  │   │   │   ├── v___c8enCfzqw_flow.npy
43 | │   │   │   ├── video_validation_0000051_rgb.npy      |  │   │   │   ├── v___c8enCfzqw_rgb.npy
44 | │   │   │   └── ...                                   |  │   │   │   └── ...
45 | │   │   ├── test                                      |  │   │   ├── validation
46 | │   │   │   ├── video_test_0000004_flow.npy           |  │   │   │   ├── v__1vYKA7mNLI_flow.npy
47 | │   │   │   ├── video_test_0000004_rgb.npy            |  │   │   │   ├── v__1vYKA7mNLI_rgb.npy
48 | │   │   │   └── ...                                   |  │   │   │   └── ...
49 | │   ├── videos                                        |  │   ├── videos
50 | │   │   ├── val                                       |  │   │   ├── training
51 | │   │   │   ├── video_validation_0000051.mp4          |  │   │   │   ├── v___c8enCfzqw.mp4
52 | │   │   │   └── ...                                   |  │   │   │   └── ...
53 | │   │   ├── test                                      |  │   │   ├── validation
54 | │   │   │   ├── video_test_0000004.mp4                |  │   │   │   ├── v__1vYKA7mNLI.mp4
55 | │   │   │   └── ...                                   |  │   │   │   └── ...
56 | │   └── annotations.json                              |  │   └── annotations_1.2.json, annotations_1.3.json
57 | ```
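
For reference, a sketch of the two-step extraction with the scripts above is shown below. The `--data_root` and `--fps`
values are the defaults from [dataset.py](dataset.py); the I3D checkpoint and the frame/feature directories are
placeholders that you should adapt to your own setup:

```
# step 1: resample videos to 25 FPS and compute optical flow with denseFlow_GPU (repeat for each split)
python dataset.py --data_root /home/data --dataset thumos14 --data_split val --fps 25
# step 2: extract I3D features from the decoded frames (run again with --mode flow)
python extract.py --mode rgb --load_model <i3d-rgb-checkpoint> --input_dir <frames-dir> --output_dir <features-dir>
```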
58 |
59 | ## Usage
60 |
61 | You can easily train and test the model by running the commands below. If you want to try other options, please
62 | refer to [utils.py](utils.py).
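
For instance, the thresholds and optimization settings can be overridden from the command line; the flags below (shown
with their default values) are all defined in [utils.py](utils.py):

```
python main.py --data_name thumos14 --cls_th 0.1 --iou_th 0.1 --alpha 0.5 --init_lr 1e-4 --num_iter 2000
```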
63 |
64 | ### Train Model
65 |
66 | ```
67 | python main.py --data_name activitynet1.2 --num_seg 80 --seed 42
68 | ```
69 |
70 | ### Test Model
71 |
72 | ```
73 | python main.py --data_name thumos14 --model_file result/thumos14.pth
74 | ```
75 |
76 | ## Benchmarks
77 |
78 | The models are trained on one NVIDIA GeForce RTX 3090 GPU (24G). `seed` is `42` for all datasets; for both the
79 | `activitynet1.2` and `activitynet1.3` datasets, `num_seg` is `80`, `alpha` is `0.8` and `batch_size` is `128`; the
80 | other hyper-parameters keep their default values.
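
For example, the ActivityNet entries below were obtained with a training command along these lines (for THUMOS14, only
the seed deviates from the defaults):

```
python main.py --data_name activitynet1.3 --num_seg 80 --alpha 0.8 --batch_size 128 --seed 42
```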
81 |
82 | ### THUMOS14
83 |
84 | | Method | mAP@0.1 | mAP@0.2 | mAP@0.3 | mAP@0.4 | mAP@0.5 | mAP@0.6 | mAP@0.7 | mAP@AVG | Download |
85 | |:------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:-------:|:--------:|
86 | | ACRNet |  76.7   |  70.7   |  61.0   |  49.0   |  37.0   |  24.8   |  13.4   |  47.5   |   MEGA   |
87 |
118 | mAP@AVG is the average mAP under the thresholds [0.1:0.1:0.7].
119 |
120 | ### ActivityNet
121 |
122 | | Method | mAP@0.5 (1.2) | mAP@0.75 (1.2) | mAP@0.95 (1.2) | mAP@AVG (1.2) | mAP@0.5 (1.3) | mAP@0.75 (1.3) | mAP@0.95 (1.3) | mAP@AVG (1.3) | Download |
123 | |:------:|:-------------:|:--------------:|:--------------:|:-------------:|:-------------:|:--------------:|:--------------:|:-------------:|:--------:|
124 | | ACRNet |     46.2      |      28.4      |      5.7       |     28.4      |     40.9      |      26.0      |      5.4       |     25.7      |   MEGA   |
125 |
157 | mAP@AVG is the average mAP under the thresholds [0.5:0.05:0.95].
158 |
159 | ## Results
160 |
161 | ![Visualization](result/vis.png)
162 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import json
4 | import os
5 | import subprocess
6 |
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import Dataset
10 |
11 | from utils import which_ffmpeg
12 |
13 |
14 | class VideoDataset(Dataset):
15 | def __init__(self, data_path, data_name, data_type, num_seg, length=None):
16 |
17 | self.data_name, self.data_type, self.num_seg = data_name, data_type, num_seg
18 |
19 | # prepare annotations
20 | if data_name == 'thumos14':
21 | data_type = 'val' if data_type == 'train' else 'test'
22 | label_name = 'annotations.json'
23 | else:
24 | data_type = 'training' if data_type == 'train' else 'validation'
25 | label_name = 'annotations_{}.json'.format(data_name[-3:])
26 | data_name = data_name[:-3]
27 | with open(os.path.join(data_path, data_name, label_name), 'r') as f:
28 | annotations = json.load(f)
29 |
30 | # prepare data
31 | self.rgb, self.flow, self.annotations, classes = [], [], {}, set()
32 | self.names, self.labels, self.class_to_idx, self.idx_to_class = [], [], {}, {}
33 | for key, value in annotations.items():
34 | if value['subset'] == data_type:
35 | # ref: Weakly-supervised Temporal Action Localization by Uncertainty Modeling (AAAI 2021)
36 | if key in ['video_test_0000270', 'video_test_0001292', 'video_test_0001496']:
37 | continue
38 | rgb = np.load('{}/{}/features/{}/{}_rgb.npy'.format(data_path, data_name, data_type, key))
39 | flow = np.load('{}/{}/features/{}/{}_flow.npy'.format(data_path, data_name, data_type, key))
40 | # ref: Cross-modal Consensus Network for Weakly Supervised Temporal Action Localization (ACM MM 2021)
41 | if len(rgb) <= 10:
42 | continue
43 | self.rgb.append(rgb)
44 | self.flow.append(flow)
45 | self.names.append(key)
46 | # the prefix is added to be compatible with the ActivityNetLocalization class
47 | self.annotations['d_{}'.format(key)] = {'annotations': value['annotations']}
48 | for annotation in value['annotations']:
49 | classes.add(annotation['label'])
50 | for i, key in enumerate(sorted(classes)):
51 | self.class_to_idx[key] = i
52 | self.idx_to_class[i] = key
53 | for i in range(len(self.rgb)):
54 | label = np.zeros(len(classes), dtype=np.float32)
55 | for item in annotations[self.names[i]]['annotations']:
56 | label[self.class_to_idx[item['label']]] = 1
57 | self.labels.append(label)
58 |
59 | # for training, the epoch length is the given length; for testing, it is the real number of videos
60 | self.num = len(self.rgb)
61 | self.sample_num = length if self.data_type == 'train' else self.num
62 |
63 | def __len__(self):
64 | return self.sample_num
65 |
66 | def __getitem__(self, index):
67 | rgb, flow = self.rgb[index % self.num], self.flow[index % self.num]
68 | video_key, label, num_seg = self.names[index % self.num], self.labels[index % self.num], len(rgb)
69 | sample_idx = self.random_sampling(num_seg) if self.data_type == 'train' else self.uniform_sampling(num_seg)
70 | rgb, flow = torch.from_numpy(rgb[sample_idx]), torch.from_numpy(flow[sample_idx])
71 | feat, label = torch.cat((rgb, flow), dim=-1), torch.from_numpy(label)
72 | return feat, label, video_key, num_seg
73 |
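# training-time sampling: divide the video into self.num_seg evenly spaced bins and randomly pick one segment index from each bin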
74 | def random_sampling(self, num_seg):
75 | sample_idx = np.append(np.arange(self.num_seg) * num_seg / self.num_seg, num_seg)
76 | for i in range(self.num_seg):
77 | if int(sample_idx[i]) == int(sample_idx[i + 1]):
78 | sample_idx[i] = int(sample_idx[i])
79 | else:
80 | sample_idx[i] = np.random.randint(int(sample_idx[i]), int(sample_idx[i + 1]))
81 | return sample_idx[:-1].astype(np.int64)
82 |
83 | def uniform_sampling(self, num_seg):
84 | # the number of sampled segments may differ across videos in these two branches, so make sure batch size == 1 in test mode
85 | if num_seg <= self.num_seg:
86 | return np.arange(num_seg).astype(np.int64)
87 | else:
88 | return np.floor(np.arange(self.num_seg) * num_seg / self.num_seg).astype(np.int64)
89 |
90 |
91 | if __name__ == '__main__':
92 | description = 'Extract the RGB and Flow features from videos with assigned fps'
93 | parser = argparse.ArgumentParser(description=description)
94 | parser.add_argument('--data_root', type=str, default='/home/data')
95 | parser.add_argument('--save_path', type=str, default='result')
96 | parser.add_argument('--dataset', type=str, default='thumos14', choices=['thumos14', 'activitynet'])
97 | parser.add_argument('--fps', type=int, default=25)
98 | parser.add_argument('--data_split', type=str, required=True)
99 | args = parser.parse_args()
100 |
101 | data_root, save_path, dataset, data_split = args.data_root, args.save_path, args.dataset, args.data_split
102 | fps, ffmpeg_path = args.fps, which_ffmpeg()
103 | videos = sorted(glob.glob('{}/{}/videos/{}/*'.format(data_root, dataset, data_split)))
104 | total = len(videos)
105 |
106 | for j, video_path in enumerate(videos):
107 | dir_name, video_name = os.path.dirname(video_path).split('/')[-1], os.path.basename(video_path).split('.')[0]
108 | save_root = '{}/{}/{}/{}'.format(save_path, dataset, dir_name, video_name)
109 | # skip the already processed videos
110 | try:
111 | os.makedirs(save_root)
112 | except OSError:
113 | continue
114 | print('[{}/{}] Saving {} to {}/{}.mp4 with {} fps'.format(j + 1, total, video_path, save_root, video_name, fps))
115 | ffmpeg_cmd = '{} -hide_banner -loglevel panic -i {} -r {} -y {}/{}.mp4' \
116 | .format(ffmpeg_path, video_path, fps, save_root, video_name)
117 | subprocess.call(ffmpeg_cmd.split())
118 | flow_cmd = './denseFlow_gpu -f={}/{}.mp4 -o={}'.format(save_root, video_name, save_root)
119 | subprocess.call(flow_cmd.split())
120 |
--------------------------------------------------------------------------------
/extract.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 | from pytorch_i3d import InceptionI3d
8 | from torchvision.transforms import CenterCrop
9 |
10 |
11 | def load_frame(frame_file):
12 | data = Image.open(frame_file)
13 | assert (min(data.size) == 256)
14 | data = CenterCrop(size=224)(data)
15 | data = np.array(data, dtype=np.float32)
16 | data = (data * 2 / 255) - 1
17 |
18 | assert (data.max() <= 1.0)
19 | assert (data.min() >= -1.0)
20 |
21 | return data
22 |
23 |
24 | def load_rgb_batch(frames_dir, rgb_files, frame_indices):
25 | batch_data = np.zeros(frame_indices.shape + (224, 224, 3), dtype=np.float32)
26 | for i in range(frame_indices.shape[0]):
27 | for j in range(frame_indices.shape[1]):
28 | batch_data[i, j, :, :, :] = load_frame(os.path.join(frames_dir, 'rgb', rgb_files[frame_indices[i][j]]))
29 |
30 | return batch_data
31 |
32 |
33 | def load_flow_batch(frames_dir, flow_x_files, flow_y_files, frame_indices):
34 | batch_data = np.zeros(frame_indices.shape + (224, 224, 2), dtype=np.float32)
35 | for i in range(frame_indices.shape[0]):
36 | for j in range(frame_indices.shape[1]):
37 | batch_data[i, j, :, :, 0] = load_frame(os.path.join(frames_dir, 'flow_x',
38 | flow_x_files[frame_indices[i][j]]))
39 | batch_data[i, j, :, :, 1] = load_frame(os.path.join(frames_dir, 'flow_y',
40 | flow_y_files[frame_indices[i][j]]))
41 |
42 | return batch_data
43 |
44 |
45 | def run(mode='rgb', load_model='', frequency=16, chunk_size=16, input_dir='', output_dir='', batch_size=40):
46 | # setup the model
47 | if mode == 'flow':
48 | i3d = InceptionI3d(400, in_channels=2)
49 | else:
50 | i3d = InceptionI3d(400, in_channels=3)
51 |
52 | i3d.load_state_dict(torch.load(load_model))
53 | i3d.cuda()
54 | # set model to evaluate mode
55 | i3d.eval()
56 |
57 | video_names = [i for i in os.listdir(input_dir)]
58 | for pro_i, video_name in enumerate(video_names):
59 | save_file = '{}_{}.npy'.format(video_name, mode)
60 | if save_file in os.listdir(output_dir):
61 | continue
62 |
63 | frames_dir = os.path.join(input_dir, video_name)
64 | if mode == 'rgb':
65 | rgb_files = [i for i in os.listdir(os.path.join(frames_dir, 'rgb'))]
66 | rgb_files.sort(key=lambda x: int(x.split('.')[0]))
67 | frame_cnt = len(rgb_files)
68 | else:
69 | flow_x_files = [i for i in os.listdir(os.path.join(frames_dir, 'flow_x'))]
70 | flow_y_files = [i for i in os.listdir(os.path.join(frames_dir, 'flow_y'))]
71 | flow_x_files.sort(key=lambda x: int(x.split('.')[0]))
72 | flow_y_files.sort(key=lambda x: int(x.split('.')[0]))
73 | assert (len(flow_y_files) == len(flow_x_files))
74 | frame_cnt = len(flow_y_files)
75 |
76 | # cut frames
77 | assert (frame_cnt > chunk_size)
78 | clipped_length = frame_cnt - chunk_size
79 | # the start of last chunk
80 | clipped_length = (clipped_length // frequency) * frequency
81 | # frames to chunks
82 | frame_indices = []
83 | for i in range(clipped_length // frequency + 1):
84 | frame_indices.append([j for j in range(i * frequency, i * frequency + chunk_size)])
85 | frame_indices = np.array(frame_indices)
86 |
87 | chunk_num = frame_indices.shape[0]
88 | # chunks to batches
89 | batch_num = int(np.ceil(chunk_num / batch_size))
90 | frame_indices = np.array_split(frame_indices, batch_num, axis=0)
91 |
92 | full_features = []
93 | for batch_id in range(batch_num):
94 | if mode == 'rgb':
95 | batch_data = load_rgb_batch(frames_dir, rgb_files, frame_indices[batch_id])
96 | else:
97 | batch_data = load_flow_batch(frames_dir, flow_x_files, flow_y_files, frame_indices[batch_id])
98 | with torch.no_grad():
99 | # [b, c, t, h, w]
100 | batch_data = torch.from_numpy(batch_data.transpose([0, 4, 1, 2, 3])).cuda()
101 | batch_feature = i3d.extract_features(batch_data)
102 | batch_feature = torch.flatten(batch_feature, start_dim=1).cpu().numpy()
103 | full_features.append(batch_feature)
104 |
105 | full_features = np.concatenate(full_features, axis=0)
106 | np.save(os.path.join(output_dir, save_file), full_features)
107 | print('[{}/{}] {} done: {} / {}, {}'.format(pro_i + 1, len(video_names), video_name, frame_cnt,
108 | clipped_length, full_features.shape))
109 |
110 |
111 | if __name__ == '__main__':
112 | parser = argparse.ArgumentParser()
113 | parser.add_argument('--mode', type=str, required=True, choices=['rgb', 'flow'])
114 | parser.add_argument('--load_model', type=str, required=True)
115 | parser.add_argument('--input_dir', type=str, default='data')
116 | parser.add_argument('--output_dir', type=str, default='result')
117 | parser.add_argument('--batch_size', type=int, default=40)
118 | parser.add_argument('--frequency', type=int, default=16)
119 | parser.add_argument('--chunk_size', type=int, default=16)
120 | args = parser.parse_args()
121 |
122 | run(mode=args.mode, load_model=args.load_model, frequency=args.frequency, chunk_size=args.chunk_size,
123 | input_dir=args.input_dir, output_dir=args.output_dir, batch_size=args.batch_size)
124 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | import numpy as np
4 | import pandas as pd
5 | import torch
6 | from mmaction.core.evaluation import ActivityNetLocalization
7 | from mmaction.localization import soft_nms
8 | from torch.optim import Adam
9 | from torch.optim.lr_scheduler import CosineAnnealingLR
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 |
13 | from dataset import VideoDataset
14 | from model import Model, cross_entropy, generalized_cross_entropy
15 | from utils import parse_args, oic_score, revert_frame, grouping, result2json, filter_results
16 |
17 |
18 | def test_loop(net, data_loader, num_iter):
19 | net.eval()
20 | results, num_correct, num_total, test_info = {'results': {}}, 0, 0, {}
21 | with torch.no_grad():
22 | for data, gt, video_name, num_seg in tqdm(data_loader, initial=1, dynamic_ncols=True):
23 | data, gt = data.cuda(non_blocking=True), gt.squeeze(0).cuda(non_blocking=True)
24 | video_name, num_seg = video_name[0], num_seg.squeeze(0)
25 | act_score, _, seg_score, _, _, _ = net(data)
26 | # [C], [T, C]
27 | act_score, seg_score = act_score.squeeze(0), seg_score.squeeze(0)
28 |
29 | pred = torch.ge(act_score, args.cls_th)
30 | # make sure at least one prediction
31 | if torch.sum(pred) == 0:
32 | pred[torch.argmax(act_score, dim=-1)] = True
33 | num_correct += 1 if torch.equal(gt, pred.float()) else 0
34 | num_total += 1
35 |
36 | frame_score = revert_frame(seg_score.cpu().numpy(), args.rate * num_seg.item())
37 | # make sure the scores are in [0, 1]
38 | frame_score = np.clip(frame_score, a_min=0.0, a_max=1.0)
39 |
40 | proposal_dict = {}
41 | for i, status in enumerate(pred):
42 | if status:
43 | # enrich the proposal pool by using multiple thresholds
44 | for threshold in args.act_th:
45 | proposals = grouping(np.where(frame_score[:, i] >= threshold)[0])
46 | # make sure each proposal covers a region of at least two frames
47 | for proposal in proposals:
48 | if len(proposal) >= 2:
49 | if i not in proposal_dict:
50 | proposal_dict[i] = []
51 | score = oic_score(frame_score[:, i], act_score[i].cpu().numpy(), proposal)
52 | # change frame index to second
53 | start, end = (proposal[0] + 1) / args.fps, (proposal[-1] + 2) / args.fps
54 | proposal_dict[i].append([start, end, score])
55 | # temporal soft nms
56 | # ref: BSN: Boundary Sensitive Network for Temporal Action Proposal Generation (ECCV 2018)
57 | if i in proposal_dict:
58 | proposal_dict[i] = soft_nms(np.array(proposal_dict[i]), args.alpha, args.iou_th, args.iou_th,
59 | top_k=len(proposal_dict[i])).tolist()
60 |
61 | results['results'][video_name] = result2json(proposal_dict, data_loader.dataset.idx_to_class)
62 |
63 | test_acc = num_correct / num_total
64 |
65 | if args.data_name == 'thumos14':
66 | results = filter_results(results, 'result/ambiguous.txt')
67 | gt_path = '{}/{}_gt.json'.format(args.save_path, args.data_name)
68 | with open(gt_path, 'w') as json_file:
69 | json.dump(data_loader.dataset.annotations, json_file, indent=4)
70 | pred_path = '{}/{}_pred.json'.format(args.save_path, args.data_name)
71 | with open(pred_path, 'w') as json_file:
72 | json.dump(results, json_file, indent=4)
73 |
74 | # evaluate the metrics
75 | evaluator_atl = ActivityNetLocalization(gt_path, pred_path, tiou_thresholds=args.map_th, verbose=False)
76 | m_ap, m_ap_avg = evaluator_atl.evaluate()
77 |
78 | desc = 'Test Step: [{}/{}] ACC: {:.1f} mAP@AVG: {:.1f}'.format(num_iter, args.num_iter, test_acc * 100,
79 | m_ap_avg * 100)
80 | test_info['Test ACC'] = round(test_acc * 100, 1)
81 | test_info['mAP@AVG'] = round(m_ap_avg * 100, 1)
82 | for i in range(args.map_th.shape[0]):
83 | desc += ' mAP@{:.2f}: {:.1f}'.format(args.map_th[i], m_ap[i] * 100)
84 | test_info['mAP@{:.2f}'.format(args.map_th[i])] = round(m_ap[i] * 100, 1)
85 | print(desc)
86 | return test_info
87 |
88 |
89 | def save_loop(net, data_loader, num_iter):
90 | global best_mAP
91 | test_info = test_loop(net, data_loader, num_iter)
92 | for key, value in test_info.items():
93 | if key not in metric_info:
94 | metric_info[key] = []
95 | metric_info[key].append('{:.3f}'.format(value))
96 |
97 | # save statistics
98 | data_frame = pd.DataFrame(data=metric_info, index=range(1, (num_iter if args.model_file
99 | else num_iter // args.eval_iter) + 1))
100 | data_frame.to_csv('{}/{}.csv'.format(args.save_path, args.data_name), index_label='Step', float_format='%.3f')
101 | if test_info['mAP@AVG'] > best_mAP:
102 | best_mAP = test_info['mAP@AVG']
103 | torch.save(net.state_dict(), '{}/{}.pth'.format(args.save_path, args.data_name))
104 |
105 |
106 | if __name__ == '__main__':
107 | args = parse_args()
108 | test_data = VideoDataset(args.data_path, args.data_name, 'test', args.num_seg)
109 | test_loader = DataLoader(test_data, 1, False, num_workers=args.workers, pin_memory=True)
110 |
111 | model = Model(len(test_data.class_to_idx)).cuda()
112 | best_mAP, metric_info = 0, {}
113 | if args.model_file:
114 | model.load_state_dict(torch.load(args.model_file))
115 | save_loop(model, test_loader, 1)
116 |
117 | else:
118 | model.train()
119 | train_data = VideoDataset(args.data_path, args.data_name, 'train', args.num_seg,
120 | args.batch_size * args.num_iter)
121 | train_loader = iter(DataLoader(train_data, args.batch_size, True, num_workers=args.workers, pin_memory=True))
122 | optimizer = Adam(model.parameters(), lr=args.init_lr, weight_decay=args.weight_decay)
123 | lr_scheduler = CosineAnnealingLR(optimizer, T_max=args.num_iter)
124 |
125 | total_loss, total_num, metric_info['Loss'] = 0.0, 0, []
126 | train_bar = tqdm(range(1, args.num_iter + 1), initial=1, dynamic_ncols=True)
127 | for step in train_bar:
128 | feat, label, _, _ = next(train_loader)
129 | feat, label = feat.cuda(non_blocking=True), label.cuda(non_blocking=True)
130 | act_score, bkg_score, seg_score, seg_mask, aas_rgb, aas_flow = model(feat)
131 | cas_loss = cross_entropy(act_score, bkg_score, label)
132 | aas_rgb_loss = generalized_cross_entropy(aas_rgb, seg_mask, label)
133 | aas_flow_loss = generalized_cross_entropy(aas_flow, seg_mask, label)
134 | loss = cas_loss + args.lamda * (aas_rgb_loss + aas_flow_loss)
135 | optimizer.zero_grad()
136 | loss.backward()
137 | optimizer.step()
138 |
139 | total_num += feat.size(0)
140 | total_loss += loss.item() * feat.size(0)
141 | train_bar.set_description('Train Step: [{}/{}] Loss: {:.3f}'
142 | .format(step, args.num_iter, total_loss / total_num))
143 | lr_scheduler.step()
144 | if step % args.eval_iter == 0:
145 | metric_info['Loss'].append('{:.3f}'.format(total_loss / total_num))
146 | save_loop(model, test_loader, step)
147 | model.train()
148 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def weights_init(m):
8 | classname = m.__class__.__name__
9 | if classname.find('Conv') != -1 or classname.find('Linear') != -1:
10 | torch.nn.init.kaiming_uniform_(m.weight)
11 | if m.bias is not None:
12 | m.bias.data.fill_(0)
13 |
14 |
15 | # ref: Weakly-supervised Temporal Action Localization with Multi-head Cross-modal Attention (PRICAI 2022)
16 | class MCA(nn.Module):
17 | def __init__(self, feat_dim, num_head=4):
18 | super(MCA, self).__init__()
19 | self.rgb_proj = nn.Parameter(torch.empty(num_head, feat_dim, feat_dim // num_head))
20 | self.flow_proj = nn.Parameter(torch.empty(num_head, feat_dim, feat_dim // num_head))
21 | self.atte = nn.Parameter(torch.empty(num_head, feat_dim // num_head, feat_dim // num_head))
22 |
23 | nn.init.uniform_(self.rgb_proj, -math.sqrt(feat_dim), math.sqrt(feat_dim))
24 | nn.init.uniform_(self.flow_proj, -math.sqrt(feat_dim), math.sqrt(feat_dim))
25 | nn.init.uniform_(self.atte, -math.sqrt(feat_dim // num_head), math.sqrt(feat_dim // num_head))
26 | self.num_head = num_head
27 |
28 | def forward(self, rgb, flow):
29 | rgb, flow = rgb.mT.contiguous(), flow.mT.contiguous()
30 | n, t, d = rgb.shape
31 | # [N, H, T, D/H]
32 | o_rgb = F.normalize(torch.matmul(rgb.unsqueeze(dim=1), self.rgb_proj), dim=-1)
33 | o_flow = F.normalize(torch.matmul(flow.unsqueeze(dim=1), self.flow_proj), dim=-1)
34 | # [N, H, T, T]
35 | atte = torch.matmul(torch.matmul(o_rgb, self.atte), o_flow.mT.contiguous())
36 | rgb_atte = torch.softmax(atte, dim=-1)
37 | flow_atte = torch.softmax(atte.mT.contiguous(), dim=-1)
38 |
39 | # [N, H, T, D/H]
40 | e_rgb = F.gelu(torch.matmul(rgb_atte, o_rgb))
41 | e_flow = F.gelu(torch.matmul(flow_atte, o_flow))
42 | # [N, T, D]
43 | f_rgb = torch.tanh(e_rgb.mT.reshape(n, t, -1).contiguous() + rgb)
44 | f_flow = torch.tanh(e_flow.mT.reshape(n, t, -1).contiguous() + flow)
45 |
46 | f_rgb, f_flow = f_rgb.mT.contiguous(), f_flow.mT.contiguous()
47 | return f_rgb, f_flow
48 |
49 |
50 | # ref: Dual-Evidential Learning for Weakly-supervised Temporal Action Localization (ECCV 2022)
51 | class BWA(nn.Module):
52 | def __init__(self, feat_dim):
53 | super(BWA, self).__init__()
54 | self.attn = nn.Sequential(nn.Conv1d(feat_dim, feat_dim, 3, padding=1), nn.LeakyReLU(0.2), nn.Dropout(0.5))
55 | self.conv = nn.Sequential(nn.Conv1d(feat_dim, feat_dim, 3, padding=1), nn.LeakyReLU(0.2), nn.Dropout(0.5))
56 | self.avg = nn.AdaptiveAvgPool1d(1)
57 |
58 | def forward(self, base_feat, ref_feat):
59 | channel_attn = self.conv(self.avg(base_feat))
60 | bit_attn = self.attn(ref_feat)
61 | filter_feat = torch.sigmoid(bit_attn * channel_attn) * base_feat
62 | return filter_feat
63 |
64 |
65 | class Model(nn.Module):
66 | def __init__(self, num_classes):
67 | super(Model, self).__init__()
68 |
69 | self.mca = MCA(1024)
70 | self.rgb_bwa = BWA(1024)
71 | self.flow_bwa = BWA(1024)
72 | self.cas_rgb_encoder = nn.Sequential(nn.Conv1d(1024, 1024, 3, padding=1), nn.ReLU(),
73 | nn.Conv1d(1024, num_classes, kernel_size=1))
74 | self.cas_flow_encoder = nn.Sequential(nn.Conv1d(1024, 1024, 3, padding=1), nn.ReLU(),
75 | nn.Conv1d(1024, num_classes, kernel_size=1))
76 |
77 | self.aas_rgb_encoder = nn.Sequential(nn.Conv1d(1024, 512, 1), nn.ReLU(), nn.Conv1d(512, 1, 1))
78 | self.aas_flow_encoder = nn.Sequential(nn.Conv1d(1024, 512, 1), nn.ReLU(), nn.Conv1d(512, 1, 1))
79 |
80 | # ref: A Hybrid Attention Mechanism for Weakly-Supervised Temporal Action Localization (AAAI 2021)
81 | if num_classes != 20:
82 | pool = nn.AvgPool1d(13, 1, padding=6, count_include_pad=True)
83 | self.cas_rgb_encoder.append(pool)
84 | self.cas_flow_encoder.append(pool)
85 | self.aas_rgb_encoder.append(pool)
86 | self.aas_flow_encoder.append(pool)
87 |
88 | self.apply(weights_init)
89 |
90 | def forward(self, x):
91 | # [N, D, T]
92 | x = x.mT.contiguous()
93 | rgb, flow = self.mca(x[:, :1024, :], x[:, 1024:, :])
94 | rgb, flow = self.rgb_bwa(rgb, flow), self.flow_bwa(flow, rgb)
95 |
96 | # [N, T, C], class activation sequence
97 | cas_rgb = self.cas_rgb_encoder(rgb).mT.contiguous()
98 | cas_flow = self.cas_flow_encoder(flow).mT.contiguous()
99 | cas = cas_rgb + cas_flow
100 | cas_score = torch.softmax(cas, dim=-1)
101 | # [N, T, 1], action activation sequence
102 | aas_rgb = torch.sigmoid(self.aas_rgb_encoder(rgb).mT.contiguous())
103 | aas_flow = torch.sigmoid(self.aas_flow_encoder(flow).mT.contiguous())
104 | aas_score = (aas_rgb + aas_flow) / 2
105 | # [N, T, C]
106 | seg_score = (cas_score + aas_score) / 2
107 | seg_mask = temporal_clustering(seg_score)
108 | seg_mask = mask_refining(seg_score, seg_mask, cas)
109 |
110 | # [N, C]
111 | act_score, bkg_score = calculate_score(seg_score, seg_mask, cas)
112 | return act_score, bkg_score, seg_score, seg_mask, aas_rgb, aas_flow
113 |
114 |
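# adaptive temporal clustering: visit the segments of each (video, class) pair in descending score order and assign
# each one to the action or background cluster, whichever running center its score is closer to; the action center is
# updated with a rank-decayed weight while the background center is a plain running mean, yielding a binary segment mask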
115 | def temporal_clustering(seg_score):
116 | n, t, c = seg_score.shape
117 | # [N*C, T]
118 | seg_score = seg_score.mT.contiguous().view(-1, t)
119 | sort_value, sort_index = torch.sort(seg_score, dim=-1, descending=True, stable=True)
120 | mask = torch.zeros_like(seg_score)
121 | row_index = torch.arange(mask.shape[0], device=mask.device)
122 | # the index of the largest value is inited as positive
123 | mask[row_index, sort_index[:, 0]] = 1
124 | # [N*C]
125 | pos_sum, neg_sum = sort_value[:, 0], sort_value[:, -1]
126 | pos_num, neg_num = torch.ones_like(pos_sum), torch.ones_like(neg_sum)
127 | for i in range(1, t - 1):
128 | pos_center = pos_sum / pos_num
129 | neg_center = neg_sum / neg_num
130 | index, value = sort_index[:, i], sort_value[:, i]
131 | pos_distance = torch.abs(value - pos_center)
132 | neg_distance = torch.abs(value - neg_center)
133 | condition = torch.le(pos_distance, neg_distance)
134 | pos_list = torch.where(condition, value, torch.zeros_like(value))
135 | neg_list = torch.where(~condition, value, torch.zeros_like(value))
136 | # update centers
137 | pos_num = pos_num + condition.float() / (i + 1)
138 | pos_sum = pos_sum + pos_list / (i + 1)
139 | neg_num = neg_num + (~condition).float()
140 | neg_sum = neg_sum + neg_list
141 | # update mask
142 | mask[row_index, index] = condition.float()
143 | # [N, T, C]
144 | mask = mask.view(n, c, t).mT.contiguous()
145 | return mask
146 |
147 |
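# mask refining: for each class, compute a rank-weighted video-level score from the masked CAS, locate the sorted
# segment score at which the running mean of the masked CAS drops below that video-level score, and use it as a
# threshold so that only mask entries whose segment score reaches that value are kept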
148 | def mask_refining(seg_score, seg_mask, cas):
149 | n, t, c = seg_score.shape
150 | sort_value, sort_index = torch.sort(seg_score, dim=1, descending=True, stable=True)
151 | # [N, T]
152 | ranks = torch.arange(2, t + 2, device=seg_score.device).reciprocal().view(1, -1).expand(n, -1).contiguous()
153 | row_index = torch.arange(n, device=seg_score.device).view(-1, 1).expand(-1, t).contiguous()
154 | # [N, C]
155 | act_score = torch.zeros(n, c, device=seg_score.device)
156 | mean_score = torch.zeros(n, c, device=seg_score.device)
157 |
158 | for i in range(c):
159 | # [N, T]
160 | index, value = sort_index[:, :, i], sort_value[:, :, i]
161 | mask = seg_mask[:, :, i][row_index, index]
162 | cs = cas[:, :, i][row_index, index]
163 | rank = ranks * mask
164 | # [N]
165 | tmp_score = (cs * rank).sum(dim=-1) / torch.clamp_min(rank.sum(dim=-1), 1.0)
166 | act_score[:, i] = tmp_score
167 | for j in range(n):
168 | ref_score = tmp_score[j]
169 | ref_val = cs[j][mask[j].bool()]
170 | sort_val = value[j][mask[j].bool()]
171 | if ref_val.shape[0] > 0:
172 | cum_cnts = torch.arange(1, mask[j].sum() + 1, device=seg_score.device)
173 | cum_scores = torch.cumsum(ref_val, dim=-1) / cum_cnts
174 | tmp_mask = torch.ge(cum_scores, ref_score).long()
175 | mean_score[j, i] = sort_val[min(tmp_mask.sum() - 1, sort_val.shape[0] - 1)]
176 | else:
177 | mean_score[j, i] = 0.0
178 | max_mask = torch.ge(seg_score, mean_score.unsqueeze(dim=1)).float()
179 | refined_mask = seg_mask * max_mask
180 | return refined_mask
181 |
182 |
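# video-level scores: the action score is a rank-weighted average of the CAS over masked (action) segments, while the
# background score is a plain average over the remaining segments; both are normalized with a softmax over classes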
183 | def calculate_score(seg_score, seg_mask, cas):
184 | n, t, c = seg_score.shape
185 | # [N*C, T]
186 | seg_score = seg_score.mT.contiguous().view(-1, t)
187 | sort_value, sort_index = torch.sort(seg_score, dim=-1, descending=True, stable=True)
188 | seg_mask = seg_mask.mT.contiguous().view(-1, t)
189 | row_index = torch.arange(seg_mask.shape[0], device=seg_mask.device).view(-1, 1).expand(-1, t).contiguous()
190 | sort_mask = seg_mask[row_index, sort_index]
191 | cas = cas.mT.contiguous().view(-1, t)
192 | sort_cas = cas[row_index, sort_index]
193 | # [1, T]
194 | rank = torch.arange(2, t + 2, device=seg_score.device).unsqueeze(dim=0).reciprocal()
195 | # [N*C]
196 | act_num = (rank * sort_mask).sum(dim=-1)
197 | act_score = (sort_cas * rank * sort_mask).sum(dim=-1) / torch.clamp_min(act_num, 1.0)
198 | bkg_num = (1.0 - sort_mask).sum(dim=-1)
199 | bkg_score = (sort_cas * (1.0 - sort_mask)).sum(dim=-1) / torch.clamp_min(bkg_num, 1.0)
200 | act_score, bkg_score = torch.softmax(act_score.view(n, c), dim=-1), torch.softmax(bkg_score.view(n, c), dim=-1)
201 | return act_score, bkg_score
202 |
203 |
204 | def cross_entropy(act_score, bkg_score, label, eps=1e-8):
205 | act_num = torch.clamp_min(torch.sum(label, dim=-1), 1.0)
206 | act_loss = (-(label * torch.log(torch.clamp_min(act_score, eps))).sum(dim=-1) / act_num).mean()
207 | bkg_loss = (-torch.log(torch.clamp_min(1.0 - bkg_score, eps))).mean()
208 | return act_loss + bkg_loss
209 |
210 |
211 | # ref: Weakly Supervised Action Selection Learning in Video (CVPR 2021)
212 | def generalized_cross_entropy(aas_score, seg_mask, label, q=0.7, eps=1e-8):
213 | # [N, T]
214 | aas_score = aas_score.squeeze(dim=-1)
215 | n, t, c = seg_mask.shape
216 | # [N, T]
217 | mask = torch.zeros(n, t, device=seg_mask.device)
218 | for i in range(n):
219 | mask[i, :] = torch.sum(seg_mask[i, :, label[i, :].bool()], dim=-1)
220 | # [N, T]
221 | mask = torch.clamp_max(mask, 1.0)
222 | # [N]
223 | pos_num = torch.clamp_min(torch.sum(mask, dim=1), 1.0)
224 | neg_num = torch.clamp_min(torch.sum(1.0 - mask, dim=1), 1.0)
225 |
226 | pos_loss = ((((1.0 - torch.clamp_min(aas_score, eps) ** q) / q) * mask).sum(dim=-1) / pos_num).mean()
227 | neg_loss = ((((1.0 - torch.clamp_min(1.0 - aas_score, eps) ** q) / q) * (1.0 - mask)).sum(dim=-1) / neg_num).mean()
228 | return pos_loss + neg_loss
229 |
--------------------------------------------------------------------------------
/result/ambiguous.txt:
--------------------------------------------------------------------------------
1 | video_test_0000278 0.0 1.4
2 | video_test_0000278 95.7 97.2
3 | video_test_0000293 50.6 54.6
4 | video_test_0000293 67.4 71.7
5 | video_test_0000293 99.7 106.4
6 | video_test_0000293 118.1 126.4
7 | video_test_0000293 145.8 149.8
8 | video_test_0000293 162.9 168.2
9 | video_test_0000293 181.3 184.0
10 | video_test_0000367 56.4 63.8
11 | video_test_0000367 167.8 170.7
12 | video_test_0000405 15.7 18.4
13 | video_test_0000426 17.8 18.6
14 | video_test_0000426 24.0 24.8
15 | video_test_0000426 40.1 41.8
16 | video_test_0000426 113.9 115.0
17 | video_test_0000426 118.8 119.7
18 | video_test_0000426 124.0 125.2
19 | video_test_0000426 135.2 136.9
20 | video_test_0000437 1.6 12.1
21 | video_test_0000437 47.2 48.4
22 | video_test_0000437 53.2 54.0
23 | video_test_0000437 65.9 67.8
24 | video_test_0000448 42.2 53.8
25 | video_test_0000461 14.0 16.7
26 | video_test_0000549 28.2 33.6
27 | video_test_0000549 14.4 17.4
28 | video_test_0000549 55.0 57.1
29 | video_test_0000593 23.6 29.0
30 | video_test_0000593 36.9 44.3
31 | video_test_0000611 43.0 45.9
32 | video_test_0000611 55.6 58.8
33 | video_test_0000611 59.9 71.0
34 | video_test_0000615 136.9 142.5
35 | video_test_0000615 152.7 159.8
36 | video_test_0000615 164.7 168.0
37 | video_test_0000624 2.5 6.8
38 | video_test_0000664 4.2 5.6
39 | video_test_0000691 36.3 80.8
40 | video_test_0000691 123.9 151.0
41 | video_test_0000714 136.7 137.5
42 | video_test_0000718 13.3 15.5
43 | video_test_0000847 33.7 35.4
44 | video_test_0000847 46.0 52.0
45 | video_test_0000847 58.7 67.0
46 | video_test_0000847 82.8 98.1
47 | video_test_0000847 136.2 171.2
48 | video_test_0000847 175.0 178.5
49 | video_test_0000847 204.5 212.8
50 | video_test_0000940 90.3 92.0
51 | video_test_0000989 170.6 188.4
52 | video_test_0001075 12.3 13.3
53 | video_test_0001075 142.6 143.9
54 | video_test_0001076 17.0 18.4
55 | video_test_0001076 23.8 25.9
56 | video_test_0001076 47.5 57.8
57 | video_test_0001079 335.1 342.8
58 | video_test_0001079 416.0 420.7
59 | video_test_0001127 76.1 161.2
60 | video_test_0001134 2.6 4.7
61 | video_test_0001134 21.2 22.8
62 | video_test_0001134 30.2 36.1
63 | video_test_0001134 41.4 45.4
64 | video_test_0001134 72.4 73.0
65 | video_test_0001168 52.6 78.2
66 | video_test_0001201 122.4 125.3
67 | video_test_0001209 122.9 141.3
68 | video_test_0001267 81.7 84.8
69 | video_test_0001292 78.4 113.2
70 | video_test_0001292 39.9 47.4
71 | video_test_0001292 141.2 154.1
72 | video_test_0001343 224.9 226.6
73 | video_test_0001343 241.0 244.3
74 | video_test_0001433 16.4 17.7
75 | video_test_0001496 22.2 23.9
76 | video_test_0001496 41.9 44.1
77 | video_test_0001496 54.8 56.6
78 | video_test_0001496 62.1 64.4
79 | video_test_0001496 71.9 73.4
80 | video_test_0001496 119.4 121.0
81 | video_test_0001496 124.2 126.3
82 | video_test_0001496 136.0 137.7
83 | video_test_0001496 145.2 147.1
84 | video_test_0001508 11.2 13.1
85 | video_test_0001508 19.4 23.3
86 | video_test_0001508 23.9 27.2
87 | video_test_0001508 29.2 32.2
88 | video_test_0001508 33.2 36.6
89 | video_test_0001508 43.0 45.6
90 | video_test_0001508 46.5 48.5
91 | video_test_0001508 131.4 132.5
92 | video_test_0001508 139.7 141.7
93 | video_test_0001508 149.1 151.9
94 | video_test_0001508 153.4 155.3
95 | video_test_0001508 160.8 166.0
96 | video_test_0001512 7.5 10.0
97 | video_test_0001532 57.3 121.4
98 | video_test_0001549 15.9 26.7
99 | video_test_0001549 75.0 100.2
--------------------------------------------------------------------------------
/result/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/ACRNet/5718240ec0311d0d61a78b8d4c1895601800e79e/result/model.png
--------------------------------------------------------------------------------
/result/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leftthomas/ACRNet/5718240ec0311d0d61a78b8d4c1895601800e79e/result/vis.png
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import random
4 | import subprocess
5 |
6 | import numpy as np
7 | import torch
8 | from scipy.interpolate import interp1d
9 | from torch.backends import cudnn
10 |
11 |
12 | def parse_args():
13 | desc = 'PyTorch Implementation of \'Weakly-supervised Temporal Action Localization with Adaptive ' \
14 | 'Clustering and Refining Network\''
15 | parser = argparse.ArgumentParser(description=desc)
16 | parser.add_argument('--data_path', type=str, default='/home/data')
17 | parser.add_argument('--save_path', type=str, default='result')
18 | parser.add_argument('--data_name', type=str, default='thumos14',
19 | choices=['thumos14', 'activitynet1.2', 'activitynet1.3'])
20 | parser.add_argument('--cls_th', type=float, default=0.1, help='threshold for action classification')
21 | parser.add_argument('--iou_th', type=float, default=0.1, help='threshold for NMS IoU')
22 | parser.add_argument('--act_th', type=str, default='np.arange(0.1, 1.0, 0.05)',
23 | help='threshold for candidate frames')
24 | parser.add_argument('--alpha', type=float, default=0.5, help='alpha value for soft nms')
25 | parser.add_argument('--num_seg', type=int, default=750, help='sampled segments for each video')
26 | parser.add_argument('--fps', type=int, default=25, help='fps for each video')
27 | parser.add_argument('--rate', type=int, default=16, help='number of frames in each segment')
28 | parser.add_argument('--num_iter', type=int, default=2000, help='iterations of training')
29 | parser.add_argument('--eval_iter', type=int, default=100, help='iterations of evaluating')
30 | parser.add_argument('--batch_size', type=int, default=32, help='batch size of loading videos for training')
31 | parser.add_argument('--init_lr', type=float, default=1e-4, help='initial learning rate')
32 | parser.add_argument('--weight_decay', type=float, default=1e-3, help='weight decay for optimizer')
33 | parser.add_argument('--lamda', type=float, default=0.1, help='loss weight for aas loss')
34 | parser.add_argument('--workers', type=int, default=8, help='number of data loading workers')
35 | parser.add_argument('--seed', type=int, default=-1, help='random seed (-1 for no manual seed)')
36 | parser.add_argument('--model_file', type=str, default=None, help='the path of pre-trained model file')
37 |
38 | return init_args(parser.parse_args())
39 |
40 |
41 | class Config(object):
42 | def __init__(self, args):
43 | self.data_path = args.data_path
44 | self.save_path = args.save_path
45 | self.data_name = args.data_name
46 | self.cls_th = args.cls_th
47 | self.iou_th = args.iou_th
48 | self.act_th = eval(args.act_th)
49 | self.map_th = args.map_th
50 | self.alpha = args.alpha
51 | self.num_seg = args.num_seg
52 | self.fps = args.fps
53 | self.rate = args.rate
54 | self.num_iter = args.num_iter
55 | self.eval_iter = args.eval_iter
56 | self.batch_size = args.batch_size
57 | self.init_lr = args.init_lr
58 | self.weight_decay = args.weight_decay
59 | self.lamda = args.lamda
60 | self.workers = args.workers
61 | self.model_file = args.model_file
62 |
63 |
64 | def init_args(args):
65 | if not os.path.exists(args.save_path):
66 | os.makedirs(args.save_path)
67 |
68 | if args.seed >= 0:
69 | random.seed(args.seed)
70 | np.random.seed(args.seed)
71 | torch.manual_seed(args.seed)
72 | torch.cuda.manual_seed_all(args.seed)
73 | cudnn.deterministic = True
74 | cudnn.benchmark = False
75 |
76 | args.map_th = np.linspace(0.1, 0.7, 7) if args.data_name == 'thumos14' else np.linspace(0.5, 0.95, 10)
77 | return Config(args)
78 |
79 |
80 | # change the segment based scores to frame based scores
81 | def revert_frame(scores, num_frame):
82 | x = np.arange(scores.shape[0])
83 | f = interp1d(x, scores, kind='linear', axis=0, fill_value='extrapolate')
84 | scale = np.arange(num_frame) * scores.shape[0] / num_frame
85 | return f(scale)
86 |
87 |
88 | # split frames to action regions
89 | def grouping(frames):
90 | return np.split(frames, np.where(np.diff(frames) != 1)[0] + 1)
91 |
92 |
93 | def result2json(result, class_dict):
94 | result_file = []
95 | for key, value in result.items():
96 | for line in value:
97 | result_file.append({'label': class_dict[key], 'score': float(line[-1]),
98 | 'segment': [float(line[0]), float(line[1])]})
99 | return result_file
100 |
101 |
102 | def which_ffmpeg():
103 | result = subprocess.run(['which', 'ffmpeg'], stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
104 | return result.stdout.decode('utf-8').replace('\n', '')
105 |
106 |
107 | # ref: D2-Net: Weakly-Supervised Action Localization via Discriminative Embeddings and Denoised Activations (ICCV 2021)
108 | def filter_results(results, ambi_file):
109 | ambi_list = [line.strip('\n').split(' ') for line in list(open(ambi_file, 'r'))]
110 | for key, value in results['results'].items():
111 | for filter_item in ambi_list:
112 | if filter_item[0] == key:
113 | filtered = []
114 | for item in value:
115 | if max(float(filter_item[2]), item['segment'][0]) < min(float(filter_item[3]), item['segment'][1]):
116 | continue
117 | else:
118 | filtered.append(item)
119 | value = filtered
120 | results['results'][key] = value
121 | return results
122 |
123 |
124 | # ref: Completeness Modeling and Context Separation for Weakly Supervised Temporal Action Localization (CVPR 2019)
125 | def oic_score(frame_scores, act_score, proposal, _lambda=0.25, gamma=0.2):
126 | inner_score = np.mean(frame_scores[proposal])
127 | outer_s = max(0, int(proposal[0] - _lambda * len(proposal)))
128 | outer_e = min(int(frame_scores.shape[0] - 1), int(proposal[-1] + _lambda * len(proposal)))
129 | outer_temp_list = list(range(outer_s, int(proposal[0]))) + list(range(int(proposal[-1] + 1), outer_e + 1))
130 |
131 | if len(outer_temp_list) == 0:
132 | outer_score = 0.0
133 | else:
134 | outer_score = np.mean(frame_scores[outer_temp_list])
135 | score = inner_score - outer_score + gamma * act_score
136 | return score
137 |
--------------------------------------------------------------------------------