├── src
│   ├── util
│   │   ├── __init__.py
│   │   ├── box_ops.py
│   │   ├── logger.py
│   │   └── misc.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── feed_forward.py
│   │   ├── position_encoding.py
│   │   ├── detr_matcher.py
│   │   ├── backbone.py
│   │   ├── hotr.py
│   │   ├── post_process.py
│   │   ├── detr.py
│   │   └── hotr_matcher.py
│   ├── engine
│   │   ├── __init__.py
│   │   ├── evaluator_coco.py
│   │   ├── trainer.py
│   │   ├── hico_det_evaluate.py
│   │   ├── evaluator_vcoco.py
│   │   ├── evaluator_hico.py
│   │   └── arg_parser.py
│   ├── data
│   │   ├── datasets
│   │   │   ├── __init__.py
│   │   │   ├── coco.py
│   │   │   ├── builtin_meta.py
│   │   │   └── hico.py
│   │   ├── evaluators
│   │   │   ├── vcoco_eval.py
│   │   │   ├── coco_eval.py
│   │   │   └── hico_eval.py
│   │   └── transforms
│   │       └── transforms.py
│   └── metrics
│       ├── utils.py
│       └── vcoco
│           ├── ap_agent.py
│           └── ap_role.py
├── imgs
│   └── STIP.png
├── .run
│   ├── STIP_vcoco_single_train.run.xml
│   └── STIP_hicodet_single_train.run.xml
├── .gitignore
├── README.md
└── STIP_main.py
/src/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/imgs/STIP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zyong812/STIP/HEAD/imgs/STIP.png
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | from .detr import build
3 |
4 | def build_model(args):
5 | return build(args)
--------------------------------------------------------------------------------
/src/models/feed_forward.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | from torch import nn
3 |
4 | class MLP(nn.Module):
5 | """ Very simple multi-layer perceptron (also called FFN)"""
6 |
7 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
8 | super().__init__()
9 | self.num_layers = num_layers
10 | h = [hidden_dim] * (num_layers - 1)
11 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
12 |
13 | def forward(self, x):
14 | for i, layer in enumerate(self.layers):
15 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
16 | return x
--------------------------------------------------------------------------------
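
Usage sketch (not part of the repository): DETR-style models typically instantiate this MLP as a 3-layer box head on top of the decoder queries; the dimensions below are illustrative assumptions, not values read from this repo's configs.

import torch
from src.models.feed_forward import MLP

bbox_head = MLP(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
queries = torch.randn(2, 100, 256)      # [batch, num_queries, hidden_dim]
boxes = bbox_head(queries).sigmoid()    # [2, 100, 4] normalized cxcywh boxes
print(boxes.shape)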
/src/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .evaluator_vcoco import vcoco_evaluate, vcoco_accumulate
2 | from .evaluator_hico import hico_evaluate
3 |
4 | def hoi_evaluator(args, model, criterion, postprocessors, data_loader, device, thr=0):
5 | if args.dataset_file == 'vcoco':
6 | return vcoco_evaluate(model, criterion, postprocessors, data_loader, device, args.output_dir, thr)
7 | elif args.dataset_file == 'hico-det':
8 | return hico_evaluate(model, postprocessors, data_loader, device, thr, args)
9 | else: raise NotImplementedError
10 |
11 | def hoi_accumulator(args, total_res, print_results=False, wandb=False):
12 | if args.dataset_file == 'vcoco':
13 | return vcoco_accumulate(total_res, args, print_results, wandb)
14 | else: raise NotImplementedError
--------------------------------------------------------------------------------
/src/data/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | import torch.utils.data
3 | import torchvision
4 |
5 | from src.data.datasets.coco import build as build_coco
6 | from src.data.datasets.vcoco import build as build_vcoco
7 | from src.data.datasets.hico import build as build_hico
8 |
9 | def get_coco_api_from_dataset(dataset):
10 | for _ in range(10):  # unwrap up to 10 nested torch.utils.data.Subset wrappers to reach the underlying CocoDetection
11 | if isinstance(dataset, torch.utils.data.Subset):
12 | dataset = dataset.dataset
13 | if isinstance(dataset, torchvision.datasets.CocoDetection):
14 | return dataset.coco
15 |
16 |
17 | def build_dataset(image_set, args):
18 | if args.dataset_file == 'coco':
19 | return build_coco(image_set, args)
20 | elif args.dataset_file == 'vcoco':
21 | return build_vcoco(image_set, args)
22 | elif args.dataset_file == 'hico-det':
23 | return build_hico(image_set, args)
24 | raise ValueError(f'dataset {args.dataset_file} not supported')
--------------------------------------------------------------------------------
/.run/STIP_vcoco_single_train.run.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.run/STIP_hicodet_single_train.run.xml:
--------------------------------------------------------------------------------
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | .idea
6 |
7 | # C extensions
8 | *.so
9 |
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | pip-wheel-metadata/
25 | share/python-wheels/
26 | *.egg-info/
27 | .installed.cfg
28 | *.egg
29 | MANIFEST
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest
35 | *.spec
36 |
37 | # Installer logs
38 | pip-log.txt
39 | pip-delete-this-directory.txt
40 |
41 | # Unit test / coverage reports
42 | htmlcov/
43 | .tox/
44 | .nox/
45 | .coverage
46 | .coverage.*
47 | .cache
48 | nosetests.xml
49 | coverage.xml
50 | *.cover
51 | *.py,cover
52 | .hypothesis/
53 | .pytest_cache/
54 |
55 | # Translations
56 | *.mo
57 | *.pot
58 |
59 | # Django stuff:
60 | *.log
61 | local_settings.py
62 | db.sqlite3
63 | db.sqlite3-journal
64 |
65 | # Flask stuff:
66 | instance/
67 | .webassets-cache
68 |
69 | # Scrapy stuff:
70 | .scrapy
71 |
72 | # Sphinx documentation
73 | docs/_build/
74 |
75 | # PyBuilder
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # pipenv
89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
92 | # install all needed dependencies.
93 | #Pipfile.lock
94 |
95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
96 | __pypackages__/
97 |
98 | # Celery stuff
99 | celerybeat-schedule
100 | celerybeat.pid
101 |
102 | # SageMath parsed files
103 | *.sage.py
104 |
105 | # Environments
106 | .env
107 | .venv
108 | env/
109 | venv/
110 | ENV/
111 | env.bak/
112 | venv.bak/
113 |
114 | # Spyder project settings
115 | .spyderproject
116 | .spyproject
117 |
118 | # Rope project settings
119 | .ropeproject
120 |
121 | # mkdocs documentation
122 | /site
123 |
124 | # mypy
125 | .mypy_cache/
126 | .dmypy.json
127 | dmypy.json
128 |
129 | # Pyre type checker
130 | .pyre/
131 | wandb/
132 | checkpoints/
133 | res/
134 |
135 |
--------------------------------------------------------------------------------
/src/data/evaluators/vcoco_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) KakaoBrain, Inc. and its affiliates. All Rights Reserved
2 | """
3 | V-COCO evaluator that works in distributed mode.
4 | """
5 | import os
6 | import numpy as np
7 | import torch
8 |
9 | from src.util.misc import all_gather
10 | from src.metrics.vcoco.ap_role import APRole
11 | from functools import partial
12 |
13 | def init_vcoco_evaluators(human_act_name, object_act_name):
14 | role_eval1 = APRole(act_name=object_act_name, scenario_flag=True, iou_threshold=0.5)
15 | role_eval2 = APRole(act_name=object_act_name, scenario_flag=False, iou_threshold=0.5)
16 |
17 | return role_eval1, role_eval2
18 |
19 | class VCocoEvaluator(object):
20 | def __init__(self, args):
21 | self.img_ids = []
22 | self.eval_imgs = []
23 | self.role_eval1, self.role_eval2 = init_vcoco_evaluators(args.human_actions, args.object_actions)
24 | self.num_human_act = args.num_human_act
25 | self.action_idx = args.valid_ids
26 |
27 | def update(self, outputs):
28 | img_ids = list(np.unique(list(outputs.keys())))
29 | for img_num, img_id in enumerate(img_ids):
30 | print(f"Evaluating Score Matrix... : [{(img_num+1):>4}/{len(img_ids):<4}]" ,flush=True, end="\r")
31 | prediction = outputs[img_id]['prediction']
32 | target = outputs[img_id]['target']
33 |
34 | # score with prediction
35 | hbox, hcat, obox, ocat = list(map(lambda x: prediction[x], \
36 | ['h_box', 'h_cat', 'o_box', 'o_cat']))
37 |
38 | assert 'pair_score' in prediction
39 | score = prediction['pair_score']
40 |
41 | hbox, hcat, obox, ocat, score =\
42 | list(map(lambda x: x.cpu().numpy(), [hbox, hcat, obox, ocat, score]))
43 |
44 | # ground-truth
45 | gt_h_inds = (target['labels'] == 1)
46 | gt_h_box = target['boxes'][gt_h_inds, :4].cpu().numpy()
47 | gt_h_act = target['inst_actions'][gt_h_inds, :self.num_human_act].cpu().numpy() ## self.num_human_act=26, for filtering GT
48 |
49 | gt_p_box = target['pair_boxes'].cpu().numpy()
50 | gt_p_act = target['pair_actions'].cpu().numpy()
51 |
52 | score = score[self.action_idx, :, :]
53 | gt_p_act = gt_p_act[:, self.action_idx]
54 |
55 | self.role_eval1.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act) # gt_h_act: Nx26, gt_p_act: Nx25
56 | self.role_eval2.add_data(hbox, obox, score, gt_h_box, gt_h_act, gt_p_box, gt_p_act)
57 | self.img_ids.append(img_id)
58 |
--------------------------------------------------------------------------------
/src/engine/evaluator_coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import src.util.misc as utils
4 | import src.util.logger as loggers
5 | from src.data.evaluators.coco_eval import CocoEvaluator
6 |
7 | @torch.no_grad()
8 | def coco_evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir):
9 | model.eval()
10 | criterion.eval()
11 |
12 | metric_logger = loggers.MetricLogger(delimiter=" ")
13 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
14 | header = 'Evaluation'
15 |
16 | iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
17 | coco_evaluator = CocoEvaluator(base_ds, iou_types)
18 | print_freq = len(data_loader)
19 | # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
20 |
21 | print("\n>>> [MS-COCO Evaluation] <<<")
22 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
23 | samples = samples.to(device)
24 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
25 |
26 | outputs = model(samples)
27 | loss_dict = criterion(outputs, targets)
28 | weight_dict = criterion.weight_dict
29 |
30 | # reduce losses over all GPUs for logging purposes
31 | loss_dict_reduced = utils.reduce_dict(loss_dict)
32 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
33 | for k, v in loss_dict_reduced.items() if k in weight_dict}
34 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
35 | for k, v in loss_dict_reduced.items()}
36 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
37 | **loss_dict_reduced_scaled,
38 | **loss_dict_reduced_unscaled)
39 | metric_logger.update(class_error=loss_dict_reduced['class_error'])
40 |
41 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
42 | results = postprocessors['bbox'](outputs, orig_target_sizes)
43 | res = {target['image_id'].item(): output for target, output in zip(targets, results)}
44 | if coco_evaluator is not None:
45 | coco_evaluator.update(res)
46 |
47 | # gather the stats from all processes
48 | metric_logger.synchronize_between_processes()
49 | print("\n>>> [Averaged stats] <<<\n", metric_logger)
50 | if coco_evaluator is not None:
51 | coco_evaluator.synchronize_between_processes()
52 |
53 | # accumulate predictions from all images
54 | if coco_evaluator is not None:
55 | coco_evaluator.accumulate()
56 | coco_evaluator.summarize()
57 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
58 | if coco_evaluator is not None:
59 | if 'bbox' in postprocessors.keys():
60 | stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
61 |
62 | return stats, coco_evaluator
--------------------------------------------------------------------------------
/src/models/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Various positional encodings for the transformer.
4 | """
5 | import math
6 | import torch
7 | from torch import nn
8 |
9 | from src.util.misc import NestedTensor
10 |
11 |
12 | class PositionEmbeddingSine(nn.Module):
13 | """
14 | This is a more standard version of the position embedding, very similar to the one
15 | used by the Attention is all you need paper, generalized to work on images.
16 | """
17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
18 | super().__init__()
19 | self.num_pos_feats = num_pos_feats
20 | self.temperature = temperature
21 | self.normalize = normalize
22 | if scale is not None and normalize is False:
23 | raise ValueError("normalize should be True if scale is passed")
24 | if scale is None:
25 | scale = 2 * math.pi
26 | self.scale = scale
27 |
28 | def forward(self, tensor_list: NestedTensor):
29 | x = tensor_list.tensors
30 | mask = tensor_list.mask
31 | assert mask is not None
32 | not_mask = ~mask
33 | y_embed = not_mask.cumsum(1, dtype=torch.float32)
34 | x_embed = not_mask.cumsum(2, dtype=torch.float32)
35 | if self.normalize:
36 | eps = 1e-6
37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
39 |
40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
42 |
43 | pos_x = x_embed[:, :, :, None] / dim_t
44 | pos_y = y_embed[:, :, :, None] / dim_t
45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
48 | return pos
49 |
50 |
51 | class PositionEmbeddingLearned(nn.Module):
52 | """
53 | Absolute pos embedding, learned.
54 | """
55 | def __init__(self, num_pos_feats=256):
56 | super().__init__()
57 | self.row_embed = nn.Embedding(50, num_pos_feats)
58 | self.col_embed = nn.Embedding(50, num_pos_feats)
59 | self.reset_parameters()
60 |
61 | def reset_parameters(self):
62 | nn.init.uniform_(self.row_embed.weight)
63 | nn.init.uniform_(self.col_embed.weight)
64 |
65 | def forward(self, tensor_list: NestedTensor):
66 | x = tensor_list.tensors
67 | h, w = x.shape[-2:]
68 | i = torch.arange(w, device=x.device)
69 | j = torch.arange(h, device=x.device)
70 | x_emb = self.col_embed(i)
71 | y_emb = self.row_embed(j)
72 | pos = torch.cat([
73 | x_emb.unsqueeze(0).repeat(h, 1, 1),
74 | y_emb.unsqueeze(1).repeat(1, w, 1),
75 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
76 | return pos
77 |
78 |
79 | def build_position_encoding(args):
80 | N_steps = args.hidden_dim // 2
81 | if args.position_embedding in ('v2', 'sine'):
82 | # TODO find a better way of exposing other arguments
83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
84 | elif args.position_embedding in ('v3', 'learned'):
85 | position_embedding = PositionEmbeddingLearned(N_steps)
86 | else:
87 | raise ValueError(f"not supported {args.position_embedding}")
88 |
89 | return position_embedding
--------------------------------------------------------------------------------
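
Sanity-check sketch (not from the repository): a dummy feature map with an all-valid mask passed through PositionEmbeddingSine; num_pos_feats=128 is an assumption matching hidden_dim=256 in build_position_encoding.

import torch
from src.util.misc import NestedTensor
from src.models.position_encoding import PositionEmbeddingSine

feat = torch.randn(2, 256, 32, 32)                 # dummy backbone feature map
mask = torch.zeros(2, 32, 32, dtype=torch.bool)    # False = valid pixel (no padding)
pos = PositionEmbeddingSine(num_pos_feats=128, normalize=True)(NestedTensor(feat, mask))
print(pos.shape)                                   # torch.Size([2, 256, 32, 32])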
/src/engine/trainer.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : engine/trainer.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | import math
9 | import torch
10 | import sys
11 | import src.util.misc as utils
12 | import src.util.logger as loggers
13 | from typing import Iterable
14 | import wandb
15 | from src.models.stip_utils import check_annotation, plot_cross_attention, plot_hoi_results
16 |
17 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
18 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
19 | device: torch.device, epoch: int, max_epoch: int, max_norm: float = 0, dataset_file: str = 'coco', log: bool = False):
20 | model.train()
21 | criterion.train()
22 | metric_logger = loggers.MetricLogger(mode="train", delimiter=" ")
23 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
24 | space_fmt = str(len(str(max_epoch)))
25 | header = 'Epoch [{start_epoch: >{fill}}/{end_epoch}]'.format(start_epoch=epoch+1, end_epoch=max_epoch, fill=space_fmt)
26 | print_freq = int(len(data_loader)/5)
27 |
28 | print(f"\n>>> Epoch #{(epoch+1)}")
29 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
30 | samples = samples.to(device)
31 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
32 |
33 | outputs = model(samples, targets)
34 | loss_dict = criterion(outputs, targets, log)
35 | weight_dict = criterion.weight_dict
36 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
37 |
38 | # reduce losses over all GPUs for logging purposes
39 | loss_dict_reduced = utils.reduce_dict(loss_dict)
40 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v
41 | for k, v in loss_dict_reduced.items()}
42 | loss_dict_reduced_scaled = {k: v * weight_dict[k]
43 | for k, v in loss_dict_reduced.items() if k in weight_dict}
44 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
45 | loss_value = losses_reduced_scaled.item()
46 | if utils.get_rank() == 0 and log: wandb.log(loss_dict_reduced_scaled)
47 |
48 | if not math.isfinite(loss_value):
49 | print("Loss is {}, stopping training".format(loss_value))
50 | print(loss_dict_reduced)
51 | sys.exit(1)
52 |
53 | optimizer.zero_grad()
54 | losses.backward()
55 | if max_norm > 0:
56 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
57 | optimizer.step()
58 |
59 | # check_annotation(samples, targets, mode='train', rel_num=20)
60 | # plot_hoi_results(samples, outputs, targets, args=model.args)
61 |
62 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
63 | if "obj_class_error" in loss_dict:
64 | metric_logger.update(obj_class_error=loss_dict_reduced['obj_class_error'])
65 | metric_logger.update(lr=optimizer.param_groups[0]["lr"])
66 |
67 | # gather the stats from all processes
68 | metric_logger.synchronize_between_processes()
69 | print("Averaged stats:", metric_logger)
70 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
71 |
--------------------------------------------------------------------------------
/src/metrics/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 | def count_parameters(model):
5 | return sum(p.numel() for p in model.parameters() if p.requires_grad)
6 |
7 |
8 | def compute_overlap(a, b):
9 | if type(a) == torch.Tensor:
10 | if len(a.shape) == 2:
11 | area = (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1)
12 |
13 | iw = torch.min(a[:, 2].unsqueeze(dim=1), b[:, 2]) - torch.max(a[:, 0].unsqueeze(dim=1), b[:, 0])
14 | ih = torch.min(a[:, 3].unsqueeze(dim=1), b[:, 3]) - torch.max(a[:, 1].unsqueeze(dim=1), b[:, 1])
15 |
16 | iw[iw<0] = 0
17 | ih[ih<0] = 0
18 |
19 | ua = torch.unsqueeze((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), dim=1) + area - iw * ih
20 | ua[ua < 1e-8] = 1e-8
21 |
22 | intersection = iw * ih
23 |
24 | return intersection / ua
25 |
26 | elif type(a) == np.ndarray:
27 | if len(a.shape) == 2:
28 | area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0) #(1, K)
29 |
30 | iw = np.minimum(np.expand_dims(a[:, 2], axis=1), np.expand_dims(b[:, 2], axis=0)) \
31 | - np.maximum(np.expand_dims(a[:, 0], axis=1), np.expand_dims(b[:, 0], axis=0)) \
32 | + 1
33 | ih = np.minimum(np.expand_dims(a[:, 3], axis=1), np.expand_dims(b[:, 3], axis=0)) \
34 | - np.maximum(np.expand_dims(a[:, 1], axis=1), np.expand_dims(b[:, 1], axis=0)) \
35 | + 1
36 |
37 | iw[iw<0] = 0 # (N, K)
38 | ih[ih<0] = 0 # (N, K)
39 |
40 | intersection = iw * ih
41 |
42 | ua = np.expand_dims((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), axis=1) + area - intersection
43 | ua[ua < 1e-8] = 1e-8
44 |
45 | return intersection / ua
46 |
47 | elif len(a.shape) == 1:
48 | area = np.expand_dims((b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1), axis=0) #(1, K)
49 |
50 | iw = np.minimum(np.expand_dims([a[2]], axis=1), np.expand_dims(b[:, 2], axis=0)) \
51 | - np.maximum(np.expand_dims([a[0]], axis=1), np.expand_dims(b[:, 0], axis=0))
52 | ih = np.minimum(np.expand_dims([a[3]], axis=1), np.expand_dims(b[:, 3], axis=0)) \
53 | - np.maximum(np.expand_dims([a[1]], axis=1), np.expand_dims(b[:, 1], axis=0))
54 |
55 | iw[iw<0] = 0 # (N, K)
56 | ih[ih<0] = 0 # (N, K)
57 |
58 | ua = np.expand_dims([(a[2] - a[0] + 1) * (a[3] - a[1] + 1)], axis=1) + area - iw * ih
59 | ua[ua < 1e-8] = 1e-8
60 |
61 | intersection = iw * ih
62 |
63 | return intersection / ua
64 |
65 |
66 | def _compute_ap(recall, precision):
67 | """ Compute the average precision, given the recall and precision curves.
68 | Code originally from https://github.com/rbgirshick/py-faster-rcnn.
69 | # Arguments
70 | recall: The recall curve (list).
71 | precision: The precision curve (list).
72 | # Returns
73 | The average precision as computed in py-faster-rcnn.
74 | """
75 | # correct AP calculation
76 | # first append sentinel values at the end
77 | mrec = np.concatenate(([0.], recall, [1.]))
78 | mpre = np.concatenate(([0.], precision, [0.]))
79 |
80 | # compute the precision envelope
81 | for i in range(mpre.size - 1, 0, -1):
82 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
83 |
84 | # to calculate area under PR curve, look for points
85 | # where X axis (recall) changes value
86 | i = np.where(mrec[1:] != mrec[:-1])[0]
87 |
88 | # and sum (\Delta recall) * prec
89 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
90 | return ap
--------------------------------------------------------------------------------
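
Worked toy example (illustrative, not from the repository): with recall points [0.5, 1.0] and precisions [1.0, 0.5], the precision envelope stays [1.0, 0.5], so AP = 0.5 * 1.0 + 0.5 * 0.5 = 0.75.

import numpy as np
from src.metrics.utils import _compute_ap

recall = np.array([0.5, 1.0])
precision = np.array([1.0, 0.5])
print(_compute_ap(recall, precision))   # 0.75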
/src/util/box_ops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Utilities for bounding box manipulation and GIoU.
4 | """
5 | import torch
6 | from torchvision.ops.boxes import box_area
7 |
8 |
9 | def box_cxcywh_to_xyxy(x):
10 | x_c, y_c, w, h = x.unbind(-1)
11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
12 | (x_c + 0.5 * w), (y_c + 0.5 * h)]
13 | return torch.stack(b, dim=-1)
14 |
15 |
16 | def box_xyxy_to_cxcywh(x):
17 | x0, y0, x1, y1 = x.unbind(-1)
18 | b = [(x0 + x1) / 2, (y0 + y1) / 2,
19 | (x1 - x0), (y1 - y0)]
20 | return torch.stack(b, dim=-1)
21 |
22 |
23 | # modified from torchvision to also return the union
24 | def box_iou(boxes1, boxes2):
25 | area1 = box_area(boxes1)
26 | area2 = box_area(boxes2)
27 |
28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2]
29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2]
30 |
31 | wh = (rb - lt).clamp(min=0) # [N,M,2]
32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
33 |
34 | union = area1[:, None] + area2 - inter
35 |
36 | iou = inter / union
37 | return iou, union
38 |
39 |
40 | def generalized_box_iou(boxes1, boxes2):
41 | """
42 | Generalized IoU from https://giou.stanford.edu/
43 | The boxes should be in [x0, y0, x1, y1] format
44 | Returns a [N, M] pairwise matrix, where N = len(boxes1)
45 | and M = len(boxes2)
46 | """
47 | # degenerate boxes gives inf / nan results
48 | # so do an early check
49 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
50 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
51 | iou, union = box_iou(boxes1, boxes2)
52 |
53 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
54 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
55 |
56 | wh = (rb - lt).clamp(min=0) # [N,M,2]
57 | area = wh[:, :, 0] * wh[:, :, 1]
58 |
59 | return iou - (area - union) / area
60 |
61 |
62 | def masks_to_boxes(masks):
63 | """Compute the bounding boxes around the provided masks
64 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
65 | Returns a [N, 4] tensors, with the boxes in xyxy format
66 | """
67 | if masks.numel() == 0:
68 | return torch.zeros((0, 4), device=masks.device)
69 |
70 | h, w = masks.shape[-2:]
71 |
72 | y = torch.arange(0, h, dtype=torch.float)
73 | x = torch.arange(0, w, dtype=torch.float)
74 | y, x = torch.meshgrid(y, x)
75 |
76 | x_mask = (masks * x.unsqueeze(0))
77 | x_max = x_mask.flatten(1).max(-1)[0]
78 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
79 |
80 | y_mask = (masks * y.unsqueeze(0))
81 | y_max = y_mask.flatten(1).max(-1)[0]
82 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0]
83 |
84 | return torch.stack([x_min, y_min, x_max, y_max], 1)
85 |
86 |
87 | def rescale_bboxes(out_bbox, size):
88 | img_h, img_w = size
89 | b = box_cxcywh_to_xyxy(out_bbox)
90 | b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(out_bbox.get_device())
91 | return b
92 |
93 |
94 | def rescale_pairs(out_pairs, size):
95 | img_h, img_w = size
96 | h_bbox = out_pairs[:, :4]
97 | o_bbox = out_pairs[:, 4:]
98 |
99 | h = box_cxcywh_to_xyxy(h_bbox)
100 | h = h * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(h_bbox.get_device())
101 |
102 | obj_mask = (o_bbox[:, 0] != -1)
103 | if obj_mask.sum() != 0:
104 | o = box_cxcywh_to_xyxy(o_bbox)
105 | o = o * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32).to(o_bbox.get_device())
106 | o_bbox[obj_mask] = o[obj_mask]
107 | o = o_bbox
108 | p = torch.cat([h, o], dim=-1)
109 |
110 | return p
--------------------------------------------------------------------------------
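
Quick sketch (not from the repository) of the conversion and GIoU helpers on a toy box; identical boxes give a GIoU of 1.0.

import torch
from src.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou

cxcywh = torch.tensor([[0.5, 0.5, 0.4, 0.4]])
xyxy = box_cxcywh_to_xyxy(cxcywh)        # approximately [[0.3, 0.3, 0.7, 0.7]]
print(generalized_box_iou(xyxy, xyxy))   # tensor([[1.]])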
/src/engine/hico_det_evaluate.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import copy
3 | import itertools
4 |
5 | import torch
6 |
7 | import src.util.misc as utils
8 | import src.util.logger as loggers
9 | from pycocotools.coco import COCO
10 | from pycocotools.cocoeval import COCOeval
11 |
12 | # evaluate detection on HICO-DET
13 | @torch.no_grad()
14 | def hico_det_evaluate(model, postprocessors, data_loader, device, args):
15 | hico_valid_obj_ids = torch.tensor(args.valid_obj_ids) # id -> coco id
16 |
17 | model.eval()
18 |
19 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
20 | header = 'Evaluation Inference (HICO-DET)'
21 |
22 | preds = []
23 | gts = []
24 |
25 | all_predictions = {}
26 | for samples, targets in metric_logger.log_every(data_loader, 50, header):
27 | samples = samples.to(device)
28 | targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets]
29 |
30 | outputs = model(samples)
31 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
32 | results = postprocessors['bbox'](outputs, orig_target_sizes, dataset='hico-det')
33 |
34 | preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results))))
35 | # Deep-copy targets to avoid a runtime error when gathering them across processes
36 | gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets)))))
37 |
38 | # gather the stats from all processes
39 | metric_logger.synchronize_between_processes()
40 |
41 | img_ids = [img_gts['id'] for img_gts in gts]
42 | _, indices = np.unique(img_ids, return_index=True)
43 | preds = {i: img_preds for i, img_preds in enumerate(preds) if i in indices}
44 | gts = {i: img_gts for i, img_gts in enumerate(gts) if i in indices}
45 |
46 | stats = do_detection_evaluation(preds, gts, hico_valid_obj_ids)
47 | return stats
48 |
49 | def do_detection_evaluation(predictions, groundtruths, hico_valid_obj_ids):
50 | # create a Coco-like object that we can use to evaluate detection!
51 | anns = []
52 | for image_id, gt in groundtruths.items():
53 | labels = gt['labels'].tolist() # map to coco like ids
54 | boxes = gt['boxes'].tolist() # xyxy
55 | for cls, box in zip(labels, boxes):
56 | anns.append({
57 | 'area': (box[3] - box[1] + 1) * (box[2] - box[0] + 1),
58 | 'bbox': [box[0], box[1], box[2] - box[0] + 1, box[3] - box[1] + 1], # xywh
59 | 'category_id': cls,
60 | 'id': len(anns),
61 | 'image_id': image_id,
62 | 'iscrowd': 0,
63 | })
64 | fauxcoco = COCO()
65 | fauxcoco.dataset = {
66 | 'info': {'description': 'use coco script for vg detection evaluation'},
67 | 'images': [{'id': i} for i in range(len(groundtruths))],
68 | 'categories': [
69 | {'supercategory': 'person', 'id': i, 'name': str(i)} for i in hico_valid_obj_ids.tolist()
70 | ],
71 | 'annotations': anns,
72 | }
73 | fauxcoco.createIndex()
74 |
75 | # format predictions to coco-like
76 | cocolike_predictions = []
77 | for image_id, prediction in predictions.items():
78 | box = torch.stack((prediction['boxes'][:,0], prediction['boxes'][:,1], prediction['boxes'][:,2]-prediction['boxes'][:,0]+1, prediction['boxes'][:,3]-prediction['boxes'][:,1]+1), dim=1).detach().cpu().numpy() # xywh
79 | label = prediction['labels'].cpu().numpy() # (#objs,)
80 | score = prediction['scores'].cpu().numpy() # (#objs,)
81 |
82 | image_id = np.asarray([image_id]*len(box))
83 | cocolike_predictions.append(
84 | np.column_stack((image_id, box, score, label))
85 | )
86 | cocolike_predictions = np.concatenate(cocolike_predictions, 0)
87 | # evaluate via coco API
88 | res = fauxcoco.loadRes(cocolike_predictions)
89 | coco_eval = COCOeval(fauxcoco, res, 'bbox')
90 | coco_eval.params.imgIds = list(range(len(groundtruths)))
91 | coco_eval.evaluate()
92 | coco_eval.accumulate()
93 | coco_eval.summarize()
94 | mAp = coco_eval.stats[1]
95 |
96 | return mAp
--------------------------------------------------------------------------------
/src/metrics/vcoco/ap_agent.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from src.metrics.utils import _compute_ap, compute_overlap
3 | import pdb
4 |
5 | class APAgent(object):
6 | def __init__(self, act_name, iou_threshold=0.5):
7 | self.act_name = act_name
8 | self.iou_threshold = iou_threshold
9 |
10 | self.fp = [np.zeros((0,))] * len(act_name)
11 | self.tp = [np.zeros((0,))] * len(act_name)
12 | self.score = [np.zeros((0,))] * len(act_name)
13 | self.num_ann = [0] * len(act_name)
14 |
15 | def add_data(self, box, act, cat, i_box, i_act):
16 | for label in range(len(self.act_name)):
17 | i_inds = (i_act[:, label] == 1)
18 | self.num_ann[label] += i_inds.sum()
19 |
20 | n_pred = box.shape[0]
21 | if n_pred == 0 : return
22 |
23 | ######################
24 | valid_i_inds = (i_act[:, 0] != -1) # (n_i, ) # both in COCO & V-COCO
25 |
26 | overlaps = compute_overlap(box, i_box) # (n_pred, n_i)
27 | assigned_input = np.argmax(overlaps, axis=1) # (n_pred, )
28 | v_inds = valid_i_inds[assigned_input] # (n_pred, )
29 |
30 | n_valid = v_inds.sum()
31 |
32 | if n_valid == 0 : return
33 | valid_box = box[v_inds]
34 | valid_act = act[v_inds]
35 | valid_cat = cat[v_inds]
36 |
37 | ######################
38 | s = valid_act * np.expand_dims(valid_cat, axis=1) # (n_v, #act)
39 |
40 | for label in range(len(self.act_name)):
41 | inds = np.argsort(s[:, label])[::-1] # (n_v, )
42 | self.score[label] = np.append(self.score[label], s[inds, label])
43 |
44 | correct_i_inds = (i_act[:, label] == 1)
45 | if correct_i_inds.sum() == 0:
46 | self.tp[label] = np.append(self.tp[label], np.array([0]*n_valid))
47 | self.fp[label] = np.append(self.fp[label], np.array([1]*n_valid))
48 | continue
49 |
50 | overlaps = compute_overlap(valid_box[inds], i_box) # (n_v, n_i)
51 | assigned_input = np.argmax(overlaps, axis=1) # (n_v, )
52 | max_overlap = overlaps[range(n_valid), assigned_input] # (n_v, )
53 |
54 | iou_inds = (max_overlap > self.iou_threshold) & correct_i_inds[assigned_input] # (n_v, )
55 |
56 | i_nonzero = iou_inds.nonzero()[0]
57 | i_inds = assigned_input[i_nonzero]
58 | i_iou = np.unique(i_inds, return_index=True)[1]
59 | i_tp = i_nonzero[i_iou]
60 |
61 | t = np.zeros(n_valid, dtype=np.uint8)
62 | t[i_tp] = 1
63 | f = 1-t
64 |
65 | self.tp[label] = np.append(self.tp[label], t)
66 | self.fp[label] = np.append(self.fp[label], f)
67 |
68 | def evaluate(self):
69 | average_precisions = dict()
70 | for label in range(len(self.act_name)):
71 | if self.num_ann[label] == 0:
72 | average_precisions[label] = 0
73 | continue
74 |
75 | # sort by score
76 | indices = np.argsort(-self.score[label])
77 | self.fp[label] = self.fp[label][indices]
78 | self.tp[label] = self.tp[label][indices]
79 |
80 | # compute false positives and true positives
81 | self.fp[label] = np.cumsum(self.fp[label])
82 | self.tp[label] = np.cumsum(self.tp[label])
83 |
84 | # compute recall and precision
85 | recall = self.tp[label] / self.num_ann[label]
86 | precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps)
87 |
88 | # compute average precision
89 | average_precisions[label] = _compute_ap(recall, precision) * 100
90 |
91 | print('\n================== AP (Agent) ===================')
92 | s, n = 0, 0
93 |
94 | for label in range(len(self.act_name)):
95 | label_name = "_".join(self.act_name[label].split("_")[1:])
96 | print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label]))
97 | s += average_precisions[label]
98 | n += 1
99 |
100 | mAP = s/n
101 | print('| mAP(agent): {:0.2f}'.format(mAP))
102 | print('----------------------------------------------------')
103 |
104 | return mAP
--------------------------------------------------------------------------------
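
Minimal sketch (not from the repository) of the APAgent protocol with a single prediction and ground-truth instance; the action names, boxes, and scores below are hypothetical and chosen only for illustration.

import numpy as np
from src.metrics.vcoco.ap_agent import APAgent

agent = APAgent(act_name=['agent_hold', 'agent_sit'], iou_threshold=0.5)
box   = np.array([[10., 10., 50., 50.]])    # predicted human box
act   = np.array([[0.9, 0.1]])              # per-action scores for that box
cat   = np.array([0.8])                     # detection confidence
i_box = np.array([[12., 12., 48., 48.]])    # ground-truth human box
i_act = np.array([[1, 0]])                  # GT performs the first action only
agent.add_data(box, act, cat, i_box, i_act)
mAP = agent.evaluate()                      # prints a per-action AP table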
/src/models/detr_matcher.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Modules to compute the matching cost and solve the corresponding LSAP.
4 | """
5 | import torch
6 | from scipy.optimize import linear_sum_assignment
7 | from torch import nn
8 |
9 | from src.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
10 |
11 |
12 | class HungarianMatcher(nn.Module):
13 | """This class computes an assignment between the targets and the predictions of the network
14 | For efficiency reasons, the targets don't include the no_object. Because of this, in general,
15 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
16 | while the others are un-matched (and thus treated as non-objects).
17 | """
18 |
19 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
20 | """Creates the matcher
21 | Params:
22 | cost_class: This is the relative weight of the classification error in the matching cost
23 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
24 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
25 | """
26 | super().__init__()
27 | self.cost_class = cost_class
28 | self.cost_bbox = cost_bbox
29 | self.cost_giou = cost_giou
30 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
31 |
32 | @torch.no_grad()
33 | def forward(self, outputs, targets):
34 | """ Performs the matching
35 | Params:
36 | outputs: This is a dict that contains at least these entries:
37 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
38 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
39 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
40 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
41 | objects in the target) containing the class labels
42 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
43 | Returns:
44 | A list of size batch_size, containing tuples of (index_i, index_j) where:
45 | - index_i is the indices of the selected predictions (in order)
46 | - index_j is the indices of the corresponding selected targets (in order)
47 | For each batch element, it holds:
48 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
49 | """
50 | bs, num_queries = outputs["pred_logits"].shape[:2]
51 |
52 | # We flatten to compute the cost matrices in a batch
53 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes]
54 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4]
55 |
56 | # Also concat the target labels and boxes
57 | tgt_ids = torch.cat([v["labels"] for v in targets])
58 | tgt_bbox = torch.cat([v["boxes"] for v in targets])
59 |
60 | # Compute the classification cost. Contrary to the loss, we don't use the NLL,
61 | # but approximate it in 1 - proba[target class].
62 | # The 1 is a constant that doesn't change the matching, it can be omitted.
63 | cost_class = -out_prob[:, tgt_ids]
64 |
65 | # Compute the L1 cost between boxes
66 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
67 |
68 | # Compute the giou cost between boxes
69 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
70 |
71 | # Final cost matrix
72 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
73 | C = C.view(bs, num_queries, -1).cpu()
74 |
75 | sizes = [len(v["boxes"]) for v in targets]
76 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
77 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
78 |
79 |
80 | def build_matcher(args):
81 | return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou)
--------------------------------------------------------------------------------
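
Usage sketch (not from the repository): running the matcher on random predictions against a two-object target; the cost weights below are the DETR defaults and are an assumption here.

import torch
from src.models.detr_matcher import HungarianMatcher

matcher = HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2)
outputs = {
    'pred_logits': torch.randn(1, 100, 92),    # [batch, num_queries, num_classes]
    'pred_boxes':  torch.rand(1, 100, 4),      # normalized cxcywh
}
targets = [{
    'labels': torch.tensor([3, 17]),
    'boxes':  torch.tensor([[0.50, 0.50, 0.20, 0.20],
                            [0.30, 0.70, 0.10, 0.40]]),
}]
(pred_idx, tgt_idx), = matcher(outputs, targets)   # one (pred, tgt) index pair per image
print(pred_idx, tgt_idx)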
/src/engine/evaluator_vcoco.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/engine/evaluator_vcoco.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | import os
9 | import torch
10 | import time
11 | import datetime
12 |
13 | import src.util.misc as utils
14 | import src.util.logger as loggers
15 | from src.data.evaluators.vcoco_eval import VCocoEvaluator
16 | from src.util.box_ops import rescale_bboxes, rescale_pairs
17 | from src.models.stip_utils import check_annotation, plot_cross_attention
18 |
19 | import wandb
20 |
21 | @torch.no_grad()
22 | def vcoco_evaluate(model, criterion, postprocessors, data_loader, device, output_dir, thr):
23 | model.eval()
24 | criterion.eval()
25 |
26 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
27 | header = 'Evaluation Inference (V-COCO)'
28 |
29 | print_freq = 1 # len(data_loader)
30 | res = {}
31 | hoi_recognition_time = []
32 |
33 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
34 | samples = samples.to(device)
35 | targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
36 |
37 | # dec_selfattn_weights, dec_crossattn_weights = [], []
38 | # hook = model.interaction_decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_crossattn_weights.append(output[1]))
39 |
40 | outputs = model(samples)
41 | loss_dict = criterion(outputs, targets)
42 | loss_dict_reduced = utils.reduce_dict(loss_dict) # ddp gathering
43 |
44 | loss_dict_reduced_scaled = {k: v * criterion.weight_dict[k] for k, v in loss_dict_reduced.items() if k in criterion.weight_dict}
45 | loss_value = sum(loss_dict_reduced_scaled.values()).item()
46 |
47 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
48 | results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='vcoco')
49 | targets = process_target(targets, orig_target_sizes)
50 | hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)
51 |
52 | # check_annotation(samples, targets, mode='eval', rel_num=20, dataset='vcoco')
53 | # plot_cross_attention(samples, outputs, targets, dec_crossattn_weights, dataset='vcoco'); hook.remove()
54 |
55 | res.update(
56 | {target['image_id'].item():\
57 | {'target': target, 'prediction': output} for target, output in zip(targets, results)
58 | }
59 | )
60 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled)
61 | # if len(res) > 10: break
62 |
63 | print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")
64 | metric_logger.synchronize_between_processes()
65 | print("Averaged validation stats:", metric_logger)
66 |
67 | start_time = time.time()
68 | gather_res = utils.all_gather(res)
69 | total_res = {}
70 | for dist_res in gather_res:
71 | total_res.update(dist_res)
72 | total_time = time.time() - start_time
73 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
74 | print(f"[stats] Distributed Gathering Time : {total_time_str}")
75 |
76 | return total_res
77 |
78 | def vcoco_accumulate(total_res, args, print_results, wandb_log):
79 | vcoco_evaluator = VCocoEvaluator(args)
80 | vcoco_evaluator.update(total_res)
81 | print(f"[stats] Score Matrix Generation completed!! ")
82 |
83 | scenario1 = vcoco_evaluator.role_eval1.evaluate(print_results)
84 | scenario2 = vcoco_evaluator.role_eval2.evaluate(print_results)
85 |
86 | if wandb_log:
87 | wandb.log({
88 | 'scenario1': scenario1,
89 | 'scenario2': scenario2
90 | })
91 |
92 | return scenario1, scenario2
93 |
94 | def process_target(targets, target_sizes):
95 | for idx, (target, target_size) in enumerate(zip(targets, target_sizes)):
96 | labels = target['labels']
97 | valid_boxes_inds = (labels > 0)
98 |
99 | targets[idx]['boxes'] = rescale_bboxes(target['boxes'], target_size) # boxes
100 | targets[idx]['pair_boxes'] = rescale_pairs(target['pair_boxes'], target_size) # pairs
101 |
102 | return targets
--------------------------------------------------------------------------------
/src/engine/evaluator_hico.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import sys
4 | from typing import Iterable
5 | import numpy as np
6 | import copy
7 | import itertools
8 | import matplotlib.pyplot as plt
9 | import torch
10 |
11 | import src.util.misc as utils
12 | import src.util.logger as loggers
13 | from src.data.evaluators.hico_eval import HICOEvaluator
14 | from src.models.stip_utils import check_annotation, plot_cross_attention, plot_hoi_results
15 |
16 | @torch.no_grad()
17 | def hico_evaluate(model, postprocessors, data_loader, device, thr, args):
18 | model.eval()
19 |
20 | metric_logger = loggers.MetricLogger(mode="test", delimiter=" ")
21 | header = 'Evaluation Inference (HICO-DET)'
22 |
23 | preds = []
24 | gts = []
25 | indices = []
26 | hoi_recognition_time = []
27 |
28 | for samples, targets in metric_logger.log_every(data_loader, 50, header):
29 | samples = samples.to(device)
30 | targets = [{k: (v.to(device) if k != 'id' else v) for k, v in t.items()} for t in targets]
31 |
32 | # # register hooks to obtain intermediate outputs
33 | # dec_selfattn_weights, dec_crossattn_weights = [], []
34 | # if 'HOTR' in type(model).__name__:
35 | # hook_self = model.interaction_transformer.decoder.layers[-1].self_attn.register_forward_hook(lambda self, input, output: dec_selfattn_weights.append(output[1]))
36 | # hook_cross = model.interaction_transformer.decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_crossattn_weights.append(output[1]))
37 | # else:
38 | # hook_self = model.interaction_decoder.layers[-1].self_attn.register_forward_hook(lambda self, input, output: dec_selfattn_weights.append(output[1]))
39 | # hook_cross = model.interaction_decoder.layers[-1].multihead_attn.register_forward_hook(lambda self, input, output: dec_crossattn_weights.append(output[1]))
40 |
41 | outputs = model(samples)
42 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
43 | results = postprocessors['hoi'](outputs, orig_target_sizes, threshold=thr, dataset='hico-det')
44 | hoi_recognition_time.append(results[0]['hoi_recognition_time'] * 1000)
45 |
46 | # # visualize
47 | # if targets[0]['id'] in [57]: # [47, 57, 81, 30, 46, 97]: # 30, 46, 97
48 | # # check_annotation(samples, targets, mode='eval', rel_num=20)
49 | #
50 | # # visualize cross-attentioa
51 | # if 'HOTR' in type(model).__name__:
52 | # outputs['pred_actions'] = outputs['pred_actions'][:, :, :args.num_actions]
53 | # outputs['pred_rel_pairs'] = [x.cpu() for x in torch.stack([outputs['pred_hidx'].argmax(-1), outputs['pred_oidx'].argmax(-1)], dim=-1)]
54 | # topk_qids, q_name_list = plot_hoi_results(samples, outputs, targets, args=args)
55 | # plot_cross_attention(samples, outputs, targets, dec_crossattn_weights, topk_qids=topk_qids)
56 | # print(f"image_id={targets[0]['id']}")
57 | #
58 | # # visualize self attention
59 | # print('visualize self-attention')
60 | # q_num = len(dec_selfattn_weights[0][0])
61 | # plt.figure(figsize=(10,4))
62 | # plt.imshow(dec_selfattn_weights[0][0].cpu().numpy(), vmin=0, vmax=0.4)
63 | # plt.xticks(np.arange(q_num), [f"{i}" for i in range(q_num)], rotation=90, fontsize=12)
64 | # plt.yticks(np.arange(q_num), [f"({q_name_list[i]})={i}" for i in range(q_num)], fontsize=12)
65 | # plt.gca().xaxis.set_ticks_position('top')
66 | # plt.grid(alpha=0.4, linestyle=':')
67 | # plt.show()
68 | # hook_self.remove(); hook_cross.remove()
69 |
70 | preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results))))
71 | # Deep-copy targets to avoid a runtime error when gathering them across processes
72 | gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets)))))
73 |
74 |
75 | print(f"[stats] HOI Recognition Time (avg) : {sum(hoi_recognition_time)/len(hoi_recognition_time):.4f} ms")
76 |
77 | # gather the stats from all processes
78 | metric_logger.synchronize_between_processes()
79 |
80 | img_ids = [img_gts['id'] for img_gts in gts]
81 | _, indices = np.unique(img_ids, return_index=True)
82 | preds = [img_preds for i, img_preds in enumerate(preds) if i in indices]
83 | gts = [img_gts for i, img_gts in enumerate(gts) if i in indices]
84 |
85 | evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets,
86 | data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat)
87 |
88 | stats = evaluator.evaluate()
89 |
90 | return stats
91 |
92 |
--------------------------------------------------------------------------------
/src/models/backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Backbone modules.
4 | """
5 | from collections import OrderedDict
6 |
7 | import torch
8 | import torch.nn.functional as F
9 | import torchvision
10 | from torch import nn
11 | from torchvision.models._utils import IntermediateLayerGetter
12 | from typing import Dict, List
13 |
14 | from src.util.misc import NestedTensor, is_main_process
15 |
16 | from .position_encoding import build_position_encoding
17 |
18 |
19 | class FrozenBatchNorm2d(torch.nn.Module):
20 | """
21 | BatchNorm2d where the batch statistics and the affine parameters are fixed.
22 | Copy-paste from torchvision.misc.ops with added eps before rsqrt,
23 | without which any other models than torchvision.models.resnet[18,34,50,101]
24 | produce nans.
25 | """
26 |
27 | def __init__(self, n):
28 | super(FrozenBatchNorm2d, self).__init__()
29 | self.register_buffer("weight", torch.ones(n))
30 | self.register_buffer("bias", torch.zeros(n))
31 | self.register_buffer("running_mean", torch.zeros(n))
32 | self.register_buffer("running_var", torch.ones(n))
33 |
34 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
35 | missing_keys, unexpected_keys, error_msgs):
36 | num_batches_tracked_key = prefix + 'num_batches_tracked'
37 | if num_batches_tracked_key in state_dict:
38 | del state_dict[num_batches_tracked_key]
39 |
40 | super(FrozenBatchNorm2d, self)._load_from_state_dict(
41 | state_dict, prefix, local_metadata, strict,
42 | missing_keys, unexpected_keys, error_msgs)
43 |
44 | def forward(self, x):
45 | # move reshapes to the beginning
46 | # to make it fuser-friendly
47 | w = self.weight.reshape(1, -1, 1, 1)
48 | b = self.bias.reshape(1, -1, 1, 1)
49 | rv = self.running_var.reshape(1, -1, 1, 1)
50 | rm = self.running_mean.reshape(1, -1, 1, 1)
51 | eps = 1e-5
52 | scale = w * (rv + eps).rsqrt()
53 | bias = b - rm * scale
54 | return x * scale + bias
55 |
56 |
57 | class BackboneBase(nn.Module):
58 |
59 | def __init__(self, args, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool):
60 | super().__init__()
61 | for name, parameter in backbone.named_parameters():
62 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
63 | parameter.requires_grad_(False) # fix other layers
64 | if return_interm_layers:
65 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"}
66 | else:
67 | if args.use_high_resolution_relation_feature_map:
68 | return_layers = {'layer3': "0", 'layer4': "1"}
69 | else:
70 | return_layers = {'layer4': "0"}
71 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
72 | self.num_channels = num_channels
73 |
74 | def forward(self, tensor_list: NestedTensor):
75 | xs = self.body(tensor_list.tensors)
76 | out: Dict[str, NestedTensor] = {}
77 | for name, x in xs.items():
78 | m = tensor_list.mask
79 | assert m is not None
80 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0]
81 | out[name] = NestedTensor(x, mask)
82 | return out
83 |
84 |
85 | class Backbone(BackboneBase):
86 | """ResNet backbone with frozen BatchNorm."""
87 | def __init__(self, args, name: str,
88 | train_backbone: bool,
89 | return_interm_layers: bool,
90 | dilation: bool):
91 | backbone = getattr(torchvision.models, name)(
92 | replace_stride_with_dilation=[False, False, dilation],
93 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d)
94 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048
95 | super().__init__(args, backbone, train_backbone, num_channels, return_interm_layers)
96 |
97 |
98 | class Joiner(nn.Sequential):
99 | def __init__(self, backbone, position_embedding):
100 | super().__init__(backbone, position_embedding)
101 |
102 | def forward(self, tensor_list: NestedTensor):
103 | xs = self[0](tensor_list)
104 | out: List[NestedTensor] = []
105 | pos = []
106 | for name, x in xs.items():
107 | out.append(x)
108 | # position encoding
109 | pos.append(self[1](x).to(x.tensors.dtype))
110 |
111 | return out, pos
112 |
113 |
114 | def build_backbone(args):
115 | position_embedding = build_position_encoding(args)
116 | train_backbone = args.lr_backbone > 0
117 | return_interm_layers = False # args.masks
118 | backbone = Backbone(args, args.backbone, train_backbone, return_interm_layers, args.dilation)
119 | model = Joiner(backbone, position_embedding)
120 | model.num_channels = backbone.num_channels
121 | return model
122 |
--------------------------------------------------------------------------------
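
Usage sketch (not from the repository): building the joined backbone + position encoding from a hypothetical argument namespace (the real fields and defaults live in src/engine/arg_parser.py; constructing the ResNet may download torchvision's pretrained weights).

import argparse
import torch
from src.models.backbone import build_backbone
from src.util.misc import NestedTensor

args = argparse.Namespace(backbone='resnet50', dilation=False, lr_backbone=1e-5,
                          hidden_dim=256, position_embedding='sine',
                          use_high_resolution_relation_feature_map=False)
backbone = build_backbone(args)

images = NestedTensor(torch.randn(2, 3, 224, 224),
                      torch.zeros(2, 224, 224, dtype=torch.bool))   # False = valid pixel
features, pos = backbone(images)
print(features[0].tensors.shape, pos[0].shape)   # [2, 2048, 7, 7] and [2, 256, 7, 7]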
/src/data/datasets/coco.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | COCO dataset which returns image_id for evaluation.
4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py
5 | """
6 | from pathlib import Path
7 |
8 | import torch
9 | import torch.utils.data
10 | import torchvision
11 | from pycocotools import mask as coco_mask
12 |
13 | import src.data.transforms.transforms as T
14 |
15 | class CocoDetection(torchvision.datasets.CocoDetection):
16 | def __init__(self, img_folder, ann_file, transforms, return_masks):
17 | super(CocoDetection, self).__init__(img_folder, ann_file)
18 | self._transforms = transforms
19 | self.prepare = ConvertCocoPolysToMask(return_masks)
20 |
21 | def __getitem__(self, idx):
22 | img, target = super(CocoDetection, self).__getitem__(idx)
23 | image_id = self.ids[idx]
24 | target = {'image_id': image_id, 'annotations': target}
25 | img, target = self.prepare(img, target)
26 | if self._transforms is not None:
27 | img, target = self._transforms(img, target)
28 | return img, target
29 |
30 |
31 | def convert_coco_poly_to_mask(segmentations, height, width):
32 | masks = []
33 | for polygons in segmentations:
34 | rles = coco_mask.frPyObjects(polygons, height, width)
35 | mask = coco_mask.decode(rles)
36 | if len(mask.shape) < 3:
37 | mask = mask[..., None]
38 | mask = torch.as_tensor(mask, dtype=torch.uint8)
39 | mask = mask.any(dim=2)
40 | masks.append(mask)
41 | if masks:
42 | masks = torch.stack(masks, dim=0)
43 | else:
44 | masks = torch.zeros((0, height, width), dtype=torch.uint8)
45 | return masks
46 |
47 |
48 | class ConvertCocoPolysToMask(object):
49 | def __init__(self, return_masks=False):
50 | self.return_masks = return_masks
51 |
52 | def __call__(self, image, target):
53 | w, h = image.size
54 |
55 | image_id = target["image_id"]
56 | image_id = torch.tensor([image_id])
57 |
58 | anno = target["annotations"]
59 |
60 | anno = [obj for obj in anno if 'iscrowd' not in obj or obj['iscrowd'] == 0]
61 |
62 | boxes = [obj["bbox"] for obj in anno]
63 | # guard against no boxes via resizing
64 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
65 | boxes[:, 2:] += boxes[:, :2] # (x1, y1, w, h) -> (x1, y1, x2, y2)
66 | boxes[:, 0::2].clamp_(min=0, max=w)
67 | boxes[:, 1::2].clamp_(min=0, max=h)
68 |
69 | classes = [obj["category_id"] for obj in anno]
70 | classes = torch.tensor(classes, dtype=torch.int64)
71 |
72 | if self.return_masks:
73 | segmentations = [obj["segmentation"] for obj in anno]
74 | masks = convert_coco_poly_to_mask(segmentations, h, w)
75 |
76 | keypoints = None
77 | if anno and "keypoints" in anno[0]:
78 | keypoints = [obj["keypoints"] for obj in anno]
79 | keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
80 | num_keypoints = keypoints.shape[0]
81 | if num_keypoints:
82 | keypoints = keypoints.view(num_keypoints, -1, 3)
83 |
84 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
85 | boxes = boxes[keep]
86 | classes = classes[keep]
87 | if self.return_masks:
88 | masks = masks[keep]
89 | if keypoints is not None:
90 | keypoints = keypoints[keep]
91 |
92 | target = {}
93 | target["boxes"] = boxes
94 | target["labels"] = classes
95 | if self.return_masks:
96 | target["masks"] = masks
97 | target["image_id"] = image_id
98 | if keypoints is not None:
99 | target["keypoints"] = keypoints
100 |
101 | # for conversion to coco api
102 | area = torch.tensor([obj["area"] for obj in anno])
103 | iscrowd = torch.tensor([obj["iscrowd"] if "iscrowd" in obj else 0 for obj in anno])
104 | target["area"] = area[keep]
105 | target["iscrowd"] = iscrowd[keep]
106 |
107 | target["orig_size"] = torch.as_tensor([int(h), int(w)])
108 | target["size"] = torch.as_tensor([int(h), int(w)])
109 |
110 | return image, target
111 |
112 |
113 | def make_coco_transforms(image_set):
114 |
115 | normalize = T.Compose([
116 | T.ToTensor(),
117 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
118 | ])
119 |
120 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
121 |
122 | if image_set == 'train':
123 | return T.Compose([
124 | T.RandomHorizontalFlip(),
125 | T.RandomSelect(
126 | T.RandomResize(scales, max_size=1333),
127 | T.Compose([
128 | T.RandomResize([400, 500, 600]),
129 | T.RandomSizeCrop(384, 600),
130 | T.RandomResize(scales, max_size=1333),
131 | ])
132 | ),
133 | normalize,
134 | ])
135 |
136 | if image_set == 'val':
137 | return T.Compose([
138 | T.RandomResize([800], max_size=1333),
139 | normalize,
140 | ])
141 |
142 | raise ValueError(f'unknown {image_set}')
143 |
144 |
145 | def build(image_set, args):
146 | root = Path(args.data_path)
147 | assert root.exists(), f'provided COCO path {root} does not exist'
148 | mode = 'instances'
149 | PATHS = {
150 | "train": (root / "train2017", root / "annotations" / f'{mode}_train2017.json'),
151 | "val": (root / "val2017", root / "annotations" / f'{mode}_val2017.json'),
152 | }
153 |
154 | img_folder, ann_file = PATHS[image_set]
155 | dataset = CocoDetection(img_folder, ann_file, transforms=make_coco_transforms(image_set), return_masks=args.masks)
156 | return dataset
--------------------------------------------------------------------------------
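
A brief usage sketch for the dataset module above (not part of the repository): it shows how `build()` ties `CocoDetection` and `make_coco_transforms()` into a standard PyTorch dataloader. The `Namespace` fields, the placeholder path, and the DETR-style `collate_fn` assumed to live in `src/util/misc.py` are illustrative assumptions, not code from this repo.

```python
# Hypothetical usage of build()/CocoDetection above; the data_path placeholder and the
# DETR-style collate_fn (assumed to exist in src/util/misc.py) are illustrative only.
from argparse import Namespace

import torch.utils.data

from src.data.datasets.coco import build
from src.util.misc import collate_fn  # assumption: same helper as in DETR's util/misc.py

args = Namespace(data_path='/path/to/coco', masks=False)  # must contain val2017/ and annotations/
dataset_val = build('val', args)                          # CocoDetection with the 'val' transforms

loader = torch.utils.data.DataLoader(
    dataset_val, batch_size=2, shuffle=False,
    collate_fn=collate_fn, num_workers=2)

images, targets = next(iter(loader))                      # padded image batch + per-image target dicts
```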
/src/util/logger.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/util/logger.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | import torch
9 | import time
10 | import datetime
11 | import sys
12 | from time import sleep
13 | from collections import defaultdict
14 |
15 | from src.util.misc import SmoothedValue
16 |
17 | def print_params(model):
18 | n_parameters = sum(p.numel() for p in model.parameters())
19 | print('\n[Logger] Number of total params: ', n_parameters)
20 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
21 | print('\n[Logger] Number of trainable params: ', n_parameters)
22 | return n_parameters
23 |
24 | def print_args(args):
25 | print('\n[Logger] DETR Arguments:')
26 | for k, v in vars(args).items():
27 | if k in [
28 | 'lr', 'lr_backbone', 'lr_drop',
29 | 'frozen_weights',
30 | 'backbone', 'dilation',
31 | 'position_embedding', 'enc_layers', 'dec_layers', 'num_queries',
32 | 'dataset_file']:
33 | print(f'\t{k}: {v}')
34 |
35 | if args.HOIDet:
36 | print('\n[Logger] DETR_HOI Arguments:')
37 | for k, v in vars(args).items():
38 | if k in [
39 | 'freeze_enc',
40 | 'query_flag',
41 | 'hoi_nheads',
42 | 'hoi_dim_feedforward',
43 | 'hoi_dec_layers',
44 | 'hoi_idx_loss_coef',
45 | 'hoi_act_loss_coef',
46 | 'hoi_eos_coef',
47 | 'object_threshold']:
48 | print(f'\t{k}: {v}')
49 |
50 | class MetricLogger(object):
51 | def __init__(self, mode="test", delimiter="\t"):
52 | self.meters = defaultdict(SmoothedValue)
53 | self.delimiter = delimiter
54 | self.mode = mode
55 |
56 | def update(self, **kwargs):
57 | for k, v in kwargs.items():
58 | if isinstance(v, torch.Tensor):
59 | v = v.item()
60 | assert isinstance(v, (float, int))
61 | self.meters[k].update(v)
62 |
63 | def __getattr__(self, attr):
64 | if attr in self.meters:
65 | return self.meters[attr]
66 | if attr in self.__dict__:
67 | return self.__dict__[attr]
68 | raise AttributeError("'{}' object has no attribute '{}'".format(
69 | type(self).__name__, attr))
70 |
71 | def __str__(self):
72 | loss_str = []
73 | for name, meter in self.meters.items():
74 | loss_str.append(
75 | "{}: {}".format(name, str(meter))
76 | )
77 | return self.delimiter.join(loss_str)
78 |
79 | def synchronize_between_processes(self):
80 | for meter in self.meters.values():
81 | meter.synchronize_between_processes()
82 |
83 | def add_meter(self, name, meter):
84 | self.meters[name] = meter
85 |
86 | def log_every(self, iterable, print_freq, header=None):
87 | i = 0
88 | if not header:
89 | header = ''
90 | start_time = time.time()
91 | end = time.time()
92 | iter_time = SmoothedValue(fmt='{avg:.4f}')
93 | data_time = SmoothedValue(fmt='{avg:.4f}')
94 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
95 | if torch.cuda.is_available():
96 | log_msg = self.delimiter.join([
97 | header,
98 | '[{0' + space_fmt + '}/{1}]',
99 | 'eta: {eta}',
100 | '{meters}',
101 | 'time: {time}',
102 | 'data: {data}',
103 | 'max mem: {memory:.0f}'
104 | ])
105 | else:
106 | log_msg = self.delimiter.join([
107 | header,
108 | '[{0' + space_fmt + '}/{1}]',
109 | 'eta: {eta}',
110 | '{meters}',
111 | 'time: {time}',
112 | 'data: {data}'
113 | ])
114 | MB = 1024.0 * 1024.0
115 | for obj in iterable:
116 | data_time.update(time.time() - end)
117 | yield obj
118 | iter_time.update(time.time() - end)
119 |
120 | if (i % print_freq == 0 and i !=0) or i == len(iterable) - 1:
121 | eta_seconds = iter_time.global_avg * (len(iterable) - i)
122 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
123 | if torch.cuda.is_available():
124 | print(log_msg.format(
125 | i+1, len(iterable), eta=eta_string,
126 | meters=str(self),
127 | time=str(iter_time), data=str(data_time),
128 | memory=torch.cuda.max_memory_allocated() / MB),
129 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
130 | else:
131 | print(log_msg.format(
132 | i+1, len(iterable), eta=eta_string,
133 | meters=str(self),
134 | time=str(iter_time), data=str(data_time)),
135 | flush=(self.mode=='test'), end=("\r" if self.mode=='test' else "\n"))
136 | else:
137 | log_interval = self.delimiter.join([header, '[{0' + space_fmt + '}/{1}]'])
138 | # the brief progress line is identical with or without CUDA
139 | print(log_interval.format(i+1, len(iterable)), flush=True, end="\r")
140 |
141 | i += 1
142 | end = time.time()
143 | total_time = time.time() - start_time
144 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
145 | if self.mode=='test': print("")
146 | print('[stats] Total Time ({}) : {} ({:.4f} s / it)'.format(
147 | self.mode, total_time_str, total_time / len(iterable)))
148 |
--------------------------------------------------------------------------------
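
A minimal, hypothetical sketch of how `MetricLogger` above is typically driven; the dummy iterable and placeholder loss values are fabricated for illustration only.

```python
# Illustrative only: feed MetricLogger.log_every a dummy iterable and update it with scalars.
import torch

from src.util.logger import MetricLogger

logger = MetricLogger(mode="train", delimiter="  ")

for batch in logger.log_every(range(20), print_freq=5, header='Epoch: [0]'):
    loss = torch.rand(1)               # placeholder for a real training loss
    logger.update(loss=loss, lr=1e-4)  # tensors are converted via .item(); floats/ints pass through

print("Averaged stats:", logger)       # __str__ joins the smoothed meters with the delimiter
```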
/README.md:
--------------------------------------------------------------------------------
1 | ## [CVPR 2022] Exploring Structure-aware Transformer over Interaction Proposals for Human-object Interaction Detection
2 |
3 | * [PDF](https://openaccess.thecvf.com/content/CVPR2022/papers/Zhang_Exploring_Structure-Aware_Transformer_Over_Interaction_Proposals_for_Human-Object_Interaction_Detection_CVPR_2022_paper.pdf)
4 | * [Supplementary](https://openaccess.thecvf.com/content/CVPR2022/supplemental/Zhang_Exploring_Structure-Aware_Transformer_CVPR_2022_supplemental.pdf)
5 |
6 | ## Paper introduction
7 |
8 | Recent high-performing Human-Object Interaction (HOI) detection techniques have been strongly influenced by Transformer-based object detectors (i.e., DETR). Nevertheless, most of them directly map parametric interaction queries into a set of HOI predictions through a vanilla Transformer in a one-stage manner, leaving the rich inter- and intra-interaction structure under-exploited. In this work, we design a novel Transformer-style HOI detector, Structure-aware Transformer over Interaction Proposals (STIP). This design decomposes HOI set prediction into two subsequent phases: interaction proposals are first generated and then transformed into HOI predictions via a structure-aware Transformer. The structure-aware Transformer upgrades the vanilla Transformer by additionally encoding the holistic semantic structure among interaction proposals as well as the local spatial structure of the human/object within each interaction proposal, thereby strengthening HOI predictions. Extensive experiments on the V-COCO and HICO-DET benchmarks demonstrate the effectiveness of STIP, with superior results compared with state-of-the-art HOI detectors.
9 |
10 | ![STIP overview](imgs/STIP.png)
11 |
12 | ## Installation
13 |
14 | ### 1. Environmental Setup
15 | ```bash
16 | $ conda create -n STIP python=3.7
17 | $ conda install -c pytorch pytorch torchvision # PyTorch 1.7.1, torchvision 0.8.2, CUDA=11.0
18 | $ conda install cython scipy
19 | $ pip install pycocotools
20 | $ pip install opencv-python
21 | $ pip install wandb
22 | ```
23 |
24 | ### 2. HOI dataset setup
25 | Our current version supports experiments on the following datasets:
26 | - [V-COCO](https://github.com/s-gupta/v-coco) and [HICO-DET](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view). Download the datasets into the cloned repository directory.
27 | - For HICO-DET, ~~we use the [annotation files](https://drive.google.com/file/d/1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk/view) provided by the PPDM authors. Download the [list of actions](https://drive.google.com/open?id=1EeHNHuYyJI-qqDk_-5nay7Mb07tzZLsl) as `list_action.txt` and place it under the extracted hico-det directory.~~ If these download links are broken, please download the data from [hico_20160224_det](https://cuhko365-my.sharepoint.com/:u:/g/personal/218019030_link_cuhk_edu_cn/EY3AbysMMUBJmQEt8zl4j9UBxHH6no2-g4o63EIt38GkgQ?e=oCEIsG).
28 |
29 | Below is the expected directory layout.
30 | ```bash
31 | # V-COCO setup
32 | $ git clone https://github.com/s-gupta/v-coco.git
33 | $ cd v-coco
34 | $ ln -s [:COCO_DIR] coco/images # COCO_DIR contains images of train2014 & val2014
35 | $ python script_pick_annotations.py [:COCO_DIR]/annotations
36 |
37 | # HICO-DET setup
38 | $ tar -zxvf hico_20160224_det.tar.gz # move the extracted folder under the cloned repository
39 |
40 | # dataset setup
41 | STIP
42 | │─ v-coco
43 | │ │─ data
44 | │ │ │─ instances_vcoco_all_2014.json
45 | │ │ :
46 | │ └─ coco
47 | │ │─ images
48 | │ │ │─ train2014
49 | │ │ │ │─ COCO_train2014_000000000009.jpg
50 | │ │ │ :
51 | │ │ └─ val2014
52 | │ │ │─ COCO_val2014_000000000042.jpg
53 | : : :
54 | │─ hico_20160224_det
55 | │ │─ list_action.txt
56 | │ │─ annotations
57 | │ │ │─ trainval_hico.json
58 | │ │ │─ test_hico.json
59 | │ │ └─ corre_hico.npy
60 | : :
61 | ```
62 |
63 | If you wish to keep the datasets in your own directory, simply point the `--data_path` argument to the directory where you downloaded them.
64 | ```bash
65 | --data_path [:your_own_directory]/[v-coco/hico_20160224_det]
66 | ```
67 |
68 | ### 3. Training/Testing on V-COCO
69 |
70 | ```shell
71 | python STIP_main.py --validate \
72 | --num_hoi_queries 32 --batch_size 4 --lr 5e-5 --HOIDet --hoi_aux_loss --no_aux_loss \
73 | --dataset_file vcoco --data_path v-coco --detr_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \
74 | --output_dir checkpoints/vcoco --group_name STIP_debug --run_name vcoco_run1
75 | ```
76 |
77 | * Add the `--eval` option to run evaluation only
78 |
79 | ### 4. Training/Testing on HICO-DET
80 |
81 | Training with a DETR detector pretrained on COCO:
82 | ```shell
83 | python STIP_main.py --validate \
84 | --num_hoi_queries 32 --batch_size 4 --lr 5e-5 --HOIDet --hoi_aux_loss --no_aux_loss \
85 | --dataset_file hico-det --data_path hico_20160224_det --detr_weights https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth \
86 | --output_dir checkpoints/hico-det --group_name STIP_debug --run_name hicodet_run1
87 | ```
88 |
89 | Jointly fine-tune the object detector and the HOI detector:
90 | ```shell
91 | python STIP_main.py --validate \
92 | --num_hoi_queries 32 --batch_size 2 --lr 1e-5 --HOIDet --hoi_aux_loss \
93 | --dataset_file hico-det --data_path hico_20160224_det \
94 | --output_dir checkpoints/hico-det --group_name STIP_debug --run_name hicodet_run1/jointly-tune \
95 | --resume checkpoints/hico-det/STIP_debug/best.pth --train_detr
96 | ```
97 |
98 | Pre-trained models:
99 | * [VCOCO](https://cuhko365-my.sharepoint.com/:u:/g/personal/218019030_link_cuhk_edu_cn/ETPIGJHooGVPuOXDBC3WTaoBQBGHXioj4hAbjllSTSOn6A?e=KBN0bq)
100 | * [HICO-DET (w/o jointly finetune)](https://cuhko365-my.sharepoint.com/:u:/g/personal/218019030_link_cuhk_edu_cn/EZyYzmW7vL9Bh7zKwT1JGFEBKbTkbs4vD9QW7kfs8EV_gQ?e=NjKBjX)
101 |
102 | ## License
103 |
104 | This repo is released under the [Apache License, Version 2.0](LICENSE).
105 |
106 |
107 | ## Acknowledgement
108 |
109 | This repo is based on [DETR](https://github.com/facebookresearch/detr) and [HOTR](https://github.com/kakaobrain/HOTR). Thanks for their wonderful work.
110 |
111 |
112 |
113 | ## Citation
114 |
115 | If you find this code helpful for your research, please cite our paper.
116 | ```
117 | @InProceedings{Zhang_2022_CVPR,
118 | author = {Zhang, Yong and Pan, Yingwei and Yao, Ting and Huang, Rui and Mei, Tao and Chen, Chang-Wen},
119 | title = {Exploring Structure-Aware Transformer Over Interaction Proposals for Human-Object Interaction Detection},
120 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
121 | month = {June},
122 | year = {2022},
123 | pages = {19548-19557}
124 | }
125 | ```
126 |
--------------------------------------------------------------------------------
/src/data/datasets/builtin_meta.py:
--------------------------------------------------------------------------------
1 | COCO_CATEGORIES = [
2 | {"color": [], "isthing": 0, "id": 0, "name": "N/A"},
3 | {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"},
4 | {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"},
5 | {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"},
6 | {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"},
7 | {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"},
8 | {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"},
9 | {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"},
10 | {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"},
11 | {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"},
12 | {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"},
13 | {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"},
14 | {"color": [], "isthing": 0, "id": 12, "name": "N/A"},
15 | {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"},
16 | {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"},
17 | {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"},
18 | {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"},
19 | {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"},
20 | {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"},
21 | {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"},
22 | {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"},
23 | {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"},
24 | {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"},
25 | {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"},
26 | {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"},
27 | {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"},
28 | {"color": [], "isthing": 0, "id": 26, "name": "N/A"},
29 | {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"},
30 | {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"},
31 | {"color": [], "isthing": 0, "id": 29, "name": "N/A"},
32 | {"color": [], "isthing": 0, "id": 30, "name": "N/A"},
33 | {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"},
34 | {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"},
35 | {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"},
36 | {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"},
37 | {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"},
38 | {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"},
39 | {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"},
40 | {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"},
41 | {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"},
42 | {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"},
43 | {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"},
44 | {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"},
45 | {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"},
46 | {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"},
47 | {"color": [], "isthing": 0, "id": 45, "name": "N/A"},
48 | {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"},
49 | {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"},
50 | {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"},
51 | {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"},
52 | {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"},
53 | {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"},
54 | {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"},
55 | {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"},
56 | {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"},
57 | {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"},
58 | {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"},
59 | {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"},
60 | {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"},
61 | {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"},
62 | {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"},
63 | {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"},
64 | {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"},
65 | {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"},
66 | {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"},
67 | {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"},
68 | {"color": [], "isthing": 0, "id": 66, "name": "N/A"},
69 | {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"},
70 | {"color": [], "isthing": 0, "id": 68, "name": "N/A"},
71 | {"color": [], "isthing": 0, "id": 69, "name": "N/A"},
72 | {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"},
73 | {"color": [], "isthing": 0, "id": 71, "name": "N/A"},
74 | {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"},
75 | {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"},
76 | {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"},
77 | {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"},
78 | {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"},
79 | {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"},
80 | {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"},
81 | {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"},
82 | {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"},
83 | {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"},
84 | {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"},
85 | {"color": [], "isthing": 0, "id": 83, "name": "N/A"},
86 | {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"},
87 | {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"},
88 | {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"},
89 | {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"},
90 | {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"},
91 | {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"},
92 | {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"},
93 | ]
94 |
95 | def _get_coco_instances_meta():
96 | thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1]
97 | assert len(thing_ids) == 80, f"Length of thing ids : {len(thing_ids)}"
98 |
99 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}
100 | thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1]
101 | thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1]
102 |
103 | coco_classes = [k["name"] for k in COCO_CATEGORIES]
104 |
105 | return {
106 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id,
107 | "thing_classes": thing_classes,
108 | "thing_colors": thing_colors,
109 | "coco_classes": coco_classes,
110 | }
111 |
--------------------------------------------------------------------------------
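
A small illustrative check of the metadata produced by `_get_coco_instances_meta()` above: COCO category ids are sparse (1-90 with `N/A` gaps), while the contiguous ids used for training run 0-79 in the same order. This snippet is not part of the repository.

```python
# Illustrative only: inspect the id mapping built from COCO_CATEGORIES.
from src.data.datasets.builtin_meta import _get_coco_instances_meta

meta = _get_coco_instances_meta()
assert meta['thing_dataset_id_to_contiguous_id'][1] == 0    # 'person' (COCO id 1)  -> contiguous 0
assert meta['thing_dataset_id_to_contiguous_id'][90] == 79  # 'toothbrush' (id 90)  -> contiguous 79
assert len(meta['thing_classes']) == 80                     # 80 'thing' categories
assert len(meta['coco_classes']) == 91                      # ids 0-90, including the 'N/A' placeholders
print(meta['thing_classes'][:3])                            # ['person', 'bicycle', 'car']
```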
/src/models/hotr.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/models/hotr.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | import copy
9 | import time
10 | import datetime
11 |
12 | from src.util.misc import NestedTensor, nested_tensor_from_tensor_list
13 | from .feed_forward import MLP
14 |
15 | class HOTR(nn.Module):
16 | def __init__(self, detr,
17 | num_hoi_queries,
18 | num_actions,
19 | interaction_transformer,
20 | freeze_detr,
21 | share_enc,
22 | pretrained_dec,
23 | temperature,
24 | hoi_aux_loss,
25 | return_obj_class=None):
26 | super().__init__()
27 |
28 | # * Instance Transformer ---------------
29 | self.detr = detr
30 | if freeze_detr:
31 | # if this flag is given, freeze the object detection related parameters of DETR
32 | for p in self.parameters():
33 | p.requires_grad_(False)
34 | hidden_dim = detr.transformer.d_model
35 | # --------------------------------------
36 |
37 | # * Interaction Transformer -----------------------------------------
38 | self.num_queries = num_hoi_queries
39 | self.query_embed = nn.Embedding(self.num_queries, hidden_dim)
40 | self.H_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
41 | self.O_Pointer_embed = MLP(hidden_dim, hidden_dim, hidden_dim, 3)
42 | self.action_embed = nn.Linear(hidden_dim, num_actions+1)
43 | # --------------------------------------------------------------------
44 |
45 | # * HICO-DET FFN heads ---------------------------------------------
46 | self.return_obj_class = (return_obj_class is not None)
47 | if return_obj_class: self._valid_obj_ids = return_obj_class + [return_obj_class[-1]+1]
48 | # ------------------------------------------------------------------
49 |
50 | # * Transformer Options ---------------------------------------------
51 | self.interaction_transformer = interaction_transformer
52 |
53 | if share_enc: # share encoder
54 | self.interaction_transformer.encoder = detr.transformer.encoder
55 |
56 | if pretrained_dec: # free variables for interaction decoder
57 | self.interaction_transformer.decoder = copy.deepcopy(detr.transformer.decoder)
58 | for p in self.interaction_transformer.decoder.parameters():
59 | p.requires_grad_(True)
60 | # ---------------------------------------------------------------------
61 |
62 | # * Loss Options -------------------
63 | self.tau = temperature
64 | self.hoi_aux_loss = hoi_aux_loss
65 | # ----------------------------------
66 |
67 | def forward(self, samples: NestedTensor, targets=None):
68 | if isinstance(samples, (list, torch.Tensor)):
69 | samples = nested_tensor_from_tensor_list(samples)
70 |
71 | # >>>>>>>>>>>> BACKBONE LAYERS <<<<<<<<<<<<<<<
72 | features, pos = self.detr.backbone(samples)
73 | bs = features[-1].tensors.shape[0]
74 | src, mask = features[-1].decompose()
75 | assert mask is not None
76 | # ----------------------------------------------
77 |
78 | # >>>>>>>>>>>> OBJECT DETECTION LAYERS <<<<<<<<<<
79 | start_time = time.time()
80 | hs, _ = self.detr.transformer(self.detr.input_proj(src), mask, self.detr.query_embed.weight, pos[-1])
81 | inst_repr = F.normalize(hs[-1], p=2, dim=2) # instance representations
82 |
83 | # Prediction Heads for Object Detection
84 | outputs_class = self.detr.class_embed(hs)
85 | outputs_coord = self.detr.bbox_embed(hs).sigmoid()
86 | object_detection_time = time.time() - start_time
87 | # -----------------------------------------------
88 |
89 | # >>>>>>>>>>>> HOI DETECTION LAYERS <<<<<<<<<<<<<<<
90 | start_time = time.time()
91 | assert hasattr(self, 'interaction_transformer'), "Missing Interaction Transformer."
92 | interaction_hs = self.interaction_transformer(self.detr.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] # interaction representations
93 |
94 | # [HO Pointers]
95 | H_Pointer_reprs = F.normalize(self.H_Pointer_embed(interaction_hs), p=2, dim=-1)
96 | O_Pointer_reprs = F.normalize(self.O_Pointer_embed(interaction_hs), p=2, dim=-1)
97 | outputs_hidx = [(torch.bmm(H_Pointer_repr, inst_repr.transpose(1,2))) / self.tau for H_Pointer_repr in H_Pointer_reprs]
98 | outputs_oidx = [(torch.bmm(O_Pointer_repr, inst_repr.transpose(1,2))) / self.tau for O_Pointer_repr in O_Pointer_reprs]
99 |
100 | # [Action Classification]
101 | outputs_action = self.action_embed(interaction_hs)
102 | # --------------------------------------------------
103 | hoi_detection_time = time.time() - start_time
104 | hoi_recognition_time = max(hoi_detection_time - object_detection_time, 0)
105 | # -------------------------------------------------------------------
106 |
107 | # [Target Classification]
108 | if self.return_obj_class:
109 | detr_logits = outputs_class[-1, ..., self._valid_obj_ids]
110 | o_indices = [output_oidx.max(-1)[-1] for output_oidx in outputs_oidx]
111 | obj_logit_stack = [torch.stack([detr_logits[batch_, o_idx, :] for batch_, o_idx in enumerate(o_indice)], 0) for o_indice in o_indices]
112 | outputs_obj_class = obj_logit_stack
113 |
114 | out = {
115 | "pred_logits": outputs_class[-1],
116 | "pred_boxes": outputs_coord[-1],
117 | "pred_hidx": outputs_hidx[-1],
118 | "pred_oidx": outputs_oidx[-1],
119 | "pred_actions": outputs_action[-1],
120 | "hoi_recognition_time": hoi_recognition_time,
121 | }
122 |
123 | if self.return_obj_class: out["pred_obj_logits"] = outputs_obj_class[-1]
124 |
125 | if self.hoi_aux_loss: # auxiliary loss
126 | out['hoi_aux_outputs'] = \
127 | self._set_aux_loss_with_tgt(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_obj_class) \
128 | if self.return_obj_class else \
129 | self._set_aux_loss(outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action)
130 |
131 | return out
132 |
133 | @torch.jit.unused
134 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action):
135 | return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e}
136 | for a, b, c, d, e in zip(
137 | outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
138 | outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
139 | outputs_hidx[:-1],
140 | outputs_oidx[:-1],
141 | outputs_action[:-1])]
142 |
143 | @torch.jit.unused
144 | def _set_aux_loss_with_tgt(self, outputs_class, outputs_coord, outputs_hidx, outputs_oidx, outputs_action, outputs_tgt):
145 | return [{'pred_logits': a, 'pred_boxes': b, 'pred_hidx': c, 'pred_oidx': d, 'pred_actions': e, 'pred_obj_logits': f}
146 | for a, b, c, d, e, f in zip(
147 | outputs_class[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
148 | outputs_coord[-1:].repeat((outputs_action.shape[0], 1, 1, 1)),
149 | outputs_hidx[:-1],
150 | outputs_oidx[:-1],
151 | outputs_action[:-1],
152 | outputs_tgt[:-1])]
--------------------------------------------------------------------------------
/src/models/post_process.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/models/post_process.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | import time
6 | import copy
7 | import torch
8 | import torch.nn.functional as F
9 | from torch import nn
10 | from src.util import box_ops
11 |
12 | class PostProcess(nn.Module):
13 | """ This module converts the model's output into the format expected by the coco api"""
14 | def __init__(self, HOIDet):
15 | super().__init__()
16 | self.HOIDet = HOIDet
17 |
18 | @torch.no_grad()
19 | def forward(self, outputs, target_sizes, threshold=0, dataset='coco'):
20 | """ Perform the computation
21 | Parameters:
22 | outputs: raw outputs of the model
23 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch
24 | For evaluation, this must be the original image size (before any data augmentation)
25 | For visualization, this should be the image size after data augment, but before padding
26 | """
27 | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']
28 |
29 | assert len(out_logits) == len(target_sizes)
30 | assert target_sizes.shape[1] == 2
31 |
32 | prob = F.softmax(out_logits, -1)
33 | scores, labels = prob[..., :-1].max(-1)
34 |
35 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox)
36 | img_h, img_w = target_sizes.unbind(1)
37 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
38 | boxes = boxes * scale_fct[:, None, :]
39 |
40 | # Prediction branch for HOI detection
41 | if self.HOIDet:
42 | if dataset == 'vcoco':
43 | """ Compute HOI triplet prediction score for V-COCO.
44 | Our scoring function follows the implementation details of UnionDet.
45 | """
46 | out_time = outputs['hoi_recognition_time']
47 |
48 | start_time = time.time()
49 | pair_actions = torch.sigmoid(outputs['pred_actions'])
50 | h_prob = F.softmax(outputs['pred_hidx'], -1)
51 | h_idx_score, h_indices = h_prob.max(-1)
52 |
53 | o_prob = F.softmax(outputs['pred_oidx'], -1)
54 | o_idx_score, o_indices = o_prob.max(-1)
55 | hoi_recognition_time = (time.time() - start_time) + out_time
56 |
57 | results = []
58 | # iterate for batch size
59 | for batch_idx, (s, l, b) in enumerate(zip(scores, labels, boxes)):
60 | h_inds = (l == 1) & (s > threshold)
61 | o_inds = (s > threshold)
62 |
63 | h_box, h_cat = b[h_inds], s[h_inds]
64 | o_box, o_cat = b[o_inds], s[o_inds]
65 |
66 | # for scenario 1 in v-coco dataset
67 | o_inds = torch.cat((o_inds, torch.ones(1).type(torch.bool).to(o_inds.device))) # add an empty box
68 | o_box = torch.cat((o_box, torch.Tensor([0, 0, 0, 0]).unsqueeze(0).to(o_box.device)))
69 |
70 | result_dict = {
71 | 'h_box': h_box, 'h_cat': h_cat,
72 | 'o_box': o_box, 'o_cat': o_cat,
73 | 'scores': s, 'labels': l, 'boxes': b
74 | }
75 |
76 | h_inds_lst = (h_inds == True).nonzero(as_tuple=False).squeeze(-1)
77 | o_inds_lst = (o_inds == True).nonzero(as_tuple=False).squeeze(-1)
78 |
79 | K = boxes.shape[1]
80 | n_act = pair_actions[batch_idx][:, :-1].shape[-1]
81 | score = torch.zeros((n_act, K, K+1)).to(pair_actions[batch_idx].device)
82 | sorted_score = torch.zeros((n_act, K, K+1)).to(pair_actions[batch_idx].device)
83 | id_score = torch.zeros((K, K+1)).to(pair_actions[batch_idx].device)
84 |
85 | # Score function: strange HOTR scoring
86 | for hs, h_idx, os, o_idx, pair_action in zip(h_idx_score[batch_idx], h_indices[batch_idx], o_idx_score[batch_idx], o_indices[batch_idx], pair_actions[batch_idx]):
87 | matching_score = (1-pair_action[-1]) # no interaction score
88 | if h_idx == o_idx: o_idx = -1 # special case, head=tail
89 | if matching_score > id_score[h_idx, o_idx]:
90 | id_score[h_idx, o_idx] = matching_score
91 | sorted_score[:, h_idx, o_idx] = matching_score * pair_action[:-1]
92 | score[:, h_idx, o_idx] += matching_score * pair_action[:-1]
93 |
94 | score += sorted_score
95 | score = score[:, h_inds, :]
96 | score = score[:, :, o_inds]
97 |
98 | result_dict.update({
99 | 'pair_score': score,
100 | 'hoi_recognition_time': hoi_recognition_time,
101 | })
102 |
103 | results.append(result_dict)
104 |
105 | elif dataset == 'hico-det':
106 | """ Compute HOI triplet prediction score for HICO-DET.
107 | For HICO-DET, we follow the same scoring function but do not accumulate the results.
108 | """
109 | out_time = outputs['hoi_recognition_time']
110 |
111 | start_time = time.time()
112 | # object and verb (action) logits for every interaction query
113 | out_obj_logits, out_verb_logits = outputs['pred_obj_logits'], outputs['pred_actions']
114 |
115 | # actions
116 | matching_scores = (1-out_verb_logits.sigmoid()[..., -1:]) #* (1-out_verb_logits.sigmoid()[..., 57:58])
117 | verb_scores = out_verb_logits.sigmoid()[..., :-1] * matching_scores
118 |
119 | # hbox, obox
120 | outputs_hrepr, outputs_orepr = outputs['pred_hidx'], outputs['pred_oidx']
121 | obj_scores, obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)
122 |
123 | h_prob = F.softmax(outputs_hrepr, -1)
124 | h_idx_score, h_indices = h_prob.max(-1)
125 |
126 | # targets
127 | o_prob = F.softmax(outputs_orepr, -1)
128 | o_idx_score, o_indices = o_prob.max(-1)
129 | hoi_recognition_time = (time.time() - start_time) + out_time
130 |
131 | # hidx, oidx
132 | sub_boxes, obj_boxes = [], []
133 | for batch_id, (box, h_idx, o_idx) in enumerate(zip(boxes, h_indices, o_indices)):
134 | sub_boxes.append(box[h_idx, :])
135 | obj_boxes.append(box[o_idx, :])
136 | sub_boxes = torch.stack(sub_boxes, dim=0)
137 | obj_boxes = torch.stack(obj_boxes, dim=0)
138 |
139 | # accumulate results (iterate through interaction queries)
140 | results = []
141 | for os, ol, vs, ms, sb, ob in zip(obj_scores, obj_labels, verb_scores, matching_scores, sub_boxes, obj_boxes):
142 | sl = torch.full_like(ol, 0) # self.subject_category_id = 0 in HICO-DET
143 | l = torch.cat((sl, ol))
144 | b = torch.cat((sb, ob))
145 | results.append({'labels': l.to('cpu'), 'boxes': b.to('cpu')})
146 | vs = vs * os.unsqueeze(1)
147 | ids = torch.arange(b.shape[0])
148 | res_dict = {
149 | 'verb_scores': vs.to('cpu'),
150 | 'sub_ids': ids[:ids.shape[0] // 2],
151 | 'obj_ids': ids[ids.shape[0] // 2:],
152 | 'hoi_recognition_time': hoi_recognition_time
153 | }
154 | results[-1].update(res_dict)
155 | else:
156 | results = [{'scores': s, 'labels': l, 'boxes': b} for s, l, b in zip(scores, labels, boxes)]
157 |
158 | return results
--------------------------------------------------------------------------------
/src/models/detr.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/models/detr.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | """
9 | DETR & HOTR model and criterion classes.
10 | """
11 | import torch
12 | import torch.nn.functional as F
13 | from torch import nn
14 |
15 | from src.util.misc import (NestedTensor, nested_tensor_from_tensor_list)
16 |
17 | from .backbone import build_backbone
18 | from .detr_matcher import build_matcher
19 | from .hotr_matcher import build_hoi_matcher
20 | from .transformer import build_transformer, build_hoi_transformer
21 | from .criterion import SetCriterion
22 | from .post_process import PostProcess
23 | from .feed_forward import MLP
24 |
25 | from .hotr import HOTR
26 | from .stip import STIP, STIPPostProcess, STIPCriterion
27 |
28 | class DETR(nn.Module):
29 | """ This is the DETR module that performs object detection """
30 | def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
31 | """ Initializes the model.
32 | Parameters:
33 | backbone: torch module of the backbone to be used. See backbone.py
34 | transformer: torch module of the transformer architecture. See transformer.py
35 | num_classes: number of object classes
36 | num_queries: number of object queries, i.e., detection slots. This is the maximal number of objects
37 | DETR can detect in a single image. For COCO, we recommend 100 queries.
38 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
39 | """
40 | super().__init__()
41 | self.num_queries = num_queries
42 | self.transformer = transformer
43 | hidden_dim = transformer.d_model
44 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1)
45 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)
46 | self.query_embed = nn.Embedding(num_queries, hidden_dim)
47 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
48 | self.backbone = backbone
49 | self.aux_loss = aux_loss
50 |
51 | def forward(self, samples: NestedTensor, targets=None):
52 | """ The forward expects a NestedTensor, which consists of:
53 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
54 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
55 | It returns a dict with the following elements:
56 | - "pred_logits": the classification logits (including no-object) for all queries.
57 | Shape= [batch_size x num_queries x (num_classes + 1)]
58 | - "pred_boxes": The normalized boxes coordinates for all queries, represented as
59 | (center_x, center_y, height, width). These values are normalized in [0, 1],
60 | relative to the size of each individual image (disregarding possible padding).
61 | See PostProcess for information on how to retrieve the unnormalized bounding box.
62 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of
63 | dictionnaries containing the two above keys for each decoder layer.
64 | """
65 | if isinstance(samples, (list, torch.Tensor)):
66 | samples = nested_tensor_from_tensor_list(samples)
67 | features, pos = self.backbone(samples)
68 |
69 | src, mask = features[-1].decompose()
70 | assert mask is not None
71 | hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0]
72 |
73 | outputs_class = self.class_embed(hs)
74 | outputs_coord = self.bbox_embed(hs).sigmoid()
75 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
76 | if self.aux_loss:
77 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
78 |
79 | return out
80 |
81 | @torch.jit.unused
82 | def _set_aux_loss(self, outputs_class, outputs_coord):
83 | # this is a workaround to make torchscript happy, as torchscript
84 | # doesn't support dictionary with non-homogeneous values, such
85 | # as a dict having both a Tensor and a list.
86 | return [{'pred_logits': a, 'pred_boxes': b}
87 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
88 |
89 |
90 | def build(args):
91 | device = torch.device(args.device)
92 |
93 | backbone = build_backbone(args)
94 |
95 | transformer = build_transformer(args)
96 |
97 | model = DETR(
98 | backbone,
99 | transformer,
100 | num_classes=args.num_classes,
101 | num_queries=args.num_queries,
102 | aux_loss=args.aux_loss,
103 | )
104 |
105 | matcher = build_matcher(args)
106 | weight_dict = {'loss_ce': 1, 'loss_bbox': args.bbox_loss_coef}
107 | weight_dict['loss_giou'] = args.giou_loss_coef
108 |
109 | # TODO this is a hack
110 | if args.aux_loss:
111 | aux_weight_dict = {}
112 | for i in range(args.dec_layers - 1):
113 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
114 | weight_dict.update(aux_weight_dict)
115 |
116 | losses = ['labels', 'boxes', 'cardinality'] if args.frozen_weights is None else []
117 | if args.HOIDet and args.STIP_relation_head:
118 | model = STIP(args, detr=model, detr_matcher=matcher)
119 | criterion = STIPCriterion(args, matcher)
120 | postprocessors = {'hoi': STIPPostProcess(args, model)}
121 | elif args.HOIDet:
122 | hoi_matcher = build_hoi_matcher(args)
123 | hoi_losses = []
124 | hoi_losses.append('pair_labels')
125 | hoi_losses.append('pair_actions')
126 | if args.dataset_file == 'hico-det': hoi_losses.append('pair_targets')
127 |
128 | hoi_weight_dict={}
129 | hoi_weight_dict['loss_hidx'] = args.hoi_idx_loss_coef
130 | hoi_weight_dict['loss_oidx'] = args.hoi_idx_loss_coef
131 | hoi_weight_dict['loss_act'] = args.hoi_act_loss_coef
132 | if args.dataset_file == 'hico-det': hoi_weight_dict['loss_tgt'] = args.hoi_tgt_loss_coef
133 | if args.hoi_aux_loss:
134 | hoi_aux_weight_dict = {}
135 | for i in range(args.hoi_dec_layers):
136 | hoi_aux_weight_dict.update({k + f'_{i}': v for k, v in hoi_weight_dict.items()})
137 | hoi_weight_dict.update(hoi_aux_weight_dict)
138 |
139 | criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=hoi_weight_dict,
140 | eos_coef=args.eos_coef, losses=losses, num_actions=args.num_actions,
141 | HOI_losses=hoi_losses, HOI_matcher=hoi_matcher, args=args)
142 |
143 | interaction_transformer = build_hoi_transformer(args) # if (args.share_enc and args.pretrained_dec) else None
144 |
145 | kwargs = {}
146 | if args.dataset_file == 'hico-det': kwargs['return_obj_class'] = args.valid_obj_ids
147 | model = HOTR(
148 | detr=model,
149 | num_hoi_queries=args.num_hoi_queries,
150 | num_actions=args.num_actions,
151 | interaction_transformer=interaction_transformer,
152 | freeze_detr=(args.frozen_weights is not None),
153 | share_enc=args.share_enc,
154 | pretrained_dec=args.pretrained_dec,
155 | temperature=args.temperature,
156 | hoi_aux_loss=args.hoi_aux_loss,
157 | **kwargs # only return object class logits for the HICO-DET dataset
158 | )
159 | postprocessors = {'hoi': PostProcess(args.HOIDet)}
160 | else:
161 | criterion = SetCriterion(args.num_classes, matcher=matcher, weight_dict=weight_dict,
162 | eos_coef=args.eos_coef, losses=losses, num_actions=args.num_actions, args=args)
163 | postprocessors = {'bbox': PostProcess(args.HOIDet)}
164 | criterion.to(device)
165 |
166 | return model, criterion, postprocessors
--------------------------------------------------------------------------------
/src/engine/arg_parser.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : engine/arg_parser.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # Modified arguments are represented with *
5 | # ------------------------------------------------------------------------
6 | # Modified from DETR (https://github.com/facebookresearch/detr)
7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
8 | # ------------------------------------------------------------------------
9 | import argparse
10 |
11 | def get_args_parser():
12 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
13 | parser.add_argument('--lr', default=1e-4, type=float)
14 | parser.add_argument('--lr_backbone', default=1e-5, type=float)
15 | parser.add_argument('--batch_size', default=2, type=int)
16 | parser.add_argument('--weight_decay', default=1e-4, type=float)
17 | parser.add_argument('--epochs', default=100, type=int)
18 | parser.add_argument('--lr_drop', default=80, type=int)
19 | parser.add_argument('--clip_max_norm', default=0.1, type=float,
20 | help='gradient clipping max norm')
21 |
22 | # DETR Model parameters
23 | parser.add_argument('--frozen_weights', type=str, default=None,
24 | help="Path to the pretrained model. If set, only the mask head will be trained")
25 | # DETR Backbone
26 | parser.add_argument('--backbone', default='resnet50', type=str,
27 | help="Name of the convolutional backbone to use")
28 | parser.add_argument('--dilation', action='store_true',
29 | help="If true, we replace stride with dilation in the last convolutional block (DC5)")
30 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
31 | help="Type of positional embedding to use on top of the image features")
32 |
33 | # DETR Transformer (= Encoder, Instance Decoder)
34 | parser.add_argument('--enc_layers', default=6, type=int,
35 | help="Number of encoding layers in the transformer")
36 | parser.add_argument('--dec_layers', default=6, type=int,
37 | help="Number of decoding layers in the transformer")
38 | parser.add_argument('--dim_feedforward', default=2048, type=int,
39 | help="Intermediate size of the feedforward layers in the transformer blocks")
40 | parser.add_argument('--hidden_dim', default=256, type=int,
41 | help="Size of the embeddings (dimension of the transformer)")
42 | parser.add_argument('--dropout', default=0.1, type=float,
43 | help="Dropout applied in the transformer")
44 | parser.add_argument('--nheads', default=8, type=int,
45 | help="Number of attention heads inside the transformer's attentions")
46 | parser.add_argument('--num_queries', default=100, type=int,
47 | help="Number of query slots")
48 | parser.add_argument('--pre_norm', action='store_true')
49 |
50 | # Segmentation
51 | parser.add_argument('--masks', action='store_true',
52 | help="Train segmentation head if the flag is provided")
53 |
54 | # Loss Option
55 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
56 | help="Disables auxiliary decoding losses (loss at each layer)")
57 |
58 | # Loss coefficients (DETR)
59 | parser.add_argument('--mask_loss_coef', default=1, type=float)
60 | parser.add_argument('--dice_loss_coef', default=1, type=float)
61 | parser.add_argument('--bbox_loss_coef', default=5, type=float)
62 | parser.add_argument('--giou_loss_coef', default=2, type=float)
63 | parser.add_argument('--eos_coef', default=0.1, type=float,
64 | help="Relative classification weight of the no-object class")
65 |
66 | # Matcher (DETR)
67 | parser.add_argument('--set_cost_class', default=1, type=float,
68 | help="Class coefficient in the matching cost")
69 | parser.add_argument('--set_cost_bbox', default=5, type=float,
70 | help="L1 box coefficient in the matching cost")
71 | parser.add_argument('--set_cost_giou', default=2, type=float,
72 | help="giou box coefficient in the matching cost")
73 |
74 | # * HOI Detection
75 | parser.add_argument('--HOIDet', action='store_true',
76 | help="Train HOI Detection head if the flag is provided")
77 | parser.add_argument('--share_enc', action='store_true',
78 | help="Share the Encoder in DETR for HOI Detection if the flag is provided")
79 | parser.add_argument('--pretrained_dec', action='store_true',
80 | help="Use Pre-trained Decoder in DETR for Interaction Decoder if the flag is provided")
81 | parser.add_argument('--hoi_enc_layers', default=1, type=int,
82 | help="Number of encoding layers in the HOI transformer")
83 | parser.add_argument('--hoi_dec_layers', default=6, type=int,
84 | help="Number of decoding layers in the HOI transformer")
85 | parser.add_argument('--hoi_nheads', default=8, type=int,
86 | help="Number of attention heads in the HOI transformer")
87 | parser.add_argument('--hoi_dim_feedforward', default=2048, type=int,
88 | help="Intermediate size of the feedforward layers in the HOI transformer")
89 | # parser.add_argument('--hoi_mode', type=str, default=None, help='[inst | pair | all]')
90 | parser.add_argument('--num_hoi_queries', default=32, type=int,
91 | help="Number of Queries for Interaction Decoder")
92 | parser.add_argument('--hoi_aux_loss', action='store_true')
93 |
94 |
95 | # * HOTR Matcher
96 | parser.add_argument('--set_cost_idx', default=1, type=float,
97 | help="IDX coefficient in the matching cost")
98 | parser.add_argument('--set_cost_act', default=1, type=float,
99 | help="Action coefficient in the matching cost")
100 | parser.add_argument('--set_cost_tgt', default=1, type=float,
101 | help="Target coefficient in the matching cost")
102 |
103 | # * HOTR Loss coefficients
104 | parser.add_argument('--temperature', default=0.05, type=float, help="temperature for scaling the HO Pointer similarity scores")
105 | parser.add_argument('--hoi_idx_loss_coef', default=1, type=float)
106 | parser.add_argument('--hoi_act_loss_coef', default=1, type=float)
107 | parser.add_argument('--hoi_tgt_loss_coef', default=1, type=float)
108 | parser.add_argument('--hoi_eos_coef', default=0.1, type=float, help="Relative classification weight of the no-object class")
109 |
110 | # * dataset parameters
111 | parser.add_argument('--dataset_file', help='[coco | vcoco | hico-det]')
112 | parser.add_argument('--data_path', type=str)
113 | parser.add_argument('--object_threshold', type=float, default=0, help='Threshold for object confidence')
114 |
115 | # machine parameters
116 | parser.add_argument('--output_dir', default='',
117 | help='path where to save, empty for no saving')
118 | parser.add_argument('--custom_path', default='',
119 | help="Data path for custom inference. Only required for custom_main.py")
120 | parser.add_argument('--device', default='cuda',
121 | help='device to use for training / testing')
122 | parser.add_argument('--seed', default=42, type=int)
123 | parser.add_argument('--resume', default='', help='resume from checkpoint')
124 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
125 | help='start epoch')
126 | parser.add_argument('--num_workers', default=4, type=int)
127 |
128 | # mode
129 | parser.add_argument('--eval', action='store_true', help="Only evaluate results if the flag is provided")
130 | parser.add_argument('--validate', action='store_true', help="Validate after every epoch")
131 |
132 | # distributed training parameters
133 | parser.add_argument('--world_size', default=1, type=int,
134 | help='number of distributed processes')
135 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
136 |
137 | # * WanDB
138 | parser.add_argument('--wandb', action='store_true')
139 | parser.add_argument('--project_name', default='HOTR')
140 | parser.add_argument('--group_name', default='KakaoBrain')
141 | parser.add_argument('--run_name', default='run_000001')
142 |
143 | # STIP
144 | parser.add_argument('--STIP_relation_head', action='store_true', default=False)
145 | parser.add_argument('--finetune_detr', action='store_true', default=False)
146 | parser.add_argument('--use_high_resolution_relation_feature_map', action='store_true', default=False)
147 | return parser
148 |
--------------------------------------------------------------------------------
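
For reference, a hypothetical snippet showing how the parser above can be exercised programmatically with a combination of the flags it defines (loosely following the README's V-COCO command). Wrapping it in a parent `ArgumentParser` mirrors DETR's convention and is an assumption about how `STIP_main.py` consumes it.

```python
# Illustrative only: parse a set of the flags defined above without launching training.
import argparse

from src.engine.arg_parser import get_args_parser

parser = argparse.ArgumentParser('STIP', parents=[get_args_parser()])
args = parser.parse_args([
    '--HOIDet', '--STIP_relation_head', '--validate',
    '--num_hoi_queries', '32', '--batch_size', '4', '--lr', '5e-5',
    '--dataset_file', 'vcoco', '--data_path', 'v-coco',
])
print(args.lr, args.num_hoi_queries, args.dataset_file)  # 5e-05 32 vcoco
```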
/src/metrics/vcoco/ap_role.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from src.metrics.utils import _compute_ap, compute_overlap
4 |
5 | class APRole(object):
6 | def __init__(self, act_name, scenario_flag=True, iou_threshold=0.5):
7 | self.act_name = act_name
8 | self.iou_threshold = iou_threshold
9 |
10 | self.scenario_flag = scenario_flag
11 | # scenario_1 : True
12 | # scenario_2 : False
13 |
14 | self.fp = [np.zeros((0,))] * len(act_name)
15 | self.tp = [np.zeros((0,))] * len(act_name)
16 | self.score = [np.zeros((0,))] * len(act_name)
17 | self.num_ann = [0] * len(act_name)
18 |
19 | def add_data(self, h_box, o_box, score, i_box, i_act, p_box, p_act):
20 | # i_box, i_act : to check if only in COCO
21 | for label in range(len(self.act_name)):
22 | p_inds = (p_act[:, label] == 1)
23 | self.num_ann[label] += p_inds.sum()
24 |
25 | if h_box.shape[0] == 0 : return # if no prediction, just return
26 | # image annotated in COCO but not in V-COCO, or no annotations left after collation => ignore
27 |
28 | valid_i_inds = (i_act[:, 0] != -1) # (n_i, )
29 | overlaps = compute_overlap(h_box, i_box) # (n_h, n_i)
30 | assigned_input = np.argmax(overlaps, axis=1) # (n_h, )
31 | v_inds = valid_i_inds[assigned_input] # (n_h, )
32 |
33 | h_box = h_box[v_inds]
34 | score = score[:, v_inds, :]
35 | if h_box.shape[0] == 0 : return
36 | n_h = h_box.shape[0]
37 |
38 | valid_p_inds = (p_act[:, 0] != -1) | (p_box[:, 0] != -1)
39 | p_act = p_act[valid_p_inds]
40 | p_box = p_box[valid_p_inds]
41 |
42 | n_o = o_box.shape[0]
43 | if n_o == 0:
44 | # no prediction for object
45 | score = score.squeeze(axis=2) # (#act, n_h)
46 |
47 | for label in range(len(self.act_name)):
48 | h_inds = np.argsort(score[label])[::-1] # (n_h, )
49 | self.score[label] = np.append(self.score[label], score[label, h_inds])
50 |
51 | p_inds = (p_act[:, label] == 1)
52 | if p_inds.sum() == 0:
53 | self.tp[label] = np.append(self.tp[label], np.array([0]*n_h))
54 | self.fp[label] = np.append(self.fp[label], np.array([1]*n_h))
55 | continue
56 |
57 | h_overlaps = compute_overlap(h_box[h_inds], p_box[p_inds, :4]) # (n_h, n_p)
58 | assigned_p = np.argmax(h_overlaps, axis=1) # (n_h, )
59 | h_max_overlap = h_overlaps[range(n_h), assigned_p] # (n_h, )
60 |
61 | o_overlaps = compute_overlap(np.zeros((n_h, 4)), p_box[p_inds][assigned_p, 4:8])
62 | o_overlaps = np.diag(o_overlaps) # (n_h, )
63 |
64 | no_role_inds = (p_box[p_inds][assigned_p, 4] == -1) # (n_h, )
65 | # human (o), action (o), no object in actual image
66 |
67 | h_iou_inds = (h_max_overlap > self.iou_threshold) # (n_h, )
68 | o_iou_inds = (o_overlaps > self.iou_threshold) # (n_h, )
69 |
70 | # scenario1 is not considered (already no object)
71 | o_iou_inds[no_role_inds] = 1
72 |
73 | iou_inds = (h_iou_inds & o_iou_inds)
74 | p_nonzero = iou_inds.nonzero()[0]
75 | p_inds = assigned_p[p_nonzero]
76 | p_iou = np.unique(p_inds, return_index=True)[1]
77 | p_tp = p_nonzero[p_iou]
78 |
79 | t = np.zeros(n_h, dtype=np.uint8)
80 | t[p_tp] = 1
81 | f = 1-t
82 |
83 | self.tp[label] = np.append(self.tp[label], t)
84 | self.fp[label] = np.append(self.fp[label], f)
85 |
86 | else:
87 | s_obj_argmax = np.argmax(score.reshape(-1, n_o), axis=1).reshape(-1, n_h) # (#act, n_h)
88 | s_obj_max = np.max(score.reshape(-1, n_o), axis=1).reshape(-1, n_h) # (#act, n_h)
89 |
90 | h_overlaps = compute_overlap(h_box, p_box[:, :4]) # (n_h, n_p)
91 | for label in range(len(self.act_name)):
92 | h_inds = np.argsort(s_obj_max[label])[::-1] # (n_h, )
93 | self.score[label] = np.append(self.score[label], s_obj_max[label, h_inds])
94 |
95 | p_inds = (p_act[:, label] == 1) # (n_p, )
96 | if p_inds.sum() == 0: ## no such relation, all considered as FP
97 | self.tp[label] = np.append(self.tp[label], np.array([0]*n_h))
98 | self.fp[label] = np.append(self.fp[label], np.array([1]*n_h))
99 | continue
100 |
101 | h_overlaps = compute_overlap(h_box[h_inds], p_box[:, :4]) # (n_h, n_p) # match for all hboxes
102 | h_max_overlap = np.max(h_overlaps, axis=1) # (n_h, ) # get the max overlap for hbox
103 |
104 | # for same human, multiple pairs exist. find the human box that has the same idx with max overlap hbox.
105 | h_max_temp = np.expand_dims(h_max_overlap, axis=1)
106 | h_over_thresh = (h_overlaps == h_max_temp) # (n_h, n_p)
107 | h_over_thresh = h_over_thresh & np.expand_dims(p_inds, axis=0) # (n_h, n_p)
108 |
109 | h_valid = h_over_thresh.sum(axis=1)>0 # (n_h, ) # at least one is True
110 | # h_valid -> if all is False, then argmax becomes 0. <- prevent
111 | assigned_p = np.argmax(h_over_thresh, axis=1) # (n_h, ) # p only for current act
112 |
113 | o_mapping_box = o_box[s_obj_argmax[label]][h_inds] # (n_h, ) # find where T is.
114 | p_mapping_box = p_box[assigned_p, 4:8] # (n_h, 4)
115 |
116 | o_overlaps = compute_overlap(o_mapping_box, p_mapping_box) # --: matching object
117 | o_overlaps = np.diag(o_overlaps) # (n_h, )
118 | o_overlaps.setflags(write=1)
119 | if (~h_valid).sum() > 0:
120 | o_overlaps[~h_valid] = 0 # (n_h, )
121 |
122 | no_role_inds = (p_box[assigned_p, 4] == -1) # (n_h, ) -- GT relation obj box is empty
123 | nan_box_inds = np.all(o_mapping_box == 0, axis=1) | np.all(np.isnan(o_mapping_box), axis=1) ## predicted relation obj box is 0 or nan
124 | no_role_inds = no_role_inds & h_valid
125 | nan_box_inds = nan_box_inds & h_valid
126 |
127 | h_iou_inds = (h_max_overlap > self.iou_threshold) # (n_h, )
128 | o_iou_inds = (o_overlaps > self.iou_threshold) # (n_h, )
129 |
130 | if self.scenario_flag: ## scenario_1: for empty GT relation obj box, the predicted obj box should also be correct
131 | o_iou_inds[no_role_inds & nan_box_inds] = 1
132 | o_iou_inds[no_role_inds & ~nan_box_inds] = 0
133 | else: ## scenario_2: for empty GT relation obj box, ignore prediction for obj box
134 | o_iou_inds[no_role_inds] = 1
135 |
136 | iou_inds = (h_iou_inds & o_iou_inds)
137 | p_nonzero = iou_inds.nonzero()[0]
138 | p_inds = assigned_p[p_nonzero]
139 | p_iou = np.unique(p_inds, return_index=True)[1] ## remove duplicate, one GT match only once
140 | p_tp = p_nonzero[p_iou]
141 |
142 | t = np.zeros(n_h, dtype=np.uint8)
143 | t[p_tp] = 1
144 | f = 1-t
145 |
146 | self.tp[label] = np.append(self.tp[label], t)
147 | self.fp[label] = np.append(self.fp[label], f)
148 | # print('add data finished')
149 |
150 | def evaluate(self, print_log=False):
151 | average_precisions = dict()
152 | role_num = 1 if self.scenario_flag else 2
153 | for label in range(len(self.act_name)):
154 |
155 | # sort by score
156 | indices = np.argsort(-self.score[label])
157 | self.fp[label] = self.fp[label][indices] ## compute AP over all predicted relations across the dataset by ranking
158 | self.tp[label] = self.tp[label][indices]
159 |
160 |
161 | if self.num_ann[label] == 0:
162 | average_precisions[label] = 0
163 | continue
164 |
165 | # compute false positives and true positives
166 | self.fp[label] = np.cumsum(self.fp[label])
167 | self.tp[label] = np.cumsum(self.tp[label])
168 |
169 | # compute recall and precision
170 | recall = self.tp[label] / self.num_ann[label]
171 | precision = self.tp[label] / np.maximum(self.tp[label] + self.fp[label], np.finfo(np.float64).eps)
172 |
173 | # compute average precision
174 | average_precisions[label] = _compute_ap(recall, precision) * 100
175 |
176 | if print_log: print(f'\n============= AP (Role scenario_{role_num}) ==============')
177 | s, n = 0, 0
178 |
179 | for label in range(len(self.act_name)):
180 | if 'point' in self.act_name[label]:
181 | continue
182 | label_name = "_".join(self.act_name[label].split("_")[1:])
183 | if print_log: print('{: >23}: AP = {:0.2f} (#pos = {:d})'.format(label_name, average_precisions[label], self.num_ann[label]))
184 | if self.num_ann[label] != 0 :
185 | s += average_precisions[label]
186 | n += 1
187 |
188 | mAP = s/n
189 | if print_log:
190 | print('| mAP(role scenario_{:d}): {:0.2f}'.format(role_num, mAP))
191 | print('----------------------------------------------------')
192 |
193 | return mAP
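# Sketch of the ranking-based bookkeeping above, with made-up arrays; `_compute_ap`
# is the same helper used in evaluate():
#   tp = np.array([1, 0, 1]); fp = 1 - tp        # per-detection hits, already sorted by score
#   tp, fp = np.cumsum(tp), np.cumsum(fp)
#   recall = tp / num_gt_pairs                   # num_gt_pairs corresponds to self.num_ann[label]
#   precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
#   ap = _compute_ap(recall, precision) * 100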
--------------------------------------------------------------------------------
/src/data/evaluators/coco_eval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | COCO evaluator that works in distributed mode.
4 | Mostly copy-paste from https://github.com/pytorch/vision/blob/edfd5a7/references/detection/coco_eval.py
5 | The difference is that there is less copy-pasting from pycocotools
6 | at the end of the file, as python3 can suppress prints with contextlib
7 | """
8 | import os
9 | import contextlib
10 | import copy
11 | import numpy as np
12 | import torch
13 |
14 | from pycocotools.cocoeval import COCOeval
15 | from pycocotools.coco import COCO
16 | import pycocotools.mask as mask_util
17 |
18 | from src.util.misc import all_gather
19 |
20 |
21 | class CocoEvaluator(object):
22 | def __init__(self, coco_gt, iou_types):
23 | assert isinstance(iou_types, (list, tuple))
24 | coco_gt = copy.deepcopy(coco_gt)
25 | self.coco_gt = coco_gt
26 |
27 | self.iou_types = iou_types
28 | self.coco_eval = {}
29 | for iou_type in iou_types:
30 | self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
31 |
32 | self.img_ids = []
33 | self.eval_imgs = {k: [] for k in iou_types}
34 |
35 | def update(self, predictions):
36 | img_ids = list(np.unique(list(predictions.keys())))
37 | self.img_ids.extend(img_ids)
38 |
39 | for iou_type in self.iou_types:
40 | results = self.prepare(predictions, iou_type)
41 |
42 | # suppress pycocotools prints
43 | with open(os.devnull, 'w') as devnull:
44 | with contextlib.redirect_stdout(devnull):
45 | coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
46 | coco_eval = self.coco_eval[iou_type]
47 |
48 | coco_eval.cocoDt = coco_dt
49 | coco_eval.params.imgIds = list(img_ids)
50 | img_ids, eval_imgs = evaluate(coco_eval)
51 |
52 | self.eval_imgs[iou_type].append(eval_imgs)
53 |
54 | def synchronize_between_processes(self):
55 | for iou_type in self.iou_types:
56 | self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
57 | create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
58 |
59 | def accumulate(self):
60 | for coco_eval in self.coco_eval.values():
61 | coco_eval.accumulate()
62 |
63 | def summarize(self):
64 | for iou_type, coco_eval in self.coco_eval.items():
65 | print("IoU metric: {}".format(iou_type))
66 | coco_eval.summarize()
67 |
68 | def prepare(self, predictions, iou_type):
69 | if iou_type == "bbox":
70 | return self.prepare_for_coco_detection(predictions)
71 | elif iou_type == "segm":
72 | return self.prepare_for_coco_segmentation(predictions)
73 | elif iou_type == "keypoints":
74 | return self.prepare_for_coco_keypoint(predictions)
75 | else:
76 | raise ValueError("Unknown iou type {}".format(iou_type))
77 |
78 | def prepare_for_coco_detection(self, predictions):
79 | coco_results = []
80 | for original_id, prediction in predictions.items():
81 | if len(prediction) == 0:
82 | continue
83 |
84 | boxes = prediction["boxes"]
85 | boxes = convert_to_xywh(boxes).tolist()
86 | scores = prediction["scores"].tolist()
87 | labels = prediction["labels"].tolist()
88 |
89 | coco_results.extend(
90 | [
91 | {
92 | "image_id": original_id,
93 | "category_id": labels[k],
94 | "bbox": box,
95 | "score": scores[k],
96 | }
97 | for k, box in enumerate(boxes)
98 | ]
99 | )
100 | return coco_results
101 |
102 | def prepare_for_coco_segmentation(self, predictions):
103 | coco_results = []
104 | for original_id, prediction in predictions.items():
105 | if len(prediction) == 0:
106 | continue
107 |
108 | scores = prediction["scores"]
109 | labels = prediction["labels"]
110 | masks = prediction["masks"]
111 |
112 | masks = masks > 0.5
113 |
114 | scores = prediction["scores"].tolist()
115 | labels = prediction["labels"].tolist()
116 |
117 | rles = [
118 | mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0]
119 | for mask in masks
120 | ]
121 | for rle in rles:
122 | rle["counts"] = rle["counts"].decode("utf-8")
123 |
124 | coco_results.extend(
125 | [
126 | {
127 | "image_id": original_id,
128 | "category_id": labels[k],
129 | "segmentation": rle,
130 | "score": scores[k],
131 | }
132 | for k, rle in enumerate(rles)
133 | ]
134 | )
135 | return coco_results
136 |
137 | def prepare_for_coco_keypoint(self, predictions):
138 | coco_results = []
139 | for original_id, prediction in predictions.items():
140 | if len(prediction) == 0:
141 | continue
142 |
143 | boxes = prediction["boxes"]
144 | boxes = convert_to_xywh(boxes).tolist()
145 | scores = prediction["scores"].tolist()
146 | labels = prediction["labels"].tolist()
147 | keypoints = prediction["keypoints"]
148 | keypoints = keypoints.flatten(start_dim=1).tolist()
149 |
150 | coco_results.extend(
151 | [
152 | {
153 | "image_id": original_id,
154 | "category_id": labels[k],
155 | 'keypoints': keypoint,
156 | "score": scores[k],
157 | }
158 | for k, keypoint in enumerate(keypoints)
159 | ]
160 | )
161 | return coco_results
162 |
163 |
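# Illustrative usage (the prediction layout follows the DETR reference code; names such
# as `postprocessors` and `orig_sizes` are placeholders, not part of this file):
#   evaluator = CocoEvaluator(coco_gt, iou_types=("bbox",))
#   for samples, targets in data_loader:
#       outputs = model(samples)
#       results = postprocessors['bbox'](outputs, orig_sizes)   # list of {'boxes', 'scores', 'labels'}
#       evaluator.update({t['image_id'].item(): r for t, r in zip(targets, results)})
#   evaluator.synchronize_between_processes()
#   evaluator.accumulate()
#   evaluator.summarize()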
164 | def convert_to_xywh(boxes):
165 | xmin, ymin, xmax, ymax = boxes.unbind(1)
166 | return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
167 |
168 |
169 | def merge(img_ids, eval_imgs):
170 | all_img_ids = all_gather(img_ids)
171 | all_eval_imgs = all_gather(eval_imgs)
172 |
173 | merged_img_ids = []
174 | for p in all_img_ids:
175 | merged_img_ids.extend(p)
176 |
177 | merged_eval_imgs = []
178 | for p in all_eval_imgs:
179 | merged_eval_imgs.append(p)
180 |
181 | merged_img_ids = np.array(merged_img_ids)
182 | merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
183 |
184 | # keep only unique (and in sorted order) images
185 | merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
186 | merged_eval_imgs = merged_eval_imgs[..., idx]
187 |
188 | return merged_img_ids, merged_eval_imgs
189 |
190 |
191 | def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
192 | img_ids, eval_imgs = merge(img_ids, eval_imgs)
193 | img_ids = list(img_ids)
194 | eval_imgs = list(eval_imgs.flatten())
195 |
196 | coco_eval.evalImgs = eval_imgs
197 | coco_eval.params.imgIds = img_ids
198 | coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
199 |
200 |
201 | #################################################################
202 | # From pycocotools, just removed the prints and fixed
203 | # a Python3 bug about unicode not defined
204 | #################################################################
205 |
206 |
207 | def evaluate(self):
208 | '''
209 | Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
210 | :return: None
211 | '''
212 | # tic = time.time()
213 | # print('Running per image evaluation...')
214 | p = self.params
215 | # add backward compatibility if useSegm is specified in params
216 | if p.useSegm is not None:
217 | p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
218 | print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
219 | # print('Evaluate annotation type *{}*'.format(p.iouType))
220 | p.imgIds = list(np.unique(p.imgIds))
221 | if p.useCats:
222 | p.catIds = list(np.unique(p.catIds))
223 | p.maxDets = sorted(p.maxDets)
224 | self.params = p
225 |
226 | self._prepare()
227 | # loop through images, area range, max detection number
228 | catIds = p.catIds if p.useCats else [-1]
229 |
230 | if p.iouType == 'segm' or p.iouType == 'bbox':
231 | computeIoU = self.computeIoU
232 | elif p.iouType == 'keypoints':
233 | computeIoU = self.computeOks
234 | self.ious = {
235 | (imgId, catId): computeIoU(imgId, catId)
236 | for imgId in p.imgIds
237 | for catId in catIds}
238 |
239 | evaluateImg = self.evaluateImg
240 | maxDet = p.maxDets[-1]
241 | evalImgs = [
242 | evaluateImg(imgId, catId, areaRng, maxDet)
243 | for catId in catIds
244 | for areaRng in p.areaRng
245 | for imgId in p.imgIds
246 | ]
247 | # this is NOT in the pycocotools code, but could be done outside
248 | evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds))
249 | self._paramsEval = copy.deepcopy(self.params)
250 | # toc = time.time()
251 | # print('DONE (t={:0.2f}s).'.format(toc-tic))
252 | return p.imgIds, evalImgs
253 |
254 | #################################################################
255 | # end of straight copy from pycocotools, just removing the prints
256 | #################################################################
257 |
--------------------------------------------------------------------------------
/src/models/hotr_matcher.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/models/hotr_matcher.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | import torch
6 | from scipy.optimize import linear_sum_assignment
7 | from torch import nn
8 |
9 | from src.util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
10 |
11 | import src.util.misc as utils
12 | import wandb
13 |
14 | class HungarianPairMatcher(nn.Module):
15 | def __init__(self, args):
16 | """Creates the matcher
17 | Params:
18 | cost_action: This is the relative weight of the multi-label action classification error in the matching cost
19 | cost_hbox: This is the relative weight of the classification error for human idx in the matching cost
20 | cost_obox: This is the relative weight of the classification error for object idx in the matching cost
21 | """
22 | super().__init__()
23 | self.cost_action = args.set_cost_act
24 | self.cost_hbox = self.cost_obox = args.set_cost_idx
25 | self.cost_target = args.set_cost_tgt
26 | self.log_printer = args.wandb
27 | self.is_vcoco = (args.dataset_file == 'vcoco')
28 | self.is_hico = (args.dataset_file == 'hico-det')
29 | if self.is_vcoco:
30 | self.valid_ids = args.valid_ids
31 | self.invalid_ids = args.invalid_ids
32 | assert self.cost_action != 0 or self.cost_hbox != 0 or self.cost_obox != 0, "all costs can't be 0"
33 |
34 | def reduce_redundant_gt_box(self, tgt_bbox, indices):
35 | """Filters redundant Ground-Truth Bounding Boxes
36 | Due to random crop augmentation, an image can contain multiple redundant
37 | labels for the exact same bounding box and object class.
38 | This function deals with the redundant labels for smoother HOTR training.
39 | """
40 | tgt_bbox_unique, map_idx, idx_cnt = torch.unique(tgt_bbox, dim=0, return_inverse=True, return_counts=True)
41 |
42 | k_idx, bbox_idx = indices
43 | triggered = False
44 | if (len(tgt_bbox) != len(tgt_bbox_unique)):
45 | map_dict = {k: v for k, v in enumerate(map_idx)}
46 | map_bbox2kidx = {int(bbox_id): k_id for bbox_id, k_id in zip(bbox_idx, k_idx)}
47 |
48 | bbox_lst, k_lst = [], []
49 | for bbox_id in bbox_idx:
50 | if map_dict[int(bbox_id)] not in bbox_lst:
51 | bbox_lst.append(map_dict[int(bbox_id)])
52 | k_lst.append(map_bbox2kidx[int(bbox_id)])
53 | bbox_idx = torch.tensor(bbox_lst)
54 | k_idx = torch.tensor(k_lst)
55 | tgt_bbox_res = tgt_bbox_unique
56 | else:
57 | tgt_bbox_res = tgt_bbox
58 | bbox_idx = bbox_idx.to(tgt_bbox.device)
59 |
60 | return tgt_bbox_res, k_idx, bbox_idx
61 |
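# Illustration of the torch.unique bookkeeping in reduce_redundant_gt_box (values are made up):
#   tgt_bbox = torch.tensor([[0., 0., 1., 1., 5.],
#                            [0., 0., 1., 1., 5.],   # duplicate produced by random crop
#                            [2., 2., 3., 3., 7.]])
#   uniq, map_idx, cnt = torch.unique(tgt_bbox, dim=0, return_inverse=True, return_counts=True)
#   # uniq has 2 rows, map_idx == tensor([0, 0, 1]), cnt == tensor([2, 1])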
62 | @torch.no_grad()
63 | def forward(self, outputs, targets, indices, log=False):
64 | assert "pred_actions" in outputs, "There is no action output for pair matching"
65 | num_obj_queries = outputs["pred_boxes"].shape[1]
66 | bs, num_queries = outputs["pred_actions"].shape[:2]
67 | detr_query_num = outputs["pred_logits"].shape[1] \
68 | if (outputs["pred_oidx"].shape[-1] == (outputs["pred_logits"].shape[1] + 1)) else -1
69 |
70 | return_list = []
71 | if self.log_printer and log:
72 | log_dict = {'h_cost': [], 'o_cost': [], 'act_cost': []}
73 | if self.is_hico: log_dict['tgt_cost'] = []
74 |
75 | for batch_idx in range(bs):
76 | tgt_bbox = targets[batch_idx]["boxes"] # (num_boxes, 4)
77 | tgt_cls = targets[batch_idx]["labels"] # (num_boxes)
78 |
79 | if self.is_vcoco:
80 | targets[batch_idx]["pair_actions"][:, self.invalid_ids] = 0
81 | keep_idx = (targets[batch_idx]["pair_actions"].sum(dim=-1) != 0)
82 | targets[batch_idx]["pair_boxes"] = targets[batch_idx]["pair_boxes"][keep_idx]
83 | targets[batch_idx]["pair_actions"] = targets[batch_idx]["pair_actions"][keep_idx]
84 | targets[batch_idx]["pair_targets"] = targets[batch_idx]["pair_targets"][keep_idx]
85 |
86 | tgt_pbox = targets[batch_idx]["pair_boxes"] # (num_pair_boxes, 8)
87 | tgt_act = targets[batch_idx]["pair_actions"] # (num_pair_boxes, 29)
88 | tgt_tgt = targets[batch_idx]["pair_targets"] # (num_pair_boxes)
89 |
90 | tgt_hbox = tgt_pbox[:, :4] # (num_pair_boxes, 4)
91 | tgt_obox = tgt_pbox[:, 4:] # (num_pair_boxes, 4)
92 | elif self.is_hico:
93 | tgt_act = targets[batch_idx]["pair_actions"] # (num_pair_boxes, 117)
94 | tgt_tgt = targets[batch_idx]["pair_targets"] # (num_pair_boxes)
95 |
96 | tgt_hbox = targets[batch_idx]["sub_boxes"] # (num_pair_boxes, 4)
97 | tgt_obox = targets[batch_idx]["obj_boxes"] # (num_pair_boxes, 4)
98 |
99 | # find which gt boxes match the h, o boxes in the pair
100 | if self.is_vcoco:
101 | hbox_with_cls = torch.cat([tgt_hbox, torch.ones((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
102 | elif self.is_hico:
103 | hbox_with_cls = torch.cat([tgt_hbox, torch.zeros((tgt_hbox.shape[0], 1)).to(tgt_hbox.device)], dim=1)
104 | obox_with_cls = torch.cat([tgt_obox, tgt_tgt.unsqueeze(-1)], dim=1)
105 | obox_with_cls[obox_with_cls[:, :4].sum(dim=1) == -4, -1] = -1 # turn the class of occluded objects to -1
106 |
107 | bbox_with_cls = torch.cat([tgt_bbox, tgt_cls.unsqueeze(-1)], dim=1)
108 | bbox_with_cls, k_idx, bbox_idx = self.reduce_redundant_gt_box(bbox_with_cls, indices[batch_idx])
109 | bbox_with_cls = torch.cat((bbox_with_cls, torch.as_tensor([-1.]*5).unsqueeze(0).to(tgt_cls.device)), dim=0)
110 |
111 | cost_hbox = torch.cdist(hbox_with_cls, bbox_with_cls, p=1)
112 | cost_obox = torch.cdist(obox_with_cls, bbox_with_cls, p=1)
113 |
114 | # find which gt boxes matches which prediction in K
115 | h_match_indices = torch.nonzero(cost_hbox == 0, as_tuple=False) # (num_hbox, num_boxes)
116 | o_match_indices = torch.nonzero(cost_obox == 0, as_tuple=False) # (num_obox, num_boxes)
117 |
118 | tgt_hids, tgt_oids = [], []
119 | # obtain ground truth indices for h
120 | if len(h_match_indices) != len(o_match_indices):
121 | import pdb; pdb.set_trace()
122 |
123 | for h_match_idx, o_match_idx in zip(h_match_indices, o_match_indices):
124 | hbox_idx, H_bbox_idx = h_match_idx
125 | obox_idx, O_bbox_idx = o_match_idx
126 | if O_bbox_idx == (len(bbox_with_cls)-1): # if the object class is -1
127 | O_bbox_idx = H_bbox_idx # happens in V-COCO, the target object may not appear
128 |
129 | GT_idx_for_H = (bbox_idx == H_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
130 | query_idx_for_H = k_idx[GT_idx_for_H]
131 | tgt_hids.append(query_idx_for_H)
132 |
133 | GT_idx_for_O = (bbox_idx == O_bbox_idx).nonzero(as_tuple=False).squeeze(-1)
134 | query_idx_for_O = k_idx[GT_idx_for_O]
135 | tgt_oids.append(query_idx_for_O)
136 |
137 | # check if empty
138 | if len(tgt_hids) == 0: tgt_hids.append(torch.as_tensor([-1])) # we later ignore the label -1
139 | if len(tgt_oids) == 0: tgt_oids.append(torch.as_tensor([-1])) # we later ignore the label -1
140 |
141 | tgt_sum = (tgt_act.sum(dim=-1)).unsqueeze(0)
142 | flag = False
143 | if tgt_act.shape[0] == 0:
144 | tgt_act = torch.zeros((1, tgt_act.shape[1])).to(targets[batch_idx]["pair_actions"].device)
145 | targets[batch_idx]["pair_actions"] = torch.zeros((1, targets[batch_idx]["pair_actions"].shape[1])).to(targets[batch_idx]["pair_actions"].device)
146 | if self.is_hico:
147 | pad_tgt = -1 # outputs["pred_obj_logits"].shape[-1]-1
148 | tgt_tgt = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"])
149 | targets[batch_idx]["pair_targets"] = torch.tensor([pad_tgt]).to(targets[batch_idx]["pair_targets"].device)
150 | tgt_sum = (tgt_act.sum(dim=-1) + 1).unsqueeze(0)
151 |
152 | # Concat target label
153 | tgt_hids = torch.cat(tgt_hids)
154 | tgt_oids = torch.cat(tgt_oids)
155 |
156 | out_hprob = outputs["pred_hidx"][batch_idx].softmax(-1)
157 | out_oprob = outputs["pred_oidx"][batch_idx].softmax(-1)
158 | out_act = outputs["pred_actions"][batch_idx].clone()
159 | if self.is_vcoco: out_act[..., self.invalid_ids] = 0
160 | if self.is_hico:
161 | out_tgt = outputs["pred_obj_logits"][batch_idx].softmax(-1)
162 | out_tgt[..., -1] = 0 # don't get cost for no-object
163 | tgt_act = torch.cat([tgt_act, torch.zeros(tgt_act.shape[0]).unsqueeze(-1).to(tgt_act.device)], dim=-1)
164 |
165 | cost_hclass = -out_hprob[:, tgt_hids] # [batch_size * num_queries, detr.num_queries+1]
166 | cost_oclass = -out_oprob[:, tgt_oids] # [batch_size * num_queries, detr.num_queries+1]
167 |
168 | cost_pos_act = (-torch.matmul(out_act, tgt_act.t().float())) / tgt_sum
169 | cost_neg_act = (torch.matmul(out_act, (~tgt_act.bool()).type(torch.int64).t().float())) / (~tgt_act.bool()).type(torch.int64).sum(dim=-1).unsqueeze(0)
170 | cost_action = cost_pos_act + cost_neg_act
171 |
172 | h_cost = self.cost_hbox * cost_hclass
173 | o_cost = self.cost_obox * cost_oclass
174 | act_cost = self.cost_action * cost_action
175 |
176 | C = h_cost + o_cost + act_cost
177 | if self.is_hico:
178 | cost_target = -out_tgt[:, tgt_tgt]
179 | tgt_cost = self.cost_target * cost_target
180 | C += tgt_cost
181 | C = C.view(num_queries, -1).cpu()
182 |
183 | return_list.append(linear_sum_assignment(C))
184 | targets[batch_idx]["h_labels"] = tgt_hids.to(tgt_hbox.device)
185 | targets[batch_idx]["o_labels"] = tgt_oids.to(tgt_obox.device)
186 | log_act_cost = torch.zeros([1]).to(tgt_act.device) if tgt_act.shape[0] == 0 else act_cost.min(dim=0)[0].mean()
187 |
188 | if self.log_printer and log:
189 | log_dict['h_cost'].append(h_cost.min(dim=0)[0].mean())
190 | log_dict['o_cost'].append(o_cost.min(dim=0)[0].mean())
191 | log_dict['act_cost'].append(act_cost.min(dim=0)[0].mean())
192 | if self.is_hico: log_dict['tgt_cost'].append(tgt_cost.min(dim=0)[0].mean())
193 | if self.log_printer and log:
194 | log_dict['h_cost'] = torch.stack(log_dict['h_cost']).mean()
195 | log_dict['o_cost'] = torch.stack(log_dict['o_cost']).mean()
196 | log_dict['act_cost'] = torch.stack(log_dict['act_cost']).mean()
197 | if self.is_hico: log_dict['tgt_cost'] = torch.stack(log_dict['tgt_cost']).mean()
198 | if utils.get_rank() == 0: wandb.log(log_dict)
199 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in return_list], targets
200 |
201 | def build_hoi_matcher(args):
202 | return HungarianPairMatcher(args)
203 |
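# Minimal construction sketch (the Namespace fields mirror the attributes read in
# __init__; the values below are illustrative, not recommended settings):
#   from argparse import Namespace
#   args = Namespace(set_cost_act=1.0, set_cost_idx=1.0, set_cost_tgt=1.0,
#                    wandb=False, dataset_file='hico-det')
#   matcher = build_hoi_matcher(args)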
--------------------------------------------------------------------------------
/src/util/misc.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/util/misc.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from DETR (https://github.com/facebookresearch/detr)
6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
7 | # ------------------------------------------------------------------------
8 | """
9 | Misc functions, including distributed helpers.
10 | Mostly copy-paste from torchvision references.
11 | """
12 | import os
13 | import subprocess
14 | from collections import deque
15 | import pickle
16 | from typing import Optional, List
17 |
18 | import torch
19 | import torch.distributed as dist
20 | from torch import Tensor
21 |
22 | # needed due to empty tensor bug in pytorch and torchvision 0.5
23 | import torchvision
24 | if tuple(int(v) for v in torchvision.__version__.split('.')[:2]) < (0, 7):  # parse (major, minor) so versions like "0.10" compare correctly
25 | from torchvision.ops import _new_empty_tensor
26 | from torchvision.ops.misc import _output_size
27 |
28 |
29 | class SmoothedValue(object):
30 | """Track a series of values and provide access to smoothed values over a
31 | window or the global series average.
32 | """
33 |
34 | def __init__(self, window_size=20, fmt=None):
35 | if fmt is None:
36 | fmt = "{median:.4f} ({global_avg:.4f})"
37 | self.deque = deque(maxlen=window_size)
38 | self.total = 0.0
39 | self.count = 0
40 | self.fmt = fmt
41 |
42 | def update(self, value, n=1):
43 | self.deque.append(value)
44 | self.count += n
45 | self.total += value * n
46 |
47 | def synchronize_between_processes(self):
48 | """
49 | Warning: does not synchronize the deque!
50 | """
51 | if not is_dist_avail_and_initialized():
52 | return
53 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
54 | dist.barrier()
55 | dist.all_reduce(t)
56 | t = t.tolist()
57 | self.count = int(t[0])
58 | self.total = t[1]
59 |
60 | @property
61 | def median(self):
62 | d = torch.tensor(list(self.deque))
63 | return d.median().item()
64 |
65 | @property
66 | def avg(self):
67 | d = torch.tensor(list(self.deque), dtype=torch.float32)
68 | return d.mean().item()
69 |
70 | @property
71 | def global_avg(self):
72 | return self.total / self.count
73 |
74 | @property
75 | def max(self):
76 | return max(self.deque)
77 |
78 | @property
79 | def value(self):
80 | return self.deque[-1]
81 |
82 | def __str__(self):
83 | return self.fmt.format(
84 | median=self.median,
85 | avg=self.avg,
86 | global_avg=self.global_avg,
87 | max=self.max,
88 | value=self.value)
89 |
90 |
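# Example usage (illustrative values): track a loss with a 20-step median window
# plus a running global average, as the training logger does.
#   loss_meter = SmoothedValue(window_size=20, fmt="{median:.4f} ({global_avg:.4f})")
#   for loss in (0.9, 0.8, 0.75):
#       loss_meter.update(loss)
#   print(loss_meter)   # e.g. "0.8000 (0.8167)"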
91 | def all_gather(data):
92 | """
93 | Run all_gather on arbitrary picklable data (not necessarily tensors)
94 | Args:
95 | data: any picklable object
96 | Returns:
97 | list[data]: list of data gathered from each rank
98 | """
99 | world_size = get_world_size()
100 | if world_size == 1:
101 | return [data]
102 |
103 | # serialized to a Tensor
104 | buffer = pickle.dumps(data)
105 | storage = torch.ByteStorage.from_buffer(buffer)
106 | tensor = torch.ByteTensor(storage).to("cuda")
107 |
108 | # obtain Tensor size of each rank
109 | local_size = torch.tensor([tensor.numel()], device="cuda")
110 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
111 | dist.all_gather(size_list, local_size)
112 | size_list = [int(size.item()) for size in size_list]
113 | max_size = max(size_list)
114 |
115 | # receiving Tensor from all ranks
116 | # we pad the tensor because torch all_gather does not support
117 | # gathering tensors of different shapes
118 | tensor_list = []
119 | for _ in size_list:
120 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
121 | if local_size != max_size:
122 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
123 | tensor = torch.cat((tensor, padding), dim=0)
124 | dist.all_gather(tensor_list, tensor)
125 |
126 | data_list = []
127 | for size, tensor in zip(size_list, tensor_list):
128 | buffer = tensor.cpu().numpy().tobytes()[:size]
129 | data_list.append(pickle.loads(buffer))
130 |
131 | return data_list
132 |
133 |
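# Sketch: gathering arbitrary per-rank objects. With a single process this simply
# returns [data]; under torch.distributed it returns one entry per rank.
#   stats = {'num_images': 42, 'split': 'val'}
#   gathered = all_gather(stats)   # e.g. [stats_rank0, stats_rank1, ...]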
134 | def reduce_dict(input_dict, average=True):
135 | """
136 | Args:
137 | input_dict (dict): all the values will be reduced
138 | average (bool): whether to do average or sum
139 | Reduce the values in the dictionary from all processes so that all processes
140 | have the averaged results. Returns a dict with the same fields as
141 | input_dict, after reduction.
142 | """
143 | world_size = get_world_size()
144 | if world_size < 2:
145 | return input_dict
146 | with torch.no_grad():
147 | names = []
148 | values = []
149 | # sort the keys so that they are consistent across processes
150 | for k in sorted(input_dict.keys()):
151 | names.append(k)
152 | values.append(input_dict[k])
153 | values = torch.stack(values, dim=0)
154 | dist.all_reduce(values)
155 | if average:
156 | values /= world_size
157 | reduced_dict = {k: v for k, v in zip(names, values)}
158 | return reduced_dict
159 |
160 |
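# Sketch: averaging loss terms across ranks so every process logs the same numbers
# (illustrative tensors; with world_size == 1 the dict is returned unchanged):
#   loss_dict = {'loss_ce': torch.tensor(0.7), 'loss_bbox': torch.tensor(0.2)}
#   loss_dict_reduced = reduce_dict(loss_dict)
#   total_loss = sum(loss_dict_reduced.values())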
161 | def get_sha():
162 | cwd = os.path.dirname(os.path.abspath(__file__))
163 |
164 | def _run(command):
165 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
166 | sha = 'N/A'
167 | diff = "clean"
168 | branch = 'N/A'
169 | try:
170 | sha = _run(['git', 'rev-parse', 'HEAD'])
171 | subprocess.check_output(['git', 'diff'], cwd=cwd)
172 | diff = _run(['git', 'diff-index', 'HEAD'])
173 | diff = "has uncommitted changes" if diff else "clean"
174 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
175 | except Exception:
176 | pass
177 | message = f"sha: {sha}, status: {diff}, branch: {branch}"
178 | return message
179 |
180 |
181 | def collate_fn(batch):
182 | batch = list(zip(*batch))
183 | batch[0] = nested_tensor_from_tensor_list(batch[0])
184 | return tuple(batch)
185 |
186 |
187 | def _max_by_axis(the_list):
188 | # type: (List[List[int]]) -> List[int]
189 | maxes = the_list[0]
190 | for sublist in the_list[1:]:
191 | for index, item in enumerate(sublist):
192 | maxes[index] = max(maxes[index], item)
193 | return maxes
194 |
195 |
196 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
197 | # TODO make this more general
198 | if tensor_list[0].ndim == 3:
199 | # TODO make it support different-sized images
200 | max_size = _max_by_axis([list(img.shape) for img in tensor_list])
201 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
202 | batch_shape = [len(tensor_list)] + max_size
203 | b, c, h, w = batch_shape
204 | dtype = tensor_list[0].dtype
205 | device = tensor_list[0].device
206 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
207 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
208 | for img, pad_img, m in zip(tensor_list, tensor, mask):
209 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
210 | m[: img.shape[1], :img.shape[2]] = False
211 | else:
212 | raise ValueError('not supported')
213 | return NestedTensor(tensor, mask)
214 |
215 |
216 | class NestedTensor(object):
217 | def __init__(self, tensors, mask: Optional[Tensor]):
218 | self.tensors = tensors
219 | self.mask = mask
220 |
221 | def to(self, device):
222 | # type: (Device) -> NestedTensor # noqa
223 | cast_tensor = self.tensors.to(device)
224 | mask = self.mask
225 | if mask is not None:
226 | assert mask is not None
227 | cast_mask = mask.to(device)
228 | else:
229 | cast_mask = None
230 | return NestedTensor(cast_tensor, cast_mask)
231 |
232 | def decompose(self):
233 | return self.tensors, self.mask
234 |
235 | def __repr__(self):
236 | return str(self.tensors)
237 |
238 |
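# Sketch: batching two differently sized images; the mask marks padded pixels with True.
#   imgs = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
#   nt = nested_tensor_from_tensor_list(imgs)
#   tensors, mask = nt.decompose()   # tensors: (2, 3, 512, 640), mask: (2, 512, 640)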
239 | def setup_for_distributed(is_master):
240 | """
241 | This function disables printing when not in master process
242 | """
243 | import builtins as __builtin__
244 | builtin_print = __builtin__.print
245 |
246 | def print(*args, **kwargs):
247 | force = kwargs.pop('force', False)
248 | if is_master or force:
249 | builtin_print(*args, **kwargs)
250 |
251 | __builtin__.print = print
252 |
253 |
254 | def is_dist_avail_and_initialized():
255 | if not dist.is_available():
256 | return False
257 | if not dist.is_initialized():
258 | return False
259 | return True
260 |
261 |
262 | def get_world_size():
263 | if not is_dist_avail_and_initialized():
264 | return 1
265 | return dist.get_world_size()
266 |
267 |
268 | def get_rank():
269 | if not is_dist_avail_and_initialized():
270 | return 0
271 | return dist.get_rank()
272 |
273 |
274 | def is_main_process():
275 | return get_rank() == 0
276 |
277 |
278 | def save_on_master(*args, **kwargs):
279 | if is_main_process():
280 | torch.save(*args, **kwargs)
281 |
282 |
283 | def init_distributed_mode(args):
284 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
285 | args.rank = int(os.environ["RANK"])
286 | args.world_size = int(os.environ['WORLD_SIZE'])
287 | args.gpu = int(os.environ['LOCAL_RANK'])
288 | elif 'SLURM_PROCID' in os.environ:
289 | args.rank = int(os.environ['SLURM_PROCID'])
290 | args.gpu = args.rank % torch.cuda.device_count()
291 | else:
292 | print('Not using distributed mode')
293 | args.distributed = False
294 | return
295 |
296 | args.distributed = True
297 |
298 | torch.cuda.set_device(args.gpu)
299 | args.dist_backend = 'nccl'
300 | print('| distributed init (rank {}): {}'.format(
301 | args.rank, args.dist_url), flush=True)
302 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
303 | world_size=args.world_size, rank=args.rank)
304 | torch.distributed.barrier()
305 | setup_for_distributed(args.rank == 0)
306 |
307 |
308 | @torch.no_grad()
309 | def accuracy(output, target, topk=(1,)):
310 | """Computes the precision@k for the specified values of k"""
311 | if target.numel() == 0:
312 | return [torch.zeros([], device=output.device)]
313 | maxk = max(topk)
314 | batch_size = target.size(0)
315 |
316 | _, pred = output.topk(maxk, 1, True, True)
317 | pred = pred.t()
318 | correct = pred.eq(target.view(1, -1).expand_as(pred))
319 |
320 | res = []
321 | for k in topk:
322 | correct_k = correct[:k].view(-1).float().sum(0)
323 | res.append(correct_k.mul_(100.0 / batch_size))
324 | return res
325 |
326 |
327 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
328 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
329 | """
330 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
331 | This will eventually be supported natively by PyTorch, and this
332 | class can go away.
333 | """
334 | if tuple(int(v) for v in torchvision.__version__.split('.')[:2]) < (0, 7):
335 | if input.numel() > 0:
336 | return torch.nn.functional.interpolate(
337 | input, size, scale_factor, mode, align_corners
338 | )
339 |
340 | output_shape = _output_size(2, input, size, scale_factor)
341 | output_shape = list(input.shape[:-2]) + list(output_shape)
342 | return _new_empty_tensor(input, output_shape)
343 | else:
344 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
--------------------------------------------------------------------------------
/src/data/evaluators/hico_eval.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/data/evaluators/hico_eval.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
6 | # Copyright (c) Hitachi, Ltd. All Rights Reserved.
7 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
8 | # ------------------------------------------------------------------------
9 | import numpy as np
10 | from collections import defaultdict
11 |
12 | hico_valid_obj_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
13 |
14 | class HICOEvaluator():
15 | def __init__(self, preds, gts, rare_triplets, non_rare_triplets, correct_mat):
16 | self.overlap_iou = 0.5
17 | self.max_hois = 100
18 | self.mode = 'default'
19 | # self.mode = 'known_objects'
20 |
21 | self.rare_triplets = rare_triplets
22 | self.non_rare_triplets = non_rare_triplets
23 |
24 | self.fp = defaultdict(list)
25 | self.tp = defaultdict(list)
26 | self.score = defaultdict(list)
27 | self.sum_gts = defaultdict(lambda: 0)
28 | self.gt_triplets = []
29 |
30 | self.preds = []
31 | for img_preds in preds:
32 | img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items() if k != 'hoi_recognition_time'}
33 | bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_preds['boxes'], img_preds['labels'])]
34 | hoi_scores = img_preds['verb_scores']
35 | verb_labels = np.tile(np.arange(hoi_scores.shape[1]), (hoi_scores.shape[0], 1))
36 | subject_ids = np.tile(img_preds['sub_ids'], (hoi_scores.shape[1], 1)).T
37 | object_ids = np.tile(img_preds['obj_ids'], (hoi_scores.shape[1], 1)).T
38 |
39 | hoi_scores = hoi_scores.ravel()
40 | verb_labels = verb_labels.ravel()
41 | subject_ids = subject_ids.ravel()
42 | object_ids = object_ids.ravel()
43 |
44 | if len(subject_ids) > 0:
45 | object_labels = np.array([bboxes[object_id]['category_id'] for object_id in object_ids])
46 | masks = correct_mat[verb_labels, object_labels]
47 | hoi_scores *= masks ## remove impossible verb-obj pairs
48 |
49 | hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} for
50 | subject_id, object_id, category_id, score in zip(subject_ids, object_ids, verb_labels, hoi_scores)]
51 | hois.sort(key=lambda k: (k.get('score', 0)), reverse=True)
52 | hois = hois[:self.max_hois]
53 | else:
54 | hois = []
55 |
56 | self.preds.append({
57 | 'predictions': bboxes,
58 | 'hoi_prediction': hois
59 | })
60 |
61 | self.gts = []
62 | for img_gts in gts:
63 | img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id'}
64 | self.gts.append({
65 | 'annotations': [{'bbox': bbox, 'category_id': hico_valid_obj_ids.index(label)} for bbox, label in zip(img_gts['boxes'], img_gts['labels'])], # map to valid obj ids
66 | 'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in img_gts['hois']]
67 | })
68 | for hoi in self.gts[-1]['hoi_annotation']:
69 | triplet = (self.gts[-1]['annotations'][hoi['subject_id']]['category_id'],
70 | self.gts[-1]['annotations'][hoi['object_id']]['category_id'],
71 | hoi['category_id'])
72 |
73 | if triplet not in self.gt_triplets:
74 | self.gt_triplets.append(triplet)
75 |
76 | self.sum_gts[triplet] += 1
77 | print('prepare for hico eval')
78 |
79 | def evaluate(self):
80 | for img_id, (img_preds, img_gts) in enumerate(zip(self.preds, self.gts)):
81 | print(f"Evaluating Score Matrix... : [{(img_id+1):>4}/{len(self.gts):<4}]" ,flush=True, end="\r")
82 | pred_bboxes = img_preds['predictions']
83 | gt_bboxes = img_gts['annotations']
84 | pred_hois = img_preds['hoi_prediction']
85 | gt_hois = img_gts['hoi_annotation']
86 | known_object_catids = [x['category_id'] for x in gt_bboxes]
87 | if len(gt_bboxes) != 0:
88 | bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes) ## match label and box
89 | self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps, known_object_catids=known_object_catids)
90 | else:
91 | if self.mode == 'default': # all predicted hois considered false positives
92 | for pred_hoi in pred_hois:
93 | triplet = [pred_bboxes[pred_hoi['subject_id']]['category_id'],
94 | pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id']]
95 | if triplet not in self.gt_triplets:
96 | continue
97 | self.tp[triplet].append(0)
98 | self.fp[triplet].append(1)
99 | self.score[triplet].append(pred_hoi['score'])
100 | print(f"[stats] Score Matrix Generation completed!! ")
101 | map = self.compute_map()
102 | return map
103 |
104 | def compute_map(self):
105 | ap = defaultdict(lambda: 0)
106 | rare_ap = defaultdict(lambda: 0)
107 | non_rare_ap = defaultdict(lambda: 0)
108 | max_recall = defaultdict(lambda: 0)
109 | rare_recall = defaultdict(lambda: 0)
110 | non_rare_recall = defaultdict(lambda: 0)
111 | for triplet in self.gt_triplets:
112 | # if triplet[-1] == 57: continue
113 | sum_gts = self.sum_gts[triplet]
114 | if sum_gts == 0:
115 | continue
116 |
117 | tp = np.array((self.tp[triplet]))
118 | fp = np.array((self.fp[triplet]))
119 | if len(tp) == 0:
120 | ap[triplet] = 0
121 | max_recall[triplet] = 0
122 | if triplet in self.rare_triplets:
123 | rare_ap[triplet] = 0
124 | rare_recall[triplet] = 0
125 | elif triplet in self.non_rare_triplets:
126 | non_rare_ap[triplet] = 0
127 | non_rare_recall[triplet] = 0
128 | else:
129 | print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet))
130 | continue
131 |
132 | score = np.array(self.score[triplet])
133 | sort_inds = np.argsort(-score)
134 | fp = fp[sort_inds]
135 | tp = tp[sort_inds]
136 | fp = np.cumsum(fp)
137 | tp = np.cumsum(tp)
138 | rec = tp / sum_gts
139 | prec = tp / (fp + tp)
140 | ap[triplet] = self.voc_ap(rec, prec)
141 | max_recall[triplet] = np.amax(rec)
142 | if triplet in self.rare_triplets:
143 | rare_ap[triplet] = ap[triplet]
144 | rare_recall[triplet] = max_recall[triplet]
145 | elif triplet in self.non_rare_triplets:
146 | non_rare_ap[triplet] = ap[triplet]
147 | non_rare_recall[triplet] = max_recall[triplet]
148 | else:
149 | print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet))
150 | m_ap = np.mean(list(ap.values())) * 100 # percentage
151 | m_ap_rare = np.mean(list(rare_ap.values())) * 100 # percentage
152 | m_ap_non_rare = np.mean(list(non_rare_ap.values())) * 100 # percentage
153 |
154 | m_max_recall = np.mean(list(max_recall.values())) * 100
155 | m_r_rare = np.mean(list(rare_recall.values())) * 100
156 | m_r_non_rare = np.mean(list(non_rare_recall.values())) * 100
157 |
158 | print('--------------------')
159 | print('mAP: {} mAP rare: {} mAP non-rare: {}'.format(m_ap, m_ap_rare, m_ap_non_rare))
160 | print('mR: {} mR rare: {} mR non-rare: {}'.format(m_max_recall, m_r_rare, m_r_non_rare))
161 | print('--------------------')
162 |
163 | return {'mAP': m_ap, 'mAP rare': m_ap_rare, 'mAP non-rare': m_ap_non_rare, 'mean max recall': m_max_recall}
164 |
165 | def voc_ap(self, rec, prec):
166 | ap = 0.
167 | for t in np.arange(0., 1.1, 0.1):
168 | if np.sum(rec >= t) == 0:
169 | p = 0
170 | else:
171 | p = np.max(prec[rec >= t])
172 | ap = ap + p / 11.
173 | return ap
174 |
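# Worked example of the 11-point interpolated AP above (illustrative numbers):
#   rec  = np.array([0.1, 0.4, 1.0])
#   prec = np.array([1.0, 0.5, 0.4])
#   # max precision with rec >= t, for t = 0.0, 0.1, ..., 1.0:
#   # 1.0 (t <= 0.1), 0.5 (0.2 <= t <= 0.4), 0.4 (t >= 0.5)
#   # ap = (2*1.0 + 3*0.5 + 6*0.4) / 11 ≈ 0.536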
175 | def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps, known_object_catids=None):
176 | pos_pred_ids = match_pairs.keys()
177 | vis_tag = np.zeros(len(gt_hois))
178 | pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True)
179 | if len(pred_hois) != 0:
180 | for pred_hoi in pred_hois:
181 | is_match = 0
182 | if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and pred_hoi['object_id'] in pos_pred_ids:
183 | pred_sub_ids = match_pairs[pred_hoi['subject_id']]
184 | pred_obj_ids = match_pairs[pred_hoi['object_id']]
185 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']]
186 | pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']]
187 | pred_category_id = pred_hoi['category_id']
188 | max_overlap = 0
189 | max_gt_hoi = 0
190 | for gt_hoi in gt_hois:
191 | if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids and pred_category_id == gt_hoi['category_id']:
192 | is_match = 1
193 | min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])],
194 | pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])])
195 | if min_overlap_gt > max_overlap:
196 | max_overlap = min_overlap_gt
197 | max_gt_hoi = gt_hoi
198 | triplet = (pred_bboxes[pred_hoi['subject_id']]['category_id'], pred_bboxes[pred_hoi['object_id']]['category_id'], pred_hoi['category_id'])
199 | if triplet not in self.gt_triplets:
200 | continue
201 | if self.mode == 'known_objects' and (triplet[1] not in known_object_catids):
202 | continue
203 | if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0:
204 | self.fp[triplet].append(0)
205 | self.tp[triplet].append(1)
206 | vis_tag[gt_hois.index(max_gt_hoi)] = 1
207 | else:
208 | self.fp[triplet].append(1)
209 | self.tp[triplet].append(0)
210 | self.score[triplet].append(pred_hoi['score'])
211 |
212 | def compute_iou_mat(self, bbox_list1, bbox_list2):
213 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2)))
214 | if len(bbox_list1) == 0 or len(bbox_list2) == 0:
215 | return {}, {}  # keep the (match_pairs, overlaps) tuple expected by the caller
216 | for i, bbox1 in enumerate(bbox_list1):
217 | for j, bbox2 in enumerate(bbox_list2):
218 | iou_i = self.compute_IOU(bbox1, bbox2)
219 | iou_mat[i, j] = iou_i
220 |
221 | iou_mat_ov=iou_mat.copy()
222 | iou_mat[iou_mat>=self.overlap_iou] = 1
223 | iou_mat[iou_mat<self.overlap_iou] = 0
224 |
225 | match_pairs = np.nonzero(iou_mat)
226 | match_pairs_dict = {}
227 | match_pair_overlaps = {}
228 | if iou_mat.max() > 0:
229 | for i, pred_id in enumerate(match_pairs[1]):
230 | if pred_id not in match_pairs_dict.keys():
231 | match_pairs_dict[pred_id] = []
232 | match_pair_overlaps[pred_id]=[]
233 | match_pairs_dict[pred_id].append(match_pairs[0][i])
234 | match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i],pred_id])
235 | return match_pairs_dict, match_pair_overlaps
236 |
237 | def compute_IOU(self, bbox1, bbox2):
238 | if isinstance(bbox1['category_id'], str):
239 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', ''))
240 | if isinstance(bbox2['category_id'], str):
241 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', ''))
242 | if bbox1['category_id'] == bbox2['category_id']:
243 | rec1 = bbox1['bbox']
244 | rec2 = bbox2['bbox']
245 | # computing area of each rectangles
246 | S_rec1 = (rec1[2] - rec1[0]+1) * (rec1[3] - rec1[1]+1)
247 | S_rec2 = (rec2[2] - rec2[0]+1) * (rec2[3] - rec2[1]+1)
248 |
249 | # computing the sum_area
250 | sum_area = S_rec1 + S_rec2
251 |
252 | # find the each edge of intersect rectangle
253 | left_line = max(rec1[1], rec2[1])
254 | right_line = min(rec1[3], rec2[3])
255 | top_line = max(rec1[0], rec2[0])
256 | bottom_line = min(rec1[2], rec2[2])
257 | # judge if there is an intersect
258 | if left_line >= right_line or top_line >= bottom_line:
259 | return 0
260 | else:
261 | intersect = (right_line - left_line+1) * (bottom_line - top_line+1)
262 | return intersect / (sum_area - intersect)
263 | else:
264 | return 0
--------------------------------------------------------------------------------
/STIP_main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import random
5 | import time
6 | import multiprocessing
7 | from pathlib import Path
8 |
9 | import numpy as np
10 | import torch
11 | from torch.utils.data import DataLoader, DistributedSampler
12 |
13 | import src.data.datasets as datasets
14 | import src.util.misc as utils
15 | from src.engine.arg_parser import get_args_parser
16 | from src.data.datasets import build_dataset, get_coco_api_from_dataset
17 | from src.engine.trainer import train_one_epoch
18 | from src.engine import hoi_evaluator, hoi_accumulator
19 | from src.models import build_model
20 | import wandb
21 | from src.engine.evaluator_coco import coco_evaluate
22 |
23 | from src.util.logger import print_params, print_args
24 | from collections import OrderedDict
25 |
26 | def save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename):
27 | # save_ckpt: function for saving checkpoints
28 | output_dir = Path(args.output_dir)
29 | if args.output_dir:
30 | checkpoint_path = output_dir / f'{filename}.pth'
31 | utils.save_on_master({
32 | 'model': model_without_ddp.state_dict(),
33 | 'optimizer': optimizer.state_dict(),
34 | 'lr_scheduler': lr_scheduler.state_dict(),
35 | 'epoch': epoch,
36 | 'args': args,
37 | }, checkpoint_path)
38 |
39 | def main(args):
40 | utils.init_distributed_mode(args)
41 |
42 | if not args.train_detr: # pretrained DETR (weights frozen)
43 | print("Freeze weights for detector")
44 |
45 | device = torch.device(args.device)
46 |
47 | # fix the seed for reproducibility
48 | seed = args.seed + utils.get_rank()
49 | torch.manual_seed(seed)
50 | np.random.seed(seed)
51 | random.seed(seed)
52 |
53 | # Data Setup
54 | dataset_train = build_dataset(image_set='train', args=args)
55 | dataset_val = build_dataset(image_set='val' if not args.eval else 'test', args=args)
56 | assert dataset_train.num_action() == dataset_val.num_action(), "Number of actions should be the same between splits"
57 | args.num_classes = dataset_train.num_category()
58 | args.num_actions = dataset_train.num_action()
59 | args.action_names = dataset_train.get_actions()
60 | if args.share_enc: args.hoi_enc_layers = args.enc_layers
61 | if args.pretrained_dec: args.hoi_dec_layers = args.dec_layers
62 | if args.dataset_file == 'vcoco':
63 | # Save V-COCO dataset statistics
64 | args.valid_ids = np.array(dataset_train.get_object_label_idx()).nonzero()[0]
65 | args.invalid_ids = np.argwhere(np.array(dataset_train.get_object_label_idx()) == 0).squeeze(1)
66 | args.human_actions = dataset_train.get_human_action()
67 | args.object_actions = dataset_train.get_object_action()
68 | args.num_human_act = dataset_train.num_human_act()
69 | elif args.dataset_file == 'hico-det':
70 | args.valid_obj_ids = dataset_train.get_valid_obj_ids()
71 | args.correct_mat = torch.tensor(dataset_val.correct_mat).to(device)
72 | print_args(args)
73 |
74 | if args.distributed:
75 | sampler_train = DistributedSampler(dataset_train, shuffle=True)
76 | sampler_val = DistributedSampler(dataset_val, shuffle=False)
77 | else:
78 | sampler_train = torch.utils.data.RandomSampler(dataset_train)
79 | sampler_val = torch.utils.data.SequentialSampler(dataset_val)
80 |
81 | batch_sampler_train = torch.utils.data.BatchSampler(
82 | sampler_train, args.batch_size, drop_last=True)
83 |
84 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
85 | collate_fn=utils.collate_fn, num_workers=args.num_workers)
86 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
87 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers)
88 |
89 |
90 | # Model Setup
91 | model, criterion, postprocessors = build_model(args)
92 | model.to(device)
93 |
94 | model_without_ddp = model
95 | if args.distributed:
96 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
97 | model_without_ddp = model.module
98 | n_parameters = print_params(model)
99 |
100 | param_dicts = [
101 | {"params": [p for n, p in model_without_ddp.named_parameters() if "detr" not in n and p.requires_grad]},
102 | {
103 | "params": [p for n, p in model_without_ddp.named_parameters() if ("detr" in n and 'backbone' not in n) and p.requires_grad],
104 | "lr": args.lr * 0.1,
105 | },
106 | {
107 | "params": [p for n, p in model_without_ddp.named_parameters() if ("detr" in n and 'backbone' in n) and p.requires_grad],
108 | "lr": args.lr * 0.01,
109 | },
110 | ]
111 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, weight_decay=args.weight_decay)
112 | # lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
113 | lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=args.reduce_lr_on_plateau_factor, patience=args.reduce_lr_on_plateau_patience, verbose=True)
114 |
115 | # Weight Setup
116 | if args.detr_weights is not None:
117 | print(f"Loading detr weights from args.detr_weights={args.detr_weights}")
118 | if args.detr_weights.startswith('https'):
119 | checkpoint = torch.hub.load_state_dict_from_url(
120 | args.detr_weights, map_location='cpu', check_hash=True)
121 | else:
122 | checkpoint = torch.load(args.detr_weights, map_location='cpu')
123 |
124 | if 'hico_ft_q16.pth' in args.detr_weights: # hack: for loading hico fine-tuned detr
125 | mapped_state_dict = OrderedDict()
126 | for k, v in checkpoint['model'].items():
127 | if k.startswith('detr.'):
128 | mapped_state_dict[k.replace('detr.', '')] = v
129 | model_without_ddp.detr.load_state_dict(mapped_state_dict)
130 | else:
131 | model_without_ddp.detr.load_state_dict(checkpoint['model'])
132 |
133 | if args.resume:
134 | print(f"Loading model weights from args.resume={args.resume}")
135 | if args.resume.startswith('https'):
136 | checkpoint = torch.hub.load_state_dict_from_url(
137 | args.resume, map_location='cpu', check_hash=True)
138 | else:
139 | checkpoint = torch.load(args.resume, map_location='cpu')
140 | model_without_ddp.load_state_dict(checkpoint['model'])
141 |
142 | if args.eval:
143 | # test only mode
144 | if args.HOIDet:
145 | if args.dataset_file == 'vcoco':
146 | total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
147 | sc1, sc2 = hoi_accumulator(args, total_res, True, False)
148 | elif args.dataset_file == 'hico-det':
149 | test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
150 | print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f}')
151 | print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f}')
152 | print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f}')
153 | else: raise ValueError(f'dataset {args.dataset_file} is not supported.')
154 | return
155 | else:
156 | # check original detr code
157 | base_ds = get_coco_api_from_dataset(data_loader_val)
158 | test_stats, coco_evaluator = coco_evaluate(model, criterion, postprocessors,
159 | data_loader_val, base_ds, device, args.output_dir)
160 | if args.output_dir:
161 | utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, Path(args.output_dir) / "eval.pth")
162 | return
163 |
164 | # stats
165 | scenario1, scenario2 = 0, 0
166 | best_mAP, best_rare, best_non_rare = 0, 0, 0
167 |
168 | # add argparse
169 | if args.wandb and utils.get_rank() == 0:
170 | wandb.init(
171 | project=args.project_name,
172 | group=args.group_name,
173 | name=args.run_name,
174 | config=args
175 | )
176 | wandb.watch(model)
177 |
178 | # Training starts here!
179 | start_time = time.time()
180 | for epoch in range(args.start_epoch, args.epochs):
181 | if args.distributed:
182 | sampler_train.set_epoch(epoch)
183 | train_stats = train_one_epoch(
184 | model, criterion, data_loader_train, optimizer, device, epoch, args.epochs,
185 | args.clip_max_norm, dataset_file=args.dataset_file, log=args.wandb)
186 |
187 | if isinstance(lr_scheduler, torch.optim.lr_scheduler.StepLR): lr_scheduler.step()
188 |
189 | # Validation
190 | if args.validate:
191 | print('-'*100)
192 | if args.dataset_file == 'vcoco':
193 | total_res = hoi_evaluator(args, model, criterion, postprocessors, data_loader_val, device)
194 | if utils.get_rank() == 0:
195 | sc1, sc2 = hoi_accumulator(args, total_res, False, args.wandb)
196 | if sc1 > scenario1:
197 | scenario1 = sc1
198 | scenario2 = sc2
199 | save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
200 | print(f'| Scenario #1 mAP : {sc1:.2f} ({scenario1:.2f})')
201 | print(f'| Scenario #2 mAP : {sc2:.2f} ({scenario2:.2f})')
202 | if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): lr_scheduler.step(sc1)
203 | elif args.dataset_file == 'hico-det':
204 | test_stats = hoi_evaluator(args, model, None, postprocessors, data_loader_val, device)
205 | if utils.get_rank() == 0:
206 | if test_stats['mAP'] > best_mAP:
207 | best_mAP = test_stats['mAP']
208 | best_rare = test_stats['mAP rare']
209 | best_non_rare = test_stats['mAP non-rare']
210 | save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename='best')
211 | print(f'| mAP (full)\t\t: {test_stats["mAP"]:.2f} ({best_mAP:.2f})')
212 | print(f'| mAP (rare)\t\t: {test_stats["mAP rare"]:.2f} ({best_rare:.2f})')
213 | print(f'| mAP (non-rare)\t: {test_stats["mAP non-rare"]:.2f} ({best_non_rare:.2f})')
214 | if args.wandb and utils.get_rank() == 0:
215 | wandb.log({
216 | 'mAP': test_stats['mAP']
217 | })
218 | if isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): lr_scheduler.step(test_stats['mAP'])
219 | print('-'*100)
220 |
221 | # if epoch%2==0:
222 | # save_ckpt(args, model_without_ddp, optimizer, lr_scheduler, epoch, filename=f'checkpoint_{epoch}')
223 |
224 | total_time = time.time() - start_time
225 | total_time_str = str(datetime.timedelta(seconds=int(total_time)))
226 | print('Training time {}'.format(total_time_str))
227 | if args.dataset_file == 'vcoco':
228 | print(f'| Scenario #1 mAP : {scenario1:.2f}')
229 | print(f'| Scenario #2 mAP : {scenario2:.2f}')
230 | elif args.dataset_file == 'hico-det':
231 | print(f'| mAP (full)\t\t: {best_mAP:.2f}')
232 | print(f'| mAP (rare)\t\t: {best_rare:.2f}')
233 | print(f'| mAP (non-rare)\t: {best_non_rare:.2f}')
234 |
235 |
236 | if __name__ == '__main__':
237 | parser = argparse.ArgumentParser(
238 | 'End-to-End Human Object Interaction training and evaluation script',
239 | parents=[get_args_parser()]
240 | )
241 | # training
242 | parser.add_argument('--detr_weights', default=None, type=str)
243 | parser.add_argument('--train_detr', action='store_true', default=False)
244 | parser.add_argument('--finetune_detr_weight', default=0.1, type=float)
245 | parser.add_argument('--lr_detr', default=1e-5, type=float)
246 | parser.add_argument('--reduce_lr_on_plateau_patience', default=2, type=int)
247 | parser.add_argument('--reduce_lr_on_plateau_factor', default=0.1, type=float)
248 |
249 | # loss
250 | parser.add_argument('--proposal_focal_loss_alpha', default=0.75, type=float) # large alpha for high recall
251 | parser.add_argument('--action_focal_loss_alpha', default=0.5, type=float)
252 | parser.add_argument('--proposal_focal_loss_gamma', default=2, type=float)
253 | parser.add_argument('--action_focal_loss_gamma', default=2, type=float)
254 | parser.add_argument('--proposal_loss_coef', default=1, type=float)
255 | parser.add_argument('--action_loss_coef', default=1, type=float)
256 |
257 | # ablations
258 | parser.add_argument('--no_hard_mining_for_relation_discovery', dest='use_hard_mining_for_relation_discovery', action='store_false', default=True)
259 | parser.add_argument('--no_relation_dependency_encoding', dest='use_relation_dependency_encoding', action='store_false', default=True)
260 | parser.add_argument('--no_memory_layout_encoding', dest='use_memory_layout_encoding', action='store_false', default=True, help='layout encodings')
261 | parser.add_argument('--no_nms_on_detr', dest='apply_nms_on_detr', action='store_false', default=True)
262 | parser.add_argument('--no_tail_semantic_feature', dest='use_tail_semantic_feature', action='store_false', default=True)
263 | parser.add_argument('--no_spatial_feature', dest='use_spatial_feature', action='store_false', default=True)
264 | parser.add_argument('--no_interaction_decoder', action='store_true', default=False)
265 |
266 | # not sensitive or effective
267 | parser.add_argument('--use_memory_union_mask', action='store_true', default=False)
268 | parser.add_argument('--use_union_feature', action='store_true', default=False)
269 | parser.add_argument('--adaptive_relation_query_num', action='store_true', default=False)
270 | parser.add_argument('--use_relation_tgt_mask', action='store_true', default=False)
271 | parser.add_argument('--use_relation_tgt_mask_attend_topk', default=10, type=int)
272 | parser.add_argument('--use_prior_verb_label_mask', action='store_true', default=False)
273 | parser.add_argument('--relation_feature_map_from', default='backbone', help='backbone | detr_encoder')
274 | parser.add_argument('--use_query_fourier_encoding', action='store_true', default=False)
275 |
276 | args = parser.parse_args()
277 | args.STIP_relation_head = True
278 |
279 | if args.output_dir:
280 | args.output_dir += f"/{args.group_name}/{args.run_name}/"
281 | Path(args.output_dir).mkdir(parents=True, exist_ok=True)
282 | main(args)
283 |
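284 | # ----------------------------------------------------------------------------
285 | # Usage sketch (illustrative comment, not part of the original script): most of
286 | # the "--no_*" ablation flags above use dest=... with action='store_false', so
287 | # disabling a component flips the corresponding positive attribute, e.g.
288 | #   python STIP_main.py ... --no_nms_on_detr --no_interaction_decoder
289 | # yields args.apply_nms_on_detr == False, while args.no_interaction_decoder ==
290 | # True (that flag has no dest and keeps its literal name). The remaining
291 | # required options come from get_args_parser() used as the parent parser.
292 | # ----------------------------------------------------------------------------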
--------------------------------------------------------------------------------
/src/data/transforms/transforms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | """
3 | Transforms and data augmentation for both image + bbox.
4 | """
5 | import random
6 |
7 | import PIL
8 | import torch
9 | import torchvision.transforms as T
10 | import torchvision.transforms.functional as F
11 |
12 | from src.util.box_ops import box_xyxy_to_cxcywh
13 | from src.util.misc import interpolate
14 |
15 |
16 | def crop(image, target, region):
17 | cropped_image = F.crop(image, *region)
18 |
19 | target = target.copy()
20 | i, j, h, w = region
21 |
22 | # should we do something wrt the original size?
23 | target["size"] = torch.tensor([h, w])
24 | max_size = torch.as_tensor([w, h], dtype=torch.float32)
25 |
26 |     fields = ["labels", "area", "iscrowd"] # per-box fields that are filtered together with the boxes below
27 | if "inst_actions" in target.keys():
28 | fields.append("inst_actions")
29 |
30 | if "boxes" in target:
31 | boxes = target["boxes"]
32 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i])
33 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size)
34 | cropped_boxes = cropped_boxes.clamp(min=0)
35 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1)
36 | target["boxes"] = cropped_boxes.reshape(-1, 4)
37 | target["area"] = area
38 | fields.append("boxes")
39 |
40 | if "pair_boxes" in target or ("sub_boxes" in target and "obj_boxes" in target):
41 | if "pair_boxes" in target:
42 | pair_boxes = target["pair_boxes"]
43 | hboxes = pair_boxes[:, :4]
44 | oboxes = pair_boxes[:, 4:]
45 | if ("sub_boxes" in target and "obj_boxes" in target):
46 | hboxes = target["sub_boxes"]
47 | oboxes = target["obj_boxes"]
48 |
49 | cropped_hboxes = hboxes - torch.as_tensor([j, i, j, i])
50 | cropped_hboxes = torch.min(cropped_hboxes.reshape(-1, 2, 2), max_size)
51 | cropped_hboxes = cropped_hboxes.clamp(min=0)
52 | hboxes = cropped_hboxes.reshape(-1, 4)
53 |
54 | obj_mask = (oboxes[:, 0] != -1)
55 | if obj_mask.sum() != 0:
56 | cropped_oboxes = oboxes[obj_mask] - torch.as_tensor([j, i, j, i])
57 | cropped_oboxes = torch.min(cropped_oboxes.reshape(-1, 2, 2), max_size)
58 | cropped_oboxes = cropped_oboxes.clamp(min=0)
59 | oboxes[obj_mask] = cropped_oboxes.reshape(-1, 4)
60 | else:
61 | cropped_oboxes = oboxes
62 |
63 | cropped_pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
64 | target["pair_boxes"] = cropped_pair_boxes
65 | pair_fields = ["pair_boxes", "pair_actions", "pair_targets"]
66 |
67 | if "masks" in target:
68 |         # FIXME should we update the area here if there are no boxes?
69 | target['masks'] = target['masks'][:, i:i + h, j:j + w]
70 | fields.append("masks")
71 |
72 |     # remove elements whose boxes or masks have zero area
73 | if "boxes" in target or "masks" in target:
74 | # favor boxes selection when defining which elements to keep
75 | # this is compatible with previous implementation
76 | if "boxes" in target:
77 | cropped_boxes = target['boxes'].reshape(-1, 2, 2)
78 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1)
79 | else:
80 | keep = target['masks'].flatten(1).any(1)
81 |
82 | for field in fields:
83 | if field in target: # added this because there is no 'iscrowd' field in v-coco dataset
84 | target[field] = target[field][keep]
85 |
86 |     # remove duplicate annotations (same box and same label)
87 | if "boxes" in target and "labels" in target:
88 | cropped_boxes = target['boxes']
89 | cropped_labels = target['labels']
90 |
91 | cnr, keep_idx = [], []
92 | for idx, (cropped_box, cropped_lbl) in enumerate(zip(cropped_boxes, cropped_labels)):
93 | if str((cropped_box, cropped_lbl)) not in cnr:
94 | cnr.append(str((cropped_box, cropped_lbl)))
95 | keep_idx.append(True)
96 | else: keep_idx.append(False)
97 |
98 | for field in fields:
99 | if field in target:
100 | target[field] = target[field][keep_idx]
101 |
102 | # remove elements for which pair boxes have zero area
103 | if "pair_boxes" in target:
104 | cropped_hboxes = target["pair_boxes"][:, :4].reshape(-1, 2, 2)
105 | cropped_oboxes = target["pair_boxes"][:, 4:].reshape(-1, 2, 2)
106 | keep_h = torch.all(cropped_hboxes[:, 1, :] > cropped_hboxes[:, 0, :], dim=1)
107 | keep_o = torch.all(cropped_oboxes[:, 1, :] > cropped_oboxes[:, 0, :], dim=1)
108 | not_empty_o = torch.all(target["pair_boxes"][:, 4:] >= 0, dim=1)
109 | discard_o = (~keep_o) & not_empty_o
110 | if (discard_o).sum() > 0:
111 | target["pair_boxes"][discard_o, 4:] = -1
112 |
113 | for pair_field in pair_fields:
114 | target[pair_field] = target[pair_field][keep_h]
115 |
116 | return cropped_image, target
117 |
118 |
119 | def hflip(image, target):
120 | flipped_image = F.hflip(image)
121 |
122 | w, h = image.size
123 |
124 | target = target.copy()
125 | if "boxes" in target:
126 | boxes = target["boxes"]
127 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
128 | target["boxes"] = boxes
129 |
130 | if "pair_boxes" in target:
131 | pair_boxes = target["pair_boxes"]
132 | hboxes = pair_boxes[:, :4]
133 | oboxes = pair_boxes[:, 4:]
134 |
135 | # human flip
136 | hboxes = hboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
137 |
138 | # object flip
139 | obj_mask = (oboxes[:, 0] != -1)
140 | if obj_mask.sum() != 0:
141 | o_tmp = oboxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0])
142 | oboxes[obj_mask] = o_tmp[obj_mask]
143 |
144 | pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
145 | target["pair_boxes"] = pair_boxes
146 |
147 | if "masks" in target:
148 | target['masks'] = target['masks'].flip(-1)
149 |
150 | return flipped_image, target
151 |
152 |
153 | def resize(image, target, size, max_size=None):
154 | # size can be min_size (scalar) or (w, h) tuple
155 |
156 | def get_size_with_aspect_ratio(image_size, size, max_size=None):
157 | w, h = image_size
158 | if max_size is not None:
159 | min_original_size = float(min((w, h)))
160 | max_original_size = float(max((w, h)))
161 | if max_original_size / min_original_size * size > max_size:
162 | size = int(round(max_size * min_original_size / max_original_size))
163 |
164 | if (w <= h and w == size) or (h <= w and h == size):
165 | return (h, w)
166 |
167 | if w < h:
168 | ow = size
169 | oh = int(size * h / w)
170 | else:
171 | oh = size
172 | ow = int(size * w / h)
173 |
174 | return (oh, ow)
175 |
176 | def get_size(image_size, size, max_size=None):
177 | if isinstance(size, (list, tuple)):
178 | return size[::-1]
179 | else:
180 | return get_size_with_aspect_ratio(image_size, size, max_size)
181 |
182 | size = get_size(image.size, size, max_size)
183 | rescaled_image = F.resize(image, size)
184 |
185 | if target is None:
186 | return rescaled_image, None
187 |
188 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size))
189 | ratio_width, ratio_height = ratios
190 |
191 | target = target.copy()
192 | if "boxes" in target:
193 | boxes = target["boxes"]
194 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
195 | target["boxes"] = scaled_boxes
196 |
197 | if "pair_boxes" in target:
198 | hboxes = target["pair_boxes"][:, :4]
199 | scaled_hboxes = hboxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
200 | hboxes = scaled_hboxes
201 |
202 | oboxes = target["pair_boxes"][:, 4:]
203 | obj_mask = (oboxes[:, 0] != -1)
204 | if obj_mask.sum() != 0:
205 | scaled_oboxes = oboxes[obj_mask] * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height])
206 | oboxes[obj_mask] = scaled_oboxes
207 |
208 | target["pair_boxes"] = torch.cat([hboxes, oboxes], dim=-1)
209 |
210 | if "area" in target:
211 | area = target["area"]
212 | scaled_area = area * (ratio_width * ratio_height)
213 | target["area"] = scaled_area
214 |
215 | h, w = size
216 | target["size"] = torch.tensor([h, w])
217 |
218 | if "masks" in target:
219 | target['masks'] = interpolate(
220 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5
221 |
222 | return rescaled_image, target
223 |
224 |
225 | def pad(image, target, padding):
226 | # assumes that we only pad on the bottom right corners
227 | padded_image = F.pad(image, (0, 0, padding[0], padding[1]))
228 | if target is None:
229 | return padded_image, None
230 | target = target.copy()
231 | # should we do something wrt the original size?
232 |     target["size"] = torch.tensor(padded_image.size[::-1])
233 | if "masks" in target:
234 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
235 | return padded_image, target
236 |
237 |
238 | class RandomCrop(object):
239 | def __init__(self, size):
240 | self.size = size
241 |
242 | def __call__(self, img, target):
243 | region = T.RandomCrop.get_params(img, self.size)
244 | return crop(img, target, region)
245 |
246 |
247 | class RandomSizeCrop(object):
248 | def __init__(self, min_size: int, max_size: int):
249 | self.min_size = min_size
250 | self.max_size = max_size
251 |
252 | def __call__(self, img: PIL.Image.Image, target: dict):
253 | w = random.randint(self.min_size, min(img.width, self.max_size))
254 | h = random.randint(self.min_size, min(img.height, self.max_size))
255 | region = T.RandomCrop.get_params(img, [h, w])
256 | return crop(img, target, region)
257 |
258 |
259 | class CenterCrop(object):
260 | def __init__(self, size):
261 | self.size = size
262 |
263 | def __call__(self, img, target):
264 | image_width, image_height = img.size
265 | crop_height, crop_width = self.size
266 | crop_top = int(round((image_height - crop_height) / 2.))
267 | crop_left = int(round((image_width - crop_width) / 2.))
268 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width))
269 |
270 |
271 | class RandomHorizontalFlip(object):
272 | def __init__(self, p=0.5):
273 | self.p = p
274 |
275 | def __call__(self, img, target):
276 | if random.random() < self.p:
277 | return hflip(img, target)
278 | return img, target
279 |
280 |
281 | class RandomResize(object):
282 | def __init__(self, sizes, max_size=None):
283 | assert isinstance(sizes, (list, tuple))
284 | self.sizes = sizes
285 | self.max_size = max_size
286 |
287 | def __call__(self, img, target=None):
288 | size = random.choice(self.sizes)
289 | return resize(img, target, size, self.max_size)
290 |
291 |
292 | class RandomPad(object):
293 | def __init__(self, max_pad):
294 | self.max_pad = max_pad
295 |
296 | def __call__(self, img, target):
297 | pad_x = random.randint(0, self.max_pad)
298 | pad_y = random.randint(0, self.max_pad)
299 | return pad(img, target, (pad_x, pad_y))
300 |
301 |
302 | class RandomSelect(object):
303 | """
304 | Randomly selects between transforms1 and transforms2,
305 | with probability p for transforms1 and (1 - p) for transforms2
306 | """
307 | def __init__(self, transforms1, transforms2, p=0.5):
308 | self.transforms1 = transforms1
309 | self.transforms2 = transforms2
310 | self.p = p
311 |
312 | def __call__(self, img, target):
313 | if random.random() < self.p:
314 | return self.transforms1(img, target)
315 | return self.transforms2(img, target)
316 |
317 |
318 | class ToTensor(object):
319 | def __call__(self, img, target):
320 | return F.to_tensor(img), target
321 |
322 |
323 | class RandomErasing(object):
324 |
325 | def __init__(self, *args, **kwargs):
326 | self.eraser = T.RandomErasing(*args, **kwargs)
327 |
328 | def __call__(self, img, target):
329 | return self.eraser(img), target
330 |
331 |
332 | class Normalize(object):
333 | def __init__(self, mean, std):
334 | self.mean = mean
335 | self.std = std
336 |
337 | def __call__(self, image, target=None):
338 | image = F.normalize(image, mean=self.mean, std=self.std)
339 | if target is None:
340 | return image, None
341 | target = target.copy()
342 | h, w = image.shape[-2:]
343 | if "boxes" in target:
344 | boxes = target["boxes"]
345 | boxes = box_xyxy_to_cxcywh(boxes)
346 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
347 | target["boxes"] = boxes
348 |
349 | if "pair_boxes" in target:
350 | hboxes = target["pair_boxes"][:, :4]
351 | hboxes = box_xyxy_to_cxcywh(hboxes)
352 | hboxes = hboxes / torch.tensor([w, h, w, h], dtype=torch.float32)
353 |
354 | oboxes = target["pair_boxes"][:, 4:]
355 | obj_mask = (oboxes[:, 0] != -1)
356 | if obj_mask.sum() != 0:
357 | oboxes[obj_mask] = box_xyxy_to_cxcywh(oboxes[obj_mask])
358 | oboxes[obj_mask] = oboxes[obj_mask] / torch.tensor([w, h, w, h], dtype=torch.float32)
359 |
360 | pair_boxes = torch.cat([hboxes, oboxes], dim=-1)
361 | target["pair_boxes"] = pair_boxes
362 |
363 | return image, target
364 |
365 | class ColorJitter(object):
366 |     def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
367 |         self.color_jitter = T.ColorJitter(brightness, contrast, saturation, hue)
368 |
369 | def __call__(self, img, target):
370 | return self.color_jitter(img), target
371 |
372 | class Compose(object):
373 | def __init__(self, transforms):
374 | self.transforms = transforms
375 |
376 | def __call__(self, image, target):
377 | for t in self.transforms:
378 | image, target = t(image, target)
379 | return image, target
380 |
381 | def __repr__(self):
382 | format_string = self.__class__.__name__ + "("
383 | for t in self.transforms:
384 | format_string += "\n"
385 | format_string += " {0}".format(t)
386 | format_string += "\n)"
387 | return format_string
388 |
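389 |
390 | if __name__ == '__main__':
391 |     # Minimal usage sketch (added for illustration; not part of the original
392 |     # module). It runs a deterministic pipeline on a dummy image and a toy
393 |     # target with a single xyxy box, mirroring how the dataset builders
394 |     # compose these transforms. All values below are made up.
395 |     from PIL import Image
396 |
397 |     dummy_img = Image.new('RGB', (640, 480))
398 |     dummy_target = {
399 |         'boxes': torch.tensor([[10., 20., 110., 220.]]),
400 |         'labels': torch.tensor([1]),
401 |         'area': torch.tensor([20000.]),
402 |         'iscrowd': torch.tensor([0]),
403 |         'size': torch.tensor([480, 640]),
404 |     }
405 |     demo = Compose([
406 |         RandomHorizontalFlip(p=1.0),         # always flip, to keep the demo deterministic
407 |         RandomResize([800], max_size=1333),  # single size -> deterministic resize
408 |         ToTensor(),
409 |         Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
410 |     ])
411 |     out_img, out_target = demo(dummy_img, dummy_target)
412 |     # out_img is a normalized CHW float tensor; out_target['boxes'] is now in
413 |     # normalized cxcywh format
414 |     print(out_img.shape, out_target['boxes'])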
--------------------------------------------------------------------------------
/src/data/datasets/hico.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # HOTR official code : src/data/datasets/hico.py
3 | # Copyright (c) Kakao Brain, Inc. and its affiliates. All Rights Reserved
4 | # ------------------------------------------------------------------------
5 | # Modified from QPIC (https://github.com/hitachi-rd-cv/qpic)
6 | # Copyright (c) Hitachi, Ltd. All Rights Reserved.
7 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
8 | # ------------------------------------------------------------------------
9 | from pathlib import Path
10 | from PIL import Image
11 | import json
12 | from collections import defaultdict
13 | import numpy as np
14 |
15 | import torch
16 | import torch.utils.data
17 | import torchvision
18 |
19 | from src.data.datasets import builtin_meta
20 | import src.data.transforms.transforms as T
21 |
22 | class HICODetection(torch.utils.data.Dataset):
23 | def __init__(self, img_set, img_folder, anno_file, action_list_file, transforms, num_queries):
24 | self.img_set = img_set
25 | self.img_folder = img_folder
26 | with open(anno_file, 'r') as f:
27 | self.annotations = json.load(f)
28 | with open(action_list_file, 'r') as f:
29 | self.action_lines = f.readlines()
30 | self._transforms = transforms
31 | self.num_queries = num_queries
32 | self.get_metadata()
33 |
34 | if img_set == 'train':
35 | self.ids = []
36 | for idx, img_anno in enumerate(self.annotations):
37 | for hoi in img_anno['hoi_annotation']:
38 | if hoi['subject_id'] >= len(img_anno['annotations']) or hoi['object_id'] >= len(img_anno['annotations']):
39 | break
40 | else:
41 | self.ids.append(idx)
42 | else:
43 | self.ids = list(range(len(self.annotations)))
44 | # self.ids = self.ids[:1000]
45 |
46 | ############################################################################
47 | # Number Method
48 | ############################################################################
49 | def get_metadata(self):
50 | meta = builtin_meta._get_coco_instances_meta()
51 | self.COCO_CLASSES = meta['coco_classes']
52 |         self._valid_obj_ids = list(meta['thing_dataset_id_to_contiguous_id'].keys())
53 | self._valid_verb_ids, self._valid_verb_names = [], []
54 | for action_line in self.action_lines[2:]:
55 | act_id, act_name = action_line.split()
56 | self._valid_verb_ids.append(int(act_id))
57 | self._valid_verb_names.append(act_name)
58 |
59 | def get_valid_obj_ids(self):
60 | return self._valid_obj_ids
61 |
62 | def get_actions(self):
63 | return self._valid_verb_names
64 |
65 | def num_category(self):
66 | return len(self.COCO_CLASSES)
67 |
68 | def num_action(self):
69 | return len(self._valid_verb_ids)
70 | ############################################################################
71 |
72 | def __len__(self):
73 | return len(self.ids)
74 |
75 | def __getitem__(self, idx):
76 | img_anno = self.annotations[self.ids[idx]]
77 |
78 | img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB')
79 | w, h = img.size
80 |
81 | if self.img_set == 'train':
82 | img_anno = merge_box_annotations(img_anno)
83 | # img_anno = merge_box_annotations(img_anno) # for finetune detr
84 |
85 | # cut out the GTs that exceed the number of object queries
86 | if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries:
87 | img_anno['annotations'] = img_anno['annotations'][:self.num_queries]
88 |
89 | boxes = [obj['bbox'] for obj in img_anno['annotations']]
90 | # guard against no boxes via resizing
91 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
92 |
93 | if self.img_set == 'train':
94 | # Add index for confirming which boxes are kept after image transformation
95 | classes = [(i, obj['category_id']) for i, obj in enumerate(img_anno['annotations'])]
96 | else:
97 | classes = [obj['category_id'] for obj in img_anno['annotations']]
98 | classes = torch.tensor(classes, dtype=torch.int64)
99 |
100 | target = {}
101 | target['orig_size'] = torch.as_tensor([int(h), int(w)])
102 | target['size'] = torch.as_tensor([int(h), int(w)])
103 | if self.img_set == 'train':
104 | boxes[:, 0::2].clamp_(min=0, max=w)
105 | boxes[:, 1::2].clamp_(min=0, max=h)
106 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
107 | boxes = boxes[keep]
108 | classes = classes[keep]
109 |
110 | target['boxes'] = boxes
111 | target['labels'] = classes
112 | target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])])
113 | target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
114 |
115 | if self._transforms is not None:
116 | img, target = self._transforms(img, target)
117 |
118 | kept_box_indices = [label[0] for label in target['labels']]
119 |
120 | target['labels'] = target['labels'][:, 1]
121 |
122 | obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], []
123 | sub_obj_pairs = []
124 | for hoi in img_anno['hoi_annotation']:
125 | if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices:
126 | continue
127 | sub_obj_pair = (hoi['subject_id'], hoi['object_id'])
128 | if sub_obj_pair in sub_obj_pairs: # multi label
129 | verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1
130 | else:
131 | sub_obj_pairs.append(sub_obj_pair)
132 | obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])])
133 | verb_label = [0 for _ in range(len(self._valid_verb_ids))]
134 | verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1
135 | sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])]
136 | obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])]
137 | verb_labels.append(verb_label)
138 | sub_boxes.append(sub_box)
139 | obj_boxes.append(obj_box)
140 | if len(sub_obj_pairs) == 0:
141 | target['pair_targets'] = torch.zeros((0,), dtype=torch.int64)
142 | target['pair_actions'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32)
143 | target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
144 | target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32)
145 | else:
146 | target['pair_targets'] = torch.stack(obj_labels)
147 | target['pair_actions'] = torch.as_tensor(verb_labels, dtype=torch.float32)
148 | target['sub_boxes'] = torch.stack(sub_boxes)
149 | target['obj_boxes'] = torch.stack(obj_boxes)
150 |
151 | # relation map
152 | relation_map = torch.zeros((len(target['boxes']), len(target['boxes']), self.num_action()))
153 | for sub_obj_pair in sub_obj_pairs:
154 | kept_subj_id = kept_box_indices.index(sub_obj_pair[0])
155 | kept_obj_id = kept_box_indices.index(sub_obj_pair[1])
156 | relation_map[kept_subj_id, kept_obj_id] = torch.tensor(verb_labels[sub_obj_pairs.index(sub_obj_pair)])
157 | target['relation_map'] = relation_map
158 | target['hois'] = relation_map.nonzero(as_tuple=False)
159 | else:
160 | target['boxes'] = boxes
161 | target['labels'] = classes
162 | target['id'] = idx
163 |
164 | if self._transforms is not None:
165 | img, _ = self._transforms(img, None)
166 |
167 | hois = []
168 | for hoi in img_anno['hoi_annotation']:
169 | hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id'])))
170 | target['hois'] = torch.as_tensor(hois, dtype=torch.int64)
171 |
172 | target['image_id'] = torch.tensor([idx])
173 | return img, target
174 |
175 | def set_rare_hois(self, anno_file):
176 | with open(anno_file, 'r') as f:
177 | annotations = json.load(f)
178 |
179 | counts = defaultdict(lambda: 0)
180 | for img_anno in annotations:
181 | hois = img_anno['hoi_annotation']
182 | bboxes = img_anno['annotations']
183 | for hoi in hois:
184 |                 # mapped to valid obj ids for evaluation
185 | triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']),
186 | self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']),
187 | self._valid_verb_ids.index(hoi['category_id']))
188 | counts[triplet] += 1
189 | self.rare_triplets = []
190 | self.non_rare_triplets = []
191 | for triplet, count in counts.items():
192 | if count < 10:
193 | self.rare_triplets.append(triplet)
194 | else:
195 | self.non_rare_triplets.append(triplet)
196 |
197 | def load_correct_mat(self, path):
198 | self.correct_mat = np.load(path)
199 |
200 |
201 | # Add color jitter to coco transforms
202 | def make_hico_transforms(image_set):
203 |
204 | normalize = T.Compose([
205 | T.ToTensor(),
206 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
207 | ])
208 |
209 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
210 |
211 | if image_set == 'train':
212 | return T.Compose([
213 | T.RandomHorizontalFlip(),
214 | T.ColorJitter(.4, .4, .4),
215 | T.RandomSelect(
216 | T.RandomResize(scales, max_size=1333),
217 | T.Compose([
218 | T.RandomResize([400, 500, 600]),
219 | T.RandomSizeCrop(384, 600),
220 | T.RandomResize(scales, max_size=1333),
221 | ])
222 | ),
223 | normalize,
224 | ])
225 |
226 | if image_set == 'val':
227 | return T.Compose([
228 | T.RandomResize([800], max_size=1333),
229 | normalize,
230 | ])
231 |
232 | if image_set == 'test':
233 | return T.Compose([
234 | T.RandomResize([800], max_size=1333),
235 | normalize,
236 | ])
237 |
238 | raise ValueError(f'unknown {image_set}')
239 |
240 |
241 | def merge_box_annotations(org_image_annotation, overlap_iou_thres=0.7):
242 | merged_image_annotation = org_image_annotation.copy()
243 |
244 | # compute match
245 | bbox_list = org_image_annotation['annotations']
246 | box_match = torch.zeros(len(bbox_list), len(bbox_list)).bool()
247 | for i, bbox1 in enumerate(bbox_list):
248 | for j, bbox2 in enumerate(bbox_list):
249 | box_match[i, j] = compute_box_match(bbox1, bbox2, overlap_iou_thres)
250 |
251 | box_groups = []
252 | for i in range(len(box_match)):
253 |         if box_match[i].any(): # box i not yet assigned to a group
254 | group_ids = box_match[i].nonzero(as_tuple=False).squeeze(1)
255 | box_groups.append(group_ids.tolist())
256 | box_match[:, group_ids] = False
257 | assert sum([len(g) for g in box_groups]) == len(bbox_list)
258 |
259 |     # merge into new annotations
260 | group_info, orgbox2group = [], {}
261 | for gid, org_box_ids in enumerate(box_groups):
262 | for orgid in org_box_ids: orgbox2group.update({orgid: gid})
263 | # selected_box_id = np.random.choice(org_box_ids)
264 | # box_info = bbox_list[selected_box_id]
265 | box_info = {
266 | 'bbox': torch.tensor([bbox_list[id]['bbox'] for id in org_box_ids]).float().mean(dim=0).int().tolist(),
267 | 'category_id': bbox_list[org_box_ids[0]]['category_id']
268 | }
269 |
270 | group_info.append(box_info)
271 |
272 | new_hois = []
273 | for hoi in org_image_annotation['hoi_annotation']:
274 | if hoi['subject_id'] in orgbox2group and hoi['object_id'] in orgbox2group:
275 | new_hois.append({
276 | 'subject_id': orgbox2group[hoi['subject_id']],
277 | 'object_id': orgbox2group[hoi['object_id']],
278 | 'category_id': hoi['category_id']
279 | })
280 |
281 | merged_image_annotation['annotations'] = group_info
282 | merged_image_annotation['hoi_annotation'] = new_hois
283 | return merged_image_annotation
284 |
285 | # iou > threshold and same category
286 | def compute_box_match(bbox1, bbox2, threshold):
287 | if isinstance(bbox1['category_id'], str):
288 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', ''))
289 | if isinstance(bbox2['category_id'], str):
290 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', ''))
291 | if bbox1['category_id'] != bbox2['category_id']:
292 | return False
293 | else:
294 | rec1 = bbox1['bbox']
295 | rec2 = bbox2['bbox']
296 |         # compute the area of each rectangle
297 |         S_rec1 = (rec1[2] - rec1[0] + 1) * (rec1[3] - rec1[1] + 1)
298 |         S_rec2 = (rec2[2] - rec2[0] + 1) * (rec2[3] - rec2[1] + 1)
299 |
300 | # computing the sum_area
301 | sum_area = S_rec1 + S_rec2
302 |
303 |         # find the edges of the intersection rectangle
304 | left_line = max(rec1[1], rec2[1])
305 | right_line = min(rec1[3], rec2[3])
306 | top_line = max(rec1[0], rec2[0])
307 | bottom_line = min(rec1[2], rec2[2])
308 |
309 |         # check whether the rectangles intersect
310 |         intersect = max((right_line - left_line + 1), 0) * max((bottom_line - top_line + 1), 0)
311 | iou = intersect / (sum_area - intersect)
312 | if iou > threshold:
313 | return True
314 | else:
315 | return False
316 |
317 | def build(image_set, args):
318 | root = Path(args.data_path)
319 | assert root.exists(), f'provided HOI path {root} does not exist'
320 | PATHS = {
321 | 'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json'),
322 | 'val': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json'),
323 | 'test': (root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json')
324 | }
325 | CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy'
326 | action_list_file = root / 'list_action.txt'
327 |
328 | img_folder, anno_file = PATHS[image_set]
329 | dataset = HICODetection(image_set, img_folder, anno_file, action_list_file, transforms=make_hico_transforms(image_set),
330 | num_queries=args.num_queries)
331 | if image_set == 'val' or image_set == 'test':
332 | dataset.set_rare_hois(PATHS['train'][1])
333 | dataset.load_correct_mat(CORRECT_MAT_PATH)
334 | return dataset
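335 |
336 | # Usage sketch (illustrative comment, not part of the original module):
337 | # `build` expects args.data_path to point at a HICO-DET root laid out as in
338 | # PATHS above, i.e.
339 | #   <data_path>/images/{train2015,test2015}
340 | #   <data_path>/annotations/{trainval_hico.json,test_hico.json,corre_hico.npy}
341 | #   <data_path>/list_action.txt
342 | # and args.num_queries to match the DETR object-query number, e.g.
343 | #   dataset = build('train', args)
344 | #   img, target = dataset[0]  # target holds boxes, labels, pair_* and relation_map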
--------------------------------------------------------------------------------