├── dataloader
│   ├── __init__.py
│   ├── cafe.py
│   └── dataloader.py
├── environment.yml
├── evaluation
│   └── cafe_eval.py
├── label_map
│   └── group_action_list.pbtxt
├── models
│   ├── __init__.py
│   ├── backbone.py
│   ├── criterion.py
│   ├── feed_forward.py
│   ├── group_matcher.py
│   ├── group_transformer.py
│   ├── models.py
│   └── position_encoding.py
├── readme.md
├── scripts
│   ├── download_checkpoints.sh
│   ├── download_datasets.sh
│   ├── setup.sh
│   ├── test_cafe_place.sh
│   ├── test_cafe_view.sh
│   ├── train_cafe_place.sh
│   └── train_cafe_view.sh
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── box_ops.py
    ├── logger.py
    ├── misc.py
    └── utils.py
/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dk-kim/CAFE_codebase/caee21a8ccdd5faadecbc5a5034107084a7cd270/dataloader/__init__.py
--------------------------------------------------------------------------------
/dataloader/cafe.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import torchvision.transforms as transforms
4 |
5 | import os
6 | import json
7 | import numpy as np
8 | import random
9 | from PIL import Image
10 |
11 | ACTIVITIES = ['Queueing', 'Ordering', 'Eating/Drinking', 'Working/Studying', 'Fighting', 'TakingSelfie']
12 |
13 |
14 | # read annotation files
15 | def cafe_read_annotations(path, videos, num_class):
16 | labels = {}
17 | group_to_id = {name: i for i, name in enumerate(ACTIVITIES)}
18 |
19 | for vid in videos:
20 | video_path = os.path.join(path, vid)
21 | for cid in os.listdir(video_path):
22 | clip_path = os.path.join(video_path, cid)
23 | label_path = clip_path + '/ann.json'
24 |
25 | with open(label_path, 'r') as file:
26 | groups = {}
27 | boxes, actions, activities, members, membership = [], [], [], [], []
28 |
29 | values = json.load(file)
30 | num_frames = values['framesCount']
31 | frame_interval = values['framesEach']
32 | actors = values['figures']
33 |
34 | key_frame = actors[0]['shapes'][0]['frame']
35 |
36 | for i, actor in enumerate(actors):
37 | actor_idx = actor['id']
38 | group_name = actor['label']
39 |
40 | box = actor['shapes'][0]['coordinates']
41 | x1, y1 = box[0]
42 | x2, y2 = box[1]
43 | x_c, y_c = (x1 + x2) / 2, (y1 + y2) / 2
44 | w, h = x2 - x1, y2 - y1
45 | boxes.append([x_c, y_c, w, h])
46 |
47 | if group_name != 'individual':
48 | group_idx = int(group_name[-1])
49 | if actor['attributes'][0]['value'] != "":
50 | action = group_to_id[actor['attributes'][0]['value']['key']]
51 |
52 | if group_idx not in groups.keys():
53 | groups[group_idx] = {'activity': action}
54 |
55 | if 'members' in groups[group_idx].keys():
56 | groups[group_idx]['members'][i] = 1
57 | else:
58 | groups[group_idx]['members'] = torch.zeros(len(actors))
59 | groups[group_idx]['members'][i] = 1
60 | else:
61 | if group_idx in groups.keys():
62 | action = groups[group_idx]['activity']
63 | else:
64 | action = -1
65 | else:
66 | action = num_class
67 | group_idx = 0
68 |
69 | actions.append(action)
70 | membership.append(group_idx)
71 |
72 | for i, action in enumerate(actions):
73 | if action == -1:
74 | group_idx = membership[i]
75 |
76 | if group_idx in groups.keys():
77 | new_action = groups[group_idx]['activity']
78 | actions[i] = new_action
79 |
80 | group_members = groups[group_idx]['members']
81 | group_members[i] = 1
82 | else:
83 | membership[i] = 0
84 | actions[i] = num_class
85 |
86 | for group_id in sorted(groups):
87 | if group_id - 1 >= len(groups):
88 | new_id = len(groups)
89 |
90 | while new_id > 0:
91 | if new_id not in groups:
92 | groups[new_id] = groups[group_id]
93 | del groups[group_id]
94 | for i in range(len(membership)):
95 | if membership[i] == group_id:
96 | membership[i] = new_id
97 | group_id = new_id
98 | new_id -= 1
99 |
100 | for group_id in sorted(groups):
101 | activities.append(groups[group_id]['activity'])
102 | members.append(groups[group_id]['members'])
103 |
104 | actions = np.array(actions, dtype=np.int32)
105 | boxes = np.vstack(boxes)
106 | membership = np.array(membership, dtype=np.int32) - 1
107 | activities = np.array(activities, dtype=np.int32)
108 |
109 | actions = torch.from_numpy(actions).long()
110 | boxes = torch.from_numpy(boxes).float()
111 | membership = torch.from_numpy(membership).long()
112 | activities = torch.from_numpy(activities).long()
113 |
114 | if len(members) == 0:
115 | members = torch.tensor(members)
116 | else:
117 | members = torch.stack(members).float()
118 |
119 | annotations = {
120 | 'boxes': boxes,
121 | 'actions': actions,
122 | 'membership': membership,
123 | 'activities': activities,
124 | 'members': members,
125 | 'num_frames': num_frames,
126 | 'interval': frame_interval,
127 | 'key_frame': key_frame,
128 | }
129 |
130 | if len(annotations['activities']) != 0:
131 | labels[(int(vid), int(cid))] = annotations
132 |
133 | return labels
134 |
135 |
136 | def cafe_all_frames(labels):
137 | frames = []
138 |
139 | for sid, anns in labels.items():
140 | frames.append(sid)
141 | return frames
142 |
143 |
144 | class CafeDataset(data.Dataset):
145 | def __init__(self, frames, anns, tracks, image_path, args, is_training=True):
146 | super(CafeDataset, self).__init__()
147 | self.frames = frames
148 | self.anns = anns
149 | self.tracks = tracks
150 | self.image_path = image_path
151 | self.image_size = (args.image_width, args.image_height)
152 | self.num_boxes = args.num_boxes
153 | self.random_sampling = args.random_sampling
154 | self.num_frame = args.num_frame
155 | self.num_class = args.num_class
156 | self.is_training = is_training
157 | self.transform = transforms.Compose([
158 | transforms.Resize((args.image_height, args.image_width)),
159 | transforms.ToTensor(),
160 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
161 | ])
162 |
163 | def __getitem__(self, idx):
164 | if self.num_frame == 1:
165 | frames = self.select_key_frames(self.frames[idx])
166 | else:
167 | frames = self.select_frames(self.frames[idx])
168 |
169 | samples = self.load_samples(frames)
170 |
171 | return samples
172 |
173 | def __len__(self):
174 | return len(self.frames)
175 |
176 | def select_key_frames(self, frame):
177 | annotation = self.anns[frame]
178 | key_frame = annotation['key_frame']
179 |
180 | return [(frame, int(key_frame))]
181 |
182 | def select_frames(self, frame):
183 | annotation = self.anns[frame]
184 | key_frame = annotation['key_frame']
185 | total_frames = annotation['num_frames']
186 | interval = annotation['interval']
187 |
188 | if self.is_training:
189 | # random sampling
190 | if self.random_sampling:
191 | sample_frames = random.sample(range(total_frames), self.num_frame)
192 | sample_frames.sort()
193 | # segment-based sampling
194 | else:
195 | segment_duration = total_frames // self.num_frame
196 | sample_frames = np.multiply(list(range(self.num_frame)), segment_duration) + np.random.randint(
197 | segment_duration, size=self.num_frame)
198 | else:
199 | # random sampling
200 | if self.random_sampling:
201 | sample_frames = random.sample(range(total_frames), self.num_frame)
202 | sample_frames.sort()
203 | # segment-based sampling
204 | else:
205 | segment_duration = total_frames // self.num_frame
206 | sample_frames = np.multiply(list(range(self.num_frame)), segment_duration) + np.random.randint(
207 | segment_duration, size=self.num_frame)
208 |
209 | return [(frame, int(fid * annotation['interval'])) for fid in sample_frames]
210 |
211 | def load_samples(self, frames):
212 | images, boxes, gt_boxes, actions, activities, members, membership = [], [], [], [], [], [], []
213 | targets = {}
214 | fids = []
215 |
216 | for i, (frame, fid) in enumerate(frames):
217 | vid, cid = frame
218 | fids.append(fid)
219 | img = Image.open(self.image_path + '/%s/%s/images/frames_%d.jpg' % (vid, cid, fid))
220 | image_w, image_h = img.width, img.height
221 | img = self.transform(img)
222 | images.append(img)
223 |
224 | num_boxes = self.anns[frame]['boxes'].shape[0]
225 |
226 | for box in self.anns[frame]['boxes']:
227 | x_c, y_c, w, h = box
228 | gt_boxes.append([x_c / image_w, y_c / image_h, w / image_w, h / image_h])
229 |
230 | temp_boxes = np.ones((num_boxes, 4))
231 | for j, track in enumerate(self.tracks[(vid, cid)][fid]):
232 | _id, x1, y1, x2, y2 = track
233 |
234 | if x1 < 0.0 and y2 < 0.0:
235 | x1, y1, x2, y2 = 0.0, 0.0, 1e-8, 1e-8
236 |
237 | x_c, y_c = (x1 + x2) / 2, (y1 + y2) / 2
238 | w, h = x2 - x1, y2 - y1
239 |
240 | if _id <= num_boxes:
241 | temp_boxes[int(_id - 1)] = np.array([x_c, y_c, w, h])
242 |
243 | boxes.append(temp_boxes)
244 | actions = [self.anns[frame]['actions']]
245 | activities = [self.anns[frame]['activities']]
246 | members = [self.anns[frame]['members']]
247 | membership = [self.anns[frame]['membership']]
248 |
249 | if len(boxes[-1]) != self.num_boxes:
250 | boxes[-1] = np.vstack([boxes[-1], (self.num_boxes - len(boxes[-1])) * [[0.0, 0.0, 0.0, 0.0]]])
251 |
252 | if len(actions[-1]) != self.num_boxes:
253 | actions[-1] = torch.cat((actions[-1], torch.tensor((self.num_boxes - len(actions[-1])) * [self.num_class + 1])))
254 |
255 | if members[-1].shape[1] != self.num_boxes:
256 | members[-1] = torch.hstack(
257 | (members[-1], torch.zeros((members[-1].shape[0], self.num_boxes - members[-1].shape[1]))))
258 |
259 |             if len(membership[-1]) != self.num_boxes:
260 | membership[-1] = torch.cat((membership[-1], torch.tensor((self.num_boxes - len(membership[-1])) * [-1])))
261 |
262 | images = torch.stack(images)
263 | boxes = np.vstack(boxes).reshape([self.num_frame, -1, 4])
264 | gt_boxes = np.vstack(gt_boxes).reshape([self.num_frame, -1, 4])
265 | actions = torch.stack(actions)
266 | membership = torch.stack(membership)
267 |
268 | if len(activities) == 0:
269 | activities = torch.tensor(activities)
270 | members = torch.tensor(activities)
271 | else:
272 | activities = torch.stack(activities)
273 | members = torch.stack(members)
274 |
275 | boxes = torch.from_numpy(boxes).float()
276 | gt_boxes = torch.from_numpy(gt_boxes).float()
277 |
278 | targets['actions'] = actions
279 | targets['activities'] = activities
280 | targets['boxes'] = boxes
281 | targets['gt_boxes'] = gt_boxes
282 | targets['members'] = members
283 | targets['membership'] = membership
284 |
285 | infos = {'vid': vid, 'sid': cid, 'fid': fids, 'key_frame': self.anns[frame]['key_frame']}
286 |
287 | return images, targets, infos
288 |
--------------------------------------------------------------------------------
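A minimal usage sketch for the dataset defined above. The hyper-parameter values and the data path are placeholders rather than the repository's actual defaults (those live in train.py and the scripts/ folder), and the sketch assumes the CAFE frames and gt_tracks.pkl are on disk:

    import pickle
    from types import SimpleNamespace
    from dataloader.cafe import cafe_read_annotations, cafe_all_frames, CafeDataset

    # Placeholder arguments; check the training scripts for the real settings.
    args = SimpleNamespace(image_width=1280, image_height=720, num_boxes=14,
                           random_sampling=False, num_frame=1, num_class=6)

    data_path = './datasets/cafe'                                   # placeholder path
    anns = cafe_read_annotations(data_path, ['1', '2'], args.num_class)
    frames = cafe_all_frames(anns)
    tracks = pickle.load(open(data_path + '/gt_tracks.pkl', 'rb'))  # ground-truth tracklets

    dataset = CafeDataset(frames, anns, tracks, data_path, args, is_training=True)
    images, targets, infos = dataset[0]    # images: [num_frame, 3, image_height, image_width]
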
/dataloader/dataloader.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from SSU (https://github.com/cvlab-epfl/social-scene-understanding)
3 | # Modified from ARG (https://github.com/wjchaoGit/Group-Activity-Recognition)
4 | # ------------------------------------------------------------------------
5 | from .cafe import *
6 |
7 | import pickle
8 |
9 | TRAIN_CAFE_P = ['1', '2', '3', '4', '9', '10', '11', '12', '17', '18', '19', '20', '21', '22', '23', '24']
10 | VAL_CAFE_P = ['13', '14', '15', '16']
11 | TEST_CAFE_P = ['5', '6', '7', '8']
12 |
13 | TRAIN_CAFE_V = ['1', '2', '5', '6', '9', '10', '13', '14', '17', '18', '21', '22']
14 | VAL_CAFE_V = ['3', '7', '11', '15', '19', '23']
15 | TEST_CAFE_V = ['4', '8', '12', '16', '20', '24']
16 |
17 |
18 | def read_dataset(args):
19 | if args.dataset == 'cafe':
20 | data_path = args.data_path + 'cafe'
21 |
22 | # split-by-place setting
23 | if args.split == 'place':
24 | TRAIN_VIDEOS_CAFE = TRAIN_CAFE_P
25 | VAL_VIDEOS_CAFE = VAL_CAFE_P
26 | TEST_VIDEOS_CAFE = TEST_CAFE_P
27 | # split-by-view setting
28 | elif args.split == 'view':
29 | TRAIN_VIDEOS_CAFE = TRAIN_CAFE_V
30 | VAL_VIDEOS_CAFE = VAL_CAFE_V
31 | TEST_VIDEOS_CAFE = TEST_CAFE_V
32 | else:
33 | assert False
34 |
35 | if args.val_mode:
36 | train_data = cafe_read_annotations(data_path, TRAIN_VIDEOS_CAFE, args.num_class)
37 | train_frames = cafe_all_frames(train_data)
38 |
39 | test_data = cafe_read_annotations(data_path, VAL_VIDEOS_CAFE, args.num_class)
40 | test_frames = cafe_all_frames(test_data)
41 | else:
42 | train_data = cafe_read_annotations(data_path, TRAIN_VIDEOS_CAFE + VAL_VIDEOS_CAFE, args.num_class)
43 | train_frames = cafe_all_frames(train_data)
44 |
45 | test_data = cafe_read_annotations(data_path, TEST_VIDEOS_CAFE, args.num_class)
46 | test_frames = cafe_all_frames(test_data)
47 |
48 | # actor tracklets for all frames
49 | all_tracks = pickle.load(open(data_path + '/gt_tracks.pkl', 'rb'))
50 |
51 | train_set = CafeDataset(train_frames, train_data, all_tracks, data_path, args, is_training=True)
52 | test_set = CafeDataset(test_frames, test_data, all_tracks, data_path, args, is_training=False)
53 | else:
54 | assert False
55 |
56 | print("%d train samples and %d test samples" % (len(train_frames), len(test_frames)))
57 |
58 | return train_set, test_set
59 |
--------------------------------------------------------------------------------
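A sketch of how read_dataset might be driven and wrapped in a PyTorch DataLoader. The argument names follow the code above; the values and paths are placeholders, and batch_size=1 is used here only because the number of groups varies per clip (train.py may use its own collate_fn):

    import torch.utils.data as data
    from types import SimpleNamespace
    from dataloader.dataloader import read_dataset

    args = SimpleNamespace(dataset='cafe', data_path='./datasets/', split='place', val_mode=False,
                           num_class=6, image_width=1280, image_height=720, num_boxes=14,
                           random_sampling=False, num_frame=1)

    train_set, test_set = read_dataset(args)
    train_loader = data.DataLoader(train_set, batch_size=1, shuffle=True)
    images, targets, infos = next(iter(train_loader))
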
/environment.yml:
--------------------------------------------------------------------------------
1 | name: gad
2 | channels:
3 | - pytorch
4 | - defaults
5 | dependencies:
6 | - _libgcc_mutex=0.1=main
7 | - _openmp_mutex=4.5=1_gnu
8 | - blas=1.0=mkl
9 | - ca-certificates=2021.10.26=h06a4308_2
10 | - certifi=2021.10.8=py38h06a4308_0
11 | - cudatoolkit=11.0.221=h6bb024c_0
12 | - freetype=2.11.0=h70c0345_0
13 | - giflib=5.2.1=h7b6447c_0
14 | - intel-openmp=2021.4.0=h06a4308_3561
15 | - joblib=1.1.0=pyhd3eb1b0_0
16 | - jpeg=9b=h024ee3a_2
17 | - lcms2=2.12=h3be6417_0
18 | - ld_impl_linux-64=2.35.1=h7274673_9
19 | - libffi=3.3=he6710b0_2
20 | - libgcc-ng=9.3.0=h5101ec6_17
21 | - libgfortran-ng=7.5.0=ha8ba4b0_17
22 | - libgfortran4=7.5.0=ha8ba4b0_17
23 | - libgomp=9.3.0=h5101ec6_17
24 | - libpng=1.6.37=hbc83047_0
25 | - libstdcxx-ng=9.3.0=hd4cf53a_17
26 | - libtiff=4.2.0=h85742a9_0
27 | - libuv=1.40.0=h7b6447c_0
28 | - libwebp=1.2.0=h89dd481_0
29 | - libwebp-base=1.2.0=h27cfd23_0
30 | - lz4-c=1.9.3=h295c915_1
31 | - mkl=2021.4.0=h06a4308_640
32 | - mkl-service=2.4.0=py38h7f8727e_0
33 | - mkl_fft=1.3.1=py38hd3c417c_0
34 | - mkl_random=1.2.2=py38h51133e4_0
35 | - ncurses=6.3=h7f8727e_2
36 | - ninja=1.10.2=py38hd09550d_3
37 | - numpy=1.21.2=py38h20f2e39_0
38 | - numpy-base=1.21.2=py38h79a1101_0
39 | - olefile=0.46=pyhd3eb1b0_0
40 | - openssl=1.1.1l=h7f8727e_0
41 | - pillow=8.4.0=py38h5aabda8_0
42 | - pip=21.2.4=py38h06a4308_0
43 | - python=3.8.5=h7579374_1
44 | - readline=8.1=h27cfd23_0
45 | - scikit-learn=1.0.1=py38h51133e4_0
46 | - scipy=1.7.1=py38h292c36d_2
47 | - setuptools=58.0.4=py38h06a4308_0
48 | - six=1.16.0=pyhd3eb1b0_0
49 | - sqlite=3.36.0=hc218d9a_0
50 | - threadpoolctl=2.2.0=pyh0d69192_0
51 | - tk=8.6.11=h1ccaba5_0
52 | - torchvision=0.8.2=py38_cu110
53 | - typing_extensions=3.10.0.2=pyh06a4308_0
54 | - wheel=0.37.0=pyhd3eb1b0_1
55 | - xz=5.2.5=h7b6447c_0
56 | - zlib=1.2.11=h7b6447c_3
57 | - zstd=1.4.9=haebb681_0
58 | prefix: /opt/conda/envs/gad
59 |
--------------------------------------------------------------------------------
/evaluation/cafe_eval.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from JRDB-ACT (https://github.com/JRDB-dataset/jrdb_toolkit/tree/main/Action%26Social_grouping_eval)
3 | # ------------------------------------------------------------------------
4 | from collections import defaultdict, Counter
5 | import numpy as np
6 | import copy
7 |
8 |
9 | def make_image_key(v_id, c_id, f_id):
10 | """Returns a unique identifier for a video id & clip id & frame id"""
11 | return "%d,%d,%d" % (int(v_id), int(c_id), int(f_id))
12 |
13 | def make_clip_key(image_key):
14 | """Returns a unique identifier for a video id & clip id"""
15 | v_id = image_key.split(',')[0]
16 | c_id = image_key.split(',')[1]
17 | return "%d,%d" % (int(v_id), int(c_id))
18 |
19 | def read_text_file(text_file, eval_type, mode):
20 |     """Loads boxes and class labels from a text file in the CAFE format.
21 |
22 | Args:
23 | text_file: A file object.
24 | mode: 'gt' or 'pred'
25 | eval_type:
26 |         'gt_base': Eval type for a model trained with ground-truth actor tracklets as inputs.
27 |         'detect_base': Eval type for a model trained with tracker-generated actor tracklets as inputs.
28 |
29 | Returns:
30 | boxes: A dictionary mapping each unique image key (string) to a list of
31 | boxes, given as coordinates [y1, x1, y2, x2].
32 | g_labels: A dictionary mapping each unique image key (string) to a list of
33 | integer group id labels, matching the corresponding box in 'boxes'.
34 | act_labels: A dictionary mapping each unique image key (string) to a list of
35 |         integer group activity class labels, matching the corresponding box in `boxes`.
36 |       a_scores: A dictionary mapping each unique image key (string) to a list of
37 |         actor confidence scores, matching the corresponding box in `boxes`.
38 |       g_scores: A dictionary mapping each unique image key (string) to a list of
39 |         group confidence scores, matching the corresponding box in `boxes`.
40 | """
41 | boxes = defaultdict(list)
42 | g_labels = defaultdict(list)
43 | act_labels = defaultdict(list)
44 | a_scores = defaultdict(list)
45 | g_scores = defaultdict(list)
46 | # reads each row in text file.
47 | with open(text_file.name) as r:
48 | for line in r.readlines():
49 | row = line[:-1].split(' ')
50 | # makes image key.
51 | image_key = make_image_key(row[0], row[1], row[2])
52 | # box coordinates.
53 | x1, y1, x2, y2 = [float(n) for n in row[3:7]]
54 | # actor confidence score.
55 | if eval_type == 'detect_base' and mode == 'pred':
56 | a_score = float(row[10])
57 | else:
58 | a_score = 1.0
59 | # group confidence score.
60 | if mode == 'gt':
61 | g_score = None
62 | elif mode == 'pred':
63 | g_score = float(row[9])
64 |             # group id.
65 | group_id = int(row[7])
66 | # group activity label.
67 | activity = int(row[8])
68 |
69 | boxes[image_key].append([x1, y1, x2, y2])
70 | g_labels[image_key].append(group_id)
71 | act_labels[image_key].append(activity)
72 | a_scores[image_key].append(a_score)
73 | g_scores[image_key].append(g_score)
74 | return boxes, g_labels, act_labels, a_scores, g_scores
75 |
76 | def actor_matching(pred_boxes, pred_a_scores, gt_boxes):
77 | """matches prediction tracklets and ground truth tracklets.
78 |
79 | Args:
80 | pred_boxes: A dictionary mapping each unique image key (string) to a list of
81 |           prediction boxes, given as coordinates [y1, x1, y2, x2]. The box ordering is
82 |           consistent across the image keys of each clip key.
83 |         pred_a_scores: A dictionary mapping each unique image key (string) to a list of
84 |           actor confidence scores, matching the corresponding box in `pred_boxes`.
85 |         gt_boxes: A dictionary mapping each unique image key (string) to a list of
86 |           ground truth boxes, given as coordinates [y1, x1, y2, x2]. The box ordering is
87 |           consistent across the image keys of each clip key.
88 |
89 | Returns:
90 | matching_results: A dictionary mapping each unique clip key (string) to a list of
91 | id matching results between prediction tracklets and ground truth tracklets, given as
92 | {[prediction id: [ground truth id]]}.
93 | """
94 | image_keys = pred_boxes.keys()
95 | frame_list = defaultdict(list)
96 | matching_results = defaultdict(list)
97 | for image_key in image_keys:
98 | clip_key = make_clip_key(image_key)
99 | frame_list[clip_key].append(image_key)
100 | clip_keys = frame_list.keys()
101 | for clip_key in clip_keys:
102 | matching = defaultdict(list)
103 | # IoU matrix
104 | iou_matrix = np.zeros((len(pred_boxes[frame_list[clip_key][0]]), len(gt_boxes[frame_list[clip_key][0]])))
105 | # confidence score for prediction tracklets.
106 | confidence_mean = np.zeros(len(pred_boxes[frame_list[clip_key][0]]))
107 |         # number of frames in the clip.
108 |         frame_len = len(frame_list[clip_key])
109 |         # accumulates IoU sums in the IoU matrix and sums the confidence scores.
110 | for image_key in frame_list[clip_key]:
111 | for i, pred_box in enumerate(pred_boxes[image_key]):
112 | if pred_box[2] != 0 and pred_box[2] != -1:
113 | confidence_mean[i] += pred_a_scores[image_key][i]
114 | for j, gt_box in enumerate(gt_boxes[image_key]):
115 | if gt_box[2] != 0 and gt_box[2] != -1:
116 | iou_matrix[i,j] += IoU(pred_box, gt_box)
117 |         # takes the mean IoU and mean confidence score over the frames.
118 | confidence_mean = confidence_mean / frame_len
119 | iou_matrix = iou_matrix / frame_len
120 | # sorts by confidence score.
121 | sorted_scores = sorted(confidence_mean, reverse=True)
122 |
123 | # matching algorithm
124 | duplicated_score = 0
125 | for score in sorted_scores:
126 | if duplicated_score == score:
127 | continue
128 | else:
129 | for i in (np.where(confidence_mean == score)[0]):
130 | if max(iou_matrix[i]) > 0.5:
131 | j = np.where(iou_matrix[i] == max(iou_matrix[i]))[0][0]
132 | matching[i].append(j)
133 | iou_matrix[:,j] = 0
134 |
135 | # matching results
136 | matching_results[clip_key].append(matching)
137 | return matching_results
138 |
139 |
140 | def make_groups(boxes, g_labels, act_labels, g_scores):
141 |     """combines boxes, activities, and scores that belong to the same group in the same clip.
142 |
143 | Returns:
144 | groups_ids: A dictionary mapping each unique clip key (string) to a list of
145 | actor ids of each 'g_label'.
146 | groups_activity: A dictionary mapping each unique clip key (string) to a list of
147 | group activity class labels.
148 | groups_score: A dictionary mapping each unique clip key (string) to a list of
149 | group confidence score.
150 | """
151 | image_keys = boxes.keys()
152 | groups_activity = defaultdict(list)
153 | groups_score = defaultdict(list)
154 | groups_ids = defaultdict(list)
155 | frame_list = defaultdict(list)
156 | # makes clip key.
157 | for image_key in image_keys:
158 | clip_key = make_clip_key(image_key)
159 | frame_list[clip_key].append(image_key)
160 | clip_keys = frame_list.keys()
161 | for clip_key in clip_keys:
162 | group_ids = defaultdict(list)
163 | group_activity = defaultdict(set)
164 | group_score = defaultdict(set)
165 | for i, (g_label, act_label, g_score) in enumerate(
166 | zip(g_labels[frame_list[clip_key][0]], act_labels[frame_list[clip_key][0]],
167 | g_scores[frame_list[clip_key][0]])):
168 | group_ids[g_label].append(i)
169 | group_activity[g_label].add(act_label)
170 | group_score[g_label].add(g_score)
171 | groups_ids[clip_key].append(group_ids)
172 | groups_activity[clip_key].append(group_activity)
173 | groups_score[clip_key].append(group_score)
174 |
175 | return groups_ids, groups_activity, groups_score
176 |
177 |
178 | def read_labelmap(labelmap_file):
179 | """Reads a labelmap without the dependency on protocol buffers.
180 |
181 | Args:
182 | labelmap_file: A file object containing a label map protocol buffer.
183 |
184 | Returns:
185 | labelmap: The label map in the form used by the object_detection_evaluation
186 | module - a list of {"id": integer, "name": classname } dicts.
187 | class_ids: A set containing all of the valid class id integers.
188 | """
189 | labelmap = []
190 | class_ids = set()
191 | name = ""
192 | class_id = ""
193 | for line in labelmap_file:
194 | if line.startswith(" name:"):
195 | name = line.split('"')[1]
196 | elif line.startswith(" id:") or line.startswith(" label_id:"):
197 | class_id = int(line.strip().split(" ")[-1])
198 | labelmap.append({"id": class_id, "name": name})
199 | class_ids.add(class_id)
200 | return labelmap, class_ids
201 |
202 |
203 | def IoU(box1, box2):
204 | """calculates IoU between two different boxes."""
205 | # box = (x1, y1, x2, y2)
206 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
207 | box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
208 |
209 | # obtain x1, y1, x2, y2 of the intersection
210 | x1 = max(box1[0], box2[0])
211 | y1 = max(box1[1], box2[1])
212 | x2 = min(box1[2], box2[2])
213 | y2 = min(box1[3], box2[3])
214 |
215 | # compute the width and height of the intersection
216 | w = max(0, x2 - x1)
217 | h = max(0, y2 - y1)
218 |
219 | inter = w * h
220 | iou = inter / (box1_area + box2_area - inter + 1e-8)
221 | return iou
222 |
223 |
224 | def cal_group_IoU(pred_group, gt_group):
225 | """calculates group IoU between two different groups"""
226 | # Intersection
227 | Intersection = sum([1 for det_id in pred_group[2] if det_id in gt_group[2]])
228 |
229 | # group IoU
230 | if Intersection != 0:
231 | group_IoU = Intersection / (len(pred_group[2]) + len(gt_group[2]) - Intersection)
232 | else:
233 | group_IoU = 0
234 | return group_IoU
235 |
236 |
237 | def calculateAveragePrecision(rec, prec):
238 |     """calculates the AP score of each activity class by the all-point interpolation method."""
239 | mrec = [0] + [e for e in rec] + [1]
240 | mpre = [0] + [e for e in prec] + [0]
241 |
242 | for i in range(len(mpre) - 1, 0, -1):
243 | mpre[i - 1] = max(mpre[i - 1], mpre[i])
244 |
245 | ii = []
246 |
247 | for i in range(len(mrec) - 1):
248 | if mrec[1:][i] != mrec[0:-1][i]:
249 | ii.append(i + 1)
250 |
251 | ap = 0
252 | for i in ii:
253 | ap = ap + np.sum((mrec[i] - mrec[i - 1]) * mpre[i])
254 |
255 | return [ap, mpre[0:len(mpre) - 1], mrec[0:len(mpre) - 1], ii]
256 |
257 |
258 | def outlier_metric(gt_groups_ids, gt_groups_activity, pred_groups_ids, pred_groups_activity, num_class):
259 | """calculates Outlier mIoU.
260 |
261 | Args:
262 |         num_class: The number of activity classes.
263 |
264 |     Returns:
265 |         outlier_mIoU: Mean of the outlier IoUs over all clips.
266 | """
267 | clip_IoU = defaultdict(list)
268 | TP = defaultdict(list)
269 | clip_keys = pred_groups_ids.keys()
270 | c_pred_groups_activity = copy.deepcopy(pred_groups_activity)
271 | c_gt_groups_activity = copy.deepcopy(gt_groups_activity)
272 |     # prediction groups labeled as outliers (activity id == num_class + 1).
273 | pred_groups = [[clip_key, group_id, pred_groups_ids[clip_key][0][group_id]] for clip_key in clip_keys if
274 | clip_key in gt_groups_ids.keys() for group_id in pred_groups_ids[clip_key][0].keys() if
275 | c_pred_groups_activity[clip_key][0][group_id].pop() == (num_class + 1)]
276 |     # ground truth groups labeled as outliers.
277 | gt_groups = [[clip_key, group_id, gt_groups_ids[clip_key][0][group_id]] for clip_key in clip_keys if
278 | clip_key in gt_groups_ids.keys() for group_id in gt_groups_ids[clip_key][0].keys() if
279 | c_gt_groups_activity[clip_key][0][group_id].pop() == (num_class + 1)]
280 | for clip_key in clip_keys:
281 |         # skips clips whose key does not appear in the ground-truth file.
282 | if clip_key in gt_groups_ids.keys():
283 | # groups on same clip
284 | c_pred_groups = [pred_group for pred_group in pred_groups if pred_group[0] == clip_key]
285 | c_gt_groups = [gt_group for gt_group in gt_groups if gt_group[0] == clip_key]
286 | if len(c_pred_groups) != 0 and len(c_gt_groups) != 0:
287 | # outliers on prediction and ground truth.
288 | c_pred_ids = [pred_id for c_pred_group in c_pred_groups for pred_id in c_pred_group[2]]
289 | c_gt_ids = [gt_id for c_gt_group in c_gt_groups for gt_id in c_gt_group[2]]
290 | # number of True positive outliers.
291 | TP[clip_key] = sum([1 for pred_id in c_pred_ids if pred_id in c_gt_ids])
292 | clip_IoU[clip_key] = TP[clip_key] / (len(c_pred_ids) + len(c_gt_ids) - TP[clip_key])
293 | clip_IoU['total'].append(clip_IoU[clip_key])
294 | elif len(c_pred_groups) != 0 or len(c_gt_groups) != 0:
295 | TP[clip_key] = 0
296 | clip_IoU[clip_key] = 0
297 | clip_IoU['total'].append(clip_IoU[clip_key])
298 | # outlier mIoU.
299 | outlier_mIoU = np.array(clip_IoU['total']).mean()
300 | return outlier_mIoU * 100.0
301 |
302 |
303 | def group_mAP_eval(gt_groups_ids, gt_groups_activity, pred_groups_ids, pred_groups_activity, pred_groups_scores,
304 | categories, thresh):
305 | """calculates group mAP.
306 |
307 | Args:
308 |         categories: A list of group activity classes, given as {'name': ..., 'id': ...} dicts.
309 |         thresh: The group IoU threshold for counting a prediction group as a true positive.
310 |
311 | Returns:
312 |         group_mAP: Mean of the group APs over all activity classes.
313 |         group_APs: A list of group APs, one per activity class.
314 | """
315 | clip_keys = pred_groups_ids.keys()
316 |     # group AP for each activity class.
317 | group_APs = np.zeros(len(categories))
318 | for c, clas in enumerate(categories):
319 |         # deep copies, since the set pop() below consumes the activity labels.
320 | c_pred_groups_activity = copy.deepcopy(pred_groups_activity)
321 | c_gt_groups_activity = copy.deepcopy(gt_groups_activity)
322 | # prediction groups on each class.
323 | pred_groups = [
324 | [clip_key, group_id, pred_groups_ids[clip_key][0][group_id], pred_groups_scores[clip_key][0][group_id]] for
325 | clip_key in clip_keys if clip_key in gt_groups_ids.keys() for group_id in
326 | pred_groups_ids[clip_key][0].keys() if
327 | c_pred_groups_activity[clip_key][0][group_id].pop() == clas['id'] and len(
328 | pred_groups_ids[clip_key][0][group_id]) >= 2]
329 | # ground truth groups on each class.
330 | gt_groups = [[clip_key, group_id, gt_groups_ids[clip_key][0][group_id]] for clip_key in clip_keys if
331 | clip_key in gt_groups_ids.keys() for group_id in gt_groups_ids[clip_key][0].keys() if
332 | c_gt_groups_activity[clip_key][0][group_id].pop() == clas['id'] and len(
333 | gt_groups_ids[clip_key][0][group_id]) >= 2]
334 |
335 | # denominator of Recall.
336 | npos = len(gt_groups)
337 |
338 | # sorts det_groups in descending order for g_score.
339 | pred_groups = sorted(pred_groups, key=lambda conf: conf[3], reverse=True)
340 |
341 | TP = np.zeros(len(pred_groups))
342 | FP = np.zeros(len(pred_groups))
343 |
344 | det = Counter(gt_group[0] for gt_group in gt_groups)
345 |
346 | for key, val in det.items():
347 | det[key] = np.zeros(val)
348 |
349 | # AP matching algorithm.
350 | for p, pred_group in enumerate(pred_groups):
351 | if pred_group[0] in gt_groups_ids.keys():
352 | gt = [gt_group for gt_group in gt_groups if gt_group[0] == pred_group[0]]
353 | group_IoU_Max = 0
354 | for j, gt_group in enumerate(gt):
355 | group_IoU = cal_group_IoU(pred_group, gt_group)
356 | if group_IoU > group_IoU_Max:
357 | group_IoU_Max = group_IoU
358 | jmax = j
359 | # true positive prediction group condition.
360 | if group_IoU_Max >= thresh:
361 | if det[pred_group[0]][jmax] == 0:
362 | TP[p] = 1
363 | det[pred_group[0]][jmax] = 1
364 | else:
365 | FP[p] = 1
366 | else:
367 | FP[p] = 1
368 |
369 | acc_FP = np.cumsum(FP)
370 | acc_TP = np.cumsum(TP)
371 | # recall
372 | rec = acc_TP / npos
373 | # precision
374 | prec = np.divide(acc_TP, (acc_FP + acc_TP))
375 | [ap, mpre, mrec, ii] = calculateAveragePrecision(rec, prec)
376 | # group AP on each group activity class
377 | group_APs[c] = ap * 100
378 | # group mAP
379 | group_mAP = group_APs.mean()
380 | return group_mAP, group_APs
381 |
382 | class GAD_Evaluation():
383 | def __init__(self, args):
384 | super(GAD_Evaluation, self).__init__()
385 | self.eval_type = args.eval_type
386 | self.categories, self.class_whitelist = read_labelmap(args.labelmap)
387 | self.gt_boxes, self.gt_g_labels, self.gt_act_labels, _, self.gt_g_scores = read_text_file(args.groundtruth, self.eval_type, mode='gt')
388 | self.gt_groups_ids, self.gt_groups_activity, _ = make_groups(
389 | self.gt_boxes, self.gt_g_labels, self.gt_act_labels, self.gt_g_scores)
390 |
391 |
392 | def evaluate(self, detections):
393 | pred_boxes, pred_g_labels, pred_act_labels, pred_a_scores, pred_g_scores = read_text_file(detections, self.eval_type, mode='pred')
394 | pred_groups_ids, pred_groups_activity, pred_groups_scores = make_groups(pred_boxes, pred_g_labels,
395 | pred_act_labels,
396 | pred_g_scores)
397 | group_mAP, group_APs = group_mAP_eval(self.gt_groups_ids, self.gt_groups_activity,
398 | pred_groups_ids, pred_groups_activity, pred_groups_scores,
399 | self.categories, thresh=1.0)
400 | group_mAP_2, group_APs_2 = group_mAP_eval(self.gt_groups_ids, self.gt_groups_activity,
401 | pred_groups_ids, pred_groups_activity, pred_groups_scores,
402 | self.categories, thresh=0.5)
403 | outlier_mIoU = outlier_metric(self.gt_groups_ids, self.gt_groups_activity,
404 | pred_groups_ids, pred_groups_activity,
405 | len(self.categories))
406 | result = {
407 | 'group_APs_1.0': group_APs,
408 | 'group_mAP_1.0': group_mAP,
409 | 'group_APs_0.5': group_APs_2,
410 | 'group_mAP_0.5': group_mAP_2,
411 | 'outlier_mIoU': outlier_mIoU,
412 | }
413 | return result
414 |
--------------------------------------------------------------------------------
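A usage sketch for the evaluator above. As read_text_file implies, every line of the ground-truth and prediction files carries space-separated fields: video id, clip id, frame id, x1 y1 x2 y2, group id, activity id, then (for predictions) a group score and, under 'detect_base', an actor score. The file names below are placeholders:

    from types import SimpleNamespace
    from evaluation.cafe_eval import GAD_Evaluation

    args = SimpleNamespace(eval_type='gt_base',
                           labelmap=open('label_map/group_action_list.pbtxt'),
                           groundtruth=open('gt.txt'))             # placeholder file name
    evaluator = GAD_Evaluation(args)
    result = evaluator.evaluate(open('pred.txt'))                  # placeholder file name
    print(result['group_mAP_1.0'], result['group_mAP_0.5'], result['outlier_mIoU'])
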
/label_map/group_action_list.pbtxt:
--------------------------------------------------------------------------------
1 | item {
2 | name: "Queueing"
3 | id: 1
4 | }
5 | item {
6 | name: "Ordering"
7 | id: 2
8 | }
9 | item {
10 | name: "Eating/Drinking"
11 | id: 3
12 | }
13 | item {
14 | name: "Working/Studying"
15 | id: 4
16 | }
17 | item {
18 | name: "Fighting"
19 | id: 5
20 | }
21 | item {
22 | name: "TakingSelfie"
23 | id: 6
24 | }
25 |
--------------------------------------------------------------------------------
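read_labelmap in evaluation/cafe_eval.py parses exactly this file; a small sketch of what it returns:

    from evaluation.cafe_eval import read_labelmap

    with open('label_map/group_action_list.pbtxt') as f:
        categories, class_ids = read_labelmap(f)

    # categories -> [{'id': 1, 'name': 'Queueing'}, ..., {'id': 6, 'name': 'TakingSelfie'}]
    # class_ids  -> {1, 2, 3, 4, 5, 6}
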
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .group_matcher import build_group_matcher
2 | from .criterion import SetCriterion
3 | from .models import GADTR
4 |
5 |
6 | def build_model(args):
7 | model = GADTR(args)
8 |
9 | losses = ['labels', 'cardinality']
10 | group_losses = ['group_labels', 'group_cardinality', 'group_code', 'group_consistency']
11 |
12 | # Set loss coefficients
13 | weight_dict = {}
14 | weight_dict['loss_ce'] = args.ce_loss_coef
15 | weight_dict['loss_group_ce'] = args.group_ce_loss_coef
16 | weight_dict['loss_group_code'] = args.group_code_loss_coef
17 | weight_dict['loss_consistency'] = args.consistency_loss_coef
18 |
19 | # Group matching
20 | group_matcher = build_group_matcher(args)
21 |
22 | # Loss functions
23 | criterion = SetCriterion(args.num_class, weight_dict=weight_dict, eos_coef=args.eos_coef,
24 | losses=losses, group_losses=group_losses, group_matcher=group_matcher, args=args)
25 |
26 | criterion.cuda()
27 |
28 | return model, criterion
29 |
--------------------------------------------------------------------------------
/models/backbone.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from DETR (https://github.com/facebookresearch/detr)
3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torchvision
9 | from torchvision.models._utils import IntermediateLayerGetter
10 |
11 | from .position_encoding import build_position_encoding
12 |
13 |
14 | class FrozenBatchNorm2d(torch.nn.Module):
15 | """
16 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
17 |
18 |     Copy-paste from torchvision.misc.ops with added eps before rsqrt,
19 | without which any other models than torchvision.models.resnet[18,34,50,101]
20 | produce nans.
21 | """
22 |
23 | def __init__(self, n):
24 | super(FrozenBatchNorm2d, self).__init__()
25 | self.register_buffer("weight", torch.ones(n))
26 | self.register_buffer("bias", torch.zeros(n))
27 | self.register_buffer("running_mean", torch.zeros(n))
28 | self.register_buffer("running_var", torch.ones(n))
29 |
30 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
31 | missing_keys, unexpected_keys, error_msgs):
32 | num_batches_tracked_key = prefix + 'num_batches_tracked'
33 | if num_batches_tracked_key in state_dict:
34 | del state_dict[num_batches_tracked_key]
35 |
36 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
37 | state_dict, prefix, local_metadata, strict,
38 | missing_keys, unexpected_keys, error_msgs)
39 |
40 | def forward(self, x):
41 | # move reshapes to the beginning
42 | # to make it fuser-friendly
43 | w = self.weight.reshape(1, -1, 1, 1)
44 | b = self.bias.reshape(1, -1, 1, 1)
45 | rv = self.running_var.reshape(1, -1, 1, 1)
46 | rm = self.running_mean.reshape(1, -1, 1, 1)
47 | eps = 1e-5
48 | scale = w * (rv + eps).rsqrt()
49 | bias = b - rm * scale
50 | return x * scale + bias
51 |
52 |
53 | class Backbone(nn.Module):
54 | def __init__(self, args):
55 | super(Backbone, self).__init__()
56 |
57 | if args.frozen_batch_norm:
58 | backbone = getattr(torchvision.models, args.backbone)(
59 | replace_stride_with_dilation=[False, False, args.dilation],
60 | pretrained=True, norm_layer=FrozenBatchNorm2d)
61 | else:
62 | backbone = getattr(torchvision.models, args.backbone)(
63 | replace_stride_with_dilation=[False, False, args.dilation],
64 | pretrained=True)
65 |
66 | self.num_frames = args.num_frame
67 | self.num_channels = 512 if args.backbone in ('resnet18', 'resnet34') else 2048
68 |
69 | self.body = IntermediateLayerGetter(backbone, return_layers={'layer4': "0"})
70 |
71 | def forward(self, x):
72 | x = self.body(x)["0"]
73 |
74 | return x
75 |
76 |
77 | class Joiner(nn.Sequential):
78 | def __init__(self, backbone, position_embedding):
79 | super().__init__(backbone, position_embedding)
80 |
81 | def forward(self, x):
82 | bs, t, _, h, w = x.shape
83 | x = x.reshape(bs * t, 3, h, w)
84 |
85 | features = self[0](x)
86 | _, c, oh, ow = features.shape
87 |
88 | pos = self[1](features).to(x.dtype)
89 |
90 | return features, pos
91 |
92 |
93 | def build_backbone(args):
94 | pos_embed = build_position_encoding(args)
95 | backbone = Backbone(args)
96 | model = Joiner(backbone, pos_embed)
97 | model.num_channels = backbone.num_channels
98 | return model
99 |
--------------------------------------------------------------------------------
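A quick shape check for the Backbone module above; the Joiner additionally needs the position-encoding arguments from position_encoding.py, which are not shown here. The argument values are hypothetical:

    import torch
    from types import SimpleNamespace
    from models.backbone import Backbone

    args = SimpleNamespace(frozen_batch_norm=True, backbone='resnet18',
                           dilation=False, num_frame=1)
    body = Backbone(args)
    feat = body(torch.randn(2, 3, 720, 1280))   # stride-32 feature map
    print(feat.shape, body.num_channels)        # torch.Size([2, 512, 23, 40]) 512
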
/models/criterion.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR)
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | import torch
9 | import torch.nn.functional as F
10 | import copy
11 | import numpy as np
12 |
13 | from torch import nn
14 |
15 | from util import box_ops
16 | from util.misc import (accuracy, get_world_size, is_dist_avail_and_initialized)
17 |
18 |
19 | class SetCriterion(nn.Module):
20 | def __init__(self, num_classes, weight_dict, eos_coef, losses, group_losses=None,
21 | group_matcher=None, args=None):
22 | """ Create the criterion.
23 | Parameters:
24 | num_classes: number of object categories, omitting the special no-object category
25 | weight_dict: dict containing as key the names of the losses and as values their relative weight.
26 | eos_coef: relative classification weight applied to the no-group activity class category
27 | losses: list of all the losses to be applied. See get_loss for list of available losses.
28 | group_losses: list of all the group losses to be applied. See get_group_loss for list of available group losses.
29 | group_matcher: module able to compute a matching between targets and predictions
30 | """
31 | super().__init__()
32 | self.num_classes = num_classes
33 | self.weight_dict = weight_dict
34 | self.losses = losses
35 | self.eos_coef = eos_coef
36 |
37 | self.group_losses = group_losses
38 | self.group_matcher = group_matcher
39 |
40 | empty_weight = torch.ones(self.num_classes + 1)
41 | empty_weight[-1] = eos_coef
42 | self.register_buffer('empty_weight', empty_weight)
43 |
44 | empty_group_weight = torch.ones(self.num_classes + 1)
45 | empty_group_weight[-1] = args.group_eos_coef
46 | self.register_buffer('empty_group_weight', empty_group_weight)
47 |
48 | self.num_boxes = args.num_boxes
49 |
50 | # option
51 | self.temperature = args.temperature
52 |
53 | #######################################################################################################################
54 | # * Individual Losses
55 | #######################################################################################################################
56 | def loss_labels(self, outputs, targets, num_boxes, log=True):
57 | """Individual action classification loss (NLL)"""
58 | assert 'pred_actions' in outputs
59 | src_logits = outputs['pred_actions']
60 | target_classes = torch.cat([v["actions"] for v in targets], dim=0)
61 |
62 | loss_ce = 0.0
63 |
64 | src_logits_log = None
65 | tgt_classes_log = None
66 |
67 | for batch_idx in range(src_logits.shape[0]):
68 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze()
69 | non_dummy_idx = dummy_idx.nonzero(as_tuple=True)
70 | src_logit = src_logits[batch_idx][non_dummy_idx].unsqueeze(0)
71 | target_class = target_classes[batch_idx][non_dummy_idx].unsqueeze(0)
72 | loss_ce += F.cross_entropy(src_logit.transpose(1, 2), target_class, self.empty_weight)
73 |
74 | if src_logits_log is None:
75 | src_logits_log = src_logit
76 | tgt_classes_log = target_class
77 | else:
78 | src_logits_log = torch.cat([src_logits_log.squeeze(), src_logit.squeeze()], dim=0)
79 | tgt_classes_log = torch.cat([tgt_classes_log.squeeze(), target_class.squeeze()], dim=0)
80 |
81 | loss_ce /= src_logits.shape[0]
82 | losses = {'loss_ce': loss_ce}
83 |
84 | if log:
85 | # TODO this should probably be a separate loss, not hacked in this one here
86 | losses['class_error'] = 100 - accuracy(src_logits_log, tgt_classes_log)[0]
87 |
88 | return losses
89 |
90 | @torch.no_grad()
91 | def loss_cardinality(self, outputs, targets, num_boxes):
92 | pred_logits = outputs['pred_actions']
93 | device = pred_logits.device
94 | # tgt_lengths = torch.as_tensor([len(v["actions"]) for v in targets], device=device)
95 | tgt_lengths = torch.as_tensor([len(k) for v in targets for k in v["actions"]], device=device)
96 | # Count the number of predictions that are NOT "no-object" (which is the last class)
97 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
98 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
99 | losses = {'cardinality_error': card_err}
100 | return losses
101 |
102 | #######################################################################################################################
103 | # * Group Losses
104 | #######################################################################################################################
105 | def loss_group_labels(self, outputs, targets, group_indices, log=True):
106 | """ Group activity classification loss (NLL)"""
107 | assert 'pred_activities' in outputs
108 | src_logits = outputs['pred_activities']
109 |
110 | idx = self._get_src_permutation_idx(group_indices)
111 | flatten_targets = [u for t in targets for u in t["activities"]]
112 | target_classes_o = torch.cat([t[J] for t, (_, J) in zip(flatten_targets, group_indices)])
113 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64,
114 | device=src_logits.device)
115 | target_classes[idx] = target_classes_o
116 |
117 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_group_weight)
118 | losses = {'loss_group_ce': loss_ce}
119 |
120 | if log:
121 | # TODO this should probably be a separate loss, not hacked in this one here
122 | losses['group_class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0]
123 |
124 | return losses
125 |
126 | @torch.no_grad()
127 | def loss_group_cardinality(self, outputs, targets, group_indices):
128 | pred_logits = outputs['pred_activities']
129 | device = pred_logits.device
130 | tgt_lengths = torch.as_tensor([len(k) for v in targets for k in v["activities"]], device=device)
131 | # Count the number of predictions that are NOT "no-object" (which is the last class)
132 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1)
133 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
134 | losses = {'group_cardinality_error': card_err}
135 | return losses
136 |
137 | def loss_group_code(self, outputs, targets, group_indices, log=True):
138 | """Membership loss"""
139 | sim = outputs['membership']
140 |
141 | idx = self._get_src_permutation_idx(group_indices)
142 |
143 | # Binary cross entropy loss
144 | flatten_targets = [u for t in targets for u in t["members"]]
145 |
146 | # target_members_o = torch.cat([t[J] for t, (_, J) in zip(flatten_targets, group_indices)]).type(torch.FloatTensor).to(sim.device)
147 | target_members_o = torch.cat([t[J] for t, (_, J) in zip(flatten_targets, group_indices)])
148 | target_members = torch.full(sim.shape, 0.0, dtype=torch.float, device=sim.device)
149 | target_members[idx] = target_members_o
150 |
151 | loss_membership = 0.0
152 | for batch_idx in range(sim.shape[0]):
153 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze()
154 | non_dummy_idx = dummy_idx.nonzero(as_tuple=True)
155 | sim_batch = sim[batch_idx].transpose(0, 1)[non_dummy_idx].transpose(0, 1).unsqueeze(0)
156 |
157 | target_members_batch = target_members[batch_idx].transpose(0, 1)[non_dummy_idx].transpose(0, 1).unsqueeze(0)
158 |
159 | loss_membership += F.binary_cross_entropy(sim_batch, target_members_batch)
160 | loss_membership /= sim.shape[0]
161 |
162 | losses = {'loss_group_code': loss_membership}
163 | return losses
164 |
165 | def loss_group_consistency(self, outputs, targets, group_indices):
166 | """Group consistency loss"""
167 | actor_embeds = outputs['actor_embeddings']
168 |
169 | consistency_loss = 0.0
170 |
171 | for batch_idx in range(actor_embeds.shape[0]):
172 | membership = targets[batch_idx]["membership"][0]
173 | actor_embed = actor_embeds[batch_idx] # [n, f]
174 |
175 | cos = nn.CosineSimilarity(dim=-1)
176 | sim = cos(actor_embed.unsqueeze(1), actor_embed.unsqueeze(0)) / self.temperature
177 |
178 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze()
179 | non_dummy_idx = dummy_idx.nonzero(as_tuple=True)
180 |
181 | N = len(non_dummy_idx[0])
182 |
183 | non_dummy_membership = membership[non_dummy_idx]
184 |
185 | group_count = 0
186 |
187 | for actor_idx in range(N):
188 | group_id = non_dummy_membership[actor_idx]
189 |
190 | if group_id != -1:
191 | positive_idx = (non_dummy_membership == group_id).nonzero(as_tuple=True)
192 | positive_idx = list(positive_idx[0])
193 | positive_idx.remove(actor_idx)
194 | positive_idx = [tuple(positive_idx)]
195 | positive_samples = sim[actor_idx][positive_idx]
196 |
197 | negative_idx = (non_dummy_membership != group_id).nonzero(as_tuple=True)
198 | negative_samples = sim[actor_idx][negative_idx]
199 |
200 |                     numerator = torch.exp(positive_samples)
201 |                     denominator = torch.exp(torch.cat((positive_samples, negative_samples)))
202 |                     loss_partial = -torch.log(torch.sum(numerator) / torch.sum(denominator))
203 | group_count += 1
204 |
205 | consistency_loss += loss_partial
206 |
207 | consistency_loss /= group_count
208 |
209 | consistency_loss /= actor_embeds.shape[0]
210 | losses = {'loss_consistency': consistency_loss}
211 | return losses
212 |
213 | def _get_src_permutation_idx(self, indices):
214 | # permute predictions following indices
215 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
216 | src_idx = torch.cat([src for (src, _) in indices])
217 | return batch_idx, src_idx
218 |
219 | def _get_tgt_permutation_idx(self, indices):
220 | # permute targets following indices
221 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
222 | tgt_idx = torch.cat([tgt for (_, tgt) in indices])
223 | return batch_idx, tgt_idx
224 |
225 | # *****************************************************************************
226 | # >>> DETR Losses
227 | def get_loss(self, loss, outputs, targets, num_boxes, **kwargs):
228 | loss_map = {
229 | 'labels': self.loss_labels,
230 | 'cardinality': self.loss_cardinality,
231 | }
232 | assert loss in loss_map, f'do you really want to compute {loss} loss?'
233 | return loss_map[loss](outputs, targets, num_boxes, **kwargs)
234 |
235 | # >>> Group Losses
236 | def get_group_loss(self, loss, outputs, targets, group_indices, **kwargs):
237 | loss_map = {
238 | 'group_labels': self.loss_group_labels,
239 | 'group_cardinality': self.loss_group_cardinality,
240 | 'group_code': self.loss_group_code,
241 | 'group_consistency': self.loss_group_consistency,
242 | }
243 | assert loss in loss_map, f'do you really want to compute {loss} loss?'
244 | return loss_map[loss](outputs, targets, group_indices, **kwargs)
245 |
246 | # *****************************************************************************
247 |
248 | def forward(self, outputs, targets, log=True):
249 | """ This performs the loss computation.
250 | Parameters:
251 | outputs: dict of tensors, see the output specification of the model for the format
252 | targets: list of dicts, such that len(targets) == batch_size.
253 | The expected keys in each dict depends on the losses applied, see each loss' doc
254 | """
255 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
256 |
257 | num_boxes = sum(len(u) for t in targets for u in t["actions"])
258 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
259 | if is_dist_avail_and_initialized():
260 | torch.distributed.all_reduce(num_boxes)
261 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
262 |
263 | sim = outputs['membership']
264 | bs, num_queries, num_clip_boxes = sim.shape
265 |
266 | for tgt in targets:
267 | tgt["dummy_idx"] = torch.ones_like(tgt["actions"], dtype=int)
268 | for box_idx in range(num_clip_boxes):
269 | if bool(tgt["actions"][0, box_idx] == self.num_classes + 1):
270 | tgt["dummy_idx"][0, box_idx] = 0
271 |
272 | input_targets = [copy.deepcopy(target) for target in targets]
273 | group_indices = self.group_matcher(outputs_without_aux, input_targets)
274 |
275 | # Compute all the requested losses
276 | losses = {}
277 | for loss in self.losses:
278 | losses.update(self.get_loss(loss, outputs, targets, num_boxes))
279 |
280 | # Group activity detection losses
281 | for loss in self.group_losses:
282 | losses.update(self.get_group_loss(loss, outputs, targets, group_indices))
283 |
284 | return losses
285 |
--------------------------------------------------------------------------------
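The losses above rely on the padding convention set by the dataloader: actor slots beyond the real actors carry the action label num_class + 1, and forward() turns that into a dummy_idx mask so the padded slots are excluded. A toy illustration of that mask:

    import torch

    num_class = 6
    # two real actors followed by three padded slots, as produced by CafeDataset.load_samples
    actions = torch.tensor([[2, 0, num_class + 1, num_class + 1, num_class + 1]])
    dummy_idx = torch.ones_like(actions)
    dummy_idx[actions == num_class + 1] = 0
    print(dummy_idx)   # tensor([[1, 1, 0, 0, 0]]) -> only the first two slots enter the losses
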
/models/feed_forward.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR)
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | import torch.nn.functional as F
6 | from torch import nn
7 |
8 |
9 | class MLP(nn.Module):
10 | """ Very simple multi-layer perceptron (also called FFN)"""
11 |
12 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
13 | super().__init__()
14 | self.num_layers = num_layers
15 | h = [hidden_dim] * (num_layers - 1)
16 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
17 |
18 | def forward(self, x):
19 | for i, layer in enumerate(self.layers):
20 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
21 | return x
--------------------------------------------------------------------------------
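A minimal shape check for the MLP head above; the dimensions are illustrative only:

    import torch
    from models.feed_forward import MLP

    # three layers: 256 -> 256 -> 256 -> 4, ReLU between hidden layers, no activation on the output
    head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    out = head(torch.randn(2, 12, 256))
    print(out.shape)   # torch.Size([2, 12, 4])
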
/models/group_matcher.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from DETR (https://github.com/facebookresearch/detr)
3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | """
6 | Modules to compute the matching cost and solve the corresponding LSAP.
7 | """
8 | import torch
9 | from scipy.optimize import linear_sum_assignment
10 | from torch import nn
11 | import torch.nn.functional as F
12 |
13 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
14 |
15 |
16 | class HungarianMatcher(nn.Module):
17 | """This class computes an assignment between the targets and the predictions of the network
18 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
19 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
20 | while the others are un-matched (and thus treated as non-objects).
21 | """
22 |
23 | def __init__(self, cost_class: float = 1, cost_code: float = 1):
24 | """Creates the matcher
25 | Params:
26 | cost_class: This is the relative weight of the classification error in the matching cost
27 | cost_code: This is the relative weight of the L2 error of the membership code in the matching cost
28 | """
29 | super().__init__()
30 | self.cost_class = cost_class
31 | self.cost_code = cost_code
32 |         assert cost_class != 0 or cost_code != 0, "all costs can't be 0"
33 |
34 | # membership cost
35 | def _get_cost_code(self, sim, targets_membership):
36 | cost_code = torch.cdist(sim, targets_membership, p=2)
37 | return cost_code
38 |
39 | @torch.no_grad()
40 | def forward(self, outputs, targets):
41 | bs, num_queries, num_boxes = outputs["membership"].shape
42 |
43 | # We flatten to compute the cost matrices in a batch
44 | out_prob = outputs["pred_activities"].flatten(0, 1).softmax(-1) # [bs * t * num_queries, num_classes]
45 |
46 | sim = outputs["membership"]
47 |
48 | if "dummy_idx" in targets[0].keys():
49 | sim_dummy = []
50 | for batch_idx in range(bs):
51 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze()
52 | sim_batch = sim[batch_idx] * dummy_idx
53 | sim_dummy.append(sim_batch.unsqueeze(0))
54 | sim = torch.cat(sim_dummy, dim=0)
55 |
56 | sim = sim.flatten(0, 1)
57 |
58 | # Also concat the target labels and boxes
59 | tgt_ids = torch.cat([v["activities"] for v in targets], dim=1).reshape(-1)
60 | targets_membership = torch.cat([v["members"] for v in targets], dim=1).reshape(-1, num_boxes)
61 |
62 | # Compute the classification cost. Contrary to the loss, we don't use the NLL,
63 | # but approximate it in 1 - proba[target class].
64 |         # The 1 is a constant that doesn't change the matching, it can be omitted.
65 | cost_class = -out_prob[:, tgt_ids]
66 |         # If the code cost were a similarity score, it would need a minus sign.
67 | cost_code = self._get_cost_code(sim, targets_membership)
68 |
69 | # Final cost matrix
70 | C = self.cost_class * cost_class + self.cost_code * cost_code
71 | C = C.view(bs, num_queries, -1).cpu()
72 |
73 | sizes = [len(k) for v in targets for k in v["activities"]]
74 |
75 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
76 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
77 |
78 |
79 | def build_group_matcher(args):
80 | return HungarianMatcher(cost_class=args.set_cost_group_class, cost_code=args.set_cost_membership)
81 |
--------------------------------------------------------------------------------
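The matcher builds a per-clip cost matrix (classification cost plus an L2 membership-code cost) and solves it with the Hungarian algorithm. A toy illustration of the final linear_sum_assignment step, with made-up costs:

    import numpy as np
    from scipy.optimize import linear_sum_assignment

    # 3 group queries vs. 2 ground-truth groups; lower cost means a better match
    C = np.array([[0.9, 0.2],
                  [0.1, 0.8],
                  [0.5, 0.4]])
    rows, cols = linear_sum_assignment(C)
    print([(int(r), int(c)) for r, c in zip(rows, cols)])   # [(0, 1), (1, 0)]
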
/models/group_transformer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR)
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | import copy
9 | from typing import Optional
10 |
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn, Tensor
14 |
15 |
16 | class Transformer(nn.Module):
17 |
18 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_actor=14,
19 | dim_feedforward=2048, dropout=0.1, activation="relu", return_intermediate_dec=False):
20 | super().__init__()
21 |
22 | decoder_layer = TransformerDecoderLayer(d_model, nhead, num_actor, dim_feedforward, dropout, activation)
23 | decoder_norm = nn.LayerNorm(d_model)
24 | self.decoder = TransformerDecoder(decoder_layer, num_encoder_layers, decoder_norm,
25 | return_intermediate=return_intermediate_dec)
26 |
27 | self._reset_parameters()
28 | self.d_model = d_model
29 | self.nhead = nhead
30 |
31 | def _reset_parameters(self):
32 | for p in self.parameters():
33 | if p.dim() > 1:
34 | nn.init.xavier_uniform_(p)
35 |
36 | def forward(self, src, actor_mask, group_dummy_mask, group_embed, pos_embed, actor_embed):
37 | bs, t, c, h, w = src.shape
38 | _, _, n, _ = actor_embed.shape
39 | src = src.reshape(bs * t, c, h, w).flatten(2).permute(2, 0, 1) # [h x w, bs x t, c]
40 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
41 | group_embed = group_embed.reshape(-1, t, c)
42 | group_embed = group_embed.unsqueeze(1).repeat(1, bs, 1, 1).reshape(-1, bs * t, c)
43 | actor_embed = actor_embed.reshape(bs * t, -1, c).permute(1, 0, 2) # [n, bs x t, c]
44 | query_embed = torch.cat([actor_embed, group_embed], dim=0) # [n + k, bs x t, c]
45 | tgt = torch.zeros_like(query_embed) # [n + k, bs x t, c]
46 | if actor_mask is not None:
47 | actor_mask = actor_mask.unsqueeze(1).repeat(1, self.nhead, 1, 1).reshape(-1, n, n)
48 | hs, actor_att, feature_att = self.decoder(tgt, src, attn_mask=actor_mask,
49 | tgt_key_padding_mask=group_dummy_mask,
50 | pos=pos_embed, query_pos=query_embed)
51 |
52 | return hs.transpose(1, 2), actor_att, feature_att
53 |
54 |
55 | class TransformerDecoder(nn.Module):
56 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
57 | super().__init__()
58 | self.layers = _get_clones(decoder_layer, num_layers)
59 | self.num_layers = num_layers
60 | self.norm = norm
61 | self.return_intermediate = return_intermediate
62 |
63 | def forward(self, tgt, memory,
64 | tgt_mask: Optional[Tensor] = None,
65 | memory_mask: Optional[Tensor] = None,
66 | tgt_key_padding_mask: Optional[Tensor] = None,
67 | memory_key_padding_mask: Optional[Tensor] = None,
68 | attn_mask: Optional[Tensor] = None,
69 | pos: Optional[Tensor] = None,
70 | query_pos: Optional[Tensor] = None):
71 | output = tgt
72 | actor_att = None
73 | feature_att = None
74 |
75 | intermediate = []
76 | intermediate_actor_att = []
77 | intermediate_feature_att = []
78 |
79 | for layer in self.layers:
80 | output, actor_att, feature_att = layer(output, memory, tgt_mask=tgt_mask,
81 | memory_mask=memory_mask,
82 | tgt_key_padding_mask=tgt_key_padding_mask,
83 | memory_key_padding_mask=memory_key_padding_mask,
84 | attn_mask=attn_mask,
85 | pos=pos, query_pos=query_pos)
86 | if self.return_intermediate:
87 | intermediate.append(self.norm(output))
88 | intermediate_actor_att.append(actor_att)
89 | intermediate_feature_att.append(feature_att)
90 |
91 | if self.norm is not None:
92 | output = self.norm(output)
93 | if self.return_intermediate:
94 | intermediate.pop()
95 | intermediate.append(output)
96 |
97 | if self.return_intermediate:
98 | return torch.stack(intermediate), torch.stack(intermediate_actor_att), torch.stack(intermediate_feature_att)
99 |
100 | if actor_att is not None:
101 | actor_att = actor_att.unsqueeze(0)
102 | if feature_att is not None:
103 | feature_att = feature_att.unsqueeze(0)
104 |
105 | return output.unsqueeze(0), actor_att, feature_att
106 |
107 |
108 | class TransformerDecoderLayer(nn.Module):
109 |
110 | def __init__(self, d_model, nhead, num_actor, dim_feedforward=2048, dropout=0.1, activation="relu"):
111 | super().__init__()
112 |
113 | self.multihead_attn1 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
114 |
115 | self.self_attn1 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
116 | self.self_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
117 | self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
118 |
119 | # Implementation of Feedforward model
120 | self.linear1 = nn.Linear(d_model, dim_feedforward)
121 | self.dropout = nn.Dropout(dropout)
122 | self.linear2 = nn.Linear(dim_feedforward, d_model)
123 |
124 | self.norm1 = nn.LayerNorm(d_model)
125 | self.norm2 = nn.LayerNorm(d_model)
126 | self.norm3 = nn.LayerNorm(d_model)
127 | self.norm4 = nn.LayerNorm(d_model)
128 | self.dropout1 = nn.Dropout(dropout)
129 | self.dropout2 = nn.Dropout(dropout)
130 | self.dropout3 = nn.Dropout(dropout)
131 | self.dropout4 = nn.Dropout(dropout)
132 |
133 | self.activation = _get_activation_fn(activation)
134 |
135 | self.num_actor = num_actor
136 |
137 | def with_pos_embed(self, tensor, pos: Optional[Tensor]):
138 | return tensor if pos is None else tensor + pos
139 |
140 | def forward_post(self, tgt, memory,
141 | tgt_mask: Optional[Tensor] = None,
142 | memory_mask: Optional[Tensor] = None,
143 | tgt_key_padding_mask: Optional[Tensor] = None,
144 | memory_key_padding_mask: Optional[Tensor] = None,
145 | attn_mask=None,
146 | pos: Optional[Tensor] = None,
147 | query_pos: Optional[Tensor] = None):
148 |
149 | q = k = self.with_pos_embed(tgt, query_pos)
150 |
151 | actor_q = q[:self.num_actor, :, :]
152 | group_q = q[self.num_actor:, :, :]
153 | actor_k = k[:self.num_actor, :, :]
154 | group_k = k[self.num_actor:, :, :]
155 |
156 | tgt_actor = tgt[:self.num_actor, :, :]
157 | tgt_group = tgt[self.num_actor:, :, :]
158 |
159 | # actor-actor, group-group self-attention
160 | tgt2_actor, actor_att = self.self_attn1(actor_q, actor_k, value=tgt_actor, attn_mask=attn_mask)
161 | tgt2_group, _ = self.self_attn2(group_q, group_k, value=tgt_group)
162 | tgt2 = torch.cat([tgt2_actor, tgt2_group], dim=0)
163 |
164 | tgt = tgt + self.dropout1(tgt2)
165 | tgt = self.norm1(tgt)
166 |
167 | # actor-group attention
168 | tgt_actor = tgt[:self.num_actor, :, :]
169 | tgt_group = tgt[self.num_actor:, :, :]
170 |
171 | tgt2_group, group_att = self.multihead_attn1(query=tgt_group, key=tgt_actor, value=tgt_actor,
172 | key_padding_mask=tgt_key_padding_mask)
173 |
174 | tgt2 = torch.cat([tgt_actor, tgt2_group], dim=0)
175 |
176 | tgt = tgt + self.dropout2(tgt2)
177 | tgt = self.norm2(tgt)
178 |
179 | # actor-feature, group-feature cross-attention
180 | tgt2, feature_att = self.multihead_attn2(query=self.with_pos_embed(tgt, query_pos),
181 | key=self.with_pos_embed(memory, pos), value=memory)
182 | tgt = tgt + self.dropout3(tgt2)
183 | tgt = self.norm3(tgt)
184 |
185 | # FFN
186 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
187 | tgt = tgt + self.dropout4(tgt2)
188 | tgt = self.norm4(tgt)
189 |
190 | return tgt, group_att, feature_att
191 |
192 | def forward(self, tgt, memory,
193 | tgt_mask: Optional[Tensor] = None,
194 | memory_mask: Optional[Tensor] = None,
195 | tgt_key_padding_mask: Optional[Tensor] = None,
196 | memory_key_padding_mask: Optional[Tensor] = None,
197 | attn_mask: Optional[Tensor] = None,
198 | pos: Optional[Tensor] = None,
199 | query_pos: Optional[Tensor] = None):
200 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask,
201 | memory_key_padding_mask, attn_mask, pos, query_pos)
202 |
203 |
204 | def _get_clones(module, N):
205 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
206 |
207 |
208 | def build_group_transformer(args):
209 | return Transformer(
210 | d_model=args.hidden_dim,
211 | dropout=args.drop_rate,
212 | nhead=args.gar_nheads,
213 | dim_feedforward=args.gar_ffn_dim,
214 | num_encoder_layers=args.gar_enc_layers,
215 | return_intermediate_dec=False,
216 | num_actor=args.num_boxes,
217 | )
218 |
219 |
220 | def _get_activation_fn(activation):
221 | """Return an activation function given a string"""
222 | if activation == "relu":
223 | return F.relu
224 | if activation == "gelu":
225 | return F.gelu
226 | if activation == "glu":
227 | return F.glu
228 |     raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.")
229 |
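
A quick way to see the tensor layout `Transformer.forward` expects is to run it once on dummy inputs. The sketch below follows the reshapes in the code (`src` as [bs, t, c, h, w], actor embeddings as [bs, t, n, c], group queries as [k x t, c]); the concrete sizes are arbitrary, `actor_mask=None` skips the distance mask, and it assumes the repo environment, since importing the `models` package pulls in the RoIAlign extension installed by scripts/setup.sh.

```python
# Dummy forward pass through the group transformer; sizes are arbitrary.
# Assumes it is run from the repo root with the RoIAlign extension installed.
import torch
from models.group_transformer import Transformer

bs, t, c, h, w, n, k = 2, 5, 256, 8, 8, 14, 12
model = Transformer(d_model=c, nhead=4, num_encoder_layers=2,
                    num_actor=n, dim_feedforward=512)

src = torch.randn(bs, t, c, h, w)                 # projected backbone features
pos = torch.randn(bs * t, c, h, w)                # positional encoding of the feature map
group_embed = torch.randn(k * t, c)               # learnable group queries (k tokens per frame)
actor_embed = torch.randn(bs, t, n, c)            # RoI-aligned actor features
group_dummy_mask = torch.zeros(bs * t, n, dtype=torch.bool)  # no padded actors in this toy case

hs, actor_att, feature_att = model(src, None, group_dummy_mask, group_embed, pos, actor_embed)
print(hs.shape)           # [1, bs * t, n + k, c]: actor tokens first, then group tokens
print(actor_att.shape)    # [1, bs * t, k, n]:     group-to-actor attention
print(feature_att.shape)  # [1, bs * t, n + k, h * w]
```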
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from roi_align.roi_align import RoIAlign
6 |
7 | from .backbone import build_backbone
8 | from .group_transformer import build_group_transformer
9 | from .feed_forward import MLP
10 |
11 |
12 | class GADTR(nn.Module):
13 | def __init__(self, args):
14 | super(GADTR, self).__init__()
15 |
16 | self.dataset = args.dataset
17 | self.num_class = args.num_class
18 | self.num_frame = args.num_frame
19 | self.num_boxes = args.num_boxes
20 |
21 | self.hidden_dim = args.hidden_dim
22 | self.backbone = build_backbone(args)
23 |
24 | # RoI Align
25 | self.crop_size = args.crop_size
26 | self.roi_align = RoIAlign(crop_height=self.crop_size, crop_width=self.crop_size)
27 | self.fc_emb = nn.Linear(self.crop_size*self.crop_size*self.backbone.num_channels, self.hidden_dim)
28 | self.drop_emb = nn.Dropout(p=args.drop_rate)
29 |
30 | # Actor embedding
31 | self.input_proj = nn.Conv2d(self.backbone.num_channels, self.hidden_dim, kernel_size=1)
32 | self.box_pos_emb = MLP(4, self.hidden_dim, self.hidden_dim, 3)
33 |
34 | # Individual action classification head
35 | self.class_emb = nn.Linear(self.hidden_dim, self.num_class + 1)
36 |
37 | # Group Transformer
38 | self.group_transformer = build_group_transformer(args)
39 | self.num_group_tokens = args.num_group_tokens
40 | self.group_query_emb = nn.Embedding(self.num_group_tokens * self.num_frame, self.hidden_dim)
41 |
42 |         # Group activity classification head
43 | self.group_emb = nn.Linear(self.hidden_dim, self.num_class + 1)
44 |
45 | # Distance mask threshold
46 | self.distance_threshold = args.distance_threshold
47 |
48 | # Membership prediction heads
49 | self.actor_match_emb = nn.Linear(self.hidden_dim, self.hidden_dim)
50 | self.group_match_emb = nn.Linear(self.hidden_dim, self.hidden_dim)
51 |
52 | self.relu = F.relu
53 |
54 | for name, m in self.named_modules():
55 | if 'backbone' not in name and 'group_transformer' not in name:
56 | if isinstance(m, nn.Linear):
57 | nn.init.kaiming_normal_(m.weight)
58 | if m.bias is not None:
59 | nn.init.zeros_(m.bias)
60 |
61 |     def calculate_pairwise_distance(self, boxes):
62 | bs = boxes.shape[0]
63 |
64 | rx = boxes.pow(2).sum(dim=2).reshape((bs, -1, 1))
65 | ry = boxes.pow(2).sum(dim=2).reshape((bs, -1, 1))
66 |
67 | dist = rx - 2.0 * boxes.matmul(boxes.transpose(1, 2)) + ry.transpose(1, 2)
68 |
69 | return torch.sqrt(dist)
70 |
71 | def forward(self, x, boxes, dummy_mask):
72 | """
73 | :param x: [B, T, 3, H, W]
74 | :param boxes: [B, T, N, 4]
75 | :param dummy_mask: [B, N]
76 | :return:
77 | """
78 | bs, t, _, h, w = x.shape
79 | n = boxes.shape[2]
80 |
81 | boxes = torch.reshape(boxes, (-1, 4)) # [b x t x n, 4]
82 | boxes_flat = boxes.clone().detach()
83 | boxes_idx = [i * torch.ones(n, dtype=torch.int) for i in range(bs * t)]
84 | boxes_idx = torch.stack(boxes_idx).to(device=boxes.device)
85 | boxes_idx_flat = torch.reshape(boxes_idx, (bs * t * n, )) # [b x t x n]
86 |
87 | features, pos = self.backbone(x)
88 | _, c, oh, ow = features.shape # [b x t, d, oh, ow]
89 |
90 | src = self.input_proj(features)
91 | src = torch.reshape(src, (bs, t, -1, oh, ow)) # [b, t, c, oh, ow]
92 |
93 | # calculate distance & distance mask
94 | boxes_center = boxes.clone().detach()
95 | boxes_center = torch.reshape(boxes_center[:, :2], (-1, n, 2))
96 |         boxes_distance = self.calculate_pairwise_distance(boxes_center)
97 |
98 | distance_mask = (boxes_distance > self.distance_threshold)
99 |
100 | # ignore dummy boxes (padded boxes to match the number of actors)
101 | dummy_mask = dummy_mask.unsqueeze(1).repeat(1, t, 1).reshape(-1, n)
102 | actor_dummy_mask = (~dummy_mask.unsqueeze(2)).float() @ (~dummy_mask.unsqueeze(1)).float()
103 | dummy_diag = (dummy_mask.unsqueeze(2).float() @ dummy_mask.unsqueeze(1).float()).nonzero(as_tuple=True)
104 | actor_mask = ~(actor_dummy_mask.bool())
105 | actor_mask[dummy_diag] = False
106 | actor_mask = distance_mask + actor_mask
107 | group_dummy_mask = dummy_mask
108 |
109 | boxes_flat[:, 0] = (boxes[:, 0] - boxes[:, 2] / 2) * ow
110 | boxes_flat[:, 1] = (boxes[:, 1] - boxes[:, 3] / 2) * oh
111 | boxes_flat[:, 2] = (boxes[:, 0] + boxes[:, 2] / 2) * ow
112 | boxes_flat[:, 3] = (boxes[:, 1] + boxes[:, 3] / 2) * oh
113 |
114 | boxes_flat.requires_grad = False
115 | boxes_idx_flat.requires_grad = False
116 |
117 | # extract actor features
118 | actor_features = self.roi_align(features, boxes_flat, boxes_idx_flat)
119 | actor_features = torch.reshape(actor_features, (bs * t * n, -1))
120 | actor_features = self.fc_emb(actor_features)
121 | actor_features = F.relu(actor_features)
122 | actor_features = self.drop_emb(actor_features)
123 | actor_features = actor_features.reshape(bs, t, n, self.hidden_dim)
124 |
125 | # add positional information to box features
126 | box_pos_emb = self.box_pos_emb(boxes)
127 | box_pos_emb = torch.reshape(box_pos_emb, (bs, t, n, -1)) # [b, t, n, c]
128 | actor_features = actor_features + box_pos_emb
129 |
130 | # group transformer
131 | hs, actor_att, feature_att = self.group_transformer(src, actor_mask, group_dummy_mask,
132 | self.group_query_emb.weight, pos, actor_features)
133 |         # hs: [1, bs * t, n + k, c], actor_att: [1, bs * t, k, n], feature_att: [1, bs * t, n + k, oh x ow]  (n: # boxes, k: # group tokens)
134 |
135 | actor_hs = hs[0, :, :n]
136 | group_hs = hs[0, :, n:]
137 |
138 | actor_hs = actor_hs.reshape(bs, t, n, -1)
139 | actor_hs = actor_features + actor_hs
140 |
141 | # normalize
142 | inst_repr = F.normalize(actor_hs.reshape(bs, t, n, -1).mean(dim=1), p=2, dim=2)
143 | group_repr = F.normalize(group_hs.reshape(bs, t, self.num_group_tokens, -1).mean(dim=1), p=2, dim=2)
144 |
145 | # prediction heads
146 | outputs_class = self.class_emb(actor_hs)
147 |
148 | outputs_group_class = self.group_emb(group_hs)
149 |
150 | outputs_actor_emb = self.actor_match_emb(inst_repr)
151 | outputs_group_emb = self.group_match_emb(group_repr)
152 |
153 | membership = torch.bmm(outputs_group_emb, outputs_actor_emb.transpose(1, 2))
154 | membership = F.softmax(membership, dim=1)
155 |
156 | out = {
157 | "pred_actions": outputs_class.reshape(bs, t, self.num_boxes, self.num_class + 1).mean(dim=1),
158 | "pred_activities": outputs_group_class.reshape(bs, t, self.num_group_tokens, self.num_class + 1).mean(dim=1),
159 | "membership": membership.reshape(bs, self.num_group_tokens, self.num_boxes),
160 | "actor_embeddings": F.normalize(actor_hs.reshape(bs, t, n, -1).mean(dim=1), p=2, dim=2),
161 | }
162 |
163 | return out
164 |
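
The membership head above scores every (group token, actor) pair with a dot product and normalizes over the group dimension, so each actor's scores across the k group tokens sum to 1. A self-contained sketch of just that step, with toy sizes and random embeddings:

```python
# Sketch of the membership prediction: dot products between group and actor
# embeddings, softmax over the group-token dimension. Sizes are illustrative.
import torch
import torch.nn.functional as F

bs, k, n, d = 2, 12, 14, 256
group_emb = F.normalize(torch.randn(bs, k, d), p=2, dim=2)    # group representations
actor_emb = F.normalize(torch.randn(bs, n, d), p=2, dim=2)    # actor representations

membership = torch.bmm(group_emb, actor_emb.transpose(1, 2))  # [bs, k, n] pairwise similarity
membership = F.softmax(membership, dim=1)                     # each actor distributed over groups

assert torch.allclose(membership.sum(dim=1), torch.ones(bs, n))
print(membership.argmax(dim=1))                               # hard group assignment per actor
```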
--------------------------------------------------------------------------------
/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from DETR (https://github.com/facebookresearch/detr)
3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 |
6 | import math
7 | import torch
8 | from torch import nn
9 |
10 |
11 | class PositionEmbeddingSine(nn.Module):
12 | """
13 | This is a more standard version of the position embedding, very similar to the one
14 | used by the Attention is all you need paper, generalized to work on images.
15 | """
16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
17 | super().__init__()
18 | self.num_pos_feats = num_pos_feats
19 | self.temperature = temperature
20 | self.normalize = normalize
21 | if scale is not None and normalize is False:
22 | raise ValueError("normalize should be True if scale is passed")
23 | if scale is None:
24 | scale = 2 * math.pi
25 | self.scale = scale
26 |
27 | def forward(self, x):
28 | bs, c, h, w = x.shape
29 |
30 | y_embed = torch.arange(1, h + 1, device=x.device).unsqueeze(0).unsqueeze(2)
31 | y_embed = y_embed.repeat(bs, 1, w)
32 | x_embed = torch.arange(1, w + 1, device=x.device).unsqueeze(0).unsqueeze(1)
33 | x_embed = x_embed.repeat(bs, h, 1)
34 |
35 | if self.normalize:
36 | eps = 1e-6
37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39 |
40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42 |
43 | pos_x = x_embed[:, :, :, None] / dim_t
44 | pos_y = y_embed[:, :, :, None] / dim_t
45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
48 | return pos
49 |
50 |
51 | class PositionEmbeddingLearned(nn.Module):
52 | """
53 | Absolute pos embedding, learned.
54 | """
55 | def __init__(self, num_pos_feats=256):
56 | super().__init__()
57 | self.row_embed = nn.Embedding(50, num_pos_feats)
58 | self.col_embed = nn.Embedding(50, num_pos_feats)
59 | self.reset_parameters()
60 |
61 | def reset_parameters(self):
62 | nn.init.uniform_(self.row_embed.weight)
63 | nn.init.uniform_(self.col_embed.weight)
64 |
65 | def forward(self, x):
66 | h, w = x.shape[-2:]
67 | i = torch.arange(w, device=x.device)
68 | j = torch.arange(h, device=x.device)
69 | x_emb = self.col_embed(i)
70 | y_emb = self.row_embed(j)
71 | pos = torch.cat([
72 | x_emb.unsqueeze(0).repeat(h, 1, 1),
73 | y_emb.unsqueeze(1).repeat(1, w, 1),
74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
75 | return pos
76 |
77 |
78 | def build_position_encoding(args):
79 | N_steps = args.hidden_dim // 2
80 | if args.position_embedding in ('v2', 'sine'):
81 | # TODO find a better way of exposing other arguments
82 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
83 | elif args.position_embedding in ('v3', 'learned'):
84 | position_embedding = PositionEmbeddingLearned(N_steps)
85 | else:
86 | raise ValueError(f"not supported {args.position_embedding}")
87 |
88 | return position_embedding
89 |
90 |
91 | def build_index_encoding(args):
92 | N_steps = args.hidden_dim // 2
93 | if args.index_embedding in ('v2', 'sine'):
94 | # TODO find a better way of exposing other arguments
95 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
96 | elif args.index_embedding in ('v3', 'learned'):
97 | position_embedding = PositionEmbeddingLearned(N_steps)
98 | else:
99 |         raise ValueError(f"not supported {args.index_embedding}")
100 |
101 | return position_embedding
102 |
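
Both encoders return a per-pixel embedding with 2 x num_pos_feats channels, which matches the transformer's hidden dimension when num_pos_feats = hidden_dim // 2. A quick shape check on dummy features (sizes arbitrary; assumes the repo environment so the `models` package imports cleanly):

```python
# Shape check for the sine position embedding; sizes are arbitrary.
import torch
from models.position_encoding import PositionEmbeddingSine

pos_enc = PositionEmbeddingSine(num_pos_feats=128, normalize=True)  # 2 * 128 = 256 channels
features = torch.randn(2, 512, 8, 8)   # [bs, c, h, w]; only bs, h, w matter to the encoder
pos = pos_enc(features)
print(pos.shape)                       # torch.Size([2, 256, 8, 8])
```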
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # Towards More Practical Group Activity Detection: A New Benchmark and Model
2 |
3 | ### [Dongkeun Kim](https://dk-kim.github.io/), [Youngkil Song](https://www.linkedin.com/in/youngkil-song-8936792a3/), [Minsu Cho](https://cvlab.postech.ac.kr/~mcho/), [Suha Kwak](https://suhakwak.github.io/)
4 |
5 | ### [Project Page](http://dk-kim.github.io/CAFE) | [Paper](https://arxiv.org/abs/2312.02878)
6 |
7 | ## Overview
8 | This work introduces a new benchmark dataset, Café, and a new model for group activity detection (GAD).
9 |
10 | ## Requirements
11 |
12 | - Ubuntu 20.04
13 | - Python 3.8.5
14 | - CUDA 11.0
15 | - PyTorch 1.7.1
16 |
17 | ## Conda environment installation
18 | conda env create --file environment.yml
19 |
20 | conda activate gad
21 |
22 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html
23 |
24 | ## Install additional package
25 | sh scripts/setup.sh
26 |
27 | ## Download datasets
28 |
29 | Download Café dataset from:
30 |
31 | https://cvlab.postech.ac.kr/research/CAFE/
32 |
33 | ## Download trained weights
34 |
35 | sh scripts/download_checkpoints.sh
36 |
37 | or from:
38 | https://drive.google.com/file/d/1W_2gkzARCzSdK8Db4G4pkzN3GrJTYo8R/view?usp=drive_link
39 |
40 | ## Run test scripts
41 |
42 | - Café dataset (split by view)
43 |
44 | sh scripts/test_cafe_view.sh
45 |
46 | - Café dataset (split by place)
47 |
48 |
49 | sh scripts/test_cafe_place.sh
50 |
51 | ## Run train scripts
52 |
53 | - Café dataset (split by view)
54 |
55 |
56 | sh scripts/train_cafe_view.sh
57 |
58 | - Café dataset (split by place)
59 |
60 |
61 | sh scripts/train_cafe_place.sh
62 |
63 |
64 | ## File structure
65 |
66 | ├── Dataset/
67 | │ └── cafe/
68 | │ └── gt_tracks.pkl
69 | ├── dataloader/
70 | ├── evaluation/
71 | │ └── gt_tracks.txt
72 | ├── label_map/
73 | ├── models/
74 | ├── scripts/
75 | └── util/
76 | train.py
77 | test.py
78 | environment.yml
79 | README.md
80 |
81 | ## Citation
82 | If you find our work useful, please consider citing our paper:
83 | ```BibTeX
84 | @article{kim2023towards,
85 | title={Towards More Practical Group Activity Detection: A New Benchmark and Model},
86 | author={Kim, Dongkeun and Song, Youngkil and Cho, Minsu and Kwak, Suha},
87 | journal={arXiv preprint arXiv:2312.02878},
88 | year={2023}
89 | }
90 | ```
91 |
92 | ## Acknowledgement
93 | This work was supported by the NRF grant and the IITP grant funded by the Ministry of Science and ICT, Korea (RS-2019-II191906, IITP-2020-0-00842, NRF-2021R1A2C3012728, RS-2022-II220264).
94 |
--------------------------------------------------------------------------------
/scripts/download_checkpoints.sh:
--------------------------------------------------------------------------------
1 | wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/file/d/1M0vmruhcU_SpYQL0SPqUcg-wKlfyOyWa/view?usp=drive_link' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1M0vmruhcU_SpYQL0SPqUcg-wKlfyOyWa" -O cafe_checkpoints.zip && rm -rf ~/cookies.txt
2 | unzip cafe_checkpoints.zip
3 | rm cafe_checkpoints.zip
--------------------------------------------------------------------------------
/scripts/download_datasets.sh:
--------------------------------------------------------------------------------
1 | wget https://cvlab.postech.ac.kr/research/CAFE/Cafe_Dataset.zip
2 | unzip Cafe_Dataset.zip
3 | rm Cafe_Dataset.zip
4 |
--------------------------------------------------------------------------------
/scripts/setup.sh:
--------------------------------------------------------------------------------
1 | wget https://github.com/longcw/RoIAlign.pytorch/archive/refs/heads/master.zip
2 | unzip master.zip
3 | rm master.zip
4 | cd RoIAlign.pytorch-master/
5 | python setup.py install
6 | mv roi_align/ roi_align_torch/
7 | cd ../
--------------------------------------------------------------------------------
/scripts/test_cafe_place.sh:
--------------------------------------------------------------------------------
1 | python test.py --data_path Dataset/ --split 'place' --model_path cafe_place.pth
--------------------------------------------------------------------------------
/scripts/test_cafe_view.sh:
--------------------------------------------------------------------------------
1 | python test.py --data_path Dataset/ --split 'view' --model_path cafe_view.pth --random_seed 11
2 |
--------------------------------------------------------------------------------
/scripts/train_cafe_place.sh:
--------------------------------------------------------------------------------
1 | python train.py --data_path Dataset/ --split 'place'
--------------------------------------------------------------------------------
/scripts/train_cafe_view.sh:
--------------------------------------------------------------------------------
1 | python train.py --data_path Dataset/ --split 'view'
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.data as data
5 |
6 | import os
7 | import math
8 | import sys
9 | import copy
10 | import time
11 | import random
12 | import numpy as np
13 | import argparse
14 |
15 | from models import build_model
16 | from util.utils import *
17 | import util.misc as utils
18 | import util.logger as loggers
19 | from dataloader.dataloader import read_dataset
20 | import evaluation.cafe_eval as evaluation
21 |
22 | parser = argparse.ArgumentParser(description='Group Activity Detection test code', add_help=False)
23 |
24 | # Dataset specification
25 | parser.add_argument('--dataset', default='cafe', type=str, help='dataset name')
26 | parser.add_argument('--val_mode', action='store_true')
27 | parser.add_argument('--split', default='place', type=str, help='dataset split. place or view')
28 | parser.add_argument('--data_path', default='../Dataset/', type=str, help='data path')
29 | parser.add_argument('--image_width', default=1280, type=int, help='Image width to resize')
30 | parser.add_argument('--image_height', default=720, type=int, help='Image height to resize')
31 | parser.add_argument('--random_sampling', action='store_true', help='random sampling strategy')
32 | parser.add_argument('--num_frame', default=5, type=int, help='number of frames for each clip')
33 | parser.add_argument('--num_class', default=6, type=int, help='number of activity classes')
34 |
35 | # Backbone parameters
36 | parser.add_argument('--backbone', default='resnet18', type=str, help='feature extraction backbone')
37 | parser.add_argument('--dilation', action='store_true', help='use dilation or not')
38 | parser.add_argument('--frozen_batch_norm', action='store_true', help='use frozen batch normalization')
39 | parser.add_argument('--hidden_dim', default=256, type=int, help='transformer channel dimension')
40 |
41 | # RoI Align parameters
42 | parser.add_argument('--num_boxes', default=14, type=int, help='maximum number of actors')
43 | parser.add_argument('--crop_size', default=5, type=int, help='roi align crop size')
44 |
45 | # Group Transformer
46 | parser.add_argument('--gar_nheads', default=4, type=int, help='number of heads')
47 | parser.add_argument('--gar_enc_layers', default=6, type=int, help='number of group transformer layers')
48 | parser.add_argument('--gar_ffn_dim', default=512, type=int, help='feed forward network dimension')
49 | parser.add_argument('--position_embedding', default='sine', type=str, help='various position encoding')
50 | parser.add_argument('--num_group_tokens', default=12, type=int, help='number of group tokens')
51 | parser.add_argument('--aux_loss', action='store_true')
52 | parser.add_argument('--group_threshold', default=0.5, type=float, help='post processing threshold')
53 | parser.add_argument('--distance_threshold', default=0.2, type=float, help='distance mask threshold')
54 |
55 | # Loss option
56 | parser.add_argument('--temperature', default=0.2, type=float, help='consistency loss temperature')
57 |
58 | # Loss coefficients (Individual)
59 | parser.add_argument('--ce_loss_coef', default=1, type=float)
60 | parser.add_argument('--eos_coef', default=1, type=float,
61 | help="Relative classification weight of the no-object class")
62 |
63 | # Loss coefficients (Group)
64 | parser.add_argument('--group_eos_coef', default=1, type=float)
65 | parser.add_argument('--group_ce_loss_coef', default=1, type=float)
66 | parser.add_argument('--group_code_loss_coef', default=5, type=float)
67 | parser.add_argument('--consistency_loss_coef', default=2, type=float)
68 |
69 | # Matcher (Group)
70 | parser.add_argument('--set_cost_group_class', default=1, type=float,
71 | help="Class coefficient in the matching cost")
72 | parser.add_argument('--set_cost_membership', default=1, type=float,
73 | help="Membership coefficient in the matching cost")
74 |
75 | # Training parameters
76 | parser.add_argument('--random_seed', default=1, type=int, help='random seed for reproduction')
77 | parser.add_argument('--batch', default=16, type=int, help='Batch size')
78 | parser.add_argument('--test_batch', default=16, type=int, help='Test batch size')
79 | parser.add_argument('--drop_rate', default=0.1, type=float, help='Dropout rate')
80 | # GPU
81 | parser.add_argument('--device', default="0, 1", type=str, help='GPU device')
82 | parser.add_argument('--distributed', action='store_true')
83 |
84 | # Load model
85 | parser.add_argument('--model_path', default="", type=str, help='pretrained model path')
86 |
87 | # Visualization
88 | parser.add_argument('--result_path', default="./outputs/")
89 |
90 | # Evaluation
91 | parser.add_argument('--groundtruth', default='./evaluation/gt_tracks.txt', type=argparse.FileType("r"))
92 | parser.add_argument('--labelmap', default='./label_map/group_action_list.pbtxt', type=argparse.FileType("r"))
93 | parser.add_argument('--giou_thresh', default=1.0, type=float)
94 | parser.add_argument('--eval_type', default="gt_base", type=str, help='gt_based or detection_based')
95 |
96 | args = parser.parse_args()
97 | path = None
98 |
99 | SEQS_CAFE = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
100 |
101 | ACTIVITIES = ['Queueing', 'Ordering', 'Drinking', 'Working', 'Fighting', 'Selfie', 'Individual', 'No']
102 |
103 |
104 | def main():
105 | global args, path
106 |
107 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
108 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
109 |
110 | time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
111 | exp_name = '[%s]_GAD_<%s>' % (args.dataset, time_str)
112 |
113 | path = args.result_path + exp_name
114 | if not os.path.exists(path):
115 | os.makedirs(path)
116 |
117 | # set random seed
118 | random.seed(args.random_seed)
119 | np.random.seed(args.random_seed)
120 | torch.manual_seed(args.random_seed)
121 | torch.cuda.manual_seed(args.random_seed)
122 | torch.cuda.manual_seed_all(args.random_seed)
123 | torch.backends.cudnn.deterministic = True
124 | torch.backends.cudnn.benchmark = False
125 |
126 | _, test_set = read_dataset(args)
127 | sampler_test = data.RandomSampler(test_set)
128 |
129 | test_loader = data.DataLoader(test_set, args.test_batch, sampler=sampler_test, drop_last=False,
130 | collate_fn=collate_fn, num_workers=4, pin_memory=True)
131 |
132 | model, criterion = build_model(args)
133 | model = torch.nn.DataParallel(model).cuda()
134 |
135 | pretrained_dict = torch.load(args.model_path)['state_dict']
136 | new_state_dict = model.state_dict()
137 | for k, v in pretrained_dict.items():
138 | if k in new_state_dict:
139 | new_state_dict.update({k:v})
140 |
141 | model.load_state_dict(new_state_dict)
142 |
143 | metrics = evaluation.GAD_Evaluation(args)
144 |
145 | test_log, result = validate(test_loader, model, criterion, metrics)
146 | print("group mAP at 1.0: %.2f" % result['group_mAP_1.0'])
147 | print("group mAP at 0.5: %.2f" % result['group_mAP_0.5'])
148 | print("outlier mIoU: %.2f" % result['outlier_mIoU'])
149 |
150 |
151 | @torch.no_grad()
152 | def validate(test_loader, model, criterion, metrics):
153 | model.eval()
154 | criterion.eval()
155 |
156 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
157 | header = 'Evaluation Inference: '
158 |
159 | print_freq = len(test_loader)
160 | name_to_vid = {name: i + 1 for i, name in enumerate(SEQS_CAFE)}
161 | file_path = path + '/pred_group_test_%s.txt' % args.split
162 |
163 | for i, (images, targets, infos) in enumerate(metric_logger.log_every(test_loader, print_freq, header)):
164 | images = images.cuda() # [B, T, 3, H, W]
165 | targets = [{k: v.cuda() for k, v in t.items()} for t in targets]
166 |
167 | boxes = torch.stack([t['boxes'] for t in targets])
168 | dummy_mask = torch.stack([t['actions'] == args.num_class + 1 for t in targets]).squeeze()
169 |
170 | # compute output
171 | outputs = model(images, boxes, dummy_mask)
172 |
173 | loss_dict = criterion(outputs, targets)
174 | weight_dict = criterion.weight_dict
175 |
176 | # reduce losses over all GPUs for logging purposes
177 | loss_dict_reduced = utils.reduce_dict(loss_dict)
178 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
179 | for k, v in loss_dict_reduced.items() if k in weight_dict}
180 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
181 | for k, v in loss_dict_reduced.items()}
182 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
183 | **loss_dict_reduced_scaled,
184 | **loss_dict_reduced_unscaled)
185 |
186 | metric_logger.update(group_class_error=loss_dict_reduced['group_class_error'])
187 |
188 | make_txt(boxes, infos, outputs, name_to_vid, file_path)
189 |
190 | # gather the stats from all processes
191 | metric_logger.synchronize_between_processes()
192 | print("Averaged stats:", metric_logger)
193 |
194 | detections = open(file_path, "r")
195 | result = metrics.evaluate(detections)
196 |
197 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, result
198 |
199 |
200 | def make_txt(boxes, infos, outputs, name_to_vid, file_path):
201 | for b in range(boxes.shape[0]):
202 | for t in range(boxes.shape[1]):
203 | image_w, image_h = args.image_width, args.image_height
204 |
205 | pred_group_actions = outputs['pred_activities'][b]
206 | pred_group_actions = F.softmax(pred_group_actions, dim=1)
207 | members = outputs['membership'][b]
208 |
209 | pred_membership = torch.argmax(members.transpose(0, 1), dim=1).detach().cpu()
210 | keep_membership = members.transpose(0, 1).max(-1).values > args.group_threshold
211 | pred_group_action = torch.argmax(pred_group_actions, dim=1).detach().cpu()
212 |
213 | for box_idx in range(boxes.shape[2]):
214 | x, y, w, h = boxes[b][t][box_idx]
215 | x1, y1, x2, y2 = (x - w / 2) * image_w, (y - h / 2) * image_h, (x + w / 2) * image_w, (
216 | y + h / 2) * image_h
217 |
218 | pred_group_id = pred_membership[box_idx]
219 | pred_group_action_idx = pred_group_action[pred_group_id]
220 | pred_group_action_prob = pred_group_actions[pred_group_id][pred_group_action_idx]
221 |
222 | if not (x1 == 0 and y1 == 0 and x2 == 0 and y2 == 0):
223 | if pred_group_action_idx != (pred_group_actions.shape[-1] - 1):
224 | if bool(keep_membership[box_idx]) is False:
225 | pred_group_id = -1
226 | pred_group_action_idx = args.num_class
227 |
228 | pred_list = [name_to_vid[infos[b]['vid']], infos[b]['sid'], infos[b]['fid'][t],
229 | int(x1), int(y1), int(x2), int(y2),
230 | int(pred_group_id), int(pred_group_action_idx) + 1,
231 | float(pred_group_action_prob)]
232 | str_to_be_added = [str(k) for k in pred_list]
233 | str_to_be_added = (" ".join(str_to_be_added))
234 |
235 | f = open(file_path, "a+")
236 | f.write(str_to_be_added + "\r\n")
237 | f.close()
238 |
239 |
240 | def collate_fn(batch):
241 | batch = list(zip(*batch))
242 | batch[0] = torch.stack([image for image in batch[0]])
243 | return tuple(batch)
244 |
245 |
246 | if __name__ == '__main__':
247 | main()
248 |
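
`make_txt` writes one predicted box per line: video id, sequence id, frame id, the box corners in pixels, the predicted group id (-1 for actors kept out of every group), the 1-based group activity label, and its probability. A small hedged parser for that format; the field names are labels chosen for this sketch, not identifiers from the repo:

```python
# Hedged parser for one line of the prediction file written by make_txt.
# Field names are descriptive labels chosen for this sketch.
def parse_pred_line(line):
    vid, sid, fid, x1, y1, x2, y2, group_id, action_id, prob = line.split()
    return {
        "vid": int(vid),
        "sid": sid,                    # kept as written; exact type depends on the dataset infos
        "fid": fid,
        "box": (int(x1), int(y1), int(x2), int(y2)),
        "group_id": int(group_id),     # -1 marks an actor not assigned to any group
        "action_id": int(action_id),   # 1-based group activity label
        "prob": float(prob),
    }

print(parse_pred_line("3 12 45 100 200 180 400 2 4 0.87"))
```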
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.utils.data as data
5 |
6 | import os
7 | import math
8 | import sys
9 | import copy
10 | import time
11 | import random
12 | import numpy as np
13 | import argparse
14 |
15 | from models import build_model
16 | from util.utils import *
17 | import util.misc as utils
18 | import util.logger as loggers
19 | from dataloader.dataloader import read_dataset
20 | import evaluation.cafe_eval as evaluation
21 |
22 | parser = argparse.ArgumentParser(description='Group Activity Detection train code', add_help=False)
23 |
24 | # Dataset specification
25 | parser.add_argument('--dataset', default='cafe', type=str, help='dataset name')
26 | parser.add_argument('--val_mode', action='store_true')
27 | parser.add_argument('--split', default='place', type=str, help='dataset split. place or view')
28 | parser.add_argument('--data_path', default='../Dataset/', type=str, help='data path')
29 | parser.add_argument('--image_width', default=1280, type=int, help='Image width to resize')
30 | parser.add_argument('--image_height', default=720, type=int, help='Image height to resize')
31 | parser.add_argument('--random_sampling', action='store_true', help='random sampling strategy')
32 | parser.add_argument('--num_frame', default=5, type=int, help='number of frames for each clip')
33 | parser.add_argument('--num_class', default=6, type=int, help='number of activity classes')
34 |
35 | # Backbone parameters
36 | parser.add_argument('--backbone', default='resnet18', type=str, help='feature extraction backbone')
37 | parser.add_argument('--dilation', action='store_true', help='use dilation or not')
38 | parser.add_argument('--frozen_batch_norm', action='store_true', help='use frozen batch normalization')
39 | parser.add_argument('--hidden_dim', default=256, type=int, help='transformer channel dimension')
40 |
41 | # RoI Align parameters
42 | parser.add_argument('--num_boxes', default=14, type=int, help='maximum number of actors')
43 | parser.add_argument('--crop_size', default=5, type=int, help='roi align crop size')
44 |
45 | # Group Transformer
46 | parser.add_argument('--gar_nheads', default=4, type=int, help='number of heads')
47 | parser.add_argument('--gar_enc_layers', default=6, type=int, help='number of group transformer layers')
48 | parser.add_argument('--gar_ffn_dim', default=512, type=int, help='feed forward network dimension')
49 | parser.add_argument('--position_embedding', default='sine', type=str, help='various position encoding')
50 | parser.add_argument('--num_group_tokens', default=12, type=int, help='number of group tokens')
51 | parser.add_argument('--aux_loss', action='store_true')
52 | parser.add_argument('--group_threshold', default=0.5, type=float, help='post processing threshold')
53 | parser.add_argument('--distance_threshold', default=0.2, type=float, help='distance mask threshold')
54 |
55 | # Loss option
56 | parser.add_argument('--temperature', default=0.2, type=float, help='consistency loss temperature')
57 |
58 | # Loss coefficients (Individual)
59 | parser.add_argument('--ce_loss_coef', default=1, type=float)
60 | parser.add_argument('--eos_coef', default=1, type=float,
61 | help="Relative classification weight of the no-object class")
62 |
63 | # Loss coefficients (Group)
64 | parser.add_argument('--group_eos_coef', default=1, type=float)
65 | parser.add_argument('--group_ce_loss_coef', default=1, type=float)
66 | parser.add_argument('--group_code_loss_coef', default=5, type=float)
67 | parser.add_argument('--consistency_loss_coef', default=2, type=float)
68 |
69 | # Matcher (Group)
70 | parser.add_argument('--set_cost_group_class', default=1, type=float,
71 | help="Class coefficient in the matching cost")
72 | parser.add_argument('--set_cost_membership', default=1, type=float,
73 | help="Membership coefficient in the matching cost")
74 |
75 | # Training parameters
76 | parser.add_argument('--random_seed', default=1, type=int, help='random seed for reproduction')
77 | parser.add_argument('--epochs', default=30, type=int, help='Max epochs')
78 | parser.add_argument('--test_freq', default=1, type=int, help='print frequency')
79 | parser.add_argument('--batch', default=16, type=int, help='Batch size')
80 | parser.add_argument('--test_batch', default=16, type=int, help='Test batch size')
81 | parser.add_argument('--lr', default=1e-5, type=float, help='Initial learning rate')
82 | parser.add_argument('--max_lr', default=1e-4, type=float, help='Max learning rate')
83 | parser.add_argument('--lr_step', default=4, type=int, help='step size for learning rate scheduler')
84 | parser.add_argument('--lr_step_down', default=25, type=int, help='step down size (cyclic) for learning rate scheduler')
85 | parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay')
86 | parser.add_argument('--drop_rate', default=0.1, type=float, help='Dropout rate')
87 | parser.add_argument('--gradient_clipping', action='store_true', help='use gradient clipping')
88 | parser.add_argument('--max_norm', default=1.0, type=float, help='gradient clipping max norm')
89 |
90 | # GPU
91 | parser.add_argument('--device', default="0, 1", type=str, help='GPU device')
92 | parser.add_argument('--distributed', action='store_true')
93 |
94 | # Load model
95 | parser.add_argument('--load_model', action='store_true', help='load model')
96 | parser.add_argument('--model_path', default="", type=str, help='pretrained model path')
97 |
98 | # Visualization
99 | parser.add_argument('--result_path', default="./outputs/")
100 |
101 | # Evaluation
102 | parser.add_argument('--groundtruth', default='./evaluation/gt_tracks.txt', type=argparse.FileType("r"))
103 | parser.add_argument('--labelmap', default='./label_map/group_action_list.pbtxt', type=argparse.FileType("r"))
104 | parser.add_argument('--giou_thresh', default=1.0, type=float)
105 | parser.add_argument('--eval_type', default="gt_base", type=str, help='gt_based or detection_based')
106 |
107 | args = parser.parse_args()
108 | path = None
109 |
110 | SEQS_CAFE = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
111 |
112 | ACTIVITIES = ['Queueing', 'Ordering', 'Drinking', 'Working', 'Fighting', 'Selfie', 'Individual', 'No']
113 |
114 |
115 | def main():
116 | global args, path
117 |
118 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
119 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device
120 |
121 | time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
122 | exp_name = '[%s]_GAD_<%s>' % (args.dataset, time_str)
123 | save_path = './result/%s' % exp_name
124 |
125 | # set random seed
126 | random.seed(args.random_seed)
127 | np.random.seed(args.random_seed)
128 | torch.manual_seed(args.random_seed)
129 | torch.cuda.manual_seed(args.random_seed)
130 | torch.cuda.manual_seed_all(args.random_seed)
131 | torch.backends.cudnn.deterministic = True
132 | torch.backends.cudnn.benchmark = False
133 |
134 | train_set, test_set = read_dataset(args)
135 |
136 | # for variable length input
137 | if args.distributed:
138 | sampler_train = data.DistributedSampler(train_set, shuffle=True)
139 | sampler_test = data.DistributedSampler(test_set, shuffle=False)
140 | else:
141 | sampler_train = data.RandomSampler(train_set)
142 | sampler_test = data.RandomSampler(test_set)
143 |
144 | batch_sampler_train = data.BatchSampler(sampler_train, args.batch, drop_last=True)
145 |
146 | train_loader = data.DataLoader(train_set, batch_sampler=batch_sampler_train,
147 | collate_fn=collate_fn, num_workers=4, pin_memory=True)
148 | test_loader = data.DataLoader(test_set, args.test_batch, sampler=sampler_test, drop_last=False,
149 | collate_fn=collate_fn, num_workers=4, pin_memory=True)
150 |
151 | model, criterion = build_model(args)
152 | model = torch.nn.DataParallel(model).cuda()
153 |
154 | # get the number of model parameters
155 | parameters = 'Number of full model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))
156 | print_log(save_path, '--------------------Number of parameters--------------------')
157 | print_log(save_path, parameters)
158 |
159 | # define loss function and optimizer
160 | optimizer = torch.optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.999), eps=1e-8,
161 | weight_decay=args.weight_decay)
162 |
163 | scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, args.lr, args.max_lr, step_size_up=args.lr_step,
164 | step_size_down=args.lr_step_down, mode='triangular2',
165 | cycle_momentum=False)
166 |
167 | if args.load_model:
168 | checkpoint = torch.load(args.model_path)
169 | model.load_state_dict(checkpoint['state_dict'])
170 | scheduler.load_state_dict(checkpoint['scheduler'])
171 | optimizer.load_state_dict(checkpoint['optimizer'])
172 | start_epoch = checkpoint['epoch'] + 1
173 | else:
174 | start_epoch = 1
175 |
176 | path = args.result_path + exp_name
177 | if not os.path.exists(path):
178 | os.makedirs(path)
179 |
180 | metrics = evaluation.GAD_Evaluation(args)
181 |
182 | # training phase
183 | for epoch in range(start_epoch, args.epochs + 1):
184 | print_log(save_path, '----- %s at epoch #%d' % ("Train", epoch))
185 | train_log = train(train_loader, model, criterion, optimizer, epoch)
186 | print_log(save_path, 'Loss: %.4f' % (train_log['loss']))
187 | print_log(save_path, 'Group class error: %.2f' % (train_log['group_class_error']))
188 | print('Current learning rate is %f' % scheduler.get_last_lr()[0])
189 | scheduler.step()
190 |
191 | if epoch % args.test_freq == 0:
192 | print_log(save_path, '----- %s at epoch #%d' % ("Test", epoch))
193 | test_log, result = validate(test_loader, model, criterion, metrics, epoch)
194 | print_log(save_path, 'Loss: %.4f' % (test_log['loss']))
195 | print_log(save_path, 'Group class error: %.2f' % (test_log['group_class_error']))
196 | print_log(save_path, "group mAP at 1.0: %.2f" % result['group_mAP_1.0'])
197 | print_log(save_path, "group mAP at 0.5: %.2f" % result['group_mAP_0.5'])
198 | print_log(save_path, "outlier mIoU: %.2f" % result['outlier_mIoU'])
199 |
200 | state = {
201 | 'epoch': epoch,
202 | 'state_dict': model.state_dict(),
203 | 'optimizer': optimizer.state_dict(),
204 | 'scheduler': scheduler.state_dict(),
205 | }
206 | result_path = save_path + '/epoch%d.pth' % epoch
207 | torch.save(state, result_path)
208 |
209 |
210 | def train(train_loader, model, criterion, optimizer, epoch):
211 | model.train()
212 | criterion.train()
213 |
214 | # logger
215 | metric_logger = loggers.MetricLogger(mode="train", delimiter=" ")
216 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
217 | space_fmt = str(len(str(args.epochs)))
218 | header = 'Epoch [{start_epoch: >{fill}}/{end_epoch}]'.format(start_epoch=epoch, end_epoch=args.epochs,
219 | fill=space_fmt)
220 | print_freq = len(train_loader)
221 |
222 | for i, (images, targets, infos) in enumerate(metric_logger.log_every(train_loader, print_freq, header)):
223 | images = images.cuda() # [B, T, 3, H, W]
224 | targets = [{k: v.cuda() for k, v in t.items()} for t in targets]
225 |
226 | boxes = torch.stack([t['boxes'] for t in targets])
227 | dummy_mask = torch.stack([t['actions'] == args.num_class + 1 for t in targets]).squeeze()
228 |
229 | num_batch = images.shape[0]
230 | num_frame = images.shape[1]
231 |
232 | # compute output
233 | outputs = model(images, boxes, dummy_mask)
234 |
235 | loss_dict = criterion(outputs, targets, log=False)
236 | weight_dict = criterion.weight_dict
237 |
238 | loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
239 |
240 | # reduce losses over all GPUs for logging purposes
241 | loss_dict_reduced = utils.reduce_dict(loss_dict)
242 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
243 | for k, v in loss_dict_reduced.items()}
244 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
245 | for k, v in loss_dict_reduced.items() if k in weight_dict}
246 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
247 | loss_value = losses_reduced_scaled.item()
248 |
249 | if not math.isfinite(loss_value):
250 | print("Loss is {}, stopping training".format(loss_value))
251 | print(loss_dict_reduced)
252 | sys.exit(1)
253 |
254 | # compute gradient and do SGD step
255 | optimizer.zero_grad()
256 | loss.backward()
257 | if args.gradient_clipping:
258 | nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
259 | optimizer.step()
260 |
261 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
262 | metric_logger.update(group_class_error=loss_dict_reduced['group_class_error'])
263 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
264 |
265 | metric_logger.synchronize_between_processes()
266 | print("Averaged stats:", metric_logger)
267 |
268 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
269 |
270 |
271 | @torch.no_grad()
272 | def validate(test_loader, model, criterion, metrics, epoch):
273 | model.eval()
274 | criterion.eval()
275 |
276 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
277 | metric_logger.add_meter('group_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
278 | header = 'Evaluation Inference: '
279 |
280 | print_freq = len(test_loader)
281 | name_to_vid = {name: i + 1 for i, name in enumerate(SEQS_CAFE)}
282 | file_path = path + '/pred_group_epoch_%d.txt' % epoch
283 |
284 | for i, (images, targets, infos) in enumerate(metric_logger.log_every(test_loader, print_freq, header)):
285 | images = images.cuda() # [B, T, 3, H, W]
286 | targets = [{k: v.cuda() for k, v in t.items()} for t in targets]
287 |
288 | boxes = torch.stack([t['boxes'] for t in targets])
289 | dummy_mask = torch.stack([t['actions'] == args.num_class + 1 for t in targets]).squeeze()
290 |
291 | # compute output
292 | outputs = model(images, boxes, dummy_mask)
293 |
294 | loss_dict = criterion(outputs, targets)
295 | weight_dict = criterion.weight_dict
296 |
297 | # reduce losses over all GPUs for logging purposes
298 | loss_dict_reduced = utils.reduce_dict(loss_dict)
299 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
300 | for k, v in loss_dict_reduced.items() if k in weight_dict}
301 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
302 | for k, v in loss_dict_reduced.items()}
303 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
304 | **loss_dict_reduced_scaled,
305 | **loss_dict_reduced_unscaled)
306 |
307 | metric_logger.update(group_class_error=loss_dict_reduced['group_class_error'])
308 |
309 | make_txt(boxes, infos, outputs, name_to_vid, file_path)
310 |
311 | # gather the stats from all processes
312 | metric_logger.synchronize_between_processes()
313 | print("Averaged stats:", metric_logger)
314 |
315 | detections = open(file_path, "r")
316 | result = metrics.evaluate(detections)
317 |
318 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, result
319 |
320 |
321 | def make_txt(boxes, infos, outputs, name_to_vid, file_path):
322 | for b in range(boxes.shape[0]):
323 | for t in range(boxes.shape[1]):
324 | image_w, image_h = args.image_width, args.image_height
325 |
326 | pred_group_actions = outputs['pred_activities'][b]
327 | pred_group_actions = F.softmax(pred_group_actions, dim=1)
328 | members = outputs['membership'][b]
329 |
330 | pred_membership = torch.argmax(members.transpose(0, 1), dim=1).detach().cpu()
331 | keep_membership = members.transpose(0, 1).max(-1).values > args.group_threshold
332 | pred_group_action = torch.argmax(pred_group_actions, dim=1).detach().cpu()
333 |
334 | for box_idx in range(boxes.shape[2]):
335 | x, y, w, h = boxes[b][t][box_idx]
336 | x1, y1, x2, y2 = (x - w / 2) * image_w, (y - h / 2) * image_h, (x + w / 2) * image_w, (
337 | y + h / 2) * image_h
338 |
339 | pred_group_id = pred_membership[box_idx]
340 | pred_group_action_idx = pred_group_action[pred_group_id]
341 | pred_group_action_prob = pred_group_actions[pred_group_id][pred_group_action_idx]
342 |
343 | if not (x1 == 0 and y1 == 0 and x2 == 0 and y2 == 0):
344 | if pred_group_action_idx != (pred_group_actions.shape[-1] - 1):
345 | if bool(keep_membership[box_idx]) is False:
346 | pred_group_id = -1
347 | pred_group_action_idx = args.num_class
348 |
349 | pred_list = [name_to_vid[infos[b]['vid']], infos[b]['sid'], infos[b]['fid'][t],
350 | int(x1), int(y1), int(x2), int(y2),
351 | int(pred_group_id), int(pred_group_action_idx) + 1,
352 | float(pred_group_action_prob)]
353 | str_to_be_added = [str(k) for k in pred_list]
354 | str_to_be_added = (" ".join(str_to_be_added))
355 |
356 | f = open(file_path, "a+")
357 | f.write(str_to_be_added + "\r\n")
358 | f.close()
359 |
360 |
361 | def collate_fn(batch):
362 | batch = list(zip(*batch))
363 | batch[0] = torch.stack([image for image in batch[0]])
364 | return tuple(batch)
365 |
366 |
367 | if __name__ == '__main__':
368 | main()
369 |
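
The optimizer setup above ramps the learning rate from --lr to --max_lr over --lr_step scheduler steps and back down over --lr_step_down steps, with triangular2 halving the peak each cycle; the scheduler is stepped once per epoch. A standalone sketch of that schedule on a toy model, using the script's default values:

```python
# Sketch of the cyclic learning-rate schedule from train.py, on a toy model.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, betas=(0.9, 0.999),
                             eps=1e-8, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=1e-5, max_lr=1e-4,
                                              step_size_up=4, step_size_down=25,
                                              mode='triangular2', cycle_momentum=False)

for epoch in range(1, 31):
    # ... one training epoch would run here ...
    print(epoch, scheduler.get_last_lr()[0])   # rises for 4 epochs, decays for 25
    scheduler.step()
```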
--------------------------------------------------------------------------------
/util/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dk-kim/CAFE_codebase/caee21a8ccdd5faadecbc5a5034107084a7cd270/util/__init__.py
--------------------------------------------------------------------------------
/util/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 |
23 | # modified from torchvision to also return the union
24 | def box_iou(boxes1, boxes2):
25 | area1 = box_area(boxes1)
26 | area2 = box_area(boxes2)
27 |
28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
30 |
31 | wh = (rb - lt).clamp(min=0) # [N,M,2]
32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
33 |
34 | union = area1[:, None] + area2 - inter
35 |
36 | iou = inter / union
37 | return iou, union
38 |
39 |
40 | def generalized_box_iou(boxes1, boxes2):
41 | """
42 | Generalized IoU from https://giou.stanford.edu/
43 | The boxes should be in [x0, y0, x1, y1] format
44 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
45 | and M = len(boxes2)
46 | """
47 | # degenerate boxes gives inf / nan results
48 | # so do an early check
49 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
50 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
51 | iou, union = box_iou(boxes1, boxes2)
52 |
53 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
54 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
55 |
56 | wh = (rb - lt).clamp(min=0) # [N,M,2]
57 | area = wh[:, :, 0] * wh[:, :, 1]
58 |
59 | return iou - (area - union) / area
60 |
61 |
62 | def masks_to_boxes(masks):
63 | """Compute the bounding boxes around the provided masks
64 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
65 | Returns a [N, 4] tensors, with the boxes in xyxy format
66 | """
67 | if masks.numel() == 0:
68 | return torch.zeros((0, 4), device=masks.device)
69 |
70 | h, w = masks.shape[-2:]
71 |
72 | y = torch.arange(0, h, dtype=torch.float)
73 | x = torch.arange(0, w, dtype=torch.float)
74 | y, x = torch.meshgrid(y, x)
75 |
76 | x_mask = (masks * x.unsqueeze(0))
77 | x_max = x_mask.flatten(1).max(-1)[0]
78 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
79 |
80 | y_mask = (masks * y.unsqueeze(0))
81 | y_max = y_mask.flatten(1).max(-1)[0]
82 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
83 |
84 | return torch.stack([x_min, y_min, x_max, y_max], 1)
85 |
86 |
87 | def rescale_bboxes(out_bbox, size):
88 | img_h, img_w = size
89 | b = box_cxcywh_to_xyxy(out_bbox)
90 | b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(out_bbox.get_device())
91 | return b
92 |
93 |
94 | def rescale_pairs(out_pairs, size):
95 | img_h, img_w = size
96 | h_bbox = out_pairs[:, :4]
97 | o_bbox = out_pairs[:, 4:]
98 |
99 | h = box_cxcywh_to_xyxy(h_bbox)
100 | h = h * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(h_bbox.get_device())
101 |
102 | obj_mask = (o_bbox[:, 0] != -1)
103 | if obj_mask.sum() != 0:
104 | o = box_cxcywh_to_xyxy(o_bbox)
105 | o = o * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(o_bbox.get_device())
106 | o_bbox[obj_mask] = o[obj_mask]
107 | o = o_bbox
108 | p = torch.cat([h, o], dim=-1)
109 |
110 | return p
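
`generalized_box_iou` expects xyxy boxes and returns a pairwise [N, M] matrix; values lie in [-1, 1], matching plain IoU for overlapping boxes and going negative as boxes separate. A quick check on toy boxes (assumes the repo root is on the import path):

```python
# Quick pairwise GIoU check on toy boxes (xyxy format).
import torch
from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

boxes1 = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                       [0.0, 0.0, 1.0, 1.0]])
boxes2 = box_cxcywh_to_xyxy(torch.tensor([[1.5, 1.5, 1.0, 1.0],    # overlaps boxes1[0]
                                          [4.5, 4.5, 1.0, 1.0]]))  # disjoint from both

giou = generalized_box_iou(boxes1, boxes2)   # [2, 2] pairwise matrix
print(giou)                                  # giou[0, 0] == 0.25; the rest are <= 0
```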
--------------------------------------------------------------------------------
/util/logger.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : hotr/util/logger.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | import torch
9 | import time
10 | import datetime
11 | import sys
12 | from time import sleep
13 | from collections import defaultdict
14 |
15 | from util.misc import SmoothedValue
16 |
17 | def print_params(model):
18 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
19 | print('\n[Logger] Number of params: ', n_parameters)
20 | return n_parameters
21 |
22 | def print_args(args):
23 | print('\n[Logger] DETR Arguments:')
24 | for k, v in vars(args).items():
25 | if k in [
26 | 'lr', 'lr_backbone', 'lr_drop',
27 | 'frozen_weights',
28 | 'backbone', 'dilation',
29 | 'position_embedding', 'enc_layers', 'dec_layers', 'num_queries',
30 | 'dataset_file']:
31 | print(f'\t{k}: {v}')
32 |
33 |     if getattr(args, 'HOIDet', False):
34 | print('\n[Logger] DETR_HOI Arguments:')
35 | for k, v in vars(args).items():
36 | if k in [
37 | 'freeze_enc',
38 | 'query_flag',
39 | 'hoi_nheads',
40 | 'hoi_dim_feedforward',
41 | 'hoi_dec_layers',
42 | 'hoi_idx_loss_coef',
43 | 'hoi_act_loss_coef',
44 | 'hoi_eos_coef',
45 | 'object_threshold']:
46 | print(f'\t{k}: {v}')
47 |
48 | class MetricLogger(object):
49 | def __init__(self, mode="test", delimiter="\t"):
50 | self.meters = defaultdict(SmoothedValue)
51 | self.delimiter = delimiter
52 | self.mode = mode
53 |
54 | def update(self, **kwargs):
55 | for k, v in kwargs.items():
56 | if isinstance(v, torch.Tensor):
57 | v = v.item()
58 | assert isinstance(v, (float, int))
59 | self.meters[k].update(v)
60 |
61 | def __getattr__(self, attr):
62 | if attr in self.meters:
63 | return self.meters[attr]
64 | if attr in self.__dict__:
65 | return self.__dict__[attr]
66 | raise AttributeError("'{}' object has no attribute '{}'".format(
67 | type(self).__name__, attr))
68 |
69 | def __str__(self):
70 | loss_str = []
71 | for name, meter in self.meters.items():
72 | loss_str.append(
73 | "{}: {}".format(name, str(meter))
74 | )
75 | return self.delimiter.join(loss_str)
76 |
77 | def synchronize_between_processes(self):
78 | for meter in self.meters.values():
79 | meter.synchronize_between_processes()
80 |
81 | def add_meter(self, name, meter):
82 | self.meters[name] = meter
83 |
84 | def log_every(self, iterable, print_freq, header=None):
85 | i = 0
86 | if not header:
87 | header = ''
88 | start_time = time.time()
89 | end = time.time()
90 | iter_time = SmoothedValue(fmt='{avg:.4f}')
91 | data_time = SmoothedValue(fmt='{avg:.4f}')
92 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
93 | if torch.cuda.is_available():
94 | log_msg = self.delimiter.join([
95 | header,
96 | '[{0' + space_fmt + '}/{1}]',
97 | 'eta: {eta}',
98 | '{meters}',
99 | 'time: {time}',
100 | 'data: {data}',
101 | 'max mem: {memory:.0f}'
102 | ])
103 | else:
104 | log_msg = self.delimiter.join([
105 | header,
106 | '[{0' + space_fmt + '}/{1}]',
107 | 'eta: {eta}',
108 | '{meters}',
109 | 'time: {time}',
110 | 'data: {data}'
111 | ])
112 | MB = 1024.0 * 1024.0
113 | for obj in iterable:
114 | data_time.update(time.time() - end)
115 | yield obj
116 | iter_time.update(time.time() - end)
117 |
118 |             if (i % print_freq == 0 and i != 0) or i == len(iterable) - 1:
119 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
120 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
121 | if torch.cuda.is_available():
122 | print(log_msg.format(
123 | i+1, len(iterable), eta=eta_string,
124 | meters=str(self),
125 | time=str(iter_time), data=str(data_time),
126 | memory=torch.cuda.max_memory_allocated() / MB),
127 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
128 | else:
129 | print(log_msg.format(
130 | i+1, len(iterable), eta=eta_string,
131 | meters=str(self),
132 | time=str(iter_time), data=str(data_time)),
133 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
134 | else:
135 | log_interval = self.delimiter.join([header, '[{0' + space_fmt + '}/{1}]'])
136 |                 # short progress-only line printed between full log updates
137 |                 print(log_interval.format(i+1, len(iterable)), flush=True, end="\r")
138 |
139 | i += 1
140 | end = time.time()
141 | total_time = time.time() - start_time
142 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
143 | if self.mode=='test': print("")
144 | print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format(
145 | self.mode, total_time_str, total_time / len(iterable)))
146 |
--------------------------------------------------------------------------------
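
A sketch of how MetricLogger is typically driven (illustrative only; the dummy batch list stands in for a real DataLoader and the meter name `loss` is arbitrary):

    import torch
    from util.logger import MetricLogger

    metric_logger = MetricLogger(mode="train", delimiter="  ")
    dummy_batches = [torch.randn(4, 3) for _ in range(100)]   # stand-in for a DataLoader
    for batch in metric_logger.log_every(dummy_batches, print_freq=20, header="Epoch: [0]"):
        loss = batch.pow(2).mean()        # stand-in for a real forward/backward step
        metric_logger.update(loss=loss)   # tensors are unwrapped with .item()
    print("smoothed loss:", metric_logger.loss.global_avg)
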
/util/misc.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR)
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | """
9 | Misc functions, including distributed helpers.
10 | Mostly copy-paste from torchvision references.
11 | """
12 | import os
13 | import subprocess
14 | from collections import deque
15 | import pickle
16 | from typing import Optional, List
17 |
18 | import torch
19 | import torch.distributed as dist
20 | from torch import Tensor
21 |
22 | # needed due to empty tensor bug in pytorch and torchvision 0.5
23 | import torchvision
24 |
25 |
26 | class SmoothedValue(object):
27 | """Track a series of values and provide access to smoothed values over a
28 | window or the global series average.
29 | """
30 |
31 | def __init__(self, window_size=20, fmt=None):
32 | if fmt is None:
33 | fmt = "{median:.4f} ({global_avg:.4f})"
34 | self.deque = deque(maxlen=window_size)
35 | self.total = 0.0
36 | self.count = 0
37 | self.fmt = fmt
38 |
39 | def update(self, value, n=1):
40 | self.deque.append(value)
41 | self.count += n
42 | self.total += value * n
43 |
44 | def synchronize_between_processes(self):
45 | """
46 | Warning: does not synchronize the deque!
47 | """
48 | if not is_dist_avail_and_initialized():
49 | return
50 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
51 | dist.barrier()
52 | dist.all_reduce(t)
53 | t = t.tolist()
54 | self.count = int(t[0])
55 | self.total = t[1]
56 |
57 | @property
58 | def median(self):
59 | d = torch.tensor(list(self.deque))
60 | return d.median().item()
61 |
62 | @property
63 | def avg(self):
64 | d = torch.tensor(list(self.deque), dtype=torch.float32)
65 | return d.mean().item()
66 |
67 | @property
68 | def global_avg(self):
69 | return self.total / self.count
70 |
71 | @property
72 | def max(self):
73 | return max(self.deque)
74 |
75 | @property
76 | def value(self):
77 | return self.deque[-1]
78 |
79 | def __str__(self):
80 | return self.fmt.format(
81 | median=self.median,
82 | avg=self.avg,
83 | global_avg=self.global_avg,
84 | max=self.max,
85 | value=self.value)
86 |
87 |
88 | def all_gather(data):
89 | """
90 | Run all_gather on arbitrary picklable data (not necessarily tensors)
91 | Args:
92 | data: any picklable object
93 | Returns:
94 | list[data]: list of data gathered from each rank
95 | """
96 | world_size = get_world_size()
97 | if world_size == 1:
98 | return [data]
99 |
100 | # serialized to a Tensor
101 | buffer = pickle.dumps(data)
102 | storage = torch.ByteStorage.from_buffer(buffer)
103 | tensor = torch.ByteTensor(storage).to("cuda")
104 |
105 | # obtain Tensor size of each rank
106 | local_size = torch.tensor([tensor.numel()], device="cuda")
107 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
108 | dist.all_gather(size_list, local_size)
109 | size_list = [int(size.item()) for size in size_list]
110 | max_size = max(size_list)
111 |
112 | # receiving Tensor from all ranks
113 | # we pad the tensor because torch all_gather does not support
114 | # gathering tensors of different shapes
115 | tensor_list = []
116 | for _ in size_list:
117 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
118 | if local_size != max_size:
119 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
120 | tensor = torch.cat((tensor, padding), dim=0)
121 | dist.all_gather(tensor_list, tensor)
122 |
123 | data_list = []
124 | for size, tensor in zip(size_list, tensor_list):
125 | buffer = tensor.cpu().numpy().tobytes()[:size]
126 | data_list.append(pickle.loads(buffer))
127 |
128 | return data_list
129 |
130 |
131 | def reduce_dict(input_dict, average=True):
132 | """
133 | Args:
134 | input_dict (dict): all the values will be reduced
135 | average (bool): whether to do average or sum
136 | Reduce the values in the dictionary from all processes so that all processes
137 | have the averaged results. Returns a dict with the same fields as
138 | input_dict, after reduction.
139 | """
140 | world_size = get_world_size()
141 | if world_size < 2:
142 | return input_dict
143 | with torch.no_grad():
144 | names = []
145 | values = []
146 | # sort the keys so that they are consistent across processes
147 | for k in sorted(input_dict.keys()):
148 | names.append(k)
149 | values.append(input_dict[k])
150 | values = torch.stack(values, dim=0)
151 | dist.all_reduce(values)
152 | if average:
153 | values /= world_size
154 | reduced_dict = {k: v for k, v in zip(names, values)}
155 | return reduced_dict
156 |
157 |
158 | def get_sha():
159 | cwd = os.path.dirname(os.path.abspath(__file__))
160 |
161 | def _run(command):
162 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
163 | sha = 'N/A'
164 | diff = "clean"
165 | branch = 'N/A'
166 | try:
167 | sha = _run(['git', 'rev-parse', 'HEAD'])
168 | subprocess.check_output(['git', 'diff'], cwd=cwd)
169 | diff = _run(['git', 'diff-index', 'HEAD'])
170 |         diff = "has uncommitted changes" if diff else "clean"
171 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
172 | except Exception:
173 | pass
174 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
175 | return message
176 |
177 |
178 | def collate_fn(batch):
179 |     # each element of batch is a per-sample tuple whose first item is an image tensor
180 | batch = list(zip(*batch))
181 | batch[0] = nested_tensor_from_tensor_list(batch[0])
182 | return tuple(batch)
183 |
184 |
185 | def _max_by_axis(the_list):
186 | # type: (List[List[int]]) -> List[int]
187 | maxes = the_list[0]
188 | for sublist in the_list[1:]:
189 | for index, item in enumerate(sublist):
190 | maxes[index] = max(maxes[index], item)
191 | return maxes
192 |
193 |
194 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
195 | # TODO make this more general
196 | if tensor_list[0].ndim == 3:
197 | # TODO make it support different-sized images
198 | max_size = _max_by_axis([list(img.shape) for img in tensor_list])
199 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
200 | batch_shape = [len(tensor_list)] + max_size
201 | b, c, h, w = batch_shape
202 | dtype = tensor_list[0].dtype
203 | device = tensor_list[0].device
204 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
205 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
206 | for img, pad_img, m in zip(tensor_list, tensor, mask):
207 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
208 | m[: img.shape[1], :img.shape[2]] = False
209 | else:
210 | raise ValueError('not supported')
211 | return NestedTensor(tensor, mask)
212 |
213 |
214 | class NestedTensor(object):
215 | def __init__(self, tensors, mask: Optional[Tensor]):
216 | self.tensors = tensors
217 | self.mask = mask
218 |
219 | def to(self, device):
220 | # type: (Device) -> NestedTensor # noqa
221 | cast_tensor = self.tensors.to(device)
222 | mask = self.mask
223 | if mask is not None:
224 | assert mask is not None
225 | cast_mask = mask.to(device)
226 | else:
227 | cast_mask = None
228 | return NestedTensor(cast_tensor, cast_mask)
229 |
230 | def decompose(self):
231 | return self.tensors, self.mask
232 |
233 | def __repr__(self):
234 | return str(self.tensors)
235 |
236 |
237 | def setup_for_distributed(is_master):
238 | """
239 | This function disables printing when not in master process
240 | """
241 | import builtins as __builtin__
242 | builtin_print = __builtin__.print
243 |
244 | def print(*args, **kwargs):
245 | force = kwargs.pop('force', False)
246 | if is_master or force:
247 | builtin_print(*args, **kwargs)
248 |
249 | __builtin__.print = print
250 |
251 |
252 | def is_dist_avail_and_initialized():
253 | if not dist.is_available():
254 | return False
255 | if not dist.is_initialized():
256 | return False
257 | return True
258 |
259 |
260 | def get_world_size():
261 | if not is_dist_avail_and_initialized():
262 | return 1
263 | return dist.get_world_size()
264 |
265 |
266 | def get_rank():
267 | if not is_dist_avail_and_initialized():
268 | return 0
269 | return dist.get_rank()
270 |
271 |
272 | def is_main_process():
273 | return get_rank() == 0
274 |
275 |
276 | def save_on_master(*args, **kwargs):
277 | if is_main_process():
278 | torch.save(*args, **kwargs)
279 |
280 |
281 | def init_distributed_mode(args):
282 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
283 | args.rank = int(os.environ["RANK"])
284 | args.world_size = int(os.environ['WORLD_SIZE'])
285 | args.gpu = int(os.environ['LOCAL_RANK'])
286 | elif 'SLURM_PROCID' in os.environ:
287 | args.rank = int(os.environ['SLURM_PROCID'])
288 | args.gpu = args.rank % torch.cuda.device_count()
289 | else:
290 | print('Not using distributed mode')
291 | args.distributed = False
292 | return
293 |
294 | args.distributed = True
295 |
296 | torch.cuda.set_device(args.gpu)
297 | args.dist_backend = 'nccl'
298 | print('| distributed init (rank {}): {}'.format(
299 | args.rank, args.dist_url), flush=True)
300 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
301 | world_size=args.world_size, rank=args.rank)
302 | torch.distributed.barrier()
303 | setup_for_distributed(args.rank == 0)
304 |
305 |
306 | @torch.no_grad()
307 | def accuracy(output, target, topk=(1,)):
308 | """Computes the precision@k for the specified values of k"""
309 | if target.numel() == 0:
310 | return [torch.zeros([], device=output.device)]
311 | maxk = max(topk)
312 | batch_size = target.size(0)
313 |
314 | _, pred = output.topk(maxk, 1, True, True)
315 | pred = pred.t()
316 | correct = pred.eq(target.view(1, -1).expand_as(pred))
317 |
318 | res = []
319 | for k in topk:
320 | correct_k = correct[:k].view(-1).float().sum(0)
321 | res.append(correct_k.mul_(100.0 / batch_size))
322 | return res
323 |
324 |
325 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
326 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
327 | """
328 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
329 | This will eventually be supported natively by PyTorch, and this
330 | class can go away.
331 | """
332 |     tv_major, tv_minor = (int(v) for v in torchvision.__version__.split('.')[:2])
333 |     if (tv_major, tv_minor) < (0, 7):
334 |         if input.numel() > 0:
335 |             return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
336 |         # empty batch on old torchvision: fall back to the private shape helpers
337 |         from torchvision.ops import _new_empty_tensor
338 |         from torchvision.ops.misc import _output_size
339 |         output_shape = list(input.shape[:-2]) + list(_output_size(2, input, size, scale_factor))
340 |         return _new_empty_tensor(input, output_shape)
341 |     else:
342 |         return torch.nn.functional.interpolate(input, size, scale_factor, mode, align_corners)
--------------------------------------------------------------------------------
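
Two small usage sketches for the helpers above (illustrative only; the image sizes and loss names are made up). nested_tensor_from_tensor_list pads variable-sized images into one batch plus a padding mask, and reduce_dict is a no-op in a single-process run:

    import torch
    from util.misc import nested_tensor_from_tensor_list, reduce_dict

    imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
    samples = nested_tensor_from_tensor_list(imgs)
    tensors, mask = samples.decompose()
    print(tensors.shape)   # torch.Size([2, 3, 512, 640]) - padded to the max H and W
    print(mask.shape)      # torch.Size([2, 512, 640]); True marks padded pixels

    losses = {"loss_ce": torch.tensor(0.7), "loss_bbox": torch.tensor(0.2)}
    print(reduce_dict(losses))   # without torch.distributed initialized, returned unchanged
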
/util/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | import os
4 |
5 |
6 | def print_log(result_path, *args):
7 | os.makedirs(result_path, exist_ok=True)
8 |
9 | print(*args)
10 |     file_path = os.path.join(result_path, 'log.txt')
11 |     # append the same message to <result_path>/log.txt
12 |     with open(file_path, 'a') as f:
13 |         print(*args, file=f)
14 |
--------------------------------------------------------------------------------
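
print_log behaves like the built-in print but also appends the message to <result_path>/log.txt, creating the directory if needed (the path and message below are illustrative):

    from util.utils import print_log

    print_log('./results/demo', '[epoch 1]', 'loss =', 0.482)
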