├── .gitignore ├── LICENSE ├── README.md ├── configs ├── __init__.py ├── default.py ├── hico.yaml └── hoia.yaml ├── data ├── hico │ ├── rel_np.npy │ ├── test_hico.json │ ├── train_anno.json │ └── trainval_hico.json └── hoia │ ├── corre_hoia.npy │ ├── test_hoia.json │ └── train_anno.json ├── eval_hico.sh ├── eval_hoia.sh ├── eval_tools ├── hico_eval.py └── hoia_eval.py ├── libs ├── datasets │ ├── collate.py │ ├── hico_det.py │ └── transform.py ├── models │ ├── asnet.py │ ├── backbone.py │ ├── hoia_asnet.py │ ├── matcher.py │ ├── position_encoding.py │ └── transformer.py ├── trainer │ ├── hoi_trainer.py │ └── trainer.py └── utils │ ├── box_ops.py │ ├── misc.py │ └── utils.py ├── requirements.txt ├── tools ├── _init_paths.py ├── eval.py └── train.py └── train.sh /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | checkpoints/ 132 | output/ 133 | data/hico/images 134 | ASNet_* -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Mingfei Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AS-Net 2 | Code for one-stage adaptive set-based HOI detector AS-Net. 3 | 4 | Mingfei Chen*, Yue Liao*, Si Liu, Zhiyuan Chen, Fei Wang, Chen Qian. "Reformulating HOI Detection as Adaptive Set Prediction." Accepted to CVPR 2021. 5 | https://arxiv.org/abs/2103.05983 6 | 7 | ## Installation 8 | Environment 9 | - python >= 3.6 10 | 11 | Install the dependencies. 12 | ```shell 13 | pip install -r requirements.txt 14 | ``` 15 | 16 | ## Data preparation 17 | - We first download the [ HICO-DET ](https://drive.google.com/open?id=1QZcJmGVlF9f4h-XLWe9Gkmnmj2z1gSnk " HICO-DET ") dataset. 18 | - The data should be prepared in the following structure: 19 | ``` 20 | data/hico 21 | |——— images 22 | | └——————train 23 | | | └——————anno.json 24 | | | └——————XXX1.jpg 25 | | | └——————XXX2.jpg 26 | | └——————test 27 | | └——————anno.json 28 | | └——————XXX1.jpg 29 | | └——————XXX2.jpg 30 | └——— test_hico.json 31 | └——— trainval_hico.json 32 | └——— rel_np.npy 33 | ``` 34 | Noted: 35 | - We transformed the original annotation files of HICO-DET to a *.json format, like data/hico/images/train_anno.json and ata/hico/images/test_hico.json. 36 | - test_hico.json, trainval_hico.json and rel_np.npy are used in the evaluation on HICO-DET. We provided these three files in our data/hico directory. 37 | - data/hico/train_anno.json and data/hico/images/train/anno.json are the same file. 
38 | `cp data/hico/train_anno.json data/hico/images/train/anno.json` 39 | - data/hico/test_hico.json and data/hico/images/test/anno.json are the same file. 40 | `cp data/hico/test_hico.json data/hico/images/test/anno.json` 41 | 42 | ## Evaluation 43 | To evaluate our model on HICO-DET: 44 | ```shell 45 | python3 tools/eval.py --cfg configs/hico.yaml MODEL.RESUME_PATH [checkpoint_path] 46 | ``` 47 | - The checkpoint is saved on HICO-DET with torch==1.4.0. 48 | - Checkpoint path:[ ASNet_hico_res50.pth ](https://drive.google.com/file/d/1EIE7KxqQO0DHU1GDRznnHnahlpOHDk6U/view?usp=sharing " ASNet_hico_res50.pth "). 49 | - Currently support evaluation on single GPU. 50 | 51 | ## Train 52 | To train our model on HICO-DET: 53 | ```shell 54 | CUDA_VISIBLE_DEVICES=0 python3 tools/train.py --cfg configs/hico.yaml MODEL.RESUME_PATH [pretrained path] 55 | ``` 56 | 57 | - The pretrained model of DETR detector [ detr-r50-e632da11.pth ]( https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth " detr-r50-e632da11.pth "). 58 | - Other pretrained models of DETR detector can be downloaded from [ detr-github ]( https://github.com/facebookresearch/detr " detr-github "). 59 | - Download the pretrain model to the [pretrained path]. 60 | 61 | 62 | ## HOIA 63 | - First download the [ HOIA ](https://drive.google.com/drive/folders/15xrIt-biSmE9hEJ2W6lWlUmdDmhatjKt " HOIA ") dataset. We also provide our transformed annotations in data/hoia. 64 | - The data preparation and training is following our data preparation and training process for HICO-DET. You need to modify the config file to hoia.yaml. 65 | - Checkpoint path:[ ASNet_hoia_res50.pth ](https://drive.google.com/file/d/1u6bCUZk063T2z5CKGwQfqWqeGKpta6kw/view?usp=sharing " ASNet_hoia_res50.pth "). 66 | 67 | ## Citation 68 | ``` 69 | @inproceedings{chen_2021_asnet, 70 | author = {Chen, Mingfei and Liao, Yue and Liu, Si and Chen, Zhiyuan and Wang, Fei and Qian, Chen}, 71 | title = {Reformulating HOI Detection as Adaptive Set Prediction}, 72 | booktitle={CVPR}, 73 | year = {2021}, 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /configs/__init__.py: -------------------------------------------------------------------------------- 1 | from .default import _C as cfg 2 | from .default import update_config -------------------------------------------------------------------------------- /configs/default.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | INF = 1e8 4 | 5 | _C = CN() 6 | 7 | # working dir 8 | _C.OUTPUT_ROOT = '' 9 | 10 | # distribution 11 | _C.DIST_BACKEND = 'nccl' 12 | _C.DEVICE = 'cuda' 13 | _C.WORKERS = 4 14 | _C.PI = 'mAP' 15 | _C.SEED = 42 16 | 17 | # cudnn related params 18 | _C.CUDNN = CN() 19 | _C.CUDNN.BENCHMARK = True 20 | _C.CUDNN.DETERMINISTIC = False 21 | _C.CUDNN.ENABLED = True 22 | 23 | # dataset 24 | _C.DATASET = CN() 25 | _C.DATASET.FILE = 'hoi_det' 26 | _C.DATASET.NAME = 'HICODetDataset' 27 | _C.DATASET.ROOT = '' 28 | _C.DATASET.MEAN = [] 29 | _C.DATASET.STD = [] 30 | _C.DATASET.MAX_SIZE = 1333 31 | _C.DATASET.SCALES = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 32 | _C.DATASET.IMG_NUM_PER_GPU = 2 33 | _C.DATASET.SUB_NUM_CLASSES = 1 34 | _C.DATASET.OBJ_NUM_CLASSES = 89 35 | _C.DATASET.REL_NUM_CLASSES = 117 36 | 37 | # model 38 | _C.MODEL = CN() 39 | # specific model 40 | _C.MODEL.FILE = '' 41 | _C.MODEL.NAME = '' 42 | # resume 43 | _C.MODEL.RESUME_PATH = '' 44 | _C.MODEL.MASKS = False 45 | 46 
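# (the defaults above are overridden per experiment: update_config() below first merges the
#  YAML file passed via --cfg, e.g. configs/hico.yaml, and then any KEY VALUE pairs given on
#  the command line, such as MODEL.RESUME_PATH <checkpoint>)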
| # backbone 47 | _C.BACKBONE = CN() 48 | _C.BACKBONE.NAME = 'resnet50' 49 | _C.BACKBONE.DIALATION = False 50 | 51 | # transformer 52 | _C.TRANSFORMER = CN() 53 | _C.TRANSFORMER.BRANCH_AGGREGATION = False 54 | _C.TRANSFORMER.POSITION_EMBEDDING = 'sine' # choices=('sine', 'learned') 55 | _C.TRANSFORMER.HIDDEN_DIM = 256 56 | _C.TRANSFORMER.ENC_LAYERS = 6 57 | _C.TRANSFORMER.DEC_LAYERS = 6 58 | _C.TRANSFORMER.DIM_FEEDFORWARD = 2048 59 | _C.TRANSFORMER.DROPOUT = 0.1 60 | _C.TRANSFORMER.NHEADS = 8 61 | _C.TRANSFORMER.NUM_QUERIES = 100 62 | _C.TRANSFORMER.REL_NUM_QUERIES = 16 63 | _C.TRANSFORMER.PRE_NORM = False 64 | 65 | # matcher 66 | _C.MATCHER = CN() 67 | _C.MATCHER.COST_CLASS = 1 68 | _C.MATCHER.COST_BBOX = 5 69 | _C.MATCHER.COST_GIOU = 2 70 | 71 | # LOSS 72 | _C.LOSS = CN() 73 | _C.LOSS.AUX_LOSS = True 74 | _C.LOSS.DICE_LOSS_COEF = 1 75 | _C.LOSS.DET_CLS_COEF = [1, 1] 76 | _C.LOSS.REL_CLS_COEF = 1 77 | _C.LOSS.BBOX_LOSS_COEF = [5, 5] 78 | _C.LOSS.GIOU_LOSS_COEF = [2, 2] 79 | _C.LOSS.EOS_COEF = 0.1 80 | 81 | # trainer 82 | _C.TRAINER = CN() 83 | _C.TRAINER.FILE = '' 84 | _C.TRAINER.NAME = '' 85 | 86 | # train 87 | _C.TRAIN = CN() 88 | _C.TRAIN.OPTIMIZER = '' 89 | _C.TRAIN.LR = 0.0001 90 | _C.TRAIN.LR_BACKBONE = 0.00001 91 | _C.TRAIN.MOMENTUM = 0.9 92 | _C.TRAIN.WEIGHT_DECAY = 0.0001 93 | # optimizer SGD 94 | _C.TRAIN.NESTEROV = False 95 | # learning rate scheduler 96 | _C.TRAIN.LR_FACTOR = 0.1 97 | _C.TRAIN.LR_DROP = 70 98 | _C.TRAIN.CLIP_MAX_NORM = 0.1 99 | _C.TRAIN.MAX_EPOCH = 100 100 | # train resume 101 | _C.TRAIN.RESUME = False 102 | # print freq 103 | _C.TRAIN.PRINT_FREQ = 20 104 | # save checkpoint during train 105 | _C.TRAIN.SAVE_INTERVAL = 5000 106 | _C.TRAIN.SAVE_EVERY_CHECKPOINT = False 107 | # val when train 108 | _C.TRAIN.VAL_WHEN_TRAIN = False 109 | 110 | # test 111 | _C.TEST = CN() 112 | _C.TEST.REL_ARRAY_PATH = '' 113 | _C.TEST.USE_EMB = False 114 | _C.TEST.MODE = '' 115 | 116 | 117 | def update_config(config, args): 118 | config.defrost() 119 | # set cfg using yaml config file 120 | config.merge_from_file(args.yaml_file) 121 | # update cfg using args 122 | config.merge_from_list(args.opts) 123 | config.freeze() 124 | 125 | 126 | if __name__ == '__main__': 127 | import sys 128 | with open(sys.argv[1], 'w') as f: 129 | print(_C, file=f) -------------------------------------------------------------------------------- /configs/hico.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_ROOT: ASNet_hico 2 | DIST_BACKEND: 'nccl' 3 | WORKERS: 4 4 | DEVICE: cuda 5 | SEED: 42 6 | PI: mAP 7 | CUDNN: 8 | BENCHMARK: False 9 | DETERMINISTIC: False 10 | ENABLED: True 11 | DATASET: 12 | FILE: hico_det 13 | NAME: HICODetDataset 14 | ROOT: 'data/hico/images/' 15 | MEAN: [0.485, 0.456, 0.406] 16 | STD: [0.229, 0.224, 0.225] 17 | IMG_NUM_PER_GPU: 2 18 | SUB_NUM_CLASSES: 1 19 | OBJ_NUM_CLASSES: 91 20 | REL_NUM_CLASSES: 117 21 | MAX_SIZE: 1333 22 | SCALES: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 23 | MODEL: 24 | FILE: asnet 25 | NAME: ASNet 26 | RESUME_PATH: 'data/detr-r50-e632da11.pth' 27 | BACKBONE: 28 | NAME: resnet50 29 | DIALATION: False 30 | TRANSFORMER: 31 | BRANCH_AGGREGATION: True 32 | POSITION_EMBEDDING: sine 33 | HIDDEN_DIM: 256 34 | ENC_LAYERS: 6 35 | DEC_LAYERS: 6 36 | DIM_FEEDFORWARD: 2048 37 | DROPOUT: 0.1 38 | NHEADS: 8 39 | NUM_QUERIES: 100 40 | REL_NUM_QUERIES: 16 41 | PRE_NORM: False 42 | MATCHER: 43 | COST_CLASS: 1 44 | COST_BBOX: 5 45 | COST_GIOU: 2 46 | LOSS: 47 | AUX_LOSS: True 48 | DICE_LOSS_COEF: 1 49 | 
REL_CLS_COEF: 1 50 | DET_CLS_COEF: [1, 1] 51 | BBOX_LOSS_COEF: [5, 5] 52 | GIOU_LOSS_COEF: [2, 2] 53 | EOS_COEF: 0.1 54 | TRAINER: 55 | FILE: hoi_trainer 56 | NAME: HOITrainer 57 | TRAIN: 58 | LR: 0.0001 59 | LR_BACKBONE: 0.00001 60 | MOMENTUM: 0.9 61 | WEIGHT_DECAY: 0.0001 62 | LR_DROP: 55 63 | MAX_EPOCH: 100 64 | PRINT_FREQ: 20 65 | SAVE_INTERVAL: 5 66 | SAVE_EVERY_CHECKPOINT: True 67 | VAL_WHEN_TRAIN: False 68 | CLIP_MAX_NORM: 0.1 69 | RESUME: True 70 | TEST: 71 | REL_ARRAY_PATH: data/hico/rel_np.npy 72 | USE_EMB: TRUE 73 | MODE: hico 74 | -------------------------------------------------------------------------------- /configs/hoia.yaml: -------------------------------------------------------------------------------- 1 | OUTPUT_ROOT: ASNet_hoia 2 | DIST_BACKEND: 'nccl' 3 | WORKERS: 4 4 | DEVICE: cuda 5 | SEED: 42 6 | PI: mAP 7 | CUDNN: 8 | BENCHMARK: False 9 | DETERMINISTIC: False 10 | ENABLED: True 11 | DATASET: 12 | FILE: hico_det 13 | NAME: HICODetDataset 14 | ROOT: 'data/hoia/' 15 | MEAN: [0.485, 0.456, 0.406] 16 | STD: [0.229, 0.224, 0.225] 17 | IMG_NUM_PER_GPU: 2 18 | SUB_NUM_CLASSES: 1 19 | OBJ_NUM_CLASSES: 12 20 | REL_NUM_CLASSES: 10 21 | MAX_SIZE: 1333 22 | SCALES: [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] 23 | MODEL: 24 | FILE: hoia_asnet 25 | NAME: ASNet_HOIA 26 | RESUME_PATH: 'data/detr-r50-e632da11.pth' 27 | BACKBONE: 28 | NAME: resnet50 29 | DIALATION: False 30 | TRANSFORMER: 31 | BRANCH_AGGREGATION: True 32 | POSITION_EMBEDDING: sine 33 | HIDDEN_DIM: 256 34 | ENC_LAYERS: 6 35 | DEC_LAYERS: 6 36 | DIM_FEEDFORWARD: 2048 37 | DROPOUT: 0.1 38 | NHEADS: 8 39 | NUM_QUERIES: 100 40 | REL_NUM_QUERIES: 16 41 | PRE_NORM: False 42 | MATCHER: 43 | COST_CLASS: 1 44 | COST_BBOX: 5 45 | COST_GIOU: 2 46 | LOSS: 47 | AUX_LOSS: True 48 | DICE_LOSS_COEF: 1 49 | REL_CLS_COEF: 1 50 | DET_CLS_COEF: [1, 1] 51 | BBOX_LOSS_COEF: [5, 5] 52 | GIOU_LOSS_COEF: [2, 2] 53 | EOS_COEF: 0.1 54 | TRAINER: 55 | FILE: hoi_trainer 56 | NAME: HOITrainer 57 | TRAIN: 58 | LR: 0.0001 59 | LR_BACKBONE: 0.00001 60 | MOMENTUM: 0.9 61 | WEIGHT_DECAY: 0.0001 62 | LR_DROP: 55 63 | MAX_EPOCH: 100 64 | PRINT_FREQ: 20 65 | SAVE_INTERVAL: 5 66 | SAVE_EVERY_CHECKPOINT: True 67 | VAL_WHEN_TRAIN: False 68 | CLIP_MAX_NORM: 0.1 69 | RESUME: True 70 | TEST: 71 | REL_ARRAY_PATH: data/hoia/corre_hoia.npy 72 | USE_EMB: TRUE 73 | MODE: hoia 74 | -------------------------------------------------------------------------------- /data/hico/rel_np.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyomimi/AS-Net/85ce753707c6d1838c3983111ccbba4b1861f438/data/hico/rel_np.npy -------------------------------------------------------------------------------- /data/hoia/corre_hoia.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoyomimi/AS-Net/85ce753707c6d1838c3983111ccbba4b1861f438/data/hoia/corre_hoia.npy -------------------------------------------------------------------------------- /eval_hico.sh: -------------------------------------------------------------------------------- 1 | python3 tools/eval.py --cfg configs/hico.yaml \ 2 | MODEL.RESUME_PATH checkpoints/ASNet_hico_res50.pth -------------------------------------------------------------------------------- /eval_hoia.sh: -------------------------------------------------------------------------------- 1 | python3 tools/eval.py --cfg configs/hoia.yaml \ 2 | MODEL.RESUME_PATH checkpoints/ASNet_hoia_res50.pth 
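The end-to-end evaluation is driven by tools/eval.py (listed in the tree above). As a rough, non-authoritative sketch, the `hico` evaluator class defined in eval_tools/hico_eval.py below can also be driven directly once predictions have been dumped to JSON. Assumptions in this sketch: the repository root is on PYTHONPATH, and the prediction dump `output/hico_predictions.json` (and the step that produces it) are hypothetical; the field layout simply mirrors what `evalution()` reads.

```python
import json

from eval_tools.hico_eval import hico

# ground-truth and training annotations shipped in data/hico (see README)
evaluator = hico('data/hico/test_hico.json', 'data/hico/trainval_hico.json')

# hypothetical prediction dump; one entry per test image
with open('output/hico_predictions.json') as f:
    predictions = json.load(f)
# each entry:
# {'file_name': 'HICO_test2015_....jpg',
#  'predictions':    [{'bbox': [x1, y1, x2, y2], 'category_id': obj_cls}, ...],  # zero-based boxes,
#                                                                                 # shifted by +1 internally (add_One)
#  'hoi_prediction': [{'subject_id': i, 'object_id': j, 'category_id': verb, 'score': s}, ...]}
mean_ap = evaluator.evalution(predictions)  # note: the method is spelled 'evalution' in the source
```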
-------------------------------------------------------------------------------- /eval_tools/hico_eval.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | 4 | class hico(): 5 | def __init__(self, annotation_file, train_annotation): 6 | self.annotations = mmcv.load(annotation_file) 7 | self.train_annotations = mmcv.load(train_annotation) 8 | self.overlap_iou = 0.5 9 | self.verb_name_dict = [] 10 | self.fp = {} 11 | self.tp = {} 12 | self.score = {} 13 | self.sum_gt = {} 14 | self.file_name = [] 15 | self.train_sum = {} 16 | for gt_i in self.annotations: 17 | self.file_name.append(gt_i['file_name']) 18 | gt_hoi = gt_i['hoi_annotation'] 19 | gt_bbox = gt_i['annotations'] 20 | for gt_hoi_i in gt_hoi: 21 | if isinstance(gt_hoi_i['category_id'], str): 22 | gt_hoi_i['category_id'] = int(gt_hoi_i['category_id'].replace('\n', '')) 23 | triplet = [gt_bbox[gt_hoi_i['subject_id']]['category_id'],gt_bbox[gt_hoi_i['object_id']]['category_id'],gt_hoi_i['category_id']] 24 | if triplet not in self.verb_name_dict: 25 | self.verb_name_dict.append(triplet) 26 | if self.verb_name_dict.index(triplet) not in self.sum_gt.keys(): 27 | self.sum_gt[self.verb_name_dict.index(triplet)] =0 28 | self.sum_gt[self.verb_name_dict.index(triplet)] += 1 29 | for train_i in self.train_annotations: 30 | train_hoi = train_i['hoi_annotation'] 31 | train_bbox = train_i['annotations'] 32 | for train_hoi_i in train_hoi: 33 | if isinstance(train_hoi_i['category_id'], str): 34 | train_hoi_i['category_id'] = int(train_hoi_i['category_id'].replace('\n', '')) 35 | triplet = [train_bbox[train_hoi_i['subject_id']]['category_id'],train_bbox[train_hoi_i['object_id']]['category_id'],train_hoi_i['category_id']] 36 | if triplet not in self.verb_name_dict: 37 | continue 38 | if self.verb_name_dict.index(triplet) not in self.train_sum.keys(): 39 | self.train_sum[self.verb_name_dict.index(triplet)] =0 40 | self.train_sum[self.verb_name_dict.index(triplet)] += 1 41 | for i in range(len(self.verb_name_dict)): 42 | self.fp[i] = [] 43 | self.tp[i] = [] 44 | self.score[i] = [] 45 | self.r_inds = [] 46 | self.c_inds = [] 47 | for id in self.train_sum.keys(): 48 | if self.train_sum[id] < 10: 49 | self.r_inds.append(id) 50 | else: 51 | self.c_inds.append(id) 52 | self.num_class = len(self.verb_name_dict) 53 | 54 | def evalution(self, predict_annot): 55 | for pred_i in predict_annot: 56 | if pred_i['file_name'] not in self.file_name: 57 | continue 58 | gt_i = self.annotations[self.file_name.index(pred_i['file_name'])] 59 | gt_bbox = gt_i['annotations'] 60 | if len(gt_bbox)!=0: 61 | pred_bbox = self.add_One(pred_i['predictions']) #convert zero-based to one-based indices 62 | bbox_pairs, bbox_ov = self.compute_iou_mat(gt_bbox, pred_bbox) 63 | pred_hoi = pred_i['hoi_prediction'] 64 | gt_hoi = gt_i['hoi_annotation'] 65 | self.compute_fptp(pred_hoi, gt_hoi, bbox_pairs, pred_bbox,bbox_ov) 66 | else: 67 | pred_bbox = self.add_One(pred_i['predictions']) #convert zero-based to one-based indices 68 | for i, pred_hoi_i in enumerate(pred_i['hoi_prediction']): 69 | triplet = [pred_bbox[pred_hoi_i['subject_id']]['category_id'], 70 | pred_bbox[pred_hoi_i['object_id']]['category_id'], pred_hoi_i['category_id']] 71 | verb_id = self.verb_name_dict.index(triplet) 72 | self.tp[verb_id].append(0) 73 | self.fp[verb_id].append(1) 74 | self.score[verb_id].append(pred_hoi_i['score']) 75 | map = self.compute_map() 76 | return map 77 | 78 | def compute_map(self): 79 | ap = np.zeros(self.num_class) 80 | 
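        # AP is computed per HOI triplet class with 11-point VOC interpolation (voc_ap);
        # classes with fewer than 10 training instances (r_inds) form the 'rare' split,
        # the remaining classes (c_inds) the 'non-rare' split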
max_recall = np.zeros(self.num_class) 81 | for i in range(len(self.verb_name_dict)): 82 | sum_gt = self.sum_gt[i] 83 | 84 | if sum_gt == 0: 85 | continue 86 | tp = np.asarray((self.tp[i]).copy()) 87 | fp = np.asarray((self.fp[i]).copy()) 88 | res_num = len(tp) 89 | if res_num == 0: 90 | continue 91 | score = np.asarray(self.score[i].copy()) 92 | sort_inds = np.argsort(-score) 93 | fp = fp[sort_inds] 94 | tp = tp[sort_inds] 95 | fp = np.cumsum(fp) 96 | tp = np.cumsum(tp) 97 | rec = tp / sum_gt 98 | prec = tp / (fp + tp) 99 | ap[i] = self.voc_ap(rec,prec) 100 | max_recall[i] = np.max(rec) 101 | mAP = np.mean(ap[:]) 102 | mAP_rare = np.mean(ap[self.r_inds]) 103 | mAP_nonrare = np.mean(ap[self.c_inds]) 104 | m_rec = np.mean(max_recall[:]) 105 | print('--------------------') 106 | print('mAP: {} mAP rare: {} mAP nonrare: {} max recall: {}'.format(mAP, mAP_rare, mAP_nonrare, m_rec)) 107 | print('--------------------') 108 | return mAP 109 | 110 | def voc_ap(self, rec, prec): 111 | ap = 0. 112 | for t in np.arange(0., 1.1, 0.1): 113 | if np.sum(rec >= t) == 0: 114 | p = 0 115 | else: 116 | p = np.max(prec[rec >= t]) 117 | ap = ap + p / 11. 118 | return ap 119 | 120 | def compute_fptp(self, pred_hoi, gt_hoi, match_pairs, pred_bbox,bbox_ov): 121 | pos_pred_ids = match_pairs.keys() 122 | vis_tag = np.zeros(len(gt_hoi)) 123 | pred_hoi.sort(key=lambda k: (k.get('score', 0)), reverse=True) 124 | if len(pred_hoi) != 0: 125 | for i, pred_hoi_i in enumerate(pred_hoi): 126 | is_match = 0 127 | if isinstance(pred_hoi_i['category_id'], str): 128 | pred_hoi_i['category_id'] = int(pred_hoi_i['category_id'].replace('\n', '')) 129 | if len(match_pairs) != 0 and pred_hoi_i['subject_id'] in pos_pred_ids and pred_hoi_i['object_id'] in pos_pred_ids: 130 | pred_sub_ids = match_pairs[pred_hoi_i['subject_id']] 131 | pred_obj_ids = match_pairs[pred_hoi_i['object_id']] 132 | pred_obj_ov=bbox_ov[pred_hoi_i['object_id']] 133 | pred_sub_ov=bbox_ov[pred_hoi_i['subject_id']] 134 | pred_category_id = pred_hoi_i['category_id'] 135 | max_ov=0 136 | max_gt_id=0 137 | for gt_id in range(len(gt_hoi)): 138 | gt_hoi_i = gt_hoi[gt_id] 139 | if (gt_hoi_i['subject_id'] in pred_sub_ids) and (gt_hoi_i['object_id'] in pred_obj_ids) and (pred_category_id == gt_hoi_i['category_id']): 140 | is_match = 1 141 | min_ov_gt=min(pred_sub_ov[pred_sub_ids.index(gt_hoi_i['subject_id'])], pred_obj_ov[pred_obj_ids.index(gt_hoi_i['object_id'])]) 142 | if min_ov_gt>max_ov: 143 | max_ov=min_ov_gt 144 | max_gt_id=gt_id 145 | triplet = [pred_bbox[pred_hoi_i['subject_id']]['category_id'], pred_bbox[pred_hoi_i['object_id']]['category_id'], pred_hoi_i['category_id']] 146 | if triplet not in self.verb_name_dict: 147 | continue 148 | verb_id = self.verb_name_dict.index(triplet) 149 | self.fp.setdefault(verb_id, []) 150 | self.tp.setdefault(verb_id, []) 151 | self.score.setdefault(verb_id, []) 152 | if is_match == 1 and vis_tag[max_gt_id] == 0: 153 | self.fp[verb_id].append(0) 154 | self.tp[verb_id].append(1) 155 | vis_tag[max_gt_id] =1 156 | else: 157 | self.fp[verb_id].append(1) 158 | self.tp[verb_id].append(0) 159 | self.score[verb_id].append(pred_hoi_i['score']) 160 | 161 | def compute_iou_mat(self, bbox_list1, bbox_list2): 162 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2))) 163 | if len(bbox_list1) == 0 or len(bbox_list2) == 0: 164 | return {}, {} 165 | for i, bbox1 in enumerate(bbox_list1): 166 | for j, bbox2 in enumerate(bbox_list2): 167 | iou_i = self.compute_IOU(bbox1, bbox2) 168 | iou_mat[i, j] = iou_i 169 | iou_mat_ov=iou_mat.copy() 170 | 
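        # binarize candidate matches at the hard-coded IoU 0.5 (equal to self.overlap_iou),
        # while iou_mat_ov keeps the raw IoUs so compute_fptp can pick, for each prediction,
        # the ground-truth pair with the largest overlap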
iou_mat[iou_mat>= 0.5] = 1 171 | iou_mat[iou_mat< 0.5] = 0 172 | 173 | match_pairs = np.nonzero(iou_mat) 174 | match_pairs_dict = {} 175 | match_pairs_ov={} 176 | if iou_mat.max() > 0: 177 | for i, pred_id in enumerate(match_pairs[1]): 178 | if pred_id not in match_pairs_dict.keys(): 179 | match_pairs_dict[pred_id] = [] 180 | match_pairs_ov[pred_id]=[] 181 | match_pairs_dict[pred_id].append(match_pairs[0][i]) 182 | match_pairs_ov[pred_id].append(iou_mat_ov[match_pairs[0][i],pred_id]) 183 | return match_pairs_dict,match_pairs_ov 184 | 185 | def compute_IOU(self, bbox1, bbox2): 186 | if isinstance(bbox1['category_id'], str): 187 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', '')) 188 | if isinstance(bbox2['category_id'], str): 189 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', '')) 190 | if bbox1['category_id'] == bbox2['category_id']: 191 | rec1 = bbox1['bbox'] 192 | rec2 = bbox2['bbox'] 193 | # computing area of each rectangles 194 | S_rec1 = (rec1[2] - rec1[0]+1) * (rec1[3] - rec1[1]+1) 195 | S_rec2 = (rec2[2] - rec2[0]+1) * (rec2[3] - rec2[1]+1) 196 | # computing the sum_area 197 | sum_area = S_rec1 + S_rec2 198 | # find the each edge of intersect rectangle 199 | left_line = max(rec1[1], rec2[1]) 200 | right_line = min(rec1[3], rec2[3]) 201 | top_line = max(rec1[0], rec2[0]) 202 | bottom_line = min(rec1[2], rec2[2]) 203 | # judge if there is an intersect 204 | if left_line >= right_line or top_line >= bottom_line: 205 | return 0 206 | else: 207 | intersect = (right_line - left_line+1) * (bottom_line - top_line+1) 208 | return intersect / (sum_area - intersect) 209 | else: 210 | return 0 211 | 212 | def add_One(self,prediction): #Add 1 to all coordinates 213 | for i, pred_bbox in enumerate(prediction): 214 | rec = pred_bbox['bbox'] 215 | rec[0]+=1 216 | rec[1]+=1 217 | rec[2]+=1 218 | rec[3]+=1 219 | return prediction -------------------------------------------------------------------------------- /eval_tools/hoia_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | 5 | class hoia(): 6 | def __init__(self, annotation_file): 7 | self.annotations = json.load(open(annotation_file, 'r')) 8 | self.overlap_iou = 0.5 9 | self.verb_name_dict = {1: 'smoke', 2: 'call', 3: 'play(cellphone)', 4: 'eat', 5: 'drink', 10 | 6: 'ride', 7: 'hold', 8: 'kick', 9: 'read', 10: 'play (computer)'} 11 | self.fp = {} 12 | self.tp = {} 13 | self.score = {} 14 | self.sum_gt = {} 15 | for i in list(self.verb_name_dict.keys()): 16 | self.fp[i] = [] 17 | self.tp[i] = [] 18 | self.score[i] = [] 19 | self.sum_gt[i] = 0 20 | self.file_name = [] 21 | for gt_i in self.annotations: 22 | self.file_name.append(gt_i['file_name']) 23 | gt_hoi = gt_i['hoi_annotation'] 24 | for gt_hoi_i in gt_hoi: 25 | if isinstance(gt_hoi_i['category_id'], str): 26 | gt_hoi_i['category_id'] = int(gt_hoi_i['category_id'].replace('\n', '')) 27 | if gt_hoi_i['category_id'] in list(self.verb_name_dict.keys()): 28 | self.sum_gt[gt_hoi_i['category_id']] += 1 29 | self.num_class = len(list(self.verb_name_dict.keys())) 30 | 31 | def evalution(self, predict_annot): 32 | for pred_i in predict_annot: 33 | if pred_i['file_name'] not in self.file_name: 34 | continue 35 | gt_i = self.annotations[self.file_name.index(pred_i['file_name'])] 36 | gt_bbox = gt_i['annotations'] 37 | pred_bbox = pred_i['predictions'] 38 | pred_hoi = pred_i['hoi_prediction'] 39 | gt_hoi = gt_i['hoi_annotation'] 40 | bbox_pairs = self.compute_iou_mat(gt_bbox, 
pred_bbox) 41 | self.compute_fptp(pred_hoi, gt_hoi, bbox_pairs) 42 | map = self.compute_map() 43 | return map 44 | 45 | def compute_map(self): 46 | ap = np.zeros(self.num_class) 47 | max_recall = np.zeros(self.num_class) 48 | for i in list(self.verb_name_dict.keys()): 49 | sum_gt = self.sum_gt[i] 50 | 51 | if sum_gt == 0: 52 | continue 53 | tp = np.asarray((self.tp[i]).copy()) 54 | fp = np.asarray((self.fp[i]).copy()) 55 | res_num = len(tp) 56 | if res_num == 0: 57 | continue 58 | score = np.asarray(self.score[i].copy()) 59 | sort_inds = np.argsort(-score) 60 | fp = fp[sort_inds] 61 | tp = tp[sort_inds] 62 | fp = np.cumsum(fp) 63 | tp = np.cumsum(tp) 64 | rec = tp / sum_gt 65 | prec = tp / (fp + tp) 66 | ap[i - 1] = self.voc_ap(rec,prec) 67 | max_recall[i-1] = np.max(rec) 68 | mAP = np.mean(ap[:]) 69 | m_rec = np.mean(max_recall[:]) 70 | print('--------------------') 71 | print('mAP: {} max recall: {}'.format(mAP, m_rec)) 72 | print('--------------------') 73 | return mAP 74 | 75 | def voc_ap(self, rec, prec): 76 | mrec = np.concatenate(([0.], rec, [1.])) 77 | mpre = np.concatenate(([0.], prec, [0.])) 78 | for i in range(mpre.size - 1, 0, -1): 79 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 80 | i = np.where(mrec[1:] != mrec[:-1])[0] 81 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 82 | return ap 83 | 84 | def compute_fptp(self, pred_hoi, gt_hoi, match_pairs): 85 | pos_pred_ids = match_pairs.keys() 86 | vis_tag = np.zeros(len(gt_hoi)) 87 | pred_hoi.sort(key=lambda k: (k.get('score', 0)), reverse=True) 88 | if len(pred_hoi) != 0: 89 | for i, pred_hoi_i in enumerate(pred_hoi): 90 | is_match = 0 91 | if isinstance(pred_hoi_i['category_id'], str): 92 | pred_hoi_i['category_id'] = int(pred_hoi_i['category_id'].replace('\n', '')) 93 | if len(match_pairs) != 0 and pred_hoi_i['subject_id'] in pos_pred_ids and pred_hoi_i['object_id'] in pos_pred_ids: 94 | pred_sub_ids = match_pairs[pred_hoi_i['subject_id']] 95 | pred_obj_ids = match_pairs[pred_hoi_i['object_id']] 96 | pred_category_id = pred_hoi_i['category_id'] 97 | for gt_id in np.nonzero(1 - vis_tag)[0]: 98 | gt_hoi_i = gt_hoi[gt_id] 99 | if (gt_hoi_i['subject_id'] in pred_sub_ids) and (gt_hoi_i['object_id'] in pred_obj_ids) and (pred_category_id == gt_hoi_i['category_id']): 100 | is_match = 1 101 | vis_tag[gt_id] = 1 102 | continue 103 | if pred_hoi_i['category_id'] not in list(self.fp.keys()): 104 | continue 105 | if is_match == 1: 106 | self.fp[pred_hoi_i['category_id']].append(0) 107 | self.tp[pred_hoi_i['category_id']].append(1) 108 | 109 | else: 110 | self.fp[pred_hoi_i['category_id']].append(1) 111 | self.tp[pred_hoi_i['category_id']].append(0) 112 | self.score[pred_hoi_i['category_id']].append(pred_hoi_i['score']) 113 | 114 | def compute_iou_mat(self, bbox_list1, bbox_list2): 115 | iou_mat = np.zeros((len(bbox_list1), len(bbox_list2))) 116 | if len(bbox_list1) == 0 or len(bbox_list2) == 0: 117 | return {} 118 | for i, bbox1 in enumerate(bbox_list1): 119 | for j, bbox2 in enumerate(bbox_list2): 120 | iou_i = self.compute_IOU(bbox1, bbox2) 121 | iou_mat[i, j] = iou_i 122 | iou_mat[iou_mat>= self.overlap_iou] = 1 123 | iou_mat[iou_mat< self.overlap_iou] = 0 124 | 125 | match_pairs = np.nonzero(iou_mat) 126 | match_pairs_dict = {} 127 | if iou_mat.max() > 0: 128 | for i, pred_id in enumerate(match_pairs[1]): 129 | if pred_id not in match_pairs_dict.keys(): 130 | match_pairs_dict[pred_id] = [] 131 | match_pairs_dict[pred_id].append(match_pairs[0][i]) 132 | return match_pairs_dict 133 | 134 | def compute_IOU(self, bbox1, bbox2): 135 
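        # category-aware IoU: boxes with different category_id are treated as non-overlapping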
| if isinstance(bbox1['category_id'], str): 136 | bbox1['category_id'] = int(bbox1['category_id'].replace('\n', '')) 137 | if isinstance(bbox2['category_id'], str): 138 | bbox2['category_id'] = int(bbox2['category_id'].replace('\n', '')) 139 | if bbox1['category_id'] == bbox2['category_id']: 140 | rec1 = bbox1['bbox'] 141 | rec2 = bbox2['bbox'] 142 | # computing area of each rectangles 143 | S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1]) 144 | S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1]) 145 | 146 | # computing the sum_area 147 | sum_area = S_rec1 + S_rec2 148 | 149 | # find the each edge of intersect rectangle 150 | left_line = max(rec1[1], rec2[1]) 151 | right_line = min(rec1[3], rec2[3]) 152 | top_line = max(rec1[0], rec2[0]) 153 | bottom_line = min(rec1[2], rec2[2]) 154 | # judge if there is an intersect 155 | if left_line >= right_line or top_line >= bottom_line: 156 | return 0 157 | else: 158 | intersect = (right_line - left_line) * (bottom_line - top_line) 159 | return intersect / (sum_area - intersect) 160 | else: 161 | return 0 -------------------------------------------------------------------------------- /libs/datasets/collate.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def collect(batch): 4 | """Collect the data for one batch. 5 | """ 6 | imgs = [] 7 | targets = [] 8 | filenames = [] 9 | for sample in batch: 10 | imgs.append(sample[0]) 11 | targets.append(sample[1]) 12 | filenames.append(sample[2]) 13 | return imgs, targets, filenames -------------------------------------------------------------------------------- /libs/datasets/hico_det.py: -------------------------------------------------------------------------------- 1 | import json 2 | import numpy as np 3 | import os 4 | import os.path as osp 5 | from PIL import Image 6 | from PIL import ImageFile 7 | ImageFile.LOAD_TRUNCATED_IMAGES = True 8 | import random 9 | 10 | import torch 11 | from torch.utils.data import Dataset 12 | 13 | 14 | class HICODetDataset(Dataset): 15 | 16 | def __init__(self, 17 | cfg, 18 | data_root, 19 | transform=None, 20 | istrain=False, 21 | ): 22 | """ 23 | Args: 24 | data_root: absolute root path for train or val data folder 25 | transform: train_transform or eval_transform or prediction_transform 26 | """ 27 | self.num_classes_verb = cfg.DATASET.REL_NUM_CLASSES 28 | self.data_root = data_root 29 | self.labels_path = osp.join(osp.abspath( 30 | self.data_root), 'anno.json') 31 | self.transform = transform 32 | self.hoi_annotations = json.load(open(self.labels_path, 'r')) 33 | self.ids = [] 34 | for i, hico in enumerate(self.hoi_annotations): 35 | flag_bad = 0 36 | if len(hico['annotations']) > cfg.TRANSFORMER.NUM_QUERIES: 37 | flag_bad = 1 38 | continue 39 | for hoi in hico['hoi_annotation']: 40 | if hoi['subject_id'] >= len(hico['annotations']) or hoi[ 41 | 'object_id'] >= len(hico['annotations']): 42 | flag_bad = 1 43 | break 44 | if flag_bad == 0: 45 | self.ids.append(i) 46 | self.neg_rel_id = 0 47 | 48 | def __len__(self): 49 | return len(self.ids) 50 | 51 | def multi_dense_to_one_hot(self, labels, num_classes): 52 | num_labels = labels.shape[0] 53 | index_offset = np.arange(num_labels) * num_classes 54 | labels_one_hot = np.zeros((num_labels, num_classes)) 55 | labels_one_hot.flat[index_offset + labels.ravel()] = 1 56 | one_hot = np.sum(labels_one_hot, axis=0)[1:] 57 | in_valid = np.where(one_hot>1)[0] 58 | one_hot[in_valid] = 1 59 | return one_hot 60 | 61 | def __getitem__(self, index): 62 | ann_id = self.ids[index] 
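        # each kept entry supplies 'file_name', the box 'annotations', and 'hoi_annotation'
        # triplets (subject_id / object_id indices into those boxes, plus the verb category_id)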
63 | file_name = self.hoi_annotations[ann_id]['file_name'] 64 | img_path = os.path.join(self.data_root, file_name) 65 | 66 | anns = self.hoi_annotations[ann_id]['annotations'] 67 | hoi_anns = self.hoi_annotations[ann_id]['hoi_annotation'] 68 | 69 | if not osp.exists(img_path): 70 | logging.error("Cannot found image data: " + img_path) 71 | raise FileNotFoundError 72 | img = Image.open(img_path).convert('RGB') 73 | w, h = img.size 74 | 75 | num_object = len(anns) 76 | num_rels = len(hoi_anns) 77 | boxes = [] 78 | labels = [] 79 | no_object = False 80 | if num_object == 0: 81 | # no gt boxes 82 | no_object = True 83 | boxes = np.array([]).reshape(-1, 4) 84 | labels = np.array([]).reshape(-1,) 85 | else: 86 | for k in range(num_object): 87 | ann = anns[k] 88 | boxes.append(np.asarray(ann['bbox'])) 89 | if isinstance(ann['category_id'], str): 90 | ann['category_id'] = int(ann['category_id'].replace('\n', '')) 91 | cls_id = int(ann['category_id']) 92 | labels.append(cls_id) 93 | boxes = np.vstack(boxes) 94 | 95 | boxes = torch.from_numpy(boxes.reshape(-1, 4).astype(np.float32)) 96 | labels = np.array(labels).reshape(-1,) 97 | target = dict( 98 | boxes=boxes, 99 | labels=labels 100 | ) 101 | if self.transform is not None: 102 | img, target = self.transform( 103 | img, target 104 | ) 105 | target['labels'] = torch.from_numpy(target['labels']).long() 106 | boxes = target['boxes'] 107 | 108 | hoi_labels = [] 109 | hoi_boxes = [] 110 | if num_object == 0: 111 | hoi_boxes = torch.from_numpy(np.array([]).reshape(-1, 4)) 112 | hoi_labels = np.array([]).reshape(-1, self.num_classes_verb) 113 | else: 114 | for k in range(num_rels): 115 | hoi = hoi_anns[k] 116 | if not isinstance(hoi['category_id'], list): 117 | hoi['category_id'] = [hoi['category_id']] 118 | hoi_label_np = np.array(hoi['category_id']) 119 | hoi_labels.append(self.multi_dense_to_one_hot(hoi_label_np, 120 | self.num_classes_verb+1)) 121 | 122 | sub_ct_coord = boxes[hoi['subject_id']][..., :2] 123 | obj_ct_coord = boxes[hoi['object_id']][..., :2] 124 | hoi_boxes.append(torch.cat([sub_ct_coord, obj_ct_coord], dim=-1).reshape(-1, 4)) 125 | hoi_labels = np.array(hoi_labels).reshape(-1, self.num_classes_verb) 126 | 127 | target['rel_labels'] = torch.from_numpy(hoi_labels) 128 | if len(hoi_boxes) == 0: 129 | target['rel_vecs'] = torch.from_numpy(np.array([]).reshape(-1, 4)).float() 130 | else: 131 | target['rel_vecs'] = torch.cat(hoi_boxes).reshape(-1, 4).float() 132 | target['size'] = torch.from_numpy(np.array([h, w])) 133 | return img, target, file_name 134 | -------------------------------------------------------------------------------- /libs/datasets/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import PIL 4 | import torch 5 | import torchvision 6 | import torchvision.transforms as T 7 | import torchvision.transforms.functional as F 8 | 9 | from libs.utils.box_ops import box_xyxy_to_cxcywh 10 | 11 | 12 | def crop(image, target, region): 13 | cropped_image = F.crop(image, *region) 14 | 15 | target = target.copy() 16 | i, j, h, w = region 17 | 18 | # should we do something wrt the original size? 
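    # the crop's (h, w) becomes the new target size; the boxes below are shifted to the
    # crop origin, clipped to the region, and degenerate (zero-area) boxes are dropped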
19 | target["size"] = torch.tensor([h, w]) 20 | 21 | fields = ["labels", "area"] # remove 'iscrowd' 22 | 23 | if "boxes" in target: 24 | boxes = target["boxes"] 25 | max_size = torch.as_tensor([w, h], dtype=torch.float32) 26 | cropped_boxes = boxes - torch.as_tensor([j, i, j, i]) 27 | cropped_boxes = torch.min(cropped_boxes.reshape(-1, 2, 2), max_size) 28 | cropped_boxes = cropped_boxes.clamp(min=0) 29 | area = (cropped_boxes[:, 1, :] - cropped_boxes[:, 0, :]).prod(dim=1) 30 | target["boxes"] = cropped_boxes.reshape(-1, 4) 31 | target["area"] = area 32 | fields.append("boxes") 33 | 34 | if "masks" in target: 35 | # FIXME should we update the area here if there are no boxes? 36 | target['masks'] = target['masks'][:, i:i + h, j:j + w] 37 | fields.append("masks") 38 | 39 | # remove elements for which the boxes or masks that have zero area 40 | if "boxes" in target or "masks" in target: 41 | # favor boxes selection when defining which elements to keep 42 | # this is compatible with previous implementation 43 | if "boxes" in target: 44 | cropped_boxes = target['boxes'].reshape(-1, 2, 2) 45 | keep = torch.all(cropped_boxes[:, 1, :] > cropped_boxes[:, 0, :], dim=1) 46 | else: 47 | keep = target['masks'].flatten(1).any(1) 48 | 49 | for field in fields: 50 | target[field] = target[field][keep] 51 | 52 | return cropped_image, target 53 | 54 | 55 | def hflip(image, target): 56 | flipped_image = F.hflip(image) 57 | 58 | w, h = image.size 59 | 60 | target = target.copy() 61 | if "boxes" in target: 62 | boxes = target["boxes"] 63 | boxes = boxes[:, [2, 1, 0, 3]] * torch.as_tensor([-1, 1, -1, 1]) + torch.as_tensor([w, 0, w, 0]) 64 | target["boxes"] = boxes 65 | 66 | if "masks" in target: 67 | target['masks'] = target['masks'].flip(-1) 68 | 69 | return flipped_image, target 70 | 71 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 72 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 73 | """ 74 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 75 | This will eventually be supported natively by PyTorch, and this 76 | class can go away. 
77 | """ 78 | if float(torchvision.__version__[:3]) < 0.7: 79 | if input.numel() > 0: 80 | return torch.nn.functional.interpolate( 81 | input, size, scale_factor, mode, align_corners 82 | ) 83 | 84 | output_shape = _output_size(2, input, size, scale_factor) 85 | output_shape = list(input.shape[:-2]) + list(output_shape) 86 | return _new_empty_tensor(input, output_shape) 87 | else: 88 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 89 | 90 | 91 | def resize(image, target, size, max_size=None): 92 | # size can be min_size (scalar) or (w, h) tuple 93 | 94 | def get_size_with_aspect_ratio(image_size, size, max_size=None): 95 | w, h = image_size 96 | if max_size is not None: 97 | min_original_size = float(min((w, h))) 98 | max_original_size = float(max((w, h))) 99 | if max_original_size / min_original_size * size > max_size: 100 | size = int(round(max_size * min_original_size / max_original_size)) 101 | 102 | if (w <= h and w == size) or (h <= w and h == size): 103 | return (h, w) 104 | 105 | if w < h: 106 | ow = size 107 | oh = int(size * h / w) 108 | else: 109 | oh = size 110 | ow = int(size * w / h) 111 | 112 | return (oh, ow) 113 | 114 | def get_size(image_size, size, max_size=None): 115 | if isinstance(size, (list, tuple)): 116 | return size[::-1] 117 | else: 118 | return get_size_with_aspect_ratio(image_size, size, max_size) 119 | 120 | size = get_size(image.size, size, max_size) 121 | rescaled_image = F.resize(image, size) 122 | 123 | if target is None: 124 | return rescaled_image, None 125 | 126 | ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) 127 | ratio_width, ratio_height = ratios 128 | 129 | target = target.copy() 130 | if "boxes" in target: 131 | boxes = target["boxes"] 132 | scaled_boxes = boxes * torch.as_tensor([ratio_width, ratio_height, ratio_width, ratio_height]) 133 | target["boxes"] = scaled_boxes 134 | 135 | if "area" in target: 136 | area = target["area"] 137 | scaled_area = area * (ratio_width * ratio_height) 138 | target["area"] = scaled_area 139 | 140 | h, w = size 141 | target["size"] = torch.tensor([h, w]) 142 | 143 | if "masks" in target: 144 | target['masks'] = interpolate( 145 | target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 146 | 147 | return rescaled_image, target 148 | 149 | 150 | def pad(image, target, padding): 151 | # assumes that we only pad on the bottom right corners 152 | padded_image = F.pad(image, (0, 0, padding[0], padding[1])) 153 | if target is None: 154 | return padded_image, None 155 | target = target.copy() 156 | # should we do something wrt the original size? 
157 | target["size"] = torch.tensor(padded_image[::-1]) 158 | if "masks" in target: 159 | target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) 160 | return padded_image, target 161 | 162 | 163 | class RandomCrop(object): 164 | def __init__(self, size): 165 | self.size = size 166 | 167 | def __call__(self, img, target): 168 | region = T.RandomCrop.get_params(img, self.size) 169 | return crop(img, target, region) 170 | 171 | 172 | class RandomSizeCrop(object): 173 | def __init__(self, min_size: int, max_size: int): 174 | self.min_size = min_size 175 | self.max_size = max_size 176 | 177 | def __call__(self, img: PIL.Image.Image, target: dict): 178 | w = random.randint(self.min_size, min(img.width, self.max_size)) 179 | h = random.randint(self.min_size, min(img.height, self.max_size)) 180 | region = T.RandomCrop.get_params(img, [h, w]) 181 | return crop(img, target, region) 182 | 183 | 184 | class CenterCrop(object): 185 | def __init__(self, size): 186 | self.size = size 187 | 188 | def __call__(self, img, target): 189 | image_width, image_height = img.size 190 | crop_height, crop_width = self.size 191 | crop_top = int(round((image_height - crop_height) / 2.)) 192 | crop_left = int(round((image_width - crop_width) / 2.)) 193 | return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) 194 | 195 | 196 | class RandomHorizontalFlip(object): 197 | def __init__(self, p=0.5): 198 | self.p = p 199 | 200 | def __call__(self, img, target): 201 | if random.random() < self.p: 202 | return hflip(img, target) 203 | return img, target 204 | 205 | 206 | class RandomResize(object): 207 | def __init__(self, sizes, max_size=None): 208 | assert isinstance(sizes, (list, tuple)) 209 | self.sizes = sizes 210 | self.max_size = max_size 211 | 212 | def __call__(self, img, target=None): 213 | size = random.choice(self.sizes) 214 | return resize(img, target, size, self.max_size) 215 | 216 | 217 | class RandomPad(object): 218 | def __init__(self, max_pad): 219 | self.max_pad = max_pad 220 | 221 | def __call__(self, img, target): 222 | pad_x = random.randint(0, self.max_pad) 223 | pad_y = random.randint(0, self.max_pad) 224 | return pad(img, target, (pad_x, pad_y)) 225 | 226 | 227 | class RandomSelect(object): 228 | """ 229 | Randomly selects between transforms1 and transforms2, 230 | with probability p for transforms1 and (1 - p) for transforms2 231 | """ 232 | def __init__(self, transforms1, transforms2, p=0.5): 233 | self.transforms1 = transforms1 234 | self.transforms2 = transforms2 235 | self.p = p 236 | 237 | def __call__(self, img, target): 238 | if random.random() < self.p: 239 | return self.transforms1(img, target) 240 | return self.transforms2(img, target) 241 | 242 | 243 | class ToTensor(object): 244 | def __call__(self, img, target): 245 | return F.to_tensor(img), target 246 | 247 | 248 | class RandomErasing(object): 249 | 250 | def __init__(self, *args, **kwargs): 251 | self.eraser = T.RandomErasing(*args, **kwargs) 252 | 253 | def __call__(self, img, target): 254 | return self.eraser(img), target 255 | 256 | 257 | class Normalize(object): 258 | def __init__(self, mean, std): 259 | self.mean = mean 260 | self.std = std 261 | 262 | def __call__(self, image, target=None): 263 | image = F.normalize(image, mean=self.mean, std=self.std) 264 | if target is None: 265 | return image, None 266 | target = target.copy() 267 | h, w = image.shape[-2:] 268 | if "boxes" in target: 269 | boxes = target["boxes"] 270 | boxes = box_xyxy_to_cxcywh(boxes) 271 | boxes = boxes / 
torch.tensor([w, h, w, h], dtype=torch.float32) 272 | target["boxes"] = boxes 273 | return image, target 274 | 275 | 276 | class Compose(object): 277 | def __init__(self, transforms): 278 | self.transforms = transforms 279 | 280 | def __call__(self, image, target): 281 | for t in self.transforms: 282 | image, target = t(image, target) 283 | return image, target 284 | 285 | def __repr__(self): 286 | format_string = self.__class__.__name__ + "(" 287 | for t in self.transforms: 288 | format_string += "\n" 289 | format_string += " {0}".format(t) 290 | format_string += "\n)" 291 | return format_string 292 | 293 | 294 | class TrainTransform(object): 295 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], 296 | scales=[480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800], max_size=1333): 297 | normalize = Compose([ 298 | ToTensor(), 299 | Normalize(mean, std) 300 | ]) 301 | self.augment = Compose([ 302 | RandomHorizontalFlip(), 303 | RandomResize(scales, max_size=max_size), 304 | normalize, 305 | ]) 306 | 307 | def __call__(self, img, target): 308 | # target["boxes"] xyxy; "masks"(optional) 309 | return self.augment(img, target) 310 | 311 | 312 | class EvalTransform(object): 313 | def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], 314 | max_size=1333): 315 | normalize = Compose([ 316 | ToTensor(), 317 | Normalize(mean, std) 318 | ]) 319 | self.augment = Compose([ 320 | RandomResize([800], max_size=max_size), 321 | normalize, 322 | ]) 323 | 324 | def __call__(self, img, target): 325 | return self.augment(img, target) -------------------------------------------------------------------------------- /libs/models/asnet.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | import os 4 | import sys 5 | import time 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | from scipy.spatial.distance import cdist 12 | 13 | from libs.models.backbone import build_backbone 14 | from libs.models.matcher import build_matcher 15 | from libs.models.transformer import build_transformer 16 | from libs.utils import box_ops 17 | from libs.utils.misc import (NestedTensor, nested_tensor_from_tensor_list, 18 | accuracy, get_world_size, interpolate, 19 | is_dist_avail_and_initialized) 20 | 21 | 22 | class ASNet(nn.Module): 23 | """ This is the HOI Transformer module that performs HOI detection """ 24 | def __init__(self, 25 | backbone, 26 | transformer, 27 | num_classes=dict( 28 | obj_labels=91, 29 | rel_labels=117 30 | ), 31 | num_queries=100, 32 | rel_num_queries=16, 33 | id_emb_dim=8, 34 | aux_loss=False): 35 | """ Initializes the model. 36 | Parameters: 37 | backbone: torch module of the backbone to be used. See backbone.py 38 | transformer: torch module of the transformer architecture. See transformer.py 39 | num_classes: dict of number of sub clses, obj clses and relation clses, 40 | omitting the special no-object category 41 | keys: ["obj_labels", "rel_labels"] 42 | num_queries: number of object queries, ie detection slot. This is the maximal number of objects 43 | DETR can detect in a single image. For COCO, we recommend 100 queries. 44 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
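            rel_num_queries: number of interaction queries for the parallel interaction branch.
            id_emb_dim: dimension of the embeddings (rel_id_embed, rel_src_embed, rel_dst_embed)
                        used to associate each predicted interaction's subject and object with an
                        instance prediction (see SetCriterion.loss_emb_pull / loss_emb_push).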
45 | """ 46 | super().__init__() 47 | self.num_queries = num_queries 48 | self.rel_num_queries = rel_num_queries 49 | self.backbone = backbone 50 | self.transformer = transformer 51 | hidden_dim = transformer.d_model 52 | # instance branch 53 | self.class_embed = nn.Linear(hidden_dim, num_classes['obj_labels'] + 1) 54 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 55 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 56 | # interaction branch 57 | self.rel_query_embed = nn.Embedding(rel_num_queries, hidden_dim) 58 | self.rel_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 59 | self.rel_class_embed = nn.Linear(hidden_dim, num_classes['rel_labels']) 60 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) 61 | # embedding 62 | self.rel_id_embed = MLP(hidden_dim, hidden_dim, id_emb_dim, 3) 63 | self.rel_src_embed = MLP(hidden_dim, hidden_dim, id_emb_dim, 3) 64 | self.rel_dst_embed = MLP(hidden_dim, hidden_dim, id_emb_dim, 3) 65 | # aux loss of each decoder layer 66 | self.aux_loss = aux_loss 67 | 68 | def forward(self, samples: NestedTensor): 69 | """ The forward expects a NestedTensor, which consists of: 70 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W] 71 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels 72 | 73 | It returns a dict with the following elements: 74 | - "pred_logits": the classification logits (including no-object) for all queries. 75 | Shape= [batch_size x num_queries x (num_classes + 1)] 76 | - "pred_boxes": The normalized boxes coordinates for all queries, represented as 77 | (center_x, center_y, height, width). These values are normalized in [0, 1], 78 | relative to the size of each individual image (disregarding possible padding). 79 | See PostProcess for information on how to retrieve the unnormalized bounding box. 80 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of 81 | dictionnaries containing the two above keys for each decoder layer. 
82 | """ 83 | if isinstance(samples, (list, torch.Tensor)): 84 | samples = nested_tensor_from_tensor_list(samples) 85 | # backbone 86 | features, pos = self.backbone(samples) 87 | src, mask = features[-1].decompose() 88 | assert mask is not None 89 | input_src = self.input_proj(src) 90 | 91 | # encoder + two parellel decoders 92 | rel_hs, hs = self.transformer(input_src, mask, self.query_embed.weight, 93 | self.rel_query_embed.weight, pos[-1])[:2] 94 | rel_hs = rel_hs[-1].unsqueeze(0) 95 | hs = hs[-1].unsqueeze(0) 96 | 97 | # FFN on top of the instance decoder 98 | outputs_class = self.class_embed(hs) 99 | outputs_coord = self.bbox_embed(hs).sigmoid() 100 | id_emb = self.rel_id_embed(hs) 101 | 102 | # FFN on top of the interaction decoder 103 | outputs_rel_class = self.rel_class_embed(rel_hs) 104 | outputs_rel_coord = self.rel_bbox_embed(rel_hs).sigmoid() 105 | src_emb = self.rel_src_embed(rel_hs) 106 | dst_emb = self.rel_dst_embed(rel_hs) 107 | 108 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 109 | 'id_emb': id_emb[-1]} 110 | rel_out = {'pred_logits': outputs_rel_class[-1], 'pred_boxes': outputs_rel_coord[-1], 111 | 'src_emb': src_emb[-1], 'dst_emb': dst_emb[-1]} 112 | output = { 113 | 'pred_det': out, 114 | 'pred_rel': rel_out 115 | } 116 | if self.aux_loss: 117 | output['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, 118 | outputs_rel_class, outputs_rel_coord, id_emb, src_emb, dst_emb) 119 | 120 | return output 121 | 122 | @torch.jit.unused 123 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_rel_class, 124 | outputs_rel_coord, id_emb, src_emb, dst_emb): 125 | # this is a workaround to make torchscript happy, as torchscript 126 | # doesn't support dictionary with non-homogeneous values, such 127 | # as a dict having both a Tensor and a list. 128 | aux_output = [] 129 | for idx in range(len(outputs_class)): 130 | out = {'pred_logits': outputs_class[idx], 'pred_boxes': outputs_coord[idx], 131 | 'id_emb': id_emb[idx]} 132 | if idx < len(outputs_rel_class): 133 | rel_out = {'pred_logits': outputs_rel_class[idx], 'pred_boxes': outputs_rel_coord[idx], 134 | 'src_emb': src_emb[idx], 'dst_emb': dst_emb[idx]} 135 | else: 136 | rel_out = None 137 | aux_output.append({ 138 | 'pred_det': out, 139 | 'pred_rel': rel_out 140 | }) 141 | return aux_output 142 | 143 | 144 | class SetCriterion(nn.Module): 145 | """ This class computes the loss for HOI Transformer. 146 | The process happens in two steps: 147 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 148 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 149 | """ 150 | def __init__(self, 151 | matcher, 152 | losses, 153 | weight_dict, 154 | eos_coef, 155 | rel_eos_coef=0.1, 156 | num_classes=dict( 157 | obj_labels=90, 158 | rel_labels=117 159 | ), 160 | neg_act_id=0): 161 | """ Create the criterion. 162 | Parameters: 163 | num_classes: dict of number of sub clses, obj clses and relation clses, 164 | omitting the special no-object category 165 | keys: ["obj_labels", "rel_labels"] 166 | matcher: module able to compute a matching between targets and proposals 167 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 168 | eos_coef: relative classification weight applied to the no-object category 169 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
170 | """ 171 | super().__init__() 172 | self.num_classes = num_classes['obj_labels'] 173 | self.rel_classes = num_classes['rel_labels'] 174 | self.matcher = matcher 175 | self.weight_dict = weight_dict 176 | self.eos_coef = eos_coef 177 | self.losses = losses 178 | empty_weight = torch.ones(self.num_classes + 1) 179 | empty_weight[-1] = self.eos_coef 180 | self.register_buffer('empty_weight', empty_weight) 181 | 182 | def loss_labels(self, outputs_dict, targets, indices_dict, num_boxes_dict, log=True): 183 | """Classification loss (NLL) 184 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 185 | """ 186 | assert 'pred_det' in outputs_dict 187 | outputs = outputs_dict['pred_det'] 188 | assert 'pred_logits' in outputs 189 | src_logits = outputs['pred_logits'] 190 | indices = indices_dict['det'] 191 | idx = self._get_src_permutation_idx(indices) 192 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 193 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, 194 | dtype=torch.int64, device=src_logits.device) 195 | target_classes[idx] = target_classes_o 196 | 197 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) 198 | losses = {'loss_ce': loss_ce} 199 | 200 | if log: 201 | # TODO this should probably be a separate loss, not hacked in this one here 202 | losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] 203 | return losses 204 | 205 | def loss_actions(self, outputs_dict, targets, indices_dict, num_boxes_dict, log=True, 206 | neg_act_id=0, topk=5, alpha=0.25, gamma=2, loss_reduce='sum'): 207 | """Intereaction classificatioon loss (multi-label Focal Loss based on Sigmoid) 208 | targets dicts must contain the key "actions" containing a tensor of dim [nb_target_boxes] 209 | Return: 210 | losses keys:["rel_loss_ce", "rel_class_error"] 211 | """ 212 | assert 'pred_rel' in outputs_dict 213 | outputs = outputs_dict['pred_rel'] 214 | assert 'pred_logits' in outputs 215 | src_logits = outputs['pred_logits'] 216 | indices = indices_dict['rel'] 217 | idx = self._get_src_permutation_idx(indices) 218 | 219 | target_classes_obj = torch.cat([t["rel_labels"][J].to(src_logits.device) for t, (_, J) in zip(targets, indices)]) 220 | 221 | target_classes = torch.zeros(src_logits.shape[0], src_logits.shape[1], 222 | self.rel_classes).type_as(src_logits).to(src_logits.device) 223 | target_classes[idx] = target_classes_obj.type_as(src_logits) 224 | losses = {} 225 | pred_sigmoid = src_logits.sigmoid() 226 | label = target_classes.long() 227 | pt = (1 - pred_sigmoid) * label + pred_sigmoid * (1 - label) 228 | focal_weight = (alpha * label + (1 - alpha) * (1 - label)) * pt.pow(gamma) 229 | rel_loss = F.binary_cross_entropy_with_logits(src_logits, 230 | target_classes, reduction='none') * focal_weight 231 | if loss_reduce == 'mean': 232 | losses['rel_loss_ce'] = rel_loss.mean() 233 | else: 234 | losses['rel_loss_ce'] = rel_loss.sum() 235 | if log: 236 | _, pred = src_logits[idx].topk(topk, 1, True, True) 237 | acc = 0.0 238 | for tid, target in enumerate(target_classes_obj): 239 | tgt_idx = torch.where(target==1)[0] 240 | if len(tgt_idx) == 0: 241 | continue 242 | acc_pred = 0.0 243 | for tgt_rel in tgt_idx: 244 | acc_pred += (tgt_rel in pred[tid]) 245 | acc += acc_pred / len(tgt_idx) 246 | rel_labels_error = 100 - 100 * acc / len(target_classes_obj) 247 | losses['rel_class_error'] = torch.from_numpy(np.array( 248 | rel_labels_error)).to(src_logits.device).float() 
249 | return losses 250 | 251 | @torch.no_grad() 252 | def loss_cardinality(self, outputs_dict, targets, indices_dict, num_boxes_dict): 253 | """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 254 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 255 | """ 256 | assert 'pred_det' in outputs_dict 257 | outputs = outputs_dict['pred_det'] 258 | assert 'pred_logits' in outputs 259 | pred_logits = outputs['pred_logits'] 260 | device = pred_logits.device 261 | tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) 262 | # Count the number of predictions that are NOT "no-object" (which is the last class) 263 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) 264 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 265 | losses = {'cardinality_error': card_err} 266 | return losses 267 | 268 | @torch.no_grad() 269 | def loss_rel_cardinality(self, outputs_dict, targets, indices_dict, num_boxes_dict, neg_act_id=0): 270 | """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 271 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 272 | """ 273 | assert 'pred_rel' in outputs_dict 274 | outputs = outputs_dict['pred_rel'] 275 | assert 'pred_logits' in outputs 276 | pred_logits = outputs['pred_logits'] 277 | device = pred_logits.device 278 | tgt_lengths = torch.as_tensor([len(v["rel_labels"]) for v in targets], device=device) 279 | # Count the number of predictions that are NOT "no-object" (which is the last class) 280 | card_pred = (pred_logits.argmax(-1) != neg_act_id).sum(1) 281 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 282 | losses = {'rel_cardinality_error': card_err} 283 | return losses 284 | 285 | def loss_boxes(self, outputs_dict, targets, indices_dict, num_boxes_dict): 286 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 287 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 288 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 289 | """ 290 | assert 'pred_det' in outputs_dict 291 | outputs = outputs_dict['pred_det'] 292 | assert 'pred_boxes' in outputs 293 | 294 | indices = indices_dict['det'] 295 | num_boxes = num_boxes_dict['det'] 296 | idx = self._get_src_permutation_idx(indices) 297 | src_boxes = outputs['pred_boxes'][idx] 298 | target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) 299 | 300 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') 301 | 302 | losses = {} 303 | losses['loss_bbox'] = loss_bbox.sum() / num_boxes 304 | 305 | loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( 306 | box_ops.box_cxcywh_to_xyxy(src_boxes), 307 | box_ops.box_cxcywh_to_xyxy(target_boxes))) 308 | losses['loss_giou'] = loss_giou.sum() / num_boxes 309 | return losses 310 | 311 | def loss_rel_vecs(self, outputs_dict, targets, indices_dict, num_boxes_dict): 312 | """Compute the losses related to the interaction vector, the L1 regression loss 313 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 314 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
315 | """ 316 | assert 'pred_rel' in outputs_dict 317 | outputs = outputs_dict['pred_rel'] 318 | assert 'pred_boxes' in outputs 319 | indices = indices_dict['rel'] 320 | num_vecs = num_boxes_dict['rel'] 321 | idx = self._get_src_permutation_idx(indices) 322 | self.out_idx = idx 323 | self.tgt_idx = self._get_tgt_permutation_idx(indices) 324 | src_vecs = outputs['pred_boxes'][idx] 325 | target_vecs = torch.cat([t['rel_vecs'][i] for t, (_, i) in zip(targets, indices)], dim=0) 326 | loss_bbox = F.l1_loss(src_vecs, target_vecs, reduction='none') 327 | losses = {} 328 | losses['rel_loss_bbox'] = loss_bbox.sum() / num_vecs 329 | return losses 330 | 331 | 332 | def loss_emb_push(self, outputs_dict, targets, indices_dict, num_boxes_dict, margin=8): 333 | """id embedding push loss. 334 | """ 335 | indices = indices_dict['det'] 336 | idx = self._get_src_permutation_idx(indices) 337 | if len(idx) == 0: 338 | losses = {'loss_push': torch.Tensor([0.]).mean().to(idx.device)} 339 | return losses 340 | id_emb = outputs_dict['pred_det']['id_emb'][idx] 341 | n = id_emb.shape[0] 342 | m = [m.reshape(-1) for m in torch.meshgrid(torch.arange(n), torch.arange(n))] 343 | mask = torch.where(m[1] < m[0])[0] 344 | emb_cmp = id_emb[m[0][mask]] - id_emb[m[1][mask]] 345 | emb_dist = torch.pow(torch.sum(torch.pow(emb_cmp, 2), 1), 0.5) 346 | loss_push = torch.pow((margin - emb_dist).clamp(0), 2).mean() 347 | losses = {'loss_push': loss_push} 348 | return losses 349 | 350 | def loss_emb_pull(self, outputs_dict, targets, indices_dict, num_boxes_dict): 351 | """id embedding pull loss. 352 | """ 353 | det_indices = indices_dict['det'] 354 | rel_indices = indices_dict['rel'] 355 | 356 | # get indices: det_idx1: [rel_idx1_src, rel_idx2_dst] 357 | det_pred_idx = self._get_src_permutation_idx(det_indices) 358 | target_det_centr = torch.cat([t['boxes'][i] for t, (_, i) in zip( 359 | targets, det_indices)], dim=0)[..., :2] 360 | rel_pred_idx = self._get_src_permutation_idx(rel_indices) 361 | if len(rel_pred_idx) == 0: 362 | losses = {'loss_pull': torch.Tensor([0.]).mean().to(rel_pred_idx.device)} 363 | return losses 364 | target_rel_centr = torch.cat([t['rel_vecs'][i] for t, (_, i) in zip( 365 | targets, rel_indices)], dim=0) 366 | src_emb = outputs_dict['pred_rel']['src_emb'][rel_pred_idx] 367 | dst_emb = outputs_dict['pred_rel']['dst_emb'][rel_pred_idx] 368 | id_emb = outputs_dict['pred_det']['id_emb'][det_pred_idx] 369 | 370 | ref_id_emb = [] 371 | for i in range(len(src_emb)): 372 | ref_idx = torch.where(target_det_centr==target_rel_centr[i, :2])[0] 373 | if len(ref_idx) == 0: 374 | # to remove cur instead of setting to 0. 
375 | losses = {'loss_pull': torch.Tensor([0.]).mean().to(ref_idx.device)} 376 | return losses 377 | ref_id_emb.append(id_emb[ref_idx[0]]) 378 | for i in range(len(dst_emb)): 379 | ref_idx = torch.where(target_det_centr==target_rel_centr[i, 2:])[0] 380 | if len(ref_idx) == 0: 381 | losses = {'loss_pull': torch.Tensor([0.]).mean().to(ref_idx.device)} 382 | return losses 383 | ref_id_emb.append(id_emb[ref_idx[0]]) 384 | pred_rel_emb = torch.cat([src_emb, dst_emb], 0) 385 | ref_id_emb = torch.stack(ref_id_emb, 0).to(pred_rel_emb.device) 386 | loss_pull = torch.pow((pred_rel_emb - ref_id_emb), 2).mean() 387 | losses = {'loss_pull': loss_pull} 388 | 389 | return losses 390 | 391 | 392 | def _get_src_permutation_idx(self, indices): 393 | # permute predictions following indices 394 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 395 | src_idx = torch.cat([src for (src, _) in indices]) 396 | return batch_idx, src_idx 397 | 398 | def _get_tgt_permutation_idx(self, indices): 399 | # permute targets following indices 400 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 401 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 402 | return batch_idx, tgt_idx 403 | 404 | def _get_neg_permutation_idx(self, neg_indices): 405 | # permute neg rel predictions following indices 406 | batch_idx = torch.cat([torch.full_like(neg_ind, i) for i, neg_ind in enumerate(neg_indices)]) 407 | neg_idx = torch.cat([neg_ind for neg_ind in neg_indices]) 408 | return batch_idx, neg_idx 409 | 410 | def get_loss(self, loss, outputs_dict, targets, indices_dict, num_boxes_dict, **kwargs): 411 | if outputs_dict['pred_rel'] is None: 412 | loss_map = { 413 | 'labels': self.loss_labels, 414 | 'cardinality': self.loss_cardinality, 415 | 'boxes': self.loss_boxes 416 | } 417 | else: 418 | loss_map = { 419 | 'labels': self.loss_labels, 420 | 'cardinality': self.loss_cardinality, 421 | 'boxes': self.loss_boxes, 422 | 'actions': self.loss_actions, 423 | 'rel_vecs': self.loss_rel_vecs, 424 | 'rel_cardinality': self.loss_rel_cardinality, 425 | 'emb_push': self.loss_emb_push, 426 | 'emb_pull':self.loss_emb_pull 427 | } 428 | if loss not in loss_map: 429 | return {} 430 | return loss_map[loss](outputs_dict, targets, indices_dict, num_boxes_dict, **kwargs) 431 | 432 | 433 | def forward(self, outputs, targets): 434 | """ This performs the loss computation. 435 | Parameters: 436 | outputs: dict of tensors, see the output specification of the model for the format 437 | targets: list of dicts, such that len(targets) == batch_size. 
438 | The expected keys in each dict depends on the losses applied, see each loss' doc 439 | """ 440 | indices_dict = self.matcher(outputs, targets) 441 | # Compute the average number of target boxes accross all nodes, for normalization purposes 442 | num_boxes = sum(len(t["labels"]) for t in targets) 443 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, 444 | device=next(iter(outputs['pred_det'].values())).device) 445 | rel_num_boxes = sum(len(t["rel_labels"]) for t in targets) 446 | rel_num_boxes = torch.as_tensor([rel_num_boxes], dtype=torch.float, 447 | device=next(iter(outputs['pred_rel'].values())).device) 448 | if is_dist_avail_and_initialized(): 449 | torch.distributed.all_reduce(num_boxes) 450 | torch.distributed.all_reduce(rel_num_boxes) 451 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 452 | rel_num_boxes = torch.clamp(rel_num_boxes / get_world_size(), min=1).item() 453 | num_boxes_dict = { 454 | 'det': num_boxes, 455 | 'rel': rel_num_boxes 456 | } 457 | # Compute all the requested losses 458 | losses = {} 459 | for loss in self.losses: 460 | losses.update(self.get_loss(loss, outputs, targets, 461 | indices_dict, num_boxes_dict)) 462 | 463 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 464 | if 'aux_outputs' in outputs.keys(): 465 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 466 | indices_dict = self.matcher(aux_outputs, targets) 467 | for loss in self.losses: 468 | kwargs = {} 469 | if loss == 'labels' or loss == 'actions': 470 | # Logging is enabled only for the last layer 471 | kwargs = {'log': False} 472 | l_dict = self.get_loss(loss, aux_outputs, targets, indices_dict, num_boxes_dict, **kwargs) 473 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 474 | losses.update(l_dict) 475 | 476 | return losses 477 | 478 | 479 | class MLP(nn.Module): 480 | """ Very simple multi-layer perceptron (also called FFN)""" 481 | 482 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 483 | super().__init__() 484 | self.num_layers = num_layers 485 | h = [hidden_dim] * (num_layers - 1) 486 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 487 | 488 | def forward(self, x): 489 | for i, layer in enumerate(self.layers): 490 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 491 | return x 492 | 493 | 494 | class PostProcess(nn.Module): 495 | """ This module converts the model's output into the format expected by the coco api""" 496 | def __init__(self, 497 | rel_array_path, 498 | use_emb=False): 499 | super().__init__() 500 | # use semantic embedding in the matching or not 501 | self.use_emb = use_emb 502 | # rel array to remove non-exist hoi categories in training 503 | self.rel_array_path = rel_array_path 504 | 505 | def get_matching_scores(self, s_cetr, o_cetr, s_scores, o_scores, rel_vec, 506 | s_emb, o_emb, src_emb, dst_emb): 507 | rel_s_centr = rel_vec[..., :2].unsqueeze(-1).repeat(1, 1, s_cetr.shape[0]) 508 | rel_o_centr = rel_vec[..., 2:].unsqueeze(-1).repeat(1, 1, o_cetr.shape[0]) 509 | s_cetr = s_cetr.unsqueeze(0).repeat(rel_vec.shape[0], 1, 1) 510 | s_scores = s_scores.repeat(rel_vec.shape[0], 1) 511 | o_cetr = o_cetr.unsqueeze(0).repeat(rel_vec.shape[0], 1, 1) 512 | o_scores = o_scores.repeat(rel_vec.shape[0], 1) 513 | dist_s_x = abs(rel_s_centr[..., 0, :] - s_cetr[..., 0]) 514 | dist_s_y = abs(rel_s_centr[..., 1, :] - s_cetr[..., 1]) 515 | dist_o_x = abs(rel_o_centr[..., 0, :] - o_cetr[..., 0]) 516 | 
dist_o_y = abs(rel_o_centr[..., 1, :] - o_cetr[..., 1]) 517 | dist_s = (1.0 / (dist_s_x + 1.0)) * (1.0 / (dist_s_y + 1.0)) 518 | dist_o = (1.0 / (dist_o_x + 1.0)) * (1.0 / (dist_o_y + 1.0)) 519 | # involving emb into the matching strategy 520 | if self.use_emb is True: 521 | s_emb_np = s_emb.data.cpu().numpy() 522 | o_emb_np = o_emb.data.cpu().numpy() 523 | src_emb_np = src_emb.data.cpu().numpy() 524 | dst_emb_np = dst_emb.data.cpu().numpy() 525 | dist_s_emb = torch.from_numpy(cdist(src_emb_np, s_emb_np, metric='euclidean')).to(rel_vec.device) 526 | dist_o_emb = torch.from_numpy(cdist(dst_emb_np, o_emb_np, metric='euclidean')).to(rel_vec.device) 527 | dist_s_emb = 1. / (dist_s_emb + 1.0) 528 | dist_o_emb = 1. / (dist_o_emb + 1.0) 529 | dist_s *= dist_s_emb 530 | dist_o *= dist_o_emb 531 | dist_s = dist_s * s_scores 532 | dist_o = dist_o * o_scores 533 | return dist_s, dist_o 534 | 535 | @torch.no_grad() 536 | def forward(self, outputs_dict, file_name, target_sizes, 537 | rel_topk=20, sub_cls=1): 538 | """ Perform the matching of postprocess to generate final predicted HOI triplets 539 | Parameters: 540 | outputs: raw outputs of the model 541 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 542 | For evaluation, this must be the original image size (before any data augmentation) 543 | For visualization, this should be the image size after data augment, but before padding 544 | """ 545 | outputs = outputs_dict['pred_det'] 546 | # '(bs, num_queries,) bs=1 547 | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] 548 | id_emb = outputs['id_emb'].flatten(0, 1) 549 | rel_outputs = outputs_dict['pred_rel'] 550 | rel_out_logits, rel_out_bbox = rel_outputs['pred_logits'], \ 551 | rel_outputs['pred_boxes'] 552 | src_emb, dst_emb = rel_outputs['src_emb'].flatten(0, 1), \ 553 | rel_outputs['dst_emb'].flatten(0, 1) 554 | assert len(out_logits) == len(target_sizes) == len(rel_out_logits) \ 555 | == len(rel_out_bbox) 556 | assert target_sizes.shape[1] == 2 557 | img_h, img_w = target_sizes.unbind(1) 558 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 559 | 560 | # parse instance detection results 561 | out_bbox = out_bbox * scale_fct[:, None, :] 562 | out_bbox_flat = out_bbox.flatten(0, 1) 563 | prob = F.softmax(out_logits, -1) 564 | scores, labels = prob[..., :-1].max(-1) 565 | labels_flat = labels.flatten(0, 1) # '(bs * num_queries, ) 566 | scores_flat = scores.flatten(0, 1) 567 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox_flat) 568 | s_idx = torch.where(labels_flat==sub_cls)[0] 569 | o_idx = torch.arange(0, len(labels_flat)).long() 570 | # no detected human or object instances 571 | if len(s_idx) == 0 or len(o_idx) == 0: 572 | pred_out = { 573 | 'file_name': file_name, 574 | 'hoi_prediction': [], 575 | 'predictions': [] 576 | } 577 | return pred_out 578 | s_cetr = box_ops.box_xyxy_to_cxcywh(boxes[s_idx])[..., :2] 579 | o_cetr = box_ops.box_xyxy_to_cxcywh(boxes[o_idx])[..., :2] 580 | s_boxes, s_clses, s_scores = boxes[s_idx], labels_flat[s_idx], scores_flat[s_idx] 581 | o_boxes, o_clses, o_scores = boxes[o_idx], labels_flat[o_idx], scores_flat[o_idx] 582 | s_emb, o_emb = id_emb[s_idx], id_emb[o_idx] 583 | 584 | # parse interaction detection results 585 | rel_prob = rel_out_logits.sigmoid() 586 | topk = rel_prob.shape[-1] 587 | rel_scores = rel_prob.flatten(0, 1) 588 | hoi_labels = torch.arange(0, topk).repeat(rel_scores.shape[0], 1).to( 589 | rel_prob.device) + 1 590 | rel_vec = rel_out_bbox * scale_fct[:, None, :] 591 
| rel_vec_flat = rel_vec.flatten(0, 1) 592 | 593 | # matching distance in post-processing 594 | dist_s, dist_o = self.get_matching_scores(s_cetr, o_cetr, s_scores, 595 | o_scores, rel_vec_flat, s_emb, o_emb, src_emb, dst_emb) 596 | rel_s_scores, rel_s_ids = torch.max(dist_s, dim=-1) 597 | rel_o_scores, rel_o_ids = torch.max(dist_o, dim=-1) 598 | hoi_scores = rel_scores * s_scores[rel_s_ids].unsqueeze(-1) * \ 599 | o_scores[rel_o_ids].unsqueeze(-1) 600 | 601 | # exclude non-exist hoi categories of training 602 | rel_array = torch.from_numpy(np.load(self.rel_array_path)).to(hoi_scores.device) 603 | valid_hoi_mask = rel_array[o_clses[rel_o_ids], 1:] 604 | hoi_scores = (valid_hoi_mask * hoi_scores).reshape(-1, 1) 605 | hoi_labels = hoi_labels.reshape(-1, 1) 606 | rel_s_ids = rel_s_ids.unsqueeze(-1).repeat(1, topk).reshape(-1, 1) 607 | rel_o_ids = rel_o_ids.unsqueeze(-1).repeat(1, topk).reshape(-1, 1) 608 | hoi_triplet = (torch.cat((rel_s_ids.float(), rel_o_ids.float(), hoi_labels.float(), 609 | hoi_scores), 1)).cpu().numpy() 610 | hoi_triplet = hoi_triplet[hoi_triplet[..., -1]>0.0] 611 | 612 | # remove repeated triplets 613 | hoi_triplet = hoi_triplet[np.argsort(-hoi_triplet[:,-1])] 614 | _, hoi_id = np.unique(hoi_triplet[:, [0, 1, 2]], axis=0, return_index=True) 615 | rel_triplet = hoi_triplet[hoi_id] 616 | rel_triplet = rel_triplet[np.argsort(-rel_triplet[:,-1])] 617 | 618 | # save topk hoi triplets 619 | rel_topk = min(rel_topk, len(rel_triplet)) 620 | rel_triplet = rel_triplet[:rel_topk] 621 | hoi_labels, hoi_scores = rel_triplet[..., 2], rel_triplet[..., 3] 622 | rel_s_ids, rel_o_ids = np.array(rel_triplet[..., 0], dtype=np.int64), np.array(rel_triplet[..., 1], dtype=np.int64) 623 | sub_boxes, obj_boxes = s_boxes.cpu().numpy()[rel_s_ids], o_boxes.cpu().numpy()[rel_o_ids] 624 | sub_clses, obj_clses = s_clses.cpu().numpy()[rel_s_ids], o_clses.cpu().numpy()[rel_o_ids] 625 | sub_scores, obj_scores = s_scores.cpu().numpy()[rel_s_ids], o_scores.cpu().numpy()[rel_o_ids] 626 | self.end_time = time.time() 627 | 628 | # wtite to files 629 | pred_out = {} 630 | pred_out['file_name'] = file_name 631 | pred_out['hoi_prediction'] = [] 632 | num_rel = len(hoi_labels) 633 | for i in range(num_rel): 634 | sid = i 635 | oid = i + num_rel 636 | hoi_dict = { 637 | 'subject_id': sid, 638 | 'object_id': oid, 639 | 'category_id': hoi_labels[i], 640 | 'score': hoi_scores[i] 641 | } 642 | pred_out['hoi_prediction'].append(hoi_dict) 643 | pred_out['predictions'] = [] 644 | for i in range(num_rel): 645 | det_dict = { 646 | 'bbox': sub_boxes[i], 647 | 'category_id': sub_clses[i], 648 | 'score': sub_scores[i] 649 | } 650 | pred_out['predictions'].append(det_dict) 651 | for i in range(num_rel): 652 | det_dict = { 653 | 'bbox': obj_boxes[i], 654 | 'category_id': obj_clses[i], 655 | 'score': obj_scores[i] 656 | } 657 | pred_out['predictions'].append(det_dict) 658 | return pred_out 659 | 660 | 661 | def build_model(cfg, device): 662 | backbone = build_backbone(cfg) 663 | transformer = build_transformer(cfg) 664 | num_classes=dict( 665 | obj_labels=cfg.DATASET.OBJ_NUM_CLASSES, 666 | rel_labels=cfg.DATASET.REL_NUM_CLASSES 667 | ) 668 | model = ASNet( 669 | backbone, 670 | transformer, 671 | num_classes=num_classes, 672 | num_queries=cfg.TRANSFORMER.NUM_QUERIES, 673 | rel_num_queries=cfg.TRANSFORMER.REL_NUM_QUERIES, 674 | aux_loss=cfg.LOSS.AUX_LOSS, 675 | ) 676 | matcher = build_matcher(cfg) 677 | weight_dict = {'loss_ce': cfg.LOSS.DET_CLS_COEF[0], 'loss_bbox': cfg.LOSS.BBOX_LOSS_COEF[0]} 678 | weight_dict['loss_giou'] = 
cfg.LOSS.GIOU_LOSS_COEF[0] 679 | weight_dict.update({'rel_loss_ce': cfg.LOSS.REL_CLS_COEF, 'rel_loss_bbox': cfg.LOSS.BBOX_LOSS_COEF[1]}) 680 | weight_dict.update({'loss_pull': 0.1, 'loss_push': 0.1}) 681 | if cfg.LOSS.AUX_LOSS: 682 | aux_weight_dict = {} 683 | for i in range(cfg.TRANSFORMER.DEC_LAYERS - 1): 684 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 685 | weight_dict.update(aux_weight_dict) 686 | 687 | losses = ['labels', 'boxes', 'cardinality', 'actions', 'rel_vecs', 'rel_cardinality', 688 | 'emb_pull', 'emb_push'] 689 | criterion = SetCriterion(matcher=matcher, losses=losses, weight_dict=weight_dict, 690 | eos_coef=cfg.LOSS.EOS_COEF, num_classes=num_classes) 691 | criterion.to(device) 692 | postprocessors = PostProcess(cfg.TEST.REL_ARRAY_PATH, cfg.TEST.USE_EMB) 693 | return model, criterion, postprocessors -------------------------------------------------------------------------------- /libs/models/backbone.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import torchvision 6 | from torch import nn 7 | from torchvision.models._utils import IntermediateLayerGetter 8 | from typing import Dict, List 9 | 10 | from libs.utils.misc import NestedTensor, is_main_process 11 | from libs.models.position_encoding import build_position_encoding 12 | 13 | class FrozenBatchNorm2d(torch.nn.Module): 14 | """ 15 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 16 | 17 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 18 | without which any other models than torchvision.models.resnet[18,34,50,101] 19 | produce nans. 20 | """ 21 | 22 | def __init__(self, n): 23 | super(FrozenBatchNorm2d, self).__init__() 24 | self.register_buffer("weight", torch.ones(n)) 25 | self.register_buffer("bias", torch.zeros(n)) 26 | self.register_buffer("running_mean", torch.zeros(n)) 27 | self.register_buffer("running_var", torch.ones(n)) 28 | 29 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 30 | missing_keys, unexpected_keys, error_msgs): 31 | num_batches_tracked_key = prefix + 'num_batches_tracked' 32 | if num_batches_tracked_key in state_dict: 33 | del state_dict[num_batches_tracked_key] 34 | 35 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 36 | state_dict, prefix, local_metadata, strict, 37 | missing_keys, unexpected_keys, error_msgs) 38 | 39 | def forward(self, x): 40 | # move reshapes to the beginning 41 | # to make it fuser-friendly 42 | w = self.weight.reshape(1, -1, 1, 1) 43 | b = self.bias.reshape(1, -1, 1, 1) 44 | rv = self.running_var.reshape(1, -1, 1, 1) 45 | rm = self.running_mean.reshape(1, -1, 1, 1) 46 | eps = 1e-5 47 | scale = w * (rv + eps).rsqrt() 48 | bias = b - rm * scale 49 | return x * scale + bias 50 | 51 | 52 | class BackboneBase(nn.Module): 53 | 54 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 55 | super().__init__() 56 | for name, parameter in backbone.named_parameters(): 57 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 58 | parameter.requires_grad_(False) 59 | if return_interm_layers: 60 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 61 | else: 62 | return_layers = {'layer4': "0"} 63 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 64 | self.num_channels = num_channels 
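# ---- Editor's aside (illustrative sketch, not part of the original file) ----
# IntermediateLayerGetter makes self.body(x) return an OrderedDict keyed by the
# values of return_layers: {'0': layer4 features} by default, or '0'..'3' for
# layer1-layer4 when return_interm_layers is set. A minimal standalone sketch
# with an assumed 800x800 input:
import torch
import torchvision
from torchvision.models._utils import IntermediateLayerGetter
body = IntermediateLayerGetter(torchvision.models.resnet50(),
                               return_layers={'layer4': '0'})
feats = body(torch.randn(1, 3, 800, 800))
print(feats['0'].shape)   # torch.Size([1, 2048, 25, 25]), overall stride 32
# ---- end of editor's aside ---------------------------------------------------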
65 | 66 | def forward(self, tensor_list: NestedTensor): 67 | xs = self.body(tensor_list.tensors) 68 | out: Dict[str, NestedTensor] = {} 69 | for name, x in xs.items(): 70 | m = tensor_list.mask 71 | assert m is not None 72 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 73 | out[name] = NestedTensor(x, mask) 74 | return out 75 | 76 | 77 | class Backbone(BackboneBase): 78 | """ResNet backbone with frozen BatchNorm.""" 79 | def __init__(self, name: str, 80 | train_backbone: bool, 81 | return_interm_layers: bool, 82 | dilation: bool): 83 | backbone = getattr(torchvision.models, name)( 84 | replace_stride_with_dilation=[False, False, dilation], 85 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 86 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 87 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 88 | 89 | 90 | class Joiner(nn.Sequential): 91 | def __init__(self, backbone, position_embedding): 92 | super().__init__(backbone, position_embedding) 93 | 94 | def forward(self, tensor_list: NestedTensor): 95 | xs = self[0](tensor_list) 96 | out: List[NestedTensor] = [] 97 | pos = [] 98 | for name, x in xs.items(): 99 | out.append(x) 100 | # position encoding 101 | pos.append(self[1](x).to(x.tensors.dtype)) 102 | return out, pos 103 | 104 | 105 | def build_backbone(cfg): 106 | position_embedding = build_position_encoding(cfg) 107 | train_backbone = cfg.TRAIN.LR_BACKBONE > 0 108 | return_interm_layers = cfg.MODEL.MASKS 109 | backbone = Backbone(cfg.BACKBONE.NAME, train_backbone, return_interm_layers, 110 | cfg.BACKBONE.DIALATION) 111 | model = Joiner(backbone, position_embedding) 112 | model.num_channels = backbone.num_channels 113 | return model 114 | -------------------------------------------------------------------------------- /libs/models/hoia_asnet.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | import os 4 | import sys 5 | import time 6 | 7 | import torch 8 | from torch import nn 9 | import torch.nn.functional as F 10 | 11 | from scipy.spatial.distance import cdist 12 | 13 | from libs.models.backbone import build_backbone 14 | from libs.models.matcher import build_matcher 15 | from libs.models.transformer import build_transformer 16 | from libs.utils import box_ops 17 | from libs.utils.misc import (NestedTensor, nested_tensor_from_tensor_list, 18 | accuracy, get_world_size, interpolate, 19 | is_dist_avail_and_initialized) 20 | 21 | 22 | class ASNet_HOIA(nn.Module): 23 | """ This is the HOI Transformer module that performs HOI detection """ 24 | def __init__(self, 25 | backbone, 26 | transformer, 27 | num_classes=dict( 28 | obj_labels=91, 29 | rel_labels=117 30 | ), 31 | num_queries=100, 32 | rel_num_queries=16, 33 | id_emb_dim=8, 34 | aux_loss=False): 35 | """ Initializes the model. 36 | Parameters: 37 | backbone: torch module of the backbone to be used. See backbone.py 38 | transformer: torch module of the transformer architecture. See transformer.py 39 | num_classes: dict of number of sub clses, obj clses and relation clses, 40 | omitting the special no-object category 41 | keys: ["obj_labels", "rel_labels"] 42 | num_queries: number of object queries, ie detection slot. This is the maximal number of objects 43 | DETR can detect in a single image. For COCO, we recommend 100 queries. 44 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 
45 | """ 46 | super().__init__() 47 | self.num_queries = num_queries 48 | self.rel_num_queries = rel_num_queries 49 | self.backbone = backbone 50 | self.transformer = transformer 51 | hidden_dim = transformer.d_model 52 | # instance branch 53 | self.rel_det_class_embed = nn.Linear(hidden_dim, num_classes['obj_labels'] + 1) 54 | self.rel_det_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 55 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 56 | # interaction branch 57 | self.rel_query_embed = nn.Embedding(rel_num_queries, hidden_dim) 58 | self.rel_bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 59 | self.rel_class_embed = nn.Linear(hidden_dim, num_classes['rel_labels']) 60 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) 61 | # embedding 62 | self.rel_id_embed = MLP(hidden_dim, hidden_dim, id_emb_dim, 3) 63 | self.rel_src_embed = MLP(hidden_dim, hidden_dim, id_emb_dim, 3) 64 | self.rel_dst_embed = MLP(hidden_dim, hidden_dim, id_emb_dim, 3) 65 | # aux loss of each decoder layer 66 | self.aux_loss = aux_loss 67 | 68 | def forward(self, samples: NestedTensor): 69 | """ The forward expects a NestedTensor, which consists of: 70 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W] 71 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels 72 | 73 | It returns a dict with the following elements: 74 | - "pred_logits": the classification logits (including no-object) for all queries. 75 | Shape= [batch_size x num_queries x (num_classes + 1)] 76 | - "pred_boxes": The normalized boxes coordinates for all queries, represented as 77 | (center_x, center_y, height, width). These values are normalized in [0, 1], 78 | relative to the size of each individual image (disregarding possible padding). 79 | See PostProcess for information on how to retrieve the unnormalized bounding box. 80 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of 81 | dictionnaries containing the two above keys for each decoder layer. 
82 | """ 83 | if isinstance(samples, (list, torch.Tensor)): 84 | samples = nested_tensor_from_tensor_list(samples) 85 | # backbone 86 | features, pos = self.backbone(samples) 87 | src, mask = features[-1].decompose() 88 | assert mask is not None 89 | input_src = self.input_proj(src) 90 | 91 | # encoder + two parellel decoders 92 | rel_hs, hs = self.transformer(input_src, mask, self.query_embed.weight, 93 | self.rel_query_embed.weight, pos[-1])[:2] 94 | rel_hs = rel_hs[-1].unsqueeze(0) 95 | hs = hs[-1].unsqueeze(0) 96 | 97 | # FFN on top of the instance decoder 98 | outputs_class = self.rel_det_class_embed(hs) 99 | outputs_coord = self.rel_det_bbox_embed(hs).sigmoid() 100 | id_emb = self.rel_id_embed(hs) 101 | 102 | # FFN on top of the interaction decoder 103 | outputs_rel_class = self.rel_class_embed(rel_hs) 104 | outputs_rel_coord = self.rel_bbox_embed(rel_hs).sigmoid() 105 | src_emb = self.rel_src_embed(rel_hs) 106 | dst_emb = self.rel_dst_embed(rel_hs) 107 | 108 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1], 109 | 'id_emb': id_emb[-1]} 110 | rel_out = {'pred_logits': outputs_rel_class[-1], 'pred_boxes': outputs_rel_coord[-1], 111 | 'src_emb': src_emb[-1], 'dst_emb': dst_emb[-1]} 112 | output = { 113 | 'pred_det': out, 114 | 'pred_rel': rel_out 115 | } 116 | if self.aux_loss: 117 | output['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord, 118 | outputs_rel_class, outputs_rel_coord, id_emb, src_emb, dst_emb) 119 | 120 | return output 121 | 122 | @torch.jit.unused 123 | def _set_aux_loss(self, outputs_class, outputs_coord, outputs_rel_class, 124 | outputs_rel_coord, id_emb, src_emb, dst_emb): 125 | # this is a workaround to make torchscript happy, as torchscript 126 | # doesn't support dictionary with non-homogeneous values, such 127 | # as a dict having both a Tensor and a list. 128 | aux_output = [] 129 | for idx in range(len(outputs_class)): 130 | out = {'pred_logits': outputs_class[idx], 'pred_boxes': outputs_coord[idx], 131 | 'id_emb': id_emb[idx]} 132 | if idx < len(outputs_rel_class): 133 | rel_out = {'pred_logits': outputs_rel_class[idx], 'pred_boxes': outputs_rel_coord[idx], 134 | 'src_emb': src_emb[idx], 'dst_emb': dst_emb[idx]} 135 | else: 136 | rel_out = None 137 | aux_output.append({ 138 | 'pred_det': out, 139 | 'pred_rel': rel_out 140 | }) 141 | return aux_output 142 | 143 | 144 | class SetCriterion(nn.Module): 145 | """ This class computes the loss for HOI Transformer. 146 | The process happens in two steps: 147 | 1) we compute hungarian assignment between ground truth boxes and the outputs of the model 148 | 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) 149 | """ 150 | def __init__(self, 151 | matcher, 152 | losses, 153 | weight_dict, 154 | eos_coef, 155 | rel_eos_coef=0.1, 156 | num_classes=dict( 157 | obj_labels=90, 158 | rel_labels=117 159 | ), 160 | neg_act_id=0): 161 | """ Create the criterion. 162 | Parameters: 163 | num_classes: dict of number of sub clses, obj clses and relation clses, 164 | omitting the special no-object category 165 | keys: ["obj_labels", "rel_labels"] 166 | matcher: module able to compute a matching between targets and proposals 167 | weight_dict: dict containing as key the names of the losses and as values their relative weight. 168 | eos_coef: relative classification weight applied to the no-object category 169 | losses: list of all the losses to be applied. See get_loss for list of available losses. 
170 | """ 171 | super().__init__() 172 | self.num_classes = num_classes['obj_labels'] 173 | self.rel_classes = num_classes['rel_labels'] 174 | self.matcher = matcher 175 | self.weight_dict = weight_dict 176 | self.eos_coef = eos_coef 177 | self.losses = losses 178 | empty_weight = torch.ones(self.num_classes + 1) 179 | empty_weight[-1] = self.eos_coef 180 | self.register_buffer('empty_weight', empty_weight) 181 | 182 | def loss_labels(self, outputs_dict, targets, indices_dict, num_boxes_dict, log=True): 183 | """Classification loss (NLL) 184 | targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] 185 | """ 186 | assert 'pred_det' in outputs_dict 187 | outputs = outputs_dict['pred_det'] 188 | assert 'pred_logits' in outputs 189 | src_logits = outputs['pred_logits'] 190 | indices = indices_dict['det'] 191 | idx = self._get_src_permutation_idx(indices) 192 | target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) 193 | target_classes = torch.full(src_logits.shape[:2], self.num_classes, 194 | dtype=torch.int64, device=src_logits.device) 195 | target_classes[idx] = target_classes_o 196 | 197 | loss_ce = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight) 198 | losses = {'loss_ce': loss_ce} 199 | 200 | if log: 201 | # TODO this should probably be a separate loss, not hacked in this one here 202 | losses['class_error'] = 100 - accuracy(src_logits[idx], target_classes_o)[0] 203 | return losses 204 | 205 | def loss_actions(self, outputs_dict, targets, indices_dict, num_boxes_dict, log=True, 206 | neg_act_id=0, topk=5, alpha=0.25, gamma=2, loss_reduce='sum'): 207 | """Intereaction classificatioon loss (multi-label Focal Loss based on Sigmoid) 208 | targets dicts must contain the key "actions" containing a tensor of dim [nb_target_boxes] 209 | Return: 210 | losses keys:["rel_loss_ce", "rel_class_error"] 211 | """ 212 | assert 'pred_rel' in outputs_dict 213 | outputs = outputs_dict['pred_rel'] 214 | assert 'pred_logits' in outputs 215 | src_logits = outputs['pred_logits'] 216 | indices = indices_dict['rel'] 217 | idx = self._get_src_permutation_idx(indices) 218 | 219 | target_classes_obj = torch.cat([t["rel_labels"][J].to(src_logits.device) for t, (_, J) in zip(targets, indices)]) 220 | 221 | target_classes = torch.zeros(src_logits.shape[0], src_logits.shape[1], 222 | self.rel_classes).type_as(src_logits).to(src_logits.device) 223 | target_classes[idx] = target_classes_obj.type_as(src_logits) 224 | losses = {} 225 | pred_sigmoid = src_logits.sigmoid() 226 | label = target_classes.long() 227 | pt = (1 - pred_sigmoid) * label + pred_sigmoid * (1 - label) 228 | focal_weight = (alpha * label + (1 - alpha) * (1 - label)) * pt.pow(gamma) 229 | rel_loss = F.binary_cross_entropy_with_logits(src_logits, 230 | target_classes, reduction='none') * focal_weight 231 | if loss_reduce == 'mean': 232 | losses['rel_loss_ce'] = rel_loss.mean() 233 | else: 234 | losses['rel_loss_ce'] = rel_loss.sum() 235 | if log: 236 | _, pred = src_logits[idx].topk(topk, 1, True, True) 237 | acc = 0.0 238 | for tid, target in enumerate(target_classes_obj): 239 | tgt_idx = torch.where(target==1)[0] 240 | if len(tgt_idx) == 0: 241 | continue 242 | acc_pred = 0.0 243 | for tgt_rel in tgt_idx: 244 | acc_pred += (tgt_rel in pred[tid]) 245 | acc += acc_pred / len(tgt_idx) 246 | rel_labels_error = 100 - 100 * acc / len(target_classes_obj) 247 | losses['rel_class_error'] = torch.from_numpy(np.array( 248 | rel_labels_error)).to(src_logits.device).float() 
249 | return losses 250 | 251 | @torch.no_grad() 252 | def loss_cardinality(self, outputs_dict, targets, indices_dict, num_boxes_dict): 253 | """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 254 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 255 | """ 256 | assert 'pred_det' in outputs_dict 257 | outputs = outputs_dict['pred_det'] 258 | assert 'pred_logits' in outputs 259 | pred_logits = outputs['pred_logits'] 260 | device = pred_logits.device 261 | tgt_lengths = torch.as_tensor([len(v["labels"]) for v in targets], device=device) 262 | # Count the number of predictions that are NOT "no-object" (which is the last class) 263 | card_pred = (pred_logits.argmax(-1) != pred_logits.shape[-1] - 1).sum(1) 264 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 265 | losses = {'cardinality_error': card_err} 266 | return losses 267 | 268 | @torch.no_grad() 269 | def loss_rel_cardinality(self, outputs_dict, targets, indices_dict, num_boxes_dict, neg_act_id=0): 270 | """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes 271 | This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients 272 | """ 273 | assert 'pred_rel' in outputs_dict 274 | outputs = outputs_dict['pred_rel'] 275 | assert 'pred_logits' in outputs 276 | pred_logits = outputs['pred_logits'] 277 | device = pred_logits.device 278 | tgt_lengths = torch.as_tensor([len(v["rel_labels"]) for v in targets], device=device) 279 | # Count the number of predictions that are NOT "no-object" (which is the last class) 280 | card_pred = (pred_logits.argmax(-1) != neg_act_id).sum(1) 281 | card_err = F.l1_loss(card_pred.float(), tgt_lengths.float()) 282 | losses = {'rel_cardinality_error': card_err} 283 | return losses 284 | 285 | def loss_boxes(self, outputs_dict, targets, indices_dict, num_boxes_dict): 286 | """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss 287 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 288 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 289 | """ 290 | assert 'pred_det' in outputs_dict 291 | outputs = outputs_dict['pred_det'] 292 | assert 'pred_boxes' in outputs 293 | 294 | indices = indices_dict['det'] 295 | num_boxes = num_boxes_dict['det'] 296 | idx = self._get_src_permutation_idx(indices) 297 | src_boxes = outputs['pred_boxes'][idx] 298 | target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0) 299 | 300 | loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none') 301 | 302 | losses = {} 303 | losses['loss_bbox'] = loss_bbox.sum() / num_boxes 304 | 305 | loss_giou = 1 - torch.diag(box_ops.generalized_box_iou( 306 | box_ops.box_cxcywh_to_xyxy(src_boxes), 307 | box_ops.box_cxcywh_to_xyxy(target_boxes))) 308 | losses['loss_giou'] = loss_giou.sum() / num_boxes 309 | return losses 310 | 311 | def loss_rel_vecs(self, outputs_dict, targets, indices_dict, num_boxes_dict): 312 | """Compute the losses related to the interaction vector, the L1 regression loss 313 | targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] 314 | The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. 
315 | """ 316 | assert 'pred_rel' in outputs_dict 317 | outputs = outputs_dict['pred_rel'] 318 | assert 'pred_boxes' in outputs 319 | indices = indices_dict['rel'] 320 | num_vecs = num_boxes_dict['rel'] 321 | idx = self._get_src_permutation_idx(indices) 322 | self.out_idx = idx 323 | self.tgt_idx = self._get_tgt_permutation_idx(indices) 324 | src_vecs = outputs['pred_boxes'][idx] 325 | target_vecs = torch.cat([t['rel_vecs'][i] for t, (_, i) in zip(targets, indices)], dim=0) 326 | loss_bbox = F.l1_loss(src_vecs, target_vecs, reduction='none') 327 | losses = {} 328 | losses['rel_loss_bbox'] = loss_bbox.sum() / num_vecs 329 | return losses 330 | 331 | 332 | def loss_emb_push(self, outputs_dict, targets, indices_dict, num_boxes_dict, margin=8): 333 | """id embedding push loss. 334 | """ 335 | indices = indices_dict['det'] 336 | idx = self._get_src_permutation_idx(indices) 337 | if len(idx) == 0: 338 | losses = {'loss_push': torch.Tensor([0.]).mean().to(idx.device)} 339 | return losses 340 | id_emb = outputs_dict['pred_det']['id_emb'][idx] 341 | n = id_emb.shape[0] 342 | m = [m.reshape(-1) for m in torch.meshgrid(torch.arange(n), torch.arange(n))] 343 | mask = torch.where(m[1] < m[0])[0] 344 | emb_cmp = id_emb[m[0][mask]] - id_emb[m[1][mask]] 345 | emb_dist = torch.pow(torch.sum(torch.pow(emb_cmp, 2), 1), 0.5) 346 | loss_push = torch.pow((margin - emb_dist).clamp(0), 2).mean() 347 | losses = {'loss_push': loss_push} 348 | return losses 349 | 350 | def loss_emb_pull(self, outputs_dict, targets, indices_dict, num_boxes_dict): 351 | """id embedding pull loss. 352 | """ 353 | det_indices = indices_dict['det'] 354 | rel_indices = indices_dict['rel'] 355 | 356 | # get indices: det_idx1: [rel_idx1_src, rel_idx2_dst] 357 | det_pred_idx = self._get_src_permutation_idx(det_indices) 358 | target_det_centr = torch.cat([t['boxes'][i] for t, (_, i) in zip( 359 | targets, det_indices)], dim=0)[..., :2] 360 | rel_pred_idx = self._get_src_permutation_idx(rel_indices) 361 | if len(rel_pred_idx) == 0: 362 | losses = {'loss_pull': torch.Tensor([0.]).mean().to(rel_pred_idx.device)} 363 | return losses 364 | target_rel_centr = torch.cat([t['rel_vecs'][i] for t, (_, i) in zip( 365 | targets, rel_indices)], dim=0) 366 | src_emb = outputs_dict['pred_rel']['src_emb'][rel_pred_idx] 367 | dst_emb = outputs_dict['pred_rel']['dst_emb'][rel_pred_idx] 368 | id_emb = outputs_dict['pred_det']['id_emb'][det_pred_idx] 369 | 370 | ref_id_emb = [] 371 | for i in range(len(src_emb)): 372 | ref_idx = torch.where(target_det_centr==target_rel_centr[i, :2])[0] 373 | if len(ref_idx) == 0: 374 | # to remove cur instead of setting to 0. 
375 | losses = {'loss_pull': torch.Tensor([0.]).mean().to(ref_idx.device)} 376 | return losses 377 | ref_id_emb.append(id_emb[ref_idx[0]]) 378 | for i in range(len(dst_emb)): 379 | ref_idx = torch.where(target_det_centr==target_rel_centr[i, 2:])[0] 380 | if len(ref_idx) == 0: 381 | losses = {'loss_pull': torch.Tensor([0.]).mean().to(ref_idx.device)} 382 | return losses 383 | ref_id_emb.append(id_emb[ref_idx[0]]) 384 | pred_rel_emb = torch.cat([src_emb, dst_emb], 0) 385 | ref_id_emb = torch.stack(ref_id_emb, 0).to(pred_rel_emb.device) 386 | loss_pull = torch.pow((pred_rel_emb - ref_id_emb), 2).mean() 387 | losses = {'loss_pull': loss_pull} 388 | 389 | return losses 390 | 391 | 392 | def _get_src_permutation_idx(self, indices): 393 | # permute predictions following indices 394 | batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)]) 395 | src_idx = torch.cat([src for (src, _) in indices]) 396 | return batch_idx, src_idx 397 | 398 | def _get_tgt_permutation_idx(self, indices): 399 | # permute targets following indices 400 | batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) 401 | tgt_idx = torch.cat([tgt for (_, tgt) in indices]) 402 | return batch_idx, tgt_idx 403 | 404 | def _get_neg_permutation_idx(self, neg_indices): 405 | # permute neg rel predictions following indices 406 | batch_idx = torch.cat([torch.full_like(neg_ind, i) for i, neg_ind in enumerate(neg_indices)]) 407 | neg_idx = torch.cat([neg_ind for neg_ind in neg_indices]) 408 | return batch_idx, neg_idx 409 | 410 | def get_loss(self, loss, outputs_dict, targets, indices_dict, num_boxes_dict, **kwargs): 411 | if outputs_dict['pred_rel'] is None: 412 | loss_map = { 413 | 'labels': self.loss_labels, 414 | 'cardinality': self.loss_cardinality, 415 | 'boxes': self.loss_boxes 416 | } 417 | else: 418 | loss_map = { 419 | 'labels': self.loss_labels, 420 | 'cardinality': self.loss_cardinality, 421 | 'boxes': self.loss_boxes, 422 | 'actions': self.loss_actions, 423 | 'rel_vecs': self.loss_rel_vecs, 424 | 'rel_cardinality': self.loss_rel_cardinality, 425 | 'emb_push': self.loss_emb_push, 426 | 'emb_pull':self.loss_emb_pull 427 | } 428 | if loss not in loss_map: 429 | return {} 430 | return loss_map[loss](outputs_dict, targets, indices_dict, num_boxes_dict, **kwargs) 431 | 432 | 433 | def forward(self, outputs, targets): 434 | """ This performs the loss computation. 435 | Parameters: 436 | outputs: dict of tensors, see the output specification of the model for the format 437 | targets: list of dicts, such that len(targets) == batch_size. 
438 | The expected keys in each dict depends on the losses applied, see each loss' doc 439 | """ 440 | indices_dict = self.matcher(outputs, targets) 441 | # Compute the average number of target boxes accross all nodes, for normalization purposes 442 | num_boxes = sum(len(t["labels"]) for t in targets) 443 | num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, 444 | device=next(iter(outputs['pred_det'].values())).device) 445 | rel_num_boxes = sum(len(t["rel_labels"]) for t in targets) 446 | rel_num_boxes = torch.as_tensor([rel_num_boxes], dtype=torch.float, 447 | device=next(iter(outputs['pred_rel'].values())).device) 448 | if is_dist_avail_and_initialized(): 449 | torch.distributed.all_reduce(num_boxes) 450 | torch.distributed.all_reduce(rel_num_boxes) 451 | num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item() 452 | rel_num_boxes = torch.clamp(rel_num_boxes / get_world_size(), min=1).item() 453 | num_boxes_dict = { 454 | 'det': num_boxes, 455 | 'rel': rel_num_boxes 456 | } 457 | # Compute all the requested losses 458 | losses = {} 459 | for loss in self.losses: 460 | losses.update(self.get_loss(loss, outputs, targets, 461 | indices_dict, num_boxes_dict)) 462 | 463 | # In case of auxiliary losses, we repeat this process with the output of each intermediate layer. 464 | if 'aux_outputs' in outputs.keys(): 465 | for i, aux_outputs in enumerate(outputs['aux_outputs']): 466 | indices_dict = self.matcher(aux_outputs, targets) 467 | for loss in self.losses: 468 | kwargs = {} 469 | if loss == 'labels' or loss == 'actions': 470 | # Logging is enabled only for the last layer 471 | kwargs = {'log': False} 472 | l_dict = self.get_loss(loss, aux_outputs, targets, indices_dict, num_boxes_dict, **kwargs) 473 | l_dict = {k + f'_{i}': v for k, v in l_dict.items()} 474 | losses.update(l_dict) 475 | 476 | return losses 477 | 478 | 479 | class MLP(nn.Module): 480 | """ Very simple multi-layer perceptron (also called FFN)""" 481 | 482 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 483 | super().__init__() 484 | self.num_layers = num_layers 485 | h = [hidden_dim] * (num_layers - 1) 486 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 487 | 488 | def forward(self, x): 489 | for i, layer in enumerate(self.layers): 490 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 491 | return x 492 | 493 | 494 | class PostProcess(nn.Module): 495 | """ This module converts the model's output into the format expected by the coco api""" 496 | def __init__(self, 497 | rel_array_path, 498 | use_emb=False): 499 | super().__init__() 500 | # use semantic embedding in the matching or not 501 | self.use_emb = use_emb 502 | # rel array to remove non-exist hoi categories in training 503 | self.rel_array_path = rel_array_path 504 | 505 | def get_matching_scores(self, s_cetr, o_cetr, s_scores, o_scores, rel_vec, 506 | s_emb, o_emb, src_emb, dst_emb): 507 | rel_s_centr = rel_vec[..., :2].unsqueeze(-1).repeat(1, 1, s_cetr.shape[0]) 508 | rel_o_centr = rel_vec[..., 2:].unsqueeze(-1).repeat(1, 1, o_cetr.shape[0]) 509 | s_cetr = s_cetr.unsqueeze(0).repeat(rel_vec.shape[0], 1, 1) 510 | s_scores = s_scores.repeat(rel_vec.shape[0], 1) 511 | o_cetr = o_cetr.unsqueeze(0).repeat(rel_vec.shape[0], 1, 1) 512 | o_scores = o_scores.repeat(rel_vec.shape[0], 1) 513 | dist_s_x = abs(rel_s_centr[..., 0, :] - s_cetr[..., 0]) 514 | dist_s_y = abs(rel_s_centr[..., 1, :] - s_cetr[..., 1]) 515 | dist_o_x = abs(rel_o_centr[..., 0, :] - o_cetr[..., 0]) 516 | 
dist_o_y = abs(rel_o_centr[..., 1, :] - o_cetr[..., 1]) 517 | dist_s = (1.0 / (dist_s_x + 1.0)) * (1.0 / (dist_s_y + 1.0)) 518 | dist_o = (1.0 / (dist_o_x + 1.0)) * (1.0 / (dist_o_y + 1.0)) 519 | # involving emb into the matching strategy 520 | if self.use_emb is True: 521 | s_emb_np = s_emb.data.cpu().numpy() 522 | o_emb_np = o_emb.data.cpu().numpy() 523 | src_emb_np = src_emb.data.cpu().numpy() 524 | dst_emb_np = dst_emb.data.cpu().numpy() 525 | dist_s_emb = torch.from_numpy(cdist(src_emb_np, s_emb_np, metric='euclidean')).to(rel_vec.device) 526 | dist_o_emb = torch.from_numpy(cdist(dst_emb_np, o_emb_np, metric='euclidean')).to(rel_vec.device) 527 | dist_s_emb = 1. / (dist_s_emb + 1.0) 528 | dist_o_emb = 1. / (dist_o_emb + 1.0) 529 | dist_s *= dist_s_emb 530 | dist_o *= dist_o_emb 531 | dist_s = dist_s * s_scores 532 | dist_o = dist_o * o_scores 533 | return dist_s, dist_o 534 | 535 | @torch.no_grad() 536 | def forward(self, outputs_dict, file_name, target_sizes, 537 | rel_topk=20, sub_cls=1): 538 | """ Perform the matching of postprocess to generate final predicted HOI triplets 539 | Parameters: 540 | outputs: raw outputs of the model 541 | target_sizes: tensor of dimension [batch_size x 2] containing the size of each images of the batch 542 | For evaluation, this must be the original image size (before any data augmentation) 543 | For visualization, this should be the image size after data augment, but before padding 544 | """ 545 | outputs = outputs_dict['pred_det'] 546 | # '(bs, num_queries,) bs=1 547 | out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes'] 548 | id_emb = outputs['id_emb'].flatten(0, 1) 549 | rel_outputs = outputs_dict['pred_rel'] 550 | rel_out_logits, rel_out_bbox = rel_outputs['pred_logits'], \ 551 | rel_outputs['pred_boxes'] 552 | src_emb, dst_emb = rel_outputs['src_emb'].flatten(0, 1), \ 553 | rel_outputs['dst_emb'].flatten(0, 1) 554 | assert len(out_logits) == len(target_sizes) == len(rel_out_logits) \ 555 | == len(rel_out_bbox) 556 | assert target_sizes.shape[1] == 2 557 | img_h, img_w = target_sizes.unbind(1) 558 | scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) 559 | 560 | # parse instance detection results 561 | out_bbox = out_bbox * scale_fct[:, None, :] 562 | out_bbox_flat = out_bbox.flatten(0, 1) 563 | prob = F.softmax(out_logits, -1) 564 | scores, labels = prob[..., :-1].max(-1) 565 | labels_flat = labels.flatten(0, 1) # '(bs * num_queries, ) 566 | scores_flat = scores.flatten(0, 1) 567 | boxes = box_ops.box_cxcywh_to_xyxy(out_bbox_flat) 568 | s_idx = torch.where(labels_flat==sub_cls)[0] 569 | o_idx = torch.arange(0, len(labels_flat)).long() 570 | # no detected human or object instances 571 | if len(s_idx) == 0 or len(o_idx) == 0: 572 | pred_out = { 573 | 'file_name': file_name, 574 | 'hoi_prediction': [], 575 | 'predictions': [] 576 | } 577 | return pred_out 578 | s_cetr = box_ops.box_xyxy_to_cxcywh(boxes[s_idx])[..., :2] 579 | o_cetr = box_ops.box_xyxy_to_cxcywh(boxes[o_idx])[..., :2] 580 | s_boxes, s_clses, s_scores = boxes[s_idx], labels_flat[s_idx], scores_flat[s_idx] 581 | o_boxes, o_clses, o_scores = boxes[o_idx], labels_flat[o_idx], scores_flat[o_idx] 582 | s_emb, o_emb = id_emb[s_idx], id_emb[o_idx] 583 | 584 | # parse interaction detection results 585 | rel_prob = rel_out_logits.sigmoid() 586 | topk = rel_prob.shape[-1] 587 | rel_scores = rel_prob.flatten(0, 1) 588 | hoi_labels = torch.arange(0, topk).repeat(rel_scores.shape[0], 1).to( 589 | rel_prob.device) + 1 590 | rel_vec = rel_out_bbox * scale_fct[:, None, :] 591 
| rel_vec_flat = rel_vec.flatten(0, 1) 592 | 593 | # matching distance in post-processing 594 | dist_s, dist_o = self.get_matching_scores(s_cetr, o_cetr, s_scores, 595 | o_scores, rel_vec_flat, s_emb, o_emb, src_emb, dst_emb) 596 | rel_s_scores, rel_s_ids = torch.max(dist_s, dim=-1) 597 | rel_o_scores, rel_o_ids = torch.max(dist_o, dim=-1) 598 | hoi_scores = rel_scores * s_scores[rel_s_ids].unsqueeze(-1) * \ 599 | o_scores[rel_o_ids].unsqueeze(-1) 600 | 601 | # exclude non-exist hoi categories of training 602 | rel_array = torch.from_numpy(np.load(self.rel_array_path)).to(hoi_scores.device) 603 | valid_hoi_mask = rel_array[..., o_clses[rel_o_ids]-1].permute(1, 0) 604 | hoi_scores = (valid_hoi_mask * hoi_scores).reshape(-1, 1) 605 | hoi_labels = hoi_labels.reshape(-1, 1) 606 | rel_s_ids = rel_s_ids.unsqueeze(-1).repeat(1, topk).reshape(-1, 1) 607 | rel_o_ids = rel_o_ids.unsqueeze(-1).repeat(1, topk).reshape(-1, 1) 608 | hoi_triplet = (torch.cat((rel_s_ids.float(), rel_o_ids.float(), hoi_labels.float(), 609 | hoi_scores.float()), 1)).cpu().numpy() 610 | hoi_triplet = hoi_triplet[hoi_triplet[..., -1]>0.0] 611 | 612 | # remove repeated triplets 613 | if len(hoi_triplet) == 0: 614 | pred_out = { 615 | 'file_name': file_name, 616 | 'hoi_prediction': [], 617 | 'predictions': [] 618 | } 619 | return pred_out 620 | hoi_triplet = hoi_triplet[np.argsort(-hoi_triplet[:,-1])] 621 | _, hoi_id = np.unique(hoi_triplet[:, [0, 1, 2]], axis=0, return_index=True) 622 | rel_triplet = hoi_triplet[hoi_id] 623 | rel_triplet = rel_triplet[np.argsort(-rel_triplet[:,-1])] 624 | 625 | # save topk hoi triplets 626 | rel_topk = min(rel_topk, len(rel_triplet)) 627 | rel_triplet = rel_triplet[:rel_topk] 628 | hoi_labels, hoi_scores = rel_triplet[..., 2], rel_triplet[..., 3] 629 | rel_s_ids, rel_o_ids = np.array(rel_triplet[..., 0], dtype=np.int64), np.array(rel_triplet[..., 1], dtype=np.int64) 630 | sub_boxes, obj_boxes = s_boxes.cpu().numpy()[rel_s_ids], o_boxes.cpu().numpy()[rel_o_ids] 631 | sub_clses, obj_clses = s_clses.cpu().numpy()[rel_s_ids], o_clses.cpu().numpy()[rel_o_ids] 632 | sub_scores, obj_scores = s_scores.cpu().numpy()[rel_s_ids], o_scores.cpu().numpy()[rel_o_ids] 633 | self.end_time = time.time() 634 | 635 | # wtite to files 636 | pred_out = {} 637 | pred_out['file_name'] = file_name 638 | pred_out['hoi_prediction'] = [] 639 | num_rel = len(hoi_labels) 640 | for i in range(num_rel): 641 | sid = i 642 | oid = i + num_rel 643 | hoi_dict = { 644 | 'subject_id': sid, 645 | 'object_id': oid, 646 | 'category_id': hoi_labels[i], 647 | 'score': hoi_scores[i] 648 | } 649 | pred_out['hoi_prediction'].append(hoi_dict) 650 | pred_out['predictions'] = [] 651 | for i in range(num_rel): 652 | det_dict = { 653 | 'bbox': sub_boxes[i], 654 | 'category_id': sub_clses[i], 655 | 'score': sub_scores[i] 656 | } 657 | pred_out['predictions'].append(det_dict) 658 | for i in range(num_rel): 659 | det_dict = { 660 | 'bbox': obj_boxes[i], 661 | 'category_id': obj_clses[i], 662 | 'score': obj_scores[i] 663 | } 664 | pred_out['predictions'].append(det_dict) 665 | return pred_out 666 | 667 | 668 | def build_model(cfg, device): 669 | backbone = build_backbone(cfg) 670 | transformer = build_transformer(cfg) 671 | num_classes=dict( 672 | obj_labels=cfg.DATASET.OBJ_NUM_CLASSES, 673 | rel_labels=cfg.DATASET.REL_NUM_CLASSES 674 | ) 675 | model = ASNet_HOIA( 676 | backbone, 677 | transformer, 678 | num_classes=num_classes, 679 | num_queries=cfg.TRANSFORMER.NUM_QUERIES, 680 | rel_num_queries=cfg.TRANSFORMER.REL_NUM_QUERIES, 681 | 
aux_loss=cfg.LOSS.AUX_LOSS, 682 | ) 683 | matcher = build_matcher(cfg) 684 | weight_dict = {'loss_ce': cfg.LOSS.DET_CLS_COEF[0], 'loss_bbox': cfg.LOSS.BBOX_LOSS_COEF[0]} 685 | weight_dict['loss_giou'] = cfg.LOSS.GIOU_LOSS_COEF[0] 686 | weight_dict.update({'rel_loss_ce': cfg.LOSS.REL_CLS_COEF, 'rel_loss_bbox': cfg.LOSS.BBOX_LOSS_COEF[1]}) 687 | weight_dict.update({'loss_pull': 0.1, 'loss_push': 0.1}) 688 | if cfg.LOSS.AUX_LOSS: 689 | aux_weight_dict = {} 690 | for i in range(cfg.TRANSFORMER.DEC_LAYERS - 1): 691 | aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()}) 692 | weight_dict.update(aux_weight_dict) 693 | 694 | losses = ['labels', 'boxes', 'cardinality', 'actions', 'rel_vecs', 'rel_cardinality', 695 | 'emb_pull', 'emb_push'] 696 | criterion = SetCriterion(matcher=matcher, losses=losses, weight_dict=weight_dict, 697 | eos_coef=cfg.LOSS.EOS_COEF, num_classes=num_classes) 698 | criterion.to(device) 699 | postprocessors = PostProcess(cfg.TEST.REL_ARRAY_PATH, cfg.TEST.USE_EMB) 700 | return model, criterion, postprocessors -------------------------------------------------------------------------------- /libs/models/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import linear_sum_assignment 3 | from torch import nn 4 | 5 | from libs.utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 6 | 7 | 8 | class HungarianMatcher(nn.Module): 9 | """This class computes an assignment between the targets and the predictions of the network 10 | 11 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 12 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 13 | while the others are un-matched (and thus treated as non-objects). 
14 | """ 15 | 16 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 17 | """Creates the matcher 18 | 19 | Params: 20 | cost_class: This is the relative weight of the classification error in the matching cost 21 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 22 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 23 | """ 24 | super().__init__() 25 | self.cost_class = cost_class 26 | self.cost_bbox = cost_bbox 27 | self.cost_giou = cost_giou 28 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 29 | 30 | @torch.no_grad() 31 | def forward(self, outputs_dict, targets): 32 | """ Performs the matching 33 | 34 | Returns: 35 | A list of size batch_size, containing tuples of (index_i, index_j) where: 36 | - index_i is the indices of the selected predictions (in order) 37 | - index_j is the indices of the corresponding selected targets (in order) 38 | For each batch element, it holds: 39 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 40 | """ 41 | outputs = outputs_dict['pred_det'] 42 | bs, num_queries = outputs["pred_logits"].shape[:2] 43 | 44 | # We flatten to compute the cost matrices in a batch 45 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 46 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 47 | 48 | # Also concat the target labels and boxes 49 | tgt_ids = torch.cat([v["labels"] for v in targets]) 50 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 51 | 52 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 53 | # but approximate it in 1 - proba[target class]. 54 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 
55 | cost_class = -out_prob[:, tgt_ids] 56 | 57 | # Compute the L1 cost between boxes 58 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 59 | 60 | # Compute the giou cost betwen boxes 61 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 62 | 63 | # Final cost matrix 64 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 65 | C = C.view(bs, num_queries, -1).cpu() 66 | 67 | sizes = [len(v["boxes"]) for v in targets] 68 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 69 | indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 70 | 71 | if outputs_dict['pred_rel'] is None: 72 | indices_dict = { 73 | 'det': indices, 74 | 'rel': None 75 | } 76 | return indices_dict 77 | 78 | # for rel 79 | rel_outputs = outputs_dict['pred_rel'] 80 | bs, rel_num_queries = rel_outputs["pred_logits"].shape[:2] 81 | rel_out_prob = rel_outputs["pred_logits"].flatten(0, 1).sigmoid() # [batch_size * num_queries, num_classes] 82 | rel_out_bbox = rel_outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 83 | rel_tgt_ids = torch.cat([v["rel_labels"] for v in targets]) 84 | rel_tgt_bbox = torch.cat([v["rel_vecs"] for v in targets]) 85 | 86 | # interaction category semantic distance 87 | rel_cost_list = [] 88 | for idx, r_tgt_id in enumerate(rel_tgt_ids): 89 | tgt_rel_id = torch.where(r_tgt_id == 1)[0] 90 | rel_cost_list.append(-(rel_out_prob[:, tgt_rel_id]).sum( 91 | dim=-1) * self.cost_class) 92 | rel_cost_class = torch.stack(rel_cost_list, dim=-1) 93 | # another implementation 94 | # rel_cost_class = -(rel_out_prob * rel_tgt_ids).sum( 95 | # dim=-1) * self.cost_class) 96 | 97 | # interaction vector location distance 98 | rel_cost_bbox = torch.cdist(rel_out_bbox, rel_tgt_bbox, p=1) 99 | 100 | # Final cost matrix 101 | rel_C = self.cost_bbox * rel_cost_bbox + self.cost_class * rel_cost_class 102 | rel_C = rel_C.view(bs, rel_num_queries, -1).cpu() 103 | 104 | rel_sizes = [len(v["rel_vecs"]) for v in targets] 105 | rel_indices = [linear_sum_assignment(c[i]) for i, c in enumerate(rel_C.split(rel_sizes, -1))] 106 | rel_indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in rel_indices] 107 | 108 | indices_dict = { 109 | 'det': indices, 110 | 'rel': rel_indices, 111 | } 112 | 113 | return indices_dict 114 | 115 | 116 | def build_matcher(cfg): 117 | return HungarianMatcher(cost_class=cfg.MATCHER.COST_CLASS, 118 | cost_bbox=cfg.MATCHER.COST_BBOX, cost_giou=cfg.MATCHER.COST_GIOU) 119 | 120 | -------------------------------------------------------------------------------- /libs/models/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | from libs.utils.misc import NestedTensor 6 | 7 | 8 | class PositionEmbeddingSine(nn.Module): 9 | """ 10 | This is a more standard version of the position embedding, very similar to the one 11 | used by the Attention is all you need paper, generalized to work on images. 
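    Concretely, for an (optionally normalized) position (x, y) and hidden size
    2 * num_pos_feats, channel pairs encode sin/cos of x and of y at
    geometrically spaced frequencies, roughly sin(x / T^(2i/d)) and
    cos(x / T^(2i/d)) with T = temperature and d = num_pos_feats; the y- and
    x-embeddings are concatenated along the channel dimension.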
12 | """ 13 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 14 | super().__init__() 15 | self.num_pos_feats = num_pos_feats 16 | self.temperature = temperature 17 | self.normalize = normalize 18 | if scale is not None and normalize is False: 19 | raise ValueError("normalize should be True if scale is passed") 20 | if scale is None: 21 | scale = 2 * math.pi 22 | self.scale = scale 23 | 24 | def forward(self, tensor_list: NestedTensor): 25 | x = tensor_list.tensors 26 | mask = tensor_list.mask 27 | assert mask is not None 28 | not_mask = ~mask 29 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 30 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 31 | if self.normalize: 32 | eps = 1e-6 33 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 34 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 35 | 36 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 37 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 38 | 39 | pos_x = x_embed[:, :, :, None] / dim_t 40 | pos_y = y_embed[:, :, :, None] / dim_t 41 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 42 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 43 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 44 | return pos 45 | 46 | 47 | class PositionEmbeddingLearned(nn.Module): 48 | """ 49 | Absolute pos embedding, learned. 50 | """ 51 | def __init__(self, num_pos_feats=256): 52 | super().__init__() 53 | self.row_embed = nn.Embedding(50, num_pos_feats) 54 | self.col_embed = nn.Embedding(50, num_pos_feats) 55 | self.reset_parameters() 56 | 57 | def reset_parameters(self): 58 | nn.init.uniform_(self.row_embed.weight) 59 | nn.init.uniform_(self.col_embed.weight) 60 | 61 | def forward(self, tensor_list: NestedTensor): 62 | x = tensor_list.tensors 63 | h, w = x.shape[-2:] 64 | i = torch.arange(w, device=x.device) 65 | j = torch.arange(h, device=x.device) 66 | x_emb = self.col_embed(i) 67 | y_emb = self.row_embed(j) 68 | pos = torch.cat([ 69 | x_emb.unsqueeze(0).repeat(h, 1, 1), 70 | y_emb.unsqueeze(1).repeat(1, w, 1), 71 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 72 | return pos 73 | 74 | 75 | def build_position_encoding(cfg): 76 | N_steps = cfg.TRANSFORMER.HIDDEN_DIM // 2 77 | if cfg.TRANSFORMER.POSITION_EMBEDDING in ('v2', 'sine'): 78 | # TODO find a better way of exposing other arguments 79 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 80 | elif cfg.TRANSFORMER.POSITION_EMBEDDING in ('v3', 'learned'): 81 | position_embedding = PositionEmbeddingLearned(N_steps) 82 | else: 83 | raise ValueError(f"not supported {cfg.TRANSFORMER.POSITION_EMBEDDING}") 84 | 85 | return position_embedding 86 | -------------------------------------------------------------------------------- /libs/models/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import nn, Tensor 9 | 10 | 11 | class InteractionTransformer(nn.Module): 12 | 13 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, 14 | num_decoder_layers=6, num_rel_decoder_layers=6, 15 | dim_feedforward=2048, dropout=0.1, 16 | activation="relu", normalize_before=False, 17 | return_intermediate_dec=False): 18 | super().__init__() 19 | self.d_model = 
d_model 20 | self.nhead = nhead 21 | # encoder of the backbone to refine feature sequence 22 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 23 | dropout, activation, normalize_before) 24 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 25 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 26 | # interaction branch 27 | rel_decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 28 | dropout, activation, normalize_before) 29 | rel_decoder_norm = nn.LayerNorm(d_model) 30 | # instance branch 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(d_model) 34 | # branch aggregation: instance-aware attention 35 | interaction_layer = InteractionLayer(d_model, d_model, dropout) 36 | 37 | self.decoder = InteractionTransformerDecoder( 38 | decoder_layer, 39 | rel_decoder_layer, 40 | num_decoder_layers, 41 | interaction_layer, 42 | decoder_norm, 43 | rel_decoder_norm, 44 | return_intermediate_dec) 45 | 46 | self._reset_parameters() 47 | 48 | 49 | def _reset_parameters(self): 50 | for p in self.parameters(): 51 | if p.dim() > 1: 52 | nn.init.xavier_uniform_(p) 53 | 54 | def forward(self, src, mask, query_embed, rel_query_embed, pos_embed): 55 | # flatten NxCxHxW to HWxNxC 56 | bs, c, h, w = src.shape 57 | # generate feature sequence 58 | src = src.flatten(2).permute(2, 0, 1) 59 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 60 | mask = mask.flatten(1) 61 | # refine the feature sequence using encoder 62 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 63 | # object query set 64 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 65 | # interaction query set 66 | rel_query_embed = rel_query_embed.unsqueeze(1).repeat(1, bs, 1) 67 | # initialize the input of instance branch 68 | tgt = torch.zeros_like(query_embed) 69 | # initialize the input of interaction branch 70 | rel_tgt = torch.zeros_like(rel_query_embed) 71 | # memory shape: (W*H, bs, d_model) 72 | hs, rel_hs = self.decoder(tgt, rel_tgt, memory, memory_key_padding_mask=mask, 73 | pos=pos_embed, query_pos=query_embed, rel_query_pos=rel_query_embed) 74 | 75 | return rel_hs.transpose(1, 2), hs.transpose(1, 2), memory.permute( 76 | 1, 2, 0).view(bs, c, h, w) 77 | 78 | 79 | class Transformer(nn.Module): 80 | 81 | def __init__(self, d_model=256, nhead=8, num_encoder_layers=6, 82 | num_decoder_layers=6, num_rel_decoder_layers=6, 83 | dim_feedforward=2048, dropout=0.1, 84 | activation="relu", normalize_before=False, 85 | return_intermediate_dec=False): 86 | super().__init__() 87 | self.d_model = d_model 88 | self.nhead = nhead 89 | # encoder 90 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 91 | dropout, activation, normalize_before) 92 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 93 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 94 | # interaction branch 95 | rel_decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 96 | dropout, activation, normalize_before) 97 | rel_decoder_norm = nn.LayerNorm(d_model) 98 | self.rel_decoder = TransformerDecoder(rel_decoder_layer, num_rel_decoder_layers, rel_decoder_norm, 99 | return_intermediate=return_intermediate_dec) 100 | # instance branch 101 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 102 | dropout, activation, normalize_before) 103 | decoder_norm = 
nn.LayerNorm(d_model) 104 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 105 | return_intermediate=return_intermediate_dec) 106 | self._reset_parameters() 107 | 108 | def _reset_parameters(self): 109 | for p in self.parameters(): 110 | if p.dim() > 1: 111 | nn.init.xavier_uniform_(p) 112 | 113 | def forward(self, src, mask, query_embed, rel_query_embed, pos_embed): 114 | # flatten NxCxHxW to HWxNxC 115 | bs, c, h, w = src.shape 116 | src = src.flatten(2).permute(2, 0, 1) 117 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 118 | mask = mask.flatten(1) 119 | # memory shape: (W*H, bs, d_model) 120 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 121 | 122 | # object query set 123 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 124 | # interaction query set 125 | rel_query_embed = rel_query_embed.unsqueeze(1).repeat(1, bs, 1) 126 | # initialize the input of instance branch 127 | tgt = torch.zeros_like(query_embed) 128 | # initialize the input of interaction branch 129 | rel_tgt = torch.zeros_like(rel_query_embed) 130 | # interaction decoder 131 | rel_hs = self.rel_decoder(rel_tgt, memory, memory_key_padding_mask=mask, 132 | pos=pos_embed, query_pos=rel_query_embed) 133 | # insatance decoder 134 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 135 | pos=pos_embed, query_pos=query_embed) 136 | 137 | return rel_hs.transpose(1, 2), hs.transpose(1, 2), memory.permute( 138 | 1, 2, 0).view(bs, c, h, w) 139 | 140 | 141 | class TransformerEncoder(nn.Module): 142 | 143 | def __init__(self, encoder_layer, num_layers, norm=None): 144 | super().__init__() 145 | self.layers = _get_clones(encoder_layer, num_layers) 146 | self.num_layers = num_layers 147 | self.norm = norm 148 | 149 | def forward(self, src, 150 | mask: Optional[Tensor] = None, 151 | src_key_padding_mask: Optional[Tensor] = None, 152 | pos: Optional[Tensor] = None): 153 | output = src 154 | 155 | for layer in self.layers: 156 | output = layer(output, src_mask=mask, 157 | src_key_padding_mask=src_key_padding_mask, pos=pos) 158 | 159 | if self.norm is not None: 160 | output = self.norm(output) 161 | 162 | return output 163 | 164 | 165 | class InteractionLayer(nn.Module): 166 | def __init__(self, d_model, d_feature, dropout=0.1): 167 | super().__init__() 168 | self.d_feature = d_feature 169 | 170 | self.det_tfm = nn.Linear(d_model, d_feature) 171 | self.rel_tfm = nn.Linear(d_model, d_feature) 172 | self.det_value_tfm = nn.Linear(d_model, d_feature) 173 | 174 | self.rel_norm = nn.LayerNorm(d_model) 175 | 176 | if dropout is not None: 177 | self.dropout = dropout 178 | self.det_dropout = nn.Dropout(dropout) 179 | self.rel_add_dropout = nn.Dropout(dropout) 180 | else: 181 | self.dropout = None 182 | 183 | def forward(self, det_in, rel_in): 184 | det_attn_in = self.det_tfm(det_in) 185 | rel_attn_in = self.rel_tfm(rel_in) 186 | det_value = self.det_value_tfm(det_in) 187 | scores = torch.matmul(det_attn_in.transpose(0, 1), 188 | rel_attn_in.permute(1, 2, 0)) / math.sqrt(self.d_feature) 189 | det_weight = F.softmax(scores.transpose(1, 2), dim = -1) 190 | if self.dropout is not None: 191 | det_weight = self.det_dropout(det_weight) 192 | rel_add = torch.matmul(det_weight, det_value.transpose(0, 1)) 193 | rel_out = self.rel_add_dropout(rel_add) + rel_in.transpose(0, 1) 194 | rel_out = self.rel_norm(rel_out) 195 | 196 | return det_in, rel_out.transpose(0, 1) 197 | 198 | 199 | class InteractionTransformerDecoder(nn.Module): 200 | 201 | def __init__(self, 202 | 
decoder_layer, 203 | rel_decoder_layer, 204 | num_layers, 205 | interaction_layer=None, 206 | norm=None, 207 | rel_norm=None, 208 | return_intermediate=False): 209 | super().__init__() 210 | self.layers = _get_clones(decoder_layer, num_layers) 211 | self.rel_layers = _get_clones(rel_decoder_layer, num_layers) 212 | self.num_layers = num_layers 213 | if interaction_layer is not None: 214 | self.rel_interaction_layers = _get_clones(interaction_layer, num_layers) 215 | else: 216 | self.rel_interaction_layers = None 217 | self.norm = norm 218 | self.rel_norm = rel_norm 219 | self.return_intermediate = return_intermediate 220 | 221 | def forward(self, tgt, rel_tgt, memory, 222 | tgt_mask: Optional[Tensor] = None, 223 | memory_mask: Optional[Tensor] = None, 224 | tgt_key_padding_mask: Optional[Tensor] = None, 225 | memory_key_padding_mask: Optional[Tensor] = None, 226 | pos: Optional[Tensor] = None, 227 | query_pos: Optional[Tensor] = None, 228 | rel_query_pos: Optional[Tensor] = None): 229 | output = tgt 230 | rel_output = rel_tgt 231 | 232 | intermediate = [] 233 | rel_intermediate = [] 234 | 235 | for i in range(self.num_layers): 236 | # instance decoder layer 237 | output = self.layers[i](output, memory, tgt_mask=tgt_mask, 238 | memory_mask=memory_mask, 239 | tgt_key_padding_mask=tgt_key_padding_mask, 240 | memory_key_padding_mask=memory_key_padding_mask, 241 | pos=pos, query_pos=query_pos) 242 | # interaction decoder layer 243 | rel_output = self.rel_layers[i](rel_output, memory, tgt_mask=tgt_mask, 244 | memory_mask=memory_mask, 245 | tgt_key_padding_mask=tgt_key_padding_mask, 246 | memory_key_padding_mask=memory_key_padding_mask, 247 | pos=pos, query_pos=rel_query_pos) 248 | # instance-aware attention module 249 | if self.rel_interaction_layers is not None: 250 | output, rel_output = self.rel_interaction_layers[i]( 251 | output, rel_output 252 | ) 253 | # for aux loss 254 | if self.return_intermediate: 255 | intermediate.append(self.norm(output)) 256 | rel_intermediate.append(self.rel_norm(rel_output)) 257 | 258 | if self.norm is not None: 259 | output = self.norm(output) 260 | rel_output = self.rel_norm(rel_output) 261 | if self.return_intermediate: 262 | intermediate.pop() 263 | intermediate.append(output) 264 | rel_intermediate.pop() 265 | rel_intermediate.append(rel_output) 266 | 267 | if self.return_intermediate: 268 | return torch.stack(intermediate), torch.stack(rel_intermediate) 269 | 270 | return output, rel_output 271 | 272 | 273 | class TransformerDecoder(nn.Module): 274 | 275 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 276 | super().__init__() 277 | self.layers = _get_clones(decoder_layer, num_layers) 278 | self.num_layers = num_layers 279 | self.norm = norm 280 | self.return_intermediate = return_intermediate 281 | 282 | def forward(self, tgt, memory, 283 | tgt_mask: Optional[Tensor] = None, 284 | memory_mask: Optional[Tensor] = None, 285 | tgt_key_padding_mask: Optional[Tensor] = None, 286 | memory_key_padding_mask: Optional[Tensor] = None, 287 | pos: Optional[Tensor] = None, 288 | query_pos: Optional[Tensor] = None): 289 | output = tgt 290 | 291 | intermediate = [] 292 | 293 | for layer in self.layers: 294 | output = layer(output, memory, tgt_mask=tgt_mask, 295 | memory_mask=memory_mask, 296 | tgt_key_padding_mask=tgt_key_padding_mask, 297 | memory_key_padding_mask=memory_key_padding_mask, 298 | pos=pos, query_pos=query_pos) 299 | if self.return_intermediate: 300 | intermediate.append(self.norm(output)) 301 | 302 | if self.norm is 
not None: 303 | output = self.norm(output) 304 | if self.return_intermediate: 305 | intermediate.pop() 306 | intermediate.append(output) 307 | 308 | if self.return_intermediate: 309 | return torch.stack(intermediate) 310 | 311 | return output 312 | 313 | 314 | class TransformerEncoderLayer(nn.Module): 315 | 316 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 317 | activation="relu", normalize_before=False): 318 | super().__init__() 319 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 320 | # Implementation of Feedforward model 321 | self.linear1 = nn.Linear(d_model, dim_feedforward) 322 | self.dropout = nn.Dropout(dropout) 323 | self.linear2 = nn.Linear(dim_feedforward, d_model) 324 | 325 | self.norm1 = nn.LayerNorm(d_model) 326 | self.norm2 = nn.LayerNorm(d_model) 327 | self.dropout1 = nn.Dropout(dropout) 328 | self.dropout2 = nn.Dropout(dropout) 329 | 330 | self.activation = _get_activation_fn(activation) 331 | self.normalize_before = normalize_before 332 | 333 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 334 | return tensor if pos is None else tensor + pos 335 | 336 | def forward_post(self, 337 | src, 338 | src_mask: Optional[Tensor] = None, 339 | src_key_padding_mask: Optional[Tensor] = None, 340 | pos: Optional[Tensor] = None): 341 | q = k = self.with_pos_embed(src, pos) 342 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 343 | key_padding_mask=src_key_padding_mask)[0] 344 | src = src + self.dropout1(src2) 345 | src = self.norm1(src) 346 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 347 | src = src + self.dropout2(src2) 348 | src = self.norm2(src) 349 | return src 350 | 351 | def forward_pre(self, src, 352 | src_mask: Optional[Tensor] = None, 353 | src_key_padding_mask: Optional[Tensor] = None, 354 | pos: Optional[Tensor] = None): 355 | src2 = self.norm1(src) 356 | q = k = self.with_pos_embed(src2, pos) 357 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 358 | key_padding_mask=src_key_padding_mask)[0] 359 | src = src + self.dropout1(src2) 360 | src2 = self.norm2(src) 361 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 362 | src = src + self.dropout2(src2) 363 | return src 364 | 365 | def forward(self, src, 366 | src_mask: Optional[Tensor] = None, 367 | src_key_padding_mask: Optional[Tensor] = None, 368 | pos: Optional[Tensor] = None): 369 | if self.normalize_before: 370 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 371 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 372 | 373 | 374 | class TransformerDecoderLayer(nn.Module): 375 | 376 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 377 | activation="relu", normalize_before=False): 378 | super().__init__() 379 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 380 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 381 | # Implementation of Feedforward model 382 | self.linear1 = nn.Linear(d_model, dim_feedforward) 383 | self.dropout = nn.Dropout(dropout) 384 | self.linear2 = nn.Linear(dim_feedforward, d_model) 385 | 386 | self.norm1 = nn.LayerNorm(d_model) 387 | self.norm2 = nn.LayerNorm(d_model) 388 | self.norm3 = nn.LayerNorm(d_model) 389 | self.dropout1 = nn.Dropout(dropout) 390 | self.dropout2 = nn.Dropout(dropout) 391 | self.dropout3 = nn.Dropout(dropout) 392 | 393 | self.activation = _get_activation_fn(activation) 394 | self.normalize_before = normalize_before 395 | 396 | def 
with_pos_embed(self, tensor, pos: Optional[Tensor]): 397 | return tensor if pos is None else tensor + pos 398 | 399 | def forward_post(self, tgt, memory, 400 | tgt_mask: Optional[Tensor] = None, 401 | memory_mask: Optional[Tensor] = None, 402 | tgt_key_padding_mask: Optional[Tensor] = None, 403 | memory_key_padding_mask: Optional[Tensor] = None, 404 | pos: Optional[Tensor] = None, 405 | query_pos: Optional[Tensor] = None): 406 | q = k = self.with_pos_embed(tgt, query_pos) 407 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 408 | key_padding_mask=tgt_key_padding_mask)[0] 409 | tgt = tgt + self.dropout1(tgt2) 410 | tgt = self.norm1(tgt) 411 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 412 | key=self.with_pos_embed(memory, pos), 413 | value=memory, attn_mask=memory_mask, 414 | key_padding_mask=memory_key_padding_mask)[0] 415 | tgt = tgt + self.dropout2(tgt2) 416 | tgt = self.norm2(tgt) 417 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 418 | tgt = tgt + self.dropout3(tgt2) 419 | tgt = self.norm3(tgt) 420 | return tgt 421 | 422 | def forward_pre(self, tgt, memory, 423 | tgt_mask: Optional[Tensor] = None, 424 | memory_mask: Optional[Tensor] = None, 425 | tgt_key_padding_mask: Optional[Tensor] = None, 426 | memory_key_padding_mask: Optional[Tensor] = None, 427 | pos: Optional[Tensor] = None, 428 | query_pos: Optional[Tensor] = None): 429 | tgt2 = self.norm1(tgt) 430 | q = k = self.with_pos_embed(tgt2, query_pos) 431 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 432 | key_padding_mask=tgt_key_padding_mask)[0] 433 | tgt = tgt + self.dropout1(tgt2) 434 | tgt2 = self.norm2(tgt) 435 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 436 | key=self.with_pos_embed(memory, pos), 437 | value=memory, attn_mask=memory_mask, 438 | key_padding_mask=memory_key_padding_mask)[0] 439 | tgt = tgt + self.dropout2(tgt2) 440 | tgt2 = self.norm3(tgt) 441 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 442 | tgt = tgt + self.dropout3(tgt2) 443 | return tgt 444 | 445 | def forward(self, tgt, memory, 446 | tgt_mask: Optional[Tensor] = None, 447 | memory_mask: Optional[Tensor] = None, 448 | tgt_key_padding_mask: Optional[Tensor] = None, 449 | memory_key_padding_mask: Optional[Tensor] = None, 450 | pos: Optional[Tensor] = None, 451 | query_pos: Optional[Tensor] = None): 452 | if self.normalize_before: 453 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 454 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 455 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 456 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 457 | 458 | 459 | def _get_clones(module, N): 460 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 461 | 462 | 463 | def build_transformer(cfg): 464 | if cfg.TRANSFORMER.BRANCH_AGGREGATION is False: 465 | return Transformer( 466 | d_model=cfg.TRANSFORMER.HIDDEN_DIM, 467 | dropout=cfg.TRANSFORMER.DROPOUT, 468 | nhead=cfg.TRANSFORMER.NHEADS, 469 | dim_feedforward=cfg.TRANSFORMER.DIM_FEEDFORWARD, 470 | num_encoder_layers=cfg.TRANSFORMER.ENC_LAYERS, 471 | num_decoder_layers=cfg.TRANSFORMER.DEC_LAYERS, 472 | num_rel_decoder_layers=cfg.TRANSFORMER.DEC_LAYERS, 473 | normalize_before=cfg.TRANSFORMER.PRE_NORM, 474 | return_intermediate_dec=True, 475 | ) 476 | else: 477 | return InteractionTransformer( 478 | d_model=cfg.TRANSFORMER.HIDDEN_DIM, 479 | dropout=cfg.TRANSFORMER.DROPOUT, 480 | nhead=cfg.TRANSFORMER.NHEADS, 481 | 
dim_feedforward=cfg.TRANSFORMER.DIM_FEEDFORWARD, 482 | num_encoder_layers=cfg.TRANSFORMER.ENC_LAYERS, 483 | num_decoder_layers=cfg.TRANSFORMER.DEC_LAYERS, 484 | num_rel_decoder_layers=cfg.TRANSFORMER.DEC_LAYERS, 485 | normalize_before=cfg.TRANSFORMER.PRE_NORM, 486 | return_intermediate_dec=True, 487 | ) 488 | 489 | 490 | def _get_activation_fn(activation): 491 | """Return an activation function given a string""" 492 | if activation == "relu": 493 | return F.relu 494 | if activation == "gelu": 495 | return F.gelu 496 | if activation == "glu": 497 | return F.glu 498 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 499 | -------------------------------------------------------------------------------- /libs/trainer/hoi_trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import math 4 | import numpy as np 5 | import sys 6 | import time 7 | from tqdm import tqdm 8 | 9 | import torch 10 | from torch import autograd 11 | 12 | from libs.trainer.trainer import BaseTrainer 13 | import libs.utils.misc as utils 14 | from libs.utils.utils import save_checkpoint, write_dict_to_json 15 | 16 | 17 | class HOITrainer(BaseTrainer): 18 | 19 | def __init__(self, 20 | cfg, 21 | model, 22 | criterion, 23 | optimizer, 24 | lr_scheduler, 25 | postprocessors, 26 | log_dir='output', 27 | performance_indicator='mAP', 28 | last_iter=-1, 29 | rank=0, 30 | device='cuda', 31 | max_norm=0): 32 | 33 | super().__init__(cfg, model, criterion, optimizer, lr_scheduler, 34 | log_dir, performance_indicator, last_iter, rank) 35 | self.postprocessors = postprocessors 36 | self.device = device 37 | self.max_norm = max_norm 38 | 39 | def _read_inputs(self, inputs): 40 | imgs, targets, filenames = inputs 41 | imgs = [img.to(self.device) for img in imgs] 42 | # targets are list type in det tasks 43 | targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets] 44 | return imgs, targets 45 | 46 | def _forward(self, data): 47 | imgs = data[0] 48 | targets = data[1] 49 | outputs = self.model(imgs) 50 | loss_dict = self.criterion(outputs, targets) 51 | return loss_dict 52 | 53 | def train(self, train_loader, eval_loader): 54 | start_time = time.time() 55 | self.model.train() 56 | self.criterion.train() 57 | metric_logger = utils.MetricLogger(delimiter=" ") 58 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 59 | metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 60 | metric_logger.add_meter('rel_class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}')) 61 | header = 'Epoch: [{}]'.format(self.epoch) 62 | print_freq = self.cfg.TRAIN.PRINT_FREQ 63 | 64 | if self.epoch > self.max_epoch: 65 | logging.info("Optimization is done !") 66 | sys.exit(0) 67 | for data in metric_logger.log_every(train_loader, print_freq, header): 68 | data = self._read_inputs(data) 69 | loss_dict = self._forward(data) 70 | weight_dict = self.criterion.weight_dict 71 | losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict) 72 | 73 | # reduce losses over all GPUs for logging purposes 74 | loss_dict_reduced = utils.reduce_dict(loss_dict) 75 | loss_dict_reduced_unscaled = {f'{k}_unscaled': v 76 | for k, v in loss_dict_reduced.items()} 77 | loss_dict_reduced_scaled = {k: v * weight_dict[k] 78 | for k, v in loss_dict_reduced.items() if k in weight_dict} 79 | losses_reduced_scaled = sum(loss_dict_reduced_scaled.values()) 80 | 81 | 
loss_value = losses_reduced_scaled.item() 82 | 83 | if not math.isfinite(loss_value): 84 | print("Loss is {}, stopping training".format(loss_value)) 85 | print(loss_dict_reduced) 86 | sys.exit(1) 87 | 88 | self.optimizer.zero_grad() 89 | losses.backward() 90 | if self.max_norm > 0: 91 | torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.max_norm) 92 | self.optimizer.step() 93 | 94 | metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled) 95 | metric_logger.update(class_error=loss_dict_reduced['class_error']) 96 | metric_logger.update(rel_class_error=loss_dict_reduced['rel_class_error']) 97 | metric_logger.update(lr=self.optimizer.param_groups[0]["lr"]) 98 | 99 | # gather the stats from all processes 100 | metric_logger.synchronize_between_processes() 101 | print("Averaged stats:", metric_logger) 102 | train_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} 103 | log_stats = {**{f'train_{k}': v for k, v in train_stats.items()}, 104 | 'epoch': self.epoch} 105 | if self.rank == 0: 106 | for (key, val) in log_stats.items(): 107 | self.writer.add_scalar(key, val, log_stats['epoch']) 108 | self.lr_scheduler.step() 109 | 110 | # save checkpoint 111 | if self.rank == 0 and self.epoch > 0 and self.epoch % self.cfg.TRAIN.SAVE_INTERVAL == 0: 112 | # evaluation 113 | if self.cfg.TRAIN.VAL_WHEN_TRAIN: 114 | self.model.eval() 115 | performance = self.evaluate(eval_loader) 116 | self.writer.add_scalar(self.PI, performance, self.epoch) 117 | if performance > self.best_performance: 118 | self.is_best = True 119 | self.best_performance = performance 120 | else: 121 | self.is_best = False 122 | logging.info(f'Now: best {self.PI} is {self.best_performance}') 123 | else: 124 | performance = -1 125 | 126 | # save checkpoint 127 | try: 128 | state_dict = self.model.module.state_dict() # remove prefix of multi GPUs 129 | except AttributeError: 130 | state_dict = self.model.state_dict() 131 | 132 | if self.rank == 0: 133 | if self.cfg.TRAIN.SAVE_EVERY_CHECKPOINT: 134 | filename = f"{self.model_name}_epoch{self.epoch:03d}_checkpoint.pth" 135 | else: 136 | filename = "checkpoint.pth" 137 | save_checkpoint( 138 | { 139 | 'epoch': self.epoch, 140 | 'model': self.model_name, 141 | f'performance/{self.PI}': performance, 142 | 'state_dict': state_dict, 143 | 'optimizer': self.optimizer.state_dict(), 144 | }, 145 | self.is_best, 146 | self.log_dir, 147 | filename=f'{self.cfg.OUTPUT_ROOT}_{filename}' 148 | ) 149 | 150 | total_time = time.time() - start_time 151 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 152 | print('Training time {}'.format(total_time_str)) 153 | self.epoch += 1 154 | 155 | 156 | def evaluate(self, eval_loader, mode, rel_topk=100): 157 | self.model.eval() 158 | results = [] 159 | count = 0 160 | for data in tqdm(eval_loader): 161 | imgs, targets, filenames = data 162 | imgs = [img.to(self.device) for img in imgs] 163 | # targets are list type 164 | targets = [{k: v.to(self.device) for k, v in t.items()} for t in targets] 165 | bs = len(imgs) 166 | target_sizes = targets[0]['size'].expand(bs, 2) 167 | target_sizes = target_sizes.to(self.device) 168 | outputs_dict = self.model(imgs) 169 | file_name = filenames[0] 170 | pred_out = self.postprocessors(outputs_dict, file_name, target_sizes, 171 | rel_topk=rel_topk) 172 | results.append(pred_out) 173 | count += 1 174 | # save the result 175 | result_path = f'{self.cfg.OUTPUT_ROOT}/pred.json' 176 | write_dict_to_json(results, result_path) 177 | 178 | # eval 179 | if mode == 'hico': 180 | from 
eval_tools.hico_eval import hico 181 | eval_tool = hico(annotation_file='data/hico/test_hico.json', 182 | train_annotation='data/hico/trainval_hico.json') 183 | mAP = eval_tool.evalution(results) 184 | elif mode == 'hoia': 185 | from eval_tools.hoia_eval import hoia 186 | eval_tool = hoia(annotation_file='data/hoia/test_hoia.json') 187 | mAP = eval_tool.evalution(results) 188 | else: 189 | mAP = 0.0 190 | 191 | return mAP -------------------------------------------------------------------------------- /libs/trainer/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | import time 4 | 5 | import torch 6 | from torch import autograd 7 | from tensorboardX import SummaryWriter 8 | 9 | from libs.utils.utils import AverageMeter, save_checkpoint 10 | 11 | 12 | class BaseTrainer(object): 13 | def __init__(self, 14 | cfg, 15 | model, 16 | criterion, 17 | optimizer, 18 | lr_scheduler, 19 | log_dir, 20 | performance_indicator='mAP', 21 | last_iter=-1, 22 | rank=0): 23 | self.cfg = cfg 24 | self.model = model 25 | self.optimizer = optimizer 26 | self.lr_scheduler = lr_scheduler 27 | self.criterion = criterion 28 | self.log_dir = log_dir 29 | self.PI = performance_indicator 30 | self.rank = rank 31 | self.epoch = last_iter + 1 32 | self.best_performance = 0.0 33 | self.is_best = False 34 | self.max_epoch = self.cfg.TRAIN.MAX_EPOCH 35 | self.model_name = self.cfg.MODEL.NAME 36 | if self.optimizer is not None and rank == 0: 37 | self.writer = SummaryWriter(log_dir, comment=f'_rank{rank}') 38 | logging.info(f"max epochs = {self.max_epoch} ") 39 | 40 | def _read_inputs(self, inputs): 41 | imgs, targets, index = inputs 42 | imgs = imgs.cuda(non_blocking=True) 43 | targets = targets.cuda(non_blocking=True) 44 | return imgs, targets 45 | 46 | def _forward(self, data): 47 | imgs = data[0] 48 | targets = data[1] 49 | pred = self.model(imgs) 50 | 51 | loss, train_prec = self.criterion(pred, targets) 52 | return loss, train_prec 53 | 54 | def train(self, train_loader, eval_loader): 55 | losses_1 = AverageMeter() 56 | losses_2 = AverageMeter() 57 | data_time = AverageMeter() 58 | batch_time = AverageMeter() 59 | end_time = time.time() 60 | if self.epoch > self.max_epoch: 61 | logging.info("Optimization is done !") 62 | sys.exit(0) 63 | for data in train_loader: 64 | self.model.train() 65 | # forward 66 | data_time.update(time.time() - end_time) 67 | data = self._read_inputs(data) 68 | # get loss 69 | loss, train_prec = self._forward(data) 70 | if isinstance(loss, tuple): 71 | losses_1.update(loss[0].item()) 72 | losses_2.update(loss[1].item()) 73 | total_loss = sum(loss) 74 | else: 75 | losses_1.update(loss.item()) 76 | total_loss = loss 77 | # optimization 78 | self.optimizer.zero_grad() 79 | with autograd.detect_anomaly(): 80 | total_loss.backward() 81 | self.optimizer.step() 82 | self.lr_scheduler.step() 83 | # time for training(forward & loss computation & optimization) on one batch 84 | batch_time.update(time.time() - end_time) 85 | 86 | # log avg loss 87 | if self.epoch > 0 and self.epoch % self.cfg.TRAIN.PRINT_FREQ == 0: 88 | if isinstance(loss, tuple): 89 | self.writer.add_scalar('loss/cls', losses_1.avg, self.epoch) 90 | self.writer.add_scalar('loss/box', losses_2.avg, self.epoch) 91 | loss_msg = f'avg_cls_loss:{losses_1.avg:.04f} avg_box_loss:{losses_2.avg:.04f}' 92 | else: 93 | self.writer.add_scalar('loss', losses_1.avg, self.epoch) 94 | loss_msg = f'avg_loss:{losses_1.avg:.04f}' 95 | 96 | logging.info( 97 | 
f'epoch:{self.epoch:03d} ' 98 | f'{loss_msg:s} ' 99 | f'io_rate:{data_time.avg / batch_time.avg:.04f} ' 100 | f'samples/(gpu*s):{self.cfg.DATASET.IMG_NUM_PER_GPU / batch_time.avg:.02f}' 101 | ) 102 | 103 | self.writer.add_scalar('speed/samples_per_second_per_gpu', 104 | self.cfg.DATASET.IMG_NUM_PER_GPU / batch_time.avg, 105 | self.epoch) 106 | self.writer.add_scalar('speed/io_rate', 107 | data_time.avg / batch_time.avg, 108 | self.epoch) 109 | if train_prec is not None: 110 | logging.info(f'train precision: {train_prec}') 111 | losses_1.reset() 112 | losses_2.reset() 113 | 114 | # save checkpoint 115 | if self.epoch > 0 and self.epoch % self.cfg.TRAIN.SAVE_INTERVAL == 0: 116 | # evaluation 117 | if self.cfg.TRAIN.VAL_WHEN_TRAIN: 118 | self.model.eval() 119 | performance = self.evaluate(eval_loader) 120 | self.writer.add_scalar(self.PI, performance, self.epoch) 121 | if self.PI == 'triplet_loss' and performance < self.best_performance: 122 | self.is_best = True 123 | self.best_performance = performance 124 | elif performance > self.best_performance: 125 | self.is_best = True 126 | self.best_performance = performance 127 | else: 128 | self.is_best = False 129 | logging.info(f'Now: best {self.PI} is {self.best_performance}') 130 | else: 131 | performance = -1 132 | 133 | # save checkpoint 134 | try: 135 | state_dict = self.model.module.state_dict() # remove prefix of multi GPUs 136 | except AttributeError: 137 | state_dict = self.model.state_dict() 138 | 139 | if self.rank == 0: 140 | if self.cfg.TRAIN.SAVE_EVERY_CHECKPOINT: 141 | filename = f"{self.model_name}_epoch{self.epoch:03d}_iter{self.epoch:06d}_checkpoint.pth" 142 | else: 143 | filename = "checkpoint.pth" 144 | save_checkpoint( 145 | { 146 | 'epoch': self.epoch, 147 | 'model': self.model_name, 148 | f'performance/{self.PI}': performance, 149 | 'state_dict': state_dict, 150 | 'optimizer': self.optimizer.state_dict(), 151 | }, 152 | self.is_best, 153 | self.log_dir, 154 | filename=filename 155 | ) 156 | 157 | self.epoch += 1 158 | end_time = time.time() 159 | self.epoch += 1 160 | 161 | def evaluate(self, eval_loader): 162 | raise NotImplementedError -------------------------------------------------------------------------------- /libs/utils/box_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.ops.boxes import box_area 3 | 4 | 5 | def box_cxcywh_to_xyxy(x): 6 | x_c, y_c, w, h = x.unbind(-1) 7 | w = w.clamp(0) 8 | h = h.clamp(0) 9 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 10 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 11 | return torch.stack(b, dim=-1) 12 | 13 | 14 | def box_xyxy_to_cxcywh(x): 15 | x0, y0, x1, y1 = x.unbind(-1) 16 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 17 | (x1 - x0), (y1 - y0)] 18 | return torch.stack(b, dim=-1) 19 | 20 | 21 | # modified from torchvision to also return the union 22 | def box_iou(boxes1, boxes2): 23 | area1 = box_area(boxes1) 24 | area2 = box_area(boxes2) 25 | 26 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 27 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 28 | 29 | wh = (rb - lt).clamp(min=0) # [N,M,2] 30 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 31 | 32 | union = area1[:, None] + area2 - inter 33 | 34 | iou = inter / union 35 | return iou, union 36 | 37 | 38 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): 39 | """Calculate overlap between two set of bboxes. 
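    Boxes are expected in (x1, y1, x2, y2) order; widths and heights are
    computed with an inclusive "+1" pixel convention (see the +1 terms below).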
40 | 41 | If ``is_aligned`` is ``False``, then calculate the ious between each bbox 42 | of bboxes1 and bboxes2, otherwise the ious between each aligned pair of 43 | bboxes1 and bboxes2. 44 | 45 | Args: 46 | bboxes1 (Tensor): shape (m, 4) in format. 47 | bboxes2 (Tensor): shape (n, 4) in format. 48 | If is_aligned is ``True``, then m and n must be equal. 49 | mode (str): "iou" (intersection over union) or iof (intersection over 50 | foreground). 51 | 52 | Returns: 53 | ious(Tensor): shape (m, n) if is_aligned == False else shape (m, 1) 54 | 55 | Example: 56 | >>> bboxes1 = torch.FloatTensor([ 57 | >>> [0, 0, 10, 10], 58 | >>> [10, 10, 20, 20], 59 | >>> [32, 32, 38, 42], 60 | >>> ]) 61 | >>> bboxes2 = torch.FloatTensor([ 62 | >>> [0, 0, 10, 20], 63 | >>> [0, 10, 10, 19], 64 | >>> [10, 10, 20, 20], 65 | >>> ]) 66 | >>> bbox_overlaps(bboxes1, bboxes2) 67 | tensor([[0.5238, 0.0500, 0.0041], 68 | [0.0323, 0.0452, 1.0000], 69 | [0.0000, 0.0000, 0.0000]]) 70 | 71 | Example: 72 | >>> empty = torch.FloatTensor([]) 73 | >>> nonempty = torch.FloatTensor([ 74 | >>> [0, 0, 10, 9], 75 | >>> ]) 76 | >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1) 77 | >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0) 78 | >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0) 79 | """ 80 | 81 | assert mode in ['iou', 'iof'] 82 | 83 | rows = bboxes1.size(0) 84 | cols = bboxes2.size(0) 85 | if is_aligned: 86 | assert rows == cols 87 | 88 | if rows * cols == 0: 89 | return bboxes1.new(rows, 1) if is_aligned else bboxes1.new(rows, cols) 90 | 91 | if is_aligned: 92 | lt = torch.max(bboxes1[:, :2], bboxes2[:, :2]) # [rows, 2] 93 | rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:]) # [rows, 2] 94 | 95 | wh = (rb - lt + 1).clamp(min=0) # [rows, 2] 96 | overlap = wh[:, 0] * wh[:, 1] 97 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( 98 | bboxes1[:, 3] - bboxes1[:, 1] + 1) 99 | 100 | if mode == 'iou': 101 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( 102 | bboxes2[:, 3] - bboxes2[:, 1] + 1) 103 | ious = overlap / (area1 + area2 - overlap) 104 | else: 105 | ious = overlap / area1 106 | else: 107 | lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [rows, cols, 2] 108 | rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [rows, cols, 2] 109 | 110 | wh = (rb - lt + 1).clamp(min=0) # [rows, cols, 2] 111 | overlap = wh[:, :, 0] * wh[:, :, 1] 112 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( 113 | bboxes1[:, 3] - bboxes1[:, 1] + 1) 114 | 115 | if mode == 'iou': 116 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( 117 | bboxes2[:, 3] - bboxes2[:, 1] + 1) 118 | ious = overlap / (area1[:, None] + area2 - overlap) 119 | else: 120 | ious = overlap / (area1[:, None]) 121 | 122 | return ious 123 | 124 | 125 | def generalized_box_iou(boxes1, boxes2): 126 | """ 127 | Generalized IoU from https://giou.stanford.edu/ 128 | 129 | The boxes should be in [x0, y0, x1, y1] format 130 | 131 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 132 | and M = len(boxes2) 133 | """ 134 | # degenerate boxes gives inf / nan results 135 | # so do an early check 136 | if (boxes1[:, 2:] < boxes1[:, :2]).any(): 137 | import pdb; pdb.set_trace() 138 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 139 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 140 | iou, union = box_iou(boxes1, boxes2) 141 | 142 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 143 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 144 | 145 | wh = (rb - lt).clamp(min=0) # [N,M,2] 146 | area = wh[:, :, 0] * wh[:, :, 1] 147 | 148 | 
return iou - (area - union) / area 149 | 150 | 151 | def masks_to_boxes(masks): 152 | """Compute the bounding boxes around the provided masks 153 | 154 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 155 | 156 | Returns a [N, 4] tensors, with the boxes in xyxy format 157 | """ 158 | if masks.numel() == 0: 159 | return torch.zeros((0, 4), device=masks.device) 160 | 161 | h, w = masks.shape[-2:] 162 | 163 | y = torch.arange(0, h, dtype=torch.float) 164 | x = torch.arange(0, w, dtype=torch.float) 165 | y, x = torch.meshgrid(y, x) 166 | 167 | x_mask = (masks * x.unsqueeze(0)) 168 | x_max = x_mask.flatten(1).max(-1)[0] 169 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 170 | 171 | y_mask = (masks * y.unsqueeze(0)) 172 | y_max = y_mask.flatten(1).max(-1)[0] 173 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 174 | 175 | return torch.stack([x_min, y_min, x_max, y_max], 1) 176 | -------------------------------------------------------------------------------- /libs/utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | import time 4 | from collections import defaultdict, deque 5 | import datetime 6 | import pickle 7 | from typing import Optional, List 8 | 9 | import torch 10 | import torch.distributed as dist 11 | from torch import Tensor 12 | 13 | # needed due to empty tensor bug in pytorch and torchvision 0.5 14 | import torchvision 15 | if float(torchvision.__version__[:3]) < 0.7: 16 | from torchvision.ops import _new_empty_tensor 17 | from torchvision.ops.misc import _output_size 18 | 19 | 20 | class SmoothedValue(object): 21 | """Track a series of values and provide access to smoothed values over a 22 | window or the global series average. 23 | """ 24 | 25 | def __init__(self, window_size=20, fmt=None): 26 | if fmt is None: 27 | fmt = "{median:.4f} ({global_avg:.4f})" 28 | self.deque = deque(maxlen=window_size) 29 | self.total = 0.0 30 | self.count = 0 31 | self.fmt = fmt 32 | 33 | def update(self, value, n=1): 34 | self.deque.append(value) 35 | self.count += n 36 | self.total += value * n 37 | 38 | def synchronize_between_processes(self): 39 | """ 40 | Warning: does not synchronize the deque! 
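        Only `count` and `total` (and therefore `global_avg`) are reduced
        across processes; the windowed statistics (`median`, `avg`, `max`,
        `value`) stay local to each rank.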
41 | """ 42 | if not is_dist_avail_and_initialized(): 43 | return 44 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 45 | dist.barrier() 46 | dist.all_reduce(t) 47 | t = t.tolist() 48 | self.count = int(t[0]) 49 | self.total = t[1] 50 | 51 | @property 52 | def median(self): 53 | d = torch.tensor(list(self.deque)) 54 | return d.median().item() 55 | 56 | @property 57 | def avg(self): 58 | d = torch.tensor(list(self.deque), dtype=torch.float32) 59 | return d.mean().item() 60 | 61 | @property 62 | def global_avg(self): 63 | return self.total / self.count 64 | 65 | @property 66 | def max(self): 67 | return max(self.deque) 68 | 69 | @property 70 | def value(self): 71 | return self.deque[-1] 72 | 73 | def __str__(self): 74 | return self.fmt.format( 75 | median=self.median, 76 | avg=self.avg, 77 | global_avg=self.global_avg, 78 | max=self.max, 79 | value=self.value) 80 | 81 | 82 | def all_gather(data): 83 | """ 84 | Run all_gather on arbitrary picklable data (not necessarily tensors) 85 | Args: 86 | data: any picklable object 87 | Returns: 88 | list[data]: list of data gathered from each rank 89 | """ 90 | world_size = get_world_size() 91 | if world_size == 1: 92 | return [data] 93 | 94 | # serialized to a Tensor 95 | buffer = pickle.dumps(data) 96 | storage = torch.ByteStorage.from_buffer(buffer) 97 | tensor = torch.ByteTensor(storage).to("cuda") 98 | 99 | # obtain Tensor size of each rank 100 | local_size = torch.tensor([tensor.numel()], device="cuda") 101 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 102 | dist.all_gather(size_list, local_size) 103 | size_list = [int(size.item()) for size in size_list] 104 | max_size = max(size_list) 105 | 106 | # receiving Tensor from all ranks 107 | # we pad the tensor because torch all_gather does not support 108 | # gathering tensors of different shapes 109 | tensor_list = [] 110 | for _ in size_list: 111 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 112 | if local_size != max_size: 113 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 114 | tensor = torch.cat((tensor, padding), dim=0) 115 | dist.all_gather(tensor_list, tensor) 116 | 117 | data_list = [] 118 | for size, tensor in zip(size_list, tensor_list): 119 | buffer = tensor.cpu().numpy().tobytes()[:size] 120 | data_list.append(pickle.loads(buffer)) 121 | 122 | return data_list 123 | 124 | 125 | def reduce_dict(input_dict, average=True): 126 | """ 127 | Args: 128 | input_dict (dict): all the values will be reduced 129 | average (bool): whether to do average or sum 130 | Reduce the values in the dictionary from all processes so that all processes 131 | have the averaged results. Returns a dict with the same fields as 132 | input_dict, after reduction. 
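    For example (illustrative values), with two processes where rank 0 holds
    {'loss': 2.0} and rank 1 holds {'loss': 4.0}, every rank receives
    {'loss': 3.0} when ``average=True`` and {'loss': 6.0} otherwise.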
133 | """ 134 | world_size = get_world_size() 135 | if world_size < 2: 136 | return input_dict 137 | with torch.no_grad(): 138 | names = [] 139 | values = [] 140 | # sort the keys so that they are consistent across processes 141 | for k in sorted(input_dict.keys()): 142 | names.append(k) 143 | values.append(input_dict[k]) 144 | values = torch.stack(values, dim=0) 145 | dist.all_reduce(values) 146 | if average: 147 | values /= world_size 148 | reduced_dict = {k: v for k, v in zip(names, values)} 149 | return reduced_dict 150 | 151 | 152 | class MetricLogger(object): 153 | def __init__(self, delimiter="\t"): 154 | self.meters = defaultdict(SmoothedValue) 155 | self.delimiter = delimiter 156 | 157 | def update(self, **kwargs): 158 | for k, v in kwargs.items(): 159 | if isinstance(v, torch.Tensor): 160 | v = v.item() 161 | assert isinstance(v, (float, int)) 162 | self.meters[k].update(v) 163 | 164 | def __getattr__(self, attr): 165 | if attr in self.meters: 166 | return self.meters[attr] 167 | if attr in self.__dict__: 168 | return self.__dict__[attr] 169 | raise AttributeError("'{}' object has no attribute '{}'".format( 170 | type(self).__name__, attr)) 171 | 172 | def __str__(self): 173 | loss_str = [] 174 | for name, meter in self.meters.items(): 175 | loss_str.append( 176 | "{}: {}".format(name, str(meter)) 177 | ) 178 | return self.delimiter.join(loss_str) 179 | 180 | def synchronize_between_processes(self): 181 | for meter in self.meters.values(): 182 | meter.synchronize_between_processes() 183 | 184 | def add_meter(self, name, meter): 185 | self.meters[name] = meter 186 | 187 | def log_every(self, iterable, print_freq, header=None): 188 | i = 0 189 | if not header: 190 | header = '' 191 | start_time = time.time() 192 | end = time.time() 193 | iter_time = SmoothedValue(fmt='{avg:.4f}') 194 | data_time = SmoothedValue(fmt='{avg:.4f}') 195 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 196 | if torch.cuda.is_available(): 197 | log_msg = self.delimiter.join([ 198 | header, 199 | '[{0' + space_fmt + '}/{1}]', 200 | 'eta: {eta}', 201 | '{meters}', 202 | 'time: {time}', 203 | 'data: {data}', 204 | 'max mem: {memory:.0f}' 205 | ]) 206 | else: 207 | log_msg = self.delimiter.join([ 208 | header, 209 | '[{0' + space_fmt + '}/{1}]', 210 | 'eta: {eta}', 211 | '{meters}', 212 | 'time: {time}', 213 | 'data: {data}' 214 | ]) 215 | MB = 1024.0 * 1024.0 216 | for obj in iterable: 217 | data_time.update(time.time() - end) 218 | yield obj 219 | iter_time.update(time.time() - end) 220 | if i % print_freq == 0 or i == len(iterable) - 1: 221 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 222 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 223 | if torch.cuda.is_available(): 224 | if is_main_process(): 225 | print(log_msg.format( 226 | i, len(iterable), eta=eta_string, 227 | meters=str(self), 228 | time=str(iter_time), data=str(data_time), 229 | memory=torch.cuda.max_memory_allocated() / MB)) 230 | else: 231 | print(log_msg.format( 232 | i, len(iterable), eta=eta_string, 233 | meters=str(self), 234 | time=str(iter_time), data=str(data_time))) 235 | i += 1 236 | end = time.time() 237 | total_time = time.time() - start_time 238 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 239 | print('{} Total time: {} ({:.4f} s / it)'.format( 240 | header, total_time_str, total_time / len(iterable))) 241 | 242 | 243 | def get_sha(): 244 | cwd = os.path.dirname(os.path.abspath(__file__)) 245 | 246 | def _run(command): 247 | return subprocess.check_output(command, 
cwd=cwd).decode('ascii').strip() 248 | sha = 'N/A' 249 | diff = "clean" 250 | branch = 'N/A' 251 | try: 252 | sha = _run(['git', 'rev-parse', 'HEAD']) 253 | subprocess.check_output(['git', 'diff'], cwd=cwd) 254 | diff = _run(['git', 'diff-index', 'HEAD']) 255 | diff = "has uncommited changes" if diff else "clean" 256 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 257 | except Exception: 258 | pass 259 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 260 | return message 261 | 262 | 263 | def collate_fn(batch): 264 | batch = list(zip(*batch)) 265 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 266 | return tuple(batch) 267 | 268 | 269 | def _max_by_axis(the_list): 270 | # type: (List[List[int]]) -> List[int] 271 | maxes = the_list[0] 272 | for sublist in the_list[1:]: 273 | for index, item in enumerate(sublist): 274 | maxes[index] = max(maxes[index], item) 275 | return maxes 276 | 277 | 278 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 279 | # TODO make this more general 280 | if tensor_list[0].ndim == 3: 281 | # TODO make it support different-sized images 282 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 283 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 284 | batch_shape = [len(tensor_list)] + max_size 285 | b, c, h, w = batch_shape 286 | dtype = tensor_list[0].dtype 287 | device = tensor_list[0].device 288 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 289 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 290 | for img, pad_img, m in zip(tensor_list, tensor, mask): 291 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 292 | m[: img.shape[1], :img.shape[2]] = False 293 | else: 294 | raise ValueError('not supported') 295 | return NestedTensor(tensor, mask) 296 | 297 | 298 | class NestedTensor(object): 299 | def __init__(self, tensors, mask: Optional[Tensor]): 300 | self.tensors = tensors 301 | self.mask = mask 302 | 303 | def to(self, device): 304 | # type: (Device) -> NestedTensor # noqa 305 | cast_tensor = self.tensors.to(device) 306 | mask = self.mask 307 | if mask is not None: 308 | assert mask is not None 309 | cast_mask = mask.to(device) 310 | else: 311 | cast_mask = None 312 | return NestedTensor(cast_tensor, cast_mask) 313 | 314 | def decompose(self): 315 | return self.tensors, self.mask 316 | 317 | def __repr__(self): 318 | return str(self.tensors) 319 | 320 | 321 | def setup_for_distributed(is_master): 322 | """ 323 | This function disables printing when not in master process 324 | """ 325 | import builtins as __builtin__ 326 | builtin_print = __builtin__.print 327 | 328 | def print(*args, **kwargs): 329 | force = kwargs.pop('force', False) 330 | if is_master or force: 331 | builtin_print(*args, **kwargs) 332 | 333 | __builtin__.print = print 334 | 335 | 336 | def is_dist_avail_and_initialized(): 337 | if not dist.is_available(): 338 | return False 339 | if not dist.is_initialized(): 340 | return False 341 | return True 342 | 343 | 344 | def get_world_size(): 345 | if not is_dist_avail_and_initialized(): 346 | return 1 347 | return dist.get_world_size() 348 | 349 | 350 | def get_rank(): 351 | if not is_dist_avail_and_initialized(): 352 | return 0 353 | return dist.get_rank() 354 | 355 | 356 | def is_main_process(): 357 | return get_rank() == 0 358 | 359 | 360 | def save_on_master(*args, **kwargs): 361 | if is_main_process(): 362 | torch.save(*args, **kwargs) 363 | 364 | 365 | def init_distributed_mode(args): 366 | if 'RANK' in 
os.environ and 'WORLD_SIZE' in os.environ: 367 | args.rank = int(os.environ["RANK"]) 368 | args.world_size = int(os.environ['WORLD_SIZE']) 369 | args.gpu = int(os.environ['LOCAL_RANK']) 370 | elif 'SLURM_PROCID' in os.environ: 371 | args.rank = int(os.environ['SLURM_PROCID']) 372 | args.gpu = args.rank % torch.cuda.device_count() 373 | else: 374 | print('Not using distributed mode') 375 | args.distributed = False 376 | return 377 | 378 | args.distributed = True 379 | 380 | torch.cuda.set_device(args.gpu) 381 | args.dist_backend = 'nccl' 382 | print('| distributed init (rank {}): {}'.format( 383 | args.rank, args.dist_url), flush=True) 384 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 385 | world_size=args.world_size, rank=args.rank) 386 | torch.distributed.barrier() 387 | setup_for_distributed(args.rank == 0) 388 | 389 | 390 | @torch.no_grad() 391 | def accuracy(output, target, topk=(1,)): 392 | """Computes the precision@k for the specified values of k""" 393 | if target.numel() == 0: 394 | return [torch.zeros([], device=output.device)] 395 | maxk = max(topk) 396 | batch_size = target.size(0) 397 | _, pred = output.topk(maxk, 1, True, True) 398 | pred = pred.t() 399 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 400 | 401 | res = [] 402 | for k in topk: 403 | correct_k = correct[:k].view(-1).float().sum(0) 404 | res.append(correct_k.mul_(100.0 / batch_size)) 405 | return res 406 | 407 | 408 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 409 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 410 | """ 411 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 412 | This will eventually be supported natively by PyTorch, and this 413 | class can go away. 
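    For torchvision >= 0.7 this simply forwards to
    torchvision.ops.misc.interpolate; the manual empty-batch handling below is
    only needed for older torchvision versions.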
414 | """ 415 | if float(torchvision.__version__[:3]) < 0.7: 416 | if input.numel() > 0: 417 | return torch.nn.functional.interpolate( 418 | input, size, scale_factor, mode, align_corners 419 | ) 420 | 421 | output_shape = _output_size(2, input, size, scale_factor) 422 | output_shape = list(input.shape[:-2]) + list(output_shape) 423 | return _new_empty_tensor(input, output_shape) 424 | else: 425 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 426 | -------------------------------------------------------------------------------- /libs/utils/utils.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import logging 3 | import numpy as np 4 | import os 5 | import os.path as osp 6 | import time 7 | 8 | import torch 9 | import torch.optim as optim 10 | from torch.optim.lr_scheduler import _LRScheduler 11 | from torch.utils.data import ConcatDataset 12 | 13 | from bisect import bisect_right 14 | from functools import partial 15 | from six.moves import map, zip 16 | 17 | from libs.datasets.transform import TrainTransform 18 | from libs.datasets.transform import EvalTransform 19 | 20 | 21 | class AverageMeter(object): 22 | """Computes and stores the average and current value 23 | """ 24 | 25 | def __init__(self): 26 | self.reset() 27 | 28 | def reset(self): 29 | self.val = 0 30 | self.avg = 0 31 | self.sum = 0 32 | self.count = 0 33 | 34 | def update(self, val, n=1): 35 | self.val = val 36 | self.sum += val * n 37 | self.count += n 38 | self.avg = self.sum / self.count 39 | 40 | def resource_path(relative_path): 41 | """To get the absolute path""" 42 | base_path = osp.abspath(".") 43 | 44 | return osp.join(base_path, relative_path) 45 | 46 | 47 | def ensure_dir(root_dir, rank=0): 48 | if not osp.exists(root_dir) and rank == 0: 49 | print(f'=> creating {root_dir}') 50 | os.mkdir(root_dir) 51 | else: 52 | while not osp.exists(root_dir): 53 | print(f'=> wait for {root_dir} created') 54 | time.sleep(10) 55 | 56 | return root_dir 57 | 58 | 59 | def create_logger(cfg, rank=0): 60 | # working_dir root 61 | abs_working_dir = resource_path('work_dirs') 62 | working_dir = ensure_dir(abs_working_dir, rank) 63 | # output_dir root 64 | output_root_dir = ensure_dir(os.path.join(working_dir, cfg.OUTPUT_ROOT), rank) 65 | time_str = time.strftime('%Y-%m-%d-%H-%M') 66 | final_output_dir = ensure_dir(os.path.join(output_root_dir, time_str), rank) 67 | # set up logger 68 | logger = setup_logger(final_output_dir, time_str, rank) 69 | 70 | return logger, final_output_dir 71 | 72 | 73 | def setup_logger(final_output_dir, time_str, rank, phase='train'): 74 | log_file = f'{phase}_{time_str}_rank{rank}.log' 75 | final_log_file = os.path.join(final_output_dir, log_file) 76 | head = '%(asctime)-15s %(message)s' 77 | logging.basicConfig(filename=str(final_log_file), format=head) 78 | logger = logging.getLogger() 79 | logger.setLevel(logging.INFO) 80 | console = logging.StreamHandler() 81 | logging.getLogger('').addHandler(console) 82 | 83 | return logger 84 | 85 | 86 | def get_model(cfg, device): 87 | module = importlib.import_module(cfg.MODEL.FILE) 88 | model, criterion, postprocessors = getattr(module, 'build_model')(cfg, device) 89 | 90 | return model, criterion, postprocessors 91 | 92 | 93 | def get_optimizer(cfg, model): 94 | """Support two types of optimizers: SGD, Adam. 
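    The relevant config keys are TRAIN.OPTIMIZER ('sgd' or 'adam'), TRAIN.LR,
    TRAIN.WEIGHT_DECAY and, for SGD, TRAIN.MOMENTUM and TRAIN.NESTEROV; only
    parameters with requires_grad=True are passed to the optimizer.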
95 |     """
96 |     assert (cfg.TRAIN.OPTIMIZER in [
97 |         'sgd',
98 |         'adam',
99 |     ])
100 |     if cfg.TRAIN.OPTIMIZER == 'sgd':
101 |         optimizer = optim.SGD(
102 |             filter(lambda p: p.requires_grad, model.parameters()),
103 |             lr=cfg.TRAIN.LR,
104 |             momentum=cfg.TRAIN.MOMENTUM,
105 |             weight_decay=cfg.TRAIN.WEIGHT_DECAY,
106 |             nesterov=cfg.TRAIN.NESTEROV)
107 |     elif cfg.TRAIN.OPTIMIZER == 'adam':
108 |         optimizer = optim.Adam(
109 |             filter(lambda p: p.requires_grad, model.parameters()),
110 |             lr=cfg.TRAIN.LR,
111 |             weight_decay=cfg.TRAIN.WEIGHT_DECAY)
112 | 
113 |     return optimizer
114 | 
115 | 
116 | def load_checkpoint(cfg, model, optimizer, lr_scheduler, device, module_name='model'):
117 |     last_iter = -1
118 |     resume_path = cfg.MODEL.RESUME_PATH
119 |     resume = cfg.TRAIN.RESUME
120 |     if resume_path and resume:
121 |         if osp.exists(resume_path):
122 |             checkpoint = torch.load(resume_path, map_location='cpu')
123 |             # resume
124 |             if 'state_dict' in checkpoint:
125 |                 model.module.load_state_dict(checkpoint['state_dict'], strict=False)
126 |                 logging.info(f'==> model pretrained from {resume_path} \n')
127 |             elif 'model' in checkpoint:
128 |                 if module_name == 'detr':
129 |                     model.module.detr_head.load_state_dict(checkpoint['model'], strict=False)
130 |                     logging.info(f'==> detr pretrained from {resume_path} \n')
131 |                 else:
132 |                     model.module.load_state_dict(checkpoint['model'], strict=False)
133 |                     logging.info(f'==> model pretrained from {resume_path} \n')
134 |             if 'optimizer' in checkpoint:
135 |                 optimizer.load_state_dict(checkpoint['optimizer'])
136 |                 logging.info('==> optimizer resumed, continue training')
137 |                 for state in optimizer.state.values():
138 |                     for k, v in state.items():
139 |                         if torch.is_tensor(v):
140 |                             state[k] = v.to(device)
141 |             if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
142 |                 lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
143 |                 last_iter = checkpoint['epoch']
144 |                 logging.info(f'==> last_epoch = {last_iter}')
145 |             if 'epoch' in checkpoint:
146 |                 last_iter = checkpoint['epoch']
147 |                 logging.info(f'==> last_epoch = {last_iter}')
148 |         # pre-train
149 |         else:
150 |             logging.error(f"==> checkpoint does not exist: \"{resume_path}\"")
151 |             raise FileNotFoundError
152 |     else:
153 |         logging.info("==> train model without resume")
154 | 
155 |     return model, optimizer, lr_scheduler, last_iter
156 | 
157 | 
158 | class WarmupMultiStepLR(_LRScheduler):
159 |     def __init__(self, optimizer, milestones, gamma=0.1, warmup_factor=1.0 / 3,
160 |                  warmup_iters=500, last_epoch=-1):
161 |         if not list(milestones) == sorted(milestones):
162 |             raise ValueError(
163 |                 "Milestones should be a list of increasing integers. "
164 |                 "Got {}".format(milestones)
165 |             )
166 | 
167 |         self.milestones = milestones
168 |         self.gamma = gamma
169 |         self.warmup_factor = warmup_factor
170 |         self.warmup_iters = warmup_iters
171 |         super().__init__(optimizer, last_epoch)
172 | 
173 |     def get_lr(self):
174 |         warmup_factor = 1
175 |         if self.last_epoch < self.warmup_iters:
176 |             alpha = float(self.last_epoch) / self.warmup_iters
177 |             warmup_factor = self.warmup_factor * (1 - alpha) + alpha
178 |         return [
179 |             base_lr
180 |             * warmup_factor
181 |             * self.gamma ** bisect_right(self.milestones, self.last_epoch)
182 |             for base_lr in self.base_lrs
183 |         ]
184 | 
185 | 
186 | def get_lr_scheduler(cfg, optimizer, last_epoch=-1):
187 |     """Support three types of learning rate schedulers: StepLR, MultiStepLR, MultiStepWithWarmup.
188 |     """
189 |     assert (cfg.TRAIN.LR_SCHEDULER in [
190 |         'StepLR',
191 |         'MultiStepLR',
192 |         'MultiStepWithWarmup',
193 |     ])
194 |     if cfg.TRAIN.LR_SCHEDULER == 'StepLR':
195 |         lr_scheduler = torch.optim.lr_scheduler.StepLR(
196 |             optimizer,
197 |             cfg.TRAIN.LR_STEPS[0],
198 |             cfg.TRAIN.LR_FACTOR,
199 |             last_epoch=last_epoch)
200 |     elif cfg.TRAIN.LR_SCHEDULER == 'MultiStepLR':
201 |         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
202 |             optimizer,
203 |             cfg.TRAIN.LR_STEPS,
204 |             cfg.TRAIN.LR_FACTOR,
205 |             last_epoch=last_epoch)
206 |     elif cfg.TRAIN.LR_SCHEDULER == 'MultiStepWithWarmup':
207 |         lr_scheduler = WarmupMultiStepLR(
208 |             optimizer,
209 |             cfg.TRAIN.LR_STEPS,
210 |             cfg.TRAIN.LR_FACTOR,
211 |             cfg.TRAIN.WARMUP_INIT_FACTOR,
212 |             cfg.TRAIN.WARMUP_STEP,
213 |             last_epoch)
214 |     else:
215 |         raise AttributeError(f'{cfg.TRAIN.LR_SCHEDULER} is not implemented')
216 | 
217 |     return lr_scheduler
218 | 
219 | 
220 | def get_det_criterion(cfg):
221 |     # Unused helper: the criterion is built together with the model in get_model.
222 |     raise NotImplementedError('use get_model, which also returns the criterion')
223 | 
224 | def get_trainer(cfg, model, criterion, optimizer, lr_scheduler, postprocessors,
225 |                 log_dir, performance_indicator, last_iter, rank, device, max_norm):
226 |     module = importlib.import_module(cfg.TRAINER.FILE)
227 |     Trainer = getattr(module, cfg.TRAINER.NAME)(
228 |         cfg,
229 |         model=model,
230 |         criterion=criterion,
231 |         optimizer=optimizer,
232 |         lr_scheduler=lr_scheduler,
233 |         postprocessors=postprocessors,
234 |         log_dir=log_dir,
235 |         performance_indicator=performance_indicator,
236 |         last_iter=last_iter,
237 |         rank=rank,
238 |         device=device,
239 |         max_norm=max_norm
240 |     )
241 |     return Trainer
242 | 
243 | def list_to_set(data_list, name='train'):
244 |     if len(data_list) == 0:
245 |         dataset = None
246 |         logging.warning(f"{name} dataset is None")
247 |     elif len(data_list) == 1:
248 |         dataset = data_list[0]
249 |     else:
250 |         dataset = ConcatDataset(data_list)
251 | 
252 |     if dataset is not None:
253 |         logging.info(f'==> the size of {name} dataset is {len(dataset)}')
254 |     return dataset
255 | 
256 | def get_dataset(cfg):
257 |     train_transform = TrainTransform(
258 |         mean=cfg.DATASET.MEAN,
259 |         std=cfg.DATASET.STD,
260 |         scales=cfg.DATASET.SCALES,
261 |         max_size=cfg.DATASET.MAX_SIZE
262 |     )
263 |     eval_transform = EvalTransform(
264 |         mean=cfg.DATASET.MEAN,
265 |         std=cfg.DATASET.STD,
266 |         max_size=cfg.DATASET.MAX_SIZE
267 |     )
268 |     module = importlib.import_module(cfg.DATASET.FILE)
269 |     Dataset = getattr(module, cfg.DATASET.NAME)
270 |     data_root = cfg.DATASET.ROOT  # abs path in yaml
271 |     # get train data list
272 |     train_root = osp.join(data_root, 'train')
273 |     train_set = [d for d in os.listdir(train_root) if osp.isdir(osp.join(train_root, d))]
274 |     if len(train_set) == 0:
275 |         train_set = ['.']
276 |     train_list = []
277 |     for sub_set in train_set:
278 |         train_sub_root = osp.join(train_root, sub_set)
279 |         logging.info(f'==> load train sub set: {train_sub_root}')
280 |         train_sub_set = Dataset(cfg, train_sub_root, train_transform)
281 |         train_list.append(train_sub_set)
282 |     # get eval data list
283 |     eval_root = osp.join(data_root, 'test')
284 |     eval_set = [d for d in os.listdir(eval_root) if osp.isdir(osp.join(eval_root, d))]
285 |     if len(eval_set) == 0:
286 |         eval_set = ['.']
287 |     eval_list = []
288 |     for sub_set in eval_set:
289 |         eval_sub_root = osp.join(eval_root, sub_set)
290 |         logging.info(f'==> load val sub set: {eval_sub_root}')
291 |         eval_sub_set = Dataset(cfg, eval_sub_root, eval_transform)
292 |         eval_list.append(eval_sub_set)
293 |     # concat dataset list
294 |     train_dataset = list_to_set(train_list, 'train')
295 |     eval_dataset = list_to_set(eval_list, 'eval')
296 | 
297 |     return train_dataset, eval_dataset
298 | 
299 | def save_checkpoint(states, is_best, output_dir, filename='checkpoint.pth'):
300 |     torch.save(states, os.path.join(output_dir, filename))
301 |     logging.info(f'save model to {output_dir}')
302 |     if is_best:
303 |         torch.save(states['state_dict'], os.path.join(output_dir, 'model_best.pth'))
304 | 
305 | def load_eval_model(resume_path, model):
306 |     if resume_path != '':
307 |         if osp.exists(resume_path):
308 |             print(f'==> load model from {resume_path}')
309 |             checkpoint = torch.load(resume_path)
310 |             if 'state_dict' in checkpoint:
311 |                 model.load_state_dict(checkpoint['state_dict'])
312 |             else:
313 |                 model.load_state_dict(checkpoint)
314 |         else:
315 |             print(f"==> checkpoint does not exist: \"{resume_path}\"")
316 |             raise FileNotFoundError
317 |     return model
318 | 
319 | def multi_apply(func, *args, **kwargs):
320 |     pfunc = partial(func, **kwargs) if kwargs else func
321 |     map_results = map(pfunc, *args)
322 |     return tuple(map(list, zip(*map_results)))
323 | 
324 | def naive_np_nms(dets, thresh):
325 |     """Pure Python NMS baseline."""
326 |     x1 = dets[:, 0]
327 |     y1 = dets[:, 1]
328 |     x2 = dets[:, 2]
329 |     y2 = dets[:, 3]
330 |     areas = (x2 - x1 + 1) * (y2 - y1 + 1)
331 |     order = x1.argsort()[::-1]  # note: boxes are visited in descending x1 order (no score column is used)
332 |     keep = []
333 |     while order.size > 0:
334 |         i = order[0]
335 |         keep.append(i)
336 |         xx1 = np.maximum(x1[i], x1[order[1:]])
337 |         yy1 = np.maximum(y1[i], y1[order[1:]])
338 |         xx2 = np.minimum(x2[i], x2[order[1:]])
339 |         yy2 = np.minimum(y2[i], y2[order[1:]])
340 |         w = np.maximum(0.0, xx2 - xx1 + 1)
341 |         h = np.maximum(0.0, yy2 - yy1 + 1)
342 |         inter = w * h
343 |         ovr = inter / (areas[i] + areas[order[1:]] - inter)
344 |         inds = np.where(ovr <= thresh)[0]
345 |         order = order[inds + 1]
346 |     return dets[keep]
347 | 
348 | 
349 | def write_dict_to_json(mydict, f_path):
350 |     import json
351 |     import numpy
352 |     class DataEncoding(json.JSONEncoder):
353 |         def default(self, obj):
354 |             if isinstance(obj, (numpy.int_, numpy.intc, numpy.intp, numpy.int8,
355 |                 numpy.int16, numpy.int32, numpy.int64, numpy.uint8,
356 |                 numpy.uint16, numpy.uint32, numpy.uint64)):
357 |                 return int(obj)
358 |             elif isinstance(obj, (numpy.float_, numpy.float16, numpy.float32,
359 |                 numpy.float64)):
360 |                 return float(obj)
361 |             elif isinstance(obj, (numpy.ndarray,)):
362 |                 return obj.tolist()
363 |             return json.JSONEncoder.default(self, obj)
364 |     with open(f_path, 'w') as f:
365 |         json.dump(mydict, f, cls=DataEncoding)
366 |     print("saved detection dict to %s!"
%(f_path)) 367 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.5.2 2 | yacs==0.1.7 3 | torch==1.4.0 4 | mmcv==1.0.4 5 | numpy==1.19.1 6 | torchvision==0.5.0 7 | tqdm==4.48.2 8 | Pillow==8.0.1 9 | tensorboardX==2.1 10 | -------------------------------------------------------------------------------- /tools/_init_paths.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | 4 | 5 | def add_path(path): 6 | if path not in sys.path: 7 | sys.path.insert(0, path) 8 | 9 | 10 | this_dir = osp.dirname(__file__) 11 | 12 | project_path = osp.dirname(this_dir) 13 | dataset_path = osp.join(project_path, 'libs', 'datasets') 14 | trainer_path = osp.join(project_path, 'libs', 'trainer') 15 | model_path = osp.join(project_path, 'libs', 'models') 16 | 17 | add_path(project_path) 18 | add_path(dataset_path) 19 | add_path(trainer_path) 20 | add_path(model_path) 21 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from __future__ import print_function 3 | from __future__ import with_statement 4 | 5 | import argparse 6 | import importlib 7 | import logging 8 | import os 9 | 10 | import torch 11 | 12 | import _init_paths 13 | from configs import cfg 14 | from configs import update_config 15 | from libs.datasets.collate import collect 16 | from libs.datasets.transform import EvalTransform 17 | 18 | 19 | def parse_args(): 20 | parser = argparse.ArgumentParser(description='HOI Detection Task') 21 | parser.add_argument( 22 | '--cfg', 23 | dest='yaml_file', 24 | help='experiment configure file name, e.g. 
configs/hico.yaml',
25 |         required=True,
26 |         type=str)
27 |     parser.add_argument(
28 |         'opts',
29 |         help="Modify config options using the command-line",
30 |         default=None,
31 |         nargs=argparse.REMAINDER)
32 |     args = parser.parse_args()
33 | 
34 |     return args
35 | 
36 | 
37 | def main_per_worker():
38 |     args = parse_args()
39 |     update_config(cfg, args)
40 |     ngpus_per_node = torch.cuda.device_count()
41 |     device = torch.device(cfg.DEVICE)
42 | 
43 |     if not os.path.exists(cfg.OUTPUT_ROOT):
44 |         os.makedirs(cfg.OUTPUT_ROOT)
45 |     logging.basicConfig(filename=f'{cfg.OUTPUT_ROOT}/eval.log', level=logging.INFO)
46 | 
47 |     # model
48 |     module = importlib.import_module(cfg.MODEL.FILE)
49 |     model, criterion, postprocessors = getattr(module, 'build_model')(cfg, device)
50 |     model = torch.nn.DataParallel(model).to(device)
51 |     model_without_ddp = model.module
52 |     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
53 |     print('number of params:', n_parameters)
54 |     # load model checkpoints
55 |     resume_path = cfg.MODEL.RESUME_PATH
56 |     if os.path.exists(resume_path):
57 |         checkpoint = torch.load(resume_path, map_location='cpu')
58 |         # resume
59 |         if 'state_dict' in checkpoint:
60 |             model.module.load_state_dict(checkpoint['state_dict'], strict=True)
61 |             logging.info(f'==> model pretrained from {resume_path}')
62 | 
63 |     # get dataset
64 |     module = importlib.import_module(cfg.DATASET.FILE)
65 |     Dataset = getattr(module, cfg.DATASET.NAME)
66 |     data_root = os.path.join(cfg.DATASET.ROOT, 'test')
67 |     if not os.path.exists(data_root):
68 |         logging.error(f'==> Cannot find data: {data_root}')
69 |         raise FileNotFoundError
70 |     eval_transform = EvalTransform(
71 |         mean=cfg.DATASET.MEAN,
72 |         std=cfg.DATASET.STD,
73 |         max_size=cfg.DATASET.MAX_SIZE
74 |     )
75 |     logging.info(f'==> load val sub set: {data_root}')
76 |     eval_dataset = Dataset(cfg, data_root, eval_transform)
77 |     if eval_dataset is not None:
78 |         logging.info(f'==> the size of eval dataset is {len(eval_dataset)}')
79 |     eval_loader = torch.utils.data.DataLoader(
80 |         eval_dataset,
81 |         batch_size=1,
82 |         shuffle=False,
83 |         drop_last=False,
84 |         collate_fn=collect,
85 |         num_workers=cfg.WORKERS
86 |     )
87 | 
88 |     # start evaluate in Trainer
89 |     module = importlib.import_module(cfg.TRAINER.FILE)
90 |     Trainer = getattr(module, cfg.TRAINER.NAME)(
91 |         cfg,
92 |         model,
93 |         criterion=criterion,
94 |         optimizer=None,
95 |         lr_scheduler=None,
96 |         postprocessors=postprocessors,
97 |         log_dir=cfg.OUTPUT_ROOT+'/output',
98 |         performance_indicator=cfg.PI,
99 |         last_iter=-1,
100 |         rank=0,
101 |         device=device,
102 |         max_norm=None
103 |     )
104 |     logging.info('==> start eval...')
105 | 
106 |     assert cfg.TEST.MODE in ['hico', 'hoia']
107 |     Trainer.evaluate(eval_loader, cfg.TEST.MODE)
108 | 
109 | 
110 | if __name__ == '__main__':
111 |     main_per_worker()
112 | 
--------------------------------------------------------------------------------
/tools/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # ------------------------------------------------------------------------------
3 | # Created by Mingfei Chen (lasiafly@gmail.com)
4 | # Created On: 2020-7-24
5 | # ------------------------------------------------------------------------------
6 | from __future__ import division
7 | from __future__ import print_function
8 | from __future__ import with_statement
9 | 
10 | import argparse
11 | import os
12 | 
13 | import numpy as np
14 | import random
15 | 
16 | import pprint
17 | import torch
18 | import torch.backends.cudnn as cudnn
19 | import torch.distributed as dist
20 | 
21 | import _init_paths
22 | from configs import cfg
23 | from configs import update_config
24 | 
25 | from libs.datasets.collate import collect
26 | from libs.utils import misc
27 | from libs.utils.utils import create_logger
28 | from libs.utils.utils import get_model
29 | from libs.utils.utils import get_dataset
30 | from libs.utils.utils import get_trainer
31 | from libs.utils.utils import load_checkpoint
32 | 
33 | 
34 | def parse_args():
35 |     parser = argparse.ArgumentParser(description='HOI Transformer Task')
36 |     parser.add_argument(
37 |         '--cfg',
38 |         dest='yaml_file',
39 |         help='experiment configure file name, e.g. configs/hico.yaml',
40 |         required=True,
41 |         type=str)
42 |     # default distributed training
43 |     parser.add_argument(
44 |         '--distributed',
45 |         action='store_true',
46 |         default=False,
47 |         help='whether to use distributed training')
48 | 
49 |     parser.add_argument(
50 |         '--dist-url',
51 |         dest='dist_url',
52 |         default='tcp://10.5.38.36:23456',
53 |         type=str,
54 |         help='url used to set up distributed training')
55 |     parser.add_argument(
56 |         '--world-size',
57 |         dest='world_size',
58 |         default=1,
59 |         type=int,
60 |         help='number of nodes for distributed training')
61 |     parser.add_argument(
62 |         '--rank',
63 |         default=0,
64 |         type=int,
65 |         help='node rank for distributed training, machine level')
66 | 
67 |     parser.add_argument(
68 |         'opts',
69 |         help="Modify config options using the command-line",
70 |         default=None,
71 |         nargs=argparse.REMAINDER)
72 |     args = parser.parse_args()
73 | 
74 |     return args
75 | 
76 | def get_ip(ip_addr):  # build a tcp:// init URL from a SLURM node name such as 'xxx-xxx-10-5-38-36'
77 |     ip_list = ip_addr.split('-')[2:6]
78 |     for i in range(4):
79 |         if ip_list[i][0] == '[':
80 |             ip_list[i] = ip_list[i][1:].split(',')[0]
81 |     return f'tcp://{ip_list[0]}.{ip_list[1]}.{ip_list[2]}.{ip_list[3]}:23456'
82 | 
83 | def main_per_worker():
84 |     args = parse_args()
85 | 
86 |     update_config(cfg, args)
87 |     ngpus_per_node = torch.cuda.device_count()
88 | 
89 |     print(cfg.OUTPUT_ROOT)
90 |     if 'SLURM_PROCID' in os.environ.keys():
91 |         proc_rank = int(os.environ['SLURM_PROCID'])
92 |         local_rank = proc_rank % ngpus_per_node
93 |         args.world_size = int(os.environ['SLURM_NTASKS'])
94 |     else:
95 |         proc_rank = 0
96 |         local_rank = 0
97 |         args.world_size = 1
98 | 
99 |     args.distributed = (args.world_size > 1 or args.distributed)
100 | 
101 |     # create logger
102 |     if proc_rank == 0:
103 |         logger, output_dir = create_logger(cfg, proc_rank)
104 | 
105 |     # distribution
106 |     if args.distributed:
107 |         dist_url = get_ip(os.environ['SLURM_STEP_NODELIST'])
108 |         if proc_rank == 0:
109 |             logger.info(
110 |                 f'Init process group: dist_url: {dist_url}, '
111 |                 f'world_size: {args.world_size}, '
112 |                 f'proc_rank: {proc_rank}, '
113 |                 f'local_rank:{local_rank}'
114 |             )
115 |         dist.init_process_group(
116 |             backend=cfg.DIST_BACKEND,
117 |             init_method=dist_url,
118 |             world_size=args.world_size,
119 |             rank=proc_rank
120 |         )
121 |         torch.distributed.barrier()
122 |         # torch seed
123 |         seed = cfg.SEED + misc.get_rank()
124 |         torch.manual_seed(seed)
125 |         np.random.seed(seed)
126 |         random.seed(seed)
127 |         torch.backends.cudnn.deterministic = True
128 |         torch.backends.cudnn.benchmark = False
129 | 
130 |         torch.cuda.set_device(local_rank)
131 |         device = torch.device(cfg.DEVICE)
132 |         model, criterion, postprocessors = get_model(cfg, device)
133 |         model.to(device)
134 |         model = torch.nn.parallel.DistributedDataParallel(
135 |             model, device_ids=[local_rank], output_device=local_rank,
136 | 
find_unused_parameters=True 137 | ) 138 | train_dataset, eval_dataset = get_dataset(cfg) 139 | train_sampler = torch.utils.data.distributed.DistributedSampler( 140 | train_dataset 141 | ) 142 | batch_size = cfg.DATASET.IMG_NUM_PER_GPU 143 | 144 | else: 145 | assert proc_rank == 0, ('proc_rank != 0, it will influence ' 146 | 'the evaluation procedure') 147 | # torch seed 148 | seed = cfg.SEED 149 | torch.manual_seed(seed) 150 | np.random.seed(seed) 151 | random.seed(seed) 152 | torch.backends.cudnn.deterministic = True 153 | torch.backends.cudnn.benchmark = False 154 | 155 | if cfg.DEVICE == 'cuda': 156 | torch.cuda.set_device(local_rank) 157 | device = torch.device(cfg.DEVICE) 158 | model, criterion, postprocessors = get_model(cfg, device) 159 | model = torch.nn.DataParallel(model).to(device) 160 | train_dataset, eval_dataset = get_dataset(cfg) 161 | train_sampler = None 162 | if ngpus_per_node == 0: 163 | batch_size = cfg.DATASET.IMG_NUM_PER_GPU 164 | else: 165 | batch_size = cfg.DATASET.IMG_NUM_PER_GPU * ngpus_per_node 166 | 167 | model_without_ddp = model.module 168 | n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) 169 | print('number of params:', n_parameters) 170 | 171 | param_dicts = [ 172 | {"params": [p for n, p in model_without_ddp.named_parameters() 173 | if "rel" in n and p.requires_grad]}, 174 | { 175 | "params": [p for n, p in model_without_ddp.named_parameters() 176 | if "rel" not in n and p.requires_grad], 177 | "lr": cfg.TRAIN.LR_BACKBONE, 178 | }, 179 | ] 180 | optimizer = torch.optim.AdamW(param_dicts, lr=cfg.TRAIN.LR, 181 | weight_decay=cfg.TRAIN.WEIGHT_DECAY) 182 | lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, cfg.TRAIN.LR_DROP) 183 | model, optimizer, lr_scheduler, last_iter = load_checkpoint(cfg, model, 184 | optimizer, lr_scheduler, device) 185 | 186 | 187 | train_loader = torch.utils.data.DataLoader( 188 | train_dataset, 189 | batch_size=batch_size, 190 | # shuffle=False, 191 | shuffle=(train_sampler is None), 192 | drop_last=True, 193 | collate_fn=collect, 194 | num_workers=cfg.WORKERS, 195 | pin_memory=True, 196 | sampler=train_sampler 197 | ) 198 | 199 | eval_loader = torch.utils.data.DataLoader( 200 | eval_dataset, 201 | batch_size=batch_size, 202 | shuffle=False, 203 | drop_last=False, 204 | collate_fn=collect, 205 | num_workers=cfg.WORKERS 206 | ) 207 | 208 | Trainer = get_trainer( 209 | cfg, 210 | model, 211 | criterion=criterion, 212 | optimizer=optimizer, 213 | lr_scheduler=lr_scheduler, 214 | postprocessors=postprocessors, 215 | log_dir='output', 216 | performance_indicator='mAP', 217 | last_iter=last_iter, 218 | rank=proc_rank, 219 | device=device, 220 | max_norm=cfg.TRAIN.CLIP_MAX_NORM 221 | ) 222 | 223 | print('start training...') 224 | while True: 225 | Trainer.train(train_loader, eval_loader) 226 | 227 | 228 | if __name__ == '__main__': 229 | main_per_worker() 230 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 tools/train.py --cfg configs/hoia.yaml 2 | --------------------------------------------------------------------------------
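For reference, evaluation can be launched the same way as train.sh, via tools/eval.py shown above. This is a minimal sketch only: the checkpoint path is hypothetical, and passing `MODEL.RESUME_PATH` through the trailing `opts` assumes the yacs-style override performed by `configs.update_config`; the config key itself and the `model_best.pth` filename come from the code above.

```shell
# Sketch: adjust the config file and the (hypothetical) checkpoint path to your setup.
CUDA_VISIBLE_DEVICES=0 python3 tools/eval.py --cfg configs/hico.yaml \
    MODEL.RESUME_PATH output/model_best.pth
```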