├── tools ├── __init__.py ├── convert_parameters.py ├── convert_vcoco_annotations.py └── covert_annot_for_official_eval.py ├── configs ├── __init__.py ├── .DS_Store ├── vcoco_l.sh ├── vcoco_m.sh ├── vcoco_s.sh ├── hico_s.sh ├── hico_l.sh ├── hico_m.sh ├── hico_s_zs_uo.sh ├── hico_s_zs_uv.sh ├── hico_s_zs_nf_uc.sh └── hico_s_zs_rf_uc.sh ├── datasets ├── .DS_Store ├── __init__.py ├── vcoco.py ├── transforms.py ├── vcoco_eval.py ├── hico.py ├── hico_eval_triplet.py └── vcoco_text_label.py ├── paper_images ├── .DS_Store └── intro.png ├── models ├── __init__.py ├── position_encoding.py ├── matcher.py ├── backbone.py └── gen.py ├── requirements.txt ├── util ├── __init__.py ├── topk.py ├── box_ops.py ├── plot_utils.py └── misc.py ├── LICENSE ├── engine.py ├── README.md ├── main.py └── generate_vcoco_official.py /tools/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YueLiao/gen-vlkt/HEAD/configs/.DS_Store -------------------------------------------------------------------------------- /datasets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YueLiao/gen-vlkt/HEAD/datasets/.DS_Store -------------------------------------------------------------------------------- /paper_images/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YueLiao/gen-vlkt/HEAD/paper_images/.DS_Store -------------------------------------------------------------------------------- /paper_images/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YueLiao/gen-vlkt/HEAD/paper_images/intro.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .gen_vlkt import build 2 | 3 | 4 | def build_model(args): 5 | return build(args) 6 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cython 2 | pycocotools 3 | torch==1.7.1 4 | torchvision==0.8.2 5 | scipy==1.3.1 6 | opencv-python 7 | ftfy 8 | regex 9 | tqdm -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data 2 | import torchvision 3 | 4 | from .hico import build as build_hico 5 | from .vcoco import build as build_vcoco 6 | 7 | def build_dataset(image_set, args): 8 | if args.dataset_file == 'hico': 9 | return build_hico(image_set, args) 10 | if args.dataset_file == 'vcoco': 11 | return build_vcoco(image_set, args) 12 | raise ValueError(f'dataset {args.dataset_file} not supported') 13 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # 
------------------------------------------------------------------------ 2 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | 9 | -------------------------------------------------------------------------------- /configs/vcoco_l.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/vcoco_gen_vlkt_l_r101_dec_6layers 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r101-pre-2branch-vcoco.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file vcoco \ 13 | --hoi_path data/v-coco \ 14 | --num_obj_classes 81 \ 15 | --num_verb_classes 29 \ 16 | --backbone resnet101 \ 17 | --num_queries 64 \ 18 | --dec_layers 6 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 27 | -------------------------------------------------------------------------------- /configs/vcoco_m.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/vcoco_gen_vlkt_m_r101_dec_3layers 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r101-pre-2branch-vcoco.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file vcoco \ 13 | --hoi_path data/v-coco \ 14 | --num_obj_classes 81 \ 15 | --num_verb_classes 29 \ 16 | --backbone resnet101 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 27 | -------------------------------------------------------------------------------- /configs/vcoco_s.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/vcoco_gen_vlkt_s_r50_dec_3layers 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r50-pre-2branch-vcoco.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file vcoco \ 13 | --hoi_path data/v-coco \ 14 | --num_obj_classes 81 \ 15 | --num_verb_classes 29 \ 16 | --backbone resnet50 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 27 | -------------------------------------------------------------------------------- /configs/hico_s.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/hico_gen_vlkt_s_r50_dec_3layers 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file hico \ 13 
| --hoi_path data/hico_20160224_det \ 14 | --num_obj_classes 80 \ 15 | --num_verb_classes 117 \ 16 | --backbone resnet50 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 27 | -------------------------------------------------------------------------------- /configs/hico_l.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/hico_gen_vlkt_l_r101_dec_6layers 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r101-pre-2branch-hico.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file hico \ 13 | --hoi_path data/hico_20160224_det \ 14 | --num_obj_classes 80 \ 15 | --num_verb_classes 117 \ 16 | --backbone resnet101 \ 17 | --num_queries 64 \ 18 | --dec_layers 6 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 27 | -------------------------------------------------------------------------------- /configs/hico_m.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/hico_gen_vlkt_m_r101_dec_3layers 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r101-pre-2branch-hico.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file hico \ 13 | --hoi_path data/hico_20160224_det \ 14 | --num_obj_classes 80 \ 15 | --num_verb_classes 117 \ 16 | --backbone resnet101 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 27 | -------------------------------------------------------------------------------- /util/topk.py: -------------------------------------------------------------------------------- 1 | def sift(li, low, high): # sift down to restore the min-heap property for the subtree rooted at low 2 | tmp = li[low] 3 | i = low 4 | j = 2 * i + 1 5 | while j <= high: # stop once i has no child inside the heap 6 | if j + 1 <= high and li[j + 1] < li[j]: # the right child exists and is smaller than the left child 7 | j += 1 8 | if tmp > li[j]: 9 | li[i] = li[j] 10 | i = j 11 | j = 2 * i + 1 12 | else: 13 | break # tmp is no larger than the smaller child, so it stays at position i 14 | li[i] = tmp 15 | 16 | def top_k(li, k): # return the k largest elements of li in descending order 17 | heap = li[0:k] 18 | # build a min-heap from the first k elements 19 | for i in range(k // 2 - 1, -1, -1): 20 | sift(heap, i, k - 1) 21 | for i in range(k, len(li)): 22 | if li[i] > heap[0]: 23 | heap[0] = li[i] 24 | sift(heap, 0, k - 1) 25 | # pop the elements one by one (heap sort) to output them in order 26 | for i in range(k - 1, -1, -1): 27 | heap[0], heap[i] = heap[i], heap[0] 28 | sift(heap, 0, i - 1) 29 | 30 | return heap 31 | 32 | -------------------------------------------------------------------------------- /configs/hico_s_zs_uo.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/hico_gen_vlkt_s_r50_dec_3layers_zero_shot_uo 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file hico \ 13 | --hoi_path data/hico_20160224_det \ 14 | --num_obj_classes 80 \ 15 | --num_verb_classes 117 \ 16 | 
--backbone resnet50 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 \ 27 | --zero_shot_type unseen_object \ 28 | --fix_clip \ 29 | --del_unseen 30 | -------------------------------------------------------------------------------- /configs/hico_s_zs_uv.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/hico_gen_vlkt_s_r50_dec_3layers_zero_shot_uv 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r50-pre-2branch-hico.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file hico \ 13 | --hoi_path data/hico_20160224_det \ 14 | --num_obj_classes 80 \ 15 | --num_verb_classes 117 \ 16 | --backbone resnet50 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 \ 27 | --zero_shot_type unseen_verb \ 28 | --fix_clip \ 29 | --del_unseen 30 | -------------------------------------------------------------------------------- /configs/hico_s_zs_nf_uc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/hico_gen_vlkt_s_r50_dec_3layers_zero_shot_nf_uc 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r50-pre-2stage-hico.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file hico \ 13 | --hoi_path data/hico_20160224_det \ 14 | --num_obj_classes 80 \ 15 | --num_verb_classes 117 \ 16 | --backbone resnet50 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 \ 27 | --zero_shot_type non_rare_first \ 28 | --fix_clip \ 29 | --del_unseen 30 | -------------------------------------------------------------------------------- /configs/hico_s_zs_rf_uc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | EXP_DIR=exps/hico_gen_vlkt_s_r50_dec_3layers_zero_shot_rf_uc 5 | 6 | python -m torch.distributed.launch \ 7 | --nproc_per_node=8 \ 8 | --use_env \ 9 | main.py \ 10 | --pretrained params/detr-r50-pre-2stage-hico.pth \ 11 | --output_dir ${EXP_DIR} \ 12 | --dataset_file hico \ 13 | --hoi_path data/hico_20160224_det \ 14 | --num_obj_classes 80 \ 15 | --num_verb_classes 117 \ 16 | --backbone resnet50 \ 17 | --num_queries 64 \ 18 | --dec_layers 3 \ 19 | --epochs 90 \ 20 | --lr_drop 60 \ 21 | --use_nms_filter \ 22 | --ft_clip_with_small_lr \ 23 | --with_clip_label \ 24 | --with_obj_clip_label \ 25 | --with_mimic \ 26 | --mimic_loss_coef 20 \ 27 | --zero_shot_type rare_first \ 28 | --fix_clip \ 29 | --del_unseen 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Yue Liao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated 
documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /util/box_ops.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | """ 9 | Utilities for bounding box manipulation and GIoU. 10 | """ 11 | import torch 12 | from torchvision.ops.boxes import box_area 13 | 14 | 15 | def box_cxcywh_to_xyxy(x): 16 | x_c, y_c, w, h = x.unbind(-1) 17 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 18 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | def box_xyxy_to_cxcywh(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 25 | (x1 - x0), (y1 - y0)] 26 | return torch.stack(b, dim=-1) 27 | 28 | 29 | # modified from torchvision to also return the union 30 | def box_iou(boxes1, boxes2): 31 | area1 = box_area(boxes1) 32 | area2 = box_area(boxes2) 33 | 34 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 35 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 36 | 37 | wh = (rb - lt).clamp(min=0) # [N,M,2] 38 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 39 | 40 | union = area1[:, None] + area2 - inter 41 | 42 | iou = inter / union 43 | return iou, union 44 | 45 | 46 | def generalized_box_iou(boxes1, boxes2): 47 | """ 48 | Generalized IoU from https://giou.stanford.edu/ 49 | 50 | The boxes should be in [x0, y0, x1, y1] format 51 | 52 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 53 | and M = len(boxes2) 54 | """ 55 | # degenerate boxes gives inf / nan results 56 | # so do an early check 57 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 58 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 59 | iou, union = box_iou(boxes1, boxes2) 60 | 61 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 62 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 63 | 64 | wh = (rb - lt).clamp(min=0) # [N,M,2] 65 | area = wh[:, :, 0] * wh[:, :, 1] 66 | 67 | return iou - (area - union) / area 68 | 69 | 70 | def masks_to_boxes(masks): 71 | """Compute the bounding boxes around the provided masks 72 | 73 | The 
masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 74 | 75 | Returns a [N, 4] tensors, with the boxes in xyxy format 76 | """ 77 | if masks.numel() == 0: 78 | return torch.zeros((0, 4), device=masks.device) 79 | 80 | h, w = masks.shape[-2:] 81 | 82 | y = torch.arange(0, h, dtype=torch.float) 83 | x = torch.arange(0, w, dtype=torch.float) 84 | y, x = torch.meshgrid(y, x) 85 | 86 | x_mask = (masks * x.unsqueeze(0)) 87 | x_max = x_mask.flatten(1).max(-1)[0] 88 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 89 | 90 | y_mask = (masks * y.unsqueeze(0)) 91 | y_max = y_mask.flatten(1).max(-1)[0] 92 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 93 | 94 | return torch.stack([x_min, y_min, x_max, y_max], 1) 95 | -------------------------------------------------------------------------------- /tools/convert_parameters.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser() 9 | 10 | parser.add_argument( 11 | '--load_path', type=str, required=True, 12 | ) 13 | parser.add_argument( 14 | '--save_path', type=str, required=True, 15 | ) 16 | parser.add_argument( 17 | '--dataset', type=str, default='hico', 18 | ) 19 | parser.add_argument( 20 | '--num_queries', type=int, default=100, 21 | ) 22 | 23 | args = parser.parse_args() 24 | 25 | return args 26 | 27 | 28 | def main(args): 29 | ps = torch.load(args.load_path) 30 | 31 | obj_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 32 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 33 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 34 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 35 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 36 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 37 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 38 | 82, 84, 85, 86, 87, 88, 89, 90] 39 | 40 | # For no pair 41 | obj_ids.append(91) 42 | 43 | for k in list(ps['model'].keys()): 44 | print(k) 45 | if len(k.split('.')) > 1 and k.split('.')[1] == 'decoder': 46 | ps['model'][k.replace('decoder', 'instance_decoder')] = ps['model'][k].clone() 47 | ps['model'][k.replace('decoder', 'interaction_decoder')] = ps['model'][k].clone() 48 | del ps['model'][k] 49 | 50 | ps['model']['hum_bbox_embed.layers.0.weight'] = ps['model']['bbox_embed.layers.0.weight'].clone() 51 | ps['model']['hum_bbox_embed.layers.0.bias'] = ps['model']['bbox_embed.layers.0.bias'].clone() 52 | ps['model']['hum_bbox_embed.layers.1.weight'] = ps['model']['bbox_embed.layers.1.weight'].clone() 53 | ps['model']['hum_bbox_embed.layers.1.bias'] = ps['model']['bbox_embed.layers.1.bias'].clone() 54 | ps['model']['hum_bbox_embed.layers.2.weight'] = ps['model']['bbox_embed.layers.2.weight'].clone() 55 | ps['model']['hum_bbox_embed.layers.2.bias'] = ps['model']['bbox_embed.layers.2.bias'].clone() 56 | 57 | ps['model']['obj_bbox_embed.layers.0.weight'] = ps['model']['bbox_embed.layers.0.weight'].clone() 58 | ps['model']['obj_bbox_embed.layers.0.bias'] = ps['model']['bbox_embed.layers.0.bias'].clone() 59 | ps['model']['obj_bbox_embed.layers.1.weight'] = ps['model']['bbox_embed.layers.1.weight'].clone() 60 | ps['model']['obj_bbox_embed.layers.1.bias'] = ps['model']['bbox_embed.layers.1.bias'].clone() 61 | ps['model']['obj_bbox_embed.layers.2.weight'] = ps['model']['bbox_embed.layers.2.weight'].clone() 62 | ps['model']['obj_bbox_embed.layers.2.bias'] = ps['model']['bbox_embed.layers.2.bias'].clone() 63 | 64 | 
ps['model']['obj_class_embed.weight'] = ps['model']['class_embed.weight'].clone()[obj_ids] 65 | ps['model']['obj_class_embed.bias'] = ps['model']['class_embed.bias'].clone()[obj_ids] 66 | 67 | ps['model']['query_embed.weight'] = ps['model']['query_embed.weight'].clone()[:args.num_queries] 68 | 69 | if args.dataset == 'vcoco': 70 | l = nn.Linear(ps['model']['obj_class_embed.weight'].shape[1], 1) 71 | l.to(ps['model']['obj_class_embed.weight'].device) 72 | ps['model']['obj_class_embed.weight'] = torch.cat(( 73 | ps['model']['obj_class_embed.weight'][:-1], l.weight, ps['model']['obj_class_embed.weight'][[-1]])) 74 | ps['model']['obj_class_embed.bias'] = torch.cat( 75 | (ps['model']['obj_class_embed.bias'][:-1], l.bias, ps['model']['obj_class_embed.bias'][[-1]])) 76 | 77 | torch.save(ps, args.save_path) 78 | 79 | 80 | if __name__ == '__main__': 81 | args = get_args() 82 | main(args) 83 | -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Various positional encodings for the transformer. 7 | """ 8 | import math 9 | import torch 10 | from torch import nn 11 | 12 | from util.misc import NestedTensor 13 | 14 | 15 | class PositionEmbeddingSine(nn.Module): 16 | """ 17 | This is a more standard version of the position embedding, very similar to the one 18 | used by the Attention is all you need paper, generalized to work on images. 19 | """ 20 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 21 | super().__init__() 22 | self.num_pos_feats = num_pos_feats 23 | self.temperature = temperature 24 | self.normalize = normalize 25 | if scale is not None and normalize is False: 26 | raise ValueError("normalize should be True if scale is passed") 27 | if scale is None: 28 | scale = 2 * math.pi 29 | self.scale = scale 30 | 31 | def forward(self, tensor_list: NestedTensor): 32 | x = tensor_list.tensors 33 | mask = tensor_list.mask 34 | assert mask is not None 35 | not_mask = ~mask 36 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 37 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 38 | if self.normalize: 39 | eps = 1e-6 40 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 41 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 42 | 43 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 44 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 45 | 46 | pos_x = x_embed[:, :, :, None] / dim_t 47 | pos_y = y_embed[:, :, :, None] / dim_t 48 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 49 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 50 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 51 | return pos 52 | 53 | 54 | class PositionEmbeddingLearned(nn.Module): 55 | """ 56 | Absolute pos embedding, learned. 
57 | """ 58 | def __init__(self, num_pos_feats=256): 59 | super().__init__() 60 | self.row_embed = nn.Embedding(50, num_pos_feats) 61 | self.col_embed = nn.Embedding(50, num_pos_feats) 62 | self.reset_parameters() 63 | 64 | def reset_parameters(self): 65 | nn.init.uniform_(self.row_embed.weight) 66 | nn.init.uniform_(self.col_embed.weight) 67 | 68 | def forward(self, tensor_list: NestedTensor): 69 | x = tensor_list.tensors 70 | h, w = x.shape[-2:] 71 | i = torch.arange(w, device=x.device) 72 | j = torch.arange(h, device=x.device) 73 | x_emb = self.col_embed(i) 74 | y_emb = self.row_embed(j) 75 | pos = torch.cat([ 76 | x_emb.unsqueeze(0).repeat(h, 1, 1), 77 | y_emb.unsqueeze(1).repeat(1, w, 1), 78 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 79 | return pos 80 | 81 | 82 | def build_position_encoding(args): 83 | N_steps = args.hidden_dim // 2 84 | if args.position_embedding in ('v2', 'sine'): 85 | # TODO find a better way of exposing other arguments 86 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 87 | elif args.position_embedding in ('v3', 'learned'): 88 | position_embedding = PositionEmbeddingLearned(N_steps) 89 | else: 90 | raise ValueError(f"not supported {args.position_embedding}") 91 | 92 | return position_embedding 93 | -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import linear_sum_assignment 3 | from torch import nn 4 | 5 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 6 | 7 | class HungarianMatcherHOI(nn.Module): 8 | 9 | def __init__(self, cost_obj_class: float = 1, cost_verb_class: float = 1, cost_bbox: float = 1, 10 | cost_giou: float = 1, cost_hoi_class: float = 1): 11 | super().__init__() 12 | self.cost_obj_class = cost_obj_class 13 | self.cost_verb_class = cost_verb_class 14 | self.cost_hoi_class = cost_hoi_class 15 | self.cost_bbox = cost_bbox 16 | self.cost_giou = cost_giou 17 | assert cost_obj_class != 0 or cost_verb_class != 0 or cost_bbox != 0 or cost_giou != 0, 'all costs cant be 0' 18 | 19 | @torch.no_grad() 20 | def forward(self, outputs, targets): 21 | bs, num_queries = outputs['pred_sub_boxes'].shape[:2] 22 | if 'pred_hoi_logits' in outputs.keys(): 23 | out_hoi_prob = outputs['pred_hoi_logits'].flatten(0, 1).sigmoid() 24 | tgt_hoi_labels = torch.cat([v['hoi_labels'] for v in targets]) 25 | tgt_hoi_labels_permute = tgt_hoi_labels.permute(1, 0) 26 | cost_hoi_class = -(out_hoi_prob.matmul(tgt_hoi_labels_permute) / \ 27 | (tgt_hoi_labels_permute.sum(dim=0, keepdim=True) + 1e-4) + \ 28 | (1 - out_hoi_prob).matmul(1 - tgt_hoi_labels_permute) / \ 29 | ((1 - tgt_hoi_labels_permute).sum(dim=0, keepdim=True) + 1e-4)) / 2 30 | cost_hoi_class = self.cost_hoi_class * cost_hoi_class 31 | else: 32 | 33 | out_verb_prob = outputs['pred_verb_logits'].flatten(0, 1).sigmoid() 34 | tgt_verb_labels = torch.cat([v['verb_labels'] for v in targets]) 35 | tgt_verb_labels_permute = tgt_verb_labels.permute(1, 0) 36 | cost_verb_class = -(out_verb_prob.matmul(tgt_verb_labels_permute) / \ 37 | (tgt_verb_labels_permute.sum(dim=0, keepdim=True) + 1e-4) + \ 38 | (1 - out_verb_prob).matmul(1 - tgt_verb_labels_permute) / \ 39 | ((1 - tgt_verb_labels_permute).sum(dim=0, keepdim=True) + 1e-4)) / 2 40 | 41 | cost_hoi_class = self.cost_verb_class * cost_verb_class 42 | tgt_obj_labels = torch.cat([v['obj_labels'] for v in targets]) 43 | out_obj_prob = 
outputs['pred_obj_logits'].flatten(0, 1).softmax(-1) 44 | cost_obj_class = -out_obj_prob[:, tgt_obj_labels] 45 | out_sub_bbox = outputs['pred_sub_boxes'].flatten(0, 1) 46 | out_obj_bbox = outputs['pred_obj_boxes'].flatten(0, 1) 47 | 48 | tgt_sub_boxes = torch.cat([v['sub_boxes'] for v in targets]) 49 | tgt_obj_boxes = torch.cat([v['obj_boxes'] for v in targets]) 50 | 51 | cost_sub_bbox = torch.cdist(out_sub_bbox, tgt_sub_boxes, p=1) 52 | cost_obj_bbox = torch.cdist(out_obj_bbox, tgt_obj_boxes, p=1) * (tgt_obj_boxes != 0).any(dim=1).unsqueeze(0) 53 | if cost_sub_bbox.shape[1] == 0: 54 | cost_bbox = cost_sub_bbox 55 | else: 56 | cost_bbox = torch.stack((cost_sub_bbox, cost_obj_bbox)).max(dim=0)[0] 57 | 58 | cost_sub_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_sub_bbox), box_cxcywh_to_xyxy(tgt_sub_boxes)) 59 | cost_obj_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_obj_bbox), box_cxcywh_to_xyxy(tgt_obj_boxes)) + \ 60 | cost_sub_giou * (tgt_obj_boxes == 0).all(dim=1).unsqueeze(0) 61 | if cost_sub_giou.shape[1] == 0: 62 | cost_giou = cost_sub_giou 63 | else: 64 | cost_giou = torch.stack((cost_sub_giou, cost_obj_giou)).max(dim=0)[0] 65 | 66 | C = self.cost_hoi_class * cost_hoi_class + self.cost_bbox * cost_bbox + \ 67 | self.cost_giou * cost_giou + self.cost_obj_class * cost_obj_class 68 | 69 | C = C.view(bs, num_queries, -1).cpu() 70 | 71 | sizes = [len(v['sub_boxes']) for v in targets] 72 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 73 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 74 | 75 | 76 | def build_matcher(args): 77 | return HungarianMatcherHOI(cost_obj_class=args.set_cost_obj_class, cost_verb_class=args.set_cost_verb_class, 78 | cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou, 79 | cost_hoi_class=args.set_cost_hoi) 80 | -------------------------------------------------------------------------------- /util/plot_utils.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | """ 9 | Plotting utilities to visualize training logs. 10 | """ 11 | import torch 12 | import pandas as pd 13 | import seaborn as sns 14 | import matplotlib.pyplot as plt 15 | 16 | from pathlib import Path, PurePath 17 | 18 | 19 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 20 | ''' 21 | Function to plot specific fields from training log(s). Plots both training and test results. 22 | 23 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 24 | - fields = which results to plot from each log file - plots both training and test for each field. 25 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 26 | - log_name = optional, name of log file if different than default 'log.txt'. 27 | 28 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 
29 | - solid lines are training results, dashed lines are test results. 30 | 31 | ''' 32 | func_name = "plot_utils.py::plot_logs" 33 | 34 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 35 | # convert single Path to list to avoid 'not iterable' error 36 | 37 | if not isinstance(logs, list): 38 | if isinstance(logs, PurePath): 39 | logs = [logs] 40 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 41 | else: 42 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 43 | Expect list[Path] or single Path obj, received {type(logs)}") 44 | 45 | # verify valid dir(s) and that every item in list is Path object 46 | for i, dir in enumerate(logs): 47 | if not isinstance(dir, PurePath): 48 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 49 | if dir.exists(): 50 | continue 51 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 52 | 53 | # load log file(s) and plot 54 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 55 | 56 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 57 | 58 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 59 | for j, field in enumerate(fields): 60 | if field == 'mAP': 61 | coco_eval = pd.DataFrame(pd.np.stack(df.test_coco_eval.dropna().values)[:, 1]).ewm(com=ewm_col).mean() 62 | axs[j].plot(coco_eval, c=color) 63 | else: 64 | df.interpolate().ewm(com=ewm_col).mean().plot( 65 | y=[f'train_{field}', f'test_{field}'], 66 | ax=axs[j], 67 | color=[color] * 2, 68 | style=['-', '--'] 69 | ) 70 | for ax, field in zip(axs, fields): 71 | ax.legend([Path(p).name for p in logs]) 72 | ax.set_title(field) 73 | 74 | 75 | def plot_precision_recall(files, naming_scheme='iter'): 76 | if naming_scheme == 'exp_id': 77 | # name becomes exp_id 78 | names = [f.parts[-3] for f in files] 79 | elif naming_scheme == 'iter': 80 | names = [f.stem for f in files] 81 | else: 82 | raise ValueError(f'not supported {naming_scheme}') 83 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 84 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 85 | data = torch.load(f) 86 | # precision is n_iou, n_points, n_cat, n_area, max_det 87 | precision = data['precision'] 88 | recall = data['params'].recThrs 89 | scores = data['scores'] 90 | # take precision for all classes, all areas and 100 detections 91 | precision = precision[0, :, :, 0, -1].mean(1) 92 | scores = scores[0, :, :, 0, -1].mean(1) 93 | prec = precision.mean() 94 | rec = data['recall'][0, :, 0, -1].mean() 95 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 96 | f'score={scores.mean():0.3f}, ' + 97 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 98 | ) 99 | axs[0].plot(recall, precision, c=color) 100 | axs[1].plot(recall, scores, c=color) 101 | 102 | axs[0].set_title('Precision / Recall') 103 | axs[0].legend(names) 104 | axs[1].set_title('Scores / Recall') 105 | axs[1].legend(names) 106 | return fig, axs 107 | -------------------------------------------------------------------------------- /models/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Backbone modules. 7 | """ 8 | from collections import OrderedDict 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | import torchvision 13 | from torch import nn 14 | from torchvision.models._utils import IntermediateLayerGetter 15 | from typing import Dict, List 16 | 17 | from util.misc import NestedTensor, is_main_process 18 | 19 | from .position_encoding import build_position_encoding 20 | 21 | 22 | class FrozenBatchNorm2d(torch.nn.Module): 23 | """ 24 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 25 | 26 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 27 | without which any other models than torchvision.models.resnet[18,34,50,101] 28 | produce nans. 29 | """ 30 | 31 | def __init__(self, n): 32 | super(FrozenBatchNorm2d, self).__init__() 33 | self.register_buffer("weight", torch.ones(n)) 34 | self.register_buffer("bias", torch.zeros(n)) 35 | self.register_buffer("running_mean", torch.zeros(n)) 36 | self.register_buffer("running_var", torch.ones(n)) 37 | 38 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 39 | missing_keys, unexpected_keys, error_msgs): 40 | num_batches_tracked_key = prefix + 'num_batches_tracked' 41 | if num_batches_tracked_key in state_dict: 42 | del state_dict[num_batches_tracked_key] 43 | 44 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 45 | state_dict, prefix, local_metadata, strict, 46 | missing_keys, unexpected_keys, error_msgs) 47 | 48 | def forward(self, x): 49 | # move reshapes to the beginning 50 | # to make it fuser-friendly 51 | w = self.weight.reshape(1, -1, 1, 1) 52 | b = self.bias.reshape(1, -1, 1, 1) 53 | rv = self.running_var.reshape(1, -1, 1, 1) 54 | rm = self.running_mean.reshape(1, -1, 1, 1) 55 | eps = 1e-5 56 | scale = w * (rv + eps).rsqrt() 57 | bias = b - rm * scale 58 | return x * scale + bias 59 | 60 | 61 | class BackboneBase(nn.Module): 62 | 63 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 64 | super().__init__() 65 | for name, parameter in backbone.named_parameters(): 66 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 67 | parameter.requires_grad_(False) 68 | if return_interm_layers: 69 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 70 | else: 71 | return_layers = {'layer4': "0"} 72 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 73 | self.num_channels = num_channels 74 | 75 | def forward(self, tensor_list: NestedTensor): 76 | xs = self.body(tensor_list.tensors) 77 | out: Dict[str, NestedTensor] = {} 78 | for name, x in xs.items(): 79 | m = tensor_list.mask 80 | assert m is not None 81 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 82 | out[name] = NestedTensor(x, mask) 83 | return out 84 | 85 | 86 | class Backbone(BackboneBase): 87 | """ResNet backbone with frozen BatchNorm.""" 88 | def __init__(self, name: str, 89 | train_backbone: bool, 90 | return_interm_layers: bool, 91 | dilation: bool): 92 | backbone = getattr(torchvision.models, name)( 93 | replace_stride_with_dilation=[False, False, dilation], 94 | pretrained=False, norm_layer=FrozenBatchNorm2d) 95 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 96 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 97 | 98 | 99 | class 
Joiner(nn.Sequential): 100 | def __init__(self, backbone, position_embedding): 101 | super().__init__(backbone, position_embedding) 102 | 103 | def forward(self, tensor_list: NestedTensor): 104 | xs = self[0](tensor_list) 105 | out: List[NestedTensor] = [] 106 | pos = [] 107 | for name, x in xs.items(): 108 | out.append(x) 109 | # position encoding 110 | pos.append(self[1](x).to(x.tensors.dtype)) 111 | 112 | return out, pos 113 | 114 | 115 | def build_backbone(args): 116 | position_embedding = build_position_encoding(args) 117 | train_backbone = args.lr_backbone > 0 118 | return_interm_layers = args.masks 119 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 120 | model = Joiner(backbone, position_embedding) 121 | model.num_channels = backbone.num_channels 122 | return model 123 | -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import sys 4 | from typing import Iterable 5 | import numpy as np 6 | import copy 7 | import itertools 8 | 9 | import torch 10 | 11 | import util.misc as utils 12 | from datasets.hico_eval_triplet import HICOEvaluator 13 | from datasets.vcoco_eval import VCOCOEvaluator 14 | 15 | 16 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 17 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 18 | device: torch.device, epoch: int, max_norm: float = 0): 19 | model.train() 20 | criterion.train() 21 | metric_logger = utils.MetricLogger(delimiter=" ") 22 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 23 | if hasattr(criterion, 'loss_labels'): 24 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 25 | elif hasattr(criterion, 'loss_hoi_labels'): 26 | metric_logger.add_meter('hoi_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 27 | else: 28 | metric_logger.add_meter('obj_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 29 | header = 'Epoch: [{}]'.format(epoch) 30 | print_freq = 10 31 | 32 | for samples, targets in metric_logger.log_every(data_loader, print_freq, header): 33 | samples = samples.to(device) 34 | targets = [{k: v.to(device) for k, v in t.items() if k != 'filename'} for t in targets] 35 | outputs = model(samples) 36 | # print(targets) 37 | loss_dict = criterion(outputs, targets) 38 | weight_dict = criterion.weight_dict 39 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 40 | 41 | # reduce losses over all GPUs for logging purposes 42 | loss_dict_reduced = utils.reduce_dict(loss_dict) 43 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 44 | for k, v in loss_dict_reduced.items()} 45 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 46 | for k, v in loss_dict_reduced.items() if k in weight_dict} 47 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 48 | 49 | loss_value = losses_reduced_scaled.item() 50 | # print(loss_value) 51 | # sys.exit() 52 | 53 | if not math.isfinite(loss_value): 54 | print("Loss is {}, stopping training".format(loss_value)) 55 | print(loss_dict_reduced) 56 | sys.exit(1) 57 | 58 | optimizer.zero_grad() 59 | losses.backward() 60 | if max_norm > 0: 61 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 62 | optimizer.step() 63 | 64 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 65 | 
if hasattr(criterion, 'loss_labels'): 66 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 67 | elif hasattr(criterion, 'loss_hoi_labels'): 68 | metric_logger.update(hoi_class_error=loss_dict_reduced['hoi_class_error']) 69 | else: 70 | metric_logger.update(obj_class_error=loss_dict_reduced['obj_class_error']) 71 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 72 | 73 | # gather the stats from all processes 74 | metric_logger.synchronize_between_processes() 75 | print("Averaged stats:", metric_logger) 76 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 77 | 78 | 79 | @torch.no_grad() 80 | def evaluate_hoi(dataset_file, model, postprocessors, data_loader, 81 | subject_category_id, device, args): 82 | model.eval() 83 | 84 | metric_logger = utils.MetricLogger(delimiter=" ") 85 | header = 'Test:' 86 | 87 | preds = [] 88 | gts = [] 89 | indices = [] 90 | counter = 0 91 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 92 | samples = samples.to(device) 93 | # print(targets) 94 | outputs = model(samples, is_training=False) 95 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 96 | results = postprocessors['hoi'](outputs, orig_target_sizes) 97 | 98 | preds.extend(list(itertools.chain.from_iterable(utils.all_gather(results)))) 99 | # For avoiding a runtime error, the copy is used 100 | gts.extend(list(itertools.chain.from_iterable(utils.all_gather(copy.deepcopy(targets))))) 101 | 102 | # counter += 1 103 | 104 | 105 | # gather the stats from all processes 106 | metric_logger.synchronize_between_processes() 107 | 108 | img_ids = [img_gts['id'] for img_gts in gts] 109 | _, indices = np.unique(img_ids, return_index=True) 110 | preds = [img_preds for i, img_preds in enumerate(preds) if i in indices] 111 | gts = [img_gts for i, img_gts in enumerate(gts) if i in indices] 112 | 113 | if dataset_file == 'hico': 114 | evaluator = HICOEvaluator(preds, gts, data_loader.dataset.rare_triplets, 115 | data_loader.dataset.non_rare_triplets, data_loader.dataset.correct_mat, args=args) 116 | elif dataset_file == 'vcoco': 117 | evaluator = VCOCOEvaluator(preds, gts, data_loader.dataset.correct_mat, use_nms_filter=args.use_nms_filter) 118 | 119 | stats = evaluator.evaluate() 120 | 121 | return stats 122 | -------------------------------------------------------------------------------- /tools/convert_vcoco_annotations.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from collections import defaultdict 4 | import json 5 | import pickle 6 | import os 7 | 8 | import vsrl_utils as vu 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument( 15 | '--load_path', type=str, required=True, 16 | ) 17 | parser.add_argument( 18 | '--prior_path', type=str, required=True, 19 | ) 20 | parser.add_argument( 21 | '--save_path', type=str, required=True, 22 | ) 23 | 24 | args = parser.parse_args() 25 | 26 | return args 27 | 28 | 29 | def set_hoi(box_annotations, hoi_annotations, verb_classes): 30 | no_object_id = -1 31 | 32 | hoia_annotations = defaultdict(lambda: { 33 | 'annotations': [], 34 | 'hoi_annotation': [] 35 | }) 36 | 37 | for action_annotation in hoi_annotations: 38 | for label, img_id, role_ids in zip(action_annotation['label'][:, 0], 39 | action_annotation['image_id'][:, 0], 40 | action_annotation['role_object_id']): 41 | hoia_annotations[img_id]['file_name'] = box_annotations[img_id]['file_name'] 
42 | hoia_annotations[img_id]['annotations'] = box_annotations[img_id]['annotations'] 43 | 44 | if label == 0: 45 | continue 46 | 47 | subject_id = box_annotations[img_id]['annotation_ids'].index(role_ids[0]) 48 | 49 | if len(role_ids) == 1: 50 | hoia_annotations[img_id]['hoi_annotation'].append( 51 | {'subject_id': subject_id, 'object_id': no_object_id, 52 | 'category_id': verb_classes.index(action_annotation['action_name'])}) 53 | continue 54 | 55 | for role_name, role_id in zip(action_annotation['role_name'][1:], role_ids[1:]): 56 | if role_id == 0: 57 | object_id = no_object_id 58 | else: 59 | object_id = box_annotations[img_id]['annotation_ids'].index(role_id) 60 | 61 | hoia_annotations[img_id]['hoi_annotation'].append( 62 | {'subject_id': subject_id, 'object_id': object_id, 63 | 'category_id': verb_classes.index('{}_{}'.format(action_annotation['action_name'], role_name))}) 64 | 65 | hoia_annotations = [v for v in hoia_annotations.values()] 66 | 67 | return hoia_annotations 68 | 69 | 70 | def main(args): 71 | vsgnet_verbs_classes = { 72 | 'carry_obj': 0, 73 | 'catch_obj': 1, 74 | 'cut_instr':2, 75 | 'cut_obj': 3, 76 | 'drink_instr': 4, 77 | 'eat_instr':5, 78 | 'eat_obj': 6, 79 | 'hit_instr':7, 80 | 'hit_obj': 8, 81 | 'hold_obj': 9, 82 | 'jump_instr': 10, 83 | 'kick_obj': 11, 84 | 'lay_instr': 12, 85 | 'look_obj': 13, 86 | 'point_instr': 14, 87 | 'read_obj': 15, 88 | 'ride_instr': 16, 89 | 'run': 17, 90 | 'sit_instr': 18, 91 | 'skateboard_instr': 19, 92 | 'ski_instr': 20, 93 | 'smile': 21, 94 | 'snowboard_instr': 22, 95 | 'stand': 23, 96 | 'surf_instr': 24, 97 | 'talk_on_phone_instr': 25, 98 | 'throw_obj': 26, 99 | 'walk': 27, 100 | 'work_on_computer_instr': 28 101 | } 102 | 103 | box_annotations = defaultdict(lambda: { 104 | 'annotations': [], 105 | 'annotation_ids': [] 106 | }) 107 | 108 | coco = vu.load_coco(args.load_path) 109 | 110 | img_ids = coco.getImgIds() 111 | img_infos = coco.loadImgs(img_ids) 112 | 113 | for img_info in img_infos: 114 | box_annotations[img_info['id']]['file_name'] = img_info['file_name'] 115 | 116 | annotation_ids = coco.getAnnIds(imgIds=img_ids) 117 | annotations = coco.loadAnns(annotation_ids) 118 | for annotation in annotations: 119 | img_id = annotation['image_id'] 120 | category_id = annotation['category_id'] 121 | box = np.array(annotation['bbox']) 122 | box[2:] += box[:2] 123 | 124 | box_annotations[img_id]['annotations'].append({'category_id': category_id, 'bbox': box.tolist()}) 125 | box_annotations[img_id]['annotation_ids'].append(annotation['id']) 126 | 127 | hoi_trainval = vu.load_vcoco('vcoco_trainval') 128 | hoi_test = vu.load_vcoco('vcoco_test') 129 | 130 | action_classes = [x['action_name'] for x in hoi_trainval] 131 | verb_classes = [] 132 | for action in hoi_trainval: 133 | if len(action['role_name']) == 1: 134 | verb_classes.append(action['action_name']) 135 | else: 136 | verb_classes += ['{}_{}'.format(action['action_name'], r) for r in action['role_name'][1:]] 137 | 138 | print('Verb class') 139 | for i, verb_class in enumerate(verb_classes): 140 | print('{:02d}: {}'.format(i, verb_class)) 141 | 142 | hoia_trainval_annotations = set_hoi(box_annotations, hoi_trainval, verb_classes) 143 | hoia_test_annotations = set_hoi(box_annotations, hoi_test, verb_classes) 144 | 145 | print('#Training images: {}, #Test images: {}'.format(len(hoia_trainval_annotations), len(hoia_test_annotations))) 146 | 147 | with open(os.path.join(args.save_path, 'trainval_vcoco.json'), 'w') as f: 148 | json.dump(hoia_trainval_annotations, f) 149 | 150 | with 
open(os.path.join(args.save_path, 'test_vcoco.json'), 'w') as f: 151 | json.dump(hoia_test_annotations, f) 152 | 153 | with open(args.prior_path, 'rb') as f: 154 | prior = pickle.load(f) 155 | 156 | prior = [prior[k] for k in sorted(prior.keys())] 157 | prior = np.concatenate(prior).T 158 | prior = prior[[vsgnet_verbs_classes[verb_class] for verb_class in verb_classes]] 159 | np.save(os.path.join(args.save_path, 'corre_vcoco.npy'), prior) 160 | 161 | 162 | if __name__ == '__main__': 163 | args = get_args() 164 | main(args) 165 | -------------------------------------------------------------------------------- /tools/covert_annot_for_official_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import scipy.io as sio 4 | import os 5 | 6 | 7 | def Format_Pred(pred_file): 8 | orig_file = json.load(open(pred_file, 'r')) 9 | if isinstance(orig_file, str): 10 | orig_file = eval(orig_file)['preds'] 11 | out_pred = {} 12 | for annot in orig_file: 13 | annot_bbox = annot['predictions'] 14 | annot_hoi = annot['hoi_prediction'] 15 | img_id = int((annot['filename'].split('.')[0]).split('_')[-1]) 16 | for hoi in annot_hoi: 17 | sub_bbox = annot_bbox[hoi['subject_id']] 18 | obj_bbox = annot_bbox[hoi['object_id']] 19 | hoi_cls = int(hoi['category_id'] + 1) 20 | score = hoi['score'] 21 | this_out = {'img_id': img_id, 'human_box': sub_bbox['bbox'], 'object_box': obj_bbox['bbox'], 'score': score} 22 | if hoi_cls not in out_pred.keys(): 23 | out_pred[hoi_cls] = [] 24 | out_pred[hoi_cls].append(this_out) 25 | return out_pred 26 | 27 | 28 | def save_HICO(HICO, HICO_dir, classid, begin, finish): 29 | all_boxes = [] 30 | for i in range(begin, finish + 1): 31 | total = [] 32 | score = [] 33 | if i in HICO.keys(): 34 | for element in HICO[i]: 35 | temp = [] 36 | temp.append(element['human_box']) # Human box 37 | temp.append(element['object_box']) # Object box 38 | temp.append(element['img_id']) # image id 39 | temp.append(int(i - begin)) # action id (0-599) 40 | temp.append(element['score'] * 1000) 41 | total.append(temp) 42 | score.append(element['score'] * 1000) 43 | 44 | idx = np.argsort(score, axis=0)[::-1] 45 | for i_idx in range(min(len(idx), 19999)): 46 | all_boxes.append(total[idx[i_idx]]) 47 | else: 48 | print(i) 49 | savefile = HICO_dir + 'detections_' + str(classid).zfill(2) + '.mat' 50 | sio.savemat(savefile, {'all_boxes': all_boxes}) 51 | 52 | def Generate_HICO_detection(output_file, HICO_dir): 53 | if not os.path.exists(HICO_dir): 54 | os.makedirs(HICO_dir) 55 | 56 | # Remove previous results 57 | filelist = [f for f in os.listdir(HICO_dir)] 58 | for f in filelist: 59 | os.remove(os.path.join(HICO_dir, f)) 60 | 61 | HICO = Format_Pred(output_file) 62 | 63 | save_HICO(HICO, HICO_dir, 1, 161, 170) # 1 person 64 | save_HICO(HICO, HICO_dir, 2, 11, 24) # 2 bicycle 65 | save_HICO(HICO, HICO_dir, 3, 66, 76) # 3 car 66 | save_HICO(HICO, HICO_dir, 4, 147, 160) # 4 motorcycle 67 | save_HICO(HICO, HICO_dir, 5, 1, 10) # 5 airplane 68 | save_HICO(HICO, HICO_dir, 6, 55, 65) # 6 bus 69 | save_HICO(HICO, HICO_dir, 7, 187, 194) # 7 train 70 | save_HICO(HICO, HICO_dir, 8, 568, 576) # 8 truck 71 | save_HICO(HICO, HICO_dir, 9, 32, 46) # 9 boat 72 | save_HICO(HICO, HICO_dir, 10, 563, 567) # 10 traffic light 73 | save_HICO(HICO, HICO_dir, 11, 326, 330) # 11 fire_hydrant 74 | save_HICO(HICO, HICO_dir, 12, 503, 506) # 12 stop_sign 75 | save_HICO(HICO, HICO_dir, 13, 415, 418) # 13 parking_meter 76 | save_HICO(HICO, HICO_dir, 14, 244, 247) # 14 
bench 77 | save_HICO(HICO, HICO_dir, 15, 25, 31) # 15 bird 78 | save_HICO(HICO, HICO_dir, 16, 77, 86) # 16 cat 79 | save_HICO(HICO, HICO_dir, 17, 112, 129) # 17 dog 80 | save_HICO(HICO, HICO_dir, 18, 130, 146) # 18 horse 81 | save_HICO(HICO, HICO_dir, 19, 175, 186) # 19 sheep 82 | save_HICO(HICO, HICO_dir, 20, 97, 107) # 20 cow 83 | save_HICO(HICO, HICO_dir, 21, 314, 325) # 21 elephant 84 | save_HICO(HICO, HICO_dir, 22, 236, 239) # 22 bear 85 | save_HICO(HICO, HICO_dir, 23, 596, 600) # 23 zebra 86 | save_HICO(HICO, HICO_dir, 24, 343, 348) # 24 giraffe 87 | save_HICO(HICO, HICO_dir, 25, 209, 214) # 25 backpack 88 | save_HICO(HICO, HICO_dir, 26, 577, 584) # 26 umbrella 89 | save_HICO(HICO, HICO_dir, 27, 353, 356) # 27 handbag 90 | save_HICO(HICO, HICO_dir, 28, 539, 546) # 28 tie 91 | save_HICO(HICO, HICO_dir, 29, 507, 516) # 29 suitcase 92 | save_HICO(HICO, HICO_dir, 30, 337, 342) # 30 Frisbee 93 | save_HICO(HICO, HICO_dir, 31, 464, 474) # 31 skis 94 | save_HICO(HICO, HICO_dir, 32, 475, 483) # 32 snowboard 95 | save_HICO(HICO, HICO_dir, 33, 489, 502) # 33 sports_ball 96 | save_HICO(HICO, HICO_dir, 34, 369, 376) # 34 kite 97 | save_HICO(HICO, HICO_dir, 35, 225, 232) # 35 baseball_bat 98 | save_HICO(HICO, HICO_dir, 36, 233, 235) # 36 baseball_glove 99 | save_HICO(HICO, HICO_dir, 37, 454, 463) # 37 skateboard 100 | save_HICO(HICO, HICO_dir, 38, 517, 528) # 38 surfboard 101 | save_HICO(HICO, HICO_dir, 39, 534, 538) # 39 tennis_racket 102 | save_HICO(HICO, HICO_dir, 40, 47, 54) # 40 bottle 103 | save_HICO(HICO, HICO_dir, 41, 589, 595) # 41 wine_glass 104 | save_HICO(HICO, HICO_dir, 42, 296, 305) # 42 cup 105 | save_HICO(HICO, HICO_dir, 43, 331, 336) # 43 fork 106 | save_HICO(HICO, HICO_dir, 44, 377, 383) # 44 knife 107 | save_HICO(HICO, HICO_dir, 45, 484, 488) # 45 spoon 108 | save_HICO(HICO, HICO_dir, 46, 253, 257) # 46 bowl 109 | save_HICO(HICO, HICO_dir, 47, 215, 224) # 47 banana 110 | save_HICO(HICO, HICO_dir, 48, 199, 208) # 48 apple 111 | save_HICO(HICO, HICO_dir, 49, 439, 445) # 49 sandwich 112 | save_HICO(HICO, HICO_dir, 50, 398, 407) # 50 orange 113 | save_HICO(HICO, HICO_dir, 51, 258, 264) # 51 broccoli 114 | save_HICO(HICO, HICO_dir, 52, 274, 283) # 52 carrot 115 | save_HICO(HICO, HICO_dir, 53, 357, 363) # 53 hot_dog 116 | save_HICO(HICO, HICO_dir, 54, 419, 429) # 54 pizza 117 | save_HICO(HICO, HICO_dir, 55, 306, 313) # 55 donut 118 | save_HICO(HICO, HICO_dir, 56, 265, 273) # 56 cake 119 | save_HICO(HICO, HICO_dir, 57, 87, 92) # 57 chair 120 | save_HICO(HICO, HICO_dir, 58, 93, 96) # 58 couch 121 | save_HICO(HICO, HICO_dir, 59, 171, 174) # 59 potted_plant 122 | save_HICO(HICO, HICO_dir, 60, 240, 243) # 60 bed 123 | save_HICO(HICO, HICO_dir, 61, 108, 111) # 61 dining_table 124 | save_HICO(HICO, HICO_dir, 62, 551, 558) # 62 toilet 125 | save_HICO(HICO, HICO_dir, 63, 195, 198) # 63 TV 126 | save_HICO(HICO, HICO_dir, 64, 384, 389) # 64 laptop 127 | save_HICO(HICO, HICO_dir, 65, 394, 397) # 65 mouse 128 | save_HICO(HICO, HICO_dir, 66, 435, 438) # 66 remote 129 | save_HICO(HICO, HICO_dir, 67, 364, 368) # 67 keyboard 130 | save_HICO(HICO, HICO_dir, 68, 284, 290) # 68 cell_phone 131 | save_HICO(HICO, HICO_dir, 69, 390, 393) # 69 microwave 132 | save_HICO(HICO, HICO_dir, 70, 408, 414) # 70 oven 133 | save_HICO(HICO, HICO_dir, 71, 547, 550) # 71 toaster 134 | save_HICO(HICO, HICO_dir, 72, 450, 453) # 72 sink 135 | save_HICO(HICO, HICO_dir, 73, 430, 434) # 73 refrigerator 136 | save_HICO(HICO, HICO_dir, 74, 248, 252) # 74 book 137 | save_HICO(HICO, HICO_dir, 75, 291, 295) # 75 clock 138 | 
save_HICO(HICO, HICO_dir, 76, 585, 588) # 76 vase 139 | save_HICO(HICO, HICO_dir, 77, 446, 449) # 77 scissors 140 | save_HICO(HICO, HICO_dir, 78, 529, 533) # 78 teddy_bear 141 | save_HICO(HICO, HICO_dir, 79, 349, 352) # 79 hair_drier 142 | save_HICO(HICO, HICO_dir, 80, 559, 562) # 80 toothbrush 143 | 144 | 145 | if __name__ == '__main__': 146 | Generate_HICO_detection('./hico_gen.json', './ppdm_results/') 147 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # GEN-VLKT 2 | Code for our CVPR 2022 paper "[GEN-VLKT: Simplify Association and Enhance Interaction Understanding for HOI Detection](https://arxiv.org/pdf/2203.13954)". 3 | 4 | Contributed by [Yue Liao*](https://liaoyue.net/), Aixi Zhang*, Miao Lu, Yongliang Wang, Xiaobo Li and [Si Liu](http://colalab.org/people). 5 | 6 | ![](paper_images/intro.png) 7 | 8 | ## Installation 9 | Install the dependencies. 10 | ``` 11 | pip install -r requirements.txt 12 | ``` 13 | Clone and build CLIP. 14 | ``` 15 | git clone https://github.com/openai/CLIP.git && cd CLIP && python setup.py develop && cd .. 16 | ``` 17 | ## Data preparation 18 | 19 | ### HICO-DET 20 | The HICO-DET dataset can be downloaded [here](https://drive.google.com/file/d/1dUByzVzM6z1Oq4gENa1-t0FLhr0UtDaS/view). After downloading, unpack the tarball (`hico_20160224_det.tar.gz`) into the `data` directory. 21 | 22 | Instead of using the original annotation files, we use the annotation files provided by the PPDM authors. The annotation files can be downloaded from [here](https://drive.google.com/open?id=1WI-gsNLS-t0Kh8TVki1wXqc3y2Ow1f2R). The downloaded annotation files have to be placed as follows. 23 | ``` 24 | data 25 | └─ hico_20160224_det 26 | |─ annotations 27 | | |─ trainval_hico.json 28 | | |─ test_hico.json 29 | | └─ corre_hico.npy 30 | : 31 | ``` 32 | 33 | ### V-COCO 34 | First clone the V-COCO repository from [here](https://github.com/s-gupta/v-coco), and then follow its instructions to generate the file `instances_vcoco_all_2014.json`. Next, download the prior file `prior.pickle` from [here](https://drive.google.com/drive/folders/10uuzvMUCVVv95-xAZg5KS94QXm7QXZW4). Place the files and make directories as follows. 35 | ``` 36 | GEN-VLKT 37 | |─ data 38 | │ └─ v-coco 39 | | |─ data 40 | | | |─ instances_vcoco_all_2014.json 41 | | | : 42 | | |─ prior.pickle 43 | | |─ images 44 | | | |─ train2014 45 | | | | |─ COCO_train2014_000000000009.jpg 46 | | | | : 47 | | | └─ val2014 48 | | | |─ COCO_val2014_000000000042.jpg 49 | | | : 50 | | |─ annotations 51 | : : 52 | ``` 53 | For our implementation, the annotation files have to be converted to the HOIA format. The conversion can be conducted as follows. 54 | ``` 55 | PYTHONPATH=data/v-coco \ 56 | python convert_vcoco_annotations.py \ 57 | --load_path data/v-coco/data \ 58 | --prior_path data/v-coco/prior.pickle \ 59 | --save_path data/v-coco/annotations 60 | ``` 61 | Note that only Python2 can be used for this conversion because `vsrl_utils.py` in the v-coco repository raises an error with Python3. 62 | 63 | The V-COCO annotations in the HOIA format, `corre_vcoco.npy`, `test_vcoco.json`, and `trainval_vcoco.json`, will be generated in the `annotations` directory. 64 | 65 | 66 | 67 | ## Pre-trained model 68 | Download the pretrained DETR detector for [ResNet50](https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth), put it in the `params` directory, and convert the parameters with the following commands. 
69 | ``` 70 | python ./tools/convert_parameters.py \ 71 | --load_path params/detr-r50-e632da11.pth \ 72 | --save_path params/detr-r50-pre-2branch-hico.pth \ 73 | --num_queries 64 74 | 75 | python ./tools/convert_parameters.py \ 76 | --load_path params/detr-r50-e632da11.pth \ 77 | --save_path params/detr-r50-pre-2branch-vcoco.pth \ 78 | --dataset vcoco \ 79 | --num_queries 64 80 | ``` 81 | 82 | ## Training 83 | After the preparation, you can start training with the following commands. 84 | ### HICO-DET 85 | ``` 86 | sh ./configs/hico_s.sh 87 | ``` 88 | 89 | ### V-COCO 90 | ``` 91 | sh ./configs/vcoco_s.sh 92 | ``` 93 | ### Zero-shot 94 | ``` 95 | sh ./configs/hico_s_zs_nf_uc.sh 96 | ``` 97 | 98 | ## Evaluation 99 | 100 | ### HICO-DET 101 | You can conduct the evaluation with trained parameters for HICO-DET as follows. 102 | ``` 103 | python -m torch.distributed.launch \ 104 | --nproc_per_node=8 \ 105 | --use_env \ 106 | main.py \ 107 | --pretrained pretrained/hico_gen_vlkt_s.pth \ 108 | --dataset_file hico \ 109 | --hoi_path data/hico_20160224_det \ 110 | --num_obj_classes 80 \ 111 | --num_verb_classes 117 \ 112 | --backbone resnet50 \ 113 | --num_queries 64 \ 114 | --dec_layers 3 \ 115 | --eval \ 116 | --with_clip_label \ 117 | --with_obj_clip_label \ 118 | --use_nms_filter 119 | ``` 120 | 121 | For the official evaluation (reported in the paper), you need to convert the prediction file to the official prediction format following [this file](./tools/covert_annot_for_official_eval.py), and then follow the [PPDM](https://github.com/YueLiao/PPDM) evaluation steps. 122 | ### V-COCO 123 | First, you need to add the following main function to `vsrl_eval.py` in `data/v-coco`. 124 | ``` 125 | if __name__ == '__main__': 126 | import sys 127 | 128 | vsrl_annot_file = 'data/vcoco/vcoco_test.json' 129 | coco_file = 'data/instances_vcoco_all_2014.json' 130 | split_file = 'data/splits/vcoco_test.ids' 131 | 132 | vcocoeval = VCOCOeval(vsrl_annot_file, coco_file, split_file) 133 | 134 | det_file = sys.argv[1] 135 | vcocoeval._do_eval(det_file, ovr_thresh=0.5) 136 | ``` 137 | 138 | Next, for the official evaluation of V-COCO, a pickle file of detection results has to be generated. You can generate the file and then evaluate it with the following commands.
139 | ``` 140 | python generate_vcoco_official.py \ 141 | --param_path pretrained/VCOCO_GEN_VLKT_S.pth \ 142 | --save_path vcoco.pickle \ 143 | --hoi_path data/v-coco \ 144 | --num_queries 64 \ 145 | --dec_layers 3 \ 146 | --use_nms_filter \ 147 | --with_clip_label \ 148 | --with_obj_clip_label 149 | 150 | cd data/v-coco 151 | python vsrl_eval.py vcoco.pickle 152 | 153 | ``` 154 | 155 | ### Zero-shot 
The following command evaluates the RF-UC setting; change `--zero_shot_type` and the checkpoint for the other settings. 156 | ``` 157 | python -m torch.distributed.launch \ 158 | --nproc_per_node=8 \ 159 | --use_env \ 160 | main.py \ 161 | --pretrained pretrained/hico_gen_vlkt_s.pth \ 162 | --dataset_file hico \ 163 | --hoi_path data/hico_20160224_det \ 164 | --num_obj_classes 80 \ 165 | --num_verb_classes 117 \ 166 | --backbone resnet50 \ 167 | --num_queries 64 \ 168 | --dec_layers 3 \ 169 | --eval \ 170 | --with_clip_label \ 171 | --with_obj_clip_label \ 172 | --use_nms_filter \ 173 | --zero_shot_type rare_first \ 174 | --del_unseen 175 | ``` 176 | 177 | ## Regular HOI Detection Results 178 | 179 | ### HICO-DET 180 | | | Full (D) |Rare (D)|Non-rare (D)|Full (KO)|Rare (KO)|Non-rare (KO)|Download| Config| 181 | |:-------------------|:--------:| :---: | :---: | :---: |:-------:|:-----------:| :---: | :---: | 182 | | GEN-VLKT-S (R50) | 33.75 | 29.25 |35.10 | 36.78| 32.75 | 37.99 | [model](https://drive.google.com/file/d/1dcxY41-fBZ1J_Rh_41VolliCwLa7qgk1/view?usp=sharing) | [config](./configs/hico_s.sh)| 183 | | GEN-VLKT-M* (R101) | 34.63 | 30.04| 36.01| 37.97| 33.72 | 39.24 | [model](https://drive.google.com/file/d/1rAS0gEOx2-L3qeprYal4oLgatQgSPtJJ/view?usp=sharing) | [config](./configs/hico_m.sh)| 184 | | GEN-VLKT-L (R101) | 34.95 | 31.18| 36.08 | 38.22| 34.36 | 39.37 | [model](https://drive.google.com/file/d/1wTSrpCZujg6kqHbikRrGbUafStggFXjh/view?usp=sharing) |[config](./configs/hico_l.sh) | 185 | 186 | D: Default, KO: Known object, *: The original model was lost; the provided checkpoint performs slightly differently from the results reported in the paper. 187 | 188 | 189 | ### V-COCO 190 | | | Scenario 1 | Scenario 2 | Download | Config | 191 | | :--- | :---: | :---: | :---: | :---: | 192 | |GEN-VLKT-S (R50)| 62.41| 64.46 | [model](https://drive.google.com/file/d/1bPlr1_jhRabcG9N4B8NN4vUr63q4Go-Y/view?usp=sharing) |[config](./configs/vcoco_s.sh) | 193 | |GEN-VLKT-M (R101)| 63.28| 65.58 | [model](https://drive.google.com/file/d/1q9KrHLfaDA6TGu5obqxrCxGlxq5iauAR/view?usp=sharing) |[config](./configs/vcoco_m.sh) | 194 | |GEN-VLKT-L (R101)| 63.58 |65.93 | [model](https://drive.google.com/file/d/1y_AYdF_BewWTZfDPEiSOy72e-63x_rDC/view?usp=sharing) |[config](./configs/vcoco_l.sh) | 195 | 196 | ## Zero-shot HOI Detection Results 197 | | |Type |Unseen| Seen| Full|Download| Config| 198 | | :--- | :---: | :---: | :---: | :---: | :---: | :---: | 199 | | GEN-VLKT-S|RF-UC |21.36 |32.91 |30.56| [model](https://drive.google.com/file/d/1h4NKPNfbH9ixED6X6-4oH5mGkeOJmW0M/view?usp=sharing)|[config](./configs/hico_s_zs_rf_uc.sh)| 200 | | GEN-VLKT-S|NF-UC |25.05| 23.38| 23.71| [model](https://drive.google.com/file/d/1J1UdauMnzc1cM-OOqMrwXpnJwBGj5pe6/view?usp=sharing)|[config](./configs/hico_s_zs_nf_uc.sh)| 201 | | GEN-VLKT-S|UO |10.51| 28.92| 25.63| [model](https://drive.google.com/file/d/19nEAr1IIeTryYFeVA6SY1pmmEpvw3eUD/view?usp=sharing)|[config](./configs/hico_s_zs_uo.sh)| 202 | | GEN-VLKT-S|UV|20.96| 30.23| 28.74| [model](https://drive.google.com/file/d/1lJbsoIgeluYFcBC_pnx5FdT3fOUpMwQl/view?usp=sharing)|[config](./configs/hico_s_zs_uv.sh)| 203 | ## Citation 204 | Please consider citing our paper if it helps your research.
205 | ``` 206 | @article{liao2022gen, 207 | title={GEN-VLKT: Simplify Association and Enhance Interaction Understanding for HOI Detection}, 208 | author={Liao, Yue and Zhang, Aixi and Lu, Miao and Wang, Yongliang and Li, Xiaobo and Liu, Si}, 209 | journal={arXiv preprint arXiv:2203.13954}, 210 | year={2022} 211 | } 212 | ``` 213 | 214 | ## License 215 | GEN-VLKT is released under the MIT license. See [LICENSE](./LICENSE) for additional details. 216 | 217 | ## Acknowledge 218 | Some of the codes are built upon [PPDM](https://github.com/YueLiao/PPDM), [DETR](https://github.com/facebookresearch/detr), [QPIC](https://github.com/hitachi-rd-cv/qpic) and [CDN](https://github.com/YueLiao/CDN). Thanks them for their great works! 219 | 220 | -------------------------------------------------------------------------------- /datasets/vcoco.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from PIL import Image 3 | import json 4 | import numpy as np 5 | 6 | import torch 7 | import torch.utils.data 8 | import torchvision 9 | 10 | import datasets.transforms as T 11 | import clip 12 | from .vcoco_text_label import * 13 | 14 | 15 | class VCOCO(torch.utils.data.Dataset): 16 | 17 | def __init__(self, img_set, img_folder, anno_file, transforms, num_queries, args): 18 | self.img_set = img_set 19 | self.img_folder = img_folder 20 | with open(anno_file, 'r') as f: 21 | self.annotations = json.load(f) 22 | self._transforms = transforms 23 | 24 | self.num_queries = num_queries 25 | 26 | self._valid_obj_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 27 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 28 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 29 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 30 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 31 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 32 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 33 | 82, 84, 85, 86, 87, 88, 89, 90) 34 | self._valid_verb_ids = range(29) 35 | 36 | device = "cuda" if torch.cuda.is_available() else "cpu" 37 | _, self.clip_preprocess = clip.load(args.clip_model, device) 38 | 39 | self.text_label_ids = list(vcoco_hoi_text_label.keys()) 40 | 41 | def __len__(self): 42 | return len(self.annotations) 43 | 44 | def __getitem__(self, idx): 45 | img_anno = self.annotations[idx] 46 | 47 | img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB') 48 | w, h = img.size 49 | 50 | if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries: 51 | img_anno['annotations'] = img_anno['annotations'][:self.num_queries] 52 | 53 | boxes = [obj['bbox'] for obj in img_anno['annotations']] 54 | # guard against no boxes via resizing 55 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 56 | 57 | if self.img_set == 'train': 58 | # Add index for confirming which boxes are kept after image transformation 59 | classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in enumerate(img_anno['annotations'])] 60 | else: 61 | classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']] 62 | classes = torch.tensor(classes, dtype=torch.int64) 63 | 64 | target = {} 65 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 66 | target['size'] = torch.as_tensor([int(h), int(w)]) 67 | if self.img_set == 'train': 68 | boxes[:, 0::2].clamp_(min=0, max=w) 69 | boxes[:, 1::2].clamp_(min=0, max=h) 70 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 71 | boxes = boxes[keep] 72 | classes = classes[keep] 73 | 74 | target['boxes'] = 
boxes 75 | target['labels'] = classes 76 | target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])]) 77 | target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 78 | 79 | if self._transforms is not None: 80 | img_0, target_0 = self._transforms[0](img, target) 81 | img, target = self._transforms[1](img_0, target_0) 82 | clip_inputs = self.clip_preprocess(img_0) 83 | target['clip_inputs'] = clip_inputs 84 | kept_box_indices = [label[0] for label in target['labels']] 85 | 86 | target['labels'] = target['labels'][:, 1] 87 | 88 | obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], [] 89 | sub_obj_pairs = [] 90 | hoi_labels = [] 91 | for hoi in img_anno['hoi_annotation']: 92 | if hoi['subject_id'] not in kept_box_indices or \ 93 | (hoi['object_id'] != -1 and hoi['object_id'] not in kept_box_indices): 94 | continue 95 | 96 | #if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices: 97 | # continue 98 | 99 | if hoi['object_id'] == -1: 100 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 80) 101 | else: 102 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 103 | target['labels'][kept_box_indices.index(hoi['object_id'])]) 104 | 105 | if verb_obj_pair not in self.text_label_ids: 106 | continue 107 | 108 | sub_obj_pair = (hoi['subject_id'], hoi['object_id']) 109 | if sub_obj_pair in sub_obj_pairs: 110 | verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1 111 | hoi_labels[sub_obj_pairs.index(sub_obj_pair)][self.text_label_ids.index(verb_obj_pair)] = 1 112 | else: 113 | sub_obj_pairs.append(sub_obj_pair) 114 | if hoi['object_id'] == -1: 115 | obj_labels.append(torch.tensor(len(self._valid_obj_ids))) 116 | else: 117 | obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])]) 118 | verb_label = [0 for _ in range(len(self._valid_verb_ids))] 119 | verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1 120 | hoi_label = [0] * len(self.text_label_ids) 121 | hoi_label[self.text_label_ids.index(verb_obj_pair)] = 1 122 | sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])] 123 | if hoi['object_id'] == -1: 124 | obj_box = torch.zeros((4,), dtype=torch.float32) 125 | else: 126 | obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])] 127 | verb_labels.append(verb_label) 128 | hoi_labels.append(hoi_label) 129 | sub_boxes.append(sub_box) 130 | obj_boxes.append(obj_box) 131 | 132 | target['filename'] = img_anno['file_name'] 133 | if len(sub_obj_pairs) == 0: 134 | target['obj_labels'] = torch.zeros((0,), dtype=torch.int64) 135 | target['verb_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 136 | #target['hoi_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 137 | target['hoi_labels'] = torch.zeros((0, len(self.text_label_ids)), dtype=torch.float32) 138 | target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 139 | target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 140 | else: 141 | target['obj_labels'] = torch.stack(obj_labels) 142 | target['verb_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 143 | #target['hoi_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 144 | target['hoi_labels'] = torch.as_tensor(hoi_labels, dtype=torch.float32) 145 | target['sub_boxes'] = torch.stack(sub_boxes) 146 | target['obj_boxes'] = torch.stack(obj_boxes) 147 | else: 148 | target['filename'] = img_anno['file_name'] 149 | 
target['boxes'] = boxes 150 | target['labels'] = classes 151 | target['id'] = idx 152 | target['img_id'] = int(img_anno['file_name'].rstrip('.jpg').split('_')[2]) 153 | 154 | if self._transforms is not None: 155 | img, _ = self._transforms(img, None) 156 | 157 | hois = [] 158 | for hoi in img_anno['hoi_annotation']: 159 | hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id']))) 160 | target['hois'] = torch.as_tensor(hois, dtype=torch.int64) 161 | 162 | return img, target 163 | 164 | def load_correct_mat(self, path): 165 | self.correct_mat = np.load(path) 166 | 167 | 168 | # Add color jitter to coco transforms 169 | def make_vcoco_transforms(image_set): 170 | 171 | normalize = T.Compose([ 172 | T.ToTensor(), 173 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 174 | ]) 175 | 176 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 177 | 178 | if image_set == 'train': 179 | return [T.Compose([ 180 | T.RandomHorizontalFlip(), 181 | T.ColorJitter(.4, .4, .4), 182 | T.RandomSelect( 183 | T.RandomResize(scales, max_size=1333), 184 | T.Compose([ 185 | T.RandomResize([400, 500, 600]), 186 | T.RandomSizeCrop(384, 600), 187 | T.RandomResize(scales, max_size=1333), 188 | ]))] 189 | ), 190 | normalize 191 | ] 192 | 193 | if image_set == 'val': 194 | return T.Compose([ 195 | T.RandomResize([800], max_size=1333), 196 | normalize, 197 | ]) 198 | 199 | raise ValueError(f'unknown {image_set}') 200 | 201 | 202 | def build(image_set, args): 203 | root = Path(args.hoi_path) 204 | assert root.exists(), f'provided HOI path {root} does not exist' 205 | PATHS = { 206 | 'train': (root / 'images' / 'train2014', root / 'annotations' / 'trainval_vcoco.json'), 207 | 'val': (root / 'images' / 'val2014', root / 'annotations' / 'test_vcoco.json') 208 | } 209 | CORRECT_MAT_PATH = root / 'annotations' / 'corre_vcoco.npy' 210 | 211 | img_folder, anno_file = PATHS[image_set] 212 | dataset = VCOCO(image_set, img_folder, anno_file, transforms=make_vcoco_transforms(image_set), 213 | num_queries=args.num_queries, args=args) 214 | if image_set == 'val': 215 | dataset.load_correct_mat(CORRECT_MAT_PATH) 216 | return dataset 217 | -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Modified from DETR (https://github.com/facebookresearch/detr) 3 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 4 | # ------------------------------------------------------------------------ 5 | """ 6 | Transforms and data augmentation for both image + bbox. 7 | """ 8 | import random 9 | 10 | import PIL 11 | import torch 12 | import torchvision.transforms as T 13 | import torchvision.transforms.functional as F 14 | 15 | from util.box_ops import box_xyxy_to_cxcywh 16 | from util.misc import interpolate 17 | 18 | 19 | def crop(image, target, region): 20 | cropped_image = F.crop(image, *region) 21 | 22 | target = target.copy() 23 | i, j, h, w = region 24 | 25 | # should we do something wrt the original size? 
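# region is (top, left, height, width), as returned by torchvision's RandomCrop.get_params; the crop size is recorded as (h, w) below and boxes/masks are clipped to it.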
26 | target["size"] = torch.tensor([h, w]) 27 | 28 | fields = ["labels", "area", "iscrowd"] 29 | 30 | if "boxes" in target: 31 | boxes = target["boxes"] 32 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 33 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 34 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 35 | cropped_boxes = cropped_boxes.clamp(min=0) 36 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 37 | target["boxes"] = cropped_boxes.reshape(-1, 4) 38 | target["area"] = area 39 | fields.append("boxes") 40 | 41 | if "masks" in target: 42 | # FIXME should we update the area here if there are no boxes? 43 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 44 | fields.append("masks") 45 | 46 | # remove elements for which the boxes or masks that have zero area 47 | if "boxes" in target or "masks" in target: 48 | # favor boxes selection when defining which elements to keep 49 | # this is compatible with previous implementation 50 | if "boxes" in target: 51 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 52 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 53 | else: 54 | keep = target['masks'].flatten(1).any(1) 55 | 56 | for field in fields: 57 | target[field] = target[field][keep] 58 | 59 | return cropped_image, target 60 | 61 | 62 | def hflip(image, target): 63 | flipped_image = F.hflip(image) 64 | 65 | w, h = image.size 66 | 67 | target = target.copy() 68 | if "boxes" in target: 69 | boxes = target["boxes"] 70 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 71 | target["boxes"] = boxes 72 | 73 | if "masks" in target: 74 | target['masks'] = target['masks'].flip(-1) 75 | 76 | return flipped_image, target 77 | 78 | 79 | def resize(image, target, size, max_size=None): 80 | # size can be min_size (scalar) or (w, h) tuple 81 | 82 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 83 | w, h = image_size 84 | if max_size is not None: 85 | min_original_size = float(min((w, h))) 86 | max_original_size = float(max((w, h))) 87 | if max_original_size / min_original_size * size > max_size: 88 | size = int(round(max_size * min_original_size / max_original_size)) 89 | 90 | if (w <= h and w == size) or (h <= w and h == size): 91 | return (h, w) 92 | 93 | if w < h: 94 | ow = size 95 | oh = int(size * h / w) 96 | else: 97 | oh = size 98 | ow = int(size * w / h) 99 | 100 | return (oh, ow) 101 | 102 | def get_size(image_size, size, max_size=None): 103 | if isinstance(size, (list, tuple)): 104 | return size[::-1] 105 | else: 106 | return get_size_with_aspect_ratio(image_size, size, max_size) 107 | 108 | size = get_size(image.size, size, max_size) 109 | rescaled_image = F.resize(image, size) 110 | 111 | if target is None: 112 | return rescaled_image, None 113 | 114 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 115 | ratio_width, ratio_height = ratios 116 | 117 | target = target.copy() 118 | if "boxes" in target: 119 | boxes = target["boxes"] 120 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 121 | target["boxes"] = scaled_boxes 122 | 123 | if "area" in target: 124 | area = target["area"] 125 | scaled_area = area * (ratio_width * ratio_height) 126 | target["area"] = scaled_area 127 | 128 | h, w = size 129 | target["size"] = torch.tensor([h, w]) 130 | 131 | if "masks" in target: 132 | target['masks'] = interpolate( 133 | target['masks'][:, 
None].float(), size, mode="nearest")[:, 0] > 0.5 134 | 135 | return rescaled_image, target 136 | 137 | 138 | def pad(image, target, padding): 139 | # assumes that we only pad on the bottom right corners 140 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 141 | if target is None: 142 | return padded_image, None 143 | target = target.copy() 144 | # should we do something wrt the original size? 145 | target["size"] = torch.tensor(padded_image[::-1]) 146 | if "masks" in target: 147 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 148 | return padded_image, target 149 | 150 | 151 | class RandomCrop(object): 152 | def __init__(self, size): 153 | self.size = size 154 | 155 | def __call__(self, img, target): 156 | region = T.RandomCrop.get_params(img, self.size) 157 | return crop(img, target, region) 158 | 159 | 160 | class RandomSizeCrop(object): 161 | def __init__(self, min_size: int, max_size: int): 162 | self.min_size = min_size 163 | self.max_size = max_size 164 | 165 | def __call__(self, img: PIL.Image.Image, target: dict): 166 | w = random.randint(self.min_size, min(img.width, self.max_size)) 167 | h = random.randint(self.min_size, min(img.height, self.max_size)) 168 | region = T.RandomCrop.get_params(img, [h, w]) 169 | return crop(img, target, region) 170 | 171 | 172 | class CenterCrop(object): 173 | def __init__(self, size): 174 | self.size = size 175 | 176 | def __call__(self, img, target): 177 | image_width, image_height = img.size 178 | crop_height, crop_width = self.size 179 | crop_top = int(round((image_height - crop_height) / 2.)) 180 | crop_left = int(round((image_width - crop_width) / 2.)) 181 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 182 | 183 | 184 | class RandomHorizontalFlip(object): 185 | def __init__(self, p=0.5): 186 | self.p = p 187 | 188 | def __call__(self, img, target): 189 | if random.random() < self.p: 190 | return hflip(img, target) 191 | return img, target 192 | 193 | 194 | class RandomResize(object): 195 | def __init__(self, sizes, max_size=None): 196 | assert isinstance(sizes, (list, tuple)) 197 | self.sizes = sizes 198 | self.max_size = max_size 199 | 200 | def __call__(self, img, target=None): 201 | size = random.choice(self.sizes) 202 | return resize(img, target, size, self.max_size) 203 | 204 | 205 | class RandomPad(object): 206 | def __init__(self, max_pad): 207 | self.max_pad = max_pad 208 | 209 | def __call__(self, img, target): 210 | pad_x = random.randint(0, self.max_pad) 211 | pad_y = random.randint(0, self.max_pad) 212 | return pad(img, target, (pad_x, pad_y)) 213 | 214 | 215 | class RandomSelect(object): 216 | """ 217 | Randomly selects between transforms1 and transforms2, 218 | with probability p for transforms1 and (1 - p) for transforms2 219 | """ 220 | def __init__(self, transforms1, transforms2, p=0.5): 221 | self.transforms1 = transforms1 222 | self.transforms2 = transforms2 223 | self.p = p 224 | 225 | def __call__(self, img, target): 226 | if random.random() < self.p: 227 | return self.transforms1(img, target) 228 | return self.transforms2(img, target) 229 | 230 | 231 | class ToTensor(object): 232 | def __call__(self, img, target): 233 | return F.to_tensor(img), target 234 | 235 | 236 | class RandomErasing(object): 237 | 238 | def __init__(self, *args, **kwargs): 239 | self.eraser = T.RandomErasing(*args, **kwargs) 240 | 241 | def __call__(self, img, target): 242 | return self.eraser(img), target 243 | 244 | 245 | class Normalize(object): 246 | def 
__init__(self, mean, std): 247 | self.mean = mean 248 | self.std = std 249 | 250 | def __call__(self, image, target=None): 251 | image = F.normalize(image, mean=self.mean, std=self.std) 252 | if target is None: 253 | return image, None 254 | target = target.copy() 255 | h, w = image.shape[-2:] 256 | if "boxes" in target: 257 | boxes = target["boxes"] 258 | boxes = box_xyxy_to_cxcywh(boxes) 259 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 260 | target["boxes"] = boxes 261 | return image, target 262 | 263 | 264 | class Compose(object): 265 | def __init__(self, transforms): 266 | self.transforms = transforms 267 | 268 | def __call__(self, image, target): 269 | for t in self.transforms: 270 | image, target = t(image, target) 271 | return image, target 272 | 273 | def __repr__(self): 274 | format_string = self.__class__.__name__ + "(" 275 | for t in self.transforms: 276 | format_string += "\n" 277 | format_string += " {0}".format(t) 278 | format_string += "\n)" 279 | return format_string 280 | 281 | class ColorJitter(object): 282 | def __init__(self, brightness=0, contrast=0, saturatio=0, hue=0): 283 | self.color_jitter = T.ColorJitter(brightness, contrast, saturatio, hue) 284 | 285 | def __call__(self, img, target): 286 | return self.color_jitter(img), target 287 | -------------------------------------------------------------------------------- /datasets/vcoco_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from collections import defaultdict 3 | import os, cv2, json 4 | from .vcoco_text_label import * 5 | from util.topk import top_k 6 | 7 | class VCOCOEvaluator(): 8 | 9 | def __init__(self, preds, gts, correct_mat, use_nms_filter=False): 10 | self.overlap_iou = 0.5 11 | self.max_hois = 100 12 | 13 | self.fp = defaultdict(list) 14 | self.tp = defaultdict(list) 15 | self.score = defaultdict(list) 16 | self.sum_gts = defaultdict(lambda: 0) 17 | 18 | self.verb_classes = ['hold_obj', 'stand', 'sit_instr', 'ride_instr', 'walk', 'look_obj', 'hit_instr', 'hit_obj', 19 | 'eat_obj', 'eat_instr', 'jump_instr', 'lay_instr', 'talk_on_phone_instr', 'carry_obj', 20 | 'throw_obj', 'catch_obj', 'cut_instr', 'cut_obj', 'run', 'work_on_computer_instr', 21 | 'ski_instr', 'surf_instr', 'skateboard_instr', 'smile', 'drink_instr', 'kick_obj', 22 | 'point_instr', 'read_obj', 'snowboard_instr'] 23 | self.thesis_map_indices = [0, 2, 3, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 24, 25, 27, 28] 24 | 25 | self.preds = [] 26 | self.hoi_obj_list = [] 27 | self.verb_hoi_dict = defaultdict(list) 28 | self.vcoco_triplet_labels = list(vcoco_hoi_text_label.keys()) 29 | for index, hoi_pair in enumerate(self.vcoco_triplet_labels): 30 | self.hoi_obj_list.append(hoi_pair[1]) 31 | self.verb_hoi_dict[hoi_pair[0]].append(index) 32 | 33 | self.score_mode = 1 34 | for img_preds in preds: 35 | img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items()} 36 | bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_preds['boxes'], img_preds['labels'])] 37 | if self.score_mode == 0: 38 | obj_scores = img_preds['obj_scores'] 39 | hoi_scores = img_preds['hoi_scores'] * obj_scores[:, self.hoi_obj_list] 40 | elif self.score_mode == 1: 41 | obj_scores = img_preds['obj_scores'] * img_preds['obj_scores'] 42 | hoi_scores = img_preds['hoi_scores'] + obj_scores[:, self.hoi_obj_list] 43 | else: 44 | raise 45 | 46 | verb_scores = np.zeros((hoi_scores.shape[0], len(self.verb_hoi_dict)))# 64 x 29 47 | for i in 
range(hoi_scores.shape[0]): 48 | for k,v in self.verb_hoi_dict.items(): 49 | #verb_scores[i][k] = np.sum(hoi_scores[i, v]) 50 | verb_scores[i][k] = np.max(hoi_scores[i, v]) 51 | 52 | verb_labels = np.tile(np.arange(verb_scores.shape[1]), (verb_scores.shape[0], 1)) 53 | subject_ids = np.tile(img_preds['sub_ids'], (verb_scores.shape[1], 1)).T 54 | object_ids = np.tile(img_preds['obj_ids'], (verb_scores.shape[1], 1)).T 55 | 56 | verb_scores = verb_scores.ravel() 57 | verb_labels = verb_labels.ravel() 58 | subject_ids = subject_ids.ravel() 59 | object_ids = object_ids.ravel() 60 | 61 | if len(subject_ids) > 0: 62 | object_labels = np.array([bboxes[object_id]['category_id'] for object_id in object_ids]) 63 | correct_mat = np.concatenate((correct_mat, np.ones((correct_mat.shape[0], 1))), axis=1) 64 | masks = correct_mat[verb_labels, object_labels] 65 | verb_scores *= masks 66 | 67 | hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} for 68 | subject_id, object_id, category_id, score in zip(subject_ids, object_ids, verb_labels, verb_scores)] 69 | hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 70 | hois = hois[:self.max_hois] 71 | else: 72 | hois = [] 73 | 74 | 75 | self.preds.append({ 76 | 'predictions': bboxes, 77 | 'hoi_prediction': hois 78 | }) 79 | 80 | self.gts = [] 81 | for img_gts in gts: 82 | img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id' and k != 'img_id' and k != 'filename'} 83 | self.gts.append({ 84 | 'annotations': [{'bbox': bbox, 'category_id': label} for bbox, label in zip(img_gts['boxes'], img_gts['labels'])], 85 | 'hoi_annotation': [{'subject_id': hoi[0], 'object_id': hoi[1], 'category_id': hoi[2]} for hoi in img_gts['hois']] 86 | }) 87 | for hoi in self.gts[-1]['hoi_annotation']: 88 | self.sum_gts[hoi['category_id']] += 1 89 | 90 | def evaluate(self): 91 | for img_preds, img_gts in zip(self.preds, self.gts): 92 | pred_bboxes = img_preds['predictions'] 93 | gt_bboxes = img_gts['annotations'] 94 | pred_hois = img_preds['hoi_prediction'] 95 | gt_hois = img_gts['hoi_annotation'] 96 | if len(gt_bboxes) != 0: 97 | bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes) 98 | self.compute_fptp(pred_hois, gt_hois, bbox_pairs, pred_bboxes, bbox_overlaps) 99 | else: 100 | for pred_hoi in pred_hois: 101 | self.tp[pred_hoi['category_id']].append(0) 102 | self.fp[pred_hoi['category_id']].append(1) 103 | self.score[pred_hoi['category_id']].append(pred_hoi['score']) 104 | map = self.compute_map() 105 | return map 106 | 107 | def compute_map(self): 108 | print('------------------------------------------------------------') 109 | ap = defaultdict(lambda: 0) 110 | aps = {} 111 | for category_id in sorted(list(self.sum_gts.keys())): 112 | sum_gts = self.sum_gts[category_id] 113 | if sum_gts == 0: 114 | continue 115 | 116 | tp = np.array((self.tp[category_id])) 117 | fp = np.array((self.fp[category_id])) 118 | if len(tp) == 0: 119 | ap[category_id] = 0 120 | else: 121 | score = np.array(self.score[category_id]) 122 | sort_inds = np.argsort(-score) 123 | fp = fp[sort_inds] 124 | tp = tp[sort_inds] 125 | fp = np.cumsum(fp) 126 | tp = np.cumsum(tp) 127 | rec = tp / sum_gts 128 | prec = tp / (fp + tp) 129 | ap[category_id] = self.voc_ap(rec, prec) 130 | print('{:>23s}: #GTs = {:>04d}, AP = {:>.4f}'.format(self.verb_classes[category_id], sum_gts, ap[category_id])) 131 | aps['AP_{}'.format(self.verb_classes[category_id])] = ap[category_id] 132 | 133 | m_ap_all = np.mean(list(ap.values())) 134 | 
m_ap_thesis = np.mean([ap[category_id] for category_id in self.thesis_map_indices]) 135 | 136 | print('------------------------------------------------------------') 137 | print('mAP all: {:.4f} mAP thesis: {:.4f}'.format(m_ap_all, m_ap_thesis)) 138 | print('------------------------------------------------------------') 139 | 140 | aps.update({'mAP_all': m_ap_all, 'mAP_thesis': m_ap_thesis}) 141 | 142 | return aps 143 | 144 | def voc_ap(self, rec, prec): 145 | ap = 0. 146 | for t in np.arange(0., 1.1, 0.1): 147 | if np.sum(rec >= t) == 0: 148 | p = 0 149 | else: 150 | p = np.max(prec[rec >= t]) 151 | ap = ap + p / 11. 152 | return ap 153 | 154 | def compute_fptp(self, pred_hois, gt_hois, match_pairs, pred_bboxes, bbox_overlaps): 155 | pos_pred_ids = match_pairs.keys() 156 | vis_tag = np.zeros(len(gt_hois)) 157 | pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 158 | if len(pred_hois) != 0: 159 | for pred_hoi in pred_hois: 160 | is_match = 0 161 | max_overlap = 0 162 | max_gt_hoi = 0 163 | for gt_hoi in gt_hois: 164 | if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and \ 165 | gt_hoi['object_id'] == -1: 166 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 167 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 168 | pred_category_id = pred_hoi['category_id'] 169 | if gt_hoi['subject_id'] in pred_sub_ids and pred_category_id == gt_hoi['category_id']: 170 | is_match = 1 171 | min_overlap_gt = pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])] 172 | if min_overlap_gt > max_overlap: 173 | max_overlap = min_overlap_gt 174 | max_gt_hoi = gt_hoi 175 | elif len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and \ 176 | pred_hoi['object_id'] in pos_pred_ids: 177 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 178 | pred_obj_ids = match_pairs[pred_hoi['object_id']] 179 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 180 | pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']] 181 | pred_category_id = pred_hoi['category_id'] 182 | if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids and \ 183 | pred_category_id == gt_hoi['category_id']: 184 | is_match = 1 185 | min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])], 186 | pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])]) 187 | if min_overlap_gt > max_overlap: 188 | max_overlap = min_overlap_gt 189 | max_gt_hoi = gt_hoi 190 | if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0: 191 | self.fp[pred_hoi['category_id']].append(0) 192 | self.tp[pred_hoi['category_id']].append(1) 193 | vis_tag[gt_hois.index(max_gt_hoi)] = 1 194 | else: 195 | self.fp[pred_hoi['category_id']].append(1) 196 | self.tp[pred_hoi['category_id']].append(0) 197 | self.score[pred_hoi['category_id']].append(pred_hoi['score']) 198 | 199 | def compute_iou_mat(self, bbox_list1, bbox_list2): 200 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2))) 201 | if len(bbox_list1) == 0 or len(bbox_list2) == 0: 202 | return {} 203 | for i, bbox1 in enumerate(bbox_list1): 204 | for j, bbox2 in enumerate(bbox_list2): 205 | iou_i = self.compute_IOU(bbox1, bbox2) 206 | iou_mat[i, j] = iou_i 207 | 208 | iou_mat_ov=iou_mat.copy() 209 | iou_mat[iou_mat>=self.overlap_iou] = 1 210 | iou_mat[iou_mat 0: 216 | for i, pred_id in enumerate(match_pairs[1]): 217 | if pred_id not in match_pairs_dict.keys(): 218 | match_pairs_dict[pred_id] = [] 219 | match_pair_overlaps[pred_id]=[] 220 | match_pairs_dict[pred_id].append(match_pairs[0][i]) 221 | 
match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i],pred_id]) 222 | return match_pairs_dict, match_pair_overlaps 223 | 224 | def compute_IOU(self, bbox1, bbox2): 225 | if isinstance(bbox1['category_id'], str): 226 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', '')) 227 | if isinstance(bbox2['category_id'], str): 228 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', '')) 229 | if bbox1['category_id'] == bbox2['category_id']: 230 | rec1 = bbox1['bbox'] 231 | rec2 = bbox2['bbox'] 232 | # computing area of each rectangles 233 | S_rec1 = (rec1[2] - rec1[0]+1) * (rec1[3] - rec1[1]+1) 234 | S_rec2 = (rec2[2] - rec2[0]+1) * (rec2[3] - rec2[1]+1) 235 | 236 | # computing the sum_area 237 | sum_area = S_rec1 + S_rec2 238 | 239 | # find the each edge of intersect rectangle 240 | left_line = max(rec1[1], rec2[1]) 241 | right_line = min(rec1[3], rec2[3]) 242 | top_line = max(rec1[0], rec2[0]) 243 | bottom_line = min(rec1[2], rec2[2]) 244 | # judge if there is an intersect 245 | if left_line >= right_line or top_line >= bottom_line: 246 | return 0 247 | else: 248 | intersect = (right_line - left_line+1) * (bottom_line - top_line+1) 249 | return intersect / (sum_area - intersect) 250 | else: 251 | return 0 252 | -------------------------------------------------------------------------------- /datasets/hico.py: -------------------------------------------------------------------------------- 1 | """ 2 | HICO detection dataset. 3 | """ 4 | from pathlib import Path 5 | 6 | import torchvision.transforms 7 | from PIL import Image 8 | import json 9 | from collections import defaultdict 10 | import numpy as np 11 | 12 | import torch 13 | import torch.utils.data 14 | import clip 15 | 16 | import datasets.transforms as T 17 | from .hico_text_label import hico_text_label, hico_unseen_index 18 | 19 | 20 | 21 | class HICODetection(torch.utils.data.Dataset): 22 | def __init__(self, img_set, img_folder, anno_file, clip_feats_folder, transforms, num_queries, args): 23 | self.img_set = img_set 24 | self.img_folder = img_folder 25 | self.clip_feates_folder = clip_feats_folder 26 | with open(anno_file, 'r') as f: 27 | self.annotations = json.load(f) 28 | self._transforms = transforms 29 | 30 | self.num_queries = num_queries 31 | 32 | self.unseen_index = hico_unseen_index.get(args.zero_shot_type, []) 33 | self._valid_obj_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 34 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 35 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 36 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 37 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 38 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 39 | 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 40 | 82, 84, 85, 86, 87, 88, 89, 90) 41 | self._valid_verb_ids = list(range(1, 118)) 42 | 43 | self.text_label_dict = hico_text_label 44 | self.text_label_ids = list(self.text_label_dict.keys()) 45 | if img_set == 'train' and len(self.unseen_index) != 0 and args.del_unseen: 46 | tmp = [] 47 | for idx, k in enumerate(self.text_label_ids): 48 | if idx in self.unseen_index: 49 | continue 50 | else: 51 | tmp.append(k) 52 | self.text_label_ids = tmp 53 | 54 | if img_set == 'train': 55 | self.ids = [] 56 | for idx, img_anno in enumerate(self.annotations): 57 | new_img_anno = [] 58 | skip_pair = [] 59 | for hoi in img_anno['hoi_annotation']: 60 | if hoi['hoi_category_id'] - 1 in self.unseen_index: 61 | skip_pair.append((hoi['subject_id'], hoi['object_id'])) 62 | for hoi in img_anno['hoi_annotation']: 63 | if hoi['subject_id'] >= len(img_anno['annotations']) or 
hoi['object_id'] >= len( 64 | img_anno['annotations']): 65 | new_img_anno = [] 66 | break 67 | if (hoi['subject_id'], hoi['object_id']) not in skip_pair: 68 | new_img_anno.append(hoi) 69 | if len(new_img_anno) > 0: 70 | self.ids.append(idx) 71 | img_anno['hoi_annotation'] = new_img_anno 72 | else: 73 | self.ids = list(range(len(self.annotations))) 74 | print("{} contains {} images".format(img_set, len(self.ids))) 75 | 76 | device = "cuda" if torch.cuda.is_available() else "cpu" 77 | _, self.clip_preprocess = clip.load(args.clip_model, device) 78 | 79 | def __len__(self): 80 | return len(self.ids) 81 | 82 | def __getitem__(self, idx): 83 | img_anno = self.annotations[self.ids[idx]] 84 | 85 | img = Image.open(self.img_folder / img_anno['file_name']).convert('RGB') 86 | w, h = img.size 87 | 88 | if self.img_set == 'train' and len(img_anno['annotations']) > self.num_queries: 89 | img_anno['annotations'] = img_anno['annotations'][:self.num_queries] 90 | 91 | boxes = [obj['bbox'] for obj in img_anno['annotations']] 92 | # guard against no boxes via resizing 93 | boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) 94 | 95 | if self.img_set == 'train': 96 | # Add index for confirming which boxes are kept after image transformation 97 | classes = [(i, self._valid_obj_ids.index(obj['category_id'])) for i, obj in 98 | enumerate(img_anno['annotations'])] 99 | else: 100 | classes = [self._valid_obj_ids.index(obj['category_id']) for obj in img_anno['annotations']] 101 | classes = torch.tensor(classes, dtype=torch.int64) 102 | 103 | target = {} 104 | target['orig_size'] = torch.as_tensor([int(h), int(w)]) 105 | target['size'] = torch.as_tensor([int(h), int(w)]) 106 | if self.img_set == 'train': 107 | boxes[:, 0::2].clamp_(min=0, max=w) 108 | boxes[:, 1::2].clamp_(min=0, max=h) 109 | keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) 110 | boxes = boxes[keep] 111 | classes = classes[keep] 112 | 113 | target['boxes'] = boxes 114 | target['labels'] = classes 115 | target['iscrowd'] = torch.tensor([0 for _ in range(boxes.shape[0])]) 116 | target['area'] = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) 117 | 118 | if self._transforms is not None: 119 | img_0, target_0 = self._transforms[0](img, target) 120 | img, target = self._transforms[1](img_0, target_0) 121 | clip_inputs = self.clip_preprocess(img_0) 122 | target['clip_inputs'] = clip_inputs 123 | kept_box_indices = [label[0] for label in target['labels']] 124 | 125 | target['labels'] = target['labels'][:, 1] 126 | 127 | obj_labels, verb_labels, sub_boxes, obj_boxes = [], [], [], [] 128 | sub_obj_pairs = [] 129 | hoi_labels = [] 130 | for hoi in img_anno['hoi_annotation']: 131 | # print('hoi: ', hoi) 132 | if hoi['subject_id'] not in kept_box_indices or hoi['object_id'] not in kept_box_indices: 133 | continue 134 | verb_obj_pair = (self._valid_verb_ids.index(hoi['category_id']), 135 | target['labels'][kept_box_indices.index(hoi['object_id'])]) 136 | if verb_obj_pair not in self.text_label_ids: 137 | continue 138 | 139 | sub_obj_pair = (hoi['subject_id'], hoi['object_id']) 140 | if sub_obj_pair in sub_obj_pairs: 141 | verb_labels[sub_obj_pairs.index(sub_obj_pair)][self._valid_verb_ids.index(hoi['category_id'])] = 1 142 | hoi_labels[sub_obj_pairs.index(sub_obj_pair)][self.text_label_ids.index(verb_obj_pair)] = 1 143 | else: 144 | sub_obj_pairs.append(sub_obj_pair) 145 | obj_labels.append(target['labels'][kept_box_indices.index(hoi['object_id'])]) 146 | verb_label = [0 for _ in range(len(self._valid_verb_ids))] 147 | 
hoi_label = [0] * len(self.text_label_ids) 148 | hoi_label[self.text_label_ids.index(verb_obj_pair)] = 1 149 | verb_label[self._valid_verb_ids.index(hoi['category_id'])] = 1 150 | sub_box = target['boxes'][kept_box_indices.index(hoi['subject_id'])] 151 | obj_box = target['boxes'][kept_box_indices.index(hoi['object_id'])] 152 | verb_labels.append(verb_label) 153 | hoi_labels.append(hoi_label) 154 | sub_boxes.append(sub_box) 155 | obj_boxes.append(obj_box) 156 | 157 | target['filename'] = img_anno['file_name'] 158 | # print('sub_obj_pairs: ', sub_obj_pairs) 159 | if len(sub_obj_pairs) == 0: 160 | target['obj_labels'] = torch.zeros((0,), dtype=torch.int64) 161 | target['verb_labels'] = torch.zeros((0, len(self._valid_verb_ids)), dtype=torch.float32) 162 | target['hoi_labels'] = torch.zeros((0, len(self.text_label_ids)), dtype=torch.float32) 163 | target['sub_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 164 | target['obj_boxes'] = torch.zeros((0, 4), dtype=torch.float32) 165 | else: 166 | target['obj_labels'] = torch.stack(obj_labels) 167 | target['verb_labels'] = torch.as_tensor(verb_labels, dtype=torch.float32) 168 | target['hoi_labels'] = torch.as_tensor(hoi_labels, dtype=torch.float32) 169 | target['sub_boxes'] = torch.stack(sub_boxes) 170 | target['obj_boxes'] = torch.stack(obj_boxes) 171 | else: 172 | target['filename'] = img_anno['file_name'] 173 | target['boxes'] = boxes 174 | target['labels'] = classes 175 | target['id'] = idx 176 | 177 | if self._transforms is not None: 178 | img, _ = self._transforms(img, None) 179 | 180 | hois = [] 181 | for hoi in img_anno['hoi_annotation']: 182 | hois.append((hoi['subject_id'], hoi['object_id'], self._valid_verb_ids.index(hoi['category_id']))) 183 | target['hois'] = torch.as_tensor(hois, dtype=torch.int64) 184 | 185 | 186 | return img, target 187 | 188 | def set_rare_hois(self, anno_file): 189 | with open(anno_file, 'r') as f: 190 | annotations = json.load(f) 191 | 192 | if len(self.unseen_index) == 0: 193 | # no unseen categoruy, use rare to evaluate 194 | counts = defaultdict(lambda: 0) 195 | for img_anno in annotations: 196 | hois = img_anno['hoi_annotation'] 197 | bboxes = img_anno['annotations'] 198 | for hoi in hois: 199 | triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']), 200 | self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']), 201 | self._valid_verb_ids.index(hoi['category_id'])) 202 | counts[triplet] += 1 203 | self.rare_triplets = [] 204 | self.non_rare_triplets = [] 205 | for triplet, count in counts.items(): 206 | if count < 10: 207 | self.rare_triplets.append(triplet) 208 | else: 209 | self.non_rare_triplets.append(triplet) 210 | print("rare:{}, non-rare:{}".format(len(self.rare_triplets), len(self.non_rare_triplets))) 211 | else: 212 | self.rare_triplets = [] 213 | self.non_rare_triplets = [] 214 | for img_anno in annotations: 215 | hois = img_anno['hoi_annotation'] 216 | bboxes = img_anno['annotations'] 217 | for hoi in hois: 218 | triplet = (self._valid_obj_ids.index(bboxes[hoi['subject_id']]['category_id']), 219 | self._valid_obj_ids.index(bboxes[hoi['object_id']]['category_id']), 220 | self._valid_verb_ids.index(hoi['category_id'])) 221 | if hoi['hoi_category_id'] - 1 in self.unseen_index: 222 | self.rare_triplets.append(triplet) 223 | else: 224 | self.non_rare_triplets.append(triplet) 225 | print("unseen:{}, seen:{}".format(len(self.rare_triplets), len(self.non_rare_triplets))) 226 | 227 | def load_correct_mat(self, path): 228 | self.correct_mat = np.load(path) 229 | 230 | 231 
| # Add color jitter to coco transforms 232 | def make_hico_transforms(image_set): 233 | normalize = T.Compose([ 234 | T.ToTensor(), 235 | T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 236 | ]) 237 | 238 | scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 239 | 240 | if image_set == 'train': 241 | return [T.Compose([ 242 | T.RandomHorizontalFlip(), 243 | T.ColorJitter(.4, .4, .4), 244 | T.RandomSelect( 245 | T.RandomResize(scales, max_size=1333), 246 | T.Compose([ 247 | T.RandomResize([400, 500, 600]), 248 | T.RandomSizeCrop(384, 600), 249 | T.RandomResize(scales, max_size=1333), 250 | ]))] 251 | ), 252 | normalize 253 | ] 254 | 255 | if image_set == 'val': 256 | return T.Compose([ 257 | T.RandomResize([800], max_size=1333), 258 | normalize, 259 | ]) 260 | 261 | raise ValueError(f'unknown {image_set}') 262 | 263 | 264 | def build(image_set, args): 265 | root = Path(args.hoi_path) 266 | assert root.exists(), f'provided HOI path {root} does not exist' 267 | PATHS = { 268 | 'train': (root / 'images' / 'train2015', root / 'annotations' / 'trainval_hico.json', 269 | root / 'clip_feats_pool' / 'train2015'), 270 | 'val': ( 271 | root / 'images' / 'test2015', root / 'annotations' / 'test_hico.json', 272 | root / 'clip_feats_pool' / 'test2015') 273 | } 274 | CORRECT_MAT_PATH = root / 'annotations' / 'corre_hico.npy' 275 | 276 | img_folder, anno_file, clip_feats_folder = PATHS[image_set] 277 | dataset = HICODetection(image_set, img_folder, anno_file, clip_feats_folder, 278 | transforms=make_hico_transforms(image_set), 279 | num_queries=args.num_queries, args=args) 280 | if image_set == 'val': 281 | dataset.set_rare_hois(PATHS['train'][1]) 282 | dataset.load_correct_mat(CORRECT_MAT_PATH) 283 | return dataset 284 | -------------------------------------------------------------------------------- /models/gen.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | 8 | 9 | class GEN(nn.Module): 10 | 11 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 12 | num_dec_layers=3, dim_feedforward=2048, dropout=0.1, 13 | activation="relu", normalize_before=False, 14 | return_intermediate_dec=False): 15 | super().__init__() 16 | 17 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 18 | dropout, activation, normalize_before) 19 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 20 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 21 | 22 | instance_decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 23 | dropout, activation, normalize_before) 24 | instance_decoder_norm = nn.LayerNorm(d_model) 25 | self.instance_decoder = TransformerDecoder(instance_decoder_layer, 26 | num_dec_layers, 27 | instance_decoder_norm, 28 | return_intermediate=return_intermediate_dec) 29 | 30 | interaction_decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 31 | dropout, activation, normalize_before) 32 | interaction_decoder_norm = nn.LayerNorm(d_model) 33 | self.interaction_decoder = TransformerDecoder(interaction_decoder_layer, 34 | num_dec_layers, 35 | interaction_decoder_norm, 36 | return_intermediate=return_intermediate_dec) 37 | 38 | self._reset_parameters() 39 | 40 | self.d_model = d_model 41 | self.nhead = nhead 42 | 43 | def _reset_parameters(self): 44 | for p in self.parameters(): 45 | if p.dim() 
> 1: 46 | nn.init.xavier_uniform_(p) 47 | 48 | def forward(self, src, mask, query_embed_h, query_embed_o, pos_guided_embed, pos_embed): 49 | # flatten NxCxHxW to HWxNxC 50 | bs, c, h, w = src.shape 51 | src = src.flatten(2).permute(2, 0, 1) 52 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 53 | num_queries = query_embed_h.shape[0] 54 | 55 | query_embed_o = query_embed_o + pos_guided_embed 56 | query_embed_h = query_embed_h + pos_guided_embed 57 | query_embed_o = query_embed_o.unsqueeze(1).repeat(1, bs, 1) 58 | query_embed_h = query_embed_h.unsqueeze(1).repeat(1, bs, 1) 59 | ins_query_embed = torch.cat((query_embed_h, query_embed_o), dim=0) 60 | 61 | mask = mask.flatten(1) 62 | ins_tgt = torch.zeros_like(ins_query_embed) 63 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 64 | 65 | ins_hs = self.instance_decoder(ins_tgt, memory, memory_key_padding_mask=mask, 66 | pos=pos_embed, query_pos=ins_query_embed) 67 | ins_hs = ins_hs.transpose(1, 2) 68 | h_hs = ins_hs[:, :, :num_queries, :] 69 | o_hs = ins_hs[:, :, num_queries:, :] 70 | 71 | # add 72 | ins_guided_embed = (h_hs + o_hs) / 2.0 73 | ins_guided_embed = ins_guided_embed.permute(0, 2, 1, 3) 74 | 75 | inter_tgt = torch.zeros_like(ins_guided_embed[0]) 76 | inter_hs = self.interaction_decoder(inter_tgt, memory, memory_key_padding_mask=mask, 77 | pos=pos_embed, query_pos=ins_guided_embed) 78 | inter_hs = inter_hs.transpose(1, 2) 79 | 80 | return h_hs, o_hs, inter_hs, memory.permute(1, 2, 0).view(bs, c, h, w) 81 | 82 | 83 | class TransformerEncoder(nn.Module): 84 | 85 | def __init__(self, encoder_layer, num_layers, norm=None): 86 | super().__init__() 87 | self.layers = _get_clones(encoder_layer, num_layers) 88 | self.num_layers = num_layers 89 | self.norm = norm 90 | 91 | def forward(self, src, 92 | mask: Optional[Tensor] = None, 93 | src_key_padding_mask: Optional[Tensor] = None, 94 | pos: Optional[Tensor] = None): 95 | output = src 96 | 97 | for layer in self.layers: 98 | output = layer(output, src_mask=mask, 99 | src_key_padding_mask=src_key_padding_mask, pos=pos) 100 | 101 | if self.norm is not None: 102 | output = self.norm(output) 103 | 104 | return output 105 | 106 | 107 | class TransformerDecoder(nn.Module): 108 | 109 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 110 | super().__init__() 111 | self.layers = _get_clones(decoder_layer, num_layers) 112 | self.num_layers = num_layers 113 | self.norm = norm 114 | self.return_intermediate = return_intermediate 115 | 116 | def forward(self, tgt, memory, 117 | tgt_mask: Optional[Tensor] = None, 118 | memory_mask: Optional[Tensor] = None, 119 | tgt_key_padding_mask: Optional[Tensor] = None, 120 | memory_key_padding_mask: Optional[Tensor] = None, 121 | pos: Optional[Tensor] = None, 122 | query_pos: Optional[Tensor] = None): 123 | output = tgt 124 | 125 | intermediate = [] 126 | 127 | for i, layer in enumerate(self.layers): 128 | if len(query_pos.shape) == 4: 129 | this_query_pos = query_pos[i] 130 | else: 131 | this_query_pos = query_pos 132 | output = layer(output, memory, tgt_mask=tgt_mask, 133 | memory_mask=memory_mask, 134 | tgt_key_padding_mask=tgt_key_padding_mask, 135 | memory_key_padding_mask=memory_key_padding_mask, 136 | pos=pos, query_pos=this_query_pos) 137 | if self.return_intermediate: 138 | intermediate.append(self.norm(output)) 139 | 140 | if self.norm is not None: 141 | output = self.norm(output) 142 | if self.return_intermediate: 143 | intermediate.pop() 144 | intermediate.append(output) 145 | 146 | if 
self.return_intermediate: 147 | return torch.stack(intermediate) 148 | 149 | return output 150 | 151 | 152 | class TransformerEncoderLayer(nn.Module): 153 | 154 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 155 | activation="relu", normalize_before=False): 156 | super().__init__() 157 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 158 | # Implementation of Feedforward model 159 | self.linear1 = nn.Linear(d_model, dim_feedforward) 160 | self.dropout = nn.Dropout(dropout) 161 | self.linear2 = nn.Linear(dim_feedforward, d_model) 162 | 163 | self.norm1 = nn.LayerNorm(d_model) 164 | self.norm2 = nn.LayerNorm(d_model) 165 | self.dropout1 = nn.Dropout(dropout) 166 | self.dropout2 = nn.Dropout(dropout) 167 | 168 | self.activation = _get_activation_fn(activation) 169 | self.normalize_before = normalize_before 170 | 171 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 172 | return tensor if pos is None else tensor + pos 173 | 174 | def forward_post(self, 175 | src, 176 | src_mask: Optional[Tensor] = None, 177 | src_key_padding_mask: Optional[Tensor] = None, 178 | pos: Optional[Tensor] = None): 179 | q = k = self.with_pos_embed(src, pos) 180 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 181 | key_padding_mask=src_key_padding_mask)[0] 182 | src = src + self.dropout1(src2) 183 | src = self.norm1(src) 184 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 185 | src = src + self.dropout2(src2) 186 | src = self.norm2(src) 187 | return src 188 | 189 | def forward_pre(self, src, 190 | src_mask: Optional[Tensor] = None, 191 | src_key_padding_mask: Optional[Tensor] = None, 192 | pos: Optional[Tensor] = None): 193 | src2 = self.norm1(src) 194 | q = k = self.with_pos_embed(src2, pos) 195 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 196 | key_padding_mask=src_key_padding_mask)[0] 197 | src = src + self.dropout1(src2) 198 | src2 = self.norm2(src) 199 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 200 | src = src + self.dropout2(src2) 201 | return src 202 | 203 | def forward(self, src, 204 | src_mask: Optional[Tensor] = None, 205 | src_key_padding_mask: Optional[Tensor] = None, 206 | pos: Optional[Tensor] = None): 207 | if self.normalize_before: 208 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 209 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 210 | 211 | 212 | class TransformerDecoderLayer(nn.Module): 213 | 214 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 215 | activation="relu", normalize_before=False): 216 | super().__init__() 217 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 218 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 219 | # Implementation of Feedforward model 220 | self.linear1 = nn.Linear(d_model, dim_feedforward) 221 | self.dropout = nn.Dropout(dropout) 222 | self.linear2 = nn.Linear(dim_feedforward, d_model) 223 | 224 | self.norm1 = nn.LayerNorm(d_model) 225 | self.norm2 = nn.LayerNorm(d_model) 226 | self.norm3 = nn.LayerNorm(d_model) 227 | self.dropout1 = nn.Dropout(dropout) 228 | self.dropout2 = nn.Dropout(dropout) 229 | self.dropout3 = nn.Dropout(dropout) 230 | 231 | self.activation = _get_activation_fn(activation) 232 | self.normalize_before = normalize_before 233 | 234 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 235 | return tensor if pos is None else tensor + pos 236 | 237 | def forward_post(self, tgt, 
memory, 238 | tgt_mask: Optional[Tensor] = None, 239 | memory_mask: Optional[Tensor] = None, 240 | tgt_key_padding_mask: Optional[Tensor] = None, 241 | memory_key_padding_mask: Optional[Tensor] = None, 242 | pos: Optional[Tensor] = None, 243 | query_pos: Optional[Tensor] = None): 244 | q = k = self.with_pos_embed(tgt, query_pos) 245 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 246 | key_padding_mask=tgt_key_padding_mask)[0] 247 | tgt = tgt + self.dropout1(tgt2) 248 | tgt = self.norm1(tgt) 249 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 250 | key=self.with_pos_embed(memory, pos), 251 | value=memory, attn_mask=memory_mask, 252 | key_padding_mask=memory_key_padding_mask)[0] 253 | tgt = tgt + self.dropout2(tgt2) 254 | tgt = self.norm2(tgt) 255 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 256 | tgt = tgt + self.dropout3(tgt2) 257 | tgt = self.norm3(tgt) 258 | return tgt 259 | 260 | def forward_pre(self, tgt, memory, 261 | tgt_mask: Optional[Tensor] = None, 262 | memory_mask: Optional[Tensor] = None, 263 | tgt_key_padding_mask: Optional[Tensor] = None, 264 | memory_key_padding_mask: Optional[Tensor] = None, 265 | pos: Optional[Tensor] = None, 266 | query_pos: Optional[Tensor] = None): 267 | tgt2 = self.norm1(tgt) 268 | q = k = self.with_pos_embed(tgt2, query_pos) 269 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 270 | key_padding_mask=tgt_key_padding_mask)[0] 271 | tgt = tgt + self.dropout1(tgt2) 272 | tgt2 = self.norm2(tgt) 273 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 274 | key=self.with_pos_embed(memory, pos), 275 | value=memory, attn_mask=memory_mask, 276 | key_padding_mask=memory_key_padding_mask)[0] 277 | tgt = tgt + self.dropout2(tgt2) 278 | tgt2 = self.norm3(tgt) 279 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 280 | tgt = tgt + self.dropout3(tgt2) 281 | return tgt 282 | 283 | def forward(self, tgt, memory, 284 | tgt_mask: Optional[Tensor] = None, 285 | memory_mask: Optional[Tensor] = None, 286 | tgt_key_padding_mask: Optional[Tensor] = None, 287 | memory_key_padding_mask: Optional[Tensor] = None, 288 | pos: Optional[Tensor] = None, 289 | query_pos: Optional[Tensor] = None): 290 | if self.normalize_before: 291 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 292 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 293 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 294 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 295 | 296 | 297 | def _get_clones(module, N): 298 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 299 | 300 | 301 | def build_gen(args): 302 | return GEN( 303 | d_model=args.hidden_dim, 304 | dropout=args.dropout, 305 | nhead=args.nheads, 306 | dim_feedforward=args.dim_feedforward, 307 | num_encoder_layers=args.enc_layers, 308 | num_dec_layers=args.dec_layers, 309 | normalize_before=args.pre_norm, 310 | return_intermediate_dec=True, 311 | ) 312 | 313 | 314 | def _get_activation_fn(activation): 315 | """Return an activation function given a string""" 316 | if activation == "relu": 317 | return F.relu 318 | if activation == "gelu": 319 | return F.gelu 320 | if activation == "glu": 321 | return F.glu 322 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 323 | -------------------------------------------------------------------------------- /util/misc.py: -------------------------------------------------------------------------------- 1 | 
# ------------------------------------------------------------------------ 2 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | # Modified from DETR (https://github.com/facebookresearch/detr) 6 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 7 | # ------------------------------------------------------------------------ 8 | """ 9 | Misc functions, including distributed helpers. 10 | 11 | Mostly copy-paste from torchvision references. 12 | """ 13 | import os 14 | import subprocess 15 | import time 16 | from collections import defaultdict, deque 17 | import datetime 18 | import pickle 19 | from typing import Optional, List 20 | 21 | import torch 22 | import torch.distributed as dist 23 | from torch import Tensor 24 | 25 | # needed due to empty tensor bug in pytorch and torchvision 0.5 26 | import torchvision 27 | if float(torchvision.__version__[:3]) < 0.7: 28 | from torchvision.ops import _new_empty_tensor 29 | from torchvision.ops.misc import _output_size 30 | 31 | 32 | class SmoothedValue(object): 33 | """Track a series of values and provide access to smoothed values over a 34 | window or the global series average. 35 | """ 36 | 37 | def __init__(self, window_size=20, fmt=None): 38 | if fmt is None: 39 | fmt = "{median:.4f} ({global_avg:.4f})" 40 | self.deque = deque(maxlen=window_size) 41 | self.total = 0.0 42 | self.count = 0 43 | self.fmt = fmt 44 | 45 | def update(self, value, n=1): 46 | self.deque.append(value) 47 | self.count += n 48 | self.total += value * n 49 | 50 | def synchronize_between_processes(self): 51 | """ 52 | Warning: does not synchronize the deque! 
53 | """ 54 | if not is_dist_avail_and_initialized(): 55 | return 56 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 57 | dist.barrier() 58 | dist.all_reduce(t) 59 | t = t.tolist() 60 | self.count = int(t[0]) 61 | self.total = t[1] 62 | 63 | @property 64 | def median(self): 65 | d = torch.tensor(list(self.deque)) 66 | return d.median().item() 67 | 68 | @property 69 | def avg(self): 70 | d = torch.tensor(list(self.deque), dtype=torch.float32) 71 | return d.mean().item() 72 | 73 | @property 74 | def global_avg(self): 75 | return self.total / self.count 76 | 77 | @property 78 | def max(self): 79 | return max(self.deque) 80 | 81 | @property 82 | def value(self): 83 | return self.deque[-1] 84 | 85 | def __str__(self): 86 | return self.fmt.format( 87 | median=self.median, 88 | avg=self.avg, 89 | global_avg=self.global_avg, 90 | max=self.max, 91 | value=self.value) 92 | 93 | 94 | def all_gather(data): 95 | """ 96 | Run all_gather on arbitrary picklable data (not necessarily tensors) 97 | Args: 98 | data: any picklable object 99 | Returns: 100 | list[data]: list of data gathered from each rank 101 | """ 102 | world_size = get_world_size() 103 | if world_size == 1: 104 | return [data] 105 | 106 | # serialized to a Tensor 107 | buffer = pickle.dumps(data) 108 | storage = torch.ByteStorage.from_buffer(buffer) 109 | tensor = torch.ByteTensor(storage).to("cuda") 110 | 111 | # obtain Tensor size of each rank 112 | local_size = torch.tensor([tensor.numel()], device="cuda") 113 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 114 | dist.all_gather(size_list, local_size) 115 | size_list = [int(size.item()) for size in size_list] 116 | max_size = max(size_list) 117 | 118 | # receiving Tensor from all ranks 119 | # we pad the tensor because torch all_gather does not support 120 | # gathering tensors of different shapes 121 | tensor_list = [] 122 | for _ in size_list: 123 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 124 | if local_size != max_size: 125 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 126 | tensor = torch.cat((tensor, padding), dim=0) 127 | dist.all_gather(tensor_list, tensor) 128 | 129 | data_list = [] 130 | for size, tensor in zip(size_list, tensor_list): 131 | buffer = tensor.cpu().numpy().tobytes()[:size] 132 | data_list.append(pickle.loads(buffer)) 133 | 134 | return data_list 135 | 136 | 137 | def reduce_dict(input_dict, average=True): 138 | """ 139 | Args: 140 | input_dict (dict): all the values will be reduced 141 | average (bool): whether to do average or sum 142 | Reduce the values in the dictionary from all processes so that all processes 143 | have the averaged results. Returns a dict with the same fields as 144 | input_dict, after reduction. 
145 | """ 146 | world_size = get_world_size() 147 | if world_size < 2: 148 | return input_dict 149 | with torch.no_grad(): 150 | names = [] 151 | values = [] 152 | # sort the keys so that they are consistent across processes 153 | for k in sorted(input_dict.keys()): 154 | names.append(k) 155 | values.append(input_dict[k]) 156 | values = torch.stack(values, dim=0) 157 | dist.all_reduce(values) 158 | if average: 159 | values /= world_size 160 | reduced_dict = {k: v for k, v in zip(names, values)} 161 | return reduced_dict 162 | 163 | 164 | class MetricLogger(object): 165 | def __init__(self, delimiter="\t"): 166 | self.meters = defaultdict(SmoothedValue) 167 | self.delimiter = delimiter 168 | 169 | def update(self, **kwargs): 170 | for k, v in kwargs.items(): 171 | if isinstance(v, torch.Tensor): 172 | v = v.item() 173 | assert isinstance(v, (float, int)) 174 | self.meters[k].update(v) 175 | 176 | def __getattr__(self, attr): 177 | if attr in self.meters: 178 | return self.meters[attr] 179 | if attr in self.__dict__: 180 | return self.__dict__[attr] 181 | raise AttributeError("'{}' object has no attribute '{}'".format( 182 | type(self).__name__, attr)) 183 | 184 | def __str__(self): 185 | loss_str = [] 186 | for name, meter in self.meters.items(): 187 | loss_str.append( 188 | "{}: {}".format(name, str(meter)) 189 | ) 190 | return self.delimiter.join(loss_str) 191 | 192 | def synchronize_between_processes(self): 193 | for meter in self.meters.values(): 194 | meter.synchronize_between_processes() 195 | 196 | def add_meter(self, name, meter): 197 | self.meters[name] = meter 198 | 199 | def log_every(self, iterable, print_freq, header=None): 200 | i = 0 201 | if not header: 202 | header = '' 203 | start_time = time.time() 204 | end = time.time() 205 | iter_time = SmoothedValue(fmt='{avg:.4f}') 206 | data_time = SmoothedValue(fmt='{avg:.4f}') 207 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 208 | if torch.cuda.is_available(): 209 | log_msg = self.delimiter.join([ 210 | header, 211 | '[{0' + space_fmt + '}/{1}]', 212 | 'eta: {eta}', 213 | '{meters}', 214 | 'time: {time}', 215 | 'data: {data}', 216 | 'max mem: {memory:.0f}' 217 | ]) 218 | else: 219 | log_msg = self.delimiter.join([ 220 | header, 221 | '[{0' + space_fmt + '}/{1}]', 222 | 'eta: {eta}', 223 | '{meters}', 224 | 'time: {time}', 225 | 'data: {data}' 226 | ]) 227 | MB = 1024.0 * 1024.0 228 | for obj in iterable: 229 | data_time.update(time.time() - end) 230 | yield obj 231 | iter_time.update(time.time() - end) 232 | if i % print_freq == 0 or i == len(iterable) - 1: 233 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 234 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 235 | if torch.cuda.is_available(): 236 | print(log_msg.format( 237 | i, len(iterable), eta=eta_string, 238 | meters=str(self), 239 | time=str(iter_time), data=str(data_time), 240 | memory=torch.cuda.max_memory_allocated() / MB)) 241 | else: 242 | print(log_msg.format( 243 | i, len(iterable), eta=eta_string, 244 | meters=str(self), 245 | time=str(iter_time), data=str(data_time))) 246 | i += 1 247 | end = time.time() 248 | total_time = time.time() - start_time 249 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 250 | print('{} Total time: {} ({:.4f} s / it)'.format( 251 | header, total_time_str, total_time / len(iterable))) 252 | 253 | 254 | def get_sha(): 255 | cwd = os.path.dirname(os.path.abspath(__file__)) 256 | 257 | def _run(command): 258 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
259 | sha = 'N/A' 260 | diff = "clean" 261 | branch = 'N/A' 262 | try: 263 | sha = _run(['git', 'rev-parse', 'HEAD']) 264 | subprocess.check_output(['git', 'diff'], cwd=cwd) 265 | diff = _run(['git', 'diff-index', 'HEAD']) 266 | diff = "has uncommited changes" if diff else "clean" 267 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 268 | except Exception: 269 | pass 270 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 271 | return message 272 | 273 | 274 | def collate_fn(batch): 275 | batch = list(zip(*batch)) 276 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 277 | return tuple(batch) 278 | 279 | 280 | def _max_by_axis(the_list): 281 | # type: (List[List[int]]) -> List[int] 282 | maxes = the_list[0] 283 | for sublist in the_list[1:]: 284 | for index, item in enumerate(sublist): 285 | maxes[index] = max(maxes[index], item) 286 | return maxes 287 | 288 | 289 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 290 | # TODO make this more general 291 | if tensor_list[0].ndim == 3: 292 | # TODO make it support different-sized images 293 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 294 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 295 | batch_shape = [len(tensor_list)] + max_size 296 | b, c, h, w = batch_shape 297 | dtype = tensor_list[0].dtype 298 | device = tensor_list[0].device 299 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 300 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 301 | for img, pad_img, m in zip(tensor_list, tensor, mask): 302 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 303 | m[: img.shape[1], :img.shape[2]] = False 304 | else: 305 | raise ValueError('not supported') 306 | return NestedTensor(tensor, mask) 307 | 308 | 309 | class NestedTensor(object): 310 | def __init__(self, tensors, mask: Optional[Tensor]): 311 | self.tensors = tensors 312 | self.mask = mask 313 | 314 | def to(self, device): 315 | # type: (Device) -> NestedTensor # noqa 316 | cast_tensor = self.tensors.to(device) 317 | mask = self.mask 318 | if mask is not None: 319 | assert mask is not None 320 | cast_mask = mask.to(device) 321 | else: 322 | cast_mask = None 323 | return NestedTensor(cast_tensor, cast_mask) 324 | 325 | def decompose(self): 326 | return self.tensors, self.mask 327 | 328 | def __repr__(self): 329 | return str(self.tensors) 330 | 331 | 332 | def setup_for_distributed(is_master): 333 | """ 334 | This function disables printing when not in master process 335 | """ 336 | import builtins as __builtin__ 337 | builtin_print = __builtin__.print 338 | 339 | def print(*args, **kwargs): 340 | force = kwargs.pop('force', False) 341 | if is_master or force: 342 | builtin_print(*args, **kwargs) 343 | 344 | __builtin__.print = print 345 | 346 | 347 | def is_dist_avail_and_initialized(): 348 | if not dist.is_available(): 349 | return False 350 | if not dist.is_initialized(): 351 | return False 352 | return True 353 | 354 | 355 | def get_world_size(): 356 | if not is_dist_avail_and_initialized(): 357 | return 1 358 | return dist.get_world_size() 359 | 360 | 361 | def get_rank(): 362 | if not is_dist_avail_and_initialized(): 363 | return 0 364 | return dist.get_rank() 365 | 366 | 367 | def is_main_process(): 368 | return get_rank() == 0 369 | 370 | 371 | def save_on_master(*args, **kwargs): 372 | if is_main_process(): 373 | torch.save(*args, **kwargs) 374 | 375 | 376 | def init_distributed_mode(args): 377 | if 'RANK' in os.environ and 'WORLD_SIZE' in 
os.environ: 378 | args.rank = int(os.environ["RANK"]) 379 | args.world_size = int(os.environ['WORLD_SIZE']) 380 | #args.gpu = int(os.environ['LOCAL_RANK']) 381 | args.gpu = args.rank % torch.cuda.device_count() 382 | elif 'SLURM_PROCID' in os.environ: 383 | args.rank = int(os.environ['SLURM_PROCID']) 384 | args.gpu = args.rank % torch.cuda.device_count() 385 | else: 386 | print('Not using distributed mode') 387 | args.distributed = False 388 | return 389 | 390 | args.distributed = True 391 | 392 | torch.cuda.set_device(args.gpu) 393 | args.dist_backend = 'nccl' 394 | print('| distributed init (rank {}): {}'.format( 395 | args.rank, args.dist_url), flush=True) 396 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 397 | world_size=args.world_size, rank=args.rank) 398 | torch.distributed.barrier() 399 | setup_for_distributed(args.rank == 0) 400 | 401 | 402 | @torch.no_grad() 403 | def accuracy(output, target, topk=(1,)): 404 | """Computes the precision@k for the specified values of k""" 405 | if target.numel() == 0: 406 | return [torch.zeros([], device=output.device)] 407 | maxk = max(topk) 408 | batch_size = target.size(0) 409 | 410 | _, pred = output.topk(maxk, 1, True, True) 411 | pred = pred.t() 412 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 413 | 414 | res = [] 415 | for k in topk: 416 | correct_k = correct[:k].view(-1).float().sum(0) 417 | res.append(correct_k.mul_(100.0 / batch_size)) 418 | return res 419 | 420 | 421 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 422 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 423 | """ 424 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 425 | This will eventually be supported natively by PyTorch, and this 426 | class can go away. 
427 | """ 428 | if float(torchvision.__version__[:3]) < 0.7: 429 | if input.numel() > 0: 430 | return torch.nn.functional.interpolate( 431 | input, size, scale_factor, mode, align_corners 432 | ) 433 | 434 | output_shape = _output_size(2, input, size, scale_factor) 435 | output_shape = list(input.shape[:-2]) + list(output_shape) 436 | return _new_empty_tensor(input, output_shape) 437 | else: 438 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 439 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import random 5 | import time 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | from torch.utils.data import DataLoader, DistributedSampler 11 | 12 | import datasets 13 | import util.misc as utils 14 | from datasets import build_dataset 15 | from engine import train_one_epoch, evaluate_hoi 16 | from models import build_model 17 | import os 18 | 19 | 20 | def get_args_parser(): 21 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False) 22 | parser.add_argument('--lr', default=1e-4, type=float) 23 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 24 | parser.add_argument('--lr_clip', default=1e-5, type=float) 25 | parser.add_argument('--batch_size', default=2, type=int) 26 | parser.add_argument('--weight_decay', default=1e-4, type=float) 27 | parser.add_argument('--epochs', default=150, type=int) 28 | parser.add_argument('--lr_drop', default=100, type=int) 29 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 30 | help='gradient clipping max norm') 31 | 32 | # Model parameters 33 | parser.add_argument('--frozen_weights', type=str, default=None, 34 | help="Path to the pretrained model. 
If set, only the mask head will be trained") 35 | # * Backbone 36 | parser.add_argument('--backbone', default='resnet50', type=str, 37 | help="Name of the convolutional backbone to use") 38 | parser.add_argument('--dilation', action='store_true', 39 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 40 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 41 | help="Type of positional embedding to use on top of the image features") 42 | 43 | # * Transformer 44 | parser.add_argument('--enc_layers', default=6, type=int, 45 | help="Number of encoding layers in the transformer") 46 | parser.add_argument('--dec_layers', default=3, type=int, 47 | help="Number of stage1 decoding layers in the transformer") 48 | parser.add_argument('--dim_feedforward', default=2048, type=int, 49 | help="Intermediate size of the feedforward layers in the transformer blocks") 50 | parser.add_argument('--hidden_dim', default=256, type=int, 51 | help="Size of the embeddings (dimension of the transformer)") 52 | parser.add_argument('--dropout', default=0.1, type=float, 53 | help="Dropout applied in the transformer") 54 | parser.add_argument('--nheads', default=8, type=int, 55 | help="Number of attention heads inside the transformer's attentions") 56 | parser.add_argument('--num_queries', default=100, type=int, 57 | help="Number of query slots") 58 | parser.add_argument('--pre_norm', action='store_true') 59 | 60 | # * Segmentation 61 | parser.add_argument('--masks', action='store_true', 62 | help="Train segmentation head if the flag is provided") 63 | 64 | # HOI 65 | parser.add_argument('--hoi', action='store_true', 66 | help="Train for HOI if the flag is provided") 67 | parser.add_argument('--num_obj_classes', type=int, default=80, 68 | help="Number of object classes") 69 | parser.add_argument('--num_verb_classes', type=int, default=117, 70 | help="Number of verb classes") 71 | parser.add_argument('--pretrained', type=str, default='', 72 | help='Pretrained model path') 73 | parser.add_argument('--subject_category_id', default=0, type=int) 74 | parser.add_argument('--verb_loss_type', type=str, default='focal', 75 | help='Loss type for the verb classification') 76 | 77 | # Loss 78 | parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 79 | help="Disables auxiliary decoding losses (loss at each layer)") 80 | parser.add_argument('--with_mimic', action='store_true', 81 | help="Use clip feature mimic") 82 | # * Matcher 83 | parser.add_argument('--set_cost_class', default=1, type=float, 84 | help="Class coefficient in the matching cost") 85 | parser.add_argument('--set_cost_bbox', default=2.5, type=float, 86 | help="L1 box coefficient in the matching cost") 87 | parser.add_argument('--set_cost_giou', default=1, type=float, 88 | help="giou box coefficient in the matching cost") 89 | parser.add_argument('--set_cost_obj_class', default=1, type=float, 90 | help="Object class coefficient in the matching cost") 91 | parser.add_argument('--set_cost_verb_class', default=1, type=float, 92 | help="Verb class coefficient in the matching cost") 93 | parser.add_argument('--set_cost_hoi', default=1, type=float, 94 | help="Hoi class coefficient") 95 | 96 | # * Loss coefficients 97 | parser.add_argument('--mask_loss_coef', default=1, type=float) 98 | parser.add_argument('--dice_loss_coef', default=1, type=float) 99 | parser.add_argument('--bbox_loss_coef', default=2.5, type=float) 100 | parser.add_argument('--giou_loss_coef', 
default=1, type=float) 101 | parser.add_argument('--obj_loss_coef', default=1, type=float) 102 | parser.add_argument('--verb_loss_coef', default=2, type=float) 103 | parser.add_argument('--hoi_loss_coef', default=2, type=float) 104 | parser.add_argument('--mimic_loss_coef', default=20, type=float) 105 | parser.add_argument('--alpha', default=0.5, type=float, help='focal loss alpha') 106 | parser.add_argument('--eos_coef', default=0.1, type=float, 107 | help="Relative classification weight of the no-object class") 108 | 109 | # dataset parameters 110 | parser.add_argument('--dataset_file', default='coco') 111 | parser.add_argument('--coco_path', type=str) 112 | parser.add_argument('--coco_panoptic_path', type=str) 113 | parser.add_argument('--remove_difficult', action='store_true') 114 | parser.add_argument('--hoi_path', type=str) 115 | 116 | parser.add_argument('--output_dir', default='', 117 | help='path where to save, empty for no saving') 118 | parser.add_argument('--device', default='cuda', 119 | help='device to use for training / testing') 120 | parser.add_argument('--seed', default=42, type=int) 121 | parser.add_argument('--resume', default='', help='resume from checkpoint') 122 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', 123 | help='start epoch') 124 | parser.add_argument('--eval', action='store_true') 125 | parser.add_argument('--num_workers', default=2, type=int) 126 | 127 | # distributed training parameters 128 | parser.add_argument('--world_size', default=1, type=int, 129 | help='number of distributed processes') 130 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 131 | 132 | # hoi eval parameters 133 | parser.add_argument('--use_nms_filter', action='store_true', help='Use pair nms filter, default not use') 134 | parser.add_argument('--thres_nms', default=0.7, type=float) 135 | parser.add_argument('--nms_alpha', default=1, type=float) 136 | parser.add_argument('--nms_beta', default=0.5, type=float) 137 | parser.add_argument('--json_file', default='results.json', type=str) 138 | 139 | # clip 140 | parser.add_argument('--ft_clip_with_small_lr', action='store_true', 141 | help='Use smaller learning rate to finetune clip weights') 142 | parser.add_argument('--with_clip_label', action='store_true', help='Use clip to classify HOI') 143 | parser.add_argument('--early_stop_mimic', action='store_true', help='stop mimic after step') 144 | parser.add_argument('--with_obj_clip_label', action='store_true', help='Use clip to classify object') 145 | parser.add_argument('--clip_model', default='ViT-B/32', 146 | help='clip pretrained model path') 147 | parser.add_argument('--fix_clip', action='store_true', help='') 148 | parser.add_argument('--clip_embed_dim', default=512, type=int) 149 | 150 | # zero shot type 151 | parser.add_argument('--zero_shot_type', default='default', 152 | help='default, rare_first, non_rare_first, unseen_object, unseen_verb') 153 | parser.add_argument('--del_unseen', action='store_true', help='') 154 | 155 | return parser 156 | 157 | 158 | def main(args): 159 | utils.init_distributed_mode(args) 160 | print("git:\n {}\n".format(utils.get_sha())) 161 | 162 | if args.frozen_weights is not None: 163 | assert args.masks, "Frozen training is meant for segmentation only" 164 | print(args) 165 | 166 | device = torch.device(args.device) 167 | 168 | # fix the seed for reproducibility 169 | seed = args.seed + utils.get_rank() 170 | torch.manual_seed(seed) 171 | torch.cuda.manual_seed_all(seed) 172 | 
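# Note on this seeding block: seed = args.seed + utils.get_rank() gives every distributed worker
# a distinct but deterministic seed, so random data augmentation differs across GPUs while a run
# with the same --seed stays reproducible; cudnn.deterministic below trades some speed for
# repeatable convolution results.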
np.random.seed(seed) 173 | random.seed(seed) 174 | torch.backends.cudnn.deterministic = True 175 | 176 | model, criterion, postprocessors = build_model(args) 177 | model.to(device) 178 | print('****************') 179 | print(model) 180 | print('****************') 181 | 182 | model_without_ddp = model 183 | if args.distributed: 184 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 185 | model_without_ddp = model.module 186 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 187 | print('number of params:', n_parameters) 188 | 189 | for name, p in model.named_parameters(): 190 | if 'eval_visual_projection' in name: 191 | p.requires_grad = False 192 | 193 | if args.fix_clip: 194 | for name, p in model.named_parameters(): 195 | if 'obj_visual_projection' in name or 'visual_projection' in name: 196 | p.requires_grad = False 197 | 198 | if args.ft_clip_with_small_lr: 199 | if args.with_obj_clip_label and args.with_clip_label: 200 | param_dicts = [ 201 | {"params": [p for n, p in model_without_ddp.named_parameters() if 202 | "backbone" not in n and 'visual_projection' not in n and 'obj_visual_projection' not in n and p.requires_grad]}, 203 | { 204 | "params": [p for n, p in model_without_ddp.named_parameters() if 205 | "backbone" in n and p.requires_grad], 206 | "lr": args.lr_backbone, 207 | }, 208 | { 209 | "params": [p for n, p in model_without_ddp.named_parameters() if 210 | ('visual_projection' in n or 'obj_visual_projection' in n) and p.requires_grad], 211 | "lr": args.lr_clip, 212 | }, 213 | ] 214 | elif args.with_clip_label: 215 | param_dicts = [ 216 | {"params": [p for n, p in model_without_ddp.named_parameters() if 217 | "backbone" not in n and 'visual_projection' not in n and p.requires_grad]}, 218 | { 219 | "params": [p for n, p in model_without_ddp.named_parameters() if 220 | "backbone" in n and p.requires_grad], 221 | "lr": args.lr_backbone, 222 | }, 223 | { 224 | "params": [p for n, p in model_without_ddp.named_parameters() if 225 | 'visual_projection' in n and p.requires_grad], 226 | "lr": args.lr_clip, 227 | }, 228 | ] 229 | elif args.with_obj_clip_label: 230 | param_dicts = [ 231 | {"params": [p for n, p in model_without_ddp.named_parameters() if 232 | "backbone" not in n and 'obj_visual_projection' not in n and p.requires_grad]}, 233 | { 234 | "params": [p for n, p in model_without_ddp.named_parameters() if 235 | "backbone" in n and p.requires_grad], 236 | "lr": args.lr_backbone, 237 | }, 238 | { 239 | "params": [p for n, p in model_without_ddp.named_parameters() if 240 | 'obj_visual_projection' in n and p.requires_grad], 241 | "lr": args.lr_clip, 242 | }, 243 | ] 244 | else: 245 | raise 246 | 247 | else: 248 | param_dicts = [ 249 | {"params": [p for n, p in model_without_ddp.named_parameters() if "backbone" not in n and p.requires_grad]}, 250 | { 251 | "params": [p for n, p in model_without_ddp.named_parameters() if "backbone" in n and p.requires_grad], 252 | "lr": args.lr_backbone, 253 | }, 254 | ] 255 | 256 | optimizer = torch.optim.AdamW(param_dicts, lr=args.lr, 257 | weight_decay=args.weight_decay) 258 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 259 | 260 | dataset_train = build_dataset(image_set='train', args=args) 261 | dataset_val = build_dataset(image_set='val', args=args) 262 | 263 | if args.distributed: 264 | sampler_train = DistributedSampler(dataset_train) 265 | sampler_val = DistributedSampler(dataset_val, shuffle=False) 266 | else: 267 | 
sampler_train = torch.utils.data.RandomSampler(dataset_train) 268 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 269 | 270 | batch_sampler_train = torch.utils.data.BatchSampler( 271 | sampler_train, args.batch_size, drop_last=True) 272 | 273 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 274 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 275 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, 276 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) 277 | 278 | if args.frozen_weights is not None: 279 | checkpoint = torch.load(args.frozen_weights, map_location='cpu') 280 | model_without_ddp.detr.load_state_dict(checkpoint['model']) 281 | 282 | output_dir = Path(args.output_dir) 283 | if args.resume: 284 | if args.resume.startswith('https'): 285 | checkpoint = torch.hub.load_state_dict_from_url( 286 | args.resume, map_location='cpu', check_hash=True) 287 | else: 288 | checkpoint = torch.load(args.resume, map_location='cpu') 289 | model_without_ddp.load_state_dict(checkpoint['model']) 290 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 291 | optimizer.load_state_dict(checkpoint['optimizer']) 292 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 293 | args.start_epoch = checkpoint['epoch'] + 1 294 | elif args.pretrained: 295 | checkpoint = torch.load(args.pretrained, map_location='cpu') 296 | if args.eval: 297 | model_without_ddp.load_state_dict(checkpoint['model']) 298 | else: 299 | model_without_ddp.load_state_dict(checkpoint['model'], strict=False) 300 | 301 | if args.eval: 302 | test_stats = evaluate_hoi(args.dataset_file, model, postprocessors, data_loader_val, 303 | args.subject_category_id, device, args) 304 | return 305 | 306 | print("Start training") 307 | start_time = time.time() 308 | best_performance = 0 309 | for epoch in range(args.start_epoch, args.epochs): 310 | if args.distributed: 311 | sampler_train.set_epoch(epoch) 312 | 313 | train_stats = train_one_epoch( 314 | model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm) 315 | lr_scheduler.step() 316 | if epoch == args.epochs - 1: 317 | checkpoint_path = os.path.join(output_dir, 'checkpoint_last.pth') 318 | utils.save_on_master({ 319 | 'model': model_without_ddp.state_dict(), 320 | 'optimizer': optimizer.state_dict(), 321 | 'lr_scheduler': lr_scheduler.state_dict(), 322 | 'epoch': epoch, 323 | 'args': args, 324 | }, checkpoint_path) 325 | 326 | if epoch < args.lr_drop and epoch % 5 != 0: ## eval every 5 epoch before lr_drop 327 | continue 328 | elif epoch >= args.lr_drop and epoch % 2 == 0: ## eval every 2 epoch after lr_drop 329 | continue 330 | 331 | test_stats = evaluate_hoi(args.dataset_file, model, postprocessors, data_loader_val, 332 | args.subject_category_id, device, args) 333 | if args.dataset_file == 'hico': 334 | performance = test_stats['mAP'] 335 | elif args.dataset_file == 'vcoco': 336 | performance = test_stats['mAP_all'] 337 | elif args.dataset_file == 'hoia': 338 | performance = test_stats['mAP'] 339 | 340 | if performance > best_performance: 341 | checkpoint_path = os.path.join(output_dir, 'checkpoint_best.pth') 342 | utils.save_on_master({ 343 | 'model': model_without_ddp.state_dict(), 344 | 'optimizer': optimizer.state_dict(), 345 | 'lr_scheduler': lr_scheduler.state_dict(), 346 | 'epoch': epoch, 347 | 'args': args, 348 | }, checkpoint_path) 349 | 350 | best_performance = performance 351 | 352 
| log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 353 | **{f'test_{k}': v for k, v in test_stats.items()}, 354 | 'epoch': epoch, 355 | 'n_parameters': n_parameters} 356 | 357 | if args.output_dir and utils.is_main_process(): 358 | with (output_dir / "log.txt").open("a") as f: 359 | f.write(json.dumps(log_stats) + "\n") 360 | 361 | total_time = time.time() - start_time 362 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 363 | print('Training time {}'.format(total_time_str)) 364 | 365 | 366 | if __name__ == '__main__': 367 | parser = argparse.ArgumentParser('GEN VLKT training and evaluation script', parents=[get_args_parser()]) 368 | args = parser.parse_args() 369 | if args.output_dir: 370 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 371 | main(args) 372 | -------------------------------------------------------------------------------- /datasets/hico_eval_triplet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) Hitachi, Ltd. All Rights Reserved. 3 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 4 | # ------------------------------------------------------------------------ 5 | import numpy as np 6 | from collections import defaultdict 7 | import os, cv2, json 8 | from .hico_text_label import hico_text_label 9 | from util.topk import top_k 10 | 11 | class HICOEvaluator(): 12 | def __init__(self, preds, gts, rare_triplets, non_rare_triplets, correct_mat, args): 13 | self.overlap_iou = 0.5 14 | self.max_hois = 100 15 | 16 | self.zero_shot_type = args.zero_shot_type 17 | 18 | self.use_nms_filter = args.use_nms_filter 19 | self.thres_nms = args.thres_nms 20 | self.nms_alpha = args.nms_alpha 21 | self.nms_beta = args.nms_beta 22 | 23 | self.use_score_thres = False 24 | self.thres_score = 1e-5 25 | 26 | self.use_soft_nms = False 27 | self.soft_nms_sigma = 0.5 28 | self.soft_nms_thres_score = 1e-11 29 | 30 | self.rare_triplets = rare_triplets 31 | self.non_rare_triplets = non_rare_triplets 32 | 33 | self.fp = defaultdict(list) 34 | self.tp = defaultdict(list) 35 | self.score = defaultdict(list) 36 | self.sum_gts = defaultdict(lambda: 0) 37 | self.gt_triplets = [] 38 | 39 | self.preds = [] 40 | self.hico_triplet_labels = list(hico_text_label.keys()) 41 | self.hoi_obj_list = [] 42 | for hoi_pair in self.hico_triplet_labels: 43 | self.hoi_obj_list.append(hoi_pair[1]) 44 | 45 | for index, img_preds in enumerate(preds): 46 | img_preds = {k: v.to('cpu').numpy() for k, v in img_preds.items()} 47 | bboxes = [{'bbox': list(bbox)} for bbox in img_preds['boxes']] 48 | obj_scores = img_preds['obj_scores'] * img_preds['obj_scores'] 49 | hoi_scores = img_preds['hoi_scores'] + obj_scores[:, self.hoi_obj_list] 50 | 51 | hoi_labels = np.tile(np.arange(hoi_scores.shape[1]), (hoi_scores.shape[0], 1)) 52 | subject_ids = np.tile(img_preds['sub_ids'], (hoi_scores.shape[1], 1)).T 53 | object_ids = np.tile(img_preds['obj_ids'], (hoi_scores.shape[1], 1)).T 54 | 55 | hoi_scores = hoi_scores.ravel() 56 | hoi_labels = hoi_labels.ravel() 57 | subject_ids = subject_ids.ravel() 58 | object_ids = object_ids.ravel() 59 | 60 | topk_hoi_scores = top_k(list(hoi_scores), self.max_hois) 61 | topk_indexes = np.array([np.where(hoi_scores == score)[0][0] for score in topk_hoi_scores]) 62 | 63 | if len(subject_ids) > 0: 64 | hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} 65 
| for 66 | subject_id, object_id, category_id, score in 67 | zip(subject_ids[topk_indexes], object_ids[topk_indexes], hoi_labels[topk_indexes], topk_hoi_scores)] 68 | hois = hois[:self.max_hois] 69 | else: 70 | hois = [] 71 | 72 | filename = gts[index]['filename'] 73 | self.preds.append({ 74 | 'filename': filename, 75 | 'predictions': bboxes, 76 | 'hoi_prediction': hois 77 | }) 78 | 79 | if self.use_nms_filter: 80 | print('eval use_nms_filter ...') 81 | self.preds = self.triplet_nms_filter(self.preds) 82 | 83 | 84 | self.gts = [] 85 | 86 | for i, img_gts in enumerate(gts): 87 | filename = img_gts['filename'] 88 | img_gts = {k: v.to('cpu').numpy() for k, v in img_gts.items() if k != 'id' and k != 'filename'} 89 | bbox_anns = [{'bbox': list(bbox), 'category_id': label} for bbox, label in 90 | zip(img_gts['boxes'], img_gts['labels'])] 91 | hoi_anns = [{'subject_id': hoi[0], 'object_id': hoi[1], 92 | 'category_id': self.hico_triplet_labels.index((hoi[2], bbox_anns[hoi[1]]['category_id']))} 93 | for hoi in img_gts['hois']] 94 | self.gts.append({ 95 | 'filename': filename, 96 | 'annotations': bbox_anns, 97 | 'hoi_annotation': hoi_anns 98 | }) 99 | for hoi in self.gts[-1]['hoi_annotation']: 100 | triplet = hoi['category_id'] 101 | 102 | if triplet not in self.gt_triplets: 103 | self.gt_triplets.append(triplet) 104 | 105 | self.sum_gts[triplet] += 1 106 | 107 | with open(args.json_file, 'w') as f: 108 | f.write(json.dumps(str({'preds': self.preds, 'gts': self.gts}))) 109 | 110 | print(len(self.preds)) 111 | print(len(self.gts)) 112 | 113 | def evaluate(self): 114 | for img_preds, img_gts in zip(self.preds, self.gts): 115 | pred_bboxes = img_preds['predictions'] 116 | if len(pred_bboxes) == 0: continue 117 | 118 | gt_bboxes = img_gts['annotations'] 119 | pred_hois = img_preds['hoi_prediction'] 120 | gt_hois = img_gts['hoi_annotation'] 121 | if len(gt_bboxes) != 0: 122 | bbox_pairs, bbox_overlaps = self.compute_iou_mat(gt_bboxes, pred_bboxes) 123 | self.compute_fptp(pred_hois, gt_hois, bbox_pairs, bbox_overlaps) 124 | else: 125 | for pred_hoi in pred_hois: 126 | triplet = pred_hoi['category_id'] 127 | if triplet not in self.gt_triplets: 128 | continue 129 | self.tp[triplet].append(0) 130 | self.fp[triplet].append(1) 131 | self.score[triplet].append(pred_hoi['score']) 132 | map = self.compute_map() 133 | return map 134 | 135 | def compute_map(self): 136 | ap = defaultdict(lambda: 0) 137 | rare_ap = defaultdict(lambda: 0) 138 | non_rare_ap = defaultdict(lambda: 0) 139 | max_recall = defaultdict(lambda: 0) 140 | for triplet in self.gt_triplets: 141 | sum_gts = self.sum_gts[triplet] 142 | orignal_triplet = self.hico_triplet_labels[triplet] 143 | orignal_triplet = (0, orignal_triplet[1], orignal_triplet[0]) 144 | if sum_gts == 0: 145 | continue 146 | 147 | tp = np.array((self.tp[triplet])) 148 | fp = np.array((self.fp[triplet])) 149 | if len(tp) == 0: 150 | ap[triplet] = 0 151 | max_recall[triplet] = 0 152 | if orignal_triplet in self.rare_triplets: 153 | rare_ap[triplet] = 0 154 | elif orignal_triplet in self.non_rare_triplets: 155 | non_rare_ap[triplet] = 0 156 | else: 157 | print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet)) 158 | continue 159 | 160 | score = np.array(self.score[triplet]) 161 | sort_inds = np.argsort(-score) 162 | fp = fp[sort_inds] 163 | tp = tp[sort_inds] 164 | fp = np.cumsum(fp) 165 | tp = np.cumsum(tp) 166 | rec = tp / sum_gts 167 | prec = tp / (fp + tp) 168 | # ap[triplet] = self.cal_prec(rec, prec) 169 | ap[triplet] = 
self.voc_ap(rec, prec) 170 | max_recall[triplet] = np.amax(rec) 171 | if orignal_triplet in self.rare_triplets: 172 | rare_ap[triplet] = ap[triplet] 173 | elif orignal_triplet in self.non_rare_triplets: 174 | non_rare_ap[triplet] = ap[triplet] 175 | else: 176 | print('Warning: triplet {} is neither in rare triplets nor in non-rare triplets'.format(triplet)) 177 | m_ap = np.mean(list(ap.values())) 178 | m_ap_rare = np.mean(list(rare_ap.values())) 179 | m_ap_non_rare = np.mean(list(non_rare_ap.values())) 180 | m_max_recall = np.mean(list(max_recall.values())) 181 | 182 | print('--------------------') 183 | if self.zero_shot_type == "default": 184 | print('mAP full: {} mAP rare: {} mAP non-rare: {} mean max recall: {}'.format(m_ap, m_ap_rare, m_ap_non_rare, 185 | m_max_recall)) 186 | return_dict = {'mAP': m_ap, 'mAP rare': m_ap_rare, 'mAP non-rare': m_ap_non_rare, 'mean max recall': m_max_recall} 187 | 188 | elif self.zero_shot_type == "unseen_object": 189 | print('mAP full: {} mAP unseen-obj: {} mAP seen-obj: {} mean max recall: {}'.format(m_ap, m_ap_rare, m_ap_non_rare, 190 | m_max_recall)) 191 | return_dict = {'mAP': m_ap, 'mAP unseen-obj': m_ap_rare, 'mAP seen-obj': m_ap_non_rare, 'mean max recall': m_max_recall} 192 | 193 | else: 194 | print('mAP full: {} mAP unseen: {} mAP seen: {} mean max recall: {}'.format(m_ap, m_ap_rare, m_ap_non_rare, 195 | m_max_recall)) 196 | return_dict = {'mAP': m_ap, 'mAP unseen': m_ap_rare, 'mAP seen': m_ap_non_rare, 'mean max recall': m_max_recall} 197 | 198 | print('--------------------') 199 | 200 | return return_dict 201 | 202 | def cal_prec(self, rec, prec, t=0.8): 203 | if np.sum(rec >= t) == 0: 204 | p = 0 205 | else: 206 | p = np.max(prec[rec >= t]) 207 | return p 208 | 209 | def voc_ap(self, rec, prec): 210 | ap = 0. 211 | for t in np.arange(0., 1.1, 0.1): 212 | if np.sum(rec >= t) == 0: 213 | p = 0 214 | else: 215 | p = np.max(prec[rec >= t]) 216 | ap = ap + p / 11. 
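# Worked example for the 11-point interpolation above (illustrative numbers): with two GT
# instances of a triplet and two ranked predictions, the first a true positive and the second a
# false positive, rec = [0.5, 0.5] and prec = [1.0, 0.5]; the interpolated precision is 1.0 at
# the six thresholds t <= 0.5 and 0 above them, so this method returns 6 / 11 ≈ 0.545.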
217 | return ap 218 | 219 | def compute_fptp(self, pred_hois, gt_hois, match_pairs, bbox_overlaps): 220 | pos_pred_ids = match_pairs.keys() 221 | vis_tag = np.zeros(len(gt_hois)) 222 | pred_hois.sort(key=lambda k: (k.get('score', 0)), reverse=True) 223 | if len(pred_hois) != 0: 224 | for pred_hoi in pred_hois: 225 | is_match = 0 226 | if len(match_pairs) != 0 and pred_hoi['subject_id'] in pos_pred_ids and pred_hoi[ 227 | 'object_id'] in pos_pred_ids: 228 | pred_sub_ids = match_pairs[pred_hoi['subject_id']] 229 | pred_obj_ids = match_pairs[pred_hoi['object_id']] 230 | pred_sub_overlaps = bbox_overlaps[pred_hoi['subject_id']] 231 | pred_obj_overlaps = bbox_overlaps[pred_hoi['object_id']] 232 | pred_category_id = pred_hoi['category_id'] 233 | max_overlap = 0 234 | max_gt_hoi = 0 235 | for gt_hoi in gt_hois: 236 | if gt_hoi['subject_id'] in pred_sub_ids and gt_hoi['object_id'] in pred_obj_ids \ 237 | and pred_category_id == gt_hoi['category_id']: 238 | is_match = 1 239 | min_overlap_gt = min(pred_sub_overlaps[pred_sub_ids.index(gt_hoi['subject_id'])], 240 | pred_obj_overlaps[pred_obj_ids.index(gt_hoi['object_id'])]) 241 | if min_overlap_gt > max_overlap: 242 | max_overlap = min_overlap_gt 243 | max_gt_hoi = gt_hoi 244 | triplet = pred_hoi['category_id'] 245 | if triplet not in self.gt_triplets: 246 | continue 247 | if is_match == 1 and vis_tag[gt_hois.index(max_gt_hoi)] == 0: 248 | self.fp[triplet].append(0) 249 | self.tp[triplet].append(1) 250 | vis_tag[gt_hois.index(max_gt_hoi)] = 1 251 | else: 252 | self.fp[triplet].append(1) 253 | self.tp[triplet].append(0) 254 | self.score[triplet].append(pred_hoi['score']) 255 | 256 | def compute_iou_mat(self, bbox_list1, bbox_list2): 257 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2))) 258 | if len(bbox_list1) == 0 or len(bbox_list2) == 0: 259 | return {} 260 | for i, bbox1 in enumerate(bbox_list1): 261 | for j, bbox2 in enumerate(bbox_list2): 262 | iou_i = self.compute_IOU(bbox1, bbox2) 263 | iou_mat[i, j] = iou_i 264 | 265 | iou_mat_ov = iou_mat.copy() 266 | iou_mat[iou_mat >= self.overlap_iou] = 1 267 | iou_mat[iou_mat < self.overlap_iou] = 0 268 | 269 | match_pairs = np.nonzero(iou_mat) 270 | match_pairs_dict = {} 271 | match_pair_overlaps = {} 272 | if iou_mat.max() > 0: 273 | for i, pred_id in enumerate(match_pairs[1]): 274 | if pred_id not in match_pairs_dict.keys(): 275 | match_pairs_dict[pred_id] = [] 276 | match_pair_overlaps[pred_id] = [] 277 | match_pairs_dict[pred_id].append(match_pairs[0][i]) 278 | match_pair_overlaps[pred_id].append(iou_mat_ov[match_pairs[0][i], pred_id]) 279 | return match_pairs_dict, match_pair_overlaps 280 | 281 | def compute_IOU(self, bbox1, bbox2): 282 | rec1 = bbox1['bbox'] 283 | rec2 = bbox2['bbox'] 284 | # computing area of each rectangles 285 | S_rec1 = (rec1[2] - rec1[0] + 1) * (rec1[3] - rec1[1] + 1) 286 | S_rec2 = (rec2[2] - rec2[0] + 1) * (rec2[3] - rec2[1] + 1) 287 | 288 | # computing the sum_area 289 | sum_area = S_rec1 + S_rec2 290 | 291 | # find the each edge of intersect rectangle 292 | left_line = max(rec1[1], rec2[1]) 293 | right_line = min(rec1[3], rec2[3]) 294 | top_line = max(rec1[0], rec2[0]) 295 | bottom_line = min(rec1[2], rec2[2]) 296 | # judge if there is an intersect 297 | if left_line >= right_line or top_line >= bottom_line: 298 | return 0 299 | else: 300 | intersect = (right_line - left_line + 1) * (bottom_line - top_line + 1) 301 | return intersect / (sum_area - intersect) 302 | 303 | def triplet_nms_filter(self, preds): 304 | preds_filtered = [] 305 | for img_preds in preds: 306 | 
pred_bboxes = img_preds['predictions'] 307 | pred_hois = img_preds['hoi_prediction'] 308 | all_triplets = {} 309 | for index, pred_hoi in enumerate(pred_hois): 310 | triplet = pred_hoi['category_id'] 311 | 312 | if triplet not in all_triplets: 313 | all_triplets[triplet] = {'subs': [], 'objs': [], 'scores': [], 'indexes': []} 314 | all_triplets[triplet]['subs'].append(pred_bboxes[pred_hoi['subject_id']]['bbox']) 315 | all_triplets[triplet]['objs'].append(pred_bboxes[pred_hoi['object_id']]['bbox']) 316 | all_triplets[triplet]['scores'].append(pred_hoi['score']) 317 | all_triplets[triplet]['indexes'].append(index) 318 | 319 | all_keep_inds = [] 320 | for triplet, values in all_triplets.items(): 321 | subs, objs, scores = values['subs'], values['objs'], values['scores'] 322 | if self.use_soft_nms: 323 | keep_inds = self.pairwise_soft_nms(np.array(subs), np.array(objs), np.array(scores)) 324 | else: 325 | keep_inds = self.pairwise_nms(np.array(subs), np.array(objs), np.array(scores)) 326 | 327 | if self.use_score_thres: 328 | sorted_scores = np.array(scores)[keep_inds] 329 | keep_inds = np.array(keep_inds)[sorted_scores > self.thres_score] 330 | 331 | keep_inds = list(np.array(values['indexes'])[keep_inds]) 332 | all_keep_inds.extend(keep_inds) 333 | 334 | preds_filtered.append({ 335 | 'filename': img_preds['filename'], 336 | 'predictions': pred_bboxes, 337 | 'hoi_prediction': list(np.array(img_preds['hoi_prediction'])[all_keep_inds]) 338 | }) 339 | 340 | return preds_filtered 341 | 342 | def pairwise_nms(self, subs, objs, scores): 343 | sx1, sy1, sx2, sy2 = subs[:, 0], subs[:, 1], subs[:, 2], subs[:, 3] 344 | ox1, oy1, ox2, oy2 = objs[:, 0], objs[:, 1], objs[:, 2], objs[:, 3] 345 | 346 | sub_areas = (sx2 - sx1 + 1) * (sy2 - sy1 + 1) 347 | obj_areas = (ox2 - ox1 + 1) * (oy2 - oy1 + 1) 348 | 349 | order = scores.argsort()[::-1] 350 | 351 | keep_inds = [] 352 | while order.size > 0: 353 | i = order[0] 354 | keep_inds.append(i) 355 | 356 | sxx1 = np.maximum(sx1[i], sx1[order[1:]]) 357 | syy1 = np.maximum(sy1[i], sy1[order[1:]]) 358 | sxx2 = np.minimum(sx2[i], sx2[order[1:]]) 359 | syy2 = np.minimum(sy2[i], sy2[order[1:]]) 360 | 361 | sw = np.maximum(0.0, sxx2 - sxx1 + 1) 362 | sh = np.maximum(0.0, syy2 - syy1 + 1) 363 | sub_inter = sw * sh 364 | sub_union = sub_areas[i] + sub_areas[order[1:]] - sub_inter 365 | 366 | oxx1 = np.maximum(ox1[i], ox1[order[1:]]) 367 | oyy1 = np.maximum(oy1[i], oy1[order[1:]]) 368 | oxx2 = np.minimum(ox2[i], ox2[order[1:]]) 369 | oyy2 = np.minimum(oy2[i], oy2[order[1:]]) 370 | 371 | ow = np.maximum(0.0, oxx2 - oxx1 + 1) 372 | oh = np.maximum(0.0, oyy2 - oyy1 + 1) 373 | obj_inter = ow * oh 374 | obj_union = obj_areas[i] + obj_areas[order[1:]] - obj_inter 375 | 376 | ovr = np.power(sub_inter / sub_union, self.nms_alpha) * np.power(obj_inter / obj_union, self.nms_beta) 377 | inds = np.where(ovr <= self.thres_nms)[0] 378 | 379 | order = order[inds + 1] 380 | return keep_inds 381 | 382 | def pairwise_soft_nms(self, subs, objs, scores): 383 | assert subs.shape[0] == objs.shape[0] 384 | N = subs.shape[0] 385 | 386 | sx1, sy1, sx2, sy2 = subs[:, 0], subs[:, 1], subs[:, 2], subs[:, 3] 387 | ox1, oy1, ox2, oy2 = objs[:, 0], objs[:, 1], objs[:, 2], objs[:, 3] 388 | 389 | sub_areas = (sx2 - sx1 + 1) * (sy2 - sy1 + 1) 390 | obj_areas = (ox2 - ox1 + 1) * (oy2 - oy1 + 1) 391 | 392 | for i in range(N): 393 | tscore = scores[i] 394 | pos = i + 1 395 | if i != N - 1: 396 | maxpos = np.argmax(scores[pos:]) 397 | maxscore = scores[pos:][maxpos] 398 | 399 | if tscore < maxscore: 400 | 
subs[i], subs[maxpos.item() + i + 1] = subs[maxpos.item() + i + 1].copy(), subs[i].copy() 401 | sub_areas[i], sub_areas[maxpos + i + 1] = sub_areas[maxpos + i + 1].copy(), sub_areas[i].copy() 402 | 403 | objs[i], objs[maxpos.item() + i + 1] = objs[maxpos.item() + i + 1].copy(), objs[i].copy() 404 | obj_areas[i], obj_areas[maxpos + i + 1] = obj_areas[maxpos + i + 1].copy(), obj_areas[i].copy() 405 | 406 | scores[i], scores[maxpos.item() + i + 1] = scores[maxpos.item() + i + 1].copy(), scores[i].copy() 407 | 408 | # IoU calculate 409 | sxx1 = np.maximum(subs[i, 0], subs[pos:, 0]) 410 | syy1 = np.maximum(subs[i, 1], subs[pos:, 1]) 411 | sxx2 = np.minimum(subs[i, 2], subs[pos:, 2]) 412 | syy2 = np.minimum(subs[i, 3], subs[pos:, 3]) 413 | 414 | sw = np.maximum(0.0, sxx2 - sxx1 + 1) 415 | sh = np.maximum(0.0, syy2 - syy1 + 1) 416 | sub_inter = sw * sh 417 | sub_union = sub_areas[i] + sub_areas[pos:] - sub_inter 418 | sub_ovr = sub_inter / sub_union 419 | 420 | oxx1 = np.maximum(objs[i, 0], objs[pos:, 0]) 421 | oyy1 = np.maximum(objs[i, 1], objs[pos:, 1]) 422 | oxx2 = np.minimum(objs[i, 2], objs[pos:, 2]) 423 | oyy2 = np.minimum(objs[i, 3], objs[pos:, 3]) 424 | 425 | ow = np.maximum(0.0, oxx2 - oxx1 + 1) 426 | oh = np.maximum(0.0, oyy2 - oyy1 + 1) 427 | obj_inter = ow * oh 428 | obj_union = obj_areas[i] + obj_areas[pos:] - obj_inter 429 | obj_ovr = obj_inter / obj_union 430 | 431 | # Gaussian decay 432 | ## mode 1 433 | # weight = np.exp(-(sub_ovr * obj_ovr) / self.soft_nms_sigma) 434 | 435 | ## mode 2 436 | weight = np.exp(-sub_ovr / self.soft_nms_sigma) * np.exp(-obj_ovr / self.soft_nms_sigma) 437 | 438 | scores[pos:] = weight * scores[pos:] 439 | 440 | # select the boxes and keep the corresponding indexes 441 | keep_inds = np.where(scores > self.soft_nms_thres_score)[0] 442 | 443 | return keep_inds 444 | 445 | def clip_preds_boxes(self, preds): 446 | preds_filtered = [] 447 | for img_preds in preds: 448 | filename = img_preds['filename'] 449 | 450 | input_file = os.path.join('data/hico_20160224_det/images/test2015/', filename) 451 | img = cv2.imread(input_file) 452 | h, w, c = img.shape 453 | 454 | pred_bboxes = img_preds['predictions'] 455 | for pred_bbox in pred_bboxes: 456 | pred_bbox['bbox'] = self.bbox_clip(pred_bbox['bbox'], (h, w)) 457 | 458 | preds_filtered.append(img_preds) 459 | 460 | return preds_filtered 461 | 462 | def bbox_clip(self, box, size): 463 | x1, y1, x2, y2 = box 464 | h, w = size 465 | x1 = max(0, x1) 466 | y1 = max(0, y1) 467 | x2 = min(x2, w) 468 | y2 = min(y2, h) 469 | return [x1, y1, x2, y2] 470 | -------------------------------------------------------------------------------- /generate_vcoco_official.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import numpy as np 4 | import copy 5 | import pickle 6 | from collections import defaultdict 7 | 8 | import torch 9 | from torch import nn 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | 13 | from datasets.vcoco import build as build_dataset 14 | from models.backbone import build_backbone 15 | from models.gen import build_gen 16 | import util.misc as utils 17 | from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 18 | from util.misc import (NestedTensor, nested_tensor_from_tensor_list, 19 | accuracy, get_world_size, interpolate, 20 | is_dist_avail_and_initialized) 21 | from datasets.vcoco_text_label import vcoco_hoi_text_label, vcoco_obj_text_label 22 | import clip 23 | 24 | 25 | 
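# Note on the pair-wise NMS shared by HICOEvaluator above and PostProcessHOI below (illustrative
# numbers): a lower-scored prediction of the same triplet class is suppressed when
# IoU(sub)^nms_alpha * IoU(obj)^nms_beta exceeds thres_nms. With the defaults
# (nms_alpha=1, nms_beta=0.5, thres_nms=0.7), human boxes overlapping at IoU 0.9 and object boxes
# at IoU 0.8 give 0.9 * sqrt(0.8) ≈ 0.80 > 0.7, so the weaker prediction is dropped, while
# IoU 0.9 / 0.5 gives 0.9 * sqrt(0.5) ≈ 0.64 <= 0.7 and both are kept.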
class GEN_VLKT(nn.Module): 26 | 27 | def __init__(self, backbone, transformer, num_obj_classes, num_verb_classes, num_queries, args=None): 28 | super().__init__() 29 | self.num_queries = num_queries 30 | self.transformer = transformer 31 | hidden_dim = transformer.d_model 32 | self.query_embed_h = nn.Embedding(num_queries, hidden_dim) 33 | self.query_embed_o = nn.Embedding(num_queries, hidden_dim) 34 | self.pos_guided_embedd = nn.Embedding(num_queries, hidden_dim) 35 | self.hum_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 36 | self.obj_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 37 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) 38 | self.backbone = backbone 39 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 40 | self.obj_logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 41 | self.args = args 42 | 43 | hoi_clip_label, obj_clip_label, v_linear_proj_weight = self.generate_vcoco_text_label() # 29x512, 81x512 44 | if args.with_clip_label: 45 | self.hoi_class_fc = nn.Sequential( 46 | nn.Linear(hidden_dim, 512), 47 | nn.LayerNorm(512), 48 | ) 49 | self.visual_projection = nn.Linear(512, 263) 50 | self.visual_projection.weight.data = hoi_clip_label / hoi_clip_label.norm(dim=-1, keepdim=True) 51 | 52 | if args.with_obj_clip_label: 53 | self.obj_class_fc = nn.Sequential( 54 | nn.Linear(hidden_dim, 512), 55 | nn.LayerNorm(512), 56 | ) 57 | self.obj_visual_projection = nn.Linear(512, num_obj_classes + 1) 58 | self.obj_visual_projection.weight.data = obj_clip_label / obj_clip_label.norm(dim=-1, keepdim=True) 59 | 60 | self.hidden_dim = hidden_dim 61 | self.reset_parameters() 62 | 63 | def reset_parameters(self): 64 | nn.init.uniform_(self.pos_guided_embedd.weight) 65 | 66 | def generate_vcoco_text_label(self): 67 | device = "cuda" if torch.cuda.is_available() else "cpu" 68 | 69 | hoi_text_inputs = torch.cat([clip.tokenize(vcoco_hoi_text_label[id]) for id in vcoco_hoi_text_label.keys()]) 70 | obj_text_inputs = torch.cat([clip.tokenize(obj_text[1]) for obj_text in vcoco_obj_text_label]) 71 | 72 | clip_model, preprocess = clip.load(self.args.clip_model, device=device) 73 | with torch.no_grad(): 74 | hoi_text_embedding = clip_model.encode_text(hoi_text_inputs.to(device)) 75 | obj_text_embedding = clip_model.encode_text(obj_text_inputs.to(device)) 76 | v_linear_proj_weight = clip_model.visual.proj.detach() 77 | del clip_model 78 | return hoi_text_embedding.float(), obj_text_embedding.float(), v_linear_proj_weight.float() 79 | 80 | def forward(self, samples: NestedTensor): 81 | if not isinstance(samples, NestedTensor): 82 | samples = nested_tensor_from_tensor_list(samples) 83 | features, pos = self.backbone(samples) 84 | 85 | src, mask = features[-1].decompose() 86 | assert mask is not None 87 | h_hs, o_hs, inter_hs = self.transformer(self.input_proj(src), mask, self.query_embed_h.weight, 88 | self.query_embed_o.weight, 89 | self.pos_guided_embedd.weight, pos[-1])[:3] 90 | 91 | outputs_sub_coord = self.hum_bbox_embed(h_hs).sigmoid() 92 | outputs_obj_coord = self.obj_bbox_embed(o_hs).sigmoid() 93 | 94 | if self.args.with_obj_clip_label: 95 | obj_logit_scale = self.obj_logit_scale.exp() 96 | o_hs = self.obj_class_fc(o_hs) 97 | o_hs = o_hs / o_hs.norm(dim=-1, keepdim=True) 98 | outputs_obj_class = obj_logit_scale * self.obj_visual_projection(o_hs) 99 | 100 | if self.args.with_clip_label: 101 | logit_scale = self.logit_scale.exp() 102 | inter_hs = self.hoi_class_fc(inter_hs) 103 | outputs_inter_hs = inter_hs.clone() 104 | inter_hs = inter_hs 
/ inter_hs.norm(dim=-1, keepdim=True) 105 | outputs_hoi_class = logit_scale * self.visual_projection(inter_hs) 106 | 107 | out = {'pred_hoi_logits': outputs_hoi_class[-1], 'pred_obj_logits': outputs_obj_class[-1], 108 | 'pred_sub_boxes': outputs_sub_coord[-1], 'pred_obj_boxes': outputs_obj_coord[-1], 109 | 'semantic_memory': outputs_inter_hs[-1]} 110 | 111 | return out 112 | 113 | 114 | class MLP(nn.Module): 115 | """ Very simple multi-layer perceptron (also called FFN)""" 116 | 117 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 118 | super().__init__() 119 | self.num_layers = num_layers 120 | h = [hidden_dim] * (num_layers - 1) 121 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 122 | 123 | def forward(self, x): 124 | for i, layer in enumerate(self.layers): 125 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 126 | return x 127 | 128 | 129 | class PostProcessHOI(nn.Module): 130 | 131 | def __init__(self, num_queries, subject_category_id, correct_mat, args): 132 | super().__init__() 133 | self.max_hois = 100 134 | 135 | self.num_queries = num_queries 136 | self.subject_category_id = subject_category_id 137 | 138 | correct_mat = np.concatenate((correct_mat, np.ones((correct_mat.shape[0], 1))), axis=1) 139 | self.register_buffer('correct_mat', torch.from_numpy(correct_mat)) 140 | 141 | self.use_nms_filter = args.use_nms_filter 142 | self.thres_nms = args.thres_nms 143 | self.nms_alpha = args.nms_alpha 144 | self.nms_beta = args.nms_beta 145 | print('using use_nms_filter: ', self.use_nms_filter) 146 | 147 | self.hoi_obj_list = [] 148 | self.verb_hoi_dict = defaultdict(list) 149 | self.vcoco_triplet_labels = list(vcoco_hoi_text_label.keys()) 150 | for index, hoi_pair in enumerate(self.vcoco_triplet_labels): 151 | self.hoi_obj_list.append(hoi_pair[1]) 152 | self.verb_hoi_dict[hoi_pair[0]].append(index) 153 | 154 | @torch.no_grad() 155 | def forward(self, outputs, target_sizes): 156 | out_obj_logits = outputs['pred_obj_logits'] 157 | out_hoi_logits = outputs['pred_hoi_logits'] 158 | out_sub_boxes = outputs['pred_sub_boxes'] 159 | out_obj_boxes = outputs['pred_obj_boxes'] 160 | 161 | assert len(out_obj_logits) == len(target_sizes) 162 | assert target_sizes.shape[1] == 2 163 | 164 | hoi_scores = out_hoi_logits.sigmoid() 165 | obj_scores = out_obj_logits.sigmoid() 166 | obj_labels = F.softmax(out_obj_logits, -1)[..., :-1].max(-1)[1] 167 | 168 | img_h, img_w = target_sizes.unbind(1) 169 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(hoi_scores.device) 170 | sub_boxes = box_cxcywh_to_xyxy(out_sub_boxes) 171 | sub_boxes = sub_boxes * scale_fct[:, None, :] 172 | obj_boxes = box_cxcywh_to_xyxy(out_obj_boxes) 173 | obj_boxes = obj_boxes * scale_fct[:, None, :] 174 | 175 | results = [] 176 | for index in range(len(hoi_scores)): 177 | hs, os, ol, sb, ob = hoi_scores[index], obj_scores[index], obj_labels[index], sub_boxes[index], obj_boxes[ 178 | index] 179 | sl = torch.full_like(ol, self.subject_category_id) 180 | l = torch.cat((sl, ol)) 181 | b = torch.cat((sb, ob)) 182 | bboxes = [{'bbox': bbox, 'category_id': label} for bbox, label in 183 | zip(b.to('cpu').numpy(), l.to('cpu').numpy())] 184 | 185 | hs = hs.to('cpu').numpy() 186 | os = os.to('cpu').numpy() 187 | os = os * os 188 | hs = hs + os[:, self.hoi_obj_list] 189 | verb_scores = np.zeros((hs.shape[0], len(self.verb_hoi_dict))) 190 | for i in range(hs.shape[0]): 191 | for k, v in self.verb_hoi_dict.items(): 192 | verb_scores[i][k] = np.max(hs[i, 
v]) 193 | 194 | verb_labels = np.tile(np.arange(verb_scores.shape[1]), (verb_scores.shape[0], 1)) 195 | 196 | ids = torch.arange(b.shape[0]) 197 | 198 | hois = [{'subject_id': subject_id, 'object_id': object_id, 'category_id': category_id, 'score': score} for 199 | subject_id, object_id, category_id, score in zip(ids[:ids.shape[0] // 2].to('cpu').numpy(), 200 | ids[ids.shape[0] // 2:].to('cpu').numpy(), 201 | verb_labels, verb_scores)] 202 | 203 | current_result = {'predictions': bboxes, 'hoi_prediction': hois} 204 | 205 | if self.use_nms_filter: 206 | current_result = self.triplet_nms_filter(current_result) 207 | 208 | results.append(current_result) 209 | 210 | return results 211 | 212 | def triplet_nms_filter(self, preds): 213 | pred_bboxes = preds['predictions'] 214 | pred_hois = preds['hoi_prediction'] 215 | all_triplets = {} 216 | for index, pred_hoi in enumerate(pred_hois): 217 | triplet = str(pred_bboxes[pred_hoi['subject_id']]['category_id']) + '_' + \ 218 | str(pred_bboxes[pred_hoi['object_id']]['category_id']) + '_' + str(pred_hoi['category_id']) 219 | 220 | if triplet not in all_triplets: 221 | all_triplets[triplet] = {'subs': [], 'objs': [], 'scores': [], 'indexes': []} 222 | all_triplets[triplet]['subs'].append(pred_bboxes[pred_hoi['subject_id']]['bbox']) 223 | all_triplets[triplet]['objs'].append(pred_bboxes[pred_hoi['object_id']]['bbox']) 224 | all_triplets[triplet]['scores'].append(pred_hoi['score']) 225 | all_triplets[triplet]['indexes'].append(index) 226 | 227 | all_keep_inds = [] 228 | for triplet, values in all_triplets.items(): 229 | subs, objs, scores = values['subs'], values['objs'], values['scores'] 230 | keep_inds = self.pairwise_nms(np.array(subs), np.array(objs), np.array(scores)) 231 | 232 | keep_inds = list(np.array(values['indexes'])[keep_inds]) 233 | all_keep_inds.extend(keep_inds) 234 | 235 | preds_filtered = { 236 | 'predictions': pred_bboxes, 237 | 'hoi_prediction': list(np.array(preds['hoi_prediction'])[all_keep_inds]) 238 | } 239 | 240 | return preds_filtered 241 | 242 | def pairwise_nms(self, subs, objs, scores): 243 | sx1, sy1, sx2, sy2 = subs[:, 0], subs[:, 1], subs[:, 2], subs[:, 3] 244 | ox1, oy1, ox2, oy2 = objs[:, 0], objs[:, 1], objs[:, 2], objs[:, 3] 245 | 246 | sub_areas = (sx2 - sx1 + 1) * (sy2 - sy1 + 1) 247 | obj_areas = (ox2 - ox1 + 1) * (oy2 - oy1 + 1) 248 | 249 | max_scores = np.max(scores, axis=1) 250 | order = max_scores.argsort()[::-1] 251 | 252 | keep_inds = [] 253 | while order.size > 0: 254 | i = order[0] 255 | keep_inds.append(i) 256 | 257 | sxx1 = np.maximum(sx1[i], sx1[order[1:]]) 258 | syy1 = np.maximum(sy1[i], sy1[order[1:]]) 259 | sxx2 = np.minimum(sx2[i], sx2[order[1:]]) 260 | syy2 = np.minimum(sy2[i], sy2[order[1:]]) 261 | 262 | sw = np.maximum(0.0, sxx2 - sxx1 + 1) 263 | sh = np.maximum(0.0, syy2 - syy1 + 1) 264 | sub_inter = sw * sh 265 | sub_union = sub_areas[i] + sub_areas[order[1:]] - sub_inter 266 | 267 | oxx1 = np.maximum(ox1[i], ox1[order[1:]]) 268 | oyy1 = np.maximum(oy1[i], oy1[order[1:]]) 269 | oxx2 = np.minimum(ox2[i], ox2[order[1:]]) 270 | oyy2 = np.minimum(oy2[i], oy2[order[1:]]) 271 | 272 | ow = np.maximum(0.0, oxx2 - oxx1 + 1) 273 | oh = np.maximum(0.0, oyy2 - oyy1 + 1) 274 | obj_inter = ow * oh 275 | obj_union = obj_areas[i] + obj_areas[order[1:]] - obj_inter 276 | 277 | ovr = np.power(sub_inter / sub_union, self.nms_alpha) * np.power(obj_inter / obj_union, self.nms_beta) 278 | inds = np.where(ovr <= self.thres_nms)[0] 279 | 280 | order = order[inds + 1] 281 | return keep_inds 282 | 283 | 284 | def 
get_args_parser(): 285 | parser = argparse.ArgumentParser('Set transformer detector', add_help=False) 286 | parser.add_argument('--batch_size', default=2, type=int) 287 | 288 | # * Backbone 289 | parser.add_argument('--backbone', default='resnet50', type=str, 290 | help="Name of the convolutional backbone to use") 291 | parser.add_argument('--dilation', action='store_true', 292 | help="If true, we replace stride with dilation in the last convolutional block (DC5)") 293 | parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'), 294 | help="Type of positional embedding to use on top of the image features") 295 | 296 | # * Transformer 297 | parser.add_argument('--enc_layers', default=6, type=int, 298 | help="Number of encoding layers in the transformer") 299 | parser.add_argument('--dec_layers', default=3, type=int, 300 | help="Number of stage-1 decoding layers in the transformer") 301 | parser.add_argument('--dim_feedforward', default=2048, type=int, 302 | help="Intermediate size of the feedforward layers in the transformer blocks") 303 | parser.add_argument('--hidden_dim', default=256, type=int, 304 | help="Size of the embeddings (dimension of the transformer)") 305 | parser.add_argument('--dropout', default=0.1, type=float, 306 | help="Dropout applied in the transformer") 307 | parser.add_argument('--nheads', default=8, type=int, 308 | help="Number of attention heads inside the transformer's attentions") 309 | parser.add_argument('--num_queries', default=100, type=int, 310 | help="Number of query slots") 311 | parser.add_argument('--pre_norm', action='store_true') 312 | 313 | # * HOI 314 | parser.add_argument('--subject_category_id', default=0, type=int) 315 | parser.add_argument('--missing_category_id', default=80, type=int) 316 | 317 | parser.add_argument('--hoi_path', type=str) 318 | parser.add_argument('--param_path', type=str, required=True) 319 | parser.add_argument('--save_path', type=str, required=True) 320 | 321 | parser.add_argument('--device', default='cuda', 322 | help='Device to use for training / testing') 323 | parser.add_argument('--num_workers', default=2, type=int) 324 | 325 | # * PNMS 326 | parser.add_argument('--use_nms_filter', action='store_true', help='Use pairwise NMS filter on HOI triplets (disabled by default)') 327 | parser.add_argument('--thres_nms', default=0.7, type=float) 328 | parser.add_argument('--nms_alpha', default=1, type=float) 329 | parser.add_argument('--nms_beta', default=0.5, type=float) 330 | 331 | # clip 332 | parser.add_argument('--ft_clip_with_small_lr', action='store_true', 333 | help='Use a smaller learning rate to fine-tune the CLIP weights') 334 | parser.add_argument('--with_clip_label', action='store_true', help='Use CLIP text embeddings to classify HOI triplets') 335 | parser.add_argument('--early_stop_mimic', action='store_true', help='Stop the mimic loss early during training') 336 | parser.add_argument('--with_obj_clip_label', action='store_true', help='Use CLIP text embeddings to classify objects') 337 | parser.add_argument('--clip_model', default='/data-nas2/liaoyue/HICO-Det/ViT-B-32.pt', 338 | help='Path to the pretrained CLIP model') 339 | 340 | return parser 341 | 342 | 343 | def main(args): 344 | print("git:\n {}\n".format(utils.get_sha())) 345 | 346 | print(args) 347 | 348 | valid_obj_ids = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 349 | 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 350 | 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 351 | 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 352 | 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 353 | 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 354 | 72, 73, 74, 75, 76, 77,
78, 79, 80, 81, 355 | 82, 84, 85, 86, 87, 88, 89, 90) 356 | 357 | verb_classes = ['hold_obj', 'stand', 'sit_instr', 'ride_instr', 'walk', 'look_obj', 'hit_instr', 'hit_obj', 358 | 'eat_obj', 'eat_instr', 'jump_instr', 'lay_instr', 'talk_on_phone_instr', 'carry_obj', 359 | 'throw_obj', 'catch_obj', 'cut_instr', 'cut_obj', 'run', 'work_on_computer_instr', 360 | 'ski_instr', 'surf_instr', 'skateboard_instr', 'smile', 'drink_instr', 'kick_obj', 361 | 'point_instr', 'read_obj', 'snowboard_instr'] 362 | 363 | device = torch.device(args.device) 364 | 365 | dataset_val = build_dataset(image_set='val', args=args) 366 | 367 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 368 | 369 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, 370 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) 371 | 372 | args.lr_backbone = 0 373 | args.masks = False 374 | backbone = build_backbone(args) 375 | gen = build_gen(args) 376 | model = GEN_VLKT(backbone, gen, len(valid_obj_ids) + 1, len(verb_classes), 377 | args.num_queries, args) 378 | post_processor = PostProcessHOI(args.num_queries, args.subject_category_id, dataset_val.correct_mat, args) 379 | model.to(device) 380 | post_processor.to(device) 381 | 382 | checkpoint = torch.load(args.param_path, map_location='cpu') 383 | model.load_state_dict(checkpoint['model']) 384 | 385 | detections = generate(model, post_processor, data_loader_val, device, verb_classes, args.missing_category_id) 386 | 387 | with open(args.save_path, 'wb') as f: 388 | pickle.dump(detections, f, protocol=2) 389 | 390 | 391 | @torch.no_grad() 392 | def generate(model, post_processor, data_loader, device, verb_classes, missing_category_id): 393 | model.eval() 394 | 395 | metric_logger = utils.MetricLogger(delimiter=" ") 396 | header = 'Generate:' 397 | 398 | detections = [] 399 | for samples, targets in metric_logger.log_every(data_loader, 10, header): 400 | samples = samples.to(device) 401 | 402 | outputs = model(samples) 403 | orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) 404 | results = post_processor(outputs, orig_target_sizes) 405 | 406 | for img_results, img_targets in zip(results, targets): 407 | for hoi in img_results['hoi_prediction']: 408 | detection = { 409 | 'image_id': img_targets['img_id'], 410 | 'person_box': img_results['predictions'][hoi['subject_id']]['bbox'].tolist() 411 | } 412 | if img_results['predictions'][hoi['object_id']]['category_id'] == missing_category_id: 413 | object_box = [np.nan, np.nan, np.nan, np.nan] 414 | else: 415 | object_box = img_results['predictions'][hoi['object_id']]['bbox'].tolist() 416 | cut_agent = 0 417 | hit_agent = 0 418 | eat_agent = 0 419 | for idx, score in zip(hoi['category_id'], hoi['score']): 420 | verb_class = verb_classes[idx] 421 | score = score.item() 422 | if len(verb_class.split('_')) == 1: 423 | detection['{}_agent'.format(verb_class)] = score 424 | elif 'cut_' in verb_class: 425 | detection[verb_class] = object_box + [score] 426 | cut_agent = score if score > cut_agent else cut_agent 427 | elif 'hit_' in verb_class: 428 | detection[verb_class] = object_box + [score] 429 | hit_agent = score if score > hit_agent else hit_agent 430 | elif 'eat_' in verb_class: 431 | detection[verb_class] = object_box + [score] 432 | eat_agent = score if score > eat_agent else eat_agent 433 | else: 434 | detection[verb_class] = object_box + [score] 435 | detection['{}_agent'.format( 436 | verb_class.replace('_obj', '').replace('_instr', ''))] = score 437 
| detection['cut_agent'] = cut_agent 438 | detection['hit_agent'] = hit_agent 439 | detection['eat_agent'] = eat_agent 440 | detections.append(detection) 441 | 442 | return detections 443 | 444 | 445 | if __name__ == '__main__': 446 | parser = argparse.ArgumentParser(parents=[get_args_parser()]) 447 | args = parser.parse_args() 448 | main(args) 449 | -------------------------------------------------------------------------------- /datasets/vcoco_text_label.py: -------------------------------------------------------------------------------- 1 | vcoco_obj_text_label = [(0, 'a photo of a person and a person'), (1, 'a photo of a person and a bicycle'), 2 | (2, 'a photo of a person and a car'), (3, 'a photo of a person and a motorcycle'), 3 | (4, 'a photo of a person and an airplane'), (5, 'a photo of a person and a bus'), 4 | (6, 'a photo of a person and a train'), (7, 'a photo of a person and a truck'), 5 | (8, 'a photo of a person and a boat'), (9, 'a photo of a person and a traffic light'), 6 | (10, 'a photo of a person and a fire hydrant'), (11, 'a photo of a person and a stop sign'), 7 | (12, 'a photo of a person and a parking meter'), (13, 'a photo of a person and a bench'), 8 | (14, 'a photo of a person and a bird'), (15, 'a photo of a person and a cat'), 9 | (16, 'a photo of a person and a dog'), (17, 'a photo of a person and a horse'), 10 | (18, 'a photo of a person and a sheep'), (19, 'a photo of a person and a cow'), 11 | (20, 'a photo of a person and an elephant'), (21, 'a photo of a person and a bear'), 12 | (22, 'a photo of a person and a zebra'), (23, 'a photo of a person and a giraffe'), 13 | (24, 'a photo of a person and a backpack'), (25, 'a photo of a person and a umbrella'), 14 | (26, 'a photo of a person and a handbag'), (27, 'a photo of a person and a tie'), 15 | (28, 'a photo of a person and a suitcase'), (29, 'a photo of a person and a frisbee'), 16 | (30, 'a photo of a person and a skis'), (31, 'a photo of a person and a snowboard'), 17 | (32, 'a photo of a person and a sports ball'), (33, 'a photo of a person and a kite'), 18 | (34, 'a photo of a person and a baseball bat'), 19 | (35, 'a photo of a person and a baseball glove'), 20 | (36, 'a photo of a person and a skateboard'), (37, 'a photo of a person and a surfboard'), 21 | (38, 'a photo of a person and a tennis racket'), (39, 'a photo of a person and a bottle'), 22 | (40, 'a photo of a person and a wine glass'), (41, 'a photo of a person and a cup'), 23 | (42, 'a photo of a person and a fork'), (43, 'a photo of a person and a knife'), 24 | (44, 'a photo of a person and a spoon'), (45, 'a photo of a person and a bowl'), 25 | (46, 'a photo of a person and a banana'), (47, 'a photo of a person and an apple'), 26 | (48, 'a photo of a person and a sandwich'), (49, 'a photo of a person and an orange'), 27 | (50, 'a photo of a person and a broccoli'), (51, 'a photo of a person and a carrot'), 28 | (52, 'a photo of a person and a hot dog'), (53, 'a photo of a person and a pizza'), 29 | (54, 'a photo of a person and a donut'), (55, 'a photo of a person and a cake'), 30 | (56, 'a photo of a person and a chair'), (57, 'a photo of a person and a couch'), 31 | (58, 'a photo of a person and a potted plant'), (59, 'a photo of a person and a bed'), 32 | (60, 'a photo of a person and a dining table'), (61, 'a photo of a person and a toilet'), 33 | (62, 'a photo of a person and a tv'), (63, 'a photo of a person and a laptop'), 34 | (64, 'a photo of a person and a mouse'), (65, 'a photo of a person and a remote'), 35 | (66, 'a 
photo of a person and a keyboard'), (67, 'a photo of a person and a cell phone'), 36 | (68, 'a photo of a person and a microwave'), (69, 'a photo of a person and an oven'), 37 | (70, 'a photo of a person and a toaster'), (71, 'a photo of a person and a sink'), 38 | (72, 'a photo of a person and a refrigerator'), (73, 'a photo of a person and a book'), 39 | (74, 'a photo of a person and a clock'), (75, 'a photo of a person and a vase'), 40 | (76, 'a photo of a person and a scissors'), (77, 'a photo of a person and a teddy bear'), 41 | (78, 'a photo of a person and a hair drier'), (79, 'a photo of a person and a toothbrush'), 42 | (80, 'a photo of a person only'), (81, 'a photo of nothing')] 43 | 44 | vcoco_hoi_text_label = {(0, 41): 'a photo of a person holding a cup', 45 | (16, 80): 'a photo of a person cutting with something', 46 | (17, 53): 'a photo of a person cutting a pizza', 47 | (0, 53): 'a photo of a person holding a pizza', (2, 80): 'a photo of a person sitting', 48 | (8, 53): 'a photo of a person eating a pizza', 49 | (9, 80): 'a photo of a person eating with something', 50 | (23, 80): 'a photo of a person smiling', (21, 37): 'a photo of a person surfing a surfboard', 51 | (0, 73): 'a photo of a person holding a book', 52 | (2, 13): 'a photo of a person sitting a bench', 53 | (5, 73): 'a photo of a person looking at a book', 54 | (27, 73): 'a photo of a person reading a book', (1, 80): 'a photo of a person standing', 55 | (22, 36): 'a photo of a person skateboarding a skateboard', 56 | (20, 30): 'a photo of a person skiing a skis', (0, 80): 'a photo of a person holding', 57 | (8, 80): 'a photo of a person eating', (2, 56): 'a photo of a person sitting a chair', 58 | (5, 63): 'a photo of a person looking at a laptop', 59 | (19, 63): 'a photo of a person working on computer a laptop', 60 | (0, 40): 'a photo of a person holding a wine glass', 61 | (24, 40): 'a photo of a person drinking a wine glass', 62 | (5, 31): 'a photo of a person looking at a snowboard', 63 | (28, 31): 'a photo of a person snowboarding a snowboard', 64 | (0, 76): 'a photo of a person holding a scissors', 65 | (5, 80): 'a photo of a person looking at something', 66 | (5, 76): 'a photo of a person looking at a scissors', 67 | (16, 76): 'a photo of a person cutting with a scissors', 68 | (17, 80): 'a photo of a person cutting', 69 | (5, 37): 'a photo of a person looking at a surfboard', 70 | (2, 17): 'a photo of a person sitting a horse', 71 | (3, 17): 'a photo of a person riding a horse', (4, 80): 'a photo of a person walking', 72 | (5, 29): 'a photo of a person looking at a frisbee', (10, 80): 'a photo of a person jumping', 73 | (14, 29): 'a photo of a person throwing a frisbee', (18, 80): 'a photo of a person running', 74 | (5, 53): 'a photo of a person looking at a pizza', 75 | (0, 48): 'a photo of a person holding a sandwich', 76 | (8, 48): 'a photo of a person eating a sandwich', 77 | (0, 67): 'a photo of a person holding a cell phone', 78 | (19, 80): 'a photo of a person working on computer', 79 | (0, 24): 'a photo of a person holding a backpack', 80 | (13, 24): 'a photo of a person carrying a backpack', (11, 80): 'a photo of a person laying', 81 | (11, 57): 'a photo of a person laying a couch', 82 | (0, 17): 'a photo of a person holding a horse', (0, 15): 'a photo of a person holding a cat', 83 | (11, 59): 'a photo of a person laying a bed', 84 | (15, 29): 'a photo of a person catching a frisbee', (3, 80): 'a photo of a person riding', 85 | (12, 67): 'a photo of a person talking on phone a cell phone', 86 
| (0, 31): 'a photo of a person holding a snowboard', 87 | (10, 31): 'a photo of a person jumping a snowboard', 88 | (5, 36): 'a photo of a person looking at a skateboard', 89 | (10, 36): 'a photo of a person jumping a skateboard', 90 | (0, 79): 'a photo of a person holding a toothbrush', (27, 80): 'a photo of a person reading', 91 | (0, 39): 'a photo of a person holding a bottle', 92 | (24, 39): 'a photo of a person drinking a bottle', 93 | (2, 59): 'a photo of a person sitting a bed', 94 | (5, 48): 'a photo of a person looking at a sandwich', 95 | (0, 30): 'a photo of a person holding a skis', 96 | (0, 38): 'a photo of a person holding a tennis racket', 97 | (5, 32): 'a photo of a person looking at a sports ball', 98 | (6, 38): 'a photo of a person hitting with a tennis racket', 99 | (7, 32): 'a photo of a person hitting a sports ball', 100 | (5, 0): 'a photo of a person looking at a person', 101 | (5, 17): 'a photo of a person looking at a horse', 102 | (0, 47): 'a photo of a person holding an apple', 103 | (5, 18): 'a photo of a person looking at a sheep', 104 | (8, 47): 'a photo of a person eating an apple', 105 | (25, 32): 'a photo of a person kicking a sports ball', 106 | (0, 44): 'a photo of a person holding a spoon', 107 | (5, 55): 'a photo of a person looking at a cake', 108 | (8, 55): 'a photo of a person eating a cake', 109 | (9, 44): 'a photo of a person eating with a spoon', 110 | (0, 63): 'a photo of a person holding a laptop', 111 | (6, 80): 'a photo of a person hitting with something', 112 | (2, 3): 'a photo of a person sitting a motorcycle', 113 | (3, 3): 'a photo of a person riding a motorcycle', 114 | (0, 43): 'a photo of a person holding a knife', 115 | (5, 43): 'a photo of a person looking at a knife', 116 | (16, 43): 'a photo of a person cutting with a knife', 117 | (17, 55): 'a photo of a person cutting a cake', (7, 80): 'a photo of a person hitting', 118 | (0, 34): 'a photo of a person holding a baseball bat', 119 | (6, 34): 'a photo of a person hitting with a baseball bat', 120 | (15, 80): 'a photo of a person catching', (2, 57): 'a photo of a person sitting a couch', 121 | (0, 77): 'a photo of a person holding a teddy bear', 122 | (13, 49): 'a photo of a person carrying an orange', 123 | (0, 42): 'a photo of a person holding a fork', 124 | (9, 42): 'a photo of a person eating with a fork', 125 | (5, 62): 'a photo of a person looking at a tv', 126 | (0, 28): 'a photo of a person holding a suitcase', 127 | (13, 28): 'a photo of a person carrying a suitcase', 128 | (2, 20): 'a photo of a person sitting an elephant', 129 | (3, 20): 'a photo of a person riding an elephant', 130 | (5, 15): 'a photo of a person looking at a cat', 131 | (0, 56): 'a photo of a person holding a chair', 132 | (5, 60): 'a photo of a person looking at a dining table', 133 | (24, 41): 'a photo of a person drinking a cup', (14, 80): 'a photo of a person throwing', 134 | (13, 26): 'a photo of a person carrying a handbag', 135 | (5, 16): 'a photo of a person looking at a dog', 136 | (0, 46): 'a photo of a person holding a banana', 137 | (13, 46): 'a photo of a person carrying a banana', 138 | (5, 28): 'a photo of a person looking at a suitcase', 139 | (9, 43): 'a photo of a person eating with a knife', 140 | (0, 37): 'a photo of a person holding a surfboard', 141 | (13, 37): 'a photo of a person carrying a surfboard', 142 | (8, 54): 'a photo of a person eating a donut', 143 | (0, 0): 'a photo of a person holding a person', 144 | (0, 35): 'a photo of a person holding a baseball glove', 145 | (0, 
65): 'a photo of a person holding a remote', 146 | (0, 54): 'a photo of a person holding a donut', 147 | (0, 26): 'a photo of a person holding a handbag', (13, 80): 'a photo of a person carrying', 148 | (13, 0): 'a photo of a person carrying a person', 149 | (0, 32): 'a photo of a person holding a sports ball', 150 | (14, 32): 'a photo of a person throwing a sports ball', 151 | (5, 54): 'a photo of a person looking at a donut', 152 | (0, 1): 'a photo of a person holding a bicycle', 153 | (2, 1): 'a photo of a person sitting a bicycle', 154 | (3, 1): 'a photo of a person riding a bicycle', 155 | (5, 1): 'a photo of a person looking at a bicycle', (25, 80): 'a photo of a person kicking', 156 | (5, 67): 'a photo of a person looking at a cell phone', 157 | (5, 6): 'a photo of a person looking at a train', 158 | (0, 29): 'a photo of a person holding a frisbee', 159 | (0, 36): 'a photo of a person holding a skateboard', 160 | (3, 7): 'a photo of a person riding a truck', 161 | (26, 63): 'a photo of a person pointing a laptop', 162 | (0, 3): 'a photo of a person holding a motorcycle', 163 | (13, 30): 'a photo of a person carrying a skis', 164 | (0, 25): 'a photo of a person holding a umbrella', 165 | (5, 45): 'a photo of a person looking at a bowl', 166 | (17, 51): 'a photo of a person cutting a carrot', 167 | (0, 52): 'a photo of a person holding a hot dog', 168 | (8, 52): 'a photo of a person eating a hot dog', 169 | (0, 33): 'a photo of a person holding a kite', 170 | (5, 13): 'a photo of a person looking at a bench', 171 | (12, 80): 'a photo of a person talking on phone', 172 | (22, 80): 'a photo of a person skateboarding', 173 | (5, 35): 'a photo of a person looking at a baseball glove', 174 | (15, 32): 'a photo of a person catching a sports ball', 175 | (26, 80): 'a photo of a person pointing', 176 | (13, 25): 'a photo of a person carrying a umbrella', 177 | (5, 40): 'a photo of a person looking at a wine glass', 178 | (10, 37): 'a photo of a person jumping a surfboard', 179 | (5, 33): 'a photo of a person looking at a kite', 180 | (13, 33): 'a photo of a person carrying a kite', 181 | (3, 6): 'a photo of a person riding a train', 182 | (5, 44): 'a photo of a person looking at a spoon', 183 | (0, 20): 'a photo of a person holding an elephant', (21, 80): 'a photo of a person surfing', 184 | (5, 20): 'a photo of a person looking at an elephant', 185 | (3, 8): 'a photo of a person riding a boat', 186 | (5, 23): 'a photo of a person looking at a giraffe', 187 | (13, 67): 'a photo of a person carrying a cell phone', 188 | (11, 56): 'a photo of a person laying a chair', 189 | (5, 19): 'a photo of a person looking at a cow', 190 | (5, 42): 'a photo of a person looking at a fork', 191 | (0, 55): 'a photo of a person holding a cake', 192 | (13, 32): 'a photo of a person carrying a sports ball', 193 | (5, 30): 'a photo of a person looking at a skis', 194 | (13, 36): 'a photo of a person carrying a skateboard', 195 | (26, 67): 'a photo of a person pointing a cell phone', 196 | (5, 52): 'a photo of a person looking at a hot dog', 197 | (8, 46): 'a photo of a person eating a banana', (20, 80): 'a photo of a person skiing', 198 | (28, 80): 'a photo of a person snowboarding', (0, 14): 'a photo of a person holding a bird', 199 | (11, 60): 'a photo of a person laying a dining table', 200 | (0, 16): 'a photo of a person holding a dog', 201 | (0, 72): 'a photo of a person holding a refrigerator', 202 | (5, 72): 'a photo of a person looking at a refrigerator', 203 | (5, 7): 'a photo of a person looking at a 
truck', 204 | (5, 41): 'a photo of a person looking at a cup', 205 | (2, 61): 'a photo of a person sitting a toilet', (24, 80): 'a photo of a person drinking', 206 | (0, 27): 'a photo of a person holding a tie', 207 | (5, 27): 'a photo of a person looking at a tie', 208 | (17, 27): 'a photo of a person cutting a tie', 209 | (5, 10): 'a photo of a person looking at a fire hydrant', 210 | (26, 10): 'a photo of a person pointing a fire hydrant', 211 | (11, 13): 'a photo of a person laying a bench', 212 | (17, 18): 'a photo of a person cutting a sheep', 213 | (0, 64): 'a photo of a person holding a mouse', 214 | (5, 64): 'a photo of a person looking at a mouse', 215 | (5, 66): 'a photo of a person looking at a keyboard', 216 | (16, 42): 'a photo of a person cutting with a fork', 217 | (17, 0): 'a photo of a person cutting a person', 218 | (5, 5): 'a photo of a person looking at a bus', (3, 2): 'a photo of a person riding a car', 219 | (10, 30): 'a photo of a person jumping a skis', 220 | (5, 4): 'a photo of a person looking at an airplane', 221 | (5, 46): 'a photo of a person looking at a banana', 222 | (2, 28): 'a photo of a person sitting a suitcase', 223 | (13, 29): 'a photo of a person carrying a frisbee', 224 | (5, 26): 'a photo of a person looking at a handbag', 225 | (8, 50): 'a photo of a person eating a broccoli', 226 | (17, 46): 'a photo of a person cutting a banana', 227 | (0, 18): 'a photo of a person holding a sheep', 228 | (17, 48): 'a photo of a person cutting a sandwich', 229 | (26, 0): 'a photo of a person pointing a person', 230 | (5, 3): 'a photo of a person looking at a motorcycle', 231 | (5, 24): 'a photo of a person looking at a backpack', 232 | (0, 45): 'a photo of a person holding a bowl', 233 | (26, 27): 'a photo of a person pointing a tie', 234 | (0, 49): 'a photo of a person holding an orange', 235 | (8, 49): 'a photo of a person eating an orange', 236 | (5, 34): 'a photo of a person looking at a baseball bat', 237 | (13, 31): 'a photo of a person carrying a snowboard', 238 | (17, 54): 'a photo of a person cutting a donut', 239 | (5, 38): 'a photo of a person looking at a tennis racket', 240 | (8, 51): 'a photo of a person eating a carrot', 241 | (17, 47): 'a photo of a person cutting an apple', 242 | (13, 40): 'a photo of a person carrying a wine glass', 243 | (26, 48): 'a photo of a person pointing a sandwich', 244 | (26, 62): 'a photo of a person pointing a tv', 245 | (13, 74): 'a photo of a person carrying a clock', 246 | (5, 61): 'a photo of a person looking at a toilet', 247 | (26, 19): 'a photo of a person pointing a cow', 248 | (5, 65): 'a photo of a person looking at a remote', 249 | (26, 18): 'a photo of a person pointing a sheep', 250 | (0, 50): 'a photo of a person holding a broccoli', 251 | (0, 13): 'a photo of a person holding a bench', 252 | (26, 33): 'a photo of a person pointing a kite', 253 | (0, 7): 'a photo of a person holding a truck', 254 | (13, 41): 'a photo of a person carrying a cup', 255 | (24, 45): 'a photo of a person drinking a bowl', 256 | (13, 38): 'a photo of a person carrying a tennis racket', 257 | (13, 39): 'a photo of a person carrying a bottle', 258 | (5, 47): 'a photo of a person looking at an apple', 259 | (5, 56): 'a photo of a person looking at a chair', 260 | (2, 24): 'a photo of a person sitting a backpack', 261 | (26, 60): 'a photo of a person pointing a dining table', 262 | (0, 78): 'a photo of a person holding a hair drier', 263 | (5, 39): 'a photo of a person looking at a bottle', 264 | (26, 55): 'a photo of a person 
pointing a cake', 265 | (26, 66): 'a photo of a person pointing a keyboard', 266 | (26, 72): 'a photo of a person pointing a refrigerator', 267 | (5, 74): 'a photo of a person looking at a clock', 268 | (0, 8): 'a photo of a person holding a boat', (17, 45): 'a photo of a person cutting a bowl', 269 | (26, 23): 'a photo of a person pointing a giraffe', 270 | (5, 25): 'a photo of a person looking at a umbrella', 271 | (0, 66): 'a photo of a person holding a keyboard', 272 | (2, 26): 'a photo of a person sitting a handbag', 273 | (26, 52): 'a photo of a person pointing a hot dog', 274 | (2, 60): 'a photo of a person sitting a dining table', 275 | (13, 77): 'a photo of a person carrying a teddy bear', 276 | (0, 51): 'a photo of a person holding a carrot', 277 | (13, 34): 'a photo of a person carrying a baseball bat', 278 | (5, 2): 'a photo of a person looking at a car', (3, 5): 'a photo of a person riding a bus', 279 | (17, 50): 'a photo of a person cutting a broccoli', 280 | (5, 14): 'a photo of a person looking at a bird', 281 | (13, 73): 'a photo of a person carrying a book', 282 | (5, 50): 'a photo of a person looking at a broccoli'} 283 | --------------------------------------------------------------------------------
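Editor's note: PostProcessHOI in generate_vcoco_official.py first turns the per-query HOI-pair scores into per-verb scores (for each verb, the maximum over its (verb, object) entries in vcoco_hoi_text_label) and then prunes duplicate triplets with a pairwise NMS whose overlap is IoU(subject)^nms_alpha * IoU(object)^nms_beta, suppressing a candidate when the overlap exceeds thres_nms (defaults in get_args_parser: 0.7, 1 and 0.5). The sketch below illustrates only that overlap test on two made-up subject/object box pairs; the box values and the helper name iou_xyxy are hypothetical and are not part of the repository.

# Minimal sketch of the pairwise-NMS overlap test mirrored from PostProcessHOI.pairwise_nms.
# Boxes are [x1, y1, x2, y2]; the "+1" area convention matches the code above.

def iou_xyxy(a, b):
    # Intersection-over-union of two boxes.
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1 + 1) * max(0.0, y2 - y1 + 1)
    area_a = (a[2] - a[0] + 1) * (a[3] - a[1] + 1)
    area_b = (b[2] - b[0] + 1) * (b[3] - b[1] + 1)
    return inter / (area_a + area_b - inter)

# Two hypothetical candidates for the same <human, verb, object> triplet class.
sub_a, obj_a = [10, 10, 110, 210], [120, 40, 220, 140]
sub_b, obj_b = [12, 14, 112, 206], [118, 45, 225, 150]

nms_alpha, nms_beta, thres_nms = 1.0, 0.5, 0.7  # defaults from get_args_parser
ovr = iou_xyxy(sub_a, sub_b) ** nms_alpha * iou_xyxy(obj_a, obj_b) ** nms_beta
print('pair overlap = %.3f, suppressed: %s' % (ovr, ovr > thres_nms))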