├── src
│   ├── util
│   │   ├── __init__.py
│   │   ├── box_ops.py
│   │   ├── logger.py
│   │   └── misc.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── feed_forward.py
│   │   ├── position_encoding.py
│   │   ├── detr_matcher.py
│   │   ├── backbone.py
│   │   ├── hotr.py
│   │   ├── post_process.py
│   │   ├── detr.py
│   │   └── hotr_matcher.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── evaluator_coco.py
│   │   ├── trainer.py
│   │   ├── hico_det_evaluate.py
│   │   ├── evaluator_vcoco.py
│   │   ├── evaluator_hico.py
│   │   └── arg_parser.py
│   ├── data
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── coco.py
│   │   │   ├── builtin_meta.py
│   │   │   └── hico.py
│   │   ├── evaluators
│   │   │   ├── vcoco_eval.py
│   │   │   ├── coco_eval.py
│   │   │   └── hico_eval.py
│   │   └── transforms
│   │       └── transforms.py
│   └── metrics
│       ├── utils.py
│       └── vcoco
│           ├── ap_agent.py
│           └── ap_role.py
├── imgs
│   └── STIP.png
├── .run
│   ├── STIP_vcoco_single_train.run.xml
│   └── STIP_hicodet_single_train.run.xml
├── .gitignore
├── README.md
└── STIP_main.py
/src/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/STIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyong812/STIP/HEAD/imgs/STIP.png
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from .detr import build
3 | 
4 | def build_model(args):
5 |     return build(args)
--------------------------------------------------------------------------------
/src/models/feed_forward.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 | 
4 | class MLP(nn.Module):
5 |     """ Very simple multi-layer perceptron (also called FFN)"""
6 | 
7 |     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
8 |         super().__init__()
9 |         self.num_layers = num_layers
10 |         h = [hidden_dim] * (num_layers - 1)
11 |         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
12 | 
13 |     def forward(self, x):
14 |         for i, layer in enumerate(self.layers):
15 |             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
16 |         return x
--------------------------------------------------------------------------------
/src/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluator_vcoco import vcoco_evaluate, vcoco_accumulate
2 | from .evaluator_hico import hico_evaluate
3 | 
4 | def hoi_evaluator(args, model, criterion, postprocessors, data_loader, device, thr=0):
5 |     if args.dataset_file == 'vcoco':
6 |         return vcoco_evaluate(model, criterion, postprocessors, data_loader, device, args.output_dir, thr)
7 |     elif args.dataset_file == 'hico-det':
8 |         return hico_evaluate(model, postprocessors, data_loader, device, thr, args)
9 |     else: raise NotImplementedError
10 | 
11 | def hoi_accumulator(args, total_res, print_results=False, wandb=False):
12 |     if args.dataset_file == 'vcoco':
13 |         return vcoco_accumulate(total_res, args, print_results, wandb)
14 |     else: raise NotImplementedError
--------------------------------------------------------------------------------
/src/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | import torch.utils.data 3 | import torchvision 4 | 5 | from src.data.datasets.coco import build as build_coco 6 | from src.data.datasets.vcoco import build as build_vcoco 7 | from src.data.datasets.hico import build as build_hico 8 | 9 | def get_coco_api_from_dataset(dataset): 10 | for _ in range(10): # what is this for? 11 | if isinstance(dataset, torch.utils.data.Subset): 12 | dataset = dataset.dataset 13 | if isinstance(dataset, torchvision.datasets.CocoDetection): 14 | return dataset.coco 15 | 16 | 17 | def build_dataset(image_set, args): 18 | if args.dataset_file == 'coco': 19 | return build_coco(image_set, args) 20 | elif args.dataset_file == 'vcoco': 21 | return build_vcoco(image_set, args) 22 | elif args.dataset_file == 'hico-det': 23 | return build_hico(image_set, args) 24 | raise ValueError(f'dataset {args.dataset_file} not supported') -------------------------------------------------------------------------------- /.run/STIP_vcoco_single_train.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 25 | -------------------------------------------------------------------------------- /.run/STIP_hicodet_single_train.run.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .idea 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | pip-wheel-metadata/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | wandb/ 132 | checkpoints/ 133 | res/ 134 | 135 | -------------------------------------------------------------------------------- /src/data/evaluators/vcoco_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) KakaoBrain, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | V-COCO evaluator that works in distributed mode. 4 | """ 5 | import os 6 | import numpy as np 7 | import torch 8 | 9 | from src.util.misc import all_gather 10 | from src.metrics.vcoco.ap_role import APRole 11 | from functools import partial 12 | 13 | def init_vcoco_evaluators(human_act_name, object_act_name): 14 | role_eval1 = APRole(act_name=object_act_name, scenario_flag=True, iou_threshold=0.5) 15 | role_eval2 = APRole(act_name=object_act_name, scenario_flag=False, iou_threshold=0.5) 16 | 17 | return role_eval1, role_eval2 18 | 19 | class VCocoEvaluator(object): 20 | def __init__(self, args): 21 | self.img_ids = [] 22 | self.eval_imgs = [] 23 | self.role_eval1, self.role_eval2 = init_vcoco_evaluators(args.human_actions, args.object_actions) 24 | self.num_human_act = args.num_human_act 25 | self.action_idx = args.valid_ids 26 | 27 | def update(self, outputs): 28 | img_ids = list(np.unique(list(outputs.keys()))) 29 | for img_num, img_id in enumerate(img_ids): 30 | print(f"Evaluating Score Matrix... 
: [{(img_num+1):>4}/{len(img_ids):<4}]" ,flush=True, end="\r") 31 | prediction = outputs[img_id]['prediction'] 32 | target = outputs[img_id]['target'] 33 | 34 | # score with prediction 35 | hbox, hcat, obox, ocat = list(map(lambda x: prediction[x], \ 36 | ['h_box', 'h_cat', 'o_box', 'o_cat'])) 37 | 38 | assert 'pair_score' in prediction 39 | score = prediction['pair_score'] 40 | 41 | hbox, hcat, obox, ocat, score =\ 42 | list(map(lambda x: x.cpu().numpy(), [hbox, hcat, obox, ocat, score])) 43 | 44 | # ground-truth 45 | gt_h_inds = (target['labels'] == 1) 46 | gt_h_box = target['boxes'][gt_h_inds, :4].cpu().numpy() 47 | gt_h_act = target['inst_actions'][gt_h_inds, :self.num_human_act].cpu().numpy() ## self.num_human_act=26, for filtering GT 48 | 49 | gt_p_box = target['pair_boxes'].cpu().numpy() 50 | gt_p_act = target['pair_actions'].cpu().numpy() 51 | 52 | score = score[self.action_idx, :, :] 53 | gt_p_act = gt_p_act[:, self.action_idx] 54 | 55 | self.role_eval1.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act) # gt_h_act: Nx26, gt_p_act: Nx25 56 | self.role_eval2.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act) 57 | self.img_ids.append(img_id) 58 | -------------------------------------------------------------------------------- /src/engine/evaluator_coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import src.util.misc as utils 4 | import src.util.logger as loggers 5 | from src.data.evaluators.coco_eval import CocoEvaluator 6 | 7 | @torch.no_grad() 8 | def coco_evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir): 9 | model.eval() 10 | criterion.eval() 11 | 12 | metric_logger = loggers.MetricLogger(delimiter=" ") 13 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 14 | header = 'Evaluation' 15 | 16 | iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys()) 17 | coco_evaluator = CocoEvaluator(base_ds, iou_types) 18 | print_freq = len(data_loader) 19 | # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75] 20 | 21 | print("\n>>> [MS-COCO Evaluation] <<<") 22 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 23 | samples = samples.to(device) 24 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 25 | 26 | outputs = model(samples) 27 | loss_dict = criterion(outputs, targets) 28 | weight_dict = criterion.weight_dict 29 | 30 | # reduce losses over all GPUs for logging purposes 31 | loss_dict_reduced = utils.reduce_dict(loss_dict) 32 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 33 | for k, v in loss_dict_reduced.items() if k in weight_dict} 34 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 35 | for k, v in loss_dict_reduced.items()} 36 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), 37 | **loss_dict_reduced_scaled, 38 | **loss_dict_reduced_unscaled) 39 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 40 | 41 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 42 | results = postprocessors['bbox'](outputs, orig_target_sizes) 43 | res = {target['image_id'].item(): output for target, output in zip(targets, results)} 44 | if coco_evaluator is not None: 45 | coco_evaluator.update(res) 46 | 47 | # gather the stats from all processes 48 | metric_logger.synchronize_between_processes() 49 | print("\n>>> [Averaged stats] <<<\n", metric_logger) 50 | 
if coco_evaluator is not None: 51 | coco_evaluator.synchronize_between_processes() 52 | 53 | # accumulate predictions from all images 54 | if coco_evaluator is not None: 55 | coco_evaluator.accumulate() 56 | coco_evaluator.summarize() 57 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 58 | if coco_evaluator is not None: 59 | if 'bbox' in postprocessors.keys(): 60 | stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() 61 | 62 | return stats, coco_evaluator -------------------------------------------------------------------------------- /src/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from src.util.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors 30 | mask = tensor_list.mask 31 | assert mask is not None 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 48 | return pos 49 | 50 | 51 | class PositionEmbeddingLearned(nn.Module): 52 | """ 53 | Absolute pos embedding, learned. 
54 | """ 55 | def __init__(self, num_pos_feats=256): 56 | super().__init__() 57 | self.row_embed = nn.Embedding(50, num_pos_feats) 58 | self.col_embed = nn.Embedding(50, num_pos_feats) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | nn.init.uniform_(self.row_embed.weight) 63 | nn.init.uniform_(self.col_embed.weight) 64 | 65 | def forward(self, tensor_list: NestedTensor): 66 | x = tensor_list.tensors 67 | h, w = x.shape[-2:] 68 | i = torch.arange(w, device=x.device) 69 | j = torch.arange(h, device=x.device) 70 | x_emb = self.col_embed(i) 71 | y_emb = self.row_embed(j) 72 | pos = torch.cat([ 73 | x_emb.unsqueeze(0).repeat(h, 1, 1), 74 | y_emb.unsqueeze(1).repeat(1, w, 1), 75 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 76 | return pos 77 | 78 | 79 | def build_position_encoding(args): 80 | N_steps = args.hidden_dim // 2 81 | if args.position_embedding in ('v2', 'sine'): 82 | # TODO find a better way of exposing other arguments 83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 84 | elif args.position_embedding in ('v3', 'learned'): 85 | position_embedding = PositionEmbeddingLearned(N_steps) 86 | else: 87 | raise ValueError(f"not supported {args.position_embedding}") 88 | 89 | return position_embedding -------------------------------------------------------------------------------- /src/engine/trainer.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : engine/trainer.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | import math 9 | import torch 10 | import sys 11 | import src.util.misc as utils 12 | import src.util.logger as loggers 13 | from typing import Iterable 14 | import wandb 15 | from src.models.stip_utils import check_annotation, plot_cross_attention, plot_hoi_results 16 | 17 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 18 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 19 | device: torch.device, epoch: int, max_epoch: int, max_norm: float = 0, dataset_file: str = 'coco', log: bool = False): 20 | model.train() 21 | criterion.train() 22 | metric_logger = loggers.MetricLogger(mode="train", delimiter=" ") 23 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 24 | space_fmt = str(len(str(max_epoch))) 25 | header = 'Epoch [{start_epoch: >{fill}}/{end_epoch}]'.format(start_epoch=epoch+1, end_epoch=max_epoch, fill=space_fmt) 26 | print_freq = int(len(data_loader)/5) 27 | 28 | print(f"\n>>> Epoch #{(epoch+1)}") 29 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 30 | samples = samples.to(device) 31 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 32 | 33 | outputs = model(samples, targets) 34 | loss_dict = criterion(outputs, targets, log) 35 | weight_dict = criterion.weight_dict 36 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 37 | 38 | # reduce losses over all GPUs for logging purposes 39 | loss_dict_reduced = utils.reduce_dict(loss_dict) 40 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 41 | for k, v in loss_dict_reduced.items()} 42 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 43 | for k, v in loss_dict_reduced.items() if k in weight_dict} 44 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 45 | loss_value = losses_reduced_scaled.item() 46 | if utils.get_rank() == 0 and log: wandb.log(loss_dict_reduced_scaled) 47 | 48 | if not math.isfinite(loss_value): 49 | print("Loss is {}, stopping training".format(loss_value)) 50 | print(loss_dict_reduced) 51 | sys.exit(1) 52 | 53 | optimizer.zero_grad() 54 | losses.backward() 55 | if max_norm > 0: 56 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 57 | optimizer.step() 58 | 59 | # check_annotation(samples, targets, mode='train', rel_num=20) 60 | # plot_hoi_results(samples, outputs, targets, args=model.args) 61 | 62 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled) 63 | if "obj_class_error" in loss_dict: 64 | metric_logger.update(obj_class_error=loss_dict_reduced['obj_class_error']) 65 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 66 | 67 | # gather the stats from all processes 68 | metric_logger.synchronize_between_processes() 69 | print("Averaged stats:", metric_logger) 70 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 71 | -------------------------------------------------------------------------------- /src/metrics/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def count_parameters(model): 5 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 6 | 7 | 8 | def compute_overlap(a, b): 9 | if type(a) == torch.Tensor: 10 | if len(a.shape) == 2: 11 | area = (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1) 12 | 13 | iw = torch.min(a[:, 2].unsqueeze(dim=1), b[:, 2]) - 
torch.max(a[:, 0].unsqueeze(dim=1), b[:, 0]) 14 | ih = torch.min(a[:, 3].unsqueeze(dim=1), b[:, 3]) - torch.max(a[:, 1].unsqueeze(dim=1), b[:, 1]) 15 | 16 | iw[iw<0] = 0 17 | ih[ih<0] = 0 18 | 19 | ua = torch.unsqueeze((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), dim=1) + area - iw * ih 20 | ua[ua < 1e-8] = 1e-8 21 | 22 | intersection = iw * ih 23 | 24 | return intersection / ua 25 | 26 | elif type(a) == np.ndarray: 27 | if len(a.shape) == 2: 28 | area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0) #(1, K) 29 | 30 | iw = np.minimum(np.expand_dims(a[:, 2], axis=1), np.expand_dims(b[:, 2], axis=0)) \ 31 | - np.maximum(np.expand_dims(a[:, 0], axis=1), np.expand_dims(b[:, 0], axis=0)) \ 32 | + 1 33 | ih = np.minimum(np.expand_dims(a[:, 3], axis=1), np.expand_dims(b[:, 3], axis=0)) \ 34 | - np.maximum(np.expand_dims(a[:, 1], axis=1), np.expand_dims(b[:, 1], axis=0)) \ 35 | + 1 36 | 37 | iw[iw<0] = 0 # (N, K) 38 | ih[ih<0] = 0 # (N, K) 39 | 40 | intersection = iw * ih 41 | 42 | ua = np.expand_dims((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), axis=1) + area - intersection 43 | ua[ua < 1e-8] = 1e-8 44 | 45 | return intersection / ua 46 | 47 | elif len(a.shape) == 1: 48 | area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0) #(1, K) 49 | 50 | iw = np.minimum(np.expand_dims([a[2]], axis=1), np.expand_dims(b[:, 2], axis=0)) \ 51 | - np.maximum(np.expand_dims([a[0]], axis=1), np.expand_dims(b[:, 0], axis=0)) 52 | ih = np.minimum(np.expand_dims([a[3]], axis=1), np.expand_dims(b[:, 3], axis=0)) \ 53 | - np.maximum(np.expand_dims([a[1]], axis=1), np.expand_dims(b[:, 1], axis=0)) 54 | 55 | iw[iw<0] = 0 # (N, K) 56 | ih[ih<0] = 0 # (N, K) 57 | 58 | ua = np.expand_dims([(a[2] - a[0] + 1) * (a[3] - a[1] + 1)], axis=1) + area - iw * ih 59 | ua[ua < 1e-8] = 1e-8 60 | 61 | intersection = iw * ih 62 | 63 | return intersection / ua 64 | 65 | 66 | def _compute_ap(recall, precision): 67 | """ Compute the average precision, given the recall and precision curves. 68 | Code originally from https://github.com/rbgirshick/py-faster-rcnn. 69 | # Arguments 70 | recall: The recall curve (list). 71 | precision: The precision curve (list). 72 | # Returns 73 | The average precision as computed in py-faster-rcnn. 74 | """ 75 | # correct AP calculation 76 | # first append sentinel values at the end 77 | mrec = np.concatenate(([0.], recall, [1.])) 78 | mpre = np.concatenate(([0.], precision, [0.])) 79 | 80 | # compute the precision envelope 81 | for i in range(mpre.size - 1, 0, -1): 82 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 83 | 84 | # to calculate area under PR curve, look for points 85 | # where X axis (recall) changes value 86 | i = np.where(mrec[1:] != mrec[:-1])[0] 87 | 88 | # and sum (\Delta recall) * prec 89 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 90 | return ap -------------------------------------------------------------------------------- /src/util/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 
4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | The boxes should be in [x0, y0, x1, y1] format 44 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 45 | and M = len(boxes2) 46 | """ 47 | # degenerate boxes gives inf / nan results 48 | # so do an early check 49 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 50 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 51 | iou, union = box_iou(boxes1, boxes2) 52 | 53 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 54 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 55 | 56 | wh = (rb - lt).clamp(min=0) # [N,M,2] 57 | area = wh[:, :, 0] * wh[:, :, 1] 58 | 59 | return iou - (area - union) / area 60 | 61 | 62 | def masks_to_boxes(masks): 63 | """Compute the bounding boxes around the provided masks 64 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
65 | Returns a [N, 4] tensors, with the boxes in xyxy format 66 | """ 67 | if masks.numel() == 0: 68 | return torch.zeros((0, 4), device=masks.device) 69 | 70 | h, w = masks.shape[-2:] 71 | 72 | y = torch.arange(0, h, dtype=torch.float) 73 | x = torch.arange(0, w, dtype=torch.float) 74 | y, x = torch.meshgrid(y, x) 75 | 76 | x_mask = (masks * x.unsqueeze(0)) 77 | x_max = x_mask.flatten(1).max(-1)[0] 78 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 79 | 80 | y_mask = (masks * y.unsqueeze(0)) 81 | y_max = y_mask.flatten(1).max(-1)[0] 82 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | return torch.stack([x_min, y_min, x_max, y_max], 1) 85 | 86 | 87 | def rescale_bboxes(out_bbox, size): 88 | img_h, img_w = size 89 | b = box_cxcywh_to_xyxy(out_bbox) 90 | b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(out_bbox.get_device()) 91 | return b 92 | 93 | 94 | def rescale_pairs(out_pairs, size): 95 | img_h, img_w = size 96 | h_bbox = out_pairs[:, :4] 97 | o_bbox = out_pairs[:, 4:] 98 | 99 | h = box_cxcywh_to_xyxy(h_bbox) 100 | h = h * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(h_bbox.get_device()) 101 | 102 | obj_mask = (o_bbox[:, 0] != -1) 103 | if obj_mask.sum() != 0: 104 | o = box_cxcywh_to_xyxy(o_bbox) 105 | o = o * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(o_bbox.get_device()) 106 | o_bbox[obj_mask] = o[obj_mask] 107 | o = o_bbox 108 | p = torch.cat([h, o], dim=-1) 109 | 110 | return p -------------------------------------------------------------------------------- /src/engine/hico_det_evaluate.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import copy 3 | import itertools 4 | 5 | import torch 6 | 7 | import src.util.misc as utils 8 | import src.util.logger as loggers 9 | from pycocotools.coco import COCO 10 | from pycocotools.cocoeval import COCOeval 11 | 12 | # evaluate detection on HICO-DET 13 | @torch.no_grad() 14 | def hico_det_evaluate(model, postprocessors, data_loader, device, args): 15 | hico_valid_obj_ids = torch.tensor(args.valid_obj_ids) # id -> coco id 16 | 17 | model.eval() 18 | 19 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ") 20 | header = 'Evaluation Inference (HICO-DET)' 21 | 22 | preds = [] 23 | gts = [] 24 | 25 | all_predictions = {} 26 | for samples, targets in metric_logger.log_every(data_loader, 50, header): 27 | samples = samples.to(device) 28 | targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets] 29 | 30 | outputs = model(samples) 31 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 32 | results = postprocessors['bbox'](outputs, orig_target_sizes, dataset='hico-det') 33 | 34 | preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results)))) 35 | # For avoiding a runtime error, the copy is used 36 | gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets))))) 37 | 38 | # gather the stats from all processes 39 | metric_logger.synchronize_between_processes() 40 | 41 | img_ids = [img_gts['id'] for img_gts in gts] 42 | _, indices = np.unique(img_ids, return_index=True) 43 | preds = {i: img_preds for i, img_preds in enumerate(preds) if i in indices} 44 | gts = {i: img_gts for i, img_gts in enumerate(gts) if i in indices} 45 | 46 | stats = do_detection_evaluation(preds, gts, hico_valid_obj_ids) 47 | return stats 48 | 49 | def 
do_detection_evaluation(predictions, groundtruths, hico_valid_obj_ids): 50 | # create a Coco-like object that we can use to evaluate detection! 51 | anns = [] 52 | for image_id, gt in groundtruths.items(): 53 | labels = gt['labels'].tolist() # map to coco like ids 54 | boxes = gt['boxes'].tolist() # xyxy 55 | for cls, box in zip(labels, boxes): 56 | anns.append({ 57 | 'area': (box[3] - box[1] + 1) * (box[2] - box[0] + 1), 58 | 'bbox': [box[0], box[1], box[2] - box[0] + 1, box[3] - box[1] + 1], # xywh 59 | 'category_id': cls, 60 | 'id': len(anns), 61 | 'image_id': image_id, 62 | 'iscrowd': 0, 63 | }) 64 | fauxcoco = COCO() 65 | fauxcoco.dataset = { 66 | 'info': {'description': 'use coco script for vg detection evaluation'}, 67 | 'images': [{'id': i} for i in range(len(groundtruths))], 68 | 'categories': [ 69 | {'supercategory': 'person', 'id': i, 'name': str(i)} for i in hico_valid_obj_ids.tolist() 70 | ], 71 | 'annotations': anns, 72 | } 73 | fauxcoco.createIndex() 74 | 75 | # format predictions to coco-like 76 | cocolike_predictions = [] 77 | for image_id, prediction in predictions.items(): 78 | box = torch.stack((prediction['boxes'][:,0], prediction['boxes'][:,1], prediction['boxes'][:,2]-prediction['boxes'][:,0]+1, prediction['boxes'][:,3]-prediction['boxes'][:,1]+1), dim=1).detach().cpu().numpy() # xywh 79 | label = prediction['labels'].cpu().numpy() # (#objs,) 80 | score = prediction['scores'].cpu().numpy() # (#objs,) 81 | 82 | image_id = np.asarray([image_id]*len(box)) 83 | cocolike_predictions.append( 84 | np.column_stack((image_id, box, score, label)) 85 | ) 86 | cocolike_predictions = np.concatenate(cocolike_predictions, 0) 87 | # evaluate via coco API 88 | res = fauxcoco.loadRes(cocolike_predictions) 89 | coco_eval = COCOeval(fauxcoco, res, 'bbox') 90 | coco_eval.params.imgIds = list(range(len(groundtruths))) 91 | coco_eval.evaluate() 92 | coco_eval.accumulate() 93 | coco_eval.summarize() 94 | mAp = coco_eval.stats[1] 95 | 96 | return mAp -------------------------------------------------------------------------------- /src/metrics/vcoco/ap_agent.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from src.metrics.utils import _compute_ap, compute_overlap 3 | import pdb 4 | 5 | class APAgent(object): 6 | def __init__(self, act_name, iou_threshold=0.5): 7 | self.act_name = act_name 8 | self.iou_threshold = iou_threshold 9 | 10 | self.fp = [np.zeros((0,))] * len(act_name) 11 | self.tp = [np.zeros((0,))] * len(act_name) 12 | self.score = [np.zeros((0,))] * len(act_name) 13 | self.num_ann = [0] * len(act_name) 14 | 15 | def add_data(self, box, act, cat, i_box, i_act): 16 | for label in range(len(self.act_name)): 17 | i_inds = (i_act[:, label] == 1) 18 | self.num_ann[label] += i_inds.sum() 19 | 20 | n_pred = box.shape[0] 21 | if n_pred == 0 : return 22 | 23 | ###################### 24 | valid_i_inds = (i_act[:, 0] != -1) # (n_i, ) # both in COCO & V-COCO 25 | 26 | overlaps = compute_overlap(box, i_box) # (n_pred, n_i) 27 | assigned_input = np.argmax(overlaps, axis=1) # (n_pred, ) 28 | v_inds = valid_i_inds[assigned_input] # (n_pred, ) 29 | 30 | n_valid = v_inds.sum() 31 | 32 | if n_valid == 0 : return 33 | valid_box = box[v_inds] 34 | valid_act = act[v_inds] 35 | valid_cat = cat[v_inds] 36 | 37 | ###################### 38 | s = valid_act * np.expand_dims(valid_cat, axis=1) # (n_v, #act) 39 | 40 | for label in range(len(self.act_name)): 41 | inds = np.argsort(s[:, label])[::-1] # (n_v, ) 42 | self.score[label] = 
np.append(self.score[label], s[inds, label]) 43 | 44 | correct_i_inds = (i_act[:, label] == 1) 45 | if correct_i_inds.sum() == 0: 46 | self.tp[label] = np.append(self.tp[label], np.array([0]*n_valid)) 47 | self.fp[label] = np.append(self.fp[label], np.array([1]*n_valid)) 48 | continue 49 | 50 | overlaps = compute_overlap(valid_box[inds], i_box) # (n_v, n_i) 51 | assigned_input = np.argmax(overlaps, axis=1) # (n_v, ) 52 | max_overlap = overlaps[range(n_valid), assigned_input] # (n_v, ) 53 | 54 | iou_inds = (max_overlap > self.iou_threshold) & correct_i_inds[assigned_input] # (n_v, ) 55 | 56 | i_nonzero = iou_inds.nonzero()[0] 57 | i_inds = assigned_input[i_nonzero] 58 | i_iou = np.unique(i_inds, return_index=True)[1] 59 | i_tp = i_nonzero[i_iou] 60 | 61 | t = np.zeros(n_valid, dtype=np.uint8) 62 | t[i_tp] = 1 63 | f = 1-t 64 | 65 | self.tp[label] = np.append(self.tp[label], t) 66 | self.fp[label] = np.append(self.fp[label], f) 67 | 68 | def evaluate(self): 69 | average_precisions = dict() 70 | for label in range(len(self.act_name)): 71 | if self.num_ann[label] == 0: 72 | average_precisions[label] = 0 73 | continue 74 | 75 | # sort by score 76 | indices = np.argsort(-self.score[label]) 77 | self.fp[label] = self.fp[label][indices] 78 | self.tp[label] = self.tp[label][indices] 79 | 80 | # compute false positives and true positives 81 | self.fp[label] = np.cumsum(self.fp[label]) 82 | self.tp[label] = np.cumsum(self.tp[label]) 83 | 84 | # compute recall and precision 85 | recall = self.tp[label] / self.num_ann[label] 86 | precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps) 87 | 88 | # compute average precision 89 | average_precisions[label] = _compute_ap(recall, precision) * 100 90 | 91 | print('\n================== AP (Agent) ===================') 92 | s, n = 0, 0 93 | 94 | for label in range(len(self.act_name)): 95 | label_name = "_".join(self.act_name[label].split("_")[1:]) 96 | print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label])) 97 | s += average_precisions[label] 98 | n += 1 99 | 100 | mAP = s/n 101 | print('| mAP(agent): {:0.2f}'.format(mAP)) 102 | print('----------------------------------------------------') 103 | 104 | return mAP -------------------------------------------------------------------------------- /src/models/detr_matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Modules to compute the matching cost and solve the corresponding LSAP. 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | 9 | from src.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 10 | 11 | 12 | class HungarianMatcher(nn.Module): 13 | """This class computes an assignment between the targets and the predictions of the network 14 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 15 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 16 | while the others are un-matched (and thus treated as non-objects). 
17 | """ 18 | 19 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 20 | """Creates the matcher 21 | Params: 22 | cost_class: This is the relative weight of the classification error in the matching cost 23 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 24 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 25 | """ 26 | super().__init__() 27 | self.cost_class = cost_class 28 | self.cost_bbox = cost_bbox 29 | self.cost_giou = cost_giou 30 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 31 | 32 | @torch.no_grad() 33 | def forward(self, outputs, targets): 34 | """ Performs the matching 35 | Params: 36 | outputs: This is a dict that contains at least these entries: 37 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 38 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 39 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 40 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 41 | objects in the target) containing the class labels 42 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 43 | Returns: 44 | A list of size batch_size, containing tuples of (index_i, index_j) where: 45 | - index_i is the indices of the selected predictions (in order) 46 | - index_j is the indices of the corresponding selected targets (in order) 47 | For each batch element, it holds: 48 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 49 | """ 50 | bs, num_queries = outputs["pred_logits"].shape[:2] 51 | 52 | # We flatten to compute the cost matrices in a batch 53 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 54 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 55 | 56 | # Also concat the target labels and boxes 57 | tgt_ids = torch.cat([v["labels"] for v in targets]) 58 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 59 | 60 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 61 | # but approximate it in 1 - proba[target class]. 62 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 
63 | cost_class = -out_prob[:, tgt_ids] 64 | 65 | # Compute the L1 cost between boxes 66 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 67 | 68 | # Compute the giou cost betwen boxes 69 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 70 | 71 | # Final cost matrix 72 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 73 | C = C.view(bs, num_queries, -1).cpu() 74 | 75 | sizes = [len(v["boxes"]) for v in targets] 76 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 77 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 78 | 79 | 80 | def build_matcher(args): 81 | return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) -------------------------------------------------------------------------------- /src/engine/evaluator_vcoco.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/engine/evaluator_vcoco.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | import os 9 | import torch 10 | import time 11 | import datetime 12 | 13 | import src.util.misc as utils 14 | import src.util.logger as loggers 15 | from src.data.evaluators.vcoco_eval import VCocoEvaluator 16 | from src.util.box_ops import rescale_bboxes, rescale_pairs 17 | from src.models.stip_utils import check_annotation, plot_cross_attention 18 | 19 | import wandb 20 | 21 | @torch.no_grad() 22 | def vcoco_evaluate(model, criterion, postprocessors, data_loader, device, output_dir, thr): 23 | model.eval() 24 | criterion.eval() 25 | 26 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ") 27 | header = 'Evaluation Inference (V-COCO)' 28 | 29 | print_freq = 1 # len(data_loader) 30 | res = {} 31 | hoi_recognition_time = [] 32 | 33 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 34 | samples = samples.to(device) 35 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets] 36 | 37 | # dec_selfattn_weights, dec_crossattn_weights = [], [] 38 | # hook = model.interaction_decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_crossattn_weights.append(output[1])) 39 | 40 | outputs = model(samples) 41 | loss_dict = criterion(outputs, targets) 42 | loss_dict_reduced = utils.reduce_dict(loss_dict) # ddp gathering 43 | 44 | loss_dict_reduced_scaled = {k: v * criterion.weight_dict[k] for k, v in loss_dict_reduced.items() if k in criterion.weight_dict} 45 | loss_value = sum(loss_dict_reduced_scaled.values()).item() 46 | 47 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 48 | results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='vcoco') 49 | targets = process_target(targets, orig_target_sizes) 50 | hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000) 51 | 52 | # check_annotation(samples, targets, mode='eval', rel_num=20, dataset='vcoco') 53 | # plot_cross_attention(samples, outputs, 
targets, dec_crossattn_weights, dataset='vcoco'); hook.remove() 54 | 55 | res.update( 56 | {target['image_id'].item():\ 57 | {'target': target, 'prediction': output} for target, output in zip(targets, results) 58 | } 59 | ) 60 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled) 61 | # if len(res) > 10: break 62 | 63 | print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms") 64 | metric_logger.synchronize_between_processes() 65 | print("Averaged validation stats:", metric_logger) 66 | 67 | start_time = time.time() 68 | gather_res = utils.all_gather(res) 69 | total_res = {} 70 | for dist_res in gather_res: 71 | total_res.update(dist_res) 72 | total_time = time.time() - start_time 73 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 74 | print(f"[stats] Distributed Gathering Time : {total_time_str}") 75 | 76 | return total_res 77 | 78 | def vcoco_accumulate(total_res, args, print_results, wandb_log): 79 | vcoco_evaluator = VCocoEvaluator(args) 80 | vcoco_evaluator.update(total_res) 81 | print(f"[stats] Score Matrix Generation completed!! ") 82 | 83 | scenario1 = vcoco_evaluator.role_eval1.evaluate(print_results) 84 | scenario2 = vcoco_evaluator.role_eval2.evaluate(print_results) 85 | 86 | if wandb_log: 87 | wandb.log({ 88 | 'scenario1': scenario1, 89 | 'scenario2': scenario2 90 | }) 91 | 92 | return scenario1, scenario2 93 | 94 | def process_target(targets, target_sizes): 95 | for idx, (target, target_size) in enumerate(zip(targets, target_sizes)): 96 | labels = target['labels'] 97 | valid_boxes_inds = (labels > 0) 98 | 99 | targets[idx]['boxes'] = rescale_bboxes(target['boxes'], target_size) # boxes 100 | targets[idx]['pair_boxes'] = rescale_pairs(target['pair_boxes'], target_size) # pairs 101 | 102 | return targets -------------------------------------------------------------------------------- /src/engine/evaluator_hico.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | from typing import Iterable 5 | import numpy as np 6 | import copy 7 | import itertools 8 | import matplotlib.pyplot as plt 9 | import torch 10 | 11 | import src.util.misc as utils 12 | import src.util.logger as loggers 13 | from src.data.evaluators.hico_eval import HICOEvaluator 14 | from src.models.stip_utils import check_annotation, plot_cross_attention, plot_hoi_results 15 | 16 | @torch.no_grad() 17 | def hico_evaluate(model, postprocessors, data_loader, device, thr, args): 18 | model.eval() 19 | 20 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ") 21 | header = 'Evaluation Inference (HICO-DET)' 22 | 23 | preds = [] 24 | gts = [] 25 | indices = [] 26 | hoi_recognition_time = [] 27 | 28 | for samples, targets in metric_logger.log_every(data_loader, 50, header): 29 | samples = samples.to(device) 30 | targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets] 31 | 32 | # # register hooks to obtain intermediate outputs 33 | # dec_selfattn_weights, dec_crossattn_weights = [], [] 34 | # if 'HOTR' in type(model).__name__: 35 | # hook_self = model.interaction_transformer.decoder.layers[-1].self_attn.register_forward_hook(lambda self, input, output: dec_selfattn_weights.append(output[1])) 36 | # hook_cross = model.interaction_transformer.decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_crossattn_weights.append(output[1])) 37 | # else: 38 | # hook_self = 
model.interaction_decoder.layers[-1].self_attn.register_forward_hook(lambda self, input, output: dec_selfattn_weights.append(output[1])) 39 | # hook_cross = model.interaction_decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_crossattn_weights.append(output[1])) 40 | 41 | outputs = model(samples) 42 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 43 | results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='hico-det') 44 | hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000) 45 | 46 | # # visualize 47 | # if targets[0]['id'] in [57]: # [47, 57, 81, 30, 46, 97]: # 30, 46, 97 48 | # # check_annotation(samples, targets, mode='eval', rel_num=20) 49 | # 50 | # # visualize cross-attentioa 51 | # if 'HOTR' in type(model).__name__: 52 | # outputs['pred_actions'] = outputs['pred_actions'][:, :, :args.num_actions] 53 | # outputs['pred_rel_pairs'] = [x.cpu() for x in torch.stack([outputs['pred_hidx'].argmax(-1), outputs['pred_oidx'].argmax(-1)], dim=-1)] 54 | # topk_qids, q_name_list = plot_hoi_results(samples, outputs, targets, args=args) 55 | # plot_cross_attention(samples, outputs, targets, dec_crossattn_weights, topk_qids=topk_qids) 56 | # print(f"image_id={targets[0]['id']}") 57 | # 58 | # # visualize self attention 59 | # print('visualize self-attention') 60 | # q_num = len(dec_selfattn_weights[0][0]) 61 | # plt.figure(figsize=(10,4)) 62 | # plt.imshow(dec_selfattn_weights[0][0].cpu().numpy(), vmin=0, vmax=0.4) 63 | # plt.xticks(np.arange(q_num), [f"{i}" for i in range(q_num)], rotation=90, fontsize=12) 64 | # plt.yticks(np.arange(q_num), [f"({q_name_list[i]})={i}" for i in range(q_num)], fontsize=12) 65 | # plt.gca().xaxis.set_ticks_position('top') 66 | # plt.grid(alpha=0.4, linestyle=':') 67 | # plt.show() 68 | # hook_self.remove(); hook_cross.remove() 69 | 70 | preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results)))) 71 | # For avoiding a runtime error, the copy is used 72 | gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets))))) 73 | 74 | 75 | print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms") 76 | 77 | # gather the stats from all processes 78 | metric_logger.synchronize_between_processes() 79 | 80 | img_ids = [img_gts['id'] for img_gts in gts] 81 | _, indices = np.unique(img_ids, return_index=True) 82 | preds = [img_preds for i, img_preds in enumerate(preds) if i in indices] 83 | gts = [img_gts for i, img_gts in enumerate(gts) if i in indices] 84 | 85 | evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets, 86 | data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat) 87 | 88 | stats = evaluator.evaluate() 89 | 90 | return stats 91 | 92 | -------------------------------------------------------------------------------- /src/models/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 
4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from src.util.misc import NestedTensor, is_main_process 15 | 16 | from .position_encoding import build_position_encoding 17 | 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 23 | without which any other models than torchvision.models.resnet[18,34,50,101] 24 | produce nans. 25 | """ 26 | 27 | def __init__(self, n): 28 | super(FrozenBatchNorm2d, self).__init__() 29 | self.register_buffer("weight", torch.ones(n)) 30 | self.register_buffer("bias", torch.zeros(n)) 31 | self.register_buffer("running_mean", torch.zeros(n)) 32 | self.register_buffer("running_var", torch.ones(n)) 33 | 34 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 35 | missing_keys, unexpected_keys, error_msgs): 36 | num_batches_tracked_key = prefix + 'num_batches_tracked' 37 | if num_batches_tracked_key in state_dict: 38 | del state_dict[num_batches_tracked_key] 39 | 40 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 41 | state_dict, prefix, local_metadata, strict, 42 | missing_keys, unexpected_keys, error_msgs) 43 | 44 | def forward(self, x): 45 | # move reshapes to the beginning 46 | # to make it fuser-friendly 47 | w = self.weight.reshape(1, -1, 1, 1) 48 | b = self.bias.reshape(1, -1, 1, 1) 49 | rv = self.running_var.reshape(1, -1, 1, 1) 50 | rm = self.running_mean.reshape(1, -1, 1, 1) 51 | eps = 1e-5 52 | scale = w * (rv + eps).rsqrt() 53 | bias = b - rm * scale 54 | return x * scale + bias 55 | 56 | 57 | class BackboneBase(nn.Module): 58 | 59 | def __init__(self, args, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 60 | super().__init__() 61 | for name, parameter in backbone.named_parameters(): 62 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 63 | parameter.requires_grad_(False) # fix other layers 64 | if return_interm_layers: 65 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 66 | else: 67 | if args.use_high_resolution_relation_feature_map: 68 | return_layers = {'layer3': "0", 'layer4': "1"} 69 | else: 70 | return_layers = {'layer4': "0"} 71 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 72 | self.num_channels = num_channels 73 | 74 | def forward(self, tensor_list: NestedTensor): 75 | xs = self.body(tensor_list.tensors) 76 | out: Dict[str, NestedTensor] = {} 77 | for name, x in xs.items(): 78 | m = tensor_list.mask 79 | assert m is not None 80 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 81 | out[name] = NestedTensor(x, mask) 82 | return out 83 | 84 | 85 | class Backbone(BackboneBase): 86 | """ResNet backbone with frozen BatchNorm.""" 87 | def __init__(self, args, name: str, 88 | train_backbone: bool, 89 | return_interm_layers: bool, 90 | dilation: bool): 91 | backbone = getattr(torchvision.models, name)( 92 | replace_stride_with_dilation=[False, False, dilation], 93 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 94 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 95 | super().__init__(args, backbone, train_backbone, num_channels, 
return_interm_layers) 96 | 97 | 98 | class Joiner(nn.Sequential): 99 | def __init__(self, backbone, position_embedding): 100 | super().__init__(backbone, position_embedding) 101 | 102 | def forward(self, tensor_list: NestedTensor): 103 | xs = self[0](tensor_list) 104 | out: List[NestedTensor] = [] 105 | pos = [] 106 | for name, x in xs.items(): 107 | out.append(x) 108 | # position encoding 109 | pos.append(self[1](x).to(x.tensors.dtype)) 110 | 111 | return out, pos 112 | 113 | 114 | def build_backbone(args): 115 | position_embedding = build_position_encoding(args) 116 | train_backbone = args.lr_backbone > 0 117 | return_interm_layers = False # args.masks 118 | backbone = Backbone(args, args.backbone, train_backbone, return_interm_layers, args.dilation) 119 | model = Joiner(backbone, position_embedding) 120 | model.num_channels = backbone.num_channels 121 | return model 122 | -------------------------------------------------------------------------------- /src/data/datasets/coco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | COCO dataset which returns image_id for evaluation. 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py 5 | """ 6 | from pathlib import Path 7 | 8 | import torch 9 | import torch.utils.data 10 | import torchvision 11 | from pycocotools import mask as coco_mask 12 | 13 | import src.data.transforms.transforms as T 14 | 15 | class CocoDetection(torchvision.datasets.CocoDetection): 16 | def __init__(self, img_folder, ann_file, transforms, return_masks): 17 | super(CocoDetection, self).__init__(img_folder, ann_file) 18 | self._transforms = transforms 19 | self.prepare = ConvertCocoPolysToMask(return_masks) 20 | 21 | def __getitem__(self, idx): 22 | img, target = super(CocoDetection, self).__getitem__(idx) 23 | image_id = self.ids[idx] 24 | target = {'image_id': image_id, 'annotations': target} 25 | img, target = self.prepare(img, target) 26 | if self._transforms is not None: 27 | img, target = self._transforms(img, target) 28 | return img, target 29 | 30 | 31 | def convert_coco_poly_to_mask(segmentations, height, width): 32 | masks = [] 33 | for polygons in segmentations: 34 | rles = coco_mask.frPyObjects(polygons, height, width) 35 | mask = coco_mask.decode(rles) 36 | if len(mask.shape) < 3: 37 | mask = mask[..., None] 38 | mask = torch.as_tensor(mask, dtype=torch.uint8) 39 | mask = mask.any(dim=2) 40 | masks.append(mask) 41 | if masks: 42 | masks = torch.stack(masks, dim=0) 43 | else: 44 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 45 | return masks 46 | 47 | 48 | class ConvertCocoPolysToMask(object): 49 | def __init__(self, return_masks=False): 50 | self.return_masks = return_masks 51 | 52 | def __call__(self, image, target): 53 | w, h = image.size 54 | 55 | image_id = target["image_id"] 56 | image_id = torch.tensor([image_id]) 57 | 58 | anno = target["annotations"] 59 | 60 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0] 61 | 62 | boxes = [obj["bbox"] for obj in anno] 63 | # guard against no boxes via resizing 64 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 65 | boxes[:, 2:] += boxes[:, :2] # (x1, y1, w, h) -> (x1, y1, x2, y2) 66 | boxes[:, 0::2].clamp_(min=0, max=w) 67 | boxes[:, 1::2].clamp_(min=0, max=h) 68 | 69 | classes = [obj["category_id"] for obj in anno] 70 | classes = torch.tensor(classes, dtype=torch.int64) 71 | 
72 | if self.return_masks: 73 | segmentations = [obj["segmentation"] for obj in anno] 74 | masks = convert_coco_poly_to_mask(segmentations, h, w) 75 | 76 | keypoints = None 77 | if anno and "keypoints" in anno[0]: 78 | keypoints = [obj["keypoints"] for obj in anno] 79 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32) 80 | num_keypoints = keypoints.shape[0] 81 | if num_keypoints: 82 | keypoints = keypoints.view(num_keypoints, -1, 3) 83 | 84 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 85 | boxes = boxes[keep] 86 | classes = classes[keep] 87 | if self.return_masks: 88 | masks = masks[keep] 89 | if keypoints is not None: 90 | keypoints = keypoints[keep] 91 | 92 | target = {} 93 | target["boxes"] = boxes 94 | target["labels"] = classes 95 | if self.return_masks: 96 | target["masks"] = masks 97 | target["image_id"] = image_id 98 | if keypoints is not None: 99 | target["keypoints"] = keypoints 100 | 101 | # for conversion to coco api 102 | area = torch.tensor([obj["area"] for obj in anno]) 103 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno]) 104 | target["area"] = area[keep] 105 | target["iscrowd"] = iscrowd[keep] 106 | 107 | target["orig_size"] = torch.as_tensor([int(h), int(w)]) 108 | target["size"] = torch.as_tensor([int(h), int(w)]) 109 | 110 | return image, target 111 | 112 | 113 | def make_coco_transforms(image_set): 114 | 115 | normalize = T.Compose([ 116 | T.ToTensor(), 117 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 118 | ]) 119 | 120 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 121 | 122 | if image_set == 'train': 123 | return T.Compose([ 124 | T.RandomHorizontalFlip(), 125 | T.RandomSelect( 126 | T.RandomResize(scales, max_size=1333), 127 | T.Compose([ 128 | T.RandomResize([400, 500, 600]), 129 | T.RandomSizeCrop(384, 600), 130 | T.RandomResize(scales, max_size=1333), 131 | ]) 132 | ), 133 | normalize, 134 | ]) 135 | 136 | if image_set == 'val': 137 | return T.Compose([ 138 | T.RandomResize([800], max_size=1333), 139 | normalize, 140 | ]) 141 | 142 | raise ValueError(f'unknown {image_set}') 143 | 144 | 145 | def build(image_set, args): 146 | root = Path(args.data_path) 147 | assert root.exists(), f'provided COCO path {root} does not exist' 148 | mode = 'instances' 149 | PATHS = { 150 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'), 151 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'), 152 | } 153 | 154 | img_folder, ann_file = PATHS[image_set] 155 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks) 156 | return dataset -------------------------------------------------------------------------------- /src/util/logger.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/util/logger.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | import torch 9 | import time 10 | import datetime 11 | import sys 12 | from time import sleep 13 | from collections import defaultdict 14 | 15 | from src.util.misc import SmoothedValue 16 | 17 | def print_params(model): 18 | n_parameters = sum(p.numel() for p in model.parameters()) 19 | print('\n[Logger] Number of total params: ', n_parameters) 20 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 21 | print('\n[Logger] Number of trainable params: ', n_parameters) 22 | return n_parameters 23 | 24 | def print_args(args): 25 | print('\n[Logger] DETR Arguments:') 26 | for k, v in vars(args).items(): 27 | if k in [ 28 | 'lr', 'lr_backbone', 'lr_drop', 29 | 'frozen_weights', 30 | 'backbone', 'dilation', 31 | 'position_embedding', 'enc_layers', 'dec_layers', 'num_queries', 32 | 'dataset_file']: 33 | print(f'\t{k}: {v}') 34 | 35 | if args.HOIDet: 36 | print('\n[Logger] DETR_HOI Arguments:') 37 | for k, v in vars(args).items(): 38 | if k in [ 39 | 'freeze_enc', 40 | 'query_flag', 41 | 'hoi_nheads', 42 | 'hoi_dim_feedforward', 43 | 'hoi_dec_layers', 44 | 'hoi_idx_loss_coef', 45 | 'hoi_act_loss_coef', 46 | 'hoi_eos_coef', 47 | 'object_threshold']: 48 | print(f'\t{k}: {v}') 49 | 50 | class MetricLogger(object): 51 | def __init__(self, mode="test", delimiter="\t"): 52 | self.meters = defaultdict(SmoothedValue) 53 | self.delimiter = delimiter 54 | self.mode = mode 55 | 56 | def update(self, **kwargs): 57 | for k, v in kwargs.items(): 58 | if isinstance(v, torch.Tensor): 59 | v = v.item() 60 | assert isinstance(v, (float, int)) 61 | self.meters[k].update(v) 62 | 63 | def __getattr__(self, attr): 64 | if attr in self.meters: 65 | return self.meters[attr] 66 | if attr in self.__dict__: 67 | return self.__dict__[attr] 68 | raise AttributeError("'{}' object has no attribute '{}'".format( 69 | type(self).__name__, attr)) 70 | 71 | def __str__(self): 72 | loss_str = [] 73 | for name, meter in self.meters.items(): 74 | loss_str.append( 75 | "{}: {}".format(name, str(meter)) 76 | ) 77 | return self.delimiter.join(loss_str) 78 | 79 | def synchronize_between_processes(self): 80 | for meter in self.meters.values(): 81 | meter.synchronize_between_processes() 82 | 83 | def add_meter(self, name, meter): 84 | self.meters[name] = meter 85 | 86 | def log_every(self, iterable, print_freq, header=None): 87 | i = 0 88 | if not header: 89 | header = '' 90 | start_time = time.time() 91 | end = time.time() 92 | iter_time = SmoothedValue(fmt='{avg:.4f}') 93 | data_time = SmoothedValue(fmt='{avg:.4f}') 94 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 95 | if torch.cuda.is_available(): 96 | log_msg = self.delimiter.join([ 97 | header, 98 | '[{0' + space_fmt + '}/{1}]', 99 | 'eta: {eta}', 100 | '{meters}', 101 | 'time: {time}', 102 | 'data: {data}', 103 | 'max mem: {memory:.0f}' 104 | ]) 105 | else: 106 | log_msg = self.delimiter.join([ 107 | header, 108 | '[{0' + space_fmt + '}/{1}]', 109 | 'eta: {eta}', 110 | '{meters}', 111 | 'time: {time}', 112 | 'data: {data}' 113 | ]) 114 | MB = 1024.0 * 1024.0 115 | for obj in iterable: 116 | data_time.update(time.time() - end) 117 | yield obj 118 | iter_time.update(time.time() - end) 119 | 120 | if (i % print_freq == 0 and i !=0) or i == len(iterable) - 1: 121 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 122 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 123 | if torch.cuda.is_available(): 124 | print(log_msg.format( 
125 | i+1, len(iterable), eta=eta_string, 126 | meters=str(self), 127 | time=str(iter_time), data=str(data_time), 128 | memory=torch.cuda.max_memory_allocated() / MB), 129 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n")) 130 | else: 131 | print(log_msg.format( 132 | i+1, len(iterable), eta=eta_string, 133 | meters=str(self), 134 | time=str(iter_time), data=str(data_time)), 135 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n")) 136 | else: 137 | log_interval = self.delimiter.join([header, '[{0' + space_fmt + '}/{1}]']) 138 | if torch.cuda.is_available(): print(log_interval.format(i+1, len(iterable)), flush=True, end="\r") 139 | else: print(log_interval.format(i+1, len(iterable)), flush=True, end="\r") 140 | 141 | i += 1 142 | end = time.time() 143 | total_time = time.time() - start_time 144 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 145 | if self.mode=='test': print("") 146 | print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format( 147 | self.mode, total_time_str, total_time / len(iterable))) 148 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## [CVPR 2022] Exploring Structure-aware Transformer over Interaction Proposals for Human-object Interaction Detection 2 | 3 | * [PDF](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhang_Exploring_Structure-Aware_Transformer_Over_Interaction_Proposals_for_Human-Object_Interaction_Detection_CVPR_2022_paper.pdf) 4 | * [Supplementary](https://openaccess.thecvf.com/content/CVPR2022/supplemental/Zhang_Exploring_Structure-Aware_Transformer_CVPR_2022_supplemental.pdf) 5 | 6 | ## Paper introduction 7 | 8 | Recent high-performing Human-Object Interaction (HOI) detection techniques have been highly influenced by Transformer-based object detector (i.e., DETR). Nevertheless, most of them directly map parametric interaction queries into a set of HOI predictions through vanilla Transformer in a one-stage manner. This leaves rich inter- or intra-interaction structure under-exploited. In this work, we design a novel Transformer-style HOI detector, i.e., Structure-aware Transformer over Interaction Proposals (STIP), for HOI detection. Such design decomposes the process of HOI set prediction into two subsequent phases, i.e., an interaction proposal generation is first performed, and then followed by transforming the non-parametric interaction proposals into HOI predictions via a structure-aware Transformer. The structure-aware Transformer upgrades vanilla Transformer by encoding additionally the holistically semantic structure among interaction proposals as well as the locally spatial structure of human/object within each interaction proposal, so as to strengthen HOI predictions. Extensive experiments conducted on V-COCO and HICO-DET benchmarks have demonstrated the effectiveness of STIP, and superior results are reported when comparing with the state-of-the-art HOI detectors. 9 | 10 |
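To make the two-phase design described above more concrete, here is a minimal, self-contained sketch of the idea (illustrative only — the module names, feature dimensions, number of actions, and the simple exhaustive human-object pairing are assumptions, not the actual STIP code): interaction proposals are first formed by pairing detected humans with detected objects, and a Transformer over the proposal set then lets each proposal attend to the others before action classification.

```python
# Illustrative sketch of the two-phase idea (NOT the actual STIP implementation).
import torch
import torch.nn as nn

hidden_dim = 256      # assumed feature size
num_actions = 29      # assumed number of action classes

# Phase 1 (assumed): build non-parametric interaction proposals by pairing every
# detected human with every detected object; random tensors stand in for detector features.
human_feats = torch.randn(3, hidden_dim)    # 3 detected humans
object_feats = torch.randn(5, hidden_dim)   # 5 detected objects
pair_feats = torch.cat([
    human_feats.unsqueeze(1).expand(-1, object_feats.size(0), -1),
    object_feats.unsqueeze(0).expand(human_feats.size(0), -1, -1),
], dim=-1).reshape(-1, 2 * hidden_dim)      # 15 proposals, one per human-object pair

proposal_proj = nn.Linear(2 * hidden_dim, hidden_dim)

# Phase 2 (assumed): a Transformer encoder over the proposal set, so each proposal
# attends to the others; the real model additionally injects semantic and spatial
# structure into this attention.
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8), num_layers=2)
action_head = nn.Linear(hidden_dim, num_actions)

tokens = proposal_proj(pair_feats).unsqueeze(1)   # (num_proposals, batch=1, hidden_dim)
action_logits = action_head(encoder(tokens))      # (15, 1, num_actions)
print(action_logits.shape)
```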

11 | 12 | ## Installation 13 | 14 | ### 1. Environmental Setup 15 | ```bash 16 | $ conda create -n STIP python=3.7 17 | $ conda install -c pytorch pytorch torchvision # PyTorch 1.7.1, torchvision 0.8.2, CUDA=11.0 18 | $ conda install cython scipy 19 | $ pip install pycocotools 20 | $ pip install opencv-python 21 | $ pip install wandb 22 | ``` 23 | 24 | ### 2. HOI dataset setup 25 | Our current version supports experiments on the 26 | - [V-COCO](https://github.com/s-gupta/v-coco) and [HICO-DET](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) datasets. Download the datasets under the cloned repository. 27 | - For HICO-DET, ~~we use the [annotation files](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) provided by the PPDM authors. Download the [list of actions](https://drive.google.com/open?id=1EeHNHuYyJI-qqDk_-5nay7Mb07tzZLsl) as `list_action.txt` and place it under the extracted hico-det directory.~~ If these download links are broken, please download the data from [hico_20160224_det](https://cuhko365-my.sharepoint.com/:u:/g/personal/218019030_link_cuhk_edu_cn/EY3AbysMMUBJmQEt8zl4j9UBxHH6no2-g4o63EIt38GkgQ?e=oCEIsG). 28 | 29 | Below we show how the files should be placed. 30 | ```bash 31 | # V-COCO setup 32 | $ git clone https://github.com/s-gupta/v-coco.git 33 | $ cd v-coco 34 | $ ln -s [:COCO_DIR] coco/images # COCO_DIR contains images of train2014 & val2014 35 | $ python script_pick_annotations.py [:COCO_DIR]/annotations 36 | 37 | # HICO-DET setup 38 | $ tar -zxvf hico_20160224_det.tar.gz # move the extracted folder under the cloned repository 39 | 40 | # dataset setup 41 | STIP 42 | │─ v-coco 43 | │ │─ data 44 | │ │ │─ instances_vcoco_all_2014.json 45 | │ │ : 46 | │ └─ coco 47 | │ │─ images 48 | │ │ │─ train2014 49 | │ │ │ │─ COCO_train2014_000000000009.jpg 50 | │ │ │ : 51 | │ │ └─ val2014 52 | │ │ │─ COCO_val2014_000000000042.jpg 53 | : : : 54 | │─ hico_20160224_det 55 | │ │─ list_action.txt 56 | │ │─ annotations 57 | │ │ │─ trainval_hico.json 58 | │ │ │─ test_hico.json 59 | │ │ └─ corre_hico.npy 60 | : : 61 | ``` 62 | 63 | If you wish to keep the datasets in your own directory, simply set the `--data_path` argument to the directory where you downloaded the datasets. 64 | ```bash 65 | --data_path [:your_own_directory]/[v-coco/hico_20160224_det] 66 | ``` 67 | 68 | ### 3. Training/Testing on V-COCO 69 | 70 | ```shell 71 | python STIP_main.py --validate \ 72 | --num_hoi_queries 32 --batch_size 4 --lr 5e-5 --HOIDet --hoi_aux_loss --no_aux_loss \ 73 | --dataset_file vcoco --data_path v-coco --detr_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \ 74 | --output_dir checkpoints/vcoco --group_name STIP_debug --run_name vcoco_run1 75 | ``` 76 | 77 | * Add the `--eval` option for evaluation 78 | 79 | ### 4. Training/Testing on HICO-DET 80 | 81 | Training with a DETR detector pretrained on COCO.
82 | ```shell 83 | python STIP_main.py --validate \ 84 | --num_hoi_queries 32 --batch_size 4 --lr 5e-5 --HOIDet --hoi_aux_loss --no_aux_loss \ 85 | --dataset_file hico-det --data_path hico_20160224_det --detr_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \ 86 | --output_dir checkpoints/hico-det --group_name STIP_debug --run_name hicodet_run1 87 | ``` 88 | 89 | Jointly fine-tune object detector & HOI detector 90 | ```shell 91 | python STIP_main.py --validate \ 92 | --num_hoi_queries 32 --batch_size 2 --lr 1e-5 --HOIDet --hoi_aux_loss \ 93 | --dataset_file hico-det --data_path hico_20160224_det \ 94 | --output_dir checkpoints/hico-det --group_name STIP_debug --run_name hicodet_run1/jointly-tune \ 95 | --resume checkpoints/hico-det/STIP_debug/best.pth --train_detr 96 | ``` 97 | 98 | Pre-trained models: 99 | * [VCOCO](https://cuhko365-my.sharepoint.com/:u:/g/personal/218019030_link_cuhk_edu_cn/ETPIGJHooGVPuOXDBC3WTaoBQBGHXioj4hAbjllSTSOn6A?e=KBN0bq) 100 | * [HICO-DET (w/o jointly finetune)](https://cuhko365-my.sharepoint.com/:u:/g/personal/218019030_link_cuhk_edu_cn/EZyYzmW7vL9Bh7zKwT1JGFEBKbTkbs4vD9QW7kfs8EV_gQ?e=NjKBjX) 101 | 102 | ## License 103 | 104 | This repo is released under the [Apache License, Version 2.0](LICENSE). 105 | 106 | 107 | ## Acknowledgement 108 | 109 | This repo is based on [DETR](https://github.com/facebookresearch/detr), [HOTR](https://github.com/kakaobrain/HOTR). Thanks for their wonderful works. 110 | 111 | 112 | 113 | ## Citation 114 | 115 | If you find this code helpful for your research, please cite our paper. 116 | ``` 117 | @InProceedings{Zhang_2022_CVPR, 118 | author = {Zhang, Yong and Pan, Yingwei and Yao, Ting and Huang, Rui and Mei, Tao and Chen, Chang-Wen}, 119 | title = {Exploring Structure-Aware Transformer Over Interaction Proposals for Human-Object Interaction Detection}, 120 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 121 | month = {June}, 122 | year = {2022}, 123 | pages = {19548-19557} 124 | } 125 | ``` 126 | -------------------------------------------------------------------------------- /src/data/datasets/builtin_meta.py: -------------------------------------------------------------------------------- 1 | COCO_CATEGORIES = [ 2 | {"color": [], "isthing": 0, "id": 0, "name": "N/A"}, 3 | {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, 4 | {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, 5 | {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, 6 | {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, 7 | {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, 8 | {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, 9 | {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, 10 | {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, 11 | {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, 12 | {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, 13 | {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, 14 | {"color": [], "isthing": 0, "id": 12, "name": "N/A"}, 15 | {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, 16 | {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, 17 | {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, 18 | {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, 19 | {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": 
"cat"}, 20 | {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, 21 | {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, 22 | {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, 23 | {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, 24 | {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, 25 | {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, 26 | {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, 27 | {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, 28 | {"color": [], "isthing": 0, "id": 26, "name": "N/A"}, 29 | {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, 30 | {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, 31 | {"color": [], "isthing": 0, "id": 29, "name": "N/A"}, 32 | {"color": [], "isthing": 0, "id": 30, "name": "N/A"}, 33 | {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, 34 | {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, 35 | {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, 36 | {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, 37 | {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, 38 | {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, 39 | {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, 40 | {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, 41 | {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, 42 | {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, 43 | {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, 44 | {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, 45 | {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, 46 | {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, 47 | {"color": [], "isthing": 0, "id": 45, "name": "N/A"}, 48 | {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, 49 | {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, 50 | {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, 51 | {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, 52 | {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, 53 | {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, 54 | {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, 55 | {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, 56 | {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, 57 | {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, 58 | {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, 59 | {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, 60 | {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, 61 | {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, 62 | {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, 63 | {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, 64 | {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, 65 | {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, 66 | {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, 67 | {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, 68 | {"color": [], "isthing": 0, "id": 66, "name": 
"N/A"}, 69 | {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, 70 | {"color": [], "isthing": 0, "id": 68, "name": "N/A"}, 71 | {"color": [], "isthing": 0, "id": 69, "name": "N/A"}, 72 | {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, 73 | {"color": [], "isthing": 0, "id": 71, "name": "N/A"}, 74 | {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, 75 | {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, 76 | {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, 77 | {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, 78 | {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, 79 | {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, 80 | {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, 81 | {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, 82 | {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, 83 | {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, 84 | {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, 85 | {"color": [], "isthing": 0, "id": 83, "name": "N/A"}, 86 | {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, 87 | {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, 88 | {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, 89 | {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, 90 | {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, 91 | {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, 92 | {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, 93 | ] 94 | 95 | def _get_coco_instances_meta(): 96 | thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1] 97 | assert len(thing_ids) == 80, f"Length of thing ids : {len(thing_ids)}" 98 | 99 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 100 | thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1] 101 | thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1] 102 | 103 | coco_classes = [k["name"] for k in COCO_CATEGORIES] 104 | 105 | return { 106 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 107 | "thing_classes": thing_classes, 108 | "thing_colors": thing_colors, 109 | "coco_classes": coco_classes, 110 | } 111 | -------------------------------------------------------------------------------- /src/models/hotr.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/models/src.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. 
All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import copy 9 | import time 10 | import datetime 11 | 12 | from src.util.misc import NestedTensor, nested_tensor_from_tensor_list 13 | from .feed_forward import MLP 14 | 15 | class HOTR(nn.Module): 16 | def __init__(self, detr, 17 | num_hoi_queries, 18 | num_actions, 19 | interaction_transformer, 20 | freeze_detr, 21 | share_enc, 22 | pretrained_dec, 23 | temperature, 24 | hoi_aux_loss, 25 | return_obj_class=None): 26 | super().__init__() 27 | 28 | # * Instance Transformer --------------- 29 | self.detr = detr 30 | if freeze_detr: 31 | # if this flag is given, freeze the object detection related parameters of DETR 32 | for p in self.parameters(): 33 | p.requires_grad_(False) 34 | hidden_dim = detr.transformer.d_model 35 | # -------------------------------------- 36 | 37 | # * Interaction Transformer ----------------------------------------- 38 | self.num_queries = num_hoi_queries 39 | self.query_embed = nn.Embedding(self.num_queries, hidden_dim) 40 | self.H_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3) 41 | self.O_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3) 42 | self.action_embed = nn.Linear(hidden_dim, num_actions+1) 43 | # -------------------------------------------------------------------- 44 | 45 | # * HICO-DET FFN heads --------------------------------------------- 46 | self.return_obj_class = (return_obj_class is not None) 47 | if return_obj_class: self._valid_obj_ids = return_obj_class + [return_obj_class[-1]+1] 48 | # ------------------------------------------------------------------ 49 | 50 | # * Transformer Options --------------------------------------------- 51 | self.interaction_transformer = interaction_transformer 52 | 53 | if share_enc: # share encoder 54 | self.interaction_transformer.encoder = detr.transformer.encoder 55 | 56 | if pretrained_dec: # free variables for interaction decoder 57 | self.interaction_transformer.decoder = copy.deepcopy(detr.transformer.decoder) 58 | for p in self.interaction_transformer.decoder.parameters(): 59 | p.requires_grad_(True) 60 | # --------------------------------------------------------------------- 61 | 62 | # * Loss Options ------------------- 63 | self.tau = temperature 64 | self.hoi_aux_loss = hoi_aux_loss 65 | # ---------------------------------- 66 | 67 | def forward(self, samples: NestedTensor, targets=None): 68 | if isinstance(samples, (list, torch.Tensor)): 69 | samples = nested_tensor_from_tensor_list(samples) 70 | 71 | # >>>>>>>>>>>> BACKBONE LAYERS <<<<<<<<<<<<<<< 72 | features, pos = self.detr.backbone(samples) 73 | bs = features[-1].tensors.shape[0] 74 | src, mask = features[-1].decompose() 75 | assert mask is not None 76 | # ---------------------------------------------- 77 | 78 | # >>>>>>>>>>>> OBJECT DETECTION LAYERS <<<<<<<<<< 79 | start_time = time.time() 80 | hs, _ = self.detr.transformer(self.detr.input_proj(src), mask, self.detr.query_embed.weight, pos[-1]) 81 | inst_repr = F.normalize(hs[-1], p=2, dim=2) # instance representations 82 | 83 | # Prediction Heads for Object Detection 84 | outputs_class = self.detr.class_embed(hs) 85 | outputs_coord = self.detr.bbox_embed(hs).sigmoid() 86 | object_detection_time = time.time() - start_time 87 | # ----------------------------------------------- 88 | 89 | # >>>>>>>>>>>> HOI DETECTION LAYERS <<<<<<<<<<<<<<< 90 | start_time = time.time() 91 | assert hasattr(self, 
'interaction_transformer'), "Missing Interaction Transformer." 92 | interaction_hs = self.interaction_transformer(self.detr.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # interaction representations 93 | 94 | # [HO Pointers] 95 | H_Pointer_reprs = F.normalize(self.H_Pointer_embed(interaction_hs), p=2, dim=-1) 96 | O_Pointer_reprs = F.normalize(self.O_Pointer_embed(interaction_hs), p=2, dim=-1) 97 | outputs_hidx = [(torch.bmm(H_Pointer_repr, inst_repr.transpose(1,2))) / self.tau for H_Pointer_repr in H_Pointer_reprs] 98 | outputs_oidx = [(torch.bmm(O_Pointer_repr, inst_repr.transpose(1,2))) / self.tau for O_Pointer_repr in O_Pointer_reprs] 99 | 100 | # [Action Classification] 101 | outputs_action = self.action_embed(interaction_hs) 102 | # -------------------------------------------------- 103 | hoi_detection_time = time.time() - start_time 104 | hoi_recognition_time = max(hoi_detection_time - object_detection_time, 0) 105 | # ------------------------------------------------------------------- 106 | 107 | # [Target Classification] 108 | if self.return_obj_class: 109 | detr_logits = outputs_class[-1, ..., self._valid_obj_ids] 110 | o_indices = [output_oidx.max(-1)[-1] for output_oidx in outputs_oidx] 111 | obj_logit_stack = [torch.stack([detr_logits[batch_, o_idx, :] for batch_, o_idx in enumerate(o_indice)], 0) for o_indice in o_indices] 112 | outputs_obj_class = obj_logit_stack 113 | 114 | out = { 115 | "pred_logits": outputs_class[-1], 116 | "pred_boxes": outputs_coord[-1], 117 | "pred_hidx": outputs_hidx[-1], 118 | "pred_oidx": outputs_oidx[-1], 119 | "pred_actions": outputs_action[-1], 120 | "hoi_recognition_time": hoi_recognition_time, 121 | } 122 | 123 | if self.return_obj_class: out["pred_obj_logits"] = outputs_obj_class[-1] 124 | 125 | if self.hoi_aux_loss: # auxiliary loss 126 | out['hoi_aux_outputs'] = \ 127 | self._set_aux_loss_with_tgt(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_obj_class) \ 128 | if self.return_obj_class else \ 129 | self._set_aux_loss(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action) 130 | 131 | return out 132 | 133 | @torch.jit.unused 134 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action): 135 | return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e} 136 | for a, b, c, d, e in zip( 137 | outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)), 138 | outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)), 139 | outputs_hidx[:-1], 140 | outputs_oidx[:-1], 141 | outputs_action[:-1])] 142 | 143 | @torch.jit.unused 144 | def _set_aux_loss_with_tgt(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_tgt): 145 | return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e, 'pred_obj_logits': f} 146 | for a, b, c, d, e, f in zip( 147 | outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)), 148 | outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)), 149 | outputs_hidx[:-1], 150 | outputs_oidx[:-1], 151 | outputs_action[:-1], 152 | outputs_tgt[:-1])] -------------------------------------------------------------------------------- /src/models/post_process.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/models/post_process.py 3 | # Copyright (c) Kakao Brain, Inc. 
and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | import time 6 | import copy 7 | import torch 8 | import torch.nn.functional as F 9 | from torch import nn 10 | from src.util import box_ops 11 | 12 | class PostProcess(nn.Module): 13 | """ This module converts the model's output into the format expected by the coco api""" 14 | def __init__(self, HOIDet): 15 | super().__init__() 16 | self.HOIDet = HOIDet 17 | 18 | @torch.no_grad() 19 | def forward(self, outputs, target_sizes, threshold=0, dataset='coco'): 20 | """ Perform the computation 21 | Parameters: 22 | outputs: raw outputs of the model 23 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 24 | For evaluation, this must be the original image size (before any data augmentation) 25 | For visualization, this should be the image size after data augment, but before padding 26 | """ 27 | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] 28 | 29 | assert len(out_logits) == len(target_sizes) 30 | assert target_sizes.shape[1] == 2 31 | 32 | prob = F.softmax(out_logits, -1) 33 | scores, labels = prob[..., :-1].max(-1) 34 | 35 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) 36 | img_h, img_w = target_sizes.unbind(1) 37 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 38 | boxes = boxes * scale_fct[:, None, :] 39 | 40 | # Preidction Branch for HOI detection 41 | if self.HOIDet: 42 | if dataset == 'vcoco': 43 | """ Compute HOI triplet prediction score for V-COCO. 44 | Our scoring function follows the implementation details of UnionDet. 45 | """ 46 | out_time = outputs['hoi_recognition_time'] 47 | 48 | start_time = time.time() 49 | pair_actions = torch.sigmoid(outputs['pred_actions']) 50 | h_prob = F.softmax(outputs['pred_hidx'], -1) 51 | h_idx_score, h_indices = h_prob.max(-1) 52 | 53 | o_prob = F.softmax(outputs['pred_oidx'], -1) 54 | o_idx_score, o_indices = o_prob.max(-1) 55 | hoi_recognition_time = (time.time() - start_time) + out_time 56 | 57 | results = [] 58 | # iterate for batch size 59 | for batch_idx, (s, l, b) in enumerate(zip(scores, labels, boxes)): 60 | h_inds = (l == 1) & (s > threshold) 61 | o_inds = (s > threshold) 62 | 63 | h_box, h_cat = b[h_inds], s[h_inds] 64 | o_box, o_cat = b[o_inds], s[o_inds] 65 | 66 | # for scenario 1 in v-coco dataset 67 | o_inds = torch.cat((o_inds, torch.ones(1).type(torch.bool).to(o_inds.device))) # add an empty box 68 | o_box = torch.cat((o_box, torch.Tensor([0, 0, 0, 0]).unsqueeze(0).to(o_box.device))) 69 | 70 | result_dict = { 71 | 'h_box': h_box, 'h_cat': h_cat, 72 | 'o_box': o_box, 'o_cat': o_cat, 73 | 'scores': s, 'labels': l, 'boxes': b 74 | } 75 | 76 | h_inds_lst = (h_inds == True).nonzero(as_tuple=False).squeeze(-1) 77 | o_inds_lst = (o_inds == True).nonzero(as_tuple=False).squeeze(-1) 78 | 79 | K = boxes.shape[1] 80 | n_act = pair_actions[batch_idx][:, :-1].shape[-1] 81 | score = torch.zeros((n_act, K, K+1)).to(pair_actions[batch_idx].device) 82 | sorted_score = torch.zeros((n_act, K, K+1)).to(pair_actions[batch_idx].device) 83 | id_score = torch.zeros((K, K+1)).to(pair_actions[batch_idx].device) 84 | 85 | # Score function: strange HOTR scoring 86 | for hs, h_idx, os, o_idx, pair_action in zip(h_idx_score[batch_idx], h_indices[batch_idx], o_idx_score[batch_idx], o_indices[batch_idx], pair_actions[batch_idx]): 87 | matching_score = (1-pair_action[-1]) # no interaction score 88 | if h_idx == o_idx: o_idx = -1 # special case, head=tail 89 | if 
matching_score > id_score[h_idx, o_idx]: 90 | id_score[h_idx, o_idx] = matching_score 91 | sorted_score[:, h_idx, o_idx] = matching_score * pair_action[:-1] 92 | score[:, h_idx, o_idx] += matching_score * pair_action[:-1] 93 | 94 | score += sorted_score 95 | score = score[:, h_inds, :] 96 | score = score[:, :, o_inds] 97 | 98 | result_dict.update({ 99 | 'pair_score': score, 100 | 'hoi_recognition_time': hoi_recognition_time, 101 | }) 102 | 103 | results.append(result_dict) 104 | 105 | elif dataset == 'hico-det': 106 | """ Compute HOI triplet prediction score for HICO-DET. 107 | For HICO-DET, we follow the same scoring function but do not accumulate the results. 108 | """ 109 | out_time = outputs['hoi_recognition_time'] 110 | 111 | start_time = time.time() 112 | out_obj_logits, out_verb_logits = outputs['pred_obj_logits'], outputs['pred_actions'] 113 | out_verb_logits = outputs['pred_actions'] 114 | 115 | # actions 116 | matching_scores = (1-out_verb_logits.sigmoid()[..., -1:]) #* (1-out_verb_logits.sigmoid()[..., 57:58]) 117 | verb_scores = out_verb_logits.sigmoid()[..., :-1] * matching_scores 118 | 119 | # hbox, obox 120 | outputs_hrepr, outputs_orepr = outputs['pred_hidx'], outputs['pred_oidx'] 121 | obj_scores, obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1) 122 | 123 | h_prob = F.softmax(outputs_hrepr, -1) 124 | h_idx_score, h_indices = h_prob.max(-1) 125 | 126 | # targets 127 | o_prob = F.softmax(outputs_orepr, -1) 128 | o_idx_score, o_indices = o_prob.max(-1) 129 | hoi_recognition_time = (time.time() - start_time) + out_time 130 | 131 | # hidx, oidx 132 | sub_boxes, obj_boxes = [], [] 133 | for batch_id, (box, h_idx, o_idx) in enumerate(zip(boxes, h_indices, o_indices)): 134 | sub_boxes.append(box[h_idx, :]) 135 | obj_boxes.append(box[o_idx, :]) 136 | sub_boxes = torch.stack(sub_boxes, dim=0) 137 | obj_boxes = torch.stack(obj_boxes, dim=0) 138 | 139 | # accumulate results (iterate through interaction queries) 140 | results = [] 141 | for os, ol, vs, ms, sb, ob in zip(obj_scores, obj_labels, verb_scores, matching_scores, sub_boxes, obj_boxes): 142 | sl = torch.full_like(ol, 0) # self.subject_category_id = 0 in HICO-DET 143 | l = torch.cat((sl, ol)) 144 | b = torch.cat((sb, ob)) 145 | results.append({'labels': l.to('cpu'), 'boxes': b.to('cpu')}) 146 | vs = vs * os.unsqueeze(1) 147 | ids = torch.arange(b.shape[0]) 148 | res_dict = { 149 | 'verb_scores': vs.to('cpu'), 150 | 'sub_ids': ids[:ids.shape[0] // 2], 151 | 'obj_ids': ids[ids.shape[0] // 2:], 152 | 'hoi_recognition_time': hoi_recognition_time 153 | } 154 | results[-1].update(res_dict) 155 | else: 156 | results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)] 157 | 158 | return results -------------------------------------------------------------------------------- /src/models/detr.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/models/detr.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | """ 9 | DETR & HOTR model and criterion classes. 
10 | """ 11 | import torch 12 | import torch.nn.functional as F 13 | from torch import nn 14 | 15 | from src.util.misc import (NestedTensor, nested_tensor_from_tensor_list) 16 | 17 | from .backbone import build_backbone 18 | from .detr_matcher import build_matcher 19 | from .hotr_matcher import build_hoi_matcher 20 | from .transformer import build_transformer, build_hoi_transformer 21 | from .criterion import SetCriterion 22 | from .post_process import PostProcess 23 | from .feed_forward import MLP 24 | 25 | from .hotr import HOTR 26 | from .stip import STIP, STIPPostProcess, STIPCriterion 27 | 28 | class DETR(nn.Module): 29 | """ This is the DETR module that performs object detection """ 30 | def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False): 31 | """ Initializes the model. 32 | Parameters: 33 | backbone: torch module of the backbone to be used. See backbone.py 34 | transformer: torch module of the transformer architecture. See transformer.py 35 | num_classes: number of object classes 36 | num_queries: number of object queries, ie detection slot. This is the maximal number of objects 37 | DETR can detect in a single image. For COCO, we recommend 100 queries. 38 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 39 | """ 40 | super().__init__() 41 | self.num_queries = num_queries 42 | self.transformer = transformer 43 | hidden_dim = transformer.d_model 44 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1) 45 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 46 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 47 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) 48 | self.backbone = backbone 49 | self.aux_loss = aux_loss 50 | 51 | def forward(self, samples: NestedTensor, targets=None): 52 | """ The forward expects a NestedTensor, which consists of: 53 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W] 54 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels 55 | It returns a dict with the following elements: 56 | - "pred_logits": the classification logits (including no-object) for all queries. 57 | Shape= [batch_size x num_queries x (num_classes + 1)] 58 | - "pred_boxes": The normalized boxes coordinates for all queries, represented as 59 | (center_x, center_y, height, width). These values are normalized in [0, 1], 60 | relative to the size of each individual image (disregarding possible padding). 61 | See PostProcess for information on how to retrieve the unnormalized bounding box. 62 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of 63 | dictionnaries containing the two above keys for each decoder layer. 
64 | """ 65 | if isinstance(samples, (list, torch.Tensor)): 66 | samples = nested_tensor_from_tensor_list(samples) 67 | features, pos = self.backbone(samples) 68 | 69 | src, mask = features[-1].decompose() 70 | assert mask is not None 71 | hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] 72 | 73 | outputs_class = self.class_embed(hs) 74 | outputs_coord = self.bbox_embed(hs).sigmoid() 75 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} 76 | if self.aux_loss: 77 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) 78 | 79 | return out 80 | 81 | @torch.jit.unused 82 | def _set_aux_loss(self, outputs_class, outputs_coord): 83 | # this is a workaround to make torchscript happy, as torchscript 84 | # doesn't support dictionary with non-homogeneous values, such 85 | # as a dict having both a Tensor and a list. 86 | return [{'pred_logits': a, 'pred_boxes': b} 87 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] 88 | 89 | 90 | def build(args): 91 | device = torch.device(args.device) 92 | 93 | backbone = build_backbone(args) 94 | 95 | transformer = build_transformer(args) 96 | 97 | model = DETR( 98 | backbone, 99 | transformer, 100 | num_classes=args.num_classes, 101 | num_queries=args.num_queries, 102 | aux_loss=args.aux_loss, 103 | ) 104 | 105 | matcher = build_matcher(args) 106 | weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef} 107 | weight_dict['loss_giou'] = args.giou_loss_coef 108 | 109 | # TODO this is a hack 110 | if args.aux_loss: 111 | aux_weight_dict = {} 112 | for i in range(args.dec_layers - 1): 113 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 114 | weight_dict.update(aux_weight_dict) 115 | 116 | losses = ['labels', 'boxes', 'cardinality'] if args.frozen_weights is None else [] 117 | if args.HOIDet and args.STIP_relation_head: 118 | model = STIP(args, detr=model, detr_matcher=matcher) 119 | criterion = STIPCriterion(args, matcher) 120 | postprocessors = {'hoi': STIPPostProcess(args, model)} 121 | elif args.HOIDet: 122 | hoi_matcher = build_hoi_matcher(args) 123 | hoi_losses = [] 124 | hoi_losses.append('pair_labels') 125 | hoi_losses.append('pair_actions') 126 | if args.dataset_file == 'hico-det': hoi_losses.append('pair_targets') 127 | 128 | hoi_weight_dict={} 129 | hoi_weight_dict['loss_hidx'] = args.hoi_idx_loss_coef 130 | hoi_weight_dict['loss_oidx'] = args.hoi_idx_loss_coef 131 | hoi_weight_dict['loss_act'] = args.hoi_act_loss_coef 132 | if args.dataset_file == 'hico-det': hoi_weight_dict['loss_tgt'] = args.hoi_tgt_loss_coef 133 | if args.hoi_aux_loss: 134 | hoi_aux_weight_dict = {} 135 | for i in range(args.hoi_dec_layers): 136 | hoi_aux_weight_dict.update({k + f'_{i}': v for k, v in hoi_weight_dict.items()}) 137 | hoi_weight_dict.update(hoi_aux_weight_dict) 138 | 139 | criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=hoi_weight_dict, 140 | eos_coef=args.eos_coef, losses=losses, num_actions=args.num_actions, 141 | HOI_losses=hoi_losses, HOI_matcher=hoi_matcher, args=args) 142 | 143 | interaction_transformer = build_hoi_transformer(args) # if (args.share_enc and args.pretrained_dec) else None 144 | 145 | kwargs = {} 146 | if args.dataset_file == 'hico-det': kwargs['return_obj_class'] = args.valid_obj_ids 147 | model = HOTR( 148 | detr=model, 149 | num_hoi_queries=args.num_hoi_queries, 150 | num_actions=args.num_actions, 151 | interaction_transformer=interaction_transformer, 152 | freeze_detr=(args.frozen_weights is not 
None), 153 | share_enc=args.share_enc, 154 | pretrained_dec=args.pretrained_dec, 155 | temperature=args.temperature, 156 | hoi_aux_loss=args.hoi_aux_loss, 157 | **kwargs # only return verb class for HICO-DET dataset 158 | ) 159 | postprocessors = {'hoi': PostProcess(args.HOIDet)} 160 | else: 161 | criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict, 162 | eos_coef=args.eos_coef, losses=losses, num_actions=args.num_actions, args=args) 163 | postprocessors = {'bbox': PostProcess(args.HOIDet)} 164 | criterion.to(device) 165 | 166 | return model, criterion, postprocessors -------------------------------------------------------------------------------- /src/engine/arg_parser.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : engine/arg_parser.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # Modified arguments are represented with * 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | import argparse 10 | 11 | def get_args_parser(): 12 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False) 13 | parser.add_argument('--lr', default=1e-4, type=float) 14 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 15 | parser.add_argument('--batch_size', default=2, type=int) 16 | parser.add_argument('--weight_decay', default=1e-4, type=float) 17 | parser.add_argument('--epochs', default=100, type=int) 18 | parser.add_argument('--lr_drop', default=80, type=int) 19 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 20 | help='gradient clipping max norm') 21 | 22 | # DETR Model parameters 23 | parser.add_argument('--frozen_weights', type=str, default=None, 24 | help="Path to the pretrained model. 
If set, only the mask head will be trained") 25 | # DETR Backbone 26 | parser.add_argument('--backbone', default='resnet50', type=str, 27 | help="Name of the convolutional backbone to use") 28 | parser.add_argument('--dilation', action='store_true', 29 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 30 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 31 | help="Type of positional embedding to use on top of the image features") 32 | 33 | # DETR Transformer (= Encoder, Instance Decoder) 34 | parser.add_argument('--enc_layers', default=6, type=int, 35 | help="Number of encoding layers in the transformer") 36 | parser.add_argument('--dec_layers', default=6, type=int, 37 | help="Number of decoding layers in the transformer") 38 | parser.add_argument('--dim_feedforward', default=2048, type=int, 39 | help="Intermediate size of the feedforward layers in the transformer blocks") 40 | parser.add_argument('--hidden_dim', default=256, type=int, 41 | help="Size of the embeddings (dimension of the transformer)") 42 | parser.add_argument('--dropout', default=0.1, type=float, 43 | help="Dropout applied in the transformer") 44 | parser.add_argument('--nheads', default=8, type=int, 45 | help="Number of attention heads inside the transformer's attentions") 46 | parser.add_argument('--num_queries', default=100, type=int, 47 | help="Number of query slots") 48 | parser.add_argument('--pre_norm', action='store_true') 49 | 50 | # Segmentation 51 | parser.add_argument('--masks', action='store_true', 52 | help="Train segmentation head if the flag is provided") 53 | 54 | # Loss Option 55 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 56 | help="Disables auxiliary decoding losses (loss at each layer)") 57 | 58 | # Loss coefficients (DETR) 59 | parser.add_argument('--mask_loss_coef', default=1, type=float) 60 | parser.add_argument('--dice_loss_coef', default=1, type=float) 61 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 62 | parser.add_argument('--giou_loss_coef', default=2, type=float) 63 | parser.add_argument('--eos_coef', default=0.1, type=float, 64 | help="Relative classification weight of the no-object class") 65 | 66 | # Matcher (DETR) 67 | parser.add_argument('--set_cost_class', default=1, type=float, 68 | help="Class coefficient in the matching cost") 69 | parser.add_argument('--set_cost_bbox', default=5, type=float, 70 | help="L1 box coefficient in the matching cost") 71 | parser.add_argument('--set_cost_giou', default=2, type=float, 72 | help="giou box coefficient in the matching cost") 73 | 74 | # * HOI Detection 75 | parser.add_argument('--HOIDet', action='store_true', 76 | help="Train HOI Detection head if the flag is provided") 77 | parser.add_argument('--share_enc', action='store_true', 78 | help="Share the Encoder in DETR for HOI Detection if the flag is provided") 79 | parser.add_argument('--pretrained_dec', action='store_true', 80 | help="Use Pre-trained Decoder in DETR for Interaction Decoder if the flag is provided") 81 | parser.add_argument('--hoi_enc_layers', default=1, type=int, 82 | help="Number of decoding layers in HOI transformer") 83 | parser.add_argument('--hoi_dec_layers', default=6, type=int, 84 | help="Number of decoding layers in HOI transformer") 85 | parser.add_argument('--hoi_nheads', default=8, type=int, 86 | help="Number of decoding layers in HOI transformer") 87 | parser.add_argument('--hoi_dim_feedforward', default=2048, type=int, 88 | 
help="Number of decoding layers in HOI transformer") 89 | # parser.add_argument('--hoi_mode', type=str, default=None, help='[inst | pair | all]') 90 | parser.add_argument('--num_hoi_queries', default=32, type=int, 91 | help="Number of Queries for Interaction Decoder") 92 | parser.add_argument('--hoi_aux_loss', action='store_true') 93 | 94 | 95 | # * HOTR Matcher 96 | parser.add_argument('--set_cost_idx', default=1, type=float, 97 | help="IDX coefficient in the matching cost") 98 | parser.add_argument('--set_cost_act', default=1, type=float, 99 | help="Action coefficient in the matching cost") 100 | parser.add_argument('--set_cost_tgt', default=1, type=float, 101 | help="Target coefficient in the matching cost") 102 | 103 | # * HOTR Loss coefficients 104 | parser.add_argument('--temperature', default=0.05, type=float, help="temperature") 105 | parser.add_argument('--hoi_idx_loss_coef', default=1, type=float) 106 | parser.add_argument('--hoi_act_loss_coef', default=1, type=float) 107 | parser.add_argument('--hoi_tgt_loss_coef', default=1, type=float) 108 | parser.add_argument('--hoi_eos_coef', default=0.1, type=float, help="Relative classification weight of the no-object class") 109 | 110 | # * dataset parameters 111 | parser.add_argument('--dataset_file', help='[coco | vcoco]') 112 | parser.add_argument('--data_path', type=str) 113 | parser.add_argument('--object_threshold', type=float, default=0, help='Threshold for object confidence') 114 | 115 | # machine parameters 116 | parser.add_argument('--output_dir', default='', 117 | help='path where to save, empty for no saving') 118 | parser.add_argument('--custom_path', default='', 119 | help="Data path for custom inference. Only required for custom_main.py") 120 | parser.add_argument('--device', default='cuda', 121 | help='device to use for training / testing') 122 | parser.add_argument('--seed', default=42, type=int) 123 | parser.add_argument('--resume', default='', help='resume from checkpoint') 124 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 125 | help='start epoch') 126 | parser.add_argument('--num_workers', default=4, type=int) 127 | 128 | # mode 129 | parser.add_argument('--eval', action='store_true', help="Only evaluate results if the flag is provided") 130 | parser.add_argument('--validate', action='store_true', help="Validate after every epoch") 131 | 132 | # distributed training parameters 133 | parser.add_argument('--world_size', default=1, type=int, 134 | help='number of distributed processes') 135 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 136 | 137 | # * WanDB 138 | parser.add_argument('--wandb', action='store_true') 139 | parser.add_argument('--project_name', default='HOTR') 140 | parser.add_argument('--group_name', default='KakaoBrain') 141 | parser.add_argument('--run_name', default='run_000001') 142 | 143 | # STIP 144 | parser.add_argument('--STIP_relation_head', action='store_true', default=False) 145 | parser.add_argument('--finetune_detr', action='store_true', default=False) 146 | parser.add_argument('--use_high_resolution_relation_feature_map', action='store_true', default=False) 147 | return parser 148 | -------------------------------------------------------------------------------- /src/metrics/vcoco/ap_role.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from src.metrics.utils import _compute_ap, compute_overlap 4 | 5 | class APRole(object): 6 | def 
__init__(self, act_name, scenario_flag=True, iou_threshold=0.5): 7 | self.act_name = act_name 8 | self.iou_threshold = iou_threshold 9 | 10 | self.scenario_flag = scenario_flag 11 | # scenario_1 : True 12 | # scenario_2 : False 13 | 14 | self.fp = [np.zeros((0,))] * len(act_name) 15 | self.tp = [np.zeros((0,))] * len(act_name) 16 | self.score = [np.zeros((0,))] * len(act_name) 17 | self.num_ann = [0] * len(act_name) 18 | 19 | def add_data(self, h_box, o_box, score, i_box, i_act, p_box, p_act): 20 | # i_box, i_act : to check if only in COCO 21 | for label in range(len(self.act_name)): 22 | p_inds = (p_act[:, label] == 1) 23 | self.num_ann[label] += p_inds.sum() 24 | 25 | if h_box.shape[0] == 0 : return # if no prediction, just return 26 | # COCO (O), V-COCO (X) __or__ collater, no ann in image => ignore 27 | 28 | valid_i_inds = (i_act[:, 0] != -1) # (n_i, ) 29 | overlaps = compute_overlap(h_box, i_box) # (n_h, n_i) 30 | assigned_input = np.argmax(overlaps, axis=1) # (n_h, ) 31 | v_inds = valid_i_inds[assigned_input] # (n_h, ) 32 | 33 | h_box = h_box[v_inds] 34 | score = score[:, v_inds, :] 35 | if h_box.shape[0] == 0 : return 36 | n_h = h_box.shape[0] 37 | 38 | valid_p_inds = (p_act[:, 0] != -1) | (p_box[:, 0] != -1) 39 | p_act = p_act[valid_p_inds] 40 | p_box = p_box[valid_p_inds] 41 | 42 | n_o = o_box.shape[0] 43 | if n_o == 0: 44 | # no prediction for object 45 | score = score.squeeze(axis=2) # (#act, n_h) 46 | 47 | for label in range(len(self.act_name)): 48 | h_inds = np.argsort(score[label])[::-1] # (n_h, ) 49 | self.score[label] = np.append(self.score[label], score[label, h_inds]) 50 | 51 | p_inds = (p_act[:, label] == 1) 52 | if p_inds.sum() == 0: 53 | self.tp[label] = np.append(self.tp[label], np.array([0]*n_h)) 54 | self.fp[label] = np.append(self.fp[label], np.array([1]*n_h)) 55 | continue 56 | 57 | h_overlaps = compute_overlap(h_box[h_inds], p_box[p_inds, :4]) # (n_h, n_p) 58 | assigned_p = np.argmax(h_overlaps, axis=1) # (n_h, ) 59 | h_max_overlap = h_overlaps[range(n_h), assigned_p] # (n_h, ) 60 | 61 | o_overlaps = compute_overlap(np.zeros((n_h, 4)), p_box[p_inds][assigned_p, 4:8]) 62 | o_overlaps = np.diag(o_overlaps) # (n_h, ) 63 | 64 | no_role_inds = (p_box[p_inds][assigned_p, 4] == -1) # (n_h, ) 65 | # human (o), action (o), no object in actual image 66 | 67 | h_iou_inds = (h_max_overlap > self.iou_threshold) # (n_h, ) 68 | o_iou_inds = (o_overlaps > self.iou_threshold) # (n_h, ) 69 | 70 | # scenario1 is not considered (already no object) 71 | o_iou_inds[no_role_inds] = 1 72 | 73 | iou_inds = (h_iou_inds & o_iou_inds) 74 | p_nonzero = iou_inds.nonzero()[0] 75 | p_inds = assigned_p[p_nonzero] 76 | p_iou = np.unique(p_inds, return_index=True)[1] 77 | p_tp = p_nonzero[p_iou] 78 | 79 | t = np.zeros(n_h, dtype=np.uint8) 80 | t[p_tp] = 1 81 | f = 1-t 82 | 83 | self.tp[label] = np.append(self.tp[label], t) 84 | self.fp[label] = np.append(self.fp[label], f) 85 | 86 | else: 87 | s_obj_argmax = np.argmax(score.reshape(-1, n_o), axis=1).reshape(-1, n_h) # (#act, n_h) 88 | s_obj_max = np.max(score.reshape(-1, n_o), axis=1).reshape(-1, n_h) # (#act, n_h) 89 | 90 | h_overlaps = compute_overlap(h_box, p_box[:, :4]) # (n_h, n_p) 91 | for label in range(len(self.act_name)): 92 | h_inds = np.argsort(s_obj_max[label])[::-1] # (n_h, ) 93 | self.score[label] = np.append(self.score[label], s_obj_max[label, h_inds]) 94 | 95 | p_inds = (p_act[:, label] == 1) # (n_p, ) 96 | if p_inds.sum() == 0: ## no such relation, all considered as FP 97 | self.tp[label] = np.append(self.tp[label], 
np.array([0]*n_h)) 98 | self.fp[label] = np.append(self.fp[label], np.array([1]*n_h)) 99 | continue 100 | 101 | h_overlaps = compute_overlap(h_box[h_inds], p_box[:, :4]) # (n_h, n_p) # match for all hboxes 102 | h_max_overlap = np.max(h_overlaps, axis=1) # (n_h, ) # get the max overlap for hbox 103 | 104 | # for same human, multiple pairs exist. find the human box that has the same idx with max overlap hbox. 105 | h_max_temp = np.expand_dims(h_max_overlap, axis=1) 106 | h_over_thresh = (h_overlaps == h_max_temp) # (n_h, n_p) 107 | h_over_thresh = h_over_thresh & np.expand_dims(p_inds, axis=0) # (n_h, n_p) 108 | 109 | h_valid = h_over_thresh.sum(axis=1)>0 # (n_h, ) # at least one is True 110 | # h_valid -> if all is False, then argmax becomes 0. <- prevent 111 | assigned_p = np.argmax(h_over_thresh, axis=1) # (n_h, ) # p only for current act 112 | 113 | o_mapping_box = o_box[s_obj_argmax[label]][h_inds] # (n_h, ) # find where T is. 114 | p_mapping_box = p_box[assigned_p, 4:8] # (n_h, 4) 115 | 116 | o_overlaps = compute_overlap(o_mapping_box, p_mapping_box) # --: matching object 117 | o_overlaps = np.diag(o_overlaps) # (n_h, ) 118 | o_overlaps.setflags(write=1) 119 | if (~h_valid).sum() > 0: 120 | o_overlaps[~h_valid] = 0 # (n_h, ) 121 | 122 | no_role_inds = (p_box[assigned_p, 4] == -1) # (n_h, ) -- GT relation obj box is empty 123 | nan_box_inds = np.all(o_mapping_box == 0, axis=1) | np.all(np.isnan(o_mapping_box), axis=1) ## predicted relation obj box is 0 or nan 124 | no_role_inds = no_role_inds & h_valid 125 | nan_box_inds = nan_box_inds & h_valid 126 | 127 | h_iou_inds = (h_max_overlap > self.iou_threshold) # (n_h, ) 128 | o_iou_inds = (o_overlaps > self.iou_threshold) # (n_h, ) 129 | 130 | if self.scenario_flag: ## scenario_1: for empty GT relation obj box, the predicted obj box should also be correct 131 | o_iou_inds[no_role_inds & nan_box_inds] = 1 132 | o_iou_inds[no_role_inds & ~nan_box_inds] = 0 133 | else: ## scenario_2: for empty GT relation obj box, ignore prediction for obj box 134 | o_iou_inds[no_role_inds] = 1 135 | 136 | iou_inds = (h_iou_inds & o_iou_inds) 137 | p_nonzero = iou_inds.nonzero()[0] 138 | p_inds = assigned_p[p_nonzero] 139 | p_iou = np.unique(p_inds, return_index=True)[1] ## remove duplicate, one GT match only once 140 | p_tp = p_nonzero[p_iou] 141 | 142 | t = np.zeros(n_h, dtype=np.uint8) 143 | t[p_tp] = 1 144 | f = 1-t 145 | 146 | self.tp[label] = np.append(self.tp[label], t) 147 | self.fp[label] = np.append(self.fp[label], f) 148 | # print('add data finished') 149 | 150 | def evaluate(self, print_log=False): 151 | average_precisions = dict() 152 | role_num = 1 if self.scenario_flag else 2 153 | for label in range(len(self.act_name)): 154 | 155 | # sort by score 156 | indices = np.argsort(-self.score[label]) 157 | self.fp[label] = self.fp[label][indices] ## compute AP over all predicted relations across the dataset by ranking 158 | self.tp[label] = self.tp[label][indices] 159 | 160 | 161 | if self.num_ann[label] == 0: 162 | average_precisions[label] = 0 163 | continue 164 | 165 | # compute false positives and true positives 166 | self.fp[label] = np.cumsum(self.fp[label]) 167 | self.tp[label] = np.cumsum(self.tp[label]) 168 | 169 | # compute recall and precision 170 | recall = self.tp[label] / self.num_ann[label] 171 | precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps) 172 | 173 | # compute average precision 174 | average_precisions[label] = _compute_ap(recall, precision) * 100 175 | 176 | if print_log: 
print(f'\n============= AP (Role scenario_{role_num}) ==============') 177 | s, n = 0, 0 178 | 179 | for label in range(len(self.act_name)): 180 | if 'point' in self.act_name[label]: 181 | continue 182 | label_name = "_".join(self.act_name[label].split("_")[1:]) 183 | if print_log: print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label])) 184 | if self.num_ann[label] != 0 : 185 | s += average_precisions[label] 186 | n += 1 187 | 188 | mAP = s/n 189 | if print_log: 190 | print('| mAP(role scenario_{:d}): {:0.2f}'.format(role_num, mAP)) 191 | print('----------------------------------------------------') 192 | 193 | return mAP -------------------------------------------------------------------------------- /src/data/evaluators/coco_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | COCO evaluator that works in distributed mode. 4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py 5 | The difference is that there is less copy-pasting from pycocotools 6 | in the end of the file, as python3 can suppress prints with contextlib 7 | """ 8 | import os 9 | import contextlib 10 | import copy 11 | import numpy as np 12 | import torch 13 | 14 | from pycocotools.cocoeval import COCOeval 15 | from pycocotools.coco import COCO 16 | import pycocotools.mask as mask_util 17 | 18 | from src.util.misc import all_gather 19 | 20 | 21 | class CocoEvaluator(object): 22 | def __init__(self, coco_gt, iou_types): 23 | assert isinstance(iou_types, (list, tuple)) 24 | coco_gt = copy.deepcopy(coco_gt) 25 | self.coco_gt = coco_gt 26 | 27 | self.iou_types = iou_types 28 | self.coco_eval = {} 29 | for iou_type in iou_types: 30 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) 31 | 32 | self.img_ids = [] 33 | self.eval_imgs = {k: [] for k in iou_types} 34 | 35 | def update(self, predictions): 36 | img_ids = list(np.unique(list(predictions.keys()))) 37 | self.img_ids.extend(img_ids) 38 | 39 | for iou_type in self.iou_types: 40 | results = self.prepare(predictions, iou_type) 41 | 42 | # suppress pycocotools prints 43 | with open(os.devnull, 'w') as devnull: 44 | with contextlib.redirect_stdout(devnull): 45 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO() 46 | coco_eval = self.coco_eval[iou_type] 47 | 48 | coco_eval.cocoDt = coco_dt 49 | coco_eval.params.imgIds = list(img_ids) 50 | img_ids, eval_imgs = evaluate(coco_eval) 51 | 52 | self.eval_imgs[iou_type].append(eval_imgs) 53 | 54 | def synchronize_between_processes(self): 55 | for iou_type in self.iou_types: 56 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) 57 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) 58 | 59 | def accumulate(self): 60 | for coco_eval in self.coco_eval.values(): 61 | coco_eval.accumulate() 62 | 63 | def summarize(self): 64 | for iou_type, coco_eval in self.coco_eval.items(): 65 | print("IoU metric: {}".format(iou_type)) 66 | coco_eval.summarize() 67 | 68 | def prepare(self, predictions, iou_type): 69 | if iou_type == "bbox": 70 | return self.prepare_for_coco_detection(predictions) 71 | elif iou_type == "segm": 72 | return self.prepare_for_coco_segmentation(predictions) 73 | elif iou_type == "keypoints": 74 | return self.prepare_for_coco_keypoint(predictions) 75 | else: 76 | raise ValueError("Unknown iou type 
{}".format(iou_type)) 77 | 78 | def prepare_for_coco_detection(self, predictions): 79 | coco_results = [] 80 | for original_id, prediction in predictions.items(): 81 | if len(prediction) == 0: 82 | continue 83 | 84 | boxes = prediction["boxes"] 85 | boxes = convert_to_xywh(boxes).tolist() 86 | scores = prediction["scores"].tolist() 87 | labels = prediction["labels"].tolist() 88 | 89 | coco_results.extend( 90 | [ 91 | { 92 | "image_id": original_id, 93 | "category_id": labels[k], 94 | "bbox": box, 95 | "score": scores[k], 96 | } 97 | for k, box in enumerate(boxes) 98 | ] 99 | ) 100 | return coco_results 101 | 102 | def prepare_for_coco_segmentation(self, predictions): 103 | coco_results = [] 104 | for original_id, prediction in predictions.items(): 105 | if len(prediction) == 0: 106 | continue 107 | 108 | scores = prediction["scores"] 109 | labels = prediction["labels"] 110 | masks = prediction["masks"] 111 | 112 | masks = masks > 0.5 113 | 114 | scores = prediction["scores"].tolist() 115 | labels = prediction["labels"].tolist() 116 | 117 | rles = [ 118 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] 119 | for mask in masks 120 | ] 121 | for rle in rles: 122 | rle["counts"] = rle["counts"].decode("utf-8") 123 | 124 | coco_results.extend( 125 | [ 126 | { 127 | "image_id": original_id, 128 | "category_id": labels[k], 129 | "segmentation": rle, 130 | "score": scores[k], 131 | } 132 | for k, rle in enumerate(rles) 133 | ] 134 | ) 135 | return coco_results 136 | 137 | def prepare_for_coco_keypoint(self, predictions): 138 | coco_results = [] 139 | for original_id, prediction in predictions.items(): 140 | if len(prediction) == 0: 141 | continue 142 | 143 | boxes = prediction["boxes"] 144 | boxes = convert_to_xywh(boxes).tolist() 145 | scores = prediction["scores"].tolist() 146 | labels = prediction["labels"].tolist() 147 | keypoints = prediction["keypoints"] 148 | keypoints = keypoints.flatten(start_dim=1).tolist() 149 | 150 | coco_results.extend( 151 | [ 152 | { 153 | "image_id": original_id, 154 | "category_id": labels[k], 155 | 'keypoints': keypoint, 156 | "score": scores[k], 157 | } 158 | for k, keypoint in enumerate(keypoints) 159 | ] 160 | ) 161 | return coco_results 162 | 163 | 164 | def convert_to_xywh(boxes): 165 | xmin, ymin, xmax, ymax = boxes.unbind(1) 166 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) 167 | 168 | 169 | def merge(img_ids, eval_imgs): 170 | all_img_ids = all_gather(img_ids) 171 | all_eval_imgs = all_gather(eval_imgs) 172 | 173 | merged_img_ids = [] 174 | for p in all_img_ids: 175 | merged_img_ids.extend(p) 176 | 177 | merged_eval_imgs = [] 178 | for p in all_eval_imgs: 179 | merged_eval_imgs.append(p) 180 | 181 | merged_img_ids = np.array(merged_img_ids) 182 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) 183 | 184 | # keep only unique (and in sorted order) images 185 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) 186 | merged_eval_imgs = merged_eval_imgs[..., idx] 187 | 188 | return merged_img_ids, merged_eval_imgs 189 | 190 | 191 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs): 192 | img_ids, eval_imgs = merge(img_ids, eval_imgs) 193 | img_ids = list(img_ids) 194 | eval_imgs = list(eval_imgs.flatten()) 195 | 196 | coco_eval.evalImgs = eval_imgs 197 | coco_eval.params.imgIds = img_ids 198 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params) 199 | 200 | 201 | ################################################################# 202 | # From pycocotools, just removed 
the prints and fixed 203 | # a Python3 bug about unicode not defined 204 | ################################################################# 205 | 206 | 207 | def evaluate(self): 208 | ''' 209 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs 210 | :return: None 211 | ''' 212 | # tic = time.time() 213 | # print('Running per image evaluation...') 214 | p = self.params 215 | # add backward compatibility if useSegm is specified in params 216 | if p.useSegm is not None: 217 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox' 218 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType)) 219 | # print('Evaluate annotation type *{}*'.format(p.iouType)) 220 | p.imgIds = list(np.unique(p.imgIds)) 221 | if p.useCats: 222 | p.catIds = list(np.unique(p.catIds)) 223 | p.maxDets = sorted(p.maxDets) 224 | self.params = p 225 | 226 | self._prepare() 227 | # loop through images, area range, max detection number 228 | catIds = p.catIds if p.useCats else [-1] 229 | 230 | if p.iouType == 'segm' or p.iouType == 'bbox': 231 | computeIoU = self.computeIoU 232 | elif p.iouType == 'keypoints': 233 | computeIoU = self.computeOks 234 | self.ious = { 235 | (imgId, catId): computeIoU(imgId, catId) 236 | for imgId in p.imgIds 237 | for catId in catIds} 238 | 239 | evaluateImg = self.evaluateImg 240 | maxDet = p.maxDets[-1] 241 | evalImgs = [ 242 | evaluateImg(imgId, catId, areaRng, maxDet) 243 | for catId in catIds 244 | for areaRng in p.areaRng 245 | for imgId in p.imgIds 246 | ] 247 | # this is NOT in the pycocotools code, but could be done outside 248 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) 249 | self._paramsEval = copy.deepcopy(self.params) 250 | # toc = time.time() 251 | # print('DONE (t={:0.2f}s).'.format(toc-tic)) 252 | return p.imgIds, evalImgs 253 | 254 | ################################################################# 255 | # end of straight copy from pycocotools, just removing the prints 256 | ################################################################# 257 | -------------------------------------------------------------------------------- /src/models/hotr_matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/models/hotr_matcher.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. 
All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | 9 | from src.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 10 | 11 | import src.util.misc as utils 12 | import wandb 13 | 14 | class HungarianPairMatcher(nn.Module): 15 | def __init__(self, args): 16 | """Creates the matcher 17 | Params: 18 | cost_action: This is the relative weight of the multi-label action classification error in the matching cost 19 | cost_hbox: This is the relative weight of the classification error for human idx in the matching cost 20 | cost_obox: This is the relative weight of the classification error for object idx in the matching cost 21 | """ 22 | super().__init__() 23 | self.cost_action = args.set_cost_act 24 | self.cost_hbox = self.cost_obox = args.set_cost_idx 25 | self.cost_target = args.set_cost_tgt 26 | self.log_printer = args.wandb 27 | self.is_vcoco = (args.dataset_file == 'vcoco') 28 | self.is_hico = (args.dataset_file == 'hico-det') 29 | if self.is_vcoco: 30 | self.valid_ids = args.valid_ids 31 | self.invalid_ids = args.invalid_ids 32 | assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs cant be 0" 33 | 34 | def reduce_redundant_gt_box(self, tgt_bbox, indices): 35 | """Filters redundant Ground-Truth Bounding Boxes 36 | Due to random crop augmentation, there exists cases where there exists 37 | multiple redundant labels for the exact same bounding box and object class. 38 | This function deals with the redundant labels for smoother HOTR training. 39 | """ 40 | tgt_bbox_unique, map_idx, idx_cnt = torch.unique(tgt_bbox, dim=0, return_inverse=True, return_counts=True) 41 | 42 | k_idx, bbox_idx = indices 43 | triggered = False 44 | if (len(tgt_bbox) != len(tgt_bbox_unique)): 45 | map_dict = {k: v for k, v in enumerate(map_idx)} 46 | map_bbox2kidx = {int(bbox_id): k_id for bbox_id, k_id in zip(bbox_idx, k_idx)} 47 | 48 | bbox_lst, k_lst = [], [] 49 | for bbox_id in bbox_idx: 50 | if map_dict[int(bbox_id)] not in bbox_lst: 51 | bbox_lst.append(map_dict[int(bbox_id)]) 52 | k_lst.append(map_bbox2kidx[int(bbox_id)]) 53 | bbox_idx = torch.tensor(bbox_lst) 54 | k_idx = torch.tensor(k_lst) 55 | tgt_bbox_res = tgt_bbox_unique 56 | else: 57 | tgt_bbox_res = tgt_bbox 58 | bbox_idx = bbox_idx.to(tgt_bbox.device) 59 | 60 | return tgt_bbox_res, k_idx, bbox_idx 61 | 62 | @torch.no_grad() 63 | def forward(self, outputs, targets, indices, log=False): 64 | assert "pred_actions" in outputs, "There is no action output for pair matching" 65 | num_obj_queries = outputs["pred_boxes"].shape[1] 66 | bs, num_queries = outputs["pred_actions"].shape[:2] 67 | detr_query_num = outputs["pred_logits"].shape[1] \ 68 | if (outputs["pred_oidx"].shape[-1] == (outputs["pred_logits"].shape[1] + 1)) else -1 69 | 70 | return_list = [] 71 | if self.log_printer and log: 72 | log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []} 73 | if self.is_hico: log_dict['tgt_cost'] = [] 74 | 75 | for batch_idx in range(bs): 76 | tgt_bbox = targets[batch_idx]["boxes"] # (num_boxes, 4) 77 | tgt_cls = targets[batch_idx]["labels"] # (num_boxes) 78 | 79 | if self.is_vcoco: 80 | targets[batch_idx]["pair_actions"][:, self.invalid_ids] = 0 81 | keep_idx = (targets[batch_idx]["pair_actions"].sum(dim=-1) != 0) 82 | targets[batch_idx]["pair_boxes"] = targets[batch_idx]["pair_boxes"][keep_idx] 83 | targets[batch_idx]["pair_actions"] = targets[batch_idx]["pair_actions"][keep_idx] 
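# the same keep_idx is applied to pair_targets below so that pair boxes, actions and target labels stay aligned after dropping GT pairs left with no valid V-COCO action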
84 | targets[batch_idx]["pair_targets"] = targets[batch_idx]["pair_targets"][keep_idx] 85 | 86 | tgt_pbox = targets[batch_idx]["pair_boxes"] # (num_pair_boxes, 8) 87 | tgt_act = targets[batch_idx]["pair_actions"] # (num_pair_boxes, 29) 88 | tgt_tgt = targets[batch_idx]["pair_targets"] # (num_pair_boxes) 89 | 90 | tgt_hbox = tgt_pbox[:, :4] # (num_pair_boxes, 4) 91 | tgt_obox = tgt_pbox[:, 4:] # (num_pair_boxes, 4) 92 | elif self.is_hico: 93 | tgt_act = targets[batch_idx]["pair_actions"] # (num_pair_boxes, 117) 94 | tgt_tgt = targets[batch_idx]["pair_targets"] # (num_pair_boxes) 95 | 96 | tgt_hbox = targets[batch_idx]["sub_boxes"] # (num_pair_boxes, 4) 97 | tgt_obox = targets[batch_idx]["obj_boxes"] # (num_pair_boxes, 4) 98 | 99 | # find which gt boxes match the h, o boxes in the pair 100 | if self.is_vcoco: 101 | hbox_with_cls = torch.cat([tgt_hbox, torch.ones((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1) 102 | elif self.is_hico: 103 | hbox_with_cls = torch.cat([tgt_hbox, torch.zeros((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1) 104 | obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1) 105 | obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1 # turn the class of occluded objects to -1 106 | 107 | bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1) 108 | bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices[batch_idx]) 109 | bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.]*5).unsqueeze(0).to(tgt_cls.device)), dim=0) 110 | 111 | cost_hbox = torch.cdist(hbox_with_cls, bbox_with_cls, p=1) 112 | cost_obox = torch.cdist(obox_with_cls, bbox_with_cls, p=1) 113 | 114 | # find which gt boxes matches which prediction in K 115 | h_match_indices = torch.nonzero(cost_hbox == 0, as_tuple=False) # (num_hbox, num_boxes) 116 | o_match_indices = torch.nonzero(cost_obox == 0, as_tuple=False) # (num_obox, num_boxes) 117 | 118 | tgt_hids, tgt_oids = [], [] 119 | # obtain ground truth indices for h 120 | if len(h_match_indices) != len(o_match_indices): 121 | import pdb; pdb.set_trace() 122 | 123 | for h_match_idx, o_match_idx in zip(h_match_indices, o_match_indices): 124 | hbox_idx, H_bbox_idx = h_match_idx 125 | obox_idx, O_bbox_idx = o_match_idx 126 | if O_bbox_idx == (len(bbox_with_cls)-1): # if the object class is -1 127 | O_bbox_idx = H_bbox_idx # happens in V-COCO, the target object may not appear 128 | 129 | GT_idx_for_H = (bbox_idx == H_bbox_idx).nonzero(as_tuple=False).squeeze(-1) 130 | query_idx_for_H = k_idx[GT_idx_for_H] 131 | tgt_hids.append(query_idx_for_H) 132 | 133 | GT_idx_for_O = (bbox_idx == O_bbox_idx).nonzero(as_tuple=False).squeeze(-1) 134 | query_idx_for_O = k_idx[GT_idx_for_O] 135 | tgt_oids.append(query_idx_for_O) 136 | 137 | # check if empty 138 | if len(tgt_hids) == 0: tgt_hids.append(torch.as_tensor([-1])) # we later ignore the label -1 139 | if len(tgt_oids) == 0: tgt_oids.append(torch.as_tensor([-1])) # we later ignore the label -1 140 | 141 | tgt_sum = (tgt_act.sum(dim=-1)).unsqueeze(0) 142 | flag = False 143 | if tgt_act.shape[0] == 0: 144 | tgt_act = torch.zeros((1, tgt_act.shape[1])).to(targets[batch_idx]["pair_actions"].device) 145 | targets[batch_idx]["pair_actions"] = torch.zeros((1, targets[batch_idx]["pair_actions"].shape[1])).to(targets[batch_idx]["pair_actions"].device) 146 | if self.is_hico: 147 | pad_tgt = -1 # outputs["pred_obj_logits"].shape[-1]-1 148 | tgt_tgt = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"]) 149 | targets[batch_idx]["pair_targets"] = 
torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"].device) 150 | tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0) 151 | 152 | # Concat target label 153 | tgt_hids = torch.cat(tgt_hids) 154 | tgt_oids = torch.cat(tgt_oids) 155 | 156 | out_hprob = outputs["pred_hidx"][batch_idx].softmax(-1) 157 | out_oprob = outputs["pred_oidx"][batch_idx].softmax(-1) 158 | out_act = outputs["pred_actions"][batch_idx].clone() 159 | if self.is_vcoco: out_act[..., self.invalid_ids] = 0 160 | if self.is_hico: 161 | out_tgt = outputs["pred_obj_logits"][batch_idx].softmax(-1) 162 | out_tgt[..., -1] = 0 # don't get cost for no-object 163 | tgt_act = torch.cat([tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1) 164 | 165 | cost_hclass = -out_hprob[:, tgt_hids] # [batch_size * num_queries, detr.num_queries+1] 166 | cost_oclass = -out_oprob[:, tgt_oids] # [batch_size * num_queries, detr.num_queries+1] 167 | 168 | cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum 169 | cost_neg_act = (torch.matmul(out_act, (~tgt_act.bool()).type(torch.int64).t().float())) / (~tgt_act.bool()).type(torch.int64).sum(dim=-1).unsqueeze(0) 170 | cost_action = cost_pos_act + cost_neg_act 171 | 172 | h_cost = self.cost_hbox * cost_hclass 173 | o_cost = self.cost_obox * cost_oclass 174 | act_cost = self.cost_action * cost_action 175 | 176 | C = h_cost + o_cost + act_cost 177 | if self.is_hico: 178 | cost_target = -out_tgt[:, tgt_tgt] 179 | tgt_cost = self.cost_target * cost_target 180 | C += tgt_cost 181 | C = C.view(num_queries, -1).cpu() 182 | 183 | return_list.append(linear_sum_assignment(C)) 184 | targets[batch_idx]["h_labels"] = tgt_hids.to(tgt_hbox.device) 185 | targets[batch_idx]["o_labels"] = tgt_oids.to(tgt_obox.device) 186 | log_act_cost = torch.zeros([1]).to(tgt_act.device) if tgt_act.shape[0] == 0 else act_cost.min(dim=0)[0].mean() 187 | 188 | if self.log_printer and log: 189 | log_dict['h_cost'].append(h_cost.min(dim=0)[0].mean()) 190 | log_dict['o_cost'].append(o_cost.min(dim=0)[0].mean()) 191 | log_dict['act_cost'].append(act_cost.min(dim=0)[0].mean()) 192 | if self.is_hico: log_dict['tgt_cost'].append(tgt_cost.min(dim=0)[0].mean()) 193 | if self.log_printer and log: 194 | log_dict['h_cost'] = torch.stack(log_dict['h_cost']).mean() 195 | log_dict['o_cost'] = torch.stack(log_dict['o_cost']).mean() 196 | log_dict['act_cost'] = torch.stack(log_dict['act_cost']).mean() 197 | if self.is_hico: log_dict['tgt_cost'] = torch.stack(log_dict['tgt_cost']).mean() 198 | if utils.get_rank() == 0: wandb.log(log_dict) 199 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in return_list], targets 200 | 201 | def build_hoi_matcher(args): 202 | return HungarianPairMatcher(args) 203 | -------------------------------------------------------------------------------- /src/util/misc.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/util/misc.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | """ 9 | Misc functions, including distributed helpers. 
10 | Mostly copy-paste from torchvision references. 11 | """ 12 | import os 13 | import subprocess 14 | from collections import deque 15 | import pickle 16 | from typing import Optional, List 17 | 18 | import torch 19 | import torch.distributed as dist 20 | from torch import Tensor 21 | 22 | # needed due to empty tensor bug in pytorch and torchvision 0.5 23 | import torchvision 24 | if float(torchvision.__version__[:3]) < 0.7: 25 | from torchvision.ops import _new_empty_tensor 26 | from torchvision.ops.misc import _output_size 27 | 28 | 29 | class SmoothedValue(object): 30 | """Track a series of values and provide access to smoothed values over a 31 | window or the global series average. 32 | """ 33 | 34 | def __init__(self, window_size=20, fmt=None): 35 | if fmt is None: 36 | fmt = "{median:.4f} ({global_avg:.4f})" 37 | self.deque = deque(maxlen=window_size) 38 | self.total = 0.0 39 | self.count = 0 40 | self.fmt = fmt 41 | 42 | def update(self, value, n=1): 43 | self.deque.append(value) 44 | self.count += n 45 | self.total += value * n 46 | 47 | def synchronize_between_processes(self): 48 | """ 49 | Warning: does not synchronize the deque! 50 | """ 51 | if not is_dist_avail_and_initialized(): 52 | return 53 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 54 | dist.barrier() 55 | dist.all_reduce(t) 56 | t = t.tolist() 57 | self.count = int(t[0]) 58 | self.total = t[1] 59 | 60 | @property 61 | def median(self): 62 | d = torch.tensor(list(self.deque)) 63 | return d.median().item() 64 | 65 | @property 66 | def avg(self): 67 | d = torch.tensor(list(self.deque), dtype=torch.float32) 68 | return d.mean().item() 69 | 70 | @property 71 | def global_avg(self): 72 | return self.total / self.count 73 | 74 | @property 75 | def max(self): 76 | return max(self.deque) 77 | 78 | @property 79 | def value(self): 80 | return self.deque[-1] 81 | 82 | def __str__(self): 83 | return self.fmt.format( 84 | median=self.median, 85 | avg=self.avg, 86 | global_avg=self.global_avg, 87 | max=self.max, 88 | value=self.value) 89 | 90 | 91 | def all_gather(data): 92 | """ 93 | Run all_gather on arbitrary picklable data (not necessarily tensors) 94 | Args: 95 | data: any picklable object 96 | Returns: 97 | list[data]: list of data gathered from each rank 98 | """ 99 | world_size = get_world_size() 100 | if world_size == 1: 101 | return [data] 102 | 103 | # serialized to a Tensor 104 | buffer = pickle.dumps(data) 105 | storage = torch.ByteStorage.from_buffer(buffer) 106 | tensor = torch.ByteTensor(storage).to("cuda") 107 | 108 | # obtain Tensor size of each rank 109 | local_size = torch.tensor([tensor.numel()], device="cuda") 110 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 111 | dist.all_gather(size_list, local_size) 112 | size_list = [int(size.item()) for size in size_list] 113 | max_size = max(size_list) 114 | 115 | # receiving Tensor from all ranks 116 | # we pad the tensor because torch all_gather does not support 117 | # gathering tensors of different shapes 118 | tensor_list = [] 119 | for _ in size_list: 120 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 121 | if local_size != max_size: 122 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 123 | tensor = torch.cat((tensor, padding), dim=0) 124 | dist.all_gather(tensor_list, tensor) 125 | 126 | data_list = [] 127 | for size, tensor in zip(size_list, tensor_list): 128 | buffer = tensor.cpu().numpy().tobytes()[:size] 129 | 
data_list.append(pickle.loads(buffer)) 130 | 131 | return data_list 132 | 133 | 134 | def reduce_dict(input_dict, average=True): 135 | """ 136 | Args: 137 | input_dict (dict): all the values will be reduced 138 | average (bool): whether to do average or sum 139 | Reduce the values in the dictionary from all processes so that all processes 140 | have the averaged results. Returns a dict with the same fields as 141 | input_dict, after reduction. 142 | """ 143 | world_size = get_world_size() 144 | if world_size < 2: 145 | return input_dict 146 | with torch.no_grad(): 147 | names = [] 148 | values = [] 149 | # sort the keys so that they are consistent across processes 150 | for k in sorted(input_dict.keys()): 151 | names.append(k) 152 | values.append(input_dict[k]) 153 | values = torch.stack(values, dim=0) 154 | dist.all_reduce(values) 155 | if average: 156 | values /= world_size 157 | reduced_dict = {k: v for k, v in zip(names, values)} 158 | return reduced_dict 159 | 160 | 161 | def get_sha(): 162 | cwd = os.path.dirname(os.path.abspath(__file__)) 163 | 164 | def _run(command): 165 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 166 | sha = 'N/A' 167 | diff = "clean" 168 | branch = 'N/A' 169 | try: 170 | sha = _run(['git', 'rev-parse', 'HEAD']) 171 | subprocess.check_output(['git', 'diff'], cwd=cwd) 172 | diff = _run(['git', 'diff-index', 'HEAD']) 173 | diff = "has uncommited changes" if diff else "clean" 174 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 175 | except Exception: 176 | pass 177 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 178 | return message 179 | 180 | 181 | def collate_fn(batch): 182 | batch = list(zip(*batch)) 183 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 184 | return tuple(batch) 185 | 186 | 187 | def _max_by_axis(the_list): 188 | # type: (List[List[int]]) -> List[int] 189 | maxes = the_list[0] 190 | for sublist in the_list[1:]: 191 | for index, item in enumerate(sublist): 192 | maxes[index] = max(maxes[index], item) 193 | return maxes 194 | 195 | 196 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 197 | # TODO make this more general 198 | if tensor_list[0].ndim == 3: 199 | # TODO make it support different-sized images 200 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 201 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 202 | batch_shape = [len(tensor_list)] + max_size 203 | b, c, h, w = batch_shape 204 | dtype = tensor_list[0].dtype 205 | device = tensor_list[0].device 206 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 207 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 208 | for img, pad_img, m in zip(tensor_list, tensor, mask): 209 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 210 | m[: img.shape[1], :img.shape[2]] = False 211 | else: 212 | raise ValueError('not supported') 213 | return NestedTensor(tensor, mask) 214 | 215 | 216 | class NestedTensor(object): 217 | def __init__(self, tensors, mask: Optional[Tensor]): 218 | self.tensors = tensors 219 | self.mask = mask 220 | 221 | def to(self, device): 222 | # type: (Device) -> NestedTensor # noqa 223 | cast_tensor = self.tensors.to(device) 224 | mask = self.mask 225 | if mask is not None: 226 | assert mask is not None 227 | cast_mask = mask.to(device) 228 | else: 229 | cast_mask = None 230 | return NestedTensor(cast_tensor, cast_mask) 231 | 232 | def decompose(self): 233 | return self.tensors, self.mask 234 | 235 | def 
__repr__(self): 236 | return str(self.tensors) 237 | 238 | 239 | def setup_for_distributed(is_master): 240 | """ 241 | This function disables printing when not in master process 242 | """ 243 | import builtins as __builtin__ 244 | builtin_print = __builtin__.print 245 | 246 | def print(*args, **kwargs): 247 | force = kwargs.pop('force', False) 248 | if is_master or force: 249 | builtin_print(*args, **kwargs) 250 | 251 | __builtin__.print = print 252 | 253 | 254 | def is_dist_avail_and_initialized(): 255 | if not dist.is_available(): 256 | return False 257 | if not dist.is_initialized(): 258 | return False 259 | return True 260 | 261 | 262 | def get_world_size(): 263 | if not is_dist_avail_and_initialized(): 264 | return 1 265 | return dist.get_world_size() 266 | 267 | 268 | def get_rank(): 269 | if not is_dist_avail_and_initialized(): 270 | return 0 271 | return dist.get_rank() 272 | 273 | 274 | def is_main_process(): 275 | return get_rank() == 0 276 | 277 | 278 | def save_on_master(*args, **kwargs): 279 | if is_main_process(): 280 | torch.save(*args, **kwargs) 281 | 282 | 283 | def init_distributed_mode(args): 284 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 285 | args.rank = int(os.environ["RANK"]) 286 | args.world_size = int(os.environ['WORLD_SIZE']) 287 | args.gpu = int(os.environ['LOCAL_RANK']) 288 | elif 'SLURM_PROCID' in os.environ: 289 | args.rank = int(os.environ['SLURM_PROCID']) 290 | args.gpu = args.rank % torch.cuda.device_count() 291 | else: 292 | print('Not using distributed mode') 293 | args.distributed = False 294 | return 295 | 296 | args.distributed = True 297 | 298 | torch.cuda.set_device(args.gpu) 299 | args.dist_backend = 'nccl' 300 | print('| distributed init (rank {}): {}'.format( 301 | args.rank, args.dist_url), flush=True) 302 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 303 | world_size=args.world_size, rank=args.rank) 304 | torch.distributed.barrier() 305 | setup_for_distributed(args.rank == 0) 306 | 307 | 308 | @torch.no_grad() 309 | def accuracy(output, target, topk=(1,)): 310 | """Computes the precision@k for the specified values of k""" 311 | if target.numel() == 0: 312 | return [torch.zeros([], device=output.device)] 313 | maxk = max(topk) 314 | batch_size = target.size(0) 315 | 316 | _, pred = output.topk(maxk, 1, True, True) 317 | pred = pred.t() 318 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 319 | 320 | res = [] 321 | for k in topk: 322 | correct_k = correct[:k].view(-1).float().sum(0) 323 | res.append(correct_k.mul_(100.0 / batch_size)) 324 | return res 325 | 326 | 327 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 328 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 329 | """ 330 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 331 | This will eventually be supported natively by PyTorch, and this 332 | class can go away. 
333 | """ 334 | if float(torchvision.__version__[:3]) < 0.7: 335 | if input.numel() > 0: 336 | return torch.nn.functional.interpolate( 337 | input, size, scale_factor, mode, align_corners 338 | ) 339 | 340 | output_shape = _output_size(2, input, size, scale_factor) 341 | output_shape = list(input.shape[:-2]) + list(output_shape) 342 | return _new_empty_tensor(input, output_shape) 343 | else: 344 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) -------------------------------------------------------------------------------- /src/data/evaluators/hico_eval.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/data/evaluators/hico_eval.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from QPIC (https://github.com/hitachi-rd-cv/qpic) 6 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 7 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 8 | # ------------------------------------------------------------------------ 9 | import numpy as np 10 | from collections import defaultdict 11 | 12 | hico_valid_obj_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] 13 | 14 | class HICOEvaluator(): 15 | def __init__(self, preds, gts, rare_triplets, non_rare_triplets, correct_mat): 16 | self.overlap_iou = 0.5 17 | self.max_hois = 100 18 | self.mode = 'default' 19 | # self.mode = 'known_objects' 20 | 21 | self.rare_triplets = rare_triplets 22 | self.non_rare_triplets = non_rare_triplets 23 | 24 | self.fp = defaultdict(list) 25 | self.tp = defaultdict(list) 26 | self.score = defaultdict(list) 27 | self.sum_gts = defaultdict(lambda: 0) 28 | self.gt_triplets = [] 29 | 30 | self.preds = [] 31 | for img_preds in preds: 32 | img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items() if k != 'hoi_recognition_time'} 33 | bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_preds['boxes'], img_preds['labels'])] 34 | hoi_scores = img_preds['verb_scores'] 35 | verb_labels = np.tile(np.arange(hoi_scores.shape[1]), (hoi_scores.shape[0], 1)) 36 | subject_ids = np.tile(img_preds['sub_ids'], (hoi_scores.shape[1], 1)).T 37 | object_ids = np.tile(img_preds['obj_ids'], (hoi_scores.shape[1], 1)).T 38 | 39 | hoi_scores = hoi_scores.ravel() 40 | verb_labels = verb_labels.ravel() 41 | subject_ids = subject_ids.ravel() 42 | object_ids = object_ids.ravel() 43 | 44 | if len(subject_ids) > 0: 45 | object_labels = np.array([bboxes[object_id]['category_id'] for object_id in object_ids]) 46 | masks = correct_mat[verb_labels, object_labels] 47 | hoi_scores *= masks ## remove impossible verb-obj pairs 48 | 49 | hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} for 50 | subject_id, object_id, category_id, score in zip(subject_ids, object_ids, verb_labels, hoi_scores)] 51 | hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 52 | hois = hois[:self.max_hois] 53 | else: 54 | hois = [] 55 | 56 | self.preds.append({ 57 | 'predictions': bboxes, 58 | 'hoi_prediction': 
hois 59 | }) 60 | 61 | self.gts = [] 62 | for img_gts in gts: 63 | img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id'} 64 | self.gts.append({ 65 | 'annotations': [{'bbox': bbox, 'category_id': hico_valid_obj_ids.index(label)} for bbox, label in zip(img_gts['boxes'], img_gts['labels'])], # map to valid obj ids 66 | 'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in img_gts['hois']] 67 | }) 68 | for hoi in self.gts[-1]['hoi_annotation']: 69 | triplet = (self.gts[-1]['annotations'][hoi['subject_id']]['category_id'], 70 | self.gts[-1]['annotations'][hoi['object_id']]['category_id'], 71 | hoi['category_id']) 72 | 73 | if triplet not in self.gt_triplets: 74 | self.gt_triplets.append(triplet) 75 | 76 | self.sum_gts[triplet] += 1 77 | print('prepare for hico eval') 78 | 79 | def evaluate(self): 80 | for img_id, (img_preds, img_gts) in enumerate(zip(self.preds, self.gts)): 81 | print(f"Evaluating Score Matrix... : [{(img_id+1):>4}/{len(self.gts):<4}]" ,flush=True, end="\r") 82 | pred_bboxes = img_preds['predictions'] 83 | gt_bboxes = img_gts['annotations'] 84 | pred_hois = img_preds['hoi_prediction'] 85 | gt_hois = img_gts['hoi_annotation'] 86 | known_object_catids = [x['category_id'] for x in gt_bboxes] 87 | if len(gt_bboxes) != 0: 88 | bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes) ## match label and box 89 | self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps, known_object_catids=known_object_catids) 90 | else: 91 | if self.mode == 'default': # all predicted hois considered false positives 92 | for pred_hoi in pred_hois: 93 | triplet = [pred_bboxes[pred_hoi['subject_id']]['category_id'], 94 | pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id']] 95 | if triplet not in self.gt_triplets: 96 | continue 97 | self.tp[triplet].append(0) 98 | self.fp[triplet].append(1) 99 | self.score[triplet].append(pred_hoi['score']) 100 | print(f"[stats] Score Matrix Generation completed!! 
") 101 | map = self.compute_map() 102 | return map 103 | 104 | def compute_map(self): 105 | ap = defaultdict(lambda: 0) 106 | rare_ap = defaultdict(lambda: 0) 107 | non_rare_ap = defaultdict(lambda: 0) 108 | max_recall = defaultdict(lambda: 0) 109 | rare_recall = defaultdict(lambda: 0) 110 | non_rare_recall = defaultdict(lambda: 0) 111 | for triplet in self.gt_triplets: 112 | # if triplet[-1] == 57: continue 113 | sum_gts = self.sum_gts[triplet] 114 | if sum_gts == 0: 115 | continue 116 | 117 | tp = np.array((self.tp[triplet])) 118 | fp = np.array((self.fp[triplet])) 119 | if len(tp) == 0: 120 | ap[triplet] = 0 121 | max_recall[triplet] = 0 122 | if triplet in self.rare_triplets: 123 | rare_ap[triplet] = 0 124 | rare_recall[triplet] = 0 125 | elif triplet in self.non_rare_triplets: 126 | non_rare_ap[triplet] = 0 127 | non_rare_recall[triplet] = 0 128 | else: 129 | print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet)) 130 | continue 131 | 132 | score = np.array(self.score[triplet]) 133 | sort_inds = np.argsort(-score) 134 | fp = fp[sort_inds] 135 | tp = tp[sort_inds] 136 | fp = np.cumsum(fp) 137 | tp = np.cumsum(tp) 138 | rec = tp / sum_gts 139 | prec = tp / (fp + tp) 140 | ap[triplet] = self.voc_ap(rec, prec) 141 | max_recall[triplet] = np.amax(rec) 142 | if triplet in self.rare_triplets: 143 | rare_ap[triplet] = ap[triplet] 144 | rare_recall[triplet] = max_recall[triplet] 145 | elif triplet in self.non_rare_triplets: 146 | non_rare_ap[triplet] = ap[triplet] 147 | non_rare_recall[triplet] = max_recall[triplet] 148 | else: 149 | print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet)) 150 | m_ap = np.mean(list(ap.values())) * 100 # percentage 151 | m_ap_rare = np.mean(list(rare_ap.values())) * 100 # percentage 152 | m_ap_non_rare = np.mean(list(non_rare_ap.values())) * 100 # percentage 153 | 154 | m_max_recall = np.mean(list(max_recall.values())) * 100 155 | m_r_rare = np.mean(list(rare_recall.values())) * 100 156 | m_r_non_rare = np.mean(list(non_rare_recall.values())) * 100 157 | 158 | print('--------------------') 159 | print('mAP: {} mAP rare: {} mAP non-rare: {}'.format(m_ap, m_ap_rare, m_ap_non_rare)) 160 | print('mR: {} mR rare: {} mR non-rare: {}'.format(m_max_recall, m_r_rare, m_r_non_rare)) 161 | print('--------------------') 162 | 163 | return {'mAP': m_ap, 'mAP rare': m_ap_rare, 'mAP non-rare': m_ap_non_rare, 'mean max recall': m_max_recall} 164 | 165 | def voc_ap(self, rec, prec): 166 | ap = 0. 167 | for t in np.arange(0., 1.1, 0.1): 168 | if np.sum(rec >= t) == 0: 169 | p = 0 170 | else: 171 | p = np.max(prec[rec >= t]) 172 | ap = ap + p / 11. 
173 | return ap 174 | 175 | def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps, known_object_catids=None): 176 | pos_pred_ids = match_pairs.keys() 177 | vis_tag = np.zeros(len(gt_hois)) 178 | pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 179 | if len(pred_hois) != 0: 180 | for pred_hoi in pred_hois: 181 | is_match = 0 182 | if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and pred_hoi['object_id'] in pos_pred_ids: 183 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 184 | pred_obj_ids = match_pairs[pred_hoi['object_id']] 185 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 186 | pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']] 187 | pred_category_id = pred_hoi['category_id'] 188 | max_overlap = 0 189 | max_gt_hoi = 0 190 | for gt_hoi in gt_hois: 191 | if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids and pred_category_id == gt_hoi['category_id']: 192 | is_match = 1 193 | min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])], 194 | pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])]) 195 | if min_overlap_gt > max_overlap: 196 | max_overlap = min_overlap_gt 197 | max_gt_hoi = gt_hoi 198 | triplet = (pred_bboxes[pred_hoi['subject_id']]['category_id'], pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id']) 199 | if triplet not in self.gt_triplets: 200 | continue 201 | if self.mode == 'known_objects' and (triplet[1] not in known_object_catids): 202 | continue 203 | if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0: 204 | self.fp[triplet].append(0) 205 | self.tp[triplet].append(1) 206 | vis_tag[gt_hois.index(max_gt_hoi)] = 1 207 | else: 208 | self.fp[triplet].append(1) 209 | self.tp[triplet].append(0) 210 | self.score[triplet].append(pred_hoi['score']) 211 | 212 | def compute_iou_mat(self, bbox_list1, bbox_list2): 213 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2))) 214 | if len(bbox_list1) == 0 or len(bbox_list2) == 0: 215 | return {} 216 | for i, bbox1 in enumerate(bbox_list1): 217 | for j, bbox2 in enumerate(bbox_list2): 218 | iou_i = self.compute_IOU(bbox1, bbox2) 219 | iou_mat[i, j] = iou_i 220 | 221 | iou_mat_ov=iou_mat.copy() 222 | iou_mat[iou_mat>=self.overlap_iou] = 1 223 | iou_mat[iou_mat<self.overlap_iou] = 0 224 | 225 | match_pairs = np.nonzero(iou_mat) 226 | match_pairs_dict = {} 227 | match_pair_overlaps = {} 228 | if iou_mat.max() > 0: 229 | for i, pred_id in enumerate(match_pairs[1]): 230 | if pred_id not in match_pairs_dict.keys(): 231 | match_pairs_dict[pred_id] = [] 232 | match_pair_overlaps[pred_id]=[] 233 | match_pairs_dict[pred_id].append(match_pairs[0][i]) 234 | match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i],pred_id]) 235 | return match_pairs_dict, match_pair_overlaps 236 | 237 | def compute_IOU(self, bbox1, bbox2): 238 | if isinstance(bbox1['category_id'], str): 239 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', '')) 240 | if isinstance(bbox2['category_id'], str): 241 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', '')) 242 | if bbox1['category_id'] == bbox2['category_id']: 243 | rec1 = bbox1['bbox'] 244 | rec2 = bbox2['bbox'] 245 | # computing area of each rectangles 246 | S_rec1 = (rec1[2] - rec1[0]+1) * (rec1[3] - rec1[1]+1) 247 | S_rec2 = (rec2[2] - rec2[0]+1) * (rec2[3] - rec2[1]+1) 248 | 249 | # computing the sum_area 250 | sum_area = S_rec1 + S_rec2 251 | 252 | # find the each edge of intersect rectangle 253 | left_line = max(rec1[1], rec2[1]) 254 | right_line = min(rec1[3], rec2[3]) 255 | top_line = max(rec1[0], rec2[0]) 256 | bottom_line = min(rec1[2], rec2[2]) 257 | # 
judge if there is an intersect 258 | if left_line >= right_line or top_line >= bottom_line: 259 | return 0 260 | else: 261 | intersect = (right_line - left_line+1) * (bottom_line - top_line+1) 262 | return intersect / (sum_area - intersect) 263 | else: 264 | return 0 -------------------------------------------------------------------------------- /STIP_main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import random 5 | import time 6 | import multiprocessing 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import torch 11 | from torch.utils.data import DataLoader, DistributedSampler 12 | 13 | import src.data.datasets as datasets 14 | import src.util.misc as utils 15 | from src.engine.arg_parser import get_args_parser 16 | from src.data.datasets import build_dataset, get_coco_api_from_dataset 17 | from src.engine.trainer import train_one_epoch 18 | from src.engine import hoi_evaluator, hoi_accumulator 19 | from src.models import build_model 20 | import wandb 21 | from src.engine.evaluator_coco import coco_evaluate 22 | 23 | from src.util.logger import print_params, print_args 24 | from collections import OrderedDict 25 | 26 | def save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename): 27 | # save_ckpt: function for saving checkpoints 28 | output_dir = Path(args.output_dir) 29 | if args.output_dir: 30 | checkpoint_path = output_dir / f'{filename}.pth' 31 | utils.save_on_master({ 32 | 'model': model_without_ddp.state_dict(), 33 | 'optimizer': optimizer.state_dict(), 34 | 'lr_scheduler': lr_scheduler.state_dict(), 35 | 'epoch': epoch, 36 | 'args': args, 37 | }, checkpoint_path) 38 | 39 | def main(args): 40 | utils.init_distributed_mode(args) 41 | 42 | if not args.train_detr is not None: # pretrained DETR 43 | print("Freeze weights for detector") 44 | 45 | device = torch.device(args.device) 46 | 47 | # fix the seed for reproducibility 48 | seed = args.seed + utils.get_rank() 49 | torch.manual_seed(seed) 50 | np.random.seed(seed) 51 | random.seed(seed) 52 | 53 | # Data Setup 54 | dataset_train = build_dataset(image_set='train', args=args) 55 | dataset_val = build_dataset(image_set='val' if not args.eval else 'test', args=args) 56 | assert dataset_train.num_action() == dataset_val.num_action(), "Number of actions should be the same between splits" 57 | args.num_classes = dataset_train.num_category() 58 | args.num_actions = dataset_train.num_action() 59 | args.action_names = dataset_train.get_actions() 60 | if args.share_enc: args.hoi_enc_layers = args.enc_layers 61 | if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers 62 | if args.dataset_file == 'vcoco': 63 | # Save V-COCO dataset statistics 64 | args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0] 65 | args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1) 66 | args.human_actions = dataset_train.get_human_action() 67 | args.object_actions = dataset_train.get_object_action() 68 | args.num_human_act = dataset_train.num_human_act() 69 | elif args.dataset_file == 'hico-det': 70 | args.valid_obj_ids = dataset_train.get_valid_obj_ids() 71 | args.correct_mat = torch.tensor(dataset_val.correct_mat).to(device) 72 | print_args(args) 73 | 74 | if args.distributed: 75 | sampler_train = DistributedSampler(dataset_train, shuffle=True) 76 | sampler_val = DistributedSampler(dataset_val, shuffle=False) 77 | else: 78 | sampler_train = 
torch.utils.data.RandomSampler(dataset_train) 79 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 80 | 81 | batch_sampler_train = torch.utils.data.BatchSampler( 82 | sampler_train, args.batch_size, drop_last=True) 83 | 84 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 85 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 86 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, 87 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) 88 | 89 | 90 | # Model Setup 91 | model, criterion, postprocessors = build_model(args) 92 | model.to(device) 93 | 94 | model_without_ddp = model 95 | if args.distributed: 96 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu]) 97 | model_without_ddp = model.module 98 | n_parameters = print_params(model) 99 | 100 | param_dicts = [ 101 | {"params": [p for n, p in model_without_ddp.named_parameters() if "detr" not in n and p.requires_grad]}, 102 | { 103 | "params": [p for n, p in model_without_ddp.named_parameters() if ("detr" in n and 'backbone' not in n) and p.requires_grad], 104 | "lr": args.lr * 0.1, 105 | }, 106 | { 107 | "params": [p for n, p in model_without_ddp.named_parameters() if ("detr" in n and 'backbone' in n) and p.requires_grad], 108 | "lr": args.lr * 0.01, 109 | }, 110 | ] 111 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay) 112 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 113 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=args.reduce_lr_on_plateau_factor, patience=args.reduce_lr_on_plateau_patience, verbose=True) 114 | 115 | # Weight Setup 116 | if args.detr_weights is not None: 117 | print(f"Loading detr weights from args.detr_weights={args.detr_weights}") 118 | if args.detr_weights.startswith('https'): 119 | checkpoint = torch.hub.load_state_dict_from_url( 120 | args.detr_weights, map_location='cpu', check_hash=True) 121 | else: 122 | checkpoint = torch.load(args.detr_weights, map_location='cpu') 123 | 124 | if 'hico_ft_q16.pth' in args.detr_weights: # hack: for loading hico fine-tuned detr 125 | mapped_state_dict = OrderedDict() 126 | for k, v in checkpoint['model'].items(): 127 | if k.startswith('detr.'): 128 | mapped_state_dict[k.replace('detr.', '')] = v 129 | model_without_ddp.detr.load_state_dict(mapped_state_dict) 130 | else: 131 | model_without_ddp.detr.load_state_dict(checkpoint['model']) 132 | 133 | if args.resume: 134 | print(f"Loading model weights from args.resume={args.resume}") 135 | if args.resume.startswith('https'): 136 | checkpoint = torch.hub.load_state_dict_from_url( 137 | args.resume, map_location='cpu', check_hash=True) 138 | else: 139 | checkpoint = torch.load(args.resume, map_location='cpu') 140 | model_without_ddp.load_state_dict(checkpoint['model']) 141 | 142 | if args.eval: 143 | # test only mode 144 | if args.HOIDet: 145 | if args.dataset_file == 'vcoco': 146 | total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device) 147 | sc1, sc2 = hoi_accumulator(args, total_res, True, False) 148 | elif args.dataset_file == 'hico-det': 149 | test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device) 150 | print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f}') 151 | print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f}') 152 | print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f}') 153 | else: raise 
ValueError(f'dataset {args.dataset_file} is not supported.') 154 | return 155 | else: 156 | # check original detr code 157 | base_ds = get_coco_api_from_dataset(data_loader_val) 158 | test_stats, coco_evaluator = coco_evaluate(model, criterion, postprocessors, 159 | data_loader_val, base_ds, device, args.output_dir) 160 | if args.output_dir: 161 | utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, args.output_dir / "eval.pth") 162 | return 163 | 164 | # stats 165 | scenario1, scenario2 = 0, 0 166 | best_mAP, best_rare, best_non_rare = 0, 0, 0 167 | 168 | # add argparse 169 | if args.wandb and utils.get_rank() == 0: 170 | wandb.init( 171 | project=args.project_name, 172 | group=args.group_name, 173 | name=args.run_name, 174 | config=args 175 | ) 176 | wandb.watch(model) 177 | 178 | # Training starts here! 179 | start_time = time.time() 180 | for epoch in range(args.start_epoch, args.epochs): 181 | if args.distributed: 182 | sampler_train.set_epoch(epoch) 183 | train_stats = train_one_epoch( 184 | model, criterion, data_loader_train, optimizer, device, epoch, args.epochs, 185 | args.clip_max_norm, dataset_file=args.dataset_file, log=args.wandb) 186 | 187 | if isinstance(lr_scheduler, torch.optim.lr_scheduler.StepLR): lr_scheduler.step() 188 | 189 | # Validation 190 | if args.validate: 191 | print('-'*100) 192 | if args.dataset_file == 'vcoco': 193 | total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device) 194 | if utils.get_rank() == 0: 195 | sc1, sc2 = hoi_accumulator(args, total_res, False, args.wandb) 196 | if sc1 > scenario1: 197 | scenario1 = sc1 198 | scenario2 = sc2 199 | save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best') 200 | print(f'| Scenario #1 mAP : {sc1:.2f} ({scenario1:.2f})') 201 | print(f'| Scenario #2 mAP : {sc2:.2f} ({scenario2:.2f})') 202 | if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): lr_scheduler.step(sc1) 203 | elif args.dataset_file == 'hico-det': 204 | test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device) 205 | if utils.get_rank() == 0: 206 | if test_stats['mAP'] > best_mAP: 207 | best_mAP = test_stats['mAP'] 208 | best_rare = test_stats['mAP rare'] 209 | best_non_rare = test_stats['mAP non-rare'] 210 | save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best') 211 | print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f} ({best_mAP:.2f})') 212 | print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f} ({best_rare:.2f})') 213 | print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f} ({best_non_rare:.2f})') 214 | if args.wandb and utils.get_rank() == 0: 215 | wandb.log({ 216 | 'mAP': test_stats['mAP'] 217 | }) 218 | if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): lr_scheduler.step(test_stats['mAP']) 219 | print('-'*100) 220 | 221 | # if epoch%2==0: 222 | # save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename=f'checkpoint_{epoch}') 223 | 224 | total_time = time.time() - start_time 225 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 226 | print('Training time {}'.format(total_time_str)) 227 | if args.dataset_file == 'vcoco': 228 | print(f'| Scenario #1 mAP : {scenario1:.2f}') 229 | print(f'| Scenario #2 mAP : {scenario2:.2f}') 230 | elif args.dataset_file == 'hico-det': 231 | print(f'| mAP (full)\t\t: {best_mAP:.2f}') 232 | print(f'| mAP (rare)\t\t: {best_rare:.2f}') 233 | print(f'| mAP (non-rare)\t: {best_non_rare:.2f}') 234 | 235 | 236 | if 
__name__ == '__main__': 237 | parser = argparse.ArgumentParser( 238 | 'End-to-End Human Object Interaction training and evaluation script', 239 | parents=[get_args_parser()] 240 | ) 241 | # training 242 | parser.add_argument('--detr_weights', default=None, type=str) 243 | parser.add_argument('--train_detr', action='store_true', default=False) 244 | parser.add_argument('--finetune_detr_weight', default=0.1, type=float) 245 | parser.add_argument('--lr_detr', default=1e-5, type=float) 246 | parser.add_argument('--reduce_lr_on_plateau_patience', default=2, type=int) 247 | parser.add_argument('--reduce_lr_on_plateau_factor', default=0.1, type=float) 248 | 249 | # loss 250 | parser.add_argument('--proposal_focal_loss_alpha', default=0.75, type=float) # large alpha for high recall 251 | parser.add_argument('--action_focal_loss_alpha', default=0.5, type=float) 252 | parser.add_argument('--proposal_focal_loss_gamma', default=2, type=float) 253 | parser.add_argument('--action_focal_loss_gamma', default=2, type=float) 254 | parser.add_argument('--proposal_loss_coef', default=1, type=float) 255 | parser.add_argument('--action_loss_coef', default=1, type=float) 256 | 257 | # ablations 258 | parser.add_argument('--no_hard_mining_for_relation_discovery', dest='use_hard_mining_for_relation_discovery', action='store_false', default=True) 259 | parser.add_argument('--no_relation_dependency_encoding', dest='use_relation_dependency_encoding', action='store_false', default=True) 260 | parser.add_argument('--no_memory_layout_encoding', dest='use_memory_layout_encoding', action='store_false', default=True, help='layout encodings') 261 | parser.add_argument('--no_nms_on_detr', dest='apply_nms_on_detr', action='store_false', default=True) 262 | parser.add_argument('--no_tail_semantic_feature', dest='use_tail_semantic_feature', action='store_false', default=True) 263 | parser.add_argument('--no_spatial_feature', dest='use_spatial_feature', action='store_false', default=True) 264 | parser.add_argument('--no_interaction_decoder', action='store_true', default=False) 265 | 266 | # not sensitive or effective 267 | parser.add_argument('--use_memory_union_mask', action='store_true', default=False) 268 | parser.add_argument('--use_union_feature', action='store_true', default=False) 269 | parser.add_argument('--adaptive_relation_query_num', action='store_true', default=False) 270 | parser.add_argument('--use_relation_tgt_mask', action='store_true', default=False) 271 | parser.add_argument('--use_relation_tgt_mask_attend_topk', default=10, type=int) 272 | parser.add_argument('--use_prior_verb_label_mask', action='store_true', default=False) 273 | parser.add_argument('--relation_feature_map_from', default='backbone', help='backbone | detr_encoder') 274 | parser.add_argument('--use_query_fourier_encoding', action='store_true', default=False) 275 | 276 | args = parser.parse_args() 277 | args.STIP_relation_head = True 278 | 279 | if args.output_dir: 280 | args.output_dir += f"/{args.group_name}/{args.run_name}/" 281 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 282 | main(args) 283 | -------------------------------------------------------------------------------- /src/data/transforms/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Transforms and data augmentation for both image + bbox. 
4 | """ 5 | import random 6 | 7 | import PIL 8 | import torch 9 | import torchvision.transforms as T 10 | import torchvision.transforms.functional as F 11 | 12 | from src.util.box_ops import box_xyxy_to_cxcywh 13 | from src.util.misc import interpolate 14 | 15 | 16 | def crop(image, target, region): 17 | cropped_image = F.crop(image, *region) 18 | 19 | target = target.copy() 20 | i, j, h, w = region 21 | 22 | # should we do something wrt the original size? 23 | target["size"] = torch.tensor([h, w]) 24 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 25 | 26 | fields = ["labels", "area", "iscrowd"] # add additional fields 27 | if "inst_actions" in target.keys(): 28 | fields.append("inst_actions") 29 | 30 | if "boxes" in target: 31 | boxes = target["boxes"] 32 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 33 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 34 | cropped_boxes = cropped_boxes.clamp(min=0) 35 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 36 | target["boxes"] = cropped_boxes.reshape(-1, 4) 37 | target["area"] = area 38 | fields.append("boxes") 39 | 40 | if "pair_boxes" in target or ("sub_boxes" in target and "obj_boxes" in target): 41 | if "pair_boxes" in target: 42 | pair_boxes = target["pair_boxes"] 43 | hboxes = pair_boxes[:, :4] 44 | oboxes = pair_boxes[:, 4:] 45 | if ("sub_boxes" in target and "obj_boxes" in target): 46 | hboxes = target["sub_boxes"] 47 | oboxes = target["obj_boxes"] 48 | 49 | cropped_hboxes = hboxes - torch.as_tensor([j, i, j, i]) 50 | cropped_hboxes = torch.min(cropped_hboxes.reshape(-1, 2, 2), max_size) 51 | cropped_hboxes = cropped_hboxes.clamp(min=0) 52 | hboxes = cropped_hboxes.reshape(-1, 4) 53 | 54 | obj_mask = (oboxes[:, 0] != -1) 55 | if obj_mask.sum() != 0: 56 | cropped_oboxes = oboxes[obj_mask] - torch.as_tensor([j, i, j, i]) 57 | cropped_oboxes = torch.min(cropped_oboxes.reshape(-1, 2, 2), max_size) 58 | cropped_oboxes = cropped_oboxes.clamp(min=0) 59 | oboxes[obj_mask] = cropped_oboxes.reshape(-1, 4) 60 | else: 61 | cropped_oboxes = oboxes 62 | 63 | cropped_pair_boxes = torch.cat([hboxes, oboxes], dim=-1) 64 | target["pair_boxes"] = cropped_pair_boxes 65 | pair_fields = ["pair_boxes", "pair_actions", "pair_targets"] 66 | 67 | if "masks" in target: 68 | # FIXME should we update the area here if there are no boxes[? 
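# the instance masks are cropped below to the same (i, j, h, w) region that was applied to the image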
69 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 70 | fields.append("masks") 71 | 72 | # remove elements for which the boxes or masks that have zero area 73 | if "boxes" in target or "masks" in target: 74 | # favor boxes selection when defining which elements to keep 75 | # this is compatible with previous implementation 76 | if "boxes" in target: 77 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 78 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 79 | else: 80 | keep = target['masks'].flatten(1).any(1) 81 | 82 | for field in fields: 83 | if field in target: # added this because there is no 'iscrowd' field in v-coco dataset 84 | target[field] = target[field][keep] 85 | 86 | # remove elements that have redundant area 87 | if "boxes" in target and "labels" in target: 88 | cropped_boxes = target['boxes'] 89 | cropped_labels = target['labels'] 90 | 91 | cnr, keep_idx = [], [] 92 | for idx, (cropped_box, cropped_lbl) in enumerate(zip(cropped_boxes, cropped_labels)): 93 | if str((cropped_box, cropped_lbl)) not in cnr: 94 | cnr.append(str((cropped_box, cropped_lbl))) 95 | keep_idx.append(True) 96 | else: keep_idx.append(False) 97 | 98 | for field in fields: 99 | if field in target: 100 | target[field] = target[field][keep_idx] 101 | 102 | # remove elements for which pair boxes have zero area 103 | if "pair_boxes" in target: 104 | cropped_hboxes = target["pair_boxes"][:, :4].reshape(-1, 2, 2) 105 | cropped_oboxes = target["pair_boxes"][:, 4:].reshape(-1, 2, 2) 106 | keep_h = torch.all(cropped_hboxes[:, 1, :] > cropped_hboxes[:, 0, :], dim=1) 107 | keep_o = torch.all(cropped_oboxes[:, 1, :] > cropped_oboxes[:, 0, :], dim=1) 108 | not_empty_o = torch.all(target["pair_boxes"][:, 4:] >= 0, dim=1) 109 | discard_o = (~keep_o) & not_empty_o 110 | if (discard_o).sum() > 0: 111 | target["pair_boxes"][discard_o, 4:] = -1 112 | 113 | for pair_field in pair_fields: 114 | target[pair_field] = target[pair_field][keep_h] 115 | 116 | return cropped_image, target 117 | 118 | 119 | def hflip(image, target): 120 | flipped_image = F.hflip(image) 121 | 122 | w, h = image.size 123 | 124 | target = target.copy() 125 | if "boxes" in target: 126 | boxes = target["boxes"] 127 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 128 | target["boxes"] = boxes 129 | 130 | if "pair_boxes" in target: 131 | pair_boxes = target["pair_boxes"] 132 | hboxes = pair_boxes[:, :4] 133 | oboxes = pair_boxes[:, 4:] 134 | 135 | # human flip 136 | hboxes = hboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 137 | 138 | # object flip 139 | obj_mask = (oboxes[:, 0] != -1) 140 | if obj_mask.sum() != 0: 141 | o_tmp = oboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 142 | oboxes[obj_mask] = o_tmp[obj_mask] 143 | 144 | pair_boxes = torch.cat([hboxes, oboxes], dim=-1) 145 | target["pair_boxes"] = pair_boxes 146 | 147 | if "masks" in target: 148 | target['masks'] = target['masks'].flip(-1) 149 | 150 | return flipped_image, target 151 | 152 | 153 | def resize(image, target, size, max_size=None): 154 | # size can be min_size (scalar) or (w, h) tuple 155 | 156 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 157 | w, h = image_size 158 | if max_size is not None: 159 | min_original_size = float(min((w, h))) 160 | max_original_size = float(max((w, h))) 161 | if max_original_size / min_original_size * size > max_size: 162 | size = int(round(max_size * min_original_size / 
max_original_size)) 163 | 164 | if (w <= h and w == size) or (h <= w and h == size): 165 | return (h, w) 166 | 167 | if w < h: 168 | ow = size 169 | oh = int(size * h / w) 170 | else: 171 | oh = size 172 | ow = int(size * w / h) 173 | 174 | return (oh, ow) 175 | 176 | def get_size(image_size, size, max_size=None): 177 | if isinstance(size, (list, tuple)): 178 | return size[::-1] 179 | else: 180 | return get_size_with_aspect_ratio(image_size, size, max_size) 181 | 182 | size = get_size(image.size, size, max_size) 183 | rescaled_image = F.resize(image, size) 184 | 185 | if target is None: 186 | return rescaled_image, None 187 | 188 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 189 | ratio_width, ratio_height = ratios 190 | 191 | target = target.copy() 192 | if "boxes" in target: 193 | boxes = target["boxes"] 194 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 195 | target["boxes"] = scaled_boxes 196 | 197 | if "pair_boxes" in target: 198 | hboxes = target["pair_boxes"][:, :4] 199 | scaled_hboxes = hboxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 200 | hboxes = scaled_hboxes 201 | 202 | oboxes = target["pair_boxes"][:, 4:] 203 | obj_mask = (oboxes[:, 0] != -1) 204 | if obj_mask.sum() != 0: 205 | scaled_oboxes = oboxes[obj_mask] * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 206 | oboxes[obj_mask] = scaled_oboxes 207 | 208 | target["pair_boxes"] = torch.cat([hboxes, oboxes], dim=-1) 209 | 210 | if "area" in target: 211 | area = target["area"] 212 | scaled_area = area * (ratio_width * ratio_height) 213 | target["area"] = scaled_area 214 | 215 | h, w = size 216 | target["size"] = torch.tensor([h, w]) 217 | 218 | if "masks" in target: 219 | target['masks'] = interpolate( 220 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 221 | 222 | return rescaled_image, target 223 | 224 | 225 | def pad(image, target, padding): 226 | # assumes that we only pad on the bottom right corners 227 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 228 | if target is None: 229 | return padded_image, None 230 | target = target.copy() 231 | # should we do something wrt the original size? 
232 | target["size"] = torch.tensor(padded_image[::-1]) 233 | if "masks" in target: 234 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 235 | return padded_image, target 236 | 237 | 238 | class RandomCrop(object): 239 | def __init__(self, size): 240 | self.size = size 241 | 242 | def __call__(self, img, target): 243 | region = T.RandomCrop.get_params(img, self.size) 244 | return crop(img, target, region) 245 | 246 | 247 | class RandomSizeCrop(object): 248 | def __init__(self, min_size: int, max_size: int): 249 | self.min_size = min_size 250 | self.max_size = max_size 251 | 252 | def __call__(self, img: PIL.Image.Image, target: dict): 253 | w = random.randint(self.min_size, min(img.width, self.max_size)) 254 | h = random.randint(self.min_size, min(img.height, self.max_size)) 255 | region = T.RandomCrop.get_params(img, [h, w]) 256 | return crop(img, target, region) 257 | 258 | 259 | class CenterCrop(object): 260 | def __init__(self, size): 261 | self.size = size 262 | 263 | def __call__(self, img, target): 264 | image_width, image_height = img.size 265 | crop_height, crop_width = self.size 266 | crop_top = int(round((image_height - crop_height) / 2.)) 267 | crop_left = int(round((image_width - crop_width) / 2.)) 268 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 269 | 270 | 271 | class RandomHorizontalFlip(object): 272 | def __init__(self, p=0.5): 273 | self.p = p 274 | 275 | def __call__(self, img, target): 276 | if random.random() < self.p: 277 | return hflip(img, target) 278 | return img, target 279 | 280 | 281 | class RandomResize(object): 282 | def __init__(self, sizes, max_size=None): 283 | assert isinstance(sizes, (list, tuple)) 284 | self.sizes = sizes 285 | self.max_size = max_size 286 | 287 | def __call__(self, img, target=None): 288 | size = random.choice(self.sizes) 289 | return resize(img, target, size, self.max_size) 290 | 291 | 292 | class RandomPad(object): 293 | def __init__(self, max_pad): 294 | self.max_pad = max_pad 295 | 296 | def __call__(self, img, target): 297 | pad_x = random.randint(0, self.max_pad) 298 | pad_y = random.randint(0, self.max_pad) 299 | return pad(img, target, (pad_x, pad_y)) 300 | 301 | 302 | class RandomSelect(object): 303 | """ 304 | Randomly selects between transforms1 and transforms2, 305 | with probability p for transforms1 and (1 - p) for transforms2 306 | """ 307 | def __init__(self, transforms1, transforms2, p=0.5): 308 | self.transforms1 = transforms1 309 | self.transforms2 = transforms2 310 | self.p = p 311 | 312 | def __call__(self, img, target): 313 | if random.random() < self.p: 314 | return self.transforms1(img, target) 315 | return self.transforms2(img, target) 316 | 317 | 318 | class ToTensor(object): 319 | def __call__(self, img, target): 320 | return F.to_tensor(img), target 321 | 322 | 323 | class RandomErasing(object): 324 | 325 | def __init__(self, *args, **kwargs): 326 | self.eraser = T.RandomErasing(*args, **kwargs) 327 | 328 | def __call__(self, img, target): 329 | return self.eraser(img), target 330 | 331 | 332 | class Normalize(object): 333 | def __init__(self, mean, std): 334 | self.mean = mean 335 | self.std = std 336 | 337 | def __call__(self, image, target=None): 338 | image = F.normalize(image, mean=self.mean, std=self.std) 339 | if target is None: 340 | return image, None 341 | target = target.copy() 342 | h, w = image.shape[-2:] 343 | if "boxes" in target: 344 | boxes = target["boxes"] 345 | boxes = box_xyxy_to_cxcywh(boxes) 346 | boxes = boxes / 
torch.tensor([w, h, w, h], dtype=torch.float32) 347 | target["boxes"] = boxes 348 | 349 | if "pair_boxes" in target: 350 | hboxes = target["pair_boxes"][:, :4] 351 | hboxes = box_xyxy_to_cxcywh(hboxes) 352 | hboxes = hboxes / torch.tensor([w, h, w, h], dtype=torch.float32) 353 | 354 | oboxes = target["pair_boxes"][:, 4:] 355 | obj_mask = (oboxes[:, 0] != -1) 356 | if obj_mask.sum() != 0: 357 | oboxes[obj_mask] = box_xyxy_to_cxcywh(oboxes[obj_mask]) 358 | oboxes[obj_mask] = oboxes[obj_mask] / torch.tensor([w, h, w, h], dtype=torch.float32) 359 | 360 | pair_boxes = torch.cat([hboxes, oboxes], dim=-1) 361 | target["pair_boxes"] = pair_boxes 362 | 363 | return image, target 364 | 365 | class ColorJitter(object): 366 | def __init__(self, brightness=0, contrast=0, saturatio=0, hue=0): 367 | self.color_jitter = T.ColorJitter(brightness, contrast, saturatio, hue) 368 | 369 | def __call__(self, img, target): 370 | return self.color_jitter(img), target 371 | 372 | class Compose(object): 373 | def __init__(self, transforms): 374 | self.transforms = transforms 375 | 376 | def __call__(self, image, target): 377 | for t in self.transforms: 378 | image, target = t(image, target) 379 | return image, target 380 | 381 | def __repr__(self): 382 | format_string = self.__class__.__name__ + "(" 383 | for t in self.transforms: 384 | format_string += "\n" 385 | format_string += " {0}".format(t) 386 | format_string += "\n)" 387 | return format_string 388 | -------------------------------------------------------------------------------- /src/data/datasets/hico.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # HOTR official code : src/data/datasets/hico.py 3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | # Modified from QPIC (https://github.com/hitachi-rd-cv/qpic) 6 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 
7 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 8 | # ------------------------------------------------------------------------ 9 | from pathlib import Path 10 | from PIL import Image 11 | import json 12 | from collections import defaultdict 13 | import numpy as np 14 | 15 | import torch 16 | import torch.utils.data 17 | import torchvision 18 | 19 | from src.data.datasets import builtin_meta 20 | import src.data.transforms.transforms as T 21 | 22 | class HICODetection(torch.utils.data.Dataset): 23 | def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries): 24 | self.img_set = img_set 25 | self.img_folder = img_folder 26 | with open(anno_file, 'r') as f: 27 | self.annotations = json.load(f) 28 | with open(action_list_file, 'r') as f: 29 | self.action_lines = f.readlines() 30 | self._transforms = transforms 31 | self.num_queries = num_queries 32 | self.get_metadata() 33 | 34 | if img_set == 'train': 35 | self.ids = [] 36 | for idx, img_anno in enumerate(self.annotations): 37 | for hoi in img_anno['hoi_annotation']: 38 | if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len(img_anno['annotations']): 39 | break 40 | else: 41 | self.ids.append(idx) 42 | else: 43 | self.ids = list(range(len(self.annotations))) 44 | # self.ids = self.ids[:1000] 45 | 46 | ############################################################################ 47 | # Number Method 48 | ############################################################################ 49 | def get_metadata(self): 50 | meta = builtin_meta._get_coco_instances_meta() 51 | self.COCO_CLASSES = meta['coco_classes'] 52 | self._valid_obj_ids = [id for id in meta['thing_dataset_id_to_contiguous_id'].keys()] 53 | self._valid_verb_ids, self._valid_verb_names = [], [] 54 | for action_line in self.action_lines[2:]: 55 | act_id, act_name = action_line.split() 56 | self._valid_verb_ids.append(int(act_id)) 57 | self._valid_verb_names.append(act_name) 58 | 59 | def get_valid_obj_ids(self): 60 | return self._valid_obj_ids 61 | 62 | def get_actions(self): 63 | return self._valid_verb_names 64 | 65 | def num_category(self): 66 | return len(self.COCO_CLASSES) 67 | 68 | def num_action(self): 69 | return len(self._valid_verb_ids) 70 | ############################################################################ 71 | 72 | def __len__(self): 73 | return len(self.ids) 74 | 75 | def __getitem__(self, idx): 76 | img_anno = self.annotations[self.ids[idx]] 77 | 78 | img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB') 79 | w, h = img.size 80 | 81 | if self.img_set == 'train': 82 | img_anno = merge_box_annotations(img_anno) 83 | # img_anno = merge_box_annotations(img_anno) # for finetune detr 84 | 85 | # cut out the GTs that exceed the number of object queries 86 | if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries: 87 | img_anno['annotations'] = img_anno['annotations'][:self.num_queries] 88 | 89 | boxes = [obj['bbox'] for obj in img_anno['annotations']] 90 | # guard against no boxes via resizing 91 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 92 | 93 | if self.img_set == 'train': 94 | # Add index for confirming which boxes are kept after image transformation 95 | classes = [(i, obj['category_id']) for i, obj in enumerate(img_anno['annotations'])] 96 | else: 97 | classes = [obj['category_id'] for obj in img_anno['annotations']] 98 | classes = torch.tensor(classes, dtype=torch.int64) 99 | 100 | target = {} 101 | 
target['orig_size'] = torch.as_tensor([int(h), int(w)]) 102 | target['size'] = torch.as_tensor([int(h), int(w)]) 103 | if self.img_set == 'train': 104 | boxes[:, 0::2].clamp_(min=0, max=w) 105 | boxes[:, 1::2].clamp_(min=0, max=h) 106 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 107 | boxes = boxes[keep] 108 | classes = classes[keep] 109 | 110 | target['boxes'] = boxes 111 | target['labels'] = classes 112 | target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])]) 113 | target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 114 | 115 | if self._transforms is not None: 116 | img, target = self._transforms(img, target) 117 | 118 | kept_box_indices = [label[0] for label in target['labels']] 119 | 120 | target['labels'] = target['labels'][:, 1] 121 | 122 | obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], [] 123 | sub_obj_pairs = [] 124 | for hoi in img_anno['hoi_annotation']: 125 | if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices: 126 | continue 127 | sub_obj_pair = (hoi['subject_id'], hoi['object_id']) 128 | if sub_obj_pair in sub_obj_pairs: # multi label 129 | verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1 130 | else: 131 | sub_obj_pairs.append(sub_obj_pair) 132 | obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])]) 133 | verb_label = [0 for _ in range(len(self._valid_verb_ids))] 134 | verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1 135 | sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])] 136 | obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])] 137 | verb_labels.append(verb_label) 138 | sub_boxes.append(sub_box) 139 | obj_boxes.append(obj_box) 140 | if len(sub_obj_pairs) == 0: 141 | target['pair_targets'] = torch.zeros((0,), dtype=torch.int64) 142 | target['pair_actions'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 143 | target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 144 | target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 145 | else: 146 | target['pair_targets'] = torch.stack(obj_labels) 147 | target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32) 148 | target['sub_boxes'] = torch.stack(sub_boxes) 149 | target['obj_boxes'] = torch.stack(obj_boxes) 150 | 151 | # relation map 152 | relation_map = torch.zeros((len(target['boxes']), len(target['boxes']), self.num_action())) 153 | for sub_obj_pair in sub_obj_pairs: 154 | kept_subj_id = kept_box_indices.index(sub_obj_pair[0]) 155 | kept_obj_id = kept_box_indices.index(sub_obj_pair[1]) 156 | relation_map[kept_subj_id, kept_obj_id] = torch.tensor(verb_labels[sub_obj_pairs.index(sub_obj_pair)]) 157 | target['relation_map'] = relation_map 158 | target['hois'] = relation_map.nonzero(as_tuple=False) 159 | else: 160 | target['boxes'] = boxes 161 | target['labels'] = classes 162 | target['id'] = idx 163 | 164 | if self._transforms is not None: 165 | img, _ = self._transforms(img, None) 166 | 167 | hois = [] 168 | for hoi in img_anno['hoi_annotation']: 169 | hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id']))) 170 | target['hois'] = torch.as_tensor(hois, dtype=torch.int64) 171 | 172 | target['image_id'] = torch.tensor([idx]) 173 | return img, target 174 | 175 | def set_rare_hois(self, anno_file): 176 | with open(anno_file, 'r') as f: 177 | annotations = json.load(f) 178 | 179 | counts = defaultdict(lambda: 0) 180 
| for img_anno in annotations: 181 | hois = img_anno['hoi_annotation'] 182 | bboxes = img_anno['annotations'] 183 | for hoi in hois: 184 | # mapped to valid obj ids for evaludation 185 | triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']), 186 | self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']), 187 | self._valid_verb_ids.index(hoi['category_id'])) 188 | counts[triplet] += 1 189 | self.rare_triplets = [] 190 | self.non_rare_triplets = [] 191 | for triplet, count in counts.items(): 192 | if count < 10: 193 | self.rare_triplets.append(triplet) 194 | else: 195 | self.non_rare_triplets.append(triplet) 196 | 197 | def load_correct_mat(self, path): 198 | self.correct_mat = np.load(path) 199 | 200 | 201 | # Add color jitter to coco transforms 202 | def make_hico_transforms(image_set): 203 | 204 | normalize = T.Compose([ 205 | T.ToTensor(), 206 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 207 | ]) 208 | 209 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 210 | 211 | if image_set == 'train': 212 | return T.Compose([ 213 | T.RandomHorizontalFlip(), 214 | T.ColorJitter(.4, .4, .4), 215 | T.RandomSelect( 216 | T.RandomResize(scales, max_size=1333), 217 | T.Compose([ 218 | T.RandomResize([400, 500, 600]), 219 | T.RandomSizeCrop(384, 600), 220 | T.RandomResize(scales, max_size=1333), 221 | ]) 222 | ), 223 | normalize, 224 | ]) 225 | 226 | if image_set == 'val': 227 | return T.Compose([ 228 | T.RandomResize([800], max_size=1333), 229 | normalize, 230 | ]) 231 | 232 | if image_set == 'test': 233 | return T.Compose([ 234 | T.RandomResize([800], max_size=1333), 235 | normalize, 236 | ]) 237 | 238 | raise ValueError(f'unknown {image_set}') 239 | 240 | 241 | def merge_box_annotations(org_image_annotation, overlap_iou_thres=0.7): 242 | merged_image_annotation = org_image_annotation.copy() 243 | 244 | # compute match 245 | bbox_list = org_image_annotation['annotations'] 246 | box_match = torch.zeros(len(bbox_list), len(bbox_list)).bool() 247 | for i, bbox1 in enumerate(bbox_list): 248 | for j, bbox2 in enumerate(bbox_list): 249 | box_match[i, j] = compute_box_match(bbox1, bbox2, overlap_iou_thres) 250 | 251 | box_groups = [] 252 | for i in range(len(box_match)): 253 | if box_match[i].any(): # box unassigned to group 254 | group_ids = box_match[i].nonzero(as_tuple=False).squeeze(1) 255 | box_groups.append(group_ids.tolist()) 256 | box_match[:, group_ids] = False 257 | assert sum([len(g) for g in box_groups]) == len(bbox_list) 258 | 259 | # merge to new anntations 260 | group_info, orgbox2group = [], {} 261 | for gid, org_box_ids in enumerate(box_groups): 262 | for orgid in org_box_ids: orgbox2group.update({orgid: gid}) 263 | # selected_box_id = np.random.choice(org_box_ids) 264 | # box_info = bbox_list[selected_box_id] 265 | box_info = { 266 | 'bbox': torch.tensor([bbox_list[id]['bbox'] for id in org_box_ids]).float().mean(dim=0).int().tolist(), 267 | 'category_id': bbox_list[org_box_ids[0]]['category_id'] 268 | } 269 | 270 | group_info.append(box_info) 271 | 272 | new_hois = [] 273 | for hoi in org_image_annotation['hoi_annotation']: 274 | if hoi['subject_id'] in orgbox2group and hoi['object_id'] in orgbox2group: 275 | new_hois.append({ 276 | 'subject_id': orgbox2group[hoi['subject_id']], 277 | 'object_id': orgbox2group[hoi['object_id']], 278 | 'category_id': hoi['category_id'] 279 | }) 280 | 281 | merged_image_annotation['annotations'] = group_info 282 | merged_image_annotation['hoi_annotation'] = new_hois 283 | return 
merged_image_annotation
284 | 
285 | # two boxes match if their IoU exceeds the threshold and they share the same category
286 | def compute_box_match(bbox1, bbox2, threshold):
287 |     if isinstance(bbox1['category_id'], str):
288 |         bbox1['category_id'] = int(bbox1['category_id'].replace('\n', ''))
289 |     if isinstance(bbox2['category_id'], str):
290 |         bbox2['category_id'] = int(bbox2['category_id'].replace('\n', ''))
291 |     if bbox1['category_id'] != bbox2['category_id']:
292 |         return False
293 |     else:
294 |         rec1 = bbox1['bbox']
295 |         rec2 = bbox2['bbox']
296 |         # compute the area of each rectangle
297 |         S_rec1 = (rec1[2] - rec1[0]+1) * (rec1[3] - rec1[1]+1)
298 |         S_rec2 = (rec2[2] - rec2[0]+1) * (rec2[3] - rec2[1]+1)
299 | 
300 |         # compute the total area of the two rectangles
301 |         sum_area = S_rec1 + S_rec2
302 | 
303 |         # find each edge of the intersection rectangle
304 |         left_line = max(rec1[1], rec2[1])
305 |         right_line = min(rec1[3], rec2[3])
306 |         top_line = max(rec1[0], rec2[0])
307 |         bottom_line = min(rec1[2], rec2[2])
308 | 
309 |         # check whether the rectangles actually intersect
310 |         intersect = max((right_line - left_line+1), 0) * max((bottom_line - top_line+1), 0)
311 |         iou = intersect / (sum_area - intersect)
312 |         if iou > threshold:
313 |             return True
314 |         else:
315 |             return False
316 | 
317 | def build(image_set, args):
318 |     root = Path(args.data_path)
319 |     assert root.exists(), f'provided HOI path {root} does not exist'
320 |     PATHS = {
321 |         'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json'),
322 |         'val': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json'),
323 |         'test': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json')
324 |     }
325 |     CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy'
326 |     action_list_file = root / 'list_action.txt'
327 | 
328 |     img_folder, anno_file = PATHS[image_set]
329 |     dataset = HICODetection(image_set, img_folder, anno_file, action_list_file, transforms=make_hico_transforms(image_set),
330 |                             num_queries=args.num_queries)
331 |     if image_set == 'val' or image_set == 'test':
332 |         dataset.set_rare_hois(PATHS['train'][1])
333 |         dataset.load_correct_mat(CORRECT_MAT_PATH)
334 |     return dataset
--------------------------------------------------------------------------------
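For quick sanity-checking, the HICO-DET builder above can be exercised on its own. The snippet below is a minimal sketch, not part of the repository: the `args` namespace, the placeholder data path, and the simple list-style collate are illustrative assumptions (the actual entry point, STIP_main.py, presumably builds `args` via src/engine/arg_parser.py and uses a DETR-style collate from src/util/misc.py).

from argparse import Namespace
from torch.utils.data import DataLoader

from src.data.datasets.hico import build

# Illustrative arguments: only the fields read by build()/HICODetection are set here.
args = Namespace(data_path='/path/to/hico_20160224_det', num_queries=100)

train_set = build('train', args)
print(len(train_set), 'training images |', train_set.num_action(), 'action classes')

# __getitem__ returns (image_tensor, target_dict) with variable-sized fields
# (boxes, pair_actions, relation_map, ...), so a list-style collate keeps them intact.
loader = DataLoader(train_set, batch_size=2, shuffle=True,
                    collate_fn=lambda batch: tuple(zip(*batch)))

images, targets = next(iter(loader))
print(images[0].shape, targets[0]['relation_map'].shape)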