├── .gitignore ├── LICENSE ├── README.md ├── arguments.py ├── baseline_disjoint_detector_and_clip.py ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── data └── README.md ├── datasets ├── __init__.py ├── augmentation.py ├── hico.py ├── hico_categories.py ├── hico_evaluator.py ├── swig.py ├── swig_evaluator.py ├── swig_v1_categories.py └── transforms.py ├── engine.py ├── figures └── THID_arch.png ├── main.py ├── models ├── __init__.py ├── criterion.py ├── matcher.py ├── model.py ├── position_encoding.py └── transformer.py └── utils ├── __init__.py ├── box_ops.py ├── misc.py ├── sampler.py ├── scheduler.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # compilation and distribution 2 | .vscode/ 3 | __pycache__/ 4 | _ext 5 | *.py[cod] 6 | *.pyc 7 | *.so 8 | build/ 9 | dist/ 10 | wheels/ 11 | *.egg-info/ 12 | *.egg/ 13 | 14 | # outputs 15 | logs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Suchen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Transferable Human-object Interaction Detector (THID) 2 | 3 | ## Overview 4 | 5 | THID is an end-to-end transformer-based human-object interaction (HOI) detector. [[Paper]](https://cse.buffalo.edu/~jsyuan/papers/2022/CVPR2022_4126.pdf) 6 | 7 | ![THID](./figures/THID_arch.png) 8 | 9 | - **Motivation**: It is difficult to construct a data collection including all possible combinations of human actions and interacting objects due to the combinatorial nature of human-object interactions (HOI). In this work, we aim to develop a transferable HOI detector for the wide range of unseen interactions. 10 | - **Components**: (1) We treat independent HOI labels as the natural language supervision of interactions and embed them into a joint visual-and-text space to capture their correlations. (2) Our visual encoder is instantiated as a Vision Transformer with new learnable HOI tokens and a sequence parser to generate HOI predictions with bounding boxes. 
(3) It distills and leverages the transferable knowledge from the pretrained CLIP model to perform zero-shot interaction detection. 11 | 12 | ## Preparation 13 | 14 | ### Installation 15 | 16 | Our code is built upon [CLIP](https://github.com/openai/CLIP). This repo requires installing [PyTorch](https://pytorch.org/get-started/locally/) and torchvision, as well as a few small additional dependencies. 17 | 18 | ```bash 19 | conda install pytorch torchvision cudatoolkit=11.3 -c pytorch 20 | pip install ftfy regex tqdm numpy Pillow matplotlib 21 | ``` 22 | 23 | ### Dataset 24 | 25 | The experiments are mainly conducted on the **HICO-DET** and **SWIG-HOI** datasets. We follow [this repo](https://github.com/YueLiao/PPDM) to prepare the HICO-DET dataset and [this repo](https://github.com/scwangdyd/large_vocabulary_hoi_detection) to prepare the SWIG-HOI dataset. 26 | 27 | #### HICO-DET 28 | 29 | The HICO-DET dataset can be downloaded [here](https://drive.google.com/open?id=1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk). After downloading, unpack the tarball (`hico_20160224_det.tar.gz`) into the `data` directory. We use the annotation files provided by the [PPDM](https://github.com/YueLiao/PPDM) authors and re-organize them with additional meta info, e.g., image width and height. The annotation files can be downloaded from [here](https://drive.google.com/open?id=1lqmevkw8fjDuTqsOOgzg07Kf6lXhK2rg). The downloaded files have to be placed as follows; otherwise, please replace the default paths with your custom locations in [datasets/hico.py](./datasets/hico.py). 30 | 31 | ``` plain 32 | |─ data 33 | │ └─ hico_20160224_det 34 | | |- images 35 | | | |─ test2015 36 | | | |─ train2015 37 | | |─ annotations 38 | | | |─ trainval_hico_ann.json 39 | | | |─ test_hico_ann.json 40 | : : 41 | ``` 42 | 43 | #### SWIG-DET 44 | 45 | The SWIG-DET dataset can be downloaded [here](https://swig-data-weights.s3.us-east-2.amazonaws.com/images_512.zip). After downloading, unpack `images_512.zip` into the `data` directory. The annotation files can be downloaded from [here](https://drive.google.com/open?id=1GxNP99J0KP6Pwfekij_M1Z0moHziX8QN). The downloaded files have to be placed as follows; otherwise, please replace the default paths with your custom locations in [datasets/swig.py](./datasets/swig.py). 
46 | 47 | ``` plain 48 | |─ data 49 | │ └─ swig_hoi 50 | | |- images_512 51 | | |─ annotations 52 | | | |─ swig_train_1000.json 53 | | | |- swig_val_1000.json 54 | | | |─ swig_trainval_1000.json 55 | | | |- swig_test_1000.json 56 | : : 57 | ``` 58 | 59 | ## Training 60 | 61 | Run this command to train the model in HICO-DET dataset 62 | 63 | ``` bash 64 | python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py \ 65 | --batch_size 8 \ 66 | --output_dir [path to save checkpoint] \ 67 | --epochs 100 \ 68 | --lr 1e-4 --min-lr 1e-7 \ 69 | --hoi_token_length 50 \ 70 | --enable_dec \ 71 | --dataset_file hico 72 | ``` 73 | 74 | Run this command to train the model in SWIG-HOI dataset 75 | 76 | ``` bash 77 | python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py \ 78 | --batch_size 8 \ 79 | --output_dir [path to save checkpoint] \ 80 | --epochs 100 \ 81 | --lr 1e-4 --min-lr 1e-7 \ 82 | --hoi_token_length 50 \ 83 | --enable_dec \ 84 | --dataset_file swig 85 | ``` 86 | 87 | ## Inference 88 | 89 | Run this command to evaluate the model on HICO-DET dataset 90 | 91 | ``` bash 92 | python main.py --eval \ 93 | --batch_size 1 \ 94 | --output_dir [path to save results] \ 95 | --hoi_token_length 50 \ 96 | --enable_dec \ 97 | --pretrained [path to the pretrained model] \ 98 | --eval_size 256 [or 224 448 ...] \ 99 | --test_score_thresh 1e-4 \ 100 | --dataset_file hico 101 | ``` 102 | 103 | Run this command to evaluate the model on SWIG-HOI dataset 104 | 105 | ``` bash 106 | python main.py --eval \ 107 | --batch_size 8 \ 108 | --output_dir [path to save results] \ 109 | --hoi_token_length 10 \ 110 | --enable_dec \ 111 | --pretrained [path to the pretrained model] \ 112 | --eval_size 256 [or 224 448 ...] \ 113 | --test_score_thresh 1e-4 \ 114 | --dataset_file swig 115 | ``` 116 | 117 | ## Models 118 | 119 | | Model | dataset | HOI Tokens | AP seen | AP unseen | Log | Checkpoint | 120 | | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | 121 | | `THID-HICO` | HICO-DET | 50 | 25.30 | 17.57 | [Log](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_hico_token50_epoch100_log.txt) | [params](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_hico_token50_epoch100.pth)| 122 | | `THID-HICO` | HICO-DET | 10 | 23.72 | 16.45 | [Log](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_hico_token10_epoch100_log.txt) | [params](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_hico_token10_epoch100.pth)| 123 | 124 | | Model | dataset | HOI Tokens | AP non-rare | AP rare | AP unseen | Log | Checkpoint | 125 | | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | :-----: | 126 | | `THID-SWIG` | SWIG-HOI | 20 | 19.49 | 14.13 | 10.49 | [Log](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_swig_token20_epoch100_log.txt) | [params](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_swig_token20_epoch100.pth)| 127 | | `THID-SWIG` | SWIG-HOI | 10 | 18.30 | 13.99 | 11.14 | [Log](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_swig_token10_epoch50_log.txt) | [params](https://github.com/scwangdyd/promting_hoi/releases/download/v0.2/thid_swig_token10_epoch50.pth)| 128 | 129 | ## Citing 130 | 131 | Please consider citing our paper if it helps your research. 
132 | 133 | ``` 134 | @inproceedings{wang_cvpr2022, 135 | author = {Wang, Suchen and Duan, Yueqi and Ding, Henghui and Tan, Yap-Peng and Yap, Kim-Hui and Yuan, Junsong}, 136 | title = {Learning Transferable Human-Object Interaction Detectors with Natural Language Supervision}, 137 | booktitle = {CVPR}, 138 | year = {2022}, 139 | } 140 | ``` -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # Additionally modified by Suchen for HOI detector 3 | import argparse 4 | 5 | 6 | def get_args_parser(): 7 | parser = argparse.ArgumentParser('Set Human-Object Interaction Detector', add_help=False) 8 | parser.add_argument('--lr', default=1e-4, type=float) 9 | parser.add_argument('--lr_backbone', default=1e-5, type=float) 10 | parser.add_argument('--batch_size', default=8, type=int) 11 | parser.add_argument('--weight_decay', default=1e-4, type=float) 12 | parser.add_argument('--epochs', default=150, type=int) 13 | parser.add_argument('--lr_drop', default=120, type=int) 14 | parser.add_argument('--clip_max_norm', default=0.1, type=float, 15 | help='gradient clipping max norm') 16 | 17 | # Model Setting 18 | parser.add_argument('--clip_model', default="ViT-B/16", type=str, 19 | help="Name of pretrained CLIP model") 20 | # parser.add_argument('--frozen_weights', type=str, default=None,) 21 | # * Vision 22 | parser.add_argument('--embed_dim', default=512, type=int, 23 | help="Size of the embeddings (dimension of the transformer)") 24 | parser.add_argument('--image_resolution', default=224, type=int, 25 | help="input image resolution to the vision transformer") 26 | parser.add_argument('--vision_layers', default=12, type=int, 27 | help="number of layers in vision transformer") 28 | parser.add_argument('--vision_width', default=768, type=int, 29 | help="feature channels in vision transformer") 30 | parser.add_argument('--vision_patch_size', default=16, type=int, 31 | help="patch size: the input image is divided into multiple patches") 32 | parser.add_argument('--hoi_token_length', default=5, type=int, 33 | help="number of [HOI] tokens added to transformer's input") 34 | # * Text 35 | parser.add_argument('--context_length', default=77, type=int, 36 | help="Maximum length of the text description") 37 | parser.add_argument('--vocab_size', default=49408, type=int, 38 | help="Vocabulary size pre-trained with text encoder") 39 | parser.add_argument('--transformer_width', default=512, type=int, 40 | help="feature channels in text tranformer") 41 | parser.add_argument('--transformer_heads', default=8, type=int, 42 | help="number of multi-attention heads in text transformer") 43 | parser.add_argument('--transformer_layers', default=12, type=int, 44 | help="number of layers in text transformer") 45 | parser.add_argument('--prefix_length', default=8, type=int, 46 | help="number of [PREFIX] tokens at the beginning of sentences") 47 | parser.add_argument('--conjun_length', default=2, type=int, 48 | help="number of [CONJUN] tokens between actions and objects") 49 | # * Bounding box head 50 | parser.add_argument('--enable_dec', action='store_true', help='enable decoders') 51 | parser.add_argument('--dec_heads', default=8, type=int, 52 | help="Number of multi-head attention") 53 | parser.add_argument('--dec_layers', default=4, type=int, 54 | help="Number of layers in the bounding box head") 55 | # Loss 56 | 
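    # NOTE: The matching-cost and loss-coefficient arguments below follow a DETR-style set-prediction recipe (classification, L1 box, GIoU, and box-confidence terms); they are presumably consumed by the Hungarian matcher in models/matcher.py and the loss criterion in models/criterion.py.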
parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false', 57 | help="Disables auxiliary decoding losses (loss at each layer)") 58 | # * Matcher 59 | parser.add_argument('--set_cost_class', default=5, type=float, 60 | help="class coefficient in the matching cost") 61 | parser.add_argument('--set_cost_bbox', default=5, type=float, 62 | help="L1 box coefficient in the matching cost") 63 | parser.add_argument('--set_cost_giou', default=2, type=float, 64 | help="giou box coefficient in the matching cost") 65 | parser.add_argument('--set_cost_conf', default=10, type=float, 66 | help="box confidence score coefficient in the matching cost") 67 | # * Loss coefficients 68 | parser.add_argument('--class_loss_coef', default=5, type=float) 69 | parser.add_argument('--bbox_loss_coef', default=5, type=float) 70 | parser.add_argument('--giou_loss_coef', default=2, type=float) 71 | parser.add_argument('--conf_loss_coef', default=10, type=float) 72 | parser.add_argument('--eos_coef', default=0.1, type=float, 73 | help="relative classification weight of the no-object class") 74 | # * Learning rate schedule parameters 75 | parser.add_argument('--sched', default='warmupcos', type=str, metavar='SCHEDULER', 76 | help='LR scheduler (default: "step", options:"step", "warmupcos"') 77 | parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', 78 | help='learning rate noise on/off epoch percentages') 79 | parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', 80 | help='learning rate noise limit percent (default: 0.67)') 81 | parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', 82 | help='learning rate noise std-dev (default: 1.0)') 83 | parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR', 84 | help='warmup learning rate (default: 1e-6)') 85 | parser.add_argument('--min-lr', type=float, default=1e-7, metavar='LR', 86 | help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') 87 | parser.add_argument('--warmup-epochs', type=int, default=0, metavar='N', 88 | help='epochs to warmup LR, if scheduler supports') 89 | parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', 90 | help='LR decay rate (default: 0.1)') 91 | # Dataset parameters 92 | parser.add_argument('--dataset_file', default='swig', choices=['hico', 'swig']) 93 | parser.add_argument('--repeat_factor_sampling', default=False, type=lambda x: (str(x).lower() == 'true'), 94 | help='apply repeat factor sampling to increase the rate at which tail categories are observed') 95 | parser.add_argument('--zero_shot_exp', default=True, type=lambda x: (str(x).lower() == 'true'), 96 | help='[specific for hico], treat 120 rare interactions as zero shot') 97 | parser.add_argument('--ignore_non_interaction', default=True, type=lambda x: (str(x).lower() == 'true'), 98 | help='[specific for hico], ignore category') 99 | # Inference 100 | parser.add_argument('--test_score_thresh', default=0.01, type=float, 101 | help="threshold to filter out HOI predictions") 102 | parser.add_argument('--eval_size', default=448, type=int, help="image resolution for evaluation") 103 | parser.add_argument('--vis_outputs', action='store_true', help='visualize the model outputs') 104 | parser.add_argument('--vis_dir', default='', help='path where to save visualization results') 105 | # Training setup 106 | parser.add_argument('--eval', action='store_true') 107 | parser.add_argument('--seed', default=22, type=int) 108 | 
parser.add_argument('--resume', default='', help='resume from checkpoint') 109 | parser.add_argument('--pretrained', default='', help='path to checkpoint') 110 | parser.add_argument('--start_epoch', default=0, type=int, metavar='N', help='start epoch') 111 | # * Log and Device 112 | parser.add_argument('--output_dir', default='', 113 | help='path where to save, empty for no saving') 114 | parser.add_argument('--device', default='cuda', 115 | help='device to use for training / testing') 116 | # * Distributed training parameters 117 | parser.add_argument('--world_size', default=1, type=int, 118 | help='number of distributed processes') 119 | parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training') 120 | parser.add_argument('--local_rank', help='url used to set up distributed training') 121 | parser.add_argument('--num_workers', default=2, type=int) 122 | return parser -------------------------------------------------------------------------------- /baseline_disjoint_detector_and_clip.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script implements the baseline method: Disjoint object detector + Pretrained CLIP model. 3 | 4 | Basic idea: we can learn an off-the-shelf object detector to first produce 5 | the bounding boxes for all humans and objects. Then we build human-object pairs. 6 | For each pair, we crop their union region and send it to the pretrained CLIP model. 7 | 8 | This script assumes that the boxes have been computed (should be given as the input). 9 | """ 10 | import argparse 11 | import os 12 | import json 13 | import pickle 14 | import clip 15 | import torch 16 | import numpy as np 17 | from tqdm import tqdm 18 | from PIL import Image 19 | from utils.hico_evaluator import hico_evaluation, prepare_hico_gts 20 | from utils.hico_categories import ( 21 | HICO_INTERACTIONS, 22 | HICO_OBJECTS, 23 | VERB_MAPPER, 24 | ZERO_SHOT_INTERACTION_IDS, 25 | NON_INTERACTION_IDS 26 | ) 27 | from utils.swig_evaluator import swig_evaluation, prepare_swig_gts 28 | from utils.swig_v1_categories import SWIG_ACTIONS, SWIG_CATEGORIES, SWIG_INTERACTIONS 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("--exp", default="SWIG", type=str, choices=["HICO", "SWIG"], 34 | help="Experiments on which dataset") 35 | parser.add_argument("--precomputed-boxes", type=str, 36 | #default="/raid1/suchen/repo/baselines_hoi/DRG/Data/test_HICO_finetuned_v3.pkl", 37 | default="/raid1/suchen/repo/promting_hoi/data/precomputed/swig_hoi/swig_dev_JSL_boxes.pkl", 38 | help="path to the precomputed boxes.") 39 | parser.add_argument("--dataset-annos", type=str, 40 | # default="/raid1/suchen/repo/promting_hoi/data/HICO-DET/test_hico.json", 41 | default="/raid1/suchen/repo/promting_hoi/data/swig_hoi/swig_dev_1000.json", 42 | help="path to the dataset annotations.") 43 | return parser.parse_args() 44 | 45 | 46 | def load_precomputed_boxes(args): 47 | """ Load precomputed boxes from the given file (default in .pkl). 
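    Returns a dict keyed by the full image file path, with the precomputed detections for that image as the value.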
""" 48 | if args.exp == "HICO": 49 | img_dir = "/raid1/suchen/dataset/hico_20160224_det/images/test2015" 50 | 51 | with open(args.precomputed_boxes, "rb") as f: 52 | boxes = pickle.load(f) 53 | with open(args.dataset_annos, "r") as f: 54 | img_annos = json.load(f) 55 | 56 | id_to_filename = {} 57 | for img_dict in img_annos: 58 | img_id = int(img_dict["file_name"].split("_")[-1].split(".")[0]) 59 | img_filename = os.path.join(img_dir, img_dict["file_name"]) 60 | id_to_filename[img_id] = img_filename 61 | 62 | boxes_dict = {} 63 | for img_id, box in boxes.items(): 64 | if img_id not in id_to_filename: 65 | continue 66 | img_filename = id_to_filename[img_id] 67 | boxes_dict[img_filename] = box 68 | 69 | elif args.exp == "SWIG": 70 | img_dir = "/raid1/suchen/dataset/swig/images_512/" 71 | 72 | with open(args.precomputed_boxes, "rb") as f: 73 | boxes = pickle.load(f) 74 | 75 | boxes_dict = {} 76 | for img_name, dets in boxes.items(): 77 | img_filename = os.path.join(img_dir, img_name) 78 | boxes_dict[img_filename] = dets 79 | 80 | return boxes_dict 81 | 82 | 83 | def build_ho_pairs(args, boxes): 84 | """ Pair every human and object box, and return the union region. """ 85 | if args.exp == "HICO": 86 | person_boxes = [] 87 | object_boxes = [] 88 | for box_data in boxes: 89 | box_dict = {"box": box_data[2], "score": box_data[-1], "category_id": box_data[-2]} 90 | score = box_dict["score"] 91 | if score < 0.2: 92 | continue 93 | if box_data[1] == "Human": 94 | person_boxes.append(box_dict) 95 | else: 96 | object_boxes.append(box_dict) 97 | 98 | ho_pairs = [] 99 | for person_dict in person_boxes: 100 | for object_dict in object_boxes: 101 | person_box = person_dict["box"] 102 | object_box = object_dict["box"] 103 | ul = [min(person_box[0], object_box[0]), min(person_box[1], object_box[1])] 104 | br = [max(person_box[2], object_box[2]), max(person_box[3], object_box[3])] 105 | ho_pairs.append({ 106 | "person_box": person_box, 107 | "object_box": object_box, 108 | "union_box": ul + br, 109 | "person_score": person_dict["score"], 110 | "object_score": object_dict["score"], 111 | "object_category": object_dict["category_id"] - 1 # start from 1 112 | }) 113 | 114 | elif args.exp == "SWIG": 115 | person_boxes = [] 116 | object_boxes = [] 117 | for box_data in boxes: 118 | box_dict = {"box": box_data[2:], "score": box_data[1], "category_id": box_data[0]} 119 | score = box_data[1] 120 | if score < 0.01: 121 | continue 122 | if box_data[0] == 0: 123 | person_boxes.append(box_dict) 124 | else: 125 | object_boxes.append(box_dict) 126 | 127 | ho_pairs = [] 128 | for person_dict in person_boxes: 129 | for object_dict in object_boxes: 130 | person_box = person_dict["box"] 131 | object_box = object_dict["box"] 132 | ul = [min(person_box[0], object_box[0]), min(person_box[1], object_box[1])] 133 | br = [max(person_box[2], object_box[2]), max(person_box[3], object_box[3])] 134 | ho_pairs.append({ 135 | "person_box": person_box, 136 | "object_box": object_box, 137 | "union_box": ul + br, 138 | "person_score": person_dict["score"], 139 | "object_score": object_dict["score"], 140 | "object_category": object_dict["category_id"] 141 | }) 142 | 143 | return ho_pairs 144 | 145 | 146 | def prepare_text_inputs(args, model): 147 | """ Encode the classes using pre-trained CLIP text encoder. 
""" 148 | device = "cuda" if torch.cuda.is_available() else "cpu" 149 | if args.exp == "HICO": 150 | text_inputs = [] 151 | indices_mapper = {} 152 | for i, hoi in enumerate(HICO_INTERACTIONS): 153 | act = hoi["action"] 154 | if act == "no_interaction": 155 | continue 156 | act = act.split("_") 157 | act[0] = VERB_MAPPER[act[0]] 158 | act = " ".join(act) 159 | obj = hoi["object"] 160 | s = f"a photo of people {act} {obj}." 161 | indices_mapper[len(text_inputs)] = i 162 | text_inputs.append(s) 163 | 164 | elif args.exp == "SWIG": 165 | text_inputs = [] 166 | indices_mapper = {} 167 | text_freq = {} 168 | for i, hoi in enumerate(SWIG_INTERACTIONS): 169 | if hoi["evaluation"] == 0: continue 170 | action_id = hoi["action_id"] 171 | object_id = hoi["object_id"] 172 | 173 | act = SWIG_ACTIONS[action_id]["name"] 174 | obj = SWIG_CATEGORIES[object_id]["name"] 175 | act_def = SWIG_ACTIONS[action_id]["def"] 176 | obj_def = SWIG_CATEGORIES[object_id]["def"] 177 | obj_gloss = SWIG_CATEGORIES[object_id]["gloss"] 178 | obj_gloss = [obj] + [x for x in obj_gloss if x != obj] 179 | if len(obj_gloss) > 1: 180 | obj_gloss = " or ".join(obj_gloss) 181 | else: 182 | obj_gloss = obj_gloss[0] 183 | # s = f"A photo of a person {act} with object {obj}. The object {obj} means {obj_def}." 184 | # s = f"a photo of a person {act} with object {obj}" 185 | # s = f"A photo of a person {act} with {obj}. The {act} means to {act_def}." 186 | s = f"A photo of a person {act} with {obj_gloss}. The {act} means to {act_def}." 187 | indices_mapper[len(text_inputs)] = i 188 | text_freq[s] = hoi["frequency"] 189 | text_inputs.append(s) 190 | 191 | text_tokens = torch.cat([clip.tokenize(s) for s in text_inputs]).to(device) 192 | with torch.no_grad(): 193 | text_features = model.encode_text(text_tokens) 194 | text_features /= text_features.norm(dim=-1, keepdim=True) 195 | return text_features, text_inputs, indices_mapper 196 | 197 | 198 | def predict(args, model, preprocess, text_features, text_inputs, indices_mapper, img_filename, ho_pairs): 199 | """ Inference using pretrained CLIP model. 
""" 200 | device = "cuda" if torch.cuda.is_available() else "cpu" 201 | 202 | image = Image.open(img_filename) 203 | 204 | predictions = [] 205 | for ho_dict in ho_pairs: 206 | union_box = ho_dict["union_box"] 207 | cropped_image = image.crop(tuple(union_box)) 208 | 209 | image_input = preprocess(cropped_image).unsqueeze(0).to(device) 210 | 211 | # Calculate features 212 | with torch.no_grad(): 213 | image_features = model.encode_image(image_input) 214 | image_features /= image_features.norm(dim=-1, keepdim=True) 215 | 216 | # Filter out text inputs 217 | if args.exp == "HICO": 218 | obj_cat = ho_dict["object_category"] 219 | obj_name = HICO_OBJECTS[obj_cat]["name"] 220 | kept_indices = [] 221 | for i, text in enumerate(text_inputs): 222 | if obj_name in text: 223 | kept_indices.append(i) 224 | kept_indices = torch.tensor(kept_indices).to(device) 225 | kept_text_features = text_features[kept_indices] 226 | elif args.exp == "SWIG": 227 | obj_cat = ho_dict["object_category"] 228 | obj_name = SWIG_CATEGORIES[obj_cat]["name"] 229 | kept_indices = [] 230 | for i, text in enumerate(text_inputs): 231 | if obj_name in text: 232 | kept_indices.append(i) 233 | if len(kept_indices) == 0: 234 | continue 235 | kept_indices = torch.tensor(kept_indices).to(device) 236 | kept_text_features = text_features[kept_indices] 237 | 238 | similarity = (100.0 * image_features @ kept_text_features.T).softmax(dim=-1) 239 | 240 | if args.exp == "HICO": 241 | 242 | values, indices = similarity[0].topk(min(3, len(similarity[0]))) 243 | preds_per_pair = [] 244 | for score, idx in zip(values, kept_indices[indices]): 245 | preds_per_pair.append([ 246 | indices_mapper[int(idx)], 247 | ho_dict["person_box"], 248 | ho_dict["object_box"], 249 | float(score) * ho_dict["person_score"] * ho_dict["object_score"] 250 | ]) 251 | elif args.exp == "SWIG": 252 | 253 | preds_per_pair = [] 254 | for score, idx in zip(similarity[0], kept_indices): 255 | preds_per_pair.append([ 256 | indices_mapper[int(idx)], 257 | ho_dict["person_box"], 258 | ho_dict["object_box"], 259 | float(score) * ho_dict["person_score"] * ho_dict["object_score"] 260 | ]) 261 | predictions.extend(preds_per_pair) 262 | 263 | return predictions 264 | 265 | 266 | def evaluate(args): 267 | 268 | if args.exp == "HICO": 269 | # Load detections 270 | with open("./baselines/disjoint_detector_clip_dets.pkl", "rb") as f: 271 | dets = pickle.load(f) 272 | predictions = {} 273 | for img_key, dets_per_img in dets.items(): 274 | img_id = int(img_key.split("_")[-1].split(".")[0]) 275 | predictions[img_id] = dets_per_img 276 | 277 | # Load and prepare ground truth 278 | gts = prepare_hico_gts(args.dataset_annos) 279 | 280 | hico_ap, hico_rec = hico_evaluation(predictions, gts) 281 | 282 | zero_inters = ZERO_SHOT_INTERACTION_IDS 283 | zero_inters = np.asarray(zero_inters) 284 | seen_inters = np.setdiff1d(np.arange(600), zero_inters) 285 | zs_mAP = np.mean(hico_ap[zero_inters]) 286 | sn_mAP = np.mean(hico_ap[seen_inters]) 287 | print("zero-shot mAP: {:.2f}".format(zs_mAP * 100.)) 288 | print("seen mAP: {:.2f}".format(sn_mAP * 100.)) 289 | print("full mAP: {:.2f}".format(np.mean(hico_ap) * 100.)) 290 | 291 | 292 | no_inters = NON_INTERACTION_IDS 293 | zero_inters = np.setdiff1d(zero_inters, no_inters) 294 | seen_inters = np.setdiff1d(seen_inters, no_inters) 295 | full_inters = np.setdiff1d(np.arange(600), no_inters) 296 | zs_mAP = np.mean(hico_ap[zero_inters]) 297 | sn_mAP = np.mean(hico_ap[seen_inters]) 298 | fl_mAP = np.mean(hico_ap[full_inters]) 299 | print("zero-shot mAP: 
{:.2f}".format(zs_mAP * 100.)) 300 | print("seen mAP: {:.2f}".format(sn_mAP * 100.)) 301 | print("full mAP: {:.2f}".format(fl_mAP * 100.)) 302 | 303 | elif args.exp == "SWIG": 304 | 305 | # Load and prepare ground truth 306 | gts, filename_to_id_mapper = prepare_swig_gts(args.dataset_annos) 307 | 308 | # Load detections 309 | with open(f"./outputs/{args.exp}/disjoint_detector_clip_dets.pkl", "rb") as f: 310 | dets = pickle.load(f) 311 | predictions = {} 312 | for img_key, dets_per_img in dets.items(): 313 | img_id = filename_to_id_mapper[img_key] 314 | predictions[img_id] = dets_per_img 315 | 316 | # Evaluation 317 | swig_ap, swig_rec = swig_evaluation(predictions, gts) 318 | 319 | eval_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["evaluation"] == 1]) 320 | zero_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 0 and x["evaluation"] == 1]) 321 | rare_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 1 and x["evaluation"] == 1]) 322 | nonrare_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 2 and x["evaluation"] == 1]) 323 | 324 | full_mAP = np.mean(swig_ap[eval_hois]) 325 | zero_mAP = np.mean(swig_ap[zero_hois]) 326 | rare_mAP = np.mean(swig_ap[rare_hois]) 327 | nonrare_mAP = np.mean(swig_ap[nonrare_hois]) 328 | print("zero-shot mAP: {:.2f}".format(zero_mAP * 100.)) 329 | print("rare mAP: {:.2f}".format(rare_mAP * 100.)) 330 | print("nonrare mAP: {:.2f}".format(nonrare_mAP * 100.)) 331 | print("full mAP: {:.2f}".format(full_mAP * 100.)) 332 | 333 | 334 | def main(args): 335 | 336 | # Load the model 337 | device = "cuda" if torch.cuda.is_available() else "cpu" 338 | model, preprocess = clip.load('ViT-B/32', device) 339 | 340 | # Load dataset 341 | boxes_dict = load_precomputed_boxes(args) 342 | 343 | # Prepare text inputs 344 | text_features, text_inputs, indices_mapper = prepare_text_inputs(args, model) 345 | 346 | predictions = {} 347 | for img_key, boxes in tqdm(boxes_dict.items()): 348 | ho_pairs = build_ho_pairs(args, boxes) 349 | preds = predict(args, model, preprocess, text_features, text_inputs, 350 | indices_mapper, img_key, ho_pairs) 351 | predictions[os.path.basename(img_key)] = preds 352 | 353 | with open(f"./outputs/{args.exp}/disjoint_detector_clip_dets.pkl", "wb") as f: 354 | pickle.dump(predictions, f) 355 | 356 | 357 | if __name__ == "__main__": 358 | args = parse_args() 359 | main(args) 360 | evaluate(args) -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scwangdyd/promting_hoi/29938ccbcb7c8206873a984628a132064c769270/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 
| 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _convert_image_to_rgb(image): 72 | return image.convert("RGB") 73 | 74 | 75 | def _transform(n_px): 76 | return Compose([ 77 | Resize(n_px, interpolation=BICUBIC), 78 | CenterCrop(n_px), 79 | _convert_image_to_rgb, 80 | ToTensor(), 81 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 82 | ]) 83 | 84 | 85 | def available_models() -> List[str]: 86 | """Returns the names of available CLIP models""" 87 | return list(_MODELS.keys()) 88 | 89 | 90 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None): 91 | """Load a CLIP model 92 | 93 | Parameters 94 | ---------- 95 | name : str 96 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 97 | 98 | device : Union[str, torch.device] 99 | The device to put the loaded model 100 | 101 | jit : bool 102 | Whether to load the optimized JIT model or more 
hackable non-JIT model (default). 103 | 104 | download_root: str 105 | path to download the model files; by default, it uses "~/.cache/clip" 106 | 107 | Returns 108 | ------- 109 | model : torch.nn.Module 110 | The CLIP model 111 | 112 | preprocess : Callable[[PIL.Image], torch.Tensor] 113 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | try: 123 | # loading JIT archive 124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 125 | state_dict = None 126 | except RuntimeError: 127 | # loading saved state dict 128 | if jit: 129 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 130 | jit = False 131 | state_dict = torch.load(model_path, map_location="cpu") 132 | 133 | if not jit: 134 | model = build_model(state_dict or model.state_dict()).to(device) 135 | if str(device) == "cpu": 136 | model.float() 137 | return model, _transform(model.visual.input_resolution) 138 | 139 | # patch the device names 140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 142 | 143 | def patch_device(module): 144 | try: 145 | graphs = [module.graph] if hasattr(module, "graph") else [] 146 | except RuntimeError: 147 | graphs = [] 148 | 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("prim::Constant"): 154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 155 | node.copyAttributes(device_node) 156 | 157 | model.apply(patch_device) 158 | patch_device(model.encode_image) 159 | patch_device(model.encode_text) 160 | 161 | # patch dtype to float32 on CPU 162 | if str(device) == "cpu": 163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 165 | float_node = float_input.node() 166 | 167 | def patch_float(module): 168 | try: 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | except RuntimeError: 171 | graphs = [] 172 | 173 | if hasattr(module, "forward1"): 174 | graphs.append(module.forward1.graph) 175 | 176 | for graph in graphs: 177 | for node in graph.findAllNodes("aten::to"): 178 | inputs = list(node.inputs()) 179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 180 | if inputs[i].node()["value"] == 5: 181 | inputs[i].node().copyAttributes(float_node) 182 | 183 | model.apply(patch_float) 184 | patch_float(model.encode_image) 185 | patch_float(model.encode_text) 186 | 187 | model.float() 188 | 189 | return model, _transform(model.input_resolution.item()) 190 | 191 | 192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 193 | """ 194 | Returns the tokenized representation of given input string(s) 195 | 196 | Parameters 197 | ---------- 198 | texts : Union[str, List[str]] 199 | An input string or a list of input strings to tokenize 200 | 201 | context_length : int 202 | The 
context length to use; all CLIP models use 77 as the context length 203 | 204 | truncate: bool 205 | Whether to truncate the text in case its encoding is longer than the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 210 | """ 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | 214 | sot_token = _tokenizer.encoder["<|startoftext|>"] 215 | eot_token = _tokenizer.encoder["<|endoftext|>"] 216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | if truncate: 222 | tokens = tokens[:context_length] 223 | tokens[-1] = eot_token 224 | else: 225 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 226 | result[i, :len(tokens)] = torch.tensor(tokens) 227 | 228 | return result 229 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
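    For example, get_pairs(("h", "e", "l", "l", "o</w>")) -> {("h", "e"), ("e", "l"), ("l", "l"), ("l", "o</w>")}.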
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | ### Dataset 2 | 3 | The experiments are mainly conducted on the **HICO-DET** and **SWIG-HOI** datasets. We follow [this repo](https://github.com/YueLiao/PPDM) to prepare the HICO-DET dataset and [this repo](https://github.com/scwangdyd/large_vocabulary_hoi_detection) to prepare the SWIG-HOI dataset. 
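For a quick sanity check after downloading, the following minimal sketch (assuming the default `./data` layout shown in the sections below; adjust the paths if you keep the datasets elsewhere) verifies that the expected folders and annotation files are in place:

``` python
from pathlib import Path

# Expected locations under the default ./data layout (see the directory trees below).
EXPECTED = [
    "data/hico_20160224_det/images/train2015",
    "data/hico_20160224_det/images/test2015",
    "data/hico_20160224_det/annotations/trainval_hico_ann.json",
    "data/hico_20160224_det/annotations/test_hico_ann.json",
    "data/swig_hoi/images_512",
    "data/swig_hoi/annotations/swig_trainval_1000.json",
    "data/swig_hoi/annotations/swig_test_1000.json",
]

missing = [p for p in EXPECTED if not Path(p).exists()]
print("all expected files found" if not missing else f"missing: {missing}")
```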
4 | 5 | #### HICO-DET 6 | 7 | HICO-DET dataset can be downloaded [here](https://drive.google.com/open?id=1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk). After finishing downloading, unpack the tarball (`hico_20160224_det.tar.gz`) to the `data` directory. We use the annotation files provided by the PPDM authors. We re-organize the annotation with additional meta info, e.g., image width and height. The annotation files can be downloaded from [here](https://drive.google.com/open?id=1lqmevkw8fjDuTqsOOgzg07Kf6lXhK2rg). The downloaded annotation files have to be placed as follows. 8 | 9 | ``` plain 10 | |─ data 11 | │ └─ hico_20160224_det 12 | | |- images 13 | | | |─ test2015 14 | | | |─ train2015 15 | | |─ annotations 16 | | | |─ trainval_hico_ann.json 17 | | | |─ test_hico_ann.json 18 | : : 19 | ``` 20 | 21 | #### SWIG-DET 22 | 23 | SWIG-DET dataset can be downloaded [here](https://swig-data-weights.s3.us-east-2.amazonaws.com/images_512.zip). After finishing downloading, unpack the `images_512.zip` to the `data` directory. The annotation files can be downloaded from [here](https://drive.google.com/open?id=1GxNP99J0KP6Pwfekij_M1Z0moHziX8QN). The downloaded files to be placed as follows. 24 | 25 | ``` plain 26 | |─ data 27 | │ └─ swig_hoi 28 | | |- images_512 29 | | |─ annotations 30 | | | |─ swig_train_1000.json 31 | | | |- swig_val_1000.json 32 | | | |─ swig_trainval_1000.json 33 | | | |- swig_test_1000.json 34 | : : 35 | ``` -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .swig import build as build_swig 2 | from .hico import build as build_hico 3 | 4 | from .swig_evaluator import SWiGEvaluator 5 | from .hico_evaluator import HICOEvaluator 6 | 7 | 8 | def build_dataset(image_set, args): 9 | if args.dataset_file == 'swig': 10 | return build_swig(image_set, args) 11 | if args.dataset_file == 'hico': 12 | return build_hico(image_set, args) 13 | raise ValueError(f'dataset {args.dataset_file} not supported') 14 | 15 | 16 | def build_evaluator(args): 17 | if args.dataset_file == "swig": 18 | from .swig import SWIG_VAL_ANNO 19 | evaluator = SWiGEvaluator(SWIG_VAL_ANNO, args.output_dir) 20 | elif args.dataset_file == "hico": 21 | from .hico import HICO_VAL_ANNO 22 | evaluator = HICOEvaluator(HICO_VAL_ANNO, args.output_dir) 23 | else: 24 | raise NotImplementedError 25 | 26 | return evaluator -------------------------------------------------------------------------------- /datasets/augmentation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | import inspect 5 | import numpy as np 6 | import pprint 7 | from typing import Any, List, Optional, Tuple, Union 8 | from fvcore.transforms.transform import Transform, TransformList 9 | 10 | """ 11 | See "Data Augmentation" tutorial for an overview of the system. 
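This module is adapted from detectron2's data augmentation utilities, so the conventions described in that tutorial apply here as well.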
12 | """ 13 | 14 | 15 | __all__ = [ 16 | "Augmentation", 17 | "AugmentationList", 18 | "AugInput", 19 | "TransformGen", 20 | "apply_transform_gens", 21 | "StandardAugInput", 22 | "apply_augmentations", 23 | ] 24 | 25 | 26 | def _check_img_dtype(img): 27 | assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format( 28 | type(img) 29 | ) 30 | assert not isinstance(img.dtype, np.integer) or ( 31 | img.dtype == np.uint8 32 | ), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format( 33 | img.dtype 34 | ) 35 | assert img.ndim in [2, 3], img.ndim 36 | 37 | 38 | def _get_aug_input_args(aug, aug_input) -> List[Any]: 39 | """ 40 | Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``. 41 | """ 42 | if aug.input_args is None: 43 | # Decide what attributes are needed automatically 44 | prms = list(inspect.signature(aug.get_transform).parameters.items()) 45 | # The default behavior is: if there is one parameter, then its "image" 46 | # (work automatically for majority of use cases, and also avoid BC breaking), 47 | # Otherwise, use the argument names. 48 | if len(prms) == 1: 49 | names = ("image",) 50 | else: 51 | names = [] 52 | for name, prm in prms: 53 | if prm.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD): 54 | raise TypeError( 55 | f""" \ 56 | The default implementation of `{type(aug)}.__call__` does not allow \ 57 | `{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \ 58 | If arguments are unknown, reimplement `__call__` instead. \ 59 | """ 60 | ) 61 | names.append(name) 62 | aug.input_args = tuple(names) 63 | 64 | args = [] 65 | for f in aug.input_args: 66 | try: 67 | args.append(getattr(aug_input, f)) 68 | except AttributeError as e: 69 | raise AttributeError( 70 | f"{type(aug)}.get_transform needs input attribute '{f}', " 71 | f"but it is not an attribute of {type(aug_input)}!" 72 | ) from e 73 | return args 74 | 75 | 76 | class Augmentation: 77 | """ 78 | Augmentation defines (often random) policies/strategies to generate :class:`Transform` 79 | from data. It is often used for pre-processing of input data. 80 | 81 | A "policy" that generates a :class:`Transform` may, in the most general case, 82 | need arbitrary information from input data in order to determine what transforms 83 | to apply. Therefore, each :class:`Augmentation` instance defines the arguments 84 | needed by its :meth:`get_transform` method. When called with the positional arguments, 85 | the :meth:`get_transform` method executes the policy. 86 | 87 | Note that :class:`Augmentation` defines the policies to create a :class:`Transform`, 88 | but not how to execute the actual transform operations to those data. 89 | Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform. 90 | 91 | The returned `Transform` object is meant to describe deterministic transformation, which means 92 | it can be re-applied on associated data, e.g. the geometry of an image and its segmentation 93 | masks need to be transformed together. 94 | (If such re-application is not needed, then determinism is not a crucial requirement.) 95 | """ 96 | 97 | input_args: Optional[Tuple[str]] = None 98 | """ 99 | Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``. 100 | By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only 101 | contain "image". 
As long as the argument name convention is followed, there is no need for 102 | users to touch this attribute. 103 | """ 104 | 105 | def _init(self, params=None): 106 | if params: 107 | for k, v in params.items(): 108 | if k != "self" and not k.startswith("_"): 109 | setattr(self, k, v) 110 | 111 | def get_transform(self, *args) -> Transform: 112 | """ 113 | Execute the policy based on input data, and decide what transform to apply to inputs. 114 | 115 | Args: 116 | args: Any fixed-length positional arguments. By default, the name of the arguments 117 | should exist in the :class:`AugInput` to be used. 118 | 119 | Returns: 120 | Transform: Returns the deterministic transform to apply to the input. 121 | 122 | Examples: 123 | :: 124 | class MyAug: 125 | # if a policy needs to know both image and semantic segmentation 126 | def get_transform(image, sem_seg) -> T.Transform: 127 | pass 128 | tfm: Transform = MyAug().get_transform(image, sem_seg) 129 | new_image = tfm.apply_image(image) 130 | 131 | Notes: 132 | Users can freely use arbitrary new argument names in custom 133 | :meth:`get_transform` method, as long as they are available in the 134 | input data. In detectron2 we use the following convention: 135 | 136 | * image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or 137 | floating point in range [0, 1] or [0, 255]. 138 | * boxes: (N,4) ndarray of float32. It represents the instance bounding boxes 139 | of N instances. Each is in XYXY format in unit of absolute coordinates. 140 | * sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel. 141 | 142 | We do not specify convention for other types and do not include builtin 143 | :class:`Augmentation` that uses other types in detectron2. 144 | """ 145 | raise NotImplementedError 146 | 147 | def __call__(self, aug_input) -> Transform: 148 | """ 149 | Augment the given `aug_input` **in-place**, and return the transform that's used. 150 | 151 | This method will be called to apply the augmentation. In most augmentation, it 152 | is enough to use the default implementation, which calls :meth:`get_transform` 153 | using the inputs. But a subclass can overwrite it to have more complicated logic. 154 | 155 | Args: 156 | aug_input (AugInput): an object that has attributes needed by this augmentation 157 | (defined by ``self.get_transform``). Its ``transform`` method will be called 158 | to in-place transform it. 159 | 160 | Returns: 161 | Transform: the transform that is applied on the input. 162 | """ 163 | args = _get_aug_input_args(self, aug_input) 164 | tfm = self.get_transform(*args) 165 | assert isinstance(tfm, (Transform, TransformList)), ( 166 | f"{type(self)}.get_transform must return an instance of Transform! " 167 | "Got {type(tfm)} instead." 168 | ) 169 | aug_input.transform(tfm) 170 | return tfm 171 | 172 | def _rand_range(self, low=1.0, high=None, size=None): 173 | """ 174 | Uniform float random number between low and high. 
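        If ``high`` is None, the range is taken to be [0, low).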
175 | """ 176 | if high is None: 177 | low, high = 0, low 178 | if size is None: 179 | size = [] 180 | return np.random.uniform(low, high, size) 181 | 182 | def __repr__(self): 183 | """ 184 | Produce something like: 185 | "MyAugmentation(field1={self.field1}, field2={self.field2})" 186 | """ 187 | try: 188 | sig = inspect.signature(self.__init__) 189 | classname = type(self).__name__ 190 | argstr = [] 191 | for name, param in sig.parameters.items(): 192 | assert ( 193 | param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD 194 | ), "The default __repr__ doesn't support *args or **kwargs" 195 | assert hasattr(self, name), ( 196 | "Attribute {} not found! " 197 | "Default __repr__ only works if attributes match the constructor.".format(name) 198 | ) 199 | attr = getattr(self, name) 200 | default = param.default 201 | if default is attr: 202 | continue 203 | attr_str = pprint.pformat(attr) 204 | if "\n" in attr_str: 205 | # don't show it if pformat decides to use >1 lines 206 | attr_str = "..." 207 | argstr.append("{}={}".format(name, attr_str)) 208 | return "{}({})".format(classname, ", ".join(argstr)) 209 | except AssertionError: 210 | return super().__repr__() 211 | 212 | __str__ = __repr__ 213 | 214 | 215 | def _transform_to_aug(tfm_or_aug): 216 | """ 217 | Wrap Transform into Augmentation. 218 | Private, used internally to implement augmentations. 219 | """ 220 | assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug 221 | if isinstance(tfm_or_aug, Augmentation): 222 | return tfm_or_aug 223 | else: 224 | 225 | class _TransformToAug(Augmentation): 226 | def __init__(self, tfm: Transform): 227 | self.tfm = tfm 228 | 229 | def get_transform(self, *args): 230 | return self.tfm 231 | 232 | def __repr__(self): 233 | return repr(self.tfm) 234 | 235 | __str__ = __repr__ 236 | 237 | return _TransformToAug(tfm_or_aug) 238 | 239 | 240 | class AugmentationList(Augmentation): 241 | """ 242 | Apply a sequence of augmentations. 243 | 244 | It has ``__call__`` method to apply the augmentations. 245 | 246 | Note that :meth:`get_transform` method is impossible (will throw error if called) 247 | for :class:`AugmentationList`, because in order to apply a sequence of augmentations, 248 | the kth augmentation must be applied first, to provide inputs needed by the (k+1)th 249 | augmentation. 250 | """ 251 | 252 | def __init__(self, augs): 253 | """ 254 | Args: 255 | augs (list[Augmentation or Transform]): 256 | """ 257 | super().__init__() 258 | self.augs = [_transform_to_aug(x) for x in augs] 259 | 260 | def __call__(self, aug_input) -> Transform: 261 | tfms = [] 262 | for x in self.augs: 263 | tfm = x(aug_input) 264 | tfms.append(tfm) 265 | return TransformList(tfms) 266 | 267 | def __repr__(self): 268 | msgs = [str(x) for x in self.augs] 269 | return "AugmentationList[{}]".format(", ".join(msgs)) 270 | 271 | __str__ = __repr__ 272 | 273 | 274 | class AugInput: 275 | """ 276 | Input that can be used with :meth:`Augmentation.__call__`. 277 | This is a standard implementation for the majority of use cases. 278 | This class provides the standard attributes **"image", "boxes", "sem_seg"** 279 | defined in :meth:`__init__` and they may be needed by different augmentations. 280 | Most augmentation policies do not need attributes beyond these three. 281 | 282 | After applying augmentations to these attributes (using :meth:`AugInput.transform`), 283 | the returned transforms can then be used to transform other data structures that users have. 
284 | 285 | Examples: 286 | :: 287 | input = AugInput(image, boxes=boxes) 288 | tfms = augmentation(input) 289 | transformed_image = input.image 290 | transformed_boxes = input.boxes 291 | transformed_other_data = tfms.apply_other(other_data) 292 | 293 | An extended project that works with new data types may implement augmentation policies 294 | that need other inputs. An algorithm may need to transform inputs in a way different 295 | from the standard approach defined in this class. In those rare situations, users can 296 | implement a class similar to this class, that satify the following condition: 297 | 298 | * The input must provide access to these data in the form of attribute access 299 | (``getattr``). For example, if an :class:`Augmentation` to be applied needs "image" 300 | and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg". 301 | * The input must have a ``transform(tfm: Transform) -> None`` method which 302 | in-place transforms all its attributes. 303 | """ 304 | 305 | # TODO maybe should support more builtin data types here 306 | def __init__( 307 | self, 308 | image: np.ndarray, 309 | *, 310 | boxes: Optional[np.ndarray] = None, 311 | sem_seg: Optional[np.ndarray] = None, 312 | ): 313 | """ 314 | Args: 315 | image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or 316 | floating point in range [0, 1] or [0, 255]. The meaning of C is up 317 | to users. 318 | boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode 319 | sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element 320 | is an integer label of pixel. 321 | """ 322 | _check_img_dtype(image) 323 | self.image = image 324 | self.boxes = boxes 325 | self.sem_seg = sem_seg 326 | 327 | def transform(self, tfm: Transform) -> None: 328 | """ 329 | In-place transform all attributes of this class. 330 | 331 | By "in-place", it means after calling this method, accessing an attribute such 332 | as ``self.image`` will return transformed data. 333 | """ 334 | self.image = tfm.apply_image(self.image) 335 | if self.boxes is not None: 336 | self.boxes = tfm.apply_box(self.boxes) 337 | if self.sem_seg is not None: 338 | self.sem_seg = tfm.apply_segmentation(self.sem_seg) 339 | 340 | def apply_augmentations( 341 | self, augmentations: List[Union[Augmentation, Transform]] 342 | ) -> TransformList: 343 | """ 344 | Equivalent of ``AugmentationList(augmentations)(self)`` 345 | """ 346 | return AugmentationList(augmentations)(self) 347 | 348 | 349 | def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs): 350 | """ 351 | Use ``T.AugmentationList(augmentations)(inputs)`` instead. 352 | """ 353 | if isinstance(inputs, np.ndarray): 354 | # handle the common case of image-only Augmentation, also for backward compatibility 355 | image_only = True 356 | inputs = AugInput(inputs) 357 | else: 358 | image_only = False 359 | tfms = inputs.apply_augmentations(augmentations) 360 | return inputs.image if image_only else inputs, tfms 361 | 362 | 363 | apply_transform_gens = apply_augmentations 364 | """ 365 | Alias for backward-compatibility. 366 | """ 367 | 368 | TransformGen = Augmentation 369 | """ 370 | Alias for Augmentation, since it is something that generates :class:`Transform`s 371 | """ 372 | 373 | StandardAugInput = AugInput 374 | """ 375 | Alias for compatibility. It's not worth the complexity to have two classes. 
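For reference, a minimal sketch of how the classes above compose. ``MyColorAug`` is a hypothetical
:class:`Augmentation` subclass, and ``CropTransform`` would need to be imported from
``fvcore.transforms``, so treat the snippet as illustrative rather than runnable as-is:
::
    aug_input = AugInput(image, boxes=boxes)
    augs = AugmentationList([
        MyColorAug(),                   # any Augmentation subclass
        CropTransform(0, 0, 100, 100),  # bare Transforms are wrapped by _transform_to_aug
    ])
    tfms = augs(aug_input)              # aug_input.image / aug_input.boxes are transformed in place
    new_sem_seg = tfms.apply_segmentation(other_sem_seg)  # reuse the same transforms on other data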
376 | """ -------------------------------------------------------------------------------- /datasets/hico.py: -------------------------------------------------------------------------------- 1 | """ 2 | HICO-DET dataset utils 3 | """ 4 | import os 5 | import json 6 | import collections 7 | import torch 8 | import torch.utils.data 9 | from torchvision.datasets import CocoDetection 10 | import datasets.transforms as T 11 | from PIL import Image 12 | from .hico_categories import HICO_INTERACTIONS, HICO_ACTIONS, HICO_OBJECTS, ZERO_SHOT_INTERACTION_IDS, NON_INTERACTION_IDS 13 | from utils.sampler import repeat_factors_from_category_frequency, get_dataset_indices 14 | 15 | 16 | # NOTE: Replace the path to your file 17 | HICO_TRAIN_ROOT = "./data/hico_20160224_det/images/train2015" 18 | HICO_TRAIN_ANNO = "./data/hico_20160224_det/annotations/trainval_hico_ann.json" 19 | HICO_VAL_ROOT = "./data/hico_20160224_det/images/test2015" 20 | HICO_VAL_ANNO = "./data/hico_20160224_det/annotations/test_hico_ann.json" 21 | 22 | 23 | class HICO(CocoDetection): 24 | def __init__( 25 | self, 26 | img_folder, 27 | ann_file, 28 | transforms, 29 | image_set, 30 | zero_shot_exp, 31 | repeat_factor_sampling, 32 | ignore_non_interaction 33 | ): 34 | """ 35 | Args: 36 | json_file (str): full path to the json file in HOI instances annotation format. 37 | image_root (str or path-like): the directory where the images in this json file exists. 38 | transforms (class): composition of image transforms. 39 | image_set (str): 'train', 'val', or 'test'. 40 | repeat_factor_sampling (bool): resampling training data to increase the rate of tail 41 | categories to be observed by oversampling the images that contain them. 42 | zero_shot_exp (bool): if true, see the last 120 rare HOI categories as zero-shot, 43 | excluding them from the training data. For the selected rare HOI categories, please 44 | refer to `: ZERO_SHOT_INTERACTION_IDS`. 45 | ignore_non_interaction (bool): Ignore non-interaction categories, since they tend to 46 | confuse the models with the meaning of true interactions. 47 | """ 48 | self.root = img_folder 49 | self.transforms = transforms 50 | # Text description of human-object interactions 51 | dataset_texts, text_mapper = prepare_dataset_text() 52 | self.dataset_texts = dataset_texts 53 | self.text_mapper = text_mapper # text to contiguous ids for evaluation 54 | object_to_related_hois, action_to_related_hois = prepare_related_hois() 55 | self.object_to_related_hois = object_to_related_hois 56 | self.action_to_related_hois = action_to_related_hois 57 | # Load dataset 58 | repeat_factor_sampling = repeat_factor_sampling and image_set == "train" 59 | zero_shot_exp = zero_shot_exp and image_set == "train" 60 | self.dataset_dicts = load_hico_json( 61 | json_file=ann_file, 62 | image_root=img_folder, 63 | zero_shot_exp=zero_shot_exp, 64 | repeat_factor_sampling=repeat_factor_sampling, 65 | ignore_non_interaction=ignore_non_interaction) 66 | 67 | def __getitem__(self, idx: int): 68 | 69 | filename = self.dataset_dicts[idx]["file_name"] 70 | image = Image.open(filename).convert("RGB") 71 | 72 | w, h = image.size 73 | assert w == self.dataset_dicts[idx]["width"], "image shape is not consistent." 74 | assert h == self.dataset_dicts[idx]["height"], "image shape is not consistent." 
75 | 76 | image_id = self.dataset_dicts[idx]["image_id"] 77 | annos = self.dataset_dicts[idx]["annotations"] 78 | 79 | boxes = torch.as_tensor(annos["boxes"], dtype=torch.float32).reshape(-1, 4) 80 | boxes[:, 0::2].clamp_(min=0, max=w) 81 | boxes[:, 1::2].clamp_(min=0, max=h) 82 | 83 | classes = torch.tensor(annos["classes"], dtype=torch.int64) 84 | 85 | target = { 86 | "image_id": torch.tensor(image_id), 87 | "orig_size": torch.tensor([h, w]), 88 | "boxes": boxes, 89 | "classes": classes, 90 | "hois": annos["hois"], 91 | } 92 | 93 | if self.transforms is not None: 94 | image, target = self.transforms(image, target) 95 | 96 | return image, target 97 | 98 | def __len__(self): 99 | return len(self.dataset_dicts) 100 | 101 | 102 | def load_hico_json( 103 | json_file, 104 | image_root, 105 | zero_shot_exp=True, 106 | repeat_factor_sampling=False, 107 | ignore_non_interaction=True, 108 | ): 109 | """ 110 | Load a json file with HOI's instances annotation. 111 | 112 | Args: 113 | json_file (str): full path to the json file in HOI instances annotation format. 114 | image_root (str or path-like): the directory where the images in this json file exists. 115 | repeat_factor_sampling (bool): resampling training data to increase the rate of tail 116 | categories to be observed by oversampling the images that contain them. 117 | zero_shot_exp (bool): if true, see the last 120 rare HOI categories as zero-shot, 118 | excluding them from the training data. For the selected rare HOI categories, please 119 | refer to `: ZERO_SHOT_INTERACTION_IDS`. 120 | ignore_non_interaction (bool): Ignore non-interaction categories, since they tend to 121 | confuse the models with the meaning of true interactions. 122 | Returns: 123 | list[dict]: a list of dicts in the following format. 
124 | { 125 | 'file_name': path-like str to load image, 126 | 'height': 480, 127 | 'width': 640, 128 | 'image_id': 222, 129 | 'annotations': { 130 | 'boxes': list[list[int]], # n x 4, bounding box annotations 131 | 'classes': list[int], # n, object category annotation of the bounding boxes 132 | 'hois': [ 133 | { 134 | 'subject_id': 0, # person box id (corresponding to the list of boxes above) 135 | 'object_id': 1, # object box id (corresponding to the list of boxes above) 136 | 'action_id', 76, # person action category 137 | 'hoi_id', 459, # interaction category 138 | 'text': ('ride', 'skateboard') # text description of human action and object 139 | } 140 | ] 141 | } 142 | } 143 | """ 144 | imgs_anns = json.load(open(json_file, "r")) 145 | 146 | id_to_contiguous_id_map = {x["id"]: i for i, x in enumerate(HICO_OBJECTS)} 147 | action_object_to_hoi_id = {(x["action"], x["object"]): x["interaction_id"] for x in HICO_INTERACTIONS} 148 | 149 | dataset_dicts = [] 150 | images_without_valid_annotations = [] 151 | for anno_dict in imgs_anns: 152 | record = {} 153 | record["file_name"] = os.path.join(image_root, anno_dict["file_name"]) 154 | record["height"] = anno_dict["height"] 155 | record["width"] = anno_dict["width"] 156 | record["image_id"] = anno_dict["img_id"] 157 | 158 | ignore_flag = False 159 | if len(anno_dict["annotations"]) == 0 or len(anno_dict["hoi_annotation"]) == 0: 160 | images_without_valid_annotations.append(anno_dict) 161 | continue 162 | 163 | boxes = [obj["bbox"] for obj in anno_dict["annotations"]] 164 | classes = [obj["category_id"] for obj in anno_dict["annotations"]] 165 | hoi_annotations = [] 166 | for hoi in anno_dict["hoi_annotation"]: 167 | action_id = hoi["category_id"] - 1 # Starting from 1 168 | target_id = hoi["object_id"] 169 | object_id = id_to_contiguous_id_map[classes[target_id]] 170 | text = (HICO_ACTIONS[action_id]["name"], HICO_OBJECTS[object_id]["name"]) 171 | hoi_id = action_object_to_hoi_id[text] 172 | 173 | # Ignore this annotation if we conduct zero-shot simulation experiments 174 | if zero_shot_exp and (hoi_id in ZERO_SHOT_INTERACTION_IDS): 175 | ignore_flag = True 176 | continue 177 | 178 | # Ignore non-interactions 179 | if ignore_non_interaction and action_id == 57: 180 | continue 181 | 182 | hoi_annotations.append({ 183 | "subject_id": hoi["subject_id"], 184 | "object_id": hoi["object_id"], 185 | "action_id": action_id, 186 | "hoi_id": hoi_id, 187 | "text": text 188 | }) 189 | 190 | if len(hoi_annotations) == 0 or ignore_flag: 191 | continue 192 | 193 | targets = { 194 | "boxes": boxes, 195 | "classes": classes, 196 | "hois": hoi_annotations, 197 | } 198 | 199 | record["annotations"] = targets 200 | dataset_dicts.append(record) 201 | 202 | if repeat_factor_sampling: 203 | repeat_factors = repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh=0.003) 204 | dataset_indices = get_dataset_indices(repeat_factors) 205 | dataset_dicts = [dataset_dicts[i] for i in dataset_indices] 206 | 207 | return dataset_dicts 208 | 209 | 210 | def prepare_dataset_text(): 211 | texts = [] 212 | text_mapper = {} 213 | for i, hoi in enumerate(HICO_INTERACTIONS): 214 | action_name = " ".join(hoi["action"].split("_")) 215 | object_name = hoi["object"] 216 | s = [action_name, object_name] 217 | text_mapper[len(texts)] = i 218 | texts.append(s) 219 | return texts, text_mapper 220 | 221 | 222 | def prepare_related_hois(): 223 | ''' Gather related hois based on object names and action names 224 | Returns: 225 | object_to_related_hois (dict): { 226 | object_text 
(e.g., chair): [ 227 | {'hoi_id': 86, 'text': ['carry', 'chair']}, 228 | {'hoi_id': 87, 'text': ['hold', 'chair']}, 229 | ... 230 | ] 231 | } 232 | 233 | action_to_relatedhois (dict): { 234 | action_text (e.g., carry): [ 235 | {'hoi_id': 10, 'text': ['carry', 'bicycle']}, 236 | {'hoi_id': 46, 'text': ['carry', 'bottle']}, 237 | ... 238 | ] 239 | } 240 | ''' 241 | object_to_related_hois = collections.defaultdict(list) 242 | action_to_related_hois = collections.defaultdict(list) 243 | 244 | for x in HICO_INTERACTIONS: 245 | action_text = x['action'] 246 | object_text = x['object'] 247 | hoi_id = x['interaction_id'] 248 | if hoi_id in ZERO_SHOT_INTERACTION_IDS or hoi_id in NON_INTERACTION_IDS: 249 | continue 250 | hoi_text = [action_text, object_text] 251 | 252 | object_to_related_hois[object_text].append({'hoi_id': hoi_id, 'text': hoi_text}) 253 | action_to_related_hois[action_text].append({'hoi_id': hoi_id, 'text': hoi_text}) 254 | 255 | return object_to_related_hois, action_to_related_hois 256 | 257 | 258 | def make_transforms(image_set, args): 259 | normalize = T.Compose([ 260 | T.ToTensor(), 261 | T.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]), 262 | ]) 263 | 264 | scales = [224, 256, 288, 320, 352, 384, 416, 448, 480, 512] 265 | 266 | if image_set == "train": 267 | return T.Compose([ 268 | T.RandomHorizontalFlip(), 269 | T.ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2], saturation=[0.8, 1.2]), 270 | T.RandomSelect( 271 | T.RandomResize(scales, max_size=scales[-1] * 1333 // 800), 272 | T.Compose([ 273 | T.RandomCrop_InteractionConstraint((0.75, 0.75), 0.8), 274 | T.RandomResize(scales, max_size=scales[-1] * 1333 // 800), 275 | ]) 276 | ), 277 | normalize, 278 | ]) 279 | 280 | if image_set == "val": 281 | return T.Compose([ 282 | T.RandomResize([args.eval_size], max_size=args.eval_size * 1333 // 800), 283 | normalize 284 | ]) 285 | 286 | raise ValueError(f'unknown {image_set}') 287 | 288 | """ deprecated (Fixed image resolution + random cropping + centering) 289 | normalize = T.Compose([ 290 | T.ToTensor(), 291 | T.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]), 292 | ]) 293 | 294 | if image_set == "train": 295 | return T.Compose([ 296 | T.RandomHorizontalFlip(), 297 | T.ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2], saturation=[0.8, 1.2]), 298 | T.RandomSelect( 299 | T.ResizeAndCenterCrop(224), 300 | T.Compose([ 301 | T.RandomCrop_InteractionConstraint((0.7, 0.7), 0.9), 302 | T.ResizeAndCenterCrop(224) 303 | ]), 304 | ), 305 | normalize 306 | ]) 307 | if image_set == "val": 308 | return T.Compose([ 309 | T.ResizeAndCenterCrop(224), 310 | normalize 311 | ]) 312 | 313 | raise ValueError(f'unknown {image_set}') 314 | """ 315 | 316 | 317 | def build(image_set, args): 318 | # NOTE: Replace the path to your file 319 | PATHS = { 320 | "train": (HICO_TRAIN_ROOT, HICO_TRAIN_ANNO), 321 | "val": (HICO_VAL_ROOT, HICO_VAL_ANNO), 322 | } 323 | 324 | img_folder, ann_file = PATHS[image_set] 325 | dataset = HICO( 326 | img_folder, 327 | ann_file, 328 | transforms=make_transforms(image_set, args), 329 | image_set=image_set, 330 | zero_shot_exp=args.zero_shot_exp, 331 | repeat_factor_sampling=args.repeat_factor_sampling, 332 | ignore_non_interaction=args.ignore_non_interaction 333 | ) 334 | 335 | return dataset -------------------------------------------------------------------------------- /datasets/hico_evaluator.py: -------------------------------------------------------------------------------- 1 | import 
collections 2 | import numpy as np 3 | import json 4 | import os 5 | import pickle 6 | from .hico_categories import HICO_INTERACTIONS, HICO_ACTIONS, HICO_OBJECTS 7 | from .hico_categories import ZERO_SHOT_INTERACTION_IDS, NON_INTERACTION_IDS 8 | 9 | 10 | class HICOEvaluator(object): 11 | ''' Evaluator for HICO-DET dataset ''' 12 | def __init__(self, anno_file, output_dir): 13 | size = 600 14 | self.size = size 15 | self.gts = self.load_anno(anno_file) 16 | self.scores = {i: [] for i in range(size)} 17 | self.boxes = {i: [] for i in range(size)} 18 | self.keys = {i: [] for i in range(size)} 19 | self.hico_ap = np.zeros(size) 20 | self.hico_rec = np.zeros(size) 21 | self.output_dir = output_dir 22 | 23 | def update(self, predictions): 24 | ''' Store predictions 25 | Args: 26 | predictions (dict): a dictionary in the following format. 27 | { 28 | img_id: [ 29 | [hoi_id, score, pbox_x1, pbox_y1, pbox_x2, pbox_y2, obox_x1, obox_y1, obox_x2, obox_y2], 30 | ... 31 | ... 32 | ] 33 | } 34 | ''' 35 | for img_id, preds in predictions.items(): 36 | for pred in preds: 37 | hoi_id = pred[0] 38 | score = pred[1] 39 | boxes = pred[2:] 40 | self.scores[hoi_id].append(score) 41 | self.boxes[hoi_id].append(boxes) 42 | self.keys[hoi_id].append(img_id) 43 | 44 | def accumulate(self): 45 | for hoi_id in range(600): 46 | gts_per_hoi = self.gts[hoi_id] 47 | ap, rec = calc_ap(self.scores[hoi_id], self.boxes[hoi_id], self.keys[hoi_id], gts_per_hoi) 48 | self.hico_ap[hoi_id], self.hico_rec[hoi_id] = ap, rec 49 | 50 | def summarize(self): 51 | valid_hois = np.setdiff1d(np.arange(600), NON_INTERACTION_IDS) 52 | seen_hois = np.setdiff1d(valid_hois, ZERO_SHOT_INTERACTION_IDS) 53 | zero_shot_hois = np.setdiff1d(ZERO_SHOT_INTERACTION_IDS, NON_INTERACTION_IDS) 54 | zero_shot_mAP = np.mean(self.hico_ap[zero_shot_hois]) 55 | seen_mAP = np.mean(self.hico_ap[seen_hois]) 56 | print("zero-shot mAP: {:.2f}".format(zero_shot_mAP * 100.)) 57 | print("seen mAP: {:.2f}".format(seen_mAP * 100.)) 58 | print("full mAP: {:.2f}".format(np.mean(self.hico_ap[valid_hois]) * 100.)) 59 | 60 | def save_preds(self): 61 | with open(os.path.join(self.output_dir, "preds.pkl"), "wb") as f: 62 | pickle.dump({"scores": self.scores, "boxes": self.boxes, "keys": self.keys}, f) 63 | 64 | def save(self, output_dir=None): 65 | if output_dir is None: 66 | output_dir = self.output_dir 67 | with open(os.path.join(output_dir, "dets.pkl"), "wb") as f: 68 | pickle.dump({"gts": self.gts, "scores": self.scores, "boxes": self.boxes, "keys": self.keys}, f) 69 | 70 | def load_anno(self, anno_file): 71 | with open(anno_file, "r") as f: 72 | dataset_dicts = json.load(f) 73 | 74 | action_id2name = {x["id"]: x["name"] for x in HICO_ACTIONS} 75 | object_id2name = {x["id"]: x["name"] for x in HICO_OBJECTS} 76 | hoi_mapper = {(x["action"], x["object"]): x["interaction_id"] for x in HICO_INTERACTIONS} 77 | 78 | size = self.size 79 | gts = {i: collections.defaultdict(list) for i in range(size)} 80 | for anno_dict in dataset_dicts: 81 | image_id = anno_dict["img_id"] 82 | box_annos = anno_dict.get("annotations", []) 83 | hoi_annos = anno_dict.get("hoi_annotation", []) 84 | for hoi in hoi_annos: 85 | person_box = box_annos[hoi["subject_id"]]["bbox"] 86 | object_box = box_annos[hoi["object_id"]]["bbox"] 87 | action_id = hoi["category_id"] - 1 # original annotations start from 1 88 | object_id = box_annos[hoi["object_id"]]["category_id"] # original annotations start from 1 89 | hoi_id = hoi_mapper[(action_id2name[action_id], object_id2name[object_id])] 90 | 
gts[hoi_id][image_id].append(person_box + object_box) 91 | 92 | for hoi_id in gts: 93 | for img_id in gts[hoi_id]: 94 | gts[hoi_id][img_id] = np.array(gts[hoi_id][img_id]) 95 | 96 | return gts 97 | 98 | 99 | def calc_ap(scores, boxes, keys, gt_boxes): 100 | 101 | if len(keys) == 0: 102 | return 0, 0 103 | 104 | if isinstance(boxes, list): 105 | scores, boxes, key = np.array(scores), np.array(boxes), np.array(keys) 106 | 107 | hit = [] 108 | idx = np.argsort(scores)[::-1] 109 | npos = 0 110 | used = {} 111 | 112 | for key in gt_boxes.keys(): 113 | npos += gt_boxes[key].shape[0] 114 | used[key] = set() 115 | 116 | for i in range(min(len(idx), 19999)): 117 | pair_id = idx[i] 118 | box = boxes[pair_id, :] 119 | key = keys[pair_id] 120 | if key in gt_boxes: 121 | maxi = 0.0 122 | k = -1 123 | for i in range(gt_boxes[key].shape[0]): 124 | tmp = calc_hit(box, gt_boxes[key][i, :]) 125 | if maxi < tmp: 126 | maxi = tmp 127 | k = i 128 | if k in used[key] or maxi < 0.5: 129 | hit.append(0) 130 | else: 131 | hit.append(1) 132 | used[key].add(k) 133 | else: 134 | hit.append(0) 135 | bottom = np.array(range(len(hit))) + 1 136 | hit = np.cumsum(hit) 137 | rec = hit / npos if npos > 0 else hit / (npos + 1e-8) 138 | prec = hit / bottom 139 | ap = 0.0 140 | for i in range(11): 141 | mask = rec >= (i / 10.0) 142 | if np.sum(mask) > 0: 143 | ap += np.max(prec[mask]) / 11.0 144 | 145 | return ap, np.max(rec) if len(rec) else 0 146 | 147 | 148 | def calc_hit(det, gtbox): 149 | gtbox = gtbox.astype(np.float64) 150 | hiou = iou(det[:4], gtbox[:4]) 151 | oiou = iou(det[4:], gtbox[4:]) 152 | return min(hiou, oiou) 153 | 154 | 155 | def iou(bb1, bb2, debug = False): 156 | x1 = bb1[2] - bb1[0] 157 | y1 = bb1[3] - bb1[1] 158 | if x1 < 0: 159 | x1 = 0 160 | if y1 < 0: 161 | y1 = 0 162 | 163 | x2 = bb2[2] - bb2[0] 164 | y2 = bb2[3] - bb2[1] 165 | if x2 < 0: 166 | x2 = 0 167 | if y2 < 0: 168 | y2 = 0 169 | 170 | xiou = min(bb1[2], bb2[2]) - max(bb1[0], bb2[0]) 171 | yiou = min(bb1[3], bb2[3]) - max(bb1[1], bb2[1]) 172 | if xiou < 0: 173 | xiou = 0 174 | if yiou < 0: 175 | yiou = 0 176 | 177 | if debug: 178 | print(x1, y1, x2, y2, xiou, yiou) 179 | print(x1 * y1, x2 * y2, xiou * yiou) 180 | if xiou * yiou <= 0: 181 | return 0 182 | else: 183 | return xiou * yiou / (x1 * y1 + x2 * y2 - xiou * yiou) 184 | 185 | 186 | ''' deprecated, evaluator 187 | def hico_evaluation(predictions, gts): 188 | images, results = [], [] 189 | for img_key, ps in predictions.items(): 190 | images.extend([img_key] * len(ps)) 191 | results.extend(ps) 192 | 193 | hico_ap, hico_rec = np.zeros(600), np.zeros(600) 194 | 195 | scores = [[] for _ in range(600)] 196 | boxes = [[] for _ in range(600)] 197 | keys = [[] for _ in range(600)] 198 | 199 | for img_id, det in zip(images, results): 200 | hoi_id, person_box, object_box, score = int(det[0]), det[1], det[2], det[-1] 201 | scores[hoi_id].append(score) 202 | boxes[hoi_id].append([float(x) for x in person_box] + [float(x) for x in object_box]) 203 | keys[hoi_id].append(img_id) 204 | 205 | for hoi_id in range(600): 206 | gts_per_hoi = gts[hoi_id] 207 | ap, rec = calc_ap(scores[hoi_id], boxes[hoi_id], keys[hoi_id], gts_per_hoi) 208 | hico_ap[hoi_id], hico_rec[hoi_id] = ap, rec 209 | 210 | return hico_ap, hico_rec 211 | 212 | 213 | def prepare_hico_gts(anno_file): 214 | """ 215 | Convert dataset to the format required by evaluator. 
216 | """ 217 | with open(anno_file, "r") as f: 218 | dataset_dicts = json.load(f) 219 | 220 | action_mapper = {x["name"]: x["id"]+1 for x in HICO_ACTIONS} 221 | object_mapper = {x["name"]: x["id"] for x in HICO_OBJECTS} 222 | hoi_mapper = {(action_mapper[x["action"]], object_mapper[x["object"]]): x["interaction_id"] 223 | for x in HICO_INTERACTIONS} 224 | 225 | gts = {i: collections.defaultdict(list) for i in range(600)} 226 | for anno_dict in dataset_dicts: 227 | image_id = int(anno_dict["file_name"].split("_")[-1].split(".")[0]) 228 | box_annos = anno_dict.get("annotations", []) 229 | hoi_annos = anno_dict.get("hoi_annotation", []) 230 | for hoi in hoi_annos: 231 | person_box = box_annos[hoi["subject_id"]]["bbox"] 232 | object_box = box_annos[hoi["object_id"]]["bbox"] 233 | action_id = hoi["category_id"] 234 | object_id = box_annos[hoi["object_id"]]["category_id"] 235 | hoi_id = hoi_mapper[(action_id, object_id)] 236 | gts[hoi_id][image_id].append(person_box + object_box) 237 | 238 | for hoi_id in gts: 239 | for img_id in gts[hoi_id]: 240 | gts[hoi_id][img_id] = np.array(gts[hoi_id][img_id]) 241 | 242 | return gts 243 | ''' -------------------------------------------------------------------------------- /datasets/swig.py: -------------------------------------------------------------------------------- 1 | """ 2 | SWiG-HOI dataset utils. 3 | """ 4 | import os 5 | import json 6 | import torch 7 | import torch.utils.data 8 | from torchvision.datasets import CocoDetection 9 | import datasets.transforms as T 10 | from PIL import Image 11 | from .swig_v1_categories import SWIG_INTERACTIONS, SWIG_ACTIONS, SWIG_CATEGORIES 12 | from utils.sampler import repeat_factors_from_category_frequency, get_dataset_indices 13 | 14 | # NOTE: Replace the path to your file 15 | SWIG_ROOT = "./data/swig_hoi/images_512" 16 | SWIG_TRAIN_ANNO = "./data/swig_hoi/annotations/swig_trainval_1000.json" 17 | SWIG_VAL_ANNO = "./data/swig_hoi/annotations/swig_test_1000.json" 18 | SWIG_TEST_ANNO = "./data/swig_hoi/annotations/swig_test_1000.json" 19 | 20 | 21 | class SWiGHOIDetection(CocoDetection): 22 | def __init__(self, img_folder, ann_file, transforms, image_set, repeat_factor_sampling): 23 | self.root = img_folder 24 | self.transforms = transforms 25 | # Text description of human-object interactions 26 | dataset_texts, text_mapper = prepare_dataset_text(image_set) 27 | self.dataset_texts = dataset_texts 28 | self.text_mapper = text_mapper 29 | # Load dataset 30 | repeat_factor_sampling = repeat_factor_sampling and image_set == "train" 31 | reverse_text_mapper = {v: k for k, v in text_mapper.items()} 32 | self.dataset_dicts = load_swig_json(ann_file, img_folder, reverse_text_mapper, repeat_factor_sampling) 33 | 34 | def __getitem__(self, idx: int): 35 | 36 | filename = self.dataset_dicts[idx]["file_name"] 37 | image = Image.open(filename).convert("RGB") 38 | 39 | w, h = image.size 40 | assert w == self.dataset_dicts[idx]["width"], "image shape is not consistent." 41 | assert h == self.dataset_dicts[idx]["height"], "image shape is not consistent." 
42 | 43 | image_id = self.dataset_dicts[idx]["image_id"] 44 | annos = self.dataset_dicts[idx]["annotations"] 45 | 46 | boxes = torch.as_tensor(annos["boxes"], dtype=torch.float32).reshape(-1, 4) 47 | boxes[:, 0::2].clamp_(min=0, max=w) 48 | boxes[:, 1::2].clamp_(min=0, max=h) 49 | 50 | classes = torch.tensor(annos["classes"], dtype=torch.int64) 51 | aux_classes = torch.tensor(annos["aux_classes"], dtype=torch.int64) 52 | 53 | target = { 54 | "image_id": torch.tensor(image_id), 55 | "orig_size": torch.tensor([h, w]), 56 | "boxes": boxes, 57 | "classes": classes, 58 | "aux_classes": aux_classes, 59 | "hois": annos["hois"], 60 | } 61 | 62 | if self.transforms is not None: 63 | image, target = self.transforms(image, target) 64 | 65 | return image, target 66 | 67 | def __len__(self): 68 | return len(self.dataset_dicts) 69 | 70 | 71 | def load_swig_json(json_file, image_root, text_mapper, repeat_factor_sampling=False): 72 | """ 73 | Load a json file with HOI's instances annotation. 74 | 75 | Args: 76 | json_file (str): full path to the json file in HOI instances annotation format. 77 | image_root (str or path-like): the directory where the images in this json file exists. 78 | text_mapper (dict): a dictionary to map text descriptions of HOIs to contiguous ids. 79 | repeat_factor_sampling (bool): resampling training data to increase the rate of tail 80 | categories to be observed by oversampling the images that contain them. 81 | Returns: 82 | list[dict]: a list of dicts in the following format. 83 | { 84 | 'file_name': path-like str to load image, 85 | 'height': 480, 86 | 'width': 640, 87 | 'image_id': 222, 88 | 'annotations': { 89 | 'boxes': list[list[int]], # n x 4, bounding box annotations 90 | 'classes': list[int], # n, object category annotation of the bounding boxes 91 | 'aux_classes': list[list], # n x 3, a list of auxiliary object annotations 92 | 'hois': [ 93 | { 94 | 'subject_id': 0, # person box id (corresponding to the list of boxes above) 95 | 'object_id': 1, # object box id (corresponding to the list of boxes above) 96 | 'action_id', 76, # person action category 97 | 'hoi_id', 459, # interaction category 98 | 'text': ('ride', 'skateboard') # text description of human action and object 99 | } 100 | ] 101 | } 102 | } 103 | """ 104 | HOI_MAPPER = {(x["action_id"], x["object_id"]): x["id"] for x in SWIG_INTERACTIONS} 105 | 106 | imgs_anns = json.load(open(json_file, "r")) 107 | 108 | dataset_dicts = [] 109 | images_without_valid_annotations = [] 110 | for anno_dict in imgs_anns: 111 | record = {} 112 | record["file_name"] = os.path.join(image_root, anno_dict["file_name"]) 113 | record["height"] = anno_dict["height"] 114 | record["width"] = anno_dict["width"] 115 | record["image_id"] = anno_dict["img_id"] 116 | 117 | if len(anno_dict["box_annotations"]) == 0 or len(anno_dict["hoi_annotations"]) == 0: 118 | images_without_valid_annotations.append(anno_dict) 119 | continue 120 | 121 | boxes = [obj["bbox"] for obj in anno_dict["box_annotations"]] 122 | classes = [obj["category_id"] for obj in anno_dict["box_annotations"]] 123 | aux_classes = [] 124 | for obj in anno_dict["box_annotations"]: 125 | aux_categories = obj["aux_category_id"] 126 | while len(aux_categories) < 3: 127 | aux_categories.append(-1) 128 | aux_classes.append(aux_categories) 129 | 130 | for hoi in anno_dict["hoi_annotations"]: 131 | target_id = hoi["object_id"] 132 | object_id = classes[target_id] 133 | action_id = hoi["action_id"] 134 | hoi["text"] = generate_text(action_id, object_id) 135 | continguous_id = 
HOI_MAPPER[(action_id, object_id)] 136 | hoi["hoi_id"] = text_mapper[continguous_id] 137 | 138 | targets = { 139 | "boxes": boxes, 140 | "classes": classes, 141 | "aux_classes": aux_classes, 142 | "hois": anno_dict["hoi_annotations"], 143 | } 144 | 145 | record["annotations"] = targets 146 | dataset_dicts.append(record) 147 | 148 | if repeat_factor_sampling: 149 | repeat_factors = repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh=0.0001) 150 | dataset_indices = get_dataset_indices(repeat_factors) 151 | dataset_dicts = [dataset_dicts[i] for i in dataset_indices] 152 | 153 | return dataset_dicts 154 | 155 | 156 | def generate_text(action_id, object_id): 157 | act = SWIG_ACTIONS[action_id]["name"] 158 | obj = SWIG_CATEGORIES[object_id]["name"] 159 | act_def = SWIG_ACTIONS[action_id]["def"] 160 | obj_def = SWIG_CATEGORIES[object_id]["def"] 161 | obj_gloss = SWIG_CATEGORIES[object_id]["gloss"] 162 | obj_gloss = [obj] + [x for x in obj_gloss if x != obj] 163 | if len(obj_gloss) > 1: 164 | obj_gloss = " or ".join(obj_gloss) 165 | else: 166 | obj_gloss = obj_gloss[0] 167 | 168 | # s = [act, obj_gloss] 169 | s = [act, obj] 170 | return s 171 | 172 | 173 | ''' deprecated, text 174 | # def generate_text(action_id, object_id): 175 | # act = SWIG_ACTIONS[action_id]["name"] 176 | # obj = SWIG_CATEGORIES[object_id]["name"] 177 | # act_def = SWIG_ACTIONS[action_id]["def"] 178 | # obj_def = SWIG_CATEGORIES[object_id]["def"] 179 | # obj_gloss = SWIG_CATEGORIES[object_id]["gloss"] 180 | # obj_gloss = [obj] + [x for x in obj_gloss if x != obj] 181 | # if len(obj_gloss) > 1: 182 | # obj_gloss = " or ".join(obj_gloss) 183 | # else: 184 | # obj_gloss = obj_gloss[0] 185 | # # s = f"A photo of a person {act} with object {obj}. The object {obj} means {obj_def}." 186 | # # s = f"a photo of a person {act} with object {obj}" 187 | # # s = f"A photo of a person {act} with {obj}. The {act} means to {act_def}." 188 | # s = f"A photo of a person {act} with {obj_gloss}. The {act} means to {act_def}." 
189 | # return s 190 | ''' 191 | 192 | 193 | def prepare_dataset_text(image_set): 194 | texts = [] 195 | text_mapper = {} 196 | for i, hoi in enumerate(SWIG_INTERACTIONS): 197 | if image_set != "train" and hoi["evaluation"] == 0: continue 198 | action_id = hoi["action_id"] 199 | object_id = hoi["object_id"] 200 | s = generate_text(action_id, object_id) 201 | text_mapper[len(texts)] = i 202 | texts.append(s) 203 | return texts, text_mapper 204 | 205 | 206 | def make_transforms(image_set, args): 207 | normalize = T.Compose([ 208 | T.ToTensor(), 209 | T.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]), 210 | ]) 211 | 212 | # scales = [224, 256, 288, 320, 352, 384, 416, 448, 480, 512] 213 | scales = [224, 256, 288, 320] 214 | 215 | if image_set == "train": 216 | return T.Compose([ 217 | T.RandomHorizontalFlip(), 218 | T.ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2], saturation=[0.8, 1.2]), 219 | T.RandomSelect( 220 | T.RandomResize(scales, max_size=scales[-1] * 1333 // 800), 221 | T.Compose([ 222 | T.RandomCrop_InteractionConstraint((0.7, 0.7), 0.9), 223 | T.RandomResize(scales, max_size=scales[-1] * 1333 // 800), 224 | ]) 225 | ), 226 | normalize, 227 | ]) 228 | 229 | if image_set == "val": 230 | return T.Compose([ 231 | T.RandomResize([args.eval_size], max_size=args.eval_size * 1333 // 800), 232 | normalize 233 | ]) 234 | 235 | raise ValueError(f'unknown {image_set}') 236 | 237 | 238 | ''' deprecated (Fixed image resolution + random cropping + centering) 239 | def make_transforms(image_set): 240 | 241 | normalize = T.Compose([ 242 | T.ToTensor(), 243 | T.Normalize([0.48145466, 0.4578275, 0.40821073], [0.26862954, 0.26130258, 0.27577711]), 244 | ]) 245 | 246 | if image_set == "train": 247 | return T.Compose([ 248 | T.RandomHorizontalFlip(), 249 | T.ColorJitter(brightness=[0.8, 1.2], contrast=[0.8, 1.2], saturation=[0.8, 1.2]), 250 | T.RandomSelect( 251 | T.ResizeAndCenterCrop(224), 252 | T.Compose([ 253 | T.RandomCrop_InteractionConstraint((0.8, 0.8), 0.9), 254 | T.ResizeAndCenterCrop(224) 255 | ]), 256 | ), 257 | normalize 258 | ]) 259 | if image_set == "val": 260 | return T.Compose([ 261 | T.ResizeAndCenterCrop(224), 262 | normalize 263 | ]) 264 | 265 | raise ValueError(f'unknown {image_set}') 266 | ''' 267 | 268 | 269 | def build(image_set, args): 270 | # NOTE: Replace the path to your file 271 | PATHS = { 272 | "train": (SWIG_ROOT, SWIG_TRAIN_ANNO), 273 | "val": (SWIG_ROOT, SWIG_VAL_ANNO), 274 | "dev": (SWIG_ROOT, SWIG_TEST_ANNO), 275 | } 276 | 277 | img_folder, ann_file = PATHS[image_set] 278 | dataset = SWiGHOIDetection( 279 | img_folder, 280 | ann_file, 281 | transforms=make_transforms(image_set, args), 282 | image_set=image_set, 283 | repeat_factor_sampling=args.repeat_factor_sampling, 284 | ) 285 | 286 | return dataset -------------------------------------------------------------------------------- /datasets/swig_evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import collections 4 | import json 5 | import numpy as np 6 | from .swig_v1_categories import SWIG_INTERACTIONS 7 | 8 | 9 | class SWiGEvaluator(object): 10 | ''' Evaluator for SWIG-HOI dataset ''' 11 | def __init__(self, anno_file, output_dir): 12 | eval_hois = [x["id"] for x in SWIG_INTERACTIONS if x["evaluation"] == 1] 13 | size = max(eval_hois) + 1 14 | self.eval_hois = eval_hois 15 | 16 | self.gts = self.load_anno(anno_file) 17 | self.scores = {i: [] for i in range(size)} 18 | self.boxes = {i: [] 
for i in range(size)} 19 | self.keys = {i: [] for i in range(size)} 20 | self.swig_ap = np.zeros(size) 21 | self.swig_rec = np.zeros(size) 22 | self.output_dir = output_dir 23 | 24 | def update(self, predictions): 25 | # update predictions 26 | for img_id, preds in predictions.items(): 27 | for pred in preds: 28 | hoi_id = pred[0] 29 | score = pred[1] 30 | boxes = pred[2:] 31 | self.scores[hoi_id].append(score) 32 | self.boxes[hoi_id].append(boxes) 33 | self.keys[hoi_id].append(img_id) 34 | 35 | def accumulate(self): 36 | for hoi_id in self.eval_hois: 37 | gts_per_hoi = self.gts[hoi_id] 38 | ap, rec = calc_ap(self.scores[hoi_id], self.boxes[hoi_id], self.keys[hoi_id], gts_per_hoi) 39 | self.swig_ap[hoi_id], self.swig_rec[hoi_id] = ap, rec 40 | 41 | def summarize(self): 42 | eval_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["evaluation"] == 1]) 43 | zero_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 0 and x["evaluation"] == 1]) 44 | rare_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 1 and x["evaluation"] == 1]) 45 | nonrare_hois = np.asarray([x["id"] for x in SWIG_INTERACTIONS if x["frequency"] == 2 and x["evaluation"] == 1]) 46 | 47 | full_mAP = np.mean(self.swig_ap[eval_hois]) 48 | zero_mAP = np.mean(self.swig_ap[zero_hois]) 49 | rare_mAP = np.mean(self.swig_ap[rare_hois]) 50 | nonrare_mAP = np.mean(self.swig_ap[nonrare_hois]) 51 | print("zero-shot mAP: {:.2f}".format(zero_mAP * 100.)) 52 | print("rare mAP: {:.2f}".format(rare_mAP * 100.)) 53 | print("nonrare mAP: {:.2f}".format(nonrare_mAP * 100.)) 54 | print("full mAP: {:.2f}".format(full_mAP * 100.)) 55 | 56 | def save_preds(self): 57 | with open(os.path.join(self.output_dir, "preds.pkl"), "wb") as f: 58 | pickle.dump({"scores": self.scores, "boxes": self.boxes, "keys": self.keys}, f) 59 | 60 | def save(self, output_dir=None): 61 | if output_dir is None: 62 | output_dir = self.output_dir 63 | with open(os.path.join(output_dir, "dets.pkl"), "wb") as f: 64 | pickle.dump({"gts": self.gts, "scores": self.scores, "boxes": self.boxes, "keys": self.keys}, f) 65 | 66 | def load_anno(self, anno_file): 67 | with open(anno_file, "r") as f: 68 | dataset_dicts = json.load(f) 69 | 70 | hoi_mapper = {(x["action_id"], x["object_id"]): x["id"] for x in SWIG_INTERACTIONS} 71 | 72 | size = max(self.eval_hois) + 1 73 | gts = {i: collections.defaultdict(list) for i in range(size)} 74 | for anno_dict in dataset_dicts: 75 | image_id = anno_dict["img_id"] 76 | box_annos = anno_dict.get("box_annotations", []) 77 | hoi_annos = anno_dict.get("hoi_annotations", []) 78 | for hoi in hoi_annos: 79 | person_box = box_annos[hoi["subject_id"]]["bbox"] 80 | object_box = box_annos[hoi["object_id"]]["bbox"] 81 | action_id = hoi["action_id"] 82 | object_id = box_annos[hoi["object_id"]]["category_id"] 83 | hoi_id = hoi_mapper[(action_id, object_id)] 84 | gts[hoi_id][image_id].append(person_box + object_box) 85 | 86 | for hoi_id in gts: 87 | for img_id in gts[hoi_id]: 88 | gts[hoi_id][img_id] = np.array(gts[hoi_id][img_id]) 89 | 90 | return gts 91 | 92 | 93 | def calc_ap(scores, boxes, keys, gt_boxes): 94 | if len(keys) == 0: 95 | return 0, 0 96 | 97 | if isinstance(boxes, list): 98 | scores, boxes, key = np.array(scores), np.array(boxes), np.array(keys) 99 | 100 | hit = [] 101 | idx = np.argsort(scores)[::-1] 102 | npos = 0 103 | used = {} 104 | 105 | for key in gt_boxes.keys(): 106 | npos += gt_boxes[key].shape[0] 107 | used[key] = set() 108 | 109 | for i in range(min(len(idx), 19999)): 110 | pair_id = 
idx[i] 111 | box = boxes[pair_id, :] 112 | key = keys[pair_id] 113 | if key in gt_boxes: 114 | maxi = 0.0 115 | k = -1 116 | for i in range(gt_boxes[key].shape[0]): 117 | tmp = calc_hit(box, gt_boxes[key][i, :]) 118 | if maxi < tmp: 119 | maxi = tmp 120 | k = i 121 | if k in used[key] or maxi < 0.5: 122 | hit.append(0) 123 | else: 124 | hit.append(1) 125 | used[key].add(k) 126 | else: 127 | hit.append(0) 128 | bottom = np.array(range(len(hit))) + 1 129 | hit = np.cumsum(hit) 130 | rec = hit / npos 131 | prec = hit / bottom 132 | ap = 0.0 133 | for i in range(11): 134 | mask = rec >= (i / 10.0) 135 | if np.sum(mask) > 0: 136 | ap += np.max(prec[mask]) / 11.0 137 | 138 | return ap, np.max(rec) 139 | 140 | 141 | def calc_hit(det, gtbox): 142 | gtbox = gtbox.astype(np.float64) 143 | hiou = iou(det[:4], gtbox[:4]) 144 | oiou = iou(det[4:], gtbox[4:]) 145 | return min(hiou, oiou) 146 | 147 | 148 | def iou(bb1, bb2, debug = False): 149 | x1 = bb1[2] - bb1[0] 150 | y1 = bb1[3] - bb1[1] 151 | if x1 < 0: 152 | x1 = 0 153 | if y1 < 0: 154 | y1 = 0 155 | 156 | x2 = bb2[2] - bb2[0] 157 | y2 = bb2[3] - bb2[1] 158 | if x2 < 0: 159 | x2 = 0 160 | if y2 < 0: 161 | y2 = 0 162 | 163 | xiou = min(bb1[2], bb2[2]) - max(bb1[0], bb2[0]) 164 | yiou = min(bb1[3], bb2[3]) - max(bb1[1], bb2[1]) 165 | if xiou < 0: 166 | xiou = 0 167 | if yiou < 0: 168 | yiou = 0 169 | 170 | if debug: 171 | print(x1, y1, x2, y2, xiou, yiou) 172 | print(x1 * y1, x2 * y2, xiou * yiou) 173 | if xiou * yiou <= 0: 174 | return 0 175 | else: 176 | return xiou * yiou / (x1 * y1 + x2 * y2 - xiou * yiou) 177 | 178 | 179 | ''' deprecated, evaluator 180 | eval_hois = [x["id"] for x in SWIG_INTERACTIONS if x["evaluation"] == 1] 181 | def swig_evaluation(predictions, gts): 182 | images, results = [], [] 183 | for img_key, ps in predictions.items(): 184 | images.extend([img_key] * len(ps)) 185 | results.extend(ps) 186 | 187 | size = max(eval_hois) + 1 188 | swig_ap, swig_rec = np.zeros(size), np.zeros(size) 189 | 190 | scores = [[] for _ in range(size)] 191 | boxes = [[] for _ in range(size)] 192 | keys = [[] for _ in range(size)] 193 | 194 | for img_id, det in zip(images, results): 195 | hoi_id, person_box, object_box, score = int(det[0]), det[1], det[2], det[-1] 196 | scores[hoi_id].append(score) 197 | boxes[hoi_id].append([float(x) for x in person_box] + [float(x) for x in object_box]) 198 | keys[hoi_id].append(img_id) 199 | 200 | for hoi_id in eval_hois: 201 | gts_per_hoi = gts[hoi_id] 202 | ap, rec = calc_ap(scores[hoi_id], boxes[hoi_id], keys[hoi_id], gts_per_hoi) 203 | swig_ap[hoi_id], swig_rec[hoi_id] = ap, rec 204 | 205 | return swig_ap, swig_rec 206 | 207 | 208 | def prepare_swig_gts(anno_file): 209 | """ 210 | Convert dataset to the format required by evaluator. 
211 | """ 212 | with open(anno_file, "r") as f: 213 | dataset_dicts = json.load(f) 214 | 215 | filename_to_id_mapper = {x["file_name"]: i for i, x in enumerate(dataset_dicts)} 216 | hoi_mapper = {(x["action_id"], x["object_id"]): x["id"] for x in SWIG_INTERACTIONS} 217 | 218 | size = max(eval_hois) + 1 219 | gts = {i: collections.defaultdict(list) for i in range(size)} 220 | for anno_dict in dataset_dicts: 221 | image_id = filename_to_id_mapper[anno_dict["file_name"]] 222 | box_annos = anno_dict.get("box_annotations", []) 223 | hoi_annos = anno_dict.get("hoi_annotations", []) 224 | for hoi in hoi_annos: 225 | person_box = box_annos[hoi["subject_id"]]["bbox"] 226 | object_box = box_annos[hoi["object_id"]]["bbox"] 227 | action_id = hoi["action_id"] 228 | object_id = box_annos[hoi["object_id"]]["category_id"] 229 | hoi_id = hoi_mapper[(action_id, object_id)] 230 | gts[hoi_id][image_id].append(person_box + object_box) 231 | 232 | for hoi_id in gts: 233 | for img_id in gts[hoi_id]: 234 | gts[hoi_id][img_id] = np.array(gts[hoi_id][img_id]) 235 | 236 | return gts, filename_to_id_mapper 237 | ''' -------------------------------------------------------------------------------- /datasets/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Transforms and data augmentation for both image + bbox. 4 | """ 5 | import random 6 | import numpy as np 7 | import PIL 8 | import torch 9 | import torchvision.transforms as T 10 | import torchvision.transforms.functional as F 11 | from utils.box_ops import box_xyxy_to_cxcywh 12 | from utils.misc import interpolate 13 | 14 | 15 | def crop(image, target, region): 16 | ori_w, ori_h = image.size 17 | cropped_image = F.crop(image, *region) 18 | 19 | target = target.copy() 20 | i, j, h, w = region 21 | 22 | # should we do something wrt the original size? 23 | target["size"] = torch.tensor([h, w]) 24 | 25 | """ deprecated, this part is mainly for ResizeAndCenterCrop (deprecated) 26 | # Image is padded with 0 if the crop region is out of boundary. 27 | # We use `image_mask` to indicate the padding regions. 
28 | image_mask = torch.zeros((h, w)).bool() 29 | image_mask[:abs(i), :] = True 30 | image_mask[:, :abs(j)] = True 31 | image_mask[abs(i) + ori_h :, :] = True 32 | image_mask[:, abs(j) + ori_w :] = True 33 | target["image_mask"] = image_mask 34 | """ 35 | 36 | fields = ["classes"] 37 | 38 | if "boxes" in target: 39 | boxes = target["boxes"] 40 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 41 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 42 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 43 | cropped_boxes = cropped_boxes.clamp(min=0) 44 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 45 | target["boxes"] = cropped_boxes.reshape(-1, 4) 46 | target["area"] = area 47 | fields.append("boxes") 48 | 49 | # remove elements for which the boxes or masks that have zero area 50 | if "boxes" in target: 51 | # favor boxes selection when defining which elements to keep 52 | # this is compatible with previous implementation 53 | if "boxes" in target: 54 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 55 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 56 | else: 57 | keep = target['masks'].flatten(1).any(1) 58 | 59 | for field in fields: 60 | target[field] = target[field][keep] 61 | 62 | id_mapper = {} 63 | cnt = 0 64 | for i, is_kept in enumerate(keep): 65 | if is_kept: 66 | id_mapper[i] = cnt 67 | cnt += 1 68 | 69 | if "hois" in target: 70 | kept_hois = [] 71 | for hoi in target["hois"]: 72 | if keep[hoi["subject_id"]] and keep[hoi["object_id"]]: 73 | hoi["subject_id"] = id_mapper[hoi["subject_id"]] 74 | hoi["object_id"] = id_mapper[hoi["object_id"]] 75 | kept_hois.append(hoi) 76 | target["hois"] = kept_hois 77 | 78 | return cropped_image, target 79 | 80 | 81 | def hflip(image, target): 82 | flipped_image = F.hflip(image) 83 | 84 | w, h = image.size 85 | 86 | target = target.copy() 87 | if "boxes" in target: 88 | boxes = target["boxes"] 89 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 90 | target["boxes"] = boxes 91 | 92 | if "masks" in target: 93 | target['masks'] = target['masks'].flip(-1) 94 | 95 | return flipped_image, target 96 | 97 | 98 | def resize(image, target, size, max_size=None): 99 | # size can be min_size (scalar) or (w, h) tuple 100 | # import pdb;pdb.set_trace() 101 | maxs = size 102 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 103 | w, h = image_size 104 | if max_size is not None: 105 | min_original_size = float(min((w, h))) 106 | max_original_size = float(max((w, h))) 107 | if max_original_size / min_original_size * size > max_size: 108 | size = int(round(max_size * min_original_size / max_original_size)) 109 | 110 | if (w <= h and w == size) or (h <= w and h == size): 111 | w_mod = np.mod(w, 16) 112 | h_mod = np.mod(h, 16) 113 | h = h - h_mod 114 | w = w - w_mod 115 | return (h, w) 116 | 117 | if w < h: 118 | ow = size 119 | oh = int(size * h / w) 120 | ow_mod = np.mod(ow, 16) 121 | oh_mod = np.mod(oh, 16) 122 | ow = ow - ow_mod 123 | oh = oh - oh_mod 124 | else: 125 | oh = size 126 | ow = int(size * w / h) 127 | ow_mod = np.mod(ow, 16) 128 | oh_mod = np.mod(oh, 16) 129 | ow = ow - ow_mod 130 | oh = oh - oh_mod 131 | 132 | return (oh, ow) 133 | 134 | def get_size(image_size, size, max_size=None): 135 | if isinstance(size, (list, tuple)): 136 | return size[::-1] 137 | else: 138 | return get_size_with_aspect_ratio(image_size, size, max_size) 139 | 140 | size = get_size(image.size, size, max_size) 141 | 
rescaled_image = F.resize(image, size) 142 | 143 | if target is None: 144 | return rescaled_image, None 145 | 146 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 147 | ratio_width, ratio_height = ratios 148 | 149 | target = target.copy() 150 | if "boxes" in target: 151 | boxes = target["boxes"] 152 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 153 | target["boxes"] = scaled_boxes 154 | 155 | if "area" in target: 156 | area = target["area"] 157 | scaled_area = area * (ratio_width * ratio_height) 158 | target["area"] = scaled_area 159 | 160 | h, w = size 161 | target["size"] = torch.tensor([h, w]) 162 | 163 | if "masks" in target: 164 | target['masks'] = interpolate( 165 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 166 | 167 | return rescaled_image, target 168 | 169 | 170 | 171 | def pad(image, target, padding): 172 | # assumes that we only pad on the bottom right corners 173 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 174 | if target is None: 175 | return padded_image, None 176 | target = target.copy() 177 | # should we do something wrt the original size? 178 | target["size"] = torch.tensor(padded_image.size[::-1]) 179 | if "masks" in target: 180 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 181 | return padded_image, target 182 | 183 | 184 | def resize_long_edge(image, target, size, max_size=None): 185 | """Resize the image based on the long edge.""" 186 | # size can be min_size (scalar) or (w, h) tuple 187 | 188 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 189 | w, h = image_size 190 | 191 | if (w >= h and w == size) or (h >= w and h == size): 192 | return (h, w) 193 | 194 | if w > h: 195 | ow = size 196 | oh = int(size * h / w) 197 | else: 198 | oh = size 199 | ow = int(size * w / h) 200 | 201 | return (oh, ow) 202 | 203 | def get_size(image_size, size, max_size=None): 204 | if isinstance(size, (list, tuple)): 205 | return size[::-1] 206 | else: 207 | return get_size_with_aspect_ratio(image_size, size, max_size) 208 | 209 | size = get_size(image.size, size, max_size) 210 | rescaled_image = F.resize(image, size) 211 | 212 | if target is None: 213 | return rescaled_image, None 214 | 215 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 216 | ratio_width, ratio_height = ratios 217 | 218 | target = target.copy() 219 | if "boxes" in target: 220 | boxes = target["boxes"] 221 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 222 | target["boxes"] = scaled_boxes 223 | 224 | if "area" in target: 225 | area = target["area"] 226 | scaled_area = area * (ratio_width * ratio_height) 227 | target["area"] = scaled_area 228 | 229 | h, w = size 230 | target["size"] = torch.tensor([h, w]) 231 | 232 | if "masks" in target: 233 | target['masks'] = interpolate( 234 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 235 | 236 | return rescaled_image, target 237 | 238 | 239 | class ColorJitter(object): 240 | def __init__(self, brightness, contrast, saturation): 241 | self.color_jitter = T.ColorJitter(brightness, contrast, saturation) 242 | 243 | def __call__(self, img, target): 244 | return self.color_jitter(img), target 245 | 246 | 247 | class RandomCrop(object): 248 | def __init__(self, size): 249 | self.size = size 250 | 251 | def __call__(self, img, target): 252 | region = 
T.RandomCrop.get_params(img, self.size) 253 | return crop(img, target, region) 254 | 255 | 256 | class RandomSizeCrop(object): 257 | def __init__(self, min_size: int, max_size: int): 258 | self.min_size = min_size 259 | self.max_size = max_size 260 | 261 | def __call__(self, img: PIL.Image.Image, target: dict): 262 | w = random.randint(self.min_size, min(img.width, self.max_size)) 263 | h = random.randint(self.min_size, min(img.height, self.max_size)) 264 | region = T.RandomCrop.get_params(img, [h, w]) 265 | return crop(img, target, region) 266 | 267 | 268 | class CenterCrop(object): 269 | def __init__(self, size): 270 | self.size = size 271 | 272 | def __call__(self, img, target): 273 | image_width, image_height = img.size 274 | crop_height, crop_width = self.size 275 | crop_top = int(round((image_height - crop_height) / 2.)) 276 | crop_left = int(round((image_width - crop_width) / 2.)) 277 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 278 | 279 | class ResizeAndCenterCrop(object): 280 | def __init__(self, size): 281 | self.size = size 282 | 283 | def __call__(self, img, target): 284 | img, target = resize_long_edge(img, target, self.size) 285 | image_width, image_height = img.size 286 | crop_height, crop_width = self.size, self.size 287 | crop_top = int(round((image_height - crop_height) / 2.)) 288 | crop_left = int(round((image_width - crop_width) / 2.)) 289 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 290 | 291 | class RandomHorizontalFlip(object): 292 | def __init__(self, p=0.5): 293 | self.p = p 294 | 295 | def __call__(self, img, target): 296 | if random.random() < self.p: 297 | return hflip(img, target) 298 | return img, target 299 | 300 | 301 | class RandomResize(object): 302 | def __init__(self, sizes, max_size=None): 303 | assert isinstance(sizes, (list, tuple)) 304 | self.sizes = sizes 305 | self.max_size = max_size 306 | 307 | def __call__(self, img, target=None): 308 | size = random.choice(self.sizes) 309 | return resize(img, target, size, self.max_size) 310 | 311 | 312 | class RandomPad(object): 313 | def __init__(self, max_pad): 314 | self.max_pad = max_pad 315 | 316 | def __call__(self, img, target): 317 | pad_x = random.randint(0, self.max_pad) 318 | pad_y = random.randint(0, self.max_pad) 319 | return pad(img, target, (pad_x, pad_y)) 320 | 321 | 322 | class RandomSelect(object): 323 | """ 324 | Randomly selects between transforms1 and transforms2, 325 | with probability p for transforms1 and (1 - p) for transforms2 326 | """ 327 | def __init__(self, transforms1, transforms2, p=0.5): 328 | self.transforms1 = transforms1 329 | self.transforms2 = transforms2 330 | self.p = p 331 | 332 | def __call__(self, img, target): 333 | if random.random() < self.p: 334 | return self.transforms1(img, target) 335 | return self.transforms2(img, target) 336 | 337 | 338 | class ToTensor(object): 339 | def __call__(self, img, target): 340 | return F.to_tensor(img), target 341 | 342 | 343 | class RandomErasing(object): 344 | 345 | def __init__(self, *args, **kwargs): 346 | self.eraser = T.RandomErasing(*args, **kwargs) 347 | 348 | def __call__(self, img, target): 349 | return self.eraser(img), target 350 | 351 | 352 | class Normalize(object): 353 | def __init__(self, mean, std): 354 | self.mean = mean 355 | self.std = std 356 | 357 | def __call__(self, image, target=None): 358 | image = F.normalize(image, mean=self.mean, std=self.std) 359 | if target is None: 360 | return image, None 361 | target = target.copy() 362 | h, w = 
image.shape[-2:] 363 | if "boxes" in target: 364 | boxes = target["boxes"] 365 | boxes = box_xyxy_to_cxcywh(boxes) 366 | boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) 367 | target["boxes"] = boxes 368 | return image, target 369 | 370 | 371 | class Compose(object): 372 | def __init__(self, transforms): 373 | self.transforms = transforms 374 | 375 | def __call__(self, image, target): 376 | for t in self.transforms: 377 | image, target = t(image, target) 378 | return image, target 379 | 380 | def __repr__(self): 381 | format_string = self.__class__.__name__ + "(" 382 | for t in self.transforms: 383 | format_string += "\n" 384 | format_string += " {0}".format(t) 385 | format_string += "\n)" 386 | return format_string 387 | 388 | 389 | class RandomCrop_InteractionConstraint(object): 390 | """ 391 | Similar to :class:`RandomCrop`, but find a cropping window such that at most interactions 392 | in the image can be kept. 393 | """ 394 | def __init__(self, crop_ratio, p: float): 395 | """ 396 | Args: 397 | crop_type, crop_size: same as in :class:`RandomCrop` 398 | """ 399 | self.crop_ratio = crop_ratio 400 | self.p = p 401 | 402 | def __call__(self, image, target): 403 | boxes = target["boxes"] 404 | w, h = image.size[:2] 405 | croph, cropw = int(h * self.crop_ratio[0]), int(w * self.crop_ratio[1]) 406 | assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self) 407 | h0_choice = np.arange(0, h - croph + 1) 408 | w0_choice = np.arange(0, w - cropw + 1) 409 | h_prob, w_prob = np.ones(len(h0_choice)), np.ones(len(w0_choice)) 410 | for box in boxes: 411 | h_min = min(int(box[1] - croph) + 1, len(h_prob)) 412 | h_max = int(box[3]) 413 | w_min = min(int(box[0] - cropw) + 1, len(w_prob)) 414 | w_max = int(box[2]) 415 | if h_min > 0: 416 | h_prob[:h_min] = h_prob[:h_min] * self.p 417 | if h_max < h0_choice[-1]: 418 | h_prob[h_max:] = h_prob[h_max:] * self.p 419 | if w_min > 0: 420 | w_prob[:w_min] = w_prob[:w_min] * self.p 421 | if w_max < w0_choice[-1]: 422 | w_prob[w_max:] = w_prob[w_max:] * self.p 423 | h_prob, w_prob = h_prob / h_prob.sum(), w_prob / w_prob.sum() 424 | h0 = int(np.random.choice(h0_choice, 1, p=h_prob)[0]) 425 | w0 = int(np.random.choice(w0_choice, 1, p=w_prob)[0]) 426 | return crop(image, target, (h0, w0, croph, cropw)) 427 | 428 | 429 | """ deprecated, resize implementation 430 | def resize(image, target, size, max_size=None): 431 | # size can be min_size (scalar) or (w, h) tuple 432 | 433 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 434 | w, h = image_size 435 | if max_size is not None: 436 | min_original_size = float(min((w, h))) 437 | max_original_size = float(max((w, h))) 438 | if max_original_size / min_original_size * size > max_size: 439 | size = int(round(max_size * min_original_size / max_original_size)) 440 | 441 | if (w <= h and w == size) or (h <= w and h == size): 442 | return (h, w) 443 | 444 | if w < h: 445 | ow = size 446 | oh = int(size * h / w) 447 | else: 448 | oh = size 449 | ow = int(size * w / h) 450 | 451 | return (oh, ow) 452 | 453 | def get_size(image_size, size, max_size=None): 454 | if isinstance(size, (list, tuple)): 455 | return size[::-1] 456 | else: 457 | return get_size_with_aspect_ratio(image_size, size, max_size) 458 | 459 | size = get_size(image.size, size, max_size) 460 | rescaled_image = F.resize(image, size) 461 | 462 | if target is None: 463 | return rescaled_image, None 464 | 465 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 466 | 
ratio_width, ratio_height = ratios 467 | 468 | target = target.copy() 469 | if "boxes" in target: 470 | boxes = target["boxes"] 471 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 472 | target["boxes"] = scaled_boxes 473 | 474 | if "area" in target: 475 | area = target["area"] 476 | scaled_area = area * (ratio_width * ratio_height) 477 | target["area"] = scaled_area 478 | 479 | h, w = size 480 | target["size"] = torch.tensor([h, w]) 481 | 482 | if "masks" in target: 483 | target['masks'] = interpolate( 484 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 485 | 486 | return rescaled_image, target 487 | """ -------------------------------------------------------------------------------- /engine.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # Modified by Suchen for HOI detection 3 | """ 4 | Train and eval functions used in main.py 5 | """ 6 | import math 7 | import sys 8 | from typing import Iterable 9 | import torch 10 | import utils.misc as utils 11 | from models.model import convert_weights 12 | from datasets import build_evaluator 13 | from utils.visualizer import Visualizer 14 | from fvcore.nn import FlopCountAnalysis, flop_count_table 15 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 16 | _tokenizer = _Tokenizer() 17 | 18 | 19 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 20 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 21 | device: torch.device, epoch: int, max_norm: float = 0): 22 | model.train() 23 | criterion.train() 24 | metric_logger = utils.MetricLogger(delimiter=" ") 25 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 26 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 27 | header = 'Epoch: [{}]'.format(epoch) 28 | print_freq = 10 29 | 30 | for images, targets in metric_logger.log_every(data_loader, print_freq, header): 31 | 32 | images, targets, texts = prepare_inputs(images, targets, data_loader, device) 33 | outputs = model(images.tensors, texts, images.mask) 34 | loss_dict, indices = criterion(outputs, targets) 35 | weight_dict = criterion.weight_dict 36 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 37 | 38 | # reduce losses over all GPUs for logging purposes 39 | loss_dict_reduced = utils.reduce_dict(loss_dict) 40 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} 41 | loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} 42 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 43 | 44 | loss_value = losses_reduced_scaled.item() 45 | 46 | if not math.isfinite(loss_value): 47 | print("Loss is {}, stopping training".format(loss_value)) 48 | print(loss_dict_reduced) 49 | sys.exit(1) 50 | 51 | optimizer.zero_grad() 52 | losses.backward() 53 | if max_norm > 0: 54 | torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) 55 | optimizer.step() 56 | 57 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 58 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 59 | metric_logger.update(lr=optimizer.param_groups[0]["lr"]) 60 | # gather the stats from all processes 61 | metric_logger.synchronize_between_processes() 62 | print("Averaged stats:", 
metric_logger) 63 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 64 | 65 | 66 | @torch.no_grad() 67 | def evaluate(model, postprocessors, criterion, data_loader, device, args): 68 | model.eval() 69 | criterion.eval() 70 | 71 | metric_logger = utils.MetricLogger(delimiter=" ") 72 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 73 | header = 'Test:' 74 | 75 | # Convert applicable model parameters to fp16 76 | convert_weights(model) 77 | 78 | # Build evaluator 79 | evaluator = build_evaluator(args) 80 | 81 | # Convert all interaction categories into embeddings 82 | text_features = prepare_text_inputs(model, data_loader.dataset.dataset_texts, device) 83 | 84 | # Inference 85 | for images, targets in metric_logger.log_every(data_loader, 10, header): 86 | images = images.to(device) 87 | targets = [{k: v.to(device) if k != "hois" else v for k, v in t.items()} for t in targets] 88 | 89 | vision_outputs = model.encode_image(images.tensors, images.mask) 90 | 91 | hoi_features = vision_outputs['hoi_features'] 92 | hoi_features = hoi_features / hoi_features.norm(dim=-1, keepdim=True) 93 | logits_per_hoi = model.logit_scale.exp() * hoi_features @ text_features.t() 94 | pred_boxes = vision_outputs["pred_boxes"] 95 | box_scores = vision_outputs["box_scores"] 96 | 97 | outputs = {"logits_per_hoi": logits_per_hoi, 98 | "pred_boxes": pred_boxes, 99 | "box_scores": box_scores, 100 | "aux_outputs": vision_outputs["aux_outputs"], 101 | "attn_maps": vision_outputs['attn_maps']} 102 | 103 | loss_dict, indices = criterion(outputs, targets) 104 | weight_dict = criterion.weight_dict 105 | 106 | if args.vis_outputs: 107 | visualizer = Visualizer(args) 108 | visualizer.visualize_preds(images, targets, outputs) 109 | # visualizer.visualize_attention(images, targets, outputs) 110 | 111 | # reduce losses over all GPUs for logging purposes 112 | loss_dict_reduced = utils.reduce_dict(loss_dict) 113 | loss_dict_reduced_scaled = {k: v * weight_dict[k] for k, v in loss_dict_reduced.items() if k in weight_dict} 114 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v for k, v in loss_dict_reduced.items()} 115 | metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()), **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled) 116 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 117 | 118 | results = {int(targets[i]['image_id']): postprocessors( 119 | {'pred_logits': logits_per_hoi[i], 'pred_boxes': pred_boxes[i], 'box_scores': box_scores[i]}, 120 | targets[i]['orig_size'], 121 | data_loader.dataset.text_mapper 122 | ) for i in range(len(images.tensors))} 123 | 124 | evaluator.update(results) 125 | 126 | # gather the stats from all processes 127 | metric_logger.synchronize_between_processes() 128 | print("Averaged stats:", metric_logger) 129 | 130 | evaluator.save_preds() 131 | # accumulate predictions from all images 132 | evaluator.accumulate() 133 | evaluator.summarize() 134 | stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 135 | return stats, evaluator 136 | 137 | 138 | def prepare_inputs(images, targets, data_loader, device): 139 | """Prepare model inputs.""" 140 | # image inputs 141 | images = images.to(device) 142 | targets = [{k: v.to(device) if k != "hois" else v for k, v in t.items()} for t in targets] 143 | 144 | # text inputs 145 | sot_token = _tokenizer.encoder["<|startoftext|>"] 146 | eot_token = _tokenizer.encoder["<|endoftext|>"] 147 | 148 | texts = [] 149 | text_inputs = [] 150 | 
unique_hois = set() 151 | 152 | for t in targets: 153 | for hoi in t["hois"]: 154 | # Ensure all texts are unique (no duplicates). 155 | hoi_id = hoi["hoi_id"] 156 | if hoi_id in unique_hois: 157 | continue 158 | else: 159 | unique_hois.add(hoi_id) 160 | action_text, object_text = hoi["text"] 161 | action_token = _tokenizer.encode(action_text.replace("_", " ")) 162 | object_token = _tokenizer.encode(object_text.replace("_", " ")) 163 | 164 | action_token = torch.as_tensor([sot_token] + action_token, dtype=torch.long).to(device) 165 | object_token = torch.as_tensor(object_token + [eot_token], dtype=torch.long).to(device) 166 | texts.append([action_token, object_token]) 167 | text_inputs.append(action_text + " " + object_text) 168 | 169 | # [specific for HICO-DET], load related hois based on the targets in mini-batch 170 | if hasattr(data_loader.dataset, 'object_to_related_hois') and hasattr(data_loader.dataset, 'action_to_related_hois'): 171 | object_to_related_hois = data_loader.dataset.object_to_related_hois 172 | action_to_related_hois = data_loader.dataset.action_to_related_hois 173 | 174 | related_texts = [] 175 | related_text_inputs = [] 176 | unique_actions = set() 177 | unique_objects = set() 178 | unique_related_hois = set() 179 | for t in targets: 180 | for hoi in t["hois"]: 181 | hoi_id = hoi["hoi_id"] 182 | query_action_text, query_object_text = hoi["text"] 183 | if query_action_text in unique_actions or query_object_text in unique_objects: 184 | continue 185 | else: 186 | unique_actions.add(query_action_text) 187 | unique_objects.add(query_object_text) 188 | 189 | related_hois = action_to_related_hois[query_action_text] 190 | for hoi in related_hois: 191 | hoi_id = hoi["hoi_id"] 192 | if hoi_id in unique_hois: 193 | continue 194 | if hoi_id in unique_related_hois: 195 | continue 196 | else: 197 | unique_related_hois.add(hoi_id) 198 | 199 | action_text, object_text = hoi["text"] 200 | action_token = _tokenizer.encode(action_text.replace("_", " ")) 201 | object_token = _tokenizer.encode(object_text.replace("_", " ")) 202 | action_token = torch.as_tensor([sot_token] + action_token, dtype=torch.long).to(device) 203 | object_token = torch.as_tensor(object_token + [eot_token], dtype=torch.long).to(device) 204 | related_texts.append([action_token, object_token]) 205 | related_text_inputs.append(action_text + " " + object_text) 206 | 207 | related_hois = object_to_related_hois[query_object_text] 208 | for hoi in related_hois: 209 | hoi_id = hoi["hoi_id"] 210 | if hoi_id in unique_hois: 211 | continue 212 | if hoi_id in unique_related_hois: 213 | continue 214 | else: 215 | unique_related_hois.add(hoi_id) 216 | 217 | action_text, object_text = hoi["text"] 218 | action_token = _tokenizer.encode(action_text.replace("_", " ")) 219 | object_token = _tokenizer.encode(object_text.replace("_", " ")) 220 | action_token = torch.as_tensor([sot_token] + action_token, dtype=torch.long).to(device) 221 | object_token = torch.as_tensor(object_token + [eot_token], dtype=torch.long).to(device) 222 | related_texts.append([action_token, object_token]) 223 | related_text_inputs.append(action_text + " " + object_text) 224 | texts.extend(related_texts) 225 | 226 | return images, targets, texts 227 | 228 | 229 | @torch.no_grad() 230 | def prepare_text_inputs(model, texts, device): 231 | sot_token = _tokenizer.encoder["<|startoftext|>"] 232 | eot_token = _tokenizer.encoder["<|endoftext|>"] 233 | 234 | text_tokens = [] 235 | for action_text, object_text in texts: 236 | action_token = 
_tokenizer.encode(action_text.replace("_", " "))
237 |         object_token = _tokenizer.encode(object_text.replace("_", " "))
238 | 
239 |         action_token = torch.as_tensor([sot_token] + action_token, dtype=torch.long).to(device)
240 |         object_token = torch.as_tensor(object_token + [eot_token], dtype=torch.long).to(device)
241 |         text_tokens.append([action_token, object_token])
242 | 
243 |     text_features = model.encode_text(text_tokens)
244 |     text_features /= text_features.norm(dim=-1, keepdim=True)
245 |     return text_features
246 | 
247 | 
248 | def get_flop_stats(model, data_loader):
249 |     """
250 |     Compute the gflops for the current model.
251 |     Args:
252 |         model (model): model to compute the flop counts.
253 |         data_loader: data loader from which a single mini-batch is drawn to
254 |             build the (images, texts, mask) inputs for the analysis.
255 |     Returns:
256 |         FlopCountAnalysis: the flop counts of the given model. The total
257 |             (in GFLOPs) and a per-module breakdown are printed as a side
258 |             effect.
259 |     """
260 |     inputs = _get_model_analysis_input(data_loader)
261 |     flops = FlopCountAnalysis(model, inputs)
262 |     print("Total FLOPs(G)", flops.total() / 1e9)
263 |     print(flop_count_table(flops, max_depth=4, show_param_shapes=False))
264 |     return flops
265 | 
266 | 
267 | def _get_model_analysis_input(data_loader):
268 |     for images, targets in data_loader:
269 |         images, targets, texts = prepare_inputs(images, targets, data_loader, "cuda")
270 |         inputs = (images.tensors, texts, images.mask)
271 |         return inputs
272 | 
273 | 
274 | ''' deprecated, text
275 | def prepare_inputs(images, targets, device):
276 |     """Prepare model inputs."""
277 |     images = images.to(device)
278 |     targets = [{k: v.to(device) if k != "hois" else v for k, v in t.items()} for t in targets]
279 | 
280 |     sot_token = _tokenizer.encoder["<|startoftext|>"]
281 |     eot_token = _tokenizer.encoder["<|endoftext|>"]
282 | 
283 |     texts = []
284 |     text_inputs = []
285 |     unique_hois = set()
286 | 
287 |     for t in targets:
288 |         for hoi in t["hois"]:
289 |             # Ensure all texts are unique (no duplicates).
290 |             hoi_id = hoi["hoi_id"]
291 |             if hoi_id in unique_hois:
292 |                 continue
293 |             else:
294 |                 unique_hois.add(hoi_id)
295 |             action_text, object_text = hoi["text"]
296 |             action_token = _tokenizer.encode(action_text.replace("_", " "))
297 |             object_token = _tokenizer.encode(object_text.replace("_", " "))
298 | 
299 |             action_token = torch.as_tensor([sot_token] + action_token, dtype=torch.long).to(device)
300 |             object_token = torch.as_tensor(object_token + [eot_token], dtype=torch.long).to(device)
301 |             texts.append([action_token, object_token])
302 |             text_inputs.append(action_text + " " + object_text)
303 | 
304 |     return images, targets, texts
305 | '''
--------------------------------------------------------------------------------
/figures/THID_arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scwangdyd/promting_hoi/29938ccbcb7c8206873a984628a132064c769270/figures/THID_arch.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | # Modified by Suchen for HOI detection 3 | 4 | import argparse 5 | import datetime 6 | import json 7 | import random 8 | import time 9 | from pathlib import Path 10 | 11 | import numpy as np 12 | import torch 13 | from torch.utils.data import DataLoader, DistributedSampler 14 | 15 | import utils.misc as utils 16 | from datasets import build_dataset 17 | from engine import train_one_epoch, evaluate, get_flop_stats 18 | from models import build_model 19 | from arguments import get_args_parser 20 | from utils.scheduler import create_scheduler 21 | 22 | 23 | def main(args): 24 | """Training and evaluation function""" 25 | 26 | # distributed data parallel setup 27 | utils.init_distributed_mode(args) 28 | print("git:\n {}\n".format(utils.get_sha())) 29 | print(args) 30 | 31 | device = torch.device(args.device) 32 | 33 | # fix the seed for reproducibility 34 | seed = args.seed + utils.get_rank() 35 | torch.manual_seed(seed) 36 | np.random.seed(seed) 37 | random.seed(seed) 38 | 39 | model, criterion, postprocessors = build_model(args) 40 | model.to(device) 41 | 42 | model_without_ddp = model 43 | if args.distributed: 44 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True) 45 | model_without_ddp = model.module 46 | 47 | # optimizer setup 48 | def build_optimizer(model): 49 | # * frozen CLIP model 50 | update_modules, update_params = [], [] 51 | frozen_modules, frozen_params = [], [] 52 | for n, p in model.named_parameters(): 53 | if 'hoi' in n or 'bbox' in n: 54 | update_modules.append(n) 55 | update_params.append(p) 56 | else: 57 | frozen_modules.append(n) 58 | frozen_params.append(p) 59 | p.requires_grad = False 60 | 61 | optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), 62 | lr=args.lr, weight_decay=args.weight_decay) 63 | return optimizer 64 | 65 | optimizer = build_optimizer(model_without_ddp) 66 | lr_scheduler, _ = create_scheduler(args, optimizer) 67 | 68 | n_parameters = sum(p.numel() for p in model_without_ddp.parameters() if p.requires_grad) 69 | print('number of trainable params:', n_parameters, f'{n_parameters/1e6:.3f}M') 70 | 71 | dataset_train = build_dataset(image_set='train', args=args) 72 | dataset_val = build_dataset(image_set='val', args=args) 73 | print('# train:', len(dataset_train), ', # val', len(dataset_val)) 74 | 75 | if args.distributed: 76 | sampler_train = DistributedSampler(dataset_train) 77 | sampler_val = DistributedSampler(dataset_val, shuffle=False) 78 | else: 79 | sampler_train = torch.utils.data.RandomSampler(dataset_train) 80 | sampler_val = torch.utils.data.SequentialSampler(dataset_val) 81 | 82 | batch_sampler_train = torch.utils.data.BatchSampler( 83 | sampler_train, args.batch_size, drop_last=True) 84 | 85 | data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train, 86 | collate_fn=utils.collate_fn, num_workers=args.num_workers) 87 | data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val, 88 | drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers) 89 | 90 | output_dir = Path(args.output_dir) 91 | 92 | # resume from the given checkpoint 93 | if args.resume: 94 | if args.resume.startswith('https'): 95 | checkpoint = torch.hub.load_state_dict_from_url( 96 | args.resume, map_location='cpu', check_hash=True) 97 | else: 98 | print(f"load checkpoint from {args.resume}") 99 | checkpoint = torch.load(args.resume, map_location='cpu') 100 | 
model_without_ddp.load_state_dict(checkpoint['model']) 101 | if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint: 102 | optimizer.load_state_dict(checkpoint['optimizer']) 103 | lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) 104 | args.start_epoch = checkpoint['epoch'] + 1 105 | 106 | # evaluation 107 | if args.eval: 108 | # print FLOPs 109 | # get_flop_stats(model, data_loader_val) 110 | test_stats, evaluator = evaluate(model, postprocessors, criterion, data_loader_val, device, args) 111 | if args.output_dir: 112 | evaluator.save(args.output_dir) 113 | return 114 | 115 | print("Start training") 116 | start_time = time.time() 117 | for epoch in range(args.start_epoch, args.epochs): 118 | if args.distributed: 119 | sampler_train.set_epoch(epoch) 120 | train_stats = train_one_epoch(model, criterion, data_loader_train, optimizer, 121 | device, epoch, args.clip_max_norm) 122 | lr_scheduler.step(epoch) 123 | 124 | # checkpoint saving 125 | if args.output_dir: 126 | checkpoint_paths = [output_dir / 'checkpoint.pth'] 127 | # extra checkpoint before LR drop and every 100 epochs 128 | if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 100 == 0: 129 | checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth') 130 | for checkpoint_path in checkpoint_paths: 131 | utils.save_on_master({ 132 | 'model': model_without_ddp.state_dict(), 133 | 'optimizer': optimizer.state_dict(), 134 | 'lr_scheduler': lr_scheduler.state_dict(), 135 | 'epoch': epoch, 136 | 'args': args, 137 | }, checkpoint_path) 138 | 139 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 140 | 'epoch': epoch, 141 | 'n_parameters': n_parameters} 142 | 143 | if args.output_dir and utils.is_main_process(): 144 | with (output_dir / "log.txt").open("a") as f: 145 | f.write(json.dumps(log_stats) + "\n") 146 | 147 | total_time = time.time() - start_time 148 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 149 | print('Training time {}'.format(total_time_str)) 150 | 151 | 152 | if __name__ == '__main__': 153 | parser = argparse.ArgumentParser('Training and evaluation script', parents=[get_args_parser()]) 154 | args = parser.parse_args() 155 | if args.output_dir: 156 | Path(args.output_dir).mkdir(parents=True, exist_ok=True) 157 | main(args) -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import build_model -------------------------------------------------------------------------------- /models/criterion.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # Modified by Suchen for HOI detection 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from utils import box_ops 7 | from utils.misc import accuracy, get_world_size, is_dist_avail_and_initialized 8 | 9 | 10 | class SetCriterion(nn.Module): 11 | """ This class computes the loss for DETR. 12 | The process happens in two steps: 13 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 14 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 15 | """ 16 | def __init__(self, matcher, weight_dict, eos_coef, losses): 17 | """ Create the criterion. 
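The same weight_dict is also used by the training loop in engine.py to scale the returned losses before backpropagation.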
18 | Parameters: 19 | matcher: module able to compute a matching between targets and proposals 20 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 21 | eos_coef: relative classification weight applied to the no-object category 22 | losses: list of all the losses to be applied. See get_loss for list of available losses. 23 | """ 24 | super().__init__() 25 | self.matcher = matcher 26 | self.weight_dict = weight_dict 27 | self.eos_coef = eos_coef 28 | self.losses = losses 29 | 30 | def loss_labels(self, outputs, targets, indices, num_boxes, log=True): 31 | """Classification loss (NLL) 32 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 33 | """ 34 | assert 'logits_per_hoi' in outputs 35 | src_logits = outputs['logits_per_hoi'] 36 | target_classes_i, target_classes_t = self._get_tgt_labels(targets, indices, src_logits.device) 37 | 38 | idx = self._get_src_permutation_idx(indices) 39 | # image-to-text alignment loss 40 | loss_i = F.cross_entropy(src_logits[idx], target_classes_i) 41 | # text-to-image alignment loss 42 | if self.training: 43 | num_tgts = target_classes_t.shape[1] 44 | loss_t = self.masked_out_cross_entropy(src_logits[idx][:, :num_tgts].t(), target_classes_t.t()) 45 | losses = {"loss_ce": (loss_i + loss_t) / 2} 46 | else: 47 | losses = {'loss_ce': loss_i} 48 | 49 | if log: 50 | # TODO this should probably be a separate loss, not hacked in this one here 51 | losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_i)[0] 52 | return losses 53 | 54 | def masked_out_cross_entropy(self, src_logits, target_classes): 55 | loss = 0 56 | num_pos = target_classes.sum(dim=-1) 57 | # If there is only one active positive label, then this will be ordinary cross entropy 58 | indices = torch.nonzero(num_pos < 2, as_tuple=True)[0] 59 | targets_one_pos = torch.argmax(target_classes[indices], dim=-1) 60 | loss += F.cross_entropy(src_logits[indices], targets_one_pos, reduction="sum") 61 | 62 | # If there are multiple positive labels, then we compute them one by one. Each time, 63 | # the other positive labels are masked out. 64 | indices = torch.nonzero(num_pos > 1, as_tuple=True)[0] 65 | for i in indices: 66 | t = target_classes[i] 67 | cnt = sum(t) 68 | loss_t = 0 69 | for j in torch.nonzero(t): 70 | mask = (t == 0) 71 | mask[j] = True 72 | tgt = t[mask].argmax(dim=-1, keepdim=True) 73 | loss_t += F.cross_entropy(src_logits[i:i+1, mask], tgt, reduction="sum") 74 | loss += (loss_t / cnt) 75 | loss = loss / len(src_logits) 76 | return loss 77 | 78 | def loss_confidences(self, outputs, targets, indices, num_boxes, log=True): 79 | """ Bounding box confidence score for the interaction prediction. """ 80 | assert 'box_scores' in outputs 81 | box_scores = outputs['box_scores'].sigmoid() 82 | 83 | idx = self._get_src_permutation_idx(indices) 84 | target_classes_o = torch.cat([torch.ones(len(J), dtype=torch.int64, device=box_scores.device) for t, (_, J) in zip(targets, indices)]) 85 | target_classes = torch.full(box_scores.shape[:2], 0, dtype=torch.int64, device=box_scores.device) 86 | target_classes[idx] = target_classes_o 87 | target_classes = target_classes.to(box_scores.dtype) 88 | 89 | weight = torch.ones_like(target_classes) * self.eos_coef 90 | weight[idx] = 1. 
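# Illustration (example values only; eos_coef comes from the config): with eos_coef = 0.1
# and 2 of 8 queries matched, the flattened BCE weights would be
#   [1.0, 0.1, 0.1, 1.0, 0.1, 0.1, 0.1, 0.1]
# i.e. matched queries are supervised towards 1 at full weight, while unmatched
# ("background") queries are supervised towards 0 at a weight of eos_coef.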
91 | loss_conf = F.binary_cross_entropy(box_scores.flatten(), target_classes.flatten(), weight=weight.flatten()) 92 | losses = {'loss_conf': loss_conf} 93 | return losses 94 | 95 | def loss_boxes(self, outputs, targets, indices, num_boxes, log=False): 96 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 97 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 98 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 99 | """ 100 | assert 'pred_boxes' in outputs 101 | idx = self._get_src_permutation_idx(indices) 102 | src_boxes = outputs['pred_boxes'][idx] 103 | target_boxes = [] 104 | for t, (_, indices_per_t) in zip(targets, indices): 105 | for i in indices_per_t: 106 | person_id = t["hois"][i]["subject_id"] 107 | object_id = t["hois"][i]["object_id"] 108 | target_boxes.append(torch.cat([t["boxes"][person_id], t["boxes"][object_id]])) 109 | target_boxes = torch.stack(target_boxes, dim=0) 110 | 111 | loss_pbbox = F.l1_loss(src_boxes[:, :4], target_boxes[:, :4], reduction='none') 112 | loss_obbox = F.l1_loss(src_boxes[:, 4:], target_boxes[:, 4:], reduction='none') 113 | 114 | losses = {} 115 | losses['loss_bbox'] = loss_pbbox.sum() / num_boxes + loss_obbox.sum() / num_boxes 116 | 117 | loss_pgiou = 1 - torch.diag(box_ops.generalized_box_iou( 118 | box_ops.box_cxcywh_to_xyxy(src_boxes[:, :4]), 119 | box_ops.box_cxcywh_to_xyxy(target_boxes[:, :4]))) 120 | loss_ogiou = 1 - torch.diag(box_ops.generalized_box_iou( 121 | box_ops.box_cxcywh_to_xyxy(src_boxes[:, 4:]), 122 | box_ops.box_cxcywh_to_xyxy(target_boxes[:, 4:]))) 123 | 124 | losses['loss_giou'] = loss_pgiou.sum() / num_boxes + loss_ogiou.sum() / num_boxes 125 | return losses 126 | 127 | def _get_src_permutation_idx(self, indices): 128 | # permute predictions following indices 129 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 130 | src_idx = torch.cat([src for (src, _) in indices]) 131 | return batch_idx, src_idx 132 | 133 | def _get_tgt_permutation_idx(self, indices): 134 | # permute targets following indices 135 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 136 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 137 | return batch_idx, tgt_idx 138 | 139 | def _get_tgt_labels(self, targets, indices, device): 140 | if self.training: 141 | unique_hois, cnt = {}, 0 # Get unique hoi ids in the mini-batch 142 | for t in targets: 143 | for hoi in t["hois"]: 144 | hoi_id = hoi["hoi_id"] 145 | if hoi_id not in unique_hois: 146 | unique_hois[hoi_id] = cnt 147 | cnt += 1 148 | target_classes_i = [] 149 | for t, (_, indices_per_t) in zip(targets, indices): 150 | for i in indices_per_t: 151 | hoi_id = t["hois"][i]["hoi_id"] 152 | target_classes_i.append(unique_hois[hoi_id]) 153 | 154 | num_fgs = len(torch.cat([src for (src, _) in indices])) 155 | target_classes_t = torch.zeros((num_fgs, cnt), dtype=torch.int64) 156 | for i, cls_id in zip(range(len(target_classes_i)), target_classes_i): 157 | target_classes_t[i, cls_id] = 1 158 | target_classes_t = target_classes_t.to(device) 159 | else: 160 | target_classes_i = [] 161 | for t, (_, indices_per_t) in zip(targets, indices): 162 | for i in indices_per_t: 163 | target_classes_i.append(t["hois"][int(i)]["hoi_id"]) 164 | target_classes_t = None # Skip the calculation of text-to-image alignment at inference 165 | target_classes_i = torch.as_tensor(target_classes_i).to(device) 166 | return 
target_classes_i, target_classes_t 167 | 168 | def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): 169 | loss_map = { 170 | 'labels': self.loss_labels, 171 | 'boxes': self.loss_boxes, 172 | "confidences": self.loss_confidences, 173 | } 174 | assert loss in loss_map, f'do you really want to compute {loss} loss?' 175 | return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs) 176 | 177 | def forward(self, outputs, targets): 178 | """ This performs the loss computation. 179 | Parameters: 180 | outputs: dict of tensors, see the output specification of the model for the format 181 | targets: list of dicts, such that len(targets) == batch_size. 182 | The expected keys in each dict depends on the losses applied, see each loss' doc 183 | """ 184 | # Retrieve the matching between the outputs of the last layer and the targets 185 | indices = self.matcher(outputs, targets) 186 | 187 | # Compute the average number of target boxes accross all nodes, for normalization purposes 188 | num_boxes = sum(len(t["hois"]) for t in targets) 189 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) 190 | if is_dist_avail_and_initialized(): 191 | torch.distributed.all_reduce(num_boxes) 192 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 193 | 194 | # Compute all the requested losses 195 | losses = {} 196 | for loss in self.losses: 197 | losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes)) 198 | 199 | if 'aux_outputs' in outputs: 200 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 201 | aux_outputs.update({'logits_per_hoi': outputs['logits_per_hoi']}) 202 | indices = self.matcher(aux_outputs, targets) 203 | for loss in ['boxes', 'confidences']: 204 | l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes) 205 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 206 | losses.update(l_dict) 207 | 208 | return losses, indices -------------------------------------------------------------------------------- /models/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # Modified by Suchen for HOI detection 3 | """ 4 | Modules to compute the matching cost and solve the corresponding LSAP. 5 | """ 6 | """ 7 | Modules to compute the matching cost and solve the corresponding LSAP. 8 | """ 9 | import torch 10 | from scipy.optimize import linear_sum_assignment 11 | from torch import nn 12 | 13 | from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 14 | 15 | 16 | class HungarianMatcher(nn.Module): 17 | """This class computes an assignment between the targets and the predictions of the network 18 | 19 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 20 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 21 | while the others are un-matched (and thus treated as non-objects). 
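In this HOI setting, every prediction is a (human box, object box) pair with an interaction score, so the matching cost below is computed jointly over both boxes together with the classification and confidence terms.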
22 | """ 23 | 24 | def __init__( 25 | self, 26 | cost_class: float = 1, 27 | cost_bbox: float = 1, 28 | cost_giou: float = 1, 29 | cost_conf: float = 1, 30 | ): 31 | """Creates the matcher 32 | 33 | Params: 34 | cost_class: This is the relative weight of the classification error in the matching cost 35 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 36 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 37 | """ 38 | super().__init__() 39 | self.cost_class = cost_class 40 | self.cost_bbox = cost_bbox 41 | self.cost_giou = cost_giou 42 | self.cost_conf = cost_conf 43 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0 or cost_conf != 0, "all costs cant be 0" 44 | 45 | @torch.no_grad() 46 | def forward(self, outputs, targets): 47 | """ Performs the matching 48 | 49 | Params: 50 | outputs: This is a dict that contains at least these entries: 51 | "logits_per_hoi": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 52 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 53 | "box_scores": Tensor of dim [batch_size, num_queries, 1] with the predicted box confidence scores 54 | 55 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 56 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 57 | objects in the target) containing the class labels 58 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 59 | 60 | Returns: 61 | A list of size batch_size, containing tuples of (index_i, index_j) where: 62 | - index_i is the indices of the selected predictions (in order) 63 | - index_j is the indices of the corresponding selected targets (in order) 64 | For each batch element, it holds: 65 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 66 | """ 67 | bs, num_queries = outputs["logits_per_hoi"].shape[:2] 68 | 69 | # We flatten to compute the cost matrices in a batch 70 | out_prob = outputs["logits_per_hoi"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 71 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 8] 72 | out_conf = outputs["box_scores"].flatten(0, 1).sigmoid() # [batch_size * num_queries, 1] 73 | 74 | # Also concat the target labels and boxes. During the training, due to the limit 75 | # GPU memory, we also consider the texts within each mini-batch. Differently, during 76 | # the inference, we consider all interactions in the dataset. 77 | unique_hois, cnt = {}, 0 78 | tgt_ids = [] 79 | for t in targets: 80 | for hoi in t["hois"]: 81 | hoi_id = hoi["hoi_id"] 82 | if self.training: 83 | # Only consider the texts within each mini-batch 84 | if hoi_id not in unique_hois: 85 | unique_hois[hoi_id] = cnt 86 | cnt += 1 87 | tgt_ids.append(unique_hois[hoi_id]) 88 | else: 89 | # Consider all hois in the dataset 90 | tgt_ids.append(hoi_id) 91 | tgt_ids = torch.as_tensor(tgt_ids, dtype=torch.int64, device=out_prob.device) 92 | 93 | tgt_bbox = [torch.cat([t["boxes"][hoi["subject_id"]], t["boxes"][hoi["object_id"]]]) 94 | for t in targets for hoi in t["hois"]] 95 | tgt_bbox = torch.stack(tgt_bbox, dim=0) 96 | 97 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 98 | # but approximate it in 1 - proba[target class]. 99 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 
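# Toy illustration (numbers are made up, not from the model): with two flattened queries
# and three HOI texts in the mini-batch,
#   out_prob = [[0.2, 0.5, 0.3],
#               [0.7, 0.1, 0.2]]
# and tgt_ids = [1, 2], the matrix -out_prob[:, tgt_ids] below becomes
#   [[-0.5, -0.3],
#    [-0.1, -0.2]]
# i.e. each column holds the negated probability assigned to one ground-truth HOI.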
100 | cost_class = -out_prob[:, tgt_ids] 101 | 102 | # Compute the confidence cost 103 | cost_conf = -out_conf 104 | 105 | # Compute the L1 cost between boxes 106 | if out_bbox.dtype == torch.float16: 107 | out_bbox = out_bbox.type(torch.float32) 108 | cost_pbbox = torch.cdist(out_bbox[:, :4], tgt_bbox[:, :4], p=1) 109 | cost_obbox = torch.cdist(out_bbox[:, 4:], tgt_bbox[:, 4:], p=1) 110 | 111 | # Compute the giou cost betwen boxes 112 | cost_pgiou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox[:, :4]), box_cxcywh_to_xyxy(tgt_bbox[:, :4])) 113 | cost_ogiou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox[:, 4:]), box_cxcywh_to_xyxy(tgt_bbox[:, 4:])) 114 | 115 | # Final cost matrix 116 | C = self.cost_bbox * cost_pbbox + self.cost_bbox * cost_obbox + \ 117 | self.cost_giou * cost_pgiou + self.cost_giou * cost_ogiou + \ 118 | self.cost_class * cost_class + self.cost_conf * cost_conf 119 | C = C.view(bs, num_queries, -1).cpu() 120 | 121 | sizes = [len(v["hois"]) for v in targets] 122 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 123 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 124 | 125 | 126 | def build_matcher(args): 127 | return HungarianMatcher( 128 | cost_class=args.set_cost_class, 129 | cost_bbox=args.set_cost_bbox, 130 | cost_giou=args.set_cost_giou, 131 | cost_conf=args.set_cost_conf, 132 | ) -------------------------------------------------------------------------------- /models/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from utils.misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, mask: torch.tensor): 29 | assert mask is not None 30 | not_mask = ~mask 31 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 32 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 33 | if self.normalize: 34 | eps = 1e-6 35 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 36 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 37 | 38 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=mask.device) 39 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 40 | 41 | pos_x = x_embed[:, :, :, None] / dim_t 42 | pos_y = y_embed[:, :, :, None] / dim_t 43 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 44 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 45 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 46 | return pos 47 | 48 | 49 | class PositionEmbeddingLearned(nn.Module): 50 | """ 51 | Absolute pos embedding, learned. 
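Row and column indices (up to 50 of each, matching the embedding tables below) are embedded separately and concatenated to form the positional feature at every spatial location.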
52 | """ 53 | def __init__(self, num_pos_feats=256): 54 | super().__init__() 55 | self.row_embed = nn.Embedding(50, num_pos_feats) 56 | self.col_embed = nn.Embedding(50, num_pos_feats) 57 | self.reset_parameters() 58 | 59 | def reset_parameters(self): 60 | nn.init.uniform_(self.row_embed.weight) 61 | nn.init.uniform_(self.col_embed.weight) 62 | 63 | def forward(self, tensor_list: NestedTensor): 64 | x = tensor_list.tensors 65 | h, w = x.shape[-2:] 66 | i = torch.arange(w, device=x.device) 67 | j = torch.arange(h, device=x.device) 68 | x_emb = self.col_embed(i) 69 | y_emb = self.row_embed(j) 70 | pos = torch.cat([ 71 | x_emb.unsqueeze(0).repeat(h, 1, 1), 72 | y_emb.unsqueeze(1).repeat(1, w, 1), 73 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 74 | return pos 75 | 76 | 77 | def build_position_encoding(args): 78 | N_steps = args.hidden_dim // 2 79 | if args.position_embedding in ('v2', 'sine'): 80 | # TODO find a better way of exposing other arguments 81 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 82 | elif args.position_embedding in ('v3', 'learned'): 83 | position_embedding = PositionEmbeddingLearned(N_steps) 84 | else: 85 | raise ValueError(f"not supported {args.position_embedding}") 86 | 87 | return position_embedding -------------------------------------------------------------------------------- /models/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | # Modified by Suchen for HOI detection 3 | """ 4 | DETR Transformer class. 5 | 6 | Copy-paste from torch.nn.Transformer with modifications: 7 | * positional encodings are passed in MHattention 8 | * extra LN at the end of encoder is removed 9 | * decoder returns a stack of activations from all decoding layers 10 | """ 11 | import copy 12 | from typing import Optional, List 13 | 14 | import torch 15 | import torch.nn.functional as F 16 | from torch import nn, Tensor 17 | from clip.model import LayerNorm 18 | 19 | 20 | class TransformerDecoder(nn.Module): 21 | 22 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 23 | super().__init__() 24 | self.layers = _get_clones(decoder_layer, num_layers) 25 | self.num_layers = num_layers 26 | self.norm = norm 27 | self.return_intermediate = return_intermediate 28 | 29 | def forward(self, tgt, memory, 30 | tgt_mask: Optional[Tensor] = None, 31 | memory_mask: Optional[Tensor] = None, 32 | tgt_key_padding_mask: Optional[Tensor] = None, 33 | memory_key_padding_mask: Optional[Tensor] = None, 34 | pos: Optional[Tensor] = None, 35 | query_pos: Optional[Tensor] = None): 36 | output = tgt 37 | 38 | intermediate = [] 39 | 40 | for layer in self.layers: 41 | output = layer(output, memory, tgt_mask=tgt_mask, 42 | memory_mask=memory_mask, 43 | tgt_key_padding_mask=tgt_key_padding_mask, 44 | memory_key_padding_mask=memory_key_padding_mask, 45 | pos=pos, query_pos=query_pos) 46 | if self.return_intermediate: 47 | intermediate.append(self.norm(output)) 48 | 49 | if self.norm is not None: 50 | output = self.norm(output) 51 | if self.return_intermediate: 52 | intermediate.pop() 53 | intermediate.append(output) 54 | 55 | if self.return_intermediate: 56 | return torch.stack(intermediate) 57 | 58 | return output.unsqueeze(0) 59 | 60 | 61 | class TransformerDecoderLayer(nn.Module): 62 | 63 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 64 | activation="relu", normalize_before=False): 65 | 
super().__init__() 66 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 67 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 68 | # Implementation of Feedforward model 69 | self.linear1 = nn.Linear(d_model, dim_feedforward) 70 | self.dropout = nn.Dropout(dropout) 71 | self.linear2 = nn.Linear(dim_feedforward, d_model) 72 | 73 | self.norm1 = LayerNorm(d_model) 74 | self.norm2 = LayerNorm(d_model) 75 | self.norm3 = LayerNorm(d_model) 76 | self.dropout1 = nn.Dropout(dropout) 77 | self.dropout2 = nn.Dropout(dropout) 78 | self.dropout3 = nn.Dropout(dropout) 79 | 80 | self.activation = _get_activation_fn(activation) 81 | self.normalize_before = normalize_before 82 | 83 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 84 | return tensor if pos is None else tensor + pos 85 | 86 | def forward_post(self, tgt, memory, 87 | tgt_mask: Optional[Tensor] = None, 88 | memory_mask: Optional[Tensor] = None, 89 | tgt_key_padding_mask: Optional[Tensor] = None, 90 | memory_key_padding_mask: Optional[Tensor] = None, 91 | pos: Optional[Tensor] = None, 92 | query_pos: Optional[Tensor] = None): 93 | q = k = self.with_pos_embed(tgt, query_pos) 94 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 95 | key_padding_mask=tgt_key_padding_mask)[0] 96 | tgt = tgt + self.dropout1(tgt2) 97 | tgt = self.norm1(tgt) 98 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 99 | key=self.with_pos_embed(memory, pos), 100 | value=self.with_pos_embed(memory, pos), attn_mask=memory_mask, 101 | key_padding_mask=memory_key_padding_mask)[0] 102 | tgt = tgt + self.dropout2(tgt2) 103 | tgt = self.norm2(tgt) 104 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 105 | tgt = tgt + self.dropout3(tgt2) 106 | tgt = self.norm3(tgt) 107 | return tgt 108 | 109 | def forward_pre(self, tgt, memory, 110 | tgt_mask: Optional[Tensor] = None, 111 | memory_mask: Optional[Tensor] = None, 112 | tgt_key_padding_mask: Optional[Tensor] = None, 113 | memory_key_padding_mask: Optional[Tensor] = None, 114 | pos: Optional[Tensor] = None, 115 | query_pos: Optional[Tensor] = None): 116 | tgt2 = self.norm1(tgt) 117 | q = k = self.with_pos_embed(tgt2, query_pos) 118 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 119 | key_padding_mask=tgt_key_padding_mask)[0] 120 | tgt = tgt + self.dropout1(tgt2) 121 | tgt2 = self.norm2(tgt) 122 | tgt2, attn = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 123 | key=self.with_pos_embed(memory, pos), 124 | value=self.with_pos_embed(memory, pos), attn_mask=memory_mask, 125 | key_padding_mask=memory_key_padding_mask) 126 | tgt = tgt + self.dropout2(tgt2) 127 | tgt2 = self.norm3(tgt) 128 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 129 | tgt = tgt + self.dropout3(tgt2) 130 | return tgt 131 | 132 | def forward(self, tgt, memory, 133 | tgt_mask: Optional[Tensor] = None, 134 | memory_mask: Optional[Tensor] = None, 135 | tgt_key_padding_mask: Optional[Tensor] = None, 136 | memory_key_padding_mask: Optional[Tensor] = None, 137 | pos: Optional[Tensor] = None, 138 | query_pos: Optional[Tensor] = None): 139 | if self.normalize_before: 140 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 141 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 142 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 143 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 144 | 145 | 146 | def _get_clones(module, N): 147 | return 
nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 148 | 149 | 150 | def _get_activation_fn(activation): 151 | """Return an activation function given a string""" 152 | if activation == "relu": 153 | return F.relu 154 | if activation == "gelu": 155 | return F.gelu 156 | if activation == "glu": 157 | return F.glu 158 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scwangdyd/promting_hoi/29938ccbcb7c8206873a984628a132064c769270/utils/__init__.py -------------------------------------------------------------------------------- /utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
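An empty input (no masks) returns an empty [0, 4] tensor, as handled at the top of the function.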
68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Misc functions, including distributed helpers. 4 | 5 | Mostly copy-paste from torchvision references. 6 | """ 7 | import os 8 | import subprocess 9 | import time 10 | from collections import defaultdict, deque 11 | import datetime 12 | import pickle 13 | from typing import Optional, List 14 | 15 | import torch 16 | import torch.distributed as dist 17 | from torch import Tensor 18 | 19 | # needed due to empty tensor bug in pytorch and torchvision 0.5 20 | import torchvision 21 | 22 | 23 | class SmoothedValue(object): 24 | """Track a series of values and provide access to smoothed values over a 25 | window or the global series average. 26 | """ 27 | 28 | def __init__(self, window_size=20, fmt=None): 29 | if fmt is None: 30 | fmt = "{median:.4f} ({global_avg:.4f})" 31 | self.deque = deque(maxlen=window_size) 32 | self.total = 0.0 33 | self.count = 0 34 | self.fmt = fmt 35 | 36 | def update(self, value, n=1): 37 | self.deque.append(value) 38 | self.count += n 39 | self.total += value * n 40 | 41 | def synchronize_between_processes(self): 42 | """ 43 | Warning: does not synchronize the deque! 
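Only count and total (and therefore global_avg) are reduced across processes; the window-based statistics (median, avg, max, value) remain local to each process.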
44 | """ 45 | if not is_dist_avail_and_initialized(): 46 | return 47 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 48 | dist.barrier() 49 | dist.all_reduce(t) 50 | t = t.tolist() 51 | self.count = int(t[0]) 52 | self.total = t[1] 53 | 54 | @property 55 | def median(self): 56 | d = torch.tensor(list(self.deque)) 57 | return d.median().item() 58 | 59 | @property 60 | def avg(self): 61 | d = torch.tensor(list(self.deque), dtype=torch.float32) 62 | return d.mean().item() 63 | 64 | @property 65 | def global_avg(self): 66 | return self.total / self.count 67 | 68 | @property 69 | def max(self): 70 | return max(self.deque) 71 | 72 | @property 73 | def value(self): 74 | return self.deque[-1] 75 | 76 | def __str__(self): 77 | return self.fmt.format( 78 | median=self.median, 79 | avg=self.avg, 80 | global_avg=self.global_avg, 81 | max=self.max, 82 | value=self.value) 83 | 84 | 85 | def all_gather(data): 86 | """ 87 | Run all_gather on arbitrary picklable data (not necessarily tensors) 88 | Args: 89 | data: any picklable object 90 | Returns: 91 | list[data]: list of data gathered from each rank 92 | """ 93 | world_size = get_world_size() 94 | if world_size == 1: 95 | return [data] 96 | 97 | # serialized to a Tensor 98 | buffer = pickle.dumps(data) 99 | storage = torch.ByteStorage.from_buffer(buffer) 100 | tensor = torch.ByteTensor(storage).to("cuda") 101 | 102 | # obtain Tensor size of each rank 103 | local_size = torch.tensor([tensor.numel()], device="cuda") 104 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 105 | dist.all_gather(size_list, local_size) 106 | size_list = [int(size.item()) for size in size_list] 107 | max_size = max(size_list) 108 | 109 | # receiving Tensor from all ranks 110 | # we pad the tensor because torch all_gather does not support 111 | # gathering tensors of different shapes 112 | tensor_list = [] 113 | for _ in size_list: 114 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 115 | if local_size != max_size: 116 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 117 | tensor = torch.cat((tensor, padding), dim=0) 118 | dist.all_gather(tensor_list, tensor) 119 | 120 | data_list = [] 121 | for size, tensor in zip(size_list, tensor_list): 122 | buffer = tensor.cpu().numpy().tobytes()[:size] 123 | data_list.append(pickle.loads(buffer)) 124 | 125 | return data_list 126 | 127 | 128 | def reduce_dict(input_dict, average=True): 129 | """ 130 | Args: 131 | input_dict (dict): all the values will be reduced 132 | average (bool): whether to do average or sum 133 | Reduce the values in the dictionary from all processes so that all processes 134 | have the averaged results. Returns a dict with the same fields as 135 | input_dict, after reduction. 
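All processes are expected to pass dictionaries with the same keys, since the values are stacked in sorted-key order before the all_reduce.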
136 | """ 137 | world_size = get_world_size() 138 | if world_size < 2: 139 | return input_dict 140 | with torch.no_grad(): 141 | names = [] 142 | values = [] 143 | # sort the keys so that they are consistent across processes 144 | for k in sorted(input_dict.keys()): 145 | names.append(k) 146 | values.append(input_dict[k]) 147 | values = torch.stack(values, dim=0) 148 | dist.all_reduce(values) 149 | if average: 150 | values /= world_size 151 | reduced_dict = {k: v for k, v in zip(names, values)} 152 | return reduced_dict 153 | 154 | 155 | class MetricLogger(object): 156 | def __init__(self, delimiter="\t"): 157 | self.meters = defaultdict(SmoothedValue) 158 | self.delimiter = delimiter 159 | 160 | def update(self, **kwargs): 161 | for k, v in kwargs.items(): 162 | if isinstance(v, torch.Tensor): 163 | v = v.item() 164 | assert isinstance(v, (float, int)) 165 | self.meters[k].update(v) 166 | 167 | def __getattr__(self, attr): 168 | if attr in self.meters: 169 | return self.meters[attr] 170 | if attr in self.__dict__: 171 | return self.__dict__[attr] 172 | raise AttributeError("'{}' object has no attribute '{}'".format( 173 | type(self).__name__, attr)) 174 | 175 | def __str__(self): 176 | loss_str = [] 177 | for name, meter in self.meters.items(): 178 | loss_str.append( 179 | "{}: {}".format(name, str(meter)) 180 | ) 181 | return self.delimiter.join(loss_str) 182 | 183 | def synchronize_between_processes(self): 184 | for meter in self.meters.values(): 185 | meter.synchronize_between_processes() 186 | 187 | def add_meter(self, name, meter): 188 | self.meters[name] = meter 189 | 190 | def log_every(self, iterable, print_freq, header=None): 191 | i = 0 192 | if not header: 193 | header = '' 194 | start_time = time.time() 195 | end = time.time() 196 | iter_time = SmoothedValue(fmt='{avg:.4f}') 197 | data_time = SmoothedValue(fmt='{avg:.4f}') 198 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 199 | if torch.cuda.is_available(): 200 | log_msg = self.delimiter.join([ 201 | header, 202 | '[{0' + space_fmt + '}/{1}]', 203 | 'eta: {eta}', 204 | '{meters}', 205 | 'time: {time}', 206 | 'data: {data}', 207 | 'max mem: {memory:.0f}' 208 | ]) 209 | else: 210 | log_msg = self.delimiter.join([ 211 | header, 212 | '[{0' + space_fmt + '}/{1}]', 213 | 'eta: {eta}', 214 | '{meters}', 215 | 'time: {time}', 216 | 'data: {data}' 217 | ]) 218 | MB = 1024.0 * 1024.0 219 | for obj in iterable: 220 | data_time.update(time.time() - end) 221 | yield obj 222 | iter_time.update(time.time() - end) 223 | if i % print_freq == 0 or i == len(iterable) - 1: 224 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 225 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 226 | if torch.cuda.is_available(): 227 | print(log_msg.format( 228 | i, len(iterable), eta=eta_string, 229 | meters=str(self), 230 | time=str(iter_time), data=str(data_time), 231 | memory=torch.cuda.max_memory_allocated() / MB)) 232 | else: 233 | print(log_msg.format( 234 | i, len(iterable), eta=eta_string, 235 | meters=str(self), 236 | time=str(iter_time), data=str(data_time))) 237 | i += 1 238 | end = time.time() 239 | total_time = time.time() - start_time 240 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 241 | print('{} Total time: {} ({:.4f} s / it)'.format( 242 | header, total_time_str, total_time / len(iterable))) 243 | 244 | 245 | def get_sha(): 246 | cwd = os.path.dirname(os.path.abspath(__file__)) 247 | 248 | def _run(command): 249 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
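# Placeholder values below are returned when the git metadata cannot be read
# (e.g., when running outside of a git checkout); see the try/except that follows.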
250 | sha = 'N/A' 251 | diff = "clean" 252 | branch = 'N/A' 253 | try: 254 | sha = _run(['git', 'rev-parse', 'HEAD']) 255 | subprocess.check_output(['git', 'diff'], cwd=cwd) 256 | diff = _run(['git', 'diff-index', 'HEAD']) 257 | diff = "has uncommited changes" if diff else "clean" 258 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 259 | except Exception: 260 | pass 261 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 262 | return message 263 | 264 | 265 | def collate_fn(batch): 266 | batch = list(zip(*batch)) 267 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 268 | return tuple(batch) 269 | 270 | 271 | def _max_by_axis(the_list): 272 | # type: (List[List[int]]) -> List[int] 273 | maxes = the_list[0] 274 | for sublist in the_list[1:]: 275 | for index, item in enumerate(sublist): 276 | maxes[index] = max(maxes[index], item) 277 | return maxes 278 | 279 | 280 | def nested_tensor_from_tensor_list(tensor_list): 281 | # TODO make this more general 282 | if tensor_list[0].ndim == 3: 283 | if torchvision._is_tracing(): 284 | # nested_tensor_from_tensor_list() does not export well to ONNX 285 | # call _onnx_nested_tensor_from_tensor_list() instead 286 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 287 | 288 | # TODO make it support different-sized images 289 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 290 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 291 | batch_shape = [len(tensor_list)] + max_size 292 | b, c, h, w = batch_shape 293 | dtype = tensor_list[0].dtype 294 | device = tensor_list[0].device 295 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 296 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 297 | for img, pad_img, m in zip(tensor_list, tensor, mask): 298 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 299 | m[: img.shape[1], :img.shape[2]] = False 300 | else: 301 | raise ValueError('not supported') 302 | 303 | ''' 304 | if "image_mask" in target_list[0]: 305 | # merge the masks 306 | for i in range(len(target_list)): 307 | preprocess_img_mask = target_list[i].pop("image_mask") 308 | mask[i] = torch.logical_or(mask[i], preprocess_img_mask) 309 | ''' 310 | 311 | return NestedTensor(tensor, mask) 312 | 313 | 314 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 315 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
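# It pads every image (and the corresponding mask) to the per-batch maximum size with
# torch.nn.functional.pad rather than in-place slice copies.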
316 | @torch.jit.unused 317 | def _onnx_nested_tensor_from_tensor_list(tensor_list): 318 | max_size = [] 319 | for i in range(tensor_list[0].dim()): 320 | max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) 321 | max_size.append(max_size_i) 322 | max_size = tuple(max_size) 323 | 324 | # work around for 325 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 326 | # m[: img.shape[1], :img.shape[2]] = False 327 | # which is not yet supported in onnx 328 | padded_imgs = [] 329 | padded_masks = [] 330 | for img in tensor_list: 331 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 332 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 333 | padded_imgs.append(padded_img) 334 | 335 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 336 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 337 | padded_masks.append(padded_mask.to(torch.bool)) 338 | 339 | tensor = torch.stack(padded_imgs) 340 | mask = torch.stack(padded_masks) 341 | 342 | return NestedTensor(tensor, mask=mask) 343 | 344 | 345 | class NestedTensor(object): 346 | def __init__(self, tensors, mask: Optional[Tensor]): 347 | self.tensors = tensors 348 | self.mask = mask 349 | 350 | def to(self, device): 351 | # type: (Device) -> NestedTensor # noqa 352 | cast_tensor = self.tensors.to(device) 353 | mask = self.mask 354 | if mask is not None: 355 | assert mask is not None 356 | cast_mask = mask.to(device) 357 | else: 358 | cast_mask = None 359 | return NestedTensor(cast_tensor, cast_mask) 360 | 361 | def decompose(self): 362 | return self.tensors, self.mask 363 | 364 | def __repr__(self): 365 | return str(self.tensors) 366 | 367 | 368 | def setup_for_distributed(is_master): 369 | """ 370 | This function disables printing when not in master process 371 | """ 372 | import builtins as __builtin__ 373 | builtin_print = __builtin__.print 374 | 375 | def print(*args, **kwargs): 376 | force = kwargs.pop('force', False) 377 | if is_master or force: 378 | builtin_print(*args, **kwargs) 379 | 380 | __builtin__.print = print 381 | 382 | 383 | def is_dist_avail_and_initialized(): 384 | if not dist.is_available(): 385 | return False 386 | if not dist.is_initialized(): 387 | return False 388 | return True 389 | 390 | 391 | def get_world_size(): 392 | if not is_dist_avail_and_initialized(): 393 | return 1 394 | return dist.get_world_size() 395 | 396 | 397 | def get_rank(): 398 | if not is_dist_avail_and_initialized(): 399 | return 0 400 | return dist.get_rank() 401 | 402 | 403 | def is_main_process(): 404 | return get_rank() == 0 405 | 406 | 407 | def save_on_master(*args, **kwargs): 408 | if is_main_process(): 409 | torch.save(*args, **kwargs) 410 | 411 | 412 | def init_distributed_mode(args): 413 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 414 | args.rank = int(os.environ["RANK"]) 415 | args.world_size = int(os.environ['WORLD_SIZE']) 416 | args.gpu = int(os.environ['LOCAL_RANK']) 417 | elif 'SLURM_PROCID' in os.environ: 418 | args.rank = int(os.environ['SLURM_PROCID']) 419 | args.gpu = args.rank % torch.cuda.device_count() 420 | else: 421 | print('Not using distributed mode') 422 | args.distributed = False 423 | return 424 | 425 | args.distributed = True 426 | 427 | torch.cuda.set_device(args.gpu) 428 | args.dist_backend = 'nccl' 429 | print('| distributed init (rank {}): {}'.format( 430 | args.rank, args.dist_url), flush=True) 431 | 
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 432 | world_size=args.world_size, rank=args.rank) 433 | torch.distributed.barrier() 434 | setup_for_distributed(args.rank == 0) 435 | 436 | 437 | @torch.no_grad() 438 | def accuracy(output, target, topk=(1,)): 439 | """Computes the precision@k for the specified values of k""" 440 | if target.numel() == 0: 441 | return [torch.zeros([], device=output.device)] 442 | maxk = max(topk) 443 | batch_size = target.size(0) 444 | 445 | _, pred = output.topk(maxk, 1, True, True) 446 | pred = pred.t() 447 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 448 | 449 | res = [] 450 | for k in topk: 451 | correct_k = correct[:k].view(-1).float().sum(0) 452 | res.append(correct_k.mul_(100.0 / batch_size)) 453 | return res 454 | 455 | 456 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 457 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 458 | """ 459 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 460 | This will eventually be supported natively by PyTorch, and this 461 | class can go away. 462 | """ 463 | if float(torchvision.__version__[:3]) < 0.7: 464 | if input.numel() > 0: 465 | return torch.nn.functional.interpolate( 466 | input, size, scale_factor, mode, align_corners 467 | ) 468 | 469 | output_shape = _output_size(2, input, size, scale_factor) 470 | output_shape = list(input.shape[:-2]) + list(output_shape) 471 | return _new_empty_tensor(input, output_shape) 472 | else: 473 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 474 | 475 | 476 | def inverse_sigmoid(x, eps=1e-5): 477 | x = x.clamp(min=0, max=1) 478 | x1 = x.clamp(min=eps) 479 | x2 = (1 - x).clamp(min=eps) 480 | return torch.log(x1/x2) -------------------------------------------------------------------------------- /utils/sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | import torch 4 | 5 | 6 | def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh): 7 | """ 8 | Compute (fractional) per-image repeat factors based on category frequency. 9 | The repeat factor for an image is a function of the frequency of the rarest 10 | category labeled in that image. The "frequency of category c" in [0, 1] is defined 11 | as the fraction of images in the training set (without repeats) in which category c 12 | appears. 13 | See :paper:`lvis` (>= v2) Appendix B.2. 14 | 15 | Args: 16 | dataset_dicts (list[dict]): annotations in Detectron2 dataset format. 17 | repeat_thresh (float): frequency threshold below which data is repeated. 18 | If the frequency is half of `repeat_thresh`, the image will be 19 | repeated twice. 20 | 21 | Returns: 22 | torch.Tensor: the i-th element is the repeat factor for the dataset image 23 | at index i. 24 | """ 25 | # 1. For each interaction c, compute the fraction of images that contain it: f(c) 26 | interaction_freq = defaultdict(int) 27 | for dataset_dict in dataset_dicts: # For each image (without repeats) 28 | cats = set() 29 | for hoi in dataset_dict["annotations"]["hois"]: 30 | cats.add(hoi["hoi_id"]) 31 | for cat_id in cats: 32 | interaction_freq[cat_id] += 1 33 | num_images = len(dataset_dicts) 34 | for k, v in interaction_freq.items(): 35 | interaction_freq[k] = v / num_images 36 | 37 | # 2. 
For each category c, compute the category-level repeat factor: 38 | # r(c) = max(1, sqrt(t / f(c))) 39 | category_rep = { 40 | cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq)) 41 | for cat_id, cat_freq in interaction_freq.items() 42 | } 43 | 44 | # 3. For each image I, compute the image-level repeat factor: 45 | # r(I) = max_{c in I} r(c) 46 | rep_factors = [] 47 | for i, dataset_dict in enumerate(dataset_dicts): 48 | cats = set() 49 | for hoi in dataset_dict["annotations"]["hois"]: 50 | cats.add(hoi["hoi_id"]) 51 | rep_factor = max({category_rep[cat_id] for cat_id in cats}) 52 | rep_factors.append(rep_factor) 53 | 54 | return torch.tensor(rep_factors, dtype=torch.float32) 55 | 56 | 57 | def get_dataset_indices(repeat_factors): 58 | g = torch.Generator() 59 | # Split into whole number (_int_part) and fractional (_frac_part) parts. 60 | _int_part = torch.trunc(repeat_factors) 61 | _frac_part = repeat_factors - _int_part 62 | 63 | # Since repeat factors are fractional, we use stochastic rounding so 64 | # that the target repeat factor is achieved in expectation over the 65 | # course of training 66 | rands = torch.rand(len(_frac_part), generator=g) 67 | rep_factors = _int_part + (rands < _frac_part).float() 68 | # Construct a list of indices in which we repeat images as specified 69 | indices = [] 70 | for dataset_index, rep_factor in enumerate(rep_factors): 71 | indices.extend([dataset_index] * int(rep_factor.item())) 72 | return indices -------------------------------------------------------------------------------- /utils/scheduler.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/hustvl/YOLOS/blob/main/util/scheduler.py 2 | 3 | import math 4 | from typing import Dict, Any 5 | 6 | import torch 7 | 8 | 9 | class Scheduler: 10 | """ Parameter Scheduler Base Class 11 | A scheduler base class that can be used to schedule any optimizer parameter groups. 12 | 13 | Unlike the builtin PyTorch schedulers, this is intended to be consistently called 14 | * At the END of each epoch, before incrementing the epoch count, to calculate next epoch's value 15 | * At the END of each optimizer update, after incrementing the update count, to calculate next update's value 16 | 17 | The schedulers built on this should try to remain as stateless as possible (for simplicity). 18 | 19 | This family of schedulers is attempting to avoid the confusion of the meaning of 'last_epoch' 20 | and -1 values for special behaviour. All epoch and update counts must be tracked in the training 21 | code and explicitly passed in to the schedulers on the corresponding step or step_update call. 
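    A minimal usage sketch (illustrative only; `args` and `optimizer` come from the
    training script, and `create_scheduler` at the bottom of this file builds the
    scheduler from the parsed arguments):

        lr_scheduler, num_epochs = create_scheduler(args, optimizer)
        for epoch in range(num_epochs):
            train_one_epoch(...)          # placeholder for one epoch of training
            lr_scheduler.step(epoch + 1)  # end-of-epoch call described above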
22 | 23 | Based on ideas from: 24 | * https://github.com/pytorch/fairseq/tree/master/fairseq/optim/lr_scheduler 25 | * https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers 26 | """ 27 | 28 | def __init__(self, 29 | optimizer: torch.optim.Optimizer, 30 | param_group_field: str, 31 | noise_range_t=None, 32 | noise_type='normal', 33 | noise_pct=0.67, 34 | noise_std=1.0, 35 | noise_seed=None, 36 | initialize: bool = True) -> None: 37 | self.optimizer = optimizer 38 | self.param_group_field = param_group_field 39 | self._initial_param_group_field = f"initial_{param_group_field}" 40 | if initialize: 41 | for i, group in enumerate(self.optimizer.param_groups): 42 | if param_group_field not in group: 43 | raise KeyError(f"{param_group_field} missing from param_groups[{i}]") 44 | group.setdefault(self._initial_param_group_field, group[param_group_field]) 45 | else: 46 | for i, group in enumerate(self.optimizer.param_groups): 47 | if self._initial_param_group_field not in group: 48 | raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]") 49 | self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups] 50 | self.metric = None # any point to having this for all? 51 | self.noise_range_t = noise_range_t 52 | self.noise_pct = noise_pct 53 | self.noise_type = noise_type 54 | self.noise_std = noise_std 55 | self.noise_seed = noise_seed if noise_seed is not None else 42 56 | self.update_groups(self.base_values) 57 | 58 | def state_dict(self) -> Dict[str, Any]: 59 | return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} 60 | 61 | def load_state_dict(self, state_dict: Dict[str, Any]) -> None: 62 | self.__dict__.update(state_dict) 63 | 64 | def get_epoch_values(self, epoch: int): 65 | return None 66 | 67 | def get_update_values(self, num_updates: int): 68 | return None 69 | 70 | def step(self, epoch: int, metric: float = None) -> None: 71 | self.metric = metric 72 | values = self.get_epoch_values(epoch) 73 | if values is not None: 74 | values = self._add_noise(values, epoch) 75 | self.update_groups(values) 76 | 77 | def step_update(self, num_updates: int, metric: float = None): 78 | self.metric = metric 79 | values = self.get_update_values(num_updates) 80 | if values is not None: 81 | values = self._add_noise(values, num_updates) 82 | self.update_groups(values) 83 | 84 | def update_groups(self, values): 85 | if not isinstance(values, (list, tuple)): 86 | values = [values] * len(self.optimizer.param_groups) 87 | for param_group, value in zip(self.optimizer.param_groups, values): 88 | param_group[self.param_group_field] = value 89 | 90 | def _add_noise(self, lrs, t): 91 | if self.noise_range_t is not None: 92 | if isinstance(self.noise_range_t, (list, tuple)): 93 | apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1] 94 | else: 95 | apply_noise = t >= self.noise_range_t 96 | if apply_noise: 97 | g = torch.Generator() 98 | g.manual_seed(self.noise_seed + t) 99 | if self.noise_type == 'normal': 100 | while True: 101 | # resample if noise out of percent limit, brute force but shouldn't spin much 102 | noise = torch.randn(1, generator=g).item() 103 | if abs(noise) < self.noise_pct: 104 | break 105 | else: 106 | noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct 107 | lrs = [v + v * noise for v in lrs] 108 | return lrs 109 | 110 | 111 | class CosineLRScheduler(Scheduler): 112 | """ 113 | Cosine decay with restarts. 
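    (Illustrative summary, not in the original docstring: within restart cycle i of length t_i,
    with gamma = decay_rate ** i and lr_max the per-group base lr, the schedule follows
    lr = lr_min * gamma + 0.5 * (lr_max - lr_min) * gamma * (1 + cos(pi * t_curr / t_i)),
    preceded by an optional linear warmup over the first warmup_t steps.)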
114 | This is described in the paper https://arxiv.org/abs/1608.03983. 115 | 116 | Inspiration from 117 | https://github.com/allenai/allennlp/blob/master/allennlp/training/learning_rate_schedulers/cosine.py 118 | """ 119 | 120 | def __init__(self, 121 | optimizer: torch.optim.Optimizer, 122 | t_initial: int, 123 | t_mul: float = 1., 124 | lr_min: float = 0., 125 | decay_rate: float = 1., 126 | warmup_t=0, 127 | warmup_lr_init=0, 128 | warmup_prefix=False, 129 | cycle_limit=0, 130 | t_in_epochs=True, 131 | noise_range_t=None, 132 | noise_pct=0.67, 133 | noise_std=1.0, 134 | noise_seed=42, 135 | initialize=True) -> None: 136 | super().__init__( 137 | optimizer, param_group_field="lr", 138 | noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, 139 | initialize=initialize) 140 | 141 | assert t_initial > 0 142 | assert lr_min >= 0 143 | if t_initial == 1 and t_mul == 1 and decay_rate == 1: 144 | _logger.warning("Cosine annealing scheduler will have no effect on the learning " 145 | "rate since t_initial = t_mul = eta_mul = 1.") 146 | self.t_initial = t_initial 147 | self.t_mul = t_mul 148 | self.lr_min = lr_min 149 | self.decay_rate = decay_rate 150 | self.cycle_limit = cycle_limit 151 | self.warmup_t = warmup_t 152 | self.warmup_lr_init = warmup_lr_init 153 | self.warmup_prefix = warmup_prefix 154 | self.t_in_epochs = t_in_epochs 155 | if self.warmup_t: 156 | self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] 157 | super().update_groups(self.warmup_lr_init) 158 | else: 159 | self.warmup_steps = [1 for _ in self.base_values] 160 | 161 | def _get_lr(self, t): 162 | if t < self.warmup_t: 163 | lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] 164 | else: 165 | if self.warmup_prefix: 166 | t = t - self.warmup_t 167 | 168 | if self.t_mul != 1: 169 | i = math.floor(math.log(1 - t / self.t_initial * (1 - self.t_mul), self.t_mul)) 170 | t_i = self.t_mul ** i * self.t_initial 171 | t_curr = t - (1 - self.t_mul ** i) / (1 - self.t_mul) * self.t_initial 172 | else: 173 | i = t // self.t_initial 174 | t_i = self.t_initial 175 | t_curr = t - (self.t_initial * i) 176 | 177 | gamma = self.decay_rate ** i 178 | lr_min = self.lr_min * gamma 179 | lr_max_values = [v * gamma for v in self.base_values] 180 | 181 | if self.cycle_limit == 0 or (self.cycle_limit > 0 and i < self.cycle_limit): 182 | lrs = [ 183 | lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t_curr / t_i)) for lr_max in lr_max_values 184 | ] 185 | else: 186 | lrs = [self.lr_min for _ in self.base_values] 187 | 188 | return lrs 189 | 190 | def get_epoch_values(self, epoch: int): 191 | if self.t_in_epochs: 192 | return self._get_lr(epoch) 193 | else: 194 | return None 195 | 196 | def get_update_values(self, num_updates: int): 197 | if not self.t_in_epochs: 198 | return self._get_lr(num_updates) 199 | else: 200 | return None 201 | 202 | def get_cycle_length(self, cycles=0): 203 | if not cycles: 204 | cycles = self.cycle_limit 205 | cycles = max(1, cycles) 206 | if self.t_mul == 1.0: 207 | return self.t_initial * cycles 208 | else: 209 | return int(math.floor(-self.t_initial * (self.t_mul ** cycles - 1) / (1 - self.t_mul))) 210 | 211 | def create_scheduler(args, optimizer): 212 | num_epochs = args.epochs 213 | if getattr(args, 'lr_noise', None) is not None: 214 | lr_noise = getattr(args, 'lr_noise') 215 | if isinstance(lr_noise, (list, tuple)): 216 | noise_range = [n * num_epochs for n in lr_noise] 217 | if len(noise_range) == 1: 218 | noise_range = 
noise_range[0] 219 | else: 220 | noise_range = lr_noise * num_epochs 221 | else: 222 | noise_range = None 223 | 224 | lr_scheduler = None 225 | if args.sched == 'step': 226 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop) 227 | elif args.sched == 'warmupcos': 228 | lr_scheduler = CosineLRScheduler( 229 | optimizer, 230 | t_initial=num_epochs, 231 | t_mul=getattr(args, 'lr_cycle_mul', 1.), 232 | lr_min=args.min_lr, 233 | decay_rate=args.decay_rate, 234 | warmup_lr_init=args.warmup_lr, 235 | warmup_t=args.warmup_epochs, 236 | cycle_limit=getattr(args, 'lr_cycle_limit', 1), 237 | t_in_epochs=True, 238 | noise_range_t=noise_range, 239 | noise_pct=getattr(args, 'lr_noise_pct', 0.67), 240 | noise_std=getattr(args, 'lr_noise_std', 1.), 241 | noise_seed=getattr(args, 'seed', 42), 242 | ) 243 | num_epochs = lr_scheduler.get_cycle_length() 244 | return lr_scheduler, num_epochs -------------------------------------------------------------------------------- /utils/visualizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import matplotlib.pyplot as plt 4 | import utils.box_ops as box_ops 5 | import torch.nn.functional as F 6 | from pathlib import Path 7 | from PIL import Image, ImageDraw 8 | 9 | 10 | class Visualizer(object): 11 | def __init__(self, args): 12 | if args.vis_dir: 13 | Path(args.vis_dir).mkdir(parents=True, exist_ok=True) 14 | self.vis_dir = Path(args.vis_dir) 15 | self.patch_size = args.vision_patch_size 16 | 17 | def visualize_preds(self, images, targets, outputs, vis_threshold=0.1): 18 | vis_images = images.tensors.permute(0, 2, 3, 1).detach().cpu().numpy() 19 | image_masks = images.mask 20 | 21 | for b in range(len(vis_images)): 22 | img_rgb = vis_images[b] 23 | img_rgb = img_rgb - img_rgb.min() 24 | img_rgb = (img_rgb / img_rgb.max()) * 255 25 | img_pd = Image.fromarray(np.uint8(img_rgb)) 26 | 27 | img_id = int(targets[b]["image_id"]) 28 | img_mask = image_masks[b] 29 | ori_h = int(torch.sum(~img_mask[:, 0])) 30 | ori_w = int(torch.sum(~img_mask[0, :])) 31 | 32 | # visualize preds 33 | hoi_scores = outputs["logits_per_hoi"][b].softmax(dim=-1) 34 | box_scores = outputs["box_scores"][b].sigmoid() 35 | scores = (hoi_scores * box_scores).detach().cpu() 36 | 37 | boxes = outputs["pred_boxes"][b].detach().cpu() 38 | pboxes = box_ops.box_cxcywh_to_xyxy(boxes[:, :4]) 39 | oboxes = box_ops.box_cxcywh_to_xyxy(boxes[:, 4:]) 40 | pboxes[:, 0::2] = pboxes[:, 0::2] * ori_w 41 | pboxes[:, 1::2] = pboxes[:, 1::2] * ori_h 42 | oboxes[:, 0::2] = oboxes[:, 0::2] * ori_w 43 | oboxes[:, 1::2] = oboxes[:, 1::2] * ori_h 44 | 45 | keep = torch.nonzero(scores > vis_threshold, as_tuple=True) 46 | scores = scores[keep].numpy() 47 | classes = keep[1].numpy() 48 | pboxes = pboxes[keep[0]].numpy() 49 | oboxes = oboxes[keep[0]].numpy() 50 | 51 | # draw predictions in descending order 52 | indices = np.argsort(scores)[::-1] 53 | for i in indices: 54 | hoi_id = int(classes[i]) 55 | img_pd = Image.fromarray(np.uint8(img_rgb)) 56 | drawing = ImageDraw.Draw(img_pd) 57 | top_left = (int(pboxes[i, 0]), int(pboxes[i, 1])) 58 | bottom_right = (int(pboxes[i, 2]), int(pboxes[i, 3])) 59 | draw_rectangle(drawing, (top_left, bottom_right), color="blue") 60 | 61 | top_left = (int(oboxes[i, 0]), int(oboxes[i, 1])) 62 | bottom_right = (int(oboxes[i, 2]), int(oboxes[i, 3])) 63 | draw_rectangle(drawing, (top_left, bottom_right), color="red") 64 | 65 | dst = Image.new('RGB', (img_pd.width, img_pd.height)) 66 | 
dst.paste(img_pd, (0, 0)) 67 | dst.save(self.vis_dir.joinpath(f'image_{img_id}_hoi_{hoi_id}_score_{scores[i]:.2f}.jpg')) 68 | 69 | 70 | def visualize_attention(self, images, targets, outputs, vis_threshold=0.1): 71 | vis_images = images.tensors.permute(0, 2, 3, 1).detach().cpu().numpy() 72 | image_masks = images.mask 73 | bs, h, w, _ = vis_images.shape 74 | 75 | for b in range(bs): 76 | img_rgb = vis_images[b] 77 | img_rgb = img_rgb - img_rgb.min() 78 | img_rgb = (img_rgb / img_rgb.max()) * 255 79 | img_pd = Image.fromarray(np.uint8(img_rgb)) 80 | 81 | img_id = int(targets[b]["image_id"]) 82 | img_mask = image_masks[b] 83 | ori_h = int(torch.sum(~img_mask[:, 0])) 84 | ori_w = int(torch.sum(~img_mask[0, :])) 85 | 86 | # visualize preds 87 | hoi_scores = outputs["logits_per_hoi"][b].softmax(dim=-1) 88 | box_scores = outputs["box_scores"][b].sigmoid() 89 | scores = (hoi_scores * box_scores).detach().cpu() 90 | 91 | boxes = outputs["pred_boxes"][b].detach().cpu() 92 | pboxes = box_ops.box_cxcywh_to_xyxy(boxes[:, :4]) 93 | oboxes = box_ops.box_cxcywh_to_xyxy(boxes[:, 4:]) 94 | pboxes[:, 0::2] = pboxes[:, 0::2] * ori_w 95 | pboxes[:, 1::2] = pboxes[:, 1::2] * ori_h 96 | oboxes[:, 0::2] = oboxes[:, 0::2] * ori_w 97 | oboxes[:, 1::2] = oboxes[:, 1::2] * ori_h 98 | 99 | keep = torch.nonzero(scores > vis_threshold, as_tuple=True) 100 | scores = scores[keep].numpy() 101 | classes = keep[1].numpy() 102 | pboxes = pboxes[keep[0]].numpy() 103 | oboxes = oboxes[keep[0]].numpy() 104 | 105 | # draw predictions in descending order 106 | indices = np.argsort(scores)[::-1] 107 | for i in indices: 108 | hoi_id = int(classes[i]) 109 | img_pd = Image.fromarray(np.uint8(img_rgb)) 110 | drawing = ImageDraw.Draw(img_pd) 111 | top_left = (int(pboxes[i, 0]), int(pboxes[i, 1])) 112 | bottom_right = (int(pboxes[i, 2]), int(pboxes[i, 3])) 113 | draw_rectangle(drawing, (top_left, bottom_right), color="blue") 114 | 115 | top_left = (int(oboxes[i, 0]), int(oboxes[i, 1])) 116 | bottom_right = (int(oboxes[i, 2]), int(oboxes[i, 3])) 117 | draw_rectangle(drawing, (top_left, bottom_right), color="red") 118 | 119 | dst = Image.new('RGB', (img_pd.width, img_pd.height)) 120 | dst.paste(img_pd, (0, 0)) 121 | dst.save(self.vis_dir.joinpath(f'image_{img_id}_hoi_{hoi_id}_score_{scores[i]:.2f}.jpg')) 122 | 123 | # visualize attention maps 124 | attn_map = outputs["attn_maps"][b] 125 | token_id = keep[0][i] 126 | attn = attn_map[token_id, 1:].view(1, h // self.patch_size, w // self.patch_size) 127 | attn = attn - attn.min() 128 | attn = attn / attn.max() 129 | attn = F.interpolate(attn.unsqueeze(0), scale_factor=self.patch_size, mode="nearest")[0][0].detach().cpu().numpy() 130 | plt.imsave(self.vis_dir.joinpath(f'image_{img_id}_hoi_{hoi_id}_score_{scores[i]:.2f}_attn.jpg'), arr=attn, format='jpg') 131 | 132 | 133 | def draw_rectangle(draw, coordinates, color, width=1): 134 | for i in range(width): 135 | rect_start = (coordinates[0][0] - i, coordinates[0][1] - i) 136 | rect_end = (coordinates[1][0] + i, coordinates[1][1] + i) 137 | draw.rectangle((rect_start, rect_end), outline = color) --------------------------------------------------------------------------------
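A minimal sketch of how the `draw_rectangle` helper above can be exercised on its own (illustrative only: the canvas size, box coordinates, and output filename are made up, and the import assumes the repository root is on `PYTHONPATH`). It mirrors the color convention of `Visualizer.visualize_preds`, where the person box is drawn in blue and the object box in red.

```python
from PIL import Image, ImageDraw

from utils.visualizer import draw_rectangle

# Blank canvas with one person box (blue) and one object box (red).
img = Image.new("RGB", (256, 256), color="white")
drawing = ImageDraw.Draw(img)
draw_rectangle(drawing, ((16, 16), (128, 192)), color="blue", width=3)
draw_rectangle(drawing, ((96, 64), (224, 176)), color="red", width=3)
img.save("draw_rectangle_demo.jpg")
```

Since `draw_rectangle` draws `width` nested one-pixel rectangles, the outline thickness grows outward from the given coordinates rather than inward.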