├── .gitignore ├── README.md ├── checkpoint ├── __init__.py └── detection_checkpoint.py ├── configs ├── Base-RCNN-FPN.yaml ├── defaults.py └── iterative_model.yaml ├── data ├── __init__.py ├── dataset_mapper.py ├── datasets │ ├── __init__.py │ └── visual_genome.py └── tools │ ├── __init__.py │ └── utils.py ├── engine ├── __init__.py └── trainer.py ├── evaluation ├── __init__.py ├── datasets │ └── vg │ │ └── zeroshot_triplet.pytorch ├── evaluator.py ├── sg_evaluation.py └── utils.py ├── modeling ├── __init__.py ├── backbone │ ├── __init__.py │ └── backbone.py ├── meta_arch │ ├── __init__.py │ └── detr.py └── transformer │ ├── __init__.py │ ├── criterion.py │ ├── detr.py │ ├── matcher.py │ ├── positional_encoding.py │ ├── segmentation.py │ ├── transformer.py │ └── util │ ├── __init__.py │ ├── box_ops.py │ ├── misc.py │ ├── plot_utils.py │ └── utils.py ├── structures ├── __init__.py ├── boxes_ops.py └── masks_ops.py └── train_iterative_model.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Iterative Scene Graph Generation 2 | 3 | This is the code for our paper titled ["Iterative Scene Graph Generation"](https://openreview.net/pdf?id=i0FnLiIRj6U). 4 | 5 | ## Requirements 6 | The following packages are needed to run the code. 7 | - `python == 3.8.5` 8 | - `PyTorch == 1.8.2` 9 | - `detectron2 == 0.6` 10 | - `h5py` 11 | - `imantics` 12 | - `easydict` 13 | - `cv2 == 4.5.5` 14 | - `scikit-learn` 15 | - `scipy` 16 | - `pandas` 17 | 18 | ## Dataset 19 | We use the filtered Visual Genome data widely used in the scene graph community. 20 | Please see the [Unbiased Scene Graph Generation repository](https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch/blob/master/DATASET.md) for instructions on downloading this dataset. After downloading the dataset, you should have the following four items: 21 | - `VG_100K` directory containing all the images 22 | - `VG-SGG-with-attri.h5` 23 | - `VG-SGG-dicts-with-attri.json` (can be found in the same repository [here](https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch/tree/master/datasets/vg)) 24 | - `image_data.json` (can be found in the same repository [here](https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch/tree/master/datasets/vg)) 25 | 26 | ## Train Iterative Model 27 | To enable faster model convergence, we pre-train DETR on Visual Genome. We replicate the DETR decoder weights three times and use them to initialize the three decoders of our model. For convenience, the pretrained weights (with the decoder replication) are made available [here](https://drive.google.com/drive/folders/1CdcYdcYEvkZHz-I1IFF8sBxVMWSyWIkh?usp=share_link). To use these weights during training, simply pass them with the `MODEL.WEIGHTS <path to pretrained weights>` flag in the training command. 28 | 29 | Our proposed iterative model can be trained using the following command: 30 | ```python 31 | python train_iterative_model.py --resume --num-gpus <number of GPUs> --config-file configs/iterative_model.yaml OUTPUT_DIR <path to output directory> DATASETS.VISUAL_GENOME.IMAGES <path to VG_100K> DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY <path to VG-SGG-dicts-with-attri.json> DATASETS.VISUAL_GENOME.IMAGE_DATA <path to image_data.json> DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 <path to VG-SGG-with-attri.h5> MODEL.DETR.OVERSAMPLE_PARAM <alpha value> MODEL.DETR.UNDERSAMPLE_PARAM <twice the beta value> SOLVER.CLIP_GRADIENTS.CLIP_VALUE 0.01 SOLVER.IMS_PER_BATCH 12 MODEL.DETR.NO_OBJECT_WEIGHT 0.1 MODEL.WEIGHTS <path to pretrained DETR weights> 32 | ``` 33 | To set the `α` value, use the `MODEL.DETR.OVERSAMPLE_PARAM` flag, and to set the `β` value, use the `MODEL.DETR.UNDERSAMPLE_PARAM` flag. Note that `MODEL.DETR.UNDERSAMPLE_PARAM` should be specified as twice the desired `β` value, so for `β=0.75` use `MODEL.DETR.UNDERSAMPLE_PARAM 1.5`. 34 | 35 | **Note**: If the code fails, try running it on a single GPU first so that some preprocessed files can be generated. This is a one-time step.
Once the code runs successfully on a single GPU, you can run it on multiple GPUs as well. By default, the code is configured to run on 4 GPUs with a batch size of 12. If you run out of memory, reduce the batch size using the `SOLVER.IMS_PER_BATCH` flag. 36 | 37 | To evaluate the code, use the following command: 38 | ```python 39 | python train_iterative_model.py --resume --eval-only --num-gpus <number of GPUs> --config-file configs/iterative_model.yaml OUTPUT_DIR <path to output directory> DATASETS.VISUAL_GENOME.IMAGES <path to VG_100K> DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY <path to VG-SGG-dicts-with-attri.json> DATASETS.VISUAL_GENOME.IMAGE_DATA <path to image_data.json> DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 <path to VG-SGG-with-attri.h5> 40 | ``` 41 | You can find our model weights for `α=0.07` and `β=0.75` [here](https://drive.google.com/drive/folders/1L2H2e-UfyKfbbmM34LaJfT6S49VTfZDY?usp=share_link). To use these weights during evaluation, simply pass them with the `MODEL.WEIGHTS <path to weights>` flag in the evaluation command. To check that the code is running correctly on your machine, the released checkpoint should give you the following metrics on the Visual Genome test set `VG_test`. 42 | 43 | ```python 44 | SGG eval: R @ 20: 0.2179; R @ 50: 0.2712; R @ 100: 0.2972; for mode=sgdet, type=Recall(Main). 45 | SGG eval: ng-R @ 20: 0.2272; ng-R @ 50: 0.3052; ng-R @ 100: 0.3547; for mode=sgdet, type=No Graph Constraint Recall(Main). 46 | SGG eval: zR @ 20: 0.0134; zR @ 50: 0.0274; zR @ 100: 0.0384; for mode=sgdet, type=Zero Shot Recall. 47 | SGG eval: mR @ 20: 0.1115; mR @ 50: 0.1561; mR @ 100: 0.1770; for mode=sgdet, type=Mean Recall. 48 | ``` 49 | 50 | 51 | -------------------------------------------------------------------------------- /checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection_checkpoint import * -------------------------------------------------------------------------------- /checkpoint/detection_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from detectron2.utils import comm 5 | from detectron2.engine import hooks, HookBase 6 | import logging 7 | 8 | class PeriodicCheckpointerWithEval(HookBase): 9 | def __init__(self, eval_period, eval_function, checkpointer, checkpoint_period, max_to_keep=5): 10 | self.eval = hooks.EvalHook(eval_period, eval_function) 11 | self.checkpointer = hooks.PeriodicCheckpointer(checkpointer, checkpoint_period, max_to_keep=max_to_keep) 12 | self.best_ap = 0.0 13 | best_model_path = checkpointer.save_dir + 'best_model_final.pth.pth' 14 | if os.path.isfile(best_model_path): 15 | best_model = torch.load(best_model_path, map_location=torch.device('cpu')) 16 | try: 17 | self.best_ap = best_model['SGMeanRecall@100'] 18 | except: 19 | self.best_ap = best_model['AP50'] 20 | del best_model 21 | print ("BEST AP: ", self.best_ap) 22 | else: 23 | self.best_ap = 0.0 24 | 25 | def before_train(self): 26 | self.max_iter = self.trainer.max_iter 27 | self.checkpointer.max_iter = self.trainer.max_iter 28 | 29 | def _do_eval(self): 30 | results = self.eval._func() 31 | comm.synchronize() 32 | return results 33 | 34 | def after_step(self): 35 | next_iter = self.trainer.iter + 1 36 | is_final = next_iter == self.trainer.max_iter 37 | if is_final or (self.eval._period > 0 and next_iter % self.eval._period == 0): 38 | results = self._do_eval() 39 | if comm.is_main_process(): 40 | try: 41 | print (results) 42 | dataset = 'VG_val' if 'VG_val' in results.keys() else 'VG_test' 43 | if results['SG']['SGMeanRecall@100'] > self.best_ap: 44 | self.best_ap = 
results['SG']['SGMeanRecall@100'] 45 | additional_state = {"iteration":self.trainer.iter, "SGMeanRecall@100":self.best_ap} 46 | self.checkpointer.checkpointer.save( 47 | "best_model_final.pth", **additional_state 48 | ) 49 | except: 50 | current_ap = results['bbox']['AP50'] 51 | if current_ap > self.best_ap: 52 | self.best_ap = current_ap 53 | additional_state = {"iteration":self.trainer.iter, "AP50":self.best_ap} 54 | self.checkpointer.checkpointer.save( 55 | "best_model_final.pth", **additional_state 56 | ) 57 | if comm.is_main_process(): 58 | self.checkpointer.step(self.trainer.iter) 59 | comm.synchronize() 60 | 61 | def after_train(self): 62 | # func is likely a closure that holds reference to the trainer 63 | # therefore we clean it to avoid circular reference in the end 64 | del self.eval._func -------------------------------------------------------------------------------- /configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 -------------------------------------------------------------------------------- /configs/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | from detectron2.config import CfgNode as CN 3 | 4 | def add_dataset_config(cfg): 5 | _C = cfg 6 | 7 | _C.MODEL.ROI_HEADS.NUM_OUTPUT_CLASSES = 80 8 | _C.MODEL.ROI_HEADS.EMBEDDINGS_PATH = "" 9 | _C.MODEL.ROI_HEADS.EMBEDDINGS_PATH_COCO = "" 10 | _C.MODEL.ROI_HEADS.LINGUAL_MATRIX_THRESHOLD = 0.05 11 | _C.MODEL.ROI_HEADS.MASK_NUM_CLASSES = 80 12 | 13 | _C.MODEL.FREEZE_LAYERS = CN() 14 | _C.MODEL.FREEZE_LAYERS.META_ARCH = [] 15 | _C.MODEL.FREEZE_LAYERS.ROI_HEADS = [] 16 | 17 | _C.DATASETS.TYPE = "" 18 | _C.DATASETS.VISUAL_GENOME = CN() 19 | _C.DATASETS.VISUAL_GENOME.IMAGES = '' 20 | _C.DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY = '' 21 | _C.DATASETS.VISUAL_GENOME.IMAGE_DATA = '' 22 | _C.DATASETS.VISUAL_GENOME.SG_TRAIN_DATA = '' 23 | _C.DATASETS.VISUAL_GENOME.SG_VAL_DATA = '' 24 | _C.DATASETS.VISUAL_GENOME.SG_MAPPER = '' 25 | _C.DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 = '' 26 | _C.DATASETS.VISUAL_GENOME.TRAIN_MASKS = "" 27 | _C.DATASETS.VISUAL_GENOME.TEST_MASKS = "" 28 | _C.DATASETS.VISUAL_GENOME.VAL_MASKS = "" 29 | _C.DATASETS.VISUAL_GENOME.CLIPPED = False 30 | _C.DATASETS.VISUAL_GENOME.PER_CLASS_DATASET = False 31 | _C.DATASETS.VISUAL_GENOME.BGNN = False 32 | 33 | _C.DATASETS.ACTION_GENOME = CN() 34 | _C.DATASETS.ACTION_GENOME.ANNOTATIONS = '' 35 | _C.DATASETS.ACTION_GENOME.FILTER_EMPTY_RELATIONS = True 36 | _C.DATASETS.ACTION_GENOME.FILTER_DUPLICATE_RELATIONS = True 37 | _C.DATASETS.ACTION_GENOME.FILTER_NON_OVERLAP = True 38 | _C.DATASETS.ACTION_GENOME.FORMAT_VID_WISE = False 39 | _C.DATASETS.ACTION_GENOME.FRAMES = '' 40 | _C.DATASETS.ACTION_GENOME.VIDEOS = '' 41 | _C.DATASETS.ACTION_GENOME.NUM_VIDEOS_TRAIN = -1 42 | _C.DATASETS.ACTION_GENOME.NUM_VIDEOS_VAL = 400 43 | _C.DATASETS.ACTION_GENOME.VAL_SET_RANDOMIZED = False 44 | 45 | 46 | _C.DATASETS.MSCOCO = CN() 47 | _C.DATASETS.MSCOCO.ANNOTATIONS = '' 48 | _C.DATASETS.MSCOCO.DATAROOT = '' 49 | 50 | _C.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS = True 51 | _C.DATASETS.VISUAL_GENOME.FILTER_DUPLICATE_RELATIONS = True 52 | _C.DATASETS.VISUAL_GENOME.FILTER_NON_OVERLAP = True 53 | _C.DATASETS.VISUAL_GENOME.EXCLUDE_LEFT_RIGHT = False 54 | _C.DATASETS.VISUAL_GENOME.NUMBER_OF_VALIDATION_IMAGES = 5000 55 | _C.DATASETS.VISUAL_GENOME.BOX_SCALE = 1024 56 | _C.DATASETS.VISUAL_GENOME.MAX_NUM_RELATIONS = -1 57 | _C.DATASETS.VISUAL_GENOME.MAX_NUM_OBJECTS = -1 58 | _C.DATASETS.VISUAL_GENOME.UNDERSAMPLE_PARAM = 0.7 59 | _C.DATASETS.VISUAL_GENOME.OVERSAMPLE_PARAM = 0.07 60 | 61 | _C.DATASETS.SEG_DATA_DIVISOR = 1 62 | 63 | _C.DATASETS.TRANSFER = ('coco_train_2014',) 64 | _C.DATASETS.MASK_TRAIN = ('coco_train_2017',) 65 | _C.DATASETS.MASK_TEST = ('coco_val_2017',) 66 | 67 | _C.MODEL.DETR = CN() 68 | _C.MODEL.DETR.NUM_CLASSES = 80 69 | _C.MODEL.DETR.NUM_RELATION_CLASSES = 50 
70 | 71 | # For Segmentation 72 | _C.MODEL.DETR.FROZEN_WEIGHTS = '' 73 | _C.MODEL.DETR.FREEZE = False 74 | _C.MODEL.DETR.SGDET_USE_GT = False 75 | 76 | # LOSS 77 | _C.MODEL.DETR.COST_CLASS = 1.0 78 | _C.MODEL.DETR.GIOU_WEIGHT = 2.0 79 | _C.MODEL.DETR.L1_WEIGHT = 5.0 80 | _C.MODEL.DETR.FOCAL_ALPHA = 0.25 81 | _C.MODEL.DETR.DEEP_SUPERVISION = True 82 | _C.MODEL.DETR.NO_OBJECT_WEIGHT = 0.1 83 | _C.MODEL.DETR.CLS_WEIGHT = 2.0 84 | _C.MODEL.DETR.COST_SELECTION = 1.0 85 | 86 | 87 | # RELATION LOSS 88 | _C.MODEL.DETR.RELATION_LOSS_WEIGHT = 1.0 89 | 90 | # TRANSFORMER 91 | _C.MODEL.DETR.NAME = 'DETR' 92 | _C.MODEL.DETR.CRITERION = 'SetCriterion' 93 | _C.MODEL.DETR.TRANSFORMER = 'Transformer' 94 | _C.MODEL.DETR.MATCHER = 'HungarianMatcher' 95 | _C.MODEL.DETR.POSITION_EMBEDDING = 'PositionEmbeddingSine' 96 | _C.MODEL.DETR.NHEADS = 8 97 | _C.MODEL.DETR.DROPOUT = 0.1 98 | _C.MODEL.DETR.DIM_FEEDFORWARD = 2048 99 | _C.MODEL.DETR.ENC_LAYERS = 6 100 | _C.MODEL.DETR.DEC_LAYERS = 6 101 | _C.MODEL.DETR.RELATION_DEC_LAYERS = 4 102 | _C.MODEL.DETR.OBJECT_DEC_LAYERS = 4 103 | _C.MODEL.DETR.PRE_NORM = False 104 | _C.MODEL.DETR.RELATION_HEAD = True 105 | _C.MODEL.DETR.CLASS_AGNOSTIC_NMS = True 106 | _C.MODEL.DETR.INTERSECTION_IOU_THRESHOLD = 0.3 107 | _C.MODEL.DETR.INTERSECTION_IOU_LAMBDA = 1.0 108 | _C.MODEL.DETR.INTERSECTION_LOSS = False 109 | _C.MODEL.DETR.NO_REL_WEIGHT = 0.1 110 | _C.MODEL.DETR.LATER_NMS_THRESHOLD = 0.3 111 | _C.MODEL.DETR.USE_FREQ_BIAS = True 112 | _C.MODEL.DETR.NUM_FEATURE_LEVELS = 4 113 | _C.MODEL.DETR.WITH_BOX_REFINE = False 114 | _C.MODEL.DETR.TWO_STAGE = False 115 | _C.MODEL.DETR.TWO_STAGE_NUM_PROPOSALS = 300 116 | _C.MODEL.DETR.REWEIGHT_RELATIONS = False 117 | _C.MODEL.DETR.REWEIGHT_USE_LOG = True 118 | _C.MODEL.DETR.REWEIGHT_REL_EOS_COEF = 0.1 119 | _C.MODEL.DETR.NEGATIVE_RELATION_FRACTION = 3.0 120 | _C.MODEL.DETR.MAX_RELATION_PAIRS = 16 121 | _C.MODEL.DETR.NMS_WEIGHT = 0.2 122 | _C.MODEL.DETR.BETA = 1000 123 | _C.MODEL.DETR.MATCHER_TOPK = 1 124 | _C.MODEL.DETR.OVERSAMPLE_PARAM = 0.07 125 | _C.MODEL.DETR.UNDERSAMPLE_PARAM = 1.0 126 | _C.MODEL.DETR.FREEZE_LAYERS = [] 127 | _C.MODEL.DETR.TEST_INDEX = -1 128 | 129 | _C.MODEL.DETR.HIDDEN_DIM = 256 130 | _C.MODEL.DETR.NUM_OBJECT_QUERIES = 100 131 | _C.MODEL.DETR.NUM_RELATION_QUERIES = 100 132 | _C.MODEL.DETR.CREATE_BG_PAIRS = False 133 | 134 | _C.SOLVER.OPTIMIZER = "ADAMW" 135 | _C.SOLVER.BACKBONE_MULTIPLIER = 0.1 136 | _C.SOLVER.RELATION_MULTIPLIER = 1.0 137 | _C.SOLVER.ENTITY_MULTIPLIER = 1.0 138 | _C.MODEL.DF_DETR = CN() 139 | _C.MODEL.DF_DETR.NUM_CLASSES = 80 140 | 141 | # MODEL Variants 142 | _C.MODEL.DF_DETR.WITH_BOX_REFINE = False 143 | _C.MODEL.DF_DETR.TWO_STAGE = False 144 | 145 | # Backbone 146 | _C.MODEL.DF_DETR.BACKBONE = "resnet50" 147 | _C.MODEL.DF_DETR.DILATION = False 148 | _C.MODEL.DF_DETR.POSITIONAL_EMBEDDING = 'sine' 149 | _C.MODEL.DF_DETR.NUM_FEATURE_LEVELS = 4 150 | 151 | # Transformer 152 | _C.MODEL.DF_DETR.ENC_LAYERS = 6 153 | _C.MODEL.DF_DETR.DEC_LAYERS = 6 154 | _C.MODEL.DF_DETR.DIM_FEEDFORWARD = 1024 155 | _C.MODEL.DF_DETR.HIDDEN_DIM = 256 156 | _C.MODEL.DF_DETR.DROPOUT = 0.1 157 | _C.MODEL.DF_DETR.NHEADS = 8 158 | _C.MODEL.DF_DETR.NUM_OBJECT_QUERIES = 300 159 | _C.MODEL.DETR.DEC_N_POINTS = 4 160 | _C.MODEL.DETR.ENC_N_POINTS = 4 161 | 162 | # Mathcher 163 | _C.MODEL.DF_DETR.SET_COST_CLASS = 2 164 | _C.MODEL.DF_DETR.SET_COST_BBOX = 5 165 | _C.MODEL.DF_DETR.SET_COST_GIOU = 2 166 | 167 | # Loss 168 | _C.MODEL.DF_DETR.CLS_LOSS_WEIGHT = 2. 169 | _C.MODEL.DF_DETR.BBOX_LOSS_WEIGHT = 5. 
170 | _C.MODEL.DF_DETR.GIOU_LOSS_WEIGHT = 2. 171 | _C.MODEL.DF_DETR.RELATION_LOSS_WEIGHT = 1. 172 | _C.MODEL.DF_DETR.FOCAL_ALPHA = 0.25 173 | _C.MODEL.DF_DETR.DEEP_SUPERVISION = True 174 | 175 | # relation part 176 | _C.MODEL.DF_DETR.RELATION_DEC_LAYERS = 4 177 | _C.MODEL.DF_DETR.PRE_NORM = False 178 | _C.MODEL.DF_DETR.RELATION_HEAD = True 179 | 180 | def add_scenegraph_config(cfg): 181 | _C = cfg 182 | 183 | _C.GLOVE_DIR = 'glove/' 184 | _C.DEV_RUN = False 185 | ################################################################################################### 186 | 187 | _C.MODEL.SCENEGRAPH_ON = True 188 | _C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = True 189 | _C.MODEL.USE_MASK_ON_NODE = False 190 | _C.MODEL.ROI_HEADS.OBJECTNESS_THRESH = 0.3 191 | _C.MODEL.GROUP_NORM = CN() 192 | _C.MODEL.GROUP_NORM.DIM_PER_GP = -1 193 | _C.MODEL.GROUP_NORM.NUM_GROUPS = 32 194 | _C.MODEL.GROUP_NORM.EPSILON = 1e-5 # default: 1e-5 195 | ################################################################################################### 196 | _C.MODEL.ROI_SCENEGRAPH_HEAD = CN() 197 | _C.MODEL.ROI_SCENEGRAPH_HEAD.NAME = 'SceneGraphHead' 198 | _C.MODEL.ROI_SCENEGRAPH_HEAD.MODE = 'predcls' 199 | _C.MODEL.ROI_SCENEGRAPH_HEAD.REQUIRE_BOX_OVERLAP = True 200 | _C.MODEL.ROI_SCENEGRAPH_HEAD.NUM_SAMPLE_PER_GT_REL = 4 # when sample fg relationship from gt, the max number of corresponding proposal pairs 201 | _C.MODEL.ROI_SCENEGRAPH_HEAD.BATCH_SIZE_PER_IMAGE = 64 202 | _C.MODEL.ROI_SCENEGRAPH_HEAD.POSITIVE_FRACTION = 0.25 203 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX = False 204 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL = False 205 | _C.MODEL.ROI_SCENEGRAPH_HEAD.NMS_FILTER_DUPLICATES = True 206 | 207 | _C.MODEL.SCENEGRAPH_SAMPLER = CN() 208 | _C.MODEL.SCENEGRAPH_SAMPLER.NAME = "TransformerRelationSampler" 209 | 210 | _C.MODEL.SCENEGRAPH_LOSS = CN() 211 | _C.MODEL.SCENEGRAPH_LOSS.NAME = "TransformerRelatonLoss" 212 | 213 | _C.MODEL.SCENEGRAPH_POST_PROCESSOR = CN() 214 | _C.MODEL.SCENEGRAPH_POST_PROCESSOR.NAME = 'TransformerRelationPostProcesser' 215 | 216 | _C.MODEL.ROI_SCENEGRAPH_HEAD.RETURN_SEG_MASKS = False 217 | _C.MODEL.ROI_SCENEGRAPH_HEAD.RETURN_SEG_ANNOS = False 218 | 219 | _C.MODEL.ROI_SCENEGRAPH_HEAD.PREDICT_USE_VISION = True 220 | _C.MODEL.ROI_SCENEGRAPH_HEAD.PREDICT_USE_BIAS = True 221 | 222 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_MASK_ATTENTION = False 223 | _C.MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE = 'Self' 224 | _C.MODEL.ROI_SCENEGRAPH_HEAD.SIGMOID_ATTENTION = True 225 | 226 | _C.MODEL.ROI_SCENEGRAPH_HEAD.PREDICTOR = "MotifPredictor" 227 | _C.MODEL.ROI_SCENEGRAPH_HEAD.NUM_CLASSES = 50 228 | _C.MODEL.ROI_SCENEGRAPH_HEAD.EMBED_DIM = 200 229 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_DROPOUT_RATE = 0.2 230 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_HIDDEN_DIM = 512 231 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_POOLING_DIM = 4096 232 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_OBJ_LAYER = 1 # assert >= 1 233 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_REL_LAYER = 1 # assert >= 1 234 | _C.MODEL.ROI_SCENEGRAPH_HEAD.ADD_GTBOX_TO_PROPOSAL_IN_TRAIN = True 235 | _C.MODEL.ROI_SCENEGRAPH_HEAD.SEG_BBOX_LOSS_MULTIPLIER = 1.0 236 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_ONLY_FG_PROPOSALS = True 237 | 238 | _C.MODEL.ROI_SCENEGRAPH_HEAD.LABEL_SMOOTHING_LOSS = False 239 | _C.MODEL.ROI_SCENEGRAPH_HEAD.REL_PROP = [0.01858, 0.00057, 0.00051, 0.00109, 0.00150, 0.00489, 0.00432, 0.02913, 0.00245, 0.00121, 240 | 0.00404, 0.00110, 0.00132, 0.00172, 0.00005, 0.00242, 0.00050, 0.00048, 0.00208, 0.15608, 241 | 0.02650, 0.06091, 0.00900, 0.00183, 
0.00225, 0.00090, 0.00028, 0.00077, 0.04844, 0.08645, 242 | 0.31621, 0.00088, 0.00301, 0.00042, 0.00186, 0.00100, 0.00027, 0.01012, 0.00010, 0.01286, 243 | 0.00647, 0.00084, 0.01077, 0.00132, 0.00069, 0.00376, 0.00214, 0.11424, 0.01205, 0.02958] 244 | 245 | _C.MODEL.ROI_SCENEGRAPH_HEAD.ZERO_SHOT_TRIPLETS = 'evaluation/datasets/vg/zeroshot_triplet.pytorch' 246 | 247 | #TransformerContext 248 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER = CN() 249 | # for TransformerPredictor only 250 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.DROPOUT_RATE = 0.1 251 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.OBJ_LAYER = 4 252 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.REL_LAYER = 2 253 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.NUM_HEAD = 8 254 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.INNER_DIM = 2048 255 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.KEY_DIM = 64 256 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.VAL_DIM = 64 257 | ################################################################################################### 258 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS = CN() 259 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.NAME = 'BoxFeatureExtractor' 260 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_RESOLUTION = 28 261 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_SAMPLING_RATIO = 0 262 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_TYPE = 'ROIAlignV2' 263 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK = True 264 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.CLASS_LOGITS_WITH_MASK = False 265 | 266 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS = CN() 267 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.NAME = 'RelationFeatureExtractor' 268 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.USE_MASK_COMBINER = False 269 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.MULTIPLY_LOGITS_WITH_MASKS = True 270 | 271 | # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD) 272 | _C.MODEL.ROI_HEADS.FG_IOU_THRESHOLD = 0.5 273 | _C.MODEL.ROI_HEADS.REFINE_SEG_MASKS = True 274 | _C.MODEL.ROI_HEADS.SEGMENTATION_STEP_MASK_REFINE = True 275 | 276 | # Settings for relation testing 277 | _C.TEST.RELATION = CN() 278 | _C.TEST.RELATION.REQUIRE_OVERLAP = True 279 | _C.TEST.RELATION.LATER_NMS_PREDICTION_THRES = 0.3 280 | _C.TEST.RELATION.MULTIPLE_PREDS = False 281 | _C.TEST.RELATION.IOU_THRESHOLD = 0.5 282 | 283 | 284 | _C.DATASETS.VISUAL_GENOME.CLIPPED = False -------------------------------------------------------------------------------- /configs/iterative_model.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "IterativeRelationDetr" 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-101.pkl" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | MASK_ON: False 7 | RESNETS: 8 | DEPTH: 101 9 | STRIDE_IN_1X1: False 10 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 11 | DETR: 12 | NAME: 'IterativeRelationDETR' 13 | TRANSFORMER: 'IterativeRelationTransformer' 14 | CRITERION: 'IterativeRelationCriterion' 15 | MATCHER: 'IterativeHungarianMatcher' 16 | GIOU_WEIGHT: 2.0 17 | L1_WEIGHT: 5.0 18 | NUM_OBJECT_QUERIES: 300 19 | NUM_RELATION_QUERIES: 300 20 | NUM_CLASSES: 150 21 | OBJECT_DEC_LAYERS: 6 22 | REWEIGHT_RELATIONS: True 23 | REWEIGHT_REL_EOS_COEF: 0.1 24 | NO_OBJECT_WEIGHT: 0.1 25 | NO_REL_WEIGHT: 0.1 26 | DATASETS: 27 | TYPE: "VISUAL GENOME" 28 | TRAIN: ('VG_train',) 29 | TEST: ('VG_val',) 30 | VISUAL_GENOME: 31 | TRAIN_MASKS: "" 32 | VAL_MASKS: "" 33 | TEST_MASKS: "" 34 | TRAIN_MASKS: "" 35 | FILTER_EMPTY_RELATIONS: True 36 | 
FILTER_NON_OVERLAP: False 37 | FILTER_DUPLICATE_RELATIONS: True 38 | IMAGES: '' 39 | MAPPING_DICTIONARY: '' 40 | IMAGE_DATA: '' 41 | VG_ATTRIBUTE_H5: '' 42 | SOLVER: 43 | IMS_PER_BATCH: 32 44 | BASE_LR: 0.0001 45 | STEPS: (160000,) 46 | MAX_ITER: 250000 47 | WARMUP_FACTOR: 1.0 48 | WARMUP_ITERS: 10 49 | WEIGHT_DECAY: 0.0001 50 | OPTIMIZER: "ADAMW" 51 | BACKBONE_MULTIPLIER: 0.1 52 | CHECKPOINT_PERIOD: 1000 53 | CLIP_GRADIENTS: 54 | ENABLED: True 55 | CLIP_TYPE: "full_model" 56 | CLIP_VALUE: 0.01 57 | NORM_TYPE: 2.0 58 | AMP: 59 | ENABLED: True 60 | INPUT: 61 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 62 | CROP: 63 | ENABLED: True 64 | TYPE: "absolute_range" 65 | SIZE: (384, 600) 66 | FORMAT: "RGB" 67 | TEST: 68 | EVAL_PERIOD: 10000 69 | DATALOADER: 70 | FILTER_EMPTY_ANNOTATIONS: False 71 | NUM_WORKERS: 4 72 | VERSION: 2 73 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_mapper import * 2 | from .tools import register_datasets 3 | from .datasets import VisualGenomeTrainData -------------------------------------------------------------------------------- /data/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import numpy as np 4 | import torch 5 | from fvcore.common.file_io import PathManager 6 | from PIL import Image 7 | 8 | from detectron2.data import detection_utils as utils 9 | from detectron2.data import transforms as T 10 | from detectron2.structures.instances import Instances 11 | from detectron2.data import DatasetCatalog, MetadataCatalog, MapDataset, DatasetFromList, DatasetMapper 12 | from collections import defaultdict 13 | from imantics import Polygons, Mask 14 | import logging 15 | from detectron2.structures import ( 16 | BitMasks, 17 | Boxes, 18 | BoxMode, 19 | Instances, 20 | Keypoints, 21 | PolygonMasks, 22 | RotatedBoxes, 23 | polygons_to_bitmask, 24 | ) 25 | 26 | def build_transform_gen(cfg, is_train): 27 | """ 28 | Create a list of :class:`TransformGen` from config. 29 | Returns: 30 | list[TransformGen] 31 | """ 32 | if is_train: 33 | min_size = cfg.INPUT.MIN_SIZE_TRAIN 34 | max_size = cfg.INPUT.MAX_SIZE_TRAIN 35 | sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING 36 | else: 37 | min_size = cfg.INPUT.MIN_SIZE_TEST 38 | max_size = cfg.INPUT.MAX_SIZE_TEST 39 | sample_style = "choice" 40 | if sample_style == "range": 41 | assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format(len(min_size)) 42 | 43 | logger = logging.getLogger(__name__) 44 | tfm_gens = [] 45 | if is_train: 46 | tfm_gens.append(T.RandomFlip()) 47 | tfm_gens.append(T.ResizeShortestEdge(min_size, max_size, sample_style)) 48 | if is_train: 49 | logger.info("TransformGens used in training: " + str(tfm_gens)) 50 | return tfm_gens 51 | 52 | def filter_empty_instances(instances, by_box=True, by_mask=True, box_threshold=1e-5): 53 | """ 54 | Filter out empty instances in an `Instances` object. 55 | Args: 56 | instances (Instances): 57 | by_box (bool): whether to filter out instances with empty boxes 58 | by_mask (bool): whether to filter out instances with empty masks 59 | box_threshold (float): minimum width and height to be considered non-empty 60 | Returns: 61 | Instances: the filtered instances. 
62 | """ 63 | assert by_box or by_mask 64 | r = [] 65 | if by_box: 66 | r.append(instances.gt_boxes.nonempty(threshold=box_threshold)) 67 | if instances.has("gt_masks") and by_mask: 68 | r.append(instances.gt_masks.nonempty()) 69 | 70 | # TODO: can also filter visible keypoints 71 | 72 | if not r: 73 | return instances 74 | m = r[0] 75 | # for x in r[1:]: 76 | # m = m & x 77 | return instances[m], r 78 | 79 | class DetrDatasetMapper: 80 | """ 81 | A callable which takes a dataset dict in Detectron2 Dataset format, 82 | and map it into a format used by DETR. 83 | The callable currently does the following: 84 | 1. Read the image from "file_name" 85 | 2. Applies geometric transforms to the image and annotation 86 | 3. Find and applies suitable cropping to the image and annotation 87 | 4. Prepare image and annotation to Tensors 88 | """ 89 | 90 | def __init__(self, cfg, is_train=True): 91 | if cfg.INPUT.CROP.ENABLED and is_train: 92 | self.crop_gen = [ 93 | T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), 94 | T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE), 95 | ] 96 | self.recompute_boxes = cfg.MODEL.MASK_ON 97 | else: 98 | self.crop_gen = None 99 | self.recompute_boxes = False 100 | 101 | self.mask_on = cfg.MODEL.MASK_ON 102 | self.tfm_gens = build_transform_gen(cfg, is_train) 103 | logging.getLogger(__name__).info( 104 | "Full TransformGens used in training: {}, crop: {}".format(str(self.tfm_gens), str(self.crop_gen)) 105 | ) 106 | 107 | self.img_format = cfg.INPUT.FORMAT 108 | self.is_train = is_train 109 | self.filter_duplicate_relations = cfg.DATASETS.VISUAL_GENOME.FILTER_DUPLICATE_RELATIONS 110 | self.max_num_rels = cfg.DATASETS.VISUAL_GENOME.MAX_NUM_RELATIONS 111 | self.max_num_objs = cfg.DATASETS.VISUAL_GENOME.MAX_NUM_OBJECTS 112 | self.data_type = cfg.DATASETS.TYPE 113 | 114 | 115 | def __call__(self, dataset_dict): 116 | """ 117 | Args: 118 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 119 | Returns: 120 | dict: a format that builtin models in detectron2 accept 121 | """ 122 | dataset_dict = copy.deepcopy(dataset_dict) 123 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 124 | h, w, _ = image.shape 125 | if w != dataset_dict['width'] or h != dataset_dict['height']: 126 | dataset_dict['width'] = w 127 | dataset_dict['height'] = h 128 | utils.check_image_size(dataset_dict, image) 129 | 130 | if self.crop_gen is None: 131 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 132 | else: 133 | if np.random.rand() > 0.5: 134 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 135 | else: 136 | image, transforms = T.apply_transform_gens( 137 | self.tfm_gens[:-1] + self.crop_gen + self.tfm_gens[-1:], image 138 | ) 139 | 140 | image_shape = image.shape[:2] # h, w 141 | 142 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 143 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 144 | # Therefore it's important to use torch.Tensor. 145 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 146 | 147 | # if not self.is_train: 148 | # # USER: Modify this if you want to keep them for some reason. 
149 | # dataset_dict.pop("annotations", None) 150 | # return dataset_dict 151 | 152 | # Filter duplicate relations 153 | rel_present = False 154 | if "relations" in dataset_dict: 155 | if self.filter_duplicate_relations and self.is_train: 156 | relation_dict = defaultdict(list) 157 | for object_0, object_1, relation in dataset_dict["relations"]: 158 | relation_dict[(object_0,object_1)].append(relation) 159 | dataset_dict["relations"] = [(k[0], k[1], np.random.choice(v)) for k,v in relation_dict.items()] 160 | 161 | dataset_dict["relations"] = torch.as_tensor(np.ascontiguousarray(dataset_dict["relations"])) 162 | rel_present = True 163 | 164 | if self.data_type == "VISUAL GENOME RELATION": 165 | if self.filter_duplicate_relations and self.is_train: 166 | relation_dict = defaultdict(list) 167 | relation_idx = defaultdict(list) 168 | for idx, (object_0, object_1, relation) in enumerate(dataset_dict['relation_mapper']): 169 | relation_dict[(object_0,object_1)].append(relation) 170 | relation_idx[(object_0,object_1)].append(idx) 171 | selected_idxs = [] 172 | for k, v in relation_idx.items(): 173 | selected_idxs.append(np.random.choice(v)) 174 | selected_idxs = np.sort(selected_idxs) 175 | dataset_dict['annotations'] = [dataset_dict['annotations'][selected_idx] for selected_idx in selected_idxs] 176 | 177 | 178 | if "annotations" in dataset_dict: 179 | # USER: Modify this if you want to keep them for some reason. 180 | for anno in dataset_dict["annotations"]: 181 | if not self.mask_on: 182 | anno.pop("segmentation", None) 183 | anno.pop("keypoints", None) 184 | 185 | # USER: Implement additional transformations if you have other types of data 186 | if self.data_type != "VISUAL GENOME RELATION": 187 | annos = [ 188 | utils.transform_instance_annotations(obj, transforms, image_shape) 189 | for obj in dataset_dict.pop("annotations") 190 | if obj.get("iscrowd", 0) == 0 191 | ] 192 | instances = utils.annotations_to_instances(annos, image_shape) 193 | else: 194 | annos = [ 195 | transform_instance_annotations_relation(obj, transforms, image_shape) 196 | for obj in dataset_dict.pop("annotations") 197 | if obj.get("iscrowd", 0) == 0 198 | ] 199 | instances = annotations_to_instances_relation(annos, image_shape) 200 | 201 | if rel_present: 202 | # Add object attributes 203 | instances.gt_attributes = torch.tensor([obj['attribute'] for obj in annos], dtype=torch.int64) 204 | if self.recompute_boxes: 205 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 206 | dataset_dict["instances"], filter_mask = utils.filter_empty_instances(instances, return_mask=True) 207 | 208 | # Fix GT relations where boxes are removed due to them being too small. 
209 | if rel_present: 210 | if not filter_mask.all(): 211 | object_mapper = {int(old_idx): new_idx for new_idx, old_idx in enumerate(torch.arange(filter_mask.size(0))[filter_mask])} 212 | new_relations = [] 213 | for idx, (object_0, object_1, relation) in enumerate(dataset_dict['relations'].numpy()): 214 | if (object_0 in object_mapper) and (object_1 in object_mapper): 215 | new_relations.append([object_mapper[object_0], object_mapper[object_1], relation]) 216 | if len(new_relations) > 0: 217 | dataset_dict['relations'] = torch.tensor(new_relations) 218 | else: 219 | dataset_dict['relations'] = torch.zeros(0, 3).long() 220 | else: 221 | if len(dataset_dict['relations']) == 0: 222 | dataset_dict['relations'] = torch.zeros(0, 3).long() 223 | else: 224 | dataset_dict['relations'] = torch.zeros(0, 3).long() 225 | 226 | if self.data_type == "GQA": 227 | if len(dataset_dict['instances']) > self.max_num_objs: 228 | # Randomly sample max number of objects 229 | sample_idxs = np.random.permutation(np.arange(len(dataset_dict['instances'])))[:self.max_num_objs] 230 | dataset_dict['instances'] = dataset_dict['instances'][sample_idxs] 231 | object_mapper = {sample_idx:new_idx for new_idx, sample_idx in enumerate(sample_idxs)} 232 | if len(dataset_dict['relations']) > 0: 233 | new_relations = [] 234 | for idx, (object_0, object_1, relation) in enumerate(dataset_dict['relations'].numpy()): 235 | if (object_0 in object_mapper) and (object_1 in object_mapper): 236 | new_relations.append([object_mapper[object_0], object_mapper[object_1], relation]) 237 | if len(new_relations) > 0: 238 | dataset_dict['relations'] = torch.tensor(new_relations) 239 | else: 240 | dataset_dict['relations'] = torch.zeros(0, 3).long() 241 | 242 | if len(dataset_dict['relations']) > self.max_num_rels: 243 | sample_idxs = np.random.permutation(np.arange(len(dataset_dict['relations'])))[:self.max_num_rels] 244 | dataset_dict['relations'] = dataset_dict['relations'][sample_idxs] 245 | 246 | return dataset_dict 247 | 248 | def transform_instance_annotations_relation( 249 | annotation, transforms, image_size, *, keypoint_hflip_indices=None 250 | ): 251 | """ 252 | Apply transforms to box, segmentation and keypoints annotations of a single instance. 253 | It will use `transforms.apply_box` for the box, and 254 | `transforms.apply_coords` for segmentation polygons & keypoints. 255 | If you need anything more specially designed for each data structure, 256 | you'll need to implement your own version of this function or the transforms. 257 | Args: 258 | annotation (dict): dict of instance annotations for a single instance. 259 | It will be modified in-place. 260 | transforms (TransformList or list[Transform]): 261 | image_size (tuple): the height, width of the transformed image 262 | keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`. 263 | Returns: 264 | dict: 265 | the same input dict with fields "bbox", "segmentation", "keypoints" 266 | transformed according to `transforms`. 267 | The "bbox_mode" field will be set to XYXY_ABS. 
268 | """ 269 | if isinstance(transforms, (tuple, list)): 270 | transforms = T.TransformList(transforms) 271 | # bbox is 1d (per-instance bounding box) 272 | bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS) 273 | bbox_union = BoxMode.convert(annotation["union_bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS) 274 | # clip transformed bbox to image size 275 | bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0) 276 | bbox_union = transforms.apply_box(np.array([bbox_union]))[0].clip(min=0) 277 | annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1]) 278 | annotation["bbox_union"] = np.minimum(bbox_union, list(image_size + image_size)[::-1]) 279 | annotation["bbox_mode"] = BoxMode.XYXY_ABS 280 | return annotation 281 | 282 | def annotations_to_instances_relation(annos, image_size, mask_format="polygon"): 283 | """ 284 | Create an :class:`Instances` object used by the models, 285 | from instance annotations in the dataset dict. 286 | Args: 287 | annos (list[dict]): a list of instance annotations in one image, each 288 | element for one instance. 289 | image_size (tuple): height, width 290 | Returns: 291 | Instances: 292 | It will contain fields "gt_boxes", "gt_classes", 293 | "gt_masks", "gt_keypoints", if they can be obtained from `annos`. 294 | This is the format that builtin models expect. 295 | """ 296 | boxes = ( 297 | np.stack( 298 | [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos] 299 | ) 300 | if len(annos) 301 | else np.zeros((0, 4)) 302 | ) 303 | union_boxes = ( 304 | np.stack( 305 | [BoxMode.convert(obj["bbox_union"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos] 306 | ) 307 | if len(annos) 308 | else np.zeros((0, 4)) 309 | ) 310 | target = Instances(image_size) 311 | target.gt_boxes = Boxes(boxes) 312 | target.union_boxes = Boxes(union_boxes) 313 | 314 | classes = [int(obj["category_id"]) for obj in annos] 315 | classes = torch.tensor(classes, dtype=torch.int64) 316 | target.gt_classes = classes 317 | 318 | return target 319 | 320 | -------------------------------------------------------------------------------- /data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .visual_genome import VisualGenomeTrainData -------------------------------------------------------------------------------- /data/datasets/visual_genome.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import json 3 | import math 4 | from math import floor 5 | from PIL import Image, ImageDraw 6 | import random 7 | import os 8 | import torch 9 | import numpy as np 10 | import pickle 11 | import yaml 12 | from detectron2.config import get_cfg 13 | from detectron2.structures import Instances, Boxes, pairwise_iou, BoxMode 14 | from detectron2.data import DatasetCatalog, MetadataCatalog 15 | import logging 16 | from collections import defaultdict 17 | from torch.utils.data import Dataset, DataLoader 18 | import copy 19 | 20 | class VisualGenomeTrainData: 21 | """ 22 | Register data for Visual Genome training 23 | """ 24 | def __init__(self, cfg, split='train'): 25 | self.cfg = cfg 26 | self.split = split 27 | if split == 'train': 28 | self.mask_location = cfg.DATASETS.VISUAL_GENOME.TRAIN_MASKS 29 | elif split == 'val': 30 | self.mask_location = cfg.DATASETS.VISUAL_GENOME.VAL_MASKS 31 | else: 32 | self.mask_location = cfg.DATASETS.VISUAL_GENOME.TEST_MASKS 33 | self.mask_exists = 
os.path.isfile(self.mask_location) 34 | self.clamped = True if "clamped" in self.mask_location else "" 35 | self.per_class_dataset = cfg.DATASETS.VISUAL_GENOME.PER_CLASS_DATASET if split == 'train' else False 36 | self.bgnn = cfg.DATASETS.VISUAL_GENOME.BGNN if split == 'train' else False 37 | self.clipped = cfg.DATASETS.VISUAL_GENOME.CLIPPED 38 | self.precompute = False if (self.cfg.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS or self.cfg.DATASETS.VISUAL_GENOME.FILTER_NON_OVERLAP) else True 39 | try: 40 | with open('./data/datasets/images_to_remove.txt', 'r') as f: 41 | ids = f.readlines() 42 | self.ids_to_remove = {int(x.strip()) : 1 for x in ids[0].replace('[', '').replace(']','').split(",")} 43 | except: 44 | self.ids_to_remove = [] 45 | # self._process_data() 46 | self.dataset_dicts = self._fetch_data_dict() 47 | self.register_dataset() 48 | try: 49 | statistics = self.get_statistics() 50 | except: 51 | pass 52 | if self.bgnn: 53 | freq = statistics['fg_rel_count'] / statistics['fg_rel_count'].sum() 54 | freq = freq.numpy() 55 | oversample_param = cfg.DATASETS.VISUAL_GENOME.OVERSAMPLE_PARAM 56 | undersample_param = cfg.DATASETS.VISUAL_GENOME.UNDERSAMPLE_PARAM 57 | oversampling_ratio = np.maximum(np.sqrt((oversample_param / (freq + 1e-5))), np.ones_like(freq))[:-1] 58 | sampled_dataset_dicts = [] 59 | sampled_num = [] 60 | unique_relation_ratios = [] 61 | unique_relations_dict = [] 62 | for record in self.dataset_dicts: 63 | relations = record['relations'] 64 | if len(relations) > 0: 65 | unique_relations = np.unique(relations[:,2]) 66 | repeat_num = int(np.ceil(np.max(oversampling_ratio[unique_relations]))) 67 | for rep_idx in range(repeat_num): 68 | sampled_num.append(repeat_num) 69 | unique_relation_ratios.append(oversampling_ratio[unique_relations]) 70 | sampled_dataset_dicts.append(record) 71 | unique_relations_dict.append({rel:idx for idx, rel in enumerate(unique_relations)}) 72 | else: 73 | sampled_dataset_dicts.append(record) 74 | sampled_num.append(1) 75 | unique_relation_ratios.append([]) 76 | unique_relations_dict.append({}) 77 | 78 | self.dataset_dicts = sampled_dataset_dicts 79 | self.dataloader = BGNNSampler(self.dataset_dicts, sampled_num, oversampling_ratio, undersample_param, unique_relation_ratios, unique_relations_dict) 80 | DatasetCatalog.remove('VG_{}'.format(self.split)) 81 | self.register_dataset(dataloader=True) 82 | MetadataCatalog.get('VG_{}'.format(self.split)).set(statistics=statistics) 83 | print (self.idx_to_predicates, statistics['fg_rel_count'].numpy().tolist()) 84 | 85 | if self.per_class_dataset: 86 | freq = statistics['fg_rel_count'] / statistics['fg_rel_count'].sum() 87 | freq = freq.numpy() 88 | oversample_param = cfg.DATASETS.VISUAL_GENOME.OVERSAMPLE_PARAM 89 | undersample_param = cfg.DATASETS.VISUAL_GENOME.UNDERSAMPLE_PARAM 90 | oversampling_ratio = np.maximum(np.sqrt((oversample_param / (freq + 1e-5))), np.ones_like(freq))[:-1] 91 | unique_relation_ratios = defaultdict(list) 92 | unique_relations_dict = defaultdict(list) 93 | per_class_dataset = defaultdict(list) 94 | sampled_num = defaultdict(list) 95 | for record in self.dataset_dicts: 96 | relations = record['relations'] 97 | if len(relations) > 0: 98 | unique_relations = np.unique(relations[:,2]) 99 | repeat_num = int(np.ceil(np.max(oversampling_ratio[unique_relations]))) 100 | for rel in unique_relations: 101 | per_class_dataset[rel].append(record) 102 | sampled_num[rel].append(repeat_num) 103 | unique_relation_ratios[rel].append(oversampling_ratio[unique_relations]) 104 | 
unique_relations_dict[rel].append({rel:idx for idx, rel in enumerate(unique_relations)}) 105 | self.dataloader = ClassBalancedSampler(per_class_dataset, len(self.dataset_dicts), sampled_num, oversampling_ratio, undersample_param, unique_relation_ratios, unique_relations_dict) 106 | DatasetCatalog.remove('VG_{}'.format(self.split)) 107 | self.register_dataset(dataloader=True) 108 | MetadataCatalog.get('VG_{}'.format(self.split)).set(statistics=statistics) 109 | print (self.idx_to_predicates, statistics['fg_rel_count'].numpy().tolist()) 110 | 111 | def register_dataset(self, dataloader=False): 112 | """ 113 | Register datasets to use with Detectron2 114 | """ 115 | if not dataloader: 116 | DatasetCatalog.register('VG_{}'.format(self.split), lambda: self.dataset_dicts) 117 | else: 118 | DatasetCatalog.register('VG_{}'.format(self.split), lambda: self.dataloader) 119 | 120 | #Get labels 121 | self.mapping_dictionary = json.load(open(self.cfg.DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY, 'r')) 122 | self.idx_to_classes = sorted(self.mapping_dictionary['label_to_idx'], key=lambda k: self.mapping_dictionary['label_to_idx'][k]) 123 | self.idx_to_predicates = sorted(self.mapping_dictionary['predicate_to_idx'], key=lambda k: self.mapping_dictionary['predicate_to_idx'][k]) 124 | self.idx_to_attributes = sorted(self.mapping_dictionary['attribute_to_idx'], key=lambda k: self.mapping_dictionary['attribute_to_idx'][k]) 125 | MetadataCatalog.get('VG_{}'.format(self.split)).set(thing_classes=self.idx_to_classes, predicate_classes=self.idx_to_predicates, attribute_classes=self.idx_to_attributes) 126 | 127 | def _fetch_data_dict(self): 128 | """ 129 | Load data in detectron format 130 | """ 131 | fileName = "tmp/visual_genome_{}_data_{}{}{}{}{}{}{}{}.pkl".format(self.split, 'masks' if self.mask_exists else '', '_oi' if 'oi' in self.mask_location else '', "_clamped" if self.clamped else "", "_precomp" if self.precompute else "", "_clipped" if self.clipped else "", '_overlapfalse' if not self.cfg.DATASETS.VISUAL_GENOME.FILTER_NON_OVERLAP else "", '_emptyfalse' if not self.cfg.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS else '', "_perclass" if self.per_class_dataset else '') 132 | print("Loading file: ", fileName) 133 | if os.path.isfile(fileName): 134 | #If data has been processed earlier, load that to save time 135 | with open(fileName, 'rb') as inputFile: 136 | dataset_dicts = pickle.load(inputFile) 137 | else: 138 | #Process data 139 | os.makedirs('tmp', exist_ok=True) 140 | dataset_dicts = self._process_data() 141 | with open(fileName, 'wb') as inputFile: 142 | pickle.dump(dataset_dicts, inputFile) 143 | return dataset_dicts 144 | 145 | def _process_data(self): 146 | self.VG_attribute_h5 = h5py.File(self.cfg.DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5, 'r') 147 | 148 | # Remove corrupted images 149 | image_data = json.load(open(self.cfg.DATASETS.VISUAL_GENOME.IMAGE_DATA, 'r')) 150 | self.corrupted_ims = ['1592', '1722', '4616', '4617'] 151 | self.image_data = [] 152 | for i, img in enumerate(image_data): 153 | if str(img['image_id']) in self.corrupted_ims: 154 | continue 155 | self.image_data.append(img) 156 | assert(len(self.image_data) == 108073) 157 | self.masks = None 158 | if self.mask_location != "": 159 | try: 160 | with open(self.mask_location, 'rb') as f: 161 | self.masks = pickle.load(f) 162 | except: 163 | pass 164 | dataset_dicts = self._load_graphs() 165 | return dataset_dicts 166 | 167 | def get_statistics(self, eps=1e-3, bbox_overlap=True): 168 | num_object_classes = 
len(MetadataCatalog.get('VG_{}'.format(self.split)).thing_classes) + 1 169 | num_relation_classes = len(MetadataCatalog.get('VG_{}'.format(self.split)).predicate_classes) + 1 170 | 171 | fg_matrix = np.zeros((num_object_classes, num_object_classes, num_relation_classes), dtype=np.int64) 172 | bg_matrix = np.zeros((num_object_classes, num_object_classes), dtype=np.int64) 173 | fg_rel_count = np.zeros((num_relation_classes), dtype=np.int64) 174 | for idx, data in enumerate(self.dataset_dicts): 175 | gt_relations = data['relations'] 176 | gt_classes = np.array([x['category_id'] for x in data['annotations']]) 177 | gt_boxes = np.array([x['bbox'] for x in data['annotations']]) 178 | for (o1, o2), rel in zip(gt_classes[gt_relations[:,:2]], gt_relations[:,2]): 179 | fg_matrix[o1, o2, rel] += 1 180 | fg_rel_count[rel] += 1 181 | 182 | for (o1, o2) in gt_classes[np.array(box_filter(gt_boxes, must_overlap=bbox_overlap), dtype=int)]: 183 | bg_matrix[o1, o2] += 1 184 | bg_matrix += 1 185 | fg_matrix[:, :, -1] = bg_matrix 186 | pred_dist = np.log(fg_matrix / fg_matrix.sum(2)[:, :, None] + eps) 187 | 188 | result = { 189 | 'fg_matrix': torch.from_numpy(fg_matrix), 190 | 'pred_dist': torch.from_numpy(pred_dist).float(), 191 | 'fg_rel_count': torch.from_numpy(fg_rel_count).float(), 192 | 'obj_classes': self.idx_to_classes + ['__background__'], 193 | 'rel_classes': self.idx_to_predicates + ['__background__'], 194 | 'att_classes': self.idx_to_attributes, 195 | } 196 | print (torch.from_numpy(fg_rel_count).float()) 197 | MetadataCatalog.get('VG_{}'.format(self.split)).set(statistics=result) 198 | return result 199 | 200 | def _load_graphs(self): 201 | """ 202 | Parse examples and create dictionaries 203 | """ 204 | data_split = self.VG_attribute_h5['split'][:] 205 | split_flag = 2 if self.split == 'test' else 0 206 | split_mask = data_split == split_flag 207 | 208 | #Filter images without bounding boxes 209 | split_mask &= self.VG_attribute_h5['img_to_first_box'][:] >= 0 210 | if self.cfg.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS: 211 | split_mask &= self.VG_attribute_h5['img_to_first_rel'][:] >= 0 212 | image_index = np.where(split_mask)[0] 213 | 214 | if self.split == 'val': 215 | image_index = image_index[:self.cfg.DATASETS.VISUAL_GENOME.NUMBER_OF_VALIDATION_IMAGES] 216 | elif self.split == 'train': 217 | image_index = image_index[self.cfg.DATASETS.VISUAL_GENOME.NUMBER_OF_VALIDATION_IMAGES:] 218 | 219 | split_mask = np.zeros_like(data_split).astype(bool) 220 | split_mask[image_index] = True 221 | 222 | # Get box information 223 | all_labels = self.VG_attribute_h5['labels'][:, 0] 224 | all_attributes = self.VG_attribute_h5['attributes'][:, :] 225 | all_boxes = self.VG_attribute_h5['boxes_{}'.format(self.cfg.DATASETS.VISUAL_GENOME.BOX_SCALE)][:] # cx,cy,w,h 226 | assert np.all(all_boxes[:, :2] >= 0) # sanity check 227 | assert np.all(all_boxes[:, 2:] > 0) # no empty box 228 | 229 | # Convert from xc, yc, w, h to x1, y1, x2, y2 230 | all_boxes[:, :2] = all_boxes[:, :2] - all_boxes[:, 2:] / 2 231 | all_boxes[:, 2:] = all_boxes[:, :2] + all_boxes[:, 2:] 232 | 233 | first_box_index = self.VG_attribute_h5['img_to_first_box'][split_mask] 234 | last_box_index = self.VG_attribute_h5['img_to_last_box'][split_mask] 235 | first_relation_index = self.VG_attribute_h5['img_to_first_rel'][split_mask] 236 | last_relation_index = self.VG_attribute_h5['img_to_last_rel'][split_mask] 237 | 238 | #Load relation labels 239 | all_relations = self.VG_attribute_h5['relationships'][:] 240 | all_relation_predicates = 
self.VG_attribute_h5['predicates'][:, 0] 241 | 242 | image_indexer = np.arange(len(self.image_data))[split_mask] 243 | # Iterate over images 244 | dataset_dicts = [] 245 | num_rels = [] 246 | num_objs = [] 247 | for idx, _ in enumerate(image_index): 248 | record = {} 249 | #Get image metadata 250 | image_data = self.image_data[image_indexer[idx]] 251 | record['file_name'] = os.path.join(self.cfg.DATASETS.VISUAL_GENOME.IMAGES, '{}.jpg'.format(image_data['image_id'])) 252 | record['image_id'] = image_data['image_id'] 253 | record['height'] = image_data['height'] 254 | record['width'] = image_data['width'] 255 | if self.clipped: 256 | if image_data['coco_id'] in self.ids_to_remove: 257 | continue 258 | #Get annotations 259 | boxes = all_boxes[first_box_index[idx]:last_box_index[idx] + 1, :] 260 | gt_classes = all_labels[first_box_index[idx]:last_box_index[idx] + 1] 261 | gt_attributes = all_attributes[first_box_index[idx]:last_box_index[idx] + 1, :] 262 | 263 | if first_relation_index[idx] > -1: 264 | predicates = all_relation_predicates[first_relation_index[idx]:last_relation_index[idx] + 1] 265 | objects = all_relations[first_relation_index[idx]:last_relation_index[idx] + 1] - first_box_index[idx] 266 | predicates = predicates - 1 267 | relations = np.column_stack((objects, predicates)) 268 | else: 269 | assert not self.cfg.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS 270 | relations = np.zeros((0, 3), dtype=np.int32) 271 | 272 | if self.cfg.DATASETS.VISUAL_GENOME.FILTER_NON_OVERLAP and self.split == 'train': 273 | # Remove boxes that don't overlap 274 | boxes_list = Boxes(boxes) 275 | ious = pairwise_iou(boxes_list, boxes_list) 276 | relation_boxes_ious = ious[relations[:,0], relations[:,1]] 277 | iou_indexes = np.where(relation_boxes_ious > 0.0)[0] 278 | if iou_indexes.size > 0: 279 | relations = relations[iou_indexes] 280 | else: 281 | #Ignore image 282 | continue 283 | #Get masks if possible 284 | if self.masks is not None: 285 | try: 286 | gt_masks = self.masks[image_data['image_id']] 287 | except: 288 | print (image_data['image_id']) 289 | record['relations'] = relations 290 | objects = [] 291 | # if len(boxes) != len(gt_masks): 292 | mask_idx = 0 293 | for obj_idx in range(len(boxes)): 294 | resized_box = boxes[obj_idx] / self.cfg.DATASETS.VISUAL_GENOME.BOX_SCALE * max(record['height'], record['width']) 295 | obj = { 296 | "bbox": resized_box.tolist(), 297 | "bbox_mode": BoxMode.XYXY_ABS, 298 | "category_id": gt_classes[obj_idx] - 1, 299 | "attribute": gt_attributes[obj_idx], 300 | } 301 | if self.masks is not None: 302 | if gt_masks['empty_index'][obj_idx]: 303 | refined_poly = [] 304 | for poly_idx, poly in enumerate(gt_masks['polygons'][mask_idx]): 305 | if len(poly) >= 6: 306 | refined_poly.append(poly) 307 | obj["segmentation"] = refined_poly 308 | mask_idx += 1 309 | else: 310 | obj["segmentation"] = [] 311 | if len(obj["segmentation"]) > 0: 312 | objects.append(obj) 313 | else: 314 | objects.append(obj) 315 | num_objs.append(len(objects)) 316 | num_rels.append(len(relations)) 317 | 318 | record['annotations'] = objects 319 | dataset_dicts.append(record) 320 | print ("Max Rels:", np.max(num_rels), "Max Objs:", np.max(num_objs)) 321 | print ("Avg Rels:", np.mean(num_rels), "Avg Objs:", np.mean(num_objs)) 322 | print ("Median Rels:", np.median(num_rels), "Median Objs:", np.median(num_objs)) 323 | return dataset_dicts 324 | 325 | class ClassBalancedSampler(Dataset): 326 | """ 327 | Wrap a list to a torch Dataset. It produces elements of the list as data. 
328 | """ 329 | 330 | def __init__(self, lst, lst_len, sampled_num, oversampled_ratio, undersample_param, unique_relation_ratios, unique_relations): 331 | self._lst = lst 332 | self._len = lst_len 333 | self.sampled_num = sampled_num 334 | self.oversampled_ratio = oversampled_ratio 335 | self.undersample_param = undersample_param 336 | self.unique_relation_ratios = unique_relation_ratios 337 | self.unique_relations = unique_relations 338 | self._num_classes = len(lst.keys()) 339 | 340 | def __len__(self): 341 | return self._len 342 | 343 | def __getitem__(self, idx): 344 | class_idx = np.random.randint(self._num_classes) 345 | random_example = np.random.randint(len(self._lst[class_idx])) 346 | record = self._lst[class_idx][random_example] 347 | relations = record['relations'] 348 | new_record = copy.deepcopy(record) 349 | if len(relations) > 0: 350 | unique_relations = self.unique_relations[class_idx][random_example] 351 | rc = self.unique_relation_ratios[class_idx][random_example] 352 | ri = self.sampled_num[class_idx][random_example] 353 | dropout = np.clip(((ri - rc)/ri) * self.undersample_param, 0.0, 1.0) 354 | random_arr = np.random.uniform(size=len(relations)) 355 | index_arr = np.array([unique_relations[rel] for rel in relations[:, 2]]) 356 | rel_dropout = dropout[index_arr] 357 | to_keep = rel_dropout < random_arr 358 | dropped_relations = relations[to_keep] 359 | new_record['relations'] = dropped_relations 360 | return new_record 361 | 362 | class BGNNSampler(Dataset): 363 | """ 364 | Wrap a list to a torch Dataset. It produces elements of the list as data. 365 | """ 366 | 367 | def __init__(self, lst, sampled_num, oversampled_ratio, undersample_param, unique_relation_ratios, unique_relations): 368 | self._lst = lst 369 | self.sampled_num = sampled_num 370 | self.oversampled_ratio = oversampled_ratio 371 | self.undersample_param = undersample_param 372 | self.unique_relation_ratios = unique_relation_ratios 373 | self.unique_relations = unique_relations 374 | 375 | 376 | def __len__(self): 377 | return len(self._lst) 378 | 379 | def __getitem__(self, idx): 380 | record = self._lst[idx] 381 | relations = record['relations'] 382 | new_record = copy.deepcopy(record) 383 | if len(relations) > 0: 384 | unique_relations = self.unique_relations[idx] 385 | rc = self.unique_relation_ratios[idx] 386 | ri = self.sampled_num[idx] 387 | dropout = np.clip(((ri - rc)/ri) * self.undersample_param, 0.0, 1.0) 388 | random_arr = np.random.uniform(size=len(relations)) 389 | index_arr = np.array([unique_relations[rel] for rel in relations[:, 2]]) 390 | rel_dropout = dropout[index_arr] 391 | to_keep = rel_dropout < random_arr 392 | dropped_relations = relations[to_keep] 393 | new_record['relations'] = dropped_relations 394 | 395 | return new_record 396 | 397 | 398 | def box_filter(boxes, must_overlap=False): 399 | """ Only include boxes that overlap as possible relations. 
400 |     If no overlapping boxes, use all of them."""
401 |     n_cands = boxes.shape[0]
402 | 
403 |     overlaps = bbox_overlaps(boxes.astype(np.float64), boxes.astype(np.float64), to_move=0) > 0
404 |     np.fill_diagonal(overlaps, 0)
405 | 
406 |     all_possib = np.ones_like(overlaps, dtype=bool)
407 |     np.fill_diagonal(all_possib, 0)
408 | 
409 |     if must_overlap:
410 |         possible_boxes = np.column_stack(np.where(overlaps))
411 | 
412 |         if possible_boxes.size == 0:
413 |             possible_boxes = np.column_stack(np.where(all_possib))
414 |     else:
415 |         possible_boxes = np.column_stack(np.where(all_possib))
416 |     return possible_boxes
417 | 
418 | def bbox_overlaps(boxes1, boxes2, to_move=1):
419 |     """
420 |     boxes1 : numpy, [num_obj, 4] (x1,y1,x2,y2)
421 |     boxes2 : numpy, [num_obj, 4] (x1,y1,x2,y2)
422 |     """
423 |     #print('boxes1: ', boxes1.shape)
424 |     #print('boxes2: ', boxes2.shape)
425 |     num_box1 = boxes1.shape[0]
426 |     num_box2 = boxes2.shape[0]
427 |     lt = np.maximum(boxes1.reshape([num_box1, 1, -1])[:,:,:2], boxes2.reshape([1, num_box2, -1])[:,:,:2]) # [N,M,2]
428 |     rb = np.minimum(boxes1.reshape([num_box1, 1, -1])[:,:,2:], boxes2.reshape([1, num_box2, -1])[:,:,2:]) # [N,M,2]
429 | 
430 |     wh = (rb - lt + to_move).clip(min=0) # [N,M,2]
431 |     inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
432 |     return inter  # pairwise intersection areas, not IoU; box_filter above only tests `> 0`
--------------------------------------------------------------------------------
/data/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .utils import register_datasets
--------------------------------------------------------------------------------
/data/tools/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | from ..datasets import VisualGenomeTrainData
5 | from detectron2.data.datasets import register_coco_instances
6 | 
7 | def register_datasets(cfg):
8 |     if cfg.DATASETS.TYPE == 'VISUAL GENOME':
9 |         for split in ['train', 'val', 'test']:
10 |             dataset_instance = VisualGenomeTrainData(cfg, split=split)
11 | 
--------------------------------------------------------------------------------
/engine/__init__.py:
--------------------------------------------------------------------------------
1 | from .trainer import *
--------------------------------------------------------------------------------
/engine/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import numpy as np
5 | import logging
6 | import detectron2.utils.comm as comm
7 | import time
8 | import datetime
9 | import pickle
10 | import itertools
11 | import pycocotools.mask as mask_util
12 | from collections import OrderedDict
13 | from detectron2.utils.logger import setup_logger, log_every_n_seconds
14 | from detectron2.engine import DefaultTrainer
15 | from detectron2.data import (
16 |     MetadataCatalog,
17 |     build_detection_test_loader,
18 |     build_detection_train_loader,
19 |     get_detection_dataset_dicts,
20 |     build_batch_data_loader
21 | )
22 | from detectron2.evaluation import DatasetEvaluators, DatasetEvaluator, inference_on_dataset, print_csv_format, inference_context
23 | from imantics import Polygons, Mask
24 | 
25 | from detectron2.engine import hooks, HookBase
26 | from ..data import DetrDatasetMapper
27 | from detectron2.evaluation import (
28 |     COCOEvaluator,
29 |     SemSegEvaluator
30 | )
31 | from ..checkpoint import PeriodicCheckpointerWithEval
32 | from ..evaluation import scenegraph_inference_on_dataset, SceneGraphEvaluator
33 | from fvcore.nn.precise_bn import get_bn_modules  # used by hooks.PreciseBN in build_hooks below
34 | from detectron2.data.samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler
35 | from detectron2.data.common import MapDataset, DatasetFromList
36 | from detectron2.data.dataset_mapper import DatasetMapper
37 | from detectron2.data.build import trivial_batch_collator
38 | from detectron2.utils.comm import get_world_size, is_main_process
39 | from typing import Any, Dict, List, Set
40 | from detectron2.solver.build import maybe_add_gradient_clipping
41 | 
42 | class JointTransformerTrainer(DefaultTrainer):
43 |     @classmethod
44 |     def build_train_loader(cls, cfg):
45 |         return build_detection_train_loader(cfg, mapper=DetrDatasetMapper(cfg, True))
46 | 
47 |     @classmethod
48 |     def build_test_loader(cls, cfg, dataset_name):
49 |         return build_detection_test_loader(cfg, dataset_name, mapper=DetrDatasetMapper(cfg, False))
50 | 
51 |     def build_hooks(self):
52 |         """
53 |         Build a list of default hooks, including timing, evaluation,
54 |         checkpointing, lr scheduling, precise BN, writing events.
55 | 
56 |         Returns:
57 |             list[HookBase]:
58 |         """
59 |         cfg = self.cfg.clone()
60 |         cfg.defrost()
61 |         cfg.DATALOADER.NUM_WORKERS = 0  # save some memory and time for PreciseBN
62 | 
63 |         ret = [
64 |             hooks.IterationTimer(),
65 |             hooks.LRScheduler(self.optimizer, self.scheduler),
66 |             hooks.PreciseBN(
67 |                 # Run at the same freq as (but before) evaluation.
68 |                 cfg.TEST.EVAL_PERIOD,
69 |                 self.model,
70 |                 # Build a new data loader to not affect training
71 |                 self.build_train_loader(cfg),
72 |                 cfg.TEST.PRECISE_BN.NUM_ITER,
73 |             )
74 |             if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
75 |             else None,
76 |         ]
77 | 
78 |         # Do PreciseBN before checkpointer, because it updates the model and needs to
79 |         # be saved by checkpointer.
80 |         # This is not always the best: if checkpointing has a different frequency,
81 |         # some checkpoints may have more precise statistics than others.
82 |         # if comm.is_main_process():
83 |         #     ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_to_keep=1))
84 | 
85 |         def test_and_save_results():
86 |             self._last_eval_results = self.test(self.cfg, self.model)
87 |             return self._last_eval_results
88 | 
89 |         # Do evaluation after checkpointer, because then if it fails,
90 |         # we can use the saved checkpoint to debug.
91 |         # ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results))
92 |         ret.append(PeriodicCheckpointerWithEval(cfg.TEST.EVAL_PERIOD, test_and_save_results, self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD, max_to_keep=1))
93 |         if comm.is_main_process():
94 |             # run writers in the end, so that evaluation metrics are written
95 |             ret.append(hooks.PeriodicWriter(self.build_writers(), period=20))
96 |         return ret
97 | 
98 |     # def run_step(self):
99 |     #     """
100 |     #     Implement the AMP training logic.
101 |     #     """
102 |     #     assert self._trainer.model.training, "[AMPTrainer] model was changed to eval mode!"
103 |     #     assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!"
104 | # from torch.cuda.amp import autocast 105 | 106 | # start = time.perf_counter() 107 | # data = next(self._trainer._data_loader_iter) 108 | # data_time = time.perf_counter() - start 109 | 110 | # with autocast(): 111 | # loss_dict = self._trainer.model(data) 112 | # if isinstance(loss_dict, torch.Tensor): 113 | # losses = loss_dict 114 | # loss_dict = {"total_loss": loss_dict} 115 | # else: 116 | # losses = sum(loss_dict.values()) 117 | 118 | # self._trainer.optimizer.zero_grad() 119 | # self._trainer.grad_scaler.scale(losses).backward() 120 | # for name, param in self._trainer.model.named_parameters(): 121 | # try: 122 | # print (name, param.grad.norm()) 123 | # except: 124 | # print (name, param.grad) 125 | # import ipdb; ipdb.set_trace() 126 | # self._write_metrics(loss_dict, data_time) 127 | 128 | # self.grad_scaler.step(self.optimizer) 129 | # self.grad_scaler.update() 130 | 131 | @classmethod 132 | def build_optimizer(cls, cfg, model): 133 | params: List[Dict[str, Any]] = [] 134 | memo: Set[torch.nn.parameter.Parameter] = set() 135 | logger = logging.getLogger("detectron2") 136 | for key, value in model.named_parameters(recurse=True): 137 | if not value.requires_grad: 138 | continue 139 | # Avoid duplicating parameters 140 | if value in memo: 141 | continue 142 | memo.add(value) 143 | lr = cfg.SOLVER.BASE_LR 144 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 145 | if "backbone" in key: 146 | lr = lr * cfg.SOLVER.BACKBONE_MULTIPLIER 147 | if "relation" in key: 148 | lr = lr * cfg.SOLVER.RELATION_MULTIPLIER 149 | logger.info("Setting LR for {} to {}".format(key, lr)) 150 | if "detr.transformer.encoder" in key or "detr.transformer.decoder.layers" in key or "detr.query_embed" in key or 'backbone' in key or 'detr.transformer.decoder.norm' in key: 151 | lr = lr * cfg.SOLVER.ENTITY_MULTIPLIER 152 | logger.info("Setting LR for {} to {}".format(key, lr)) 153 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 154 | 155 | def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class 156 | # detectron2 doesn't have full model gradient clipping now 157 | clip_norm_val = cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE 158 | enable = ( 159 | cfg.SOLVER.CLIP_GRADIENTS.ENABLED 160 | and cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model" 161 | and clip_norm_val > 0.0 162 | ) 163 | 164 | class FullModelGradientClippingOptimizer(optim): 165 | def step(self, closure=None): 166 | all_params = itertools.chain(*[x["params"] for x in self.param_groups]) 167 | torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) 168 | super().step(closure=closure) 169 | 170 | return FullModelGradientClippingOptimizer if enable else optim 171 | 172 | optimizer_type = cfg.SOLVER.OPTIMIZER 173 | if optimizer_type == "SGD": 174 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.SGD)( 175 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM 176 | ) 177 | elif optimizer_type == "ADAMW": 178 | optimizer = maybe_add_full_model_gradient_clipping(torch.optim.AdamW)( 179 | params, cfg.SOLVER.BASE_LR 180 | ) 181 | else: 182 | raise NotImplementedError(f"no optimizer type {optimizer_type}") 183 | if not cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE == "full_model": 184 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 185 | return optimizer 186 | 187 | @classmethod 188 | def test(cls, cfg, model, evaluators=None): 189 | """ 190 | Args: 191 | cfg (CfgNode): 192 | model (nn.Module): 193 | evaluators (list[DatasetEvaluator] or None): if None, will call 194 | :meth:`build_evaluator`. 
Otherwise, must have the same length as
195 |                 ``cfg.DATASETS.TEST``.
196 |         Returns:
197 |             dict: a dict of result metrics
198 |         """
199 |         logger = logging.getLogger(__name__)
200 | 
201 | 
202 |         results = OrderedDict()
203 |         for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
204 |             data_loader = cls.build_test_loader(cfg, dataset_name)
205 |             # import ipdb; ipdb.set_trace()
206 | 
207 |             output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
208 |             # Detection-only meta-architectures are scored with COCOEvaluator; anything else
            # with a relation head uses the SceneGraphEvaluator.
            # ('LatentRelationCoordsNoAttentionDETR' was previously part of this exclusion list as well.)
            detection_only_archs = {'DetrWithSGGBBox', 'Detr', 'QuerySplitObjectDetr', 'QuerySplitUnionBoxDetr', 'QuerySplitObjectDetrTest', 'RelationDetr', 'ConditionalDETR', 'QueryConditionalDETR', 'QueryConditionalDeformableDETR', 'LatentBoxConditionalDETR', 'LatentRelationDETRTest', 'LatentBoxDETR', 'LatentBoxDeformableDETR', 'LatentBoxDETRTest', 'LatentBoxConditionalDETRTest', 'LatentRelationCoordsDETRTest', 'LatentBoxCoordsDetr'}
209 |             if cfg.MODEL.DETR.RELATION_HEAD and cfg.MODEL.META_ARCHITECTURE not in detection_only_archs:
210 |                 evaluator = SceneGraphEvaluator(dataset_name, cfg, True, output_folder)
211 |             else:
212 |                 evaluator = COCOEvaluator(dataset_name, cfg, True, output_folder)
213 |             results_i = scenegraph_inference_on_dataset(cfg, model, data_loader, evaluator)
214 | 
215 |             # print("Out of sg inference")
216 |             results[dataset_name] = results_i
217 |             if comm.is_main_process():
218 |                 assert isinstance(
219 |                     results_i, dict
220 |                 ), "Evaluator must return a dict on the main process. 
Got {} instead.".format( 221 | results_i 222 | ) 223 | logger.info("Evaluation results for {} in csv format:".format(dataset_name)) 224 | print_csv_format(results_i) 225 | comm.synchronize() 226 | if len(results) == 1: 227 | results = list(results.values())[0] 228 | return results 229 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import scenegraph_inference_on_dataset 2 | from .sg_evaluation import SceneGraphEvaluator 3 | 4 | 5 | -------------------------------------------------------------------------------- /evaluation/datasets/vg/zeroshot_triplet.pytorch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/IterativeSG/7abd162cb8e510cfbcededbeb18b36b54f381189/evaluation/datasets/vg/zeroshot_triplet.pytorch -------------------------------------------------------------------------------- /evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | from collections import OrderedDict 5 | from contextlib import contextmanager 6 | from collections import Counter 7 | import torch 8 | 9 | from detectron2.utils.comm import get_world_size, is_main_process 10 | from detectron2.utils.logger import log_every_n_seconds 11 | 12 | 13 | from detectron2.evaluation import COCOEvaluator 14 | 15 | 16 | def scenegraph_inference_on_dataset(cfg, model, data_loader, evaluator): 17 | """ 18 | Run model on the data_loader and evaluate the metrics with evaluator. 19 | Also benchmark the inference speed of `model.forward` accurately. 20 | The model will be used in eval mode. 21 | 22 | Args: 23 | model (nn.Module): a module which accepts an object from 24 | `data_loader` and returns some outputs. It will be temporarily set to `eval` mode. 25 | 26 | If you wish to evaluate a model in `training` mode instead, you can 27 | wrap the given model and override its behavior of `.eval()` and `.train()`. 28 | data_loader: an iterable object with a length. 29 | The elements it generates will be the inputs to the model. 30 | evaluator (DatasetEvaluator): the evaluator to run. Use `None` if you only want 31 | to benchmark, but don't want to do any evaluation. 
32 | 33 | Returns: 34 | The return value of `evaluator.evaluate()` 35 | """ 36 | num_devices = get_world_size() 37 | logger = logging.getLogger('detectron2') 38 | logger.info("Start inference on {} images".format(len(data_loader))) 39 | 40 | total = len(data_loader) # inference data loader must have a fixed length 41 | 42 | # evaluator = COCOEvaluator(dataset_name, cfg, True, output_folder) 43 | 44 | evaluator.reset() 45 | num_warmup = min(5, total - 1) 46 | start_time = time.perf_counter() 47 | total_compute_time = 0 48 | with inference_context(model), torch.no_grad(): 49 | for idx, inputs in enumerate(data_loader): 50 | if idx == num_warmup: 51 | start_time = time.perf_counter() 52 | total_compute_time = 0 53 | 54 | # if len(inputs[0]['instances']) > 40: 55 | # continue 56 | start_compute_time = time.perf_counter() 57 | 58 | outputs = model(inputs) 59 | 60 | if torch.cuda.is_available(): 61 | torch.cuda.synchronize() 62 | total_compute_time += time.perf_counter() - start_compute_time 63 | evaluator.process(inputs, outputs) 64 | 65 | iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) 66 | seconds_per_img = total_compute_time / iters_after_start 67 | 68 | if idx >= num_warmup * 2 or seconds_per_img > 5: 69 | total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start 70 | eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1))) 71 | # logger.info("Inference done {}/{}. {:.4f} s / img. ETA={}".format(idx + 1, total, seconds_per_img, str(eta))) 72 | log_every_n_seconds( 73 | logging.INFO, 74 | "Inference done {}/{}. {:.4f} s / img. ETA={}".format( 75 | idx + 1, total, seconds_per_img, str(eta) 76 | ), 77 | n=5, 78 | name='detectron2' 79 | ) 80 | 81 | if cfg.DEV_RUN and idx==2: 82 | break 83 | 84 | 85 | # Measure the time only for this worker (before the synchronization barrier) 86 | total_time = time.perf_counter() - start_time 87 | total_time_str = str(datetime.timedelta(seconds=total_time)) 88 | # NOTE this format is parsed by grep 89 | logger.info( 90 | "Total inference time: {} ({:.6f} s / img per device, on {} devices)".format( 91 | total_time_str, total_time / (total - num_warmup), num_devices 92 | ) 93 | ) 94 | total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time))) 95 | logger.info( 96 | "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format( 97 | total_compute_time_str, total_compute_time / (total - num_warmup), num_devices 98 | ) 99 | ) 100 | 101 | results = evaluator.evaluate() 102 | # An evaluator may return None when not in main process. 103 | # Replace it by an empty dict instead to make it easier for downstream code to handle 104 | if results is None: 105 | results = {} 106 | return results 107 | 108 | # _LOG_COUNTER = Counter() 109 | # _LOG_TIMER = {} 110 | 111 | # def log_every_n_seconds(lvl, msg, n=1, *, name=None): 112 | # """ 113 | # Log no more than once per n seconds. 114 | 115 | # Args: 116 | # lvl (int): the logging level 117 | # msg (str): 118 | # n (int): 119 | # name (str): name of the logger to use. Will use the caller's module by default. 
120 | # """ 121 | # caller_module, key = _find_caller() 122 | # last_logged = _LOG_TIMER.get(key, None) 123 | # current_time = time.time() 124 | # if last_logged is None or current_time - last_logged >= n: 125 | # logging.getLogger('detectron2').log(lvl, msg) 126 | # _LOG_TIMER[key] = current_time 127 | 128 | @contextmanager 129 | def inference_context(model): 130 | """ 131 | A context where the model is temporarily changed to eval mode, 132 | and restored to previous mode afterwards. 133 | 134 | Args: 135 | model: a torch Module 136 | """ 137 | training_mode = model.training 138 | model.eval() 139 | yield 140 | model.train(training_mode) 141 | -------------------------------------------------------------------------------- /evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def intersect_2d(x1, x2): 4 | """ 5 | Given two arrays [m1, n], [m2,n], returns a [m1, m2] array where each entry is True if those 6 | rows match. 7 | :param x1: [m1, n] numpy array 8 | :param x2: [m2, n] numpy array 9 | :return: [m1, m2] bool array of the intersections 10 | """ 11 | if x1.shape[1] != x2.shape[1]: 12 | raise ValueError("Input arrays must have same #columns") 13 | 14 | # This performs a matrix multiplication-esque thing between the two arrays 15 | # Instead of summing, we want the equality, so we reduce in that way 16 | res = (x1[..., None] == x2.T[None, ...]).all(1) 17 | return res 18 | 19 | def argsort_desc(scores): 20 | """ 21 | Returns the indices that sort scores descending in a smart way 22 | :param scores: Numpy array of arbitrary size 23 | :return: an array of size [numel(scores), dim(scores)] where each row is the index you'd 24 | need to get the score. 25 | """ 26 | return np.column_stack(np.unravel_index(np.argsort(-scores.ravel()), scores.shape)) 27 | -------------------------------------------------------------------------------- /modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .meta_arch import * -------------------------------------------------------------------------------- /modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbone import * -------------------------------------------------------------------------------- /modeling/backbone/backbone.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from typing import Dict, List 4 | 5 | import numpy as np 6 | import torch 7 | import torch.distributed as dist 8 | import torch.nn.functional as F 9 | from torch import nn 10 | 11 | from detectron2.layers import ShapeSpec 12 | from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess 13 | from collections import OrderedDict 14 | import torchvision 15 | from torch import nn 16 | from torchvision.models._utils import IntermediateLayerGetter 17 | from ..transformer.util.misc import NestedTensor, is_main_process 18 | 19 | class FrozenBatchNorm2d(torch.nn.Module): 20 | """ 21 | BatchNorm2d where the batch statistics and the affine parameters are fixed. 22 | Copy-paste from torchvision.misc.ops with added eps before rqsrt, 23 | without which any other models than torchvision.models.resnet[18,34,50,101] 24 | produce nans. 
25 | """ 26 | 27 | def __init__(self, n): 28 | super(FrozenBatchNorm2d, self).__init__() 29 | self.register_buffer("weight", torch.ones(n)) 30 | self.register_buffer("bias", torch.zeros(n)) 31 | self.register_buffer("running_mean", torch.zeros(n)) 32 | self.register_buffer("running_var", torch.ones(n)) 33 | 34 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, 35 | missing_keys, unexpected_keys, error_msgs): 36 | num_batches_tracked_key = prefix + 'num_batches_tracked' 37 | if num_batches_tracked_key in state_dict: 38 | del state_dict[num_batches_tracked_key] 39 | 40 | super(FrozenBatchNorm2d, self)._load_from_state_dict( 41 | state_dict, prefix, local_metadata, strict, 42 | missing_keys, unexpected_keys, error_msgs) 43 | 44 | def forward(self, x): 45 | # move reshapes to the beginning 46 | # to make it fuser-friendly 47 | w = self.weight.reshape(1, -1, 1, 1) 48 | b = self.bias.reshape(1, -1, 1, 1) 49 | rv = self.running_var.reshape(1, -1, 1, 1) 50 | rm = self.running_mean.reshape(1, -1, 1, 1) 51 | eps = 1e-5 52 | scale = w * (rv + eps).rsqrt() 53 | bias = b - rm * scale 54 | return x * scale + bias 55 | 56 | class MaskedBackbone(nn.Module): 57 | """ This is a thin wrapper around D2's backbone to provide padding masking""" 58 | 59 | def __init__(self, cfg): 60 | super().__init__() 61 | self.backbone = build_backbone(cfg) 62 | backbone_shape = self.backbone.output_shape() 63 | self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()] 64 | self.num_channels = backbone_shape[list(backbone_shape.keys())[-1]].channels 65 | 66 | def forward(self, images): 67 | features = self.backbone(images.tensor) 68 | masks = self.mask_out_padding( 69 | [features_per_level.shape for features_per_level in features.values()], 70 | images.image_sizes, 71 | images.tensor.device, 72 | ) 73 | assert len(features) == len(masks) 74 | for i, k in enumerate(features.keys()): 75 | features[k] = NestedTensor(features[k], masks[i]) 76 | return features 77 | 78 | def mask_out_padding(self, feature_shapes, image_sizes, device): 79 | masks = [] 80 | assert len(feature_shapes) == len(self.feature_strides) 81 | for idx, shape in enumerate(feature_shapes): 82 | N, _, H, W = shape 83 | masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device) 84 | for img_idx, (h, w) in enumerate(image_sizes): 85 | masks_per_feature_level[ 86 | img_idx, 87 | : int(np.ceil(float(h) / self.feature_strides[idx])), 88 | : int(np.ceil(float(w) / self.feature_strides[idx])), 89 | ] = 0 90 | masks.append(masks_per_feature_level) 91 | return masks 92 | 93 | class DeformableDETRMaskedBackbone(nn.Module): 94 | """ This is a thin wrapper around D2's backbone to provide padding masking""" 95 | 96 | def __init__(self, cfg, return_interm_layers=False): 97 | super().__init__() 98 | self.backbone = build_backbone(cfg) 99 | backbone_shape = self.backbone.output_shape() 100 | self.feature_strides = [backbone_shape[f].stride for f in backbone_shape.keys()] 101 | self.num_channels = backbone_shape[list(backbone_shape.keys())[-1]].channels 102 | if return_interm_layers: 103 | # return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 104 | return_layers = {"res3": "0", "res4": "1", "res5": "2"} 105 | self.strides = [8, 16, 32] 106 | self.num_channels = [512, 1024, 2048] 107 | else: 108 | return_layers = {'res5': "0"} 109 | self.strides = [32] 110 | self.num_channels = [2048] 111 | 112 | def forward(self, images): 113 | features = self.backbone(images.tensor) 114 | masks = 
self.mask_out_padding( 115 | [features_per_level.shape for features_per_level in features.values()], 116 | images.image_sizes, 117 | images.tensor.device, 118 | ) 119 | assert len(features) == len(masks) 120 | for i, k in enumerate(features.keys()): 121 | features[k] = NestedTensor(features[k], masks[i]) 122 | return features 123 | 124 | def mask_out_padding(self, feature_shapes, image_sizes, device): 125 | masks = [] 126 | assert len(feature_shapes) == len(self.feature_strides) 127 | for idx, shape in enumerate(feature_shapes): 128 | N, _, H, W = shape 129 | masks_per_feature_level = torch.ones((N, H, W), dtype=torch.bool, device=device) 130 | for img_idx, (h, w) in enumerate(image_sizes): 131 | masks_per_feature_level[ 132 | img_idx, 133 | : int(np.ceil(float(h) / self.feature_strides[idx])), 134 | : int(np.ceil(float(w) / self.feature_strides[idx])), 135 | ] = 0 136 | masks.append(masks_per_feature_level) 137 | return masks 138 | 139 | class Joiner(nn.Sequential): 140 | def __init__(self, backbone, position_embedding): 141 | super().__init__(backbone, position_embedding) 142 | 143 | def forward(self, tensor_list: NestedTensor): 144 | xs = self[0](tensor_list) 145 | out: List[NestedTensor] = [] 146 | pos = [] 147 | for name, x in xs.items(): 148 | out.append(x) 149 | # position encoding 150 | pos.append(self[1](x).to(x.tensors.dtype)) 151 | 152 | return out, pos 153 | 154 | class DeformableDETRJoiner(nn.Sequential): 155 | def __init__(self, backbone, position_embedding): 156 | super().__init__(backbone, position_embedding) 157 | self.strides = backbone.strides 158 | self.num_channels = backbone.num_channels 159 | 160 | def forward(self, tensor_list: NestedTensor): 161 | xs = self[0](tensor_list) 162 | out: List[NestedTensor] = [] 163 | pos = [] 164 | for name, x in sorted(xs.items()): 165 | out.append(x) 166 | 167 | # position encoding 168 | for x in out: 169 | pos.append(self[1](x).to(x.tensors.dtype)) 170 | 171 | return out, pos 172 | 173 | class BackboneBase(nn.Module): 174 | def __init__(self, backbone: nn.Module, train_backbone: bool, num_channels: int, return_interm_layers: bool): 175 | super().__init__() 176 | for name, parameter in backbone.named_parameters(): 177 | if not train_backbone or 'layer2' not in name and 'layer3' not in name and 'layer4' not in name: 178 | parameter.requires_grad_(False) 179 | if return_interm_layers: 180 | return_layers = {"layer1": "0", "layer2": "1", "layer3": "2", "layer4": "3"} 181 | else: 182 | return_layers = {'layer4': "0"} 183 | self.body = IntermediateLayerGetter(backbone, return_layers=return_layers) 184 | self.num_channels = num_channels 185 | 186 | def forward(self, tensor_list): 187 | xs = self.body(tensor_list.tensor) 188 | out: Dict[str, NestedTensor] = {} 189 | for name, x in xs.items(): 190 | m = tensor_list.mask 191 | assert m is not None 192 | mask = F.interpolate(m[None].float(), size=x.shape[-2:]).to(torch.bool)[0] 193 | out[name] = NestedTensor(x, mask) 194 | return out 195 | 196 | class Backbone(BackboneBase): 197 | """ResNet backbone with frozen BatchNorm.""" 198 | def __init__(self, name: str, 199 | train_backbone: bool, 200 | return_interm_layers: bool, 201 | dilation: bool): 202 | backbone = getattr(torchvision.models, name)( 203 | replace_stride_with_dilation=[False, False, dilation], 204 | pretrained=is_main_process(), norm_layer=FrozenBatchNorm2d) 205 | num_channels = 512 if name in ('resnet18', 'resnet34') else 2048 206 | super().__init__(backbone, train_backbone, num_channels, return_interm_layers) 207 | 208 | 209 | 210 
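# --- Illustrative sanity check (not part of the original file) ---
# A minimal sketch showing that FrozenBatchNorm2d matches nn.BatchNorm2d in eval
# mode once the running statistics and affine parameters are copied over. It only
# relies on `torch`/`nn` imported at the top of this file and the class above.
if __name__ == "__main__":
    bn = nn.BatchNorm2d(8).eval()
    # give the reference layer non-trivial statistics and affine parameters
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-0.5, 0.5)

    frozen = FrozenBatchNorm2d(8)
    # FrozenBatchNorm2d's _load_from_state_dict hook silently drops num_batches_tracked
    frozen.load_state_dict(bn.state_dict())

    x = torch.randn(2, 8, 16, 16)
    print(torch.allclose(bn(x), frozen(x), atol=1e-5))  # expected: True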
| -------------------------------------------------------------------------------- /modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .detr import * -------------------------------------------------------------------------------- /modeling/meta_arch/detr.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import math 3 | from multiprocessing import Condition 4 | from turtle import back 5 | from typing import List 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | import torch.nn.functional as F 11 | from scipy.optimize import linear_sum_assignment 12 | from torch import nn 13 | 14 | from detectron2.layers import ShapeSpec 15 | from detectron2.modeling import META_ARCH_REGISTRY, build_backbone, detector_postprocess 16 | from detectron2.structures import Boxes, ImageList, Instances, BitMasks, PolygonMasks, pairwise_iou 17 | from detectron2.utils.logger import log_first_n 18 | from detectron2.data import MetadataCatalog 19 | from fvcore.nn import giou_loss, smooth_l1_loss 20 | from ..transformer import build_detr, build_criterion, build_transformer, build_matcher, build_position_encoding 21 | from ..transformer.segmentation import DETRsegm, PostProcessPanoptic, PostProcessSegm 22 | from ..transformer.util.utils import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh, NestedTensor, convert_coco_poly_to_mask 23 | from ..transformer.util import box_ops 24 | from ..backbone import MaskedBackbone, Joiner, Backbone, DeformableDETRMaskedBackbone, DeformableDETRJoiner 25 | from detectron2.layers import batched_nms 26 | 27 | __all__ = ["Detr"] 28 | 29 | 30 | @META_ARCH_REGISTRY.register() 31 | class Detr(nn.Module): 32 | """ 33 | Implement Detr 34 | """ 35 | 36 | def __init__(self, cfg): 37 | super().__init__() 38 | 39 | self.device = torch.device(cfg.MODEL.DEVICE) 40 | self.num_classes = cfg.MODEL.DETR.NUM_CLASSES 41 | self.num_relation_classes = cfg.MODEL.DETR.NUM_RELATION_CLASSES 42 | self.mask_on = cfg.MODEL.MASK_ON 43 | self.use_gt_box = cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX 44 | self.use_gt_label = cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL 45 | self.later_nms_thres = cfg.MODEL.DETR.LATER_NMS_THRESHOLD 46 | self.use_freq_bias = cfg.MODEL.DETR.USE_FREQ_BIAS 47 | self.test_index = cfg.MODEL.DETR.TEST_INDEX 48 | hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM 49 | num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES 50 | num_relation_queries = cfg.MODEL.DETR.NUM_RELATION_QUERIES 51 | create_bg_pairs = cfg.MODEL.DETR.CREATE_BG_PAIRS 52 | # Transformer parameters: 53 | nheads = cfg.MODEL.DETR.NHEADS 54 | dropout = cfg.MODEL.DETR.DROPOUT 55 | dim_feedforward = cfg.MODEL.DETR.DIM_FEEDFORWARD 56 | enc_layers = cfg.MODEL.DETR.ENC_LAYERS 57 | dec_layers = cfg.MODEL.DETR.DEC_LAYERS 58 | obj_dec_layers = cfg.MODEL.DETR.OBJECT_DEC_LAYERS 59 | pre_norm = cfg.MODEL.DETR.PRE_NORM 60 | 61 | # Loss parameters: 62 | giou_weight = cfg.MODEL.DETR.GIOU_WEIGHT 63 | l1_weight = cfg.MODEL.DETR.L1_WEIGHT 64 | deep_supervision = cfg.MODEL.DETR.DEEP_SUPERVISION 65 | no_object_weight = cfg.MODEL.DETR.NO_OBJECT_WEIGHT 66 | no_rel_weight = cfg.MODEL.DETR.NO_REL_WEIGHT 67 | cost_class = cfg.MODEL.DETR.COST_CLASS 68 | nms_weight = cfg.MODEL.DETR.NMS_WEIGHT 69 | cost_selection = cfg.MODEL.DETR.COST_SELECTION 70 | beta = cfg.MODEL.DETR.BETA 71 | matcher_topk = cfg.MODEL.DETR.MATCHER_TOPK 72 | self.nms_thresh = cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST 73 | 74 | d2_backbone = 
MaskedBackbone(cfg) 75 | position_embedding = build_position_encoding(cfg.MODEL.DETR.POSITION_EMBEDDING, hidden_dim) 76 | backbone = Joiner(d2_backbone, position_embedding) 77 | backbone.num_channels = d2_backbone.num_channels 78 | 79 | transformer = build_transformer(cfg.MODEL.DETR.TRANSFORMER, d_model=hidden_dim, dropout=dropout, nhead=nheads, dim_feedforward=dim_feedforward, num_encoder_layers=enc_layers, num_decoder_layers=dec_layers, normalize_before=pre_norm, return_intermediate_dec=deep_supervision, num_object_decoder_layers=obj_dec_layers, num_classes=self.num_classes, num_relation_classes=self.num_relation_classes, beta=beta) 80 | 81 | self.detr = build_detr(cfg.MODEL.DETR.NAME, backbone, transformer, num_classes=self.num_classes, num_queries=num_queries, aux_loss=deep_supervision, use_gt_box=self.use_gt_box, use_gt_label = self.use_gt_label, num_relation_queries=num_relation_queries) 82 | if self.mask_on: 83 | frozen_weights = cfg.MODEL.DETR.FROZEN_WEIGHTS 84 | if frozen_weights != '': 85 | print("LOAD pre-trained weights") 86 | weight = torch.load(frozen_weights, map_location=lambda storage, loc: storage)['model'] 87 | new_weight = {} 88 | for k, v in weight.items(): 89 | if 'detr.' in k: 90 | new_weight[k.replace('detr.', '')] = v 91 | else: 92 | print(f"Skipping loading weight {k} from frozen model") 93 | del weight 94 | self.detr.load_state_dict(new_weight) 95 | del new_weight 96 | self.detr = DETRsegm(self.detr, freeze_detr=(frozen_weights != '')) 97 | self.seg_postprocess = PostProcessSegm 98 | if cfg.MODEL.DETR.FROZEN_WEIGHTS != '': 99 | self.load_detr_from_pretrained(cfg.MODEL.DETR.FROZEN_WEIGHTS) 100 | self.detr.to(self.device) 101 | 102 | # building criterion 103 | matcher = build_matcher(cfg.MODEL.DETR.MATCHER, cost_class=cost_class, cost_bbox=l1_weight, cost_giou=giou_weight, topk=matcher_topk) 104 | weight_dict = {"loss_ce": cost_class, "loss_bbox": l1_weight, 'loss_ce_subject': cost_class, 'loss_ce_object': cost_class, 'loss_bbox_subject': l1_weight, 'loss_bbox_object': l1_weight, 'loss_giou_subject': giou_weight, 'loss_giou_object': giou_weight, 'loss_relation': 1, 'loss_bbox_relation': l1_weight, 'loss_giou_relation':giou_weight, 'loss_nms':nms_weight, 'loss_selection_subject': cost_selection, 'loss_selection_object': cost_selection} 105 | weight_dict["loss_giou"] = giou_weight 106 | if deep_supervision: 107 | aux_weight_dict = {} 108 | for i in range(max(dec_layers, obj_dec_layers) - 1): 109 | aux_weight_dict.update({k + f"_{i}": v for k, v in weight_dict.items()}) 110 | weight_dict.update(aux_weight_dict) 111 | print (weight_dict) 112 | losses = ["labels", "boxes", "cardinality"] 113 | if self.mask_on: 114 | losses += ["masks"] 115 | statistics = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]).statistics 116 | self.criterion = build_criterion(cfg.MODEL.DETR.CRITERION, self.num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses, use_gt_box=self.use_gt_box, use_gt_label=self.use_gt_label, num_relation_classes=self.num_relation_classes, intersection_iou_threshold=cfg.MODEL.DETR.INTERSECTION_IOU_THRESHOLD, intersection_iou_lambda=cfg.MODEL.DETR.INTERSECTION_IOU_LAMBDA, intersection_loss=cfg.MODEL.DETR.INTERSECTION_LOSS, rel_eos_coef=no_rel_weight, statistics=statistics, reweight_relations=cfg.MODEL.DETR.REWEIGHT_RELATIONS, reweight_rel_eos_coef=cfg.MODEL.DETR.REWEIGHT_REL_EOS_COEF, neg_rel_fraction=cfg.MODEL.DETR.NEGATIVE_RELATION_FRACTION, max_rel_pairs=cfg.MODEL.DETR.MAX_RELATION_PAIRS, 
use_reweight_log=cfg.MODEL.DETR.REWEIGHT_USE_LOG, focal_alpha=cfg.MODEL.DETR.FOCAL_ALPHA, create_bg_pairs=create_bg_pairs, oversample_param=cfg.MODEL.DETR.OVERSAMPLE_PARAM, undersample_param=cfg.MODEL.DETR.UNDERSAMPLE_PARAM) 117 | self.criterion.to(self.device) 118 | 119 | pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1) 120 | pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1) 121 | self.normalizer = lambda x: (x - pixel_mean) / pixel_std 122 | 123 | self.to(self.device) 124 | self._freeze_layers(layers=cfg.MODEL.DETR.FREEZE_LAYERS) 125 | pytorch_total_params = sum(p.numel() for p in self.parameters()) 126 | print ("Number of Parameters:", pytorch_total_params) 127 | 128 | def _freeze_layers(self, layers): 129 | # Freeze layers 130 | for name, param in self.named_parameters(): 131 | if any(layer in name for layer in layers): 132 | logging.getLogger('detectron2').log(logging.WARN, "Freezed Layer: {}".format(name)) 133 | param.requires_grad = False 134 | 135 | def load_detr_from_pretrained(self, path): 136 | print("Loading DETR checkpoint from pretrained: ", path) 137 | pretrained_detr = torch.load(path)['model'] 138 | pretrained_detr_without_class_head = {k: v for k, v in pretrained_detr.items() if 'class_embed' not in k} 139 | self.detr.load_state_dict(pretrained_detr_without_class_head, strict=False) 140 | 141 | def forward(self, batched_inputs): 142 | """ 143 | Args: 144 | batched_inputs: a list, batched outputs of :class:`DatasetMapper` . 145 | Each item in the list contains the inputs for one image. 146 | For now, each item in the list is a dict that contains: 147 | * image: Tensor, image in (C, H, W) format. 148 | * instances: Instances 149 | Other information that's included in the original dicts, such as: 150 | * "height", "width" (int): the output resolution of the model, used in inference. 151 | See :meth:`postprocess` for details. 152 | Returns: 153 | dict[str: Tensor]: 154 | mapping from a named loss to a tensor storing the loss. Used during training only. 
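            At inference time the model instead returns a list with one dict per image,
            each containing a post-processed "instances" field (see the else-branch below).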
155 | """ 156 | images = self.preprocess_image(batched_inputs) 157 | output = self.detr(images) 158 | 159 | if self.training: 160 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 161 | 162 | targets = self.prepare_targets(gt_instances) 163 | loss_dict = self.criterion(output, targets) 164 | weight_dict = self.criterion.weight_dict 165 | for k in loss_dict.keys(): 166 | if k in weight_dict: 167 | loss_dict[k] *= weight_dict[k] 168 | return loss_dict 169 | else: 170 | box_cls = output["pred_logits"] 171 | box_pred = output["pred_boxes"] 172 | mask_pred = output["pred_masks"] if self.mask_on else None 173 | results = self.inference(box_cls, box_pred, mask_pred, images.image_sizes) 174 | processed_results = [] 175 | for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): 176 | height = input_per_image.get("height", image_size[0]) 177 | width = input_per_image.get("width", image_size[1]) 178 | r = detector_postprocess(results_per_image, height, width) 179 | processed_results.append({"instances": r}) 180 | return processed_results 181 | 182 | def prepare_targets(self, targets): 183 | new_targets = [] 184 | for targets_per_image in targets: 185 | h, w = targets_per_image.image_size 186 | image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) 187 | gt_classes = targets_per_image.gt_classes 188 | gt_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy 189 | gt_boxes = box_xyxy_to_cxcywh(gt_boxes) 190 | new_targets.append({"labels": gt_classes, "boxes": gt_boxes}) 191 | if self.mask_on and hasattr(targets_per_image, 'gt_masks'): 192 | gt_masks = targets_per_image.gt_masks 193 | gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w) 194 | new_targets[-1].update({'masks': gt_masks}) 195 | return new_targets 196 | 197 | def inference(self, box_cls, box_pred, mask_pred, image_sizes): 198 | """ 199 | Arguments: 200 | box_cls (Tensor): tensor of shape (batch_size, num_queries, K). 201 | The tensor predicts the classification probability for each query. 202 | box_pred (Tensor): tensors of shape (batch_size, num_queries, 4). 203 | The tensor predicts 4-vector (x,y,w,h) box 204 | regression values for every queryx 205 | image_sizes (List[torch.Size]): the input image sizes 206 | Returns: 207 | results (List[Instances]): a list of #images elements. 208 | """ 209 | assert len(box_cls) == len(image_sizes) 210 | results = [] 211 | 212 | # For each box we assign the best class or the second best if the best on is `no_object`. 
213 | scores, labels = F.softmax(box_cls, dim=-1)[:, :, :-1].max(-1) 214 | 215 | for i, (scores_per_image, labels_per_image, box_pred_per_image, image_size) in enumerate(zip( 216 | scores, labels, box_pred, image_sizes 217 | )): 218 | result = Instances(image_size) 219 | result.pred_boxes = Boxes(box_cxcywh_to_xyxy(box_pred_per_image)) 220 | 221 | result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0]) 222 | if self.mask_on: 223 | mask = F.interpolate(mask_pred[i].unsqueeze(0), size=image_size, mode='bilinear', align_corners=False) 224 | mask = mask[0].sigmoid() > 0.5 225 | B, N, H, W = mask_pred.shape 226 | mask = BitMasks(mask.cpu()).crop_and_resize(result.pred_boxes.tensor.cpu(), 32) 227 | result.pred_masks = mask.unsqueeze(1).to(mask_pred[0].device) 228 | 229 | result.scores = scores_per_image 230 | result.pred_classes = labels_per_image 231 | results.append(result) 232 | return results 233 | 234 | def preprocess_image(self, batched_inputs): 235 | """ 236 | Normalize, pad and batch the input images. 237 | """ 238 | images = [self.normalizer(x["image"].to(self.device)) for x in batched_inputs] 239 | images = ImageList.from_tensors(images) 240 | return images 241 | 242 | @META_ARCH_REGISTRY.register() 243 | class IterativeRelationDetr(Detr): 244 | def forward(self, batched_inputs): 245 | """ 246 | Args: 247 | batched_inputs: a list, batched outputs of :class:`DatasetMapper` . 248 | Each item in the list contains the inputs for one image. 249 | For now, each item in the list is a dict that contains: 250 | * image: Tensor, image in (C, H, W) format. 251 | * instances: Instances 252 | Other information that's included in the original dicts, such as: 253 | * "height", "width" (int): the output resolution of the model, used in inference. 254 | See :meth:`postprocess` for details. 255 | Returns: 256 | dict[str: Tensor]: 257 | mapping from a named loss to a tensor storing the loss. Used during training only. 258 | """ 259 | images = self.preprocess_image(batched_inputs) 260 | output = self.detr(images) 261 | 262 | if self.training: 263 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 264 | gt_relations = [x["relations"].to(self.device) for x in batched_inputs] 265 | targets = self.prepare_targets((gt_instances, gt_relations)) 266 | 267 | loss_dict = self.criterion(output, targets) 268 | weight_dict = self.criterion.weight_dict 269 | for k in loss_dict.keys(): 270 | if k in weight_dict: 271 | loss_dict[k] *= weight_dict[k] 272 | return loss_dict 273 | else: 274 | results = self.inference(output, images.image_sizes) 275 | processed_results = [] 276 | for results_per_image, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): 277 | height = input_per_image.get("height", image_size[0]) 278 | width = input_per_image.get("width", image_size[1]) 279 | r = detector_postprocess(results_per_image, height, width) 280 | processed_results.append({"instances": r, 281 | "rel_pair_idxs": results_per_image._rel_pair_idxs, 282 | "pred_rel_scores": results_per_image._pred_rel_scores 283 | }) 284 | return processed_results 285 | 286 | def inference(self, output, image_sizes): 287 | """ 288 | Arguments: 289 | box_cls (Tensor): tensor of shape (batch_size, num_queries, K). 290 | The tensor predicts the classification probability for each query. 291 | box_pred (Tensor): tensors of shape (batch_size, num_queries, 4). 
292 | The tensor predicts 4-vector (x,y,w,h) box 293 | regression values for every queryx 294 | image_sizes (List[torch.Size]): the input image sizes 295 | Returns: 296 | results (List[Instances]): a list of #images elements. 297 | """ 298 | results = [] 299 | if self.test_index == -1: 300 | logits_r = F.softmax(output['relation_logits'], -1) 301 | 302 | # For each box we assign the best class or the second best if the best on is `no_object`. 303 | scores_s, labels_s = F.softmax(output['relation_subject_logits'], -1)[:, :, :-1].max(-1) 304 | scores_o, labels_o = F.softmax(output['relation_object_logits'], -1)[:, :, :-1].max(-1) 305 | scores_r, labels_r = logits_r[:, :, :-1].max(-1) 306 | 307 | box_s = output['relation_subject_boxes'] 308 | box_o = output['relation_object_boxes'] 309 | else: 310 | logits_r = F.softmax(output['aux_outputs_r'][self.test_index]['pred_logits'], -1) 311 | # For each box we assign the best class or the second best if the best on is `no_object`. 312 | scores_s, labels_s = F.softmax(output['aux_outputs_r_sub'][self.test_index]['pred_logits'], -1)[:, :, :-1].max(-1) 313 | scores_o, labels_o = F.softmax(output['aux_outputs_r_obj'][self.test_index]['pred_logits'], -1)[:, :, :-1].max(-1) 314 | scores_r, labels_r = logits_r[:, :, :-1].max(-1) 315 | 316 | box_s = output['aux_outputs_r_sub'][self.test_index]['pred_boxes'] 317 | box_o = output['aux_outputs_r_obj'][self.test_index]['pred_boxes'] 318 | 319 | for i, (scores_per_image_s, labels_per_image_s, box_per_image_s, scores_per_image_o, labels_per_image_o, box_per_image_o, scores_per_image_r, labels_per_image_r, logits_per_image_r, image_size) in enumerate(zip( 320 | scores_s, labels_s, box_s, scores_o, labels_o, box_o, scores_r, labels_r, logits_r, image_sizes 321 | )): 322 | 323 | image_boxes = Boxes(box_cxcywh_to_xyxy(torch.cat([box_per_image_s, box_per_image_o]))) 324 | image_scores = torch.cat([scores_per_image_s, scores_per_image_o]) 325 | image_pred_classes = torch.cat([labels_per_image_s, labels_per_image_o]) 326 | keep = batched_nms(image_boxes.tensor, image_scores, image_pred_classes, self.nms_thresh) 327 | keep_classes = image_pred_classes[keep] 328 | ious = pairwise_iou(image_boxes, image_boxes[keep]) 329 | iou_assignments = torch.zeros_like(image_pred_classes) 330 | for class_id in torch.unique(keep_classes): 331 | curr_indices = torch.where(image_pred_classes == class_id)[0] 332 | curr_keep_indices = torch.where(keep_classes == class_id)[0] 333 | curr_ious = ious[curr_indices][:, curr_keep_indices] 334 | curr_iou_assignment = curr_keep_indices[curr_ious.argmax(-1)] 335 | iou_assignments[curr_indices] = curr_iou_assignment 336 | 337 | result = Instances(image_size) 338 | result.pred_boxes = image_boxes[keep] 339 | result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0]) 340 | result.scores = image_scores[keep] 341 | result.pred_classes = image_pred_classes[keep] 342 | 343 | # rel_pair_indexer = torch.arange(labels_per_image_s.size(0)).to(labels_per_image_s.device) 344 | # rel_pair_idx = torch.stack([rel_pair_indexer, rel_pair_indexer + labels_per_image_s.size(0)], 1) 345 | rel_pair_idx = torch.stack(torch.split(iou_assignments, labels_per_image_s.size(0)), 1) 346 | 347 | triple_scores = scores_per_image_r * scores_per_image_s * scores_per_image_o 348 | _, sorting_idx = torch.sort(triple_scores, descending=True) 349 | rel_pair_idx = rel_pair_idx[sorting_idx] 350 | rel_class_prob = logits_per_image_r[sorting_idx] 351 | rel_labels = labels_per_image_r[sorting_idx] 352 | 353 | triplets = 
torch.cat((rel_pair_idx, rel_labels.unsqueeze(-1)), -1) 354 | unique_triplets = {} 355 | keep_triplet = torch.zeros_like(rel_labels) 356 | for idx, triplet in enumerate(triplets): 357 | if "{}-{}-{}".format(triplet[0], triplet[1], triplet[2]) not in unique_triplets: 358 | unique_triplets[ "{}-{}-{}".format(triplet[0], triplet[1], triplet[2])] = 1 359 | keep_triplet[idx] = 1 360 | 361 | result._rel_pair_idxs = rel_pair_idx[keep_triplet == 1] # (#rel, 2) 362 | result._pred_rel_scores = rel_class_prob[keep_triplet == 1] # (#rel, #rel_class) 363 | result._pred_rel_labels = rel_labels[keep_triplet == 1] # (#rel, ) 364 | results.append(result) 365 | return results 366 | 367 | def boxes_union(self, boxes1, boxes2): 368 | """ 369 | Compute the union region of two set of boxes 370 | Arguments: 371 | box1: (Boxes) bounding boxes, sized [N,4]. 372 | box2: (Boxes) bounding boxes, sized [N,4]. 373 | Returns: 374 | (Boxes) union, sized [N,4]. 375 | """ 376 | assert len(boxes1) == len(boxes2) 377 | 378 | union_box = torch.cat(( 379 | torch.min(boxes1.tensor[:,:2], boxes2.tensor[:,:2]), 380 | torch.max(boxes1.tensor[:,2:], boxes2.tensor[:,2:]) 381 | ),dim=1) 382 | return Boxes(union_box) 383 | 384 | def get_center_coords(self, boxes): 385 | x0, y0, x1, y1 = boxes.unbind(-1) 386 | b = [(x0 + x1) / 2, (y0 + y1) / 2] 387 | return torch.stack(b, dim=-1) 388 | 389 | def center_xyxy_to_cxcywh(self, x): 390 | x0, y0, x1, y1 = x.unbind(-1) 391 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 392 | torch.abs(x1 - x0), torch.abs(y1 - y0)] 393 | return torch.stack(b, dim=-1) 394 | 395 | def prepare_targets(self, targets, box_threshold=1e-5): 396 | new_targets = [] 397 | for image_idx, (targets_per_image, relations_per_image) in enumerate(zip(targets[0], targets[1])): 398 | h, w = targets_per_image.image_size 399 | image_size_xyxy = torch.as_tensor([w, h, w, h], dtype=torch.float, device=self.device) 400 | subject_boxes = targets_per_image.gt_boxes[relations_per_image[:, 0]] 401 | object_boxes = targets_per_image.gt_boxes[relations_per_image[:, 1]] 402 | 403 | gt_boxes = self.boxes_union(subject_boxes, object_boxes) 404 | gt_boxes = gt_boxes.tensor / image_size_xyxy 405 | gt_boxes = box_xyxy_to_cxcywh(gt_boxes) 406 | 407 | gt_subject_classes = targets_per_image.gt_classes[relations_per_image[:, 0]] 408 | gt_subject_boxes = subject_boxes.tensor / image_size_xyxy 409 | gt_subject_boxes = box_xyxy_to_cxcywh(gt_subject_boxes) 410 | 411 | gt_object_classes = targets_per_image.gt_classes[relations_per_image[:, 1]] 412 | gt_object_boxes = object_boxes.tensor / image_size_xyxy 413 | gt_object_boxes = box_xyxy_to_cxcywh(gt_object_boxes) 414 | 415 | gt_combined_classes = targets_per_image.gt_classes 416 | gt_combined_boxes = targets_per_image.gt_boxes.tensor / image_size_xyxy 417 | gt_combined_boxes = box_xyxy_to_cxcywh(gt_combined_boxes) 418 | 419 | subject_boxes_center = self.get_center_coords(subject_boxes.tensor) 420 | object_boxes_center = self.get_center_coords(object_boxes.tensor) 421 | center_boxes = torch.cat([subject_boxes_center, object_boxes_center], -1) 422 | center_boxes = center_boxes / image_size_xyxy 423 | center_boxes = self.center_xyxy_to_cxcywh(center_boxes) 424 | 425 | # Remove degenerate boxes 426 | center_boxes_xyxy = Boxes(box_cxcywh_to_xyxy(center_boxes)) 427 | center_boxes_xyxy.scale(scale_x=targets_per_image.image_size[1], scale_y=targets_per_image.image_size[0]) 428 | center_masks = center_boxes_xyxy.nonempty(threshold=box_threshold) 429 | center_boxes[~center_masks] = gt_subject_boxes[~center_masks] 430 | 431 
| gt_classes = relations_per_image[:, 2] 432 | new_targets.append({"labels": gt_classes, "boxes": gt_boxes, 'subject_boxes': gt_subject_boxes, 'object_boxes': gt_object_boxes, 'combined_boxes': gt_combined_boxes, 'subject_labels': gt_subject_classes, 'object_labels': gt_object_classes, 'combined_labels': gt_combined_classes, 'image_relations': relations_per_image, 'relation_boxes':center_boxes, 'relation_labels':relations_per_image[:, 2]}) 433 | return new_targets 434 | 435 | -------------------------------------------------------------------------------- /modeling/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .detr import * 2 | from .criterion import * 3 | from .transformer import * 4 | from .matcher import * 5 | from .positional_encoding import * -------------------------------------------------------------------------------- /modeling/transformer/detr.py: -------------------------------------------------------------------------------- 1 | """ 2 | DETR model and criterion classes. 3 | """ 4 | from multiprocessing import Condition 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | from .util.misc import (NestedTensor, nested_tensor_from_tensor_list, 10 | accuracy, get_world_size, interpolate, 11 | is_dist_avail_and_initialized, inverse_sigmoid) 12 | 13 | import copy 14 | import numpy as np 15 | from detectron2.utils.registry import Registry 16 | import math 17 | 18 | DETR_REGISTRY = Registry("DETR_REGISTRY") 19 | 20 | 21 | @DETR_REGISTRY.register() 22 | class DETR(nn.Module): 23 | """ This is the DETR module that performs object detection """ 24 | def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False, use_gt_box=False, use_gt_label=False, **kwargs): 25 | """ Initializes the model. 26 | Parameters: 27 | backbone: torch module of the backbone to be used. See backbone.py 28 | transformer: torch module of the transformer architecture. See transformer.py 29 | num_classes: number of object classes 30 | num_queries: number of object queries, ie detection slot. This is the maximal number of objects 31 | DETR can detect in a single image. For COCO, we recommend 100 queries. 32 | aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. 33 | """ 34 | super().__init__() 35 | self.num_queries = num_queries 36 | self.transformer = transformer 37 | hidden_dim = transformer.d_model 38 | self.class_embed = nn.Linear(hidden_dim, num_classes + 1) 39 | self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) 40 | self.query_embed = nn.Embedding(num_queries, hidden_dim) 41 | self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1) 42 | self.backbone = backbone 43 | self.aux_loss = aux_loss 44 | 45 | def forward(self, samples: NestedTensor): 46 | """ The forward expects a NestedTensor, which consists of: 47 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W] 48 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels 49 | It returns a dict with the following elements: 50 | - "pred_logits": the classification logits (including no-object) for all queries. 51 | Shape= [batch_size x num_queries x (num_classes + 1)] 52 | - "pred_boxes": The normalized boxes coordinates for all queries, represented as 53 | (center_x, center_y, height, width). These values are normalized in [0, 1], 54 | relative to the size of each individual image (disregarding possible padding). 
55 | See PostProcess for information on how to retrieve the unnormalized bounding box. 56 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of 57 | dictionnaries containing the two above keys for each decoder layer. 58 | """ 59 | if isinstance(samples, (list, torch.Tensor)): 60 | samples = nested_tensor_from_tensor_list(samples) 61 | features, pos = self.backbone(samples) 62 | 63 | src, mask = features[-1].decompose() 64 | assert mask is not None 65 | hs = self.transformer(self.input_proj(src), mask, self.query_embed.weight, pos[-1])[0] 66 | 67 | outputs_class = self.class_embed(hs) 68 | outputs_coord = self.bbox_embed(hs).sigmoid() 69 | out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]} 70 | if self.aux_loss: 71 | out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord) 72 | return out 73 | 74 | @torch.jit.unused 75 | def _set_aux_loss(self, outputs_class, outputs_coord): 76 | # this is a workaround to make torchscript happy, as torchscript 77 | # doesn't support dictionary with non-homogeneous values, such 78 | # as a dict having both a Tensor and a list. 79 | return [{'pred_logits': a, 'pred_boxes': b} 80 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] 81 | 82 | @DETR_REGISTRY.register() 83 | class IterativeRelationDETR(DETR): 84 | def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False, use_gt_box=False, use_gt_label=False, **kwargs): 85 | super().__init__(backbone=backbone, transformer=transformer, num_classes=num_classes, num_queries=num_queries, aux_loss=aux_loss, use_gt_box=use_gt_box, use_gt_label=use_gt_label, **kwargs) 86 | self.relation_query_embed = nn.Embedding(num_queries, transformer.d_model) 87 | self.object_query_embed = nn.Embedding(num_queries, transformer.d_model) 88 | del self.class_embed 89 | del self.bbox_embed 90 | 91 | def forward(self, samples: NestedTensor): 92 | """ The forward expects a NestedTensor, which consists of: 93 | - samples.tensor: batched images, of shape [batch_size x 3 x H x W] 94 | - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels 95 | It returns a dict with the following elements: 96 | - "pred_logits": the classification logits (including no-object) for all queries. 97 | Shape= [batch_size x num_queries x (num_classes + 1)] 98 | - "pred_boxes": The normalized boxes coordinates for all queries, represented as 99 | (center_x, center_y, height, width). These values are normalized in [0, 1], 100 | relative to the size of each individual image (disregarding possible padding). 101 | See PostProcess for information on how to retrieve the unnormalized bounding box. 102 | - "aux_outputs": Optional, only returned when auxilary losses are activated. It is a list of 103 | dictionnaries containing the two above keys for each decoder layer. 
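            In this relation variant the returned dict actually contains "relation_logits",
            "relation_boxes", "relation_subject_logits"/"relation_subject_boxes" and
            "relation_object_logits"/"relation_object_boxes", plus "aux_outputs_r",
            "aux_outputs_r_sub" and "aux_outputs_r_obj" lists when aux_loss is enabled
            (see the dict constructed below).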
104 | """ 105 | if isinstance(samples, (list, torch.Tensor)): 106 | samples = nested_tensor_from_tensor_list(samples) 107 | features, pos = self.backbone(samples) 108 | 109 | src, mask = features[-1].decompose() 110 | assert mask is not None 111 | 112 | output = self.transformer(self.input_proj(src), mask, self.query_embed.weight, self.object_query_embed.weight, self.relation_query_embed.weight, pos[-1]) 113 | 114 | out = { 115 | 'relation_boxes': output['relation_coords'][-1], 116 | 'relation_logits': output['relation_logits'][-1], 117 | 'relation_subject_logits': output['relation_subject_logits'][-1], 118 | 'relation_object_logits': output['relation_object_logits'][-1], 119 | 'relation_subject_boxes': output['relation_subject_coords'][-1], 120 | 'relation_object_boxes': output['relation_object_coords'][-1] 121 | } 122 | 123 | if self.aux_loss: 124 | out['aux_outputs_r'] = self._set_aux_loss(output['relation_logits'], output['relation_coords']) 125 | out['aux_outputs_r_sub'] = self._set_aux_loss(output['relation_subject_logits'], output['relation_subject_coords']) 126 | out['aux_outputs_r_obj'] = self._set_aux_loss(output['relation_object_logits'], output['relation_object_coords']) 127 | 128 | return out 129 | 130 | @torch.jit.unused 131 | def _set_aux_loss(self, outputs_class, outputs_coord=None): 132 | # this is a workaround to make torchscript happy, as torchscript 133 | # doesn't support dictionary with non-homogeneous values, such 134 | # as a dict having both a Tensor and a list. 135 | if outputs_coord is not None: 136 | return [{'pred_logits': a, 'pred_boxes': b} 137 | for a, b in zip(outputs_class[:-1], outputs_coord[:-1])] 138 | else: 139 | return [{'pred_logits': a} 140 | for a in outputs_class[:-1]] 141 | 142 | 143 | 144 | class MLP(nn.Module): 145 | """ Very simple multi-layer perceptron (also called FFN)""" 146 | 147 | def __init__(self, input_dim, hidden_dim, output_dim, num_layers): 148 | super().__init__() 149 | self.num_layers = num_layers 150 | h = [hidden_dim] * (num_layers - 1) 151 | self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) 152 | 153 | def forward(self, x): 154 | for i, layer in enumerate(self.layers): 155 | x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) 156 | return x 157 | 158 | def _get_clones(module, N): 159 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 160 | 161 | def gen_sineembed_for_position(pos_tensor): 162 | # n_query, bs, _ = pos_tensor.size() 163 | # sineembed_tensor = torch.zeros(n_query, bs, 256) 164 | scale = 2 * math.pi 165 | dim_t = torch.arange(128, dtype=torch.float32, device=pos_tensor.device) 166 | dim_t = 10000 ** (2 * (dim_t // 2) / 128) 167 | x_embed = pos_tensor[:, :, 0] * scale 168 | y_embed = pos_tensor[:, :, 1] * scale 169 | pos_x = x_embed[:, :, None] / dim_t 170 | pos_y = y_embed[:, :, None] / dim_t 171 | pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2) 172 | pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2) 173 | pos = torch.cat((pos_y, pos_x), dim=2) 174 | return pos 175 | 176 | def build_detr(name, backbone, transformer, num_classes, num_queries, aux_loss=False, use_gt_box=False, use_gt_label=False, **kwargs): 177 | return DETR_REGISTRY.get(name)(backbone, transformer, num_classes, num_queries, aux_loss=aux_loss, use_gt_box=use_gt_box, use_gt_label=use_gt_label, **kwargs) -------------------------------------------------------------------------------- 
/modeling/transformer/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from scipy.optimize import linear_sum_assignment 3 | from torch import nn 4 | import numpy as np 5 | 6 | from .util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou 7 | 8 | from detectron2.utils.registry import Registry 9 | from torchvision.ops.boxes import box_area 10 | 11 | MATCHER_REGISTRY = Registry("MATCHER_REGISTRY") 12 | 13 | @MATCHER_REGISTRY.register() 14 | class HungarianMatcher(nn.Module): 15 | """This class computes an assignment between the targets and the predictions of the network 16 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 17 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 18 | while the others are un-matched (and thus treated as non-objects). 19 | """ 20 | 21 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 22 | """Creates the matcher 23 | Params: 24 | cost_class: This is the relative weight of the classification error in the matching cost 25 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 26 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 27 | """ 28 | super().__init__() 29 | self.cost_class = cost_class 30 | self.cost_bbox = cost_bbox 31 | self.cost_giou = cost_giou 32 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 33 | 34 | @torch.no_grad() 35 | def forward(self, outputs, targets, return_cost_matrix=False): 36 | """ Performs the matching 37 | Params: 38 | outputs: This is a dict that contains at least these entries: 39 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 40 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 41 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 42 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 43 | objects in the target) containing the class labels 44 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 45 | Returns: 46 | A list of size batch_size, containing tuples of (index_i, index_j) where: 47 | - index_i is the indices of the selected predictions (in order) 48 | - index_j is the indices of the corresponding selected targets (in order) 49 | For each batch element, it holds: 50 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 51 | """ 52 | bs, num_queries = outputs["pred_logits"].shape[:2] 53 | 54 | # We flatten to compute the cost matrices in a batch 55 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 56 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 57 | 58 | # Also concat the target labels and boxes 59 | tgt_ids = torch.cat([v["labels"] for v in targets]) 60 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 61 | 62 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 63 | # but approximate it in 1 - proba[target class]. 64 | # The 1 is a constant that doesn't change the matching, it can be ommitted. 
65 | cost_class = -out_prob[:, tgt_ids] 66 | 67 | # Compute the L1 cost between boxes 68 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 69 | 70 | # Compute the giou cost betwen boxes 71 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 72 | 73 | # Final cost matrix 74 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 75 | C = C.view(bs, num_queries, -1).cpu() 76 | 77 | sizes = [len(v["boxes"]) for v in targets] 78 | C_split = C.split(sizes, -1) 79 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C_split)] 80 | if not return_cost_matrix: 81 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 82 | else: 83 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices], C_split 84 | 85 | @MATCHER_REGISTRY.register() 86 | class IterativeHungarianMatcher(nn.Module): 87 | """This class computes an assignment between the targets and the predictions of the network 88 | For efficiency reasons, the targets don't include the no_object. Because of this, in general, 89 | there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, 90 | while the others are un-matched (and thus treated as non-objects). 91 | """ 92 | 93 | def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): 94 | """Creates the matcher 95 | Params: 96 | cost_class: This is the relative weight of the classification error in the matching cost 97 | cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost 98 | cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost 99 | """ 100 | super().__init__() 101 | self.cost_class = cost_class 102 | self.cost_bbox = cost_bbox 103 | self.cost_giou = cost_giou 104 | assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0" 105 | 106 | @torch.no_grad() 107 | def forward(self, outputs, targets, return_cost_matrix=False, mask=None): 108 | """ Performs the matching 109 | Params: 110 | outputs: This is a dict that contains at least these entries: 111 | "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits 112 | "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates 113 | targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing: 114 | "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth 115 | objects in the target) containing the class labels 116 | "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates 117 | Returns: 118 | A list of size batch_size, containing tuples of (index_i, index_j) where: 119 | - index_i is the indices of the selected predictions (in order) 120 | - index_j is the indices of the corresponding selected targets (in order) 121 | For each batch element, it holds: 122 | len(index_i) = len(index_j) = min(num_queries, num_target_boxes) 123 | """ 124 | bs, num_queries = outputs["pred_logits"].shape[:2] 125 | 126 | # We flatten to compute the cost matrices in a batch 127 | out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1) # [batch_size * num_queries, num_classes] 128 | out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] 129 | 130 | # Also concat the target labels and 
boxes 131 | tgt_ids = torch.cat([v["labels"] for v in targets]) 132 | tgt_bbox = torch.cat([v["boxes"] for v in targets]) 133 | 134 | # Compute the classification cost. Contrary to the loss, we don't use the NLL, 135 | # but approximate it by 1 - proba[target class]. 136 | # The 1 is a constant that doesn't change the matching, it can be omitted. 137 | cost_class = -out_prob[:, tgt_ids] 138 | 139 | # Compute the L1 cost between boxes 140 | cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) 141 | 142 | # Compute the giou cost between boxes 143 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) 144 | 145 | # Final cost matrix 146 | C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou 147 | C = C.view(bs, num_queries, -1).cpu() 148 | if mask is not None: 149 | C[:, ~mask] = float("inf") 150 | 151 | sizes = [len(v["boxes"]) for v in targets] 152 | C_split = C.split(sizes, -1) 153 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C_split)] 154 | 155 | if not return_cost_matrix: 156 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 157 | else: 158 | return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices], C_split 159 | 160 | @torch.no_grad() 161 | def forward_relation(self, outputs, targets, return_cost_matrix=False): 162 | bs, num_queries = outputs["relation_logits"].shape[:2] 163 | out_prob = outputs["relation_logits"].flatten(0, 1).softmax(-1) 164 | out_sub_prob = outputs["relation_subject_logits"].flatten(0, 1).softmax(-1) 165 | out_obj_prob = outputs["relation_object_logits"].flatten(0, 1).softmax(-1) 166 | out_bbox = outputs["relation_boxes"].flatten(0, 1) 167 | out_sub_bbox = outputs["relation_subject_boxes"].flatten(0, 1) 168 | out_obj_bbox = outputs["relation_object_boxes"].flatten(0, 1) 169 | 170 | aux_out_prob = [output['pred_logits'].flatten(0, 1).softmax(-1) for output in outputs['aux_outputs_r']] 171 | aux_out_sub_prob = [output['pred_logits'].flatten(0, 1).softmax(-1) for output in outputs['aux_outputs_r_sub']] 172 | aux_out_obj_prob = [output['pred_logits'].flatten(0, 1).softmax(-1) for output in outputs['aux_outputs_r_obj']] 173 | aux_out_bbox = [output['pred_boxes'].flatten(0, 1) for output in outputs['aux_outputs_r']] 174 | aux_out_sub_bbox = [output['pred_boxes'].flatten(0, 1) for output in outputs['aux_outputs_r_sub']] 175 | aux_out_obj_bbox = [output['pred_boxes'].flatten(0, 1) for output in outputs['aux_outputs_r_obj']] 176 | 177 | device = out_prob.device 178 | 179 | gt_labels = [v['combined_labels'] for v in targets] 180 | gt_boxes = [v['combined_boxes'] for v in targets] 181 | relations = [v["image_relations"] for v in targets] 182 | relation_boxes = [v['relation_boxes'] for v in targets] 183 | 184 | if len(relations) > 0: 185 | tgt_ids = torch.cat(relations)[:, 2] 186 | tgt_sub_labels = torch.cat([gt_label[relation[:, 0]] for gt_label, relation in zip(gt_labels, relations)]) 187 | tgt_obj_labels = torch.cat([gt_label[relation[:, 1]] for gt_label, relation in zip(gt_labels, relations)]) 188 | tgt_boxes = torch.cat(relation_boxes) 189 | tgt_sub_boxes = torch.cat([gt_box[relation[:, 0]] for gt_box, relation in zip(gt_boxes, relations)]) 190 | tgt_obj_boxes = torch.cat([gt_box[relation[:, 1]] for gt_box, relation in zip(gt_boxes, relations)]) 191 | else: 192 | tgt_ids = torch.tensor([]).long().to(device) 193 | tgt_sub_labels = torch.tensor([]).long().to(device) 194 | 
tgt_obj_labels = torch.tensor([]).long().to(device) 195 | tgt_boxes = torch.zeros((0,4)).to(device) 196 | tgt_sub_boxes = torch.zeros((0,4)).to(device) 197 | tgt_obj_boxes = torch.zeros((0,4)).to(device) 198 | 199 | cost_class = -out_prob[:, tgt_ids] 200 | cost_subject_class = -out_sub_prob[:, tgt_sub_labels] 201 | cost_object_class = -out_obj_prob[:, tgt_obj_labels] 202 | 203 | cost_bbox = torch.cdist(out_bbox, tgt_boxes, p=1) 204 | cost_subject_bbox = torch.cdist(out_sub_bbox, tgt_sub_boxes, p=1) 205 | cost_object_bbox = torch.cdist(out_obj_bbox, tgt_obj_boxes, p=1) 206 | 207 | cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_boxes)) 208 | cost_subject_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_sub_bbox), box_cxcywh_to_xyxy(tgt_sub_boxes)) 209 | cost_object_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_obj_bbox), box_cxcywh_to_xyxy(tgt_obj_boxes)) 210 | 211 | C = self.cost_bbox * (cost_bbox + cost_subject_bbox + cost_object_bbox) + self.cost_class * (cost_class + cost_subject_class + cost_object_class) + self.cost_giou * (cost_giou + cost_subject_giou + cost_object_giou) 212 | 213 | # Add aux loss cost 214 | for aux_idx in range(len(aux_out_prob)): 215 | aux_cost_class = -aux_out_prob[aux_idx][:, tgt_ids] 216 | aux_cost_subject_class = -aux_out_sub_prob[aux_idx][:, tgt_sub_labels] 217 | aux_cost_object_class = -aux_out_obj_prob[aux_idx][:, tgt_obj_labels] 218 | 219 | aux_cost_bbox = torch.cdist(aux_out_bbox[aux_idx], tgt_boxes, p=1) 220 | aux_cost_subject_bbox = torch.cdist(aux_out_sub_bbox[aux_idx], tgt_sub_boxes, p=1) 221 | aux_cost_object_bbox = torch.cdist(aux_out_obj_bbox[aux_idx], tgt_obj_boxes, p=1) 222 | 223 | aux_cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(aux_out_bbox[aux_idx]), box_cxcywh_to_xyxy(tgt_boxes)) 224 | aux_cost_subject_giou = -generalized_box_iou(box_cxcywh_to_xyxy(aux_out_sub_bbox[aux_idx]), box_cxcywh_to_xyxy(tgt_sub_boxes)) 225 | aux_cost_object_giou = -generalized_box_iou(box_cxcywh_to_xyxy(aux_out_obj_bbox[aux_idx]), box_cxcywh_to_xyxy(tgt_obj_boxes)) 226 | aux_C = self.cost_bbox * (aux_cost_bbox + aux_cost_subject_bbox + aux_cost_object_bbox) + self.cost_class * (aux_cost_class + aux_cost_subject_class + aux_cost_object_class) + self.cost_giou * (aux_cost_giou + aux_cost_subject_giou + aux_cost_object_giou) 227 | 228 | C = C + aux_C 229 | 230 | C = C.view(bs, num_queries, -1).cpu() 231 | 232 | sizes = [len(v["image_relations"]) for v in targets] 233 | C_split = C.split(sizes, -1) 234 | indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C_split)] 235 | indices = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices] 236 | 237 | # Remaining GT objects matching 238 | pred_masks = {'subject': [], 'object': []} 239 | target_masks = {'subject' :[], 'object': []} 240 | combined_indices = {'subject' :[], 'object': [], 'relation': []} 241 | for image_idx, target in enumerate(targets): 242 | all_objects = torch.arange(len(gt_labels[image_idx])).to(device) 243 | relation = target['image_relations'] 244 | curr_relation_idx = indices[image_idx] 245 | curr_pred_mask = torch.ones(num_queries, device=device) 246 | curr_pred_mask[curr_relation_idx[0]] = 0 247 | curr_pred_mask = (curr_pred_mask == 1) 248 | 249 | combined_indices['relation'].append((curr_relation_idx[0], curr_relation_idx[1])) 250 | for branch_idx, branch_type in enumerate(['subject', 'object']): 251 | combined_indices[branch_type].append((curr_relation_idx[0], relation[:, 
branch_idx][curr_relation_idx[1]].cpu())) 252 | return combined_indices 253 | 254 | 255 | 256 | def build_matcher(name, cost_class, cost_bbox, cost_giou, topk=1): 257 | if topk == 1: 258 | return MATCHER_REGISTRY.get(name)(cost_class=cost_class, cost_bbox=cost_bbox, cost_giou=cost_giou) 259 | else: 260 | return MATCHER_REGISTRY.get(name)(cost_class=cost_class, cost_bbox=cost_bbox, cost_giou=cost_giou, topk=topk) -------------------------------------------------------------------------------- /modeling/transformer/positional_encoding.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various positional encodings for the transformer. 3 | """ 4 | import math 5 | import torch 6 | from torch import nn,Tensor 7 | 8 | from .util.misc import NestedTensor 9 | from detectron2.utils.registry import Registry 10 | 11 | POSITION_ENCODING_REGISTRY = Registry("POSITION_ENCODING_REGISTRY") 12 | 13 | @POSITION_ENCODING_REGISTRY.register() 14 | class PositionalEncoding(nn.Module): 15 | 16 | def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): 17 | super().__init__() 18 | self.dropout = nn.Dropout(p=dropout) 19 | 20 | position = torch.arange(max_len).unsqueeze(1) 21 | div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) 22 | pe = torch.zeros(max_len, 1, d_model) 23 | pe[:, 0, 0::2] = torch.sin(position * div_term) 24 | pe[:, 0, 1::2] = torch.cos(position * div_term) 25 | self.register_buffer('pe', pe) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | """ 29 | Args: 30 | x: Tensor, shape [seq_len, batch_size, embedding_dim] 31 | """ 32 | x = x + self.pe[:x.size(0)] 33 | return self.dropout(x) 34 | 35 | @POSITION_ENCODING_REGISTRY.register() 36 | class PositionEmbeddingSine(nn.Module): 37 | """ 38 | This is a more standard version of the position embedding, very similar to the one 39 | used by the Attention is all you need paper, generalized to work on images. 
40 | """ 41 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): 42 | super().__init__() 43 | self.num_pos_feats = num_pos_feats 44 | self.temperature = temperature 45 | self.normalize = normalize 46 | if scale is not None and normalize is False: 47 | raise ValueError("normalize should be True if scale is passed") 48 | if scale is None: 49 | scale = 2 * math.pi 50 | self.scale = scale 51 | 52 | def forward(self, tensor_list: NestedTensor): 53 | x = tensor_list.tensors 54 | mask = tensor_list.mask 55 | assert mask is not None 56 | not_mask = ~mask 57 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 58 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 59 | if self.normalize: 60 | eps = 1e-6 61 | y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale 62 | x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale 63 | 64 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 65 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 66 | 67 | pos_x = x_embed[:, :, :, None] / dim_t 68 | pos_y = y_embed[:, :, :, None] / dim_t 69 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 70 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 71 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 72 | return pos 73 | 74 | @POSITION_ENCODING_REGISTRY.register() 75 | class DeformableDETRPositionEmbeddingSine(nn.Module): 76 | """ 77 | This is a more standard version of the position embedding, very similar to the one 78 | used by the Attention is all you need paper, generalized to work on images. 79 | """ 80 | def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): 81 | super().__init__() 82 | self.num_pos_feats = num_pos_feats 83 | self.temperature = temperature 84 | self.normalize = normalize 85 | if scale is not None and normalize is False: 86 | raise ValueError("normalize should be True if scale is passed") 87 | if scale is None: 88 | scale = 2 * math.pi 89 | self.scale = scale 90 | 91 | def forward(self, tensor_list: NestedTensor): 92 | x = tensor_list.tensors 93 | mask = tensor_list.mask 94 | assert mask is not None 95 | not_mask = ~mask 96 | y_embed = not_mask.cumsum(1, dtype=torch.float32) 97 | x_embed = not_mask.cumsum(2, dtype=torch.float32) 98 | if self.normalize: 99 | eps = 1e-6 100 | y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale 101 | x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale 102 | 103 | dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) 104 | dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) 105 | 106 | pos_x = x_embed[:, :, :, None] / dim_t 107 | pos_y = y_embed[:, :, :, None] / dim_t 108 | pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) 109 | pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) 110 | pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) 111 | return pos 112 | 113 | @POSITION_ENCODING_REGISTRY.register() 114 | class PositionEmbeddingLearned(nn.Module): 115 | """ 116 | Absolute pos embedding, learned. 
117 | """ 118 | def __init__(self, num_pos_feats=256): 119 | super().__init__() 120 | self.row_embed = nn.Embedding(50, num_pos_feats) 121 | self.col_embed = nn.Embedding(50, num_pos_feats) 122 | self.reset_parameters() 123 | 124 | def reset_parameters(self): 125 | nn.init.uniform_(self.row_embed.weight) 126 | nn.init.uniform_(self.col_embed.weight) 127 | 128 | def forward(self, tensor_list: NestedTensor): 129 | x = tensor_list.tensors 130 | h, w = x.shape[-2:] 131 | i = torch.arange(w, device=x.device) 132 | j = torch.arange(h, device=x.device) 133 | x_emb = self.col_embed(i) 134 | y_emb = self.row_embed(j) 135 | pos = torch.cat([ 136 | x_emb.unsqueeze(0).repeat(h, 1, 1), 137 | y_emb.unsqueeze(1).repeat(1, w, 1), 138 | ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) 139 | return pos 140 | 141 | 142 | def build_position_encoding(name, hidden_dim): 143 | N_steps = hidden_dim // 2 144 | return POSITION_ENCODING_REGISTRY.get(name)(N_steps) -------------------------------------------------------------------------------- /modeling/transformer/segmentation.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file provides the definition of the convolutional heads used to predict masks, as well as the losses 3 | """ 4 | import io 5 | from collections import defaultdict 6 | from typing import List, Optional 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | from torch import Tensor 12 | from PIL import Image 13 | 14 | from .util import box_ops as box_ops 15 | from .util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list 16 | 17 | try: 18 | from panopticapi.utils import id2rgb, rgb2id 19 | except ImportError: 20 | pass 21 | 22 | 23 | class DETRsegm(nn.Module): 24 | def __init__(self, detr, freeze_detr=False): 25 | super().__init__() 26 | self.detr = detr 27 | 28 | if freeze_detr: 29 | for p in self.parameters(): 30 | p.requires_grad_(False) 31 | 32 | hidden_dim, nheads = detr.transformer.d_model, detr.transformer.nhead 33 | self.bbox_attention = MHAttentionMap(hidden_dim, hidden_dim, nheads, dropout=0.0) 34 | self.mask_head = MaskHeadSmallConv(hidden_dim + nheads, [1024, 512, 256], hidden_dim) 35 | 36 | def forward(self, samples: NestedTensor): 37 | if isinstance(samples, (list, torch.Tensor)): 38 | samples = nested_tensor_from_tensor_list(samples) 39 | features, pos = self.detr.backbone(samples) 40 | 41 | bs = features[-1].tensors.shape[0] 42 | 43 | src, mask = features[-1].decompose() 44 | assert mask is not None 45 | src_proj = self.detr.input_proj(src) 46 | hs, memory = self.detr.transformer(src_proj, mask, self.detr.query_embed.weight, pos[-1]) 47 | 48 | outputs_class = self.detr.class_embed(hs) 49 | outputs_coord = self.detr.bbox_embed(hs).sigmoid() 50 | out = {"pred_logits": outputs_class[-1], "pred_boxes": outputs_coord[-1]} 51 | if self.detr.aux_loss: 52 | out['aux_outputs'] = self.detr._set_aux_loss(outputs_class, outputs_coord) 53 | 54 | # FIXME h_boxes takes the last one computed, keep this in mind 55 | bbox_mask = self.bbox_attention(hs[-1], memory, mask=mask) 56 | 57 | seg_masks = self.mask_head(src_proj, bbox_mask, [features[2].tensors, features[1].tensors, features[0].tensors]) 58 | outputs_seg_masks = seg_masks.view(bs, self.detr.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) 59 | 60 | out["pred_masks"] = outputs_seg_masks 61 | return out 62 | 63 | 64 | def _expand(tensor, length: int): 65 | return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 
1).flatten(0, 1) 66 | 67 | 68 | class MaskHeadSmallConv(nn.Module): 69 | """ 70 | Simple convolutional head, using group norm. 71 | Upsampling is done using a FPN approach 72 | """ 73 | 74 | def __init__(self, dim, fpn_dims, context_dim): 75 | super().__init__() 76 | 77 | inter_dims = [dim, context_dim // 2, context_dim // 4, context_dim // 8, context_dim // 16, context_dim // 64] 78 | self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1) 79 | self.gn1 = torch.nn.GroupNorm(8, dim) 80 | self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1) 81 | self.gn2 = torch.nn.GroupNorm(8, inter_dims[1]) 82 | self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1) 83 | self.gn3 = torch.nn.GroupNorm(8, inter_dims[2]) 84 | self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1) 85 | self.gn4 = torch.nn.GroupNorm(8, inter_dims[3]) 86 | self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1) 87 | self.gn5 = torch.nn.GroupNorm(8, inter_dims[4]) 88 | self.out_lay = torch.nn.Conv2d(inter_dims[4], 1, 3, padding=1) 89 | 90 | self.dim = dim 91 | 92 | self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1) 93 | self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1) 94 | self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1) 95 | 96 | for m in self.modules(): 97 | if isinstance(m, nn.Conv2d): 98 | nn.init.kaiming_uniform_(m.weight, a=1) 99 | nn.init.constant_(m.bias, 0) 100 | 101 | def forward(self, x: Tensor, bbox_mask: Tensor, fpns: List[Tensor]): 102 | x = torch.cat([_expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) 103 | 104 | x = self.lay1(x) 105 | x = self.gn1(x) 106 | x = F.relu(x) 107 | x = self.lay2(x) 108 | x = self.gn2(x) 109 | x = F.relu(x) 110 | 111 | cur_fpn = self.adapter1(fpns[0]) 112 | if cur_fpn.size(0) != x.size(0): 113 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 114 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 115 | x = self.lay3(x) 116 | x = self.gn3(x) 117 | x = F.relu(x) 118 | 119 | cur_fpn = self.adapter2(fpns[1]) 120 | if cur_fpn.size(0) != x.size(0): 121 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 122 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 123 | x = self.lay4(x) 124 | x = self.gn4(x) 125 | x = F.relu(x) 126 | 127 | cur_fpn = self.adapter3(fpns[2]) 128 | if cur_fpn.size(0) != x.size(0): 129 | cur_fpn = _expand(cur_fpn, x.size(0) // cur_fpn.size(0)) 130 | x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest") 131 | x = self.lay5(x) 132 | x = self.gn5(x) 133 | x = F.relu(x) 134 | 135 | x = self.out_lay(x) 136 | return x 137 | 138 | 139 | class MHAttentionMap(nn.Module): 140 | """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" 141 | 142 | def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True): 143 | super().__init__() 144 | self.num_heads = num_heads 145 | self.hidden_dim = hidden_dim 146 | self.dropout = nn.Dropout(dropout) 147 | 148 | self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 149 | self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) 150 | 151 | nn.init.zeros_(self.k_linear.bias) 152 | nn.init.zeros_(self.q_linear.bias) 153 | nn.init.xavier_uniform_(self.k_linear.weight) 154 | nn.init.xavier_uniform_(self.q_linear.weight) 155 | self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 156 | 157 | def forward(self, q, k, mask: Optional[Tensor] = None): 158 | q = self.q_linear(q) 159 | k = 
F.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) 160 | qh = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) 161 | kh = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) 162 | weights = torch.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh) 163 | 164 | if mask is not None: 165 | weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf")) 166 | weights = F.softmax(weights.flatten(2), dim=-1).view(weights.size()) 167 | weights = self.dropout(weights) 168 | return weights 169 | 170 | 171 | def dice_loss(inputs, targets, num_boxes): 172 | """ 173 | Compute the DICE loss, similar to generalized IOU for masks 174 | Args: 175 | inputs: A float tensor of arbitrary shape. 176 | The predictions for each example. 177 | targets: A float tensor with the same shape as inputs. Stores the binary 178 | classification label for each element in inputs 179 | (0 for the negative class and 1 for the positive class). 180 | """ 181 | inputs = inputs.sigmoid() 182 | inputs = inputs.flatten(1) 183 | numerator = 2 * (inputs * targets).sum(1) 184 | denominator = inputs.sum(-1) + targets.sum(-1) 185 | loss = 1 - (numerator + 1) / (denominator + 1) 186 | return loss.sum() / num_boxes 187 | 188 | 189 | def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): 190 | """ 191 | Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. 192 | Args: 193 | inputs: A float tensor of arbitrary shape. 194 | The predictions for each example. 195 | targets: A float tensor with the same shape as inputs. Stores the binary 196 | classification label for each element in inputs 197 | (0 for the negative class and 1 for the positive class). 198 | alpha: (optional) Weighting factor in range (0,1) to balance 199 | positive vs negative examples. Default = 0.25; a negative value disables the weighting. 200 | gamma: Exponent of the modulating factor (1 - p_t) to 201 | balance easy vs hard examples. 
202 | Returns: 203 | Loss tensor 204 | """ 205 | prob = inputs.sigmoid() 206 | ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none") 207 | p_t = prob * targets + (1 - prob) * (1 - targets) 208 | loss = ce_loss * ((1 - p_t) ** gamma) 209 | 210 | if alpha >= 0: 211 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 212 | loss = alpha_t * loss 213 | 214 | return loss.mean(1).sum() / num_boxes 215 | 216 | 217 | class PostProcessSegm(nn.Module): 218 | def __init__(self, threshold=0.5): 219 | super().__init__() 220 | self.threshold = threshold 221 | 222 | @torch.no_grad() 223 | def forward(self, results, outputs, orig_target_sizes, max_target_sizes): 224 | assert len(orig_target_sizes) == len(max_target_sizes) 225 | max_h, max_w = max_target_sizes.max(0)[0].tolist() 226 | outputs_masks = outputs["pred_masks"].squeeze(2) 227 | outputs_masks = F.interpolate(outputs_masks, size=(max_h, max_w), mode="bilinear", align_corners=False) 228 | outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() 229 | 230 | for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)): 231 | img_h, img_w = t[0], t[1] 232 | results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) 233 | results[i]["masks"] = F.interpolate( 234 | results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest" 235 | ).byte() 236 | 237 | return results 238 | 239 | 240 | class PostProcessPanoptic(nn.Module): 241 | """This class converts the output of the model to the final panoptic result, in the format expected by the 242 | coco panoptic API """ 243 | 244 | def __init__(self, is_thing_map, threshold=0.85): 245 | """ 246 | Parameters: 247 | is_thing_map: This is a dict whose keys are the class ids, and whose values are booleans indicating whether 248 | the class is a thing (True) or a stuff (False) class 249 | threshold: confidence threshold; segments with confidence lower than this will be deleted 250 | """ 251 | super().__init__() 252 | self.threshold = threshold 253 | self.is_thing_map = is_thing_map 254 | 255 | def forward(self, outputs, processed_sizes, target_sizes=None): 256 | """ This function computes the panoptic prediction from the model's predictions. 257 | Parameters: 258 | outputs: This is a dict coming directly from the model. See the model doc for the content. 259 | processed_sizes: This is a list of tuples (or torch tensors) of sizes of the images that were passed to the 260 | model, i.e. the size after data augmentation but before batching. 261 | target_sizes: This is a list of tuples (or torch tensors) corresponding to the requested final size 262 | of each prediction. 
If left to None, it will default to the processed_sizes 263 | """ 264 | if target_sizes is None: 265 | target_sizes = processed_sizes 266 | assert len(processed_sizes) == len(target_sizes) 267 | out_logits, raw_masks, raw_boxes = outputs["pred_logits"], outputs["pred_masks"], outputs["pred_boxes"] 268 | assert len(out_logits) == len(raw_masks) == len(target_sizes) 269 | preds = [] 270 | 271 | def to_tuple(tup): 272 | if isinstance(tup, tuple): 273 | return tup 274 | return tuple(tup.cpu().tolist()) 275 | 276 | for cur_logits, cur_masks, cur_boxes, size, target_size in zip( 277 | out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes 278 | ): 279 | # we filter empty queries and detection below threshold 280 | scores, labels = cur_logits.softmax(-1).max(-1) 281 | keep = labels.ne(outputs["pred_logits"].shape[-1] - 1) & (scores > self.threshold) 282 | cur_scores, cur_classes = cur_logits.softmax(-1).max(-1) 283 | cur_scores = cur_scores[keep] 284 | cur_classes = cur_classes[keep] 285 | cur_masks = cur_masks[keep] 286 | cur_masks = interpolate(cur_masks[:, None], to_tuple(size), mode="bilinear").squeeze(1) 287 | cur_boxes = box_ops.box_cxcywh_to_xyxy(cur_boxes[keep]) 288 | 289 | h, w = cur_masks.shape[-2:] 290 | assert len(cur_boxes) == len(cur_classes) 291 | 292 | # It may be that we have several predicted masks for the same stuff class. 293 | # In the following, we track the list of masks ids for each stuff class (they are merged later on) 294 | cur_masks = cur_masks.flatten(1) 295 | stuff_equiv_classes = defaultdict(lambda: []) 296 | for k, label in enumerate(cur_classes): 297 | if not self.is_thing_map[label.item()]: 298 | stuff_equiv_classes[label.item()].append(k) 299 | 300 | def get_ids_area(masks, scores, dedup=False): 301 | # This helper function creates the final panoptic segmentation image 302 | # It also returns the area of the masks that appears on the image 303 | 304 | m_id = masks.transpose(0, 1).softmax(-1) 305 | 306 | if m_id.shape[-1] == 0: 307 | # We didn't detect any mask :( 308 | m_id = torch.zeros((h, w), dtype=torch.long, device=m_id.device) 309 | else: 310 | m_id = m_id.argmax(-1).view(h, w) 311 | 312 | if dedup: 313 | # Merge the masks corresponding to the same stuff class 314 | for equiv in stuff_equiv_classes.values(): 315 | if len(equiv) > 1: 316 | for eq_id in equiv: 317 | m_id.masked_fill_(m_id.eq(eq_id), equiv[0]) 318 | 319 | final_h, final_w = to_tuple(target_size) 320 | 321 | seg_img = Image.fromarray(id2rgb(m_id.view(h, w).cpu().numpy())) 322 | seg_img = seg_img.resize(size=(final_w, final_h), resample=Image.NEAREST) 323 | 324 | np_seg_img = ( 325 | torch.ByteTensor(torch.ByteStorage.from_buffer(seg_img.tobytes())).view(final_h, final_w, 3).numpy() 326 | ) 327 | m_id = torch.from_numpy(rgb2id(np_seg_img)) 328 | 329 | area = [] 330 | for i in range(len(scores)): 331 | area.append(m_id.eq(i).sum().item()) 332 | return area, seg_img 333 | 334 | area, seg_img = get_ids_area(cur_masks, cur_scores, dedup=True) 335 | if cur_classes.numel() > 0: 336 | # We know filter empty masks as long as we find some 337 | while True: 338 | filtered_small = torch.as_tensor( 339 | [area[i] <= 4 for i, c in enumerate(cur_classes)], dtype=torch.bool, device=keep.device 340 | ) 341 | if filtered_small.any().item(): 342 | cur_scores = cur_scores[~filtered_small] 343 | cur_classes = cur_classes[~filtered_small] 344 | cur_masks = cur_masks[~filtered_small] 345 | area, seg_img = get_ids_area(cur_masks, cur_scores) 346 | else: 347 | break 348 | 349 | else: 350 | cur_classes = 
torch.ones(1, dtype=torch.long, device=cur_classes.device) 351 | 352 | segments_info = [] 353 | for i, a in enumerate(area): 354 | cat = cur_classes[i].item() 355 | segments_info.append({"id": i, "isthing": self.is_thing_map[cat], "category_id": cat, "area": a}) 356 | del cur_classes 357 | 358 | with io.BytesIO() as out: 359 | seg_img.save(out, format="PNG") 360 | predictions = {"png_string": out.getvalue(), "segments_info": segments_info} 361 | preds.append(predictions) 362 | return preds 363 | -------------------------------------------------------------------------------- /modeling/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Optional, List 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | from torch import nn, Tensor 7 | from detectron2.utils.registry import Registry 8 | from .detr import MLP 9 | from .detr import gen_sineembed_for_position 10 | from .util.misc import inverse_sigmoid 11 | import math 12 | from torch.nn.init import xavier_uniform_, constant_, uniform_, normal_ 13 | 14 | TRANSFORMER_REGISTRY = Registry("TRANSFORMER_REGISTRY") 15 | 16 | @TRANSFORMER_REGISTRY.register() 17 | class Transformer(nn.Module): 18 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 19 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 20 | activation="relu", normalize_before=False, 21 | return_intermediate_dec=False, **kwargs): 22 | super().__init__() 23 | 24 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 25 | dropout, activation, normalize_before) 26 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 27 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm) 28 | 29 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 30 | dropout, activation, normalize_before) 31 | decoder_norm = nn.LayerNorm(d_model) 32 | self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm, 33 | return_intermediate=return_intermediate_dec) 34 | 35 | self._reset_parameters() 36 | 37 | self.d_model = d_model 38 | self.nhead = nhead 39 | 40 | def _reset_parameters(self): 41 | for p in self.parameters(): 42 | if p.dim() > 1: 43 | nn.init.xavier_uniform_(p) 44 | 45 | def forward(self, src, mask, query_embed, pos_embed): 46 | # flatten NxCxHxW to HWxNxC 47 | bs, c, h, w = src.shape 48 | src = src.flatten(2).permute(2, 0, 1) 49 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 50 | query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1) 51 | mask = mask.flatten(1) 52 | 53 | tgt = torch.zeros_like(query_embed) 54 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 55 | hs = self.decoder(tgt, memory, memory_key_padding_mask=mask, 56 | pos=pos_embed, query_pos=query_embed) 57 | return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w) 58 | 59 | @TRANSFORMER_REGISTRY.register() 60 | class IterativeRelationTransformer(nn.Module): 61 | def __init__(self, d_model=512, nhead=8, num_encoder_layers=6, 62 | num_decoder_layers=6, dim_feedforward=2048, dropout=0.1, 63 | activation="relu", normalize_before=False, 64 | return_intermediate_dec=False, **kwargs): 65 | super().__init__() 66 | 67 | encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward, 68 | dropout, activation, normalize_before) 69 | encoder_norm = nn.LayerNorm(d_model) if normalize_before else None 70 | self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, 
encoder_norm) 71 | 72 | decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward, 73 | dropout, activation, normalize_before) 74 | layer_norm = nn.LayerNorm(d_model) 75 | relation_layer_norm = nn.LayerNorm(d_model) 76 | self.decoder = IterativeRelationDecoder(decoder_layer, num_decoder_layers, layer_norm, relation_layer_norm, 77 | return_intermediate=return_intermediate_dec, d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout) 78 | 79 | self.d_model = d_model 80 | self.nhead = nhead 81 | self.object_embed = nn.Linear(d_model, kwargs['num_classes'] + 1) 82 | self.object_bbox_coords = MLP(d_model, d_model, 4, 3) 83 | self.relation_embed = nn.Linear(d_model, kwargs['num_relation_classes'] + 1) 84 | 85 | self.num_relation_classes = kwargs['num_relation_classes'] 86 | self.num_object_classes = kwargs['num_classes'] 87 | self._reset_parameters() 88 | 89 | for layer in range(self.decoder.num_layers - 1): 90 | nn.init.constant_(self.decoder.subject_graph_query_residual[layer].weight, 0) 91 | nn.init.constant_(self.decoder.subject_graph_query_residual[layer].bias, 0) 92 | 93 | nn.init.constant_(self.decoder.object_graph_query_residual[layer].weight, 0) 94 | nn.init.constant_(self.decoder.object_graph_query_residual[layer].bias, 0) 95 | 96 | nn.init.constant_(self.decoder.relation_graph_query_residual[layer].weight, 0) 97 | nn.init.constant_(self.decoder.relation_graph_query_residual[layer].bias, 0) 98 | 99 | for layer in range(self.decoder.num_layers): 100 | nn.init.constant_(self.decoder.object_pos_linear[layer].weight, 0) 101 | nn.init.constant_(self.decoder.object_pos_linear[layer].bias, 0) 102 | nn.init.constant_(self.decoder.relation_pos_linear[layer].weight, 0) 103 | nn.init.constant_(self.decoder.relation_pos_linear[layer].bias, 0) 104 | 105 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 106 | return tensor if pos is None else tensor + pos 107 | 108 | def _reset_parameters(self): 109 | for p in self.parameters(): 110 | if p.dim() > 1: 111 | nn.init.xavier_uniform_(p) 112 | 113 | def forward(self, src, mask, subject_embed, object_embed, relation_embed, pos_embed): 114 | # flatten NxCxHxW to HWxNxC 115 | bs, c, h, w = src.shape 116 | src = src.flatten(2).permute(2, 0, 1) 117 | pos_embed = pos_embed.flatten(2).permute(2, 0, 1) 118 | mask = mask.flatten(1) 119 | memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed) 120 | 121 | subject_query_embed = subject_embed.unsqueeze(1).repeat(1, bs, 1) 122 | object_query_embed = object_embed.unsqueeze(1).repeat(1, bs, 1) 123 | relation_query_embed = relation_embed.unsqueeze(1).repeat(1, bs, 1) 124 | 125 | # Condition on subject 126 | tgt_sub = torch.zeros_like(subject_query_embed) 127 | tgt_obj = torch.zeros_like(object_query_embed) 128 | tgt_rel = torch.zeros_like(relation_query_embed) 129 | hs_subject, hs_object, hs_relation = self.decoder(tgt_sub, tgt_obj, tgt_rel, memory, memory_key_padding_mask=mask, 130 | pos=pos_embed, subject_pos=subject_query_embed, object_pos=object_query_embed, relation_pos=relation_query_embed) 131 | relation_subject_class = self.object_embed(hs_subject) 132 | relation_subject_coords = self.object_bbox_coords(hs_subject).sigmoid() 133 | relation_object_class = self.object_embed(hs_object) 134 | relation_object_coords = self.object_bbox_coords(hs_object).sigmoid() 135 | relation_class = self.relation_embed(hs_relation) 136 | relation_coords = self.object_bbox_coords(hs_relation).sigmoid() 137 | output = { 138 | 'relation_coords': relation_coords.transpose(1, 
2), 139 | 'relation_logits': relation_class.transpose(1, 2), 140 | 'relation_subject_logits': relation_subject_class.transpose(1, 2), 141 | 'relation_object_logits': relation_object_class.transpose(1, 2), 142 | 'relation_subject_coords': relation_subject_coords.transpose(1, 2), 143 | 'relation_object_coords': relation_object_coords.transpose(1, 2) 144 | } 145 | 146 | return output 147 | 148 | 149 | class TransformerEncoder(nn.Module): 150 | 151 | def __init__(self, encoder_layer, num_layers, norm=None): 152 | super().__init__() 153 | self.layers = _get_clones(encoder_layer, num_layers) 154 | self.num_layers = num_layers 155 | self.norm = norm 156 | 157 | def forward(self, src, 158 | mask: Optional[Tensor] = None, 159 | src_key_padding_mask: Optional[Tensor] = None, 160 | pos: Optional[Tensor] = None): 161 | output = src 162 | 163 | for layer in self.layers: 164 | output = layer(output, src_mask=mask, 165 | src_key_padding_mask=src_key_padding_mask, pos=pos) 166 | 167 | if self.norm is not None: 168 | output = self.norm(output) 169 | 170 | return output 171 | 172 | class TransformerDecoder(nn.Module): 173 | 174 | def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False): 175 | super().__init__() 176 | self.layers = _get_clones(decoder_layer, num_layers) 177 | self.num_layers = num_layers 178 | self.norm = norm 179 | self.return_intermediate = return_intermediate 180 | 181 | def forward(self, tgt, memory, 182 | tgt_mask: Optional[Tensor] = None, 183 | memory_mask: Optional[Tensor] = None, 184 | tgt_key_padding_mask: Optional[Tensor] = None, 185 | memory_key_padding_mask: Optional[Tensor] = None, 186 | pos: Optional[Tensor] = None, 187 | query_pos: Optional[Tensor] = None): 188 | output = tgt 189 | 190 | intermediate = [] 191 | 192 | for layer in self.layers: 193 | output = layer(output, memory, tgt_mask=tgt_mask, 194 | memory_mask=memory_mask, 195 | tgt_key_padding_mask=tgt_key_padding_mask, 196 | memory_key_padding_mask=memory_key_padding_mask, 197 | pos=pos, query_pos=query_pos) 198 | if self.return_intermediate: 199 | intermediate.append(self.norm(output)) 200 | 201 | if self.norm is not None: 202 | output = self.norm(output) 203 | if self.return_intermediate: 204 | intermediate.pop() 205 | intermediate.append(output) 206 | 207 | if self.return_intermediate: 208 | return torch.stack(intermediate) 209 | 210 | return output.unsqueeze(0) 211 | 212 | class TransformerEncoderLayer(nn.Module): 213 | 214 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 215 | activation="relu", normalize_before=False): 216 | super().__init__() 217 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 218 | # Implementation of Feedforward model 219 | self.linear1 = nn.Linear(d_model, dim_feedforward) 220 | self.dropout = nn.Dropout(dropout) 221 | self.linear2 = nn.Linear(dim_feedforward, d_model) 222 | 223 | self.norm1 = nn.LayerNorm(d_model) 224 | self.norm2 = nn.LayerNorm(d_model) 225 | self.dropout1 = nn.Dropout(dropout) 226 | self.dropout2 = nn.Dropout(dropout) 227 | 228 | self.activation = _get_activation_fn(activation) 229 | self.normalize_before = normalize_before 230 | 231 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 232 | return tensor if pos is None else tensor + pos 233 | 234 | def forward_post(self, 235 | src, 236 | src_mask: Optional[Tensor] = None, 237 | src_key_padding_mask: Optional[Tensor] = None, 238 | pos: Optional[Tensor] = None): 239 | q = k = self.with_pos_embed(src, pos) 240 | src2 = self.self_attn(q, k, 
value=src, attn_mask=src_mask, 241 | key_padding_mask=src_key_padding_mask)[0] 242 | src = src + self.dropout1(src2) 243 | src = self.norm1(src) 244 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 245 | src = src + self.dropout2(src2) 246 | src = self.norm2(src) 247 | return src 248 | 249 | def forward_pre(self, src, 250 | src_mask: Optional[Tensor] = None, 251 | src_key_padding_mask: Optional[Tensor] = None, 252 | pos: Optional[Tensor] = None): 253 | src2 = self.norm1(src) 254 | q = k = self.with_pos_embed(src2, pos) 255 | src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask, 256 | key_padding_mask=src_key_padding_mask)[0] 257 | src = src + self.dropout1(src2) 258 | src2 = self.norm2(src) 259 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 260 | src = src + self.dropout2(src2) 261 | return src 262 | 263 | def forward(self, src, 264 | src_mask: Optional[Tensor] = None, 265 | src_key_padding_mask: Optional[Tensor] = None, 266 | pos: Optional[Tensor] = None): 267 | if self.normalize_before: 268 | return self.forward_pre(src, src_mask, src_key_padding_mask, pos) 269 | return self.forward_post(src, src_mask, src_key_padding_mask, pos) 270 | 271 | class TransformerDecoderLayer(nn.Module): 272 | 273 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, 274 | activation="relu", normalize_before=False): 275 | super().__init__() 276 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 277 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout) 278 | # Implementation of Feedforward model 279 | self.linear1 = nn.Linear(d_model, dim_feedforward) 280 | self.dropout = nn.Dropout(dropout) 281 | self.linear2 = nn.Linear(dim_feedforward, d_model) 282 | 283 | self.norm1 = nn.LayerNorm(d_model) 284 | self.norm2 = nn.LayerNorm(d_model) 285 | self.norm3 = nn.LayerNorm(d_model) 286 | self.dropout1 = nn.Dropout(dropout) 287 | self.dropout2 = nn.Dropout(dropout) 288 | self.dropout3 = nn.Dropout(dropout) 289 | 290 | self.activation = _get_activation_fn(activation) 291 | self.normalize_before = normalize_before 292 | 293 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 294 | return tensor if pos is None else tensor + pos 295 | 296 | def forward_post(self, tgt, memory, 297 | tgt_mask: Optional[Tensor] = None, 298 | memory_mask: Optional[Tensor] = None, 299 | tgt_key_padding_mask: Optional[Tensor] = None, 300 | memory_key_padding_mask: Optional[Tensor] = None, 301 | pos: Optional[Tensor] = None, 302 | query_pos: Optional[Tensor] = None): 303 | q = k = self.with_pos_embed(tgt, query_pos) 304 | tgt2 = self.self_attn(q, k, value=tgt, attn_mask=tgt_mask, 305 | key_padding_mask=tgt_key_padding_mask)[0] 306 | tgt = tgt + self.dropout1(tgt2) 307 | tgt = self.norm1(tgt) 308 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos), 309 | key=self.with_pos_embed(memory, pos), 310 | value=memory, attn_mask=memory_mask, 311 | key_padding_mask=memory_key_padding_mask)[0] 312 | tgt = tgt + self.dropout2(tgt2) 313 | tgt = self.norm2(tgt) 314 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) 315 | tgt = tgt + self.dropout3(tgt2) 316 | tgt = self.norm3(tgt) 317 | return tgt 318 | 319 | def forward_pre(self, tgt, memory, 320 | tgt_mask: Optional[Tensor] = None, 321 | memory_mask: Optional[Tensor] = None, 322 | tgt_key_padding_mask: Optional[Tensor] = None, 323 | memory_key_padding_mask: Optional[Tensor] = None, 324 | pos: Optional[Tensor] = None, 325 | query_pos: 
Optional[Tensor] = None): 326 | tgt2 = self.norm1(tgt) 327 | q = k = self.with_pos_embed(tgt2, query_pos) 328 | tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, 329 | key_padding_mask=tgt_key_padding_mask)[0] 330 | tgt = tgt + self.dropout1(tgt2) 331 | tgt2 = self.norm2(tgt) 332 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 333 | key=self.with_pos_embed(memory, pos), 334 | value=memory, attn_mask=memory_mask, 335 | key_padding_mask=memory_key_padding_mask)[0] 336 | tgt = tgt + self.dropout2(tgt2) 337 | tgt2 = self.norm3(tgt) 338 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 339 | tgt = tgt + self.dropout3(tgt2) 340 | return tgt 341 | 342 | def forward(self, tgt, memory, 343 | tgt_mask: Optional[Tensor] = None, 344 | memory_mask: Optional[Tensor] = None, 345 | tgt_key_padding_mask: Optional[Tensor] = None, 346 | memory_key_padding_mask: Optional[Tensor] = None, 347 | pos: Optional[Tensor] = None, 348 | query_pos: Optional[Tensor] = None): 349 | if self.normalize_before: 350 | return self.forward_pre(tgt, memory, tgt_mask, memory_mask, 351 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 352 | return self.forward_post(tgt, memory, tgt_mask, memory_mask, 353 | tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos) 354 | 355 | class IterativeRelationDecoder(nn.Module): 356 | def __init__(self, decoder_layer, num_layers, norm=None, relation_norm=None, return_intermediate=False, d_model=512, nhead=8, dim_feedforward=2048, dropout=0.1): 357 | super().__init__() 358 | self.subject_layers = _get_clones(decoder_layer, num_layers) 359 | self.object_layers = _get_clones(decoder_layer, num_layers) 360 | self.relation_layers = _get_clones(decoder_layer, num_layers) 361 | self.num_layers = num_layers 362 | self.subject_norm = norm 363 | self.relation_norm = relation_norm 364 | self.return_intermediate = return_intermediate 365 | 366 | self.object_pos_attn = nn.ModuleList([nn.MultiheadAttention(d_model, nhead, dropout=dropout) for _ in range(num_layers)]) 367 | self.object_pos_linear = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)]) 368 | self.object_pos_dropout= nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)]) 369 | 370 | self.relation_pos_attn = nn.ModuleList([nn.MultiheadAttention(d_model, nhead, dropout=dropout, kdim=2*d_model, vdim=2*d_model) for _ in range(num_layers)]) 371 | self.relation_pos_linear = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers)]) 372 | self.relation_pos_dropout = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers)]) 373 | 374 | self.subject_graph_query_attn = nn.ModuleList([nn.MultiheadAttention(d_model, nhead, dropout=dropout, kdim=3*d_model, vdim=3*d_model) for _ in range(num_layers-1)]) 375 | self.subject_graph_query_residual = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers-1)]) 376 | self.subject_graph_query_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers-1)]) 377 | self.subject_graph_query_dropout = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers-1)]) 378 | 379 | self.object_graph_query_attn = nn.ModuleList([nn.MultiheadAttention(d_model, nhead, dropout=dropout, kdim=3*d_model, vdim=3*d_model) for _ in range(num_layers-1)]) 380 | self.object_graph_query_residual = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers-1)]) 381 | self.object_graph_query_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers-1)]) 382 | 
self.object_graph_query_dropout = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers-1)]) 383 | 384 | self.relation_graph_query_attn = nn.ModuleList([nn.MultiheadAttention(d_model, nhead, dropout=dropout, kdim=3*d_model, vdim=3*d_model) for _ in range(num_layers-1)]) 385 | self.relation_graph_query_residual = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_layers-1)]) 386 | self.relation_graph_query_norm = nn.ModuleList([nn.LayerNorm(d_model) for _ in range(num_layers-1)]) 387 | self.relation_graph_query_dropout = nn.ModuleList([nn.Dropout(dropout) for _ in range(num_layers-1)]) 388 | 389 | def with_pos_embed(self, tensor, pos: Optional[Tensor]): 390 | return tensor if pos is None else tensor + pos 391 | 392 | def object_query_generator(self, object_query_embed, subject_features, subject_pos, layer): 393 | sub = self.with_pos_embed(subject_features, subject_pos) 394 | tgt = self.object_pos_attn[layer](object_query_embed, sub, value=subject_features)[0] 395 | tgt_residual = self.object_pos_linear[layer](tgt) 396 | object_query_embed = object_query_embed + self.object_pos_dropout[layer](tgt_residual) 397 | 398 | return object_query_embed 399 | 400 | def relation_query_generator(self, relation_query_embed, subject_features, object_features, subject_pos, object_pos, layer): 401 | sub = self.with_pos_embed(subject_features, subject_pos) 402 | obj = self.with_pos_embed(object_features, object_pos) 403 | 404 | k = torch.cat([sub, obj], -1) 405 | v = torch.cat([subject_features, object_features], -1) 406 | tgt = self.relation_pos_attn[layer](relation_query_embed, k, value=v)[0] 407 | tgt_residual = self.relation_pos_linear[layer](tgt) 408 | relation_query_embed = relation_query_embed + self.relation_pos_dropout[layer](tgt_residual) 409 | return relation_query_embed 410 | 411 | def forward(self, tgt_sub, tgt_obj, tgt_rel, memory, 412 | tgt_mask: Optional[Tensor] = None, 413 | memory_mask: Optional[Tensor] = None, 414 | tgt_key_padding_mask: Optional[Tensor] = None, 415 | memory_key_padding_mask: Optional[Tensor] = None, 416 | pos: Optional[Tensor] = None, 417 | subject_pos: Optional[Tensor] = None, 418 | object_pos: Optional[Tensor] = None, 419 | relation_pos: Optional[Tensor] = None): 420 | output_subject = tgt_sub 421 | output_object = tgt_obj 422 | output_relation = tgt_rel 423 | 424 | intermediate_relation = [] 425 | intermediate_subject = [] 426 | intermediate_object = [] 427 | for layer_id in range(self.num_layers): 428 | sub_features = self.subject_layers[layer_id](output_subject, memory, tgt_mask=tgt_mask, 429 | memory_mask=memory_mask, 430 | tgt_key_padding_mask=tgt_key_padding_mask, 431 | memory_key_padding_mask=memory_key_padding_mask, 432 | pos=pos, query_pos=subject_pos) 433 | 434 | conditional_object_pos = self.object_query_generator(object_pos, sub_features, subject_pos, layer_id) 435 | obj_features = self.object_layers[layer_id](output_object, memory, tgt_mask=tgt_mask, 436 | memory_mask=memory_mask, 437 | tgt_key_padding_mask=tgt_key_padding_mask, 438 | memory_key_padding_mask=memory_key_padding_mask, 439 | pos=pos, query_pos=conditional_object_pos) 440 | 441 | conditional_relation_pos = self.relation_query_generator(relation_pos, sub_features, obj_features, subject_pos, conditional_object_pos, layer_id) 442 | rel_features = self.relation_layers[layer_id](output_relation, memory, tgt_mask=tgt_mask, 443 | memory_mask=memory_mask, 444 | tgt_key_padding_mask=tgt_key_padding_mask, 445 | memory_key_padding_mask=memory_key_padding_mask, 446 | pos=pos, 
query_pos=conditional_relation_pos) 447 | 448 | if self.return_intermediate: 449 | intermediate_subject.append(self.subject_norm(sub_features)) 450 | intermediate_object.append(self.subject_norm(obj_features)) 451 | intermediate_relation.append(self.relation_norm(rel_features)) 452 | 453 | if layer_id != self.num_layers - 1: 454 | 455 | # Get queries for each decoder 456 | triplet_features = torch.cat([sub_features, obj_features, rel_features], -1) 457 | triplet_pos = torch.cat([subject_pos, conditional_object_pos, conditional_relation_pos], -1) 458 | triplet_features_pos = self.with_pos_embed(triplet_features, triplet_pos) 459 | 460 | subject_with_pos = self.with_pos_embed(sub_features, subject_pos) 461 | object_with_pos = self.with_pos_embed(obj_features, conditional_object_pos) 462 | relation_with_pos = self.with_pos_embed(rel_features, conditional_relation_pos) 463 | 464 | # Relation queries 465 | subject_graph_residual = self.subject_graph_query_residual[layer_id](self.subject_graph_query_attn[layer_id](subject_with_pos, triplet_features_pos, value=triplet_features)[0]) 466 | output_subject = sub_features + self.subject_graph_query_norm[layer_id](self.subject_graph_query_dropout[layer_id](subject_graph_residual)) 467 | 468 | object_graph_residual = self.object_graph_query_residual[layer_id](self.object_graph_query_attn[layer_id](object_with_pos, triplet_features_pos, value=triplet_features)[0]) 469 | output_object = obj_features + self.object_graph_query_norm[layer_id](self.object_graph_query_dropout[layer_id](object_graph_residual)) 470 | 471 | relation_graph_residual = self.relation_graph_query_residual[layer_id](self.relation_graph_query_attn[layer_id](relation_with_pos, triplet_features_pos, value=triplet_features)[0]) 472 | output_relation = rel_features + self.relation_graph_query_norm[layer_id](self.relation_graph_query_dropout[layer_id](relation_graph_residual)) 473 | 474 | 475 | if self.subject_norm is not None: 476 | sub_features = self.subject_norm(sub_features) 477 | obj_features = self.subject_norm(obj_features) 478 | rel_features = self.relation_norm(rel_features) 479 | if self.return_intermediate: 480 | intermediate_subject.pop() 481 | intermediate_subject.append(sub_features) 482 | intermediate_object.pop() 483 | intermediate_object.append(obj_features) 484 | intermediate_relation.pop() 485 | intermediate_relation.append(rel_features) 486 | 487 | if self.return_intermediate: 488 | return torch.stack(intermediate_subject), torch.stack(intermediate_object), torch.stack(intermediate_relation) 489 | 490 | return sub_features.unsqueeze(0), obj_features.unsqueeze(0), rel_features.unsqueeze(0) 491 | 492 | 493 | 494 | def _get_clones(module, N): 495 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 496 | 497 | 498 | def build_transformer(name, d_model, dropout, nhead, dim_feedforward, num_encoder_layers, num_decoder_layers, normalize_before, return_intermediate_dec, **kwargs): 499 | return TRANSFORMER_REGISTRY.get(name)( 500 | d_model=d_model, 501 | dropout=dropout, 502 | nhead=nhead, 503 | dim_feedforward=dim_feedforward, 504 | num_encoder_layers=num_encoder_layers, 505 | num_decoder_layers=num_decoder_layers, 506 | normalize_before=normalize_before, 507 | return_intermediate_dec=return_intermediate_dec, 508 | **kwargs 509 | ) 510 | 511 | def _get_activation_fn(activation): 512 | """Return an activation function given a string""" 513 | if activation == "relu": 514 | return F.relu 515 | if activation == "gelu": 516 | return F.gelu 517 | if activation == "glu": 
518 | return F.glu 519 | raise RuntimeError(F"activation should be relu/gelu/glu, not {activation}.") -------------------------------------------------------------------------------- /modeling/transformer/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/IterativeSG/7abd162cb8e510cfbcededbeb18b36b54f381189/modeling/transformer/util/__init__.py -------------------------------------------------------------------------------- /modeling/transformer/util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities for bounding box manipulation and GIoU. 3 | """ 4 | import torch 5 | from torchvision.ops.boxes import box_area 6 | 7 | 8 | def box_cxcywh_to_xyxy(x): 9 | x_c, y_c, w, h = x.unbind(-1) 10 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 11 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 12 | return torch.stack(b, dim=-1) 13 | 14 | 15 | def box_xyxy_to_cxcywh(x): 16 | x0, y0, x1, y1 = x.unbind(-1) 17 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 18 | (x1 - x0), (y1 - y0)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | # modified from torchvision to also return the union 23 | def box_iou(boxes1, boxes2): 24 | area1 = box_area(boxes1) 25 | area2 = box_area(boxes2) 26 | 27 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 28 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 29 | 30 | wh = (rb - lt).clamp(min=0) # [N,M,2] 31 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 32 | 33 | union = area1[:, None] + area2 - inter 34 | 35 | iou = inter / union 36 | return iou, union 37 | 38 | 39 | def generalized_box_iou(boxes1, boxes2): 40 | """ 41 | Generalized IoU from https://giou.stanford.edu/ 42 | The boxes should be in [x0, y0, x1, y1] format 43 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 44 | and M = len(boxes2) 45 | """ 46 | # degenerate boxes give inf / nan results 47 | # so do an early check 48 | if not (boxes1[:, 2:] >= boxes1[:, :2]).all(): 49 | print ("Box1", boxes1, boxes1[:, :2][boxes1[:, 2:] < boxes1[:, :2]], boxes1[:, 2:][boxes1[:, 2:] < boxes1[:, :2]], (boxes1[:, 2:] < boxes1[:, :2]).nonzero(), boxes1.max(), boxes1.min(), boxes1.isnan().nonzero()) 50 | if not (boxes2[:, 2:] >= boxes2[:, :2]).all(): 51 | print ("Box2", boxes2, boxes2[:, :2][boxes2[:, 2:] < boxes2[:, :2]], boxes2[:, 2:][boxes2[:, 2:] < boxes2[:, :2]], (boxes2[:, 2:] < boxes2[:, :2]).nonzero(), boxes2.max(), boxes2.min(), boxes2.isnan().nonzero()) 52 | assert (boxes1[:, 2:] >= boxes1[:, :2]).all() 53 | assert (boxes2[:, 2:] >= boxes2[:, :2]).all() 54 | iou, union = box_iou(boxes1, boxes2) 55 | 56 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 57 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 58 | 59 | wh = (rb - lt).clamp(min=0) # [N,M,2] 60 | area = wh[:, :, 0] * wh[:, :, 1] 61 | 62 | return iou - (area - union) / area 63 | 64 | 65 | def masks_to_boxes(masks): 66 | """Compute the bounding boxes around the provided masks 67 | The masks should be in format [N, H, W] where N is the number of masks, (H, W) are the spatial dimensions.
68 | Returns a [N, 4] tensors, with the boxes in xyxy format 69 | """ 70 | if masks.numel() == 0: 71 | return torch.zeros((0, 4), device=masks.device) 72 | 73 | h, w = masks.shape[-2:] 74 | 75 | y = torch.arange(0, h, dtype=torch.float) 76 | x = torch.arange(0, w, dtype=torch.float) 77 | y, x = torch.meshgrid(y, x) 78 | 79 | x_mask = (masks * x.unsqueeze(0)) 80 | x_max = x_mask.flatten(1).max(-1)[0] 81 | x_min = x_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 82 | 83 | y_mask = (masks * y.unsqueeze(0)) 84 | y_max = y_mask.flatten(1).max(-1)[0] 85 | y_min = y_mask.masked_fill(~(masks.bool()), 1e8).flatten(1).min(-1)[0] 86 | 87 | return torch.stack([x_min, y_min, x_max, y_max], 1) -------------------------------------------------------------------------------- /modeling/transformer/util/misc.py: -------------------------------------------------------------------------------- 1 | """ 2 | Misc functions, including distributed helpers. 3 | Mostly copy-paste from torchvision references. 4 | """ 5 | import os 6 | import subprocess 7 | import time 8 | from collections import defaultdict, deque 9 | import datetime 10 | import pickle 11 | from packaging import version 12 | from typing import Optional, List 13 | 14 | import torch 15 | import torch.distributed as dist 16 | from torch import Tensor 17 | 18 | # needed due to empty tensor bug in pytorch and torchvision 0.5 19 | import torchvision 20 | if version.parse(torchvision.__version__) < version.parse('0.7'): 21 | from torchvision.ops import _new_empty_tensor 22 | from torchvision.ops.misc import _output_size 23 | 24 | 25 | class SmoothedValue(object): 26 | """Track a series of values and provide access to smoothed values over a 27 | window or the global series average. 28 | """ 29 | 30 | def __init__(self, window_size=20, fmt=None): 31 | if fmt is None: 32 | fmt = "{median:.4f} ({global_avg:.4f})" 33 | self.deque = deque(maxlen=window_size) 34 | self.total = 0.0 35 | self.count = 0 36 | self.fmt = fmt 37 | 38 | def update(self, value, n=1): 39 | self.deque.append(value) 40 | self.count += n 41 | self.total += value * n 42 | 43 | def synchronize_between_processes(self): 44 | """ 45 | Warning: does not synchronize the deque! 
46 | """ 47 | if not is_dist_avail_and_initialized(): 48 | return 49 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') 50 | dist.barrier() 51 | dist.all_reduce(t) 52 | t = t.tolist() 53 | self.count = int(t[0]) 54 | self.total = t[1] 55 | 56 | @property 57 | def median(self): 58 | d = torch.tensor(list(self.deque)) 59 | return d.median().item() 60 | 61 | @property 62 | def avg(self): 63 | d = torch.tensor(list(self.deque), dtype=torch.float32) 64 | return d.mean().item() 65 | 66 | @property 67 | def global_avg(self): 68 | return self.total / self.count 69 | 70 | @property 71 | def max(self): 72 | return max(self.deque) 73 | 74 | @property 75 | def value(self): 76 | return self.deque[-1] 77 | 78 | def __str__(self): 79 | return self.fmt.format( 80 | median=self.median, 81 | avg=self.avg, 82 | global_avg=self.global_avg, 83 | max=self.max, 84 | value=self.value) 85 | 86 | 87 | def all_gather(data): 88 | """ 89 | Run all_gather on arbitrary picklable data (not necessarily tensors) 90 | Args: 91 | data: any picklable object 92 | Returns: 93 | list[data]: list of data gathered from each rank 94 | """ 95 | world_size = get_world_size() 96 | if world_size == 1: 97 | return [data] 98 | 99 | # serialized to a Tensor 100 | buffer = pickle.dumps(data) 101 | storage = torch.ByteStorage.from_buffer(buffer) 102 | tensor = torch.ByteTensor(storage).to("cuda") 103 | 104 | # obtain Tensor size of each rank 105 | local_size = torch.tensor([tensor.numel()], device="cuda") 106 | size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] 107 | dist.all_gather(size_list, local_size) 108 | size_list = [int(size.item()) for size in size_list] 109 | max_size = max(size_list) 110 | 111 | # receiving Tensor from all ranks 112 | # we pad the tensor because torch all_gather does not support 113 | # gathering tensors of different shapes 114 | tensor_list = [] 115 | for _ in size_list: 116 | tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) 117 | if local_size != max_size: 118 | padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") 119 | tensor = torch.cat((tensor, padding), dim=0) 120 | dist.all_gather(tensor_list, tensor) 121 | 122 | data_list = [] 123 | for size, tensor in zip(size_list, tensor_list): 124 | buffer = tensor.cpu().numpy().tobytes()[:size] 125 | data_list.append(pickle.loads(buffer)) 126 | 127 | return data_list 128 | 129 | 130 | def reduce_dict(input_dict, average=True): 131 | """ 132 | Args: 133 | input_dict (dict): all the values will be reduced 134 | average (bool): whether to do average or sum 135 | Reduce the values in the dictionary from all processes so that all processes 136 | have the averaged results. Returns a dict with the same fields as 137 | input_dict, after reduction. 
138 | """ 139 | world_size = get_world_size() 140 | if world_size < 2: 141 | return input_dict 142 | with torch.no_grad(): 143 | names = [] 144 | values = [] 145 | # sort the keys so that they are consistent across processes 146 | for k in sorted(input_dict.keys()): 147 | names.append(k) 148 | values.append(input_dict[k]) 149 | values = torch.stack(values, dim=0) 150 | dist.all_reduce(values) 151 | if average: 152 | values /= world_size 153 | reduced_dict = {k: v for k, v in zip(names, values)} 154 | return reduced_dict 155 | 156 | 157 | class MetricLogger(object): 158 | def __init__(self, delimiter="\t"): 159 | self.meters = defaultdict(SmoothedValue) 160 | self.delimiter = delimiter 161 | 162 | def update(self, **kwargs): 163 | for k, v in kwargs.items(): 164 | if isinstance(v, torch.Tensor): 165 | v = v.item() 166 | assert isinstance(v, (float, int)) 167 | self.meters[k].update(v) 168 | 169 | def __getattr__(self, attr): 170 | if attr in self.meters: 171 | return self.meters[attr] 172 | if attr in self.__dict__: 173 | return self.__dict__[attr] 174 | raise AttributeError("'{}' object has no attribute '{}'".format( 175 | type(self).__name__, attr)) 176 | 177 | def __str__(self): 178 | loss_str = [] 179 | for name, meter in self.meters.items(): 180 | loss_str.append( 181 | "{}: {}".format(name, str(meter)) 182 | ) 183 | return self.delimiter.join(loss_str) 184 | 185 | def synchronize_between_processes(self): 186 | for meter in self.meters.values(): 187 | meter.synchronize_between_processes() 188 | 189 | def add_meter(self, name, meter): 190 | self.meters[name] = meter 191 | 192 | def log_every(self, iterable, print_freq, header=None): 193 | i = 0 194 | if not header: 195 | header = '' 196 | start_time = time.time() 197 | end = time.time() 198 | iter_time = SmoothedValue(fmt='{avg:.4f}') 199 | data_time = SmoothedValue(fmt='{avg:.4f}') 200 | space_fmt = ':' + str(len(str(len(iterable)))) + 'd' 201 | if torch.cuda.is_available(): 202 | log_msg = self.delimiter.join([ 203 | header, 204 | '[{0' + space_fmt + '}/{1}]', 205 | 'eta: {eta}', 206 | '{meters}', 207 | 'time: {time}', 208 | 'data: {data}', 209 | 'max mem: {memory:.0f}' 210 | ]) 211 | else: 212 | log_msg = self.delimiter.join([ 213 | header, 214 | '[{0' + space_fmt + '}/{1}]', 215 | 'eta: {eta}', 216 | '{meters}', 217 | 'time: {time}', 218 | 'data: {data}' 219 | ]) 220 | MB = 1024.0 * 1024.0 221 | for obj in iterable: 222 | data_time.update(time.time() - end) 223 | yield obj 224 | iter_time.update(time.time() - end) 225 | if i % print_freq == 0 or i == len(iterable) - 1: 226 | eta_seconds = iter_time.global_avg * (len(iterable) - i) 227 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 228 | if torch.cuda.is_available(): 229 | print(log_msg.format( 230 | i, len(iterable), eta=eta_string, 231 | meters=str(self), 232 | time=str(iter_time), data=str(data_time), 233 | memory=torch.cuda.max_memory_allocated() / MB)) 234 | else: 235 | print(log_msg.format( 236 | i, len(iterable), eta=eta_string, 237 | meters=str(self), 238 | time=str(iter_time), data=str(data_time))) 239 | i += 1 240 | end = time.time() 241 | total_time = time.time() - start_time 242 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 243 | print('{} Total time: {} ({:.4f} s / it)'.format( 244 | header, total_time_str, total_time / len(iterable))) 245 | 246 | 247 | def get_sha(): 248 | cwd = os.path.dirname(os.path.abspath(__file__)) 249 | 250 | def _run(command): 251 | return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() 
252 | sha = 'N/A' 253 | diff = "clean" 254 | branch = 'N/A' 255 | try: 256 | sha = _run(['git', 'rev-parse', 'HEAD']) 257 | subprocess.check_output(['git', 'diff'], cwd=cwd) 258 | diff = _run(['git', 'diff-index', 'HEAD']) 259 | diff = "has uncommitted changes" if diff else "clean" 260 | branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) 261 | except Exception: 262 | pass 263 | message = f"sha: {sha}, status: {diff}, branch: {branch}" 264 | return message 265 | 266 | 267 | def collate_fn(batch): 268 | batch = list(zip(*batch)) 269 | batch[0] = nested_tensor_from_tensor_list(batch[0]) 270 | return tuple(batch) 271 | 272 | 273 | def _max_by_axis(the_list): 274 | # type: (List[List[int]]) -> List[int] 275 | maxes = the_list[0] 276 | for sublist in the_list[1:]: 277 | for index, item in enumerate(sublist): 278 | maxes[index] = max(maxes[index], item) 279 | return maxes 280 | 281 | 282 | class NestedTensor(object): 283 | def __init__(self, tensors, mask: Optional[Tensor]): 284 | self.tensors = tensors 285 | self.mask = mask 286 | 287 | def to(self, device): 288 | # type: (Device) -> NestedTensor # noqa 289 | cast_tensor = self.tensors.to(device) 290 | mask = self.mask 291 | if mask is not None: 292 | assert mask is not None 293 | cast_mask = mask.to(device) 294 | else: 295 | cast_mask = None 296 | return NestedTensor(cast_tensor, cast_mask) 297 | 298 | def decompose(self): 299 | return self.tensors, self.mask 300 | 301 | def __repr__(self): 302 | return str(self.tensors) 303 | 304 | 305 | def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): 306 | # TODO make this more general 307 | if tensor_list[0].ndim == 3: 308 | if torchvision._is_tracing(): 309 | # nested_tensor_from_tensor_list() does not export well to ONNX 310 | # call _onnx_nested_tensor_from_tensor_list() instead 311 | return _onnx_nested_tensor_from_tensor_list(tensor_list) 312 | 313 | # TODO make it support different-sized images 314 | max_size = _max_by_axis([list(img.shape) for img in tensor_list]) 315 | # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) 316 | batch_shape = [len(tensor_list)] + max_size 317 | b, c, h, w = batch_shape 318 | dtype = tensor_list[0].dtype 319 | device = tensor_list[0].device 320 | tensor = torch.zeros(batch_shape, dtype=dtype, device=device) 321 | mask = torch.ones((b, h, w), dtype=torch.bool, device=device) 322 | for img, pad_img, m in zip(tensor_list, tensor, mask): 323 | pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 324 | m[: img.shape[1], :img.shape[2]] = False 325 | else: 326 | raise ValueError('not supported') 327 | return NestedTensor(tensor, mask) 328 | 329 | 330 | # _onnx_nested_tensor_from_tensor_list() is an implementation of 331 | # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
332 | @torch.jit.unused 333 | def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: 334 | max_size = [] 335 | for i in range(tensor_list[0].dim()): 336 | max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) 337 | max_size.append(max_size_i) 338 | max_size = tuple(max_size) 339 | 340 | # work around for 341 | # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) 342 | # m[: img.shape[1], :img.shape[2]] = False 343 | # which is not yet supported in onnx 344 | padded_imgs = [] 345 | padded_masks = [] 346 | for img in tensor_list: 347 | padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] 348 | padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) 349 | padded_imgs.append(padded_img) 350 | 351 | m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) 352 | padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) 353 | padded_masks.append(padded_mask.to(torch.bool)) 354 | 355 | tensor = torch.stack(padded_imgs) 356 | mask = torch.stack(padded_masks) 357 | 358 | return NestedTensor(tensor, mask=mask) 359 | 360 | 361 | def setup_for_distributed(is_master): 362 | """ 363 | This function disables printing when not in master process 364 | """ 365 | import builtins as __builtin__ 366 | builtin_print = __builtin__.print 367 | 368 | def print(*args, **kwargs): 369 | force = kwargs.pop('force', False) 370 | if is_master or force: 371 | builtin_print(*args, **kwargs) 372 | 373 | __builtin__.print = print 374 | 375 | 376 | def is_dist_avail_and_initialized(): 377 | if not dist.is_available(): 378 | return False 379 | if not dist.is_initialized(): 380 | return False 381 | return True 382 | 383 | 384 | def get_world_size(): 385 | if not is_dist_avail_and_initialized(): 386 | return 1 387 | return dist.get_world_size() 388 | 389 | 390 | def get_rank(): 391 | if not is_dist_avail_and_initialized(): 392 | return 0 393 | return dist.get_rank() 394 | 395 | 396 | def is_main_process(): 397 | return get_rank() == 0 398 | 399 | 400 | def save_on_master(*args, **kwargs): 401 | if is_main_process(): 402 | torch.save(*args, **kwargs) 403 | 404 | 405 | def init_distributed_mode(args): 406 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 407 | args.rank = int(os.environ["RANK"]) 408 | args.world_size = int(os.environ['WORLD_SIZE']) 409 | args.gpu = int(os.environ['LOCAL_RANK']) 410 | elif 'SLURM_PROCID' in os.environ: 411 | args.rank = int(os.environ['SLURM_PROCID']) 412 | args.gpu = args.rank % torch.cuda.device_count() 413 | else: 414 | print('Not using distributed mode') 415 | args.distributed = False 416 | return 417 | 418 | args.distributed = True 419 | 420 | torch.cuda.set_device(args.gpu) 421 | args.dist_backend = 'nccl' 422 | print('| distributed init (rank {}): {}'.format( 423 | args.rank, args.dist_url), flush=True) 424 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 425 | world_size=args.world_size, rank=args.rank) 426 | torch.distributed.barrier() 427 | setup_for_distributed(args.rank == 0) 428 | 429 | 430 | @torch.no_grad() 431 | def accuracy(output, target, topk=(1,)): 432 | """Computes the precision@k for the specified values of k""" 433 | if target.numel() == 0: 434 | return [torch.zeros([], device=output.device)] 435 | maxk = max(topk) 436 | batch_size = target.size(0) 437 | 438 | _, pred = output.topk(maxk, 1, True, True) 439 | pred = pred.t() 440 | correct = 
pred.eq(target.view(1, -1).expand_as(pred)) 441 | 442 | res = [] 443 | for k in topk: 444 | correct_k = correct[:k].view(-1).float().sum(0) 445 | res.append(correct_k.mul_(100.0 / batch_size)) 446 | return res 447 | 448 | 449 | def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): 450 | # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor 451 | """ 452 | Equivalent to nn.functional.interpolate, but with support for empty batch sizes. 453 | This will eventually be supported natively by PyTorch, and this 454 | class can go away. 455 | """ 456 | if version.parse(torchvision.__version__) < version.parse('0.7'): 457 | if input.numel() > 0: 458 | return torch.nn.functional.interpolate( 459 | input, size, scale_factor, mode, align_corners 460 | ) 461 | 462 | output_shape = _output_size(2, input, size, scale_factor) 463 | output_shape = list(input.shape[:-2]) + list(output_shape) 464 | return _new_empty_tensor(input, output_shape) 465 | else: 466 | return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) 467 | 468 | 469 | def inverse_sigmoid(x, eps=1e-4): 470 | x = x.clamp(min=0, max=1) 471 | x1 = x.clamp(min=eps) 472 | x2 = (1 - x).clamp(min=eps) 473 | return torch.log(x1/x2) 474 | 475 | 476 | -------------------------------------------------------------------------------- /modeling/transformer/util/plot_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Plotting utilities to visualize training logs. 3 | """ 4 | import torch 5 | import pandas as pd 6 | import numpy as np 7 | import seaborn as sns 8 | import matplotlib.pyplot as plt 9 | 10 | from pathlib import Path, PurePath 11 | 12 | 13 | def plot_logs(logs, fields=('class_error', 'loss_bbox_unscaled', 'mAP'), ewm_col=0, log_name='log.txt'): 14 | ''' 15 | Function to plot specific fields from training log(s). Plots both training and test results. 16 | :: Inputs - logs = list containing Path objects, each pointing to individual dir with a log file 17 | - fields = which results to plot from each log file - plots both training and test for each field. 18 | - ewm_col = optional, which column to use as the exponential weighted smoothing of the plots 19 | - log_name = optional, name of log file if different than default 'log.txt'. 20 | :: Outputs - matplotlib plots of results in fields, color coded for each log file. 21 | - solid lines are training results, dashed lines are test results. 
22 | ''' 23 | func_name = "plot_utils.py::plot_logs" 24 | 25 | # verify logs is a list of Paths (list[Paths]) or single Pathlib object Path, 26 | # convert single Path to list to avoid 'not iterable' error 27 | 28 | if not isinstance(logs, list): 29 | if isinstance(logs, PurePath): 30 | logs = [logs] 31 | print(f"{func_name} info: logs param expects a list argument, converted to list[Path].") 32 | else: 33 | raise ValueError(f"{func_name} - invalid argument for logs parameter.\n \ 34 | Expect list[Path] or single Path obj, received {type(logs)}") 35 | 36 | # Quality checks - verify valid dir(s), that every item in list is Path object, and that log_name exists in each dir 37 | for i, dir in enumerate(logs): 38 | if not isinstance(dir, PurePath): 39 | raise ValueError(f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}") 40 | if not dir.exists(): 41 | raise ValueError(f"{func_name} - invalid directory in logs argument:\n{dir}") 42 | # verify log_name exists 43 | fn = Path(dir / log_name) 44 | if not fn.exists(): 45 | print(f"-> missing {log_name}. Have you gotten to Epoch 1 in training?") 46 | print(f"--> full path of missing log file: {fn}") 47 | return 48 | 49 | # load log file(s) and plot 50 | dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] 51 | 52 | fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) 53 | 54 | for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): 55 | for j, field in enumerate(fields): 56 | if field == 'mAP': 57 | coco_eval = pd.DataFrame( 58 | np.stack(df.test_coco_eval_bbox.dropna().values)[:, 1] 59 | ).ewm(com=ewm_col).mean() 60 | axs[j].plot(coco_eval, c=color) 61 | else: 62 | df.interpolate().ewm(com=ewm_col).mean().plot( 63 | y=[f'train_{field}', f'test_{field}'], 64 | ax=axs[j], 65 | color=[color] * 2, 66 | style=['-', '--'] 67 | ) 68 | for ax, field in zip(axs, fields): 69 | ax.legend([Path(p).name for p in logs]) 70 | ax.set_title(field) 71 | 72 | 73 | def plot_precision_recall(files, naming_scheme='iter'): 74 | if naming_scheme == 'exp_id': 75 | # name becomes exp_id 76 | names = [f.parts[-3] for f in files] 77 | elif naming_scheme == 'iter': 78 | names = [f.stem for f in files] 79 | else: 80 | raise ValueError(f'not supported {naming_scheme}') 81 | fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) 82 | for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): 83 | data = torch.load(f) 84 | # precision is n_iou, n_points, n_cat, n_area, max_det 85 | precision = data['precision'] 86 | recall = data['params'].recThrs 87 | scores = data['scores'] 88 | # take precision for all classes, all areas and 100 detections 89 | precision = precision[0, :, :, 0, -1].mean(1) 90 | scores = scores[0, :, :, 0, -1].mean(1) 91 | prec = precision.mean() 92 | rec = data['recall'][0, :, 0, -1].mean() 93 | print(f'{naming_scheme} {name}: mAP@50={prec * 100: 05.1f}, ' + 94 | f'score={scores.mean():0.3f}, ' + 95 | f'f1={2 * prec * rec / (prec + rec + 1e-8):0.3f}' 96 | ) 97 | axs[0].plot(recall, precision, c=color) 98 | axs[1].plot(recall, scores, c=color) 99 | 100 | axs[0].set_title('Precision / Recall') 101 | axs[0].legend(names) 102 | axs[1].set_title('Scores / Recall') 103 | axs[1].legend(names) 104 | return fig, axs -------------------------------------------------------------------------------- /modeling/transformer/util/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | from typing import Optional 4 | 
from pycocotools import mask as coco_mask 5 | 6 | 7 | # in util.box_ops 8 | def box_cxcywh_to_xyxy(x): 9 | x_c, y_c, w, h = x.unbind(-1) 10 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 11 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 12 | return torch.stack(b, dim=-1) 13 | 14 | 15 | def box_xyxy_to_cxcywh(x): 16 | x0, y0, x1, y1 = x.unbind(-1) 17 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 18 | (x1 - x0), (y1 - y0)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | # in util.misc 23 | class NestedTensor(object): 24 | def __init__(self, tensors, mask: Optional[Tensor]): 25 | self.tensors = tensors 26 | self.mask = mask 27 | 28 | def to(self, device): 29 | # type: (Device) -> NestedTensor # noqa 30 | cast_tensor = self.tensors.to(device) 31 | mask = self.mask 32 | if mask is not None: 33 | assert mask is not None 34 | cast_mask = mask.to(device) 35 | else: 36 | cast_mask = None 37 | return NestedTensor(cast_tensor, cast_mask) 38 | 39 | def decompose(self): 40 | return self.tensors, self.mask 41 | 42 | def __repr__(self): 43 | return str(self.tensors) 44 | 45 | 46 | # in datasets.coco 47 | def convert_coco_poly_to_mask(segmentations, height, width): 48 | masks = [] 49 | for polygons in segmentations: 50 | rles = coco_mask.frPyObjects(polygons, height, width) 51 | mask = coco_mask.decode(rles) 52 | if len(mask.shape) < 3: 53 | mask = mask[..., None] 54 | mask = torch.as_tensor(mask, dtype=torch.uint8) 55 | mask = mask.any(dim=2) 56 | masks.append(mask) 57 | if masks: 58 | masks = torch.stack(masks, dim=0) 59 | else: 60 | masks = torch.zeros((0, height, width), dtype=torch.uint8) 61 | return masks 62 | -------------------------------------------------------------------------------- /structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .boxes_ops import boxes_union 2 | from .masks_ops import masks_union -------------------------------------------------------------------------------- /structures/boxes_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from detectron2.structures.boxes import Boxes 4 | 5 | def boxes_union(boxes1, boxes2): 6 | """ 7 | Compute the union region of two set of boxes 8 | Arguments: 9 | box1: (Boxes) bounding boxes, sized [N,4]. 10 | box2: (Boxes) bounding boxes, sized [N,4]. 11 | Returns: 12 | (Boxes) union, sized [N,4]. 
13 | """ 14 | assert len(boxes1) == len(boxes2) 15 | 16 | union_box = torch.cat(( 17 | torch.min(boxes1.tensor[:,:2], boxes2.tensor[:,:2]), 18 | torch.max(boxes1.tensor[:,2:], boxes2.tensor[:,2:]) 19 | ),dim=1) 20 | return Boxes(union_box) 21 | -------------------------------------------------------------------------------- /structures/masks_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from detectron2.structures.boxes import Boxes 4 | 5 | def masks_union(masks1, masks2): 6 | assert len(masks1) == len(masks2) 7 | masks_union = (masks1 + masks2)/2.0 8 | return masks_union 9 | -------------------------------------------------------------------------------- /train_iterative_model.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import torch 5 | sys.path.insert(0, '../../') 6 | sys.path.insert(0, '../') 7 | 8 | import detectron2.utils.comm as comm 9 | from detectron2.utils.logger import setup_logger 10 | from detectron2.engine import default_argument_parser, default_setup, launch 11 | from detectron2.config import get_cfg 12 | from detectron2.checkpoint import DetectionCheckpointer 13 | 14 | from IterativeSG.engine import JointTransformerTrainer 15 | from IterativeSG.data import VisualGenomeTrainData, register_datasets, DatasetCatalog, MetadataCatalog 16 | from IterativeSG.configs.defaults import add_dataset_config, add_scenegraph_config 17 | from IterativeSG.modeling import Detr 18 | from detectron2.data.datasets import register_coco_instances 19 | 20 | parser = default_argument_parser() 21 | 22 | def setup(args): 23 | cfg = get_cfg() 24 | add_dataset_config(cfg) 25 | add_scenegraph_config(cfg) 26 | assert(cfg.MODEL.ROI_SCENEGRAPH_HEAD.MODE in ['predcls', 'sgls', 'sgdet']), "Mode {} not supported".format(cfg.MODEL.ROI_SCENEGRaGraph.MODE) 27 | cfg.merge_from_file(args.config_file) 28 | cfg.merge_from_list(args.opts) 29 | cfg.freeze() 30 | register_datasets(cfg) 31 | # register_coco_data(cfg) 32 | default_setup(cfg, args) 33 | 34 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="LSDA") 35 | return cfg 36 | 37 | def main(args): 38 | cfg = setup(args) 39 | if args.eval_only: 40 | model = JointTransformerTrainer.build_model(cfg) 41 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 42 | cfg.MODEL.WEIGHTS, resume=args.resume 43 | ) 44 | res = JointTransformerTrainer.test(cfg, model) 45 | # if comm.is_main_process(): 46 | # verify_results(cfg, res) 47 | return res 48 | 49 | trainer = JointTransformerTrainer(cfg) 50 | trainer.resume_or_load(resume=args.resume) 51 | return trainer.train() 52 | 53 | if __name__ == '__main__': 54 | args = parser.parse_args() 55 | try: 56 | # use the last 4 numbers in the job id as the id 57 | default_port = os.environ['SLURM_JOB_ID'] 58 | default_port = default_port[-4:] 59 | 60 | # all ports should be in the 10k+ range 61 | default_port = int(default_port) + 15000 62 | except Exception: 63 | default_port = 59482 64 | 65 | args.dist_url = 'tcp://127.0.0.1:'+str(default_port) 66 | print(args) 67 | 68 | launch( 69 | main, 70 | args.num_gpus, 71 | num_machines=args.num_machines, 72 | machine_rank=args.machine_rank, 73 | dist_url=args.dist_url, 74 | args=(args,), 75 | ) --------------------------------------------------------------------------------