├── metrics ├── __init__.py ├── common.py ├── calculate_ap_results.py ├── vrmn_relationship.py └── ap_eval_rel.py ├── models ├── detr_modules │ ├── __init__.py │ ├── position_encoding.py │ ├── matcher.py │ ├── backbone.py │ └── transformer.py ├── __init__.py ├── deformable_detr_modules │ ├── ops │ │ ├── make.sh │ │ ├── modules │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn.py │ │ ├── functions │ │ │ ├── __init__.py │ │ │ └── ms_deform_attn_func.py │ │ ├── src │ │ │ ├── vision.cpp │ │ │ ├── cuda │ │ │ │ ├── ms_deform_attn_cuda.h │ │ │ │ └── ms_deform_attn_cuda.cu │ │ │ ├── cpu │ │ │ │ ├── ms_deform_attn_cpu.h │ │ │ │ └── ms_deform_attn_cpu.cpp │ │ │ └── ms_deform_attn.h │ │ ├── setup.py │ │ └── test.py │ ├── __init__.py │ ├── position_encoding.py │ ├── matcher.py │ └── backbone.py ├── detr_gheads.py ├── graph_transformer_dense.py ├── rcnn_graph.py └── detr.py ├── resources └── model_all.png ├── data ├── __init__.py ├── graph_builder.py ├── metagraspnet_labels.py ├── eval_metagraspnet.py ├── augmentations.py ├── metagraspnet_real_mapper.py └── metagraspnet_synth_mapper.py ├── configs ├── pretrains │ ├── rcnn_pretrain.yaml │ ├── detr_pretrain.yaml │ └── defdetr_pretrain.yaml ├── vrmn.yaml ├── gru_gnn.yaml ├── pair_gnn.yaml ├── detr_graphdense.yaml ├── def_detr_graphdense.yaml ├── detr_base.yaml └── rcnn_base.yaml ├── requirements.txt ├── LICENSE ├── utils ├── configs.py ├── box_ops.py ├── train_utils.py ├── data_utils.py └── vis_utils.py ├── README.md ├── .gitignore └── main.py /metrics/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /models/detr_modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /resources/model_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/paolotron/D3G/HEAD/resources/model_all.png -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .rcnn_vrmn import VMN_Head 2 | from .rcnn_graph import GraphRCNN 3 | from .rcnn_gheads import GraphHead 4 | from .detr import Detr 5 | from .detr_graph import GraphDetr 6 | try: 7 | from .deformable_detr_graph import GraphDeformableDetr 8 | from .deformable_detr import DeformableDetr 9 | except: 10 | print('Failed to load deformable Detr') -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .metagraspnet_synth_mapper import * 2 | from .metagraspnet_real_mapper import * 3 | 4 | def get_mapper(name: str): 5 | 6 | if name.startswith('meta_graspnet_v2_synth'): 7 | mapper = MetaGraspNetV2Mapper 8 | elif name.startswith('meta_graspnet_v2_real'): 9 | mapper = MetaGraspNetV2MapperReal 10 | else: 11 | mapper = MetaGraspNetV2Mapper 12 | return mapper -------------------------------------------------------------------------------- /configs/pretrains/rcnn_pretrain.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../rcnn_base.yaml" 2 | 3 | MODEL: 4 | META_ARCHITECTURE: GeneralizedRCNN 5 | ROI_HEADS: 6 | NUM_CLASSES: 97 7 | # PIXEL_STD: [57.375, 57.120, 58.395] 8 | PIXEL_STD: [1, 1, 1] 9 | 10 | 
SOLVER: 11 | IMS_PER_BATCH: 16 12 | MAX_ITER: 60000 13 | BASE_LR: 0.001 14 | LR_SCHEDULER_NAME: WarmupCosineLR 15 | BASE_LR_END: 0.0000001 16 | AMP: 17 | ENABLED: True 18 | 19 | INPUT: 20 | DEP_GRAPH: False -------------------------------------------------------------------------------- /metrics/common.py: -------------------------------------------------------------------------------- 1 | 2 | def filter_graph(instances, graph, thresh=0.5): 3 | if instances['scores'].shape[0] <= 1: 4 | return graph 5 | keep = instances['scores'] > thresh 6 | rel, ix = graph 7 | keep = keep[ix[0]] & keep[ix[1]] 8 | keep = keep.cpu().numpy() 9 | rel = rel[keep] 10 | ix = ix[:, keep] 11 | return rel, ix 12 | 13 | def filter_boxes(instances, thresh=0.5): 14 | keep = instances['scores'] > thresh 15 | instances = {k: v[keep] for k, v in instances.items()} 16 | return instances -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name <env> --file <this file> 3 | # platform: linux-64 4 | 5 | beautifulsoup4 6 | detectron2 7 | easydict 8 | einops 9 | fvcore 10 | h5py 11 | huggingface-hub 12 | lovely-numpy 13 | lovely-tensors 14 | matplotlib 15 | networkx 16 | numpy 17 | pandas 18 | pycocotools 19 | setuptools 20 | tensorboard 21 | timm 22 | tokenizers 23 | tqdm 24 | transformers 25 | torch==2.2.0+cu118 26 | torch-geometric 27 | torchdata 28 | torchmetrics 29 | torchvision -------------------------------------------------------------------------------- /configs/vrmn.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "rcnn_base.yaml" 2 | 3 | MODEL: 4 | WEIGHTS: ./checkpoints/pretrain_metagraspnetv2_maskrcnn.pth 5 | META_ARCHITECTURE: GraphRCNN 6 | ROI_HEADS: 7 | NAME: VMN_Head 8 | NUM_CLASSES: 98 9 | # PIXEL_STD: [57.375, 57.120, 58.395] 10 | PIXEL_STD: [1, 1, 1] 11 | 12 | SOLVER: 13 | IMS_PER_BATCH: 32 14 | MAX_ITER: 30000 15 | BASE_LR: 0.0005 16 | LR_SCHEDULER_NAME: WarmupCosineLR 17 | BASE_LR_END: 0.0000001 18 | AMP: 19 | ENABLED: True 20 | 21 | INPUT: 22 | GRAPH_GT_TYPE: classification 23 | AUGMENT: v3 -------------------------------------------------------------------------------- /configs/gru_gnn.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "rcnn_base.yaml" 2 | 3 | MODEL: 4 | WEIGHTS: ./checkpoints/pretrain_metagraspnetv2_maskrcnn.pth 5 | META_ARCHITECTURE: GraphRCNN 6 | ROI_HEADS: 7 | NAME: GraphHeadGru 8 | NUM_CLASSES: 98 9 | # PIXEL_STD: [57.375, 57.120, 58.395] 10 | PIXEL_STD: [1, 1, 1] 11 | 12 | DATALOADER: 13 | NUM_WORKERS: 20 14 | 15 | SOLVER: 16 | IMS_PER_BATCH: 32 17 | MAX_ITER: 30000 18 | BASE_LR: 0.0005 19 | LR_SCHEDULER_NAME: WarmupCosineLR 20 | BASE_LR_END: 0.0000001 21 | AMP: 22 | ENABLED: True 23 | # REFERENCE_WORLD_SIZE : 1 24 | 25 | 26 | INPUT: 27 | GRAPH_GT_TYPE: gru_graph 28 | AUGMENT: v3 29 | 30 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/make.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # ------------------------------------------------------------------------------------------------ 3 | # Deformable DETR 4 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
5 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | # ------------------------------------------------------------------------------------------------ 7 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | # ------------------------------------------------------------------------------------------------ 9 | 10 | python setup.py build install 11 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn import MSDeformAttn 10 | -------------------------------------------------------------------------------- /configs/pair_gnn.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "rcnn_base.yaml" 2 | 3 | 4 | MODEL: 5 | WEIGHTS: ./checkpoints/pretrain_metagraspnetv2_maskrcnn.pth 6 | META_ARCHITECTURE: GraphRCNN 7 | ROI_HEADS: 8 | NAME: GraphHead 9 | NUM_CLASSES: 98 10 | # PIXEL_STD: [57.375, 57.120, 58.395] 11 | PIXEL_STD: [1, 1, 1] 12 | 13 | DATALOADER: 14 | NUM_WORKERS: 20 15 | 16 | SOLVER: 17 | IMS_PER_BATCH: 32 18 | MAX_ITER: 30000 19 | BASE_LR: 0.0005 20 | LR_SCHEDULER_NAME: WarmupCosineLR 21 | BASE_LR_END: 0.0000001 22 | AMP: 23 | ENABLED: True 24 | # REFERENCE_WORLD_SIZE : 1 25 | 26 | 27 | INPUT: 28 | GRAPH_GT_TYPE: relation_graph 29 | AUGMENT: v3 30 | 31 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | from .deformable_detr import build 11 | 12 | def build_model(args): 13 | return build(args) 14 | 15 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/functions/__init__.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved.
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from .ms_deform_attn_func import MSDeformAttnFunction 10 | 11 | -------------------------------------------------------------------------------- /configs/pretrains/detr_pretrain.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: '../detr_base.yaml' 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "Detr" 5 | WEIGHTS: "./checkpoints/detr-r50-dc5.pth" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | MASK_ON: False 9 | DETR: 10 | NUM_CLASSES: 98 11 | 12 | SOLVER: 13 | IMS_PER_BATCH: 32 14 | BASE_LR: 0.0006 15 | BASE_LR_END: 0.00000001 16 | MAX_ITER: 60000 17 | WARMUP_FACTOR: 1.0 18 | WARMUP_ITERS: 1000 19 | LR_SCHEDULER_NAME: WarmupCosineLR 20 | WEIGHT_DECAY: 0.0001 21 | OPTIMIZER: "ADAMW" 22 | BACKBONE_MULTIPLIER: 0.1 23 | CLIP_GRADIENTS: 24 | ENABLED: True 25 | CLIP_TYPE: "full_model" 26 | CLIP_VALUE: 0.01 27 | NORM_TYPE: 2.0 28 | AMP: 29 | ENABLED: True 30 | 31 | INPUT: 32 | DEP_GRAPH: False -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/src/vision.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include "ms_deform_attn.h" 12 | 13 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 14 | m.def("ms_deform_attn_forward", &ms_deform_attn_forward, "ms_deform_attn_forward"); 15 | m.def("ms_deform_attn_backward", &ms_deform_attn_backward, "ms_deform_attn_backward"); 16 | } 17 | -------------------------------------------------------------------------------- /configs/detr_graphdense.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: 'detr_base.yaml' 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "GraphDetr" 5 | WEIGHTS: "./checkpoints/pretrain_metagraspnetv2_detr.pth" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | MASK_ON: False 9 | DETR: 10 | NUM_CLASSES: 98 11 | 12 | GRAPH_HEAD: 13 | NAME: 'GraphTransformerDense' 14 | EDGE_FEATURES: 'concat' 15 | HIDDEN_DIM: 256 16 | NUM_HEADS: 2 17 | NUM_LAYERS: 2 18 | 19 | 20 | SOLVER: 21 | IMS_PER_BATCH: 32 22 | BASE_LR: 0.0005 23 | BASE_LR_END: 0.00000001 24 | MAX_ITER: 30000 25 | WARMUP_FACTOR: 1.0 26 | WARMUP_ITERS: 100 27 | LR_SCHEDULER_NAME: WarmupCosineLR 28 | WEIGHT_DECAY: 0.0001 29 | OPTIMIZER: "ADAMW" 30 | BACKBONE_MULTIPLIER: 0.1 31 | CLIP_GRADIENTS: 32 | ENABLED: True 33 | CLIP_TYPE: "full_model" 34 | CLIP_VALUE: 0.01 35 | NORM_TYPE: 2.0 36 | AMP: 37 | ENABLED: True 38 | 39 | INPUT: 40 | GRAPH_GT_TYPE: "dense" 41 | AUGMENT: v3 -------------------------------------------------------------------------------- /configs/pretrains/defdetr_pretrain.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: '../detr_base.yaml' 2 | 3 | MODEL: 4 | META_ARCHITECTURE: "DeformableDetr" 5 | WEIGHTS: "./checkpoints/r50_deformable_detr-checkpoint.pth" 6 | PIXEL_MEAN: [123.675, 116.280, 103.530] 7 | PIXEL_STD: [58.395, 57.120, 57.375] 8 | MASK_ON: False 9 | 10 | DETR: 11 | NUM_CLASSES: 98 12 | DEC_N_POINTS: 4 13 | ENC_N_POINTS: 4 14 | NHEADS: 8 15 | TWO_STAGE: False 16 | NUM_FEATURE_LEVELS: 4 17 | HIDDEN_DIM: 256 18 | DIM_FEEDFORWARD: 1024 19 | DROPOUT: 0.1 20 | NUM_OBJECT_QUERIES: 100 21 | 22 | 23 | SOLVER: 24 | IMS_PER_BATCH: 32 25 | BASE_LR: 0.0003 26 | BASE_LR_END: 0.00000001 27 | MAX_ITER: 60000 28 | WARMUP_FACTOR: 1.0 29 | WARMUP_ITERS: 5000 30 | LR_SCHEDULER_NAME: WarmupCosineLR 31 | WEIGHT_DECAY: 0.0001 32 | OPTIMIZER: "ADAMW" 33 | BACKBONE_MULTIPLIER: 0.1 34 | CLIP_GRADIENTS: 35 | ENABLED: True 36 | CLIP_TYPE: "full_model" 37 | CLIP_VALUE: 0.01 38 | NORM_TYPE: 2.0 39 | AMP: 40 | ENABLED: True 41 | 42 | INPUT: 43 | DEP_GRAPH: False -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Paolo Rabino 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the 
Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/def_detr_graphdense.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: 'detr_base.yaml' 2 | 3 | 4 | MODEL: 5 | META_ARCHITECTURE: "GraphDeformableDetr" 6 | WEIGHTS: "./checkpoints/pretrain_metagraspnetv2_dfedetr.pth" 7 | PIXEL_MEAN: [123.675, 116.280, 103.530] 8 | PIXEL_STD: [58.395, 57.120, 57.375] 9 | MASK_ON: False 10 | 11 | DETR: 12 | NUM_CLASSES: 98 13 | DEC_N_POINTS: 4 14 | ENC_N_POINTS: 4 15 | NHEADS: 8 16 | TWO_STAGE: False 17 | NUM_FEATURE_LEVELS: 4 18 | HIDDEN_DIM: 256 19 | DIM_FEEDFORWARD: 1024 20 | DROPOUT: 0.1 21 | NUM_OBJECT_QUERIES: 100 22 | 23 | GRAPH_HEAD: 24 | NAME: 'GraphTransformerDense' 25 | EDGE_FEATURES: 'concat' 26 | HIDDEN_DIM: 256 27 | NUM_HEADS: 2 28 | NUM_LAYERS: 2 29 | 30 | 31 | SOLVER: 32 | IMS_PER_BATCH: 32 33 | BASE_LR: 0.0005 34 | BASE_LR_END: 0.00000001 35 | MAX_ITER: 30000 36 | WARMUP_FACTOR: 1.0 37 | WARMUP_ITERS: 100 38 | LR_SCHEDULER_NAME: WarmupCosineLR 39 | WEIGHT_DECAY: 0.0001 40 | OPTIMIZER: "ADAMW" 41 | BACKBONE_MULTIPLIER: 0.1 42 | CLIP_GRADIENTS: 43 | ENABLED: True 44 | CLIP_TYPE: "full_model" 45 | CLIP_VALUE: 0.01 46 | NORM_TYPE: 2.0 47 | AMP: 48 | ENABLED: True 49 | 50 | INPUT: 51 | GRAPH_GT_TYPE: "dense" 52 | AUGMENT: v3 -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/src/cuda/ms_deform_attn_cuda.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor ms_deform_attn_cuda_forward( 15 | const at::Tensor &value, 16 | const at::Tensor &spatial_shapes, 17 | const at::Tensor &level_start_index, 18 | const at::Tensor &sampling_loc, 19 | const at::Tensor &attn_weight, 20 | const int im2col_step); 21 | 22 | std::vector ms_deform_attn_cuda_backward( 23 | const at::Tensor &value, 24 | const at::Tensor &spatial_shapes, 25 | const at::Tensor &level_start_index, 26 | const at::Tensor &sampling_loc, 27 | const at::Tensor &attn_weight, 28 | const at::Tensor &grad_output, 29 | const int im2col_step); 30 | 31 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/src/cpu/ms_deform_attn_cpu.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | #include 13 | 14 | at::Tensor 15 | ms_deform_attn_cpu_forward( 16 | const at::Tensor &value, 17 | const at::Tensor &spatial_shapes, 18 | const at::Tensor &level_start_index, 19 | const at::Tensor &sampling_loc, 20 | const at::Tensor &attn_weight, 21 | const int im2col_step); 22 | 23 | std::vector 24 | ms_deform_attn_cpu_backward( 25 | const at::Tensor &value, 26 | const at::Tensor &spatial_shapes, 27 | const at::Tensor &level_start_index, 28 | const at::Tensor &sampling_loc, 29 | const at::Tensor &attn_weight, 30 | const at::Tensor &grad_output, 31 | const int im2col_step); 32 | 33 | 34 | -------------------------------------------------------------------------------- /configs/detr_base.yaml: -------------------------------------------------------------------------------- 1 | DATASETS: 2 | ROOT: './datasets' 3 | TRAIN: ("meta_graspnet_v2_synth_train", "meta_graspnet_v2_synth_eval") 4 | EVAL: ("meta_graspnet_v2_synth_test_hard",) 5 | TEST: ("meta_graspnet_v2_synth_test_hard", "meta_graspnet_v2_synth_test_easy", "meta_graspnet_v2_synth_test_medium", "meta_graspnet_v2_real_test") 6 | 7 | INPUT: 8 | DEP_GRAPH: True 9 | CLS_GT: True 10 | OBJ_DET: True 11 | 12 | TEST: 13 | EVAL_PERIOD: 1000 14 | 15 | MODEL: 16 | META_ARCHITECTURE: "Detr" 17 | PIXEL_MEAN: [123.675, 116.280, 103.530] 18 | PIXEL_STD: [58.395, 57.120, 57.375] 19 | MASK_ON: True 20 | RESNETS: 21 | DEPTH: 50 22 | STRIDE_IN_1X1: False 23 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 24 | DETR: 25 | GIOU_WEIGHT: 2.0 26 | L1_WEIGHT: 5.0 27 | NUM_OBJECT_QUERIES: 100 28 | FROZEN_WEIGHTS: '' 29 | 30 | DATALOADER: 31 | NUM_WORKERS: 20 32 | 33 | SOLVER: 34 | IMS_PER_BATCH: 64 35 | BASE_LR: 0.0001 36 | STEPS: (55440,) 37 | 
MAX_ITER: 92400 38 | WARMUP_FACTOR: 1.0 39 | WARMUP_ITERS: 10 40 | WEIGHT_DECAY: 0.0001 41 | OPTIMIZER: "ADAMW" 42 | BACKBONE_MULTIPLIER: 0.1 43 | CLIP_GRADIENTS: 44 | ENABLED: True 45 | CLIP_TYPE: "full_model" 46 | CLIP_VALUE: 0.01 47 | NORM_TYPE: 2.0 48 | INPUT: 49 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 50 | CROP: 51 | ENABLED: True 52 | TYPE: "absolute_range" 53 | SIZE: (384, 600) 54 | FORMAT: "RGB" 55 | 56 | VERSION: 2 -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/src/cpu/ms_deform_attn_cpu.cpp: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | 13 | #include 14 | #include 15 | 16 | 17 | at::Tensor 18 | ms_deform_attn_cpu_forward( 19 | const at::Tensor &value, 20 | const at::Tensor &spatial_shapes, 21 | const at::Tensor &level_start_index, 22 | const at::Tensor &sampling_loc, 23 | const at::Tensor &attn_weight, 24 | const int im2col_step) 25 | { 26 | AT_ERROR("Not implement on cpu"); 27 | } 28 | 29 | std::vector 30 | ms_deform_attn_cpu_backward( 31 | const at::Tensor &value, 32 | const at::Tensor &spatial_shapes, 33 | const at::Tensor &level_start_index, 34 | const at::Tensor &sampling_loc, 35 | const at::Tensor &attn_weight, 36 | const at::Tensor &grad_output, 37 | const int im2col_step) 38 | { 39 | AT_ERROR("Not implement on cpu"); 40 | } 41 | 42 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/src/ms_deform_attn.h: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 
5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #pragma once 12 | 13 | #include "cpu/ms_deform_attn_cpu.h" 14 | 15 | #ifdef WITH_CUDA 16 | #include "cuda/ms_deform_attn_cuda.h" 17 | #endif 18 | 19 | 20 | at::Tensor 21 | ms_deform_attn_forward( 22 | const at::Tensor &value, 23 | const at::Tensor &spatial_shapes, 24 | const at::Tensor &level_start_index, 25 | const at::Tensor &sampling_loc, 26 | const at::Tensor &attn_weight, 27 | const int im2col_step) 28 | { 29 | if (value.type().is_cuda()) 30 | { 31 | #ifdef WITH_CUDA 32 | return ms_deform_attn_cuda_forward( 33 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, im2col_step); 34 | #else 35 | AT_ERROR("Not compiled with GPU support"); 36 | #endif 37 | } 38 | AT_ERROR("Not implemented on the CPU"); 39 | } 40 | 41 | std::vector 42 | ms_deform_attn_backward( 43 | const at::Tensor &value, 44 | const at::Tensor &spatial_shapes, 45 | const at::Tensor &level_start_index, 46 | const at::Tensor &sampling_loc, 47 | const at::Tensor &attn_weight, 48 | const at::Tensor &grad_output, 49 | const int im2col_step) 50 | { 51 | if (value.type().is_cuda()) 52 | { 53 | #ifdef WITH_CUDA 54 | return ms_deform_attn_cuda_backward( 55 | value, spatial_shapes, level_start_index, sampling_loc, attn_weight, grad_output, im2col_step); 56 | #else 57 | AT_ERROR("Not compiled with GPU support"); 58 | #endif 59 | } 60 | AT_ERROR("Not implemented on the CPU"); 61 | } 62 | 63 | -------------------------------------------------------------------------------- /configs/rcnn_base.yaml: -------------------------------------------------------------------------------- 1 | 2 | DATASETS: 3 | ROOT: './datasets' 4 | TRAIN: ("meta_graspnet_v2_synth_train", "meta_graspnet_v2_synth_eval") 5 | EVAL: ("meta_graspnet_v2_synth_test_hard",) 6 | TEST: ("meta_graspnet_v2_synth_test_hard", "meta_graspnet_v2_synth_test_easy", "meta_graspnet_v2_synth_test_medium", "meta_graspnet_v2_real_test") 7 | 8 | INPUT: 9 | DEP_GRAPH: True 10 | CLS_GT: True 11 | OBJ_DET: True 12 | 13 | TEST: 14 | EVAL_PERIOD: 1000 15 | DETECTIONS_PER_IMAGE: 50 16 | 17 | DATALOADER: 18 | NUM_WORKERS: 20 19 | 20 | 21 | MODEL: 22 | WEIGHTS: "https://dl.fbaipublicfiles.com/detectron2/COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x/137260431/model_final_a54504.pkl" 23 | MASK_ON: True 24 | META_ARCHITECTURE: "GeneralizedRCNN" 25 | BACKBONE: 26 | NAME: "build_resnet_fpn_backbone" 27 | RESNETS: 28 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 29 | DEPTH: 50 30 | FPN: 31 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 32 | ANCHOR_GENERATOR: 33 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 34 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 35 | RPN: 36 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 37 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 38 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 39 | # Detectron1 uses 2000 proposals per-batch, 40 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 41 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
42 | POST_NMS_TOPK_TRAIN: 1000 43 | POST_NMS_TOPK_TEST: 1000 44 | ROI_HEADS: 45 | NAME: "StandardROIHeads" 46 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 47 | ROI_BOX_HEAD: 48 | NAME: "FastRCNNConvFCHead" 49 | NUM_FC: 2 50 | POOLER_RESOLUTION: 7 51 | ROI_MASK_HEAD: 52 | NAME: "MaskRCNNConvUpsampleHead" 53 | NUM_CONV: 4 54 | POOLER_RESOLUTION: 14 55 | 56 | SOLVER: 57 | IMS_PER_BATCH: 128 58 | MAX_ITER: 30000 59 | OPTIMIZER: ADAM 60 | BASE_LR: 0.001 61 | LR_SCHEDULER_NAME: WarmupCosineLR 62 | BASE_LR_END: 0.00001 63 | AMP: 64 | ENABLED: True 65 | 66 | VERSION: 2 -------------------------------------------------------------------------------- /data/graph_builder.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch_geometric as tg 3 | import torch_geometric.data as data 4 | import numpy as np 5 | import torch 6 | 7 | def objgraph_to_objrelgraph(*, obj_graph:Optional[data.Data] = None , num_objs:Optional[int]=None): 8 | 9 | if not ((obj_graph is None) ^ (num_objs is None)): 10 | raise TypeError("Only one argument can be specified") 11 | 12 | n_nodes = obj_graph.num_nodes if obj_graph is not None else num_objs 13 | n_rel = int(n_nodes * (n_nodes - 1) / 2) 14 | total_nodes = n_nodes + n_rel * 2 15 | is_object = torch.from_numpy(np.array([True] * n_nodes + [False] * n_rel * 2)) 16 | 17 | rel_i, rel_j = np.triu_indices(n_nodes, 1) 18 | rel_i, rel_j = np.concatenate([rel_i, rel_j]), np.concatenate([rel_j, rel_i]) 19 | 20 | roi_indices = np.triu_indices(n_nodes, 1) 21 | # indices = list(range(n_nodes)) 22 | # index_to_feat = {frozenset(i): ix for ix, i in enumerate(zip(*roi_indices))} 23 | index_to_node = {i : ((i, i))for i in range(n_nodes)} 24 | index_to_node = {**{j: ((rel_i[i], rel_j[i])) for i, j in enumerate(range(n_nodes, total_nodes))}, **index_to_node} 25 | node_to_index = {v: k for k, v in index_to_node.items()} 26 | edge_to_feat = {**{frozenset(i): ix+n_nodes for ix, i in enumerate(zip(*roi_indices))}, 27 | **{frozenset([i]): i for i in range(n_nodes)}} 28 | feat_indices = torch.tensor([edge_to_feat[frozenset(index_to_node[i])] for i in range(len(index_to_node))]) 29 | edges = [] 30 | for (start, dest), index in node_to_index.items(): 31 | if start == dest: 32 | edges.append((index, index)) 33 | else: 34 | edges.append((start, index)) 35 | edges.append((index, dest)) 36 | edges = torch.Tensor(edges) 37 | rel_graph = data.Data(num_nodes=total_nodes, edge_index=edges.T.long(), is_object=is_object, index_to_node=feat_indices) 38 | if obj_graph is not None: 39 | rel_gt = torch.zeros(rel_graph.num_nodes) 40 | if obj_graph.edge_index is not None: 41 | for edge in obj_graph.edge_index.T: 42 | edge = tuple(edge.long().tolist()) 43 | rel_gt[node_to_index[edge]] = 1 44 | rel_gt[node_to_index[tuple(reversed(edge))]] = 2 45 | rel_graph.rel_gt = rel_gt 46 | return rel_graph 47 | 48 | -------------------------------------------------------------------------------- /utils/configs.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 3 | from detectron2.config import CfgNode as CN 4 | 5 | 6 | def add_dep_graph_config(cfg): 7 | """ 8 | Add generic dependency config 9 | """ 10 | cfg.INPUT.RGB = True 11 | cfg.INPUT.DEPTH = False 12 | 13 | cfg.INPUT.DEP_GRAPH = False 14 | cfg.INPUT.OBJ_DET = False 15 | cfg.INPUT.CLS_GT = False 16 | cfg.INPUT.GRAPH_GT_TYPE = 'classification' 17 | cfg.INPUT.AUGMENT = 'default' 18 | cfg.DATASETS.ROOT = './datasets' 19 | 20 | cfg.SOLVER.OPTIMIZER = "ADAM" 21 | cfg.SOLVER.BACKBONE_MULTIPLIER = 1 22 | cfg.SOLVER.GRAD_STEP = 1 23 | 24 | cfg.OUTPUT_DIR = '' 25 | 26 | cfg.TEST.GRAPH_THRESH = 0.5 27 | 28 | cfg.DATASETS.EVAL = [] 29 | 30 | return cfg 31 | 32 | 33 | 34 | def add_detr_config(cfg): 35 | """ 36 | Add config for DETR. 37 | """ 38 | cfg.MODEL.DETR = CN() 39 | cfg.MODEL.DETR.NUM_CLASSES = 80 40 | 41 | # For Segmentation 42 | cfg.MODEL.DETR.FROZEN_WEIGHTS = '' 43 | 44 | # LOSS 45 | cfg.MODEL.DETR.GIOU_WEIGHT = 2.0 46 | cfg.MODEL.DETR.L1_WEIGHT = 5.0 47 | cfg.MODEL.DETR.DEEP_SUPERVISION = True 48 | cfg.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1 49 | 50 | # TRANSFORMER 51 | cfg.MODEL.DETR.NHEADS = 8 52 | cfg.MODEL.DETR.DROPOUT = 0.1 53 | cfg.MODEL.DETR.DIM_FEEDFORWARD = 2048 54 | cfg.MODEL.DETR.ENC_LAYERS = 6 55 | cfg.MODEL.DETR.DEC_LAYERS = 6 56 | cfg.MODEL.DETR.PRE_NORM = False 57 | 58 | cfg.MODEL.DETR.HIDDEN_DIM = 256 59 | cfg.MODEL.DETR.NUM_OBJECT_QUERIES = 100 60 | 61 | cfg.MODEL.GRAPH_HEAD = CN() 62 | cfg.MODEL.GRAPH_HEAD.NAME = 'StandardHead' 63 | cfg.MODEL.GRAPH_HEAD.HIDDEN_DIM = 256 64 | cfg.MODEL.GRAPH_HEAD.NUM_HEADS = 1 65 | cfg.MODEL.GRAPH_HEAD.NUM_LAYERS = 1 66 | cfg.MODEL.GRAPH_HEAD.EDGE_FEATURES = 'constant_zero' 67 | cfg.MODEL.DETR.GRAPH_CRITEREON = 'cross' 68 | cfg.MODEL.DETR.GRAPH_WEIGHT = 1.0 69 | cfg.MODEL.DETR.FINETUNE_GHEAD = False 70 | 71 | cfg.SOLVER.OPTIMIZER = "ADAMW" 72 | cfg.SOLVER.BACKBONE_MULTIPLIER = 0.1 73 | 74 | # Deformable Params 75 | cfg.MODEL.DETR.DEC_N_POINTS = 4 76 | cfg.MODEL.DETR.ENC_N_POINTS = 4 77 | cfg.MODEL.DETR.TWO_STAGE = False 78 | cfg.MODEL.DETR.NUM_FEATURE_LEVELS = 4 79 | cfg.MODEL.DETR.POS_EMBEDDING = 'sine' 80 | 81 | return cfg -------------------------------------------------------------------------------- /metrics/calculate_ap_results.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def pack_instance(pred_bboxes, pred_graph, pred_scores, pred_labels, gt_bbox, gt_graph, gt_label): 5 | 6 | if pred_bboxes.shape[0] <= 1: 7 | pred_graph = np.zeros((0, 3), dtype=bool) 8 | 9 | pred_indices = np.vstack(np.triu_indices(pred_bboxes.shape[0], 1)) 10 | pred_indices = pred_indices * pred_graph[:, 1] + pred_indices[[1, 0]] * pred_graph[:, 2] * ~pred_graph[:, 1] 11 | pred_indices = pred_indices[:, ~pred_graph[:, 0]] 12 | prd_scores = pred_graph.astype(np.int32)[~pred_graph[:, 0]] 13 | prd_scores = np.ones((prd_scores.shape[0], 2)) 14 | 15 | gt_indices = np.vstack(np.triu_indices(gt_bbox.shape[0], 1)) 16 | gt_indices = gt_indices * (gt_graph == 1).astype(np.int32) + gt_indices[[1, 0]] * (gt_graph == 2).astype(np.int32) 17 | gt_indices = gt_indices[:, gt_graph != 0] 18 | prd_gt_classes = gt_graph[gt_graph != 0] 19 | prd_gt_classes = np.ones((prd_gt_classes.shape[0])) 20 | 21 | subj_bbox = pred_bboxes[pred_indices[0]] 22 | sbj_scores = pred_scores[pred_indices[0]] 23 | subj_label = pred_labels[pred_indices[0]] 24 | 25 | obj_boxes = pred_bboxes[pred_indices[1]] 26 | obj_scores = pred_scores[pred_indices[1]] 27 | obj_labels = pred_labels[pred_indices[1]] 28 | 29 | sbj_gt_boxes = 
gt_bbox[gt_indices[0]] 30 | sbj_gt_classes = gt_label[gt_indices[0]] 31 | 32 | obj_gt_boxes = gt_bbox[gt_indices[1]] 33 | obj_gt_classes = gt_label[gt_indices[1]] 34 | 35 | 36 | packed_data = dict( 37 | sbj_boxes=subj_bbox, # N 4 box ? V 38 | sbj_labels=subj_label.astype(np.int32, copy=False), 39 | sbj_scores=sbj_scores, 40 | obj_boxes=obj_boxes, # N 4 box ? 41 | obj_labels=obj_labels.astype(np.int32, copy=False), 42 | obj_scores=obj_scores, 43 | prd_scores=prd_scores, 44 | gt_sbj_boxes=sbj_gt_boxes, 45 | gt_obj_boxes=obj_gt_boxes, 46 | gt_sbj_labels=sbj_gt_classes.astype(np.int32, copy=False), 47 | gt_obj_labels=obj_gt_classes.astype(np.int32, copy=False), 48 | gt_prd_labels=prd_gt_classes.astype(np.int32, copy=False)) 49 | return packed_data 50 | 51 | def pack_data(local_res): 52 | packed_data = [] 53 | for ix in range(len(local_res['pred_bbox'])): 54 | pred_bboxes = local_res['pred_bbox'][ix] 55 | pred_graph = local_res['pred_graph'][ix] 56 | pred_scores = local_res['pred_scores'][ix] 57 | pred_labels = local_res['pred_classes'][ix] 58 | gt_bbox = local_res['gt_bbox'][ix] 59 | gt_graph = local_res['gt_graph'][ix] 60 | gt_label = local_res['gt_class'][ix] 61 | 62 | inst = pack_instance(pred_bboxes, pred_graph, pred_scores, pred_labels, gt_bbox, gt_graph, gt_label) 63 | packed_data.append(inst) 64 | 65 | return packed_data 66 | 67 | 68 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/setup.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | import os 10 | import glob 11 | 12 | import torch 13 | 14 | from torch.utils.cpp_extension import CUDA_HOME 15 | from torch.utils.cpp_extension import CppExtension 16 | from torch.utils.cpp_extension import CUDAExtension 17 | 18 | from setuptools import find_packages 19 | from setuptools import setup 20 | 21 | requirements = ["torch", "torchvision"] 22 | 23 | def get_extensions(): 24 | this_dir = os.path.dirname(os.path.abspath(__file__)) 25 | extensions_dir = os.path.join(this_dir, "src") 26 | 27 | main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) 28 | source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp")) 29 | source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu")) 30 | 31 | sources = main_file + source_cpu 32 | extension = CppExtension 33 | extra_compile_args = {"cxx": []} 34 | define_macros = [] 35 | 36 | if torch.cuda.is_available() and CUDA_HOME is not None: 37 | extension = CUDAExtension 38 | sources += source_cuda 39 | define_macros += [("WITH_CUDA", None)] 40 | extra_compile_args["nvcc"] = [ 41 | "-DCUDA_HAS_FP16=1", 42 | "-D__CUDA_NO_HALF_OPERATORS__", 43 | "-D__CUDA_NO_HALF_CONVERSIONS__", 44 | "-D__CUDA_NO_HALF2_OPERATORS__", 45 | ] 46 | else: 47 | raise NotImplementedError('Cuda is not availabel') 48 | 49 | sources = [os.path.join(extensions_dir, s) for s in sources] 50 | include_dirs = [extensions_dir] 51 | ext_modules = [ 52 | 
extension( 53 | "MultiScaleDeformableAttention", 54 | sources, 55 | include_dirs=include_dirs, 56 | define_macros=define_macros, 57 | extra_compile_args=extra_compile_args, 58 | ) 59 | ] 60 | return ext_modules 61 | 62 | setup( 63 | name="MultiScaleDeformableAttention", 64 | version="1.0", 65 | author="Weijie Su", 66 | url="https://github.com/fundamentalvision/Deformable-DETR", 67 | description="PyTorch Wrapper for CUDA Functions of Multi-Scale Deformable Attention", 68 | packages=find_packages(exclude=("configs", "tests",)), 69 | ext_modules=get_extensions(), 70 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 71 | ) 72 | -------------------------------------------------------------------------------- /utils/box_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Utilities for bounding box manipulation and GIoU. 4 | """ 5 | import torch 6 | from torchvision.ops.boxes import box_area 7 | 8 | 9 | def box_cxcywh_to_xyxy(x): 10 | x_c, y_c, w, h = x.unbind(-1) 11 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 12 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 13 | return torch.stack(b, dim=-1) 14 | 15 | 16 | def box_xyxy_to_cxcywh(x): 17 | x0, y0, x1, y1 = x.unbind(-1) 18 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 19 | (x1 - x0), (y1 - y0)] 20 | return torch.stack(b, dim=-1) 21 | 22 | 23 | # modified from torchvision to also return the union 24 | def box_iou(boxes1, boxes2): 25 | area1 = box_area(boxes1) 26 | area2 = box_area(boxes2) 27 | 28 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 29 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 30 | 31 | wh = (rb - lt).clamp(min=0) # [N,M,2] 32 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 33 | 34 | union = area1[:, None] + area2 - inter 35 | 36 | iou = inter / union 37 | return iou, union 38 | 39 | 40 | def generalized_box_iou(boxes1, boxes2): 41 | """ 42 | Generalized IoU from https://giou.stanford.edu/ 43 | 44 | The boxes should be in [x0, y0, x1, y1] format 45 | 46 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 47 | and M = len(boxes2) 48 | """ 49 | # degenerate boxes gives inf / nan results 50 | # so do an early check 51 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 52 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 53 | iou, union = box_iou(boxes1, boxes2) 54 | 55 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 56 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 57 | 58 | wh = (rb - lt).clamp(min=0) # [N,M,2] 59 | area = wh[:, :, 0] * wh[:, :, 1] 60 | 61 | return iou - (area - union) / area 62 | 63 | 64 | def masks_to_boxes(masks): 65 | """Compute the bounding boxes around the provided masks 66 | 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions. 
68 | 69 | Returns a [N, 4] tensors, with the boxes in xyxy format 70 | """ 71 | if masks.numel() == 0: 72 | return torch.zeros((0, 4), device=masks.device) 73 | 74 | h, w = masks.shape[-2:] 75 | 76 | y = torch.arange(0, h, dtype=torch.float) 77 | x = torch.arange(0, w, dtype=torch.float) 78 | y, x = torch.meshgrid(y, x) 79 | 80 | x_mask = (masks * x.unsqueeze(0)) 81 | x_max = x_mask.flatten(1).max(-1)[0] 82 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 83 | 84 | y_mask = (masks * y.unsqueeze(0)) 85 | y_max = y_mask.flatten(1).max(-1)[0] 86 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 87 | 88 | return torch.stack([x_min, y_min, x_max, y_max], 1) 89 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # D3G 2 | This is the code for the paper "A Modern Take on Visual Relationship Reasoning for Grasp Planning" 3 | ![model picture](./resources/model_all.png) 4 | 5 | ## D3GD Data 6 | We base our test-bed on the MetaGraspNetV2 dataset, download from [here](https://github.com/maximiliangilles/MetaGraspNet) the **MGN-Sim** and **MGN-Real** data and put both of them in the same folder. 7 | 8 | We also provide splits and compressed meta-data for our testbed [here](https://drive.google.com/drive/folders/1e9_Oa05Cdt5K4aa3rRRf__t5l5ozUeZf?usp=drive_link), download all files and put them in the ```d3g/data/``` folder 9 | 10 | ## Installation 11 | For installing the environment create a fresh python env and run the following command 12 | ``` 13 | pip install -r requirements.txt 14 | ``` 15 | If you want to run deformable detr models run the following 16 | 17 | ``` 18 | cd ./models/deformable_detr_modules/ops 19 | sh ./make.sh 20 | python ./test.py 21 | ``` 22 | 23 | ## Run experiments and define new ones 24 | We leverage [detectron2](https://detectron2.readthedocs.io/en/latest/) config systems to define experiments and models, to launch one simply run 25 | ``` 26 | python main.py --config-file configs/config.yaml --data-path /your/data/path 27 | ``` 28 | create a new .yaml file in the config folder to create new experiments 29 | 30 | ### Pretraining 31 | For all reported experiments we first pretrain the detection part of the model on the detection task and then fine-tune/train the complete model on the detection and relationship understanding tasks. 
32 | To pretrain models, run the pretraining configs as follows 33 | 34 | ``` 35 | # Depending on your desired model run one of the following 36 | 37 | # Detr based Models 38 | python main.py --config-file ./configs/pretrains/detr_pretrain.yaml 39 | 40 | # Deformable Detr based models 41 | python main.py --config-file ./configs/pretrains/defdetr_pretrain.yaml 42 | 43 | # Mask-RCNN based models 44 | python main.py --config-file ./configs/pretrains/rcnn_pretrain.yaml 45 | 46 | ``` 47 | For all models we start from the publicly available COCO checkpoints. Because of our changes to the DETR and Deformable DETR architectures some checkpoint key names must be remapped; pretrained checkpoints with fixed keys are available [here](https://drive.google.com/drive/folders/1v9XdnxK1eKCYFYpilOSpm3zOi3BycMqY?usp=drive_link) 48 | ### Relationship Reasoning Training 49 | Now that you have your pretrained model, you can train it on the relationship reasoning task as follows 50 | ``` 51 | python main.py --config-file detr_graphdense_medium.yaml MODEL.WEIGHTS /path/to/checkpoint 52 | ``` 53 | ### Citation 54 | If you find this work useful, please cite it using: 55 | ``` 56 | @ARTICLE{10819650, 57 | author={Rabino, Paolo and Tommasi, Tatiana}, 58 | journal={IEEE Robotics and Automation Letters}, 59 | title={A Modern Take on Visual Relationship Reasoning for Grasp Planning}, 60 | year={2025}, 61 | volume={10}, 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /data/metagraspnet_labels.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | 4 | IFL_SCENE_RECORDING = { 5 | "cracker_box" : "097", 6 | "sugar_box" : "001", 7 | "tomato_soup_can" : "002", 8 | "mustard_bottle" : "003", 9 | "potted_meat_can" : "004", 10 | "banana" : "005", 11 | "bowl" : "006", 12 | "mug" : "007", 13 | "power_drill" : "008", 14 | "scissor" : "009", 15 | "chips_can" : "010", 16 | "strawberry" : "011", 17 | "apple" : "012", 18 | "lemon" : "013", 19 | "peach" : "014", 20 | "pear" : "015", 21 | "orange" : "016", 22 | "plum" : "017", 23 | "knife" : "018", 24 | "phillips_screwdriver" : "019", 25 | "flat_screwdriver" : "020", 26 | "racquetball" : "021", 27 | "b_cups" : "022", 28 | "d_cups" : "023", 29 | "a_toy_airplane" : "024", 30 | "c_toy_airplane" : "025", 31 | "d_toy_airplane" : "026", 32 | "f_toy_airplane" : "027", 33 | "h_toy_airplane" : "028", 34 | "i_toy_airplane" : "029", 35 | "j_toy_airplane" : "030", 36 | "k_toy_airplane" : "031", 37 | "light_bulb" : "032", 38 | "kitchen_knife" : "034", 39 | "screw_valve" : "035", 40 | "plastic_pipes" : "036", 41 | "cables_in_transparent_bag" : "037", 42 | "cables" : "038", 43 | "wire_cutter" : "039", 44 | "desinfection" : "040", 45 | "hairspray" : "041", 46 | "handcream" : "042", 47 | "toothpaste" : "043", 48 | "toydog" : "044", 49 | "sponge" : "045", 50 | "pneumatic_cylinder" : "046", 51 | "airfilter" : "047", 52 | "coffeefilter" : "048", 53 | "wash_glove" : "049", 54 | "wash_sponge" : "050", 55 | "garbage_bags" : "051", 56 | "deo" : "052", 57 | "cat_milk" : "053", 58 | "bottle_glass" : "054", 59 | "bottle_press_head" : "055", 60 | "shaving_cream" : "056", 61 | "chewing_gum_with_spray" : "057", 62 | "lighters" : "058", 63 | "cream_soap" : "059", 64 | "box_1" : "060", 65 | "box_2" : "061", 66 | "box_3" : "062", 67 | "box_4" : "063", 68 | "box_5" : "064", 69 | "box_6" : "065", 70 | "box_7" : "066", 71 | "box_8" : "067", 72 | "glass_cup" : "068", 73 | "tennis_ball" : "069", 74 | "cup" : "070", 75 | 
"wineglass" : "071", 76 | "handsaw" : "072", 77 | "lipcare" : "073", 78 | "woodcube_a" : "074", 79 | "lipstick" : "075", 80 | "nosespray" : "076", 81 | "tape" : "077", 82 | "bookholder" : "078", 83 | "clamp" : "079", 84 | "glue" : "080", 85 | "stapler" : "081", 86 | "calculator" : "082", 87 | "clamp_small" : "083", 88 | "clamp_big" : "084", 89 | "glasses" : "085", 90 | "crayons" : "086", 91 | "marker_big" : "087", 92 | "marker_small" : "088", 93 | "greek_busts" : "089", 94 | "object_wrapped_in_foil" : "090", 95 | "water_bottle_deformed" : "091", 96 | "bubble_wrap" : "092", 97 | "woodblock_a" : "093", 98 | "woodblock_b" : "094", 99 | "woodblock_c" : "095", 100 | "mannequin" : "096" 101 | } 102 | 103 | IFL_SYNSET_TO_LABEL = {v: k for k, v in IFL_SCENE_RECORDING.items()} 104 | 105 | def get_coco_things(): 106 | things = defaultdict(lambda: '', {int(v): k for k, v in IFL_SCENE_RECORDING.items()}) 107 | things = [things[i] for i in range(max(things)+1)] 108 | return things -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/functions/ms_deform_attn_func.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Function 16 | from torch.autograd.function import once_differentiable 17 | 18 | import MultiScaleDeformableAttention as MSDA 19 | 20 | 21 | class MSDeformAttnFunction(Function): 22 | @staticmethod 23 | def forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step): 24 | 25 | ctx.im2col_step = im2col_step 26 | output = MSDA.ms_deform_attn_forward( 27 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, ctx.im2col_step) 28 | ctx.save_for_backward(value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights) 29 | return output 30 | 31 | @staticmethod 32 | @once_differentiable 33 | def backward(ctx, grad_output): 34 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights = ctx.saved_tensors 35 | grad_value, grad_sampling_loc, grad_attn_weight = \ 36 | MSDA.ms_deform_attn_backward( 37 | value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, grad_output, ctx.im2col_step) 38 | 39 | return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None 40 | 41 | 42 | def ms_deform_attn_core_pytorch(value, value_spatial_shapes, sampling_locations, attention_weights): 43 | # for debug and test only, 44 | # need to use cuda version instead 45 | N_, S_, M_, D_ = value.shape 46 | _, Lq_, M_, L_, P_, _ = sampling_locations.shape 47 | value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1) 48 | sampling_grids = 
2 * sampling_locations - 1 49 | sampling_value_list = [] 50 | for lid_, (H_, W_) in enumerate(value_spatial_shapes): 51 | # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_ 52 | value_l_ = value_list[lid_].flatten(2).transpose(1, 2).reshape(N_*M_, D_, H_, W_) 53 | # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2 54 | sampling_grid_l_ = sampling_grids[:, :, :, lid_].transpose(1, 2).flatten(0, 1) 55 | # N_*M_, D_, Lq_, P_ 56 | sampling_value_l_ = F.grid_sample(value_l_, sampling_grid_l_, 57 | mode='bilinear', padding_mode='zeros', align_corners=False) 58 | sampling_value_list.append(sampling_value_l_) 59 | # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_, M_, 1, Lq_, L_*P_) 60 | attention_weights = attention_weights.transpose(1, 2).reshape(N_*M_, 1, Lq_, L_*P_) 61 | output = (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights).sum(-1).view(N_, M_*D_, Lq_) 62 | return output.transpose(1, 2).contiguous() 63 | -------------------------------------------------------------------------------- /models/detr_modules/position_encoding.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Various positional encodings for the transformer. 4 | """ 5 | import math 6 | import torch 7 | from torch import nn 8 | 9 | from utils.fb_misc import NestedTensor 10 | 11 | 12 | class PositionEmbeddingSine(nn.Module): 13 | """ 14 | This is a more standard version of the position embedding, very similar to the one 15 | used by the Attention is all you need paper, generalized to work on images. 16 | """ 17 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 18 | super().__init__() 19 | self.num_pos_feats = num_pos_feats 20 | self.temperature = temperature 21 | self.normalize = normalize 22 | if scale is not None and normalize is False: 23 | raise ValueError("normalize should be True if scale is passed") 24 | if scale is None: 25 | scale = 2 * math.pi 26 | self.scale = scale 27 | 28 | def forward(self, tensor_list: NestedTensor): 29 | x = tensor_list.tensors 30 | mask = tensor_list.mask 31 | assert mask is not None 32 | not_mask = ~mask 33 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 34 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 35 | if self.normalize: 36 | eps = 1e-6 37 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 38 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 39 | 40 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 41 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 42 | 43 | pos_x = x_embed[:, :, :, None] / dim_t 44 | pos_y = y_embed[:, :, :, None] / dim_t 45 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 46 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 47 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 48 | return pos 49 | 50 | 51 | class PositionEmbeddingLearned(nn.Module): 52 | """ 53 | Absolute pos embedding, learned. 
54 | """ 55 | def __init__(self, num_pos_feats=256): 56 | super().__init__() 57 | self.row_embed = nn.Embedding(50, num_pos_feats) 58 | self.col_embed = nn.Embedding(50, num_pos_feats) 59 | self.reset_parameters() 60 | 61 | def reset_parameters(self): 62 | nn.init.uniform_(self.row_embed.weight) 63 | nn.init.uniform_(self.col_embed.weight) 64 | 65 | def forward(self, tensor_list: NestedTensor): 66 | x = tensor_list.tensors 67 | h, w = x.shape[-2:] 68 | i = torch.arange(w, device=x.device) 69 | j = torch.arange(h, device=x.device) 70 | x_emb = self.col_embed(i) 71 | y_emb = self.row_embed(j) 72 | pos = torch.cat([ 73 | x_emb.unsqueeze(0).repeat(h, 1, 1), 74 | y_emb.unsqueeze(1).repeat(1, w, 1), 75 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 76 | return pos 77 | 78 | 79 | def build_position_encoding(args): 80 | N_steps = args.hidden_dim // 2 81 | if args.position_embedding in ('v2', 'sine'): 82 | # TODO find a better way of exposing other arguments 83 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 84 | elif args.position_embedding in ('v3', 'learned'): 85 | position_embedding = PositionEmbeddingLearned(N_steps) 86 | else: 87 | raise ValueError(f"not supported {args.position_embedding}") 88 | 89 | return position_embedding 90 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 
95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | output/* 165 | .vscode/* 166 | *.json -------------------------------------------------------------------------------- /models/deformable_detr_modules/position_encoding.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Various positional encodings for the transformer. 12 | """ 13 | import math 14 | import torch 15 | from torch import nn 16 | 17 | from utils.fb_misc import NestedTensor 18 | 19 | 20 | class PositionEmbeddingSine(nn.Module): 21 | """ 22 | This is a more standard version of the position embedding, very similar to the one 23 | used by the Attention is all you need paper, generalized to work on images. 
24 | """ 25 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 26 | super().__init__() 27 | self.num_pos_feats = num_pos_feats 28 | self.temperature = temperature 29 | self.normalize = normalize 30 | if scale is not None and normalize is False: 31 | raise ValueError("normalize should be True if scale is passed") 32 | if scale is None: 33 | scale = 2 * math.pi 34 | self.scale = scale 35 | 36 | def forward(self, tensor_list: NestedTensor): 37 | x = tensor_list.tensors 38 | mask = tensor_list.mask 39 | assert mask is not None 40 | not_mask = ~mask 41 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 42 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 43 | if self.normalize: 44 | eps = 1e-6 45 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 46 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 47 | 48 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 49 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 50 | 51 | pos_x = x_embed[:, :, :, None] / dim_t 52 | pos_y = y_embed[:, :, :, None] / dim_t 53 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 54 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 55 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 56 | return pos 57 | 58 | 59 | class PositionEmbeddingLearned(nn.Module): 60 | """ 61 | Absolute pos embedding, learned. 62 | """ 63 | def __init__(self, num_pos_feats=256): 64 | super().__init__() 65 | self.row_embed = nn.Embedding(50, num_pos_feats) 66 | self.col_embed = nn.Embedding(50, num_pos_feats) 67 | self.reset_parameters() 68 | 69 | def reset_parameters(self): 70 | nn.init.uniform_(self.row_embed.weight) 71 | nn.init.uniform_(self.col_embed.weight) 72 | 73 | def forward(self, tensor_list: NestedTensor): 74 | x = tensor_list.tensors 75 | h, w = x.shape[-2:] 76 | i = torch.arange(w, device=x.device) 77 | j = torch.arange(h, device=x.device) 78 | x_emb = self.col_embed(i) 79 | y_emb = self.row_embed(j) 80 | pos = torch.cat([ 81 | x_emb.unsqueeze(0).repeat(h, 1, 1), 82 | y_emb.unsqueeze(1).repeat(1, w, 1), 83 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 84 | return pos 85 | 86 | 87 | def build_position_encoding(cfg): 88 | N_steps = cfg.MODEL.DETR.HIDDEN_DIM // 2 89 | pos_embedding = cfg.MODEL.DETR.POS_EMBEDDING 90 | if pos_embedding in ('v2', 'sine'): 91 | # TODO find a better way of exposing other arguments 92 | position_embedding = PositionEmbeddingSine(N_steps, normalize=True) 93 | elif pos_embedding in ('v3', 'learned'): 94 | position_embedding = PositionEmbeddingLearned(N_steps) 95 | else: 96 | raise ValueError(f"not supported {pos_embedding}") 97 | 98 | return position_embedding 99 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/test.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 
4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import time 14 | import torch 15 | import torch.nn as nn 16 | from torch.autograd import gradcheck 17 | 18 | from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch 19 | 20 | 21 | N, M, D = 1, 2, 2 22 | Lq, L, P = 2, 2, 2 23 | shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() 24 | level_start_index = torch.cat((shapes.new_zeros((1, )), shapes.prod(1).cumsum(0)[:-1])) 25 | S = sum([(H*W).item() for H, W in shapes]) 26 | 27 | 28 | torch.manual_seed(3) 29 | 30 | 31 | @torch.no_grad() 32 | def check_forward_equal_with_pytorch_double(): 33 | value = torch.rand(N, S, M, D).cuda() * 0.01 34 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 35 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 36 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 37 | im2col_step = 2 38 | output_pytorch = ms_deform_attn_core_pytorch(value.double(), shapes, sampling_locations.double(), attention_weights.double()).detach().cpu() 39 | output_cuda = MSDeformAttnFunction.apply(value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step).detach().cpu() 40 | fwdok = torch.allclose(output_cuda, output_pytorch) 41 | max_abs_err = (output_cuda - output_pytorch).abs().max() 42 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 43 | 44 | print(f'* {fwdok} check_forward_equal_with_pytorch_double: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 45 | 46 | 47 | @torch.no_grad() 48 | def check_forward_equal_with_pytorch_float(): 49 | value = torch.rand(N, S, M, D).cuda() * 0.01 50 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 51 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 52 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 53 | im2col_step = 2 54 | output_pytorch = ms_deform_attn_core_pytorch(value, shapes, sampling_locations, attention_weights).detach().cpu() 55 | output_cuda = MSDeformAttnFunction.apply(value, shapes, level_start_index, sampling_locations, attention_weights, im2col_step).detach().cpu() 56 | fwdok = torch.allclose(output_cuda, output_pytorch, rtol=1e-2, atol=1e-3) 57 | max_abs_err = (output_cuda - output_pytorch).abs().max() 58 | max_rel_err = ((output_cuda - output_pytorch).abs() / output_pytorch.abs()).max() 59 | 60 | print(f'* {fwdok} check_forward_equal_with_pytorch_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}') 61 | 62 | 63 | def check_gradient_numerical(channels=1, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True): 64 | 65 | value = torch.rand(N, S, M, channels).cuda() * 0.01 66 | sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() 67 | attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 68 | attention_weights /= attention_weights.sum(-1, keepdim=True).sum(-2, keepdim=True) 69 | im2col_step = 2 70 | func = MSDeformAttnFunction.apply 71 | 72 | value.requires_grad = 
grad_value 73 | sampling_locations.requires_grad = grad_sampling_loc 74 | attention_weights.requires_grad = grad_attn_weight 75 | 76 | gradok = gradcheck(func, (value.double(), shapes, level_start_index, sampling_locations.double(), attention_weights.double(), im2col_step)) 77 | 78 | print(f'* {gradok} check_gradient_numerical(D={channels})') 79 | 80 | 81 | if __name__ == '__main__': 82 | check_forward_equal_with_pytorch_double() 83 | check_forward_equal_with_pytorch_float() 84 | 85 | for channels in [30, 32, 64, 71, 1025, 2048, 3096]: 86 | check_gradient_numerical(channels, True, True, True) 87 | 88 | 89 | 90 | -------------------------------------------------------------------------------- /models/detr_modules/matcher.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Modules to compute the matching cost and solve the corresponding LSAP. 4 | """ 5 | import torch 6 | from scipy.optimize import linear_sum_assignment 7 | from torch import nn 8 | 9 | from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 10 | 11 | 12 | class HungarianMatcher(nn.Module): 13 | """This class computes an assignment between the targets and the predictions of the network 14 | 15 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 16 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 17 | while the others are un-matched (and thus treated as non-objects). 18 | """ 19 | 20 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 21 | """Creates the matcher 22 | 23 | Params: 24 | cost_class: This is the relative weight of the classification error in the matching cost 25 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 26 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 27 | """ 28 | super().__init__() 29 | self.cost_class = cost_class 30 | self.cost_bbox = cost_bbox 31 | self.cost_giou = cost_giou 32 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 33 | 34 | @torch.no_grad() 35 | def forward(self, outputs, targets): 36 | """ Performs the matching 37 | 38 | Params: 39 | outputs: This is a dict that contains at least these entries: 40 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 41 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 42 | 43 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 44 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 45 | objects in the target) containing the class labels 46 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 47 | 48 | Returns: 49 | A list of size batch_size, containing tuples of (index_i, index_j) where: 50 | - index_i is the indices of the selected predictions (in order) 51 | - index_j is the indices of the corresponding selected targets (in order) 52 | For each batch element, it holds: 53 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 54 | """ 55 | bs, num_queries = outputs["pred_logits"].shape[:2] 56 | 57 | # We flatten to compute the cost matrices in a batch 58 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) 
# [batch_size * num_queries, num_classes] 59 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 60 | 61 | # Also concat the target labels and boxes 62 | tgt_ids = torch.cat([v["labels"] for v in targets]) 63 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 64 | 65 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 66 | # but approximate it in 1 - proba[target class]. 67 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 68 | cost_class = -out_prob[:, tgt_ids] 69 | 70 | # Compute the L1 cost between boxes 71 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 72 | 73 | # Compute the giou cost betwen boxes 74 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 75 | 76 | # Final cost matrix 77 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 78 | C = C.view(bs, num_queries, -1).cpu() 79 | 80 | sizes = [len(v["boxes"]) for v in targets] 81 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 82 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 83 | 84 | 85 | def build_matcher(args): 86 | return HungarianMatcher(cost_class=args.set_cost_class, cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou) 87 | -------------------------------------------------------------------------------- /models/detr_modules/backbone.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | Backbone modules. 4 | """ 5 | from collections import OrderedDict 6 | 7 | import torch 8 | import torch.nn.functional as F 9 | import torchvision 10 | from torch import nn 11 | from torchvision.models._utils import IntermediateLayerGetter 12 | from typing import Dict, List 13 | 14 | from utils.fb_misc import NestedTensor, is_main_process 15 | 16 | from .position_encoding import build_position_encoding 17 | 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | 23 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 24 | without which any other models than torchvision.models.resnet[18,34,50,101] 25 | produce nans. 
26 | """ 27 | 28 | def __init__(self, n): 29 | super(FrozenBatchNorm2d, self).__init__() 30 | self.register_buffer("weight", torch.ones(n)) 31 | self.register_buffer("bias", torch.zeros(n)) 32 | self.register_buffer("running_mean", torch.zeros(n)) 33 | self.register_buffer("running_var", torch.ones(n)) 34 | 35 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 36 | missing_keys, unexpected_keys, error_msgs): 37 | num_batches_tracked_key = prefix + 'num_batches_tracked' 38 | if num_batches_tracked_key in state_dict: 39 | del state_dict[num_batches_tracked_key] 40 | 41 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 42 | state_dict, prefix, local_metadata, strict, 43 | missing_keys, unexpected_keys, error_msgs) 44 | 45 | def forward(self, x): 46 | # move reshapes to the beginning 47 | # to make it fuser-friendly 48 | w = self.weight.reshape(1, -1, 1, 1) 49 | b = self.bias.reshape(1, -1, 1, 1) 50 | rv = self.running_var.reshape(1, -1, 1, 1) 51 | rm = self.running_mean.reshape(1, -1, 1, 1) 52 | eps = 1e-5 53 | scale = w * (rv + eps).rsqrt() 54 | bias = b - rm * scale 55 | return x * scale + bias 56 | 57 | 58 | class BackboneBase(nn.Module): 59 | 60 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 61 | super().__init__() 62 | for name, parameter in backbone.named_parameters(): 63 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 64 | parameter.requires_grad_(False) 65 | if return_interm_layers: 66 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 67 | else: 68 | return_layers = {'layer4': "0"} 69 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 70 | self.num_channels = num_channels 71 | 72 | def forward(self, tensor_list: NestedTensor): 73 | xs = self.body(tensor_list.tensors) 74 | out: Dict[str, NestedTensor] = {} 75 | for name, x in xs.items(): 76 | m = tensor_list.mask 77 | assert m is not None 78 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 79 | out[name] = NestedTensor(x, mask) 80 | return out 81 | 82 | 83 | class Backbone(BackboneBase): 84 | """ResNet backbone with frozen BatchNorm.""" 85 | def __init__(self, name: str, 86 | train_backbone: bool, 87 | return_interm_layers: bool, 88 | dilation: bool): 89 | backbone = getattr(torchvision.models, name)( 90 | replace_stride_with_dilation=[False, False, dilation], 91 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 92 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 93 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 94 | 95 | 96 | class Joiner(nn.Sequential): 97 | def __init__(self, backbone, position_embedding): 98 | super().__init__(backbone, position_embedding) 99 | 100 | def forward(self, tensor_list: NestedTensor): 101 | xs = self[0](tensor_list) 102 | out: List[NestedTensor] = [] 103 | pos = [] 104 | for name, x in xs.items(): 105 | out.append(x) 106 | # position encoding 107 | pos.append(self[1](x).to(x.tensors.dtype)) 108 | 109 | return out, pos 110 | 111 | 112 | def build_backbone(args): 113 | position_embedding = build_position_encoding(args) 114 | train_backbone = args.lr_backbone > 0 115 | return_interm_layers = args.masks 116 | backbone = Backbone(args.backbone, train_backbone, return_interm_layers, args.dilation) 117 | model = Joiner(backbone, position_embedding) 118 | model.num_channels = backbone.num_channels 119 | return model 120 | 
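# A minimal sanity check (an illustrative sketch, not part of the training pipeline): once its
# buffers are copied from a standard nn.BatchNorm2d, the FrozenBatchNorm2d defined above should
# reproduce the eval-mode output, since both use eps = 1e-5 and the same normalization formula.
# It relies only on the torch/nn imports already at the top of this file.
if __name__ == "__main__":
    bn = nn.BatchNorm2d(8).eval()
    # Give the reference layer non-trivial statistics and affine parameters.
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)

    frozen = FrozenBatchNorm2d(8)
    # _load_from_state_dict above drops num_batches_tracked, so the remaining buffer keys line up.
    frozen.load_state_dict(bn.state_dict(), strict=False)

    x = torch.randn(2, 8, 16, 16)
    with torch.no_grad():
        assert torch.allclose(bn(x), frozen(x), atol=1e-5)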
-------------------------------------------------------------------------------- /models/deformable_detr_modules/matcher.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Modules to compute the matching cost and solve the corresponding LSAP. 12 | """ 13 | import torch 14 | from scipy.optimize import linear_sum_assignment 15 | from torch import nn 16 | 17 | from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 18 | 19 | 20 | class HungarianMatcher(nn.Module): 21 | """This class computes an assignment between the targets and the predictions of the network 22 | 23 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 24 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 25 | while the others are un-matched (and thus treated as non-objects). 26 | """ 27 | 28 | def __init__(self, 29 | cost_class: float = 1, 30 | cost_bbox: float = 1, 31 | cost_giou: float = 1): 32 | """Creates the matcher 33 | 34 | Params: 35 | cost_class: This is the relative weight of the classification error in the matching cost 36 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 37 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 38 | """ 39 | super().__init__() 40 | self.cost_class = cost_class 41 | self.cost_bbox = cost_bbox 42 | self.cost_giou = cost_giou 43 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 44 | 45 | def forward(self, outputs, targets): 46 | """ Performs the matching 47 | 48 | Params: 49 | outputs: This is a dict that contains at least these entries: 50 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 51 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 52 | 53 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 54 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 55 | objects in the target) containing the class labels 56 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 57 | 58 | Returns: 59 | A list of size batch_size, containing tuples of (index_i, index_j) where: 60 | - index_i is the indices of the selected predictions (in order) 61 | - index_j is the indices of the corresponding selected targets (in order) 62 | For each batch element, it holds: 63 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 64 | """ 65 | with torch.no_grad(): 66 | bs, num_queries = outputs["pred_logits"].shape[:2] 67 | 68 | # We flatten to compute the cost matrices in a batch 69 | out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid() 70 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 71 | 72 | # Also concat the 
target labels and boxes 73 | tgt_ids = torch.cat([v["labels"] for v in targets]) 74 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 75 | 76 | # Compute the classification cost. 77 | alpha = 0.25 78 | gamma = 2.0 79 | neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log()) 80 | pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log()) 81 | cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] 82 | 83 | # Compute the L1 cost between boxes 84 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 85 | 86 | # Compute the giou cost betwen boxes 87 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), 88 | box_cxcywh_to_xyxy(tgt_bbox)) 89 | 90 | # Final cost matrix 91 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 92 | C = C.view(bs, num_queries, -1).cpu() 93 | 94 | sizes = [len(v["boxes"]) for v in targets] 95 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))] 96 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 97 | 98 | 99 | def build_matcher(args): 100 | return HungarianMatcher(cost_class=args.set_cost_class, 101 | cost_bbox=args.set_cost_bbox, 102 | cost_giou=args.set_cost_giou) 103 | -------------------------------------------------------------------------------- /models/detr_gheads.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | from models.detr_modules.detr import MLP 6 | from models.detr_modules.transformer import TransformerDecoder, TransformerDecoderLayer 7 | from models.graph_transformer_dense import GraphTransformerLayerDense 8 | 9 | 10 | 11 | import networkx as nx 12 | import einops as es 13 | import einops.layers.torch as el 14 | from detectron2.config import configurable 15 | 16 | class PairwiseConcatLayer(nn.Module): 17 | def __init__(self): 18 | super().__init__() 19 | 20 | def forward(self, x, y): 21 | d1, d2 = x.shape[-2], y.shape[-2] 22 | grid_x, grid_y = torch.meshgrid(torch.arange(d1, device=x.device), torch.arange(d2, device=y.device), indexing='ij') 23 | res = torch.concat([torch.index_select(x, dim=-2, index=grid_x.flatten()), torch.index_select(y, dim=-2, index=grid_y.flatten())], dim=-1) 24 | res = es.rearrange(res, '... (L1 L2) C -> ... 
L1 L2 C', L1=d1, L2=d2) 25 | return res 26 | 27 | 28 | class DummyHead(nn.Module): 29 | def __init__(self, cfg): 30 | super().__init__() 31 | 32 | def forward(self, hs): 33 | hs = hs ** 2 34 | 35 | class BaseGHead(nn.Module): 36 | 37 | @configurable 38 | def __init__(self, num_layers=1, in_dim=256, hidden_dim=256, num_heads=1, num_nodes=100, edge_features='constant_one') -> None: 39 | super().__init__() 40 | self.in_dim = in_dim 41 | self.hidden_dim = hidden_dim 42 | self.num_heads = num_heads 43 | self.num_nodes = num_nodes 44 | self.num_layers = num_layers 45 | self.edge_features = edge_features 46 | 47 | if edge_features == 'concat': 48 | out_proj_edge = hidden_dim // 2 49 | self.pairwise_layer = PairwiseConcatLayer() 50 | else: 51 | out_proj_edge = hidden_dim 52 | 53 | 54 | self.proj_e1 = nn.Linear(in_dim, out_proj_edge) 55 | self.proj_e2 = nn.Linear(in_dim, out_proj_edge) 56 | 57 | self.proj_node_input = nn.Linear(in_dim, hidden_dim) 58 | self.edge_features = edge_features 59 | self.hidden_dim = hidden_dim 60 | 61 | @classmethod 62 | def from_config(cls, cfg): 63 | cfg = cfg.MODEL.GRAPH_HEAD 64 | return {'num_layers': cfg.NUM_LAYERS, 'hidden_dim': cfg.HIDDEN_DIM, 'num_heads': cfg.NUM_HEADS, 'edge_features': cfg.EDGE_FEATURES} 65 | 66 | 67 | def _compute_edge_features(self, features): 68 | # features L B Q C 69 | L, B, Q, C = features.shape 70 | device = features.device 71 | C = self.hidden_dim 72 | e1, e2 = self.proj_e1(features), self.proj_e2(features) 73 | if self.edge_features == 'concat': 74 | e = self.pairwise_layer(e1, e2) 75 | elif self.edge_features == 'sum': 76 | e = e1[:, :, None, :, :] + e2[:, :, :, None, :] 77 | elif self.edge_features == 'diff': 78 | e = e1[:, :, None, :, :] - e2[:, :, :, None, :] 79 | elif self.edge_features == 'div': 80 | e = e1[:, :, None, :, :] / e2[:, :, :, None, :] 81 | elif self.edge_features == 'mul': 82 | e = e1[:, :, None, :, :] * e2[:, :, :, None, :] 83 | else: 84 | raise NotImplementedError(f'{self.edge_features} aggregations not implemented') 85 | return e 86 | 87 | 88 | class DenseGraphTransformerHead(BaseGHead): 89 | 90 | @configurable 91 | def __init__(self, *args, **kwargs): 92 | super().__init__(*args, **kwargs) 93 | self.graph_transformer_layers = nn.ModuleList([ 94 | GraphTransformerLayerDense(in_dim=self.hidden_dim, out_dim=self.hidden_dim, num_heads=self.num_heads, layer_norm=True, batch_norm=False) 95 | for _ in range(self.num_layers) 96 | ]) 97 | self.edge_cls = MLP(input_dim=self.hidden_dim, hidden_dim=self.hidden_dim//2, output_dim=1, num_layers=3) 98 | 99 | @classmethod 100 | def from_config(cls, cfg): 101 | cfg = cfg.MODEL.GRAPH_HEAD 102 | return {'num_layers': cfg.NUM_LAYERS, 'hidden_dim': cfg.HIDDEN_DIM, 'num_heads': cfg.NUM_HEADS, 'edge_features': cfg.EDGE_FEATURES} 103 | 104 | 105 | def forward(self, hs: torch.Tensor): 106 | """ _summary_ 107 | Args: 108 | hs (torch.Tensor): L x B x Q x C 109 | """ 110 | L, B, Q, C = hs.shape 111 | 112 | e = self._compute_edge_features(features=hs) 113 | hs = self.proj_node_input(hs) 114 | # hs = es.rearrange(hs, 'B L Q C -> (B L) Q C') 115 | 116 | 117 | # e = es.rearrange(e, '(B Q1 Q2) C -> B Q1 Q2 C', B=L*B, Q1=Q, Q2=Q, C=C) 118 | hs = es.rearrange(hs, 'L B Q C -> (L B) Q C') 119 | e = es.rearrange(e, 'L B Q1 Q2 C -> (L B) Q1 Q2 C') 120 | 121 | for layer in self.graph_transformer_layers: 122 | hs, e = layer(hs, e) 123 | e = self.edge_cls(e) 124 | e = es.rearrange(e, "(L B) Q1 Q2 C -> L B Q1 Q2 C", C=1, L=L, Q1=Q, Q2=Q, B=B) 125 | return e 126 | 127 | 128 | def build_graph_head(cfg): 129 | 
name = cfg.MODEL.GRAPH_HEAD.NAME 130 | head = { 131 | 'DummyHead': DummyHead, 132 | 'GraphTransformerDense': DenseGraphTransformerHead, 133 | }[name](cfg) 134 | return head 135 | 136 | -------------------------------------------------------------------------------- /data/eval_metagraspnet.py: -------------------------------------------------------------------------------- 1 | from detectron2.evaluation import DatasetEvaluator 2 | from torchmetrics.detection import MeanAveragePrecision, IntersectionOverUnion 3 | import torch 4 | import numpy as np 5 | from metrics.vrmn_relationship import RelationshipEval 6 | from metrics.calculate_ap_results import pack_instance 7 | import metrics.oi_eval 8 | 9 | 10 | class GraphEvaluator(DatasetEvaluator): 11 | 12 | def __init__(self, dataset_name, output_dir=None, thresh=0.3, det_only=False) -> None: 13 | super().__init__() 14 | self.output_dir = output_dir 15 | self.dataset_name = dataset_name 16 | self.classless = 'real' in self.dataset_name 17 | self.threshold = thresh 18 | mkw = {'sync_on_compute': False} 19 | self.m_ap = MeanAveragePrecision(**mkw) 20 | self.m_ap_classless = MeanAveragePrecision(extended_summary=True, **mkw) 21 | self.iou = IntersectionOverUnion(**mkw) 22 | self.IoU = IntersectionOverUnion(class_metrics=True) 23 | self.VRMNRel = RelationshipEval(classless=self.classless, thresh=self.threshold) 24 | 25 | self.det_only = det_only 26 | self.ap_prep_list = [] 27 | 28 | def reset(self): 29 | self.m_ap.reset() 30 | self.m_ap_classless.reset() 31 | self.iou.reset() 32 | self.VRMNRel.reset() 33 | self.ap_prep_list = [] 34 | 35 | def _convert_predictions(self, d): 36 | d = d.get_fields() 37 | return { 38 | 'boxes': d['pred_boxes'].tensor.cpu(), 39 | 'scores': d['scores'].cpu(), 40 | 'labels': d['pred_classes'].cpu() if not self.classless else torch.zeros_like(d['pred_classes'], device='cpu'), 41 | } 42 | 43 | def _convert_gt(self, d): 44 | d = d.get_fields() 45 | return { 46 | 'boxes': d['gt_boxes'].tensor, 47 | 'labels': d['gt_classes'] if not self.classless else torch.zeros_like(d['gt_classes'], device='cpu'), 48 | } 49 | 50 | def _convert_graph_pred(self, graph: torch.Tensor): 51 | 52 | graph = graph.bool() 53 | num_objs = graph.shape[0] 54 | if num_objs <= 1: 55 | return torch.zeros(0, 3), torch.zeros(0, 2) 56 | x, y = np.triu_indices(num_objs, 1) 57 | graph = torch.vstack([graph[x, y], graph[y, x]]).T 58 | graph = torch.concat([~(graph[:, 0]|graph[:, 1])[:, None], graph], dim=1) 59 | graph = graph.cpu() 60 | return graph, np.vstack([x, y]) 61 | 62 | def process(self, input, output): 63 | assert len(input) == 1 # Adapt code if BS > 1 64 | gt_detection = [self._convert_gt(x['instances']) for x in input] 65 | predictions = [self._convert_predictions(x['instances']) for x in output] 66 | 67 | self.m_ap(preds=predictions, target=gt_detection) 68 | self.IoU(preds=predictions, target=gt_detection) 69 | 70 | if self.det_only: 71 | return 72 | 73 | graph_gt = [self._convert_graph_pred(x['dense_gt']) for x in input] 74 | graph_pred_all = [self._convert_graph_pred(x['graph_all']) for x in output] 75 | 76 | if self.save_all: 77 | self.save_list.append({ 78 | 'gt_bbox': gt_detection[0]['boxes'].numpy().tolist(), 79 | 'gt_class': gt_detection[0]['labels'].numpy().tolist(), 80 | 'pred_bbox': predictions[0]['boxes'].numpy().tolist(), 81 | 'pred_classes': predictions[0]['labels'].numpy().tolist(), 82 | 'pred_scores': predictions[0]['scores'].numpy().tolist(), 83 | 'gt_graph': torch.argmax(graph_gt[0][0].float(), 1).numpy().tolist(), 84 | 'pred_graph': 
graph_pred_all[0][0].cpu().numpy().tolist(), 85 | 'image_id': input[0]['image_id'] 86 | }) 87 | 88 | pack = pack_instance( 89 | pred_bboxes=predictions[0]['boxes'].numpy(), 90 | pred_labels=predictions[0]['labels'].numpy(), 91 | pred_scores=predictions[0]['scores'].numpy(), 92 | pred_graph=graph_pred_all[0][0].cpu().numpy(), 93 | gt_bbox=gt_detection[0]['boxes'].numpy(), 94 | gt_label=gt_detection[0]['labels'].numpy(), 95 | gt_graph=torch.argmax(graph_gt[0][0].float(), 1).numpy(), 96 | ) 97 | self.ap_prep_list.append(pack) 98 | 99 | if graph_gt[0][0].shape[0] == 0: 100 | return 101 | 102 | relation_res = self.VRMNRel(predictions, graph_pred_all, gt_detection, graph_gt) 103 | 104 | 105 | def evaluate(self): 106 | m_ap = self.m_ap.compute() 107 | if self.det_only: 108 | return { 109 | 'eval/map': m_ap['map'].item(), 110 | 'eval/map_50': m_ap['map_50'].item(), 111 | 'eval/map_75': m_ap['map_75'].item(), 112 | } 113 | relation_res = self.VRMNRel.compute() 114 | eval_relation_res = {f'eval/{k}': v for k, v in relation_res.items()} 115 | rel_ap_metrics = metrics.oi_eval.eval_rel_results(self.ap_prep_list, ['background', 'rel']) 116 | rel_ap_metrics = {'eval/'+ k: v for k, v in rel_ap_metrics.items()} 117 | 118 | 119 | return { 120 | 'eval/map': m_ap['map'].item(), 121 | 'eval/map_50': m_ap['map_50'].item(), 122 | 'eval/map_75': m_ap['map_75'].item(), 123 | **eval_relation_res, 124 | **rel_ap_metrics, 125 | } 126 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/backbone.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------ 6 | # Modified from DETR (https://github.com/facebookresearch/detr) 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | # ------------------------------------------------------------------------ 9 | 10 | """ 11 | Backbone modules. 12 | """ 13 | from collections import OrderedDict 14 | 15 | import torch 16 | import torch.nn.functional as F 17 | import torchvision 18 | from torch import nn 19 | from torchvision.models._utils import IntermediateLayerGetter 20 | from typing import Dict, List 21 | 22 | from utils.fb_misc import NestedTensor, is_main_process 23 | 24 | from .position_encoding import build_position_encoding 25 | 26 | 27 | class FrozenBatchNorm2d(torch.nn.Module): 28 | """ 29 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 30 | 31 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 32 | without which any other models than torchvision.models.resnet[18,34,50,101] 33 | produce nans. 
34 | """ 35 | 36 | def __init__(self, n, eps=1e-5): 37 | super(FrozenBatchNorm2d, self).__init__() 38 | self.register_buffer("weight", torch.ones(n)) 39 | self.register_buffer("bias", torch.zeros(n)) 40 | self.register_buffer("running_mean", torch.zeros(n)) 41 | self.register_buffer("running_var", torch.ones(n)) 42 | self.eps = eps 43 | 44 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 45 | missing_keys, unexpected_keys, error_msgs): 46 | num_batches_tracked_key = prefix + 'num_batches_tracked' 47 | if num_batches_tracked_key in state_dict: 48 | del state_dict[num_batches_tracked_key] 49 | 50 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 51 | state_dict, prefix, local_metadata, strict, 52 | missing_keys, unexpected_keys, error_msgs) 53 | 54 | def forward(self, x): 55 | # move reshapes to the beginning 56 | # to make it fuser-friendly 57 | w = self.weight.reshape(1, -1, 1, 1) 58 | b = self.bias.reshape(1, -1, 1, 1) 59 | rv = self.running_var.reshape(1, -1, 1, 1) 60 | rm = self.running_mean.reshape(1, -1, 1, 1) 61 | eps = self.eps 62 | scale = w * (rv + eps).rsqrt() 63 | bias = b - rm * scale 64 | return x * scale + bias 65 | 66 | 67 | class BackboneBase(nn.Module): 68 | 69 | def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): 70 | super().__init__() 71 | for name, parameter in backbone.named_parameters(): 72 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 73 | parameter.requires_grad_(False) 74 | if return_interm_layers: 75 | # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 76 | return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"} 77 | self.strides = [8, 16, 32] 78 | self.num_channels = [512, 1024, 2048] 79 | else: 80 | return_layers = {'layer4': "0"} 81 | self.strides = [32] 82 | self.num_channels = [2048] 83 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 84 | 85 | def forward(self, tensor_list: NestedTensor): 86 | xs = self.body(tensor_list.tensors) 87 | out: Dict[str, NestedTensor] = {} 88 | for name, x in xs.items(): 89 | m = tensor_list.mask 90 | assert m is not None 91 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 92 | out[name] = NestedTensor(x, mask) 93 | return out 94 | 95 | 96 | class Backbone(BackboneBase): 97 | """ResNet backbone with frozen BatchNorm.""" 98 | def __init__(self, name: str, 99 | train_backbone: bool, 100 | return_interm_layers: bool, 101 | dilation: bool): 102 | norm_layer = FrozenBatchNorm2d 103 | backbone = getattr(torchvision.models, name)( 104 | replace_stride_with_dilation=[False, False, dilation], 105 | pretrained=is_main_process(), norm_layer=norm_layer) 106 | assert name not in ('resnet18', 'resnet34'), "number of channels are hard coded" 107 | super().__init__(backbone, train_backbone, return_interm_layers) 108 | if dilation: 109 | self.strides[-1] = self.strides[-1] // 2 110 | 111 | 112 | class Joiner(nn.Sequential): 113 | def __init__(self, backbone, position_embedding): 114 | super().__init__(backbone, position_embedding) 115 | self.strides = backbone.strides 116 | self.num_channels = backbone.num_channels 117 | 118 | def forward(self, tensor_list: NestedTensor): 119 | xs = self[0](tensor_list) 120 | out: List[NestedTensor] = [] 121 | pos = [] 122 | for name, x in sorted(xs.items()): 123 | out.append(x) 124 | 125 | # position encoding 126 | for x in out: 127 | pos.append(self[1](x).to(x.tensors.dtype)) 128 | 129 | return 
out, pos 130 | 131 | 132 | def build_backbone(cfg): 133 | position_embedding = build_position_encoding(cfg) 134 | train_backbone = cfg.SOLVER.BACKBONE_MULTIPLIER > 0 135 | return_interm_layers = cfg.MODEL.MASK_ON or (cfg.MODEL.DETR.NUM_FEATURE_LEVELS > 1) 136 | backbone = Backbone('resnet50', train_backbone, return_interm_layers, False) 137 | model = Joiner(backbone, position_embedding) 138 | return model 139 | -------------------------------------------------------------------------------- /metrics/vrmn_relationship.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | import itertools 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch 6 | from torchvision.ops import box_iou 7 | 8 | from metrics.common import filter_boxes, filter_graph 9 | 10 | 11 | def do_rel_single_image_eval(preds_det, preds_graph, gts_det, gts_graph, iou_thresh=0.5, classless=False): 12 | gt_bboxes = gts_det["boxes"].cpu().numpy() 13 | gt_classes = gts_det["labels"].cpu().numpy() 14 | num_gt = gt_bboxes.shape[0] 15 | rel_mat_gt = np.zeros([num_gt, num_gt]) 16 | indices = gts_graph[1] 17 | rel_gt = torch.argmax(gts_graph[0].float(), 1) 18 | rel_mat_gt[indices[0], indices[1]] = rel_gt 19 | rel_mat_gt = rel_mat_gt + (rel_mat_gt == 1).T *2 + (rel_mat_gt == 2).T * 1 20 | det_bboxes = preds_det['boxes'].cpu().numpy() 21 | det_labels = preds_det['labels'].cpu().numpy() 22 | det_rel_prob = preds_graph 23 | 24 | # total number of relationships 25 | ngt_rel = num_gt * (num_gt - 1) / 2 26 | 27 | 28 | # no detected rel, tp and fp is all 0 29 | if not det_rel_prob[0].shape[0]: 30 | # return 0, 0, num_gt * (num_gt - 1) /2 31 | res = { 32 | 'true_positive': 0, 33 | 'false_positive': 0, 34 | 'all_correct': 0 35 | } 36 | return res 37 | 38 | det_rel = np.argmax(det_rel_prob[0].float(), 1) 39 | overlaps = box_iou(torch.from_numpy(gt_bboxes), torch.from_numpy(det_bboxes)).numpy().T 40 | # match bbox ground truth and detections 41 | match_mat = np.zeros([det_bboxes.shape[0], gt_bboxes.shape[0]]) 42 | for i in range(det_bboxes.shape[0]): 43 | if classless: 44 | match_cand_inds = np.ones_like(det_labels[i]) 45 | else: 46 | match_cand_inds = (det_labels[i] == gt_classes) 47 | match_cand_overlap = overlaps[i] * match_cand_inds 48 | # decending sort 49 | ovs = np.sort(match_cand_overlap, 0) 50 | ovs = ovs[::-1] 51 | inds = np.argsort(match_cand_overlap, 0) 52 | inds = inds[::-1] 53 | for ii, ov in enumerate(ovs): 54 | if ov > iou_thresh and np.sum(match_mat[:,inds[ii]]) == 0: 55 | match_mat[i, inds[ii]] = 1 56 | break 57 | elif ov < iou_thresh: 58 | break 59 | 60 | # true positive and false positive 61 | tp = 0 62 | fp = 0 63 | rel_ind = 0 64 | correct_edge_found = 0 65 | wrong_edge_direction = 0 66 | missed_edge = 0 67 | correct_empty = 0 68 | wrong_presence_of_edge = 0 69 | relation_from_wrong_box = 0 70 | 71 | for b1 in range(det_bboxes.shape[0]): 72 | for b2 in range(b1+1, det_bboxes.shape[0]): 73 | if np.sum(match_mat[b1]) > 0 and np.sum(match_mat[b2])> 0: 74 | b1_gt = np.argmax(match_mat[b1]) 75 | b2_gt = np.argmax(match_mat[b2]) 76 | rel_gt = rel_mat_gt[b1_gt, b2_gt] 77 | rel_pred = det_rel[rel_ind] 78 | 79 | if rel_gt != 0: 80 | # WE FOUND AN EDGE 81 | if rel_gt == rel_pred: 82 | correct_edge_found += 1 83 | elif (rel_gt != rel_pred) and rel_pred != 0: 84 | wrong_edge_direction += 1 85 | elif (rel_gt != rel_pred) and rel_pred == 0: 86 | missed_edge += 1 87 | else: 88 | if rel_gt == rel_pred: 89 | correct_empty += 1 90 | else: 91 | 
wrong_presence_of_edge += 1 92 | 93 | if rel_gt == rel_pred: 94 | tp += 1 95 | else: 96 | fp += 1 97 | else: 98 | relation_from_wrong_box += 1 99 | fp += 1 100 | rel_ind += 1 101 | 102 | assert fp + tp == det_bboxes.shape[0] * (det_bboxes.shape[0] - 1) / 2 103 | 104 | res = { 105 | 'true_positive': tp, 106 | 'false_positive': fp, 107 | 'all_correct': fp == 0 and tp == ngt_rel 108 | } 109 | return res 110 | return tp, fp, ngt_rel 111 | 112 | class RelationshipEval(nn.Module): 113 | def __init__(self, classless=False, thresh=0.5, iou_thresh=0.5,): 114 | self.preds = [] 115 | super().__init__() 116 | self.counter = 0 117 | self.classless = classless 118 | self.threshold = thresh 119 | self.iou_thresh = iou_thresh 120 | self.results = [] 121 | 122 | def reset(self): 123 | self.counter = 0 124 | self.results = [] 125 | 126 | def compute_sum(self,): 127 | return {k: float(sum(map(lambda x: x[k], self.results))) for k in self.results[0]} 128 | 129 | def compute(self): 130 | acc_res = self.compute_sum() 131 | 132 | if acc_res['true_positive'] + acc_res['false_positive'] > 0: 133 | o_prec = acc_res['true_positive'] / (acc_res['true_positive'] + acc_res['false_positive']) 134 | else: 135 | o_prec = 0 136 | o_rec = acc_res['true_positive'] / acc_res['ngt_rel'] 137 | 138 | 139 | img_acc = acc_res['all_correct'] / len(self.results) 140 | 141 | return { 142 | 'OP': o_prec, # Precision over all relations 143 | 'OR': o_rec, # Recall over all relations 144 | 'IA': img_acc, # Complitely Solved Images 145 | } 146 | 147 | 148 | 149 | 150 | def forward(self, preds_det, preds_graph, gts_det, gts_graph): 151 | filtered_pred_graph = [filter_graph(instances=b, graph=g, thresh=self.threshold) for b, g in zip(preds_det, preds_graph)] 152 | filtered_boxes = [filter_boxes(instances=b, thresh=self.threshold) for b in preds_det] 153 | 154 | res_list = [] 155 | for p_det, p_graph, gt_det, gt_graph in zip(filtered_boxes, filtered_pred_graph, gts_det, gts_graph): 156 | res = do_rel_single_image_eval(p_det, p_graph, gt_det, gt_graph, iou_thresh=self.iou_thresh, classless=self.classless) 157 | res_list.append(res) 158 | 159 | self.results += res_list 160 | return res 161 | 162 | 163 | -------------------------------------------------------------------------------- /models/graph_transformer_dense.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import einops as ei 6 | import lovely_tensors as lt 7 | from einops.layers.torch import Rearrange 8 | 9 | lt.monkey_patch() 10 | 11 | class MultiHeadAttentionLayer(nn.Module): 12 | def __init__(self, in_dim, out_dim, num_heads, use_bias): 13 | super().__init__() 14 | 15 | self.in_dim = in_dim 16 | self.out_dim = out_dim 17 | self.num_heads = num_heads 18 | self.Q = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) 19 | self.K = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) 20 | self.V = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) 21 | self.proj_e = nn.Linear(in_dim, out_dim * num_heads, bias=use_bias) 22 | self.sqrt_dim = np.sqrt(out_dim) 23 | 24 | def _reshape_to_batches(self, x): 25 | batch_size, seq_len, in_feature = x.size() 26 | sub_dim = in_feature // self.head_num 27 | return x.reshape(batch_size, seq_len, self.head_num, sub_dim)\ 28 | .permute(0, 2, 1, 3)\ 29 | .reshape(batch_size * self.head_num, seq_len, sub_dim) 30 | 31 | 32 | def forward(self, h, e): 33 | 34 | b, q1, c = h.shape 35 | b, q2, q3, c = e.shape 36 | 
assert q1 == q2 == q3 37 | 38 | Q_h = self.Q(h) 39 | K_h = self.K(h) 40 | V_h = self.V(h) 41 | proj_e = self.proj_e(e) 42 | Q_h = ei.rearrange(Q_h, 'B L (H C) -> B H L C', H=self.num_heads, C=self.out_dim) 43 | K_h = ei.rearrange(K_h, 'B L (H C) -> B H L C', H=self.num_heads, C=self.out_dim) 44 | V_h = ei.rearrange(V_h, 'B L (H C) -> B H L C', H=self.num_heads, C=self.out_dim) 45 | proj_e = ei.rearrange(proj_e, 'B L1 L2 (H C) -> B H L1 L2 C', H=self.num_heads, C=self.out_dim) 46 | 47 | 48 | score = Q_h[:, :, :, None, :] * K_h[:, :, None, :, :] / self.sqrt_dim 49 | score = ei.repeat(Q_h, 'B H L C -> B H L Lr C', Lr=q1) * ei.repeat(K_h, 'B H L C -> B H Lr L C', Lr=q2) / self.sqrt_dim 50 | 51 | score = score * proj_e 52 | e_out = score 53 | score = ei.reduce(score, 'B H L1 L2 C -> B H L1 L2', 'sum').clamp(-5, 5) 54 | score = torch.nn.functional.softmax(score, dim=2) 55 | h_out = score @ V_h 56 | 57 | h_out = ei.rearrange(h_out, 'B H L C -> B L (H C)') 58 | e_out = ei.rearrange(e_out, 'B H L1 L2 C -> B L1 L2 (H C)') 59 | 60 | return h_out, e_out 61 | 62 | 63 | class GraphTransformerLayerDense(nn.Module): 64 | 65 | def _build_norm(self, out_dim, num_nodes): 66 | if self.layer_norm: 67 | norm_h = nn.LayerNorm(out_dim) 68 | norm_e = nn.LayerNorm(out_dim) 69 | elif self.batch_norm: 70 | norm_h = nn.Sequential( 71 | Rearrange('B L C -> B C L'), 72 | nn.BatchNorm1d(out_dim), 73 | Rearrange('B C L -> B L C') 74 | ) 75 | norm_e = nn.Sequential( 76 | Rearrange('B L1 L2 C -> B C (L1 L2)'), 77 | nn.BatchNorm1d(out_dim), 78 | Rearrange('B C (L1 L2) -> B L1 L2 C', L1=num_nodes, L2=num_nodes) 79 | ) 80 | else: 81 | norm_h = nn.Identity() 82 | norm_e = nn.Identity() 83 | return norm_h, norm_e 84 | 85 | 86 | def __init__(self, in_dim, out_dim, num_heads, num_nodes=100, dropout=0.0, layer_norm=False, batch_norm=True, residual=True, use_bias=False) -> None: 87 | super().__init__() 88 | 89 | self.in_channels = in_dim 90 | self.out_channels = out_dim 91 | self.num_heads = num_heads 92 | self.dropout = dropout 93 | self.residual = residual 94 | self.layer_norm = layer_norm 95 | self.batch_norm = batch_norm 96 | self.num_nodes = num_nodes 97 | 98 | self.attention = MultiHeadAttentionLayer(in_dim, out_dim//num_heads, num_heads, use_bias) 99 | 100 | self.O_h = nn.Linear(out_dim, out_dim) 101 | self.O_e = nn.Linear(out_dim, out_dim) 102 | 103 | self.norm1_h, self.norm1_e = self._build_norm(out_dim=out_dim, num_nodes=num_nodes) 104 | 105 | # FFN for h 106 | self.FFN_h_layer1 = nn.Linear(out_dim, out_dim*2) 107 | self.FFN_h_layer2 = nn.Linear(out_dim*2, out_dim) 108 | 109 | # FFN for e 110 | self.FFN_e_layer1 = nn.Linear(out_dim, out_dim*2) 111 | self.FFN_e_layer2 = nn.Linear(out_dim*2, out_dim) 112 | 113 | self.norm2_h, self.norm2_e = self._build_norm(out_dim=out_dim, num_nodes=num_nodes) 114 | 115 | 116 | def forward(self, h, e): 117 | h_in1 = h # for first residual connection 118 | e_in1 = e # for first residual connection 119 | 120 | # multi-head attention out 121 | h_attn_out, e_attn_out = self.attention(h, e) 122 | 123 | h = h_attn_out 124 | e = e_attn_out 125 | 126 | h = F.dropout(h, self.dropout, training=self.training) 127 | e = F.dropout(e, self.dropout, training=self.training) 128 | 129 | h = self.O_h(h) 130 | e = self.O_e(e) 131 | 132 | if self.residual: 133 | h = h_in1 + h # residual connection 134 | e = e_in1 + e # residual connection 135 | 136 | h = self.norm1_h(h) 137 | e = self.norm1_e(e) 138 | 139 | 140 | h_in2 = h # for second residual connection 141 | e_in2 = e # for second residual connection 142 | 
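        # The position-wise feed-forward pattern below is applied to both streams: per node for h (B, L, C)
        # and per ordered node pair for e (B, L1, L2, C), each followed by a second residual connection and normalization.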
143 | # FFN for h 144 | h = self.FFN_h_layer1(h) 145 | h = F.relu(h) 146 | h = F.dropout(h, self.dropout, training=self.training) 147 | h = self.FFN_h_layer2(h) 148 | 149 | # FFN for e 150 | e = self.FFN_e_layer1(e) 151 | e = F.relu(e) 152 | e = F.dropout(e, self.dropout, training=self.training) 153 | e = self.FFN_e_layer2(e) 154 | 155 | if self.residual: 156 | h = h_in2 + h # residual connection 157 | e = e_in2 + e # residual connection 158 | 159 | h = self.norm2_h(h) 160 | e = self.norm2_e(e) 161 | 162 | return h, e 163 | 164 | if __name__ == '__main__': 165 | l = GraphTransformerLayerDense(256, 256, 2) 166 | x = torch.rand(20, 100, 256) 167 | e = torch.rand(20, 100, 100, 256) 168 | x = l(x, e) 169 | x = x -------------------------------------------------------------------------------- /utils/train_utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from detectron2.solver.build import maybe_add_gradient_clipping 3 | from typing import Set, Sequence, List, Any, Dict 4 | import torch 5 | from detectron2.engine import hooks 6 | import detectron2.utils.comm as comm 7 | from detectron2.evaluation.testing import flatten_results_dict 8 | 9 | 10 | def build_optimizer(cfg, model): 11 | params: List[Dict[str, Any]] = [] 12 | memo: Set[torch.nn.parameter.Parameter] = set() 13 | gradient_accum = cfg.SOLVER.GRAD_STEP 14 | for key, value in model.named_parameters(recurse=True): 15 | if not value.requires_grad: 16 | continue 17 | # Avoid duplicating parameters 18 | if value in memo: 19 | continue 20 | memo.add(value) 21 | lr = cfg.SOLVER.BASE_LR 22 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 23 | if "backbone" in key: 24 | lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER 25 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 26 | 27 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 28 | # detectron2 doesn't have full model gradient clipping now 29 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 30 | enable = ( 31 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 32 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 33 | and clip_norm_val > 0.0 34 | ) 35 | 36 | class FullModelGradientClippingOptimizer(optim): 37 | def step(self, closure=None): 38 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 39 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 40 | super().step(closure=closure) 41 | 42 | return FullModelGradientClippingOptimizer if enable else optim 43 | 44 | def add_gradient_accumultion(optim): 45 | accum_steps = gradient_accum 46 | class GradAccumOptim(optim): 47 | def __init__(self, *args, **kwargs): 48 | super().__init__(*args, **kwargs) 49 | self._iter_counter = 0 50 | 51 | def zero_grad(self): 52 | if self._iter_counter % accum_steps == 0 and self._iter_counter != 0: 53 | super().zero_grad() 54 | self._iter_counter += 1 55 | 56 | return GradAccumOptim if accum_steps > 1 else optim 57 | 58 | 59 | optimizer_type = cfg.SOLVER.OPTIMIZER 60 | if optimizer_type == "SGD": 61 | optimizer = add_gradient_accumultion(maybe_add_full_model_gradient_clipping(torch.optim.SGD))( 62 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM 63 | ) 64 | elif optimizer_type == "ADAMW": 65 | optimizer = add_gradient_accumultion(maybe_add_full_model_gradient_clipping(torch.optim.AdamW))( 66 | params, cfg.SOLVER.BASE_LR 67 | ) 68 | elif optimizer_type == "ADAM": 69 | optimizer = add_gradient_accumultion(maybe_add_full_model_gradient_clipping(torch.optim.Adam))( 70 | params, 
cfg.SOLVER.BASE_LR 71 | ) 72 | else: 73 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 74 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": 75 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 76 | 77 | return optimizer 78 | 79 | 80 | 81 | class EvalTestHook(hooks.HookBase): 82 | """ 83 | Run an evaluation function periodically, and at the end of training. 84 | 85 | It is executed every ``eval_period`` iterations and after the last iteration. 86 | """ 87 | 88 | def __init__(self, eval_period, eval_function, test_function, eval_after_train=True): 89 | """ 90 | Args: 91 | eval_period (int): the period to run `eval_function`. Set to 0 to 92 | not evaluate periodically (but still evaluate after the last iteration 93 | if `eval_after_train` is True). 94 | eval_function (callable): a function which takes no arguments, and 95 | returns a nested dict of evaluation metrics. 96 | eval_after_train (bool): whether to evaluate after the last iteration 97 | 98 | Note: 99 | This hook must be enabled in all or none workers. 100 | If you would like only certain workers to perform evaluation, 101 | give other workers a no-op function (`eval_function=lambda: None`). 102 | """ 103 | self._period = eval_period 104 | self._eval_func = eval_function 105 | self._test_func = test_function 106 | self._eval_after_train = eval_after_train 107 | 108 | def _do_eval(self, test=False): 109 | 110 | if test: 111 | results = self._test_func() 112 | else: 113 | results = self._eval_func() 114 | 115 | if results: 116 | assert isinstance( 117 | results, dict 118 | ), "Eval function must return a dict. Got {} instead.".format(results) 119 | 120 | flattened_results = flatten_results_dict(results) 121 | for k, v in flattened_results.items(): 122 | try: 123 | v = float(v) 124 | except Exception as e: 125 | raise ValueError( 126 | "[EvalHook] eval_function should return a nested dict of float. " 127 | "Got '{}: {}' instead.".format(k, v) 128 | ) from e 129 | self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False) 130 | 131 | # Evaluation may take different time among workers. 132 | # A barrier make them start the next iteration together. 
133 | comm.synchronize() 134 | 135 | def after_step(self): 136 | next_iter = self.trainer.iter + 1 137 | if self._period > 0 and next_iter % self._period == 0: 138 | # do the last eval in after_train 139 | if next_iter != self.trainer.max_iter: 140 | self._do_eval(test=False) 141 | 142 | def after_train(self): 143 | # This condition is to prevent the eval from running after a failed training 144 | if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter: 145 | self._do_eval(test=True) 146 | # func is likely a closure that holds reference to the trainer 147 | # therefore we clean it to avoid circular reference in the end 148 | del self._eval_func 149 | del self._test_func -------------------------------------------------------------------------------- /models/rcnn_graph.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from detectron2.modeling import META_ARCH_REGISTRY 3 | from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN 4 | from detectron2.structures import Instances 5 | from detectron2.utils.events import get_event_storage 6 | from typing import Dict, List, Optional 7 | import torch.nn as nn 8 | 9 | 10 | class RelCELoss(nn.Module): 11 | def __init__(self, reweight=(0.1, 1.0, 1.0)): 12 | super().__init__() 13 | self.loss = nn.CrossEntropyLoss(weight=torch.tensor(reweight)) 14 | 15 | def forward(self, x, y): 16 | loss = self.loss(x, y) 17 | preds = (torch.argmax(x, dim=1) == y.long()) 18 | train_acc = torch.mean(preds, dtype=float) 19 | edge_recall = torch.mean(preds[y.long()>0], dtype=float) 20 | edge_recall = torch.nan_to_num(edge_recall, nan=0) 21 | return {'graph_bce_loss': loss, 'rel_acc': train_acc, 'edge_recall': edge_recall} 22 | 23 | 24 | @META_ARCH_REGISTRY.register() 25 | class GraphRCNN(GeneralizedRCNN): 26 | 27 | def __init__(self, *args, **kwargs): 28 | super().__init__(*args, **kwargs) 29 | 30 | 31 | def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): 32 | """ 33 | Args: 34 | batched_inputs: a list, batched outputs of :class:`DatasetMapper` . 35 | Each item in the list contains the inputs for one image. 36 | For now, each item in the list is a dict that contains: 37 | 38 | * image: Tensor, image in (C, H, W) format. 39 | * instances (optional): groundtruth :class:`Instances` 40 | * proposals (optional): :class:`Instances`, precomputed proposals. 41 | 42 | Other information that's included in the original dicts, such as: 43 | 44 | * "height", "width" (int): the output resolution of the model, used in inference. 45 | See :meth:`postprocess` for details. 46 | 47 | Returns: 48 | list[dict]: 49 | Each dict is the output for one input image. 50 | The dict contains one key "instances" whose value is a :class:`Instances`. 
51 |             The :class:`Instances` object has the following keys:
52 |             "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints"
53 |         """
54 |         if not self.training:
55 |             if "instances" in batched_inputs[0]:
56 |                 gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
57 |             return self.inference(batched_inputs, detected_instances=gt_instances)
58 | 
59 |         images = self.preprocess_image(batched_inputs)
60 |         if "instances" in batched_inputs[0]:
61 |             gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
62 |             gt_graphs = [x["graph_gt"] for x in batched_inputs]
63 |         else:
64 |             gt_instances, gt_graphs = None, None
65 | 
66 |         features = self.backbone(images.tensor)
67 | 
68 |         if self.proposal_generator is not None:
69 |             proposals, proposal_losses = self.proposal_generator(images, features, gt_instances)
70 |         else:
71 |             assert "proposals" in batched_inputs[0]
72 |             proposals = [x["proposals"].to(self.device) for x in batched_inputs]
73 |             proposal_losses = {}
74 | 
75 |         _, detector_losses = self.roi_heads(images, features, proposals, gt_instances, gt_graph=gt_graphs)
76 | 
77 |         losses = {}
78 |         losses.update(detector_losses)
79 |         losses.update(proposal_losses)
80 |         return losses
81 | 
82 | 
83 |     def inference(
84 |         self,
85 |         batched_inputs: List[Dict[str, torch.Tensor]],
86 |         detected_instances: Optional[List[Instances]] = None,
87 |         do_postprocess: bool = True,
88 |     ):
89 |         """
90 |         Run inference on the given inputs.
91 | 
92 |         Args:
93 |             batched_inputs (list[dict]): same as in :meth:`forward`
94 |             detected_instances (None or list[Instances]): if not None, it
95 |                 contains an `Instances` object per image. The `Instances`
96 |                 object contains "pred_boxes" and "pred_classes" which are
97 |                 known boxes in the image.
98 |                 The inference will then skip the detection of bounding boxes,
99 |                 and only predict other per-ROI outputs.
100 |             do_postprocess (bool): whether to apply post-processing on the outputs.
101 | 
102 |         Returns:
103 |             When do_postprocess=True, same as in :meth:`forward`.
104 |             Otherwise, a list[Instances] containing raw network outputs.
105 |         """
106 |         assert not self.training
107 | 
108 |         images = self.preprocess_image(batched_inputs)
109 |         features = self.backbone(images.tensor)
110 | 
111 |         proposals, _ = self.proposal_generator(images, features, None)
112 | 
113 |         if detected_instances is not None:
114 |             proposals_gt = [x.gt_boxes.to(images.device) for x in detected_instances]
115 | 
116 |             num_objs = [x.tensor.shape[0] for x in proposals_gt]
117 |             results, _ = self.roi_heads(images, features, proposals, None)
118 |             proposals_pred = [x.pred_boxes for x in results]
119 | 
120 |             graph_pred = self.roi_heads._forward_graph(features, proposals_gt, None)
121 |             graph_pred = self.roi_heads.pred_to_dense(graph_pred, num_objs)
122 | 
123 |             num_objs = [x.tensor.shape[0] for x in proposals_pred]
124 |             graph_pred_all = self.roi_heads._forward_graph(features, proposals_pred, None)
125 |             graph_pred_all = self.roi_heads.pred_to_dense(graph_pred_all, num_objs)
126 |         else:
127 |             results, _ = self.roi_heads(images, features, proposals, None)
128 |             proposals_pred = [x.pred_boxes for x in results]
129 |             num_objs = [x.tensor.shape[0] for x in proposals_pred]
130 |             graph_pred_all = self.roi_heads._forward_graph(features, proposals_pred, None)
131 |             graph_pred_all = self.roi_heads.pred_to_dense(graph_pred_all, num_objs)
132 |             graph_pred = [None] * len(graph_pred_all)
133 | 
134 |         if do_postprocess:
135 |             assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess."
136 | instances = GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) 137 | else: 138 | instances = results 139 | for i, g, ga in zip(instances, graph_pred, graph_pred_all): 140 | i['graph'] = g 141 | i['graph_all'] = ga 142 | 143 | return instances -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/modules/ms_deform_attn.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------------------------------ 2 | # Deformable DETR 3 | # Copyright (c) 2020 SenseTime. All Rights Reserved. 4 | # Licensed under the Apache License, Version 2.0 [see LICENSE for details] 5 | # ------------------------------------------------------------------------------------------------ 6 | # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 7 | # ------------------------------------------------------------------------------------------------ 8 | 9 | from __future__ import absolute_import 10 | from __future__ import print_function 11 | from __future__ import division 12 | 13 | import warnings 14 | import math 15 | 16 | import torch 17 | from torch import nn 18 | import torch.nn.functional as F 19 | from torch.nn.init import xavier_uniform_, constant_ 20 | 21 | from ..functions import MSDeformAttnFunction 22 | 23 | 24 | def _is_power_of_2(n): 25 | if (not isinstance(n, int)) or (n < 0): 26 | raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) 27 | return (n & (n-1) == 0) and n != 0 28 | 29 | 30 | class MSDeformAttn(nn.Module): 31 | def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): 32 | """ 33 | Multi-Scale Deformable Attention Module 34 | :param d_model hidden dimension 35 | :param n_levels number of feature levels 36 | :param n_heads number of attention heads 37 | :param n_points number of sampling points per attention head per feature level 38 | """ 39 | super().__init__() 40 | if d_model % n_heads != 0: 41 | raise ValueError('d_model must be divisible by n_heads, but got {} and {}'.format(d_model, n_heads)) 42 | _d_per_head = d_model // n_heads 43 | # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation 44 | if not _is_power_of_2(_d_per_head): 45 | warnings.warn("You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " 46 | "which is more efficient in our CUDA implementation.") 47 | 48 | self.im2col_step = 64 49 | 50 | self.d_model = d_model 51 | self.n_levels = n_levels 52 | self.n_heads = n_heads 53 | self.n_points = n_points 54 | 55 | self.sampling_offsets = nn.Linear(d_model, n_heads * n_levels * n_points * 2) 56 | self.attention_weights = nn.Linear(d_model, n_heads * n_levels * n_points) 57 | self.value_proj = nn.Linear(d_model, d_model) 58 | self.output_proj = nn.Linear(d_model, d_model) 59 | 60 | self._reset_parameters() 61 | 62 | def _reset_parameters(self): 63 | constant_(self.sampling_offsets.weight.data, 0.) 
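        # The no-grad initialization below gives each attention head a fixed starting
        # direction around its reference point, with the k-th sampling point of a head
        # pushed (k+1) cells further out along that direction; training then refines
        # these offsets from this spread-out starting pattern.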
64 | thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) 65 | grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) 66 | grid_init = (grid_init / grid_init.abs().max(-1, keepdim=True)[0]).view(self.n_heads, 1, 1, 2).repeat(1, self.n_levels, self.n_points, 1) 67 | for i in range(self.n_points): 68 | grid_init[:, :, i, :] *= i + 1 69 | with torch.no_grad(): 70 | self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) 71 | constant_(self.attention_weights.weight.data, 0.) 72 | constant_(self.attention_weights.bias.data, 0.) 73 | xavier_uniform_(self.value_proj.weight.data) 74 | constant_(self.value_proj.bias.data, 0.) 75 | xavier_uniform_(self.output_proj.weight.data) 76 | constant_(self.output_proj.bias.data, 0.) 77 | 78 | def forward(self, query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None): 79 | """ 80 | :param query (N, Length_{query}, C) 81 | :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area 82 | or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes 83 | :param input_flatten (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) 84 | :param input_spatial_shapes (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] 85 | :param input_level_start_index (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] 86 | :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements 87 | 88 | :return output (N, Length_{query}, C) 89 | """ 90 | N, Len_q, _ = query.shape 91 | N, Len_in, _ = input_flatten.shape 92 | assert (input_spatial_shapes[:, 0] * input_spatial_shapes[:, 1]).sum() == Len_in 93 | 94 | value = self.value_proj(input_flatten) 95 | if input_padding_mask is not None: 96 | value = value.masked_fill(input_padding_mask[..., None], float(0)) 97 | value = value.view(N, Len_in, self.n_heads, self.d_model // self.n_heads) 98 | sampling_offsets = self.sampling_offsets(query).view(N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) 99 | attention_weights = self.attention_weights(query).view(N, Len_q, self.n_heads, self.n_levels * self.n_points) 100 | attention_weights = F.softmax(attention_weights, -1).view(N, Len_q, self.n_heads, self.n_levels, self.n_points) 101 | # N, Len_q, n_heads, n_levels, n_points, 2 102 | if reference_points.shape[-1] == 2: 103 | offset_normalizer = torch.stack([input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1) 104 | sampling_locations = reference_points[:, :, None, :, None, :] \ 105 | + sampling_offsets / offset_normalizer[None, None, None, :, None, :] 106 | elif reference_points.shape[-1] == 4: 107 | sampling_locations = reference_points[:, :, None, :, None, :2] \ 108 | + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 109 | else: 110 | raise ValueError( 111 | 'Last dim of reference_points must be 2 or 4, but get {} instead.'.format(reference_points.shape[-1])) 112 | 113 | sampling_locations = sampling_locations.to(value.dtype) 114 | attention_weights = attention_weights.to(value.dtype) 115 | output = MSDeformAttnFunction.apply( 116 | value, input_spatial_shapes, input_level_start_index, sampling_locations, attention_weights, self.im2col_step) 117 | 118 | output = self.output_proj(output) 119 | return output 120 | 
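# Usage sketch (illustrative, not part of the upstream module): a minimal shape check for
# MSDeformAttn. It assumes the CUDA extension has been built (ops/setup.py or ops/make.sh)
# and a GPU is available; the import path is inferred from this repository's layout.
#
#   import torch
#   from models.deformable_detr_modules.ops.modules.ms_deform_attn import MSDeformAttn
#
#   N, d_model, n_levels, n_heads, n_points = 2, 256, 4, 8, 4
#   spatial_shapes = torch.as_tensor([[64, 64], [32, 32], [16, 16], [8, 8]], dtype=torch.long).cuda()
#   level_start_index = torch.cat((spatial_shapes.new_zeros(1),
#                                  (spatial_shapes[:, 0] * spatial_shapes[:, 1]).cumsum(0)[:-1]))
#   len_in = int((spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum())
#   len_q = 300
#
#   attn = MSDeformAttn(d_model=d_model, n_levels=n_levels, n_heads=n_heads, n_points=n_points).cuda()
#   query = torch.rand(N, len_q, d_model).cuda()
#   reference_points = torch.rand(N, len_q, n_levels, 2).cuda()   # normalized (x, y) in [0, 1]
#   value_flatten = torch.rand(N, len_in, d_model).cuda()
#
#   out = attn(query, reference_points, value_flatten, spatial_shapes, level_start_index)
#   assert out.shape == (N, len_q, d_model)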
-------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import numpy as np 4 | import torch_geometric.data as data_g 5 | import torch_geometric as tg 6 | import itertools 7 | import logging 8 | import numpy as np 9 | import operator 10 | import pickle 11 | from collections import OrderedDict, defaultdict 12 | from typing import Any, Callable, Dict, List, Optional, Union 13 | import torch 14 | import torch.utils.data as torchdata 15 | from tabulate import tabulate 16 | from termcolor import colored 17 | 18 | from detectron2.config import configurable 19 | from detectron2.structures import BoxMode 20 | from detectron2.utils.comm import get_world_size 21 | from detectron2.utils.env import seed_all_rng 22 | from detectron2.utils.file_io import PathManager 23 | from detectron2.utils.logger import _log_api_usage, log_first_n 24 | 25 | from detectron2.data.catalog import DatasetCatalog, MetadataCatalog 26 | from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset 27 | from detectron2.data.dataset_mapper import DatasetMapper 28 | from detectron2.data.detection_utils import check_metadata_consistency 29 | from detectron2.data.samplers import ( 30 | InferenceSampler, 31 | RandomSubsetTrainingSampler, 32 | RepeatFactorTrainingSampler, 33 | TrainingSampler, 34 | ) 35 | 36 | 37 | 38 | def cls_graph_to_dense(cls_pred: torch.Tensor, obj_num: torch.Tensor): 39 | res = [] 40 | running_ix = 0 41 | for num in obj_num: 42 | if num <= 1: 43 | res.append(torch.tensor([])) 44 | continue 45 | i_x, i_y = np.triu_indices(n=num, k=1) 46 | curr_pred = cls_pred[running_ix:running_ix+len(i_x)] 47 | dense = torch.zeros(num, num, dtype=curr_pred.dtype, device=curr_pred.device) 48 | dense[i_x, i_y] = curr_pred 49 | inverse_mask = dense == 2 50 | dense += inverse_mask.T 51 | dense -= inverse_mask * 2 52 | res.append(dense) 53 | running_ix += len(i_x) 54 | return res 55 | 56 | 57 | def relation_graph_to_dense(graph: data_g.Data): 58 | num_objs = graph.is_object.sum() 59 | dense = torch.zeros(num_objs, num_objs, graph.x.shape[1]) 60 | edge_index, _ = tg.utils.remove_self_loops(graph.edge_index) 61 | for i in range(num_objs): 62 | nodes, connectivity, mapping, e_mask = tg.utils.k_hop_subgraph(i, 2, edge_index, flow='source_to_target') 63 | connection_nodes = nodes[~graph.is_object[nodes]] 64 | src = connectivity[0, ::2] 65 | dest = connectivity[1, 1::2] 66 | dense[src, dest] = graph.x[connection_nodes] 67 | return dense 68 | 69 | 70 | def _test_loader_from_config(cfg, dataset_name, mapper=None): 71 | """ 72 | Uses the given `dataset_name` argument (instead of the names in cfg), because the 73 | standard practice is to evaluate each test set individually (not combining them). 
74 | """ 75 | if isinstance(dataset_name, str): 76 | dataset_name = [dataset_name] 77 | 78 | dataset = get_detection_dataset_dicts( 79 | dataset_name, 80 | filter_empty=False, 81 | proposal_files=[ 82 | cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name 83 | ] 84 | if cfg.MODEL.LOAD_PROPOSALS 85 | else None, 86 | ) 87 | if mapper is None: 88 | mapper = DatasetMapper(cfg, False) 89 | return { 90 | "dataset": dataset, 91 | "mapper": mapper, 92 | "num_workers": cfg.DATALOADER.NUM_WORKERS, 93 | "sampler": InferenceSampler(len(dataset)) 94 | if not isinstance(dataset, torchdata.IterableDataset) 95 | else None, 96 | } 97 | 98 | 99 | @configurable(from_config=_test_loader_from_config) 100 | def build_detection_test_loader( 101 | dataset: Union[List[Any], torchdata.Dataset], 102 | *, 103 | mapper: Callable[[Dict[str, Any]], Any], 104 | sampler: Optional[torchdata.Sampler] = None, 105 | batch_size: int = 1, 106 | num_workers: int = 0, 107 | collate_fn: Optional[Callable[[List[Any]], Any]] = None, 108 | **kwargs, 109 | ) -> torchdata.DataLoader: 110 | """ 111 | Similar to `build_detection_train_loader`, with default batch size = 1, 112 | and sampler = :class:`InferenceSampler`. This sampler coordinates all workers 113 | to produce the exact set of all samples. 114 | 115 | Args: 116 | dataset: a list of dataset dicts, 117 | or a pytorch dataset (either map-style or iterable). They can be obtained 118 | by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. 119 | mapper: a callable which takes a sample (dict) from dataset 120 | and returns the format to be consumed by the model. 121 | When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. 122 | sampler: a sampler that produces 123 | indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, 124 | which splits the dataset across all workers. Sampler must be None 125 | if `dataset` is iterable. 126 | batch_size: the batch size of the data loader to be created. 127 | Default to 1 image per worker since this is the standard when reporting 128 | inference time in papers. 129 | num_workers: number of parallel data loading workers 130 | collate_fn: same as the argument of `torch.utils.data.DataLoader`. 131 | Defaults to do no collation and return a list of data. 132 | 133 | Returns: 134 | DataLoader: a torch DataLoader, that loads the given detection 135 | dataset, with test-time transformation and batching. 136 | 137 | Examples: 138 | :: 139 | data_loader = build_detection_test_loader( 140 | DatasetRegistry.get("my_test"), 141 | mapper=DatasetMapper(...)) 142 | 143 | # or, instantiate with a CfgNode: 144 | data_loader = build_detection_test_loader(cfg, "my_test") 145 | """ 146 | if isinstance(dataset, list): 147 | dataset = DatasetFromList(dataset, copy=False) 148 | if mapper is not None: 149 | dataset = MapDataset(dataset, mapper) 150 | if isinstance(dataset, torchdata.IterableDataset): 151 | assert sampler is None, "sampler must be None if dataset is IterableDataset" 152 | else: 153 | if sampler is None: 154 | sampler = InferenceSampler(len(dataset)) 155 | return torchdata.DataLoader( 156 | dataset, 157 | batch_size=batch_size, 158 | sampler=sampler, 159 | drop_last=False, 160 | num_workers=num_workers, 161 | collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, 162 | **kwargs 163 | ) 164 | 165 | 166 | 167 | def trivial_batch_collator(batch): 168 | """ 169 | A batch collator that does nothing. 
170 | """ 171 | return batch -------------------------------------------------------------------------------- /data/augmentations.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import warnings 3 | import detectron2.data.transforms as T 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as VT 7 | 8 | import torchvision.transforms.functional as F 9 | from fvcore.transforms import Transform 10 | 11 | 12 | class EraseTransform(Transform): 13 | 14 | def __init__(self, x, y, h, w, v, inplace=False): 15 | """Erase the input Tensor Image with given value. 16 | This transform does not support PIL Image. 17 | 18 | Args: 19 | img (Tensor Image): Tensor image of size (C, H, W) to be erased 20 | i (int): i in (i,j) i.e coordinates of the upper left corner. 21 | j (int): j in (i,j) i.e coordinates of the upper left corner. 22 | h (int): Height of the erased region. 23 | w (int): Width of the erased region. 24 | v: Erasing value. 25 | inplace(bool, optional): For in-place operations. By default, is set False. 26 | """ 27 | super().__init__() 28 | self._set_attributes(locals()) 29 | 30 | def apply_image(self, img: np.ndarray, interp: str = None) -> np.ndarray: 31 | """ 32 | Apply blend transform on the image(s). 33 | 34 | Args: 35 | img (ndarray): of shape NxHxWxC, or HxWxC or HxW. The array can be 36 | of type uint8 in range [0, 255], or floating point in range 37 | [0, 1] or [0, 255]. 38 | interp (str): keep this option for consistency 39 | Returns: 40 | ndarray: blended image(s). 41 | """ 42 | img = F.erase(torch.from_numpy(img).permute(2, 0, 1), self.x, self.y, self.h, self.w, self.v, self.inplace) 43 | return img.permute(1, 2, 0).numpy() 44 | 45 | def apply_coords(self, coords: np.ndarray) -> np.ndarray: 46 | """ 47 | Apply no transform on the coordinates. 48 | """ 49 | return coords 50 | 51 | def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray: 52 | """ 53 | Apply no transform on the full-image segmentation. 54 | """ 55 | return segmentation 56 | 57 | def inverse(self) -> T.Transform: 58 | """ 59 | The inverse is a no-op. 
60 | """ 61 | return T.NoOpTransform() 62 | 63 | 64 | 65 | class RamndomEraseTransform(T.Augmentation): 66 | def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False): 67 | if not isinstance(value, (numbers.Number, str, tuple, list)): 68 | raise TypeError("Argument value should be either a number or str or a sequence") 69 | if isinstance(value, str) and value != "random": 70 | raise ValueError("If value is str, it should be 'random'") 71 | if not isinstance(scale, (tuple, list)): 72 | raise TypeError("Scale should be a sequence") 73 | if not isinstance(ratio, (tuple, list)): 74 | raise TypeError("Ratio should be a sequence") 75 | if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): 76 | warnings.warn("Scale and ratio should be of kind (min, max)") 77 | if scale[0] < 0 or scale[1] > 1: 78 | raise ValueError("Scale should be between 0 and 1") 79 | if p < 0 or p > 1: 80 | raise ValueError("Random erasing probability should be between 0 and 1") 81 | 82 | 83 | self.p = p 84 | self.scale = scale 85 | self.ratio = ratio 86 | self.value = value 87 | self.inplace = inplace 88 | 89 | def get_transform(self, image): 90 | x, y, h, w, v = VT.RandomErasing.get_params(image.transpose(2, 0, 1), scale=self.scale, ratio=self.ratio) 91 | return EraseTransform(x, y, h, w, v, inplace=self.inplace) 92 | 93 | 94 | def get_default_rgb_transform_train(img_size=(512, 512)): 95 | augs = T.AugmentationList([ 96 | T.Resize(img_size), 97 | T.RandomBrightness(0.9, 1.1), 98 | T.RandomFlip(prob=0.5), 99 | ]) 100 | return augs 101 | 102 | def get_enhanced_rgb_transform_train(img_size=(512, 512)): 103 | return T.AugmentationList([ 104 | T.RandomBrightness(0.7, 1.2), 105 | T.RandomFlip(prob=0.5), 106 | T.RandomSaturation(0.7, 1.2), 107 | # T.RandomRotation([0, 180]), 108 | T.MinIoURandomCrop((0.9), min_crop_size=0.7), 109 | T.Resize(img_size), 110 | ]) 111 | 112 | def get_v3_transform(img_size=(512, 512)): 113 | return T.AugmentationList([ 114 | T.RandomBrightness(0.9, 1.1), 115 | T.RandomFlip(prob=0.5), 116 | T.RandomSaturation(0.9, 1.1), 117 | T.RandomContrast(0.9, 1.1), 118 | T.Resize(img_size), 119 | ]) 120 | 121 | def get_color_transform(img_size=(512, 512)): 122 | return T.AugmentationList([ 123 | T.RandomBrightness(0.7, 1.3), 124 | T.RandomFlip(prob=0.5, horizontal=False, vertical=True), 125 | T.RandomFlip(prob=0.5, horizontal=True, vertical=False), 126 | T.RandomSaturation(0.7, 1.3), 127 | T.RandomLighting(2), 128 | T.RandomContrast(0.7, 1.3), 129 | T.Resize(img_size), 130 | ]) 131 | 132 | def get_resize_transform(img_size=(512, 512)): 133 | return T.AugmentationList([ 134 | T.RandomApply(T.RandomRotation([0, 180])), 135 | T.RandomBrightness(0.9, 1.1), 136 | T.ResizeScale(min_scale=0.7, max_scale=1.3, target_width=img_size[0],target_height=img_size[1]), 137 | ]) 138 | 139 | def get_erase_transform(img_size=(512, 512)): 140 | return T.AugmentationList([ 141 | T.RandomBrightness(0.7, 1.2), 142 | T.RandomApply(RamndomEraseTransform(), 0.5), 143 | T.Resize(img_size), 144 | ]) 145 | 146 | def get_extreme_salad_transform(img_size=(512, 512)): 147 | return T.AugmentationList([ 148 | T.RandomApply(RamndomEraseTransform(), 0.3), 149 | T.RandomApply(T.RandomRotation([-20, 20])), 150 | T.RandomBrightness(0.5, 1.5), 151 | T.RandomFlip(prob=0.5, horizontal=False, vertical=True), 152 | T.RandomFlip(prob=0.5, horizontal=True, vertical=False), 153 | T.RandomSaturation(0.5, 1.5), 154 | T.RandomLighting(1.2), 155 | T.RandomContrast(0.5, 1.5), 156 | T.MinIoURandomCrop(min_ious=(0.8, 0.9), min_crop_size=0.8), 
157 | T.Resize(img_size), 158 | ]) 159 | 160 | def get_v4_transform(img_size): 161 | 162 | return T.AugmentationList([ 163 | T.RandomSaturation(0.5, 1.5), 164 | T.RandomLighting(1.2), 165 | T.RandomContrast(0.5, 1.5), 166 | T.RandomContrast(0.5, 1.5), 167 | T.Resize(img_size), 168 | ]) 169 | 170 | def get_train_transform(name, img_size=(512,512)): 171 | return { 172 | 'default': get_default_rgb_transform_train, 173 | 'enhanced': get_enhanced_rgb_transform_train, 174 | 'v3': get_v3_transform, 175 | 'color': get_color_transform, 176 | 'resize': get_resize_transform, 177 | 'erase': get_erase_transform, 178 | 'extreme': get_extreme_salad_transform, 179 | 'v4': get_v4_transform, 180 | }[name](img_size=img_size) 181 | -------------------------------------------------------------------------------- /utils/vis_utils.py: -------------------------------------------------------------------------------- 1 | import torch_geometric 2 | from torchvision.utils import draw_bounding_boxes, draw_segmentation_masks 3 | import matplotlib.pyplot as plt 4 | import torchvision.ops as ops 5 | import torch 6 | import cv2 7 | import torch_geometric.utils as utils 8 | from utils.data_utils import cls_graph_to_dense 9 | import numpy as np 10 | 11 | 12 | 13 | def plot_graph(drawing, dep_graph, bbox, color=(255, 0, 0), thickness=4): 14 | edge_ix = dep_graph.edge_index.cpu().to(torch.int) 15 | for i, j in edge_ix.T: 16 | start = bbox[i] 17 | s_x, s_y = (start[0].item() + start[2].item()) // 2, (start[1].item() + start[3].item()) // 2 18 | end = bbox[j] 19 | e_x, e_y = (end[0].item() + end[2].item()) // 2, (end[1].item() + end[3].item()) // 2 20 | drawing = cv2.arrowedLine(drawing, (s_x, s_y), (e_x, e_y), color, thickness) 21 | return drawing 22 | 23 | 24 | from collections import defaultdict 25 | import data.metagraspnet_labels as labels 26 | l = {int(k): v for k, v in labels.IFL_SYNSET_TO_LABEL.items()} 27 | meta_labels = defaultdict(lambda: 'NA') 28 | meta_labels.update(l) 29 | 30 | def plot_sample(image, bbox, dense_graph=None, label=None): 31 | color = (255, 0, 0) 32 | thickness = 2 33 | if bbox.shape[0] == 0: 34 | drawing = image 35 | else: 36 | drawing = draw_bounding_boxes(image, bbox, labels=label,) 37 | drawing = drawing.permute(1, 2, 0).numpy() 38 | if dense_graph is not None and dense_graph.shape[0] > 1: 39 | edge_ix = torch.argwhere(dense_graph) 40 | edge_ix = edge_ix.to(torch.int).numpy() 41 | for i, j in edge_ix: 42 | start = bbox[i] 43 | s_x, s_y = (start[0].item() + start[2].item()) // 2, (start[1].item() + start[3].item()) // 2 44 | end = bbox[j] 45 | e_x, e_y = (end[0].item() + end[2].item()) // 2, (end[1].item() + end[3].item()) // 2 46 | drawing = cv2.arrowedLine(drawing, (s_x, s_y), (e_x, e_y), color, thickness) 47 | return drawing 48 | 49 | 50 | 51 | def plot_errors(ax, res: dict, total='gts', del_thresh=0.01): 52 | assert total in ('gts', 'preds') 53 | missings = 'relation_from_wrong_box' if total == 'preds' else 'non_matched_gt_relations' 54 | name = 'predictions' if total == 'preds' else 'ground truths' 55 | dict_hint = """ { 56 | 'true_positive': tp, 57 | 'false_positive': fp, 58 | 'correct_edge_found': correct_edge_found, 59 | 'correct_empty': correct_empty, 60 | 'wrong_edge_direction': wrong_edge_direction, 61 | 'wrong_presence_of_edge': wrong_presence_of_edge, 62 | 'missed_edge': missed_edge, 63 | 'relation_from_wrong_box': relation_from_wrong_box, 64 | 'non_matched_gt_relations': non_matched_gt_relations, 65 | 'non_matched_edge_relation': non_matched_edge_relation, 66 | 'ngt_rel': ngt_rel, 
67 | 'ngt_edge': ngt_edge, 68 | 'all_correct': fp == 0 and tp == ngt_rel 69 | } """ 70 | 71 | keys = ['correct_edge_found', 'correct_empty', 'wrong_edge_direction', 'missed_edge', 'wrong_presence_of_edge', missings] 72 | 73 | # green dgreen lyellow red purple grey 74 | colors = ['#00cc66', '#669900', '#ffcc66', '#ff3333', '#b30059', '#8c8c8c'] 75 | values = [res[k] for k in keys] 76 | total = sum(values) 77 | if total == 0: 78 | return ax 79 | # explode = [0.1, 0, -0.1, 0, 0, 0] 80 | explode = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 81 | try: 82 | keys, colors, values, explode = zip(*[(k, c, v, e) for k, c, v, e in zip(keys, colors, values, explode) if v / total > del_thresh]) 83 | except: 84 | return ax 85 | keys = [k.replace('_', ' ') for k in keys] 86 | ax.pie(x=values, 87 | colors=colors, 88 | labels=keys, 89 | autopct='%1.1f%%', 90 | explode=explode, 91 | textprops={'size': 'large'} 92 | ) 93 | ax.set_axis_off() 94 | ax.set_title(f'Distribution of {name}') 95 | 96 | return ax 97 | 98 | def plot_only_edges(ax, res: dict, del_thresh=0.01): 99 | keys = ['correct_edge_found', 'wrong_edge_direction', 'missed_edge', 'non_matched_edge_relations'] 100 | values = [res[k] for k in keys] 101 | total = sum(values) 102 | # green lyellow red grey 103 | colors = ['#00cc66', '#ffcc66', '#ff3333', '#8c8c8c'] 104 | try: 105 | keys, colors, values = zip(*[(k, c, v) for k, c, v in zip(keys, colors, values) if v / total > del_thresh]) 106 | except: 107 | return ax 108 | values = [res[k] for k in keys] 109 | total = sum(values) 110 | if total == 0: 111 | return ax 112 | keys = [k.replace('_', ' ') for k in keys] 113 | ax.pie(x=values, 114 | colors=colors, 115 | labels=keys, 116 | autopct='%1.1f%%', 117 | textprops={'size': 'large'} 118 | ) 119 | ax.set_axis_off() 120 | ax.set_title(f'Distribution of only edges') 121 | return ax 122 | 123 | def plot_pie_charts(relation_res): 124 | fig = plt.figure(figsize=(17, 12)) 125 | ax1 = fig.add_subplot(1, 3, 1) 126 | ax2 = fig.add_subplot(1, 3, 2) 127 | ax3 = fig.add_subplot(1, 3, 3) 128 | ax1 = plot_errors(ax1, relation_res, total='preds') 129 | ax2 = plot_errors(ax2, relation_res, total='gts') 130 | ax3 = plot_only_edges(ax3, relation_res) 131 | fig.tight_layout() 132 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 133 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 134 | return data 135 | 136 | 137 | def plot_output(input, output, relation_res=None, thresh=0.5): 138 | _detach = lambda x: x.cpu().detach() 139 | img = _detach(input['image']).to(torch.uint8) 140 | 141 | # End-to-End 142 | keep = _detach(output['instances'].scores > thresh) 143 | boxes = _detach(output['instances'].pred_boxes.tensor.long())[keep] 144 | if boxes.shape[0] > 1: 145 | graph = _detach(output['graph_all'][keep, :][:, keep].long()) 146 | else: 147 | graph = None 148 | 149 | labels = [meta_labels[i.item()] for i in _detach(output['instances'].pred_classes[keep])] 150 | img_pred = plot_sample(img, boxes, graph, labels) 151 | 152 | # GT 153 | gt_boxes = input['instances'].gt_boxes.tensor.long() 154 | gt_graph = _detach(input['dense_gt'].long()) 155 | gt_labels = [meta_labels[i.item()] for i in _detach(input['instances'].gt_classes)] 156 | img_gt = plot_sample(img, gt_boxes, gt_graph, gt_labels) 157 | 158 | # Only-Graph 159 | graph = _detach(output['graph'].long()) 160 | img_only_graph = plot_sample(img, gt_boxes, graph, gt_labels) 161 | 162 | fig = plt.figure(figsize=(17, 12)) 163 | ax1 = fig.add_subplot(2, 3, 1) 164 | ax2 = fig.add_subplot(2, 3, 2) 165 | ax3 = 
fig.add_subplot(2, 3, 3) 166 | ax4 = fig.add_subplot(2, 3, 4) 167 | ax5 = fig.add_subplot(2, 3, 5) 168 | ax6 = fig.add_subplot(2, 3, 6) 169 | ax1.imshow(img_gt) 170 | ax1.set_axis_off() 171 | ax1.title.set_text(f'GT') 172 | ax2.imshow(img_only_graph) 173 | ax2.set_axis_off() 174 | ax2.title.set_text(f'PRED Only Graph') 175 | ax3.imshow(img_pred) 176 | ax3.set_axis_off() 177 | ax3.title.set_text(f'PRED') 178 | # Plot Error distribution 179 | if relation_res is not None: 180 | ax4 = plot_errors(ax4, relation_res, total='preds') 181 | ax5 = plot_errors(ax5, relation_res, total='gts') 182 | ax6 = plot_only_edges(ax6, relation_res) 183 | fig.tight_layout() 184 | fig.text(.45, .25, f'ID: {input["image_id"]}') 185 | fig.canvas.draw() 186 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 187 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 188 | return data 189 | -------------------------------------------------------------------------------- /data/metagraspnet_real_mapper.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any 3 | import detectron2.data.detection_utils as utils 4 | import detectron2.data.transforms as T 5 | from detectron2.config import configurable 6 | import numpy as np 7 | import torch 8 | import torch 9 | import torch_geometric.utils as utils 10 | import torch_geometric.data as data_g 11 | import torch_geometric as tg 12 | from torchvision.utils import draw_bounding_boxes 13 | import matplotlib.pyplot as plt 14 | import torch 15 | import torch_geometric.utils as utils 16 | import os.path as osp 17 | import numpy as np 18 | from detectron2.data import DatasetCatalog 19 | from detectron2.data import transforms as T 20 | import detectron2.structures as structures 21 | import imageio.v3 as iio 22 | import pandas as pd 23 | from data.graph_builder import * 24 | from data.metagraspnet_synth_mapper import MetaGraspNetV2Mapper 25 | 26 | scene_data = None 27 | sample_metadata = None 28 | sample_empty_metadata = None 29 | cache_dir = osp.dirname(__file__) 30 | 31 | def get_rgb_test_transform(img_size=(512, 512)): 32 | augs = T.AugmentationList([ 33 | T.CropTransform(372, 0, 1200, 1200), 34 | T.Resize(img_size) 35 | ]) 36 | return augs 37 | 38 | def get_metagraspnet_dict_empty_bin(cache_dir=cache_dir): 39 | global sample_empty_metadata 40 | if sample_empty_metadata is None: 41 | sample_empty_metadata = pd.read_json(osp.join(cache_dir, 'sample_empty_metadata.json')) 42 | 43 | return sample_empty_metadata.to_dict('records') 44 | 45 | 46 | def get_metagraspnet_dict_real(split='test_all', cache_dir=cache_dir): 47 | global scene_data 48 | global sample_metadata 49 | 50 | assert split in ('test_easy', 'test_medium', 'test_all', 'debug') 51 | 52 | if scene_data is None: 53 | with open(osp.join(cache_dir, 'scene_real_metadata.json')) as f: 54 | scene_data = json.load(f) 55 | difficulties = { 56 | 'test_all': set([0, 1, 2, 3]), 57 | 'test_easy': set([1]), 58 | 'test_medium': set([2]), 59 | 'debug': set([0, 1, 2, 3]) 60 | }[split] 61 | 62 | scene_ix = [int(k[5:]) for k, v in scene_data['difficulty'].items() 63 | if v in difficulties] 64 | scene_ix = set(scene_ix) 65 | 66 | if sample_metadata is None: 67 | sample_metadata = pd.read_json(osp.join(cache_dir, 'sample_real_metadata.json')) 68 | 69 | if split == 'debug': 70 | return sample_metadata[sample_metadata['scene'].isin([565])].to_dict('records') 71 | 72 | 73 | valid_sample = sample_metadata['bbox'].map(lambda x:len(x)) == 
sample_metadata['graph'].map(lambda x:len(x)) 74 | sample = sample_metadata[sample_metadata['scene'].isin(scene_ix) & (sample_metadata['num_objects'] > 0) & valid_sample] 75 | 76 | return sample.to_dict('records') 77 | 78 | DatasetCatalog.register('meta_graspnet_v2_real_test', func=lambda: get_metagraspnet_dict_real('test_all')) 79 | 80 | 81 | 82 | class MetaGraspNetV2MapperReal(MetaGraspNetV2Mapper): 83 | 84 | def __init__(self, *args, **kwargs): 85 | super().__init__(*args, **kwargs) 86 | self.mask_on = False 87 | self.rgb_transform = get_rgb_test_transform(self.img_size) 88 | 89 | 90 | 91 | def __call__(self, dataset_dict) -> Any: 92 | 93 | npz = osp.join(self.data_root, dataset_dict['npz_path']) 94 | # h5 = osp.join(self.data_root, dataset_dict['h5_path']) 95 | rgb_path = osp.join(self.data_root, dataset_dict['rgb_path']) 96 | depth_path = osp.join(self.data_root, dataset_dict['depth_path']) 97 | bbox = np.array(dataset_dict['bbox'], dtype=int) 98 | bbox_categories = dataset_dict['bbox_categories'] 99 | 100 | rgb = None 101 | depth = None 102 | 103 | rgb = iio.imread(rgb_path) 104 | 105 | if self.load_depth: 106 | depth = iio.imread(depth_path) 107 | depth = depth * -1 108 | depth[depth == 0] = 255 109 | 110 | # Instance Seg 111 | if self.mask_on: 112 | with np.load(npz) as file_npz: 113 | instance_seg = file_npz['instances_objects'] 114 | else: 115 | instance_seg = None 116 | 117 | label = torch.tensor(bbox_categories).long() 118 | in_aug = T.AugInput(image=rgb, boxes=bbox, sem_seg=instance_seg) 119 | transformation = self.rgb_transform(in_aug) 120 | 121 | if depth is not None: 122 | depth = transformation.apply_segmentation(depth) 123 | num_boxes = in_aug.boxes.shape[0] 124 | obj_boxes = in_aug.boxes 125 | obj_mask = torch.Tensor(np.array([in_aug.sem_seg == i for i in range(1, num_boxes + 1)])) 126 | 127 | adj_matrix = (torch.tensor(np.atleast_2d(dataset_dict['graph'])).T * -1).long() 128 | 129 | graph_gt = self._generate_graph_gt(adj_matrix) 130 | 131 | size = in_aug.image.shape 132 | image = in_aug.image 133 | 134 | if depth is not None: 135 | image = np.concatenate([image, depth[..., None]], axis=2) 136 | 137 | instances = structures.Instances( 138 | image_size=size[:2], 139 | gt_boxes=structures.Boxes(obj_boxes), 140 | gt_classes=label) 141 | 142 | if self.mask_on: 143 | instances.set('gt_masks', structures.BitMasks(obj_mask.bool())) 144 | 145 | return { 146 | 'width': size[0], 147 | 'height': size[1], 148 | 'image': torch.from_numpy(image.transpose(2, 0, 1).copy()).float(), 149 | 'instances': instances, 150 | 'graph_gt': graph_gt, 151 | 'dense_gt': adj_matrix, 152 | 'image_id': dataset_dict['id'], 153 | } 154 | 155 | def _generate_graph_gt(self, order): 156 | n = order.shape[0] 157 | if self.graph_gt_type == 'relation_graph': 158 | if order.sum() > 0: 159 | edge_ix = utils.to_torch_coo_tensor(utils.dense_to_sparse(order)[0]).indices().float() 160 | else: 161 | edge_ix = None 162 | geometric_graph = data_g.Data(edge_index=edge_ix, num_nodes=n) 163 | graph_gt = objgraph_to_objrelgraph(obj_graph=geometric_graph) 164 | elif self.graph_gt_type == 'classification': 165 | # The Visual Manipulation RelationShip paper wants them in the form n*(n-1) where the indices are given by triu_indices 166 | indices = np.array(np.triu_indices(n, k=1)) 167 | graph_gt = (order.triu() + order.T.triu() * 2)[indices[0], indices[1]].int() 168 | elif self.graph_gt_type == 'gru_graph': 169 | indices = np.array(np.triu_indices(n, k=1)) 170 | indices = torch.from_numpy(np.concatenate([indices[[0, 1]], 
indices[[1, 0]]], axis=1)) 171 | r1 = order[indices[0], indices[1]] 172 | r2 = order[indices[1], indices[0]] 173 | r0 = ((r1+r2) == 0).long() 174 | y = torch.vstack([r0, r1, r2]).T 175 | y = torch.argmax(y, dim=1) 176 | edges = torch.cartesian_prod(torch.arange(len(y)), torch.arange(len(y))).T 177 | graph_gt = tg.data.Data(edge_index=edges, num_nodes=n, rel_gt=y) 178 | elif self.graph_gt_type == 'dense': 179 | graph_gt = order 180 | else: 181 | raise NotImplemented("Type of graph GT Not implemented") 182 | return graph_gt 183 | 184 | 185 | if __name__ == '__main__': 186 | 187 | data_dict = get_metagraspnet_dict_empty_bin() 188 | mapper = MetaGraspNetV2Mapper(data_root='./datasets', graph_gt_type='classification', is_train=True) 189 | 190 | sample = mapper(data_dict[3]) 191 | sample['graph_gt'].x = sample['graph_gt'].rel_gt[:, None] 192 | out = cls_graph_to_dense(sample['graph_gt']) 193 | plot_sample(sample['image'], sample['instances'].gt_boxes, out[..., 0]) 194 | print(out) 195 | -------------------------------------------------------------------------------- /models/deformable_detr_modules/ops/src/cuda/ms_deform_attn_cuda.cu: -------------------------------------------------------------------------------- 1 | /*! 2 | ************************************************************************************************** 3 | * Deformable DETR 4 | * Copyright (c) 2020 SenseTime. All Rights Reserved. 5 | * Licensed under the Apache License, Version 2.0 [see LICENSE for details] 6 | ************************************************************************************************** 7 | * Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 8 | ************************************************************************************************** 9 | */ 10 | 11 | #include 12 | #include "cuda/ms_deform_im2col_cuda.cuh" 13 | 14 | #include 15 | #include 16 | #include 17 | #include 18 | 19 | 20 | at::Tensor ms_deform_attn_cuda_forward( 21 | const at::Tensor &value, 22 | const at::Tensor &spatial_shapes, 23 | const at::Tensor &level_start_index, 24 | const at::Tensor &sampling_loc, 25 | const at::Tensor &attn_weight, 26 | const int im2col_step) 27 | { 28 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 29 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 30 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 31 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 32 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 33 | 34 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 35 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 36 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 37 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 38 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 39 | 40 | const int batch = value.size(0); 41 | const int spatial_size = value.size(1); 42 | const int num_heads = value.size(2); 43 | const int channels = value.size(3); 44 | 45 | const int num_levels = spatial_shapes.size(0); 46 | 47 | const int num_query = sampling_loc.size(1); 48 | const int num_point = sampling_loc.size(4); 49 | 50 | const int im2col_step_ = std::min(batch, im2col_step); 51 | 52 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) 
must divide im2col_step(%d)", batch, im2col_step_); 53 | 54 | auto output = at::zeros({batch, num_query, num_heads, channels}, value.options()); 55 | 56 | const int batch_n = im2col_step_; 57 | auto output_n = output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 58 | auto per_value_size = spatial_size * num_heads * channels; 59 | auto per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 60 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 61 | for (int n = 0; n < batch/im2col_step_; ++n) 62 | { 63 | auto columns = output_n.select(0, n); 64 | // AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_forward_cuda", ([&] { 65 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_forward_cuda", ([&] { 66 | ms_deformable_im2col_cuda(at::cuda::getCurrentCUDAStream(), 67 | value.data() + n * im2col_step_ * per_value_size, 68 | spatial_shapes.data(), 69 | level_start_index.data(), 70 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 71 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 72 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 73 | columns.data()); 74 | 75 | })); 76 | } 77 | 78 | output = output.view({batch, num_query, num_heads*channels}); 79 | 80 | return output; 81 | } 82 | 83 | 84 | std::vector ms_deform_attn_cuda_backward( 85 | const at::Tensor &value, 86 | const at::Tensor &spatial_shapes, 87 | const at::Tensor &level_start_index, 88 | const at::Tensor &sampling_loc, 89 | const at::Tensor &attn_weight, 90 | const at::Tensor &grad_output, 91 | const int im2col_step) 92 | { 93 | 94 | AT_ASSERTM(value.is_contiguous(), "value tensor has to be contiguous"); 95 | AT_ASSERTM(spatial_shapes.is_contiguous(), "spatial_shapes tensor has to be contiguous"); 96 | AT_ASSERTM(level_start_index.is_contiguous(), "level_start_index tensor has to be contiguous"); 97 | AT_ASSERTM(sampling_loc.is_contiguous(), "sampling_loc tensor has to be contiguous"); 98 | AT_ASSERTM(attn_weight.is_contiguous(), "attn_weight tensor has to be contiguous"); 99 | AT_ASSERTM(grad_output.is_contiguous(), "grad_output tensor has to be contiguous"); 100 | 101 | AT_ASSERTM(value.type().is_cuda(), "value must be a CUDA tensor"); 102 | AT_ASSERTM(spatial_shapes.type().is_cuda(), "spatial_shapes must be a CUDA tensor"); 103 | AT_ASSERTM(level_start_index.type().is_cuda(), "level_start_index must be a CUDA tensor"); 104 | AT_ASSERTM(sampling_loc.type().is_cuda(), "sampling_loc must be a CUDA tensor"); 105 | AT_ASSERTM(attn_weight.type().is_cuda(), "attn_weight must be a CUDA tensor"); 106 | AT_ASSERTM(grad_output.type().is_cuda(), "grad_output must be a CUDA tensor"); 107 | 108 | const int batch = value.size(0); 109 | const int spatial_size = value.size(1); 110 | const int num_heads = value.size(2); 111 | const int channels = value.size(3); 112 | 113 | const int num_levels = spatial_shapes.size(0); 114 | 115 | const int num_query = sampling_loc.size(1); 116 | const int num_point = sampling_loc.size(4); 117 | 118 | const int im2col_step_ = std::min(batch, im2col_step); 119 | 120 | AT_ASSERTM(batch % im2col_step_ == 0, "batch(%d) must divide im2col_step(%d)", batch, im2col_step_); 121 | 122 | auto grad_value = at::zeros_like(value); 123 | auto grad_sampling_loc = at::zeros_like(sampling_loc); 124 | auto grad_attn_weight = at::zeros_like(attn_weight); 125 | 126 | const int batch_n = im2col_step_; 127 | auto per_value_size = spatial_size * num_heads * channels; 128 | auto 
per_sample_loc_size = num_query * num_heads * num_levels * num_point * 2; 129 | auto per_attn_weight_size = num_query * num_heads * num_levels * num_point; 130 | auto grad_output_n = grad_output.view({batch/im2col_step_, batch_n, num_query, num_heads, channels}); 131 | 132 | for (int n = 0; n < batch/im2col_step_; ++n) 133 | { 134 | auto grad_output_g = grad_output_n.select(0, n); 135 | // AT_DISPATCH_FLOATING_TYPES(value.type(), "ms_deform_attn_backward_cuda", ([&] { 136 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(value.type(), "ms_deform_attn_backward_cuda", ([&] { 137 | ms_deformable_col2im_cuda(at::cuda::getCurrentCUDAStream(), 138 | grad_output_g.data(), 139 | value.data() + n * im2col_step_ * per_value_size, 140 | spatial_shapes.data(), 141 | level_start_index.data(), 142 | sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 143 | attn_weight.data() + n * im2col_step_ * per_attn_weight_size, 144 | batch_n, spatial_size, num_heads, channels, num_levels, num_query, num_point, 145 | grad_value.data() + n * im2col_step_ * per_value_size, 146 | grad_sampling_loc.data() + n * im2col_step_ * per_sample_loc_size, 147 | grad_attn_weight.data() + n * im2col_step_ * per_attn_weight_size); 148 | 149 | })); 150 | } 151 | 152 | return { 153 | grad_value, grad_sampling_loc, grad_attn_weight 154 | }; 155 | } -------------------------------------------------------------------------------- /data/metagraspnet_synth_mapper.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Any 3 | import detectron2.data.detection_utils as utils 4 | import detectron2.data.transforms as T 5 | from detectron2.config import configurable 6 | import numpy as np 7 | import torch 8 | import torch 9 | import torch_geometric.utils as utils 10 | import torch_geometric.data as data_g 11 | import torch_geometric as tg 12 | from torchvision.utils import draw_bounding_boxes 13 | import matplotlib.pyplot as plt 14 | import torch 15 | import cv2 16 | import torch_geometric.utils as utils 17 | 18 | import os.path as osp 19 | import numpy as np 20 | from detectron2.data import DatasetCatalog 21 | from detectron2.data import transforms as T 22 | import detectron2.structures as structures 23 | import imageio.v3 as iio 24 | import pandas as pd 25 | import tqdm 26 | 27 | from data.augmentations import get_train_transform 28 | from data.graph_builder import objgraph_to_objrelgraph 29 | # from data.graph_builder import * 30 | 31 | 32 | 33 | scene_data = None 34 | splits = None 35 | sample_metadata = None 36 | cache_dir = osp.dirname(__file__) 37 | 38 | def get_test_rgb_transform_test(img_size=(512, 512)): 39 | augs = T.AugmentationList([ 40 | T.Resize(img_size) 41 | ]) 42 | return augs 43 | 44 | 45 | def get_metagraspnet_dict_synth(split='all', cache_dir=cache_dir): 46 | global scene_data 47 | global splits 48 | global sample_metadata 49 | if scene_data is None: 50 | with open(osp.join(cache_dir, 'scene_synt_metadata.json')) as f: 51 | scene_data = json.load(f) 52 | if splits is None: 53 | with open(osp.join(cache_dir, 'splits.json')) as f: 54 | splits = json.load(f) 55 | if split == 'all': 56 | scene_ix = splits['train'] + splits['test'] + splits['val'] 57 | if split.startswith('test'): 58 | difficulty = split.split('_')[1] 59 | scene_ix = splits['test'] 60 | difficulty_ix = {'easy': 1, 'medium': 2, 'hard': 3}[difficulty] 61 | scene_difficulty = {int(k[5:]): v for k, v in scene_data['difficulty'].items()} 62 | scene_ix = set([ix for ix in scene_ix if 
scene_difficulty[ix] == difficulty_ix])
63 |     elif split != 'all':
64 |         scene_ix = splits[split]
65 |     scene_ix = set(scene_ix)
66 |     if sample_metadata is None:
67 |         sample_metadata = pd.read_json(osp.join(cache_dir, 'sample_metadata.json'))
68 |     sample = sample_metadata[sample_metadata['scene'].isin(scene_ix)]
69 |     sample = sample[sample['num_objects'] > 1]
70 | 
71 | 
72 | 
73 | 
74 |     return sample.to_dict('records')
75 | 
76 | 
77 | DatasetCatalog.register('meta_graspnet_v2_synth_train', func=lambda: get_metagraspnet_dict_synth('train'))
78 | DatasetCatalog.register('meta_graspnet_v2_synth_test_easy', func=lambda: get_metagraspnet_dict_synth('test_easy'))
79 | DatasetCatalog.register('meta_graspnet_v2_synth_test_medium', func=lambda: get_metagraspnet_dict_synth('test_medium'))
80 | DatasetCatalog.register('meta_graspnet_v2_synth_test_hard', func=lambda: get_metagraspnet_dict_synth('test_hard'))
81 | DatasetCatalog.register('meta_graspnet_v2_synth_eval', func=lambda: get_metagraspnet_dict_synth('val'))
82 | DatasetCatalog.register('meta_graspnet_v2_synth_all', func=lambda: get_metagraspnet_dict_synth('all'))
83 | 
84 | 
85 | 
86 | 
87 | class MetaGraspNetV2Mapper:
88 | 
89 |     @configurable
90 |     def __init__(
91 |         self,
92 |         data_root: str,
93 |         is_train: bool,
94 |         graph_gt_type: str = 'relation_graph',
95 |         depth: bool = False,
96 |         mask_on: bool = False,
97 |         aug='default',
98 |         img_size = (512, 512)
99 |     ) -> None:
100 | 
101 |         assert graph_gt_type in ('relation_graph', 'classification', 'dense', 'gru_graph')
102 | 
103 |         self.data_root = data_root
104 |         self.is_train = is_train
105 |         self.graph_gt_type = graph_gt_type
106 |         self.img_size = img_size
107 |         self.load_depth = depth
108 |         if self.is_train:
109 |             self.rgb_transform = get_train_transform(aug, self.img_size)
110 |         else:
111 |             self.rgb_transform = get_test_rgb_transform_test(self.img_size)
112 |         self.mask_on = mask_on
113 | 
114 |     @classmethod
115 |     def from_config(cls, cfg, is_train: bool=True, ):
116 |         return {
117 |             'data_root': cfg.DATASETS.ROOT,
118 |             'is_train': is_train,
119 |             'graph_gt_type': cfg.INPUT.GRAPH_GT_TYPE,
120 |             'depth': cfg.INPUT.DEPTH,
121 |             'aug': cfg.INPUT.AUGMENT,
122 |             'mask_on': cfg.MODEL.MASK_ON
123 |         }
124 | 
125 |     def __call__(self, dataset_dict) -> Any:
126 | 
127 |         npz = osp.join(self.data_root, dataset_dict['npz_path'])
128 |         # h5 = osp.join(self.data_root, dataset_dict['h5_path'])
129 |         rgb_path = osp.join(self.data_root, dataset_dict['rgb_path'])
130 |         depth_path = osp.join(self.data_root, dataset_dict['depth_path'])
131 |         bbox = np.array(dataset_dict['bbox'], dtype=int)
132 |         bbox_categories = dataset_dict['bbox_categories']
133 | 
134 |         rgb = None
135 |         depth = None
136 | 
137 |         rgb = iio.imread(rgb_path)
138 | 
139 |         if self.load_depth:
140 |             depth = iio.imread(depth_path)
141 | 
142 |         # Instance Seg
143 |         if self.mask_on:
144 |             with np.load(npz) as file_npz:
145 |                 instance_seg = file_npz['instances_objects']
146 |         else:
147 |             instance_seg = None
148 | 
149 |         if bbox.size != 0:
150 |             bbox = structures.BoxMode.convert(bbox, from_mode=structures.BoxMode.XYWH_ABS, to_mode=structures.BoxMode.XYXY_ABS)
151 |         label = torch.tensor(bbox_categories).long()
152 |         in_aug = T.AugInput(image=rgb, boxes=bbox, sem_seg=instance_seg)
153 |         transformation = self.rgb_transform(in_aug)
154 | 
155 |         if depth is not None:
156 |             depth = transformation.apply_segmentation(depth)
157 |         num_boxes = in_aug.boxes.shape[0]
158 |         obj_boxes = in_aug.boxes
159 |         obj_mask = torch.Tensor(np.array([in_aug.sem_seg 
== i for i in range(1, num_boxes + 1)])) 160 | 161 | adj_matrix = (torch.tensor(np.atleast_2d(dataset_dict['graph'])).T * -1).long() 162 | 163 | graph_gt = self._generate_graph_gt(adj_matrix) 164 | 165 | size = in_aug.image.shape 166 | image = in_aug.image 167 | 168 | if depth is not None: 169 | image = np.concatenate([image, depth[..., None]], axis=2) 170 | 171 | instances = structures.Instances( 172 | image_size=size[:2], 173 | gt_boxes=structures.Boxes(obj_boxes), 174 | gt_classes=label) 175 | 176 | if self.mask_on: 177 | if obj_mask.shape[0] == 0: 178 | obj_mask = obj_mask.reshape(0, 0, 0) 179 | instances.set('gt_masks', structures.BitMasks(obj_mask.bool())) 180 | 181 | return { 182 | 'width': size[0], 183 | 'height': size[1], 184 | 'image': torch.from_numpy(image.transpose(2, 0, 1).copy()).float(), 185 | 'instances': instances, 186 | 'graph_gt': graph_gt, 187 | 'dense_gt': adj_matrix, 188 | 'image_id': dataset_dict['id'], 189 | } 190 | 191 | def _generate_graph_gt(self, order): 192 | n = order.shape[0] 193 | if self.graph_gt_type == 'relation_graph': 194 | if order.sum() > 0: 195 | edge_ix = utils.to_torch_coo_tensor(utils.dense_to_sparse(order)[0]).indices().float() 196 | else: 197 | edge_ix = None 198 | geometric_graph = data_g.Data(edge_index=edge_ix, num_nodes=n) 199 | graph_gt = objgraph_to_objrelgraph(obj_graph=geometric_graph) 200 | elif self.graph_gt_type == 'classification': 201 | # The Visual Manipulation RelationShip paper wants them in the form n*(n-1) where the indices are given by triu_indices 202 | indices = np.array(np.triu_indices(n, k=1)) 203 | graph_gt = (order.triu() + order.T.triu() * 2)[indices[0], indices[1]].int() 204 | elif self.graph_gt_type == 'gru_graph': 205 | indices = np.array(np.triu_indices(n, k=1)) 206 | indices = torch.from_numpy(np.concatenate([indices[[0, 1]], indices[[1, 0]]], axis=1)) 207 | r1 = order[indices[0], indices[1]] 208 | r2 = order[indices[1], indices[0]] 209 | r0 = ((r1+r2) == 0).long() 210 | y = torch.vstack([r0, r1, r2]).T 211 | y = torch.argmax(y, dim=1) 212 | edges = torch.cartesian_prod(torch.arange(len(y)), torch.arange(len(y))).T 213 | graph_gt = tg.data.Data(edge_index=edges, num_nodes=n, rel_gt=y) 214 | elif self.graph_gt_type == 'dense': 215 | graph_gt = order 216 | else: 217 | raise NotImplemented("Type of graph GT Not implemented") 218 | return graph_gt 219 | 220 | 221 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import pathlib 3 | 4 | from tqdm import tqdm 5 | from utils.configs import add_dep_graph_config, add_detr_config 6 | 7 | 8 | from detectron2.engine import DefaultTrainer 9 | from detectron2.data import build_detection_train_loader, get_detection_dataset_dicts 10 | from detectron2.config import get_cfg 11 | from detectron2.engine import default_argument_parser, launch 12 | from utils.data_utils import build_detection_test_loader 13 | import detectron2.utils.comm as comm 14 | from detectron2.utils.file_io import PathManager 15 | import detectron2.evaluation 16 | import data.eval_metagraspnet 17 | import detectron2.utils 18 | import models 19 | import logging 20 | from detectron2.checkpoint import DetectionCheckpointer 21 | import os 22 | 23 | from detectron2.evaluation import ( 24 | DatasetEvaluator, 25 | inference_on_dataset, 26 | print_csv_format, 27 | ) 28 | from utils.train_utils import build_optimizer, EvalTestHook 29 | from detectron2.engine 
import hooks 30 | from fvcore.nn.precise_bn import get_bn_modules 31 | import data 32 | import os.path as osp 33 | import json 34 | import datetime 35 | 36 | logger = logging.getLogger("detectron2") 37 | 38 | class Trainer(DefaultTrainer): 39 | @classmethod 40 | def build_train_loader(cls, cfg): 41 | mapper = data.get_mapper(cfg.DATASETS.TRAIN[0])(cfg, is_train=True) 42 | return build_detection_train_loader(cfg, mapper=mapper, pin_memory=True) 43 | 44 | @classmethod 45 | def build_test_loader(cls, cfg, dataset_name): 46 | """ 47 | Returns: 48 | iterable 49 | """ 50 | mapper = data.get_mapper(dataset_name)(cfg, is_train=False) 51 | dataset = get_detection_dataset_dicts(names=dataset_name) 52 | return build_detection_test_loader(dataset=dataset, mapper=mapper, num_workers=cfg.DATALOADER.NUM_WORKERS, ) 53 | 54 | @classmethod 55 | def build_evaluator(cls, cfg, dataset_name, fast=True, save_all=False): 56 | det_only = cfg.MODEL.META_ARCHITECTURE in ('GeneralizedRCNN', 'Detr', 'DeformableDetr') 57 | 58 | return data.eval_metagraspnet.GraphEvaluator(dataset_name, cfg.OUTPUT_DIR, thresh=cfg.TEST.GRAPH_THRESH, det_only=det_only) 59 | 60 | @classmethod 61 | def test(cls, cfg, model, datasets, evaluators=None, fast=True, save_all=False): 62 | """ 63 | Evaluate the given model. The given model is expected to already contain 64 | weights to evaluate. 65 | 66 | Args: 67 | cfg (CfgNode): 68 | model (nn.Module): 69 | evaluators (list[DatasetEvaluator] or None): if None, will call 70 | :meth:`build_evaluator`. Otherwise, must have the same length as 71 | ``cfg.DATASETS.TEST``. 72 | 73 | Returns: 74 | dict: a dict of result metrics 75 | """ 76 | logger = logging.getLogger(__name__) 77 | logger.addHandler(logging.StreamHandler) 78 | 79 | if isinstance(evaluators, DatasetEvaluator): 80 | evaluators = [evaluators] 81 | if evaluators is not None: 82 | assert len(datasets) == len(evaluators), "{} != {}".format( 83 | len(datasets), len(evaluators) 84 | ) 85 | 86 | results = OrderedDict() 87 | for idx, dataset_name in tqdm(enumerate(datasets)): 88 | data_loader = cls.build_test_loader(cfg, dataset_name) 89 | # When evaluators are passed in as arguments, 90 | # implicitly assume that evaluators can be created before data_loader. 91 | if evaluators is not None: 92 | evaluator = evaluators[idx] 93 | else: 94 | try: 95 | evaluator = cls.build_evaluator(cfg, dataset_name, fast=fast, save_all=save_all) 96 | except NotImplementedError: 97 | logger.warn( 98 | "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, " 99 | "or implement its `build_evaluator` method." 100 | ) 101 | results[dataset_name] = {} 102 | continue 103 | results_i = inference_on_dataset(model, data_loader, evaluator) 104 | results[dataset_name] = results_i 105 | if comm.is_main_process(): 106 | assert isinstance( 107 | results_i, dict 108 | ), "Evaluator must return a dict on the main process. 
Got {} instead.".format( 109 | results_i 110 | ) 111 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 112 | print_csv_format(results_i) 113 | 114 | if len(results) == 1: 115 | results = list(results.values())[0] 116 | return results 117 | 118 | @classmethod 119 | def build_optimizer(cls, cfg, model): 120 | total_params = sum(p.numel() for p in model.parameters()) 121 | print(f'TOTAL PARAMETERS {total_params}') 122 | try: 123 | graph_head = model.detr.graph_embed 124 | graph_params = sum(p.numel() for p in graph_head.graph_transformer_layers.parameters()) 125 | print(f'GRAPH PARAMETERS {graph_params}') 126 | except AttributeError: 127 | pass 128 | 129 | return build_optimizer(cfg, model) 130 | 131 | def build_hooks(self): 132 | """ 133 | Build a list of default hooks, including timing, evaluation, 134 | checkpointing, lr scheduling, precise BN, writing events. 135 | 136 | Returns: 137 | list[HookBase]: 138 | """ 139 | cfg = self.cfg.clone() 140 | cfg.defrost() 141 | cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN 142 | 143 | ret = [ 144 | hooks.IterationTimer(), 145 | hooks.LRScheduler(), 146 | hooks.PreciseBN( 147 | # Run at the same freq as (but before) evaluation. 148 | cfg.TEST.EVAL_PERIOD, 149 | self.model, 150 | # Build a new data loader to not affect training 151 | self.build_train_loader(cfg), 152 | cfg.TEST.PRECISE_BN.NUM_ITER, 153 | ) 154 | if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) 155 | else None, 156 | ] 157 | 158 | # Do PreciseBN before checkpointer, because it updates the model and need to 159 | # be saved by checkpointer. 160 | # This is not always the best: if checkpointing has a different frequency, 161 | # some checkpoints may have more precise statistics than others. 162 | if comm.is_main_process(): 163 | ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) 164 | 165 | def test_and_save_results(): 166 | self._last_eval_results = self.test(self.cfg, self.model, datasets=cfg.DATASETS.TEST) 167 | return self._last_eval_results 168 | 169 | def eval_and_save_results(): 170 | self._last_eval_results = self.test(self.cfg, self.model, datasets=cfg.DATASETS.EVAL) 171 | return self._last_eval_results 172 | 173 | 174 | # Do evaluation after checkpointer, because then if it fails, 175 | # we can use the saved checkpoint to debug. 176 | ret.append(EvalTestHook(cfg.TEST.EVAL_PERIOD, test_function=test_and_save_results, eval_function=eval_and_save_results)) 177 | 178 | if comm.is_main_process(): 179 | # Here the default print/log frequency of each writer is used. 
180 | # run writers in the end, so that evaluation metrics are written 181 | ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) 182 | 183 | return ret 184 | 185 | def build_writers(self): 186 | PathManager.mkdirs(self.cfg.OUTPUT_DIR) 187 | return [ 188 | detectron2.utils.events.CommonMetricPrinter(self.max_iter), 189 | detectron2.utils.events.JSONWriter(os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")), 190 | detectron2.utils.events.TensorboardXWriter(self.cfg.OUTPUT_DIR), 191 | ] 192 | 193 | 194 | def setup(args): 195 | cfg = get_cfg() 196 | cfg = add_dep_graph_config(cfg) 197 | if 'detr' in args.config_file: 198 | cfg = add_detr_config(cfg) 199 | cfg.merge_from_file(args.config_file) 200 | 201 | if args.data_path is not None: 202 | cfg.DATASETS.ROOT = args.data_path 203 | 204 | opts = [i.split('=') for i in args.opts] 205 | opts = [x for xs in opts for x in xs] 206 | run_id = f"{pathlib.Path(args.config_file).name[:-1]}_{datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}" 207 | cfg.merge_from_list(opts) 208 | if cfg.OUTPUT_DIR == '': 209 | cfg.OUTPUT_DIR = f"./output/{''.join(cfg.DATASETS.TRAIN)}_out/{run_id}" 210 | # remove .yaml 211 | cfg.NAME = args.name 212 | cfg.freeze() 213 | return cfg 214 | 215 | 216 | def main(args): 217 | cfg = setup(args) 218 | 219 | if args.eval_only: 220 | model = Trainer.build_model(cfg) 221 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 222 | cfg.MODEL.WEIGHTS, resume=True 223 | ) 224 | res = Trainer.test(cfg, model, datasets=cfg.DATASETS.TEST, fast=False, save_all=False) 225 | print(res) 226 | with open(osp.join(cfg.OUTPUT_DIR, 'test_results.json'), mode='w') as f: 227 | json.dump(res, f) 228 | return res 229 | 230 | trainer = Trainer(cfg) 231 | trainer.resume_or_load(resume=args.resume) 232 | trainer.train() 233 | 234 | 235 | 236 | 237 | if __name__ == '__main__': 238 | parser = default_argument_parser() 239 | parser.add_argument('--data-path', type=str, default=None, help='root path for the dataset folder, if specified overwrites the one defined from the config file') 240 | parser.add_argument('--name', type=str, default=None, help='custom experiment name if needed') 241 | args = parser.parse_args() 242 | print("Command Line Args:", args) 243 | launch( 244 | main, 245 | args.num_gpus, 246 | num_machines=args.num_machines, 247 | machine_rank=args.machine_rank, 248 | dist_url=args.dist_url, 249 | args=(args,), 250 | ) 251 | 252 | -------------------------------------------------------------------------------- /metrics/ap_eval_rel.py: -------------------------------------------------------------------------------- 1 | # Copied from https://github.com/SHTUPLUS/PySGG/blob/main/pysgg/data/datasets/evaluation/oi/ap_eval_rel.py 2 | # Slightly changed the function from https://github.com/SHTUPLUS/PySGG/blob/main/pysgg/structures/boxlist_ops.py 3 | 4 | # Adapted from Detectron.pytorch/lib/datasets/voc_eval.py for 5 | # this project by Ji Zhang, 2019 6 | # ----------------------------------------------------------------------------- 7 | # Copyright (c) 2017-present, Facebook, Inc. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 
11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | ############################################################################## 21 | # 22 | # Based on: 23 | # -------------------------------------------------------- 24 | # Fast/er R-CNN 25 | # Licensed under The MIT License [see LICENSE for details] 26 | # Written by Bharath Hariharan 27 | # -------------------------------------------------------- 28 | 29 | """relationship AP evaluation code.""" 30 | 31 | import logging 32 | 33 | import numpy as np 34 | import torch 35 | from tqdm import tqdm 36 | 37 | logger = logging.getLogger(__name__) 38 | 39 | 40 | # https://github.com/SHTUPLUS/PySGG/blob/main/pysgg/structures/boxlist_ops.py#L54 41 | def bbox_iou(box1, box2): 42 | """Compute the intersection over union of two set of boxes. 43 | The box order must be (xmin, ymin, xmax, ymax). 44 | Args: 45 | box1: (tensor) bounding boxes, sized [N,4]. 46 | box2: (tensor) bounding boxes, sized [M,4]. 47 | Return: 48 | (tensor) iou, sized [N,M]. 49 | Reference: 50 | https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py 51 | """ 52 | box1 = torch.from_numpy(box1) 53 | box2 = torch.from_numpy(box2) 54 | 55 | lt = torch.max(box1[:, None, :2], box2[:, :2]) # [N,M,2] 56 | rb = torch.min(box1[:, None, 2:], box2[:, 2:]) # [N,M,2] 57 | 58 | TO_REMOVE = 1 59 | wh = (rb - lt + TO_REMOVE).clamp(min=0) # [N,M,2] 60 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 61 | 62 | area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1]) # [N,] 63 | area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1]) # [M,] 64 | iou = inter / (area1[:, None] + area2 - inter) 65 | return iou 66 | 67 | 68 | def prepare_mAP_dets(topk_dets, cls_num): 69 | cls_image_ids = [[] for _ in range(cls_num)] 70 | cls_dets = [ 71 | { 72 | "confidence": np.empty(0), 73 | "BB_s": np.empty((0, 4)), 74 | "BB_o": np.empty((0, 4)), 75 | "BB_r": np.empty((0, 4)), 76 | "LBL_s": np.empty(0), 77 | "LBL_o": np.empty(0), 78 | } 79 | for _ in range(cls_num) 80 | ] 81 | cls_gts = [{} for _ in range(cls_num)] 82 | npos = [0 for _ in range(cls_num)] 83 | for dets in tqdm(topk_dets): 84 | image_id = dets["image"] 85 | sbj_boxes = dets["det_boxes_s_top"] 86 | obj_boxes = dets["det_boxes_o_top"] 87 | rel_boxes = boxes_union(sbj_boxes, obj_boxes) 88 | sbj_labels = dets["det_labels_s_top"] 89 | obj_labels = dets["det_labels_o_top"] 90 | prd_labels = dets["det_labels_p_top"] 91 | det_scores = dets["det_scores_top"] 92 | gt_boxes_sbj = dets["gt_boxes_sbj"] 93 | gt_boxes_obj = dets["gt_boxes_obj"] 94 | gt_boxes_rel = boxes_union(gt_boxes_sbj, gt_boxes_obj) 95 | gt_labels_sbj = dets["gt_labels_sbj"] 96 | gt_labels_prd = dets["gt_labels_prd"] 97 | gt_labels_obj = dets["gt_labels_obj"] 98 | for c in range(cls_num): 99 | cls_inds = np.where(prd_labels == c)[0] 100 | # logger.info(cls_inds) 101 | if len(cls_inds): 102 | cls_sbj_boxes = sbj_boxes[cls_inds] 103 | cls_obj_boxes = obj_boxes[cls_inds] 104 | cls_rel_boxes = rel_boxes[cls_inds] 105 | cls_sbj_labels = sbj_labels[cls_inds] 106 | cls_obj_labels = obj_labels[cls_inds] 107 | cls_det_scores = det_scores[cls_inds] 108 | 
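                # Accumulate this image's class-c detections into the global per-class pools (confidences, subject/object/union boxes, labels) that ap_eval consumes.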
cls_dets[c]["confidence"] = np.concatenate( 109 | (cls_dets[c]["confidence"], cls_det_scores) 110 | ) 111 | cls_dets[c]["BB_s"] = np.concatenate( 112 | (cls_dets[c]["BB_s"], cls_sbj_boxes), 0 113 | ) 114 | cls_dets[c]["BB_o"] = np.concatenate( 115 | (cls_dets[c]["BB_o"], cls_obj_boxes), 0 116 | ) 117 | cls_dets[c]["BB_r"] = np.concatenate( 118 | (cls_dets[c]["BB_r"], cls_rel_boxes), 0 119 | ) 120 | cls_dets[c]["LBL_s"] = np.concatenate( 121 | (cls_dets[c]["LBL_s"], cls_sbj_labels) 122 | ) 123 | cls_dets[c]["LBL_o"] = np.concatenate( 124 | (cls_dets[c]["LBL_o"], cls_obj_labels) 125 | ) 126 | cls_image_ids[c] += [image_id] * len(cls_inds) 127 | cls_gt_inds = np.where(gt_labels_prd == c)[0] 128 | cls_gt_boxes_sbj = gt_boxes_sbj[cls_gt_inds] 129 | cls_gt_boxes_obj = gt_boxes_obj[cls_gt_inds] 130 | cls_gt_boxes_rel = gt_boxes_rel[cls_gt_inds] 131 | cls_gt_labels_sbj = gt_labels_sbj[cls_gt_inds] 132 | cls_gt_labels_obj = gt_labels_obj[cls_gt_inds] 133 | cls_gt_num = len(cls_gt_inds) 134 | det = [False] * cls_gt_num 135 | npos[c] = npos[c] + cls_gt_num 136 | cls_gts[c][image_id] = { 137 | "gt_boxes_sbj": cls_gt_boxes_sbj, 138 | "gt_boxes_obj": cls_gt_boxes_obj, 139 | "gt_boxes_rel": cls_gt_boxes_rel, 140 | "gt_labels_sbj": cls_gt_labels_sbj, 141 | "gt_labels_obj": cls_gt_labels_obj, 142 | "gt_num": cls_gt_num, 143 | "det": det, 144 | } 145 | return cls_image_ids, cls_dets, cls_gts, npos 146 | 147 | 148 | def get_ap(rec, prec): 149 | """Compute AP given precision and recall.""" 150 | # correct AP calculation 151 | # first append sentinel values at the end 152 | mrec = np.concatenate(([0.0], rec, [1.0])) 153 | mpre = np.concatenate(([0.0], prec, [0.0])) 154 | 155 | # compute the precision envelope 156 | for i in range(mpre.size - 1, 0, -1): 157 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 158 | 159 | # to calculate area under PR curve, look for points 160 | # where X axis (recall) changes value 161 | i = np.where(mrec[1:] != mrec[:-1])[0] 162 | 163 | # and sum (\Delta recall) * prec 164 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 165 | return ap 166 | 167 | 168 | def ap_eval(image_ids, dets, gts, npos, rel_or_phr=True, ovthresh=0.5): 169 | """ 170 | Top level function that does the relationship AP evaluation. 171 | 172 | detpath: Path to detections 173 | detpath.format(classname) should produce the detection results file. 
174 | classname: Category name (duh) 175 | [ovthresh]: Overlap threshold (default = 0.5) 176 | """ 177 | 178 | confidence = dets["confidence"] 179 | BB_s = dets["BB_s"] 180 | BB_o = dets["BB_o"] 181 | BB_r = dets["BB_r"] 182 | LBL_s = dets["LBL_s"] 183 | LBL_o = dets["LBL_o"] 184 | 185 | # sort by confidence 186 | sorted_ind = np.argsort(-confidence) 187 | BB_s = BB_s[sorted_ind, :] 188 | BB_o = BB_o[sorted_ind, :] 189 | BB_r = BB_r[sorted_ind, :] 190 | LBL_s = LBL_s[sorted_ind] 191 | LBL_o = LBL_o[sorted_ind] 192 | image_ids = [image_ids[x] for x in sorted_ind] 193 | 194 | # go down dets and mark TPs and FPs 195 | nd = len(image_ids) 196 | tp = np.zeros(nd) 197 | fp = np.zeros(nd) 198 | gts_visited = {k: [False] * v["gt_num"] for k, v in gts.items()} 199 | for d in range(nd): 200 | R = gts[image_ids[d]] 201 | visited = gts_visited[image_ids[d]] 202 | bb_s = BB_s[d, :].astype(float) 203 | bb_o = BB_o[d, :].astype(float) 204 | bb_r = BB_r[d, :].astype(float) 205 | lbl_s = LBL_s[d] 206 | lbl_o = LBL_o[d] 207 | ovmax = -np.inf 208 | BBGT_s = R["gt_boxes_sbj"].astype(float) 209 | BBGT_o = R["gt_boxes_obj"].astype(float) 210 | BBGT_r = R["gt_boxes_rel"].astype(float) 211 | LBLGT_s = R["gt_labels_sbj"] 212 | LBLGT_o = R["gt_labels_obj"] 213 | if BBGT_s.size > 0: 214 | valid_mask = np.logical_and(LBLGT_s == lbl_s, LBLGT_o == lbl_o) 215 | if valid_mask.any(): 216 | if rel_or_phr: # mAP(rel) 217 | overlaps_s = bbox_iou( 218 | bb_s[None, :].astype(dtype=np.float32, copy=False), 219 | BBGT_s.astype(dtype=np.float32, copy=False), 220 | )[0] 221 | overlaps_o = bbox_iou( 222 | bb_o[None, :].astype(dtype=np.float32, copy=False), 223 | BBGT_o.astype(dtype=np.float32, copy=False), 224 | )[0] 225 | overlaps = np.minimum(overlaps_s, overlaps_o) 226 | else: # mAP(phr) 227 | overlaps = bbox_iou( 228 | bb_r[None, :].astype(dtype=np.float32, copy=False), 229 | BBGT_r.astype(dtype=np.float32, copy=False), 230 | )[0] 231 | 232 | overlaps *= valid_mask 233 | ovmax = torch.max(overlaps) 234 | jmax = torch.argmax(overlaps) 235 | else: 236 | ovmax = 0.0 237 | jmax = -1 238 | 239 | if ovmax > ovthresh: 240 | if not visited[jmax]: 241 | tp[d] = 1.0 242 | visited[jmax] = 1 243 | else: 244 | fp[d] = 1.0 245 | else: 246 | fp[d] = 1.0 247 | 248 | # compute precision recall 249 | fp = np.cumsum(fp) 250 | tp = np.cumsum(tp) 251 | rec = tp / (float(npos) + 1e-12) 252 | # ground truth 253 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 254 | ap = get_ap(rec, prec) 255 | 256 | return rec, prec, ap 257 | 258 | 259 | def boxes_union(boxes1, boxes2): 260 | assert boxes1.shape == boxes2.shape 261 | xmin = np.minimum(boxes1[:, 0], boxes2[:, 0]) 262 | ymin = np.minimum(boxes1[:, 1], boxes2[:, 1]) 263 | xmax = np.maximum(boxes1[:, 2], boxes2[:, 2]) 264 | ymax = np.maximum(boxes1[:, 3], boxes2[:, 3]) 265 | return np.vstack((xmin, ymin, xmax, ymax)).transpose() 266 | -------------------------------------------------------------------------------- /models/detr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import logging 3 | import math 4 | from typing import List 5 | 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | import torch.nn.functional as F 10 | from scipy.optimize import linear_sum_assignment 11 | from torch import nn 12 | 13 | from detectron2.layers import ShapeSpec 14 | from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess 15 | from detectron2.structures import Boxes, ImageList, Instances, BitMasks, PolygonMasks 16 | from detectron2.utils.logger import log_first_n 17 | from fvcore.nn import giou_loss, smooth_l1_loss 18 | from .detr_modules.backbone import Joiner 19 | from .detr_modules.detr import DETR, SetCriterion 20 | from .detr_modules.matcher import HungarianMatcher 21 | from .detr_modules.position_encoding import PositionEmbeddingSine 22 | from .detr_modules.transformer import Transformer 23 | from .detr_modules.segmentation import DETRsegm, PostProcessPanoptic, PostProcessSegm 24 | from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh 25 | from utils.fb_misc import NestedTensor 26 | from pycocotools import mask as coco_mask 27 | 28 | __all__ = ["Detr"] 29 | 30 | def convert_coco_poly_to_mask(segmentations, height, width): 31 | masks = [] 32 | for polygons in segmentations: 33 | rles = coco_mask.frPyObjects(polygons, height, width) 34 | mask = coco_mask.decode(rles) 35 | if len(mask.shape) < 3: 36 | mask = mask[..., None] 37 | mask = torch.as_tensor(mask, dtype=torch.uint8) 38 | mask = mask.any(dim=2) 39 | masks.append(mask) 40 | if masks: 41 | masks = torch.stack(masks, dim=0) 42 | else: 43 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 44 | return masks 45 | 46 | 47 | 48 | 49 | class MaskedBackbone(nn.Module): 50 | """ This is a thin wrapper around D2's backbone to provide padding masking""" 51 | 52 | def __init__(self, cfg): 53 | super().__init__() 54 | self.backbone = build_backbone(cfg) 55 | backbone_shape = self.backbone.output_shape() 56 | self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()] 57 | self.num_channels = backbone_shape[list(backbone_shape.keys())[-1]].channels 58 | 59 | def forward(self, images): 60 | features = self.backbone(images.tensor) 61 | masks = self.mask_out_padding( 62 | [features_per_level.shape for features_per_level in features.values()], 63 | images.image_sizes, 64 | images.tensor.device, 65 | ) 66 | assert len(features) == len(masks) 67 | for i, k in enumerate(features.keys()): 68 | features[k] = NestedTensor(features[k], masks[i]) 69 | return features 70 | 71 | def mask_out_padding(self, feature_shapes, image_sizes, device): 72 | masks = [] 73 | assert len(feature_shapes) == len(self.feature_strides) 74 | for idx, shape in enumerate(feature_shapes): 75 | N, _, H, W = shape 76 | masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device) 77 | for img_idx, (h, w) in enumerate(image_sizes): 78 | masks_per_feature_level[ 79 | img_idx, 80 | : int(np.ceil(float(h) / self.feature_strides[idx])), 81 | : int(np.ceil(float(w) / self.feature_strides[idx])), 82 | ] = 0 83 | masks.append(masks_per_feature_level) 84 | return masks 85 | 86 | 87 | @META_ARCH_REGISTRY.register() 88 | class Detr(nn.Module): 89 | """ 90 | Implement Detr 91 | """ 92 | 93 | def __init__(self, cfg): 94 | super().__init__() 95 | 96 | self.device = torch.device(cfg.MODEL.DEVICE) 97 | 98 | self.num_classes = cfg.MODEL.DETR.NUM_CLASSES 99 | self.mask_on = cfg.MODEL.MASK_ON 100 | hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM 101 
| num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES 102 | # Transformer parameters: 103 | nheads = cfg.MODEL.DETR.NHEADS 104 | dropout = cfg.MODEL.DETR.DROPOUT 105 | dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD 106 | enc_layers = cfg.MODEL.DETR.ENC_LAYERS 107 | dec_layers = cfg.MODEL.DETR.DEC_LAYERS 108 | pre_norm = cfg.MODEL.DETR.PRE_NORM 109 | 110 | # Loss parameters: 111 | giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT 112 | l1_weight = cfg.MODEL.DETR.L1_WEIGHT 113 | deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION 114 | no_object_weight = cfg.MODEL.DETR.NO_OBJECT_WEIGHT 115 | 116 | N_steps = hidden_dim // 2 117 | d2_backbone = MaskedBackbone(cfg) 118 | backbone = Joiner(d2_backbone, PositionEmbeddingSine(N_steps, normalize=True)) 119 | backbone.num_channels = d2_backbone.num_channels 120 | 121 | transformer = Transformer( 122 | d_model=hidden_dim, 123 | dropout=dropout, 124 | nhead=nheads, 125 | dim_feedforward=dim_feedforward, 126 | num_encoder_layers=enc_layers, 127 | num_decoder_layers=dec_layers, 128 | normalize_before=pre_norm, 129 | return_intermediate_dec=deep_supervision, 130 | ) 131 | 132 | self.detr = DETR( 133 | backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, aux_loss=deep_supervision 134 | ) 135 | if self.mask_on: 136 | frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS 137 | if frozen_weights != '': 138 | print("LOAD pre-trained weights") 139 | weight = torch.load(frozen_weights, map_location=lambda storage, loc: storage)['model'] 140 | new_weight = {} 141 | for k, v in weight.items(): 142 | if 'detr.' in k: 143 | new_weight[k.replace('detr.', '')] = v 144 | else: 145 | print(f"Skipping loading weight {k} from frozen model") 146 | del weight 147 | self.detr.load_state_dict(new_weight) 148 | del new_weight 149 | self.detr = DETRsegm(self.detr, freeze_detr=(frozen_weights != '')) 150 | self.seg_postprocess = PostProcessSegm 151 | 152 | self.detr.to(self.device) 153 | 154 | # building criterion 155 | matcher = HungarianMatcher(cost_class=1, cost_bbox=l1_weight, cost_giou=giou_weight) 156 | weight_dict = {"loss_ce": 1, "loss_bbox": l1_weight} 157 | weight_dict["loss_giou"] = giou_weight 158 | if deep_supervision: 159 | aux_weight_dict = {} 160 | for i in range(dec_layers - 1): 161 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 162 | weight_dict.update(aux_weight_dict) 163 | losses = ["labels", "boxes", "cardinality"] 164 | if self.mask_on: 165 | losses += ["masks"] 166 | self.criterion = SetCriterion( 167 | self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses, 168 | ) 169 | self.criterion.to(self.device) 170 | channels = len(cfg.MODEL.PIXEL_MEAN) 171 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(channels, 1, 1) 172 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(channels, 1, 1) 173 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std 174 | self.to(self.device) 175 | 176 | def forward(self, batched_inputs): 177 | """ 178 | Args: 179 | batched_inputs: a list, batched outputs of :class:`DatasetMapper` . 180 | Each item in the list contains the inputs for one image. 181 | For now, each item in the list is a dict that contains: 182 | 183 | * image: Tensor, image in (C, H, W) format. 184 | * instances: Instances 185 | 186 | Other information that's included in the original dicts, such as: 187 | 188 | * "height", "width" (int): the output resolution of the model, used in inference. 189 | See :meth:`postprocess` for details. 
190 | Returns: 191 | dict[str: Tensor]: 192 | mapping from a named loss to a tensor storing the loss. Used during training only. 193 | """ 194 | images = self.preprocess_image(batched_inputs) 195 | output = self.detr(images) 196 | 197 | if self.training: 198 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 199 | 200 | targets = self.prepare_targets(gt_instances) 201 | loss_dict = self.criterion(output, targets) 202 | weight_dict = self.criterion.weight_dict 203 | for k in loss_dict.keys(): 204 | if k in weight_dict: 205 | loss_dict[k] *= weight_dict[k] 206 | return loss_dict 207 | else: 208 | box_cls = output["pred_logits"] 209 | box_pred = output["pred_boxes"] 210 | mask_pred = output["pred_masks"] if self.mask_on else None 211 | results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes) 212 | processed_results = [] 213 | for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): 214 | height = input_per_image.get("height", image_size[0]) 215 | width = input_per_image.get("width", image_size[1]) 216 | r = detector_postprocess(results_per_image, height, width) 217 | processed_results.append({"instances": r}) 218 | return processed_results 219 | 220 | def prepare_targets(self, targets): 221 | new_targets = [] 222 | for targets_per_image in targets: 223 | h, w = targets_per_image.image_size 224 | image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) 225 | gt_classes = targets_per_image.gt_classes 226 | gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy 227 | gt_boxes = box_xyxy_to_cxcywh(gt_boxes) 228 | new_targets.append({"labels": gt_classes, "boxes": gt_boxes}) 229 | if self.mask_on and hasattr(targets_per_image, 'gt_masks'): 230 | gt_masks = targets_per_image.gt_masks.tensor 231 | new_targets[-1].update({'masks': gt_masks}) 232 | return new_targets 233 | 234 | def inference(self, box_cls, box_pred, mask_pred, image_sizes): 235 | """ 236 | Arguments: 237 | box_cls (Tensor): tensor of shape (batch_size, num_queries, K). 238 | The tensor predicts the classification probability for each query. 239 | box_pred (Tensor): tensors of shape (batch_size, num_queries, 4). 240 | The tensor predicts 4-vector (x,y,w,h) box 241 | regression values for every query. 242 | image_sizes (List[torch.Size]): the input image sizes 243 | 244 | Returns: 245 | results (List[Instances]): a list of #images elements. 246 | """ 247 | assert len(box_cls) == len(image_sizes) 248 | results = [] 249 | 250 | # For each box we assign the best class or the second best if the best one is `no_object`.
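        # e.g. for one query with softmax probabilities [0.15, 0.25, 0.60], where the last entry is `no_object`, dropping the last column before the max yields class 1 with score 0.25.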
251 | scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1) 252 | 253 | for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(zip( 254 | scores, labels, box_pred, image_sizes 255 | )): 256 | result = Instances(image_size) 257 | result.pred_boxes = Boxes(box_cxcywh_to_xyxy(box_pred_per_image)) 258 | 259 | result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0]) 260 | if self.mask_on: 261 | mask = F.interpolate(mask_pred[i].unsqueeze(0), size=image_size, mode='bilinear', align_corners=False) 262 | mask = mask[0].sigmoid() > 0.5 263 | B, N, H, W = mask_pred.shape 264 | mask = BitMasks(mask.cpu()).crop_and_resize(result.pred_boxes.tensor.cpu(), 32) 265 | result.pred_masks = mask.unsqueeze(1).to(mask_pred[0].device) 266 | 267 | result.scores = scores_per_image 268 | result.pred_classes = labels_per_image 269 | results.append(result) 270 | return results 271 | 272 | def preprocess_image(self, batched_inputs): 273 | """ 274 | Normalize, pad and batch the input images. 275 | """ 276 | images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs] 277 | images = ImageList.from_tensors(images) 278 | return images 279 | -------------------------------------------------------------------------------- /models/detr_modules/transformer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | """ 3 | DETR Transformer class. 4 | 5 | Copy-paste from torch.nn.Transformer with modifications: 6 | * positional encodings are passed in MHattention 7 | * extra LN at the end of encoder is removed 8 | * decoder returns a stack of activations from all decoding layers 9 | """ 10 | import copy 11 | from typing import Optional, List 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch import nn, Tensor 16 | 17 | 18 | class Transformer(nn.Module): 19 | 20 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 21 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 22 | activation="relu", normalize_before=False, 23 | return_intermediate_dec=False): 24 | super().__init__() 25 | 26 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 27 | dropout, activation, normalize_before) 28 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 29 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 30 | 31 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 32 | dropout, activation, normalize_before) 33 | decoder_norm = nn.LayerNorm(d_model) 34 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 35 | return_intermediate=return_intermediate_dec) 36 | 37 | self._reset_parameters() 38 | 39 | self.d_model = d_model 40 | self.nhead = nhead 41 | 42 | def _reset_parameters(self): 43 | for p in self.parameters(): 44 | if p.dim() > 1: 45 | nn.init.xavier_uniform_(p) 46 | 47 | def forward(self, src, mask, query_embed, pos_embed): 48 | # flatten NxCxHxW to HWxNxC 49 | bs, c, h, w = src.shape 50 | src = src.flatten(2).permute(2, 0, 1) 51 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 52 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 53 | mask = mask.flatten(1) 54 | 55 | tgt = torch.zeros_like(query_embed) 56 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 57 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 58 | pos=pos_embed, query_pos=query_embed) 59 | 
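        # hs: (num_decoder_layers, num_queries, bs, d_model) when return_intermediate_dec is True, otherwise (1, num_queries, bs, d_model); memory: (h*w, bs, d_model), reshaped back to (bs, d_model, h, w) in the return below.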
return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 60 | 61 | 62 | class TransformerEncoder(nn.Module): 63 | 64 | def __init__(self, encoder_layer, num_layers, norm=None): 65 | super().__init__() 66 | self.layers = _get_clones(encoder_layer, num_layers) 67 | self.num_layers = num_layers 68 | self.norm = norm 69 | 70 | def forward(self, src, 71 | mask: Optional[Tensor] = None, 72 | src_key_padding_mask: Optional[Tensor] = None, 73 | pos: Optional[Tensor] = None): 74 | output = src 75 | 76 | for layer in self.layers: 77 | output = layer(output, src_mask=mask, 78 | src_key_padding_mask=src_key_padding_mask, pos=pos) 79 | 80 | if self.norm is not None: 81 | output = self.norm(output) 82 | 83 | return output 84 | 85 | 86 | class TransformerDecoder(nn.Module): 87 | 88 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 89 | super().__init__() 90 | self.layers = _get_clones(decoder_layer, num_layers) 91 | self.num_layers = num_layers 92 | self.norm = norm 93 | self.return_intermediate = return_intermediate 94 | 95 | def forward(self, tgt, memory, 96 | tgt_mask: Optional[Tensor] = None, 97 | memory_mask: Optional[Tensor] = None, 98 | tgt_key_padding_mask: Optional[Tensor] = None, 99 | memory_key_padding_mask: Optional[Tensor] = None, 100 | pos: Optional[Tensor] = None, 101 | query_pos: Optional[Tensor] = None): 102 | output = tgt 103 | 104 | intermediate = [] 105 | 106 | for layer in self.layers: 107 | output = layer(output, memory, tgt_mask=tgt_mask, 108 | memory_mask=memory_mask, 109 | tgt_key_padding_mask=tgt_key_padding_mask, 110 | memory_key_padding_mask=memory_key_padding_mask, 111 | pos=pos, query_pos=query_pos) 112 | if self.return_intermediate: 113 | intermediate.append(self.norm(output)) 114 | 115 | if self.norm is not None: 116 | output = self.norm(output) 117 | if self.return_intermediate: 118 | intermediate.pop() 119 | intermediate.append(output) 120 | 121 | if self.return_intermediate: 122 | return torch.stack(intermediate) 123 | 124 | return output.unsqueeze(0) 125 | 126 | 127 | class TransformerEncoderLayer(nn.Module): 128 | 129 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 130 | activation="relu", normalize_before=False): 131 | super().__init__() 132 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 133 | # Implementation of Feedforward model 134 | self.linear1 = nn.Linear(d_model, dim_feedforward) 135 | self.dropout = nn.Dropout(dropout) 136 | self.linear2 = nn.Linear(dim_feedforward, d_model) 137 | 138 | self.norm1 = nn.LayerNorm(d_model) 139 | self.norm2 = nn.LayerNorm(d_model) 140 | self.dropout1 = nn.Dropout(dropout) 141 | self.dropout2 = nn.Dropout(dropout) 142 | 143 | self.activation = _get_activation_fn(activation) 144 | self.normalize_before = normalize_before 145 | 146 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 147 | return tensor if pos is None else tensor + pos 148 | 149 | def forward_post(self, 150 | src, 151 | src_mask: Optional[Tensor] = None, 152 | src_key_padding_mask: Optional[Tensor] = None, 153 | pos: Optional[Tensor] = None): 154 | q = k = self.with_pos_embed(src, pos) 155 | src2 = self.self_attn(q, k, value=src, attn_mask=src_mask, 156 | key_padding_mask=src_key_padding_mask)[0] 157 | src = src + self.dropout1(src2) 158 | src = self.norm1(src) 159 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 160 | src = src + self.dropout2(src2) 161 | src = self.norm2(src) 162 | return src 163 | 164 | def forward_pre(self, src, 165 | 
src_mask: Optional[Tensor] = None, 166 | src_key_padding_mask: Optional[Tensor] = None, 167 | pos: Optional[Tensor] = None): 168 | src2 = self.norm1(src) 169 | q = k = self.with_pos_embed(src2, pos) 170 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 171 | key_padding_mask=src_key_padding_mask)[0] 172 | src = src + self.dropout1(src2) 173 | src2 = self.norm2(src) 174 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 175 | src = src + self.dropout2(src2) 176 | return src 177 | 178 | def forward(self, src, 179 | src_mask: Optional[Tensor] = None, 180 | src_key_padding_mask: Optional[Tensor] = None, 181 | pos: Optional[Tensor] = None): 182 | if self.normalize_before: 183 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 184 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 185 | 186 | 187 | class TransformerDecoderLayer(nn.Module): 188 | 189 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 190 | activation="relu", normalize_before=False): 191 | super().__init__() 192 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 193 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 194 | # Implementation of Feedforward model 195 | self.linear1 = nn.Linear(d_model, dim_feedforward) 196 | self.dropout = nn.Dropout(dropout) 197 | self.linear2 = nn.Linear(dim_feedforward, d_model) 198 | 199 | self.norm1 = nn.LayerNorm(d_model) 200 | self.norm2 = nn.LayerNorm(d_model) 201 | self.norm3 = nn.LayerNorm(d_model) 202 | self.dropout1 = nn.Dropout(dropout) 203 | self.dropout2 = nn.Dropout(dropout) 204 | self.dropout3 = nn.Dropout(dropout) 205 | 206 | self.activation = _get_activation_fn(activation) 207 | self.normalize_before = normalize_before 208 | 209 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 210 | return tensor if pos is None else tensor + pos 211 | 212 | def forward_post(self, tgt, memory, 213 | tgt_mask: Optional[Tensor] = None, 214 | memory_mask: Optional[Tensor] = None, 215 | tgt_key_padding_mask: Optional[Tensor] = None, 216 | memory_key_padding_mask: Optional[Tensor] = None, 217 | pos: Optional[Tensor] = None, 218 | query_pos: Optional[Tensor] = None): 219 | q = k = self.with_pos_embed(tgt, query_pos) 220 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 221 | key_padding_mask=tgt_key_padding_mask)[0] 222 | tgt = tgt + self.dropout1(tgt2) 223 | tgt = self.norm1(tgt) 224 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 225 | key=self.with_pos_embed(memory, pos), 226 | value=memory, attn_mask=memory_mask, 227 | key_padding_mask=memory_key_padding_mask)[0] 228 | tgt = tgt + self.dropout2(tgt2) 229 | tgt = self.norm2(tgt) 230 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 231 | tgt = tgt + self.dropout3(tgt2) 232 | tgt = self.norm3(tgt) 233 | return tgt 234 | 235 | def forward_pre(self, tgt, memory, 236 | tgt_mask: Optional[Tensor] = None, 237 | memory_mask: Optional[Tensor] = None, 238 | tgt_key_padding_mask: Optional[Tensor] = None, 239 | memory_key_padding_mask: Optional[Tensor] = None, 240 | pos: Optional[Tensor] = None, 241 | query_pos: Optional[Tensor] = None): 242 | tgt2 = self.norm1(tgt) 243 | q = k = self.with_pos_embed(tgt2, query_pos) 244 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 245 | key_padding_mask=tgt_key_padding_mask)[0] 246 | tgt = tgt + self.dropout1(tgt2) 247 | tgt2 = self.norm2(tgt) 248 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, 
query_pos), 249 | key=self.with_pos_embed(memory, pos), 250 | value=memory, attn_mask=memory_mask, 251 | key_padding_mask=memory_key_padding_mask)[0] 252 | tgt = tgt + self.dropout2(tgt2) 253 | tgt2 = self.norm3(tgt) 254 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 255 | tgt = tgt + self.dropout3(tgt2) 256 | return tgt 257 | 258 | def forward(self, tgt, memory, 259 | tgt_mask: Optional[Tensor] = None, 260 | memory_mask: Optional[Tensor] = None, 261 | tgt_key_padding_mask: Optional[Tensor] = None, 262 | memory_key_padding_mask: Optional[Tensor] = None, 263 | pos: Optional[Tensor] = None, 264 | query_pos: Optional[Tensor] = None): 265 | if self.normalize_before: 266 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 267 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 268 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 269 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 270 | 271 | 272 | def _get_clones(module, N): 273 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 274 | 275 | 276 | def build_transformer(args): 277 | return Transformer( 278 | d_model=args.hidden_dim, 279 | dropout=args.dropout, 280 | nhead=args.nheads, 281 | dim_feedforward=args.dim_feedforward, 282 | num_encoder_layers=args.enc_layers, 283 | num_decoder_layers=args.dec_layers, 284 | normalize_before=args.pre_norm, 285 | return_intermediate_dec=True, 286 | ) 287 | 288 | 289 | def _get_activation_fn(activation): 290 | """Return an activation function given a string""" 291 | if activation == "relu": 292 | return F.relu 293 | if activation == "gelu": 294 | return F.gelu 295 | if activation == "glu": 296 | return F.glu 297 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 298 | --------------------------------------------------------------------------------
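As a quick sanity check of the tensor shapes flowing through the Transformer defined in models/detr_modules/transformer.py, here is a minimal sketch that is not part of the repository: it instantiates a small model and runs a dummy forward pass, assuming only that the repository root is on PYTHONPATH and that torch is installed.

```python
import torch
from torch import nn

from models.detr_modules.transformer import Transformer

# Small configuration so the example runs quickly on CPU.
d_model, nhead, num_queries = 256, 8, 100
bs, h, w = 2, 16, 20

model = Transformer(d_model=d_model, nhead=nhead,
                    num_encoder_layers=2, num_decoder_layers=2,
                    return_intermediate_dec=True)

src = torch.randn(bs, d_model, h, w)                      # projected backbone features
mask = torch.zeros(bs, h, w, dtype=torch.bool)            # padding mask, True marks padded pixels
query_embed = nn.Embedding(num_queries, d_model).weight   # learned object queries
pos_embed = torch.randn(bs, d_model, h, w)                # stand-in for the sine positional encoding

hs, memory = model(src, mask, query_embed, pos_embed)
print(hs.shape)      # (num_decoder_layers, bs, num_queries, d_model) = (2, 2, 100, 256)
print(memory.shape)  # (bs, d_model, h, w) = (2, 256, 16, 20)
```

Each level of hs is what DETR's classification and box heads consume; the intermediate levels only matter when deep supervision (aux_loss) is enabled, as configured in models/detr.py above.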