├── dataloader ├── __init__.py ├── cafe.py └── dataloader.py ├── environment.yml ├── evaluation └── cafe_eval.py ├── label_map └── group_action_list.pbtxt ├── models ├── __init__.py ├── backbone.py ├── criterion.py ├── feed_forward.py ├── group_matcher.py ├── group_transformer.py ├── models.py └── position_encoding.py ├── readme.md ├── scripts ├── download_checkpoints.sh ├── download_datasets.sh ├── setup.sh ├── test_cafe_place.sh ├── test_cafe_view.sh ├── train_cafe_place.sh └── train_cafe_view.sh ├── test.py ├── train.py └── util ├── __init__.py ├── box_ops.py ├── logger.py ├── misc.py └── utils.py /dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dk-kim/CAFE_codebase/caee21a8ccdd5faadecbc5a5034107084a7cd270/dataloader/__init__.py -------------------------------------------------------------------------------- /dataloader/cafe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import torchvision.transforms as transforms 4 | 5 | import os 6 | import json 7 | import numpy as np 8 | import random 9 | from PIL import Image 10 | 11 | ACTIVITIES = ['Queueing', 'Ordering', 'Eating/Drinking', 'Working/Studying', 'Fighting', 'TakingSelfie'] 12 | 13 | 14 | # read annotation files 15 | def cafe_read_annotations(path, videos, num_class): 16 | labels = {} 17 | group_to_id = {name: i for i, name in enumerate(ACTIVITIES)} 18 | 19 | for vid in videos: 20 | video_path = os.path.join(path, vid) 21 | for cid in os.listdir(video_path): 22 | clip_path = os.path.join(video_path, cid) 23 | label_path = clip_path + '/ann.json' 24 | 25 | with open(label_path, 'r') as file: 26 | groups = {} 27 | boxes, actions, activities, members, membership = [], [], [], [], [] 28 | 29 | values = json.load(file) 30 | num_frames = values['framesCount'] 31 | frame_interval = values['framesEach'] 32 | actors = values['figures'] 33 | 34 | key_frame = actors[0]['shapes'][0]['frame'] 35 | 36 | for i, actor in enumerate(actors): 37 | actor_idx = actor['id'] 38 | group_name = actor['label'] 39 | 40 | box = actor['shapes'][0]['coordinates'] 41 | x1, y1 = box[0] 42 | x2, y2 = box[1] 43 | x_c, y_c = (x1 + x2) / 2, (y1 + y2) / 2 44 | w, h = x2 - x1, y2 - y1 45 | boxes.append([x_c, y_c, w, h]) 46 | 47 | if group_name != 'individual': 48 | group_idx = int(group_name[-1]) 49 | if actor['attributes'][0]['value'] != "": 50 | action = group_to_id[actor['attributes'][0]['value']['key']] 51 | 52 | if group_idx not in groups.keys(): 53 | groups[group_idx] = {'activity': action} 54 | 55 | if 'members' in groups[group_idx].keys(): 56 | groups[group_idx]['members'][i] = 1 57 | else: 58 | groups[group_idx]['members'] = torch.zeros(len(actors)) 59 | groups[group_idx]['members'][i] = 1 60 | else: 61 | if group_idx in groups.keys(): 62 | action = groups[group_idx]['activity'] 63 | else: 64 | action = -1 65 | else: 66 | action = num_class 67 | group_idx = 0 68 | 69 | actions.append(action) 70 | membership.append(group_idx) 71 | 72 | for i, action in enumerate(actions): 73 | if action == -1: 74 | group_idx = membership[i] 75 | 76 | if group_idx in groups.keys(): 77 | new_action = groups[group_idx]['activity'] 78 | actions[i] = new_action 79 | 80 | group_members = groups[group_idx]['members'] 81 | group_members[i] = 1 82 | else: 83 | membership[i] = 0 84 | actions[i] = num_class 85 | 86 | for group_id in sorted(groups): 87 | if group_id - 1 >= len(groups): 88 | new_id = 
len(groups) 89 | 90 | while new_id > 0: 91 | if new_id not in groups: 92 | groups[new_id] = groups[group_id] 93 | del groups[group_id] 94 | for i in range(len(membership)): 95 | if membership[i] == group_id: 96 | membership[i] = new_id 97 | group_id = new_id 98 | new_id -= 1 99 | 100 | for group_id in sorted(groups): 101 | activities.append(groups[group_id]['activity']) 102 | members.append(groups[group_id]['members']) 103 | 104 | actions = np.array(actions, dtype=np.int32) 105 | boxes = np.vstack(boxes) 106 | membership = np.array(membership, dtype=np.int32) - 1 107 | activities = np.array(activities, dtype=np.int32) 108 | 109 | actions = torch.from_numpy(actions).long() 110 | boxes = torch.from_numpy(boxes).float() 111 | membership = torch.from_numpy(membership).long() 112 | activities = torch.from_numpy(activities).long() 113 | 114 | if len(members) == 0: 115 | members = torch.tensor(members) 116 | else: 117 | members = torch.stack(members).float() 118 | 119 | annotations = { 120 | 'boxes': boxes, 121 | 'actions': actions, 122 | 'membership': membership, 123 | 'activities': activities, 124 | 'members': members, 125 | 'num_frames': num_frames, 126 | 'interval': frame_interval, 127 | 'key_frame': key_frame, 128 | } 129 | 130 | if len(annotations['activities']) != 0: 131 | labels[(int(vid), int(cid))] = annotations 132 | 133 | return labels 134 | 135 | 136 | def cafe_all_frames(labels): 137 | frames = [] 138 | 139 | for sid, anns in labels.items(): 140 | frames.append(sid) 141 | return frames 142 | 143 | 144 | class CafeDataset(data.Dataset): 145 | def __init__(self, frames, anns, tracks, image_path, args, is_training=True): 146 | super(CafeDataset, self).__init__() 147 | self.frames = frames 148 | self.anns = anns 149 | self.tracks = tracks 150 | self.image_path = image_path 151 | self.image_size = (args.image_width, args.image_height) 152 | self.num_boxes = args.num_boxes 153 | self.random_sampling = args.random_sampling 154 | self.num_frame = args.num_frame 155 | self.num_class = args.num_class 156 | self.is_training = is_training 157 | self.transform = transforms.Compose([ 158 | transforms.Resize((args.image_height, args.image_width)), 159 | transforms.ToTensor(), 160 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 161 | ]) 162 | 163 | def __getitem__(self, idx): 164 | if self.num_frame == 1: 165 | frames = self.select_key_frames(self.frames[idx]) 166 | else: 167 | frames = self.select_frames(self.frames[idx]) 168 | 169 | samples = self.load_samples(frames) 170 | 171 | return samples 172 | 173 | def __len__(self): 174 | return len(self.frames) 175 | 176 | def select_key_frames(self, frame): 177 | annotation = self.anns[frame] 178 | key_frame = annotation['key_frame'] 179 | 180 | return [(frame, int(key_frame))] 181 | 182 | def select_frames(self, frame): 183 | annotation = self.anns[frame] 184 | key_frame = annotation['key_frame'] 185 | total_frames = annotation['num_frames'] 186 | interval = annotation['interval'] 187 | 188 | if self.is_training: 189 | # random sampling 190 | if self.random_sampling: 191 | sample_frames = random.sample(range(total_frames), self.num_frame) 192 | sample_frames.sort() 193 | # segment-based sampling 194 | else: 195 | segment_duration = total_frames // self.num_frame 196 | sample_frames = np.multiply(list(range(self.num_frame)), segment_duration) + np.random.randint( 197 | segment_duration, size=self.num_frame) 198 | else: 199 | # random sampling 200 | if self.random_sampling: 201 | sample_frames = 
random.sample(range(total_frames), self.num_frame) 202 | sample_frames.sort() 203 | # segment-based sampling 204 | else: 205 | segment_duration = total_frames // self.num_frame 206 | sample_frames = np.multiply(list(range(self.num_frame)), segment_duration) + np.random.randint( 207 | segment_duration, size=self.num_frame) 208 | 209 | return [(frame, int(fid * annotation['interval'])) for fid in sample_frames] 210 | 211 | def load_samples(self, frames): 212 | images, boxes, gt_boxes, actions, activities, members, membership = [], [], [], [], [], [], [] 213 | targets = {} 214 | fids = [] 215 | 216 | for i, (frame, fid) in enumerate(frames): 217 | vid, cid = frame 218 | fids.append(fid) 219 | img = Image.open(self.image_path + '/%s/%s/images/frames_%d.jpg' % (vid, cid, fid)) 220 | image_w, image_h = img.width, img.height 221 | img = self.transform(img) 222 | images.append(img) 223 | 224 | num_boxes = self.anns[frame]['boxes'].shape[0] 225 | 226 | for box in self.anns[frame]['boxes']: 227 | x_c, y_c, w, h = box 228 | gt_boxes.append([x_c / image_w, y_c / image_h, w / image_w, h / image_h]) 229 | 230 | temp_boxes = np.ones((num_boxes, 4)) 231 | for j, track in enumerate(self.tracks[(vid, cid)][fid]): 232 | _id, x1, y1, x2, y2 = track 233 | 234 | if x1 < 0.0 and y2 < 0.0: 235 | x1, y1, x2, y2 = 0.0, 0.0, 1e-8, 1e-8 236 | 237 | x_c, y_c = (x1 + x2) / 2, (y1 + y2) / 2 238 | w, h = x2 - x1, y2 - y1 239 | 240 | if _id <= num_boxes: 241 | temp_boxes[int(_id - 1)] = np.array([x_c, y_c, w, h]) 242 | 243 | boxes.append(temp_boxes) 244 | actions = [self.anns[frame]['actions']] 245 | activities = [self.anns[frame]['activities']] 246 | members = [self.anns[frame]['members']] 247 | membership = [self.anns[frame]['membership']] 248 | 249 | if len(boxes[-1]) != self.num_boxes: 250 | boxes[-1] = np.vstack([boxes[-1], (self.num_boxes - len(boxes[-1])) * [[0.0, 0.0, 0.0, 0.0]]]) 251 | 252 | if len(actions[-1]) != self.num_boxes: 253 | actions[-1] = torch.cat((actions[-1], torch.tensor((self.num_boxes - len(actions[-1])) * [self.num_class + 1]))) 254 | 255 | if members[-1].shape[1] != self.num_boxes: 256 | members[-1] = torch.hstack( 257 | (members[-1], torch.zeros((members[-1].shape[0], self.num_boxes - members[-1].shape[1])))) 258 | 259 | if len(membership) != self.num_boxes: 260 | membership[-1] = torch.cat((membership[-1], torch.tensor((self.num_boxes - len(membership[-1])) * [-1]))) 261 | 262 | images = torch.stack(images) 263 | boxes = np.vstack(boxes).reshape([self.num_frame, -1, 4]) 264 | gt_boxes = np.vstack(gt_boxes).reshape([self.num_frame, -1, 4]) 265 | actions = torch.stack(actions) 266 | membership = torch.stack(membership) 267 | 268 | if len(activities) == 0: 269 | activities = torch.tensor(activities) 270 | members = torch.tensor(activities) 271 | else: 272 | activities = torch.stack(activities) 273 | members = torch.stack(members) 274 | 275 | boxes = torch.from_numpy(boxes).float() 276 | gt_boxes = torch.from_numpy(gt_boxes).float() 277 | 278 | targets['actions'] = actions 279 | targets['activities'] = activities 280 | targets['boxes'] = boxes 281 | targets['gt_boxes'] = gt_boxes 282 | targets['members'] = members 283 | targets['membership'] = membership 284 | 285 | infos = {'vid': vid, 'sid': cid, 'fid': fids, 'key_frame': self.anns[frame]['key_frame']} 286 | 287 | return images, targets, infos 288 | -------------------------------------------------------------------------------- /dataloader/dataloader.py: -------------------------------------------------------------------------------- 1 | # 
------------------------------------------------------------------------ 2 | # Modified from SSU (https://github.com/cvlab-epfl/social-scene-understanding) 3 | # Modified from ARG (https://github.com/wjchaoGit/Group-Activity-Recognition) 4 | # ------------------------------------------------------------------------ 5 | from .cafe import * 6 | 7 | import pickle 8 | 9 | TRAIN_CAFE_P = ['1', '2', '3', '4', '9', '10', '11', '12', '17', '18', '19', '20', '21', '22', '23', '24'] 10 | VAL_CAFE_P = ['13', '14', '15', '16'] 11 | TEST_CAFE_P = ['5', '6', '7', '8'] 12 | 13 | TRAIN_CAFE_V = ['1', '2', '5', '6', '9', '10', '13', '14', '17', '18', '21', '22'] 14 | VAL_CAFE_V = ['3', '7', '11', '15', '19', '23'] 15 | TEST_CAFE_V = ['4', '8', '12', '16', '20', '24'] 16 | 17 | 18 | def read_dataset(args): 19 | if args.dataset == 'cafe': 20 | data_path = args.data_path + 'cafe' 21 | 22 | # split-by-place setting 23 | if args.split == 'place': 24 | TRAIN_VIDEOS_CAFE = TRAIN_CAFE_P 25 | VAL_VIDEOS_CAFE = VAL_CAFE_P 26 | TEST_VIDEOS_CAFE = TEST_CAFE_P 27 | # split-by-view setting 28 | elif args.split == 'view': 29 | TRAIN_VIDEOS_CAFE = TRAIN_CAFE_V 30 | VAL_VIDEOS_CAFE = VAL_CAFE_V 31 | TEST_VIDEOS_CAFE = TEST_CAFE_V 32 | else: 33 | assert False 34 | 35 | if args.val_mode: 36 | train_data = cafe_read_annotations(data_path, TRAIN_VIDEOS_CAFE, args.num_class) 37 | train_frames = cafe_all_frames(train_data) 38 | 39 | test_data = cafe_read_annotations(data_path, VAL_VIDEOS_CAFE, args.num_class) 40 | test_frames = cafe_all_frames(test_data) 41 | else: 42 | train_data = cafe_read_annotations(data_path, TRAIN_VIDEOS_CAFE + VAL_VIDEOS_CAFE, args.num_class) 43 | train_frames = cafe_all_frames(train_data) 44 | 45 | test_data = cafe_read_annotations(data_path, TEST_VIDEOS_CAFE, args.num_class) 46 | test_frames = cafe_all_frames(test_data) 47 | 48 | # actor tracklets for all frames 49 | all_tracks = pickle.load(open(data_path + '/gt_tracks.pkl', 'rb')) 50 | 51 | train_set = CafeDataset(train_frames, train_data, all_tracks, data_path, args, is_training=True) 52 | test_set = CafeDataset(test_frames, test_data, all_tracks, data_path, args, is_training=False) 53 | else: 54 | assert False 55 | 56 | print("%d train samples and %d test samples" % (len(train_frames), len(test_frames))) 57 | 58 | return train_set, test_set 59 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: gad 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=4.5=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2021.10.26=h06a4308_2 10 | - certifi=2021.10.8=py38h06a4308_0 11 | - cudatoolkit=11.0.221=h6bb024c_0 12 | - freetype=2.11.0=h70c0345_0 13 | - giflib=5.2.1=h7b6447c_0 14 | - intel-openmp=2021.4.0=h06a4308_3561 15 | - joblib=1.1.0=pyhd3eb1b0_0 16 | - jpeg=9b=h024ee3a_2 17 | - lcms2=2.12=h3be6417_0 18 | - ld_impl_linux-64=2.35.1=h7274673_9 19 | - libffi=3.3=he6710b0_2 20 | - libgcc-ng=9.3.0=h5101ec6_17 21 | - libgfortran-ng=7.5.0=ha8ba4b0_17 22 | - libgfortran4=7.5.0=ha8ba4b0_17 23 | - libgomp=9.3.0=h5101ec6_17 24 | - libpng=1.6.37=hbc83047_0 25 | - libstdcxx-ng=9.3.0=hd4cf53a_17 26 | - libtiff=4.2.0=h85742a9_0 27 | - libuv=1.40.0=h7b6447c_0 28 | - libwebp=1.2.0=h89dd481_0 29 | - libwebp-base=1.2.0=h27cfd23_0 30 | - lz4-c=1.9.3=h295c915_1 31 | - mkl=2021.4.0=h06a4308_640 32 | - mkl-service=2.4.0=py38h7f8727e_0 33 | - mkl_fft=1.3.1=py38hd3c417c_0 34 | - 
mkl_random=1.2.2=py38h51133e4_0 35 | - ncurses=6.3=h7f8727e_2 36 | - ninja=1.10.2=py38hd09550d_3 37 | - numpy=1.21.2=py38h20f2e39_0 38 | - numpy-base=1.21.2=py38h79a1101_0 39 | - olefile=0.46=pyhd3eb1b0_0 40 | - openssl=1.1.1l=h7f8727e_0 41 | - pillow=8.4.0=py38h5aabda8_0 42 | - pip=21.2.4=py38h06a4308_0 43 | - python=3.8.5=h7579374_1 44 | - readline=8.1=h27cfd23_0 45 | - scikit-learn=1.0.1=py38h51133e4_0 46 | - scipy=1.7.1=py38h292c36d_2 47 | - setuptools=58.0.4=py38h06a4308_0 48 | - six=1.16.0=pyhd3eb1b0_0 49 | - sqlite=3.36.0=hc218d9a_0 50 | - threadpoolctl=2.2.0=pyh0d69192_0 51 | - tk=8.6.11=h1ccaba5_0 52 | - torchvision=0.8.2=py38_cu110 53 | - typing_extensions=3.10.0.2=pyh06a4308_0 54 | - wheel=0.37.0=pyhd3eb1b0_1 55 | - xz=5.2.5=h7b6447c_0 56 | - zlib=1.2.11=h7b6447c_3 57 | - zstd=1.4.9=haebb681_0 58 | prefix: /opt/conda/envs/gad 59 | -------------------------------------------------------------------------------- /evaluation/cafe_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from JRDB-ACT (https://github.com/JRDB-dataset/jrdb_toolkit/tree/main/Action%26Social_grouping_eval) 3 | # ------------------------------------------------------------------------ 4 | from collections import defaultdict, Counter 5 | import numpy as np 6 | import copy 7 | 8 | 9 | def make_image_key(v_id, c_id, f_id): 10 | """Returns a unique identifier for a video id & clip id & frame id""" 11 | return "%d,%d,%d" % (int(v_id), int(c_id), int(f_id)) 12 | 13 | def make_clip_key(image_key): 14 | """Returns a unique identifier for a video id & clip id""" 15 | v_id = image_key.split(',')[0] 16 | c_id = image_key.split(',')[1] 17 | return "%d,%d" % (int(v_id), int(c_id)) 18 | 19 | def read_text_file(text_file, eval_type, mode): 20 | """Loads boxes and class labels from a CSV file in the cafe format. 21 | 22 | Args: 23 | text_file: A file object. 24 | mode: 'gt' or 'pred' 25 | eval_type: 26 | 'gt_base': Eval type for trained model with ground turth actor tracklets as inputs. 27 | 'detect_base': Eval type for trained model with tracker actor tracklets as inputs. 28 | 29 | Returns: 30 | boxes: A dictionary mapping each unique image key (string) to a list of 31 | boxes, given as coordinates [y1, x1, y2, x2]. 32 | g_labels: A dictionary mapping each unique image key (string) to a list of 33 | integer group id labels, matching the corresponding box in 'boxes'. 34 | act_labels: A dictionary mapping each unique image key (string) to a list of 35 | integer group activity class lables, matching the corresponding box in `boxes`. 36 | a_scores: A dictionary mapping each unique image key (string) to a list of 37 | actor confidence score values lables, matching the corresponding box in `boxes`. 38 | g_scores: A dictionary mapping each unique image key (string) to a list of 39 | group confidence score values lables, matching the corresponding box in `boxes`. 40 | """ 41 | boxes = defaultdict(list) 42 | g_labels = defaultdict(list) 43 | act_labels = defaultdict(list) 44 | a_scores = defaultdict(list) 45 | g_scores = defaultdict(list) 46 | # reads each row in text file. 47 | with open(text_file.name) as r: 48 | for line in r.readlines(): 49 | row = line[:-1].split(' ') 50 | # makes image key. 51 | image_key = make_image_key(row[0], row[1], row[2]) 52 | # box coordinates. 53 | x1, y1, x2, y2 = [float(n) for n in row[3:7]] 54 | # actor confidence score. 
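# Each space-separated row is parsed below as:
#   v_id c_id f_id x1 y1 x2 y2 group_id activity [g_score [a_score]]
# g_score appears only in prediction files, and a_score only for the
# 'detect_base' eval type (ground-truth boxes default to a_score = 1.0).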
55 | if eval_type == 'detect_base' and mode == 'pred': 56 | a_score = float(row[10]) 57 | else: 58 | a_score = 1.0 59 | # group confidence score. 60 | if mode == 'gt': 61 | g_score = None 62 | elif mode == 'pred': 63 | g_score = float(row[9]) 64 | # group identity document. 65 | group_id = int(row[7]) 66 | # group activity label. 67 | activity = int(row[8]) 68 | 69 | boxes[image_key].append([x1, y1, x2, y2]) 70 | g_labels[image_key].append(group_id) 71 | act_labels[image_key].append(activity) 72 | a_scores[image_key].append(a_score) 73 | g_scores[image_key].append(g_score) 74 | return boxes, g_labels, act_labels, a_scores, g_scores 75 | 76 | def actor_matching(pred_boxes, pred_a_scores, gt_boxes): 77 | """matches prediction tracklets and ground truth tracklets. 78 | 79 | Args: 80 | pred_boxes: A dictionary mapping each unique image key (string) to a list of 81 | prediction boxes, given as coordinates [y1, x1, y2, x2]. it has same permutation 82 | in image keys of each clip key. 83 | pred_a_scores: A dictionary mapping each unique image key (string) to a list of 84 | actor confidence score values lables, matching the corresponding box in `pred_boxes`. 85 | gt_boxes: A dictionary mapping each unique image key (string) to a list of 86 | ground truth boxes, given as coordinates [y1, x1, y2, x2]. it has same permutation 87 | in image keys of each clip key. 88 | 89 | Returns: 90 | matching_results: A dictionary mapping each unique clip key (string) to a list of 91 | id matching results between prediction tracklets and ground truth tracklets, given as 92 | {[prediction id: [ground truth id]]}. 93 | """ 94 | image_keys = pred_boxes.keys() 95 | frame_list = defaultdict(list) 96 | matching_results = defaultdict(list) 97 | for image_key in image_keys: 98 | clip_key = make_clip_key(image_key) 99 | frame_list[clip_key].append(image_key) 100 | clip_keys = frame_list.keys() 101 | for clip_key in clip_keys: 102 | matching = defaultdict(list) 103 | # IoU matrix 104 | iou_matrix = np.zeros((len(pred_boxes[frame_list[clip_key][0]]), len(gt_boxes[frame_list[clip_key][0]]))) 105 | # confidence score for prediction tracklets. 106 | confidence_mean = np.zeros(len(pred_boxes[frame_list[clip_key][0]])) 107 | # image numbers of clip. 108 | frame_len = len(frame_list[clip_key]) 109 | # puts sum of IoUs on the IoU matrix and sum of confidence score. 110 | for image_key in frame_list[clip_key]: 111 | for i, pred_box in enumerate(pred_boxes[image_key]): 112 | if pred_box[2] != 0 and pred_box[2] != -1: 113 | confidence_mean[i] += pred_a_scores[image_key][i] 114 | for j, gt_box in enumerate(gt_boxes[image_key]): 115 | if gt_box[2] != 0 and gt_box[2] != -1: 116 | iou_matrix[i,j] += IoU(pred_box, gt_box) 117 | # takes each mean of IoU and confidence score. 118 | confidence_mean = confidence_mean / frame_len 119 | iou_matrix = iou_matrix / frame_len 120 | # sorts by confidence score. 
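# Greedy assignment: prediction tracklets are visited in order of decreasing
# mean confidence; each one is matched to the ground-truth tracklet with the
# highest mean IoU, provided that IoU exceeds 0.5, and the matched ground-truth
# column is zeroed out so it cannot be assigned twice.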
121 | sorted_scores = sorted(confidence_mean, reverse=True) 122 | 123 | # matching algorithm 124 | duplicated_score = 0 125 | for score in sorted_scores: 126 | if duplicated_score == score: 127 | continue 128 | else: 129 | for i in (np.where(confidence_mean == score)[0]): 130 | if max(iou_matrix[i]) > 0.5: 131 | j = np.where(iou_matrix[i] == max(iou_matrix[i]))[0][0] 132 | matching[i].append(j) 133 | iou_matrix[:,j] = 0 134 | 135 | # matching results 136 | matching_results[clip_key].append(matching) 137 | return matching_results 138 | 139 | 140 | def make_groups(boxes, g_labels, act_labels, g_scores): 141 | """combines boxes, activity, score to same group, same image 142 | 143 | Returns: 144 | groups_ids: A dictionary mapping each unique clip key (string) to a list of 145 | actor ids of each 'g_label'. 146 | groups_activity: A dictionary mapping each unique clip key (string) to a list of 147 | group activity class labels. 148 | groups_score: A dictionary mapping each unique clip key (string) to a list of 149 | group confidence score. 150 | """ 151 | image_keys = boxes.keys() 152 | groups_activity = defaultdict(list) 153 | groups_score = defaultdict(list) 154 | groups_ids = defaultdict(list) 155 | frame_list = defaultdict(list) 156 | # makes clip key. 157 | for image_key in image_keys: 158 | clip_key = make_clip_key(image_key) 159 | frame_list[clip_key].append(image_key) 160 | clip_keys = frame_list.keys() 161 | for clip_key in clip_keys: 162 | group_ids = defaultdict(list) 163 | group_activity = defaultdict(set) 164 | group_score = defaultdict(set) 165 | for i, (g_label, act_label, g_score) in enumerate( 166 | zip(g_labels[frame_list[clip_key][0]], act_labels[frame_list[clip_key][0]], 167 | g_scores[frame_list[clip_key][0]])): 168 | group_ids[g_label].append(i) 169 | group_activity[g_label].add(act_label) 170 | group_score[g_label].add(g_score) 171 | groups_ids[clip_key].append(group_ids) 172 | groups_activity[clip_key].append(group_activity) 173 | groups_score[clip_key].append(group_score) 174 | 175 | return groups_ids, groups_activity, groups_score 176 | 177 | 178 | def read_labelmap(labelmap_file): 179 | """Reads a labelmap without the dependency on protocol buffers. 180 | 181 | Args: 182 | labelmap_file: A file object containing a label map protocol buffer. 183 | 184 | Returns: 185 | labelmap: The label map in the form used by the object_detection_evaluation 186 | module - a list of {"id": integer, "name": classname } dicts. 187 | class_ids: A set containing all of the valid class id integers. 
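    Each labelmap item provides a class name and an integer id, e.g.
    name: "Queueing" together with id: 1 (see label_map/group_action_list.pbtxt).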
188 | """ 189 | labelmap = [] 190 | class_ids = set() 191 | name = "" 192 | class_id = "" 193 | for line in labelmap_file: 194 | if line.startswith(" name:"): 195 | name = line.split('"')[1] 196 | elif line.startswith(" id:") or line.startswith(" label_id:"): 197 | class_id = int(line.strip().split(" ")[-1]) 198 | labelmap.append({"id": class_id, "name": name}) 199 | class_ids.add(class_id) 200 | return labelmap, class_ids 201 | 202 | 203 | def IoU(box1, box2): 204 | """calculates IoU between two different boxes.""" 205 | # box = (x1, y1, x2, y2) 206 | box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1]) 207 | box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1]) 208 | 209 | # obtain x1, y1, x2, y2 of the intersection 210 | x1 = max(box1[0], box2[0]) 211 | y1 = max(box1[1], box2[1]) 212 | x2 = min(box1[2], box2[2]) 213 | y2 = min(box1[3], box2[3]) 214 | 215 | # compute the width and height of the intersection 216 | w = max(0, x2 - x1) 217 | h = max(0, y2 - y1) 218 | 219 | inter = w * h 220 | iou = inter / (box1_area + box2_area - inter + 1e-8) 221 | return iou 222 | 223 | 224 | def cal_group_IoU(pred_group, gt_group): 225 | """calculates group IoU between two different groups""" 226 | # Intersection 227 | Intersection = sum([1 for det_id in pred_group[2] if det_id in gt_group[2]]) 228 | 229 | # group IoU 230 | if Intersection != 0: 231 | group_IoU = Intersection / (len(pred_group[2]) + len(gt_group[2]) - Intersection) 232 | else: 233 | group_IoU = 0 234 | return group_IoU 235 | 236 | 237 | def calculateAveragePrecision(rec, prec): 238 | """calculates AP score of each activity class by all-point interploation method.""" 239 | mrec = [0] + [e for e in rec] + [1] 240 | mpre = [0] + [e for e in prec] + [0] 241 | 242 | for i in range(len(mpre) - 1, 0, -1): 243 | mpre[i - 1] = max(mpre[i - 1], mpre[i]) 244 | 245 | ii = [] 246 | 247 | for i in range(len(mrec) - 1): 248 | if mrec[1:][i] != mrec[0:-1][i]: 249 | ii.append(i + 1) 250 | 251 | ap = 0 252 | for i in ii: 253 | ap = ap + np.sum((mrec[i] - mrec[i - 1]) * mpre[i]) 254 | 255 | return [ap, mpre[0:len(mpre) - 1], mrec[0:len(mpre) - 1], ii] 256 | 257 | 258 | def outlier_metric(gt_groups_ids, gt_groups_activity, pred_groups_ids, pred_groups_activity, num_class): 259 | """calculates Outlier mIoU. 260 | 261 | Args: 262 | num_class: A number of activity classes. 263 | 264 | Returns: 265 | outlier_mIoU: Mean of outlier IoUs on each clip. 266 | """ 267 | clip_IoU = defaultdict(list) 268 | TP = defaultdict(list) 269 | clip_keys = pred_groups_ids.keys() 270 | c_pred_groups_activity = copy.deepcopy(pred_groups_activity) 271 | c_gt_groups_activity = copy.deepcopy(gt_groups_activity) 272 | # prediction groups on each class. defines group has members equals or more than two. 273 | pred_groups = [[clip_key, group_id, pred_groups_ids[clip_key][0][group_id]] for clip_key in clip_keys if 274 | clip_key in gt_groups_ids.keys() for group_id in pred_groups_ids[clip_key][0].keys() if 275 | c_pred_groups_activity[clip_key][0][group_id].pop() == (num_class + 1)] 276 | # ground truth groups on each class. 277 | gt_groups = [[clip_key, group_id, gt_groups_ids[clip_key][0][group_id]] for clip_key in clip_keys if 278 | clip_key in gt_groups_ids.keys() for group_id in gt_groups_ids[clip_key][0].keys() if 279 | c_gt_groups_activity[clip_key][0][group_id].pop() == (num_class + 1)] 280 | for clip_key in clip_keys: 281 | # escapes error that there are not exist pred_image_key on gt.txt. 
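# (i.e., prediction clips that have no corresponding ground-truth clip are skipped)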
282 | if clip_key in gt_groups_ids.keys(): 283 | # groups on same clip 284 | c_pred_groups = [pred_group for pred_group in pred_groups if pred_group[0] == clip_key] 285 | c_gt_groups = [gt_group for gt_group in gt_groups if gt_group[0] == clip_key] 286 | if len(c_pred_groups) != 0 and len(c_gt_groups) != 0: 287 | # outliers on prediction and ground truth. 288 | c_pred_ids = [pred_id for c_pred_group in c_pred_groups for pred_id in c_pred_group[2]] 289 | c_gt_ids = [gt_id for c_gt_group in c_gt_groups for gt_id in c_gt_group[2]] 290 | # number of True positive outliers. 291 | TP[clip_key] = sum([1 for pred_id in c_pred_ids if pred_id in c_gt_ids]) 292 | clip_IoU[clip_key] = TP[clip_key] / (len(c_pred_ids) + len(c_gt_ids) - TP[clip_key]) 293 | clip_IoU['total'].append(clip_IoU[clip_key]) 294 | elif len(c_pred_groups) != 0 or len(c_gt_groups) != 0: 295 | TP[clip_key] = 0 296 | clip_IoU[clip_key] = 0 297 | clip_IoU['total'].append(clip_IoU[clip_key]) 298 | # outlier mIoU. 299 | outlier_mIoU = np.array(clip_IoU['total']).mean() 300 | return outlier_mIoU * 100.0 301 | 302 | 303 | def group_mAP_eval(gt_groups_ids, gt_groups_activity, pred_groups_ids, pred_groups_activity, pred_groups_scores, 304 | categories, thresh): 305 | """calculates group mAP. 306 | 307 | Args: 308 | categories: A list of group activity classes, given as {name: ,id: }. 309 | thresh: A group IoU threshold for true positive prediction group condition. 310 | 311 | Returns: 312 | group_mAP: Mean of group APs on each activity class. 313 | group_APs; A list of each group AP on each activity class. 314 | """ 315 | clip_keys = pred_groups_ids.keys() 316 | # acc on each class. 317 | group_APs = np.zeros(len(categories)) 318 | for c, clas in enumerate(categories): 319 | # copy for set funtion to pop. 320 | c_pred_groups_activity = copy.deepcopy(pred_groups_activity) 321 | c_gt_groups_activity = copy.deepcopy(gt_groups_activity) 322 | # prediction groups on each class. 323 | pred_groups = [ 324 | [clip_key, group_id, pred_groups_ids[clip_key][0][group_id], pred_groups_scores[clip_key][0][group_id]] for 325 | clip_key in clip_keys if clip_key in gt_groups_ids.keys() for group_id in 326 | pred_groups_ids[clip_key][0].keys() if 327 | c_pred_groups_activity[clip_key][0][group_id].pop() == clas['id'] and len( 328 | pred_groups_ids[clip_key][0][group_id]) >= 2] 329 | # ground truth groups on each class. 330 | gt_groups = [[clip_key, group_id, gt_groups_ids[clip_key][0][group_id]] for clip_key in clip_keys if 331 | clip_key in gt_groups_ids.keys() for group_id in gt_groups_ids[clip_key][0].keys() if 332 | c_gt_groups_activity[clip_key][0][group_id].pop() == clas['id'] and len( 333 | gt_groups_ids[clip_key][0][group_id]) >= 2] 334 | 335 | # denominator of Recall. 336 | npos = len(gt_groups) 337 | 338 | # sorts det_groups in descending order for g_score. 339 | pred_groups = sorted(pred_groups, key=lambda conf: conf[3], reverse=True) 340 | 341 | TP = np.zeros(len(pred_groups)) 342 | FP = np.zeros(len(pred_groups)) 343 | 344 | det = Counter(gt_group[0] for gt_group in gt_groups) 345 | 346 | for key, val in det.items(): 347 | det[key] = np.zeros(val) 348 | 349 | # AP matching algorithm. 
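# Prediction groups are visited in order of decreasing group confidence. Each one
# is compared against the ground-truth groups of the same clip using the group IoU
# (number of shared member ids divided by the size of the union of the two member
# sets); the best-overlapping ground-truth group yields a true positive if that IoU
# reaches `thresh` and the group has not been claimed yet, otherwise a false positive.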
350 | for p, pred_group in enumerate(pred_groups): 351 | if pred_group[0] in gt_groups_ids.keys(): 352 | gt = [gt_group for gt_group in gt_groups if gt_group[0] == pred_group[0]] 353 | group_IoU_Max = 0 354 | for j, gt_group in enumerate(gt): 355 | group_IoU = cal_group_IoU(pred_group, gt_group) 356 | if group_IoU > group_IoU_Max: 357 | group_IoU_Max = group_IoU 358 | jmax = j 359 | # true positive prediction group condition. 360 | if group_IoU_Max >= thresh: 361 | if det[pred_group[0]][jmax] == 0: 362 | TP[p] = 1 363 | det[pred_group[0]][jmax] = 1 364 | else: 365 | FP[p] = 1 366 | else: 367 | FP[p] = 1 368 | 369 | acc_FP = np.cumsum(FP) 370 | acc_TP = np.cumsum(TP) 371 | # recall 372 | rec = acc_TP / npos 373 | # precision 374 | prec = np.divide(acc_TP, (acc_FP + acc_TP)) 375 | [ap, mpre, mrec, ii] = calculateAveragePrecision(rec, prec) 376 | # group AP on each group activity class 377 | group_APs[c] = ap * 100 378 | # group mAP 379 | group_mAP = group_APs.mean() 380 | return group_mAP, group_APs 381 | 382 | class GAD_Evaluation(): 383 | def __init__(self, args): 384 | super(GAD_Evaluation, self).__init__() 385 | self.eval_type = args.eval_type 386 | self.categories, self.class_whitelist = read_labelmap(args.labelmap) 387 | self.gt_boxes, self.gt_g_labels, self.gt_act_labels, _, self.gt_g_scores = read_text_file(args.groundtruth, self.eval_type, mode='gt') 388 | self.gt_groups_ids, self.gt_groups_activity, _ = make_groups( 389 | self.gt_boxes, self.gt_g_labels, self.gt_act_labels, self.gt_g_scores) 390 | 391 | 392 | def evaluate(self, detections): 393 | pred_boxes, pred_g_labels, pred_act_labels, pred_a_scores, pred_g_scores = read_text_file(detections, self.eval_type, mode='pred') 394 | pred_groups_ids, pred_groups_activity, pred_groups_scores = make_groups(pred_boxes, pred_g_labels, 395 | pred_act_labels, 396 | pred_g_scores) 397 | group_mAP, group_APs = group_mAP_eval(self.gt_groups_ids, self.gt_groups_activity, 398 | pred_groups_ids, pred_groups_activity, pred_groups_scores, 399 | self.categories, thresh=1.0) 400 | group_mAP_2, group_APs_2 = group_mAP_eval(self.gt_groups_ids, self.gt_groups_activity, 401 | pred_groups_ids, pred_groups_activity, pred_groups_scores, 402 | self.categories, thresh=0.5) 403 | outlier_mIoU = outlier_metric(self.gt_groups_ids, self.gt_groups_activity, 404 | pred_groups_ids, pred_groups_activity, 405 | len(self.categories)) 406 | result = { 407 | 'group_APs_1.0': group_APs, 408 | 'group_mAP_1.0': group_mAP, 409 | 'group_APs_0.5': group_APs_2, 410 | 'group_mAP_0.5': group_mAP_2, 411 | 'outlier_mIoU': outlier_mIoU, 412 | } 413 | return result 414 | -------------------------------------------------------------------------------- /label_map/group_action_list.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | name: "Queueing" 3 | id: 1 4 | } 5 | item { 6 | name: "Ordering" 7 | id: 2 8 | } 9 | item { 10 | name: "Eating/Drinking" 11 | id: 3 12 | } 13 | item { 14 | name: "Working/Studying" 15 | id: 4 16 | } 17 | item { 18 | name: "Fighting" 19 | id: 5 20 | } 21 | item { 22 | name: "TakingSelfie" 23 | id: 6 24 | } 25 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .group_matcher import build_group_matcher 2 | from .criterion import SetCriterion 3 | from .models import GADTR 4 | 5 | 6 | def build_model(args): 7 | model = GADTR(args) 8 | 9 | losses = ['labels', 'cardinality'] 
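# Individual-level losses above; the group-level losses below cover group activity
# classification, group cardinality, the membership code, and group consistency
# (see SetCriterion in models/criterion.py for the corresponding loss terms).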
10 | group_losses = ['group_labels', 'group_cardinality', 'group_code', 'group_consistency'] 11 | 12 | # Set loss coefficients 13 | weight_dict = {} 14 | weight_dict['loss_ce'] = args.ce_loss_coef 15 | weight_dict['loss_group_ce'] = args.group_ce_loss_coef 16 | weight_dict['loss_group_code'] = args.group_code_loss_coef 17 | weight_dict['loss_consistency'] = args.consistency_loss_coef 18 | 19 | # Group matching 20 | group_matcher = build_group_matcher(args) 21 | 22 | # Loss functions 23 | criterion = SetCriterion(args.num_class, weight_dict=weight_dict, eos_coef=args.eos_coef, 24 | losses=losses, group_losses=group_losses, group_matcher=group_matcher, args=args) 25 | 26 | criterion.cuda() 27 | 28 | return model, criterion 29 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torchvision 9 | from torchvision.models._utils import IntermediateLayerGetter 10 | 11 | from .position_encoding import build_position_encoding 12 | 13 | 14 | class FrozenBatchNorm2d(torch.nn.Module): 15 | """ 16 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 17 | 18 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 19 | without which any other models than torchvision.models.resnet[18,34,50,101] 20 | produce nans. 21 | """ 22 | 23 | def __init__(self, n): 24 | super(FrozenBatchNorm2d, self).__init__() 25 | self.register_buffer("weight", torch.ones(n)) 26 | self.register_buffer("bias", torch.zeros(n)) 27 | self.register_buffer("running_mean", torch.zeros(n)) 28 | self.register_buffer("running_var", torch.ones(n)) 29 | 30 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 31 | missing_keys, unexpected_keys, error_msgs): 32 | num_batches_tracked_key = prefix + 'num_batches_tracked' 33 | if num_batches_tracked_key in state_dict: 34 | del state_dict[num_batches_tracked_key] 35 | 36 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 37 | state_dict, prefix, local_metadata, strict, 38 | missing_keys, unexpected_keys, error_msgs) 39 | 40 | def forward(self, x): 41 | # move reshapes to the beginning 42 | # to make it fuser-friendly 43 | w = self.weight.reshape(1, -1, 1, 1) 44 | b = self.bias.reshape(1, -1, 1, 1) 45 | rv = self.running_var.reshape(1, -1, 1, 1) 46 | rm = self.running_mean.reshape(1, -1, 1, 1) 47 | eps = 1e-5 48 | scale = w * (rv + eps).rsqrt() 49 | bias = b - rm * scale 50 | return x * scale + bias 51 | 52 | 53 | class Backbone(nn.Module): 54 | def __init__(self, args): 55 | super(Backbone, self).__init__() 56 | 57 | if args.frozen_batch_norm: 58 | backbone = getattr(torchvision.models, args.backbone)( 59 | replace_stride_with_dilation=[False, False, args.dilation], 60 | pretrained=True, norm_layer=FrozenBatchNorm2d) 61 | else: 62 | backbone = getattr(torchvision.models, args.backbone)( 63 | replace_stride_with_dilation=[False, False, args.dilation], 64 | pretrained=True) 65 | 66 | self.num_frames = args.num_frame 67 | self.num_channels = 512 if args.backbone in ('resnet18', 'resnet34') else 2048 68 | 69 | self.body = IntermediateLayerGetter(backbone, 
return_layers={'layer4': "0"}) 70 | 71 | def forward(self, x): 72 | x = self.body(x)["0"] 73 | 74 | return x 75 | 76 | 77 | class Joiner(nn.Sequential): 78 | def __init__(self, backbone, position_embedding): 79 | super().__init__(backbone, position_embedding) 80 | 81 | def forward(self, x): 82 | bs, t, _, h, w = x.shape 83 | x = x.reshape(bs * t, 3, h, w) 84 | 85 | features = self[0](x) 86 | _, c, oh, ow = features.shape 87 | 88 | pos = self[1](features).to(x.dtype) 89 | 90 | return features, pos 91 | 92 | 93 | def build_backbone(args): 94 | pos_embed = build_position_encoding(args) 95 | backbone = Backbone(args) 96 | model = Joiner(backbone, pos_embed) 97 | model.num_channels = backbone.num_channels 98 | return model 99 | -------------------------------------------------------------------------------- /models/criterion.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR) 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | import torch 9 | import torch.nn.functional as F 10 | import copy 11 | import numpy as np 12 | 13 | from torch import nn 14 | 15 | from util import box_ops 16 | from util.misc import (accuracy, get_world_size, is_dist_avail_and_initialized) 17 | 18 | 19 | class SetCriterion(nn.Module): 20 | def __init__(self, num_classes, weight_dict, eos_coef, losses, group_losses=None, 21 | group_matcher=None, args=None): 22 | """ Create the criterion. 23 | Parameters: 24 | num_classes: number of object categories, omitting the special no-object category 25 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 26 | eos_coef: relative classification weight applied to the no-group activity class category 27 | losses: list of all the losses to be applied. See get_loss for list of available losses. 28 | group_losses: list of all the group losses to be applied. See get_group_loss for list of available group losses. 
29 | group_matcher: module able to compute a matching between targets and predictions 30 | """ 31 | super().__init__() 32 | self.num_classes = num_classes 33 | self.weight_dict = weight_dict 34 | self.losses = losses 35 | self.eos_coef = eos_coef 36 | 37 | self.group_losses = group_losses 38 | self.group_matcher = group_matcher 39 | 40 | empty_weight = torch.ones(self.num_classes + 1) 41 | empty_weight[-1] = eos_coef 42 | self.register_buffer('empty_weight', empty_weight) 43 | 44 | empty_group_weight = torch.ones(self.num_classes + 1) 45 | empty_group_weight[-1] = args.group_eos_coef 46 | self.register_buffer('empty_group_weight', empty_group_weight) 47 | 48 | self.num_boxes = args.num_boxes 49 | 50 | # option 51 | self.temperature = args.temperature 52 | 53 | ####################################################################################################################### 54 | # * Individual Losses 55 | ####################################################################################################################### 56 | def loss_labels(self, outputs, targets, num_boxes, log=True): 57 | """Individual action classification loss (NLL)""" 58 | assert 'pred_actions' in outputs 59 | src_logits = outputs['pred_actions'] 60 | target_classes = torch.cat([v["actions"] for v in targets], dim=0) 61 | 62 | loss_ce = 0.0 63 | 64 | src_logits_log = None 65 | tgt_classes_log = None 66 | 67 | for batch_idx in range(src_logits.shape[0]): 68 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze() 69 | non_dummy_idx = dummy_idx.nonzero(as_tuple=True) 70 | src_logit = src_logits[batch_idx][non_dummy_idx].unsqueeze(0) 71 | target_class = target_classes[batch_idx][non_dummy_idx].unsqueeze(0) 72 | loss_ce += F.cross_entropy(src_logit.transpose(1, 2), target_class, self.empty_weight) 73 | 74 | if src_logits_log is None: 75 | src_logits_log = src_logit 76 | tgt_classes_log = target_class 77 | else: 78 | src_logits_log = torch.cat([src_logits_log.squeeze(), src_logit.squeeze()], dim=0) 79 | tgt_classes_log = torch.cat([tgt_classes_log.squeeze(), target_class.squeeze()], dim=0) 80 | 81 | loss_ce /= src_logits.shape[0] 82 | losses = {'loss_ce': loss_ce} 83 | 84 | if log: 85 | # TODO this should probably be a separate loss, not hacked in this one here 86 | losses['class_error'] = 100 - accuracy(src_logits_log, tgt_classes_log)[0] 87 | 88 | return losses 89 | 90 | @torch.no_grad() 91 | def loss_cardinality(self, outputs, targets, num_boxes): 92 | pred_logits = outputs['pred_actions'] 93 | device = pred_logits.device 94 | # tgt_lengths = torch.as_tensor([len(v["actions"]) for v in targets], device=device) 95 | tgt_lengths = torch.as_tensor([len(k) for v in targets for k in v["actions"]], device=device) 96 | # Count the number of predictions that are NOT "no-object" (which is the last class) 97 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) 98 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 99 | losses = {'cardinality_error': card_err} 100 | return losses 101 | 102 | ####################################################################################################################### 103 | # * Group Losses 104 | ####################################################################################################################### 105 | def loss_group_labels(self, outputs, targets, group_indices, log=True): 106 | """ Group activity classification loss (NLL)""" 107 | assert 'pred_activities' in outputs 108 | src_logits = outputs['pred_activities'] 109 | 110 | idx = 
self._get_src_permutation_idx(group_indices) 111 | flatten_targets = [u for t in targets for u in t["activities"]] 112 | target_classes_o = torch.cat([t[J] for t, (_, J) in zip(flatten_targets, group_indices)]) 113 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, dtype=torch.int64, 114 | device=src_logits.device) 115 | target_classes[idx] = target_classes_o 116 | 117 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_group_weight) 118 | losses = {'loss_group_ce': loss_ce} 119 | 120 | if log: 121 | # TODO this should probably be a separate loss, not hacked in this one here 122 | losses['group_class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] 123 | 124 | return losses 125 | 126 | @torch.no_grad() 127 | def loss_group_cardinality(self, outputs, targets, group_indices): 128 | pred_logits = outputs['pred_activities'] 129 | device = pred_logits.device 130 | tgt_lengths = torch.as_tensor([len(k) for v in targets for k in v["activities"]], device=device) 131 | # Count the number of predictions that are NOT "no-object" (which is the last class) 132 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) 133 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 134 | losses = {'group_cardinality_error': card_err} 135 | return losses 136 | 137 | def loss_group_code(self, outputs, targets, group_indices, log=True): 138 | """Membership loss""" 139 | sim = outputs['membership'] 140 | 141 | idx = self._get_src_permutation_idx(group_indices) 142 | 143 | # Binary cross entropy loss 144 | flatten_targets = [u for t in targets for u in t["members"]] 145 | 146 | # target_members_o = torch.cat([t[J] for t, (_, J) in zip(flatten_targets, group_indices)]).type(torch.FloatTensor).to(sim.device) 147 | target_members_o = torch.cat([t[J] for t, (_, J) in zip(flatten_targets, group_indices)]) 148 | target_members = torch.full(sim.shape, 0.0, dtype=torch.float, device=sim.device) 149 | target_members[idx] = target_members_o 150 | 151 | loss_membership = 0.0 152 | for batch_idx in range(sim.shape[0]): 153 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze() 154 | non_dummy_idx = dummy_idx.nonzero(as_tuple=True) 155 | sim_batch = sim[batch_idx].transpose(0, 1)[non_dummy_idx].transpose(0, 1).unsqueeze(0) 156 | 157 | target_members_batch = target_members[batch_idx].transpose(0, 1)[non_dummy_idx].transpose(0, 1).unsqueeze(0) 158 | 159 | loss_membership += F.binary_cross_entropy(sim_batch, target_members_batch) 160 | loss_membership /= sim.shape[0] 161 | 162 | losses = {'loss_group_code': loss_membership} 163 | return losses 164 | 165 | def loss_group_consistency(self, outputs, targets, group_indices): 166 | """Group consistency loss""" 167 | actor_embeds = outputs['actor_embeddings'] 168 | 169 | consistency_loss = 0.0 170 | 171 | for batch_idx in range(actor_embeds.shape[0]): 172 | membership = targets[batch_idx]["membership"][0] 173 | actor_embed = actor_embeds[batch_idx] # [n, f] 174 | 175 | cos = nn.CosineSimilarity(dim=-1) 176 | sim = cos(actor_embed.unsqueeze(1), actor_embed.unsqueeze(0)) / self.temperature 177 | 178 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze() 179 | non_dummy_idx = dummy_idx.nonzero(as_tuple=True) 180 | 181 | N = len(non_dummy_idx[0]) 182 | 183 | non_dummy_membership = membership[non_dummy_idx] 184 | 185 | group_count = 0 186 | 187 | for actor_idx in range(N): 188 | group_id = non_dummy_membership[actor_idx] 189 | 190 | if group_id != -1: 191 | positive_idx = (non_dummy_membership == 
group_id).nonzero(as_tuple=True) 192 | positive_idx = list(positive_idx[0]) 193 | positive_idx.remove(actor_idx) 194 | positive_idx = [tuple(positive_idx)] 195 | positive_samples = sim[actor_idx][positive_idx] 196 | 197 | negative_idx = (non_dummy_membership != group_id).nonzero(as_tuple=True) 198 | negative_samples = sim[actor_idx][negative_idx] 199 | 200 | nominator = torch.exp(positive_samples) 201 | denominator = torch.exp(torch.cat((positive_samples, negative_samples))) 202 | loss_partial = -torch.log(torch.sum(nominator) / torch.sum(denominator)) 203 | group_count += 1 204 | 205 | consistency_loss += loss_partial 206 | 207 | consistency_loss /= group_count 208 | 209 | consistency_loss /= actor_embeds.shape[0] 210 | losses = {'loss_consistency': consistency_loss} 211 | return losses 212 | 213 | def _get_src_permutation_idx(self, indices): 214 | # permute predictions following indices 215 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 216 | src_idx = torch.cat([src for (src, _) in indices]) 217 | return batch_idx, src_idx 218 | 219 | def _get_tgt_permutation_idx(self, indices): 220 | # permute targets following indices 221 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 222 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 223 | return batch_idx, tgt_idx 224 | 225 | # ***************************************************************************** 226 | # >>> DETR Losses 227 | def get_loss(self, loss, outputs, targets, num_boxes, **kwargs): 228 | loss_map = { 229 | 'labels': self.loss_labels, 230 | 'cardinality': self.loss_cardinality, 231 | } 232 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 233 | return loss_map[loss](outputs, targets, num_boxes, **kwargs) 234 | 235 | # >>> Group Losses 236 | def get_group_loss(self, loss, outputs, targets, group_indices, **kwargs): 237 | loss_map = { 238 | 'group_labels': self.loss_group_labels, 239 | 'group_cardinality': self.loss_group_cardinality, 240 | 'group_code': self.loss_group_code, 241 | 'group_consistency': self.loss_group_consistency, 242 | } 243 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 244 | return loss_map[loss](outputs, targets, group_indices, **kwargs) 245 | 246 | # ***************************************************************************** 247 | 248 | def forward(self, outputs, targets, log=True): 249 | """ This performs the loss computation. 250 | Parameters: 251 | outputs: dict of tensors, see the output specification of the model for the format 252 | targets: list of dicts, such that len(targets) == batch_size. 
253 | The expected keys in each dict depends on the losses applied, see each loss' doc 254 | """ 255 | outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'} 256 | 257 | num_boxes = sum(len(u) for t in targets for u in t["actions"]) 258 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) 259 | if is_dist_avail_and_initialized(): 260 | torch.distributed.all_reduce(num_boxes) 261 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 262 | 263 | sim = outputs['membership'] 264 | bs, num_queries, num_clip_boxes = sim.shape 265 | 266 | for tgt in targets: 267 | tgt["dummy_idx"] = torch.ones_like(tgt["actions"], dtype=int) 268 | for box_idx in range(num_clip_boxes): 269 | if bool(tgt["actions"][0, box_idx] == self.num_classes + 1): 270 | tgt["dummy_idx"][0, box_idx] = 0 271 | 272 | input_targets = [copy.deepcopy(target) for target in targets] 273 | group_indices = self.group_matcher(outputs_without_aux, input_targets) 274 | 275 | # Compute all the requested losses 276 | losses = {} 277 | for loss in self.losses: 278 | losses.update(self.get_loss(loss, outputs, targets, num_boxes)) 279 | 280 | # Group activity detection losses 281 | for loss in self.group_losses: 282 | losses.update(self.get_group_loss(loss, outputs, targets, group_indices)) 283 | 284 | return losses 285 | -------------------------------------------------------------------------------- /models/feed_forward.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR) 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | import torch.nn.functional as F 6 | from torch import nn 7 | 8 | 9 | class MLP(nn.Module): 10 | """ Very simple multi-layer perceptron (also called FFN)""" 11 | 12 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 13 | super().__init__() 14 | self.num_layers = num_layers 15 | h = [hidden_dim] * (num_layers - 1) 16 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 17 | 18 | def forward(self, x): 19 | for i, layer in enumerate(self.layers): 20 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 21 | return x -------------------------------------------------------------------------------- /models/group_matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Modules to compute the matching cost and solve the corresponding LSAP. 7 | """ 8 | import torch 9 | from scipy.optimize import linear_sum_assignment 10 | from torch import nn 11 | import torch.nn.functional as F 12 | 13 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 14 | 15 | 16 | class HungarianMatcher(nn.Module): 17 | """This class computes an assignment between the targets and the predictions of the network 18 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 19 | there are more predictions than targets. 
In this case, we do a 1-to-1 matching of the best predictions, 20 | while the others are un-matched (and thus treated as non-objects). 21 | """ 22 | 23 | def __init__(self, cost_class: float = 1, cost_code: float = 1): 24 | """Creates the matcher 25 | Params: 26 | cost_class: This is the relative weight of the classification error in the matching cost 27 | cost_code: This is the relative weight of the L2 error of the membership code in the matching cost 28 | """ 29 | super().__init__() 30 | self.cost_class = cost_class 31 | self.cost_code = cost_code 32 | assert cost_class != 0 or cost_code != 0, "all costs cant be 0" 33 | 34 | # membership cost 35 | def _get_cost_code(self, sim, targets_membership): 36 | cost_code = torch.cdist(sim, targets_membership, p=2) 37 | return cost_code 38 | 39 | @torch.no_grad() 40 | def forward(self, outputs, targets): 41 | bs, num_queries, num_boxes = outputs["membership"].shape 42 | 43 | # We flatten to compute the cost matrices in a batch 44 | out_prob = outputs["pred_activities"].flatten(0, 1).softmax(-1) # [bs * t * num_queries, num_classes] 45 | 46 | sim = outputs["membership"] 47 | 48 | if "dummy_idx" in targets[0].keys(): 49 | sim_dummy = [] 50 | for batch_idx in range(bs): 51 | dummy_idx = targets[batch_idx]["dummy_idx"].squeeze() 52 | sim_batch = sim[batch_idx] * dummy_idx 53 | sim_dummy.append(sim_batch.unsqueeze(0)) 54 | sim = torch.cat(sim_dummy, dim=0) 55 | 56 | sim = sim.flatten(0, 1) 57 | 58 | # Also concat the target labels and boxes 59 | tgt_ids = torch.cat([v["activities"] for v in targets], dim=1).reshape(-1) 60 | targets_membership = torch.cat([v["members"] for v in targets], dim=1).reshape(-1, num_boxes) 61 | 62 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 63 | # but approximate it in 1 - proba[target class]. 64 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 65 | cost_class = -out_prob[:, tgt_ids] 66 | # If cost code is similarity, it should be minus sign 67 | cost_code = self._get_cost_code(sim, targets_membership) 68 | 69 | # Final cost matrix 70 | C = self.cost_class * cost_class + self.cost_code * cost_code 71 | C = C.view(bs, num_queries, -1).cpu() 72 | 73 | sizes = [len(k) for v in targets for k in v["activities"]] 74 | 75 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 76 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 77 | 78 | 79 | def build_group_matcher(args): 80 | return HungarianMatcher(cost_class=args.set_cost_group_class, cost_code=args.set_cost_membership) 81 | -------------------------------------------------------------------------------- /models/group_transformer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR) 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | import copy 9 | from typing import Optional 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn, Tensor 14 | 15 | 16 | class Transformer(nn.Module): 17 | 18 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, num_actor=14, 19 | dim_feedforward=2048, dropout=0.1, activation="relu", return_intermediate_dec=False): 20 | super().__init__() 21 | 22 | decoder_layer = TransformerDecoderLayer(d_model, nhead, num_actor, dim_feedforward, dropout, activation) 23 | decoder_norm = nn.LayerNorm(d_model) 24 | self.decoder = TransformerDecoder(decoder_layer, num_encoder_layers, decoder_norm, 25 | return_intermediate=return_intermediate_dec) 26 | 27 | self._reset_parameters() 28 | self.d_model = d_model 29 | self.nhead = nhead 30 | 31 | def _reset_parameters(self): 32 | for p in self.parameters(): 33 | if p.dim() > 1: 34 | nn.init.xavier_uniform_(p) 35 | 36 | def forward(self, src, actor_mask, group_dummy_mask, group_embed, pos_embed, actor_embed): 37 | bs, t, c, h, w = src.shape 38 | _, _, n, _ = actor_embed.shape 39 | src = src.reshape(bs * t, c, h, w).flatten(2).permute(2, 0, 1) # [h x w, bs x t, c] 40 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 41 | group_embed = group_embed.reshape(-1, t, c) 42 | group_embed = group_embed.unsqueeze(1).repeat(1, bs, 1, 1).reshape(-1, bs * t, c) 43 | actor_embed = actor_embed.reshape(bs * t, -1, c).permute(1, 0, 2) # [n, bs x t, c] 44 | query_embed = torch.cat([actor_embed, group_embed], dim=0) # [n + k, bs x t, c] 45 | tgt = torch.zeros_like(query_embed) # [n + k, bs x t, c] 46 | if actor_mask is not None: 47 | actor_mask = actor_mask.unsqueeze(1).repeat(1, self.nhead, 1, 1).reshape(-1, n, n) 48 | hs, actor_att, feature_att = self.decoder(tgt, src, attn_mask=actor_mask, 49 | tgt_key_padding_mask=group_dummy_mask, 50 | pos=pos_embed, query_pos=query_embed) 51 | 52 | return hs.transpose(1, 2), actor_att, feature_att 53 | 54 | 55 | class TransformerDecoder(nn.Module): 56 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 57 | super().__init__() 58 | self.layers = _get_clones(decoder_layer, num_layers) 59 | self.num_layers = num_layers 60 | self.norm = norm 61 | self.return_intermediate = return_intermediate 62 | 63 | def forward(self, tgt, memory, 64 | tgt_mask: Optional[Tensor] = None, 65 | memory_mask: Optional[Tensor] = None, 66 | tgt_key_padding_mask: Optional[Tensor] = None, 67 | memory_key_padding_mask: Optional[Tensor] = None, 68 | attn_mask: Optional[Tensor] = None, 69 | pos: Optional[Tensor] = None, 70 | query_pos: Optional[Tensor] = None): 71 | output = tgt 72 | actor_att = None 73 | feature_att = None 74 | 75 | intermediate = [] 76 | intermediate_actor_att = [] 77 | intermediate_feature_att = [] 78 | 79 | for layer in self.layers: 80 | output, actor_att, feature_att = layer(output, memory, tgt_mask=tgt_mask, 81 | memory_mask=memory_mask, 82 | tgt_key_padding_mask=tgt_key_padding_mask, 83 | memory_key_padding_mask=memory_key_padding_mask, 84 | attn_mask=attn_mask, 85 | pos=pos, query_pos=query_pos) 86 | if self.return_intermediate: 87 | intermediate.append(self.norm(output)) 88 | intermediate_actor_att.append(actor_att) 89 | intermediate_feature_att.append(feature_att) 90 | 91 | if self.norm is not None: 92 | output = self.norm(output) 93 | if self.return_intermediate: 94 | intermediate.pop() 95 | intermediate.append(output) 96 | 97 | if self.return_intermediate: 98 | 
return torch.stack(intermediate), torch.stack(intermediate_actor_att), torch.stack(intermediate_feature_att) 99 | 100 | if actor_att is not None: 101 | actor_att = actor_att.unsqueeze(0) 102 | if feature_att is not None: 103 | feature_att = feature_att.unsqueeze(0) 104 | 105 | return output.unsqueeze(0), actor_att, feature_att 106 | 107 | 108 | class TransformerDecoderLayer(nn.Module): 109 | 110 | def __init__(self, d_model, nhead, num_actor, dim_feedforward=2048, dropout=0.1, activation="relu"): 111 | super().__init__() 112 | 113 | self.multihead_attn1 = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 114 | 115 | self.self_attn1 = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 116 | self.self_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 117 | self.multihead_attn2 = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 118 | 119 | # Implementation of Feedforward model 120 | self.linear1 = nn.Linear(d_model, dim_feedforward) 121 | self.dropout = nn.Dropout(dropout) 122 | self.linear2 = nn.Linear(dim_feedforward, d_model) 123 | 124 | self.norm1 = nn.LayerNorm(d_model) 125 | self.norm2 = nn.LayerNorm(d_model) 126 | self.norm3 = nn.LayerNorm(d_model) 127 | self.norm4 = nn.LayerNorm(d_model) 128 | self.dropout1 = nn.Dropout(dropout) 129 | self.dropout2 = nn.Dropout(dropout) 130 | self.dropout3 = nn.Dropout(dropout) 131 | self.dropout4 = nn.Dropout(dropout) 132 | 133 | self.activation = _get_activation_fn(activation) 134 | 135 | self.num_actor = num_actor 136 | 137 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 138 | return tensor if pos is None else tensor + pos 139 | 140 | def forward_post(self, tgt, memory, 141 | tgt_mask: Optional[Tensor] = None, 142 | memory_mask: Optional[Tensor] = None, 143 | tgt_key_padding_mask: Optional[Tensor] = None, 144 | memory_key_padding_mask: Optional[Tensor] = None, 145 | attn_mask=None, 146 | pos: Optional[Tensor] = None, 147 | query_pos: Optional[Tensor] = None): 148 | 149 | q = k = self.with_pos_embed(tgt, query_pos) 150 | 151 | actor_q = q[:self.num_actor, :, :] 152 | group_q = q[self.num_actor:, :, :] 153 | actor_k = k[:self.num_actor, :, :] 154 | group_k = k[self.num_actor:, :, :] 155 | 156 | tgt_actor = tgt[:self.num_actor, :, :] 157 | tgt_group = tgt[self.num_actor:, :, :] 158 | 159 | # actor-actor, group-group self-attention 160 | tgt2_actor, actor_att = self.self_attn1(actor_q, actor_k, value=tgt_actor, attn_mask=attn_mask) 161 | tgt2_group, _ = self.self_attn2(group_q, group_k, value=tgt_group) 162 | tgt2 = torch.cat([tgt2_actor, tgt2_group], dim=0) 163 | 164 | tgt = tgt + self.dropout1(tgt2) 165 | tgt = self.norm1(tgt) 166 | 167 | # actor-group attention 168 | tgt_actor = tgt[:self.num_actor, :, :] 169 | tgt_group = tgt[self.num_actor:, :, :] 170 | 171 | tgt2_group, group_att = self.multihead_attn1(query=tgt_group, key=tgt_actor, value=tgt_actor, 172 | key_padding_mask=tgt_key_padding_mask) 173 | 174 | tgt2 = torch.cat([tgt_actor, tgt2_group], dim=0) 175 | 176 | tgt = tgt + self.dropout2(tgt2) 177 | tgt = self.norm2(tgt) 178 | 179 | # actor-feature, group-feature cross-attention 180 | tgt2, feature_att = self.multihead_attn2(query=self.with_pos_embed(tgt, query_pos), 181 | key=self.with_pos_embed(memory, pos), value=memory) 182 | tgt = tgt + self.dropout3(tgt2) 183 | tgt = self.norm3(tgt) 184 | 185 | # FFN 186 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 187 | tgt = tgt + self.dropout4(tgt2) 188 | tgt = self.norm4(tgt) 189 | 190 | return tgt, group_att, feature_att 191 | 192 | 
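# Descriptive note on the decoder layer above: forward() below simply dispatches to
# forward_post(), which runs four stages:
#   (1) separate self-attention over actor queries (self_attn1, using the distance/dummy
#       attention mask) and over group queries (self_attn2),
#   (2) group-to-actor attention (multihead_attn1), where group tokens aggregate actor
#       features and padded dummy actors are excluded via tgt_key_padding_mask,
#   (3) cross-attention of all queries to the backbone feature map (multihead_attn2),
#   (4) a position-wise feed-forward network,
# with dropout and LayerNorm applied after each stage.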
def forward(self, tgt, memory, 193 | tgt_mask: Optional[Tensor] = None, 194 | memory_mask: Optional[Tensor] = None, 195 | tgt_key_padding_mask: Optional[Tensor] = None, 196 | memory_key_padding_mask: Optional[Tensor] = None, 197 | attn_mask: Optional[Tensor] = None, 198 | pos: Optional[Tensor] = None, 199 | query_pos: Optional[Tensor] = None): 200 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, tgt_key_padding_mask, 201 | memory_key_padding_mask, attn_mask, pos, query_pos) 202 | 203 | 204 | def _get_clones(module, N): 205 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 206 | 207 | 208 | def build_group_transformer(args): 209 | return Transformer( 210 | d_model=args.hidden_dim, 211 | dropout=args.drop_rate, 212 | nhead=args.gar_nheads, 213 | dim_feedforward=args.gar_ffn_dim, 214 | num_encoder_layers=args.gar_enc_layers, 215 | return_intermediate_dec=False, 216 | num_actor=args.num_boxes, 217 | ) 218 | 219 | 220 | def _get_activation_fn(activation): 221 | """Return an activation function given a string""" 222 | if activation == "relu": 223 | return F.relu 224 | if activation == "gelu": 225 | return F.gelu 226 | if activation == "glu": 227 | return F.glu 228 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 229 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from roi_align.roi_align import RoIAlign 6 | 7 | from .backbone import build_backbone 8 | from .group_transformer import build_group_transformer 9 | from .feed_forward import MLP 10 | 11 | 12 | class GADTR(nn.Module): 13 | def __init__(self, args): 14 | super(GADTR, self).__init__() 15 | 16 | self.dataset = args.dataset 17 | self.num_class = args.num_class 18 | self.num_frame = args.num_frame 19 | self.num_boxes = args.num_boxes 20 | 21 | self.hidden_dim = args.hidden_dim 22 | self.backbone = build_backbone(args) 23 | 24 | # RoI Align 25 | self.crop_size = args.crop_size 26 | self.roi_align = RoIAlign(crop_height=self.crop_size, crop_width=self.crop_size) 27 | self.fc_emb = nn.Linear(self.crop_size*self.crop_size*self.backbone.num_channels, self.hidden_dim) 28 | self.drop_emb = nn.Dropout(p=args.drop_rate) 29 | 30 | # Actor embedding 31 | self.input_proj = nn.Conv2d(self.backbone.num_channels, self.hidden_dim, kernel_size=1) 32 | self.box_pos_emb = MLP(4, self.hidden_dim, self.hidden_dim, 3) 33 | 34 | # Individual action classification head 35 | self.class_emb = nn.Linear(self.hidden_dim, self.num_class + 1) 36 | 37 | # Group Transformer 38 | self.group_transformer = build_group_transformer(args) 39 | self.num_group_tokens = args.num_group_tokens 40 | self.group_query_emb = nn.Embedding(self.num_group_tokens * self.num_frame, self.hidden_dim) 41 | 42 | # Group activity classfication head 43 | self.group_emb = nn.Linear(self.hidden_dim, self.num_class + 1) 44 | 45 | # Distance mask threshold 46 | self.distance_threshold = args.distance_threshold 47 | 48 | # Membership prediction heads 49 | self.actor_match_emb = nn.Linear(self.hidden_dim, self.hidden_dim) 50 | self.group_match_emb = nn.Linear(self.hidden_dim, self.hidden_dim) 51 | 52 | self.relu = F.relu 53 | 54 | for name, m in self.named_modules(): 55 | if 'backbone' not in name and 'group_transformer' not in name: 56 | if isinstance(m, nn.Linear): 57 | nn.init.kaiming_normal_(m.weight) 58 | if m.bias is not 
None: 59 | nn.init.zeros_(m.bias) 60 | 61 | def calculate_pairwise_distnace(self, boxes): 62 | bs = boxes.shape[0] 63 | 64 | rx = boxes.pow(2).sum(dim=2).reshape((bs, -1, 1)) 65 | ry = boxes.pow(2).sum(dim=2).reshape((bs, -1, 1)) 66 | 67 | dist = rx - 2.0 * boxes.matmul(boxes.transpose(1, 2)) + ry.transpose(1, 2) 68 | 69 | return torch.sqrt(dist) 70 | 71 | def forward(self, x, boxes, dummy_mask): 72 | """ 73 | :param x: [B, T, 3, H, W] 74 | :param boxes: [B, T, N, 4] 75 | :param dummy_mask: [B, N] 76 | :return: 77 | """ 78 | bs, t, _, h, w = x.shape 79 | n = boxes.shape[2] 80 | 81 | boxes = torch.reshape(boxes, (-1, 4)) # [b x t x n, 4] 82 | boxes_flat = boxes.clone().detach() 83 | boxes_idx = [i * torch.ones(n, dtype=torch.int) for i in range(bs * t)] 84 | boxes_idx = torch.stack(boxes_idx).to(device=boxes.device) 85 | boxes_idx_flat = torch.reshape(boxes_idx, (bs * t * n, )) # [b x t x n] 86 | 87 | features, pos = self.backbone(x) 88 | _, c, oh, ow = features.shape # [b x t, d, oh, ow] 89 | 90 | src = self.input_proj(features) 91 | src = torch.reshape(src, (bs, t, -1, oh, ow)) # [b, t, c, oh, ow] 92 | 93 | # calculate distance & distance mask 94 | boxes_center = boxes.clone().detach() 95 | boxes_center = torch.reshape(boxes_center[:, :2], (-1, n, 2)) 96 | boxes_distance = self.calculate_pairwise_distnace(boxes_center) 97 | 98 | distance_mask = (boxes_distance > self.distance_threshold) 99 | 100 | # ignore dummy boxes (padded boxes to match the number of actors) 101 | dummy_mask = dummy_mask.unsqueeze(1).repeat(1, t, 1).reshape(-1, n) 102 | actor_dummy_mask = (~dummy_mask.unsqueeze(2)).float() @ (~dummy_mask.unsqueeze(1)).float() 103 | dummy_diag = (dummy_mask.unsqueeze(2).float() @ dummy_mask.unsqueeze(1).float()).nonzero(as_tuple=True) 104 | actor_mask = ~(actor_dummy_mask.bool()) 105 | actor_mask[dummy_diag] = False 106 | actor_mask = distance_mask + actor_mask 107 | group_dummy_mask = dummy_mask 108 | 109 | boxes_flat[:, 0] = (boxes[:, 0] - boxes[:, 2] / 2) * ow 110 | boxes_flat[:, 1] = (boxes[:, 1] - boxes[:, 3] / 2) * oh 111 | boxes_flat[:, 2] = (boxes[:, 0] + boxes[:, 2] / 2) * ow 112 | boxes_flat[:, 3] = (boxes[:, 1] + boxes[:, 3] / 2) * oh 113 | 114 | boxes_flat.requires_grad = False 115 | boxes_idx_flat.requires_grad = False 116 | 117 | # extract actor features 118 | actor_features = self.roi_align(features, boxes_flat, boxes_idx_flat) 119 | actor_features = torch.reshape(actor_features, (bs * t * n, -1)) 120 | actor_features = self.fc_emb(actor_features) 121 | actor_features = F.relu(actor_features) 122 | actor_features = self.drop_emb(actor_features) 123 | actor_features = actor_features.reshape(bs, t, n, self.hidden_dim) 124 | 125 | # add positional information to box features 126 | box_pos_emb = self.box_pos_emb(boxes) 127 | box_pos_emb = torch.reshape(box_pos_emb, (bs, t, n, -1)) # [b, t, n, c] 128 | actor_features = actor_features + box_pos_emb 129 | 130 | # group transformer 131 | hs, actor_att, feature_att = self.group_transformer(src, actor_mask, group_dummy_mask, 132 | self.group_query_emb.weight, pos, actor_features) 133 | # [1, bs * t, n + k, f'], [1, bs * t, k, n], [1, bs * t, n + k, oh x ow] M: # group tokens, K: # boxes 134 | 135 | actor_hs = hs[0, :, :n] 136 | group_hs = hs[0, :, n:] 137 | 138 | actor_hs = actor_hs.reshape(bs, t, n, -1) 139 | actor_hs = actor_features + actor_hs 140 | 141 | # normalize 142 | inst_repr = F.normalize(actor_hs.reshape(bs, t, n, -1).mean(dim=1), p=2, dim=2) 143 | group_repr = F.normalize(group_hs.reshape(bs, t, self.num_group_tokens, 
-1).mean(dim=1), p=2, dim=2) 144 | 145 | # prediction heads 146 | outputs_class = self.class_emb(actor_hs) 147 | 148 | outputs_group_class = self.group_emb(group_hs) 149 | 150 | outputs_actor_emb = self.actor_match_emb(inst_repr) 151 | outputs_group_emb = self.group_match_emb(group_repr) 152 | 153 | membership = torch.bmm(outputs_group_emb, outputs_actor_emb.transpose(1, 2)) 154 | membership = F.softmax(membership, dim=1) 155 | 156 | out = { 157 | "pred_actions": outputs_class.reshape(bs, t, self.num_boxes, self.num_class + 1).mean(dim=1), 158 | "pred_activities": outputs_group_class.reshape(bs, t, self.num_group_tokens, self.num_class + 1).mean(dim=1), 159 | "membership": membership.reshape(bs, self.num_group_tokens, self.num_boxes), 160 | "actor_embeddings": F.normalize(actor_hs.reshape(bs, t, n, -1).mean(dim=1), p=2, dim=2), 161 | } 162 | 163 | return out 164 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | 6 | import math 7 | import torch 8 | from torch import nn 9 | 10 | 11 | class PositionEmbeddingSine(nn.Module): 12 | """ 13 | This is a more standard version of the position embedding, very similar to the one 14 | used by the Attention is all you need paper, generalized to work on images. 15 | """ 16 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 17 | super().__init__() 18 | self.num_pos_feats = num_pos_feats 19 | self.temperature = temperature 20 | self.normalize = normalize 21 | if scale is not None and normalize is False: 22 | raise ValueError("normalize should be True if scale is passed") 23 | if scale is None: 24 | scale = 2 * math.pi 25 | self.scale = scale 26 | 27 | def forward(self, x): 28 | bs, c, h, w = x.shape 29 | 30 | y_embed = torch.arange(1, h + 1, device=x.device).unsqueeze(0).unsqueeze(2) 31 | y_embed = y_embed.repeat(bs, 1, w) 32 | x_embed = torch.arange(1, w + 1, device=x.device).unsqueeze(0).unsqueeze(1) 33 | x_embed = x_embed.repeat(bs, h, 1) 34 | 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 48 | return pos 49 | 50 | 51 | class PositionEmbeddingLearned(nn.Module): 52 | """ 53 | Absolute pos embedding, learned. 
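Row and column indices are embedded separately with two learned embedding tables
(50 positions each) and concatenated along the channel dimension.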
54 | """ 55 | def __init__(self, num_pos_feats=256): 56 | super().__init__() 57 | self.row_embed = nn.Embedding(50, num_pos_feats) 58 | self.col_embed = nn.Embedding(50, num_pos_feats) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | nn.init.uniform_(self.row_embed.weight) 63 | nn.init.uniform_(self.col_embed.weight) 64 | 65 | def forward(self, x): 66 | h, w = x.shape[-2:] 67 | i = torch.arange(w, device=x.device) 68 | j = torch.arange(h, device=x.device) 69 | x_emb = self.col_embed(i) 70 | y_emb = self.row_embed(j) 71 | pos = torch.cat([ 72 | x_emb.unsqueeze(0).repeat(h, 1, 1), 73 | y_emb.unsqueeze(1).repeat(1, w, 1), 74 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 75 | return pos 76 | 77 | 78 | def build_position_encoding(args): 79 | N_steps = args.hidden_dim // 2 80 | if args.position_embedding in ('v2', 'sine'): 81 | # TODO find a better way of exposing other arguments 82 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 83 | elif args.position_embedding in ('v3', 'learned'): 84 | position_embedding = PositionEmbeddingLearned(N_steps) 85 | else: 86 | raise ValueError(f"not supported {args.position_embedding}") 87 | 88 | return position_embedding 89 | 90 | 91 | def build_index_encoding(args): 92 | N_steps = args.hidden_dim // 2 93 | if args.index_embedding in ('v2', 'sine'): 94 | # TODO find a better way of exposing other arguments 95 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 96 | elif args.index_embedding in ('v3', 'learned'): 97 | position_embedding = PositionEmbeddingLearned(N_steps) 98 | else: 99 | raise ValueError(f"not supported {args.index_embedding}") 100 | 101 | return position_embedding 102 | -------------------------------------------------------------------------------- /readme.md: -------------------------------------------------------------------------------- 1 | # Towards More Practical Group Activity Detection:
A New Benchmark and Model 2 | 3 | ### [Dongkeun Kim](https://dk-kim.github.io/), [Youngkil Song](https://www.linkedin.com/in/youngkil-song-8936792a3/), [Minsu Cho](https://cvlab.postech.ac.kr/~mcho/), [Suha Kwak](https://suhakwak.github.io/) 4 | 5 | ### [Project Page](http://dk-kim.github.io/CAFE) | [Paper](https://arxiv.org/abs/2312.02878) 6 | 7 | ## Overview 8 | This work introduces Café, a new benchmark dataset for group activity detection (GAD), together with a new GAD model. 9 | 10 | ## Requirements 11 | 12 | - Ubuntu 20.04 13 | - Python 3.8.5 14 | - CUDA 11.0 15 | - PyTorch 1.7.1 16 | 17 | ## Conda environment installation 18 | conda env create --file environment.yml 19 | 20 | conda activate gad 21 | 22 | pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 -f https://download.pytorch.org/whl/torch_stable.html 23 | 24 | ## Install additional package 25 | sh scripts/setup.sh 26 | 27 | ## Download datasets 28 | 29 | Download the Café dataset either by running the provided download script:
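sh scripts/download_datasets.sh

or download it manually from: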
30 | 31 | https://cvlab.postech.ac.kr/research/CAFE/ 32 | 33 | ## Download trained weights 34 | 35 | sh scripts/download_checkpoints.sh 36 | 37 | or from:
38 | https://drive.google.com/file/d/1W_2gkzARCzSdK8Db4G4pkzN3GrJTYo8R/view?usp=drive_link 39 | 40 | ## Run test scripts 41 | 42 | - Café dataset (split by view) 43 | 44 | sh scripts/test_cafe_view.sh 45 | 46 | - Café dataset (split by place) 47 | 48 | 49 | sh scripts/test_cafe_place.sh 50 | 51 | ## Run train scripts 52 | 53 | - Café dataset (split by view) 54 | 55 | 56 | sh scripts/train_cafe_view.sh 57 | 58 | - Café dataset (split by place) 59 | 60 | 61 | sh scripts/train_cafe_place.sh 62 | 63 | 64 | ## File structure 65 | 66 | ├── Dataset/ 67 | │ └── cafe/ 68 | │ └── gt_tracks.pkl 69 | ├── dataloader/ 70 | ├── evaluation/ 71 | │ └── gt_tracks.txt 72 | ├── label_map/ 73 | ├── models/ 74 | ├── scripts/ 75 | └── util/ 76 | train.py 77 | test.py 78 | environment.yml 79 | README.md 80 | 81 | ## Citation 82 | If you find our work useful, please consider citing our paper: 83 | ```BibTeX 84 | @article{kim2023towards, 85 | title={Towards More Practical Group Activity Detection: A New Benchmark and Model}, 86 | author={Kim, Dongkeun and Song, Youngkil and Cho, Minsu and Kwak, Suha}, 87 | journal={arXiv preprint arXiv:2312.02878}, 88 | year={2023} 89 | } 90 | ``` 91 | 92 | ## Acknowledgement 93 | This work was supported by the NRF grant and the IITP grant funded by the Ministry of Science and ICT, Korea (RS-2019-II191906, IITP-2020-0-00842, NRF-2021R1A2C3012728, RS-2022-II220264). 94 | -------------------------------------------------------------------------------- /scripts/download_checkpoints.sh: -------------------------------------------------------------------------------- 1 | wget --load-cookies ~/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies ~/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/file/d/1M0vmruhcU_SpYQL0SPqUcg-wKlfyOyWa/view?usp=drive_link' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1M0vmruhcU_SpYQL0SPqUcg-wKlfyOyWa" -O cafe_checkpoints.zip && rm -rf ~/cookies.txt 2 | unzip cafe_checkpoints.zip 3 | rm cafe_checkpoints.zip -------------------------------------------------------------------------------- /scripts/download_datasets.sh: -------------------------------------------------------------------------------- 1 | wget https://cvlab.postech.ac.kr/research/CAFE/Cafe_Dataset.zip 2 | unzip Cafe_Dataset.zip 3 | rm Cafe_Dataset.zip 4 | -------------------------------------------------------------------------------- /scripts/setup.sh: -------------------------------------------------------------------------------- 1 | wget https://github.com/longcw/RoIAlign.pytorch/archive/refs/heads/master.zip 2 | unzip master.zip 3 | rm master.zip 4 | cd RoIAlign.pytorch-master/ 5 | python setup.py install 6 | mv roi_align/ roi_align_torch/ 7 | cd ../ -------------------------------------------------------------------------------- /scripts/test_cafe_place.sh: -------------------------------------------------------------------------------- 1 | python test.py --data_path Dataset/ --split 'place' --model_path cafe_place.pth -------------------------------------------------------------------------------- /scripts/test_cafe_view.sh: -------------------------------------------------------------------------------- 1 | python test.py --data_path Dataset/ --split 'view' --model_path cafe_view.pth --random_seed 11 2 | -------------------------------------------------------------------------------- /scripts/train_cafe_place.sh: -------------------------------------------------------------------------------- 1 | python train.py
--data_path Dataset/ --split 'place' -------------------------------------------------------------------------------- /scripts/train_cafe_view.sh: -------------------------------------------------------------------------------- 1 | python train.py --data_path Dataset/ --split 'view' -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.data as data 5 | 6 | import os 7 | import math 8 | import sys 9 | import copy 10 | import time 11 | import random 12 | import numpy as np 13 | import argparse 14 | 15 | from models import build_model 16 | from util.utils import * 17 | import util.misc as utils 18 | import util.logger as loggers 19 | from dataloader.dataloader import read_dataset 20 | import evaluation.cafe_eval as evaluation 21 | 22 | parser = argparse.ArgumentParser(description='Group Activity Detection train code', add_help=False) 23 | 24 | # Dataset specification 25 | parser.add_argument('--dataset', default='cafe', type=str, help='dataset name') 26 | parser.add_argument('--val_mode', action='store_true') 27 | parser.add_argument('--split', default='place', type=str, help='dataset split. place or view') 28 | parser.add_argument('--data_path', default='../Dataset/', type=str, help='data path') 29 | parser.add_argument('--image_width', default=1280, type=int, help='Image width to resize') 30 | parser.add_argument('--image_height', default=720, type=int, help='Image height to resize') 31 | parser.add_argument('--random_sampling', action='store_true', help='random sampling strategy') 32 | parser.add_argument('--num_frame', default=5, type=int, help='number of frames for each clip') 33 | parser.add_argument('--num_class', default=6, type=int, help='number of activity classes') 34 | 35 | # Backbone parameters 36 | parser.add_argument('--backbone', default='resnet18', type=str, help='feature extraction backbone') 37 | parser.add_argument('--dilation', action='store_true', help='use dilation or not') 38 | parser.add_argument('--frozen_batch_norm', action='store_true', help='use frozen batch normalization') 39 | parser.add_argument('--hidden_dim', default=256, type=int, help='transformer channel dimension') 40 | 41 | # RoI Align parameters 42 | parser.add_argument('--num_boxes', default=14, type=int, help='maximum number of actors') 43 | parser.add_argument('--crop_size', default=5, type=int, help='roi align crop size') 44 | 45 | # Group Transformer 46 | parser.add_argument('--gar_nheads', default=4, type=int, help='number of heads') 47 | parser.add_argument('--gar_enc_layers', default=6, type=int, help='number of group transformer layers') 48 | parser.add_argument('--gar_ffn_dim', default=512, type=int, help='feed forward network dimension') 49 | parser.add_argument('--position_embedding', default='sine', type=str, help='various position encoding') 50 | parser.add_argument('--num_group_tokens', default=12, type=int, help='number of group tokens') 51 | parser.add_argument('--aux_loss', action='store_true') 52 | parser.add_argument('--group_threshold', default=0.5, type=float, help='post processing threshold') 53 | parser.add_argument('--distance_threshold', default=0.2, type=float, help='distance mask threshold') 54 | 55 | # Loss option 56 | parser.add_argument('--temperature', default=0.2, type=float, help='consistency loss temperature') 57 | 58 | # Loss coefficients (Individual) 
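# The coefficients below populate the criterion's weight_dict. Even at test time the
# criterion is evaluated, and validate() logs the weighted sum of the reduced losses,
# roughly: loss = sum(loss_dict_reduced[k] * weight_dict[k] for k in loss_dict_reduced if k in weight_dict)
# (a sketch of the weighting only; the individual loss terms come from the criterion returned by build_model).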
59 | parser.add_argument('--ce_loss_coef', default=1, type=float) 60 | parser.add_argument('--eos_coef', default=1, type=float, 61 | help="Relative classification weight of the no-object class") 62 | 63 | # Loss coefficients (Group) 64 | parser.add_argument('--group_eos_coef', default=1, type=float) 65 | parser.add_argument('--group_ce_loss_coef', default=1, type=float) 66 | parser.add_argument('--group_code_loss_coef', default=5, type=float) 67 | parser.add_argument('--consistency_loss_coef', default=2, type=float) 68 | 69 | # Matcher (Group) 70 | parser.add_argument('--set_cost_group_class', default=1, type=float, 71 | help="Class coefficient in the matching cost") 72 | parser.add_argument('--set_cost_membership', default=1, type=float, 73 | help="Membership coefficient in the matching cost") 74 | 75 | # Training parameters 76 | parser.add_argument('--random_seed', default=1, type=int, help='random seed for reproduction') 77 | parser.add_argument('--batch', default=16, type=int, help='Batch size') 78 | parser.add_argument('--test_batch', default=16, type=int, help='Test batch size') 79 | parser.add_argument('--drop_rate', default=0.1, type=float, help='Dropout rate') 80 | # GPU 81 | parser.add_argument('--device', default="0, 1", type=str, help='GPU device') 82 | parser.add_argument('--distributed', action='store_true') 83 | 84 | # Load model 85 | parser.add_argument('--model_path', default="", type=str, help='pretrained model path') 86 | 87 | # Visualization 88 | parser.add_argument('--result_path', default="./outputs/") 89 | 90 | # Evaluation 91 | parser.add_argument('--groundtruth', default='./evaluation/gt_tracks.txt', type=argparse.FileType("r")) 92 | parser.add_argument('--labelmap', default='./label_map/group_action_list.pbtxt', type=argparse.FileType("r")) 93 | parser.add_argument('--giou_thresh', default=1.0, type=float) 94 | parser.add_argument('--eval_type', default="gt_base", type=str, help='gt_based or detection_based') 95 | 96 | args = parser.parse_args() 97 | path = None 98 | 99 | SEQS_CAFE = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] 100 | 101 | ACTIVITIES = ['Queueing', 'Ordering', 'Drinking', 'Working', 'Fighting', 'Selfie', 'Individual', 'No'] 102 | 103 | 104 | def main(): 105 | global args, path 106 | 107 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 108 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 109 | 110 | time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) 111 | exp_name = '[%s]_GAD_<%s>' % (args.dataset, time_str) 112 | 113 | path = args.result_path + exp_name 114 | if not os.path.exists(path): 115 | os.makedirs(path) 116 | 117 | # set random seed 118 | random.seed(args.random_seed) 119 | np.random.seed(args.random_seed) 120 | torch.manual_seed(args.random_seed) 121 | torch.cuda.manual_seed(args.random_seed) 122 | torch.cuda.manual_seed_all(args.random_seed) 123 | torch.backends.cudnn.deterministic = True 124 | torch.backends.cudnn.benchmark = False 125 | 126 | _, test_set = read_dataset(args) 127 | sampler_test = data.RandomSampler(test_set) 128 | 129 | test_loader = data.DataLoader(test_set, args.test_batch, sampler=sampler_test, drop_last=False, 130 | collate_fn=collate_fn, num_workers=4, pin_memory=True) 131 | 132 | model, criterion = build_model(args) 133 | model = torch.nn.DataParallel(model).cuda() 134 | 135 | pretrained_dict = torch.load(args.model_path)['state_dict'] 136 | new_state_dict = model.state_dict() 137 | for k, v in pretrained_dict.items(): 138 | if k in new_state_dict: 139 | 
new_state_dict.update({k:v}) 140 | 141 | model.load_state_dict(new_state_dict) 142 | 143 | metrics = evaluation.GAD_Evaluation(args) 144 | 145 | test_log, result = validate(test_loader, model, criterion, metrics) 146 | print("group mAP at 1.0: %.2f" % result['group_mAP_1.0']) 147 | print("group mAP at 0.5: %.2f" % result['group_mAP_0.5']) 148 | print("outlier mIoU: %.2f" % result['outlier_mIoU']) 149 | 150 | 151 | @torch.no_grad() 152 | def validate(test_loader, model, criterion, metrics): 153 | model.eval() 154 | criterion.eval() 155 | 156 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ") 157 | header = 'Evaluation Inference: ' 158 | 159 | print_freq = len(test_loader) 160 | name_to_vid = {name: i + 1 for i, name in enumerate(SEQS_CAFE)} 161 | file_path = path + '/pred_group_test_%s.txt' % args.split 162 | 163 | for i, (images, targets, infos) in enumerate(metric_logger.log_every(test_loader, print_freq, header)): 164 | images = images.cuda() # [B, T, 3, H, W] 165 | targets = [{k: v.cuda() for k, v in t.items()} for t in targets] 166 | 167 | boxes = torch.stack([t['boxes'] for t in targets]) 168 | dummy_mask = torch.stack([t['actions'] == args.num_class + 1 for t in targets]).squeeze() 169 | 170 | # compute output 171 | outputs = model(images, boxes, dummy_mask) 172 | 173 | loss_dict = criterion(outputs, targets) 174 | weight_dict = criterion.weight_dict 175 | 176 | # reduce losses over all GPUs for logging purposes 177 | loss_dict_reduced = utils.reduce_dict(loss_dict) 178 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 179 | for k, v in loss_dict_reduced.items() if k in weight_dict} 180 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 181 | for k, v in loss_dict_reduced.items()} 182 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), 183 | **loss_dict_reduced_scaled, 184 | **loss_dict_reduced_unscaled) 185 | 186 | metric_logger.update(group_class_error=loss_dict_reduced['group_class_error']) 187 | 188 | make_txt(boxes, infos, outputs, name_to_vid, file_path) 189 | 190 | # gather the stats from all processes 191 | metric_logger.synchronize_between_processes() 192 | print("Averaged stats:", metric_logger) 193 | 194 | detections = open(file_path, "r") 195 | result = metrics.evaluate(detections) 196 | 197 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, result 198 | 199 | 200 | def make_txt(boxes, infos, outputs, name_to_vid, file_path): 201 | for b in range(boxes.shape[0]): 202 | for t in range(boxes.shape[1]): 203 | image_w, image_h = args.image_width, args.image_height 204 | 205 | pred_group_actions = outputs['pred_activities'][b] 206 | pred_group_actions = F.softmax(pred_group_actions, dim=1) 207 | members = outputs['membership'][b] 208 | 209 | pred_membership = torch.argmax(members.transpose(0, 1), dim=1).detach().cpu() 210 | keep_membership = members.transpose(0, 1).max(-1).values > args.group_threshold 211 | pred_group_action = torch.argmax(pred_group_actions, dim=1).detach().cpu() 212 | 213 | for box_idx in range(boxes.shape[2]): 214 | x, y, w, h = boxes[b][t][box_idx] 215 | x1, y1, x2, y2 = (x - w / 2) * image_w, (y - h / 2) * image_h, (x + w / 2) * image_w, ( 216 | y + h / 2) * image_h 217 | 218 | pred_group_id = pred_membership[box_idx] 219 | pred_group_action_idx = pred_group_action[pred_group_id] 220 | pred_group_action_prob = pred_group_actions[pred_group_id][pred_group_action_idx] 221 | 222 | if not (x1 == 0 and y1 == 0 and x2 == 0 and y2 == 0): 223 | if pred_group_action_idx != (pred_group_actions.shape[-1] - 
1): 224 | if bool(keep_membership[box_idx]) is False: 225 | pred_group_id = -1 226 | pred_group_action_idx = args.num_class 227 | 228 | pred_list = [name_to_vid[infos[b]['vid']], infos[b]['sid'], infos[b]['fid'][t], 229 | int(x1), int(y1), int(x2), int(y2), 230 | int(pred_group_id), int(pred_group_action_idx) + 1, 231 | float(pred_group_action_prob)] 232 | str_to_be_added = [str(k) for k in pred_list] 233 | str_to_be_added = (" ".join(str_to_be_added)) 234 | 235 | f = open(file_path, "a+") 236 | f.write(str_to_be_added + "\r\n") 237 | f.close() 238 | 239 | 240 | def collate_fn(batch): 241 | batch = list(zip(*batch)) 242 | batch[0] = torch.stack([image for image in batch[0]]) 243 | return tuple(batch) 244 | 245 | 246 | if __name__ == '__main__': 247 | main() 248 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.utils.data as data 5 | 6 | import os 7 | import math 8 | import sys 9 | import copy 10 | import time 11 | import random 12 | import numpy as np 13 | import argparse 14 | 15 | from models import build_model 16 | from util.utils import * 17 | import util.misc as utils 18 | import util.logger as loggers 19 | from dataloader.dataloader import read_dataset 20 | import evaluation.cafe_eval as evaluation 21 | 22 | parser = argparse.ArgumentParser(description='Group Activity Detection train code', add_help=False) 23 | 24 | # Dataset specification 25 | parser.add_argument('--dataset', default='cafe', type=str, help='dataset name') 26 | parser.add_argument('--val_mode', action='store_true') 27 | parser.add_argument('--split', default='place', type=str, help='dataset split. 
place or view') 28 | parser.add_argument('--data_path', default='../Dataset/', type=str, help='data path') 29 | parser.add_argument('--image_width', default=1280, type=int, help='Image width to resize') 30 | parser.add_argument('--image_height', default=720, type=int, help='Image height to resize') 31 | parser.add_argument('--random_sampling', action='store_true', help='random sampling strategy') 32 | parser.add_argument('--num_frame', default=5, type=int, help='number of frames for each clip') 33 | parser.add_argument('--num_class', default=6, type=int, help='number of activity classes') 34 | 35 | # Backbone parameters 36 | parser.add_argument('--backbone', default='resnet18', type=str, help='feature extraction backbone') 37 | parser.add_argument('--dilation', action='store_true', help='use dilation or not') 38 | parser.add_argument('--frozen_batch_norm', action='store_true', help='use frozen batch normalization') 39 | parser.add_argument('--hidden_dim', default=256, type=int, help='transformer channel dimension') 40 | 41 | # RoI Align parameters 42 | parser.add_argument('--num_boxes', default=14, type=int, help='maximum number of actors') 43 | parser.add_argument('--crop_size', default=5, type=int, help='roi align crop size') 44 | 45 | # Group Transformer 46 | parser.add_argument('--gar_nheads', default=4, type=int, help='number of heads') 47 | parser.add_argument('--gar_enc_layers', default=6, type=int, help='number of group transformer layers') 48 | parser.add_argument('--gar_ffn_dim', default=512, type=int, help='feed forward network dimension') 49 | parser.add_argument('--position_embedding', default='sine', type=str, help='various position encoding') 50 | parser.add_argument('--num_group_tokens', default=12, type=int, help='number of group tokens') 51 | parser.add_argument('--aux_loss', action='store_true') 52 | parser.add_argument('--group_threshold', default=0.5, type=float, help='post processing threshold') 53 | parser.add_argument('--distance_threshold', default=0.2, type=float, help='distance mask threshold') 54 | 55 | # Loss option 56 | parser.add_argument('--temperature', default=0.2, type=float, help='consistency loss temperature') 57 | 58 | # Loss coefficients (Individual) 59 | parser.add_argument('--ce_loss_coef', default=1, type=float) 60 | parser.add_argument('--eos_coef', default=1, type=float, 61 | help="Relative classification weight of the no-object class") 62 | 63 | # Loss coefficients (Group) 64 | parser.add_argument('--group_eos_coef', default=1, type=float) 65 | parser.add_argument('--group_ce_loss_coef', default=1, type=float) 66 | parser.add_argument('--group_code_loss_coef', default=5, type=float) 67 | parser.add_argument('--consistency_loss_coef', default=2, type=float) 68 | 69 | # Matcher (Group) 70 | parser.add_argument('--set_cost_group_class', default=1, type=float, 71 | help="Class coefficient in the matching cost") 72 | parser.add_argument('--set_cost_membership', default=1, type=float, 73 | help="Membership coefficient in the matching cost") 74 | 75 | # Training parameters 76 | parser.add_argument('--random_seed', default=1, type=int, help='random seed for reproduction') 77 | parser.add_argument('--epochs', default=30, type=int, help='Max epochs') 78 | parser.add_argument('--test_freq', default=1, type=int, help='print frequency') 79 | parser.add_argument('--batch', default=16, type=int, help='Batch size') 80 | parser.add_argument('--test_batch', default=16, type=int, help='Test batch size') 81 | parser.add_argument('--lr', default=1e-5, type=float, 
help='Initial learning rate') 82 | parser.add_argument('--max_lr', default=1e-4, type=float, help='Max learning rate') 83 | parser.add_argument('--lr_step', default=4, type=int, help='step size for learning rate scheduler') 84 | parser.add_argument('--lr_step_down', default=25, type=int, help='step down size (cyclic) for learning rate scheduler') 85 | parser.add_argument('--weight_decay', default=1e-4, type=float, help='weight decay') 86 | parser.add_argument('--drop_rate', default=0.1, type=float, help='Dropout rate') 87 | parser.add_argument('--gradient_clipping', action='store_true', help='use gradient clipping') 88 | parser.add_argument('--max_norm', default=1.0, type=float, help='gradient clipping max norm') 89 | 90 | # GPU 91 | parser.add_argument('--device', default="0, 1", type=str, help='GPU device') 92 | parser.add_argument('--distributed', action='store_true') 93 | 94 | # Load model 95 | parser.add_argument('--load_model', action='store_true', help='load model') 96 | parser.add_argument('--model_path', default="", type=str, help='pretrained model path') 97 | 98 | # Visualization 99 | parser.add_argument('--result_path', default="./outputs/") 100 | 101 | # Evaluation 102 | parser.add_argument('--groundtruth', default='./evaluation/gt_tracks.txt', type=argparse.FileType("r")) 103 | parser.add_argument('--labelmap', default='./label_map/group_action_list.pbtxt', type=argparse.FileType("r")) 104 | parser.add_argument('--giou_thresh', default=1.0, type=float) 105 | parser.add_argument('--eval_type', default="gt_base", type=str, help='gt_based or detection_based') 106 | 107 | args = parser.parse_args() 108 | path = None 109 | 110 | SEQS_CAFE = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24] 111 | 112 | ACTIVITIES = ['Queueing', 'Ordering', 'Drinking', 'Working', 'Fighting', 'Selfie', 'Individual', 'No'] 113 | 114 | 115 | def main(): 116 | global args, path 117 | 118 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 119 | os.environ["CUDA_VISIBLE_DEVICES"] = args.device 120 | 121 | time_str = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) 122 | exp_name = '[%s]_GAD_<%s>' % (args.dataset, time_str) 123 | save_path = './result/%s' % exp_name 124 | 125 | # set random seed 126 | random.seed(args.random_seed) 127 | np.random.seed(args.random_seed) 128 | torch.manual_seed(args.random_seed) 129 | torch.cuda.manual_seed(args.random_seed) 130 | torch.cuda.manual_seed_all(args.random_seed) 131 | torch.backends.cudnn.deterministic = True 132 | torch.backends.cudnn.benchmark = False 133 | 134 | train_set, test_set = read_dataset(args) 135 | 136 | # for variable length input 137 | if args.distributed: 138 | sampler_train = data.DistributedSampler(train_set, shuffle=True) 139 | sampler_test = data.DistributedSampler(test_set, shuffle=False) 140 | else: 141 | sampler_train = data.RandomSampler(train_set) 142 | sampler_test = data.RandomSampler(test_set) 143 | 144 | batch_sampler_train = data.BatchSampler(sampler_train, args.batch, drop_last=True) 145 | 146 | train_loader = data.DataLoader(train_set, batch_sampler=batch_sampler_train, 147 | collate_fn=collate_fn, num_workers=4, pin_memory=True) 148 | test_loader = data.DataLoader(test_set, args.test_batch, sampler=sampler_test, drop_last=False, 149 | collate_fn=collate_fn, num_workers=4, pin_memory=True) 150 | 151 | model, criterion = build_model(args) 152 | model = torch.nn.DataParallel(model).cuda() 153 | 154 | # get the number of model parameters 155 | parameters = 'Number of full model parameters: 
{}'.format(sum([p.data.nelement() for p in model.parameters()])) 156 | print_log(save_path, '--------------------Number of parameters--------------------') 157 | print_log(save_path, parameters) 158 | 159 | # define loss function and optimizer 160 | optimizer = torch.optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.999), eps=1e-8, 161 | weight_decay=args.weight_decay) 162 | 163 | scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, args.lr, args.max_lr, step_size_up=args.lr_step, 164 | step_size_down=args.lr_step_down, mode='triangular2', 165 | cycle_momentum=False) 166 | 167 | if args.load_model: 168 | checkpoint = torch.load(args.model_path) 169 | model.load_state_dict(checkpoint['state_dict']) 170 | scheduler.load_state_dict(checkpoint['scheduler']) 171 | optimizer.load_state_dict(checkpoint['optimizer']) 172 | start_epoch = checkpoint['epoch'] + 1 173 | else: 174 | start_epoch = 1 175 | 176 | path = args.result_path + exp_name 177 | if not os.path.exists(path): 178 | os.makedirs(path) 179 | 180 | metrics = evaluation.GAD_Evaluation(args) 181 | 182 | # training phase 183 | for epoch in range(start_epoch, args.epochs + 1): 184 | print_log(save_path, '----- %s at epoch #%d' % ("Train", epoch)) 185 | train_log = train(train_loader, model, criterion, optimizer, epoch) 186 | print_log(save_path, 'Loss: %.4f' % (train_log['loss'])) 187 | print_log(save_path, 'Group class error: %.2f' % (train_log['group_class_error'])) 188 | print('Current learning rate is %f' % scheduler.get_last_lr()[0]) 189 | scheduler.step() 190 | 191 | if epoch % args.test_freq == 0: 192 | print_log(save_path, '----- %s at epoch #%d' % ("Test", epoch)) 193 | test_log, result = validate(test_loader, model, criterion, metrics, epoch) 194 | print_log(save_path, 'Loss: %.4f' % (test_log['loss'])) 195 | print_log(save_path, 'Group class error: %.2f' % (test_log['group_class_error'])) 196 | print_log(save_path, "group mAP at 1.0: %.2f" % result['group_mAP_1.0']) 197 | print_log(save_path, "group mAP at 0.5: %.2f" % result['group_mAP_0.5']) 198 | print_log(save_path, "outlier mIoU: %.2f" % result['outlier_mIoU']) 199 | 200 | state = { 201 | 'epoch': epoch, 202 | 'state_dict': model.state_dict(), 203 | 'optimizer': optimizer.state_dict(), 204 | 'scheduler': scheduler.state_dict(), 205 | } 206 | result_path = save_path + '/epoch%d.pth' % epoch 207 | torch.save(state, result_path) 208 | 209 | 210 | def train(train_loader, model, criterion, optimizer, epoch): 211 | model.train() 212 | criterion.train() 213 | 214 | # logger 215 | metric_logger = loggers.MetricLogger(mode="train", delimiter=" ") 216 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 217 | space_fmt = str(len(str(args.epochs))) 218 | header = 'Epoch [{start_epoch: >{fill}}/{end_epoch}]'.format(start_epoch=epoch, end_epoch=args.epochs, 219 | fill=space_fmt) 220 | print_freq = len(train_loader) 221 | 222 | for i, (images, targets, infos) in enumerate(metric_logger.log_every(train_loader, print_freq, header)): 223 | images = images.cuda() # [B, T, 3, H, W] 224 | targets = [{k: v.cuda() for k, v in t.items()} for t in targets] 225 | 226 | boxes = torch.stack([t['boxes'] for t in targets]) 227 | dummy_mask = torch.stack([t['actions'] == args.num_class + 1 for t in targets]).squeeze() 228 | 229 | num_batch = images.shape[0] 230 | num_frame = images.shape[1] 231 | 232 | # compute output 233 | outputs = model(images, boxes, dummy_mask) 234 | 235 | loss_dict = criterion(outputs, targets, log=False) 236 | weight_dict = 
criterion.weight_dict 237 | 238 | loss = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 239 | 240 | # reduce losses over all GPUs for logging purposes 241 | loss_dict_reduced = utils.reduce_dict(loss_dict) 242 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 243 | for k, v in loss_dict_reduced.items()} 244 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 245 | for k, v in loss_dict_reduced.items() if k in weight_dict} 246 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 247 | loss_value = losses_reduced_scaled.item() 248 | 249 | if not math.isfinite(loss_value): 250 | print("Loss is {}, stopping training".format(loss_value)) 251 | print(loss_dict_reduced) 252 | sys.exit(1) 253 | 254 | # compute gradient and do SGD step 255 | optimizer.zero_grad() 256 | loss.backward() 257 | if args.gradient_clipping: 258 | nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) 259 | optimizer.step() 260 | 261 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 262 | metric_logger.update(group_class_error=loss_dict_reduced['group_class_error']) 263 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 264 | 265 | metric_logger.synchronize_between_processes() 266 | print("Averaged stats:", metric_logger) 267 | 268 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 269 | 270 | 271 | @torch.no_grad() 272 | def validate(test_loader, model, criterion, metrics, epoch): 273 | model.eval() 274 | criterion.eval() 275 | 276 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ") 277 | metric_logger.add_meter('group_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 278 | header = 'Evaluation Inference: ' 279 | 280 | print_freq = len(test_loader) 281 | name_to_vid = {name: i + 1 for i, name in enumerate(SEQS_CAFE)} 282 | file_path = path + '/pred_group_epoch_%d.txt' % epoch 283 | 284 | for i, (images, targets, infos) in enumerate(metric_logger.log_every(test_loader, print_freq, header)): 285 | images = images.cuda() # [B, T, 3, H, W] 286 | targets = [{k: v.cuda() for k, v in t.items()} for t in targets] 287 | 288 | boxes = torch.stack([t['boxes'] for t in targets]) 289 | dummy_mask = torch.stack([t['actions'] == args.num_class + 1 for t in targets]).squeeze() 290 | 291 | # compute output 292 | outputs = model(images, boxes, dummy_mask) 293 | 294 | loss_dict = criterion(outputs, targets) 295 | weight_dict = criterion.weight_dict 296 | 297 | # reduce losses over all GPUs for logging purposes 298 | loss_dict_reduced = utils.reduce_dict(loss_dict) 299 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 300 | for k, v in loss_dict_reduced.items() if k in weight_dict} 301 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 302 | for k, v in loss_dict_reduced.items()} 303 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), 304 | **loss_dict_reduced_scaled, 305 | **loss_dict_reduced_unscaled) 306 | 307 | metric_logger.update(group_class_error=loss_dict_reduced['group_class_error']) 308 | 309 | make_txt(boxes, infos, outputs, name_to_vid, file_path) 310 | 311 | # gather the stats from all processes 312 | metric_logger.synchronize_between_processes() 313 | print("Averaged stats:", metric_logger) 314 | 315 | detections = open(file_path, "r") 316 | result = metrics.evaluate(detections) 317 | 318 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}, result 319 | 320 | 321 | def make_txt(boxes, infos, outputs, name_to_vid, 
file_path): 322 | for b in range(boxes.shape[0]): 323 | for t in range(boxes.shape[1]): 324 | image_w, image_h = args.image_width, args.image_height 325 | 326 | pred_group_actions = outputs['pred_activities'][b] 327 | pred_group_actions = F.softmax(pred_group_actions, dim=1) 328 | members = outputs['membership'][b] 329 | 330 | pred_membership = torch.argmax(members.transpose(0, 1), dim=1).detach().cpu() 331 | keep_membership = members.transpose(0, 1).max(-1).values > args.group_threshold 332 | pred_group_action = torch.argmax(pred_group_actions, dim=1).detach().cpu() 333 | 334 | for box_idx in range(boxes.shape[2]): 335 | x, y, w, h = boxes[b][t][box_idx] 336 | x1, y1, x2, y2 = (x - w / 2) * image_w, (y - h / 2) * image_h, (x + w / 2) * image_w, ( 337 | y + h / 2) * image_h 338 | 339 | pred_group_id = pred_membership[box_idx] 340 | pred_group_action_idx = pred_group_action[pred_group_id] 341 | pred_group_action_prob = pred_group_actions[pred_group_id][pred_group_action_idx] 342 | 343 | if not (x1 == 0 and y1 == 0 and x2 == 0 and y2 == 0): 344 | if pred_group_action_idx != (pred_group_actions.shape[-1] - 1): 345 | if bool(keep_membership[box_idx]) is False: 346 | pred_group_id = -1 347 | pred_group_action_idx = args.num_class 348 | 349 | pred_list = [name_to_vid[infos[b]['vid']], infos[b]['sid'], infos[b]['fid'][t], 350 | int(x1), int(y1), int(x2), int(y2), 351 | int(pred_group_id), int(pred_group_action_idx) + 1, 352 | float(pred_group_action_prob)] 353 | str_to_be_added = [str(k) for k in pred_list] 354 | str_to_be_added = (" ".join(str_to_be_added)) 355 | 356 | f = open(file_path, "a+") 357 | f.write(str_to_be_added + "\r\n") 358 | f.close() 359 | 360 | 361 | def collate_fn(batch): 362 | batch = list(zip(*batch)) 363 | batch[0] = torch.stack([image for image in batch[0]]) 364 | return tuple(batch) 365 | 366 | 367 | if __name__ == '__main__': 368 | main() 369 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dk-kim/CAFE_codebase/caee21a8ccdd5faadecbc5a5034107084a7cd270/util/__init__.py -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
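Boxes are handled in either (cx, cy, w, h) or (x0, y0, x1, y1) form; box_cxcywh_to_xyxy
and box_xyxy_to_cxcywh below convert between the two.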
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | The boxes should be in [x0, y0, x1, y1] format 44 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 45 | and M = len(boxes2) 46 | """ 47 | # degenerate boxes gives inf / nan results 48 | # so do an early check 49 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 50 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 51 | iou, union = box_iou(boxes1, boxes2) 52 | 53 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 54 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 55 | 56 | wh = (rb - lt).clamp(min=0) # [N,M,2] 57 | area = wh[:, :, 0] * wh[:, :, 1] 58 | 59 | return iou - (area - union) / area 60 | 61 | 62 | def masks_to_boxes(masks): 63 | """Compute the bounding boxes around the provided masks 64 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
65 | Returns a [N, 4] tensors, with the boxes in xyxy format 66 | """ 67 | if masks.numel() == 0: 68 | return torch.zeros((0, 4), device=masks.device) 69 | 70 | h, w = masks.shape[-2:] 71 | 72 | y = torch.arange(0, h, dtype=torch.float) 73 | x = torch.arange(0, w, dtype=torch.float) 74 | y, x = torch.meshgrid(y, x) 75 | 76 | x_mask = (masks * x.unsqueeze(0)) 77 | x_max = x_mask.flatten(1).max(-1)[0] 78 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 79 | 80 | y_mask = (masks * y.unsqueeze(0)) 81 | y_max = y_mask.flatten(1).max(-1)[0] 82 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | return torch.stack([x_min, y_min, x_max, y_max], 1) 85 | 86 | 87 | def rescale_bboxes(out_bbox, size): 88 | img_h, img_w = size 89 | b = box_cxcywh_to_xyxy(out_bbox) 90 | b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(out_bbox.get_device()) 91 | return b 92 | 93 | 94 | def rescale_pairs(out_pairs, size): 95 | img_h, img_w = size 96 | h_bbox = out_pairs[:, :4] 97 | o_bbox = out_pairs[:, 4:] 98 | 99 | h = box_cxcywh_to_xyxy(h_bbox) 100 | h = h * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(h_bbox.get_device()) 101 | 102 | obj_mask = (o_bbox[:, 0] != -1) 103 | if obj_mask.sum() != 0: 104 | o = box_cxcywh_to_xyxy(o_bbox) 105 | o = o * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(o_bbox.get_device()) 106 | o_bbox[obj_mask] = o[obj_mask] 107 | o = o_bbox 108 | p = torch.cat([h, o], dim=-1) 109 | 110 | return p -------------------------------------------------------------------------------- /util/logger.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : hotr/util/logger.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | import torch 9 | import time 10 | import datetime 11 | import sys 12 | from time import sleep 13 | from collections import defaultdict 14 | 15 | from util.misc import SmoothedValue 16 | 17 | def print_params(model): 18 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 19 | print('\n[Logger] Number of params: ', n_parameters) 20 | return n_parameters 21 | 22 | def print_args(args): 23 | print('\n[Logger] DETR Arguments:') 24 | for k, v in vars(args).items(): 25 | if k in [ 26 | 'lr', 'lr_backbone', 'lr_drop', 27 | 'frozen_weights', 28 | 'backbone', 'dilation', 29 | 'position_embedding', 'enc_layers', 'dec_layers', 'num_queries', 30 | 'dataset_file']: 31 | print(f'\t{k}: {v}') 32 | 33 | if args.HOIDet: 34 | print('\n[Logger] DETR_HOI Arguments:') 35 | for k, v in vars(args).items(): 36 | if k in [ 37 | 'freeze_enc', 38 | 'query_flag', 39 | 'hoi_nheads', 40 | 'hoi_dim_feedforward', 41 | 'hoi_dec_layers', 42 | 'hoi_idx_loss_coef', 43 | 'hoi_act_loss_coef', 44 | 'hoi_eos_coef', 45 | 'object_threshold']: 46 | print(f'\t{k}: {v}') 47 | 48 | class MetricLogger(object): 49 | def __init__(self, mode="test", delimiter="\t"): 50 | self.meters = defaultdict(SmoothedValue) 51 | self.delimiter = delimiter 52 | self.mode = mode 53 | 54 | def update(self, **kwargs): 55 | for k, v in kwargs.items(): 56 | if isinstance(v, torch.Tensor): 57 | v = v.item() 58 | assert isinstance(v, (float, int)) 59 | self.meters[k].update(v) 60 | 61 | def __getattr__(self, attr): 62 | if attr in self.meters: 63 | return self.meters[attr] 64 | if attr in self.__dict__: 65 | return self.__dict__[attr] 66 | raise AttributeError("'{}' object has no attribute '{}'".format( 67 | type(self).__name__, attr)) 68 | 69 | def __str__(self): 70 | loss_str = [] 71 | for name, meter in self.meters.items(): 72 | loss_str.append( 73 | "{}: {}".format(name, str(meter)) 74 | ) 75 | return self.delimiter.join(loss_str) 76 | 77 | def synchronize_between_processes(self): 78 | for meter in self.meters.values(): 79 | meter.synchronize_between_processes() 80 | 81 | def add_meter(self, name, meter): 82 | self.meters[name] = meter 83 | 84 | def log_every(self, iterable, print_freq, header=None): 85 | i = 0 86 | if not header: 87 | header = '' 88 | start_time = time.time() 89 | end = time.time() 90 | iter_time = SmoothedValue(fmt='{avg:.4f}') 91 | data_time = SmoothedValue(fmt='{avg:.4f}') 92 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 93 | if torch.cuda.is_available(): 94 | log_msg = self.delimiter.join([ 95 | header, 96 | '[{0' + space_fmt + '}/{1}]', 97 | 'eta: {eta}', 98 | '{meters}', 99 | 'time: {time}', 100 | 'data: {data}', 101 | 'max mem: {memory:.0f}' 102 | ]) 103 | else: 104 | log_msg = self.delimiter.join([ 105 | header, 106 | '[{0' + space_fmt + '}/{1}]', 107 | 'eta: {eta}', 108 | '{meters}', 109 | 'time: {time}', 110 | 'data: {data}' 111 | ]) 112 | MB = 1024.0 * 1024.0 113 | for obj in iterable: 114 | data_time.update(time.time() - end) 115 | yield obj 116 | iter_time.update(time.time() - end) 117 | 118 | if (i % print_freq == 0 and i !=0) or i == len(iterable) - 1: 119 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 120 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 121 | if torch.cuda.is_available(): 122 | print(log_msg.format( 123 | i+1, len(iterable), eta=eta_string, 124 | meters=str(self), 125 | time=str(iter_time), data=str(data_time), 126 | 
memory=torch.cuda.max_memory_allocated() / MB), 127 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n")) 128 | else: 129 | print(log_msg.format( 130 | i+1, len(iterable), eta=eta_string, 131 | meters=str(self), 132 | time=str(iter_time), data=str(data_time)), 133 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n")) 134 | else: 135 | log_interval = self.delimiter.join([header, '[{0' + space_fmt + '}/{1}]']) 136 | if torch.cuda.is_available(): print(log_interval.format(i+1, len(iterable)), flush=True, end="\r") 137 | else: print(log_interval.format(i+1, len(iterable)), flush=True, end="\r") 138 | 139 | i += 1 140 | end = time.time() 141 | total_time = time.time() - start_time 142 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 143 | if self.mode=='test': print("") 144 | print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format( 145 | self.mode, total_time_str, total_time / len(iterable))) 146 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from HOTR (https://github.com/kakaobrain/HOTR) 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | """ 9 | Misc functions, including distributed helpers. 10 | Mostly copy-paste from torchvision references. 11 | """ 12 | import os 13 | import subprocess 14 | from collections import deque 15 | import pickle 16 | from typing import Optional, List 17 | 18 | import torch 19 | import torch.distributed as dist 20 | from torch import Tensor 21 | 22 | # needed due to empty tensor bug in pytorch and torchvision 0.5 23 | import torchvision 24 | 25 | 26 | class SmoothedValue(object): 27 | """Track a series of values and provide access to smoothed values over a 28 | window or the global series average. 29 | """ 30 | 31 | def __init__(self, window_size=20, fmt=None): 32 | if fmt is None: 33 | fmt = "{median:.4f} ({global_avg:.4f})" 34 | self.deque = deque(maxlen=window_size) 35 | self.total = 0.0 36 | self.count = 0 37 | self.fmt = fmt 38 | 39 | def update(self, value, n=1): 40 | self.deque.append(value) 41 | self.count += n 42 | self.total += value * n 43 | 44 | def synchronize_between_processes(self): 45 | """ 46 | Warning: does not synchronize the deque! 
47 | """ 48 | if not is_dist_avail_and_initialized(): 49 | return 50 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 51 | dist.barrier() 52 | dist.all_reduce(t) 53 | t = t.tolist() 54 | self.count = int(t[0]) 55 | self.total = t[1] 56 | 57 | @property 58 | def median(self): 59 | d = torch.tensor(list(self.deque)) 60 | return d.median().item() 61 | 62 | @property 63 | def avg(self): 64 | d = torch.tensor(list(self.deque), dtype=torch.float32) 65 | return d.mean().item() 66 | 67 | @property 68 | def global_avg(self): 69 | return self.total / self.count 70 | 71 | @property 72 | def max(self): 73 | return max(self.deque) 74 | 75 | @property 76 | def value(self): 77 | return self.deque[-1] 78 | 79 | def __str__(self): 80 | return self.fmt.format( 81 | median=self.median, 82 | avg=self.avg, 83 | global_avg=self.global_avg, 84 | max=self.max, 85 | value=self.value) 86 | 87 | 88 | def all_gather(data): 89 | """ 90 | Run all_gather on arbitrary picklable data (not necessarily tensors) 91 | Args: 92 | data: any picklable object 93 | Returns: 94 | list[data]: list of data gathered from each rank 95 | """ 96 | world_size = get_world_size() 97 | if world_size == 1: 98 | return [data] 99 | 100 | # serialized to a Tensor 101 | buffer = pickle.dumps(data) 102 | storage = torch.ByteStorage.from_buffer(buffer) 103 | tensor = torch.ByteTensor(storage).to("cuda") 104 | 105 | # obtain Tensor size of each rank 106 | local_size = torch.tensor([tensor.numel()], device="cuda") 107 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 108 | dist.all_gather(size_list, local_size) 109 | size_list = [int(size.item()) for size in size_list] 110 | max_size = max(size_list) 111 | 112 | # receiving Tensor from all ranks 113 | # we pad the tensor because torch all_gather does not support 114 | # gathering tensors of different shapes 115 | tensor_list = [] 116 | for _ in size_list: 117 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 118 | if local_size != max_size: 119 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 120 | tensor = torch.cat((tensor, padding), dim=0) 121 | dist.all_gather(tensor_list, tensor) 122 | 123 | data_list = [] 124 | for size, tensor in zip(size_list, tensor_list): 125 | buffer = tensor.cpu().numpy().tobytes()[:size] 126 | data_list.append(pickle.loads(buffer)) 127 | 128 | return data_list 129 | 130 | 131 | def reduce_dict(input_dict, average=True): 132 | """ 133 | Args: 134 | input_dict (dict): all the values will be reduced 135 | average (bool): whether to do average or sum 136 | Reduce the values in the dictionary from all processes so that all processes 137 | have the averaged results. Returns a dict with the same fields as 138 | input_dict, after reduction. 
139 | """ 140 | world_size = get_world_size() 141 | if world_size < 2: 142 | return input_dict 143 | with torch.no_grad(): 144 | names = [] 145 | values = [] 146 | # sort the keys so that they are consistent across processes 147 | for k in sorted(input_dict.keys()): 148 | names.append(k) 149 | values.append(input_dict[k]) 150 | values = torch.stack(values, dim=0) 151 | dist.all_reduce(values) 152 | if average: 153 | values /= world_size 154 | reduced_dict = {k: v for k, v in zip(names, values)} 155 | return reduced_dict 156 | 157 | 158 | def get_sha(): 159 | cwd = os.path.dirname(os.path.abspath(__file__)) 160 | 161 | def _run(command): 162 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 163 | sha = 'N/A' 164 | diff = "clean" 165 | branch = 'N/A' 166 | try: 167 | sha = _run(['git', 'rev-parse', 'HEAD']) 168 | subprocess.check_output(['git', 'diff'], cwd=cwd) 169 | diff = _run(['git', 'diff-index', 'HEAD']) 170 | diff = "has uncommited changes" if diff else "clean" 171 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 172 | except Exception: 173 | pass 174 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 175 | return message 176 | 177 | 178 | def collate_fn(batch): 179 | print(batch) 180 | batch = list(zip(*batch)) 181 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 182 | return tuple(batch) 183 | 184 | 185 | def _max_by_axis(the_list): 186 | # type: (List[List[int]]) -> List[int] 187 | maxes = the_list[0] 188 | for sublist in the_list[1:]: 189 | for index, item in enumerate(sublist): 190 | maxes[index] = max(maxes[index], item) 191 | return maxes 192 | 193 | 194 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 195 | # TODO make this more general 196 | if tensor_list[0].ndim == 3: 197 | # TODO make it support different-sized images 198 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 199 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 200 | batch_shape = [len(tensor_list)] + max_size 201 | b, c, h, w = batch_shape 202 | dtype = tensor_list[0].dtype 203 | device = tensor_list[0].device 204 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 205 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 206 | for img, pad_img, m in zip(tensor_list, tensor, mask): 207 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 208 | m[: img.shape[1], :img.shape[2]] = False 209 | else: 210 | raise ValueError('not supported') 211 | return NestedTensor(tensor, mask) 212 | 213 | 214 | class NestedTensor(object): 215 | def __init__(self, tensors, mask: Optional[Tensor]): 216 | self.tensors = tensors 217 | self.mask = mask 218 | 219 | def to(self, device): 220 | # type: (Device) -> NestedTensor # noqa 221 | cast_tensor = self.tensors.to(device) 222 | mask = self.mask 223 | if mask is not None: 224 | assert mask is not None 225 | cast_mask = mask.to(device) 226 | else: 227 | cast_mask = None 228 | return NestedTensor(cast_tensor, cast_mask) 229 | 230 | def decompose(self): 231 | return self.tensors, self.mask 232 | 233 | def __repr__(self): 234 | return str(self.tensors) 235 | 236 | 237 | def setup_for_distributed(is_master): 238 | """ 239 | This function disables printing when not in master process 240 | """ 241 | import builtins as __builtin__ 242 | builtin_print = __builtin__.print 243 | 244 | def print(*args, **kwargs): 245 | force = kwargs.pop('force', False) 246 | if is_master or force: 247 | builtin_print(*args, **kwargs) 248 | 249 | 
__builtin__.print = print 250 | 251 | 252 | def is_dist_avail_and_initialized(): 253 | if not dist.is_available(): 254 | return False 255 | if not dist.is_initialized(): 256 | return False 257 | return True 258 | 259 | 260 | def get_world_size(): 261 | if not is_dist_avail_and_initialized(): 262 | return 1 263 | return dist.get_world_size() 264 | 265 | 266 | def get_rank(): 267 | if not is_dist_avail_and_initialized(): 268 | return 0 269 | return dist.get_rank() 270 | 271 | 272 | def is_main_process(): 273 | return get_rank() == 0 274 | 275 | 276 | def save_on_master(*args, **kwargs): 277 | if is_main_process(): 278 | torch.save(*args, **kwargs) 279 | 280 | 281 | def init_distributed_mode(args): 282 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 283 | args.rank = int(os.environ["RANK"]) 284 | args.world_size = int(os.environ['WORLD_SIZE']) 285 | args.gpu = int(os.environ['LOCAL_RANK']) 286 | elif 'SLURM_PROCID' in os.environ: 287 | args.rank = int(os.environ['SLURM_PROCID']) 288 | args.gpu = args.rank % torch.cuda.device_count() 289 | else: 290 | print('Not using distributed mode') 291 | args.distributed = False 292 | return 293 | 294 | args.distributed = True 295 | 296 | torch.cuda.set_device(args.gpu) 297 | args.dist_backend = 'nccl' 298 | print('| distributed init (rank {}): {}'.format( 299 | args.rank, args.dist_url), flush=True) 300 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 301 | world_size=args.world_size, rank=args.rank) 302 | torch.distributed.barrier() 303 | setup_for_distributed(args.rank == 0) 304 | 305 | 306 | @torch.no_grad() 307 | def accuracy(output, target, topk=(1,)): 308 | """Computes the precision@k for the specified values of k""" 309 | if target.numel() == 0: 310 | return [torch.zeros([], device=output.device)] 311 | maxk = max(topk) 312 | batch_size = target.size(0) 313 | 314 | _, pred = output.topk(maxk, 1, True, True) 315 | pred = pred.t() 316 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 317 | 318 | res = [] 319 | for k in topk: 320 | correct_k = correct[:k].view(-1).float().sum(0) 321 | res.append(correct_k.mul_(100.0 / batch_size)) 322 | return res 323 | 324 | 325 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 326 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 327 | """ 328 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 329 | This will eventually be supported natively by PyTorch, and this 330 | class can go away. 
331 | """ 332 | if float(torchvision.__version__[:3]) < 0.7: 333 | if input.numel() > 0: 334 | return torch.nn.functional.interpolate( 335 | input, size, scale_factor, mode, align_corners 336 | ) 337 | 338 | output_shape = _output_size(2, input, size, scale_factor) 339 | output_shape = list(input.shape[:-2]) + list(output_shape) 340 | return _new_empty_tensor(input, output_shape) 341 | else: 342 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) -------------------------------------------------------------------------------- /util/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | 5 | 6 | def print_log(result_path, *args): 7 | os.makedirs(result_path, exist_ok=True) 8 | 9 | print(*args) 10 | file_path = result_path + '/log.txt' 11 | if file_path is not None: 12 | with open(file_path, 'a') as f: 13 | print(*args, file=f) 14 | --------------------------------------------------------------------------------