├── configs
│   ├── R-50-grid.yaml
│   ├── R-50-updn.yaml
│   ├── X-152-grid.yaml
│   ├── X-101-grid.yaml
│   ├── X-152-challenge.yaml
│   ├── Base-RCNN-updn.yaml
│   └── Base-RCNN-grid.yaml
├── grid_feats
│   ├── __init__.py
│   ├── config.py
│   ├── build_loader.py
│   ├── visual_genome.py
│   ├── dataset_mapper.py
│   └── roi_heads.py
├── CONTRIBUTING.md
├── CODE_OF_CONDUCT.md
├── extract_grid_feature.py
├── train_net.py
├── README.md
└── LICENSE
/configs/R-50-grid.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-grid.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | RESNETS: 5 | DEPTH: 50 -------------------------------------------------------------------------------- /configs/R-50-updn.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-updn.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/MSRA/R-50.pkl" 4 | RESNETS: 5 | DEPTH: 50 -------------------------------------------------------------------------------- /configs/X-152-grid.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-grid.yaml" 2 | MODEL: 3 | WEIGHTS: "catalog://ImageNetPretrained/FAIR/X-152-32x8d-IN5k" 4 | RESNETS: 5 | STRIDE_IN_1X1: False # this is a C2 model 6 | NUM_GROUPS: 32 7 | WIDTH_PER_GROUP: 8 8 | DEPTH: 152 -------------------------------------------------------------------------------- /configs/X-101-grid.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-grid.yaml" 2 | MODEL: 3 | WEIGHTS: "detectron2://ImageNetPretrained/FAIR/X-101-32x8d.pkl" 4 | PIXEL_STD: [57.375, 57.120, 58.395] 5 | RESNETS: 6 | STRIDE_IN_1X1: False # this is a C2 model 7 | NUM_GROUPS: 32 8 | WIDTH_PER_GROUP: 8 9 | DEPTH: 101 -------------------------------------------------------------------------------- /grid_feats/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | from .config import add_attribute_config 3 | from .build_loader import ( 4 | build_detection_train_loader_with_attributes, 5 | build_detection_test_loader_with_attributes, 6 | ) 7 | from .roi_heads import AttributeRes5ROIHeads, AttributeStandardROIHeads 8 | from . import visual_genome -------------------------------------------------------------------------------- /configs/X-152-challenge.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "Base-RCNN-grid.yaml" 2 | MODEL: 3 | WEIGHTS: "catalog://ImageNetPretrained/FAIR/X-152-32x8d-IN5k" 4 | RESNETS: 5 | STRIDE_IN_1X1: False # this is a C2 model 6 | NUM_GROUPS: 32 7 | WIDTH_PER_GROUP: 8 8 | DEPTH: 152 9 | DEFORM_ON_PER_STAGE: [False,False,True,True] 10 | RPN: 11 | BBOX_LOSS_WEIGHT: .5 12 | ROI_BOX_HEAD: 13 | BBOX_LOSS_WEIGHT: 0.
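# NOTE: a box-regression loss weight of 0 effectively disables the R-CNN box loss for this
# challenge model; Trainer.run_step in train_net.py scales loss_box_reg by this value.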
14 | POOLER_SAMPLING_RATIO: 0 15 | INPUT: 16 | MIN_SIZE_TRAIN: (360,600) 17 | MIN_SIZE_TRAIN_SAMPLING: range 18 | SOLVER: 19 | MAX_ITER: 60000 20 | LR_SCHEDULER_NAME: WarmupCosineLR -------------------------------------------------------------------------------- /configs/Base-RCNN-updn.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | ATTRIBUTE_ON: True 4 | RPN: 5 | PRE_NMS_TOPK_TEST: 6000 6 | POST_NMS_TOPK_TEST: 1000 7 | SMOOTH_L1_BETA: 0.1111 8 | BOUNDARY_THRESH: 0 9 | ROI_HEADS: 10 | NAME: "AttributeRes5ROIHeads" 11 | NUM_CLASSES: 1600 12 | ROI_BOX_HEAD: 13 | POOLER_SAMPLING_RATIO: 2 14 | SMOOTH_L1_BETA: 1. 15 | DATASETS: 16 | TRAIN: ("visual_genome_train", "visual_genome_val") 17 | TEST: ("visual_genome_test",) 18 | SOLVER: 19 | IMS_PER_BATCH: 16 20 | BASE_LR: 0.02 21 | STEPS: (60000, 80000) 22 | MAX_ITER: 90000 23 | INPUT: 24 | MIN_SIZE_TRAIN: (600,) 25 | MAX_SIZE_TRAIN: 1000 26 | MIN_SIZE_TEST: 600 27 | MAX_SIZE_TEST: 1000 28 | VERSION: 2 29 | -------------------------------------------------------------------------------- /configs/Base-RCNN-grid.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | META_ARCHITECTURE: "GeneralizedRCNN" 3 | ATTRIBUTE_ON: True 4 | RESNETS: 5 | OUT_FEATURES: ["res5"] 6 | RES5_DILATION: 2 7 | RPN: 8 | IN_FEATURES: ["res5"] 9 | PRE_NMS_TOPK_TEST: 6000 10 | POST_NMS_TOPK_TEST: 1000 11 | SMOOTH_L1_BETA: 0.1111 12 | BOUNDARY_THRESH: 0 13 | ROI_HEADS: 14 | NAME: "AttributeStandardROIHeads" 15 | IN_FEATURES: ["res5"] 16 | NUM_CLASSES: 1600 17 | ROI_BOX_HEAD: 18 | NAME: "FastRCNNConvFCHead" 19 | NUM_FC: 2 20 | POOLER_RESOLUTION: 1 21 | POOLER_SAMPLING_RATIO: 2 22 | SMOOTH_L1_BETA: 1. 23 | DATASETS: 24 | TRAIN: ("visual_genome_train", "visual_genome_val") 25 | TEST: ("visual_genome_test",) 26 | SOLVER: 27 | IMS_PER_BATCH: 16 28 | BASE_LR: 0.02 29 | STEPS: (60000, 80000) 30 | MAX_ITER: 90000 31 | INPUT: 32 | MIN_SIZE_TRAIN: (600,) 33 | MAX_SIZE_TRAIN: 1000 34 | MIN_SIZE_TEST: 600 35 | MAX_SIZE_TEST: 1000 36 | VERSION: 2 37 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to grid-feats-vqa 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `master`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 
28 | 29 | ## Coding Style 30 | * 4 spaces for indentation rather than tabs 31 | 32 | ## License 33 | By contributing to grid-feats-vqa, you agree that your contributions will be licensed 34 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /grid_feats/config.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | from detectron2.config import CfgNode as CN 5 | 6 | 7 | def add_attribute_config(cfg): 8 | """ 9 | Add config for attribute prediction. 10 | """ 11 | # Whether to have attribute prediction 12 | cfg.MODEL.ATTRIBUTE_ON = False 13 | # Maximum number of attributes per foreground instance 14 | cfg.INPUT.MAX_ATTR_PER_INS = 16 15 | # ------------------------------------------------------------------------ # 16 | # Attribute Head 17 | # ----------------------------------------------------------------------- # 18 | cfg.MODEL.ROI_ATTRIBUTE_HEAD = CN() 19 | # Dimension for object class embedding, used in conjunction with 20 | # visual features to predict attributes 21 | cfg.MODEL.ROI_ATTRIBUTE_HEAD.OBJ_EMBED_DIM = 256 22 | # Dimension of the hidden fc layer of the input visual features 23 | cfg.MODEL.ROI_ATTRIBUTE_HEAD.FC_DIM = 512 24 | # Loss weight for attribute prediction, 0.2 is best per analysis 25 | cfg.MODEL.ROI_ATTRIBUTE_HEAD.LOSS_WEIGHT = 0.2 26 | # Number of classes for attributes 27 | cfg.MODEL.ROI_ATTRIBUTE_HEAD.NUM_CLASSES = 400 28 | 29 | """ 30 | Add config for box regression loss adjustment. 31 | """ 32 | # Loss weights for RPN box regression 33 | cfg.MODEL.RPN.BBOX_LOSS_WEIGHT = 1.0 34 | # Loss weights for R-CNN box regression 35 | cfg.MODEL.ROI_BOX_HEAD.BBOX_LOSS_WEIGHT = 1.0 -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq 77 | -------------------------------------------------------------------------------- /extract_grid_feature.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | """ 5 | Grid features extraction script. 
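A typical invocation, assuming a model pre-trained with one of the grid configs (the
checkpointer resumes from the last checkpoint in cfg.OUTPUT_DIR, falling back to
MODEL.WEIGHTS), might look like:

    python extract_grid_feature.py --config-file configs/R-50-grid.yaml --dataset coco_2014_val

Features are saved per image as <image_id>.pth tensors under <OUTPUT_DIR>/features/<split>/
and can be read back with torch.load().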
6 | """ 7 | import argparse 8 | import os 9 | import torch 10 | import tqdm 11 | from fvcore.common.file_io import PathManager 12 | 13 | from detectron2.checkpoint import DetectionCheckpointer 14 | from detectron2.config import get_cfg 15 | from detectron2.engine import default_setup 16 | from detectron2.evaluation import inference_context 17 | from detectron2.modeling import build_model 18 | 19 | from grid_feats import ( 20 | add_attribute_config, 21 | build_detection_test_loader_with_attributes, 22 | ) 23 | 24 | # A simple mapper from object detection dataset to VQA dataset names 25 | dataset_to_folder_mapper = {} 26 | dataset_to_folder_mapper['coco_2014_train'] = 'train2014' 27 | dataset_to_folder_mapper['coco_2014_val'] = 'val2014' 28 | # One may need to change the Detectron2 code to support coco_2015_test 29 | # insert "coco_2015_test": ("coco/test2015", "coco/annotations/image_info_test2015.json"), 30 | # at: https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/builtin.py#L36 31 | dataset_to_folder_mapper['coco_2015_test'] = 'test2015' 32 | 33 | def extract_grid_feature_argument_parser(): 34 | parser = argparse.ArgumentParser(description="Grid feature extraction") 35 | parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") 36 | parser.add_argument("--dataset", help="name of the dataset", default="coco_2014_train", 37 | choices=['coco_2014_train', 'coco_2014_val', 'coco_2015_test']) 38 | parser.add_argument( 39 | "opts", 40 | help="Modify config options using the command-line", 41 | default=None, 42 | nargs=argparse.REMAINDER, 43 | ) 44 | return parser 45 | 46 | def extract_grid_feature_on_dataset(model, data_loader, dump_folder): 47 | for idx, inputs in enumerate(tqdm.tqdm(data_loader)): 48 | with torch.no_grad(): 49 | image_id = inputs[0]['image_id'] 50 | file_name = '%d.pth' % image_id 51 | # compute features 52 | images = model.preprocess_image(inputs) 53 | features = model.backbone(images.tensor) 54 | outputs = model.roi_heads.get_conv5_features(features) 55 | with PathManager.open(os.path.join(dump_folder, file_name), "wb") as f: 56 | # save as CPU tensors 57 | torch.save(outputs.cpu(), f) 58 | 59 | def do_feature_extraction(cfg, model, dataset_name): 60 | with inference_context(model): 61 | dump_folder = os.path.join(cfg.OUTPUT_DIR, "features", dataset_to_folder_mapper[dataset_name]) 62 | PathManager.mkdirs(dump_folder) 63 | data_loader = build_detection_test_loader_with_attributes(cfg, dataset_name) 64 | extract_grid_feature_on_dataset(model, data_loader, dump_folder) 65 | 66 | def setup(args): 67 | """ 68 | Create configs and perform basic setups. 
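Note that RESNETS.RES5_DILATION is forced back to 1 below: the grid configs
(see Base-RCNN-grid.yaml) pre-train with RES5_DILATION: 2, but features are
extracted without dilation.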
69 | """ 70 | cfg = get_cfg() 71 | add_attribute_config(cfg) 72 | cfg.merge_from_file(args.config_file) 73 | cfg.merge_from_list(args.opts) 74 | # force the final residual block to have dilations 1 75 | cfg.MODEL.RESNETS.RES5_DILATION = 1 76 | cfg.freeze() 77 | default_setup(cfg, args) 78 | return cfg 79 | 80 | 81 | def main(args): 82 | cfg = setup(args) 83 | model = build_model(cfg) 84 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 85 | cfg.MODEL.WEIGHTS, resume=True 86 | ) 87 | do_feature_extraction(cfg, model, args.dataset) 88 | 89 | 90 | if __name__ == "__main__": 91 | args = extract_grid_feature_argument_parser().parse_args() 92 | print("Command Line Args:", args) 93 | main(args) 94 | -------------------------------------------------------------------------------- /grid_feats/build_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | import logging 3 | import operator 4 | import torch.utils.data 5 | 6 | from detectron2.utils.comm import get_world_size 7 | from detectron2.data import samplers 8 | from detectron2.data.build import get_detection_dataset_dicts, worker_init_reset_seed, trivial_batch_collator 9 | from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, MapDataset 10 | 11 | from .dataset_mapper import AttributeDatasetMapper 12 | 13 | 14 | def build_detection_train_loader_with_attributes(cfg, mapper=None): 15 | num_workers = get_world_size() 16 | images_per_batch = cfg.SOLVER.IMS_PER_BATCH 17 | assert ( 18 | images_per_batch % num_workers == 0 19 | ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format( 20 | images_per_batch, num_workers 21 | ) 22 | assert ( 23 | images_per_batch >= num_workers 24 | ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format( 25 | images_per_batch, num_workers 26 | ) 27 | images_per_worker = images_per_batch // num_workers 28 | 29 | dataset_dicts = get_detection_dataset_dicts( 30 | cfg.DATASETS.TRAIN, 31 | filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, 32 | min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE 33 | if cfg.MODEL.KEYPOINT_ON 34 | else 0, 35 | proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, 36 | ) 37 | dataset = DatasetFromList(dataset_dicts, copy=False) 38 | 39 | if mapper is None: 40 | mapper = AttributeDatasetMapper(cfg, True) 41 | dataset = MapDataset(dataset, mapper) 42 | 43 | sampler_name = cfg.DATALOADER.SAMPLER_TRAIN 44 | logger = logging.getLogger(__name__) 45 | logger.info("Using training sampler {}".format(sampler_name)) 46 | if sampler_name == "TrainingSampler": 47 | sampler = samplers.TrainingSampler(len(dataset)) 48 | elif sampler_name == "RepeatFactorTrainingSampler": 49 | sampler = samplers.RepeatFactorTrainingSampler( 50 | dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD 51 | ) 52 | else: 53 | raise ValueError("Unknown training sampler: {}".format(sampler_name)) 54 | 55 | if cfg.DATALOADER.ASPECT_RATIO_GROUPING: 56 | data_loader = torch.utils.data.DataLoader( 57 | dataset, 58 | sampler=sampler, 59 | num_workers=cfg.DATALOADER.NUM_WORKERS, 60 | batch_sampler=None, 61 | collate_fn=operator.itemgetter(0), 62 | worker_init_fn=worker_init_reset_seed, 63 | ) 64 | data_loader = AspectRatioGroupedDataset(data_loader, images_per_worker) 65 | else: 66 | batch_sampler = torch.utils.data.sampler.BatchSampler( 67 | sampler, images_per_worker, 
drop_last=True 68 | ) 69 | data_loader = torch.utils.data.DataLoader( 70 | dataset, 71 | num_workers=cfg.DATALOADER.NUM_WORKERS, 72 | batch_sampler=batch_sampler, 73 | collate_fn=trivial_batch_collator, 74 | worker_init_fn=worker_init_reset_seed, 75 | ) 76 | 77 | return data_loader 78 | 79 | 80 | def build_detection_test_loader_with_attributes(cfg, dataset_name, mapper=None): 81 | dataset_dicts = get_detection_dataset_dicts( 82 | [dataset_name], 83 | filter_empty=False, 84 | proposal_files=[ 85 | cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)] 86 | ] 87 | if cfg.MODEL.LOAD_PROPOSALS 88 | else None, 89 | ) 90 | 91 | dataset = DatasetFromList(dataset_dicts) 92 | if mapper is None: 93 | mapper = AttributeDatasetMapper(cfg, False) 94 | dataset = MapDataset(dataset, mapper) 95 | 96 | sampler = samplers.InferenceSampler(len(dataset)) 97 | batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False) 98 | 99 | data_loader = torch.utils.data.DataLoader( 100 | dataset, 101 | num_workers=cfg.DATALOADER.NUM_WORKERS, 102 | batch_sampler=batch_sampler, 103 | collate_fn=trivial_batch_collator, 104 | ) 105 | return data_loader -------------------------------------------------------------------------------- /train_net.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | 4 | """ 5 | Grid features pre-training script. 6 | 7 | This script is a simplified version of the training script in detectron2/tools. 8 | """ 9 | 10 | import os 11 | import time 12 | import torch 13 | 14 | import detectron2.utils.comm as comm 15 | from detectron2.checkpoint import DetectionCheckpointer 16 | from detectron2.config import get_cfg 17 | from detectron2.data import MetadataCatalog 18 | from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch 19 | from detectron2.evaluation import COCOEvaluator, DatasetEvaluators, verify_results 20 | 21 | from grid_feats import ( 22 | add_attribute_config, 23 | build_detection_train_loader_with_attributes, 24 | build_detection_test_loader_with_attributes, 25 | ) 26 | 27 | 28 | class Trainer(DefaultTrainer): 29 | """ 30 | A trainer for visual genome dataset. 31 | """ 32 | def __init__(self, cfg): 33 | super().__init__(cfg) 34 | self.rpn_box_lw = cfg.MODEL.RPN.BBOX_LOSS_WEIGHT 35 | self.rcnn_box_lw = cfg.MODEL.ROI_BOX_HEAD.BBOX_LOSS_WEIGHT 36 | 37 | @classmethod 38 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 39 | if output_folder is None: 40 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 41 | evaluator_list = [] 42 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 43 | if evaluator_type == "coco": 44 | return COCOEvaluator(dataset_name, cfg, True, output_folder) 45 | if len(evaluator_list) == 0: 46 | raise NotImplementedError( 47 | "no Evaluator for the dataset {} with the type {}".format( 48 | dataset_name, evaluator_type 49 | ) 50 | ) 51 | if len(evaluator_list) == 1: 52 | return evaluator_list[0] 53 | return DatasetEvaluators(evaluator_list) 54 | 55 | @classmethod 56 | def build_train_loader(cls, cfg): 57 | return build_detection_train_loader_with_attributes(cfg) 58 | 59 | @classmethod 60 | def build_test_loader(cls, cfg, dataset_name): 61 | return build_detection_test_loader_with_attributes(cfg, dataset_name) 62 | 63 | def run_step(self): 64 | """ 65 | !!Hack!! 
for the run_step method in SimpleTrainer to adjust the loss 66 | """ 67 | assert self.model.training, "[Trainer] model was changed to eval mode!" 68 | start = time.perf_counter() 69 | data = next(self._data_loader_iter) 70 | data_time = time.perf_counter() - start 71 | loss_dict = self.model(data) 72 | # RPN box loss: 73 | loss_dict["loss_rpn_loc"] *= self.rpn_box_lw 74 | # R-CNN box loss: 75 | loss_dict["loss_box_reg"] *= self.rcnn_box_lw 76 | losses = sum(loss_dict.values()) 77 | self._detect_anomaly(losses, loss_dict) 78 | 79 | metrics_dict = loss_dict 80 | metrics_dict["data_time"] = data_time 81 | self._write_metrics(metrics_dict) 82 | self.optimizer.zero_grad() 83 | losses.backward() 84 | self.optimizer.step() 85 | 86 | 87 | def setup(args): 88 | """ 89 | Create configs and perform basic setups. 90 | """ 91 | cfg = get_cfg() 92 | add_attribute_config(cfg) 93 | cfg.merge_from_file(args.config_file) 94 | cfg.merge_from_list(args.opts) 95 | cfg.freeze() 96 | default_setup(cfg, args) 97 | return cfg 98 | 99 | 100 | def main(args): 101 | cfg = setup(args) 102 | 103 | if args.eval_only: 104 | model = Trainer.build_model(cfg) 105 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 106 | cfg.MODEL.WEIGHTS, resume=args.resume 107 | ) 108 | res = Trainer.test(cfg, model) 109 | if comm.is_main_process(): 110 | verify_results(cfg, res) 111 | return res 112 | 113 | trainer = Trainer(cfg) 114 | trainer.resume_or_load(resume=args.resume) 115 | return trainer.train() 116 | 117 | 118 | if __name__ == "__main__": 119 | args = default_argument_parser().parse_args() 120 | print("Command Line Args:", args) 121 | launch( 122 | main, 123 | args.num_gpus, 124 | num_machines=args.num_machines, 125 | machine_rank=args.machine_rank, 126 | dist_url=args.dist_url, 127 | args=(args,), 128 | ) 129 | -------------------------------------------------------------------------------- /grid_feats/visual_genome.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 3 | import contextlib 4 | import io 5 | import logging 6 | import os 7 | from fvcore.common.file_io import PathManager 8 | from fvcore.common.timer import Timer 9 | 10 | from detectron2.data import DatasetCatalog, MetadataCatalog 11 | from detectron2.structures import BoxMode 12 | 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | def load_coco_with_attributes_json(json_file, 17 | image_root, 18 | dataset_name=None, 19 | extra_annotation_keys=None): 20 | """ 21 | Extend load_coco_json() with additional support for attributes 22 | """ 23 | from pycocotools.coco import COCO 24 | 25 | timer = Timer() 26 | json_file = PathManager.get_local_path(json_file) 27 | with contextlib.redirect_stdout(io.StringIO()): 28 | coco_api = COCO(json_file) 29 | if timer.seconds() > 1: 30 | logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) 31 | 32 | id_map = None 33 | if dataset_name is not None: 34 | meta = MetadataCatalog.get(dataset_name) 35 | cat_ids = sorted(coco_api.getCatIds()) 36 | cats = coco_api.loadCats(cat_ids) 37 | thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])] 38 | meta.thing_classes = thing_classes 39 | if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)): 40 | if "coco" not in dataset_name: 41 | logger.warning( 42 | """ 43 | Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you. 
44 | """ 45 | ) 46 | id_map = {v: i for i, v in enumerate(cat_ids)} 47 | meta.thing_dataset_id_to_contiguous_id = id_map 48 | 49 | img_ids = sorted(coco_api.imgs.keys()) 50 | imgs = coco_api.loadImgs(img_ids) 51 | anns = [coco_api.imgToAnns[img_id] for img_id in img_ids] 52 | 53 | if "minival" not in json_file: 54 | ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] 55 | assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format( 56 | json_file 57 | ) 58 | 59 | imgs_anns = list(zip(imgs, anns)) 60 | 61 | logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file)) 62 | 63 | dataset_dicts = [] 64 | 65 | ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or []) 66 | 67 | num_instances_without_valid_segmentation = 0 68 | 69 | for (img_dict, anno_dict_list) in imgs_anns: 70 | record = {} 71 | record["file_name"] = os.path.join(image_root, img_dict["file_name"]) 72 | record["height"] = img_dict["height"] 73 | record["width"] = img_dict["width"] 74 | image_id = record["image_id"] = img_dict["id"] 75 | 76 | objs = [] 77 | for anno in anno_dict_list: 78 | assert anno["image_id"] == image_id 79 | 80 | assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.' 81 | 82 | obj = {key: anno[key] for key in ann_keys if key in anno} 83 | 84 | segm = anno.get("segmentation", None) 85 | if segm: 86 | if not isinstance(segm, dict): 87 | segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] 88 | if len(segm) == 0: 89 | num_instances_without_valid_segmentation += 1 90 | continue 91 | obj["segmentation"] = segm 92 | 93 | keypts = anno.get("keypoints", None) 94 | if keypts: 95 | for idx, v in enumerate(keypts): 96 | if idx % 3 != 2: 97 | keypts[idx] = v + 0.5 98 | obj["keypoints"] = keypts 99 | 100 | attrs = anno.get("attribute_ids", None) 101 | if attrs: # list[int] 102 | obj["attribute_ids"] = attrs 103 | 104 | obj["bbox_mode"] = BoxMode.XYWH_ABS 105 | if id_map: 106 | obj["category_id"] = id_map[obj["category_id"]] 107 | objs.append(obj) 108 | record["annotations"] = objs 109 | dataset_dicts.append(record) 110 | 111 | if num_instances_without_valid_segmentation > 0: 112 | logger.warning( 113 | "Filtered out {} instances without valid segmentation. 
" 114 | "There might be issues in your dataset generation process.".format( 115 | num_instances_without_valid_segmentation 116 | ) 117 | ) 118 | return dataset_dicts 119 | 120 | def register_coco_instances_with_attributes(name, metadata, json_file, image_root): 121 | DatasetCatalog.register(name, lambda: load_coco_with_attributes_json(json_file, 122 | image_root, 123 | name)) 124 | MetadataCatalog.get(name).set( 125 | json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata 126 | ) 127 | 128 | # ==== Predefined splits for visual genome images =========== 129 | _PREDEFINED_SPLITS_VG = { 130 | "visual_genome_train": ("visual_genome/images", 131 | "visual_genome/annotations/visual_genome_train.json"), 132 | "visual_genome_val": ("visual_genome/images", 133 | "visual_genome/annotations/visual_genome_val.json"), 134 | "visual_genome_test": ("visual_genome/images", 135 | "visual_genome/annotations/visual_genome_test.json"), 136 | } 137 | 138 | def register_all_vg(root): 139 | for key, (image_root, json_file) in _PREDEFINED_SPLITS_VG.items(): 140 | register_coco_instances_with_attributes( 141 | key, 142 | {}, # no meta data 143 | os.path.join(root, json_file), 144 | os.path.join(root, image_root), 145 | ) 146 | 147 | # Register them all under "./datasets" 148 | _root = os.getenv("DETECTRON2_DATASETS", "datasets") 149 | register_all_vg(_root) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # In Defense of Grid Features for Visual Question Answering 2 | **Grid Feature Pre-Training Code** 3 | 4 |
 5 | 6 |
 7 | 8 | This is a feature pre-training code release of the [paper](https://arxiv.org/abs/2001.03615): 9 | ``` 10 | @InProceedings{jiang2020defense, 11 | title={In Defense of Grid Features for Visual Question Answering}, 12 | author={Jiang, Huaizu and Misra, Ishan and Rohrbach, Marcus and Learned-Miller, Erik and Chen, Xinlei}, 13 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 14 | year={2020} 15 | } 16 | ``` 17 | For more sustained maintenance, we release code using [Detectron2](https://github.com/facebookresearch/detectron2) instead of [mask-rcnn-benchmark](https://github.com/facebookresearch/maskrcnn-benchmark), on which the original code was based. The current repository should reproduce the results reported in the paper, *e.g.*, reporting **~72.5** single-model VQA score for an X-101 backbone paired with [MCAN](https://github.com/MILVLG/mcan-vqa)-large. 18 | 19 | ## Installation 20 | Install Detectron2 following [INSTALL.md](https://github.com/facebookresearch/detectron2/blob/master/INSTALL.md). Since Detectron2 is actively updated, which can result in breaking behaviors, it is **highly recommended** to install via the following command: 21 | ```bash 22 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git@ffff8ac' 23 | ``` 24 | Commits before or after `ffff8ac` might also work, but it could be risky. 25 | Then clone this repository: 26 | ```bash 27 | git clone git@github.com:facebookresearch/grid-feats-vqa.git 28 | cd grid-feats-vqa 29 | ``` 30 | 31 | ## Data 32 | [Visual Genome](http://visualgenome.org/) `train+val` splits released from the bottom-up-attention [code](https://github.com/peteanderson80/bottom-up-attention) are used for pre-training, and the `test` split is used for evaluating detection performance. All of them are prepared in [COCO](http://cocodataset.org/) format but include an additional field for `attribute` prediction. We provide the `.json` files [here](https://dl.fbaipublicfiles.com/grid-feats-vqa/json/visual_genome.tgz), which can be directly loaded by Detectron2. Same as in Detectron2, the expected dataset structure under the `DETECTRON2_DATASETS` (default is `./datasets` relative to your current working directory) folder should be: 33 | ``` 34 | visual_genome/ 35 | annotations/ 36 | visual_genome_{train,val,test}.json 37 | images/ 38 | # visual genome images (~108K) 39 | ``` 40 | 41 | ## Training 42 | Once the dataset is set up, to train a model, run (by default we use 8 GPUs): 43 | ```bash 44 | python train_net.py --num-gpus 8 --config-file <config.yaml> 45 | ``` 46 | For example, to launch grid-feature pre-training with a ResNet-50 backbone on 8 GPUs, one should execute: 47 | ```bash 48 | python train_net.py --num-gpus 8 --config-file configs/R-50-grid.yaml 49 | ``` 50 | The final model is by default saved under `./output` of your current working directory once training is done. Note that we use a `0.2` attribute loss (`MODEL.ROI_ATTRIBUTE_HEAD.LOSS_WEIGHT = 0.2`), which is better for down-stream tasks like VQA per our analysis. 51 | 52 | We also release the configuration (`configs/R-50-updn.yaml`) for training the region features described in the **bottom-up-attention** paper, which is a faithful re-implementation of the original [one](https://github.com/peteanderson80/bottom-up-attention) in Detectron2.
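Any config option can also be overridden from the command line by appending `KEY VALUE` pairs after the named arguments; they are applied through `cfg.merge_from_list(args.opts)` in `train_net.py`. For instance, a purely illustrative run that sets the attribute loss weight explicitly and redirects the output directory might look like:
```bash
# illustrative only: both overrides are optional knobs, not required settings
python train_net.py --num-gpus 8 --config-file configs/R-50-grid.yaml \
    MODEL.ROI_ATTRIBUTE_HEAD.LOSS_WEIGHT 0.2 \
    OUTPUT_DIR ./output/r50-grid
```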
53 | 54 | ## Feature Extraction 55 | Once the model is trained (or you can directly download our pre-trained models, see below), grid feature extraction can be done by simply running: 56 | ```bash 57 | python extract_grid_feature.py --config-file configs/R-50-grid.yaml --dataset <dataset> 58 | ``` 59 | The code will load the final model from `cfg.OUTPUT_DIR` (which one can override on the command line) and start extracting features for `<dataset>`. We provide three options for `<dataset>`: `coco_2014_train`, `coco_2014_val` and `coco_2015_test`; they correspond to the `train`, `val` and `test` splits of the VQA dataset. The extracted features can be conveniently loaded in [Pythia](https://github.com/facebookresearch/pythia). 60 | 61 | To extract features on your customized dataset, you may want to dump the image information into [COCO](http://cocodataset.org/) `.json` format and register the dataset so that `extract_grid_feature.py` can use it, or you can hack `extract_grid_feature.py` and directly loop over images. 62 | 63 | ## Pre-Trained Models and Features 64 | We release several pre-trained models for grid features: one with an R-50 backbone, one with X-101, one with X-152, and one with additional improvements used for the 2020 VQA Challenge (see `X-152-challenge.yaml`). The models can be used directly to extract features. For your convenience, we also release the pre-extracted features for direct download. 65 | 66 | | Backbone | AP50:95 | Download | 67 | | -------- | ---- | -------- | 68 | | R-50 | 3.1 | model \|  metrics \|  features | 69 | | X-101 | 4.3 | model \|  metrics \|  features | 70 | | X-152 | 4.7 | model \|  metrics \|  features | 71 | | X-152++ | 3.7 | model \|  metrics \|  features | 72 | 73 | ## License 74 | 75 | The code is released under the [Apache 2.0 license](LICENSE). -------------------------------------------------------------------------------- /grid_feats/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | import copy 3 | import logging 4 | import numpy as np 5 | import torch 6 | from fvcore.common.file_io import PathManager 7 | from PIL import Image 8 | import pycocotools.mask as mask_util  # used below to decode COCO-style RLE segmentations 9 | from detectron2.data import detection_utils as utils 10 | from detectron2.data import transforms as T 11 | from detectron2.data import DatasetMapper 12 | from detectron2.structures import ( 13 | BitMasks, 14 | Boxes, 15 | BoxMode, 16 | Instances, 17 | Keypoints, 18 | PolygonMasks, 19 | polygons_to_bitmask, 20 | ) 21 | 22 | 23 | def annotations_to_instances_with_attributes(annos, 24 | image_size, 25 | mask_format="polygon", 26 | load_attributes=False, 27 | max_attr_per_ins=16): 28 | """ 29 | Extend the function annotations_to_instances() to support attributes 30 | """ 31 | boxes = [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos] 32 | target = Instances(image_size) 33 | boxes = target.gt_boxes = Boxes(boxes) 34 | boxes.clip(image_size) 35 | 36 | classes = [obj["category_id"] for obj in annos] 37 | classes = torch.tensor(classes, dtype=torch.int64) 38 | target.gt_classes = classes 39 | 40 | if len(annos) and "segmentation" in annos[0]: 41 | segms = [obj["segmentation"] for obj in annos] 42 | if mask_format == "polygon": 43 | masks = PolygonMasks(segms) 44 | else: 45 | assert mask_format == "bitmask", mask_format 46 | masks = [] 47 | for segm in segms: 48 | if isinstance(segm, list): 49 | # polygon 50 | masks.append(polygons_to_bitmask(segm, *image_size)) 51 | elif isinstance(segm, dict): 52 | # COCO RLE 53 | masks.append(mask_util.decode(segm)) 54 | elif isinstance(segm, np.ndarray): 55 | assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format( 56 | segm.ndim 57 | ) 58 | # mask array 59 | masks.append(segm) 60 | else: 61 | raise ValueError( 62 | "Cannot convert segmentation of type '{}' to BitMasks!" 63 | "Supported types are: polygons as list[list[float] or ndarray]," 64 | " COCO-style RLE as a dict, or a full-image segmentation mask " 65 | "as a 2D ndarray.".format(type(segm)) 66 | ) 67 | masks = BitMasks( 68 | torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks]) 69 | ) 70 | target.gt_masks = masks 71 | 72 | if len(annos) and "keypoints" in annos[0]: 73 | kpts = [obj.get("keypoints", []) for obj in annos] 74 | target.gt_keypoints = Keypoints(kpts) 75 | 76 | if len(annos) and load_attributes: 77 | attributes = -torch.ones((len(annos), max_attr_per_ins), dtype=torch.int64) 78 | for idx, anno in enumerate(annos): 79 | if "attribute_ids" in anno: 80 | for jdx, attr_id in enumerate(anno["attribute_ids"]): 81 | attributes[idx, jdx] = attr_id 82 | target.gt_attributes = attributes 83 | 84 | return target 85 | 86 | 87 | class AttributeDatasetMapper(DatasetMapper): 88 | """ 89 | Extend DatasetMapper to support attributes.
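When MODEL.ATTRIBUTE_ON is set, each returned `Instances` object additionally carries a
`gt_attributes` tensor of shape (num_instances, INPUT.MAX_ATTR_PER_INS), padded with -1
(see annotations_to_instances_with_attributes above).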
90 | """ 91 | def __init__(self, cfg, is_train=True): 92 | super().__init__(cfg, is_train) 93 | 94 | # fmt: off 95 | self.attribute_on = cfg.MODEL.ATTRIBUTE_ON 96 | self.max_attr_per_ins = cfg.INPUT.MAX_ATTR_PER_INS 97 | # fmt: on 98 | 99 | def __call__(self, dataset_dict): 100 | dataset_dict = copy.deepcopy(dataset_dict) 101 | image = utils.read_image(dataset_dict["file_name"], format=self.img_format) 102 | utils.check_image_size(dataset_dict, image) 103 | 104 | if "annotations" not in dataset_dict: 105 | image, transforms = T.apply_transform_gens( 106 | ([self.crop_gen] if self.crop_gen else []) + self.tfm_gens, image 107 | ) 108 | else: 109 | if self.crop_gen: 110 | crop_tfm = utils.gen_crop_transform_with_instance( 111 | self.crop_gen.get_crop_size(image.shape[:2]), 112 | image.shape[:2], 113 | np.random.choice(dataset_dict["annotations"]), 114 | ) 115 | image = crop_tfm.apply_image(image) 116 | image, transforms = T.apply_transform_gens(self.tfm_gens, image) 117 | if self.crop_gen: 118 | transforms = crop_tfm + transforms 119 | 120 | image_shape = image.shape[:2] 121 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 122 | 123 | if self.load_proposals: 124 | utils.transform_proposals( 125 | dataset_dict, image_shape, transforms, self.min_box_side_len, self.proposal_topk 126 | ) 127 | 128 | if not self.is_train: 129 | dataset_dict.pop("annotations", None) 130 | dataset_dict.pop("sem_seg_file_name", None) 131 | return dataset_dict 132 | 133 | if "annotations" in dataset_dict: 134 | for anno in dataset_dict["annotations"]: 135 | if not self.mask_on: 136 | anno.pop("segmentation", None) 137 | if not self.keypoint_on: 138 | anno.pop("keypoints", None) 139 | if not self.attribute_on: 140 | anno.pop("attribute_ids") 141 | 142 | annos = [ 143 | utils.transform_instance_annotations( 144 | obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices 145 | ) 146 | for obj in dataset_dict.pop("annotations") 147 | if obj.get("iscrowd", 0) == 0 148 | ] 149 | instances = annotations_to_instances_with_attributes( 150 | annos, image_shape, mask_format=self.mask_format, 151 | load_attributes=self.attribute_on, max_attr_per_ins=self.max_attr_per_ins 152 | ) 153 | if self.crop_gen and instances.has("gt_masks"): 154 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 155 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 156 | 157 | if "sem_seg_file_name" in dataset_dict: 158 | with PathManager.open(dataset_dict.pop("sem_seg_file_name"), "rb") as f: 159 | sem_seg_gt = Image.open(f) 160 | sem_seg_gt = np.asarray(sem_seg_gt, dtype="uint8") 161 | sem_seg_gt = transforms.apply_segmentation(sem_seg_gt) 162 | sem_seg_gt = torch.as_tensor(sem_seg_gt.astype("long")) 163 | dataset_dict["sem_seg"] = sem_seg_gt 164 | return dataset_dict 165 | -------------------------------------------------------------------------------- /grid_feats/roi_heads.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | 6 | from detectron2.layers import ShapeSpec 7 | from detectron2.modeling.roi_heads import ( 8 | build_box_head, 9 | build_mask_head, 10 | select_foreground_proposals, 11 | ROI_HEADS_REGISTRY, 12 | ROIHeads, 13 | Res5ROIHeads, 14 | StandardROIHeads, 15 | ) 16 | from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers 17 | from detectron2.modeling.poolers import ROIPooler 18 | from detectron2.structures import Boxes  # used in AttributeStandardROIHeads._forward_box 19 | 20 | class AttributePredictor(nn.Module): 21 | """ 22 | Head for attribute prediction, including feature/score computation and 23 | loss computation. 24 | 25 | """ 26 | def __init__(self, cfg, input_dim): 27 | super().__init__() 28 | 29 | # fmt: off 30 | self.num_objs = cfg.MODEL.ROI_HEADS.NUM_CLASSES 31 | self.obj_embed_dim = cfg.MODEL.ROI_ATTRIBUTE_HEAD.OBJ_EMBED_DIM 32 | self.fc_dim = cfg.MODEL.ROI_ATTRIBUTE_HEAD.FC_DIM 33 | self.num_attributes = cfg.MODEL.ROI_ATTRIBUTE_HEAD.NUM_CLASSES 34 | self.max_attr_per_ins = cfg.INPUT.MAX_ATTR_PER_INS 35 | self.loss_weight = cfg.MODEL.ROI_ATTRIBUTE_HEAD.LOSS_WEIGHT 36 | # fmt: on 37 | 38 | # object class embedding, including the background class 39 | self.obj_embed = nn.Embedding(self.num_objs + 1, self.obj_embed_dim) 40 | input_dim += self.obj_embed_dim 41 | self.fc = nn.Sequential( 42 | nn.Linear(input_dim, self.fc_dim), 43 | nn.ReLU() 44 | ) 45 | self.attr_score = nn.Linear(self.fc_dim, self.num_attributes) 46 | nn.init.normal_(self.attr_score.weight, std=0.01) 47 | nn.init.constant_(self.attr_score.bias, 0) 48 | 49 | def forward(self, x, obj_labels): 50 | attr_feat = torch.cat((x, self.obj_embed(obj_labels)), dim=1) 51 | return self.attr_score(self.fc(attr_feat)) 52 | 53 | def loss(self, score, label): 54 | n = score.shape[0] 55 | score = score.unsqueeze(1) 56 | score = score.expand(n, self.max_attr_per_ins, self.num_attributes).contiguous() 57 | score = score.view(-1, self.num_attributes) 58 | inv_weights = ( 59 | (label >= 0).sum(dim=1).repeat(self.max_attr_per_ins, 1).transpose(0, 1).flatten() 60 | ) 61 | weights = inv_weights.float().reciprocal() 62 | weights[weights > 1] = 0. 63 | n_valid = len((label >= 0).sum(dim=1).nonzero()) 64 | label = label.view(-1) 65 | attr_loss = F.cross_entropy(score, label, reduction="none", ignore_index=-1) 66 | attr_loss = (attr_loss * weights).view(n, -1).sum(dim=1) 67 | 68 | if n_valid > 0: 69 | attr_loss = attr_loss.sum() * self.loss_weight / n_valid 70 | else: 71 | attr_loss = attr_loss.sum() * 0. 72 | return {"loss_attr": attr_loss} 73 | 74 | 75 | class AttributeROIHeads(ROIHeads): 76 | """ 77 | An extension of ROIHeads to include attribute prediction. 78 | """ 79 | def forward_attribute_loss(self, proposals, box_features): 80 | proposals, fg_selection_attributes = select_foreground_proposals( 81 | proposals, self.num_classes 82 | ) 83 | attribute_features = box_features[torch.cat(fg_selection_attributes, dim=0)] 84 | obj_labels = torch.cat([p.gt_classes for p in proposals]) 85 | attribute_labels = torch.cat([p.gt_attributes for p in proposals], dim=0) 86 | attribute_scores = self.attribute_predictor(attribute_features, obj_labels) 87 | return self.attribute_predictor.loss(attribute_scores, attribute_labels) 88 | 89 | 90 | @ROI_HEADS_REGISTRY.register() 91 | class AttributeRes5ROIHeads(AttributeROIHeads, Res5ROIHeads): 92 | """ 93 | An extension of Res5ROIHeads to include attribute prediction.
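This is the head used by the region-feature ("updn") configs: Base-RCNN-updn.yaml sets
ROI_HEADS.NAME to "AttributeRes5ROIHeads". get_conv5_features() below applies the res5
block to the backbone feature map; extract_grid_feature.py uses that hook when dumping
features.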
94 | """ 95 | def __init__(self, cfg, input_shape): 96 | super(Res5ROIHeads, self).__init__(cfg, input_shape) 97 | 98 | assert len(self.in_features) == 1 99 | 100 | # fmt: off 101 | pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION 102 | pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE 103 | pooler_scales = (1.0 / input_shape[self.in_features[0]].stride, ) 104 | sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO 105 | self.mask_on = cfg.MODEL.MASK_ON 106 | self.attribute_on = cfg.MODEL.ATTRIBUTE_ON 107 | # fmt: on 108 | assert not cfg.MODEL.KEYPOINT_ON 109 | 110 | self.pooler = ROIPooler( 111 | output_size=pooler_resolution, 112 | scales=pooler_scales, 113 | sampling_ratio=sampling_ratio, 114 | pooler_type=pooler_type, 115 | ) 116 | 117 | self.res5, out_channels = self._build_res5_block(cfg) 118 | self.box_predictor = FastRCNNOutputLayers( 119 | cfg, ShapeSpec(channels=out_channels, height=1, width=1) 120 | ) 121 | 122 | if self.mask_on: 123 | self.mask_head = build_mask_head( 124 | cfg, 125 | ShapeSpec(channels=out_channels, width=pooler_resolution, height=pooler_resolution), 126 | ) 127 | 128 | if self.attribute_on: 129 | self.attribute_predictor = AttributePredictor(cfg, out_channels) 130 | 131 | def forward(self, images, features, proposals, targets=None): 132 | del images 133 | 134 | if self.training: 135 | assert targets 136 | proposals = self.label_and_sample_proposals(proposals, targets) 137 | del targets 138 | 139 | proposal_boxes = [x.proposal_boxes for x in proposals] 140 | box_features = self._shared_roi_transform( 141 | [features[f] for f in self.in_features], proposal_boxes 142 | ) 143 | feature_pooled = box_features.mean(dim=[2, 3]) 144 | predictions = self.box_predictor(feature_pooled) 145 | 146 | if self.training: 147 | del features 148 | losses = self.box_predictor.losses(predictions, proposals) 149 | if self.mask_on: 150 | proposals, fg_selection_masks = select_foreground_proposals( 151 | proposals, self.num_classes 152 | ) 153 | mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] 154 | del box_features 155 | losses.update(self.mask_head(mask_features, proposals)) 156 | if self.attribute_on: 157 | losses.update(self.forward_attribute_loss(proposals, feature_pooled)) 158 | return [], losses 159 | else: 160 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 161 | pred_instances = self.forward_with_given_boxes(features, pred_instances) 162 | return pred_instances, {} 163 | 164 | def get_conv5_features(self, features): 165 | features = [features[f] for f in self.in_features] 166 | return self.res5(features[0]) 167 | 168 | 169 | @ROI_HEADS_REGISTRY.register() 170 | class AttributeStandardROIHeads(AttributeROIHeads, StandardROIHeads): 171 | """ 172 | An extension of StandardROIHeads to include attribute prediction. 
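This is the head used by the grid-feature configs: Base-RCNN-grid.yaml sets ROI_HEADS.NAME
to "AttributeStandardROIHeads" with a single "res5" input feature and a 1x1 box pooler, and
get_conv5_features() below simply returns that feature map, which is what
extract_grid_feature.py saves as grid features.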
173 | """ 174 | def __init__(self, cfg, input_shape): 175 | super(StandardROIHeads, self).__init__(cfg, input_shape) 176 | self._init_box_head(cfg, input_shape) 177 | self._init_mask_head(cfg, input_shape) 178 | self._init_keypoint_head(cfg, input_shape) 179 | 180 | def _init_box_head(self, cfg, input_shape): 181 | # fmt: off 182 | pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION 183 | pooler_scales = tuple(1.0 / input_shape[k].stride for k in self.in_features) 184 | sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO 185 | pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE 186 | self.train_on_pred_boxes = cfg.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES 187 | self.attribute_on = cfg.MODEL.ATTRIBUTE_ON 188 | # fmt: on 189 | 190 | in_channels = [input_shape[f].channels for f in self.in_features] 191 | assert len(set(in_channels)) == 1, in_channels 192 | in_channels = in_channels[0] 193 | 194 | self.box_pooler = ROIPooler( 195 | output_size=pooler_resolution, 196 | scales=pooler_scales, 197 | sampling_ratio=sampling_ratio, 198 | pooler_type=pooler_type, 199 | ) 200 | self.box_head = build_box_head( 201 | cfg, ShapeSpec(channels=in_channels, height=pooler_resolution, width=pooler_resolution) 202 | ) 203 | self.box_predictor = FastRCNNOutputLayers(cfg, self.box_head.output_shape) 204 | 205 | if self.attribute_on: 206 | self.attribute_predictor = AttributePredictor(cfg, self.box_head.output_shape.channels) 207 | 208 | def _forward_box(self, features, proposals): 209 | features = [features[f] for f in self.in_features] 210 | box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals]) 211 | box_features = self.box_head(box_features) 212 | predictions = self.box_predictor(box_features) 213 | 214 | if self.training: 215 | if self.train_on_pred_boxes: 216 | with torch.no_grad(): 217 | pred_boxes = self.box_predictor.predict_boxes_for_gt_classes( 218 | predictions, proposals 219 | ) 220 | for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes): 221 | proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image) 222 | losses = self.box_predictor.losses(predictions, proposals) 223 | if self.attribute_on: 224 | losses.update(self.forward_attribute_loss(proposals, box_features)) 225 | del box_features 226 | 227 | return losses 228 | else: 229 | pred_instances, _ = self.box_predictor.inference(predictions, proposals) 230 | return pred_instances 231 | 232 | def get_conv5_features(self, features): 233 | assert len(self.in_features) == 1 234 | 235 | features = [features[f] for f in self.in_features] 236 | return features[0] 237 | 238 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 
134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 
193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. --------------------------------------------------------------------------------