├── .gitignore ├── INSTALL.md ├── README.md ├── configs ├── Base-RCNN-FPN.yaml ├── pretrain_baseline.yaml ├── pretrain_object_detector_coco.yaml ├── pretrain_object_detector_vgg_coco.yaml ├── sg_baseline.yaml ├── sg_dev_masktransfer.yaml └── sg_dev_masktransfer_vgg.yaml ├── glove └── glove ├── scripts ├── pretrain_object_detector_baseline.py ├── pretrain_object_detector_withcoco.py ├── train_SG_baseline.py └── train_SG_segmentation_head.py ├── segmentationsg ├── __init__.py ├── checkpoint │ ├── __init__.py │ └── detection_checkpoint.py ├── data │ ├── __init__.py │ ├── dataset_mapper.py │ ├── datasets │ │ ├── __init__.py │ │ ├── coco.py │ │ ├── images_to_remove.txt │ │ ├── openimage_preprocess.py │ │ ├── openimages.py │ │ └── visual_genome.py │ ├── embeddings │ │ ├── glove_coco │ │ ├── glove_mean_coco │ │ ├── glove_mean_open_images │ │ ├── glove_mean_vg │ │ ├── glove_open_images │ │ └── glove_vg │ ├── get_embeddings.py │ ├── remove_intersection.py │ └── tools │ │ ├── __init__.py │ │ ├── config.py │ │ └── utils.py ├── engine │ ├── __init__.py │ ├── sg_trainer.py │ └── trainer.py ├── evaluation │ ├── __init__.py │ ├── coco_evaluation.py │ ├── datasets │ │ └── vg │ │ │ └── zeroshot_triplet.pytorch │ ├── evaluator.py │ ├── sg_evaluation.py │ └── utils.py ├── modeling │ ├── __init__.py │ ├── backbone │ │ ├── __init__.py │ │ └── vgg.py │ ├── meta_arch │ │ ├── __init__.py │ │ └── rcnn.py │ └── roi_heads │ │ ├── __init__.py │ │ ├── box_head.py │ │ ├── fast_rcnn.py │ │ ├── mask_head.py │ │ ├── roi_heads.py │ │ └── scenegraph_head │ │ ├── __init__.py │ │ ├── box_feature_extractor.py │ │ ├── defaults.py │ │ ├── imp │ │ ├── __init__.py │ │ ├── make_layers.py │ │ └── model_imp.py │ │ ├── inference.py │ │ ├── loss.py │ │ ├── motif │ │ ├── __init__.py │ │ ├── model_motifs.py │ │ └── utils_motifs.py │ │ ├── relation_feature_extractor.py │ │ ├── sampling.py │ │ ├── scenegraph_head.py │ │ ├── scenegraph_predictor.py │ │ ├── transformer │ │ ├── __init__.py │ │ └── model_transformer.py │ │ ├── utils.py │ │ └── vctree │ │ ├── __init__.py │ │ ├── model_vctree.py │ │ ├── utils_treelstm.py │ │ └── utils_vctree.py └── structures │ ├── __init__.py │ ├── boxes_ops.py │ └── masks_ops.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /INSTALL.md: -------------------------------------------------------------------------------- 1 | ## Requirements 2 | In order to run the code, the following libraries are required: 3 | - python >= 3.7 4 | - PyTorch >= 1.6 5 | - detectron2 >= 0.3 6 | - cv2 >= 4.4.0 7 | - scikit-learn >= 0.23.2 8 | - imantics 9 | - easydict 10 | - h5py 11 | 12 | **Note**: Please make sure the detectron2 version correctly corresponds to the PyTorch and CUDA versions installed. Please check [this page](https://detectron2.readthedocs.io/en/latest/tutorials/install.html) for installation instructions and common issues. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Segmentation-Grounded Scene Graph Generation 2 | 3 | This repository contains the code for the ICCV 2021 paper titled [**"Segmentation-Grounded Scene Graph Generation"**](https://arxiv.org/pdf/2104.14207.pdf). 4 | 5 | ## BibTeX 6 | ``` 7 | @inproceedings{khandelwal2021segmentation, 8 | title={Segmentation-grounded Scene Graph Generation}, 9 | author={Khandelwal, Siddhesh and Suhail, Mohammed and Sigal, Leonid}, 10 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 11 | year={2021} 12 | } 13 | ``` 14 | 15 | ## Requirements 16 | To set up the environment with all the required dependencies, follow the steps detailed in [INSTALL.md](INSTALL.md). Additionally, please rename the cloned repository from `segmentation-sg` to `segmentationsg`.
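As a quick, optional sanity check, the snippet below prints the versions of the libraries listed in [INSTALL.md](INSTALL.md) and whether a CUDA-enabled PyTorch build is visible; it is only a convenience for cross-checking against the detectron2 installation matrix and is not part of the repository's scripts.
```python
# Optional environment check: confirms that the INSTALL.md dependencies import
# correctly and reports the PyTorch / CUDA / detectron2 versions.
import sys
import torch
import detectron2
import cv2
import sklearn
import h5py
import easydict   # noqa: F401 -- imported only to confirm it is installed
import imantics   # noqa: F401 -- imported only to confirm it is installed

print(f"python      : {sys.version.split()[0]}")
print(f"torch       : {torch.__version__} (CUDA build: {torch.version.cuda})")
print(f"CUDA usable : {torch.cuda.is_available()}")
print(f"detectron2  : {detectron2.__version__}")
print(f"opencv (cv2): {cv2.__version__}")
print(f"scikit-learn: {sklearn.__version__}")
print(f"h5py        : {h5py.__version__}")
```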
17 | 18 | ## Prepare Dataset 19 | The approach requires access to Visual Genome and MS-COCO datasets. 20 | - MS-COCO is publicly available [here](https://cocodataset.org/#download). We use the 2017 Train/Val splits in our experiments. 21 | - We use the Visual Genome filtered data widely used in the Scene Graph community. Please see the [Unbiased Scene Graph Generation repo](https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch/blob/master/DATASET.md) on instructions to download this dataset. 22 | 23 | ## Pretrain Object Detector 24 | Before the scene graph models can be trained, the first step involves jointly pre-training the object detector to accurately predict bounding boxes on Visual Genome and segmentation masks on MS-COCO. 25 | 26 | If using the ResNeXt-101 backbone, the pre-training can be achieved by running the following command 27 | ```python 28 | python pretrain_object_detector_withcoco.py --config-file ../configs/pretrain_object_detector_coco.yaml --num-gpus 4 --resume DATASETS.VISUAL_GENOME.IMAGES DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY DATASETS.VISUAL_GENOME.IMAGE_DATA DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 DATASETS.MSCOCO.ANNOTATIONS DATASETS.MSCOCO.DATAROOT OUTPUT_DIR 29 | ``` 30 | 31 | If using the VGG-16 backbone, the pre-training can be achieved by running the following command 32 | ```python 33 | python pretrain_object_detector_withcoco.py --config-file ../configs/pretrain_object_detector_vgg_coco.yaml --num-gpus 4 --resume DATASETS.VISUAL_GENOME.IMAGES DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY DATASETS.VISUAL_GENOME.IMAGE_DATA DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 DATASETS.MSCOCO.ANNOTATIONS DATASETS.MSCOCO.DATAROOT OUTPUT_DIR 34 | ``` 35 | 36 | The jointly trained pre-trained weights can be found [here](https://drive.google.com/drive/folders/1YZ3ipSi_ao_Xl9UsMBbmro7sp2mi8bqr?usp=sharing). 37 | 38 | ## Train Scene Graph Model 39 | Once the object detector pre-training is complete, prepare the pre-training weights to be used with scene graph training. Run the following script to achieve this 40 | ```python 41 | import torch 42 | pretrain_model = torch.load('') 43 | pretrain_weight = {} 44 | pretrain_weight['model'] = pretrain_model['model'] 45 | with open('', 'wb') as f: 46 | torch.save(pretrain_weight, f) 47 | 48 | ``` 49 | 50 | Depending on the task, the scene graph training can then be run as follows. The training scripts are available in the `scripts` folder. 
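For quick reference, the three commands below share the same configuration and differ mainly in whether ground-truth boxes and object labels are supplied to the scene graph head, controlled by `MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX` and `MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL`. The summary below is derived directly from the commands that follow.
```python
# Flag settings used by the three scene graph tasks (see the full commands below).
TASK_FLAGS = {
    "PredCls": {"USE_GT_BOX": True,  "USE_GT_OBJECT_LABEL": True},   # GT boxes and labels given
    "SGCls":   {"USE_GT_BOX": True,  "USE_GT_OBJECT_LABEL": False},  # GT boxes given, labels predicted
    "SGPred":  {"USE_GT_BOX": False, "USE_GT_OBJECT_LABEL": False},  # boxes and labels both predicted
}
```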
51 | * Predicate Classification (PredCls) 52 | 53 | ```python 54 | python train_SG_segmentation_head.py --config-file ../configs/sg_dev_masktransfer.yaml --num-gpus 4 --resume DATALOADER.NUM_WORKERS 2 \ 55 | MODEL.WEIGHTS \ 56 | OUTPUT_DIR \ 57 | DATASETS.VISUAL_GENOME.IMAGES DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY DATASETS.VISUAL_GENOME.IMAGE_DATA DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 \ 58 | DATASETS.MSCOCO.ANNOTATIONS DATASETS.MSCOCO.DATAROOT \ 59 | MODEL.MASK_ON True \ 60 | MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX True MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL True \ 61 | MODEL.ROI_SCENEGRAPH_HEAD.USE_MASK_ATTENTION True MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE 'Weighted' \ 62 | MODEL.ROI_SCENEGRAPH_HEAD.SIGMOID_ATTENTION True TEST.EVAL_PERIOD 100000 \ 63 | MODEL.ROI_RELATION_FEATURE_EXTRACTORS.MULTIPLY_LOGITS_WITH_MASKS False \ 64 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK True \ 65 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.CLASS_LOGITS_WITH_MASK False SOLVER.IMS_PER_BATCH 16 DATASETS.SEG_DATA_DIVISOR 2 \ 66 | MODEL.ROI_SCENEGRAPH_HEAD.PREDICTOR 'MotifSegmentationPredictorC' MODEL.ROI_HEADS.REFINE_SEG_MASKS False 67 | ``` 68 | 69 | - SceneGraph Classification (SGCls) 70 | ```python 71 | python train_SG_segmentation_head.py --config-file ../configs/sg_dev_masktransfer.yaml --num-gpus 4 --resume DATALOADER.NUM_WORKERS 2 \ 72 | MODEL.WEIGHTS \ 73 | OUTPUT_DIR \ 74 | DATASETS.VISUAL_GENOME.IMAGES DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY DATASETS.VISUAL_GENOME.IMAGE_DATA DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 \ 75 | DATASETS.MSCOCO.ANNOTATIONS DATASETS.MSCOCO.DATAROOT \ 76 | MODEL.MASK_ON True \ 77 | MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX True MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL False \ 78 | MODEL.ROI_SCENEGRAPH_HEAD.USE_MASK_ATTENTION True MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE 'Weighted' \ 79 | MODEL.ROI_SCENEGRAPH_HEAD.SIGMOID_ATTENTION True TEST.EVAL_PERIOD 100000 \ 80 | MODEL.ROI_RELATION_FEATURE_EXTRACTORS.MULTIPLY_LOGITS_WITH_MASKS False \ 81 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK True \ 82 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.CLASS_LOGITS_WITH_MASK False SOLVER.IMS_PER_BATCH 16 DATASETS.SEG_DATA_DIVISOR 2 \ 83 | MODEL.ROI_SCENEGRAPH_HEAD.PREDICTOR 'MotifSegmentationPredictorC' MODEL.ROI_HEADS.REFINE_SEG_MASKS False 84 | ``` 85 | 86 | - SceneGraph Prediction (SGPred) 87 | ```python 88 | python train_SG_segmentation_head.py --config-file ../configs/sg_dev_masktransfer.yaml --num-gpus 4 --resume DATALOADER.NUM_WORKERS 2 \ 89 | MODEL.WEIGHTS \ 90 | OUTPUT_DIR \ 91 | DATASETS.VISUAL_GENOME.IMAGES DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY DATASETS.VISUAL_GENOME.IMAGE_DATA DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 \ 92 | DATASETS.MSCOCO.ANNOTATIONS DATASETS.MSCOCO.DATAROOT \ 93 | MODEL.MASK_ON True \ 94 | MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX False MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL False \ 95 | MODEL.ROI_SCENEGRAPH_HEAD.USE_MASK_ATTENTION True MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE 'Weighted' \ 96 | MODEL.ROI_SCENEGRAPH_HEAD.SIGMOID_ATTENTION True TEST.EVAL_PERIOD 100000 \ 97 | MODEL.ROI_RELATION_FEATURE_EXTRACTORS.MULTIPLY_LOGITS_WITH_MASKS False \ 98 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK True \ 99 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.CLASS_LOGITS_WITH_MASK False SOLVER.IMS_PER_BATCH 16 DATASETS.SEG_DATA_DIVISOR 2 \ 100 | MODEL.ROI_SCENEGRAPH_HEAD.PREDICTOR 'MotifSegmentationPredictorC' MODEL.ROI_HEADS.REFINE_SEG_MASKS False TEST.DETECTIONS_PER_IMAGE 40 101 | ``` 102 | 103 | Note that these commands augment our approach to 
Neural Motifs with ResNeXt 101 backbone. To use VCTree, use 104 | ```python 105 | MODEL.ROI_SCENEGRAPH_HEAD.PREDICTOR 'VCTreeSegmentationPredictorC' 106 | ``` 107 | To use VGG-16 backbone, use 108 | ```python 109 | --config-file ../configs/sg_dev_masktransfer_vgg.yaml 110 | ``` 111 | 112 | ## Evaluation 113 | 114 | Evaluation can be done using the `--eval-only` flag. For example, evaluation can be run on the PredCLS model as follows, 115 | ```python 116 | python train_SG_segmentation_head.py --eval-only --config-file ../configs/sg_dev_masktransfer.yaml --num-gpus 4 --resume DATALOADER.NUM_WORKERS 2 \ 117 | MODEL.WEIGHTS \ 118 | OUTPUT_DIR \ 119 | DATASETS.VISUAL_GENOME.IMAGES DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY DATASETS.VISUAL_GENOME.IMAGE_DATA DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 \ 120 | DATASETS.MSCOCO.ANNOTATIONS DATASETS.MSCOCO.DATAROOT \ 121 | MODEL.MASK_ON True \ 122 | MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX True MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL True \ 123 | MODEL.ROI_SCENEGRAPH_HEAD.USE_MASK_ATTENTION True MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE 'Weighted' \ 124 | MODEL.ROI_SCENEGRAPH_HEAD.SIGMOID_ATTENTION True TEST.EVAL_PERIOD 100000 \ 125 | MODEL.ROI_RELATION_FEATURE_EXTRACTORS.MULTIPLY_LOGITS_WITH_MASKS False \ 126 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK True \ 127 | MODEL.ROI_BOX_FEATURE_EXTRACTORS.CLASS_LOGITS_WITH_MASK False SOLVER.IMS_PER_BATCH 16 DATASETS.SEG_DATA_DIVISOR 2 \ 128 | MODEL.ROI_SCENEGRAPH_HEAD.PREDICTOR 'MotifSegmentationPredictorC' MODEL.ROI_HEADS.REFINE_SEG_MASKS False 129 | ``` 130 | 131 | **Note**: The default training/testing assumes 4 GPUs. It can be modified to suit other GPU configurations, but would require changing the learning rate and batch sizes accordingly. Please look at `SOLVER.REFERENCE_WORLD_SIZE` parameter in the [detectron2 configurations](https://detectron2.readthedocs.io/en/latest/modules/config.html#config-references) for details on how this can be done automatically. 132 | 133 | -------------------------------------------------------------------------------- /configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | BACKBONE: 4 | NAME: "build_resnet_fpn_backbone" 5 | RESNETS: 6 | OUT_FEATURES: ["res2", "res3", "res4", "res5"] 7 | FPN: 8 | IN_FEATURES: ["res2", "res3", "res4", "res5"] 9 | ANCHOR_GENERATOR: 10 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 11 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 12 | RPN: 13 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 14 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 15 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 16 | # Detectron1 uses 2000 proposals per-batch, 17 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 18 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
19 | POST_NMS_TOPK_TRAIN: 1000 20 | POST_NMS_TOPK_TEST: 1000 21 | ROI_HEADS: 22 | NAME: "StandardROIHeads" 23 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 24 | ROI_BOX_HEAD: 25 | NAME: "FastRCNNConvFCHead" 26 | NUM_FC: 2 27 | POOLER_RESOLUTION: 7 28 | ROI_MASK_HEAD: 29 | NAME: "MaskRCNNConvUpsampleHead" 30 | NUM_CONV: 4 31 | POOLER_RESOLUTION: 14 32 | DATASETS: 33 | TRAIN: ("coco_2017_train",) 34 | TEST: ("coco_2017_val",) 35 | SOLVER: 36 | IMS_PER_BATCH: 16 37 | BASE_LR: 0.02 38 | STEPS: (60000, 80000) 39 | MAX_ITER: 90000 40 | INPUT: 41 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 42 | VERSION: 2 -------------------------------------------------------------------------------- /configs/pretrain_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-FPN.yaml" 2 | MODEL: 3 | MASK_ON: False 4 | WEIGHTS: "../models/model_final_2d9806.pkl" 5 | # WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" 6 | PIXEL_STD: [57.375, 57.120, 58.395] 7 | RESNETS: 8 | STRIDE_IN_1X1: False # this is a C2 model 9 | NUM_GROUPS: 32 10 | WIDTH_PER_GROUP: 8 11 | DEPTH: 101 12 | ROI_HEADS: 13 | NUM_CLASSES: 150 14 | EMBEDDINGS_PATH: "../data/embeddings/glove_mean_vg" 15 | EMBEDDINGS_PATH_COCO: "../data/embeddings/glove_mean_coco" 16 | DATASETS: 17 | TYPE: "VISUAL GENOME" 18 | TRAIN: ('VG_train',) 19 | TEST: ('VG_val',) 20 | VISUAL_GENOME: 21 | TRAIN_MASKS: '' 22 | VAL_MASKS: '' 23 | TEST_MASKS: '' 24 | TRAIN_MASKS: '' 25 | FILTER_EMPTY_RELATIONS: False 26 | FILTER_NON_OVERLAP: False 27 | FILTER_DUPLICATE_RELATIONS: False 28 | DATALOADER: 29 | NUM_WORKERS: 2 30 | SOLVER: 31 | IMS_PER_BATCH: 16 32 | CHECKPOINT_PERIOD: 500 33 | STEPS: (210000, 250000) 34 | MAX_ITER: 270000 35 | TEST: 36 | EVAL_PERIOD: 10000 -------------------------------------------------------------------------------- /configs/pretrain_object_detector_coco.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-FPN.yaml" 2 | MODEL: 3 | MASK_ON: True 4 | META_ARCHITECTURE: 'GeneralizedRCNNWithCOCO' 5 | WEIGHTS: "../models/model_final_2d9806.pkl" 6 | PIXEL_STD: [57.375, 57.120, 58.395] 7 | RESNETS: 8 | STRIDE_IN_1X1: False # this is a C2 model 9 | NUM_GROUPS: 32 10 | WIDTH_PER_GROUP: 8 11 | DEPTH: 101 12 | ROI_HEADS: 13 | NUM_CLASSES: 150 14 | EMBEDDINGS_PATH: "../data/embeddings/glove_mean_vg" 15 | EMBEDDINGS_PATH_COCO: "../data/embeddings/glove_mean_coco" 16 | NAME: 'StandardROIHeadsWithCOCO' 17 | ROI_MASK_HEAD: 18 | NAME: 'MaskRCNNConvUpsampleHeadwithCOCO' 19 | DATASETS: 20 | TYPE: "VISUAL GENOME" 21 | TRAIN: ('VG_train',) 22 | TEST: ('coco_val_2017', 'VG_val',) 23 | VISUAL_GENOME: 24 | TRAIN_MASKS: '' 25 | VAL_MASKS: '' 26 | TEST_MASKS: '' 27 | TRAIN_MASKS: '' 28 | FILTER_EMPTY_RELATIONS: False 29 | FILTER_NON_OVERLAP: False 30 | FILTER_DUPLICATE_RELATIONS: False 31 | DATALOADER: 32 | NUM_WORKERS: 2 33 | SOLVER: 34 | IMS_PER_BATCH: 8 35 | CHECKPOINT_PERIOD: 500 36 | STEPS: (210000, 250000) 37 | MAX_ITER: 270000 38 | TEST: 39 | EVAL_PERIOD: 10000 -------------------------------------------------------------------------------- /configs/pretrain_object_detector_vgg_coco.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | PIXEL_MEAN: (123.675, 116.280, 103.530) #RGB 3 | PIXEL_STD: (58.395, 57.120, 57.375) #RGB 4 | MASK_ON: True 5 | META_ARCHITECTURE: 'GeneralizedRCNNWithCOCO' 6 | WEIGHTS: "" 7 | BACKBONE: 8 | NAME: 'VGG' 9 | ROI_HEADS: 10 | NUM_CLASSES: 150 11 | 
EMBEDDINGS_PATH: "../data/embeddings/glove_mean_vg" 12 | EMBEDDINGS_PATH_COCO: "../data/embeddings/glove_mean_coco" 13 | NAME: 'StandardROIHeadsWithCOCO' 14 | IN_FEATURES: ['vgg_conv'] 15 | ROI_MASK_HEAD: 16 | NAME: 'MaskRCNNConvUpsampleHeadwithCOCO' 17 | ROI_BOX_HEAD: 18 | NAME: 'VGGConvFCHead' 19 | POOLER_RESOLUTION: 7 20 | POOLER_TYPE: "ROIPool" 21 | RPN: 22 | IN_FEATURES: ['vgg_conv'] 23 | DATASETS: 24 | TYPE: "VISUAL GENOME" 25 | TRAIN: ('VG_train',) 26 | TEST: ('coco_val_2017', 'VG_val',) 27 | VISUAL_GENOME: 28 | TRAIN_MASKS: '' 29 | VAL_MASKS: '' 30 | TEST_MASKS: '' 31 | TRAIN_MASKS: '' 32 | FILTER_EMPTY_RELATIONS: False 33 | FILTER_NON_OVERLAP: False 34 | FILTER_DUPLICATE_RELATIONS: False 35 | DATALOADER: 36 | NUM_WORKERS: 2 37 | SOLVER: 38 | BASE_LR: 0.02 39 | IMS_PER_BATCH: 8 40 | CHECKPOINT_PERIOD: 500 41 | STEPS: (210000, 250000) 42 | MAX_ITER: 270000 43 | TEST: 44 | EVAL_PERIOD: 10000 45 | INPUT: 46 | FORMAT: "RGB" -------------------------------------------------------------------------------- /configs/sg_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-FPN.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: 'SceneGraphRCNN' 4 | MASK_ON: False 5 | USE_MASK_ON_NODE: False 6 | WEIGHTS: "" 7 | PIXEL_STD: [57.375, 57.120, 58.395] 8 | RESNETS: 9 | STRIDE_IN_1X1: False # this is a C2 model 10 | NUM_GROUPS: 32 11 | WIDTH_PER_GROUP: 8 12 | DEPTH: 101 13 | ROI_HEADS: 14 | NAME: 'StandardSGROIHeads' 15 | NUM_CLASSES: 150 16 | EMBEDDINGS_PATH: "../data/embeddings/glove_mean_vg" 17 | EMBEDDINGS_PATH_COCO: "../data/embeddings/glove_mean_coco" 18 | ROI_MASK_HEAD: 19 | NAME: 'SceneGraphMaskHead' 20 | FREEZE_LAYERS: 21 | META_ARCH: [proposal_generator, backbone] 22 | ROI_HEADS: [box_pooler, box_head, box_predictor, mask_pooler, mask_head, keypoint_pooler, keypoint_head] 23 | DATASETS: 24 | TYPE: "VISUAL GENOME" 25 | TRAIN: ('VG_train',) 26 | TEST: ('VG_test',) 27 | VISUAL_GENOME: 28 | TRAIN_MASKS: "" 29 | VAL_MASKS: "" 30 | TEST_MASKS: "" 31 | TRAIN_MASKS: "" 32 | FILTER_EMPTY_RELATIONS: True 33 | FILTER_NON_OVERLAP: True 34 | FILTER_DUPLICATE_RELATIONS: True 35 | DATALOADER: 36 | NUM_WORKERS: 2 37 | SOLVER: 38 | IMS_PER_BATCH: 16 39 | BIAS_LR_FACTOR: 1.0 40 | BASE_LR: 0.01 41 | WARMUP_FACTOR: 0.1 42 | WEIGHT_DECAY: 0.0001 43 | MOMENTUM: 0.9 44 | STEPS: (10000, 16000) 45 | MAX_ITER: 40000 46 | CHECKPOINT_PERIOD: 500 47 | CLIP_GRADIENTS: 48 | CLIP_VALUE: 5.0 49 | TEST: 50 | EVAL_PERIOD: 5000 51 | GLOVE_DIR: '../glove/' -------------------------------------------------------------------------------- /configs/sg_dev_masktransfer.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-FPN.yaml" 2 | MODEL: 3 | META_ARCHITECTURE: 'SceneGraphSegmentationRCNN' 4 | MASK_ON: True 5 | USE_MASK_ON_NODE: False 6 | WEIGHTS: "" 7 | PIXEL_STD: [57.375, 57.120, 58.395] 8 | RESNETS: 9 | STRIDE_IN_1X1: False # this is a C2 model 10 | NUM_GROUPS: 32 11 | WIDTH_PER_GROUP: 8 12 | DEPTH: 101 13 | ROI_HEADS: 14 | NAME: 'SGSegmentationROIHeadsMaskTransfer' 15 | NUM_CLASSES: 150 16 | EMBEDDINGS_PATH: "../data/embeddings/glove_mean_vg" 17 | EMBEDDINGS_PATH_COCO: "../data/embeddings/glove_mean_coco" 18 | ROI_SCENEGRAPH_HEAD: 19 | NAME: 'SceneGraphSegmentationHead' 20 | PREDICTOR: 'MotifSegmentationPredictor' 21 | ROI_MASK_HEAD: 22 | NAME: 'SceneGraphMaskHeadTransfer' 23 | FREEZE_LAYERS: 24 | META_ARCH: [proposal_generator, backbone] 25 | ROI_HEADS: [box_pooler, box_head, box_predictor, mask_pooler, mask_head, 
keypoint_pooler, keypoint_head] 26 | ROI_BOX_FEATURE_EXTRACTORS: 27 | NAME: 'BoxFeatureSegmentationExtractor' 28 | DATASETS: 29 | TYPE: "VISUAL GENOME" 30 | TRAIN: ('VG_train',) 31 | TEST: ('coco_val_2017','VG_test',) 32 | VISUAL_GENOME: 33 | TRAIN_MASKS: "" 34 | VAL_MASKS: "" 35 | TEST_MASKS: "" 36 | TRAIN_MASKS: "" 37 | FILTER_EMPTY_RELATIONS: True 38 | FILTER_NON_OVERLAP: True 39 | FILTER_DUPLICATE_RELATIONS: True 40 | DATALOADER: 41 | NUM_WORKERS: 2 42 | SOLVER: 43 | IMS_PER_BATCH: 16 44 | BIAS_LR_FACTOR: 1.0 45 | BASE_LR: 0.01 46 | WARMUP_FACTOR: 0.1 47 | WEIGHT_DECAY: 0.0001 48 | MOMENTUM: 0.9 49 | STEPS: (25000, 35000) 50 | MAX_ITER: 40000 51 | CHECKPOINT_PERIOD: 500 52 | CLIP_GRADIENTS: 53 | CLIP_VALUE: 5.0 54 | TEST: 55 | EVAL_PERIOD: 5000 56 | GLOVE_DIR: '../glove/' 57 | -------------------------------------------------------------------------------- /configs/sg_dev_masktransfer_vgg.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: 'SceneGraphSegmentationRCNN' 3 | MASK_ON: True 4 | USE_MASK_ON_NODE: False 5 | WEIGHTS: "" 6 | PIXEL_MEAN: (123.675, 116.280, 103.530) #RGB 7 | PIXEL_STD: (58.395, 57.120, 57.375) #RGB 8 | BACKBONE: 9 | NAME: 'VGG' 10 | ROI_HEADS: 11 | NAME: 'SGSegmentationROIHeadsMaskTransfer' 12 | NUM_CLASSES: 150 13 | EMBEDDINGS_PATH: "../data/embeddings/glove_mean_vg" 14 | EMBEDDINGS_PATH_COCO: "../data/embeddings/glove_mean_coco" 15 | IN_FEATURES: ['vgg_conv'] 16 | ROI_SCENEGRAPH_HEAD: 17 | NAME: 'SceneGraphSegmentationHead' 18 | PREDICTOR: 'MotifSegmentationPredictor' 19 | ROI_MASK_HEAD: 20 | NAME: 'SceneGraphMaskHeadTransfer' 21 | FREEZE_LAYERS: 22 | META_ARCH: [proposal_generator, backbone] 23 | ROI_HEADS: [box_pooler, box_head, box_predictor, mask_pooler, mask_head, keypoint_pooler, keypoint_head] 24 | ROI_BOX_FEATURE_EXTRACTORS: 25 | NAME: 'BoxFeatureSegmentationExtractor' 26 | ROI_BOX_HEAD: 27 | NAME: 'VGGConvFCHead' 28 | POOLER_RESOLUTION: 7 29 | POOLER_TYPE: "ROIPool" 30 | RPN: 31 | IN_FEATURES: ['vgg_conv'] 32 | DATASETS: 33 | TYPE: "VISUAL GENOME" 34 | TRAIN: ('VG_train',) 35 | TEST: ('coco_val_2017','VG_test',) 36 | VISUAL_GENOME: 37 | TRAIN_MASKS: "" 38 | VAL_MASKS: "" 39 | TEST_MASKS: "" 40 | TRAIN_MASKS: "" 41 | FILTER_EMPTY_RELATIONS: True 42 | FILTER_NON_OVERLAP: True 43 | FILTER_DUPLICATE_RELATIONS: True 44 | DATALOADER: 45 | NUM_WORKERS: 2 46 | SOLVER: 47 | IMS_PER_BATCH: 16 48 | BIAS_LR_FACTOR: 1.0 49 | BASE_LR: 0.01 50 | WARMUP_FACTOR: 0.1 51 | WEIGHT_DECAY: 0.0001 52 | MOMENTUM: 0.9 53 | STEPS: (25000, 35000) 54 | MAX_ITER: 40000 55 | CHECKPOINT_PERIOD: 500 56 | CLIP_GRADIENTS: 57 | CLIP_VALUE: 5.0 58 | TEST: 59 | EVAL_PERIOD: 5000 60 | GLOVE_DIR: '../glove/' 61 | INPUT: 62 | FORMAT: "RGB" 63 | -------------------------------------------------------------------------------- /glove/glove: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/glove/glove -------------------------------------------------------------------------------- /scripts/pretrain_object_detector_baseline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import torch 5 | import logging 6 | 7 | import detectron2.utils.comm as comm 8 | from detectron2.utils.logger import setup_logger, log_every_n_seconds 9 | from detectron2.engine import default_argument_parser, default_setup, launch 10 | 
from detectron2.config import get_cfg 11 | from detectron2.evaluation import DatasetEvaluators, DatasetEvaluator, inference_on_dataset, print_csv_format, inference_context 12 | from detectron2.checkpoint import DetectionCheckpointer 13 | 14 | from segmentationsg.engine import ObjectDetectorTrainer 15 | from segmentationsg.data import add_dataset_config, VisualGenomeTrainData, register_datasets 16 | 17 | 18 | parser = default_argument_parser() 19 | 20 | def setup(args): 21 | cfg = get_cfg() 22 | add_dataset_config(cfg) 23 | cfg.merge_from_file(args.config_file) 24 | cfg.merge_from_list(args.opts) 25 | cfg.freeze() 26 | register_datasets(cfg) 27 | default_setup(cfg, args) 28 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="LSDA") 29 | return cfg 30 | 31 | def main(args): 32 | cfg = setup(args) 33 | if args.eval_only: 34 | model = ObjectDetectorTrainer.build_model(cfg) 35 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 36 | cfg.MODEL.WEIGHTS, resume=args.resume 37 | ) 38 | res = ObjectDetectorTrainer.test(cfg, model) 39 | if comm.is_main_process(): 40 | verify_results(cfg, res) 41 | return res 42 | trainer = ObjectDetectorTrainer(cfg) 43 | trainer.resume_or_load(resume=args.resume) 44 | return trainer.train() 45 | 46 | if __name__ == '__main__': 47 | args = parser.parse_args() 48 | try: 49 | # use the last 4 numbers in the job id as the id 50 | default_port = os.environ['SLURM_JOB_ID'] 51 | default_port = default_port[-4:] 52 | 53 | # all ports should be in the 10k+ range 54 | default_port = int(default_port) + 15000 55 | 56 | except Exception: 57 | default_port = 59482 58 | 59 | args.dist_url = 'tcp://127.0.0.1:'+str(default_port) 60 | print (args) 61 | 62 | launch( 63 | main, 64 | args.num_gpus, 65 | num_machines=args.num_machines, 66 | machine_rank=args.machine_rank, 67 | dist_url=args.dist_url, 68 | args=(args,), 69 | ) 70 | -------------------------------------------------------------------------------- /scripts/pretrain_object_detector_withcoco.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import torch 5 | import logging 6 | 7 | import detectron2.utils.comm as comm 8 | from detectron2.utils.logger import setup_logger, log_every_n_seconds 9 | from detectron2.engine import default_argument_parser, default_setup, launch 10 | from detectron2.config import get_cfg 11 | from detectron2.evaluation import DatasetEvaluators, DatasetEvaluator, inference_on_dataset, print_csv_format, inference_context 12 | from detectron2.checkpoint import DetectionCheckpointer 13 | 14 | from segmentationsg.engine import ObjectDetectorTrainerWithCoco 15 | from segmentationsg.data import add_dataset_config, VisualGenomeTrainData, register_datasets 16 | from detectron2.data.datasets import register_coco_instances 17 | from segmentationsg.modeling import * 18 | parser = default_argument_parser() 19 | 20 | def register_coco_data(args): 21 | annotations = args.DATASETS.MSCOCO.ANNOTATIONS 22 | dataroot = args.DATASETS.MSCOCO.DATAROOT 23 | register_coco_instances("coco_train_2017", {}, annotations + 'instances_train2017.json', dataroot + '/train2017/') 24 | register_coco_instances("coco_val_2017", {}, annotations + 'instances_val2017.json', dataroot + '/val2017/') 25 | 26 | def setup(args): 27 | cfg = get_cfg() 28 | add_dataset_config(cfg) 29 | cfg.merge_from_file(args.config_file) 30 | cfg.merge_from_list(args.opts) 31 | cfg.freeze() 32 | register_datasets(cfg) 33 | 
register_coco_data(cfg) 34 | default_setup(cfg, args) 35 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="LSDA") 36 | return cfg 37 | 38 | def main(args): 39 | cfg = setup(args) 40 | if args.eval_only: 41 | model = ObjectDetectorTrainerWithCoco.build_model(cfg) 42 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 43 | cfg.MODEL.WEIGHTS, resume=args.resume 44 | ) 45 | res = ObjectDetectorTrainerWithCoco.test(cfg, model) 46 | if comm.is_main_process(): 47 | verify_results(cfg, res) 48 | return res 49 | trainer = ObjectDetectorTrainerWithCoco(cfg) 50 | trainer.resume_or_load(resume=args.resume) 51 | return trainer.train() 52 | 53 | if __name__ == '__main__': 54 | args = parser.parse_args() 55 | try: 56 | # use the last 4 numbers in the job id as the id 57 | default_port = os.environ['SLURM_JOB_ID'] 58 | default_port = default_port[-4:] 59 | 60 | # all ports should be in the 10k+ range 61 | default_port = int(default_port) + 15000 62 | 63 | except Exception: 64 | default_port = 59482 65 | 66 | args.dist_url = 'tcp://127.0.0.1:'+str(default_port) 67 | print (args) 68 | 69 | launch( 70 | main, 71 | args.num_gpus, 72 | num_machines=args.num_machines, 73 | machine_rank=args.machine_rank, 74 | dist_url=args.dist_url, 75 | args=(args,), 76 | ) 77 | -------------------------------------------------------------------------------- /scripts/train_SG_baseline.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | import detectron2.utils.comm as comm 7 | from detectron2.utils.logger import setup_logger 8 | from detectron2.engine import default_argument_parser, default_setup, launch 9 | from detectron2.config import get_cfg 10 | from detectron2.checkpoint import DetectionCheckpointer 11 | 12 | from segmentationsg.engine import SceneGraphTrainer, TestTrainer 13 | from segmentationsg.data import add_dataset_config, VisualGenomeTrainData, register_datasets 14 | from segmentationsg.modeling.roi_heads.scenegraph_head import add_scenegraph_config 15 | from detectron2.data import DatasetCatalog, MetadataCatalog 16 | 17 | parser = default_argument_parser() 18 | 19 | def register_coco_data(args): 20 | # annotations = json.load(open('/h/skhandel/SceneGraph/data/coco/instances_train2014.json', 'r')) 21 | # classes = [x['name'] for x in annotations['categories']] 22 | classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] 23 | MetadataCatalog.get('coco_train_2014').set(thing_classes=classes, evaluator_type='coco') 24 | 25 | def setup(args): 26 | cfg = get_cfg() 27 | add_dataset_config(cfg) 28 | add_scenegraph_config(cfg) 29 | 
assert(cfg.MODEL.ROI_SCENEGRAPH_HEAD.MODE in ['predcls', 'sgls', 'sgdet']) , "Mode {} not supported".format(cfg.MODEL.ROI_SCENEGRaGraph.MODE) 30 | cfg.merge_from_file(args.config_file) 31 | cfg.merge_from_list(args.opts) 32 | cfg.freeze() 33 | register_datasets(cfg) 34 | register_coco_data(cfg) 35 | default_setup(cfg, args) 36 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="LSDA") 37 | return cfg 38 | 39 | def main(args): 40 | cfg = setup(args) 41 | if args.eval_only: 42 | 43 | model = SceneGraphTrainer.build_model(cfg) 44 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 45 | cfg.MODEL.WEIGHTS, resume=args.resume 46 | ) 47 | res = SceneGraphTrainer.test(cfg, model) 48 | # if comm.is_main_process(): 49 | # verify_results(cfg, res) 50 | return res 51 | 52 | trainer = SceneGraphTrainer(cfg) 53 | trainer.resume_or_load(resume=args.resume) 54 | return trainer.train() 55 | 56 | if __name__ == '__main__': 57 | args = parser.parse_args() 58 | try: 59 | # use the last 4 numbers in the job id as the id 60 | default_port = os.environ['SLURM_JOB_ID'] 61 | default_port = default_port[-4:] 62 | 63 | # all ports should be in the 10k+ range 64 | default_port = int(default_port) + 15000 65 | 66 | except Exception: 67 | default_port = 59482 68 | 69 | args.dist_url = 'tcp://127.0.0.1:'+str(default_port) 70 | print (args) 71 | 72 | launch( 73 | main, 74 | args.num_gpus, 75 | num_machines=args.num_machines, 76 | machine_rank=args.machine_rank, 77 | dist_url=args.dist_url, 78 | args=(args,), 79 | ) -------------------------------------------------------------------------------- /scripts/train_SG_segmentation_head.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import numpy as np 4 | import torch 5 | 6 | import detectron2.utils.comm as comm 7 | from detectron2.utils.logger import setup_logger 8 | from detectron2.engine import default_argument_parser, default_setup, launch 9 | from detectron2.config import get_cfg 10 | from detectron2.checkpoint import DetectionCheckpointer 11 | 12 | from segmentationsg.engine import SceneGraphSegmentationTrainer 13 | from segmentationsg.data import add_dataset_config, VisualGenomeTrainData, register_datasets 14 | from segmentationsg.modeling.roi_heads.scenegraph_head import add_scenegraph_config 15 | from detectron2.data import DatasetCatalog, MetadataCatalog 16 | from detectron2.data.datasets import register_coco_instances 17 | from segmentationsg.modeling import * 18 | 19 | parser = default_argument_parser() 20 | 21 | def register_coco_data(args): 22 | # annotations = json.load(open('/h/skhandel/SceneGraph/data/coco/instances_train2014.json', 'r')) 23 | # classes = [x['name'] for x in annotations['categories']] 24 | classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 
'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'] 25 | MetadataCatalog.get('coco_train_2014').set(thing_classes=classes, evaluator_type='coco') 26 | annotations = args.DATASETS.MSCOCO.ANNOTATIONS 27 | dataroot = args.DATASETS.MSCOCO.DATAROOT 28 | register_coco_instances("coco_train_2017", {}, annotations + 'instances_train2017.json', dataroot + '/train2017/') 29 | register_coco_instances("coco_val_2017", {}, annotations + 'instances_val2017.json', dataroot + '/val2017/') 30 | 31 | def setup(args): 32 | cfg = get_cfg() 33 | add_dataset_config(cfg) 34 | add_scenegraph_config(cfg) 35 | assert(cfg.MODEL.ROI_SCENEGRAPH_HEAD.MODE in ['predcls', 'sgls', 'sgdet']) , "Mode {} not supported".format(cfg.MODEL.ROI_SCENEGRaGraph.MODE) 36 | cfg.merge_from_file(args.config_file) 37 | cfg.merge_from_list(args.opts) 38 | cfg.freeze() 39 | register_datasets(cfg) 40 | register_coco_data(cfg) 41 | default_setup(cfg, args) 42 | setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="LSDA") 43 | return cfg 44 | 45 | def main(args): 46 | cfg = setup(args) 47 | if args.eval_only: 48 | 49 | model = SceneGraphSegmentationTrainer.build_model(cfg) 50 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 51 | cfg.MODEL.WEIGHTS, resume=args.resume 52 | ) 53 | res = SceneGraphSegmentationTrainer.test(cfg, model) 54 | # if comm.is_main_process(): 55 | # verify_results(cfg, res) 56 | return res 57 | 58 | trainer = SceneGraphSegmentationTrainer(cfg) 59 | trainer.resume_or_load(resume=args.resume) 60 | return trainer.train() 61 | 62 | if __name__ == '__main__': 63 | args = parser.parse_args() 64 | try: 65 | # use the last 4 numbers in the job id as the id 66 | default_port = os.environ['SLURM_JOB_ID'] 67 | default_port = default_port[-4:] 68 | 69 | # all ports should be in the 10k+ range 70 | default_port = int(default_port) + 15000 71 | 72 | except Exception: 73 | default_port = 59482 74 | 75 | args.dist_url = 'tcp://127.0.0.1:'+str(default_port) 76 | print (args) 77 | 78 | launch( 79 | main, 80 | args.num_gpus, 81 | num_machines=args.num_machines, 82 | machine_rank=args.machine_rank, 83 | dist_url=args.dist_url, 84 | args=(args,), 85 | ) 86 | -------------------------------------------------------------------------------- /segmentationsg/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/__init__.py -------------------------------------------------------------------------------- /segmentationsg/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection_checkpoint import * -------------------------------------------------------------------------------- /segmentationsg/checkpoint/detection_checkpoint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | from detectron2.utils import comm 5 | from detectron2.engine import hooks, HookBase 6 | import logging 7 | 8 | class PeriodicCheckpointerWithEval(HookBase): 9 | def __init__(self, eval_period, eval_function, checkpointer, checkpoint_period, max_to_keep=5): 10 | self.eval = hooks.EvalHook(eval_period, eval_function) 11 | self.checkpointer = hooks.PeriodicCheckpointer(checkpointer, checkpoint_period, max_to_keep=max_to_keep) 12 | self.best_ap = 0.0 13 | 
best_model_path = checkpointer.save_dir + 'best_model_final.pth.pth' 14 | if os.path.isfile(best_model_path): 15 | best_model = torch.load(best_model_path, map_location=torch.device('cpu')) 16 | try: 17 | self.best_ap = best_model['SGMeanRecall@20'] 18 | except: 19 | self.best_ap = best_model['AP50'] 20 | del best_model 21 | else: 22 | self.best_ap = 0.0 23 | 24 | def before_train(self): 25 | self.max_iter = self.trainer.max_iter 26 | self.checkpointer.max_iter = self.trainer.max_iter 27 | 28 | def _do_eval(self): 29 | results = self.eval._func() 30 | comm.synchronize() 31 | return results 32 | 33 | def after_step(self): 34 | next_iter = self.trainer.iter + 1 35 | is_final = next_iter == self.trainer.max_iter 36 | if is_final or (self.eval._period > 0 and next_iter % self.eval._period == 0): 37 | results = self._do_eval() 38 | if comm.is_main_process(): 39 | try: 40 | dataset = 'VG_val' if 'VG_val' in results.keys() else 'VG_test' 41 | if results[dataset]['SG']['SGMeanRecall@20'] > self.best_ap: 42 | self.best_ap = results[dataset]['SG']['SGMeanRecall@20'] 43 | additional_state = {"iteration":self.trainer.iter, "SGMeanRecall@20":self.best_ap} 44 | self.checkpointer.checkpointer.save( 45 | "best_model_final.pth", **additional_state 46 | ) 47 | except: 48 | current_ap = results['bbox']['AP50'] 49 | if current_ap > self.best_ap: 50 | self.best_ap = current_ap 51 | additional_state = {"iteration":self.trainer.iter, "AP50":self.best_ap} 52 | self.checkpointer.checkpointer.save( 53 | "best_model_final.pth", **additional_state 54 | ) 55 | if comm.is_main_process(): 56 | self.checkpointer.step(self.trainer.iter) 57 | comm.synchronize() 58 | 59 | def after_train(self): 60 | # func is likely a closure that holds reference to the trainer 61 | # therefore we clean it to avoid circular reference in the end 62 | del self.eval._func -------------------------------------------------------------------------------- /segmentationsg/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset_mapper import * 2 | from .tools import add_dataset_config, register_datasets 3 | from .datasets import VisualGenomeTrainData -------------------------------------------------------------------------------- /segmentationsg/data/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .visual_genome import VisualGenomeTrainData -------------------------------------------------------------------------------- /segmentationsg/data/datasets/coco.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import datetime 3 | import io 4 | import json 5 | import logging 6 | import numpy as np 7 | import os 8 | import shutil 9 | import pycocotools.mask as mask_util 10 | from PIL import Image 11 | from fvcore.common.file_io import PathManager, file_lock 12 | from fvcore.common.timer import Timer 13 | from detectron2.structures import Boxes, BoxMode, PolygonMasks 14 | from detectron2.data import DatasetCatalog, MetadataCatalog 15 | 16 | 17 | def convert_to_coco_dict(dataset_name): 18 | dataset_dicts = DatasetCatalog.get(dataset_name) 19 | metadata = MetadataCatalog.get(dataset_name) 20 | 21 | # unmap the category mapping ids for COCO 22 | if hasattr(metadata, "thing_dataset_id_to_contiguous_id"): 23 | reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()} 24 | reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id] # 
noqa 25 | else: 26 | reverse_id_mapper = lambda contiguous_id: contiguous_id # noqa 27 | 28 | categories = [ 29 | {"id": reverse_id_mapper(id), "name": name} 30 | for id, name in enumerate(metadata.thing_classes) 31 | ] 32 | 33 | logger.info("Converting dataset dicts into COCO format") 34 | coco_images = [] 35 | coco_annotations = [] 36 | 37 | for image_id, image_dict in enumerate(dataset_dicts): 38 | coco_image = { 39 | "id": image_dict.get("image_id", image_id), 40 | "width": int(image_dict["width"]), 41 | "height": int(image_dict["height"]), 42 | "file_name": str(image_dict["file_name"]), 43 | } 44 | coco_images.append(coco_image) 45 | 46 | anns_per_image = image_dict.get("annotations", []) 47 | for annotation in anns_per_image: 48 | # create a new dict with only COCO fields 49 | coco_annotation = {} 50 | 51 | # COCO requirement: XYWH box format 52 | bbox = annotation["bbox"] 53 | bbox_mode = annotation["bbox_mode"] 54 | bbox = BoxMode.convert(bbox, bbox_mode, BoxMode.XYWH_ABS) 55 | 56 | # COCO requirement: instance area 57 | if "segmentation" in annotation: 58 | # Computing areas for instances by counting the pixels 59 | segmentation = annotation["segmentation"] 60 | # TODO: check segmentation type: RLE, BinaryMask or Polygon 61 | if isinstance(segmentation, list): 62 | polygons = PolygonMasks([segmentation]) 63 | area = polygons.area()[0].item() 64 | elif isinstance(segmentation, dict): # RLE 65 | area = mask_util.area(segmentation).item() 66 | else: 67 | raise TypeError(f"Unknown segmentation type {type(segmentation)}!") 68 | else: 69 | # Computing areas using bounding boxes 70 | bbox_xy = BoxMode.convert(bbox, BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) 71 | area = Boxes([bbox_xy]).area()[0].item() 72 | 73 | if "keypoints" in annotation: 74 | keypoints = annotation["keypoints"] # list[int] 75 | for idx, v in enumerate(keypoints): 76 | if idx % 3 != 2: 77 | # COCO's segmentation coordinates are floating points in [0, H or W], 78 | # but keypoint coordinates are integers in [0, H-1 or W-1] 79 | # For COCO format consistency we substract 0.5 80 | # https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163 81 | keypoints[idx] = v - 0.5 82 | if "num_keypoints" in annotation: 83 | num_keypoints = annotation["num_keypoints"] 84 | else: 85 | num_keypoints = sum(kp > 0 for kp in keypoints[2::3]) 86 | 87 | # COCO requirement: 88 | # linking annotations to images 89 | # "id" field must start with 1 90 | coco_annotation["id"] = len(coco_annotations) + 1 91 | coco_annotation["image_id"] = coco_image["id"] 92 | coco_annotation["bbox"] = [round(float(x), 3) for x in bbox] 93 | coco_annotation["area"] = float(area) 94 | coco_annotation["iscrowd"] = int(annotation.get("iscrowd", 0)) 95 | coco_annotation["category_id"] = int(reverse_id_mapper(annotation["category_id"])) 96 | 97 | # Add optional fields 98 | if "keypoints" in annotation: 99 | coco_annotation["keypoints"] = keypoints 100 | coco_annotation["num_keypoints"] = num_keypoints 101 | 102 | if "segmentation" in annotation: 103 | seg = coco_annotation["segmentation"] = annotation["segmentation"] 104 | if isinstance(seg, dict): # RLE 105 | counts = seg["counts"] 106 | if not isinstance(counts, str): 107 | # make it json-serializable 108 | seg["counts"] = counts.decode("ascii") 109 | 110 | coco_annotations.append(coco_annotation) 111 | 112 | logger.info( 113 | "Conversion finished, " 114 | f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}" 115 | ) 116 | 117 | info = { 118 | "date_created": 
str(datetime.datetime.now()), 119 | "description": "Automatically generated COCO json file for Detectron2.", 120 | } 121 | coco_dict = {"info": info, "images": coco_images, "categories": categories, "licenses": None} 122 | if len(coco_annotations) > 0: 123 | coco_dict["annotations"] = coco_annotations 124 | return coco_dict 125 | 126 | def convert_to_coco_json(dataset_name, output_file, allow_cached=True): 127 | """ 128 | Converts dataset into COCO format and saves it to a json file. 129 | dataset_name must be registered in DatasetCatalog and in detectron2's standard format. 130 | Args: 131 | dataset_name: 132 | reference from the config file to the catalogs 133 | must be registered in DatasetCatalog and in detectron2's standard format 134 | output_file: path of json file that will be saved to 135 | allow_cached: if json file is already present then skip conversion 136 | """ 137 | 138 | # TODO: The dataset or the conversion script *may* change, 139 | # a checksum would be useful for validating the cached data 140 | 141 | PathManager.mkdirs(os.path.dirname(output_file)) 142 | with file_lock(output_file): 143 | if PathManager.exists(output_file) and allow_cached: 144 | logger.warning( 145 | f"Using previously cached COCO format annotations at '{output_file}'. " 146 | "You need to clear the cache file if your dataset has been modified." 147 | ) 148 | else: 149 | logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...)") 150 | coco_dict = convert_to_coco_dict(dataset_name) 151 | 152 | logger.info(f"Caching COCO format annotations at '{output_file}' ...") 153 | tmp_file = output_file + ".tmp" 154 | with PathManager.open(tmp_file, "w") as f: 155 | json.dump(coco_dict, f) 156 | shutil.move(tmp_file, output_file) -------------------------------------------------------------------------------- /segmentationsg/data/datasets/openimage_preprocess.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | from pathlib import Path 4 | from math import floor 5 | from PIL import Image, ImageDraw 6 | import random 7 | import os 8 | import torch 9 | import numpy as np 10 | import pickle 11 | import yaml 12 | import pandas as pd 13 | import cv2 14 | from tqdm import tqdm 15 | from pycocotools import mask 16 | from detectron2.config import get_cfg 17 | from detectron2.structures import Instances, Boxes, pairwise_iou, BoxMode 18 | 19 | 20 | splits = ['train', 'test'] 21 | # splits = ['validation'] 22 | data_dir= Path('/scratch/hdd001/datasets/openimages') 23 | dev_run = False 24 | 25 | orig_id2cat_id = json.load(open('/h/suhail/SceneGraph/data/datasets/openimages/orig_id2cat_id.json', 'r')) 26 | 27 | for split in splits: 28 | 29 | #Directory containing the annotation files 30 | annotation_dir = data_dir / 'annotations' 31 | #Directory contating the segementation masks 32 | segmentation_dir = annotation_dir / 'segmentations' 33 | segmentation_annotation_file = segmentation_dir / split / '{}-annotations-object-segmentation.csv'.format(split) 34 | 35 | #Sort and group the data frame by imageid 36 | segmentation_annotion = pd.read_csv(segmentation_annotation_file).sort_values('ImageID') 37 | grouped_segmentation_annotation = segmentation_annotion.groupby('ImageID') 38 | 39 | dataset_dicts = [] 40 | # i = 0 41 | for imageid, group in tqdm(grouped_segmentation_annotation): 42 | 43 | image_dict = {} 44 | 45 | image_dict['image_id'] = imageid 46 | image_dict['file_name'] = str(data_dir / split / '{}.jpg'.format(imageid) ) 47 | 48 
| image = cv2.imread(image_dict['file_name']) 49 | image_dict['height'], image_dict['width'] = image.shape[0], image.shape[1] 50 | 51 | objs = [] 52 | for index, rows in group.iterrows(): 53 | obj = {} 54 | 55 | obj["bbox"] = [rows['BoxXMin']*image_dict['width'], rows['BoxYMin']*image_dict['height'], 56 | rows['BoxXMax']*image_dict['width'], rows['BoxYMax']*image_dict['height']] 57 | 58 | obj["bbox_mode"] = BoxMode.XYXY_ABS 59 | 60 | #Reshape mask as the images and segmentation mask have different size 61 | maskimage = cv2.resize(cv2.imread(str(segmentation_dir / split / '{}-masks-{}'.format(split, imageid[0]) / rows['MaskPath'] )), (image_dict['width'], image_dict['height'])) 62 | 63 | #Conver the numpy array to contours 64 | imgray = cv2.cvtColor(maskimage, cv2.COLOR_BGR2GRAY) 65 | ret, thresh = cv2.threshold(imgray, 127, 255, 0) 66 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 67 | contour_lengths = [x.shape[0] for x in contours] 68 | 69 | #Some masks are corrupted and donot have any annotions(Skip these masks) 70 | if len(contour_lengths) == 0: 71 | continue 72 | 73 | #In case of multiple contours extracted choose the largest one 74 | choose_contour = contour_lengths.index(max(contour_lengths)) 75 | poly = contours[choose_contour].squeeze().tolist() 76 | if isinstance(poly[0], int): 77 | continue 78 | poly = [p for x in poly for p in x] 79 | 80 | obj["segmentation"] = [poly] 81 | obj["category_id"] = orig_id2cat_id[rows['LabelName']] - 1 82 | 83 | objs.append(obj) 84 | 85 | image_dict["annotations"] = objs 86 | 87 | dataset_dicts.append(image_dict) 88 | 89 | print("Svaing the dict for {}".format(split)) 90 | with open('/h/suhail/SceneGraph/data/datasets/openimages/{}-imagedict.pkl'.format(split), 'wb+') as f: 91 | pickle.dump(dataset_dicts, f) 92 | -------------------------------------------------------------------------------- /segmentationsg/data/datasets/openimages.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import json 3 | import math 4 | from pathlib import Path 5 | from math import floor 6 | from PIL import Image, ImageDraw 7 | import random 8 | import os 9 | import torch 10 | import numpy as np 11 | import pickle 12 | import yaml 13 | import cv2 14 | from detectron2.config import get_cfg 15 | from detectron2.structures import Instances, Boxes, pairwise_iou, BoxMode 16 | from detectron2.data import DatasetCatalog, MetadataCatalog 17 | 18 | class OpenImageTrainData(): 19 | ''' 20 | Registed Open Images Dataset in Detectron 21 | ''' 22 | def __init__(self, cfg, data_dir='/scratch/hdd001/datasets/openimages', split='train'): 23 | 24 | self.cfg = cfg 25 | valid_splits = ['train', 'validation', 'test'] 26 | assert split in valid_splits, "Invalid split {}. 
Specify one of {}".format(split, valid_splits) 27 | self.split = split 28 | self.data_dir = Path(data_dir) 29 | self.dataset_dicts = self.get_dataset_dicts() 30 | 31 | self.register_dataset() 32 | 33 | def register_dataset(self): 34 | 35 | DatasetCatalog.register("openimages_" + self.split, lambda: self.dataset_dicts) 36 | #MetaData for Open Images 37 | 38 | def get_dataset_dicts(self): 39 | ''' 40 | Convert dataset to Detectron format 41 | ''' 42 | 43 | annotation_dir = data_dir / 'annotations' 44 | segmentation_dir = annotation_dir / 'segmentations' 45 | # coco_style_annotation_file = annotation_dir / 'coco_style' / '{}-annotation-bbox.json' 46 | # cs_annotation = json.load(open(coco_style_annotation_file,'r')) 47 | 48 | segmentation_annotation_file = segmentation_dir / split / '{}-annotations-object-segmentation.csv'.format(split) 49 | 50 | segmentation_annotion = pd.read_csv(segmentation_annotation_file).sort_values('ImageID') 51 | 52 | grouped_segmentation_annotation = segmentation_annotion.groupby('ImageID') 53 | 54 | dataset_dicts = [] 55 | for idx, (imageid, group) in tqdm(enumerate(grouped_segmentation_annotation)): 56 | 57 | image_dict = {} 58 | 59 | image_dict['image_id'] = imageid 60 | image_dict['file_name'] = str(data_dir / split / '{}.jpg'.format(imageid) ) 61 | 62 | image = cv2.imread(image_dict['file_name']) 63 | image_dict['height'], image_dict['width'] = image.shape[0], image.shape[1] 64 | 65 | objs = [] 66 | for index, rows in group.iterrows(): 67 | obj = {} 68 | 69 | obj["bbox"] = [rows['BoxXMin'], rows['BoxYMin'], rows['BoxXMax'], rows['BoxYMax']] 70 | obj["bbox_mode"] = BoxMode.XYXY_ABS 71 | 72 | maskimage = cv2.resize(cv2.imread(str(segmentation_dir / split / '{}-masks-{}'.format(split, imageid[0]) / rows['MaskPath'] )), (image_dict['width'], image_dict['height'])) 73 | 74 | imgray = cv2.cvtColor(maskimage, cv2.COLOR_BGR2GRAY) 75 | ret, thresh = cv2.threshold(imgray, 127, 255, 0) 76 | contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 77 | poly = contours[0].squeeze().tolist() 78 | poly = [p for x in poly for p in x] 79 | 80 | obj["segmentation"] = [poly] 81 | obj["category_id"] = orig_id2cat_id[rows['LabelName']] - 1 82 | 83 | objs.append(obj) 84 | 85 | image_dict["annotations"] = objs 86 | 87 | dataset_dicts.append(image_dict) 88 | 89 | return dataset_dicts -------------------------------------------------------------------------------- /segmentationsg/data/datasets/visual_genome.py: -------------------------------------------------------------------------------- 1 | import h5py 2 | import json 3 | import math 4 | from math import floor 5 | from PIL import Image, ImageDraw 6 | import random 7 | import os 8 | import torch 9 | import numpy as np 10 | import pickle 11 | import yaml 12 | from detectron2.config import get_cfg 13 | from detectron2.structures import Instances, Boxes, pairwise_iou, BoxMode 14 | from detectron2.data import DatasetCatalog, MetadataCatalog 15 | import logging 16 | 17 | class VisualGenomeTrainData: 18 | """ 19 | Register data for Visual Genome training 20 | """ 21 | def __init__(self, cfg, split='train'): 22 | self.cfg = cfg 23 | self.split = split 24 | if split == 'train': 25 | self.mask_location = cfg.DATASETS.VISUAL_GENOME.TRAIN_MASKS 26 | elif split == 'val': 27 | self.mask_location = cfg.DATASETS.VISUAL_GENOME.VAL_MASKS 28 | else: 29 | self.mask_location = cfg.DATASETS.VISUAL_GENOME.TEST_MASKS 30 | self.mask_exists = os.path.isfile(self.mask_location) 31 | self.clamped = True if "clamped" in 
self.mask_location else "" 32 | self.clipped = cfg.DATASETS.VISUAL_GENOME.CLIPPED 33 | self.precompute = False if (self.cfg.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS or self.cfg.DATASETS.VISUAL_GENOME.FILTER_NON_OVERLAP) else True 34 | 35 | # self._process_data() 36 | self.dataset_dicts = self._fetch_data_dict() 37 | self.register_dataset() 38 | try: 39 | self.get_statistics() 40 | except: 41 | pass 42 | 43 | def register_dataset(self): 44 | """ 45 | Register datasets to use with Detectron2 46 | """ 47 | DatasetCatalog.register('VG_{}'.format(self.split), lambda: self.dataset_dicts) 48 | 49 | #Get labels 50 | self.mapping_dictionary = json.load(open(self.cfg.DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY, 'r')) 51 | self.idx_to_classes = sorted(self.mapping_dictionary['label_to_idx'], key=lambda k: self.mapping_dictionary['label_to_idx'][k]) 52 | self.idx_to_predicates = sorted(self.mapping_dictionary['predicate_to_idx'], key=lambda k: self.mapping_dictionary['predicate_to_idx'][k]) 53 | self.idx_to_attributes = sorted(self.mapping_dictionary['attribute_to_idx'], key=lambda k: self.mapping_dictionary['attribute_to_idx'][k]) 54 | MetadataCatalog.get('VG_{}'.format(self.split)).set(thing_classes=self.idx_to_classes, predicate_classes=self.idx_to_predicates, attribute_classes=self.idx_to_attributes) 55 | 56 | def _fetch_data_dict(self): 57 | """ 58 | Load data in detectron format 59 | """ 60 | fileName = "tmp/visual_genome_{}_data_{}{}{}{}{}.pkl".format(self.split, 'masks' if self.mask_exists else '', '_oi' if 'oi' in self.mask_location else '', "_clamped" if self.clamped else "", "_precomp" if self.precompute else "", "_clipped" if self.clipped else "") 61 | if os.path.isfile(fileName): 62 | #If data has been processed earlier, load that to save time 63 | print("loading cached file: ", fileName) 64 | with open(fileName, 'rb') as inputFile: 65 | dataset_dicts = pickle.load(inputFile) 66 | else: 67 | #Process data 68 | os.makedirs('tmp', exist_ok=True) 69 | dataset_dicts = self._process_data() 70 | #TODO: this can cause problems, if it is excecuted in a distributed setup 71 | print("creating cache file: ", fileName) 72 | with open(fileName, 'wb') as inputFile: 73 | pickle.dump(dataset_dicts, inputFile) 74 | return dataset_dicts 75 | 76 | def _process_data(self): 77 | self.VG_attribute_h5 = h5py.File(self.cfg.DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5, 'r') 78 | 79 | # Remove corrupted images 80 | image_data = json.load(open(self.cfg.DATASETS.VISUAL_GENOME.IMAGE_DATA, 'r')) 81 | self.corrupted_ims = ['1592', '1722', '4616', '4617'] 82 | self.image_data = [] 83 | for i, img in enumerate(image_data): 84 | if str(img['image_id']) in self.corrupted_ims: 85 | continue 86 | self.image_data.append(img) 87 | assert(len(self.image_data) == 108073) 88 | self.masks = None 89 | if self.mask_location != "": 90 | try: 91 | with open(self.mask_location, 'rb') as f: 92 | self.masks = pickle.load(f) 93 | except: 94 | pass 95 | dataset_dicts = self._load_graphs() 96 | return dataset_dicts 97 | 98 | def get_statistics(self, eps=1e-3, bbox_overlap=True): 99 | num_object_classes = len(MetadataCatalog.get('VG_{}'.format(self.split)).thing_classes) + 1 100 | num_relation_classes = len(MetadataCatalog.get('VG_{}'.format(self.split)).predicate_classes) + 1 101 | 102 | fg_matrix = np.zeros((num_object_classes, num_object_classes, num_relation_classes), dtype=np.int64) 103 | bg_matrix = np.zeros((num_object_classes, num_object_classes), dtype=np.int64) 104 | for idx, data in enumerate(self.dataset_dicts): 105 | gt_relations = 
data['relations'] 106 | gt_classes = np.array([x['category_id'] for x in data['annotations']]) 107 | gt_boxes = np.array([x['bbox'] for x in data['annotations']]) 108 | for (o1, o2), rel in zip(gt_classes[gt_relations[:,:2]], gt_relations[:,2]): 109 | fg_matrix[o1, o2, rel] += 1 110 | for (o1, o2) in gt_classes[np.array(box_filter(gt_boxes, must_overlap=bbox_overlap), dtype=int)]: 111 | bg_matrix[o1, o2] += 1 112 | bg_matrix += 1 113 | fg_matrix[:, :, -1] = bg_matrix 114 | pred_dist = np.log(fg_matrix / fg_matrix.sum(2)[:, :, None] + eps) 115 | 116 | result = { 117 | 'fg_matrix': torch.from_numpy(fg_matrix), 118 | 'pred_dist': torch.from_numpy(pred_dist).float(), 119 | 'obj_classes': self.idx_to_classes + ['__background__'], 120 | 'rel_classes': self.idx_to_predicates + ['__background__'], 121 | 'att_classes': self.idx_to_attributes, 122 | } 123 | MetadataCatalog.get('VG_{}'.format(self.split)).set(statistics=result) 124 | return result 125 | 126 | def _load_graphs(self): 127 | """ 128 | Parse examples and create dictionaries 129 | """ 130 | data_split = self.VG_attribute_h5['split'][:] 131 | split_flag = 2 if self.split == 'test' else 0 132 | split_mask = data_split == split_flag 133 | 134 | #Filter images without bounding boxes 135 | split_mask &= self.VG_attribute_h5['img_to_first_box'][:] >= 0 136 | if self.cfg.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS: 137 | split_mask &= self.VG_attribute_h5['img_to_first_rel'][:] >= 0 138 | image_index = np.where(split_mask)[0] 139 | 140 | if self.split == 'val': 141 | image_index = image_index[:self.cfg.DATASETS.VISUAL_GENOME.NUMBER_OF_VALIDATION_IMAGES] 142 | elif self.split == 'train': 143 | image_index = image_index[self.cfg.DATASETS.VISUAL_GENOME.NUMBER_OF_VALIDATION_IMAGES:] 144 | 145 | split_mask = np.zeros_like(data_split).astype(bool) 146 | split_mask[image_index] = True 147 | 148 | # Get box information 149 | all_labels = self.VG_attribute_h5['labels'][:, 0] 150 | all_attributes = self.VG_attribute_h5['attributes'][:, :] 151 | all_boxes = self.VG_attribute_h5['boxes_{}'.format(self.cfg.DATASETS.VISUAL_GENOME.BOX_SCALE)][:] # cx,cy,w,h 152 | assert np.all(all_boxes[:, :2] >= 0) # sanity check 153 | assert np.all(all_boxes[:, 2:] > 0) # no empty box 154 | 155 | # Convert from xc, yc, w, h to x1, y1, x2, y2 156 | all_boxes[:, :2] = all_boxes[:, :2] - all_boxes[:, 2:] / 2 157 | all_boxes[:, 2:] = all_boxes[:, :2] + all_boxes[:, 2:] 158 | 159 | first_box_index = self.VG_attribute_h5['img_to_first_box'][split_mask] 160 | last_box_index = self.VG_attribute_h5['img_to_last_box'][split_mask] 161 | first_relation_index = self.VG_attribute_h5['img_to_first_rel'][split_mask] 162 | last_relation_index = self.VG_attribute_h5['img_to_last_rel'][split_mask] 163 | 164 | #Load relation labels 165 | all_relations = self.VG_attribute_h5['relationships'][:] 166 | all_relation_predicates = self.VG_attribute_h5['predicates'][:, 0] 167 | 168 | image_indexer = np.arange(len(self.image_data))[split_mask] 169 | # Iterate over images 170 | dataset_dicts = [] 171 | for idx, _ in enumerate(image_index): 172 | record = {} 173 | #Get image metadata 174 | image_data = self.image_data[image_indexer[idx]] 175 | record['file_name'] = os.path.join(self.cfg.DATASETS.VISUAL_GENOME.IMAGES, '{}.jpg'.format(image_data['image_id'])) 176 | record['image_id'] = image_data['image_id'] 177 | record['height'] = image_data['height'] 178 | record['width'] = image_data['width'] 179 | 180 | #Get annotations 181 | boxes = all_boxes[first_box_index[idx]:last_box_index[idx] + 1, :] 182 | 
gt_classes = all_labels[first_box_index[idx]:last_box_index[idx] + 1] 183 | gt_attributes = all_attributes[first_box_index[idx]:last_box_index[idx] + 1, :] 184 | 185 | if first_relation_index[idx] > -1: 186 | predicates = all_relation_predicates[first_relation_index[idx]:last_relation_index[idx] + 1] 187 | objects = all_relations[first_relation_index[idx]:last_relation_index[idx] + 1] - first_box_index[idx] 188 | predicates = predicates - 1 189 | relations = np.column_stack((objects, predicates)) 190 | else: 191 | assert not self.cfg.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS 192 | relations = np.zeros((0, 3), dtype=np.int32) 193 | 194 | if self.cfg.DATASETS.VISUAL_GENOME.FILTER_NON_OVERLAP and self.split == 'train': 195 | # Remove boxes that don't overlap 196 | boxes_list = Boxes(boxes) 197 | ious = pairwise_iou(boxes_list, boxes_list) 198 | relation_boxes_ious = ious[relations[:,0], relations[:,1]] 199 | iou_indexes = np.where(relation_boxes_ious > 0.0)[0] 200 | if iou_indexes.size > 0: 201 | relations = relations[iou_indexes] 202 | else: 203 | #Ignore image 204 | continue 205 | #Get masks if possible 206 | if self.masks is not None: 207 | try: 208 | gt_masks = self.masks[image_data['image_id']] 209 | except: 210 | print (image_data['image_id']) 211 | record['relations'] = relations 212 | objects = [] 213 | # if len(boxes) != len(gt_masks): 214 | mask_idx = 0 215 | for obj_idx in range(len(boxes)): 216 | resized_box = boxes[obj_idx] / self.cfg.DATASETS.VISUAL_GENOME.BOX_SCALE * max(record['height'], record['width']) 217 | obj = { 218 | "bbox": resized_box.tolist(), 219 | "bbox_mode": BoxMode.XYXY_ABS, 220 | "category_id": gt_classes[obj_idx] - 1, 221 | "attribute": gt_attributes[obj_idx], 222 | } 223 | if self.masks is not None: 224 | if gt_masks['empty_index'][obj_idx]: 225 | refined_poly = [] 226 | for poly_idx, poly in enumerate(gt_masks['polygons'][mask_idx]): 227 | if len(poly) >= 6: 228 | refined_poly.append(poly) 229 | obj["segmentation"] = refined_poly 230 | mask_idx += 1 231 | else: 232 | obj["segmentation"] = [] 233 | if len(obj["segmentation"]) > 0: 234 | objects.append(obj) 235 | else: 236 | objects.append(obj) 237 | record['annotations'] = objects 238 | dataset_dicts.append(record) 239 | 240 | return dataset_dicts 241 | 242 | def box_filter(boxes, must_overlap=False): 243 | """ Only include boxes that overlap as possible relations. 
244 | If no overlapping boxes, use all of them.""" 245 | n_cands = boxes.shape[0] 246 | 247 | overlaps = bbox_overlaps(boxes.astype(np.float), boxes.astype(np.float), to_move=0) > 0 248 | np.fill_diagonal(overlaps, 0) 249 | 250 | all_possib = np.ones_like(overlaps, dtype=np.bool) 251 | np.fill_diagonal(all_possib, 0) 252 | 253 | if must_overlap: 254 | possible_boxes = np.column_stack(np.where(overlaps)) 255 | 256 | if possible_boxes.size == 0: 257 | possible_boxes = np.column_stack(np.where(all_possib)) 258 | else: 259 | possible_boxes = np.column_stack(np.where(all_possib)) 260 | return possible_boxes 261 | 262 | def bbox_overlaps(boxes1, boxes2, to_move=1): 263 | """ 264 | boxes1 : numpy, [num_obj, 4] (x1,y1,x2,y2) 265 | boxes2 : numpy, [num_obj, 4] (x1,y1,x2,y2) 266 | """ 267 | #print('boxes1: ', boxes1.shape) 268 | #print('boxes2: ', boxes2.shape) 269 | num_box1 = boxes1.shape[0] 270 | num_box2 = boxes2.shape[0] 271 | lt = np.maximum(boxes1.reshape([num_box1, 1, -1])[:,:,:2], boxes2.reshape([1, num_box2, -1])[:,:,:2]) # [N,M,2] 272 | rb = np.minimum(boxes1.reshape([num_box1, 1, -1])[:,:,2:], boxes2.reshape([1, num_box2, -1])[:,:,2:]) # [N,M,2] 273 | 274 | wh = (rb - lt + to_move).clip(min=0) # [N,M,2] 275 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 276 | return inter -------------------------------------------------------------------------------- /segmentationsg/data/embeddings/glove_coco: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/data/embeddings/glove_coco -------------------------------------------------------------------------------- /segmentationsg/data/embeddings/glove_mean_coco: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/data/embeddings/glove_mean_coco -------------------------------------------------------------------------------- /segmentationsg/data/embeddings/glove_mean_open_images: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/data/embeddings/glove_mean_open_images -------------------------------------------------------------------------------- /segmentationsg/data/embeddings/glove_mean_vg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/data/embeddings/glove_mean_vg -------------------------------------------------------------------------------- /segmentationsg/data/embeddings/glove_open_images: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/data/embeddings/glove_open_images -------------------------------------------------------------------------------- /segmentationsg/data/embeddings/glove_vg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/data/embeddings/glove_vg -------------------------------------------------------------------------------- 
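A note on the embedding files above: they are small binary checkpoints rather than source files. Each is a torch.save'd dictionary holding a single 'embeddings' tensor with one row per object class (300-dimensional vectors from torchtext's default GloVe model), as produced by get_embeddings.py below. A minimal loading sketch, assuming only that layout; the helper name and example path are illustrative:

import torch

def load_class_embeddings(path):
    # Each file stores {'embeddings': FloatTensor of shape [num_classes, dim]}
    with open(path, 'rb') as f:
        return torch.load(f, map_location='cpu')['embeddings']

# e.g. vg_class_vectors = load_class_embeddings('segmentationsg/data/embeddings/glove_mean_vg')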
/segmentationsg/data/get_embeddings.py: -------------------------------------------------------------------------------- 1 | import torchtext 2 | import numpy as np 3 | import os 4 | import torch 5 | import json 6 | 7 | vg = {'airplane': 0, 'animal': 1, 'arm': 2, 'bag': 3, 'banana': 4, 'basket': 5, 'beach': 6, 'bear': 7, 'bed': 8, 'bench': 9, 'bike': 10, 'bird': 11, 'board': 12, 'boat': 13, 'book': 14, 'boot': 15, 'bottle': 16, 'bowl': 17, 'box': 18, 'boy': 19, 'branch': 20, 'building': 21, 'bus': 22, 'cabinet': 23, 'cap': 24, 'car': 25, 'cat': 26, 'chair': 27, 'child': 28, 'clock': 29, 'coat': 30, 'counter': 31, 'cow': 32, 'cup': 33, 'curtain': 34, 'desk': 35, 'dog': 36, 'door': 37, 'drawer': 38, 'ear': 39, 'elephant': 40, 'engine': 41, 'eye': 42, 'face': 43, 'fence': 44, 'finger': 45, 'flag': 46, 'flower': 47, 'food': 48, 'fork': 49, 'fruit': 50, 'giraffe': 51, 'girl': 52, 'glass': 53, 'glove': 54, 'guy': 55, 'hair': 56, 'hand': 57, 'handle': 58, 'hat': 59, 'head': 60, 'helmet': 61, 'hill': 62, 'horse': 63, 'house': 64, 'jacket': 65, 'jean': 66, 'kid': 67, 'kite': 68, 'lady': 69, 'lamp': 70, 'laptop': 71, 'leaf': 72, 'leg': 73, 'letter': 74, 'light': 75, 'logo': 76, 'man': 77, 'men': 78, 'motorcycle': 79, 'mountain': 80, 'mouth': 81, 'neck': 82, 'nose': 83, 'number': 84, 'orange': 85, 'pant': 86, 'paper': 87, 'paw': 88, 'people': 89, 'person': 90, 'phone': 91, 'pillow': 92, 'pizza': 93, 'plane': 94, 'plant': 95, 'plate': 96, 'player': 97, 'pole': 98, 'post': 99, 'pot': 100, 'racket': 101, 'railing': 102, 'rock': 103, 'roof': 104, 'room': 105, 'screen': 106, 'seat': 107, 'sheep': 108, 'shelf': 109, 'shirt': 110, 'shoe': 111, 'short': 112, 'sidewalk': 113, 'sign': 114, 'sink': 115, 'skateboard': 116, 'ski': 117, 'skier': 118, 'sneaker': 119, 'snow': 120, 'sock': 121, 'stand': 122, 'street': 123, 'surfboard': 124, 'table': 125, 'tail': 126, 'tie': 127, 'tile': 128, 'tire': 129, 'toilet': 130, 'towel': 131, 'tower': 132, 'track': 133, 'train': 134, 'tree': 135, 'truck': 136, 'trunk': 137, 'umbrella': 138, 'vase': 139, 'vegetable': 140, 'vehicle': 141, 'wave': 142, 'wheel': 143, 'window': 144, 'windshield': 145, 'wing': 146, 'wire': 147, 'woman': 148, 'zebra': 149} 8 | 9 | coco = {'person': 0, 'bicycle': 1, 'car': 2, 'motorcycle': 3, 'airplane': 4, 'bus': 5, 'train': 6, 'truck': 7, 'boat': 8, 'traffic light': 9, 'fire hydrant': 10, 'stop sign': 11, 'parking meter': 12, 'bench': 13, 'bird': 14, 'cat': 15, 'dog': 16, 'horse': 17, 'sheep': 18, 'cow': 19, 'elephant': 20, 'bear': 21, 'zebra': 22, 'giraffe': 23, 'backpack': 24, 'umbrella': 25, 'handbag': 26, 'tie': 27, 'suitcase': 28, 'frisbee': 29, 'skis': 30, 'snowboard': 31, 'sports ball': 32, 'kite': 33, 'baseball bat': 34, 'baseball glove': 35, 'skateboard': 36, 'surfboard': 37, 'tennis racket': 38, 'bottle': 39, 'wine glass': 40, 'cup': 41, 'fork': 42, 'knife': 43, 'spoon': 44, 'bowl': 45, 'banana': 46, 'apple': 47, 'sandwich': 48, 'orange': 49, 'broccoli': 50, 'carrot': 51, 'hot dog': 52, 'pizza': 53, 'donut': 54, 'cake': 55, 'chair': 56, 'couch': 57, 'potted plant': 58, 'bed': 59, 'dining table': 60, 'toilet': 61, 'tv': 62, 'laptop': 63, 'mouse': 64, 'remote': 65, 'keyboard': 66, 'cell phone': 67, 'microwave': 68, 'oven': 69, 'toaster': 70, 'sink': 71, 'refrigerator': 72, 'book': 73, 'clock': 74, 'vase': 75, 'scissors': 76, 'teddy bear': 77, 'hair drier': 78, 'toothbrush': 79} 10 | 11 | open_images = {'tortoise': 0, 'container': 1, 'magpie': 2, 'sea turtle': 3, 'football': 4, 'ambulance': 5, 'ladder': 6, 'toothbrush': 7, 'syringe': 8, 
'sink': 9, 'toy': 10, 'organ (musical instrument)': 11, 'cassette deck': 12, 'apple': 13, 'human eye': 14, 'cosmetics': 15, 'paddle': 16, 'snowman': 17, 'beer': 18, 'chopsticks': 19, 'human beard': 20, 'bird': 21, 'parking meter': 22, 'traffic light': 23, 'croissant': 24, 'cucumber': 25, 'radish': 26, 'towel': 27, 'doll': 28, 'skull': 29, 'washing machine': 30, 'glove': 31, 'tick': 32, 'belt': 33, 'sunglasses': 34, 'banjo': 35, 'cart': 36, 'ball': 37, 'backpack': 38, 'bicycle': 39, 'home appliance': 40, 'centipede': 41, 'boat': 42, 'surfboard': 43, 'boot': 44, 'headphones': 45, 'hot dog': 46, 'shorts': 47, 'fast food': 48, 'bus': 49, 'boy': 50, 'screwdriver': 51, 'bicycle wheel': 52, 'barge': 53, 'laptop': 54, 'miniskirt': 55, 'drill (tool)': 56, 'dress': 57, 'bear': 58, 'waffle': 59, 'pancake': 60, 'brown bear': 61, 'woodpecker': 62, 'blue jay': 63, 'pretzel': 64, 'bagel': 65, 'tower': 66, 'teapot': 67, 'person': 68, 'bow and arrow': 69, 'swimwear': 70, 'beehive': 71, 'brassiere': 72, 'bee': 73, 'bat (animal)': 74, 'starfish': 75, 'popcorn': 76, 'burrito': 77, 'chainsaw': 78, 'balloon': 79, 'wrench': 80, 'tent': 81, 'vehicle registration plate': 82, 'lantern': 83, 'toaster': 84, 'flashlight': 85, 'billboard': 86, 'tiara': 87, 'limousine': 88, 'necklace': 89, 'carnivore': 90, 'scissors': 91, 'stairs': 92, 'computer keyboard': 93, 'printer': 94, 'traffic sign': 95, 'chair': 96, 'shirt': 97, 'poster': 98, 'cheese': 99, 'sock': 100, 'fire hydrant': 101, 'land vehicle': 102, 'earrings': 103, 'tie': 104, 'watercraft': 105, 'cabinetry': 106, 'suitcase': 107, 'muffin': 108, 'bidet': 109, 'snack': 110, 'snowmobile': 111, 'clock': 112, 'medical equipment': 113, 'cattle': 114, 'cello': 115, 'jet ski': 116, 'camel': 117, 'coat': 118, 'suit': 119, 'desk': 120, 'cat': 121, 'bronze sculpture': 122, 'juice': 123, 'gondola': 124, 'beetle': 125, 'cannon': 126, 'computer mouse': 127, 'cookie': 128, 'office building': 129, 'fountain': 130, 'coin': 131, 'calculator': 132, 'cocktail': 133, 'computer monitor': 134, 'box': 135, 'stapler': 136, 'christmas tree': 137, 'cowboy hat': 138, 'hiking equipment': 139, 'studio couch': 140, 'drum': 141, 'dessert': 142, 'wine rack': 143, 'drink': 144, 'zucchini': 145, 'ladle': 146, 'human mouth': 147, 'dairy product': 148, 'dice': 149, 'oven': 150, 'dinosaur': 151, 'ratchet (device)': 152, 'couch': 153, 'cricket ball': 154, 'winter melon': 155, 'spatula': 156, 'whiteboard': 157, 'pencil sharpener': 158, 'door': 159, 'hat': 160, 'shower': 161, 'eraser': 162, 'fedora': 163, 'guacamole': 164, 'dagger': 165, 'scarf': 166, 'dolphin': 167, 'sombrero': 168, 'tin can': 169, 'mug': 170, 'tap': 171, 'harbor seal': 172, 'stretcher': 173, 'can opener': 174, 'goggles': 175, 'human body': 176, 'roller skates': 177, 'coffee cup': 178, 'cutting board': 179, 'blender': 180, 'plumbing fixture': 181, 'stop sign': 182, 'office supplies': 183, 'volleyball (ball)': 184, 'vase': 185, 'slow cooker': 186, 'wardrobe': 187, 'coffee': 188, 'whisk': 189, 'paper towel': 190, 'personal care': 191, 'food': 192, 'sun hat': 193, 'tree house': 194, 'flying disc': 195, 'skirt': 196, 'gas stove': 197, 'salt and pepper shakers': 198, 'mechanical fan': 199, 'face powder': 200, 'fax': 201, 'fruit': 202, 'french fries': 203, 'nightstand': 204, 'barrel': 205, 'kite': 206, 'tart': 207, 'treadmill': 208, 'fox': 209, 'flag': 210, 'french horn': 211, 'window blind': 212, 'human foot': 213, 'golf cart': 214, 'jacket': 215, 'egg (food)': 216, 'street light': 217, 'guitar': 218, 'pillow': 219, 'human leg': 220, 'isopod': 
221, 'grape': 222, 'human ear': 223, 'power plugs and sockets': 224, 'panda': 225, 'giraffe': 226, 'woman': 227, 'door handle': 228, 'rhinoceros': 229, 'bathtub': 230, 'goldfish': 231, 'houseplant': 232, 'goat': 233, 'baseball bat': 234, 'baseball glove': 235, 'mixing bowl': 236, 'marine invertebrates': 237, 'kitchen utensil': 238, 'light switch': 239, 'house': 240, 'horse': 241, 'stationary bicycle': 242, 'hammer': 243, 'ceiling fan': 244, 'sofa bed': 245, 'adhesive tape': 246, 'harp': 247, 'sandal': 248, 'bicycle helmet': 249, 'saucer': 250, 'harpsichord': 251, 'human hair': 252, 'heater': 253, 'harmonica': 254, 'hamster': 255, 'curtain': 256, 'bed': 257, 'kettle': 258, 'fireplace': 259, 'scale': 260, 'drinking straw': 261, 'insect': 262, 'hair dryer': 263, 'kitchenware': 264, 'indoor rower': 265, 'invertebrate': 266, 'food processor': 267, 'bookcase': 268, 'refrigerator': 269, 'wood-burning stove': 270, 'punching bag': 271, 'common fig': 272, 'cocktail shaker': 273, 'jaguar (animal)': 274, 'golf ball': 275, 'fashion accessory': 276, 'alarm clock': 277, 'filing cabinet': 278, 'artichoke': 279, 'table': 280, 'tableware': 281, 'kangaroo': 282, 'koala': 283, 'knife': 284, 'bottle': 285, 'bottle opener': 286, 'lynx': 287, 'lavender (plant)': 288, 'lighthouse': 289, 'dumbbell': 290, 'human head': 291, 'bowl': 292, 'humidifier': 293, 'porch': 294, 'lizard': 295, 'billiard table': 296, 'mammal': 297, 'mouse': 298, 'motorcycle': 299, 'musical instrument': 300, 'swim cap': 301, 'frying pan': 302, 'snowplow': 303, 'bathroom cabinet': 304, 'missile': 305, 'bust': 306, 'man': 307, 'waffle iron': 308, 'milk': 309, 'ring binder': 310, 'plate': 311, 'mobile phone': 312, 'baked goods': 313, 'mushroom': 314, 'crutch': 315, 'pitcher (container)': 316, 'mirror': 317, 'personal flotation device': 318, 'table tennis racket': 319, 'pencil case': 320, 'musical keyboard': 321, 'scoreboard': 322, 'briefcase': 323, 'kitchen knife': 324, 'nail (construction)': 325, 'tennis ball': 326, 'plastic bag': 327, 'oboe': 328, 'chest of drawers': 329, 'ostrich': 330, 'piano': 331, 'girl': 332, 'plant': 333, 'potato': 334, 'hair spray': 335, 'sports equipment': 336, 'pasta': 337, 'penguin': 338, 'pumpkin': 339, 'pear': 340, 'infant bed': 341, 'polar bear': 342, 'mixer': 343, 'cupboard': 344, 'jacuzzi': 345, 'pizza': 346, 'digital clock': 347, 'pig': 348, 'reptile': 349, 'rifle': 350, 'lipstick': 351, 'skateboard': 352, 'raven': 353, 'high heels': 354, 'red panda': 355, 'rose': 356, 'rabbit': 357, 'sculpture': 358, 'saxophone': 359, 'shotgun': 360, 'seafood': 361, 'submarine sandwich': 362, 'snowboard': 363, 'sword': 364, 'picture frame': 365, 'sushi': 366, 'loveseat': 367, 'ski': 368, 'squirrel': 369, 'tripod': 370, 'stethoscope': 371, 'submarine': 372, 'scorpion': 373, 'segway': 374, 'training bench': 375, 'snake': 376, 'coffee table': 377, 'skyscraper': 378, 'sheep': 379, 'television': 380, 'trombone': 381, 'tea': 382, 'tank': 383, 'taco': 384, 'telephone': 385, 'torch': 386, 'tiger': 387, 'strawberry': 388, 'trumpet': 389, 'tree': 390, 'tomato': 391, 'train': 392, 'tool': 393, 'picnic basket': 394, 'cooking spray': 395, 'trousers': 396, 'bowling equipment': 397, 'football helmet': 398, 'truck': 399, 'measuring cup': 400, 'coffeemaker': 401, 'violin': 402, 'vehicle': 403, 'handbag': 404, 'paper cutter': 405, 'wine': 406, 'weapon': 407, 'wheel': 408, 'worm': 409, 'wok': 410, 'whale': 411, 'zebra': 412, 'auto part': 413, 'jug': 414, 'pizza cutter': 415, 'cream': 416, 'monkey': 417, 'lion': 418, 'bread': 419, 'platter': 420, 
'chicken': 421, 'eagle': 422, 'helicopter': 423, 'owl': 424, 'duck': 425, 'turtle': 426, 'hippopotamus': 427, 'crocodile': 428, 'toilet': 429, 'toilet paper': 430, 'squid': 431, 'clothing': 432, 'footwear': 433, 'lemon': 434, 'spider': 435, 'deer': 436, 'frog': 437, 'banana': 438, 'rocket': 439, 'wine glass': 440, 'countertop': 441, 'tablet computer': 442, 'waste container': 443, 'swimming pool': 444, 'dog': 445, 'book': 446, 'elephant': 447, 'shark': 448, 'candle': 449, 'leopard': 450, 'axe': 451, 'hand dryer': 452, 'soap dispenser': 453, 'porcupine': 454, 'flower': 455, 'canary': 456, 'cheetah': 457, 'palm tree': 458, 'hamburger': 459, 'maple': 460, 'building': 461, 'fish': 462, 'lobster': 463, 'garden asparagus': 464, 'furniture': 465, 'hedgehog': 466, 'airplane': 467, 'spoon': 468, 'otter': 469, 'bull': 470, 'oyster': 471, 'horizontal bar': 472, 'convenience store': 473, 'bomb': 474, 'bench': 475, 'ice cream': 476, 'caterpillar': 477, 'butterfly': 478, 'parachute': 479, 'orange': 480, 'antelope': 481, 'beaker': 482, 'moths and butterflies': 483, 'window': 484, 'closet': 485, 'castle': 486, 'jellyfish': 487, 'goose': 488, 'mule': 489, 'swan': 490, 'peach': 491, 'coconut': 492, 'seat belt': 493, 'raccoon': 494, 'chisel': 495, 'fork': 496, 'lamp': 497, 'camera': 498, 'squash (plant)': 499, 'racket': 500, 'human face': 501, 'human arm': 502, 'vegetable': 503, 'diaper': 504, 'unicycle': 505, 'falcon': 506, 'chime': 507, 'snail': 508, 'shellfish': 509, 'cabbage': 510, 'carrot': 511, 'mango': 512, 'jeans': 513, 'flowerpot': 514, 'pineapple': 515, 'drawer': 516, 'stool': 517, 'envelope': 518, 'cake': 519, 'dragonfly': 520, 'common sunflower': 521, 'microwave oven': 522, 'honeycomb': 523, 'marine mammal': 524, 'sea lion': 525, 'ladybug': 526, 'shelf': 527, 'watch': 528, 'candy': 529, 'salad': 530, 'parrot': 531, 'handgun': 532, 'sparrow': 533, 'van': 534, 'grinder': 535, 'spice rack': 536, 'light bulb': 537, 'corded phone': 538, 'sports uniform': 539, 'tennis racket': 540, 'wall clock': 541, 'serving tray': 542, 'kitchen & dining room table': 543, 'dog bed': 544, 'cake stand': 545, 'cat furniture': 546, 'bathroom accessory': 547, 'facial tissue holder': 548, 'pressure cooker': 549, 'kitchen appliance': 550, 'tire': 551, 'ruler': 552, 'luggage and bags': 553, 'microphone': 554, 'broccoli': 555, 'umbrella': 556, 'pastry': 557, 'grapefruit': 558, 'band-aid': 559, 'animal': 560, 'bell pepper': 561, 'turkey': 562, 'lily': 563, 'pomegranate': 564, 'doughnut': 565, 'glasses': 566, 'human nose': 567, 'pen': 568, 'ant': 569, 'car': 570, 'aircraft': 571, 'human hand': 572, 'skunk': 573, 'teddy bear': 574, 'watermelon': 575, 'cantaloupe': 576, 'dishwasher': 577, 'flute': 578, 'balance beam': 579, 'sandwich': 580, 'shrimp': 581, 'sewing machine': 582, 'binoculars': 583, 'rays and skates': 584, 'ipod': 585, 'accordion': 586, 'willow': 587, 'crab': 588, 'crown': 589, 'seahorse': 590, 'perfume': 591, 'alpaca': 592, 'taxi': 593, 'canoe': 594, 'remote control': 595, 'wheelchair': 596, 'rugby ball': 597, 'armadillo': 598, 'maracas': 599, 'helmet': 600} 12 | 13 | data_dicts = {'vg':vg, 'coco':coco, 'open_images':open_images} 14 | for data_name, data_dict in data_dicts.items(): 15 | print ("Processing:", data_name) 16 | inverse_dict = {data_dict[name]:name for name in data_dict.keys()} 17 | vec = torchtext.vocab.GloVe() 18 | embeddings = vec.get_vecs_by_tokens([inverse_dict[x] for x in range(len(data_dict.keys()))], lower_case_backup=True) 19 | with open('embeddings/glove_{}'.format(data_name),'wb') as f: 20 | 
torch.save({'embeddings':embeddings}, f) 21 | 22 | embedding_arr = [] 23 | for x in range(len(data_dict.keys())): 24 | embedding = torch.mean(vec.get_vecs_by_tokens(inverse_dict[x].split(' '), lower_case_backup=True), dim=0) 25 | embedding_arr.append(embedding) 26 | 27 | with open('embeddings/glove_mean_{}'.format(data_name),'wb') as f: 28 | torch.save({'embeddings':torch.stack(embedding_arr, dim=0)}, f) 29 | -------------------------------------------------------------------------------- /segmentationsg/data/remove_intersection.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import json 3 | 4 | coco_train = json.load(open('/h/skhandel/FewshotDetection/WSASOD/data/data_utils/data/MSCOCO2017/annotations/instances_train2017.json', 'r')) 5 | coco_test = json.load(open('/h/skhandel/FewshotDetection/WSASOD/data/data_utils/data/MSCOCO2017/annotations/instances_val2017.json', 'r')) 6 | 7 | vg_test_ids = pickle.load(open('/h/skhandel/SceneGraph/scripts/test_coco_ids.pkl', 'rb')) 8 | vg_train_ids = pickle.load(open('/h/skhandel/SceneGraph/scripts/VG_COCO_ids_train', 'rb')) 9 | coco_test_ids = [x['id'] for x in coco_test['images']] 10 | coco_train_ids = [x['id'] for x in coco_train['images']] 11 | num_overlap = 0 12 | for id in coco_test_ids: 13 | if id in vg_train_ids: 14 | num_overlap += 1 15 | 16 | num_overlap_1 = 0 17 | ids_to_remove = [] 18 | for id in vg_test_ids: 19 | if id in coco_train_ids: 20 | ids_to_remove.append(id) 21 | num_overlap_1 += 1 22 | 23 | num_overlap_2 = 0 24 | for id in vg_test_ids: 25 | if id in coco_test_ids: 26 | num_overlap_2 += 1 27 | 28 | num_overlap_3 = 0 29 | for id in coco_test_ids: 30 | if id in vg_test_ids: 31 | num_overlap_3 += 1 32 | 33 | new_coco_train = {} 34 | new_coco_train['info'] = coco_train['info'] 35 | new_coco_train['categories'] = coco_train['categories'] 36 | new_coco_train['licenses'] = coco_train['licenses'] 37 | new_coco_train['images'] = [] 38 | new_coco_train['annotations'] = [] 39 | 40 | new_coco_test = {} 41 | new_coco_test['info'] = coco_test['info'] 42 | new_coco_test['categories'] = coco_test['categories'] 43 | new_coco_test['licenses'] = coco_test['licenses'] 44 | new_coco_test['images'] = [] 45 | new_coco_test['annotations'] = [] 46 | 47 | 48 | for idx, data in enumerate(coco_train['images']): 49 | if data['id'] not in vg_test_ids: 50 | new_coco_train['images'].append(coco_train['images'][idx]) 51 | 52 | for idx, data in enumerate(coco_train['annotations']): 53 | if data['image_id'] not in vg_test_ids: 54 | new_coco_train['annotations'].append(coco_train['annotations'][idx]) 55 | 56 | for idx, data in enumerate(coco_test['images']): 57 | if data['id'] not in vg_train_ids: 58 | new_coco_test['images'].append(coco_test['images'][idx]) 59 | 60 | for idx, data in enumerate(coco_test['annotations']): 61 | if data['image_id'] not in vg_train_ids: 62 | new_coco_test['annotations'].append(coco_test['annotations'][idx]) 63 | 64 | with open('/h/skhandel/FewshotDetection/WSASOD/data/data_utils/data/MSCOCO2017/annotations/instances_train2017_clipped.json', 'w') as f: 65 | json.dump(new_coco_train, f) 66 | 67 | with open('/h/skhandel/FewshotDetection/WSASOD/data/data_utils/data/MSCOCO2017/annotations/instances_val2017_clipped.json', 'w') as f: 68 | json.dump(new_coco_test, f) 69 | 70 | with open('/h/skhandel/SceneGraph/scripts/coco_ids_to_remove.pkl', 'wb') as f: 71 | pickle.dump(coco_test_ids, f) 72 | 73 | import ipdb; ipdb.set_trace() 74 | a = 1 
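# Summary of outputs: instances_train2017_clipped.json (COCO train with images that also appear in the VG test split removed),
# instances_val2017_clipped.json (COCO val with images that appear in the VG train split removed),
# and coco_ids_to_remove.pkl (a pickle of the COCO val image ids).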
-------------------------------------------------------------------------------- /segmentationsg/data/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import add_dataset_config 2 | from .utils import register_datasets -------------------------------------------------------------------------------- /segmentationsg/data/tools/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from detectron2.config import CfgNode as CN 3 | 4 | def add_dataset_config(cfg): 5 | _C = cfg 6 | 7 | _C.MODEL.ROI_HEADS.NUM_OUTPUT_CLASSES = 80 8 | _C.MODEL.ROI_HEADS.EMBEDDINGS_PATH = "" 9 | _C.MODEL.ROI_HEADS.EMBEDDINGS_PATH_COCO = "" 10 | _C.MODEL.ROI_HEADS.LINGUAL_MATRIX_THRESHOLD = 0.05 11 | _C.MODEL.ROI_HEADS.MASK_NUM_CLASSES = 80 12 | 13 | _C.MODEL.FREEZE_LAYERS = CN() 14 | _C.MODEL.FREEZE_LAYERS.META_ARCH = [] 15 | _C.MODEL.FREEZE_LAYERS.ROI_HEADS = [] 16 | 17 | _C.DATASETS.TYPE = "" 18 | _C.DATASETS.VISUAL_GENOME = CN() 19 | _C.DATASETS.VISUAL_GENOME.IMAGES = '' 20 | _C.DATASETS.VISUAL_GENOME.MAPPING_DICTIONARY = '' 21 | _C.DATASETS.VISUAL_GENOME.IMAGE_DATA = '' 22 | _C.DATASETS.VISUAL_GENOME.VG_ATTRIBUTE_H5 = '' 23 | _C.DATASETS.VISUAL_GENOME.TRAIN_MASKS = "" 24 | _C.DATASETS.VISUAL_GENOME.TEST_MASKS = "" 25 | _C.DATASETS.VISUAL_GENOME.VAL_MASKS = "" 26 | _C.DATASETS.VISUAL_GENOME.CLIPPED = False 27 | 28 | _C.DATASETS.MSCOCO = CN() 29 | _C.DATASETS.MSCOCO.ANNOTATIONS = '' 30 | _C.DATASETS.MSCOCO.DATAROOT = '' 31 | 32 | _C.DATASETS.VISUAL_GENOME.FILTER_EMPTY_RELATIONS = True 33 | _C.DATASETS.VISUAL_GENOME.FILTER_DUPLICATE_RELATIONS = True 34 | _C.DATASETS.VISUAL_GENOME.FILTER_NON_OVERLAP = True 35 | _C.DATASETS.VISUAL_GENOME.NUMBER_OF_VALIDATION_IMAGES = 5000 36 | _C.DATASETS.VISUAL_GENOME.BOX_SCALE = 1024 37 | 38 | _C.DATASETS.SEG_DATA_DIVISOR = 1 39 | 40 | _C.DATASETS.TRANSFER = ('coco_train_2014',) 41 | _C.DATASETS.MASK_TRAIN = ('coco_train_2017',) 42 | _C.DATASETS.MASK_TEST = ('coco_val_2017',) -------------------------------------------------------------------------------- /segmentationsg/data/tools/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | from ..datasets import VisualGenomeTrainData 5 | 6 | def register_datasets(cfg): 7 | if cfg.DATASETS.TYPE == 'VISUAL GENOME': 8 | for split in ['train', 'val', 'test']: 9 | dataset_instance = VisualGenomeTrainData(cfg, split=split) 10 | -------------------------------------------------------------------------------- /segmentationsg/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .trainer import * 2 | from .sg_trainer import * -------------------------------------------------------------------------------- /segmentationsg/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco_evaluation import * 2 | from .evaluator import scenegraph_inference_on_dataset 3 | from .sg_evaluation import SceneGraphEvaluator 4 | 5 | 6 | -------------------------------------------------------------------------------- /segmentationsg/evaluation/coco_evaluation.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import copy 3 | import io 4 | import itertools 5 | import json 6 | import logging 7 | import numpy as np 8 | import os 9 | import pickle 10 | from collections import OrderedDict 11 | import 
pycocotools.mask as mask_util 12 | import torch 13 | from fvcore.common.file_io import PathManager 14 | from pycocotools.coco import COCO 15 | from pycocotools.cocoeval import COCOeval 16 | from tabulate import tabulate 17 | 18 | import detectron2.utils.comm as comm 19 | from detectron2.data import MetadataCatalog 20 | from detectron2.data.datasets.coco import convert_to_coco_json 21 | from detectron2.evaluation.fast_eval_api import COCOeval_opt 22 | from detectron2.structures import Boxes, BoxMode, pairwise_iou 23 | from detectron2.utils.logger import create_small_table 24 | from detectron2.evaluation import COCOEvaluator 25 | 26 | class COCOEvaluatorWeakSegmentation(COCOEvaluator): 27 | def _tasks_from_config(self, cfg): 28 | """ 29 | Returns: 30 | tuple[str]: tasks that can be evaluated under the given configuration. 31 | """ 32 | tasks = ("bbox",) 33 | # if cfg.MODEL.MASK_ON: 34 | # tasks = tasks + ("segm",) 35 | # if cfg.MODEL.KEYPOINT_ON: 36 | # tasks = tasks + ("keypoints",) 37 | return tasks -------------------------------------------------------------------------------- /segmentationsg/evaluation/datasets/vg/zeroshot_triplet.pytorch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ubc-vision/segmentation-sg/ee3f05036489036bb6a326ae388da5a21c9fa354/segmentationsg/evaluation/datasets/vg/zeroshot_triplet.pytorch -------------------------------------------------------------------------------- /segmentationsg/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | from collections import OrderedDict 5 | from contextlib import contextmanager 6 | from collections import Counter 7 | import torch 8 | 9 | from detectron2.utils.comm import get_world_size, is_main_process 10 | from detectron2.utils.logger import log_every_n_seconds 11 | 12 | 13 | from detectron2.evaluation import COCOEvaluator 14 | 15 | 16 | def scenegraph_inference_on_dataset(cfg, model, data_loader, evaluator): 17 | """ 18 | Run model on the data_loader and evaluate the metrics with evaluator. 19 | Also benchmark the inference speed of `model.forward` accurately. 20 | The model will be used in eval mode. 21 | 22 | Args: 23 | model (nn.Module): a module which accepts an object from 24 | `data_loader` and returns some outputs. It will be temporarily set to `eval` mode. 25 | 26 | If you wish to evaluate a model in `training` mode instead, you can 27 | wrap the given model and override its behavior of `.eval()` and `.train()`. 28 | data_loader: an iterable object with a length. 29 | The elements it generates will be the inputs to the model. 30 | evaluator (DatasetEvaluator): the evaluator to run. Use `None` if you only want 31 | to benchmark, but don't want to do any evaluation. 
32 | 33 | Returns: 34 | The return value of `evaluator.evaluate()` 35 | """ 36 | num_devices = get_world_size() 37 | logger = logging.getLogger('detectron2') 38 | logger.info("Start inference on {} images".format(len(data_loader))) 39 | 40 | total = len(data_loader) # inference data loader must have a fixed length 41 | 42 | # evaluator = COCOEvaluator(dataset_name, cfg, True, output_folder) 43 | 44 | evaluator.reset() 45 | num_warmup = min(5, total - 1) 46 | start_time = time.perf_counter() 47 | total_compute_time = 0 48 | with inference_context(model), torch.no_grad(): 49 | for idx, inputs in enumerate(data_loader): 50 | if idx == num_warmup: 51 | start_time = time.perf_counter() 52 | total_compute_time = 0 53 | 54 | if len(inputs[0]['instances']) > 40: 55 | continue 56 | start_compute_time = time.perf_counter() 57 | 58 | outputs = model(inputs) 59 | 60 | if torch.cuda.is_available(): 61 | torch.cuda.synchronize() 62 | total_compute_time += time.perf_counter() - start_compute_time 63 | evaluator.process(inputs, outputs) 64 | 65 | iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) 66 | seconds_per_img = total_compute_time / iters_after_start 67 | 68 | if idx >= num_warmup * 2 or seconds_per_img > 5: 69 | total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start 70 | eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1))) 71 | # logger.info("Inference done {}/{}. {:.4f} s / img. ETA={}".format(idx + 1, total, seconds_per_img, str(eta))) 72 | log_every_n_seconds( 73 | logging.INFO, 74 | "Inference done {}/{}. {:.4f} s / img. ETA={}".format( 75 | idx + 1, total, seconds_per_img, str(eta) 76 | ), 77 | n=5, 78 | name='detectron2' 79 | ) 80 | 81 | if cfg.DEV_RUN and idx==2: 82 | break 83 | 84 | 85 | # Measure the time only for this worker (before the synchronization barrier) 86 | total_time = time.perf_counter() - start_time 87 | total_time_str = str(datetime.timedelta(seconds=total_time)) 88 | # NOTE this format is parsed by grep 89 | logger.info( 90 | "Total inference time: {} ({:.6f} s / img per device, on {} devices)".format( 91 | total_time_str, total_time / (total - num_warmup), num_devices 92 | ) 93 | ) 94 | total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time))) 95 | logger.info( 96 | "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format( 97 | total_compute_time_str, total_compute_time / (total - num_warmup), num_devices 98 | ) 99 | ) 100 | 101 | results = evaluator.evaluate() 102 | # An evaluator may return None when not in main process. 103 | # Replace it by an empty dict instead to make it easier for downstream code to handle 104 | if results is None: 105 | results = {} 106 | return results 107 | 108 | # _LOG_COUNTER = Counter() 109 | # _LOG_TIMER = {} 110 | 111 | # def log_every_n_seconds(lvl, msg, n=1, *, name=None): 112 | # """ 113 | # Log no more than once per n seconds. 114 | 115 | # Args: 116 | # lvl (int): the logging level 117 | # msg (str): 118 | # n (int): 119 | # name (str): name of the logger to use. Will use the caller's module by default. 
120 | # """ 121 | # caller_module, key = _find_caller() 122 | # last_logged = _LOG_TIMER.get(key, None) 123 | # current_time = time.time() 124 | # if last_logged is None or current_time - last_logged >= n: 125 | # logging.getLogger('detectron2').log(lvl, msg) 126 | # _LOG_TIMER[key] = current_time 127 | 128 | @contextmanager 129 | def inference_context(model): 130 | """ 131 | A context where the model is temporarily changed to eval mode, 132 | and restored to previous mode afterwards. 133 | 134 | Args: 135 | model: a torch Module 136 | """ 137 | training_mode = model.training 138 | model.eval() 139 | yield 140 | model.train(training_mode) -------------------------------------------------------------------------------- /segmentationsg/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def intersect_2d(x1, x2): 4 | """ 5 | Given two arrays [m1, n], [m2,n], returns a [m1, m2] array where each entry is True if those 6 | rows match. 7 | :param x1: [m1, n] numpy array 8 | :param x2: [m2, n] numpy array 9 | :return: [m1, m2] bool array of the intersections 10 | """ 11 | if x1.shape[1] != x2.shape[1]: 12 | raise ValueError("Input arrays must have same #columns") 13 | 14 | # This performs a matrix multiplication-esque thing between the two arrays 15 | # Instead of summing, we want the equality, so we reduce in that way 16 | res = (x1[..., None] == x2.T[None, ...]).all(1) 17 | return res 18 | 19 | def argsort_desc(scores): 20 | """ 21 | Returns the indices that sort scores descending in a smart way 22 | :param scores: Numpy array of arbitrary size 23 | :return: an array of size [numel(scores), dim(scores)] where each row is the index you'd 24 | need to get the score. 25 | """ 26 | return np.column_stack(np.unravel_index(np.argsort(-scores.ravel()), scores.shape)) 27 | -------------------------------------------------------------------------------- /segmentationsg/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .meta_arch import * 2 | from .roi_heads import StandardSGROIHeads -------------------------------------------------------------------------------- /segmentationsg/modeling/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import * -------------------------------------------------------------------------------- /segmentationsg/modeling/backbone/vgg.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import fvcore.nn.weight_init as weight_init 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn 6 | from detectron2.modeling.backbone.build import BACKBONE_REGISTRY 7 | from detectron2.layers import ( 8 | CNNBlockBase, 9 | Conv2d, 10 | DeformConv, 11 | ModulatedDeformConv, 12 | ShapeSpec, 13 | get_norm, 14 | ) 15 | from torchvision.models.vgg import vgg16 16 | from detectron2.modeling.backbone import Backbone 17 | 18 | def load_vgg(use_dropout=True, use_relu=True, use_linear=True, pretrained=True): 19 | model = vgg16(pretrained=pretrained) 20 | del model.features._modules['30'] # Get rid of the maxpool 21 | del model.classifier._modules['6'] # Get rid of class layer 22 | if not use_dropout: 23 | del model.classifier._modules['5'] # Get rid of dropout 24 | if not use_relu: 25 | del model.classifier._modules['4'] # Get rid of relu activation 26 | if not use_linear: 27 | del model.classifier._modules['3'] # 
Get rid of linear layer 28 | convs = model.features 29 | fc = model.classifier 30 | return convs, fc 31 | 32 | def get_conv_scale(convs): 33 | """ 34 | Determines the downscaling performed by a sequence of convolutional and pooling layers 35 | """ 36 | scale = 1. 37 | channels = 3 38 | for c in convs: 39 | stride = getattr(c, 'stride', 1.) 40 | scale *= stride if isinstance(stride, (int, float)) else stride[0] 41 | channels = getattr(c, 'out_channels') if isinstance(c, nn.Conv2d) else channels 42 | return scale, channels 43 | 44 | @BACKBONE_REGISTRY.register() 45 | class VGG(Backbone): 46 | def __init__(self, cfg, input_shape): 47 | super().__init__() 48 | convs, _ = load_vgg(pretrained=True) 49 | self.convs = convs 50 | self.scale, self.channels = get_conv_scale(convs) 51 | self._out_features = ['vgg_conv'] 52 | 53 | def output_shape(self): 54 | return { 55 | name: ShapeSpec( 56 | channels=self.channels, stride=self.scale 57 | ) 58 | for name in self._out_features 59 | } 60 | def forward(self, x): 61 | output = self.convs(x) 62 | return {self._out_features[0]: output} -------------------------------------------------------------------------------- /segmentationsg/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .rcnn import * -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .roi_heads import * 2 | from .mask_head import * 3 | from .box_head import * -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/box_head.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import List 3 | import fvcore.nn.weight_init as weight_init 4 | import torch 5 | from torch import nn 6 | import logging 7 | from detectron2.config import configurable 8 | from detectron2.layers import Conv2d, Linear, ShapeSpec, get_norm 9 | from detectron2.utils.registry import Registry 10 | from detectron2.modeling.roi_heads.box_head import ROI_BOX_HEAD_REGISTRY 11 | 12 | from ..backbone import load_vgg 13 | from torchvision import models as M, ops 14 | from detectron2.modeling.backbone.resnet import BottleneckBlock, ResNet 15 | 16 | @ROI_BOX_HEAD_REGISTRY.register() 17 | class VGGConvFCHead(nn.Module): 18 | def __init__(self, cfg, input_shape): 19 | super().__init__() 20 | _, fc = load_vgg(pretrained=True) 21 | _output_size = input_shape.channels 22 | for c in fc: 23 | _output_size = getattr(c, 'out_features') if isinstance(c, nn.Linear) else _output_size 24 | self.fc = fc 25 | self._output_size = _output_size 26 | 27 | def forward(self, x): 28 | x = x.flatten(1) 29 | return self.fc(x) 30 | 31 | @property 32 | @torch.jit.unused 33 | def output_shape(self): 34 | """ 35 | Returns: 36 | ShapeSpec: the output feature shape 37 | """ 38 | o = self._output_size 39 | if isinstance(o, int): 40 | return ShapeSpec(channels=o) 41 | else: 42 | return ShapeSpec(channels=o[0], height=o[1], width=o[2]) -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/mask_head.py: -------------------------------------------------------------------------------- 1 | import fvcore.nn.weight_init as weight_init 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from 
detectron2.config import configurable 7 | from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm 8 | from detectron2.structures import Instances 9 | from detectron2.modeling.roi_heads.mask_head import MaskRCNNConvUpsampleHead, ROI_MASK_HEAD_REGISTRY 10 | from detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference, mask_rcnn_loss 11 | from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm 12 | from detectron2.utils.events import get_event_storage 13 | from detectron2.structures import Boxes, Instances 14 | from detectron2.modeling.roi_heads.roi_heads import select_foreground_proposals 15 | import copy 16 | 17 | @ROI_MASK_HEAD_REGISTRY.register() 18 | class SceneGraphMaskHeadAllClasses(MaskRCNNConvUpsampleHead): 19 | def forward(self, x, pred_instances): 20 | x = self.layers(x) 21 | 22 | mask_probs_pred = x.sigmoid() 23 | num_boxes_per_image = [len(i) for i in pred_instances] 24 | mask_probs_pred = mask_probs_pred.split(num_boxes_per_image, dim=0) 25 | 26 | for prob, instances in zip(mask_probs_pred, pred_instances): 27 | instances.pred_masks = prob # (1, Hmask, Wmask) 28 | 29 | return pred_instances 30 | 31 | @ROI_MASK_HEAD_REGISTRY.register() 32 | class SceneGraphMaskHead(MaskRCNNConvUpsampleHead): 33 | def forward(self, x, instances): 34 | x = self.layers(x) 35 | mask_rcnn_inference(x, instances) 36 | return instances 37 | 38 | @ROI_MASK_HEAD_REGISTRY.register() 39 | class SceneGraphMaskHeadTransfer(MaskRCNNConvUpsampleHead): 40 | @classmethod 41 | def from_config(cls, cfg, input_shape): 42 | ret = super().from_config(cfg, input_shape) 43 | if cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK: 44 | ret["num_classes"] = 1 45 | else: 46 | ret["num_classes"] = cfg.MODEL.ROI_HEADS.MASK_NUM_CLASSES 47 | return ret 48 | 49 | def forward(self, x, pred_instances, similarity_matrix, base_class_indexer, novel_class_indexer, output_to_coco_indexer, segmentation_step=False, residual_masks=None): 50 | x = self.layers(x) 51 | if residual_masks is not None: 52 | x = x + residual_masks 53 | if not segmentation_step: 54 | #Get mask for output class 55 | base_class_mask = x.index_select(1, output_to_coco_indexer) 56 | novel_class_mask = torch.bmm(similarity_matrix.unsqueeze(0).expand(x.size(0),-1,-1), x.view(*x.size()[:2],-1)).view(x.size(0), -1, *x.size()[2:]) 57 | output_class_mask = torch.zeros(x.size(0), base_class_mask.size(1) + novel_class_mask.size(1), *x.size()[2:]).to(x.device) 58 | output_class_mask = output_class_mask.index_copy(1, base_class_indexer, base_class_mask) 59 | output_class_mask = output_class_mask.index_copy(1, novel_class_indexer, novel_class_mask) 60 | mask_probs_pred = output_class_mask.sigmoid() 61 | else: 62 | mask_probs_pred = x.sigmoid() 63 | num_boxes_per_image = [len(i) for i in pred_instances] 64 | mask_probs_pred = mask_probs_pred.split(num_boxes_per_image, dim=0) 65 | for prob, instances in zip(mask_probs_pred, pred_instances): 66 | instances.pred_masks = prob # (1, Hmask, Wmask) 67 | if not self.training: 68 | instances.pred_masks_base = prob.detach().clone() 69 | 70 | # if not self.training: 71 | # pred_mask_logits = pred_instances[0].pred_masks_base 72 | # num_masks = pred_mask_logits.shape[0] 73 | # class_pred = cat([i.pred_classes for i in pred_instances]) 74 | # indices = torch.arange(num_masks, device=class_pred.device) 75 | # mask_probs_pred_copy = pred_mask_logits[indices, class_pred][:, None] 76 | # num_boxes_per_image = [len(i) for i in pred_instances] 77 | # mask_probs_pred_copy = 
mask_probs_pred_copy.split(num_boxes_per_image, dim=0) 78 | # for prob, instances in zip(mask_probs_pred_copy, pred_instances): 79 | # instances.pred_masks_base = prob # (1, Hmask, Wmask) 80 | return pred_instances 81 | 82 | @ROI_MASK_HEAD_REGISTRY.register() 83 | class SceneGraphMaskHeadTransferSingleClass(MaskRCNNConvUpsampleHead): 84 | @classmethod 85 | def from_config(cls, cfg, input_shape): 86 | ret = super().from_config(cfg, input_shape) 87 | if cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK: 88 | ret["num_classes"] = 1 89 | else: 90 | ret["num_classes"] = cfg.MODEL.ROI_HEADS.MASK_NUM_CLASSES 91 | return ret 92 | 93 | def forward(self, x, pred_instances, similarity_matrix, base_class_indexer, novel_class_indexer, output_to_coco_indexer, segmentation_step=False, residual_masks=None): 94 | x = self.layers(x) 95 | if residual_masks is not None: 96 | x = x + residual_masks 97 | if not segmentation_step: 98 | #Get mask for output class 99 | base_class_mask = x.index_select(1, output_to_coco_indexer) 100 | novel_class_mask = torch.bmm(similarity_matrix.unsqueeze(0).expand(x.size(0),-1,-1), x.view(*x.size()[:2],-1)).view(x.size(0), -1, *x.size()[2:]) 101 | output_class_mask = torch.zeros(x.size(0), base_class_mask.size(1) + novel_class_mask.size(1), *x.size()[2:]).to(x.device) 102 | output_class_mask = output_class_mask.index_copy(1, base_class_indexer, base_class_mask) 103 | output_class_mask = output_class_mask.index_copy(1, novel_class_indexer, novel_class_mask) 104 | mask_rcnn_inference(output_class_mask, pred_instances) 105 | else: 106 | try: 107 | mask_rcnn_inference(x, pred_instances) 108 | except: 109 | if not self.training: 110 | pred_instances[0].pred_masks = torch.zeros_like(x).narrow(1, 0, 1) 111 | else: 112 | pass 113 | return pred_instances 114 | 115 | 116 | @ROI_MASK_HEAD_REGISTRY.register() 117 | class SGSceneGraphMaskHead(SceneGraphMaskHeadTransfer): 118 | @configurable 119 | def __init__(self, input_shape, *, num_classes, conv_dims, conv_norm="", **kwargs): 120 | self.use_only_fg_proposals = kwargs['use_only_fg_proposals'] 121 | self.num_classes = num_classes 122 | del kwargs['use_only_fg_proposals'] 123 | super(SGSceneGraphMaskHead, self).__init__(input_shape=input_shape, num_classes=num_classes, conv_dims=conv_dims, conv_norm=conv_norm, **kwargs) 124 | nn.init.constant_(self.predictor.weight, 0) 125 | if self.predictor.bias is not None: 126 | nn.init.constant_(self.predictor.bias, 0) 127 | 128 | @classmethod 129 | def from_config(cls, cfg, input_shape): 130 | ret = super().from_config(cfg, input_shape) 131 | if cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK: 132 | ret["num_classes"] = 1 133 | else: 134 | ret["num_classes"] = cfg.MODEL.ROI_HEADS.MASK_NUM_CLASSES 135 | ret['use_only_fg_proposals'] = cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_ONLY_FG_PROPOSALS 136 | return ret 137 | 138 | def forward(self, x, masks, proposals, eps=1e-8, return_masks=False): 139 | masks = torch.cat(masks) 140 | masks = -1 * torch.log((1.0 / (masks + eps)) - 1) 141 | if (not self.use_only_fg_proposals) and self.training: 142 | proposals, fg_indices = select_foreground_proposals(proposals, self.num_classes) 143 | fg_indices = torch.cat(fg_indices, 0) 144 | x = x[fg_indices] 145 | masks = masks[fg_indices] 146 | x = self.layers(x) 147 | if return_masks: 148 | return x 149 | 150 | combined_masks = x + masks 151 | if self.training: 152 | for proposal in proposals: 153 | if not proposal.has('proposal_boxes'): 154 | proposal.proposal_boxes = Boxes(proposal.pred_boxes.tensor.detach().clone()) 155 | loss = 
mask_rcnn_loss(combined_masks, proposals) 156 | if torch.any(torch.isnan(loss)): 157 | loss = torch.sum(x) * 0.0 158 | return {"loss_mask_segmentation" : loss}, proposals 159 | else: 160 | mask_rcnn_inference(combined_masks, proposals) 161 | return {}, proposals 162 | 163 | 164 | 165 | @ROI_MASK_HEAD_REGISTRY.register() 166 | class MaskLabelRCNNHead(MaskRCNNConvUpsampleHead): 167 | def forward(self, x, instances, similarity_matrix, base_class_indexer, novel_class_indexer, output_to_coco_indexer): 168 | x = self.layers(x) 169 | #Get mask for output class 170 | base_class_mask = x.index_select(1, output_to_coco_indexer) 171 | novel_class_mask = torch.bmm(similarity_matrix.unsqueeze(0).expand(x.size(0),-1,-1), x.view(*x.size()[:2],-1)).view(x.size(0), -1, *x.size()[2:]) 172 | output_class_mask = torch.zeros(x.size(0), base_class_mask.size(1) + novel_class_mask.size(1), *x.size()[2:]).to(x.device) 173 | output_class_mask = output_class_mask.index_copy(1, base_class_indexer, base_class_mask) 174 | output_class_mask = output_class_mask.index_copy(1, novel_class_indexer, novel_class_mask) 175 | 176 | if self.training: 177 | raise NotImplementedError 178 | else: 179 | mask_rcnn_inference(output_class_mask, instances) 180 | return instances 181 | 182 | @ROI_MASK_HEAD_REGISTRY.register() 183 | class PretrainObjectDetectionMaskHead(MaskRCNNConvUpsampleHead): 184 | def forward(self, x, instances): 185 | """ 186 | Args: 187 | x: input region feature(s) provided by :class:`ROIHeads`. 188 | instances (list[Instances]): contains the boxes & labels corresponding 189 | to the input features. 190 | Exact format is up to its caller to decide. 191 | Typically, this is the foreground instances in training, with 192 | "proposal_boxes" field and other gt annotations. 193 | In inference, it contains boxes that are already predicted. 194 | Returns: 195 | A dict of losses in training. The predicted "instances" in inference. 196 | """ 197 | x = self.layers(x) 198 | if self.training: 199 | return {"loss_mask": mask_rcnn_loss_with_empty_polygons(x, instances, self.vis_period)} 200 | else: 201 | mask_rcnn_inference(x, instances) 202 | return instances 203 | 204 | @ROI_MASK_HEAD_REGISTRY.register() 205 | class MaskRCNNConvUpsampleHeadwithCOCO(MaskRCNNConvUpsampleHead): 206 | @classmethod 207 | def from_config(cls, cfg, input_shape): 208 | ret = super().from_config(cfg, input_shape) 209 | conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM 210 | num_conv = cfg.MODEL.ROI_MASK_HEAD.NUM_CONV 211 | ret.update( 212 | conv_dims=[conv_dim] * (num_conv + 1), # +1 for ConvTranspose 213 | conv_norm=cfg.MODEL.ROI_MASK_HEAD.NORM, 214 | input_shape=input_shape, 215 | ) 216 | if cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK: 217 | ret["num_classes"] = 1 218 | else: 219 | ret["num_classes"] = cfg.MODEL.ROI_HEADS.MASK_NUM_CLASSES 220 | return ret 221 | 222 | def mask_rcnn_loss_with_empty_polygons(pred_mask_logits, instances, vis_period=0): 223 | cls_agnostic_mask = pred_mask_logits.size(1) == 1 224 | mask_side_len = pred_mask_logits.size(2) 225 | assert pred_mask_logits.size(2) == pred_mask_logits.size(3), "Mask prediction must be square!" 
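# Note: unlike detectron2's stock mask_rcnn_loss, the code below first filters each image's instances with
# gt_masks.nonempty(), so proposals whose ground-truth polygons are empty contribute nothing to the loss,
# and a zero loss (still attached to the graph) is returned when no valid masks remain.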
226 | 227 | gt_classes = [] 228 | gt_masks = [] 229 | gt_masks_nonzero = [] 230 | for instances_per_image in instances: 231 | if len(instances_per_image) == 0: 232 | continue 233 | current_gt_masks_nonzero = instances_per_image.gt_masks.nonempty() 234 | instances_per_image = instances_per_image[current_gt_masks_nonzero] 235 | if not cls_agnostic_mask: 236 | gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) 237 | gt_classes.append(gt_classes_per_image) 238 | gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize( 239 | instances_per_image.proposal_boxes.tensor, mask_side_len 240 | ).to(device=pred_mask_logits.device) 241 | # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len 242 | gt_masks.append(gt_masks_per_image) 243 | gt_masks_nonzero.append(current_gt_masks_nonzero) 244 | 245 | if len(gt_masks) == 0: 246 | return pred_mask_logits.sum() * 0 247 | 248 | gt_masks_nonzero = cat(gt_masks_nonzero, dim=0) 249 | gt_masks = cat(gt_masks, dim=0) 250 | pred_mask_logits = pred_mask_logits[gt_masks_nonzero] 251 | total_num_masks = pred_mask_logits.size(0) 252 | if cls_agnostic_mask: 253 | pred_mask_logits = pred_mask_logits[:, 0] 254 | else: 255 | indices = torch.arange(total_num_masks) 256 | gt_classes = cat(gt_classes, dim=0) 257 | pred_mask_logits = pred_mask_logits[indices, gt_classes] 258 | 259 | if gt_masks.dtype == torch.bool: 260 | gt_masks_bool = gt_masks 261 | else: 262 | # Here we allow gt_masks to be float as well (depend on the implementation of rasterize()) 263 | gt_masks_bool = gt_masks > 0.5 264 | gt_masks = gt_masks.to(dtype=torch.float32) 265 | 266 | # Log the training accuracy (using gt classes and 0.5 threshold) 267 | mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool 268 | mask_accuracy = 1 - (mask_incorrect.sum().item() / max(mask_incorrect.numel(), 1.0)) 269 | num_positive = gt_masks_bool.sum().item() 270 | false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max( 271 | gt_masks_bool.numel() - num_positive, 1.0 272 | ) 273 | false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(num_positive, 1.0) 274 | 275 | storage = get_event_storage() 276 | storage.put_scalar("mask_rcnn/accuracy", mask_accuracy) 277 | storage.put_scalar("mask_rcnn/false_positive", false_positive) 278 | storage.put_scalar("mask_rcnn/false_negative", false_negative) 279 | if vis_period > 0 and storage.iter % vis_period == 0: 280 | pred_masks = pred_mask_logits.sigmoid() 281 | vis_masks = torch.cat([pred_masks, gt_masks], axis=2) 282 | name = "Left: mask prediction; Right: mask GT" 283 | for idx, vis_mask in enumerate(vis_masks): 284 | vis_mask = torch.stack([vis_mask] * 3, axis=0) 285 | storage.put_image(name + f" ({idx})", vis_mask) 286 | 287 | mask_loss = F.binary_cross_entropy_with_logits(pred_mask_logits, gt_masks, reduction="mean") 288 | return mask_loss 289 | -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/__init__.py: -------------------------------------------------------------------------------- 1 | from .scenegraph_head import build_scenegraph_head 2 | from .defaults import add_scenegraph_config -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/box_feature_extractor.py: -------------------------------------------------------------------------------- 1 | import fvcore.nn.weight_init as weight_init 2 | import torch 3 | from torch import nn 4 | 
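# This module pools per-box features from the selected FPN levels and, when mask grounding is
# enabled, fuses the predicted per-class masks with the pooled features through a small conv
# "combiner" before two fully connected layers; the resulting vectors serve as node states for
# the scene-graph heads.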
from torch.nn import functional as F 5 | 6 | from detectron2.utils.registry import Registry 7 | from detectron2.modeling.poolers import ROIPooler 8 | from detectron2.layers import ShapeSpec 9 | from detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm 10 | ROI_BOX_FEATURE_EXTRACTORS_REGISTRY = Registry("ROI_BOX_FEATURE_EXTRACTORS_REGISTRY") 11 | 12 | @ROI_BOX_FEATURE_EXTRACTORS_REGISTRY.register() 13 | class BoxFeatureExtractor(nn.Module): 14 | """ 15 | Class to pool the features from different scales and flatten them using fully connected layers. 16 | These features will be used as node states for the scene graph. 17 | """ 18 | 19 | def __init__(self, cfg, input_shape): 20 | super(BoxFeatureExtractor, self).__init__() 21 | 22 | in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES 23 | pooler_resolution = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_RESOLUTION 24 | pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) 25 | sampling_ratio = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_SAMPLING_RATIO 26 | pooler_type = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_TYPE 27 | mask_on = cfg.MODEL.MASK_ON 28 | use_mask_in_box_features = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK 29 | pooler = ROIPooler( 30 | output_size=pooler_resolution, 31 | scales=pooler_scales, 32 | sampling_ratio=sampling_ratio, 33 | pooler_type=pooler_type 34 | ) 35 | in_channels = [input_shape[f].channels for f in in_features][0] 36 | 37 | input_size = in_channels * pooler_resolution ** 2 38 | self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES 39 | self.attention_type = cfg.MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE 40 | if mask_on and use_mask_in_box_features: 41 | # input_size = input_size * 2 42 | self.combined_mask_input = cfg.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.USE_MASK_COMBINER 43 | if not self.combined_mask_input: 44 | if self.attention_type == 'Diff_Channels': 45 | self.mask_combiner_box = nn.Conv2d(in_channels, in_channels - 10, kernel_size=3, padding=1) 46 | self.mask_combiner_mask = nn.Conv2d(self.num_classes, 10, kernel_size=3, padding=1) 47 | else: 48 | self.mask_combiner = nn.Conv2d(in_channels + self.num_classes, in_channels, kernel_size=3, padding=1) 49 | else: 50 | self.mask_feature_extractor = nn.Conv2d(1, in_channels, kernel_size=3, padding=1) 51 | self.mask_combiner = nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) 52 | representation_size = cfg.MODEL.ROI_BOX_HEAD.FC_DIM 53 | 54 | self.mask_on = mask_on 55 | self.use_mask_in_box_features = use_mask_in_box_features 56 | self.in_features = in_features 57 | # self.input_shape = shape 58 | self.pooler = pooler 59 | self.fc6 = make_fc(input_size, representation_size) 60 | 61 | out_dim = representation_size 62 | 63 | self.fc7 = make_fc(representation_size, out_dim) 64 | self.resize_channels = input_size 65 | self.out_channels = out_dim 66 | 67 | def forward(self, features, boxes, masks=None, logits=None, segmentation_step=False): 68 | features = [features[f] for f in self.in_features] 69 | box_features = self.pooler(features, boxes) 70 | if self.mask_on and (masks is not None) and self.use_mask_in_box_features: 71 | masks = torch.cat(masks) 72 | if self.attention_type == 'Zero': 73 | masks = torch.zeros_like(masks) 74 | if logits is not None: 75 | logits = torch.cat(logits) 76 | logits = logits.narrow(1, 0, masks.size(1)).unsqueeze(2).unsqueeze(3) 77 | masks = masks * logits 78 | 79 | if not self.combined_mask_input: 80 | if self.attention_type == 'Diff_Channels': 81 | 
box_features = self.mask_combiner_box(box_features) 82 | masks = self.mask_combiner_mask(masks) 83 | box_features = torch.cat([box_features, masks], 1) 84 | else: 85 | box_features = torch.cat([box_features, masks], 1) 86 | if box_features.size(0) > 500: 87 | # Do it in chunks 88 | box_features_chunks = torch.split(box_features, 100, dim=0) 89 | box_features_all = [] 90 | for idx, box_feature_chunk in enumerate(box_features_chunks): 91 | box_features_all.append(self.mask_combiner(box_feature_chunk)) 92 | box_features = torch.cat(box_features_all, dim=0) 93 | else: 94 | box_features = self.mask_combiner(box_features) 95 | else: 96 | mask_features = self.mask_feature_extractor(masks) 97 | box_features = torch.cat([box_features, mask_features], 1) 98 | if box_features.size(0) > 500: 99 | # Do it in chunks 100 | box_features_chunks = torch.split(box_features, 100, dim=0) 101 | box_features_all = [] 102 | for idx, box_feature_chunk in enumerate(box_features_chunks): 103 | box_features_all.append(self.mask_combiner(box_feature_chunk)) 104 | box_features = torch.cat(box_features_all, dim=0) 105 | else: 106 | box_features = self.mask_combiner(box_features) 107 | box_features = box_features.flatten(1) 108 | box_features = F.relu(self.fc6(box_features)) 109 | box_features = F.relu(self.fc7(box_features)) 110 | return box_features 111 | 112 | def forward_without_pool(self, x): 113 | x = x.view(x.size(0), -1) 114 | x = F.relu(self.fc6(x)) 115 | x = F.relu(self.fc7(x)) 116 | return x 117 | 118 | @ROI_BOX_FEATURE_EXTRACTORS_REGISTRY.register() 119 | class BoxFeatureSegmentationExtractor(nn.Module): 120 | """ 121 | Class to pool the features from different scales and flatten them using fully connected layers. 122 | These features will be used as node states for the scene graph. 
123 | """ 124 | 125 | def __init__(self, cfg, input_shape): 126 | super(BoxFeatureSegmentationExtractor, self).__init__() 127 | 128 | in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES 129 | pooler_resolution = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_RESOLUTION 130 | pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) 131 | sampling_ratio = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_SAMPLING_RATIO 132 | pooler_type = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_TYPE 133 | mask_on = cfg.MODEL.MASK_ON 134 | use_mask_in_box_features = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK 135 | pooler = ROIPooler( 136 | output_size=pooler_resolution, 137 | scales=pooler_scales, 138 | sampling_ratio=sampling_ratio, 139 | pooler_type=pooler_type 140 | ) 141 | in_channels = [input_shape[f].channels for f in in_features][0] 142 | input_size = in_channels * pooler_resolution ** 2 143 | self.num_classes = cfg.MODEL.ROI_HEADS.NUM_CLASSES 144 | self.mask_num_classes = cfg.MODEL.ROI_HEADS.MASK_NUM_CLASSES 145 | if mask_on and use_mask_in_box_features: 146 | # input_size = input_size * 2 147 | self.combined_mask_input = cfg.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.USE_MASK_COMBINER 148 | if not self.combined_mask_input: 149 | self.mask_combiner = nn.Conv2d(in_channels + self.num_classes, in_channels, kernel_size=3, padding=1) 150 | self.mask_combiner_segmentation = nn.Conv2d(in_channels + self.mask_num_classes, in_channels, kernel_size=3, padding=1) 151 | else: 152 | self.mask_feature_extractor = nn.Conv2d(1, in_channels, kernel_size=3, padding=1) 153 | self.mask_combiner = nn.Conv2d(in_channels*2, in_channels, kernel_size=3, padding=1) 154 | representation_size = cfg.MODEL.ROI_BOX_HEAD.FC_DIM 155 | self.attention_type = cfg.MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE 156 | self.mask_on = mask_on 157 | self.use_mask_in_box_features = use_mask_in_box_features 158 | self.in_features = in_features 159 | # self.input_shape = shape 160 | self.pooler = pooler 161 | self.fc6 = make_fc(input_size, representation_size) 162 | 163 | out_dim = representation_size 164 | 165 | self.fc7 = make_fc(representation_size, out_dim) 166 | self.resize_channels = input_size 167 | self.out_channels = out_dim 168 | 169 | def forward(self, features, boxes, masks=None, logits=None, segmentation_step=False): 170 | features = [features[f] for f in self.in_features] 171 | box_features = self.pooler(features, boxes) 172 | if self.mask_on and (masks is not None) and self.use_mask_in_box_features: 173 | masks = torch.cat(masks) 174 | if self.attention_type == 'Zero': 175 | masks = torch.zeros_like(masks) 176 | if logits is not None: 177 | logits = torch.cat(logits) 178 | logits = logits.narrow(1, 0, masks.size(1)).unsqueeze(2).unsqueeze(3) 179 | masks = masks * logits 180 | 181 | if not self.combined_mask_input: 182 | box_features = torch.cat([box_features, masks], 1) 183 | if box_features.size(0) > 500: 184 | # Do it in chunks 185 | box_features_chunks = torch.split(box_features, 100, dim=0) 186 | box_features_all = [] 187 | for idx, box_feature_chunk in enumerate(box_features_chunks): 188 | if not segmentation_step: 189 | box_features_all.append(self.mask_combiner(box_feature_chunk)) 190 | else: 191 | box_features_all.append(self.mask_combiner_segmentation(box_feature_chunk)) 192 | box_features = torch.cat(box_features_all, dim=0) 193 | else: 194 | if not segmentation_step: 195 | box_features = self.mask_combiner(box_features) 196 | else: 197 | box_features = 
self.mask_combiner_segmentation(box_features) 198 | else: 199 | mask_features = self.mask_feature_extractor(masks) 200 | box_features = torch.cat([box_features, mask_features], 1) 201 | if box_features.size(0) > 500: 202 | # Do it in chunks 203 | box_features_chunks = torch.split(box_features, 100, dim=0) 204 | box_features_all = [] 205 | for idx, box_feature_chunk in enumerate(box_features_chunks): 206 | box_features_all.append(self.mask_combiner(box_feature_chunk)) 207 | box_features = torch.cat(box_features_all, dim=0) 208 | else: 209 | box_features = self.mask_combiner(box_features) 210 | box_features = box_features.flatten(1) 211 | box_features = F.relu(self.fc6(box_features)) 212 | box_features = F.relu(self.fc7(box_features)) 213 | return box_features 214 | 215 | def forward_without_pool(self, x): 216 | x = x.view(x.size(0), -1) 217 | x = F.relu(self.fc6(x)) 218 | x = F.relu(self.fc7(x)) 219 | return x 220 | 221 | def make_fc(dim_in, hidden_dim): 222 | ''' 223 | Make Fully connected Layer with xavier initialization 224 | ''' 225 | 226 | fc = nn.Linear(dim_in, hidden_dim) 227 | # nn.init.kaiming_uniform_(fc.weight, a=1) 228 | # nn.init.constant_(fc.bias, 0) 229 | weight_init.c2_xavier_fill(fc) 230 | return fc 231 | 232 | def build_box_feature_extractor(cfg, in_channels): 233 | name = cfg.MODEL.ROI_BOX_FEATURE_EXTRACTORS.NAME 234 | return ROI_BOX_FEATURE_EXTRACTORS_REGISTRY.get(name)(cfg, in_channels) 235 | -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/defaults.py: -------------------------------------------------------------------------------- 1 | import os 2 | from detectron2.config import CfgNode as CN 3 | 4 | def add_scenegraph_config(cfg): 5 | _C = cfg 6 | 7 | _C.GLOVE_DIR = '../glove/' 8 | _C.DEV_RUN = False 9 | ################################################################################################### 10 | 11 | _C.MODEL.SCENEGRAPH_ON = True 12 | _C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = True 13 | _C.MODEL.USE_MASK_ON_NODE = False 14 | _C.MODEL.ROI_HEADS.OBJECTNESS_THRESH = 0.3 15 | _C.MODEL.GROUP_NORM = CN() 16 | _C.MODEL.GROUP_NORM.DIM_PER_GP = -1 17 | _C.MODEL.GROUP_NORM.NUM_GROUPS = 32 18 | _C.MODEL.GROUP_NORM.EPSILON = 1e-5 # default: 1e-5 19 | ################################################################################################### 20 | _C.MODEL.ROI_SCENEGRAPH_HEAD = CN() 21 | _C.MODEL.ROI_SCENEGRAPH_HEAD.NAME = 'SceneGraphHead' 22 | _C.MODEL.ROI_SCENEGRAPH_HEAD.MODE = 'predcls' 23 | _C.MODEL.ROI_SCENEGRAPH_HEAD.REQUIRE_BOX_OVERLAP = True 24 | _C.MODEL.ROI_SCENEGRAPH_HEAD.NUM_SAMPLE_PER_GT_REL = 4 # when sample fg relationship from gt, the max number of corresponding proposal pairs 25 | _C.MODEL.ROI_SCENEGRAPH_HEAD.BATCH_SIZE_PER_IMAGE = 64 26 | _C.MODEL.ROI_SCENEGRAPH_HEAD.POSITIVE_FRACTION = 0.25 27 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX = True 28 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL = True 29 | _C.MODEL.ROI_SCENEGRAPH_HEAD.NMS_FILTER_DUPLICATES = True 30 | 31 | _C.MODEL.ROI_SCENEGRAPH_HEAD.RETURN_SEG_MASKS = False 32 | _C.MODEL.ROI_SCENEGRAPH_HEAD.RETURN_SEG_ANNOS = False 33 | 34 | _C.MODEL.ROI_SCENEGRAPH_HEAD.PREDICT_USE_VISION = True 35 | _C.MODEL.ROI_SCENEGRAPH_HEAD.PREDICT_USE_BIAS = True 36 | 37 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_MASK_ATTENTION = True 38 | _C.MODEL.ROI_SCENEGRAPH_HEAD.MASK_ATTENTION_TYPE = 'Weighted' 39 | _C.MODEL.ROI_SCENEGRAPH_HEAD.SIGMOID_ATTENTION = True 40 | 41 | _C.MODEL.ROI_SCENEGRAPH_HEAD.PREDICTOR = "MotifPredictor" 42 
| _C.MODEL.ROI_SCENEGRAPH_HEAD.NUM_CLASSES = 50 43 | _C.MODEL.ROI_SCENEGRAPH_HEAD.EMBED_DIM = 200 44 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_DROPOUT_RATE = 0.2 45 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_HIDDEN_DIM = 512 46 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_POOLING_DIM = 4096 47 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_OBJ_LAYER = 1 # assert >= 1 48 | _C.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_REL_LAYER = 1 # assert >= 1 49 | _C.MODEL.ROI_SCENEGRAPH_HEAD.ADD_GTBOX_TO_PROPOSAL_IN_TRAIN = True 50 | _C.MODEL.ROI_SCENEGRAPH_HEAD.SEG_BBOX_LOSS_MULTIPLIER = 1.0 51 | _C.MODEL.ROI_SCENEGRAPH_HEAD.USE_ONLY_FG_PROPOSALS = True 52 | 53 | _C.MODEL.ROI_SCENEGRAPH_HEAD.LABEL_SMOOTHING_LOSS = False 54 | _C.MODEL.ROI_SCENEGRAPH_HEAD.REL_PROP = [0.01858, 0.00057, 0.00051, 0.00109, 0.00150, 0.00489, 0.00432, 0.02913, 0.00245, 0.00121, 55 | 0.00404, 0.00110, 0.00132, 0.00172, 0.00005, 0.00242, 0.00050, 0.00048, 0.00208, 0.15608, 56 | 0.02650, 0.06091, 0.00900, 0.00183, 0.00225, 0.00090, 0.00028, 0.00077, 0.04844, 0.08645, 57 | 0.31621, 0.00088, 0.00301, 0.00042, 0.00186, 0.00100, 0.00027, 0.01012, 0.00010, 0.01286, 58 | 0.00647, 0.00084, 0.01077, 0.00132, 0.00069, 0.00376, 0.00214, 0.11424, 0.01205, 0.02958] 59 | 60 | _C.MODEL.ROI_SCENEGRAPH_HEAD.ZERO_SHOT_TRIPLETS = '../evaluation/datasets/vg/zeroshot_triplet.pytorch' 61 | 62 | #TransformerContext 63 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER = CN() 64 | # for TransformerPredictor only 65 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.DROPOUT_RATE = 0.1 66 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.OBJ_LAYER = 4 67 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.REL_LAYER = 2 68 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.NUM_HEAD = 8 69 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.INNER_DIM = 2048 70 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.KEY_DIM = 64 71 | _C.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.VAL_DIM = 64 72 | ################################################################################################### 73 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS = CN() 74 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.NAME = 'BoxFeatureExtractor' 75 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_RESOLUTION = 28 76 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_SAMPLING_RATIO = 0 77 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.POOLER_TYPE = 'ROIAlignV2' 78 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.BOX_FEATURE_MASK = True 79 | _C.MODEL.ROI_BOX_FEATURE_EXTRACTORS.CLASS_LOGITS_WITH_MASK = False 80 | 81 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS = CN() 82 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.NAME = 'RelationFeatureExtractor' 83 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.USE_MASK_COMBINER = False 84 | _C.MODEL.ROI_RELATION_FEATURE_EXTRACTORS.MULTIPLY_LOGITS_WITH_MASKS = False 85 | 86 | # Overlap threshold for an RoI to be considered foreground (if >= FG_IOU_THRESHOLD) 87 | _C.MODEL.ROI_HEADS.FG_IOU_THRESHOLD = 0.5 88 | _C.MODEL.ROI_HEADS.REFINE_SEG_MASKS = False 89 | _C.MODEL.ROI_HEADS.SEGMENTATION_STEP_MASK_REFINE = True 90 | 91 | # Settings for relation testing 92 | _C.TEST.RELATION = CN() 93 | _C.TEST.RELATION.REQUIRE_OVERLAP = True 94 | _C.TEST.RELATION.LATER_NMS_PREDICTION_THRES = 0.3 95 | _C.TEST.RELATION.MULTIPLE_PREDS = False 96 | _C.TEST.RELATION.IOU_THRESHOLD = 0.5 97 | 98 | 99 | _C.DATASETS.VISUAL_GENOME.CLIPPED = False -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/imp/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_imp import IMPContext, 
IMPSegmentationContext -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/imp/make_layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | """ 3 | Miscellaneous utility functions 4 | """ 5 | 6 | import torch 7 | from torch import nn 8 | from torch.nn import functional as F 9 | 10 | 11 | def get_group_gn(dim, dim_per_gp, num_groups): 12 | """get number of groups used by GroupNorm, based on number of channels.""" 13 | assert dim_per_gp == -1 or num_groups == -1, \ 14 | "GroupNorm: can only specify G or C/G." 15 | 16 | if dim_per_gp > 0: 17 | assert dim % dim_per_gp == 0, \ 18 | "dim: {}, dim_per_gp: {}".format(dim, dim_per_gp) 19 | group_gn = dim // dim_per_gp 20 | else: 21 | assert dim % num_groups == 0, \ 22 | "dim: {}, num_groups: {}".format(dim, num_groups) 23 | group_gn = num_groups 24 | 25 | return group_gn 26 | 27 | 28 | def group_norm(cfg, out_channels, affine=True, divisor=1): 29 | out_channels = out_channels // divisor 30 | dim_per_gp = cfg.MODEL.GROUP_NORM.DIM_PER_GP // divisor 31 | num_groups = cfg.MODEL.GROUP_NORM.NUM_GROUPS // divisor 32 | eps = cfg.MODEL.GROUP_NORM.EPSILON # default: 1e-5 33 | return torch.nn.GroupNorm( 34 | get_group_gn(out_channels, dim_per_gp, num_groups), 35 | out_channels, 36 | eps, 37 | affine 38 | ) 39 | 40 | 41 | 42 | def make_fc(dim_in, hidden_dim, use_gn=False): 43 | ''' 44 | Caffe2 implementation uses XavierFill, which in fact 45 | corresponds to kaiming_uniform_ in PyTorch 46 | ''' 47 | if use_gn: 48 | fc = nn.Linear(dim_in, hidden_dim, bias=False) 49 | nn.init.kaiming_uniform_(fc.weight, a=1) 50 | return nn.Sequential(fc, group_norm(hidden_dim)) 51 | fc = nn.Linear(dim_in, hidden_dim) 52 | nn.init.kaiming_uniform_(fc.weight, a=1) 53 | nn.init.constant_(fc.bias, 0) 54 | return fc 55 | 56 | 57 | -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/imp/model_imp.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/rowanz/neural-motifs 2 | import numpy as np 3 | import torch 4 | from torch import nn 5 | from torch.nn.utils.rnn import PackedSequence 6 | from torch.nn import functional as F 7 | from .make_layers import make_fc 8 | from ..motif.utils_motifs import obj_edge_vectors, center_x, sort_by_score, to_onehot, get_dropout_mask, encode_box_info, cat 9 | 10 | 11 | class IMPContext(nn.Module): 12 | def __init__(self, config, num_obj, num_rel, in_channels, hidden_dim=512, num_iter=3): 13 | super(IMPContext, self).__init__() 14 | self.cfg = config 15 | self.num_obj = num_obj + 1 16 | self.num_rel = num_rel + 1 17 | self.pooling_dim = config.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_POOLING_DIM 18 | self.hidden_dim = hidden_dim 19 | self.num_iter = num_iter 20 | # mode 21 | if self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX: 22 | if self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL: 23 | self.mode = 'predcls' 24 | else: 25 | self.mode = 'sgcls' 26 | else: 27 | self.mode = 'sgdet' 28 | 29 | # self.rel_fc = make_fc(hidden_dim, self.num_rel) 30 | # self.obj_fc = make_fc(hidden_dim, self.num_obj) 31 | 32 | # self.obj_unary = make_fc(in_channels, hidden_dim) 33 | # self.edge_unary = make_fc(self.pooling_dim, hidden_dim) 34 | 35 | self.rel_fc = nn.Linear(hidden_dim, self.num_rel) 36 | self.obj_fc = nn.Linear(hidden_dim, 
self.num_obj) 37 | 38 | self.obj_unary = nn.Linear(in_channels, hidden_dim) 39 | self.edge_unary = nn.Linear(self.pooling_dim, hidden_dim) 40 | 41 | self.edge_gru = nn.GRUCell(input_size=hidden_dim, hidden_size=hidden_dim) 42 | self.node_gru = nn.GRUCell(input_size=hidden_dim, hidden_size=hidden_dim) 43 | 44 | self.sub_vert_w_fc = nn.Sequential(nn.Linear(hidden_dim*2, 1), nn.Sigmoid()) 45 | self.obj_vert_w_fc = nn.Sequential(nn.Linear(hidden_dim*2, 1), nn.Sigmoid()) 46 | self.out_edge_w_fc = nn.Sequential(nn.Linear(hidden_dim*2, 1), nn.Sigmoid()) 47 | self.in_edge_w_fc = nn.Sequential(nn.Linear(hidden_dim*2, 1), nn.Sigmoid()) 48 | 49 | 50 | def forward(self, x, proposals, union_features, rel_pair_idxs, logger=None): 51 | num_objs = [len(b) for b in proposals] 52 | 53 | obj_rep = self.obj_unary(x) 54 | rel_rep = F.relu(self.edge_unary(union_features)) 55 | 56 | obj_count = obj_rep.shape[0] 57 | rel_count = rel_rep.shape[0] 58 | 59 | # generate sub-rel-obj mapping 60 | sub2rel = torch.zeros(obj_count, rel_count).to(obj_rep.device).float() 61 | obj2rel = torch.zeros(obj_count, rel_count).to(obj_rep.device).float() 62 | obj_offset = 0 63 | rel_offset = 0 64 | sub_global_inds = [] 65 | obj_global_inds = [] 66 | for pair_idx, num_obj in zip(rel_pair_idxs, num_objs): 67 | num_rel = pair_idx.shape[0] 68 | sub_idx = pair_idx[:,0].contiguous().long().view(-1) + obj_offset 69 | obj_idx = pair_idx[:,1].contiguous().long().view(-1) + obj_offset 70 | rel_idx = torch.arange(num_rel).to(obj_rep.device).long().view(-1) + rel_offset 71 | 72 | sub_global_inds.append(sub_idx) 73 | obj_global_inds.append(obj_idx) 74 | 75 | sub2rel[sub_idx, rel_idx] = 1.0 76 | obj2rel[obj_idx, rel_idx] = 1.0 77 | 78 | obj_offset += num_obj 79 | rel_offset += num_rel 80 | 81 | sub_global_inds = torch.cat(sub_global_inds, dim=0) 82 | obj_global_inds = torch.cat(obj_global_inds, dim=0) 83 | 84 | # iterative message passing 85 | hx_obj = torch.zeros(obj_count, self.hidden_dim, requires_grad=False).to(obj_rep.device).float() 86 | hx_rel = torch.zeros(rel_count, self.hidden_dim, requires_grad=False).to(obj_rep.device).float() 87 | 88 | vert_factor = [self.node_gru(obj_rep, hx_obj)] 89 | edge_factor = [self.edge_gru(rel_rep, hx_rel)] 90 | 91 | for i in range(self.num_iter): 92 | # compute edge context 93 | sub_vert = vert_factor[i][sub_global_inds] 94 | obj_vert = vert_factor[i][obj_global_inds] 95 | weighted_sub = self.sub_vert_w_fc( 96 | torch.cat((sub_vert, edge_factor[i]), 1)) * sub_vert 97 | weighted_obj = self.obj_vert_w_fc( 98 | torch.cat((obj_vert, edge_factor[i]), 1)) * obj_vert 99 | 100 | edge_factor.append(self.edge_gru(weighted_sub + weighted_obj, edge_factor[i])) 101 | 102 | # Compute vertex context 103 | pre_out = self.out_edge_w_fc(torch.cat((sub_vert, edge_factor[i]), 1)) * edge_factor[i] 104 | pre_in = self.in_edge_w_fc(torch.cat((obj_vert, edge_factor[i]), 1)) * edge_factor[i] 105 | vert_ctx = sub2rel @ pre_out + obj2rel @ pre_in 106 | vert_factor.append(self.node_gru(vert_ctx, vert_factor[i])) 107 | 108 | if self.mode == 'predcls': 109 | obj_labels = cat([proposal.pred_classes for proposal in proposals], dim=0) 110 | obj_dists = to_onehot(obj_labels, self.num_obj) 111 | else: 112 | obj_dists = self.obj_fc(vert_factor[-1]) 113 | 114 | rel_dists = self.rel_fc(edge_factor[-1]) 115 | 116 | return obj_dists, rel_dists 117 | 118 | 119 | 120 | class IMPSegmentationContext(IMPContext): 121 | def __init__(self, config, num_obj, num_rel, in_channels, hidden_dim=512, num_iter=3, mask_obj_classes=None): 122 | 
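# mask_obj_classes is not used in this constructor; everything is forwarded unchanged to
# IMPContext (the extra argument presumably only keeps the call signature uniform across the
# segmentation-aware context models).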
super(IMPSegmentationContext, self).__init__(config, num_obj, num_rel, in_channels, hidden_dim=hidden_dim, num_iter=num_iter) 123 | 124 | def forward(self, x, proposals, union_features, rel_pair_idxs, logger=None, mask_box_features=None, masks=None, segmentation_step=False, return_masks=False): 125 | num_objs = [len(b) for b in proposals] 126 | 127 | obj_rep = self.obj_unary(x) 128 | obj_count = obj_rep.shape[0] 129 | if not segmentation_step: 130 | rel_rep = F.relu(self.edge_unary(union_features)) 131 | rel_count = rel_rep.shape[0] 132 | 133 | # generate sub-rel-obj mapping 134 | sub2rel = torch.zeros(obj_count, rel_count).to(obj_rep.device).float() 135 | obj2rel = torch.zeros(obj_count, rel_count).to(obj_rep.device).float() 136 | obj_offset = 0 137 | rel_offset = 0 138 | sub_global_inds = [] 139 | obj_global_inds = [] 140 | for pair_idx, num_obj in zip(rel_pair_idxs, num_objs): 141 | num_rel = pair_idx.shape[0] 142 | sub_idx = pair_idx[:,0].contiguous().long().view(-1) + obj_offset 143 | obj_idx = pair_idx[:,1].contiguous().long().view(-1) + obj_offset 144 | rel_idx = torch.arange(num_rel).to(obj_rep.device).long().view(-1) + rel_offset 145 | 146 | sub_global_inds.append(sub_idx) 147 | obj_global_inds.append(obj_idx) 148 | 149 | sub2rel[sub_idx, rel_idx] = 1.0 150 | obj2rel[obj_idx, rel_idx] = 1.0 151 | 152 | obj_offset += num_obj 153 | rel_offset += num_rel 154 | 155 | sub_global_inds = torch.cat(sub_global_inds, dim=0) 156 | obj_global_inds = torch.cat(obj_global_inds, dim=0) 157 | 158 | # iterative message passing 159 | hx_obj = torch.zeros(obj_count, self.hidden_dim, requires_grad=False).to(obj_rep.device).float() 160 | vert_factor = [self.node_gru(obj_rep, hx_obj)] 161 | if not segmentation_step: 162 | hx_rel = torch.zeros(rel_count, self.hidden_dim, requires_grad=False).to(obj_rep.device).float() 163 | edge_factor = [self.edge_gru(rel_rep, hx_rel)] 164 | 165 | for i in range(self.num_iter): 166 | # compute edge context 167 | sub_vert = vert_factor[i][sub_global_inds] 168 | obj_vert = vert_factor[i][obj_global_inds] 169 | weighted_sub = self.sub_vert_w_fc( 170 | torch.cat((sub_vert, edge_factor[i]), 1)) * sub_vert 171 | weighted_obj = self.obj_vert_w_fc( 172 | torch.cat((obj_vert, edge_factor[i]), 1)) * obj_vert 173 | 174 | edge_factor.append(self.edge_gru(weighted_sub + weighted_obj, edge_factor[i])) 175 | 176 | # Compute vertex context 177 | pre_out = self.out_edge_w_fc(torch.cat((sub_vert, edge_factor[i]), 1)) * edge_factor[i] 178 | pre_in = self.in_edge_w_fc(torch.cat((obj_vert, edge_factor[i]), 1)) * edge_factor[i] 179 | vert_ctx = sub2rel @ pre_out + obj2rel @ pre_in 180 | vert_factor.append(self.node_gru(vert_ctx, vert_factor[i])) 181 | 182 | if self.mode == 'predcls': 183 | obj_labels = cat([proposal.pred_classes for proposal in proposals], dim=0) 184 | obj_dists = to_onehot(obj_labels, self.num_obj) 185 | else: 186 | obj_dists = self.obj_fc(vert_factor[-1]) 187 | 188 | rel_dists = self.rel_fc(edge_factor[-1]) 189 | 190 | return obj_dists, rel_dists -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from detectron2.structures.instances import Instances 6 | from detectron2.structures.boxes import Boxes 7 | from .utils import obj_prediction_nms 8 | 9 | class PostProcessor(nn.Module): 10 | 
""" 11 | From a set of classification scores, box regression and proposals, 12 | computes the post-processed boxes, and applies NMS to obtain the 13 | final results 14 | """ 15 | 16 | def __init__( 17 | self, 18 | use_gt_box=False, 19 | later_nms_pred_thres=0.3, 20 | ): 21 | """ 22 | Arguments: 23 | """ 24 | super(PostProcessor, self).__init__() 25 | self.use_gt_box = use_gt_box 26 | self.later_nms_pred_thres = later_nms_pred_thres 27 | 28 | def forward(self, x, rel_pair_idxs, boxes, img_sizes, segmentation_vis=False): 29 | """ 30 | Arguments: 31 | x (tuple[tensor, tensor]): x contains the relation logits 32 | and finetuned object logits from the relation model. 33 | rel_pair_idxs (list[tensor]): subject and object indice of each relation, 34 | the size of tensor is (num_rel, 2) 35 | boxes (list[BoxList]): bounding boxes that are used as 36 | reference, one for ech image 37 | Returns: 38 | results (list[BoxList]): one BoxList for each image, containing 39 | the extra fields labels and scores 40 | """ 41 | relation_logits, refine_logits = x 42 | finetune_obj_logits = refine_logits 43 | 44 | results = [] 45 | for i, (rel_logit, obj_logit, rel_pair_idx, box, img_size) in enumerate(zip( 46 | relation_logits, finetune_obj_logits, rel_pair_idxs, boxes, img_sizes 47 | )): 48 | 49 | obj_class_prob = F.softmax(obj_logit, -1) 50 | obj_class_prob[:, -1] = 0 # set background score to 0 51 | num_obj_bbox = obj_class_prob.shape[0] 52 | num_obj_class = obj_class_prob.shape[1] 53 | 54 | if self.use_gt_box: 55 | obj_scores, obj_pred = obj_class_prob[:, :-1].max(dim=1) 56 | else: 57 | # apply late nms for object prediction 58 | obj_pred = obj_prediction_nms(box.boxes_per_cls, obj_logit, self.later_nms_pred_thres) 59 | obj_score_ind = torch.arange(num_obj_bbox, device=obj_logit.device) * num_obj_class + obj_pred 60 | obj_scores = obj_class_prob.view(-1)[obj_score_ind] 61 | 62 | assert obj_scores.shape[0] == num_obj_bbox 63 | obj_class = obj_pred 64 | 65 | result = Instances(img_size) 66 | 67 | if self.use_gt_box: 68 | result.pred_boxes = box 69 | else: 70 | # mode==sgdet 71 | # apply regression based on finetuned object class 72 | #FIXME 73 | device = obj_class.device 74 | batch_size = obj_class.shape[0] 75 | regressed_box_idxs = obj_class 76 | result.pred_boxes = Boxes(box.boxes_per_cls[torch.arange(batch_size, device=device), regressed_box_idxs]) 77 | 78 | result.pred_classes = obj_class 79 | result.scores = obj_scores 80 | 81 | # sorting triples according to score production 82 | obj_scores0 = obj_scores[rel_pair_idx[:, 0]] 83 | obj_scores1 = obj_scores[rel_pair_idx[:, 1]] 84 | rel_class_prob = F.softmax(rel_logit, -1) 85 | rel_scores, rel_class = rel_class_prob[:, :-1].max(dim=1) 86 | 87 | triple_scores = rel_scores * obj_scores0 * obj_scores1 88 | _, sorting_idx = torch.sort(triple_scores.view(-1), dim=0, descending=True) 89 | rel_pair_idx = rel_pair_idx[sorting_idx] 90 | rel_class_prob = rel_class_prob[sorting_idx] 91 | rel_labels = rel_class[sorting_idx] 92 | 93 | result._rel_pair_idxs = rel_pair_idx # (#rel, 2) 94 | result._pred_rel_scores = rel_class_prob # (#rel, #rel_class) 95 | result._pred_rel_labels = rel_labels # (#rel, ) 96 | if segmentation_vis: 97 | result._sorting_idx = sorting_idx 98 | # should have fields : rel_pair_idxs, pred_rel_class_prob, pred_rel_labels, pred_labels, pred_scores 99 | results.append(result) 100 | return results 101 | 102 | 103 | def build_roi_scenegraph_post_processor(cfg): 104 | 105 | use_gt_box = cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX 106 | later_nms_pred_thres = 
cfg.TEST.RELATION.LATER_NMS_PREDICTION_THRES 107 | 108 | postprocessor = PostProcessor( 109 | use_gt_box, 110 | later_nms_pred_thres, 111 | ) 112 | return postprocessor -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | import numpy.random as npr 6 | 7 | from fvcore.nn import smooth_l1_loss 8 | from .motif import cat 9 | 10 | class RelationLossComputation(object): 11 | """ 12 | Computes the loss for relation triplet. 13 | Also supports FPN 14 | """ 15 | 16 | def __init__( 17 | self, 18 | use_label_smoothing, 19 | predicate_proportion, 20 | ): 21 | """ 22 | Arguments: 23 | bbox_proposal_matcher (Matcher) 24 | rel_fg_bg_sampler (RelationPositiveNegativeSampler) 25 | """ 26 | 27 | self.use_label_smoothing = use_label_smoothing 28 | self.pred_weight = (1.0 / torch.FloatTensor([0.5,] + predicate_proportion)).cuda() 29 | 30 | if self.use_label_smoothing: 31 | self.criterion_loss = Label_Smoothing_Regression(e=0.01) 32 | else: 33 | self.criterion_loss = nn.CrossEntropyLoss() 34 | 35 | 36 | def __call__(self, proposals, rel_labels, relation_logits, refine_logits): 37 | """ 38 | Computes the loss for relation triplet. 39 | This requires that the subsample method has been called beforehand. 40 | Arguments: 41 | relation_logits (list[Tensor]) 42 | refine_obj_logits (list[Tensor]) 43 | Returns: 44 | predicate_loss (Tensor) 45 | finetune_obj_loss (Tensor) 46 | """ 47 | refine_obj_logits = refine_logits 48 | relation_logits = cat(relation_logits, dim=0) 49 | refine_obj_logits = cat(refine_obj_logits, dim=0) 50 | 51 | fg_labels = cat([proposal.gt_classes for proposal in proposals], dim=0) 52 | rel_labels = cat(rel_labels, dim=0) 53 | 54 | loss_relation = self.criterion_loss(relation_logits, rel_labels.long()) 55 | loss_refine_obj = self.criterion_loss(refine_obj_logits, fg_labels.long()) 56 | 57 | 58 | return loss_relation, loss_refine_obj 59 | 60 | 61 | class FocalLoss(nn.Module): 62 | def __init__(self, gamma=0, alpha=None, size_average=True): 63 | super(FocalLoss, self).__init__() 64 | self.gamma = gamma 65 | self.alpha = alpha 66 | self.size_average = size_average 67 | 68 | def forward(self, input, target): 69 | target = target.view(-1) 70 | 71 | logpt = F.log_softmax(input) 72 | logpt = logpt.index_select(-1, target).diag() 73 | logpt = logpt.view(-1) 74 | pt = logpt.exp() 75 | 76 | logpt = logpt * self.alpha * (target > 0).float() + logpt * (1 - self.alpha) * (target <= 0).float() 77 | 78 | loss = -1 * (1-pt)**self.gamma * logpt 79 | if self.size_average: return loss.mean() 80 | else: return loss.sum() 81 | 82 | import torch 83 | import torch.nn as nn 84 | 85 | 86 | class Label_Smoothing_Regression(nn.Module): 87 | 88 | def __init__(self, e=0.01, reduction='mean'): 89 | super().__init__() 90 | 91 | self.log_softmax = nn.LogSoftmax(dim=1) 92 | self.e = e 93 | self.reduction = reduction 94 | 95 | def _one_hot(self, labels, classes, value=1): 96 | """ 97 | Convert labels to one hot vectors 98 | 99 | Args: 100 | labels: torch tensor in format [label1, label2, label3, ...] 
101 | classes: int, number of classes 102 | value: label value in one hot vector, default to 1 103 | 104 | Returns: 105 | return one hot format labels in shape [batchsize, classes] 106 | """ 107 | 108 | one_hot = torch.zeros(labels.size(0), classes) 109 | 110 | #labels and value_added size must match 111 | labels = labels.view(labels.size(0), -1) 112 | value_added = torch.Tensor(labels.size(0), 1).fill_(value) 113 | 114 | value_added = value_added.to(labels.device) 115 | one_hot = one_hot.to(labels.device) 116 | 117 | one_hot.scatter_add_(1, labels, value_added) 118 | 119 | return one_hot 120 | 121 | def _smooth_label(self, target, length, smooth_factor): 122 | """convert targets to one-hot format, and smooth 123 | them. 124 | Args: 125 | target: target in form with [label1, label2, label_batchsize] 126 | length: length of one-hot format(number of classes) 127 | smooth_factor: smooth factor for label smooth 128 | 129 | Returns: 130 | smoothed labels in one hot format 131 | """ 132 | one_hot = self._one_hot(target, length, value=1 - smooth_factor) 133 | one_hot += smooth_factor / length 134 | 135 | return one_hot.to(target.device) 136 | 137 | def forward(self, x, target): 138 | 139 | if x.size(0) != target.size(0): 140 | raise ValueError('Expected input batchsize ({}) to match target batch_size({})' 141 | .format(x.size(0), target.size(0))) 142 | 143 | if x.dim() < 2: 144 | raise ValueError('Expected input tensor to have least 2 dimensions(got {})' 145 | .format(x.size(0))) 146 | 147 | if x.dim() != 2: 148 | raise ValueError('Only 2 dimension tensor are implemented, (got {})' 149 | .format(x.size())) 150 | 151 | 152 | smoothed_target = self._smooth_label(target, x.size(1), self.e) 153 | x = self.log_softmax(x) 154 | loss = torch.sum(- x * smoothed_target, dim=1) 155 | 156 | if self.reduction == 'none': 157 | return loss 158 | 159 | elif self.reduction == 'sum': 160 | return torch.sum(loss) 161 | 162 | elif self.reduction == 'mean': 163 | return torch.mean(loss) 164 | 165 | else: 166 | raise ValueError('unrecognized option, expect reduction to be one of none, mean, sum') 167 | 168 | def build_roi_scenegraph_loss_evaluator(cfg): 169 | 170 | loss_evaluator = RelationLossComputation( 171 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.LABEL_SMOOTHING_LOSS, 172 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.REL_PROP, 173 | ) 174 | 175 | return loss_evaluator -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/motif/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_motifs import LSTMContext, FrequencyBias, LSTMContextSegmentation, LSTMContextSegmentationC, LSTMContextSegmentationClsLoss, LSTMContextSegmentationNoLSTMEnc, LSTMContextSegmentationNoLSTMSeg 2 | from .utils_motifs import cat -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/motif/utils_motifs.py: -------------------------------------------------------------------------------- 1 | import array 2 | import os 3 | import zipfile 4 | import itertools 5 | import six 6 | import torch 7 | import numpy as np 8 | from six.moves.urllib.request import urlretrieve 9 | from tqdm import tqdm 10 | import sys 11 | import logging 12 | 13 | def cat(tensors, dim=0): 14 | """ 15 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 16 | """ 17 | assert isinstance(tensors, (list, tuple)) 18 | if len(tensors) == 1: 19 | return 
tensors[0] 20 | return torch.cat(tensors, dim) 21 | 22 | def normalize_sigmoid_logits(orig_logits): 23 | orig_logits = torch.sigmoid(orig_logits) 24 | orig_logits = orig_logits / (orig_logits.sum(1).unsqueeze(-1) + 1e-12) 25 | return orig_logits 26 | 27 | def generate_attributes_target(attributes, device, max_num_attri, num_attri_cat): 28 | """ 29 | from list of attribute indexs to [1,0,1,0,0,1] form 30 | """ 31 | assert max_num_attri == attributes.shape[1] 32 | num_obj = attributes.shape[0] 33 | 34 | with_attri_idx = (attributes.sum(-1) > 0).long() 35 | attribute_targets = torch.zeros((num_obj, num_attri_cat), device=device).float() 36 | 37 | for idx in torch.nonzero(with_attri_idx, as_tuple=False).squeeze(1).tolist(): 38 | for k in range(max_num_attri): 39 | att_id = int(attributes[idx, k]) 40 | if att_id == 0: 41 | break 42 | else: 43 | attribute_targets[idx, att_id] = 1 44 | return attribute_targets, with_attri_idx 45 | 46 | def transpose_packed_sequence_inds(lengths): 47 | """ 48 | Get a TxB indices from sorted lengths. 49 | Fetch new_inds, split by new_lens, padding to max(new_lens), and stack. 50 | Returns: 51 | new_inds (np.array) [sum(lengths), ] 52 | new_lens (list(np.array)): number of elements of each time step, descending 53 | """ 54 | new_inds = [] 55 | new_lens = [] 56 | cum_add = np.cumsum([0] + lengths) 57 | max_len = lengths[0] 58 | length_pointer = len(lengths) - 1 59 | for i in range(max_len): 60 | while length_pointer > 0 and lengths[length_pointer] <= i: 61 | length_pointer -= 1 62 | new_inds.append(cum_add[:(length_pointer+1)].copy()) 63 | cum_add[:(length_pointer+1)] += 1 64 | new_lens.append(length_pointer+1) 65 | new_inds = np.concatenate(new_inds, 0) 66 | return new_inds, new_lens 67 | 68 | 69 | def sort_by_score(boxes, scores): 70 | """ 71 | We'll sort everything scorewise from Hi->low, BUT we need to keep images together 72 | and sort LSTM from l 73 | :param im_inds: Which im we're on 74 | :param scores: Goodness ranging between [0, 1]. Higher numbers come FIRST 75 | :return: Permutation to put everything in the right order for the LSTM 76 | Inverse permutation 77 | Lengths for the TxB packed sequence. 78 | """ 79 | num_rois = [len(b) for b in boxes] 80 | num_im = len(num_rois) 81 | 82 | scores = scores.split(num_rois, dim=0) 83 | ordered_scores = [] 84 | for i, (score, num_roi) in enumerate(zip(scores, num_rois)): 85 | ordered_scores.append( score + 2.0 * float(num_roi * 2 * num_im + i) ) 86 | ordered_scores = cat(ordered_scores, dim=0) 87 | _, perm = torch.sort(ordered_scores, 0, descending=True) 88 | 89 | num_rois = sorted(num_rois, reverse=True) 90 | inds, ls_transposed = transpose_packed_sequence_inds(num_rois) # move it to TxB form 91 | inds = torch.LongTensor(inds).to(scores[0].device) 92 | ls_transposed = torch.LongTensor(ls_transposed) 93 | 94 | perm = perm[inds] # (batch_num_box, ) 95 | _, inv_perm = torch.sort(perm) 96 | 97 | return perm, inv_perm, ls_transposed 98 | 99 | 100 | def to_onehot(vec, num_classes, fill=1000): 101 | """ 102 | Creates a [size, num_classes] torch FloatTensor where 103 | one_hot[i, vec[i]] = fill 104 | 105 | :param vec: 1d torch tensor 106 | :param num_classes: int 107 | :param fill: value that we want + and - things to be. 
108 | :return: 109 | """ 110 | onehot_result = vec.new(vec.size(0), num_classes).float().fill_(-fill) 111 | arange_inds = vec.new(vec.size(0)).long() 112 | torch.arange(0, vec.size(0), out=arange_inds) 113 | 114 | onehot_result.view(-1)[vec.long() + num_classes*arange_inds] = fill 115 | return onehot_result 116 | 117 | def get_dropout_mask(dropout_probability, tensor_shape, device): 118 | """ 119 | once get, it is fixed all the time 120 | """ 121 | binary_mask = (torch.rand(tensor_shape) > dropout_probability) 122 | # Scale mask by 1/keep_prob to preserve output statistics. 123 | dropout_mask = binary_mask.float().to(device).div(1.0 - dropout_probability) 124 | return dropout_mask 125 | 126 | def center_x(boxes): 127 | 128 | boxes = cat([p.tensor for p in boxes], dim=0) 129 | c_x = 0.5 * (boxes[:, 0] + boxes[:, 2]) 130 | return c_x.view(-1) 131 | 132 | def encode_box_info(boxes_list, img_sizes): 133 | """ 134 | encode proposed box information (x1, y1, x2, y2) to 135 | (cx/wid, cy/hei, w/wid, h/hei, x1/wid, y1/hei, x2/wid, y2/hei, wh/wid*hei) 136 | """ 137 | 138 | boxes_info = [] 139 | for (boxes, img_size) in zip(boxes_list, img_sizes): 140 | 141 | wid = img_size[0] 142 | hei = img_size[1] 143 | wh = boxes.tensor[:, 2:] - boxes.tensor[:, :2] + 1.0 144 | xy = boxes.tensor[:, :2] + 0.5 * wh 145 | w, h = wh.split([1,1], dim=-1) 146 | x, y = xy.split([1,1], dim=-1) 147 | x1, y1, x2, y2 = boxes.tensor.split([1,1,1,1], dim=-1) 148 | assert wid * hei != 0 149 | info = torch.cat([w/wid, h/hei, x/wid, y/hei, x1/wid, y1/hei, x2/wid, y2/hei, 150 | w*h/(wid*hei)], dim=-1).view(-1, 9) 151 | boxes_info.append(info) 152 | 153 | return torch.cat(boxes_info, dim=0) 154 | 155 | 156 | def obj_edge_vectors(names, wv_dir, wv_type='glove.6B', wv_dim=300): 157 | wv_dict, wv_arr, wv_size = load_word_vectors(wv_dir, wv_type, wv_dim) 158 | vectors = torch.Tensor(len(names), wv_dim) 159 | vectors.normal_(0,1) 160 | logger = logging.getLogger(__name__) 161 | for i, token in enumerate(names): 162 | wv_index = wv_dict.get(token, None) 163 | if wv_index is not None: 164 | vectors[i] = wv_arr[wv_index] 165 | else: 166 | # Try the longest word 167 | lw_token = sorted(token.split(' '), key=lambda x: len(x), reverse=True)[0] 168 | logger.info("{} -> {} ".format(token, lw_token)) 169 | wv_index = wv_dict.get(lw_token, None) 170 | if wv_index is not None: 171 | vectors[i] = wv_arr[wv_index] 172 | else: 173 | logger.warning("fail on {}".format(token)) 174 | 175 | return vectors 176 | 177 | def obj_edge_vectors_segmentation(names, wv_dir, wv_type='glove.6B', wv_dim=300): 178 | wv_dict, wv_arr, wv_size = load_word_vectors(wv_dir, wv_type, wv_dim) 179 | 180 | vectors = torch.Tensor(len(names), wv_dim) 181 | vectors.normal_(0,1) 182 | logger = logging.getLogger(__name__) 183 | for i, token in enumerate(names): 184 | wv_index = wv_dict.get(token, None) 185 | if wv_index is not None: 186 | vectors[i] = wv_arr[wv_index] 187 | else: 188 | # Try the longest word 189 | lw_token = sorted(token.split(' '), key=lambda x: len(x), reverse=True)[0] 190 | logger.info("{} -> {} ".format(token, lw_token)) 191 | wv_index = wv_dict.get(lw_token, None) 192 | if wv_index is not None: 193 | vectors[i] = wv_arr[wv_index] 194 | else: 195 | logger.warning("fail on {}".format(token)) 196 | 197 | return vectors 198 | 199 | def load_word_vectors(root, wv_type, dim): 200 | """Load word vectors from a path, trying .pt, .txt, and .zip extensions.""" 201 | URL = { 202 | 'glove.42B': 'http://nlp.stanford.edu/data/glove.42B.300d.zip', 203 | 'glove.840B': 
'http://nlp.stanford.edu/data/glove.840B.300d.zip', 204 | 'glove.twitter.27B': 'http://nlp.stanford.edu/data/glove.twitter.27B.zip', 205 | 'glove.6B': 'http://nlp.stanford.edu/data/glove.6B.zip', 206 | } 207 | if isinstance(dim, int): 208 | dim = str(dim) + 'd' 209 | fname = os.path.join(root, wv_type + '.' + dim) 210 | 211 | if os.path.isfile(fname + '.pt'): 212 | fname_pt = fname + '.pt' 213 | print('loading word vectors from', fname_pt) 214 | try: 215 | return torch.load(fname_pt, map_location=torch.device("cpu")) 216 | except Exception as e: 217 | print("Error loading the model from {}{}".format(fname_pt, str(e))) 218 | sys.exit(-1) 219 | if os.path.isfile(fname + '.txt'): 220 | fname_txt = fname + '.txt' 221 | cm = open(fname_txt, 'rb') 222 | cm = [line for line in cm] 223 | elif os.path.basename(wv_type) in URL: 224 | url = URL[wv_type] 225 | print('downloading word vectors from {}'.format(url)) 226 | filename = os.path.basename(fname) 227 | if not os.path.exists(root): 228 | os.makedirs(root) 229 | with tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t: 230 | fname, _ = urlretrieve(url, fname, reporthook=reporthook(t)) 231 | with zipfile.ZipFile(fname, "r") as zf: 232 | print('extracting word vectors into {}'.format(root)) 233 | zf.extractall(root) 234 | if not os.path.isfile(fname + '.txt'): 235 | raise RuntimeError('no word vectors of requested dimension found') 236 | return load_word_vectors(root, wv_type, dim) 237 | else: 238 | raise RuntimeError('unable to load word vectors') 239 | 240 | wv_tokens, wv_arr, wv_size = [], array.array('d'), None 241 | if cm is not None: 242 | for line in tqdm(range(len(cm)), desc="loading word vectors from {}".format(fname_txt)): 243 | entries = cm[line].strip().split(b' ') 244 | word, entries = entries[0], entries[1:] 245 | if wv_size is None: 246 | wv_size = len(entries) 247 | try: 248 | if isinstance(word, six.binary_type): 249 | word = word.decode('utf-8') 250 | except: 251 | print('non-UTF8 token', repr(word), 'ignored') 252 | continue 253 | wv_arr.extend(float(x) for x in entries) 254 | wv_tokens.append(word) 255 | 256 | wv_dict = {word: i for i, word in enumerate(wv_tokens)} 257 | wv_arr = torch.Tensor(wv_arr).view(-1, wv_size) 258 | ret = (wv_dict, wv_arr, wv_size) 259 | torch.save(ret, fname + '.pt') 260 | return ret 261 | 262 | def reporthook(t): 263 | """https://github.com/tqdm/tqdm""" 264 | last_b = [0] 265 | def inner(b=1, bsize=1, tsize=None): 266 | if tsize is not None: 267 | t.total = tsize 268 | t.update((b - last_b[0]) * bsize) 269 | last_b[0] = b 270 | return inner -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/sampling.py: -------------------------------------------------------------------------------- 1 | #Modified from https://github.com/KaihuaTang/Scene-Graph-Benchmark.pytorch/blob/master/maskrcnn_benchmark/modeling/roi_heads/relation_head/sampling.py 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | import numpy.random as npr 7 | 8 | from detectron2.structures.boxes import pairwise_iou 9 | 10 | class RelationSampling(object): 11 | #sample relation pair proposals from given sets of bounding boxes 12 | def __init__( 13 | self, 14 | fg_thres, 15 | require_overlap, 16 | num_sample_per_gt_rel, 17 | batch_size_per_image, 18 | positive_fraction, 19 | use_gt_box, 20 | num_rel_classes, 21 | test_overlap, 22 | ): 23 | 24 | self.fg_thres = fg_thres 25 | self.require_overlap = require_overlap 26 | 
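# Sampling budget: at most num_sample_per_gt_rel proposal pairs are kept for each ground-truth
# relation, and batch_size_per_image relation pairs are sampled per image with the given
# positive_fraction of foreground pairs.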
self.num_sample_per_gt_rel = num_sample_per_gt_rel 27 | self.batch_size_per_image = batch_size_per_image 28 | self.positive_fraction = positive_fraction 29 | self.use_gt_box = use_gt_box 30 | self.num_rel_classes = num_rel_classes 31 | self.test_overlap = test_overlap 32 | 33 | 34 | def prepare_test_pairs(self, device, proposals): 35 | # prepare object pairs for relation prediction 36 | rel_pair_idxs = [] 37 | for p in proposals: 38 | n = len(p) 39 | cand_matrix = torch.ones((n, n), device=device) - torch.eye(n, device=device) 40 | # mode==sgdet and require_overlap 41 | if (not self.use_gt_box) and self.test_overlap: 42 | cand_matrix = cand_matrix.byte() & pairwise_iou(p.pred_boxes, p.pred_boxes).gt(0).byte() 43 | 44 | idxs = torch.nonzero(cand_matrix, as_tuple=False).view(-1,2) 45 | del cand_matrix 46 | if len(idxs) > 0: 47 | rel_pair_idxs.append(idxs) 48 | else: 49 | # if there is no candidate pairs, give a placeholder of [[0, 0]] 50 | rel_pair_idxs.append(torch.zeros((1, 2), dtype=torch.int64, device=device)) 51 | return rel_pair_idxs 52 | 53 | def gtbox_relsample(self, boxes, targets, relations): 54 | assert self.use_gt_box 55 | num_pos_per_img = int(self.batch_size_per_image * self.positive_fraction) 56 | rel_idx_pairs = [] 57 | rel_labels = [] 58 | rel_sym_binarys = [] 59 | 60 | for img_id, (box, target, relation) in enumerate(zip(boxes, targets, relations)): 61 | device = box.device 62 | num_prp = box.tensor.shape[0] 63 | 64 | assert box.tensor.shape[0] == target.gt_boxes.tensor.shape[0] 65 | tgt_relations = relation # [tgt, tgt] 66 | # tgt_pair_idxs = torch.nonzero(tgt_rel_matrix > 0) 67 | tgt_pair_idxs = tgt_relations[:, :2] 68 | assert tgt_pair_idxs.shape[1] == 2 69 | tgt_head_idxs = tgt_pair_idxs[:, 0].contiguous().view(-1) 70 | tgt_tail_idxs = tgt_pair_idxs[:, 1].contiguous().view(-1) 71 | tgt_rel_labs = tgt_relations[:,2].contiguous().view(-1) 72 | 73 | # sym_binary_rels 74 | binary_rel = torch.zeros((num_prp, num_prp), device=device).long() 75 | binary_rel[tgt_head_idxs, tgt_tail_idxs] = 1 76 | binary_rel[tgt_tail_idxs, tgt_head_idxs] = 1 77 | rel_sym_binarys.append(binary_rel) 78 | 79 | rel_possibility = torch.ones((num_prp, num_prp), device=device).long() - torch.eye(num_prp, device=device).long() 80 | rel_possibility[tgt_head_idxs, tgt_tail_idxs] = 0 81 | tgt_bg_idxs = torch.nonzero(rel_possibility > 0, as_tuple=False) 82 | 83 | # generate fg bg rel_pairs 84 | if tgt_pair_idxs.shape[0] > num_pos_per_img: 85 | perm = torch.randperm(tgt_pair_idxs.shape[0], device=device)[:num_pos_per_img] 86 | tgt_pair_idxs = tgt_pair_idxs[perm] 87 | tgt_rel_labs = tgt_rel_labs[perm] 88 | num_fg = min(tgt_pair_idxs.shape[0], num_pos_per_img) 89 | 90 | num_bg = self.batch_size_per_image - num_fg 91 | perm = torch.randperm(tgt_bg_idxs.shape[0], device=device)[:num_bg] 92 | tgt_bg_idxs = tgt_bg_idxs[perm] 93 | 94 | img_rel_idxs = torch.cat((tgt_pair_idxs, tgt_bg_idxs), dim=0) 95 | img_rel_labels = torch.cat((tgt_rel_labs.long(), torch.full((tgt_bg_idxs.shape[0],), fill_value=self.num_rel_classes, device=device, dtype=torch.long)), dim=0).contiguous().view(-1) 96 | 97 | rel_idx_pairs.append(img_rel_idxs) 98 | rel_labels.append(img_rel_labels) 99 | 100 | return boxes, rel_labels, rel_idx_pairs, rel_sym_binarys 101 | 102 | def detect_relsample(self, proposals, targets, relations): 103 | # corresponding to rel_assignments function in neural-motifs 104 | """ 105 | The input proposals are already processed by subsample function of box_head, 106 | in this function, we should only care about fg box, 
and sample corresponding fg/bg relations 107 | Note: this function keeps a state. 108 | Arguments: 109 | boxes (list[BoxList]) contain fields: labels, predict_logits 110 | targets (list[BoxList]) contain fields: labels 111 | """ 112 | self.num_pos_per_img = int(self.batch_size_per_image * self.positive_fraction) 113 | rel_idx_pairs = [] 114 | rel_labels = [] 115 | rel_sym_binarys = [] 116 | for img_id, (proposal, target, relation) in enumerate(zip(proposals, targets, relations)): 117 | device = proposal.pred_boxes.device 118 | prp_box = proposal.pred_boxes 119 | prp_lab = proposal.pred_classes.long() 120 | tgt_box = target.gt_boxes 121 | tgt_lab = target.gt_classes.long() 122 | tgt_rel_matrix = torch.zeros(tgt_lab.shape[0],tgt_lab.shape[0]).long().to(device) # [tgt, tgt] 123 | tgt_rel_matrix[relation[:,0], relation[:,1]] = relation[:,2] 124 | # IoU matching 125 | ious = pairwise_iou(tgt_box, prp_box) # [tgt, prp] 126 | 127 | is_match = (tgt_lab[:,None] == prp_lab[None]) & (ious > self.fg_thres) # [tgt, prp] 128 | # Proposal self IoU to filter non-overlap 129 | prp_self_iou = pairwise_iou(prp_box, prp_box) # [prp, prp] 130 | if self.require_overlap and (not self.use_gt_box): 131 | rel_possibility = (prp_self_iou > 0) & (prp_self_iou < 1) # not self & intersect 132 | else: 133 | num_prp = prp_box.shape[0] 134 | rel_possibility = torch.ones((num_prp, num_prp), device=device).long() - torch.eye(num_prp, device=device).long() 135 | # only select relations between fg proposals 136 | #Fix for background class 137 | rel_possibility[prp_lab == self.num_rel_classes] = 0 138 | rel_possibility[:, prp_lab == self.num_rel_classes] = 0 139 | 140 | img_rel_triplets, binary_rel = self.motif_rel_fg_bg_sampling(device, tgt_rel_matrix, ious, is_match, rel_possibility) 141 | rel_idx_pairs.append(img_rel_triplets[:, :2]) # (num_rel, 2), (sub_idx, obj_idx) 142 | rel_labels.append(img_rel_triplets[:, 2]) # (num_rel, ) 143 | rel_sym_binarys.append(binary_rel) 144 | 145 | return proposals, rel_labels, rel_idx_pairs, rel_sym_binarys 146 | 147 | def motif_rel_fg_bg_sampling(self, device, tgt_rel_matrix, ious, is_match, rel_possibility): 148 | """ 149 | prepare to sample fg relation triplet and bg relation triplet 150 | tgt_rel_matrix: # [number_target, number_target] 151 | ious: # [number_target, num_proposal] 152 | is_match: # [number_target, num_proposal] 153 | rel_possibility:# [num_proposal, num_proposal] 154 | """ 155 | tgt_pair_idxs = torch.nonzero(tgt_rel_matrix > 0) 156 | assert tgt_pair_idxs.shape[1] == 2 157 | tgt_head_idxs = tgt_pair_idxs[:, 0].contiguous().view(-1) 158 | tgt_tail_idxs = tgt_pair_idxs[:, 1].contiguous().view(-1) 159 | tgt_rel_labs = tgt_rel_matrix[tgt_head_idxs, tgt_tail_idxs].contiguous().view(-1) 160 | 161 | num_tgt_rels = tgt_rel_labs.shape[0] 162 | # generate binary prp mask 163 | num_prp = is_match.shape[-1] 164 | binary_prp_head = is_match[tgt_head_idxs] # num_tgt_rel, num_prp (matched prp head) 165 | binary_prp_tail = is_match[tgt_tail_idxs] # num_tgt_rel, num_prp (matched prp head) 166 | binary_rel = torch.zeros((num_prp, num_prp), device=device).long() 167 | 168 | fg_rel_triplets = [] 169 | for i in range(num_tgt_rels): 170 | # generate binary prp mask 171 | bi_match_head = torch.nonzero(binary_prp_head[i] > 0) 172 | bi_match_tail = torch.nonzero(binary_prp_tail[i] > 0) 173 | 174 | num_bi_head = bi_match_head.shape[0] 175 | num_bi_tail = bi_match_tail.shape[0] 176 | if num_bi_head > 0 and num_bi_tail > 0: 177 | bi_match_head = bi_match_head.view(1, num_bi_head).expand(num_bi_tail, 
num_bi_head).contiguous() 178 | bi_match_tail = bi_match_tail.view(num_bi_tail, 1).expand(num_bi_tail, num_bi_head).contiguous() 179 | # binary rel only consider related or not, so its symmetric 180 | binary_rel[bi_match_head.view(-1), bi_match_tail.view(-1)] = 1 181 | binary_rel[bi_match_tail.view(-1), bi_match_head.view(-1)] = 1 182 | 183 | tgt_head_idx = int(tgt_head_idxs[i]) 184 | tgt_tail_idx = int(tgt_tail_idxs[i]) 185 | tgt_rel_lab = int(tgt_rel_labs[i]) 186 | # find matching pair in proposals (might be more than one) 187 | prp_head_idxs = torch.nonzero(is_match[tgt_head_idx]).squeeze(1) 188 | prp_tail_idxs = torch.nonzero(is_match[tgt_tail_idx]).squeeze(1) 189 | num_match_head = prp_head_idxs.shape[0] 190 | num_match_tail = prp_tail_idxs.shape[0] 191 | if num_match_head <= 0 or num_match_tail <= 0: 192 | continue 193 | # all combination pairs 194 | prp_head_idxs = prp_head_idxs.view(-1,1).expand(num_match_head,num_match_tail).contiguous().view(-1) 195 | prp_tail_idxs = prp_tail_idxs.view(1,-1).expand(num_match_head,num_match_tail).contiguous().view(-1) 196 | valid_pair = prp_head_idxs != prp_tail_idxs 197 | if valid_pair.sum().item() <= 0: 198 | continue 199 | # remove self-pair 200 | # remove selected pair from rel_possibility 201 | prp_head_idxs = prp_head_idxs[valid_pair] 202 | prp_tail_idxs = prp_tail_idxs[valid_pair] 203 | rel_possibility[prp_head_idxs, prp_tail_idxs] = 0 204 | # construct corresponding proposal triplets corresponding to i_th gt relation 205 | fg_labels = torch.tensor([tgt_rel_lab]*prp_tail_idxs.shape[0], dtype=torch.int64, device=device).view(-1,1) 206 | fg_rel_i = cat((prp_head_idxs.view(-1,1), prp_tail_idxs.view(-1,1), fg_labels), dim=-1).to(torch.int64) 207 | # select if too many corresponding proposal pairs to one pair of gt relationship triplet 208 | # NOTE that in original motif, the selection is based on a ious_score score 209 | if fg_rel_i.shape[0] > self.num_sample_per_gt_rel: 210 | ious_score = (ious[tgt_head_idx, prp_head_idxs] * ious[tgt_tail_idx, prp_tail_idxs]).view(-1).detach().cpu().numpy() 211 | ious_score = ious_score / ious_score.sum() 212 | perm = npr.choice(ious_score.shape[0], p=ious_score, size=self.num_sample_per_gt_rel, replace=False) 213 | fg_rel_i = fg_rel_i[perm] 214 | if fg_rel_i.shape[0] > 0: 215 | fg_rel_triplets.append(fg_rel_i) 216 | 217 | # select fg relations 218 | if len(fg_rel_triplets) == 0: 219 | fg_rel_triplets = torch.zeros((0, 3), dtype=torch.int64, device=device) 220 | else: 221 | fg_rel_triplets = cat(fg_rel_triplets, dim=0).to(torch.int64) 222 | if fg_rel_triplets.shape[0] > self.num_pos_per_img: 223 | perm = torch.randperm(fg_rel_triplets.shape[0], device=device)[:self.num_pos_per_img] 224 | fg_rel_triplets = fg_rel_triplets[perm] 225 | 226 | # select bg relations 227 | bg_rel_inds = torch.nonzero(rel_possibility>0).view(-1,2) 228 | bg_rel_labs = torch.full((bg_rel_inds.shape[0],), fill_value=self.num_rel_classes, dtype=torch.int64, device=device) 229 | bg_rel_triplets = cat((bg_rel_inds, bg_rel_labs.view(-1,1)), dim=-1).to(torch.int64) 230 | 231 | num_neg_per_img = min(self.batch_size_per_image - fg_rel_triplets.shape[0], bg_rel_triplets.shape[0]) 232 | if bg_rel_triplets.shape[0] > 0: 233 | perm = torch.randperm(bg_rel_triplets.shape[0], device=device)[:num_neg_per_img] 234 | bg_rel_triplets = bg_rel_triplets[perm] 235 | else: 236 | bg_rel_triplets = torch.zeros((0, 3), dtype=torch.int64, device=device) 237 | 238 | # if both fg and bg is none 239 | if fg_rel_triplets.shape[0] == 0 and bg_rel_triplets.shape[0] == 
0: 240 | bg_rel_triplets = torch.zeros((1, 3), dtype=torch.int64, device=device) 241 | 242 | return cat((fg_rel_triplets, bg_rel_triplets), dim=0), binary_rel 243 | 244 | def cat(tensors, dim=0): 245 | """ 246 | Efficient version of torch.cat that avoids a copy if there is only a single element in a list 247 | """ 248 | assert isinstance(tensors, (list, tuple)) 249 | if len(tensors) == 1: 250 | return tensors[0] 251 | return torch.cat(tensors, dim) 252 | 253 | def build_roi_scenegraph_samp_processor(cfg): 254 | 255 | samp_processor = RelationSampling( 256 | cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD, 257 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.REQUIRE_BOX_OVERLAP, 258 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.NUM_SAMPLE_PER_GT_REL, 259 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.BATCH_SIZE_PER_IMAGE, 260 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.POSITIVE_FRACTION, 261 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX, 262 | cfg.MODEL.ROI_SCENEGRAPH_HEAD.NUM_CLASSES, 263 | cfg.TEST.RELATION.REQUIRE_OVERLAP, 264 | ) 265 | 266 | return samp_processor 267 | -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/transformer/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_transformer import TransformerContext -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/transformer/model_transformer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Based on the implementation of https://github.com/jadore801120/attention-is-all-you-need-pytorch 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from ..motif.utils_motifs import cat 9 | from ..motif.utils_motifs import obj_edge_vectors, to_onehot, encode_box_info 10 | from ..utils import nms_overlaps 11 | 12 | class ScaledDotProductAttention(nn.Module): 13 | ''' Scaled Dot-Product Attention ''' 14 | def __init__(self, temperature, attn_dropout=0.1): 15 | super().__init__() 16 | self.temperature = temperature 17 | self.dropout = nn.Dropout(attn_dropout) 18 | self.softmax = nn.Softmax(dim=2) 19 | 20 | def forward(self, q, k, v, mask=None): 21 | """ 22 | Args: 23 | q (bsz, len_q, dim_q) 24 | k (bsz, len_k, dim_k) 25 | v (bsz, len_v, dim_v) 26 | Note: len_k==len_v, and dim_q==dim_k 27 | Returns: 28 | output (bsz, len_q, dim_v) 29 | attn (bsz, len_q, len_k) 30 | """ 31 | attn = torch.bmm(q, k.transpose(1, 2)) 32 | attn = attn / self.temperature 33 | 34 | if mask is not None: 35 | attn = attn.masked_fill(mask, -np.inf) 36 | 37 | attn = self.softmax(attn) 38 | attn = self.dropout(attn) 39 | output = torch.bmm(attn, v) 40 | 41 | return output, attn 42 | 43 | 44 | class MultiHeadAttention(nn.Module): 45 | ''' Multi-Head Attention module ''' 46 | def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1): 47 | super().__init__() 48 | self.n_head = n_head 49 | self.d_k = d_k 50 | self.d_v = d_v 51 | 52 | self.w_qs = nn.Linear(d_model, n_head * d_k) 53 | self.w_ks = nn.Linear(d_model, n_head * d_k) 54 | self.w_vs = nn.Linear(d_model, n_head * d_v) 55 | nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 56 | nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_k))) 57 | nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model + d_v))) 58 | 59 | self.attention = ScaledDotProductAttention(temperature=np.power(d_k, 0.5)) 60 | self.layer_norm = 
nn.LayerNorm(d_model) 61 | 62 | self.fc = nn.Linear(n_head * d_v, d_model) 63 | nn.init.xavier_normal_(self.fc.weight) 64 | 65 | self.dropout = nn.Dropout(dropout) 66 | 67 | 68 | def forward(self, q, k, v, mask=None): 69 | """ 70 | Args: 71 | q (bsz, len_q, dim_q) 72 | k (bsz, len_k, dim_k) 73 | v (bsz, len_v, dim_v) 74 | Note: len_k==len_v, and dim_q==dim_k 75 | Returns: 76 | output (bsz, len_q, d_model) 77 | attn (bsz, len_q, len_k) 78 | """ 79 | d_k, d_v, n_head = self.d_k, self.d_v, self.n_head 80 | 81 | sz_b, len_q, _ = q.size() 82 | sz_b, len_k, _ = k.size() 83 | sz_b, len_v, _ = v.size() # len_k==len_v 84 | 85 | residual = q 86 | 87 | q = self.w_qs(q).view(sz_b, len_q, n_head, d_k) 88 | k = self.w_ks(k).view(sz_b, len_k, n_head, d_k) 89 | v = self.w_vs(v).view(sz_b, len_v, n_head, d_v) 90 | 91 | q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk 92 | k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk 93 | v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv 94 | 95 | mask = mask.repeat(n_head, 1, 1) # (n*b) x .. x .. 96 | output, attn = self.attention(q, k, v, mask=mask) 97 | 98 | output = output.view(n_head, sz_b, len_q, d_v) 99 | output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_q, -1) # b x lq x (n*dv) 100 | 101 | output = self.dropout(self.fc(output)) 102 | output = self.layer_norm(output + residual) 103 | 104 | return output, attn 105 | 106 | 107 | class PositionwiseFeedForward(nn.Module): 108 | ''' A two-feed-forward-layer module ''' 109 | def __init__(self, d_in, d_hid, dropout=0.1): 110 | super().__init__() 111 | self.w_1 = nn.Conv1d(d_in, d_hid, 1) # position-wise 112 | self.w_2 = nn.Conv1d(d_hid, d_in, 1) # position-wise 113 | self.layer_norm = nn.LayerNorm(d_in) 114 | self.dropout = nn.Dropout(dropout) 115 | 116 | def forward(self, x): 117 | """ 118 | Merge adjacent information. Equal to linear layer if kernel size is 1 119 | Args: 120 | x (bsz, len, dim) 121 | Returns: 122 | output (bsz, len, dim) 123 | """ 124 | residual = x 125 | output = x.transpose(1, 2) 126 | output = self.w_2(F.relu(self.w_1(output))) 127 | output = output.transpose(1, 2) 128 | output = self.dropout(output) 129 | output = self.layer_norm(output + residual) 130 | return output 131 | 132 | 133 | class EncoderLayer(nn.Module): 134 | ''' Compose with two layers ''' 135 | def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1): 136 | super(EncoderLayer, self).__init__() 137 | self.slf_attn = MultiHeadAttention( 138 | n_head, d_model, d_k, d_v, dropout=dropout) 139 | self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout) 140 | 141 | def forward(self, enc_input, non_pad_mask=None, slf_attn_mask=None): 142 | enc_output, enc_slf_attn = self.slf_attn( 143 | enc_input, enc_input, enc_input, mask=slf_attn_mask) 144 | enc_output *= non_pad_mask.float() 145 | 146 | enc_output = self.pos_ffn(enc_output) 147 | enc_output *= non_pad_mask.float() 148 | 149 | return enc_output, enc_slf_attn 150 | 151 | 152 | class TransformerEncoder(nn.Module): 153 | """ 154 | A encoder model with self attention mechanism. 
155 | """ 156 | def __init__(self, n_layers, n_head, d_k, d_v, d_model, d_inner, dropout=0.1): 157 | super().__init__() 158 | self.layer_stack = nn.ModuleList([ 159 | EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout) 160 | for _ in range(n_layers)]) 161 | 162 | def forward(self, input_feats, num_objs): 163 | """ 164 | Args: 165 | input_feats [Tensor] (#total_box, d_model) : bounding box features of a batch 166 | num_objs [list of int] (bsz, ) : number of bounding box of each image 167 | Returns: 168 | enc_output [Tensor] (#total_box, d_model) 169 | """ 170 | original_input_feats = input_feats 171 | input_feats = input_feats.split(num_objs, dim=0) 172 | input_feats = nn.utils.rnn.pad_sequence(input_feats, batch_first=True) 173 | 174 | # -- Prepare masks 175 | bsz = len(num_objs) 176 | device = input_feats.device 177 | pad_len = max(num_objs) 178 | num_objs_ = torch.LongTensor(num_objs).to(device).unsqueeze(1).expand(-1, pad_len) 179 | slf_attn_mask = torch.arange(pad_len, device=device).view(1, -1).expand(bsz, -1).ge(num_objs_).unsqueeze(1).expand(-1, pad_len, -1) # (bsz, pad_len, pad_len) 180 | non_pad_mask = torch.arange(pad_len, device=device).to(device).view(1, -1).expand(bsz, -1).lt(num_objs_).unsqueeze(-1) # (bsz, pad_len, 1) 181 | 182 | # -- Forward 183 | enc_output = input_feats 184 | for enc_layer in self.layer_stack: 185 | enc_output, enc_slf_attn = enc_layer( 186 | enc_output, 187 | non_pad_mask=non_pad_mask, 188 | slf_attn_mask=slf_attn_mask) 189 | 190 | enc_output = enc_output[non_pad_mask.squeeze(-1)] 191 | return enc_output 192 | 193 | 194 | class TransformerContext(nn.Module): 195 | def __init__(self, config, obj_classes, rel_classes, in_channels): 196 | super().__init__() 197 | self.cfg = config 198 | # setting parameters 199 | if self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_BOX: 200 | self.mode = 'predcls' if self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL else 'sgcls' 201 | else: 202 | self.mode = 'sgdet' 203 | self.obj_classes = obj_classes 204 | self.rel_classes = rel_classes 205 | self.num_obj_cls = len(obj_classes) 206 | self.num_rel_cls = len(rel_classes) 207 | self.in_channels = in_channels 208 | self.obj_dim = in_channels 209 | self.embed_dim = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.EMBED_DIM 210 | self.hidden_dim = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.CONTEXT_HIDDEN_DIM 211 | self.nms_thresh = self.cfg.TEST.RELATION.LATER_NMS_PREDICTION_THRES 212 | 213 | self.dropout_rate = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.DROPOUT_RATE 214 | self.obj_layer = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.OBJ_LAYER 215 | self.edge_layer = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.REL_LAYER 216 | self.num_head = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.NUM_HEAD 217 | self.inner_dim = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.INNER_DIM 218 | self.k_dim = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.KEY_DIM 219 | self.v_dim = self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.TRANSFORMER.VAL_DIM 220 | 221 | 222 | # the following word embedding layer should be initalize by glove.6B before using 223 | embed_vecs = obj_edge_vectors(self.obj_classes, wv_dir=self.cfg.GLOVE_DIR, wv_dim=self.embed_dim) 224 | self.obj_embed1 = nn.Embedding(self.num_obj_cls, self.embed_dim) 225 | self.obj_embed2 = nn.Embedding(self.num_obj_cls, self.embed_dim) 226 | with torch.no_grad(): 227 | self.obj_embed1.weight.copy_(embed_vecs, non_blocking=True) 228 | self.obj_embed2.weight.copy_(embed_vecs, non_blocking=True) 229 | 230 | # position embedding 231 | self.bbox_embed = 
nn.Sequential(*[ 232 | nn.Linear(9, 32), nn.ReLU(inplace=True), nn.Dropout(0.1), 233 | nn.Linear(32, 128), nn.ReLU(inplace=True), nn.Dropout(0.1), 234 | ]) 235 | self.lin_obj = nn.Linear(self.in_channels + self.embed_dim + 128, self.hidden_dim) 236 | self.lin_edge = nn.Linear(self.embed_dim + self.hidden_dim + self.in_channels, self.hidden_dim) 237 | self.out_obj = nn.Linear(self.hidden_dim, self.num_obj_cls) 238 | self.context_obj = TransformerEncoder(self.obj_layer, self.num_head, self.k_dim, 239 | self.v_dim, self.hidden_dim, self.inner_dim, self.dropout_rate) 240 | self.context_edge = TransformerEncoder(self.edge_layer, self.num_head, self.k_dim, 241 | self.v_dim, self.hidden_dim, self.inner_dim, self.dropout_rate) 242 | 243 | 244 | def forward(self, roi_features, proposals, boxes, logger=None): 245 | # labels will be used in DecoderRNN during training 246 | use_gt_label = self.training or self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL 247 | obj_labels = cat([proposal.pred_classes for proposal in proposals], dim=0) if use_gt_label else None 248 | 249 | # label/logits embedding will be used as input 250 | if self.cfg.MODEL.ROI_SCENEGRAPH_HEAD.USE_GT_OBJECT_LABEL: 251 | obj_embed = self.obj_embed1(obj_labels) 252 | else: 253 | obj_logits = cat([proposal.pred_scores for proposal in proposals], dim=0).detach() 254 | obj_embed = F.softmax(obj_logits, dim=1) @ self.obj_embed1.weight 255 | 256 | # bbox embedding will be used as input 257 | # assert proposals[0].mode == 'xyxy' 258 | img_sizes = [proposal.image_size for proposal in proposals] 259 | pos_embed = self.bbox_embed(encode_box_info(boxes, img_sizes)) 260 | 261 | # encode objects with transformer 262 | obj_pre_rep = cat((roi_features, obj_embed, pos_embed), -1) 263 | num_objs = [len(p) for p in proposals] 264 | obj_pre_rep = self.lin_obj(obj_pre_rep) 265 | obj_feats = self.context_obj(obj_pre_rep, num_objs) 266 | 267 | # predict obj_dists and obj_preds 268 | if self.mode == 'predcls': 269 | obj_preds = obj_labels 270 | obj_dists = to_onehot(obj_preds, self.num_obj_cls) 271 | edge_pre_rep = cat((roi_features, obj_feats, self.obj_embed2(obj_labels)), dim=-1) 272 | else: 273 | obj_dists = self.out_obj(obj_feats) 274 | use_decoder_nms = self.mode == 'sgdet' and not self.training 275 | if use_decoder_nms: 276 | boxes_per_cls = [proposal.boxes_per_cls for proposal in proposals] 277 | obj_preds = self.nms_per_cls(obj_dists, boxes_per_cls, num_objs) 278 | else: 279 | obj_preds = obj_dists[:, 1:].max(1)[1] 280 | edge_pre_rep = cat((roi_features, obj_feats, self.obj_embed2(obj_preds)), dim=-1) 281 | 282 | # edge context 283 | edge_pre_rep = self.lin_edge(edge_pre_rep) 284 | edge_ctx = self.context_edge(edge_pre_rep, num_objs) 285 | 286 | return obj_dists, obj_preds, edge_ctx 287 | 288 | def nms_per_cls(self, obj_dists, boxes_per_cls, num_objs): 289 | obj_dists = obj_dists.split(num_objs, dim=0) 290 | obj_preds = [] 291 | for i in range(len(num_objs)): 292 | is_overlap = nms_overlaps(boxes_per_cls[i]).cpu().numpy() >= self.nms_thresh # (#box, #box, #class) 293 | 294 | out_dists_sampled = F.softmax(obj_dists[i], -1).cpu().numpy() 295 | #Last index for background 296 | out_dists_sampled[:, -1] = -1 297 | 298 | out_label = obj_dists[i].new(num_objs[i]).fill_(0) 299 | 300 | for i in range(num_objs[i]): 301 | box_ind, cls_ind = np.unravel_index(out_dists_sampled.argmax(), out_dists_sampled.shape) 302 | out_label[int(box_ind)] = int(cls_ind) 303 | out_dists_sampled[is_overlap[box_ind,:,cls_ind], cls_ind] = 0.0 304 | out_dists_sampled[box_ind] = 
-1.0 # This way we won't re-sample 305 | 306 | obj_preds.append(out_label.long()) 307 | obj_preds = torch.cat(obj_preds, dim=0) 308 | return obj_preds 309 | -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | def nms_overlaps(boxes): 7 | """ get overlaps for each channel""" 8 | assert boxes.dim() == 3 9 | N = boxes.size(0) 10 | nc = boxes.size(1) 11 | max_xy = torch.min(boxes[:, None, :, 2:].expand(N, N, nc, 2), 12 | boxes[None, :, :, 2:].expand(N, N, nc, 2)) 13 | 14 | min_xy = torch.max(boxes[:, None, :, :2].expand(N, N, nc, 2), 15 | boxes[None, :, :, :2].expand(N, N, nc, 2)) 16 | 17 | inter = torch.clamp((max_xy - min_xy + 1.0), min=0) 18 | 19 | # n, n, 151 20 | inters = inter[:,:,:,0]*inter[:,:,:,1] 21 | boxes_flat = boxes.view(-1, 4) 22 | areas_flat = (boxes_flat[:,2]- boxes_flat[:,0]+1.0)*( 23 | boxes_flat[:,3]- boxes_flat[:,1]+1.0) 24 | areas = areas_flat.view(boxes.size(0), boxes.size(1)) 25 | union = -inters + areas[None] + areas[:, None] 26 | return inters / union 27 | 28 | def layer_init(layer, init_para=0.1, normal=False, xavier=True): 29 | xavier = False if normal == True else True 30 | if normal: 31 | torch.nn.init.normal_(layer.weight, mean=0, std=init_para) 32 | torch.nn.init.constant_(layer.bias, 0) 33 | return 34 | elif xavier: 35 | torch.nn.init.xavier_normal_(layer.weight, gain=1.0) 36 | torch.nn.init.constant_(layer.bias, 0) 37 | return 38 | 39 | def block_orthogonal(tensor, split_sizes, gain=1.0): 40 | sizes = list(tensor.size()) 41 | if any([a % b != 0 for a, b in zip(sizes, split_sizes)]): 42 | raise ValueError("tensor dimensions must be divisible by their respective " 43 | "split_sizes. Found size: {} and split_sizes: {}".format(sizes, split_sizes)) 44 | indexes = [list(range(0, max_size, split)) 45 | for max_size, split in zip(sizes, split_sizes)] 46 | # Iterate over all possible blocks within the tensor. 47 | for block_start_indices in itertools.product(*indexes): 48 | # A list of tuples containing the index to start at for this block 49 | # and the appropriate step size (i.e split_size[i] for dimension i). 50 | index_and_step_tuples = zip(block_start_indices, split_sizes) 51 | # This is a tuple of slices corresponding to: 52 | # tensor[index: index + step_size, ...]. This is 53 | # required because we could have an arbitrary number 54 | # of dimensions. The actual slices we need are the 55 | # start_index: start_index + step for each dimension in the tensor. 
56 | block_slice = tuple([slice(start_index, start_index + step) 57 | for start_index, step in index_and_step_tuples]) 58 | 59 | # let's not initialize empty things to 0s because THAT SOUNDS REALLY BAD 60 | assert len(block_slice) == 2 61 | sizes = [x.stop - x.start for x in block_slice] 62 | tensor_copy = tensor.new(max(sizes), max(sizes)) 63 | torch.nn.init.orthogonal_(tensor_copy, gain=gain) 64 | tensor[block_slice] = tensor_copy[0:sizes[0], 0:sizes[1]] 65 | 66 | def obj_prediction_nms(boxes_per_cls, pred_logits, nms_thresh=0.3): 67 | """ 68 | boxes_per_cls: [num_obj, num_cls, 4] 69 | pred_logits: [num_obj, num_category] 70 | """ 71 | num_obj = pred_logits.shape[0] 72 | assert num_obj == boxes_per_cls.shape[0] 73 | is_overlap = nms_overlaps(boxes_per_cls).view(boxes_per_cls.size(0), boxes_per_cls.size(0), 74 | boxes_per_cls.size(1)).cpu().numpy() >= nms_thresh 75 | 76 | prob_sampled = F.softmax(pred_logits, 1).cpu().numpy() 77 | prob_sampled[:, -1] = 0 # set bg to 0 78 | 79 | pred_label = torch.zeros(num_obj, device=pred_logits.device, dtype=torch.int64) 80 | 81 | for i in range(num_obj): 82 | box_ind, cls_ind = np.unravel_index(prob_sampled.argmax(), prob_sampled.shape) 83 | if float(pred_label[int(box_ind)]) > 0: 84 | pass 85 | else: 86 | pred_label[int(box_ind)] = int(cls_ind) 87 | prob_sampled[is_overlap[box_ind,:,cls_ind], cls_ind] = 0.0 88 | prob_sampled[box_ind] = -1.0 # This way we won't re-sample 89 | 90 | return pred_label -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/vctree/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_vctree import VCTreeLSTMContext, VCTreeLSTMSegmentationContext, VCTreeLSTMSegmentationContextC, VCTreeLSTMContextC -------------------------------------------------------------------------------- /segmentationsg/modeling/roi_heads/scenegraph_head/vctree/utils_vctree.py: -------------------------------------------------------------------------------- 1 | import array 2 | import os 3 | import zipfile 4 | import itertools 5 | import six 6 | from six.moves.urllib.request import urlretrieve 7 | from tqdm import tqdm 8 | import sys 9 | from detectron2.layers import cat 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | import numpy as np 15 | 16 | 17 | def generate_forest(pair_scores, proposals, mode): 18 | """ 19 | generate a list of trees that covers all the objects in a batch 20 | proposal.bbox: [obj_num, (x1, y1, x2, y2)] 21 | pair_scores: [obj_num, obj_num] 22 | output: list of trees, each present a chunk of overlaping objects 23 | """ 24 | output_forest = [] # the list of trees, each one is a chunk of overlapping objects 25 | 26 | for pair_score, proposal in zip(pair_scores, proposals): 27 | num_obj = pair_score.shape[0] 28 | if mode == 'predcls': 29 | obj_label = proposal.pred_classes 30 | else: 31 | obj_label = proposal.pred_scores.max(-1)[1] 32 | 33 | 34 | assert pair_score.shape[0] == len(proposal) 35 | assert pair_score.shape[0] == pair_score.shape[1] 36 | node_scores = pair_score.mean(1).view(-1) 37 | root_idx = int(node_scores.max(-1)[1]) 38 | 39 | root = ArbitraryTree(root_idx, float(node_scores[root_idx]), int(obj_label[root_idx]), proposal.pred_boxes[root_idx].tensor, is_root=True) 40 | 41 | node_container = [] 42 | remain_index = [] 43 | # put all nodes into node container 44 | for idx in list(range(num_obj)): 45 | if idx == root_idx: 46 | continue 47 | new_node 
= ArbitraryTree(idx, float(node_scores[idx]), int(obj_label[idx]), proposal.pred_boxes[idx].tensor) 48 | node_container.append(new_node) 49 | remain_index.append(idx) 50 | 51 | # iteratively generate tree 52 | gen_tree(node_container, pair_score, node_scores, root, remain_index, mode) 53 | output_forest.append(root) 54 | 55 | return output_forest 56 | 57 | def gen_tree(node_container, pair_score, node_scores, root, remain_index, mode): 58 | """ 59 | Step 1: Devide all nodes into left child container and right child container 60 | Step 2: From left child container and right child container, select their respective sub roots 61 | pair_scores: [obj_num, obj_num] 62 | node_scores: [obj_num] 63 | """ 64 | num_nodes = len(node_container) 65 | device = pair_score.device 66 | # Step 0 67 | if num_nodes == 0: 68 | return 69 | # Step 1 70 | select_node = [] 71 | select_index = [] 72 | select_node.append(root) 73 | select_index.append(root.index) 74 | 75 | while len(node_container) > 0: 76 | wid = len(remain_index) 77 | select_indexs = torch.tensor(select_index, device=device, dtype=torch.int64) 78 | remain_indexs = torch.tensor(remain_index, device=device, dtype=torch.int64) 79 | select_score_map = pair_score[select_indexs][:, remain_indexs].view(-1) 80 | best_id = select_score_map.max(0)[1] 81 | 82 | depend_id = int(best_id) // wid 83 | insert_id = int(best_id) % wid 84 | best_depend_node = select_node[depend_id] 85 | best_insert_node = node_container[insert_id] 86 | best_depend_node.add_child(best_insert_node) 87 | 88 | select_node.append(best_insert_node) 89 | select_index.append(best_insert_node.index) 90 | node_container.remove(best_insert_node) 91 | remain_index.remove(best_insert_node.index) 92 | 93 | 94 | 95 | 96 | def arbForest_to_biForest(forest): 97 | """ 98 | forest: a set of arbitrary Tree 99 | output: a set of corresponding binary Tree 100 | """ 101 | output = [] 102 | for i in range(len(forest)): 103 | result_tree = arTree_to_biTree(forest[i]) 104 | output.append(result_tree) 105 | 106 | return output 107 | 108 | 109 | def arTree_to_biTree(arTree): 110 | root_node = arTree.generate_bi_tree() 111 | arNode_to_biNode(arTree, root_node) 112 | 113 | return root_node 114 | 115 | def arNode_to_biNode(arNode, biNode): 116 | if arNode.get_child_num() >= 1: 117 | new_bi_node = arNode.children[0].generate_bi_tree() 118 | biNode.add_left_child(new_bi_node) 119 | arNode_to_biNode(arNode.children[0], biNode.left_child) 120 | 121 | if arNode.get_child_num() > 1: 122 | current_bi_node = biNode.left_child 123 | for i in range(arNode.get_child_num() - 1): 124 | new_bi_node = arNode.children[i+1].generate_bi_tree() 125 | current_bi_node.add_right_child(new_bi_node) 126 | current_bi_node = current_bi_node.right_child 127 | arNode_to_biNode(arNode.children[i+1], current_bi_node) 128 | 129 | def find_best_node(node_container): 130 | max_node_score = -1 131 | best_node = None 132 | for i in range(len(node_container)): 133 | if node_container[i].score > max_node_score: 134 | max_node_score = node_container[i].score 135 | best_node = node_container[i] 136 | return best_node 137 | 138 | 139 | 140 | 141 | 142 | class BasicBiTree(object): 143 | def __init__(self, idx, is_root=False): 144 | self.index = int(idx) 145 | self.is_root = is_root 146 | self.left_child = None 147 | self.right_child = None 148 | self.parent = None 149 | self.num_child = 0 150 | 151 | def add_left_child(self, child): 152 | if self.left_child is not None: 153 | print('Left child already exist') 154 | return 155 | child.parent = self 156 | 
self.num_child += 1 157 | self.left_child = child 158 | 159 | def add_right_child(self, child): 160 | if self.right_child is not None: 161 | print('Right child already exist') 162 | return 163 | child.parent = self 164 | self.num_child += 1 165 | self.right_child = child 166 | 167 | def get_total_child(self): 168 | sum = 0 169 | sum += self.num_child 170 | if self.left_child is not None: 171 | sum += self.left_child.get_total_child() 172 | if self.right_child is not None: 173 | sum += self.right_child.get_total_child() 174 | return sum 175 | 176 | def depth(self): 177 | if hasattr(self, '_depth'): 178 | return self._depth 179 | if self.parent is None: 180 | count = 1 181 | else: 182 | count = self.parent.depth() + 1 183 | self._depth = count 184 | return self._depth 185 | 186 | def max_depth(self): 187 | if hasattr(self, '_max_depth'): 188 | return self._max_depth 189 | count = 0 190 | if self.left_child is not None: 191 | left_depth = self.left_child.max_depth() 192 | if left_depth > count: 193 | count = left_depth 194 | if self.right_child is not None: 195 | right_depth = self.right_child.max_depth() 196 | if right_depth > count: 197 | count = right_depth 198 | count += 1 199 | self._max_depth = count 200 | return self._max_depth 201 | 202 | # by index 203 | def is_descendant(self, idx): 204 | left_flag = False 205 | right_flag = False 206 | # node is left child 207 | if self.left_child is not None: 208 | if self.left_child.index is idx: 209 | return True 210 | else: 211 | left_flag = self.left_child.is_descendant(idx) 212 | # node is right child 213 | if self.right_child is not None: 214 | if self.right_child.index is idx: 215 | return True 216 | else: 217 | right_flag = self.right_child.is_descendant(idx) 218 | # node is descendant 219 | if left_flag or right_flag: 220 | return True 221 | else: 222 | return False 223 | 224 | # whether input node is under left sub tree 225 | def is_left_descendant(self, idx): 226 | if self.left_child is not None: 227 | if self.left_child.index is idx: 228 | return True 229 | else: 230 | return self.left_child.is_descendant(idx) 231 | else: 232 | return False 233 | 234 | # whether input node is under right sub tree 235 | def is_right_descendant(self, idx): 236 | if self.right_child is not None: 237 | if self.right_child.index is idx: 238 | return True 239 | else: 240 | return self.right_child.is_descendant(idx) 241 | else: 242 | return False 243 | 244 | 245 | class ArbitraryTree(object): 246 | def __init__(self, idx, score, label=-1, box=None, is_root=False): 247 | self.index = int(idx) 248 | self.is_root = is_root 249 | self.score = float(score) 250 | self.children = [] 251 | self.label = label 252 | self.embeded_label = None 253 | self.box = box.view(-1) if box is not None else None #[x1,y1,x2,y2] 254 | self.parent = None 255 | self.node_order = -1 # the n_th node added to the tree 256 | 257 | def generate_bi_tree(self): 258 | # generate a BiTree node, parent/child relationship are not inherited 259 | return BiTree(self.index, self.score, self.label, self.box, self.is_root) 260 | 261 | def add_child(self, child): 262 | child.parent = self 263 | self.children.append(child) 264 | 265 | def print(self): 266 | print('index: ', self.index) 267 | print('node_order: ', self.node_order) 268 | print('num of child: ', len(self.children)) 269 | for node in self.children: 270 | node.print() 271 | 272 | def find_node_by_order(self, order, result_node): 273 | if self.node_order == order: 274 | result_node = self 275 | elif len(self.children) > 0: 276 | for i in 
range(len(self.children)): 277 | result_node = self.children[i].find_node_by_order(order, result_node) 278 | 279 | return result_node 280 | 281 | def find_node_by_index(self, index, result_node): 282 | if self.index == index: 283 | result_node = self 284 | elif len(self.children) > 0: 285 | for i in range(len(self.children)): 286 | result_node = self.children[i].find_node_by_index(index, result_node) 287 | 288 | return result_node 289 | 290 | def search_best_insert(self, score_map, best_score, insert_node, best_depend_node, best_insert_node, ignore_root = True): 291 | if self.is_root and ignore_root: 292 | pass 293 | elif float(score_map[self.index, insert_node.index]) > float(best_score): 294 | best_score = score_map[self.index, insert_node.index] 295 | best_depend_node = self 296 | best_insert_node = insert_node 297 | 298 | # iteratively search child 299 | for i in range(self.get_child_num()): 300 | best_score, best_depend_node, best_insert_node = \ 301 | self.children[i].search_best_insert(score_map, best_score, insert_node, best_depend_node, best_insert_node) 302 | 303 | return best_score, best_depend_node, best_insert_node 304 | 305 | def get_child_num(self): 306 | return len(self.children) 307 | 308 | def get_total_child(self): 309 | sum = 0 310 | num_current_child = self.get_child_num() 311 | sum += num_current_child 312 | for i in range(num_current_child): 313 | sum += self.children[i].get_total_child() 314 | return sum 315 | 316 | # only support binary tree 317 | class BiTree(BasicBiTree): 318 | def __init__(self, idx, node_score, label, box, is_root=False): 319 | super(BiTree, self).__init__(idx, is_root) 320 | self.state_c = None 321 | self.state_h = None 322 | self.state_c_backward = None 323 | self.state_h_backward = None 324 | # used to select node 325 | self.node_score = float(node_score) 326 | self.label = label 327 | self.embeded_label = None 328 | self.box = box.view(-1) #[x1,y1,x2,y2] 329 | 330 | 331 | 332 | def bbox_intersection(box_a, box_b): 333 | A = box_a.size(0) 334 | B = box_b.size(0) 335 | max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), 336 | box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) 337 | min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), 338 | box_b[:, :2].unsqueeze(0).expand(A, B, 2)) 339 | inter = torch.clamp((max_xy - min_xy + 1.0), min=0) 340 | return inter[:, :, 0] * inter[:, :, 1] 341 | 342 | 343 | def bbox_overlap(box_a, box_b): 344 | inter = bbox_intersection(box_a, box_b) 345 | area_a = ((box_a[:, 2] - box_a[:, 0] + 1.0) * 346 | (box_a[:, 3] - box_a[:, 1] + 1.0)).unsqueeze(1).expand_as(inter) # [A,B] 347 | area_b = ((box_b[:, 2] - box_b[:, 0] + 1.0) * 348 | (box_b[:, 3] - box_b[:, 1] + 1.0)).unsqueeze(0).expand_as(inter) # [A,B] 349 | union = area_a + area_b - inter 350 | return inter / (union + 1e-9) 351 | 352 | 353 | def bbox_area(bbox): 354 | area = (bbox[:,2] - bbox[:,0]) * (bbox[:,3] - bbox[:,1]) 355 | return area.view(-1, 1) 356 | 357 | 358 | def get_overlap_info(boxes_list): 359 | IM_SCALE = 1024 360 | overlap_info = [] 361 | for bbox in boxes_list : 362 | boxes = bbox.tensor 363 | intersection = bbox_intersection(boxes, boxes).float() # num, num 364 | overlap = bbox_overlap(boxes, boxes).float() # num, num 365 | area = bbox_area(boxes).float() # num, 1 366 | 367 | info1 = (intersection > 0.0).float().sum(1).view(-1, 1) 368 | info2 = intersection.sum(1).view(-1, 1) / float(IM_SCALE * IM_SCALE) 369 | info3 = overlap.sum(1).view(-1, 1) 370 | info4 = info2 / (info1 + 1e-9) 371 | info5 = info3 / (info1 + 1e-9) 372 | info6 = 
area / float(IM_SCALE * IM_SCALE) 373 | 374 | info = torch.cat([info1, info2, info3, info4, info5, info6], dim=1) 375 | overlap_info.append(info) 376 | 377 | return torch.cat(overlap_info, dim=0) 378 | -------------------------------------------------------------------------------- /segmentationsg/structures/__init__.py: -------------------------------------------------------------------------------- 1 | from .boxes_ops import boxes_union 2 | from .masks_ops import masks_union -------------------------------------------------------------------------------- /segmentationsg/structures/boxes_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from detectron2.structures.boxes import Boxes 4 | 5 | def boxes_union(boxes1, boxes2): 6 | """ 7 | Compute the union region of two sets of boxes 8 | Arguments: 9 | boxes1: (Boxes) bounding boxes, sized [N,4]. 10 | boxes2: (Boxes) bounding boxes, sized [N,4]. 11 | Returns: 12 | (Boxes) union, sized [N,4]. 13 | """ 14 | assert len(boxes1) == len(boxes2) 15 | 16 | union_box = torch.cat(( 17 | torch.min(boxes1.tensor[:,:2], boxes2.tensor[:,:2]), 18 | torch.max(boxes1.tensor[:,2:], boxes2.tensor[:,2:]) 19 | ),dim=1) 20 | return Boxes(union_box) 21 | -------------------------------------------------------------------------------- /segmentationsg/structures/masks_ops.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from detectron2.structures.boxes import Boxes 4 | 5 | def masks_union(masks1, masks2): 6 | assert len(masks1) == len(masks2) 7 | union_masks = (masks1 + masks2) / 2.0 # element-wise average of the two soft masks 8 | return union_masks 9 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import setup 4 | import torch 5 | 6 | torch_ver = [int(x) for x in torch.__version__.split(".")[:2]] 7 | assert torch_ver >= [1, 6], "Requires PyTorch >= 1.6" 8 | 9 | 10 | setup(name='segmentationsg', 11 | version='0.1', 12 | description='segmentationsg', 13 | packages=['segmentationsg'], 14 | zip_safe=False) 15 | 16 | --------------------------------------------------------------------------------
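A quick way to sanity-check the structure helpers exported by `segmentationsg/structures/__init__.py` is to call them on toy inputs. The sketch below is illustrative only: it assumes the package and detectron2 are importable, and the box coordinates and mask tensors are made-up values, not data from the repository.

```python
# Minimal, hypothetical usage sketch for boxes_union / masks_union (toy data only).
import torch
from detectron2.structures.boxes import Boxes

from segmentationsg.structures import boxes_union, masks_union

# Two aligned sets of boxes in (x1, y1, x2, y2) format.
boxes1 = Boxes(torch.tensor([[10., 10., 50., 50.],
                             [ 0.,  0., 20., 30.]]))
boxes2 = Boxes(torch.tensor([[30., 20., 80., 60.],
                             [ 5., 10., 25., 25.]]))

# Each output row is the tightest box enclosing the corresponding pair:
# tensor([[10., 10., 80., 60.],
#         [ 0.,  0., 25., 30.]])
print(boxes_union(boxes1, boxes2).tensor)

# masks_union averages two equally sized soft-mask tensors element-wise.
m1 = torch.zeros(2, 28, 28)
m2 = torch.ones(2, 28, 28)
print(masks_union(m1, m2).mean())  # tensor(0.5000)
```

Pairing a union box with an averaged mask per subject–object pair is a common way to form union regions for relation features; the authoritative wiring lives in the scene-graph head's relation feature extractor, not in this sketch.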