├── defrcn ├── __init__.py ├── evaluation │ ├── archs │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── resnet.cpython-37.pyc │ │ └── resnet.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── evaluator.cpython-37.pyc │ │ ├── testing.cpython-37.pyc │ │ ├── coco_evaluation.cpython-37.pyc │ │ ├── calibration_layer.cpython-37.pyc │ │ └── pascal_voc_evaluation.cpython-37.pyc │ ├── __init__.py │ ├── testing.py │ ├── evaluator.py │ ├── calibration_layer.py │ ├── coco_evaluation.py │ └── pascal_voc_evaluation.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── builtin.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ ├── meta_coco.cpython-37.pyc │ │ ├── meta_voc.cpython-37.pyc │ │ ├── pcb_common.cpython-37.pyc │ │ └── builtin_meta.cpython-37.pyc │ ├── pcb_common.py │ ├── meta_coco.py │ ├── builtin.py │ ├── meta_voc.py │ └── builtin_meta.py ├── modeling │ ├── meta_arch │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── gdl.cpython-37.pyc │ │ │ ├── build.cpython-37.pyc │ │ │ ├── rcnn.cpython-37.pyc │ │ │ └── __init__.cpython-37.pyc │ │ ├── build.py │ │ ├── gdl.py │ │ └── rcnn.py │ ├── __pycache__ │ │ └── __init__.cpython-37.pyc │ ├── roi_heads │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ ├── box_head.cpython-37.pyc │ │ │ ├── fast_rcnn.cpython-37.pyc │ │ │ └── roi_heads.cpython-37.pyc │ │ ├── __init__.py │ │ └── box_head.py │ └── __init__.py ├── __pycache__ │ └── __init__.cpython-37.pyc ├── checkpoint │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── detection_checkpoint.cpython-37.pyc │ └── detection_checkpoint.py ├── engine │ ├── __pycache__ │ │ ├── hooks.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── defaults.cpython-37.pyc │ ├── __init__.py │ └── hooks.py ├── solver │ ├── __pycache__ │ │ ├── build.cpython-37.pyc │ │ ├── __init__.cpython-37.pyc │ │ └── lr_scheduler.cpython-37.pyc │ ├── __init__.py │ ├── lr_scheduler.py │ └── build.py ├── utils │ ├── __pycache__ │ │ └── kdloss.cpython-37.pyc │ └── kdloss.py ├── config │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ ├── compat.cpython-37.pyc │ │ ├── config.cpython-37.pyc │ │ └── defaults.cpython-37.pyc │ ├── __init__.py │ ├── defaults.py │ ├── config.py │ └── compat.py └── dataloader │ ├── __pycache__ │ ├── build.cpython-37.pyc │ ├── __init__.cpython-37.pyc │ └── dataset_mapper.cpython-37.pyc │ ├── __init__.py │ ├── dataset_mapper.py │ └── build.py ├── requirements.txt ├── .DS_Store ├── assets ├── arch.png └── header.png ├── configs ├── Base-RCNN.yaml ├── coco │ ├── defrcn_det_r101_base.yaml │ ├── defrcn_gfsod_r101_novel_1shot_seedx.yaml │ ├── defrcn_gfsod_r101_novel_2shot_seedx.yaml │ ├── defrcn_gfsod_r101_novel_3shot_seedx.yaml │ ├── defrcn_gfsod_r101_novel_5shot_seedx.yaml │ ├── defrcn_gfsod_r101_novel_10shot_seedx.yaml │ └── defrcn_gfsod_r101_novel_30shot_seedx.yaml └── voc │ ├── defrcn_det_r101_base1.yaml │ ├── defrcn_gfsod_r101_novelx_10shot_seedx.yaml │ ├── defrcn_gfsod_r101_novelx_1shot_seedx.yaml │ ├── defrcn_gfsod_r101_novelx_2shot_seedx.yaml │ ├── defrcn_gfsod_r101_novelx_3shot_seedx.yaml │ └── defrcn_gfsod_r101_novelx_5shot_seedx.yaml ├── LICENSE ├── run_coco.sh ├── run_voc.sh ├── tools ├── extract_results.py ├── create_config.py └── model_surgery.py ├── main.py └── README.md /defrcn/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | opencv-python 2 | sklearn 3 | -------------------------------------------------------------------------------- /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/.DS_Store -------------------------------------------------------------------------------- /defrcn/evaluation/archs/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet import resnet101 2 | -------------------------------------------------------------------------------- /assets/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/assets/arch.png -------------------------------------------------------------------------------- /assets/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/assets/header.png -------------------------------------------------------------------------------- /defrcn/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .builtin import register_all_voc, register_all_coco 2 | -------------------------------------------------------------------------------- /defrcn/modeling/meta_arch/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import META_ARCH_REGISTRY, build_model 2 | from .rcnn import GeneralizedRCNN 3 | -------------------------------------------------------------------------------- /defrcn/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/checkpoint/__init__.py: -------------------------------------------------------------------------------- 1 | from .detection_checkpoint import DetectionCheckpointer 2 | 3 | __all__ = ["DetectionCheckpointer"] 4 | -------------------------------------------------------------------------------- /defrcn/data/__pycache__/builtin.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/data/__pycache__/builtin.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/engine/__pycache__/hooks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/engine/__pycache__/hooks.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/solver/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/solver/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/utils/__pycache__/kdloss.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/utils/__pycache__/kdloss.cpython-37.pyc 
-------------------------------------------------------------------------------- /defrcn/config/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/config/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/config/__pycache__/compat.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/config/__pycache__/compat.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/config/__pycache__/config.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/config/__pycache__/config.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/config/__pycache__/defaults.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/config/__pycache__/defaults.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/data/__pycache__/meta_coco.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/data/__pycache__/meta_coco.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/data/__pycache__/meta_voc.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/data/__pycache__/meta_voc.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/data/__pycache__/pcb_common.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/data/__pycache__/pcb_common.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/engine/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/engine/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/engine/__pycache__/defaults.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/engine/__pycache__/defaults.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/solver/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/solver/__pycache__/__init__.cpython-37.pyc 
-------------------------------------------------------------------------------- /defrcn/data/__pycache__/builtin_meta.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/data/__pycache__/builtin_meta.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/dataloader/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/dataloader/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .defaults import DefaultPredictor, DefaultTrainer, default_argument_parser, default_setup 2 | from .hooks import * 3 | -------------------------------------------------------------------------------- /defrcn/modeling/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/checkpoint/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/checkpoint/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/dataloader/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/dataloader/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/__pycache__/evaluator.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/__pycache__/evaluator.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/__pycache__/testing.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/__pycache__/testing.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/solver/__pycache__/lr_scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/solver/__pycache__/lr_scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/modeling/meta_arch/__pycache__/gdl.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/meta_arch/__pycache__/gdl.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/dataloader/__pycache__/dataset_mapper.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/dataloader/__pycache__/dataset_mapper.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/archs/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/archs/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/archs/__pycache__/resnet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/archs/__pycache__/resnet.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/modeling/meta_arch/__pycache__/build.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/meta_arch/__pycache__/build.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/modeling/meta_arch/__pycache__/rcnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/meta_arch/__pycache__/rcnn.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/__pycache__/coco_evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/__pycache__/coco_evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/meta_arch/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/modeling/roi_heads/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/roi_heads/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/modeling/roi_heads/__pycache__/box_head.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/roi_heads/__pycache__/box_head.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/__pycache__/calibration_layer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/__pycache__/calibration_layer.cpython-37.pyc 
-------------------------------------------------------------------------------- /defrcn/modeling/roi_heads/__pycache__/fast_rcnn.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/roi_heads/__pycache__/fast_rcnn.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/modeling/roi_heads/__pycache__/roi_heads.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/modeling/roi_heads/__pycache__/roi_heads.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/checkpoint/__pycache__/detection_checkpoint.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/checkpoint/__pycache__/detection_checkpoint.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/evaluation/__pycache__/pascal_voc_evaluation.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ZYN-1101/DandR/HEAD/defrcn/evaluation/__pycache__/pascal_voc_evaluation.cpython-37.pyc -------------------------------------------------------------------------------- /defrcn/solver/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import build_lr_scheduler, build_optimizer 2 | from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR 3 | 4 | __all__ = [k for k in globals().keys() if not k.startswith("_")] 5 | -------------------------------------------------------------------------------- /defrcn/modeling/roi_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .box_head import ROI_BOX_HEAD_REGISTRY, build_box_head 2 | from .roi_heads import ( 3 | ROI_HEADS_REGISTRY, ROIHeads, StandardROIHeads, build_roi_heads, select_foreground_proposals) 4 | -------------------------------------------------------------------------------- /defrcn/modeling/__init__.py: -------------------------------------------------------------------------------- 1 | from .meta_arch import META_ARCH_REGISTRY, GeneralizedRCNN, build_model 2 | from .roi_heads import ( 3 | ROI_BOX_HEAD_REGISTRY, ROI_HEADS_REGISTRY, ROIHeads, StandardROIHeads, build_box_head, 4 | build_roi_heads) 5 | -------------------------------------------------------------------------------- /defrcn/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .compat import downgrade_config, upgrade_config 2 | from .config import CfgNode, get_cfg, global_cfg, set_global_cfg 3 | 4 | __all__ = [ 5 | "CfgNode", 6 | "get_cfg", 7 | "global_cfg", 8 | "set_global_cfg", 9 | "downgrade_config", 10 | "upgrade_config", 11 | ] -------------------------------------------------------------------------------- /defrcn/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .coco_evaluation import COCOEvaluator 2 | from .pascal_voc_evaluation import PascalVOCDetectionEvaluator 3 | from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset 4 | from .testing import print_csv_format, verify_results 5 | 6 | __all__ = [k for k in 
globals().keys() if not k.startswith("_")] 7 | -------------------------------------------------------------------------------- /configs/Base-RCNN.yaml: -------------------------------------------------------------------------------- 1 | VERSION: 2 2 | MODEL: 3 | META_ARCHITECTURE: "GeneralizedRCNN" 4 | RPN: 5 | PRE_NMS_TOPK_TEST: 6000 6 | POST_NMS_TOPK_TEST: 1000 7 | ROI_HEADS: 8 | NAME: "Res5ROIHeads" 9 | DATASETS: 10 | TRAIN: ("coco_2017_train",) 11 | TEST: ("coco_2017_val",) 12 | SOLVER: 13 | IMS_PER_BATCH: 16 14 | BASE_LR: 0.02 15 | STEPS: (60000, 80000) 16 | MAX_ITER: 90000 17 | INPUT: 18 | MIN_SIZE_TRAIN: (640, 672, 704, 736, 768, 800) 19 | -------------------------------------------------------------------------------- /defrcn/dataloader/__init__.py: -------------------------------------------------------------------------------- 1 | from detectron2.data import transforms 2 | from detectron2.data.catalog import DatasetCatalog, MetadataCatalog, Metadata 3 | from detectron2.data.common import DatasetFromList, MapDataset 4 | from .build import ( 5 | build_batch_data_loader, 6 | build_detection_test_loader, 7 | build_detection_train_loader, 8 | get_detection_dataset_dicts, 9 | load_proposals_into_dataset, 10 | print_instances_class_histogram, 11 | ) 12 | from .dataset_mapper import DatasetMapper 13 | 14 | __all__ = [k for k in globals().keys() if not k.startswith("_")] -------------------------------------------------------------------------------- /defrcn/modeling/meta_arch/build.py: -------------------------------------------------------------------------------- 1 | from detectron2.utils.registry import Registry 2 | 3 | META_ARCH_REGISTRY = Registry("META_ARCH") 4 | META_ARCH_REGISTRY.__doc__ = """ 5 | Registry for meta-architectures, i.e. the whole model. 6 | 7 | The registered object will be called with `obj(cfg)` 8 | and expected to return a `nn.Module` object. 9 | """ 10 | 11 | 12 | def build_model(cfg): 13 | """ 14 | Build the whole model, defined by `cfg.MODEL.META_ARCHITECTURE`.
15 | """ 16 | meta_arch = cfg.MODEL.META_ARCHITECTURE 17 | return META_ARCH_REGISTRY.get(meta_arch)(cfg) 18 | -------------------------------------------------------------------------------- /configs/coco/defrcn_det_r101_base.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/ImageNet/Pretrain/Weight" 4 | MASK_ON: False 5 | RESNETS: 6 | DEPTH: 101 7 | RPN: 8 | ENABLE_DECOUPLE: True 9 | BACKWARD_SCALE: 0.0 10 | ROI_HEADS: 11 | NAME: "AuxRes5ROIHeads" 12 | NUM_CLASSES: 60 13 | ENABLE_DECOUPLE: True 14 | BACKWARD_SCALE: 0.75 15 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 16 | AUX_MODEL: 17 | SEMANTIC_DIM: 512 18 | INFERENCE_WITH_AUX: False 19 | DATASETS: 20 | TRAIN: ('coco14_trainval_base',) 21 | TEST: ('coco14_test_base',) 22 | SOLVER: 23 | IMS_PER_BATCH: 8 24 | BASE_LR: 0.01 25 | STEPS: (170000, 200000) 26 | MAX_ITER: 220000 27 | CHECKPOINT_PERIOD: 200000 28 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/voc/defrcn_det_r101_base1.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/ImageNet/Pretrain/Weight" 4 | MASK_ON: False 5 | RESNETS: 6 | DEPTH: 101 7 | RPN: 8 | ENABLE_DECOUPLE: True 9 | BACKWARD_SCALE: 0.0 10 | ROI_HEADS: 11 | NAME: "AuxRes5ROIHeads" 12 | NUM_CLASSES: 15 13 | ENABLE_DECOUPLE: True 14 | BACKWARD_SCALE: 0.75 15 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 16 | AUX_MODEL: 17 | SEMANTIC_DIM: 512 18 | INFERENCE_WITH_AUX: False 19 | INPUT: 20 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 21 | MIN_SIZE_TEST: 800 22 | DATASETS: 23 | TRAIN: ('voc_2007_trainval_base1', 'voc_2012_trainval_base1') 24 | TEST: ('voc_2007_test_base1',) 25 | SOLVER: 26 | IMS_PER_BATCH: 8 27 | BASE_LR: 0.01 28 | STEPS: (20000, 26600) 29 | MAX_ITER: 30000 30 | CHECKPOINT_PERIOD: 200000 31 | WARMUP_ITERS: 100 32 | OUTPUT_DIR: "/Path/to/Output/Dir" 33 | -------------------------------------------------------------------------------- /configs/coco/defrcn_gfsod_r101_novel_1shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | NUM_CLASSES: 80 16 | FREEZE_FEAT: True 17 | CLS_DROPOUT: True 18 | ENABLE_DECOUPLE: True 19 | BACKWARD_SCALE: 0.01 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 5.0 25 | WEIGHT_KD: 5.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.5 28 | KD_BETA: 4.0 29 | DATASETS: 30 | TRAIN: ('coco14_trainval_all_1shot_seedx',) 31 | TEST: ('coco14_test_all',) 32 | SOLVER: 33 | IMS_PER_BATCH: 8 34 | BASE_LR: 0.005 35 | STEPS: (5120,) 36 | MAX_ITER: 6400 37 | CHECKPOINT_PERIOD: 100000 38 | WARMUP_ITERS: 0 39 | TEST: 40 | PCB_ENABLE: True 41 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 42 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/coco/defrcn_gfsod_r101_novel_2shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 
| WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | NUM_CLASSES: 80 16 | FREEZE_FEAT: True 17 | CLS_DROPOUT: True 18 | ENABLE_DECOUPLE: True 19 | BACKWARD_SCALE: 0.01 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 5.0 25 | WEIGHT_KD: 5.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.5 28 | KD_BETA: 4.0 29 | DATASETS: 30 | TRAIN: ('coco14_trainval_all_2shot_seedx',) 31 | TEST: ('coco14_test_all',) 32 | SOLVER: 33 | IMS_PER_BATCH: 8 34 | BASE_LR: 0.005 35 | STEPS: (6400,) 36 | MAX_ITER: 8000 37 | CHECKPOINT_PERIOD: 100000 38 | WARMUP_ITERS: 0 39 | TEST: 40 | PCB_ENABLE: True 41 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 42 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/coco/defrcn_gfsod_r101_novel_3shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | NUM_CLASSES: 80 16 | FREEZE_FEAT: True 17 | CLS_DROPOUT: True 18 | ENABLE_DECOUPLE: True 19 | BACKWARD_SCALE: 0.01 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 5.0 25 | WEIGHT_KD: 5.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.5 28 | KD_BETA: 4.0 29 | DATASETS: 30 | TRAIN: ('coco14_trainval_all_3shot_seedx',) 31 | TEST: ('coco14_test_all',) 32 | SOLVER: 33 | IMS_PER_BATCH: 8 34 | BASE_LR: 0.005 35 | STEPS: (7680,) 36 | MAX_ITER: 9600 37 | CHECKPOINT_PERIOD: 100000 38 | WARMUP_ITERS: 0 39 | TEST: 40 | PCB_ENABLE: True 41 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 42 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/coco/defrcn_gfsod_r101_novel_5shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | NUM_CLASSES: 80 16 | FREEZE_FEAT: True 17 | CLS_DROPOUT: True 18 | ENABLE_DECOUPLE: True 19 | BACKWARD_SCALE: 0.01 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 5.0 25 | WEIGHT_KD: 5.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.5 28 | KD_BETA: 4.0 29 | DATASETS: 30 | TRAIN: ('coco14_trainval_all_5shot_seedx',) 31 | TEST: ('coco14_test_all',) 32 | SOLVER: 33 | IMS_PER_BATCH: 8 34 | BASE_LR: 0.005 35 | STEPS: (9600,) 36 | MAX_ITER: 12000 37 | CHECKPOINT_PERIOD: 100000 38 | WARMUP_ITERS: 0 39 | TEST: 40 | PCB_ENABLE: True 41 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 42 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/coco/defrcn_gfsod_r101_novel_10shot_seedx.yaml: 
-------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | NUM_CLASSES: 80 16 | FREEZE_FEAT: True 17 | CLS_DROPOUT: True 18 | ENABLE_DECOUPLE: True 19 | BACKWARD_SCALE: 0.01 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 5.0 25 | WEIGHT_KD: 5.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.5 28 | KD_BETA: 4.0 29 | DATASETS: 30 | TRAIN: ('coco14_trainval_all_10shot_seedx',) 31 | TEST: ('coco14_test_all',) 32 | SOLVER: 33 | IMS_PER_BATCH: 8 34 | BASE_LR: 0.005 35 | STEPS: (9600,) 36 | MAX_ITER: 12000 37 | CHECKPOINT_PERIOD: 100000 38 | WARMUP_ITERS: 0 39 | TEST: 40 | PCB_ENABLE: True 41 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 42 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/coco/defrcn_gfsod_r101_novel_30shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | NUM_CLASSES: 80 16 | FREEZE_FEAT: True 17 | CLS_DROPOUT: True 18 | ENABLE_DECOUPLE: True 19 | BACKWARD_SCALE: 0.01 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 5.0 25 | WEIGHT_KD: 5.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.5 28 | KD_BETA: 4.0 29 | DATASETS: 30 | TRAIN: ('coco14_trainval_all_30shot_seedx',) 31 | TEST: ('coco14_test_all',) 32 | SOLVER: 33 | IMS_PER_BATCH: 8 34 | BASE_LR: 0.005 35 | STEPS: (19200,) 36 | MAX_ITER: 24000 37 | CHECKPOINT_PERIOD: 100000 38 | WARMUP_ITERS: 0 39 | TEST: 40 | PCB_ENABLE: True 41 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 42 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 er-muyue 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/voc/defrcn_gfsod_r101_novelx_10shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | ENABLE_DECOUPLE: True 16 | BACKWARD_SCALE: 0.001 17 | NUM_CLASSES: 20 18 | FREEZE_FEAT: True 19 | CLS_DROPOUT: True 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 10.0 25 | WEIGHT_KD: 1.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.2 28 | KD_BETA: 10.0 29 | INPUT: 30 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 31 | MIN_SIZE_TEST: 800 32 | DATASETS: 33 | TRAIN: ("voc_2007_trainval_allx_10shot_seedx", ) 34 | TEST: ('voc_2007_test_allx',) 35 | SOLVER: 36 | IMS_PER_BATCH: 8 37 | BASE_LR: 0.005 38 | STEPS: (6400,) 39 | MAX_ITER: 8000 40 | CHECKPOINT_PERIOD: 100000 41 | WARMUP_ITERS: 0 42 | TEST: 43 | PCB_ENABLE: True 44 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 45 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/voc/defrcn_gfsod_r101_novelx_1shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | ENABLE_DECOUPLE: True 16 | BACKWARD_SCALE: 0.001 17 | NUM_CLASSES: 20 18 | FREEZE_FEAT: True 19 | CLS_DROPOUT: True 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 10.0 25 | WEIGHT_KD: 1.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.2 28 | KD_BETA: 10.0 29 | INPUT: 30 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 31 | MIN_SIZE_TEST: 800 32 | DATASETS: 33 | TRAIN: ("voc_2007_trainval_allx_1shot_seedx", ) 34 | TEST: ('voc_2007_test_allx',) 35 | SOLVER: 36 | IMS_PER_BATCH: 8 37 | BASE_LR: 0.005 38 | STEPS: (1600, ) 39 | MAX_ITER: 2000 40 | CHECKPOINT_PERIOD: 100000 41 | WARMUP_ITERS: 0 42 | TEST: 43 | PCB_ENABLE: True 44 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 45 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/voc/defrcn_gfsod_r101_novelx_2shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | ENABLE_DECOUPLE: True 16 | BACKWARD_SCALE: 0.001 17 | NUM_CLASSES: 
20 18 | FREEZE_FEAT: True 19 | CLS_DROPOUT: True 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 10.0 25 | WEIGHT_KD: 1.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.2 28 | KD_BETA: 10.0 29 | INPUT: 30 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 31 | MIN_SIZE_TEST: 800 32 | DATASETS: 33 | TRAIN: ("voc_2007_trainval_allx_2shot_seedx", ) 34 | TEST: ('voc_2007_test_allx',) 35 | SOLVER: 36 | IMS_PER_BATCH: 8 37 | BASE_LR: 0.005 38 | STEPS: (2400,) 39 | MAX_ITER: 3000 40 | CHECKPOINT_PERIOD: 100000 41 | WARMUP_ITERS: 0 42 | TEST: 43 | PCB_ENABLE: True 44 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 45 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/voc/defrcn_gfsod_r101_novelx_3shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | ENABLE_DECOUPLE: True 16 | BACKWARD_SCALE: 0.001 17 | NUM_CLASSES: 20 18 | FREEZE_FEAT: True 19 | CLS_DROPOUT: True 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 10.0 25 | WEIGHT_KD: 1.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.2 28 | KD_BETA: 10.0 29 | INPUT: 30 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 31 | MIN_SIZE_TEST: 800 32 | DATASETS: 33 | TRAIN: ("voc_2007_trainval_allx_3shot_seedx", ) 34 | TEST: ('voc_2007_test_allx',) 35 | SOLVER: 36 | IMS_PER_BATCH: 8 37 | BASE_LR: 0.005 38 | STEPS: (3200,) 39 | MAX_ITER: 4000 40 | CHECKPOINT_PERIOD: 100000 41 | WARMUP_ITERS: 0 42 | TEST: 43 | PCB_ENABLE: True 44 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 45 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- /configs/voc/defrcn_gfsod_r101_novelx_5shot_seedx.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../Base-RCNN.yaml" 2 | MODEL: 3 | WEIGHTS: "/Path/to/Base/Pretrain/Weight" 4 | MASK_ON: False 5 | BACKBONE: 6 | FREEZE: False 7 | RESNETS: 8 | DEPTH: 101 9 | RPN: 10 | ENABLE_DECOUPLE: True 11 | BACKWARD_SCALE: 0.0 12 | FREEZE: False 13 | ROI_HEADS: 14 | NAME: "AuxRes5ROIHeads" 15 | ENABLE_DECOUPLE: True 16 | BACKWARD_SCALE: 0.001 17 | NUM_CLASSES: 20 18 | FREEZE_FEAT: True 19 | CLS_DROPOUT: True 20 | OUTPUT_LAYER: "AuxFastRCNNOutputLayers" 21 | AUX_MODEL: 22 | SEMANTIC_DIM: 512 23 | INFERENCE_WITH_AUX: False 24 | KD_TEMPERATURE: 10.0 25 | WEIGHT_KD: 1.0 26 | WEIGHT_CE: 1.0 27 | KD_ALPHA: 0.2 28 | KD_BETA: 10.0 29 | INPUT: 30 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 31 | MIN_SIZE_TEST: 800 32 | DATASETS: 33 | TRAIN: ("voc_2007_trainval_allx_5shot_seedx", ) 34 | TEST: ('voc_2007_test_allx',) 35 | SOLVER: 36 | IMS_PER_BATCH: 8 37 | BASE_LR: 0.005 38 | STEPS: (4000,) 39 | MAX_ITER: 5000 40 | CHECKPOINT_PERIOD: 100000 41 | WARMUP_ITERS: 0 42 | TEST: 43 | PCB_ENABLE: True 44 | PCB_MODELPATH: "/Path/to/ImageNet/Pre-Train/Weight" 45 | OUTPUT_DIR: "/Path/to/Output/Dir" -------------------------------------------------------------------------------- 
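The *_seedx.yaml files above are templates: both the file names and the DATASETS entries keep novelx/seedx placeholders, and a concrete config is generated per split, shot and seed before fine-tuning (tools/create_config.py further below does this by plain string substitution). A minimal sketch of that substitution for the VOC gfsod case, assuming the same naming convention; the paths and values here are illustrative only:

import os

def specialize_voc_template(config_root, shot, seed, split, setting="gfsod"):
    # Load a seedx/novelx template, e.g. defrcn_gfsod_r101_novelx_5shot_seedx.yaml.
    template = os.path.join(
        config_root, "defrcn_{}_r101_novelx_{}shot_seedx.yaml".format(setting, shot)
    )
    with open(template) as f:
        lines = f.readlines()
    # Point TRAIN/TEST at the concrete few-shot split registered for this seed.
    for i, line in enumerate(lines):
        if " TRAIN: " in line:
            lines[i] = '  TRAIN: ("voc_2007_trainval_all{}_{}shot_seed{}", )\n'.format(split, shot, seed)
        if " TEST: " in line:
            lines[i] = '  TEST: ("voc_2007_test_all{}",)\n'.format(split)
    # Materialize the resolved config next to the template.
    out_path = template.replace("novelx", "novel{}".format(split)).replace("seedx", "seed{}".format(seed))
    with open(out_path, "w") as f:
        f.writelines(lines)
    return out_path

# e.g. specialize_voc_template("configs/voc", shot=5, seed=0, split=1)
#      -> configs/voc/defrcn_gfsod_r101_novel1_5shot_seed0.yaml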
/defrcn/modeling/meta_arch/gdl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Function 4 | 5 | 6 | class GradientDecoupleLayer(Function): 7 | 8 | @staticmethod 9 | def forward(ctx, x, _lambda): 10 | ctx._lambda = _lambda 11 | return x 12 | 13 | @staticmethod 14 | def backward(ctx, grad_output): 15 | grad_output = grad_output * ctx._lambda 16 | return grad_output, None 17 | 18 | 19 | class AffineLayer(nn.Module): 20 | def __init__(self, num_channels, bias=False): 21 | super(AffineLayer, self).__init__() 22 | weight = torch.FloatTensor(1, num_channels, 1, 1).fill_(1) 23 | self.weight = nn.Parameter(weight, requires_grad=True) 24 | 25 | self.bias = None 26 | if bias: 27 | bias = torch.FloatTensor(1, num_channels, 1, 1).fill_(0) 28 | self.bias = nn.Parameter(bias, requires_grad=True) 29 | 30 | def forward(self, X): 31 | out = X * self.weight.expand_as(X) 32 | if self.bias is not None: 33 | out = out + self.bias.expand_as(X) 34 | return out 35 | 36 | 37 | def decouple_layer(x, _lambda): 38 | return GradientDecoupleLayer.apply(x, _lambda) 39 | -------------------------------------------------------------------------------- /defrcn/config/defaults.py: -------------------------------------------------------------------------------- 1 | from detectron2.config.defaults import _C 2 | from detectron2.config import CfgNode as CN 3 | 4 | _CC = _C 5 | 6 | # ----------- Backbone ----------- # 7 | _CC.MODEL.BACKBONE.FREEZE = False 8 | _CC.MODEL.BACKBONE.FREEZE_AT = 3 9 | 10 | # ------------- RPN -------------- # 11 | _CC.MODEL.RPN.FREEZE = False 12 | _CC.MODEL.RPN.ENABLE_DECOUPLE = False 13 | _CC.MODEL.RPN.BACKWARD_SCALE = 1.0 14 | 15 | # ------------- ROI -------------- # 16 | _CC.MODEL.ROI_HEADS.NAME = "Res5ROIHeads" 17 | _CC.MODEL.ROI_HEADS.FREEZE_FEAT = False 18 | _CC.MODEL.ROI_HEADS.ENABLE_DECOUPLE = False 19 | _CC.MODEL.ROI_HEADS.BACKWARD_SCALE = 1.0 20 | _CC.MODEL.ROI_HEADS.OUTPUT_LAYER = "FastRCNNOutputLayers" 21 | _CC.MODEL.ROI_HEADS.CLS_DROPOUT = False 22 | _CC.MODEL.ROI_HEADS.DROPOUT_RATIO = 0.8 23 | _CC.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 7 # for faster 24 | 25 | _CC.AUX_MODEL = CN() 26 | _CC.AUX_MODEL.SEMANTIC_DIM = 512 27 | _CC.AUX_MODEL.INFERENCE_WITH_AUX = False 28 | _CC.AUX_MODEL.KD_TEMPERATURE = 5.0 29 | _CC.AUX_MODEL.WEIGHT_CE = 1.0 30 | _CC.AUX_MODEL.WEIGHT_KD = 1.0 31 | _CC.AUX_MODEL.KD_ALPHA = 0.5 32 | _CC.AUX_MODEL.KD_BETA = 4.0 33 | 34 | # ------------- TEST ------------- # 35 | _CC.TEST.PCB_ENABLE = False 36 | _CC.TEST.PCB_MODELTYPE = 'resnet' # res-like 37 | _CC.TEST.PCB_MODELPATH = "" 38 | _CC.TEST.PCB_ALPHA = 0.50 39 | _CC.TEST.PCB_UPPER = 1.0 40 | _CC.TEST.PCB_LOWER = 0.05 41 | 42 | # ------------ Other ------------- # 43 | _CC.SOLVER.WEIGHT_DECAY = 5e-5 44 | _CC.MUTE_HEADER = True 45 | -------------------------------------------------------------------------------- /defrcn/data/pcb_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
2 | import copy 3 | import itertools 4 | import logging 5 | import numpy as np 6 | import pickle 7 | import random 8 | import torch.utils.data as data 9 | from torch.utils.data.sampler import Sampler 10 | 11 | from detectron2.utils.serialize import PicklableWrapper 12 | from detectron2.data.common import AspectRatioGroupedDataset 13 | 14 | __all__ = ["PCBAspectRatioGroupedDataset"] 15 | 16 | 17 | class PCBAspectRatioGroupedDataset(AspectRatioGroupedDataset): 18 | """ 19 | Batch data that have similar aspect ratio together. 20 | In this implementation, images whose aspect ratio < (or >) 1 will 21 | be batched together. 22 | This improves training speed because the images then need less padding 23 | to form a batch. 24 | 25 | It assumes the underlying dataset produces dicts with "width" and "height" keys. 26 | It will then produce a list of original dicts with length = batch_size, 27 | all with similar aspect ratios. 28 | """ 29 | 30 | def __init__(self, dataset, batch_size, proto_dataset): 31 | """ 32 | Args: 33 | dataset: an iterable. Each element must be a dict with keys 34 | "width" and "height", which will be used to batch data. 35 | batch_size (int): 36 | """ 37 | self.dataset = dataset 38 | self.batch_size = batch_size 39 | self._buckets = [[] for _ in range(2)] 40 | self.proto_dataset = proto_dataset 41 | -------------------------------------------------------------------------------- /run_coco.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | SAVEDIR=checkpoints/coco/ 4 | IMAGENET_PRETRAIN=ImageNetPretrained/MSRA/R-101.pkl # <-- change it to your path 5 | IMAGENET_PRETRAIN_TORCH=ImageNetPretrained/torchvision/resnet101-5d3b4d8f.pth # <-- change it to your path 6 | 7 | python3 main.py --num-gpus 1 --config-file configs/coco/defrcn_det_r101_base.yaml \ 8 | --opts MODEL.WEIGHTS ${IMAGENET_PRETRAIN} \ 9 | OUTPUT_DIR ${SAVEDIR}/defrcn_det_r101_base 10 | 11 | python3 tools/model_surgery.py --dataset coco --method randinit \ 12 | --src-path ${SAVEDIR}/defrcn_det_r101_base/model_final.pth \ 13 | --save-dir ${SAVEDIR}/defrcn_det_r101_base 14 | BASE_WEIGHT=checkpoints/coco/defrcn_det_r101_base/model_reset_surgery.pth 15 | 16 | # ------------------------------ Novel Fine-tuning ------------------------------- # 17 | for seed in 0 1 2 3 4 5 6 7 8 9 18 | do 19 | for shot in 1 2 3 5 10 30 20 | do 21 | python3 tools/create_config.py --dataset coco14 --config_root configs/coco/${EXPNAME} \ 22 | --shot ${shot} --seed ${seed} --setting 'gfsod' 23 | CONFIG_PATH=configs/coco/${EXPNAME}/defrcn_gfsod_r101_novel_${shot}shot_seed${seed}.yaml 24 | OUTPUT_DIR=${SAVEDIR}/defrcn_gfsod_r101_novel/tfa-like/${shot}shot_seed${seed} 25 | python3 main.py --num-gpus 1 --config-file ${CONFIG_PATH} \ 26 | --opts MODEL.WEIGHTS ${BASE_WEIGHT} OUTPUT_DIR ${OUTPUT_DIR} \ 27 | TEST.PCB_MODELPATH ${IMAGENET_PRETRAIN_TORCH} 28 | rm ${CONFIG_PATH} 29 | done 30 | done 31 | -------------------------------------------------------------------------------- /run_voc.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | SPLIT_ID=$1 3 | SAVE_DIR=checkpoints/voc/ 4 | IMAGENET_PRETRAIN=ImageNetPretrained/MSRA/R-101.pkl # <-- change it to your path 5 | IMAGENET_PRETRAIN_TORCH=ImageNetPretrained/torchvision/resnet101-5d3b4d8f.pth # <-- change it to your path 6 | SEED=7882548 7 | 8 | python3 main.py --num-gpus 1 --config-file configs/voc/defrcn_det_r101_base${SPLIT_ID}.yaml \ 9 | --opts MODEL.WEIGHTS
${IMAGENET_PRETRAIN} \ 10 | OUTPUT_DIR ${SAVE_DIR}/defrcn_det_r101_base${SPLIT_ID} 11 | 12 | # ----------------------------- Model Preparation --------------------------------- # 13 | python3 tools/model_surgery.py --dataset voc --method randinit \ 14 | --src-path ${SAVE_DIR}/defrcn_det_r101_base${SPLIT_ID}/model_final.pth \ 15 | --save-dir ${SAVE_DIR}/defrcn_det_r101_base${SPLIT_ID} 16 | 17 | # ------------------------------ Novel Fine-tuning ------------------------------- # 18 | BASE_WEIGHT=checkpoints/voc/defrcn_det_r101_base${SPLIT_ID}/model_reset_surgery.pth 19 | 20 | for seed in 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 21 | do 22 | for shot in 1 2 3 5 10 23 | do 24 | python3 tools/create_config.py --dataset voc --config_root configs/voc/${EXP_NAME} \ 25 | --shot ${shot} --seed ${seed} --setting 'gfsod' --split ${SPLIT_ID} 26 | CONFIG_PATH=configs/voc/defrcn_gfsod_r101_novel${SPLIT_ID}_${shot}shot_seed${seed}.yaml 27 | OUTPUT_DIR=${SAVE_DIR}/defrcn_gfsod_r101_novel${SPLIT_ID}/tfa-like/${shot}shot_seed${seed} 28 | python3 main.py --num-gpus 1 --config-file ${CONFIG_PATH} \ 29 | --opts MODEL.WEIGHTS ${BASE_WEIGHT} OUTPUT_DIR ${OUTPUT_DIR} \ 30 | TEST.PCB_MODELPATH ${IMAGENET_PRETRAIN_TORCH} SEED ${SEED} 31 | rm ${CONFIG_PATH} 32 | done 33 | done 34 | 35 | python3 tools/extract_results.py --res-dir ${SAVE_DIR}/defrcn_gfsod_r101_novel${SPLIT_ID}/tfa-like --shot-list 1 2 3 5 10 # summarize all results 36 | -------------------------------------------------------------------------------- /tools/extract_results.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import numpy as np 5 | from tabulate import tabulate 6 | 7 | 8 | def main(): 9 | parser = argparse.ArgumentParser() 10 | parser.add_argument('--res-dir', type=str, default='', help='Path to the results') 11 | parser.add_argument('--shot-list', type=int, nargs='+', default=[10], help='') 12 | args = parser.parse_args() 13 | 14 | wf = open(os.path.join(args.res_dir, 'results.txt'), 'w') 15 | 16 | for shot in args.shot_list: 17 | 18 | file_paths = [] 19 | for fid, fname in enumerate(os.listdir(args.res_dir)): 20 | if fname.split('_')[0] != '{}shot'.format(shot): 21 | continue 22 | _dir = os.path.join(args.res_dir, fname) 23 | if not os.path.isdir(_dir): 24 | continue 25 | file_paths.append(os.path.join(_dir, 'log.txt')) 26 | 27 | header, results = [], [] 28 | for fid, fpath in enumerate(sorted(file_paths)): 29 | lineinfos = open(fpath).readlines() 30 | if fid == 0: 31 | res_info = lineinfos[-2].strip() 32 | header = res_info.split(':')[-1].split(',') 33 | res_info = lineinfos[-1].strip() 34 | results.append([fid] + [float(x) for x in res_info.split(':')[-1].split(',')]) 35 | 36 | results_np = np.array(results) 37 | avg = np.mean(results_np, axis=0).tolist() 38 | cid = [1.96 * s / math.sqrt(results_np.shape[0]) for s in np.std(results_np, axis=0)] 39 | results.append(['μ'] + avg[1:]) 40 | results.append(['c'] + cid[1:]) 41 | 42 | table = tabulate( 43 | results, 44 | tablefmt="pipe", 45 | floatfmt=".2f", 46 | headers=[''] + header, 47 | numalign="left", 48 | ) 49 | 50 | wf.write('--> {}-shot\n'.format(shot)) 51 | wf.write('{}\n\n'.format(table)) 52 | wf.flush() 53 | wf.close() 54 | 55 | print('Reformat all results -> {}'.format(os.path.join(args.res_dir, 'results.txt'))) 56 | 57 | 58 | if __name__ == '__main__': 59 | main() 60 | --------------------------------------------------------------------------------
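tools/extract_results.py above appends two summary rows to each shot's table: μ, the mean over seeds, and c, the 95% confidence half-width 1.96·σ/√n computed with the population standard deviation. A minimal self-contained sketch of that aggregation on made-up AP-style values (the numbers below are illustrative, not results from this repo):

import math
import numpy as np

# Hypothetical per-seed scores for one shot setting (illustrative only).
ap_per_seed = np.array([55.2, 57.1, 54.8, 56.3, 55.9])

mu = ap_per_seed.mean()                                            # the μ row
c = 1.96 * np.std(ap_per_seed) / math.sqrt(len(ap_per_seed))       # the c row, same formula as above

print("mu = {:.2f}, c = {:.2f}".format(mu, c))                     # mu = 55.86, c = 0.71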
/defrcn/evaluation/testing.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import numpy as np 3 | import pprint 4 | import sys 5 | from collections import Mapping, OrderedDict 6 | 7 | 8 | def print_csv_format(results): 9 | """ 10 | Print main metrics in a format similar to Detectron, 11 | so that they are easy to copypaste into a spreadsheet. 12 | 13 | Args: 14 | results (OrderedDict[dict]): task_name -> {metric -> score} 15 | """ 16 | assert isinstance(results, OrderedDict), results # unordered results cannot be properly printed 17 | logger = logging.getLogger(__name__) 18 | for task, res in results.items(): 19 | # Don't print "AP-category" metrics since they are usually not tracked. 20 | important_res = [(k, v) for k, v in res.items() if "-" not in k] 21 | logger.info("copypaste: Task: {}".format(task)) 22 | logger.info("copypaste: " + ",".join([k[0] for k in important_res])) 23 | logger.info("copypaste: " + ",".join(["{0:.4f}".format(k[1]) for k in important_res])) 24 | 25 | 26 | def verify_results(cfg, results): 27 | """ 28 | Args: 29 | results (OrderedDict[dict]): task_name -> {metric -> score} 30 | 31 | Returns: 32 | bool: whether the verification succeeds or not 33 | """ 34 | expected_results = cfg.TEST.EXPECTED_RESULTS 35 | if not len(expected_results): 36 | return True 37 | 38 | ok = True 39 | for task, metric, expected, tolerance in expected_results: 40 | actual = results[task][metric] 41 | if not np.isfinite(actual): 42 | ok = False 43 | diff = abs(actual - expected) 44 | if diff > tolerance: 45 | ok = False 46 | 47 | logger = logging.getLogger(__name__) 48 | if not ok: 49 | logger.error("Result verification failed!") 50 | logger.error("Expected Results: " + str(expected_results)) 51 | logger.error("Actual Results: " + pprint.pformat(results)) 52 | 53 | sys.exit(1) 54 | else: 55 | logger.info("Results verification passed.") 56 | return ok 57 | 58 | 59 | def flatten_results_dict(results): 60 | """ 61 | Expand a hierarchical dict of scalars into a flat dict of scalars. 62 | If results[k1][k2][k3] = v, the returned dict will have the entry 63 | {"k1/k2/k3": v}. 64 | 65 | Args: 66 | results (dict): 67 | """ 68 | r = {} 69 | for k, v in results.items(): 70 | if isinstance(v, Mapping): 71 | v = flatten_results_dict(v) 72 | for kk, vv in v.items(): 73 | r[k + "/" + kk] = vv 74 | else: 75 | r[k] = v 76 | return r 77 | -------------------------------------------------------------------------------- /defrcn/checkpoint/detection_checkpoint.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import detectron2.utils.comm as comm 3 | from fvcore.common.file_io import PathManager 4 | from fvcore.common.checkpoint import Checkpointer 5 | from detectron2.checkpoint.c2_model_loading import align_and_update_state_dicts 6 | 7 | 8 | class DetectionCheckpointer(Checkpointer): 9 | """ 10 | Same as :class:`Checkpointer`, but is able to handle models in detectron & detectron2 11 | model zoo, and apply conversions for legacy models. 
12 | """ 13 | 14 | def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables): 15 | is_main_process = comm.is_main_process() 16 | super().__init__( 17 | model, 18 | save_dir, 19 | save_to_disk=is_main_process if save_to_disk is None else save_to_disk, 20 | **checkpointables, 21 | ) 22 | 23 | def _load_file(self, filename): 24 | if filename.endswith(".pkl"): 25 | with PathManager.open(filename, "rb") as f: 26 | data = pickle.load(f, encoding="latin1") 27 | if "model" in data and "__author__" in data: 28 | # file is in Detectron2 model zoo format 29 | self.logger.info("Reading a file from '{}'".format(data["__author__"])) 30 | return data 31 | else: 32 | # assume file is from Caffe2 / Detectron1 model zoo 33 | if "blobs" in data: 34 | # Detection models have "blobs", but ImageNet models don't 35 | data = data["blobs"] 36 | data = {k: v for k, v in data.items() if not k.endswith("_momentum")} 37 | return {"model": data, "__author__": "Caffe2", "matching_heuristics": True} 38 | 39 | loaded = super()._load_file(filename) # load native pth checkpoint 40 | if "model" not in loaded: 41 | loaded = {"model": loaded} 42 | return loaded 43 | 44 | def _load_model(self, checkpoint): 45 | if checkpoint.get("matching_heuristics", False): 46 | self._convert_ndarray_to_tensor(checkpoint["model"]) 47 | # convert weights by name-matching heuristics 48 | model_state_dict = self.model.state_dict() 49 | align_and_update_state_dicts( 50 | model_state_dict, 51 | checkpoint["model"], 52 | c2_conversion=checkpoint.get("__author__", None) == "Caffe2", 53 | ) 54 | checkpoint["model"] = model_state_dict 55 | # for non-caffe2 models, use standard ways to load it 56 | super()._load_model(checkpoint) 57 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from detectron2.utils import comm 3 | from detectron2.engine import launch 4 | from detectron2.data import MetadataCatalog 5 | from detectron2.checkpoint import DetectionCheckpointer 6 | from defrcn.config import get_cfg, set_global_cfg 7 | from defrcn.evaluation import DatasetEvaluators, verify_results 8 | from defrcn.engine import DefaultTrainer, default_argument_parser, default_setup 9 | 10 | 11 | class Trainer(DefaultTrainer): 12 | 13 | @classmethod 14 | def build_evaluator(cls, cfg, dataset_name, output_folder=None): 15 | if output_folder is None: 16 | output_folder = os.path.join(cfg.OUTPUT_DIR, "inference") 17 | evaluator_list = [] 18 | evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type 19 | if evaluator_type == "coco": 20 | from defrcn.evaluation import COCOEvaluator 21 | evaluator_list.append(COCOEvaluator(dataset_name, True, output_folder)) 22 | if evaluator_type == "pascal_voc": 23 | from defrcn.evaluation import PascalVOCDetectionEvaluator 24 | return PascalVOCDetectionEvaluator(dataset_name) 25 | if len(evaluator_list) == 0: 26 | raise NotImplementedError( 27 | "no Evaluator for the dataset {} with the type {}".format( 28 | dataset_name, evaluator_type 29 | ) 30 | ) 31 | if len(evaluator_list) == 1: 32 | return evaluator_list[0] 33 | return DatasetEvaluators(evaluator_list) 34 | 35 | 36 | def setup(args): 37 | cfg = get_cfg() 38 | cfg.merge_from_file(args.config_file) 39 | if args.opts: 40 | cfg.merge_from_list(args.opts) 41 | cfg.freeze() 42 | set_global_cfg(cfg) 43 | default_setup(cfg, args) 44 | return cfg 45 | 46 | 47 | def main(args): 48 | cfg = setup(args) 49 | 50 | if 
args.eval_only: 51 | model = Trainer.build_model(cfg) 52 | DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( 53 | cfg.MODEL.WEIGHTS, resume=args.resume 54 | ) 55 | res = Trainer.test(cfg, model) 56 | if comm.is_main_process(): 57 | verify_results(cfg, res) 58 | return res 59 | 60 | trainer = Trainer(cfg) 61 | trainer.resume_or_load(resume=args.resume) 62 | return trainer.train() 63 | 64 | 65 | if __name__ == "__main__": 66 | args = default_argument_parser().parse_args() 67 | launch( 68 | main, 69 | args.num_gpus, 70 | num_machines=args.num_machines, 71 | machine_rank=args.machine_rank, 72 | dist_url=args.dist_url, 73 | args=(args,), 74 | ) 75 | -------------------------------------------------------------------------------- /tools/create_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument('--dataset', type=str, default='coco14', help='', choices=['coco14', 'voc']) 8 | parser.add_argument('--config_root', type=str, default='', help='the path to config dir') 9 | parser.add_argument('--shot', type=int, default=1, help='shot to run experiments over') 10 | parser.add_argument('--seed', type=int, default=0, help='seed to run experiments over') 11 | parser.add_argument('--setting', type=str, default='fsod', choices=['fsod', 'gfsod']) 12 | parser.add_argument('--split', type=int, default=1, help='only for voc') 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def load_config_file(yaml_path): 18 | fpath = os.path.join(yaml_path) 19 | yaml_info = open(fpath).readlines() 20 | return yaml_info 21 | 22 | 23 | def save_config_file(yaml_info, yaml_path): 24 | wf = open(yaml_path, 'w') 25 | for line in yaml_info: 26 | wf.write('{}'.format(line)) 27 | wf.close() 28 | 29 | 30 | def main(): 31 | args = parse_args() 32 | suffix = 'novel' if args.setting == 'fsod' else 'all' 33 | 34 | if args.dataset in ['voc']: 35 | name_template = 'defrcn_{}_r101_novelx_{}shot_seedx.yaml' 36 | yaml_path = os.path.join(args.config_root, name_template.format(args.setting, args.shot)) 37 | yaml_info = load_config_file(yaml_path) 38 | for i, lineinfo in enumerate(yaml_info): 39 | if ' TRAIN: ' in lineinfo: 40 | _str_ = ' TRAIN: ("voc_2007_trainval_{}{}_{}shot_seed{}", )\n' 41 | yaml_info[i] = _str_.format(suffix, args.split, args.shot, args.seed) 42 | if ' TEST: ' in lineinfo: 43 | _str_ = ' TEST: ("voc_2007_test_{}{}",)\n' 44 | yaml_info[i] = _str_.format(suffix, args.split) 45 | yaml_path = yaml_path.replace('novelx', 'novel{}'.format(args.split)) 46 | elif args.dataset in ['coco14']: 47 | name_template = 'defrcn_{}_r101_novel_{}shot_seedx.yaml' 48 | yaml_path = os.path.join(args.config_root, name_template.format(args.setting, args.shot)) 49 | yaml_info = load_config_file(yaml_path) 50 | for i, lineinfo in enumerate(yaml_info): 51 | if ' TRAIN: ' in lineinfo: 52 | _str_ = ' TRAIN: ("coco14_trainval_{}_{}shot_seed{}", )\n' 53 | yaml_info[i] = _str_.format(suffix, args.shot, args.seed) 54 | else: 55 | raise NotImplementedError 56 | 57 | yaml_path = yaml_path.replace('seedx', 'seed{}'.format(args.seed)) 58 | save_config_file(yaml_info, yaml_path) 59 | 60 | 61 | if __name__ == '__main__': 62 | main() 63 | -------------------------------------------------------------------------------- /defrcn/utils/kdloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
torch.nn.functional as F 3 | 4 | 5 | def dandr_loss(logits_student, logits_teacher, target, alpha, beta, temperature, detach_target=True): 6 | if detach_target: 7 | logits_teacher = logits_teacher.detach() 8 | index_fg = (target != (logits_teacher.shape[1] - 1)) 9 | index_bg = (target == (logits_teacher.shape[1] - 1)) 10 | 11 | gt_mask = _get_target_mask(logits_student, target) 12 | other_mask = _get_other_mask(logits_student, target) 13 | 14 | pred_teacher = F.softmax(logits_teacher[index_fg] / temperature, dim=1) 15 | 16 | p_non_target_pos_teacher = (pred_teacher * other_mask[index_fg]).sum(1, keepdims=True)[:, 0] 17 | 18 | non_target_logits_teacher = logits_teacher - 1000.0 * gt_mask #.type(torch.float64) 19 | non_target_logits_student = logits_student - 1000.0 * gt_mask #.type(torch.float64) 20 | 21 | bg_mask = _get_bg_mask(non_target_logits_teacher) 22 | non_bg_mask = _get_non_bg_mask(non_target_logits_teacher) 23 | non_target_pred_student = F.softmax(non_target_logits_student / temperature, dim=1) 24 | non_target_pred_teacher = F.softmax(non_target_logits_teacher / temperature, dim=1) 25 | 26 | p_fbd_student = cat_mask(non_target_pred_student, bg_mask, non_bg_mask) 27 | p_fbd_teacher = cat_mask(non_target_pred_teacher, bg_mask, non_bg_mask) 28 | 29 | log_p_fbd_student = torch.log(p_fbd_student) 30 | loss_fbd =( 31 | F.kl_div(log_p_fbd_student, p_fbd_teacher, reduction='none').sum(1) 32 | * (temperature ** 2) 33 | ) 34 | 35 | p_fcd_teacher = F.softmax( 36 | non_target_logits_teacher / temperature - 1000 * bg_mask, dim=1 37 | ) 38 | log_p_fcd_student = F.log_softmax( 39 | non_target_logits_student / temperature - 1000 * bg_mask, dim=1 40 | ) 41 | loss_fcd = p_fbd_teacher[:, 1] * ( 42 | F.kl_div(log_p_fcd_student, p_fcd_teacher, reduction='none').sum(1) 43 | * (temperature**2) 44 | ) 45 | 46 | loss = alpha * torch.mean(p_non_target_pos_teacher * loss_fbd[index_fg]) \ 47 | + beta * torch.mean(loss_fbd[index_bg]) \ 48 | + torch.mean(p_non_target_pos_teacher * loss_fcd[index_fg])\ 49 | + torch.mean(loss_fcd[index_bg]) 50 | 51 | return loss 52 | 53 | 54 | def _get_target_mask(logits, target): 55 | target = target.reshape(-1) 56 | mask = torch.zeros_like(logits).scatter_(1, target.unsqueeze(1), 1).bool() 57 | mask[:, -1] = 0 58 | return mask 59 | 60 | def _get_other_mask(logits, target): 61 | target = target.reshape(-1) 62 | mask = torch.ones_like(logits).scatter_(1, target.unsqueeze(1), 0).bool() 63 | mask[:, -1] = 1 64 | return mask 65 | 66 | def _get_bg_mask(logits): 67 | mask = torch.zeros_like(logits) 68 | mask[:, -1] = 1 69 | mask = mask.bool() 70 | return mask 71 | 72 | def _get_non_bg_mask(logits): 73 | mask = torch.ones_like(logits) 74 | mask[:, -1] = 0 75 | mask = mask.bool() 76 | return mask 77 | 78 | def cat_mask(t, mask1, mask2): 79 | t1 = (t * mask1).sum(dim=1, keepdims=True) 80 | t2 = (t * mask2).sum(1, keepdims=True) 81 | rt = torch.cat([t1, t2], dim=1) 82 | return rt 83 | 84 | 85 | -------------------------------------------------------------------------------- /defrcn/config/config.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from fvcore.common.config import CfgNode as _CfgNode 3 | 4 | 5 | class CfgNode(_CfgNode): 6 | """ 7 | The same as `fvcore.common.config.CfgNode`, but different in: 8 | 1. Use unsafe yaml loading by default. 9 | Note that this may lead to arbitrary code execution: you must not 10 | load a config file from untrusted sources before manually inspecting 11 | the content of the file. 12 | 2. 
Support config versioning. 13 | When attempting to merge an old config, it will convert the old config automatically. 14 | """ 15 | 16 | # Note that the default value of allow_unsafe is changed to True 17 | def merge_from_file( 18 | self, cfg_filename: str, allow_unsafe: bool = True 19 | ) -> None: 20 | loaded_cfg = _CfgNode.load_yaml_with_base( 21 | cfg_filename, allow_unsafe=allow_unsafe 22 | ) 23 | loaded_cfg = type(self)(loaded_cfg) 24 | 25 | # defaults.py needs to import CfgNode 26 | from .defaults import _CC as _C 27 | 28 | latest_ver = _C.VERSION 29 | assert ( 30 | latest_ver == self.VERSION 31 | ), "CfgNode.merge_from_file is only allowed on a config of latest version!" 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | loaded_ver = loaded_cfg.get("VERSION", None) 36 | if loaded_ver is None: 37 | from .compat import guess_version 38 | 39 | loaded_ver = guess_version(loaded_cfg, cfg_filename) 40 | assert ( 41 | loaded_ver <= self.VERSION 42 | ), "Cannot merge a v{} config into a v{} config.".format( 43 | loaded_ver, self.VERSION 44 | ) 45 | 46 | if loaded_ver == self.VERSION: 47 | self.merge_from_other_cfg(loaded_cfg) 48 | else: 49 | # compat.py needs to import CfgNode 50 | from .compat import downgrade_config, upgrade_config 51 | 52 | logger.warning( 53 | "Loading an old v{} config file '{}' by automatically upgrading to v{}. " 54 | "See docs/CHANGELOG.md for instructions to update your files.".format( 55 | loaded_ver, cfg_filename, self.VERSION 56 | ) 57 | ) 58 | # To convert, first obtain a full config at an old version 59 | old_self = downgrade_config(self, to_version=loaded_ver) 60 | old_self.merge_from_other_cfg(loaded_cfg) 61 | new_config = upgrade_config(old_self) 62 | self.clear() 63 | self.update(new_config) 64 | 65 | 66 | global_cfg = CfgNode() 67 | 68 | 69 | def get_cfg() -> CfgNode: 70 | """ 71 | Get a copy of the default config. 72 | Returns: 73 | a fsdet CfgNode instance. 74 | """ 75 | from .defaults import _C 76 | 77 | return _C.clone() 78 | 79 | 80 | def set_global_cfg(cfg: CfgNode) -> None: 81 | """ 82 | Let the global config point to the given cfg. 83 | Assume that the given "cfg" has the key "KEY", after calling 84 | `set_global_cfg(cfg)`, the key can be accessed by: 85 | .. code-block:: python 86 | from fsdet.config import global_cfg 87 | print(global_cfg.KEY) 88 | By using a hacky global config, you can access these configs anywhere, 89 | without having to pass the config object or the values deep into the code. 90 | This is a hacky feature introduced for quick prototyping / research exploration. 91 | """ 92 | global global_cfg 93 | global_cfg.clear() 94 | global_cfg.update(cfg) 95 | -------------------------------------------------------------------------------- /defrcn/engine/hooks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | import torch 5 | import itertools 6 | import detectron2.utils.comm as comm 7 | from fvcore.common.file_io import PathManager 8 | from detectron2.config import global_cfg 9 | from detectron2.engine.train_loop import HookBase 10 | from detectron2.evaluation.testing import flatten_results_dict 11 | 12 | __all__ = ["EvalHookDeFRCN"] 13 | 14 | 15 | class EvalHookDeFRCN(HookBase): 16 | """ 17 | Run an evaluation function periodically, and at the end of training. 18 | It is executed every ``eval_period`` iterations and after the last iteration. 
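    A hedged registration sketch (the `trainer` object and the wrapped test
    call are assumptions for illustration, not the exact wiring used in
    defrcn/engine):

        def _eval():
            return Trainer.test(cfg, trainer.model)

        trainer.register_hooks([EvalHookDeFRCN(cfg.TEST.EVAL_PERIOD, _eval, cfg)])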
19 | """ 20 | 21 | def __init__(self, eval_period, eval_function, cfg): 22 | """ 23 | Args: 24 | eval_period (int): the period to run `eval_function`. Set to 0 to 25 | not evaluate periodically (but still after the last iteration). 26 | eval_function (callable): a function which takes no arguments, and 27 | returns a nested dict of evaluation metrics. 28 | cfg: config 29 | Note: 30 | This hook must be enabled in all or none workers. 31 | If you would like only certain workers to perform evaluation, 32 | give other workers a no-op function (`eval_function=lambda: None`). 33 | """ 34 | self._period = eval_period 35 | self._func = eval_function 36 | self.cfg = cfg 37 | 38 | def _do_eval(self): 39 | results = self._func() 40 | 41 | if results: 42 | assert isinstance( 43 | results, dict 44 | ), "Eval function must return a dict. Got {} instead.".format(results) 45 | 46 | flattened_results = flatten_results_dict(results) 47 | for k, v in flattened_results.items(): 48 | try: 49 | v = float(v) 50 | except Exception as e: 51 | raise ValueError( 52 | "[EvalHook] eval_function should return a nested dict of float. " 53 | "Got '{}: {}' instead.".format(k, v) 54 | ) from e 55 | self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False) 56 | 57 | if comm.is_main_process() and results: 58 | # save evaluation results in json 59 | is_final = self.trainer.iter + 1 >= self.trainer.max_iter 60 | os.makedirs( 61 | os.path.join(self.cfg.OUTPUT_DIR, 'inference'), exist_ok=True) 62 | output_file = 'res_final.json' if is_final else \ 63 | 'iter_{:07d}.json'.format(self.trainer.iter) 64 | with PathManager.open(os.path.join(self.cfg.OUTPUT_DIR, 'inference', 65 | output_file), 'w') as fp: 66 | json.dump(results, fp) 67 | 68 | # Evaluation may take different time among workers. 69 | # A barrier make them start the next iteration together. 70 | comm.synchronize() 71 | 72 | def after_step(self): 73 | next_iter = self.trainer.iter + 1 74 | if self._period > 0 and next_iter % self._period == 0: 75 | self._do_eval() 76 | 77 | def after_train(self): 78 | # This condition is to prevent the eval from running after a failed training 79 | if self.trainer.iter + 1 >= self.trainer.max_iter: 80 | self._do_eval() 81 | # func is likely a closure that holds reference to the trainer 82 | # therefore we clean it to avoid circular reference in the end 83 | del self._func 84 | -------------------------------------------------------------------------------- /defrcn/modeling/roi_heads/box_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import fvcore.nn.weight_init as weight_init 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from detectron2.layers import Conv2d, ShapeSpec, get_norm 7 | from detectron2.utils.registry import Registry 8 | 9 | ROI_BOX_HEAD_REGISTRY = Registry("ROI_BOX_HEAD") 10 | ROI_BOX_HEAD_REGISTRY.__doc__ = """ 11 | Registry for box heads, which make box predictions from per-region features. 12 | 13 | The registered object will be called with `obj(cfg, input_shape)`. 14 | """ 15 | 16 | 17 | @ROI_BOX_HEAD_REGISTRY.register() 18 | class FastRCNNConvFCHead(nn.Module): 19 | """ 20 | A head with several 3x3 conv layers (each followed by norm & relu) and 21 | several fc layers (each followed by relu). 
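    For illustration, a config along these lines (the values are assumptions,
    not necessarily this project's defaults) builds a head with no conv layers
    and two 1024-d fc layers:

        cfg.MODEL.ROI_BOX_HEAD.NUM_CONV = 0
        cfg.MODEL.ROI_BOX_HEAD.NUM_FC = 2
        cfg.MODEL.ROI_BOX_HEAD.FC_DIM = 1024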
22 | """ 23 | 24 | def __init__(self, cfg, input_shape: ShapeSpec): 25 | """ 26 | The following attributes are parsed from config: 27 | num_conv, num_fc: the number of conv/fc layers 28 | conv_dim/fc_dim: the dimension of the conv/fc layers 29 | norm: normalization for the conv layers 30 | """ 31 | super().__init__() 32 | 33 | # fmt: off 34 | num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV 35 | conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM 36 | num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC 37 | fc_dim = cfg.MODEL.ROI_BOX_HEAD.FC_DIM 38 | norm = cfg.MODEL.ROI_BOX_HEAD.NORM 39 | # fmt: on 40 | assert num_conv + num_fc > 0 41 | 42 | self._output_size = ( 43 | input_shape.channels, 44 | input_shape.height, 45 | input_shape.width, 46 | ) 47 | 48 | self.conv_norm_relus = [] 49 | for k in range(num_conv): 50 | conv = Conv2d( 51 | self._output_size[0], 52 | conv_dim, 53 | kernel_size=3, 54 | padding=1, 55 | bias=not norm, 56 | norm=get_norm(norm, conv_dim), 57 | activation=F.relu, 58 | ) 59 | self.add_module("conv{}".format(k + 1), conv) 60 | self.conv_norm_relus.append(conv) 61 | self._output_size = ( 62 | conv_dim, 63 | self._output_size[1], 64 | self._output_size[2], 65 | ) 66 | 67 | self.fcs = [] 68 | for k in range(num_fc): 69 | fc = nn.Linear(np.prod(self._output_size), fc_dim) 70 | self.add_module("fc{}".format(k + 1), fc) 71 | self.fcs.append(fc) 72 | self._output_size = fc_dim 73 | 74 | for layer in self.conv_norm_relus: 75 | weight_init.c2_msra_fill(layer) 76 | for layer in self.fcs: 77 | weight_init.c2_xavier_fill(layer) 78 | 79 | def forward(self, x): 80 | for layer in self.conv_norm_relus: 81 | x = layer(x) 82 | if len(self.fcs): 83 | if x.dim() > 2: 84 | x = torch.flatten(x, start_dim=1) 85 | for layer in self.fcs: 86 | x = F.relu(layer(x)) 87 | return x 88 | 89 | @property 90 | def output_size(self): 91 | return self._output_size 92 | 93 | 94 | def build_box_head(cfg, input_shape): 95 | """ 96 | Build a box head defined by `cfg.MODEL.ROI_BOX_HEAD.NAME`. 
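    Example (illustrative; the pooled feature shape and the `pooled_features`
    tensor are assumptions):

        head = build_box_head(cfg, ShapeSpec(channels=2048, height=7, width=7))
        out = head(pooled_features)  # (N, head.output_size) when fc layers are used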
97 | """ 98 | name = cfg.MODEL.ROI_BOX_HEAD.NAME 99 | return ROI_BOX_HEAD_REGISTRY.get(name)(cfg, input_shape) 100 | -------------------------------------------------------------------------------- /defrcn/data/meta_coco.py: -------------------------------------------------------------------------------- 1 | import io 2 | import os 3 | import contextlib 4 | import numpy as np 5 | from pycocotools.coco import COCO 6 | from detectron2.structures import BoxMode 7 | from fvcore.common.file_io import PathManager 8 | from detectron2.data import DatasetCatalog, MetadataCatalog 9 | 10 | 11 | __all__ = ["register_meta_coco"] 12 | 13 | 14 | def load_coco_json(json_file, image_root, metadata, dataset_name): 15 | is_shots = "shot" in dataset_name # few-shot 16 | if is_shots: 17 | imgid2info = {} 18 | shot = dataset_name.split('_')[-2].split('shot')[0] 19 | seed = int(dataset_name.split('_seed')[-1]) 20 | split_dir = os.path.join('datasets', 'cocosplit', 'seed{}'.format(seed)) 21 | for idx, cls in enumerate(metadata["thing_classes"]): 22 | json_file = os.path.join(split_dir, "full_box_{}shot_{}_trainval.json".format(shot, cls)) 23 | json_file = PathManager.get_local_path(json_file) 24 | with contextlib.redirect_stdout(io.StringIO()): 25 | coco_api = COCO(json_file) 26 | img_ids = sorted(list(coco_api.imgs.keys())) 27 | for img_id in img_ids: 28 | if img_id not in imgid2info: 29 | imgid2info[img_id] = [coco_api.loadImgs([img_id])[0], coco_api.imgToAnns[img_id]] 30 | else: 31 | for item in coco_api.imgToAnns[img_id]: 32 | imgid2info[img_id][1].append(item) 33 | imgs, anns = [], [] 34 | for img_id in imgid2info: 35 | imgs.append(imgid2info[img_id][0]) 36 | anns.append(imgid2info[img_id][1]) 37 | else: 38 | json_file = PathManager.get_local_path(json_file) 39 | with contextlib.redirect_stdout(io.StringIO()): 40 | coco_api = COCO(json_file) 41 | # sort indices for reproducible results 42 | img_ids = sorted(list(coco_api.imgs.keys())) 43 | imgs = coco_api.loadImgs(img_ids) 44 | anns = [coco_api.imgToAnns[img_id] for img_id in img_ids] 45 | 46 | imgs_anns = list(zip(imgs, anns)) 47 | id_map = metadata["thing_dataset_id_to_contiguous_id"] 48 | 49 | dataset_dicts = [] 50 | ann_keys = ["iscrowd", "bbox", "category_id"] 51 | 52 | for (img_dict, anno_dict_list) in imgs_anns: 53 | record = {} 54 | record["file_name"] = os.path.join( 55 | image_root, img_dict["file_name"] 56 | ) 57 | record["height"] = img_dict["height"] 58 | record["width"] = img_dict["width"] 59 | image_id = record["image_id"] = img_dict["id"] 60 | 61 | objs = [] 62 | for anno in anno_dict_list: 63 | assert anno["image_id"] == image_id 64 | assert anno.get("ignore", 0) == 0 65 | 66 | obj = {key: anno[key] for key in ann_keys if key in anno} 67 | 68 | obj["bbox_mode"] = BoxMode.XYWH_ABS 69 | if obj["category_id"] in id_map: 70 | obj["category_id"] = id_map[obj["category_id"]] 71 | objs.append(obj) 72 | record["annotations"] = objs 73 | dataset_dicts.append(record) 74 | 75 | return dataset_dicts 76 | 77 | 78 | def register_meta_coco(name, metadata, imgdir, annofile): 79 | DatasetCatalog.register( 80 | name, 81 | lambda: load_coco_json(annofile, imgdir, metadata, name), 82 | ) 83 | 84 | if "_base" in name or "_novel" in name: 85 | split = "base" if "_base" in name else "novel" 86 | metadata["thing_dataset_id_to_contiguous_id"] = metadata[ 87 | "{}_dataset_id_to_contiguous_id".format(split) 88 | ] 89 | metadata["thing_classes"] = metadata["{}_classes".format(split)] 90 | 91 | MetadataCatalog.get(name).set( 92 | json_file=annofile, 93 | 
image_root=imgdir, 94 | evaluator_type="coco", 95 | dirname="datasets/coco", 96 | **metadata, 97 | ) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Introduction 2 | 3 | This repo contains the official PyTorch implementation of D\&R 4 | 5 | ## Quick Start 6 | 7 | **1. Check Requirements** 8 | * Linux with Python >= 3.6 9 | * [PyTorch](https://pytorch.org/get-started/locally/) >= 1.6 & [torchvision](https://github.com/pytorch/vision/) that matches the PyTorch version. 10 | * CUDA 10.1, 10.2 11 | * GCC >= 4.9 12 | 13 | **2. Build** 14 | 15 | * Create a virtual environment (optional) 16 | ``` 17 | conda create -n dandr python=3.7 18 | conda activate dandrzq 19 | ``` 20 | * Install PyTorch according to your CUDA version 21 | 22 | * Install Detectron2 (the version of Detectron2 must be 0.3) 23 | ```angular2html 24 | python3 -m pip install detectron2==0.3 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.6/index.html 25 | ``` 26 | * Install other requirements. 27 | ```angular2html 28 | python3 -m pip install -r requirements.txt 29 | ``` 30 | 31 | **3. Prepare Data and Weights** 32 | * Data Preparation (from DeFRCN) 33 | 34 | | Dataset | Size | GoogleDrive | BaiduYun | Note | 35 | |:---:|:---:|:---:|:---:|:---:| 36 | |VOC2007| 0.8G |[download](https://drive.google.com/file/d/1BcuJ9j9Mtymp56qGSOfYxlXN4uEVyxFm/view?usp=sharing)|[download](https://pan.baidu.com/s/1kjAmHY5JKDoG0L65T3dK9g)| - | 37 | |VOC2012| 3.5G |[download](https://drive.google.com/file/d/1NjztPltqm-Z-pG94a6PiPVP4BgD8Sz1H/view?usp=sharing)|[download](https://pan.baidu.com/s/1DUJT85AG_fqP9NRPhnwU2Q)| - | 38 | |vocsplit| <1M |[download](https://drive.google.com/file/d/1BpDDqJ0p-fQAFN_pthn2gqiK5nWGJ-1a/view?usp=sharing)|[download](https://pan.baidu.com/s/1518_egXZoJNhqH4KRDQvfw)| refer from [TFA](https://github.com/ucbdrive/few-shot-object-detection#models) | 39 | |COCO| ~19G | - | - | download from [offical](https://cocodataset.org/#download)| 40 | |cocosplit| 174M |[download](https://drive.google.com/file/d/1T_cYLxNqYlbnFNJt8IVvT7ZkWb5c0esj/view?usp=sharing)|[download](https://pan.baidu.com/s/1NELvshrbkpRS8BiuBIr5gA)| refer from [TFA](https://github.com/ucbdrive/few-shot-object-detection#models) | 41 | - Unzip the downloaded data-source to `datasets` and put it into your project directory: 42 | ```angular2html 43 | ... 44 | datasets 45 | | -- coco (trainval2014/*.jpg, val2014/*.jpg, annotations/*.json) 46 | | -- cocosplit 47 | | -- VOC2007 48 | | -- VOC2012 49 | | -- vocsplit 50 | defrcn 51 | tools 52 | ... 53 | ``` 54 | * Weights Preparation 55 | - DeFRCN use the imagenet pretrain weights to initialize the model. 56 | Download the same models from (given by DeFRCN): [GoogleDrive](https://drive.google.com/file/d/1rsE20_fSkYeIhFaNU04rBfEDkMENLibj/view?usp=sharing) [BaiduYun](https://pan.baidu.com/s/1IfxFq15LVUI3iIMGFT8slw) 57 | - Put the chekpoints into ImageNetPretrained/MSRA/R-101.pkl, ImageNetPretrained/torchvision, respectively 58 | - We provide the BASE_WEIGHT (refer to run_*.sh) we used. 
59 | | Dataset | Split | Size | GoogleDrive | 60 | |:---:|:---:|:---:|:---:| 61 | |VOC2007| 1 | 203.8M | [download](https://drive.google.com/file/d/19LxiN9cj92YePs02k9E4-KyGY5ohTU9w/view?usp=share_link)| 62 | |VOC2007| 2 | 203.8M | [download](https://drive.google.com/file/d/1t1bbJ-YsXohIDUsQvUiF8pC6f7vxh0Z3/view?usp=share_link)| 63 | |VOC2007| 3 | 203.8M | [download](https://drive.google.com/file/d/1bWiS0fBrQDljTnpBFFldbZ8ZmZUmhB8w/view?usp=share_link)| 64 | | COCO | - | 206.2MB | [download](https://drive.google.com/file/d/1pH-7b_1B3qm_rJo-_nEfcrHmy-PCHHZ7/view?usp=share_link) | 65 | 66 | * Text Embeddings Preparation 67 | - Refer to the official implementation of [CLIP](https://github.com/openai/CLIP) for text embedding generation. 68 | - Put the generated text embeddings into 'dataset/clip' 69 | 70 | **4. Training and Evaluation** 71 | 72 | * To reproduce the results on VOC, 73 | ```angular2html 74 | sh run_voc.sh SPLIT_ID (1, 2 or 3) 75 | ``` 76 | * To reproduce the results on COCO 77 | ```angular2html 78 | sh run_coco.sh 79 | ``` 80 | * Please read the details of few-shot object detection pipeline in `run_*.sh`. 81 | 82 | ## Acknowledgement 83 | This repo is developed based on DeFRCN and [Detectron2](https://github.com/facebookresearch/detectron2). Please check them for more details and features. 84 | ``` 85 | -------------------------------------------------------------------------------- /tools/model_surgery.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | 5 | 6 | def surgery_loop(args, surgery): 7 | 8 | save_name = args.tar_name + '_' + ('remove' if args.method == 'remove' else 'surgery') + '.pth' 9 | save_path = os.path.join(args.save_dir, save_name) 10 | os.makedirs(args.save_dir, exist_ok=True) 11 | 12 | ckpt = torch.load(args.src_path) 13 | if 'scheduler' in ckpt: 14 | del ckpt['scheduler'] 15 | if 'optimizer' in ckpt: 16 | del ckpt['optimizer'] 17 | if 'iteration' in ckpt: 18 | ckpt['iteration'] = 0 19 | 20 | if args.method == 'remove': 21 | for param_name in args.param_name: 22 | del ckpt['model'][param_name + '.weight'] 23 | if param_name+'.bias' in ckpt['model']: 24 | del ckpt['model'][param_name+'.bias'] 25 | elif args.method == 'randinit': 26 | tar_sizes = [TAR_SIZE + 1, TAR_SIZE * 4] 27 | for idx, (param_name, tar_size) in enumerate(zip(args.param_name, tar_sizes)): 28 | surgery(param_name, True, tar_size, ckpt) 29 | surgery(param_name, False, tar_size, ckpt) 30 | else: 31 | raise NotImplementedError 32 | 33 | torch.save(ckpt, save_path) 34 | print('save changed ckpt to {}'.format(save_path)) 35 | 36 | 37 | def main(args): 38 | """ 39 | Either remove the final layer weights for fine-tuning on novel dataset or 40 | append randomly initialized weights for the novel classes. 
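    Example invocations (paths are placeholders):

        python3 tools/model_surgery.py --dataset coco --method randinit \
            --src-path <path/to/base_model.pth> --save-dir <output_dir>
        python3 tools/model_surgery.py --dataset voc --method remove \
            --src-path <path/to/base_model.pth> --save-dir <output_dir>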
41 | """ 42 | def surgery(param_name, is_weight, tar_size, ckpt): 43 | weight_name = param_name + ('.weight' if is_weight else '.bias') 44 | pretrained_weight = ckpt['model'][weight_name] 45 | prev_cls = pretrained_weight.size(0) 46 | if 'cls_score' in param_name: 47 | prev_cls -= 1 48 | if is_weight: 49 | feat_size = pretrained_weight.size(1) 50 | new_weight = torch.rand((tar_size, feat_size)) 51 | torch.nn.init.normal_(new_weight, 0, 0.01) 52 | else: 53 | new_weight = torch.zeros(tar_size) 54 | if args.dataset == 'coco': 55 | for idx, c in enumerate(BASE_CLASSES): 56 | # idx = i if args.dataset == 'coco' else c 57 | if 'cls_score' in param_name: 58 | new_weight[IDMAP[c]] = pretrained_weight[idx] 59 | else: 60 | new_weight[IDMAP[c]*4:(IDMAP[c]+1)*4] = \ 61 | pretrained_weight[idx*4:(idx+1)*4] 62 | else: 63 | new_weight[:prev_cls] = pretrained_weight[:prev_cls] 64 | if 'cls_score' in param_name: 65 | new_weight[-1] = pretrained_weight[-1] # bg class 66 | ckpt['model'][weight_name] = new_weight 67 | 68 | surgery_loop(args, surgery) 69 | 70 | 71 | if __name__ == '__main__': 72 | 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument('--dataset', type=str, default='coco', choices=['voc', 'coco']) 75 | parser.add_argument('--src-path', type=str, default='', help='Path to the main checkpoint') 76 | parser.add_argument('--save-dir', type=str, default='', required=True, help='Save directory') 77 | parser.add_argument('--method', choices=['remove', 'randinit'], required=True, 78 | help='remove = remove the final layer of the base detector. ' 79 | 'randinit = randomly initialize novel weights.') 80 | parser.add_argument('--param-name', type=str, nargs='+', help='Target parameter names', 81 | default=['roi_heads.box_predictor.cls_score', 'roi_heads.box_predictor.bbox_pred']) 82 | parser.add_argument('--tar-name', type=str, default='model_reset', help='Name of the new ckpt') 83 | args = parser.parse_args() 84 | 85 | if args.dataset == 'coco': 86 | NOVEL_CLASSES = [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 44, 62, 63, 64, 67, 72] 87 | BASE_CLASSES = [8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 88 | 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 89 | 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90] 90 | ALL_CLASSES = sorted(BASE_CLASSES + NOVEL_CLASSES) 91 | IDMAP = {v: i for i, v in enumerate(ALL_CLASSES)} 92 | TAR_SIZE = 80 93 | elif args.dataset == 'voc': 94 | TAR_SIZE = 20 95 | else: 96 | raise NotImplementedError 97 | 98 | main(args) 99 | -------------------------------------------------------------------------------- /defrcn/data/builtin.py: -------------------------------------------------------------------------------- 1 | import os 2 | from .meta_voc import register_meta_voc 3 | from .meta_coco import register_meta_coco 4 | from .builtin_meta import _get_builtin_metadata 5 | from detectron2.data import DatasetCatalog, MetadataCatalog 6 | 7 | 8 | # -------- COCO -------- # 9 | def register_all_coco(root="datasets"): 10 | 11 | METASPLITS = [ 12 | ("coco14_trainval_all", "coco/trainval2014", "cocosplit/datasplit/trainvalno5k.json"), 13 | ("coco14_trainval_base", "coco/trainval2014", "cocosplit/datasplit/trainvalno5k.json"), 14 | ("coco14_test_all", "coco/val2014", "cocosplit/datasplit/5k.json"), 15 | ("coco14_test_base", "coco/val2014", "cocosplit/datasplit/5k.json"), 16 | ("coco14_test_novel", "coco/val2014", "cocosplit/datasplit/5k.json"), 17 | ] 18 | for prefix in ["all", 
"novel"]: 19 | for shot in [1, 2, 3, 5, 10, 30]: 20 | for seed in range(10): 21 | name = "coco14_trainval_{}_{}shot_seed{}".format(prefix, shot, seed) 22 | METASPLITS.append((name, "coco/trainval2014", "")) 23 | 24 | for name, imgdir, annofile in METASPLITS: 25 | register_meta_coco( 26 | name, 27 | _get_builtin_metadata("coco_fewshot"), 28 | os.path.join(root, imgdir), 29 | os.path.join(root, annofile), 30 | ) 31 | 32 | 33 | # -------- PASCAL VOC -------- # 34 | def register_all_voc(root="datasets"): 35 | 36 | METASPLITS = [ 37 | ("voc_2007_trainval_base1", "VOC2007", "trainval", "base1", 1), 38 | ("voc_2007_trainval_base2", "VOC2007", "trainval", "base2", 2), 39 | ("voc_2007_trainval_base3", "VOC2007", "trainval", "base3", 3), 40 | ("voc_2012_trainval_base1", "VOC2012", "trainval", "base1", 1), 41 | ("voc_2012_trainval_base2", "VOC2012", "trainval", "base2", 2), 42 | ("voc_2012_trainval_base3", "VOC2012", "trainval", "base3", 3), 43 | ("voc_2007_trainval_all1", "VOC2007", "trainval", "base_novel_1", 1), 44 | ("voc_2007_trainval_all2", "VOC2007", "trainval", "base_novel_2", 2), 45 | ("voc_2007_trainval_all3", "VOC2007", "trainval", "base_novel_3", 3), 46 | ("voc_2012_trainval_all1", "VOC2012", "trainval", "base_novel_1", 1), 47 | ("voc_2012_trainval_all2", "VOC2012", "trainval", "base_novel_2", 2), 48 | ("voc_2012_trainval_all3", "VOC2012", "trainval", "base_novel_3", 3), 49 | ("voc_2007_test_base1", "VOC2007", "test", "base1", 1), 50 | ("voc_2007_test_base2", "VOC2007", "test", "base2", 2), 51 | ("voc_2007_test_base3", "VOC2007", "test", "base3", 3), 52 | ("voc_2007_test_novel1", "VOC2007", "test", "novel1", 1), 53 | ("voc_2007_test_novel2", "VOC2007", "test", "novel2", 2), 54 | ("voc_2007_test_novel3", "VOC2007", "test", "novel3", 3), 55 | ("voc_2007_test_all1", "VOC2007", "test", "base_novel_1", 1), 56 | ("voc_2007_test_all2", "VOC2007", "test", "base_novel_2", 2), 57 | ("voc_2007_test_all3", "VOC2007", "test", "base_novel_3", 3), 58 | ] 59 | for prefix in ["all", "novel"]: 60 | for sid in range(1, 4): 61 | for shot in [1, 2, 3, 5, 10]: 62 | for year in [2007, 2012]: 63 | for seed in range(30): 64 | seed = "_seed{}".format(seed) 65 | name = "voc_{}_trainval_{}{}_{}shot{}".format( 66 | year, prefix, sid, shot, seed 67 | ) 68 | dirname = "VOC{}".format(year) 69 | img_file = "{}_{}shot_split_{}_trainval".format( 70 | prefix, shot, sid 71 | ) 72 | keepclasses = ( 73 | "base_novel_{}".format(sid) 74 | if prefix == "all" 75 | else "novel{}".format(sid) 76 | ) 77 | METASPLITS.append( 78 | (name, dirname, img_file, keepclasses, sid) 79 | ) 80 | 81 | for name, dirname, split, keepclasses, sid in METASPLITS: 82 | year = 2007 if "2007" in name else 2012 83 | register_meta_voc( 84 | name, 85 | _get_builtin_metadata("voc_fewshot"), 86 | os.path.join(root, dirname), 87 | split, 88 | year, 89 | keepclasses, 90 | sid, 91 | ) 92 | MetadataCatalog.get(name).evaluator_type = "pascal_voc" 93 | 94 | 95 | register_all_coco() 96 | register_all_voc() -------------------------------------------------------------------------------- /defrcn/solver/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | import math 3 | import torch 4 | from typing import List 5 | from bisect import bisect_right 6 | 7 | # NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes 8 | # only on epoch boundaries. We typically use iteration based schedules instead. 
9 | # As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean 10 | # "iteration" instead. 11 | 12 | # FIXME: ideally this would be achieved with a CombinedLRScheduler, separating 13 | # MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it. 14 | 15 | 16 | class WarmupMultiStepLR(torch.optim.lr_scheduler._LRScheduler): 17 | def __init__( 18 | self, 19 | optimizer: torch.optim.Optimizer, 20 | milestones: List[int], 21 | gamma: float = 0.1, 22 | warmup_factor: float = 0.001, 23 | warmup_iters: int = 1000, 24 | warmup_method: str = "linear", 25 | last_epoch: int = -1, 26 | ): 27 | if not list(milestones) == sorted(milestones): 28 | raise ValueError( 29 | "Milestones should be a list of" " increasing integers. Got {}", milestones 30 | ) 31 | self.milestones = milestones 32 | self.gamma = gamma 33 | self.warmup_factor = warmup_factor 34 | self.warmup_iters = warmup_iters 35 | self.warmup_method = warmup_method 36 | super().__init__(optimizer, last_epoch) 37 | 38 | def get_lr(self) -> List[float]: 39 | warmup_factor = _get_warmup_factor_at_iter( 40 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 41 | ) 42 | return [ 43 | base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) 44 | for base_lr in self.base_lrs 45 | ] 46 | 47 | def _compute_values(self) -> List[float]: 48 | # The new interface 49 | return self.get_lr() 50 | 51 | 52 | class WarmupCosineLR(torch.optim.lr_scheduler._LRScheduler): 53 | def __init__( 54 | self, 55 | optimizer: torch.optim.Optimizer, 56 | max_iters: int, 57 | warmup_factor: float = 0.001, 58 | warmup_iters: int = 1000, 59 | warmup_method: str = "linear", 60 | last_epoch: int = -1, 61 | ): 62 | self.max_iters = max_iters 63 | self.warmup_factor = warmup_factor 64 | self.warmup_iters = warmup_iters 65 | self.warmup_method = warmup_method 66 | super().__init__(optimizer, last_epoch) 67 | 68 | def get_lr(self) -> List[float]: 69 | warmup_factor = _get_warmup_factor_at_iter( 70 | self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor 71 | ) 72 | # Different definitions of half-cosine with warmup are possible. For 73 | # simplicity we multiply the standard half-cosine schedule by the warmup 74 | # factor. An alternative is to start the period of the cosine at warmup_iters 75 | # instead of at 0. In the case that warmup_iters << max_iters the two are 76 | # very close to each other. 77 | return [ 78 | base_lr 79 | * warmup_factor 80 | * 0.5 81 | * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)) 82 | for base_lr in self.base_lrs 83 | ] 84 | 85 | def _compute_values(self) -> List[float]: 86 | # The new interface 87 | return self.get_lr() 88 | 89 | 90 | def _get_warmup_factor_at_iter( 91 | method: str, iter: int, warmup_iters: int, warmup_factor: float 92 | ) -> float: 93 | """ 94 | Return the learning rate warmup factor at a specific iteration. 95 | See :paper:`ImageNet in 1h` for more details. 96 | Args: 97 | method (str): warmup method; either "constant" or "linear". 98 | iter (int): iteration at which to calculate the warmup factor. 99 | warmup_iters (int): the number of warmup iterations. 100 | warmup_factor (float): the base warmup factor (the meaning changes according 101 | to the method used). 102 | Returns: 103 | float: the effective warmup factor at the given iteration. 
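    Worked example: with method="linear", warmup_iters=1000 and
    warmup_factor=0.001, the factor at iter=500 is
    0.001 * (1 - 0.5) + 0.5 = 0.5005, and it is 1.0 for any iter >= 1000.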
104 | """ 105 | if iter >= warmup_iters: 106 | return 1.0 107 | 108 | if method == "constant": 109 | return warmup_factor 110 | elif method == "linear": 111 | alpha = iter / warmup_iters 112 | return warmup_factor * (1 - alpha) + alpha 113 | else: 114 | raise ValueError("Unknown warmup method: {}".format(method)) 115 | -------------------------------------------------------------------------------- /defrcn/modeling/meta_arch/rcnn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from torch import nn 4 | from detectron2.structures import ImageList 5 | from detectron2.utils.logger import log_first_n 6 | from detectron2.modeling.backbone import build_backbone 7 | from detectron2.modeling.postprocessing import detector_postprocess 8 | from detectron2.modeling.proposal_generator import build_proposal_generator 9 | from .build import META_ARCH_REGISTRY 10 | from .gdl import decouple_layer, AffineLayer 11 | from defrcn.modeling.roi_heads import build_roi_heads 12 | 13 | __all__ = ["GeneralizedRCNN"] 14 | 15 | 16 | @META_ARCH_REGISTRY.register() 17 | class GeneralizedRCNN(nn.Module): 18 | 19 | def __init__(self, cfg): 20 | super().__init__() 21 | 22 | self.cfg = cfg 23 | self.device = torch.device(cfg.MODEL.DEVICE) 24 | self.backbone = build_backbone(cfg) 25 | self._SHAPE_ = self.backbone.output_shape() 26 | self.proposal_generator = build_proposal_generator(cfg, self._SHAPE_) 27 | self.roi_heads = build_roi_heads(cfg, self._SHAPE_) 28 | self.normalizer = self.normalize_fn() 29 | self.affine_rpn = AffineLayer(num_channels=self._SHAPE_['res4'].channels, bias=True) 30 | self.affine_rcnn = AffineLayer(num_channels=self._SHAPE_['res4'].channels, bias=True) 31 | self.to(self.device) 32 | 33 | if cfg.MODEL.BACKBONE.FREEZE: 34 | for p in self.backbone.parameters(): 35 | p.requires_grad = False 36 | print("froze backbone parameters") 37 | 38 | if cfg.MODEL.RPN.FREEZE: 39 | for p in self.proposal_generator.parameters(): 40 | p.requires_grad = False 41 | print("froze proposal generator parameters") 42 | 43 | if cfg.MODEL.ROI_HEADS.FREEZE_FEAT: 44 | for p in self.roi_heads.res5.parameters(): 45 | p.requires_grad = False 46 | print("froze roi_box_head parameters") 47 | 48 | def forward(self, batched_inputs, inference_with_aux=False): 49 | if not self.training: 50 | return self.inference(batched_inputs, inference_with_aux=inference_with_aux) 51 | assert "instances" in batched_inputs[0] 52 | gt_instances = [x["instances"].to(self.device) for x in batched_inputs] 53 | proposal_losses, detector_losses, _, _ = self._forward_once_(batched_inputs, gt_instances) 54 | losses = {} 55 | losses.update(detector_losses) 56 | losses.update(proposal_losses) 57 | return losses 58 | 59 | def inference(self, batched_inputs, inference_with_aux): 60 | assert not self.training 61 | _, _, results, image_sizes = self._forward_once_(batched_inputs, None, inference_with_aux=inference_with_aux) 62 | processed_results = [] 63 | for r, input, image_size in zip(results, batched_inputs, image_sizes): 64 | height = input.get("height", image_size[0]) 65 | width = input.get("width", image_size[1]) 66 | r = detector_postprocess(r, height, width) 67 | processed_results.append({"instances": r}) 68 | return processed_results 69 | 70 | def _forward_once_(self, batched_inputs, gt_instances=None, inference_with_aux=False): 71 | 72 | images = self.preprocess_image(batched_inputs) 73 | features = self.backbone(images.tensor) 74 | 75 | features_de_rpn = features 76 | if 
self.cfg.MODEL.RPN.ENABLE_DECOUPLE: 77 | scale = self.cfg.MODEL.RPN.BACKWARD_SCALE 78 | features_de_rpn = {k: self.affine_rpn(decouple_layer(features[k], scale)) for k in features} 79 | proposals, proposal_losses = self.proposal_generator(images, features_de_rpn, gt_instances) 80 | 81 | features_de_rcnn = features 82 | if self.cfg.MODEL.ROI_HEADS.ENABLE_DECOUPLE: 83 | scale = self.cfg.MODEL.ROI_HEADS.BACKWARD_SCALE 84 | features_de_rcnn = {k: self.affine_rcnn(decouple_layer(features[k], scale)) for k in features} 85 | results, detector_losses = self.roi_heads(images, features_de_rcnn, proposals, gt_instances, 86 | inference_with_aux=inference_with_aux) 87 | 88 | return proposal_losses, detector_losses, results, images.image_sizes 89 | 90 | def preprocess_image(self, batched_inputs): 91 | images = [x["image"].to(self.device) for x in batched_inputs] 92 | images = [self.normalizer(x) for x in images] 93 | images = ImageList.from_tensors(images, self.backbone.size_divisibility) 94 | return images 95 | 96 | def normalize_fn(self): 97 | assert len(self.cfg.MODEL.PIXEL_MEAN) == len(self.cfg.MODEL.PIXEL_STD) 98 | num_channels = len(self.cfg.MODEL.PIXEL_MEAN) 99 | pixel_mean = (torch.Tensor( 100 | self.cfg.MODEL.PIXEL_MEAN).to(self.device).view(num_channels, 1, 1)) 101 | pixel_std = (torch.Tensor( 102 | self.cfg.MODEL.PIXEL_STD).to(self.device).view(num_channels, 1, 1)) 103 | return lambda x: (x - pixel_mean) / pixel_std 104 | -------------------------------------------------------------------------------- /defrcn/data/meta_voc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import xml.etree.ElementTree as ET 4 | from detectron2.structures import BoxMode 5 | from fvcore.common.file_io import PathManager 6 | from detectron2.data import DatasetCatalog, MetadataCatalog 7 | import logging 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | __all__ = ["register_meta_voc"] 12 | 13 | 14 | def load_filtered_voc_instances( 15 | name: str, dirname: str, split: str, classnames: str 16 | ): 17 | """ 18 | Load Pascal VOC detection annotations to Detectron2 format. 
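    Each returned record follows the Detectron2 "dataset dict" convention,
    roughly (the concrete values below are illustrative only):

        {"file_name": ".../JPEGImages/000005.jpg", "image_id": "000005",
         "height": 375, "width": 500,
         "annotations": [{"category_id": 3,
                          "bbox": [262.0, 210.0, 323.0, 338.0],
                          "bbox_mode": BoxMode.XYXY_ABS}]}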
19 | Args: 20 | dirname: Contain "Annotations", "ImageSets", "JPEGImages" 21 | split (str): one of "train", "test", "val", "trainval" 22 | """ 23 | is_shots = "shot" in name 24 | if is_shots: 25 | fileids = {} 26 | split_dir = os.path.join("datasets", "vocsplit") 27 | shot = name.split("_")[-2].split("shot")[0] 28 | seed = int(name.split("_seed")[-1]) 29 | split_dir = os.path.join(split_dir, "seed{}".format(seed)) 30 | for cls in classnames: 31 | with PathManager.open( 32 | os.path.join( 33 | split_dir, "box_{}shot_{}_train.txt".format(shot, cls) 34 | ) 35 | ) as f: 36 | fileids_ = np.loadtxt(f, dtype=np.str).tolist() 37 | if isinstance(fileids_, str): 38 | fileids_ = [fileids_] 39 | fileids_ = [ 40 | fid.split("/")[-1].split(".jpg")[0] for fid in fileids_ 41 | ] 42 | fileids[cls] = fileids_ 43 | else: 44 | with PathManager.open( 45 | os.path.join(dirname, "ImageSets", "Main", split + ".txt") 46 | ) as f: 47 | fileids = np.loadtxt(f, dtype=np.str) 48 | 49 | dicts = [] 50 | if is_shots: 51 | for cls, fileids_ in fileids.items(): 52 | dicts_ = [] 53 | for fileid in fileids_: 54 | year = "2012" if "_" in fileid else "2007" 55 | dirname = os.path.join("datasets", "VOC{}".format(year)) 56 | anno_file = os.path.join( 57 | dirname, "Annotations", fileid + ".xml" 58 | ) 59 | jpeg_file = os.path.join( 60 | dirname, "JPEGImages", fileid + ".jpg" 61 | ) 62 | 63 | tree = ET.parse(anno_file) 64 | 65 | for obj in tree.findall("object"): 66 | r = { 67 | "file_name": jpeg_file, 68 | "image_id": fileid, 69 | "height": int(tree.findall("./size/height")[0].text), 70 | "width": int(tree.findall("./size/width")[0].text), 71 | } 72 | cls_ = obj.find("name").text 73 | if cls != cls_: 74 | continue 75 | bbox = obj.find("bndbox") 76 | bbox = [ 77 | float(bbox.find(x).text) 78 | for x in ["xmin", "ymin", "xmax", "ymax"] 79 | ] 80 | bbox[0] -= 1.0 81 | bbox[1] -= 1.0 82 | 83 | instances = [ 84 | { 85 | "category_id": classnames.index(cls), 86 | "bbox": bbox, 87 | "bbox_mode": BoxMode.XYXY_ABS, 88 | } 89 | ] 90 | r["annotations"] = instances 91 | dicts_.append(r) 92 | if len(dicts_) > int(shot): 93 | dicts_ = np.random.choice(dicts_, int(shot), replace=False) 94 | logger.info(str(cls) + ': ' + str(dicts_)) 95 | dicts.extend(dicts_) 96 | else: 97 | for fileid in fileids: 98 | anno_file = os.path.join(dirname, "Annotations", fileid + ".xml") 99 | jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg") 100 | 101 | tree = ET.parse(anno_file) 102 | 103 | r = { 104 | "file_name": jpeg_file, 105 | "image_id": fileid, 106 | "height": int(tree.findall("./size/height")[0].text), 107 | "width": int(tree.findall("./size/width")[0].text), 108 | } 109 | instances = [] 110 | 111 | for obj in tree.findall("object"): 112 | cls = obj.find("name").text 113 | if not (cls in classnames): 114 | continue 115 | bbox = obj.find("bndbox") 116 | bbox = [ 117 | float(bbox.find(x).text) 118 | for x in ["xmin", "ymin", "xmax", "ymax"] 119 | ] 120 | bbox[0] -= 1.0 121 | bbox[1] -= 1.0 122 | 123 | instances.append( 124 | { 125 | "category_id": classnames.index(cls), 126 | "bbox": bbox, 127 | "bbox_mode": BoxMode.XYXY_ABS, 128 | } 129 | ) 130 | r["annotations"] = instances 131 | dicts.append(r) 132 | return dicts 133 | 134 | 135 | def register_meta_voc( 136 | name, metadata, dirname, split, year, keepclasses, sid 137 | ): 138 | if keepclasses.startswith("base_novel"): 139 | thing_classes = metadata["thing_classes"][sid] 140 | elif keepclasses.startswith("base"): 141 | thing_classes = metadata["base_classes"][sid] 142 | elif 
keepclasses.startswith("novel"): 143 | thing_classes = metadata["novel_classes"][sid] 144 | 145 | DatasetCatalog.register( 146 | name, 147 | lambda: load_filtered_voc_instances( 148 | name, dirname, split, thing_classes 149 | ), 150 | ) 151 | 152 | MetadataCatalog.get(name).set( 153 | thing_classes=thing_classes, 154 | dirname=dirname, 155 | year=year, 156 | split=split, 157 | base_classes=metadata["base_classes"][sid], 158 | novel_classes=metadata["novel_classes"][sid], 159 | ) 160 | -------------------------------------------------------------------------------- /defrcn/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | import logging 4 | import datetime 5 | from collections import OrderedDict 6 | from contextlib import contextmanager 7 | from detectron2.utils.comm import is_main_process 8 | from .calibration_layer import PrototypicalCalibrationBlock 9 | 10 | 11 | class DatasetEvaluator: 12 | """ 13 | Base class for a dataset evaluator. 14 | 15 | The function :func:`inference_on_dataset` runs the model over 16 | all samples in the dataset, and have a DatasetEvaluator to process the inputs/outputs. 17 | 18 | This class will accumulate information of the inputs/outputs (by :meth:`process`), 19 | and produce evaluation results in the end (by :meth:`evaluate`). 20 | """ 21 | 22 | def reset(self): 23 | """ 24 | Preparation for a new round of evaluation. 25 | Should be called before starting a round of evaluation. 26 | """ 27 | pass 28 | 29 | def process(self, input, output): 30 | """ 31 | Process an input/output pair. 32 | 33 | Args: 34 | input: the input that's used to call the model. 35 | output: the return value of `model(output)` 36 | """ 37 | pass 38 | 39 | def evaluate(self): 40 | """ 41 | Evaluate/summarize the performance, after processing all input/output pairs. 42 | 43 | Returns: 44 | dict: 45 | A new evaluator class can return a dict of arbitrary format 46 | as long as the user can process the results. 
47 | In our train_net.py, we expect the following format: 48 | 49 | * key: the name of the task (e.g., bbox) 50 | * value: a dict of {metric name: score}, e.g.: {"AP50": 80} 51 | """ 52 | pass 53 | 54 | 55 | class DatasetEvaluators(DatasetEvaluator): 56 | def __init__(self, evaluators): 57 | assert len(evaluators) 58 | super().__init__() 59 | self._evaluators = evaluators 60 | 61 | def reset(self): 62 | for evaluator in self._evaluators: 63 | evaluator.reset() 64 | 65 | def process(self, input, output): 66 | for evaluator in self._evaluators: 67 | evaluator.process(input, output) 68 | 69 | def evaluate(self): 70 | results = OrderedDict() 71 | for evaluator in self._evaluators: 72 | result = evaluator.evaluate() 73 | if is_main_process(): 74 | for k, v in result.items(): 75 | assert ( 76 | k not in results 77 | ), "Different evaluators produce results with the same key {}".format(k) 78 | results[k] = v 79 | return results 80 | 81 | 82 | def inference_on_dataset(model, data_loader, evaluator, cfg=None, inference_with_aux=False, proto_dataset=None): 83 | 84 | num_devices = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 85 | logger = logging.getLogger(__name__) 86 | 87 | pcb = None 88 | if cfg.TEST.PCB_ENABLE: 89 | logger.info("Start initializing PCB module, please wait a seconds...") 90 | pcb = PrototypicalCalibrationBlock(cfg, proto_dataset=proto_dataset) 91 | 92 | logger.info("Start inference on {} images".format(len(data_loader))) 93 | total = len(data_loader) # inference data loader must have a fixed length 94 | evaluator.reset() 95 | 96 | logging_interval = 50 97 | num_warmup = min(5, logging_interval - 1, total - 1) 98 | start_time = time.time() 99 | total_compute_time = 0 100 | with inference_context(model), torch.no_grad(): 101 | for idx, inputs in enumerate(data_loader): 102 | if idx == num_warmup: 103 | start_time = time.time() 104 | total_compute_time = 0 105 | 106 | start_compute_time = time.time() 107 | outputs = model(inputs, inference_with_aux) 108 | if cfg.TEST.PCB_ENABLE: 109 | outputs = pcb.execute_calibration(inputs, outputs) 110 | torch.cuda.synchronize() 111 | total_compute_time += time.time() - start_compute_time 112 | evaluator.process(inputs, outputs) 113 | 114 | if (idx + 1) % logging_interval == 0: 115 | duration = time.time() - start_time 116 | seconds_per_img = duration / (idx + 1 - num_warmup) 117 | eta = datetime.timedelta( 118 | seconds=int(seconds_per_img * (total - num_warmup) - duration) 119 | ) 120 | logger.info( 121 | "Inference done {}/{}. {:.4f} s / img. ETA={}".format( 122 | idx + 1, total, seconds_per_img, str(eta) 123 | ) 124 | ) 125 | 126 | # Measure the time only for this worker (before the synchronization barrier) 127 | total_time = int(time.time() - start_time) 128 | total_time_str = str(datetime.timedelta(seconds=total_time)) 129 | # NOTE this format is parsed by grep 130 | logger.info( 131 | "Total inference time: {} ({:.6f} s / img per device, on {} devices)".format( 132 | total_time_str, total_time / (total - num_warmup), num_devices 133 | ) 134 | ) 135 | total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time))) 136 | logger.info( 137 | "Total inference pure compute time: {} ({:.6f} s / img per device, on {} devices)".format( 138 | total_compute_time_str, total_compute_time / (total - num_warmup), num_devices 139 | ) 140 | ) 141 | 142 | results = evaluator.evaluate() 143 | # An evaluator may return None when not in main process. 
144 | # Replace it by an empty dict instead to make it easier for downstream code to handle 145 | if results is None: 146 | results = {} 147 | return results 148 | 149 | 150 | @contextmanager 151 | def inference_context(model): 152 | """ 153 | A context where the model is temporarily changed to eval mode, 154 | and restored to previous mode afterwards. 155 | 156 | Args: 157 | model: a torch Module 158 | """ 159 | training_mode = model.training 160 | model.eval() 161 | yield 162 | model.train(training_mode) 163 | -------------------------------------------------------------------------------- /defrcn/solver/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from enum import Enum 3 | from typing import Any, Callable, Dict, Iterable, List, Set, Type, Union 4 | from detectron2.config import CfgNode 5 | from .lr_scheduler import WarmupCosineLR, WarmupMultiStepLR 6 | 7 | _GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]] 8 | _GradientClipper = Callable[[_GradientClipperInput], None] 9 | 10 | 11 | class GradientClipType(Enum): 12 | VALUE = "value" 13 | NORM = "norm" 14 | 15 | 16 | def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper: 17 | """ 18 | Creates gradient clipping closure to clip by value or by norm, 19 | according to the provided config. 20 | """ 21 | cfg = cfg.clone() 22 | 23 | def clip_grad_norm(p: _GradientClipperInput): 24 | torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE) 25 | 26 | def clip_grad_value(p: _GradientClipperInput): 27 | torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE) 28 | 29 | _GRADIENT_CLIP_TYPE_TO_CLIPPER = { 30 | GradientClipType.VALUE: clip_grad_value, 31 | GradientClipType.NORM: clip_grad_norm, 32 | } 33 | return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)] 34 | 35 | 36 | def _generate_optimizer_class_with_gradient_clipping( 37 | optimizer_type: Type[torch.optim.Optimizer], gradient_clipper: _GradientClipper 38 | ) -> Type[torch.optim.Optimizer]: 39 | """ 40 | Dynamically creates a new type that inherits the type of a given instance 41 | and overrides the `step` method to add gradient clipping 42 | """ 43 | 44 | def optimizer_wgc_step(self, closure=None): 45 | for group in self.param_groups: 46 | for p in group["params"]: 47 | gradient_clipper(p) 48 | super(type(self), self).step(closure) 49 | 50 | OptimizerWithGradientClip = type( 51 | optimizer_type.__name__ + "WithGradientClip", 52 | (optimizer_type,), 53 | {"step": optimizer_wgc_step}, 54 | ) 55 | return OptimizerWithGradientClip 56 | 57 | 58 | def maybe_add_gradient_clipping( 59 | cfg: CfgNode, optimizer: torch.optim.Optimizer 60 | ) -> torch.optim.Optimizer: 61 | """ 62 | If gradient clipping is enabled through config options, wraps the existing 63 | optimizer instance of some type OptimizerType to become an instance 64 | of the new dynamically created class OptimizerTypeWithGradientClip 65 | that inherits OptimizerType and overrides the `step` method to 66 | include gradient clipping. 
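    For example, norm-based clipping could be switched on with config values
    such as (illustrative, not the project defaults):

        cfg.SOLVER.CLIP_GRADIENTS.ENABLED = True
        cfg.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "norm"   # or "value"
        cfg.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
        cfg.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
        optimizer = maybe_add_gradient_clipping(cfg, optimizer)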
67 | 68 | Args: 69 | cfg: CfgNode 70 | configuration options 71 | optimizer: torch.optim.Optimizer 72 | existing optimizer instance 73 | 74 | Return: 75 | optimizer: torch.optim.Optimizer 76 | either the unmodified optimizer instance (if gradient clipping is 77 | disabled), or the same instance with adjusted __class__ to override 78 | the `step` method and include gradient clipping 79 | """ 80 | if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED: 81 | return optimizer 82 | grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS) 83 | OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping( 84 | type(optimizer), grad_clipper 85 | ) 86 | optimizer.__class__ = OptimizerWithGradientClip 87 | return optimizer 88 | 89 | 90 | def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer: 91 | """ 92 | Build an optimizer from config. 93 | """ 94 | norm_module_types = ( 95 | torch.nn.BatchNorm1d, 96 | torch.nn.BatchNorm2d, 97 | torch.nn.BatchNorm3d, 98 | torch.nn.SyncBatchNorm, 99 | # NaiveSyncBatchNorm inherits from BatchNorm2d 100 | torch.nn.GroupNorm, 101 | torch.nn.InstanceNorm1d, 102 | torch.nn.InstanceNorm2d, 103 | torch.nn.InstanceNorm3d, 104 | torch.nn.LayerNorm, 105 | torch.nn.LocalResponseNorm, 106 | ) 107 | params: List[Dict[str, Any]] = [] 108 | memo: Set[torch.nn.parameter.Parameter] = set() 109 | for module_name, module in model.named_modules(): 110 | for key, value in module.named_parameters(recurse=False): 111 | if not value.requires_grad: 112 | continue 113 | # Avoid duplicating parameters 114 | if value in memo: 115 | continue 116 | memo.add(value) 117 | lr = cfg.SOLVER.BASE_LR 118 | weight_decay = cfg.SOLVER.WEIGHT_DECAY 119 | if isinstance(module, norm_module_types): 120 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_NORM 121 | elif key == "bias": 122 | # NOTE: unlike Detectron v1, we now default BIAS_LR_FACTOR to 1.0 123 | # and WEIGHT_DECAY_BIAS to WEIGHT_DECAY so that bias optimizer 124 | # hyperparameters are by default exactly the same as for regular 125 | # weights. 126 | lr = cfg.SOLVER.BASE_LR * cfg.SOLVER.BIAS_LR_FACTOR 127 | weight_decay = cfg.SOLVER.WEIGHT_DECAY_BIAS 128 | 129 | params += [{"params": [value], "lr": lr, "weight_decay": weight_decay}] 130 | 131 | optimizer = torch.optim.SGD( 132 | params, cfg.SOLVER.BASE_LR, momentum=cfg.SOLVER.MOMENTUM, nesterov=cfg.SOLVER.NESTEROV 133 | ) 134 | optimizer = maybe_add_gradient_clipping(cfg, optimizer) 135 | return optimizer 136 | 137 | 138 | def build_lr_scheduler( 139 | cfg: CfgNode, optimizer: torch.optim.Optimizer 140 | ) -> torch.optim.lr_scheduler._LRScheduler: 141 | """ 142 | Build a LR scheduler from config. 
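    Example (illustrative):

        optimizer = build_optimizer(cfg, model)
        scheduler = build_lr_scheduler(cfg, optimizer)
        # cfg.SOLVER.LR_SCHEDULER_NAME selects WarmupMultiStepLR or WarmupCosineLR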
143 | """ 144 | name = cfg.SOLVER.LR_SCHEDULER_NAME 145 | if name == "WarmupMultiStepLR": 146 | return WarmupMultiStepLR( 147 | optimizer, 148 | cfg.SOLVER.STEPS, 149 | cfg.SOLVER.GAMMA, 150 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 151 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 152 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 153 | ) 154 | elif name == "WarmupCosineLR": 155 | return WarmupCosineLR( 156 | optimizer, 157 | cfg.SOLVER.MAX_ITER, 158 | warmup_factor=cfg.SOLVER.WARMUP_FACTOR, 159 | warmup_iters=cfg.SOLVER.WARMUP_ITERS, 160 | warmup_method=cfg.SOLVER.WARMUP_METHOD, 161 | ) 162 | else: 163 | raise ValueError("Unknown LR scheduler: {}".format(name)) 164 | -------------------------------------------------------------------------------- /defrcn/evaluation/calibration_layer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import torch 5 | import logging 6 | import detectron2 7 | import numpy as np 8 | from detectron2.structures import ImageList 9 | from detectron2.modeling.poolers import ROIPooler 10 | from sklearn.metrics.pairwise import cosine_similarity 11 | from defrcn.dataloader import build_detection_test_loader 12 | from defrcn.evaluation.archs import resnet101 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class PrototypicalCalibrationBlock: 18 | 19 | def __init__(self, cfg, proto_dataset=None): 20 | super().__init__() 21 | self.cfg = cfg 22 | self.device = torch.device(cfg.MODEL.DEVICE) 23 | self.alpha = self.cfg.TEST.PCB_ALPHA 24 | 25 | self.imagenet_model = self.build_model() 26 | 27 | self.proto_dataset = proto_dataset 28 | if self.proto_dataset is None: 29 | self.dataloader = build_detection_test_loader(self.cfg, self.cfg.DATASETS.TRAIN[0]) 30 | self.proto_dataset = self.dataloader.dataset 31 | 32 | self.roi_pooler = ROIPooler(output_size=(1, 1), scales=(1 / 32,), sampling_ratio=(0), pooler_type="ROIAlignV2") 33 | self.prototypes = self.build_prototypes() 34 | 35 | self.exclude_cls = self.clsid_filter() 36 | 37 | def build_model(self): 38 | logger.info("Loading ImageNet Pre-train Model from {}".format(self.cfg.TEST.PCB_MODELPATH)) 39 | if self.cfg.TEST.PCB_MODELTYPE == 'resnet': 40 | imagenet_model = resnet101() 41 | else: 42 | raise NotImplementedError 43 | state_dict = torch.load(self.cfg.TEST.PCB_MODELPATH) 44 | imagenet_model.load_state_dict(state_dict) 45 | imagenet_model = imagenet_model.to(self.device) 46 | imagenet_model.eval() 47 | return imagenet_model 48 | 49 | def build_prototypes(self): 50 | 51 | all_features, all_labels = [], [] 52 | for index in range(len(self.proto_dataset)): 53 | inputs = [self.proto_dataset[index]] 54 | assert len(inputs) == 1 55 | # load support images and gt-boxes 56 | img = cv2.imread(inputs[0]['file_name']) # BGR 57 | img_h, img_w = img.shape[0], img.shape[1] 58 | ratio = img_h / inputs[0]['instances'].image_size[0] 59 | inputs[0]['instances'].gt_boxes.tensor = inputs[0]['instances'].gt_boxes.tensor * ratio 60 | boxes = [x["instances"].gt_boxes.to(self.device) for x in inputs] 61 | 62 | # extract roi features 63 | features = self.extract_roi_features(img, boxes) 64 | all_features.append(features.cpu().data) 65 | 66 | gt_classes = [x['instances'].gt_classes for x in inputs] 67 | all_labels.append(gt_classes[0].cpu().data) 68 | 69 | # concat 70 | all_features = torch.cat(all_features, dim=0) 71 | all_labels = torch.cat(all_labels, dim=0) 72 | assert all_features.shape[0] == all_labels.shape[0] 73 | 74 | # calculate prototype 75 | 
features_dict = {} 76 | for i, label in enumerate(all_labels): 77 | label = int(label) 78 | if label not in features_dict: 79 | features_dict[label] = [] 80 | features_dict[label].append(all_features[i].unsqueeze(0)) 81 | 82 | prototypes_dict = {} 83 | for label in features_dict: 84 | features = torch.cat(features_dict[label], dim=0) 85 | prototypes_dict[label] = torch.mean(features, dim=0, keepdim=True) 86 | 87 | return prototypes_dict 88 | 89 | def extract_roi_features(self, img, boxes): 90 | """ 91 | :param img: 92 | :param boxes: 93 | :return: 94 | """ 95 | 96 | mean = torch.tensor([0.406, 0.456, 0.485]).reshape((3, 1, 1)).to(self.device) 97 | std = torch.tensor([[0.225, 0.224, 0.229]]).reshape((3, 1, 1)).to(self.device) 98 | 99 | img = img.transpose((2, 0, 1)) 100 | img = torch.from_numpy(img).to(self.device) 101 | images = [(img / 255. - mean) / std] 102 | images = ImageList.from_tensors(images, 0) 103 | conv_feature = self.imagenet_model(images.tensor[:, [2, 1, 0]])[1] # size: BxCxHxW 104 | 105 | box_features = self.roi_pooler([conv_feature], boxes).squeeze(2).squeeze(2) 106 | 107 | activation_vectors = self.imagenet_model.fc(box_features) 108 | 109 | return activation_vectors 110 | 111 | def execute_calibration(self, inputs, dts): 112 | 113 | img = cv2.imread(inputs[0]['file_name']) 114 | 115 | ileft = (dts[0]['instances'].scores > self.cfg.TEST.PCB_UPPER).sum() 116 | iright = (dts[0]['instances'].scores > self.cfg.TEST.PCB_LOWER).sum() 117 | assert ileft <= iright 118 | boxes = [dts[0]['instances'].pred_boxes[ileft:iright]] 119 | 120 | features = self.extract_roi_features(img, boxes) 121 | 122 | for i in range(ileft, iright): 123 | tmp_class = int(dts[0]['instances'].pred_classes[i]) 124 | if tmp_class in self.exclude_cls: 125 | continue 126 | tmp_cos = cosine_similarity(features[i - ileft].cpu().data.numpy().reshape((1, -1)), 127 | self.prototypes[tmp_class].cpu().data.numpy())[0][0] 128 | dts[0]['instances'].scores[i] = dts[0]['instances'].scores[i] * self.alpha + tmp_cos * (1 - self.alpha) 129 | return dts 130 | 131 | def clsid_filter(self): 132 | dsname = self.cfg.DATASETS.TEST[0] 133 | exclude_ids = [] 134 | if 'test_all' in dsname: 135 | if 'coco' in dsname: 136 | exclude_ids = [7, 9, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 137 | 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 41, 42, 43, 44, 45, 138 | 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 61, 63, 64, 65, 139 | 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79] 140 | elif 'voc' in dsname: 141 | exclude_ids = list(range(0, 15)) 142 | else: 143 | raise NotImplementedError 144 | return exclude_ids 145 | 146 | 147 | @torch.no_grad() 148 | def concat_all_gather(tensor): 149 | """ 150 | Performs all_gather operation on the provided tensors. 151 | *** Warning ***: torch.distributed.all_gather has no gradient. 
152 | """ 153 | tensors_gather = [torch.ones_like(tensor) for _ in range(torch.distributed.get_world_size())] 154 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False) 155 | output = torch.cat(tensors_gather, dim=0) 156 | return output 157 | -------------------------------------------------------------------------------- /defrcn/dataloader/dataset_mapper.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import logging 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | from typing import List, Optional, Union 7 | from detectron2.config import configurable 8 | from detectron2.data import detection_utils as utils 9 | from detectron2.data import transforms as T 10 | 11 | """ 12 | This file contains the default mapping that's applied to "dataset dicts". 13 | """ 14 | 15 | __all__ = ["DatasetMapper"] 16 | 17 | 18 | class DatasetMapper: 19 | """ 20 | A callable which takes a dataset dict in Detectron2 Dataset format, 21 | and map it into a format used by the model. 22 | This is the default callable to be used to map your dataset dict into training data. 23 | You may need to follow it to implement your own one for customized logic, 24 | such as a different way to read or transform images. 25 | See :doc:`/tutorials/data_loading` for details. 26 | The callable currently does the following: 27 | 1. Read the image from "file_name" 28 | 2. Applies cropping/geometric transforms to the image and annotations 29 | 3. Prepare data and annotations to Tensor and :class:`Instances` 30 | """ 31 | 32 | @configurable 33 | def __init__( 34 | self, 35 | is_train: bool, 36 | *, 37 | augmentations: List[Union[T.Augmentation, T.Transform]], 38 | image_format: str, 39 | use_instance_mask: bool = False, 40 | use_keypoint: bool = False, 41 | instance_mask_format: str = "polygon", 42 | keypoint_hflip_indices: Optional[np.ndarray] = None, 43 | precomputed_proposal_topk: Optional[int] = None, 44 | recompute_boxes: bool = False 45 | ): 46 | """ 47 | NOTE: this interface is experimental. 48 | Args: 49 | is_train: whether it's used in training or inference 50 | augmentations: a list of augmentations or deterministic transforms to apply 51 | image_format: an image format supported by :func:`detection_utils.read_image`. 52 | use_instance_mask: whether to process instance segmentation annotations, if available 53 | use_keypoint: whether to process keypoint annotations if available 54 | instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation 55 | masks into this format. 56 | keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices` 57 | precomputed_proposal_topk: if given, will load pre-computed 58 | proposals from dataset_dict and keep the top k proposals for each image. 59 | recompute_boxes: whether to overwrite bounding box annotations 60 | by computing tight bounding boxes from instance mask annotations. 
61 | """ 62 | if recompute_boxes: 63 | assert use_instance_mask, "recompute_boxes requires instance masks" 64 | # fmt: off 65 | self.is_train = is_train 66 | self.augmentations = T.AugmentationList(augmentations) 67 | self.image_format = image_format 68 | self.use_instance_mask = use_instance_mask 69 | self.instance_mask_format = instance_mask_format 70 | self.use_keypoint = use_keypoint 71 | self.keypoint_hflip_indices = keypoint_hflip_indices 72 | self.proposal_topk = precomputed_proposal_topk 73 | self.recompute_boxes = recompute_boxes 74 | # fmt: on 75 | logger = logging.getLogger(__name__) 76 | mode = "training" if is_train else "inference" 77 | logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}") 78 | 79 | @classmethod 80 | def from_config(cls, cfg, is_train: bool = True): 81 | augs = utils.build_augmentation(cfg, is_train) 82 | if cfg.INPUT.CROP.ENABLED and is_train: 83 | augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) 84 | recompute_boxes = cfg.MODEL.MASK_ON 85 | else: 86 | recompute_boxes = False 87 | 88 | ret = { 89 | "is_train": is_train, 90 | "augmentations": augs, 91 | "image_format": cfg.INPUT.FORMAT, 92 | "use_instance_mask": cfg.MODEL.MASK_ON, 93 | "instance_mask_format": cfg.INPUT.MASK_FORMAT, 94 | "use_keypoint": cfg.MODEL.KEYPOINT_ON, 95 | "recompute_boxes": recompute_boxes, 96 | } 97 | 98 | if cfg.MODEL.KEYPOINT_ON: 99 | ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) 100 | 101 | if cfg.MODEL.LOAD_PROPOSALS: 102 | ret["precomputed_proposal_topk"] = ( 103 | cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN 104 | if is_train 105 | else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST 106 | ) 107 | return ret 108 | 109 | def __call__(self, dataset_dict): 110 | """ 111 | Args: 112 | dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. 113 | Returns: 114 | dict: a format that builtin models in detectron2 accept 115 | """ 116 | dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below 117 | # USER: Write your own image loading if it's not from a file 118 | image = utils.read_image(dataset_dict["file_name"], format=self.image_format) 119 | utils.check_image_size(dataset_dict, image) 120 | 121 | # USER: Remove if you don't do semantic/panoptic segmentation. 122 | if "sem_seg_file_name" in dataset_dict: 123 | sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) 124 | else: 125 | sem_seg_gt = None 126 | 127 | aug_input = T.AugInput(image, sem_seg=sem_seg_gt) 128 | transforms = self.augmentations(aug_input) 129 | image, sem_seg_gt = aug_input.image, aug_input.sem_seg 130 | 131 | image_shape = image.shape[:2] # h, w 132 | # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, 133 | # but not efficient on large generic data structures due to the use of pickle & mp.Queue. 134 | # Therefore it's important to use torch.Tensor. 135 | dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) 136 | if sem_seg_gt is not None: 137 | dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) 138 | 139 | # USER: Remove if you don't use pre-computed proposals. 140 | # Most users would not need this feature. 141 | if self.proposal_topk is not None: 142 | utils.transform_proposals( 143 | dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk 144 | ) 145 | 146 | # if not self.is_train: 147 | # USER: Modify this if you want to keep them for some reason. 
148 | # dataset_dict.pop("annotations", None) 149 | dataset_dict.pop("sem_seg_file_name", None) 150 | # return dataset_dict 151 | 152 | if "annotations" in dataset_dict: 153 | # USER: Modify this if you want to keep them for some reason. 154 | for anno in dataset_dict["annotations"]: 155 | if not self.use_instance_mask: 156 | anno.pop("segmentation", None) 157 | if not self.use_keypoint: 158 | anno.pop("keypoints", None) 159 | 160 | # USER: Implement additional transformations if you have other types of data 161 | annos = [ 162 | utils.transform_instance_annotations( 163 | obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices 164 | ) 165 | for obj in dataset_dict.pop("annotations") 166 | if obj.get("iscrowd", 0) == 0 167 | ] 168 | instances = utils.annotations_to_instances( 169 | annos, image_shape, mask_format=self.instance_mask_format 170 | ) 171 | 172 | # After transforms such as cropping are applied, the bounding box may no longer 173 | # tightly bound the object. As an example, imagine a triangle object 174 | # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight 175 | # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to 176 | # the intersection of original bounding box and the cropping box. 177 | if self.recompute_boxes: 178 | instances.gt_boxes = instances.gt_masks.get_bounding_boxes() 179 | dataset_dict["instances"] = utils.filter_empty_instances(instances) 180 | return dataset_dict 181 | -------------------------------------------------------------------------------- /defrcn/config/compat.py: -------------------------------------------------------------------------------- 1 | """ 2 | Backward compatibility of configs. 3 | Instructions to bump version: 4 | + It's not needed to bump version if new keys are added. 5 | It's only needed when backward-incompatible changes happen 6 | (i.e., some existing keys disappear, or the meaning of a key changes) 7 | + To bump version, do the following: 8 | 1. Increment _C.VERSION in defaults.py 9 | 2. Add a converter in this file. 10 | Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X, 11 | and a function "downgrade" which in-place downgrades config from X to X-1 12 | In each function, VERSION is left unchanged. 13 | Each converter assumes that its input has the relevant keys 14 | (i.e., the input is not a partial config). 15 | 3. Run the tests (test_config.py) to make sure the upgrade & downgrade 16 | functions are consistent. 17 | """ 18 | 19 | import logging 20 | from typing import List, Optional, Tuple 21 | 22 | from .config import CfgNode as CN 23 | from .defaults import _CC as _C 24 | 25 | __all__ = ["upgrade_config", "downgrade_config"] 26 | 27 | 28 | def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN: 29 | """ 30 | Upgrade a config from its current version to a newer version. 31 | Args: 32 | cfg (CfgNode): 33 | to_version (int): defaults to the latest version. 34 | """ 35 | cfg = cfg.clone() 36 | if to_version is None: 37 | to_version = _C.VERSION 38 | 39 | assert cfg.VERSION <= to_version, "Cannot upgrade from v{} to v{}!".format( 40 | cfg.VERSION, to_version 41 | ) 42 | for k in range(cfg.VERSION, to_version): 43 | converter = globals()["ConverterV" + str(k + 1)] 44 | converter.upgrade(cfg) 45 | cfg.VERSION = k + 1 46 | return cfg 47 | 48 | 49 | def downgrade_config(cfg: CN, to_version: int) -> CN: 50 | """ 51 | Downgrade a config from its current version to an older version. 
52 | Args: 53 | cfg (CfgNode): 54 | to_version (int): 55 | Note: 56 | A general downgrade of arbitrary configs is not always possible due to the 57 | different functionalities in different versions. 58 | The purpose of downgrade is only to recover the defaults in old versions, 59 | allowing it to load an old partial yaml config. 60 | Therefore, the implementation only needs to fill in the default values 61 | in the old version when a general downgrade is not possible. 62 | """ 63 | cfg = cfg.clone() 64 | assert ( 65 | cfg.VERSION >= to_version 66 | ), "Cannot downgrade from v{} to v{}!".format(cfg.VERSION, to_version) 67 | for k in range(cfg.VERSION, to_version, -1): 68 | converter = globals()["ConverterV" + str(k)] 69 | converter.downgrade(cfg) 70 | cfg.VERSION = k - 1 71 | return cfg 72 | 73 | 74 | def guess_version(cfg: CN, filename: str) -> int: 75 | """ 76 | Guess the version of a partial config where the VERSION field is not specified. 77 | Returns the version, or the latest if cannot make a guess. 78 | This makes it easier for users to migrate. 79 | """ 80 | logger = logging.getLogger(__name__) 81 | 82 | def _has(name: str) -> bool: 83 | cur = cfg 84 | for n in name.split("."): 85 | if n not in cur: 86 | return False 87 | cur = cur[n] 88 | return True 89 | 90 | # Most users' partial configs have "MODEL.WEIGHT", so guess on it 91 | ret = None 92 | if _has("MODEL.WEIGHT") or _has("TEST.AUG_ON"): 93 | ret = 1 94 | 95 | if ret is not None: 96 | logger.warning( 97 | "Config '{}' has no VERSION. Assuming it to be v{}.".format( 98 | filename, ret 99 | ) 100 | ) 101 | else: 102 | ret = _C.VERSION 103 | logger.warning( 104 | "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format( 105 | filename, ret 106 | ) 107 | ) 108 | return ret 109 | 110 | 111 | def _rename(cfg: CN, old: str, new: str) -> None: 112 | old_keys = old.split(".") 113 | new_keys = new.split(".") 114 | 115 | def _set(key_seq: List[str], val: str) -> None: 116 | cur = cfg 117 | for k in key_seq[:-1]: 118 | if k not in cur: 119 | cur[k] = CN() 120 | cur = cur[k] 121 | cur[key_seq[-1]] = val 122 | 123 | def _get(key_seq: List[str]) -> CN: 124 | cur = cfg 125 | for k in key_seq: 126 | cur = cur[k] 127 | return cur 128 | 129 | def _del(key_seq: List[str]) -> None: 130 | cur = cfg 131 | for k in key_seq[:-1]: 132 | cur = cur[k] 133 | del cur[key_seq[-1]] 134 | if len(cur) == 0 and len(key_seq) > 1: 135 | _del(key_seq[:-1]) 136 | 137 | _set(new_keys, _get(old_keys)) 138 | _del(old_keys) 139 | 140 | 141 | class _RenameConverter: 142 | """ 143 | A converter that handles simple rename. 144 | """ 145 | 146 | RENAME: List[ 147 | Tuple[str, str] 148 | ] = [] # list of tuples of (old name, new name) 149 | 150 | @classmethod 151 | def upgrade(cls, cfg: CN) -> None: 152 | for old, new in cls.RENAME: 153 | _rename(cfg, old, new) 154 | 155 | @classmethod 156 | def downgrade(cls, cfg: CN) -> None: 157 | for old, new in cls.RENAME[::-1]: 158 | _rename(cfg, new, old) 159 | 160 | 161 | class ConverterV1(_RenameConverter): 162 | RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")] 163 | 164 | 165 | class ConverterV2(_RenameConverter): 166 | """ 167 | A large bulk of rename, before public release. 
168 | """ 169 | 170 | RENAME = [ 171 | ("MODEL.WEIGHT", "MODEL.WEIGHTS"), 172 | ( 173 | "MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", 174 | "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT", 175 | ), 176 | ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"), 177 | ( 178 | "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", 179 | "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT", 180 | ), 181 | ( 182 | "MODEL.PANOPTIC_FPN.COMBINE_ON", 183 | "MODEL.PANOPTIC_FPN.COMBINE.ENABLED", 184 | ), 185 | ( 186 | "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD", 187 | "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH", 188 | ), 189 | ( 190 | "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT", 191 | "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT", 192 | ), 193 | ( 194 | "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD", 195 | "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH", 196 | ), 197 | ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"), 198 | ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"), 199 | ( 200 | "MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", 201 | "MODEL.RETINANET.SCORE_THRESH_TEST", 202 | ), 203 | ( 204 | "MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", 205 | "MODEL.RETINANET.TOPK_CANDIDATES_TEST", 206 | ), 207 | ( 208 | "MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", 209 | "MODEL.RETINANET.NMS_THRESH_TEST", 210 | ), 211 | ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"), 212 | ("TEST.AUG_ON", "TEST.AUG.ENABLED"), 213 | ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"), 214 | ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"), 215 | ("TEST.AUG_FLIP", "TEST.AUG.FLIP"), 216 | ] 217 | 218 | @classmethod 219 | def upgrade(cls, cfg: CN) -> None: 220 | super().upgrade(cfg) 221 | 222 | if cfg.MODEL.META_ARCHITECTURE == "RetinaNet": 223 | _rename( 224 | cfg, 225 | "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", 226 | "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", 227 | ) 228 | _rename( 229 | cfg, 230 | "MODEL.RETINANET.ANCHOR_SIZES", 231 | "MODEL.ANCHOR_GENERATOR.SIZES", 232 | ) 233 | del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"] 234 | del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"] 235 | else: 236 | _rename( 237 | cfg, 238 | "MODEL.RPN.ANCHOR_ASPECT_RATIOS", 239 | "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", 240 | ) 241 | _rename( 242 | cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES" 243 | ) 244 | del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"] 245 | del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"] 246 | del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"] 247 | 248 | @classmethod 249 | def downgrade(cls, cfg: CN) -> None: 250 | super().downgrade(cfg) 251 | 252 | _rename( 253 | cfg, 254 | "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", 255 | "MODEL.RPN.ANCHOR_ASPECT_RATIOS", 256 | ) 257 | _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES") 258 | cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = ( 259 | cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS 260 | ) 261 | cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES 262 | cfg.MODEL.RETINANET.ANCHOR_STRIDES = ( 263 | [] 264 | ) # this is not used anywhere in any version 265 | -------------------------------------------------------------------------------- /defrcn/evaluation/coco_evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import json 4 | import copy 5 | import torch 6 | import logging 7 | import itertools 8 | import contextlib 9 | import numpy as np 10 | from tabulate import tabulate 11 | from collections import OrderedDict 12 | from fvcore.common.file_io import PathManager 13 | from 
pycocotools.coco import COCO 14 | from pycocotools.cocoeval import COCOeval 15 | from detectron2.structures import BoxMode 16 | from detectron2.utils import comm as comm 17 | from detectron2.data import MetadataCatalog 18 | from detectron2.utils.logger import create_small_table 19 | from detectron2.data.datasets.coco import convert_to_coco_json 20 | from defrcn.evaluation.evaluator import DatasetEvaluator 21 | 22 | 23 | class COCOEvaluator(DatasetEvaluator): 24 | 25 | def __init__(self, dataset_name, distributed, output_dir=None): 26 | 27 | self._distributed = distributed 28 | self._output_dir = output_dir 29 | self._dataset_name = dataset_name 30 | self._cpu_device = torch.device("cpu") 31 | self._logger = logging.getLogger(__name__) 32 | 33 | self._metadata = MetadataCatalog.get(dataset_name) 34 | if not hasattr(self._metadata, "json_file"): 35 | self._logger.warning( 36 | f"json_file was not found in MetaDataCatalog for '{dataset_name}'") 37 | cache_path = convert_to_coco_json(dataset_name, output_dir) 38 | self._metadata.json_file = cache_path 39 | self._is_splits = "all" in dataset_name or "base" in dataset_name \ 40 | or "novel" in dataset_name 41 | self._base_classes = [ 42 | 8, 10, 11, 13, 14, 15, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 43 | 36, 37, 38, 39, 40, 41, 42, 43, 46, 47, 48, 49, 50, 51, 52, 53, 54, 44 | 55, 56, 57, 58, 59, 60, 61, 65, 70, 73, 74, 75, 76, 77, 78, 79, 80, 45 | 81, 82, 84, 85, 86, 87, 88, 89, 90, 46 | ] 47 | self._novel_classes = [1, 2, 3, 4, 5, 6, 7, 9, 16, 17, 18, 19, 20, 21, 48 | 44, 62, 63, 64, 67, 72] 49 | 50 | json_file = PathManager.get_local_path(self._metadata.json_file) 51 | with contextlib.redirect_stdout(io.StringIO()): 52 | self._coco_api = COCO(json_file) 53 | self._do_evaluation = "annotations" in self._coco_api.dataset 54 | 55 | def reset(self): 56 | self._predictions = [] 57 | self._coco_results = [] 58 | 59 | def process(self, inputs, outputs): 60 | """ 61 | Args: 62 | inputs: the inputs to a COCO model (e.g., GeneralizedRCNN). 63 | It is a list of dict. Each dict corresponds to an image and 64 | contains keys like "height", "width", "file_name", "image_id". 65 | outputs: the outputs of a COCO model. It is a list of dicts with key 66 | "instances" that contains :class:`Instances`. 
67 | """ 68 | for input, output in zip(inputs, outputs): 69 | prediction = {"image_id": input["image_id"]} 70 | # TODO this is ugly 71 | if "instances" in output: 72 | instances = output["instances"].to(self._cpu_device) 73 | prediction["instances"] = instances_to_coco_json( 74 | instances, input["image_id"]) 75 | self._predictions.append(prediction) 76 | 77 | def evaluate(self): 78 | if self._distributed: 79 | comm.synchronize() 80 | self._predictions = comm.gather(self._predictions, dst=0) 81 | self._predictions = list(itertools.chain(*self._predictions)) 82 | if not comm.is_main_process(): 83 | return {} 84 | 85 | if len(self._predictions) == 0: 86 | self._logger.warning( 87 | "[COCOEvaluator] Did not receive valid predictions.") 88 | return {} 89 | 90 | if self._output_dir: 91 | PathManager.mkdirs(self._output_dir) 92 | file_path = os.path.join( 93 | self._output_dir, "instances_predictions.pth") 94 | with PathManager.open(file_path, "wb") as f: 95 | torch.save(self._predictions, f) 96 | 97 | self._results = OrderedDict() 98 | if "instances" in self._predictions[0]: 99 | self._eval_predictions() 100 | # Copy so the caller can do whatever with results 101 | return copy.deepcopy(self._results) 102 | 103 | def _eval_predictions(self): 104 | """ 105 | Evaluate self._predictions on the instance detection task. 106 | Fill self._results with the metrics of the instance detection task. 107 | """ 108 | self._logger.info("Preparing results for COCO format ...") 109 | self._coco_results = list( 110 | itertools.chain(*[x["instances"] for x in self._predictions])) 111 | 112 | # unmap the category ids for COCO 113 | if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): 114 | reverse_id_mapping = { 115 | v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items() 116 | } 117 | for result in self._coco_results: 118 | result["category_id"] = reverse_id_mapping[result["category_id"]] 119 | 120 | if self._output_dir: 121 | file_path = os.path.join(self._output_dir, "coco_instances_results.json") 122 | self._logger.info("Saving results to {}".format(file_path)) 123 | with PathManager.open(file_path, "w") as f: 124 | f.write(json.dumps(self._coco_results)) 125 | f.flush() 126 | 127 | if not self._do_evaluation: 128 | self._logger.info("Annotations are not available for evaluation.") 129 | return 130 | 131 | self._logger.info("Evaluating predictions ...") 132 | if self._is_splits: 133 | self._results["bbox"] = {} 134 | for split, classes, names in [ 135 | ("all", None, self._metadata.get("thing_classes")), 136 | ("base", self._base_classes, self._metadata.get("base_classes")), 137 | ("novel", self._novel_classes, self._metadata.get("novel_classes"))]: 138 | if "all" not in self._dataset_name and \ 139 | split not in self._dataset_name: 140 | continue 141 | coco_eval = ( 142 | _evaluate_predictions_on_coco( 143 | self._coco_api, self._coco_results, "bbox", classes, 144 | ) 145 | if len(self._coco_results) > 0 146 | else None # cocoapi does not handle empty results very well 147 | ) 148 | res_ = self._derive_coco_results(coco_eval, "bbox", class_names=names) 149 | res = {} 150 | for metric in res_.keys(): 151 | if len(metric) <= 4: 152 | if split == "all": 153 | res[metric] = res_[metric] 154 | elif split == "base": 155 | res["b"+metric] = res_[metric] 156 | elif split == "novel": 157 | res["n"+metric] = res_[metric] 158 | self._results["bbox"].update(res) 159 | 160 | # add "AP" if not already in 161 | if "AP" not in self._results["bbox"]: 162 | if "nAP" in self._results["bbox"]: 163 
| self._results["bbox"]["AP"] = self._results["bbox"]["nAP"] 164 | else: 165 | self._results["bbox"]["AP"] = self._results["bbox"]["bAP"] 166 | else: 167 | coco_eval = ( 168 | _evaluate_predictions_on_coco( 169 | self._coco_api, self._coco_results, "bbox", 170 | ) 171 | if len(self._coco_results) > 0 172 | else None # cocoapi does not handle empty results very well 173 | ) 174 | res = self._derive_coco_results( 175 | coco_eval, "bbox", 176 | class_names=self._metadata.get("thing_classes") 177 | ) 178 | self._results["bbox"] = res 179 | 180 | def _derive_coco_results(self, coco_eval, iou_type, class_names=None): 181 | """ 182 | Derive the desired score numbers from summarized COCOeval. 183 | 184 | Args: 185 | coco_eval (None or COCOEval): None represents no predictions from model. 186 | iou_type (str): 187 | class_names (None or list[str]): if provided, will use it to predict 188 | per-category AP. 189 | 190 | Returns: 191 | a dict of {metric name: score} 192 | """ 193 | 194 | metrics = ["AP", "AP50", "AP75", "APs", "APm", "APl"] 195 | 196 | if coco_eval is None: 197 | self._logger.warn("No predictions from the model! Set scores to -1") 198 | return {metric: -1 for metric in metrics} 199 | 200 | # the standard metrics 201 | results = { 202 | metric: float(coco_eval.stats[idx] * 100) \ 203 | for idx, metric in enumerate(metrics) 204 | } 205 | self._logger.info( 206 | "Evaluation results for {}: \n".format(iou_type) + \ 207 | create_small_table(results) 208 | ) 209 | 210 | if class_names is None or len(class_names) <= 1: 211 | return results 212 | # Compute per-category AP 213 | precisions = coco_eval.eval["precision"] 214 | # precision has dims (iou, recall, cls, area range, max dets) 215 | assert len(class_names) == precisions.shape[2] 216 | 217 | results_per_category = [] 218 | for idx, name in enumerate(class_names): 219 | # area range index 0: all area ranges 220 | # max dets index -1: typically 100 per image 221 | precision = precisions[:, :, idx, 0, -1] 222 | precision = precision[precision > -1] 223 | ap = np.mean(precision) if precision.size else float("nan") 224 | results_per_category.append(("{}".format(name), float(ap * 100))) 225 | 226 | # tabulate it 227 | N_COLS = min(6, len(results_per_category) * 2) 228 | results_flatten = list(itertools.chain(*results_per_category)) 229 | results_2d = itertools.zip_longest( 230 | *[results_flatten[i::N_COLS] for i in range(N_COLS)]) 231 | table = tabulate( 232 | results_2d, 233 | tablefmt="pipe", 234 | floatfmt=".3f", 235 | headers=["category", "AP"] * (N_COLS // 2), 236 | numalign="left", 237 | ) 238 | self._logger.info("Per-category {} AP: \n".format(iou_type) + table) 239 | 240 | results.update({"AP-" + name: ap for name, ap in results_per_category}) 241 | return results 242 | 243 | 244 | def instances_to_coco_json(instances, img_id): 245 | """ 246 | Dump an "Instances" object to a COCO-format json that's used for evaluation. 247 | 248 | Args: 249 | instances (Instances): 250 | img_id (int): the image id 251 | 252 | Returns: 253 | list[dict]: list of json annotations in COCO format. 
254 | """ 255 | num_instance = len(instances) 256 | if num_instance == 0: 257 | return [] 258 | 259 | boxes = instances.pred_boxes.tensor.numpy() 260 | boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) 261 | boxes = boxes.tolist() 262 | scores = instances.scores.tolist() 263 | classes = instances.pred_classes.tolist() 264 | 265 | results = [] 266 | for k in range(num_instance): 267 | result = { 268 | "image_id": img_id, 269 | "category_id": classes[k], 270 | "bbox": boxes[k], 271 | "score": scores[k], 272 | } 273 | results.append(result) 274 | return results 275 | 276 | 277 | def _evaluate_predictions_on_coco(coco_gt, coco_results, iou_type, catIds=None): 278 | """ 279 | Evaluate the coco results using COCOEval API. 280 | """ 281 | assert len(coco_results) > 0 282 | 283 | coco_dt = coco_gt.loadRes(coco_results) 284 | coco_eval = COCOeval(coco_gt, coco_dt, iou_type) 285 | if catIds is not None: 286 | coco_eval.params.catIds = catIds 287 | coco_eval.evaluate() 288 | coco_eval.accumulate() 289 | coco_eval.summarize() 290 | 291 | return coco_eval 292 | -------------------------------------------------------------------------------- /defrcn/evaluation/pascal_voc_evaluation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import torch 4 | import logging 5 | import tempfile 6 | import numpy as np 7 | from functools import lru_cache 8 | from xml.etree import ElementTree as ET 9 | from collections import OrderedDict, defaultdict 10 | from detectron2.utils import comm 11 | from detectron2.data import MetadataCatalog 12 | from detectron2.utils.logger import create_small_table 13 | from defrcn.evaluation.evaluator import DatasetEvaluator 14 | 15 | 16 | class PascalVOCDetectionEvaluator(DatasetEvaluator): 17 | """ 18 | Evaluate Pascal VOC AP. 19 | It contains a synchronization, therefore has to be called from all ranks. 20 | 21 | Note that this is a rewrite of the official Matlab API. 22 | The results should be similar, but not identical to the one produced by 23 | the official API. 
24 | """ 25 | 26 | def __init__(self, dataset_name): 27 | """ 28 | Args: 29 | dataset_name (str): name of the dataset, e.g., "voc_2007_test" 30 | """ 31 | self._dataset_name = dataset_name 32 | meta = MetadataCatalog.get(dataset_name) 33 | self._anno_file_template = os.path.join(meta.dirname, "Annotations", "{}.xml") 34 | self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt") 35 | self._class_names = meta.thing_classes 36 | # add this two terms for calculating the mAP of different subset 37 | self._base_classes = meta.base_classes 38 | self._novel_classes = meta.novel_classes 39 | assert meta.year in [2007, 2012], meta.year 40 | self._is_2007 = meta.year == 2007 41 | self._cpu_device = torch.device("cpu") 42 | self._logger = logging.getLogger(__name__) 43 | 44 | def reset(self): 45 | self._predictions = defaultdict(list) # class name -> list of prediction strings 46 | 47 | def process(self, inputs, outputs): 48 | for input, output in zip(inputs, outputs): 49 | image_id = input["image_id"] 50 | instances = output["instances"].to(self._cpu_device) 51 | boxes = instances.pred_boxes.tensor.numpy() 52 | scores = instances.scores.tolist() 53 | classes = instances.pred_classes.tolist() 54 | for box, score, cls in zip(boxes, scores, classes): 55 | xmin, ymin, xmax, ymax = box 56 | # The inverse of data loading logic in `datasets/pascal_voc.py` 57 | xmin += 1 58 | ymin += 1 59 | self._predictions[cls].append( 60 | f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}" 61 | ) 62 | 63 | def evaluate(self): 64 | """ 65 | Returns: 66 | dict: has a key "segm", whose value is a dict of "AP", "AP50", and "AP75". 67 | """ 68 | all_predictions = comm.gather(self._predictions, dst=0) 69 | if not comm.is_main_process(): 70 | return 71 | predictions = defaultdict(list) 72 | for predictions_per_rank in all_predictions: 73 | for clsid, lines in predictions_per_rank.items(): 74 | predictions[clsid].extend(lines) 75 | del all_predictions 76 | 77 | self._logger.info( 78 | "Evaluating {} using {} metric. 
" 79 | "Note that results do not use the official Matlab API.".format( 80 | self._dataset_name, 2007 if self._is_2007 else 2012 81 | ) 82 | ) 83 | 84 | with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname: 85 | res_file_template = os.path.join(dirname, "{}.txt") 86 | 87 | aps = defaultdict(list) # iou -> ap per class 88 | aps_base = defaultdict(list) 89 | aps_novel = defaultdict(list) 90 | exist_base, exist_novel = False, False 91 | for cls_id, cls_name in enumerate(self._class_names): 92 | lines = predictions.get(cls_id, [""]) 93 | 94 | with open(res_file_template.format(cls_name), "w") as f: 95 | f.write("\n".join(lines)) 96 | 97 | for thresh in range(50, 100, 5): 98 | rec, prec, ap = voc_eval( 99 | res_file_template, 100 | self._anno_file_template, 101 | self._image_set_path, 102 | cls_name, 103 | ovthresh=thresh / 100.0, 104 | use_07_metric=self._is_2007, 105 | ) 106 | aps[thresh].append(ap * 100) 107 | 108 | if self._base_classes is not None and cls_name in self._base_classes: 109 | aps_base[thresh].append(ap * 100) 110 | exist_base = True 111 | 112 | if self._novel_classes is not None and cls_name in self._novel_classes: 113 | aps_novel[thresh].append(ap * 100) 114 | exist_novel = True 115 | 116 | ret = OrderedDict() 117 | mAP = {iou: np.mean(x) for iou, x in aps.items()} 118 | ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]} 119 | 120 | # adding evaluation of the base and novel classes 121 | if exist_base: 122 | mAP_base = {iou: np.mean(x) for iou, x in aps_base.items()} 123 | ret["bbox"].update( 124 | {"bAP": np.mean(list(mAP_base.values())), "bAP50": mAP_base[50], 125 | "bAP75": mAP_base[75]} 126 | ) 127 | 128 | if exist_novel: 129 | mAP_novel = {iou: np.mean(x) for iou, x in aps_novel.items()} 130 | ret["bbox"].update({ 131 | "nAP": np.mean(list(mAP_novel.values())), "nAP50": mAP_novel[50], 132 | "nAP75": mAP_novel[75] 133 | }) 134 | 135 | # write per class AP to logger 136 | per_class_res = {self._class_names[idx]: ap for idx, ap in enumerate(aps[50])} 137 | 138 | self._logger.info("Evaluate per-class mAP50:\n"+create_small_table(per_class_res)) 139 | self._logger.info("Evaluate overall bbox:\n"+create_small_table(ret["bbox"])) 140 | return ret 141 | 142 | 143 | ############################################################################## 144 | # 145 | # Below code is modified from 146 | # https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py 147 | # -------------------------------------------------------- 148 | # Fast/er R-CNN 149 | # Licensed under The MIT License [see LICENSE for details] 150 | # Written by Bharath Hariharan 151 | # -------------------------------------------------------- 152 | 153 | """Python implementation of the PASCAL VOC devkit's AP evaluation code.""" 154 | 155 | 156 | @lru_cache(maxsize=None) 157 | def parse_rec(filename): 158 | """Parse a PASCAL VOC xml file.""" 159 | tree = ET.parse(filename) 160 | objects = [] 161 | for obj in tree.findall("object"): 162 | obj_struct = {} 163 | obj_struct["name"] = obj.find("name").text 164 | obj_struct["pose"] = obj.find("pose").text 165 | obj_struct["truncated"] = int(obj.find("truncated").text) 166 | obj_struct["difficult"] = int(obj.find("difficult").text) 167 | bbox = obj.find("bndbox") 168 | obj_struct["bbox"] = [ 169 | int(bbox.find("xmin").text), 170 | int(bbox.find("ymin").text), 171 | int(bbox.find("xmax").text), 172 | int(bbox.find("ymax").text), 173 | ] 174 | objects.append(obj_struct) 175 | 176 | return objects 177 | 178 | 
179 | def voc_ap(rec, prec, use_07_metric=False): 180 | """Compute VOC AP given precision and recall. If use_07_metric is true, uses 181 | the VOC 07 11-point method (default:False). 182 | """ 183 | if use_07_metric: 184 | # 11 point metric 185 | ap = 0.0 186 | for t in np.arange(0.0, 1.1, 0.1): 187 | if np.sum(rec >= t) == 0: 188 | p = 0 189 | else: 190 | p = np.max(prec[rec >= t]) 191 | ap = ap + p / 11.0 192 | else: 193 | # correct AP calculation 194 | # first append sentinel values at the end 195 | mrec = np.concatenate(([0.0], rec, [1.0])) 196 | mpre = np.concatenate(([0.0], prec, [0.0])) 197 | 198 | # compute the precision envelope 199 | for i in range(mpre.size - 1, 0, -1): 200 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 201 | 202 | # to calculate area under PR curve, look for points 203 | # where X axis (recall) changes value 204 | i = np.where(mrec[1:] != mrec[:-1])[0] 205 | 206 | # and sum (\Delta recall) * prec 207 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 208 | return ap 209 | 210 | 211 | def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False): 212 | """rec, prec, ap = voc_eval(detpath, 213 | annopath, 214 | imagesetfile, 215 | classname, 216 | [ovthresh], 217 | [use_07_metric]) 218 | 219 | Top level function that does the PASCAL VOC evaluation. 220 | 221 | detpath: Path to detections 222 | detpath.format(classname) should produce the detection results file. 223 | annopath: Path to annotations 224 | annopath.format(imagename) should be the xml annotations file. 225 | imagesetfile: Text file containing the list of images, one image per line. 226 | classname: Category name (duh) 227 | [ovthresh]: Overlap threshold (default = 0.5) 228 | [use_07_metric]: Whether to use VOC07's 11 point AP computation 229 | (default False) 230 | """ 231 | # assumes detections are in detpath.format(classname) 232 | # assumes annotations are in annopath.format(imagename) 233 | # assumes imagesetfile is a text file with each line an image name 234 | 235 | # first load gt 236 | # read list of images 237 | with open(imagesetfile, "r") as f: 238 | lines = f.readlines() 239 | imagenames = [x.strip() for x in lines] 240 | 241 | # load annots 242 | recs = {} 243 | for imagename in imagenames: 244 | recs[imagename] = parse_rec(annopath.format(imagename)) 245 | 246 | # extract gt objects for this class 247 | class_recs = {} 248 | npos = 0 249 | for imagename in imagenames: 250 | R = [obj for obj in recs[imagename] if obj["name"] == classname] 251 | bbox = np.array([x["bbox"] for x in R]) 252 | difficult = np.array([x["difficult"] for x in R]).astype(np.bool) 253 | # difficult = np.array([False for x in R]).astype(np.bool) # treat all "difficult" as GT 254 | det = [False] * len(R) 255 | npos = npos + sum(~difficult) 256 | class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det} 257 | 258 | # read dets 259 | detfile = detpath.format(classname) 260 | with open(detfile, "r") as f: 261 | lines = f.readlines() 262 | 263 | splitlines = [x.strip().split(" ") for x in lines] 264 | image_ids = [x[0] for x in splitlines] 265 | confidence = np.array([float(x[1]) for x in splitlines]) 266 | BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4) 267 | 268 | # sort by confidence 269 | sorted_ind = np.argsort(-confidence) 270 | BB = BB[sorted_ind, :] 271 | image_ids = [image_ids[x] for x in sorted_ind] 272 | 273 | # go down dets and mark TPs and FPs 274 | nd = len(image_ids) 275 | tp = np.zeros(nd) 276 | fp = np.zeros(nd) 277 | for d 
in range(nd): 278 | R = class_recs[image_ids[d]] 279 | bb = BB[d, :].astype(float) 280 | ovmax = -np.inf 281 | BBGT = R["bbox"].astype(float) 282 | 283 | if BBGT.size > 0: 284 | # compute overlaps 285 | # intersection 286 | ixmin = np.maximum(BBGT[:, 0], bb[0]) 287 | iymin = np.maximum(BBGT[:, 1], bb[1]) 288 | ixmax = np.minimum(BBGT[:, 2], bb[2]) 289 | iymax = np.minimum(BBGT[:, 3], bb[3]) 290 | iw = np.maximum(ixmax - ixmin + 1.0, 0.0) 291 | ih = np.maximum(iymax - iymin + 1.0, 0.0) 292 | inters = iw * ih 293 | 294 | # union 295 | uni = ( 296 | (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0) 297 | + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0) 298 | - inters 299 | ) 300 | 301 | overlaps = inters / uni 302 | ovmax = np.max(overlaps) 303 | jmax = np.argmax(overlaps) 304 | 305 | if ovmax > ovthresh: 306 | if not R["difficult"][jmax]: 307 | if not R["det"][jmax]: 308 | tp[d] = 1.0 309 | R["det"][jmax] = 1 310 | else: 311 | fp[d] = 1.0 312 | else: 313 | fp[d] = 1.0 314 | 315 | # compute precision recall 316 | fp = np.cumsum(fp) 317 | tp = np.cumsum(tp) 318 | rec = tp / float(npos) 319 | # avoid divide by zero in case the first detection matches a difficult 320 | # ground truth 321 | prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) 322 | ap = voc_ap(rec, prec, use_07_metric) 323 | 324 | return rec, prec, ap 325 | -------------------------------------------------------------------------------- /defrcn/evaluation/archs/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.hub import load_state_dict_from_url 4 | 5 | 6 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 7 | 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 8 | 'wide_resnet50_2', 'wide_resnet101_2'] 9 | 10 | 11 | model_urls = { 12 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 13 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 14 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 15 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 16 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 17 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 18 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 19 | 'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth', 20 | 'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=dilation, groups=groups, bias=False, dilation=dilation) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 39 | base_width=64, dilation=1, norm_layer=None): 40 | super(BasicBlock, self).__init__() 41 | if norm_layer is None: 42 | norm_layer = nn.BatchNorm2d 43 | if groups != 1 or base_width != 64: 44 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 45 | if dilation > 1: 46 | raise NotImplementedError("Dilation 
> 1 not supported in BasicBlock") 47 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 48 | self.conv1 = conv3x3(inplanes, planes, stride) 49 | self.bn1 = norm_layer(planes) 50 | self.relu = nn.ReLU(inplace=True) 51 | self.conv2 = conv3x3(planes, planes) 52 | self.bn2 = norm_layer(planes) 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | identity = x 58 | 59 | out = self.conv1(x) 60 | out = self.bn1(out) 61 | out = self.relu(out) 62 | 63 | out = self.conv2(out) 64 | out = self.bn2(out) 65 | 66 | if self.downsample is not None: 67 | identity = self.downsample(x) 68 | 69 | out += identity 70 | out = self.relu(out) 71 | 72 | return out 73 | 74 | 75 | class Bottleneck(nn.Module): 76 | expansion = 4 77 | 78 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 79 | base_width=64, dilation=1, norm_layer=None): 80 | super(Bottleneck, self).__init__() 81 | if norm_layer is None: 82 | norm_layer = nn.BatchNorm2d 83 | width = int(planes * (base_width / 64.)) * groups 84 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 85 | self.conv1 = conv1x1(inplanes, width) 86 | self.bn1 = norm_layer(width) 87 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 88 | self.bn2 = norm_layer(width) 89 | self.conv3 = conv1x1(width, planes * self.expansion) 90 | self.bn3 = norm_layer(planes * self.expansion) 91 | self.relu = nn.ReLU(inplace=True) 92 | self.downsample = downsample 93 | self.stride = stride 94 | 95 | def forward(self, x): 96 | identity = x 97 | 98 | out = self.conv1(x) 99 | out = self.bn1(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv2(out) 103 | out = self.bn2(out) 104 | out = self.relu(out) 105 | 106 | out = self.conv3(out) 107 | out = self.bn3(out) 108 | 109 | if self.downsample is not None: 110 | identity = self.downsample(x) 111 | 112 | out += identity 113 | out = self.relu(out) 114 | 115 | return out 116 | 117 | 118 | class ResNet(nn.Module): 119 | 120 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 121 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 122 | norm_layer=None): 123 | super(ResNet, self).__init__() 124 | if norm_layer is None: 125 | norm_layer = nn.BatchNorm2d 126 | self._norm_layer = norm_layer 127 | 128 | self.inplanes = 64 129 | self.dilation = 1 130 | if replace_stride_with_dilation is None: 131 | # each element in the tuple indicates if we should replace 132 | # the 2x2 stride with a dilated convolution instead 133 | replace_stride_with_dilation = [False, False, False] 134 | if len(replace_stride_with_dilation) != 3: 135 | raise ValueError("replace_stride_with_dilation should be None " 136 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 137 | self.groups = groups 138 | self.base_width = width_per_group 139 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 140 | bias=False) 141 | self.bn1 = norm_layer(self.inplanes) 142 | self.relu = nn.ReLU(inplace=True) 143 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 144 | self.layer1 = self._make_layer(block, 64, layers[0]) 145 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 146 | dilate=replace_stride_with_dilation[0]) 147 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 148 | dilate=replace_stride_with_dilation[1]) 149 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 150 | 
dilate=replace_stride_with_dilation[2]) 151 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 152 | self.fc = nn.Linear(512 * block.expansion, num_classes) 153 | 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 157 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 158 | nn.init.constant_(m.weight, 1) 159 | nn.init.constant_(m.bias, 0) 160 | 161 | # Zero-initialize the last BN in each residual branch, 162 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 163 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 164 | if zero_init_residual: 165 | for m in self.modules(): 166 | if isinstance(m, Bottleneck): 167 | nn.init.constant_(m.bn3.weight, 0) 168 | elif isinstance(m, BasicBlock): 169 | nn.init.constant_(m.bn2.weight, 0) 170 | 171 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 172 | norm_layer = self._norm_layer 173 | downsample = None 174 | previous_dilation = self.dilation 175 | if dilate: 176 | self.dilation *= stride 177 | stride = 1 178 | if stride != 1 or self.inplanes != planes * block.expansion: 179 | downsample = nn.Sequential( 180 | conv1x1(self.inplanes, planes * block.expansion, stride), 181 | norm_layer(planes * block.expansion), 182 | ) 183 | 184 | layers = [] 185 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 186 | self.base_width, previous_dilation, norm_layer)) 187 | self.inplanes = planes * block.expansion 188 | for _ in range(1, blocks): 189 | layers.append(block(self.inplanes, planes, groups=self.groups, 190 | base_width=self.base_width, dilation=self.dilation, 191 | norm_layer=norm_layer)) 192 | 193 | return nn.Sequential(*layers) 194 | 195 | def forward(self, x): 196 | x = self.conv1(x) 197 | x = self.bn1(x) 198 | x = self.relu(x) 199 | x = self.maxpool(x) 200 | 201 | x = self.layer1(x) 202 | x = self.layer2(x) 203 | x = self.layer3(x) 204 | feature = self.layer4(x) 205 | 206 | x = self.avgpool(feature) 207 | x = torch.flatten(x, 1) 208 | x = self.fc(x) 209 | 210 | return x, feature 211 | 212 | 213 | def _resnet(arch, block, layers, pretrained, progress, **kwargs): 214 | model = ResNet(block, layers, **kwargs) 215 | if pretrained: 216 | state_dict = load_state_dict_from_url(model_urls[arch], 217 | progress=progress) 218 | model.load_state_dict(state_dict) 219 | return model 220 | 221 | 222 | def resnet18(pretrained=False, progress=True, **kwargs): 223 | r"""ResNet-18 model from 224 | `"Deep Residual Learning for Image Recognition" `_ 225 | 226 | Args: 227 | pretrained (bool): If True, returns a model pre-trained on ImageNet 228 | progress (bool): If True, displays a progress bar of the download to stderr 229 | """ 230 | return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, 231 | **kwargs) 232 | 233 | 234 | def resnet34(pretrained=False, progress=True, **kwargs): 235 | r"""ResNet-34 model from 236 | `"Deep Residual Learning for Image Recognition" `_ 237 | 238 | Args: 239 | pretrained (bool): If True, returns a model pre-trained on ImageNet 240 | progress (bool): If True, displays a progress bar of the download to stderr 241 | """ 242 | return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, 243 | **kwargs) 244 | 245 | 246 | def resnet50(pretrained=False, progress=True, **kwargs): 247 | r"""ResNet-50 model from 248 | `"Deep Residual Learning for Image Recognition" `_ 249 | 250 | Args: 251 | pretrained 
(bool): If True, returns a model pre-trained on ImageNet 252 | progress (bool): If True, displays a progress bar of the download to stderr 253 | """ 254 | return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, 255 | **kwargs) 256 | 257 | 258 | def resnet101(pretrained=False, progress=True, **kwargs): 259 | r"""ResNet-101 model from 260 | `"Deep Residual Learning for Image Recognition" `_ 261 | 262 | Args: 263 | pretrained (bool): If True, returns a model pre-trained on ImageNet 264 | progress (bool): If True, displays a progress bar of the download to stderr 265 | """ 266 | return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, 267 | **kwargs) 268 | 269 | 270 | def resnet152(pretrained=False, progress=True, **kwargs): 271 | r"""ResNet-152 model from 272 | `"Deep Residual Learning for Image Recognition" `_ 273 | 274 | Args: 275 | pretrained (bool): If True, returns a model pre-trained on ImageNet 276 | progress (bool): If True, displays a progress bar of the download to stderr 277 | """ 278 | return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, 279 | **kwargs) 280 | 281 | 282 | def resnext50_32x4d(pretrained=False, progress=True, **kwargs): 283 | r"""ResNeXt-50 32x4d model from 284 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 285 | 286 | Args: 287 | pretrained (bool): If True, returns a model pre-trained on ImageNet 288 | progress (bool): If True, displays a progress bar of the download to stderr 289 | """ 290 | kwargs['groups'] = 32 291 | kwargs['width_per_group'] = 4 292 | return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3], 293 | pretrained, progress, **kwargs) 294 | 295 | 296 | def resnext101_32x8d(pretrained=False, progress=True, **kwargs): 297 | r"""ResNeXt-101 32x8d model from 298 | `"Aggregated Residual Transformation for Deep Neural Networks" `_ 299 | 300 | Args: 301 | pretrained (bool): If True, returns a model pre-trained on ImageNet 302 | progress (bool): If True, displays a progress bar of the download to stderr 303 | """ 304 | kwargs['groups'] = 32 305 | kwargs['width_per_group'] = 8 306 | return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], 307 | pretrained, progress, **kwargs) 308 | 309 | 310 | def wide_resnet50_2(pretrained=False, progress=True, **kwargs): 311 | r"""Wide ResNet-50-2 model from 312 | `"Wide Residual Networks" `_ 313 | 314 | The model is the same as ResNet except for the bottleneck number of channels 315 | which is twice larger in every block. The number of channels in outer 1x1 316 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 317 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 318 | 319 | Args: 320 | pretrained (bool): If True, returns a model pre-trained on ImageNet 321 | progress (bool): If True, displays a progress bar of the download to stderr 322 | """ 323 | kwargs['width_per_group'] = 64 * 2 324 | return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3], 325 | pretrained, progress, **kwargs) 326 | 327 | 328 | def wide_resnet101_2(pretrained=False, progress=True, **kwargs): 329 | r"""Wide ResNet-101-2 model from 330 | `"Wide Residual Networks" `_ 331 | 332 | The model is the same as ResNet except for the bottleneck number of channels 333 | which is twice larger in every block. The number of channels in outer 1x1 334 | convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 335 | channels, and in Wide ResNet-50-2 has 2048-1024-2048. 
336 | 337 | Args: 338 | pretrained (bool): If True, returns a model pre-trained on ImageNet 339 | progress (bool): If True, displays a progress bar of the download to stderr 340 | """ 341 | kwargs['width_per_group'] = 64 * 2 342 | return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3], 343 | pretrained, progress, **kwargs) 344 | -------------------------------------------------------------------------------- /defrcn/data/builtin_meta.py: -------------------------------------------------------------------------------- 1 | # It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json 2 | COCO_CATEGORIES = [ 3 | {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, 4 | {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, 5 | {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, 6 | {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, 7 | {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, 8 | {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, 9 | {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, 10 | {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, 11 | {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, 12 | {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, 13 | {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, 14 | {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, 15 | {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter",}, 16 | {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, 17 | {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, 18 | {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, 19 | {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, 20 | {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, 21 | {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, 22 | {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, 23 | {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, 24 | {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, 25 | {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, 26 | {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, 27 | {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, 28 | {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, 29 | {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, 30 | {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, 31 | {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, 32 | {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, 33 | {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, 34 | {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, 35 | {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, 36 | {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, 37 | {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, 38 | {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, 39 | {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, 40 | {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, 41 | {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket",}, 42 | {"color": [197, 
226, 255], "isthing": 1, "id": 44, "name": "bottle"}, 43 | {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, 44 | {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, 45 | {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, 46 | {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, 47 | {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, 48 | {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, 49 | {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, 50 | {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, 51 | {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, 52 | {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, 53 | {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, 54 | {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, 55 | {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, 56 | {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, 57 | {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, 58 | {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, 59 | {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, 60 | {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, 61 | {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, 62 | {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, 63 | {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, 64 | {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, 65 | {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, 66 | {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, 67 | {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, 68 | {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, 69 | {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, 70 | {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, 71 | {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, 72 | {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, 73 | {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, 74 | {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, 75 | {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, 76 | {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, 77 | {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, 78 | {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, 79 | {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, 80 | {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, 81 | {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, 82 | {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, 83 | {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, 84 | {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, 85 | {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, 86 | {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, 87 | {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, 88 | {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, 89 | {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, 90 | 
{"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, 91 | {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, 92 | {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, 93 | {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, 94 | {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, 95 | {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, 96 | {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, 97 | {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, 98 | {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, 99 | {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, 100 | {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, 101 | {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, 102 | {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, 103 | {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, 104 | {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, 105 | {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, 106 | {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, 107 | {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, 108 | {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, 109 | {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, 110 | {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, 111 | {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, 112 | {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, 113 | {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, 114 | {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, 115 | {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, 116 | {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, 117 | {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, 118 | {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, 119 | {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, 120 | {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, 121 | {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, 122 | {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, 123 | {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, 124 | {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, 125 | {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, 126 | {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, 127 | {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, 128 | {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, 129 | {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, 130 | {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, 131 | {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, 132 | {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, 133 | {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, 134 | {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": 
"wall-other-merged"}, 135 | {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}, 136 | ] 137 | 138 | # Novel COCO categories 139 | COCO_NOVEL_CATEGORIES = [ 140 | {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, 141 | {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, 142 | {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, 143 | {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, 144 | {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, 145 | {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, 146 | {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, 147 | {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, 148 | {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, 149 | {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, 150 | {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, 151 | {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, 152 | {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, 153 | {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, 154 | {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, 155 | {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, 156 | {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, 157 | {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, 158 | {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, 159 | {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, 160 | ] 161 | 162 | # PASCAL VOC categories 163 | PASCAL_VOC_ALL_CATEGORIES = { 164 | 1: ["aeroplane", "bicycle", "boat", "bottle", "car", 165 | "cat", "chair", "diningtable", "dog", "horse", 166 | "person", "pottedplant", "sheep", "train", "tvmonitor", 167 | "bird", "bus", "cow", "motorbike", "sofa", 168 | ], 169 | 2: ["bicycle", "bird", "boat", "bus", "car", 170 | "cat", "chair", "diningtable", "dog", "motorbike", 171 | "person", "pottedplant", "sheep", "train", "tvmonitor", 172 | "aeroplane", "bottle", "cow", "horse", "sofa", 173 | ], 174 | 3: ["aeroplane", "bicycle", "bird", "bottle", "bus", 175 | "car", "chair", "cow", "diningtable", "dog", 176 | "horse", "person", "pottedplant", "train", "tvmonitor", 177 | "boat", "cat", "motorbike", "sheep", "sofa", 178 | ], 179 | } 180 | 181 | PASCAL_VOC_NOVEL_CATEGORIES = { 182 | 1: ["bird", "bus", "cow", "motorbike", "sofa"], 183 | 2: ["aeroplane", "bottle", "cow", "horse", "sofa"], 184 | 3: ["boat", "cat", "motorbike", "sheep", "sofa"], 185 | } 186 | 187 | PASCAL_VOC_BASE_CATEGORIES = { 188 | 1: ["aeroplane", "bicycle", "boat", "bottle", "car", 189 | "cat", "chair", "diningtable", "dog", "horse", 190 | "person", "pottedplant", "sheep", "train", "tvmonitor", 191 | ], 192 | 2: ["bicycle", "bird", "boat", "bus", "car", 193 | "cat", "chair", "diningtable", "dog", "motorbike", 194 | "person", "pottedplant", "sheep", "train", "tvmonitor", 195 | ], 196 | 3: ["aeroplane", "bicycle", "bird", "bottle", "bus", 197 | "car", "chair", "cow", "diningtable", "dog", 198 | "horse", "person", "pottedplant", "train", "tvmonitor", 199 | ], 200 | } 201 | 202 | 203 | def _get_coco_instances_meta(): 204 | thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1] 205 | thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1] 206 | assert len(thing_ids) == 80, len(thing_ids) 207 | # Mapping from the incontiguous COCO category id to an id 
in [0, 79] 208 | thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} 209 | thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1] 210 | ret = { 211 | "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, 212 | "thing_classes": thing_classes, 213 | "thing_colors": thing_colors, 214 | } 215 | return ret 216 | 217 | 218 | def _get_coco_fewshot_instances_meta(): 219 | ret = _get_coco_instances_meta() 220 | novel_ids = [k["id"] for k in COCO_NOVEL_CATEGORIES if k["isthing"] == 1] 221 | novel_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(novel_ids)} 222 | novel_classes = [ 223 | k["name"] for k in COCO_NOVEL_CATEGORIES if k["isthing"] == 1 224 | ] 225 | base_categories = [ 226 | k for k in COCO_CATEGORIES if k["isthing"] == 1 and k["name"] not in novel_classes 227 | ] 228 | base_ids = [k["id"] for k in base_categories] 229 | base_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(base_ids)} 230 | base_classes = [k["name"] for k in base_categories] 231 | ret["novel_dataset_id_to_contiguous_id"] = novel_dataset_id_to_contiguous_id 232 | ret["novel_classes"] = novel_classes 233 | ret["base_dataset_id_to_contiguous_id"] = base_dataset_id_to_contiguous_id 234 | ret["base_classes"] = base_classes 235 | return ret 236 | 237 | 238 | def _get_voc_fewshot_instances_meta(): 239 | ret = { 240 | "thing_classes": PASCAL_VOC_ALL_CATEGORIES, 241 | "novel_classes": PASCAL_VOC_NOVEL_CATEGORIES, 242 | "base_classes": PASCAL_VOC_BASE_CATEGORIES, 243 | } 244 | return ret 245 | 246 | 247 | def _get_builtin_metadata(dataset_name): 248 | if dataset_name == "coco": 249 | return _get_coco_instances_meta() 250 | elif dataset_name == "coco_fewshot": 251 | return _get_coco_fewshot_instances_meta() 252 | elif dataset_name == "voc_fewshot": 253 | return _get_voc_fewshot_instances_meta() 254 | raise KeyError("No built-in metadata for dataset {}".format(dataset_name)) 255 | -------------------------------------------------------------------------------- /defrcn/dataloader/build.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import logging 3 | import operator 4 | import itertools 5 | import numpy as np 6 | import torch.utils.data 7 | from tabulate import tabulate 8 | from termcolor import colored 9 | from detectron2.structures import BoxMode 10 | from detectron2.config import configurable 11 | from detectron2.utils.env import seed_all_rng 12 | from detectron2.utils.logger import log_first_n 13 | from detectron2.utils.comm import get_world_size 14 | from detectron2.utils.file_io import PathManager 15 | from detectron2.data.catalog import DatasetCatalog, MetadataCatalog 16 | from detectron2.data.common import AspectRatioGroupedDataset, DatasetFromList, MapDataset 17 | from defrcn.data.pcb_common import PCBAspectRatioGroupedDataset 18 | from detectron2.data.detection_utils import check_metadata_consistency 19 | from detectron2.data.samplers import InferenceSampler, RepeatFactorTrainingSampler, TrainingSampler 20 | from .dataset_mapper import DatasetMapper 21 | 22 | 23 | __all__ = [ 24 | "build_batch_data_loader", 25 | "build_detection_train_loader", 26 | "build_detection_test_loader", 27 | "get_detection_dataset_dicts", 28 | "load_proposals_into_dataset", 29 | "print_instances_class_histogram", 30 | ] 31 | 32 | 33 | def filter_images_with_only_crowd_annotations(dataset_dicts): 34 | """ 35 | Filter out images with none annotations or only crowd annotations 36 | (i.e., images without non-crowd 
annotations). 37 | A common training-time preprocessing on COCO dataset. 38 | Args: 39 | dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. 40 | Returns: 41 | list[dict]: the same format, but filtered. 42 | """ 43 | num_before = len(dataset_dicts) 44 | 45 | def valid(anns): 46 | for ann in anns: 47 | if ann.get("iscrowd", 0) == 0: 48 | return True 49 | return False 50 | 51 | dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])] 52 | num_after = len(dataset_dicts) 53 | logger = logging.getLogger(__name__) 54 | logger.info( 55 | "Removed {} images with no usable annotations. {} images left.".format( 56 | num_before - num_after, num_after 57 | ) 58 | ) 59 | return dataset_dicts 60 | 61 | 62 | def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image): 63 | """ 64 | Filter out images with too few number of keypoints. 65 | Args: 66 | dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. 67 | Returns: 68 | list[dict]: the same format as dataset_dicts, but filtered. 69 | """ 70 | num_before = len(dataset_dicts) 71 | 72 | def visible_keypoints_in_image(dic): 73 | # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility 74 | annotations = dic["annotations"] 75 | return sum( 76 | (np.array(ann["keypoints"][2::3]) > 0).sum() 77 | for ann in annotations 78 | if "keypoints" in ann 79 | ) 80 | 81 | dataset_dicts = [ 82 | x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image 83 | ] 84 | num_after = len(dataset_dicts) 85 | logger = logging.getLogger(__name__) 86 | logger.info( 87 | "Removed {} images with fewer than {} keypoints.".format( 88 | num_before - num_after, min_keypoints_per_image 89 | ) 90 | ) 91 | return dataset_dicts 92 | 93 | 94 | def load_proposals_into_dataset(dataset_dicts, proposal_file): 95 | """ 96 | Load precomputed object proposals into the dataset. 97 | The proposal file should be a pickled dict with the following keys: 98 | - "ids": list[int] or list[str], the image ids 99 | - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id 100 | - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores 101 | corresponding to the boxes. 102 | - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``. 103 | Args: 104 | dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. 105 | proposal_file (str): file path of pre-computed proposals, in pkl format. 106 | Returns: 107 | list[dict]: the same format as dataset_dicts, but added proposal field. 108 | """ 109 | logger = logging.getLogger(__name__) 110 | logger.info("Loading proposals from: {}".format(proposal_file)) 111 | 112 | with PathManager.open(proposal_file, "rb") as f: 113 | proposals = pickle.load(f, encoding="latin1") 114 | 115 | # Rename the key names in D1 proposal files 116 | rename_keys = {"indexes": "ids", "scores": "objectness_logits"} 117 | for key in rename_keys: 118 | if key in proposals: 119 | proposals[rename_keys[key]] = proposals.pop(key) 120 | 121 | # Fetch the indexes of all proposals that are in the dataset 122 | # Convert image_id to str since they could be int. 
123 | img_ids = set({str(record["image_id"]) for record in dataset_dicts}) 124 | id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids} 125 | 126 | # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS' 127 | bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS 128 | 129 | for record in dataset_dicts: 130 | # Get the index of the proposal 131 | i = id_to_index[str(record["image_id"])] 132 | 133 | boxes = proposals["boxes"][i] 134 | objectness_logits = proposals["objectness_logits"][i] 135 | # Sort the proposals in descending order of the scores 136 | inds = objectness_logits.argsort()[::-1] 137 | record["proposal_boxes"] = boxes[inds] 138 | record["proposal_objectness_logits"] = objectness_logits[inds] 139 | record["proposal_bbox_mode"] = bbox_mode 140 | 141 | return dataset_dicts 142 | 143 | 144 | def print_instances_class_histogram(dataset_dicts, class_names): 145 | """ 146 | Args: 147 | dataset_dicts (list[dict]): list of dataset dicts. 148 | class_names (list[str]): list of class names (zero-indexed). 149 | """ 150 | num_classes = len(class_names) 151 | hist_bins = np.arange(num_classes + 1) 152 | histogram = np.zeros((num_classes,), dtype=int)  # plain int: np.int is deprecated/removed in recent NumPy 153 | for entry in dataset_dicts: 154 | annos = entry["annotations"] 155 | classes = [x["category_id"] for x in annos if not x.get("iscrowd", 0)] 156 | histogram += np.histogram(classes, bins=hist_bins)[0] 157 | 158 | N_COLS = min(6, len(class_names) * 2) 159 | 160 | def short_name(x): 161 | # make long class names shorter. useful for lvis 162 | if len(x) > 13: 163 | return x[:11] + ".." 164 | return x 165 | 166 | data = list( 167 | itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]) 168 | ) 169 | total_num_instances = sum(data[1::2]) 170 | data.extend([None] * (N_COLS - (len(data) % N_COLS))) 171 | if num_classes > 1: 172 | data.extend(["total", total_num_instances]) 173 | data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)]) 174 | table = tabulate( 175 | data, 176 | headers=["category", "#instances"] * (N_COLS // 2), 177 | tablefmt="pipe", 178 | numalign="left", 179 | stralign="center", 180 | ) 181 | log_first_n( 182 | logging.INFO, 183 | "Distribution of instances among all {} categories:\n".format(num_classes) 184 | + colored(table, "cyan"), 185 | key="message", 186 | ) 187 | 188 | 189 | def get_detection_dataset_dicts( 190 | dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None 191 | ): 192 | """ 193 | Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation. 194 | Args: 195 | dataset_names (list[str]): a list of dataset names 196 | filter_empty (bool): whether to filter out images without instance annotations 197 | min_keypoints (int): filter out images with fewer keypoints than 198 | `min_keypoints`. Set to 0 to do nothing. 199 | proposal_files (list[str]): if given, a list of object proposal files 200 | that match each dataset in `dataset_names`. 201 | Returns: 202 | list[dict]: a list of dicts following the standard dataset dict format.
203 | """ 204 | assert len(dataset_names) 205 | dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] 206 | for dataset_name, dicts in zip(dataset_names, dataset_dicts): 207 | assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) 208 | 209 | if proposal_files is not None: 210 | assert len(dataset_names) == len(proposal_files) 211 | # load precomputed proposals from proposal files 212 | dataset_dicts = [ 213 | load_proposals_into_dataset(dataset_i_dicts, proposal_file) 214 | for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files) 215 | ] 216 | 217 | dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) 218 | 219 | has_instances = "annotations" in dataset_dicts[0] 220 | if filter_empty and has_instances: 221 | dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) 222 | if min_keypoints > 0 and has_instances: 223 | dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) 224 | 225 | if has_instances: 226 | try: 227 | class_names = MetadataCatalog.get(dataset_names[0]).thing_classes 228 | check_metadata_consistency("thing_classes", dataset_names) 229 | print_instances_class_histogram(dataset_dicts, class_names) 230 | except AttributeError: # class names are not available for this dataset 231 | pass 232 | 233 | assert len(dataset_dicts), "No valid data found in {}.".format(",".join(dataset_names)) 234 | return dataset_dicts 235 | 236 | 237 | def build_batch_data_loader( 238 | dataset, proto_dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0 239 | ): 240 | """ 241 | Build a batched dataloader for training. 242 | Args: 243 | dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed. 244 | sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices 245 | total_batch_size, aspect_ratio_grouping, num_workers): see 246 | :func:`build_detection_train_loader`. 247 | Returns: 248 | iterable[list]. Length of each list is the batch size of the current 249 | GPU. Each element in the list comes from the dataset. 
250 | """ 251 | world_size = get_world_size() 252 | assert ( 253 | total_batch_size > 0 and total_batch_size % world_size == 0 254 | ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( 255 | total_batch_size, world_size 256 | ) 257 | 258 | batch_size = total_batch_size // world_size 259 | if aspect_ratio_grouping: 260 | data_loader = torch.utils.data.DataLoader( 261 | dataset, 262 | sampler=sampler, 263 | num_workers=num_workers, 264 | batch_sampler=None, 265 | collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements 266 | worker_init_fn=worker_init_reset_seed, 267 | ) # yield individual mapped dict 268 | return PCBAspectRatioGroupedDataset(data_loader, batch_size, proto_dataset) 269 | else: 270 | batch_sampler = torch.utils.data.sampler.BatchSampler( 271 | sampler, batch_size, drop_last=True 272 | ) # drop_last so the batch always have the same size 273 | return torch.utils.data.DataLoader( 274 | dataset, 275 | num_workers=num_workers, 276 | batch_sampler=batch_sampler, 277 | collate_fn=trivial_batch_collator, 278 | worker_init_fn=worker_init_reset_seed, 279 | ) 280 | 281 | 282 | def _train_loader_from_config(cfg, *, mapper=None, dataset=None, sampler=None): 283 | if dataset is None: 284 | dataset = get_detection_dataset_dicts( 285 | cfg.DATASETS.TRAIN, 286 | filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, 287 | min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE 288 | if cfg.MODEL.KEYPOINT_ON 289 | else 0, 290 | proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, 291 | ) 292 | 293 | if cfg.TEST.PCB_ENABLE: 294 | import copy 295 | proto_dataset = copy.deepcopy(dataset) 296 | proto_mapper = DatasetMapper(cfg, False) 297 | proto_dataset = MapDataset(proto_dataset, proto_mapper) 298 | else: 299 | proto_dataset = None 300 | 301 | if mapper is None: 302 | mapper = DatasetMapper(cfg, True) 303 | 304 | if sampler is None: 305 | sampler_name = cfg.DATALOADER.SAMPLER_TRAIN 306 | logger = logging.getLogger(__name__) 307 | logger.info("Using training sampler {}".format(sampler_name)) 308 | if sampler_name == "TrainingSampler": 309 | sampler = TrainingSampler(len(dataset)) 310 | elif sampler_name == "RepeatFactorTrainingSampler": 311 | repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( 312 | dataset, cfg.DATALOADER.REPEAT_THRESHOLD 313 | ) 314 | sampler = RepeatFactorTrainingSampler(repeat_factors) 315 | else: 316 | raise ValueError("Unknown training sampler: {}".format(sampler_name)) 317 | 318 | return { 319 | "dataset": dataset, 320 | "proto_dataset": proto_dataset, 321 | "sampler": sampler, 322 | "mapper": mapper, 323 | "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, 324 | "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, 325 | "num_workers": cfg.DATALOADER.NUM_WORKERS, 326 | } 327 | 328 | 329 | # TODO can allow dataset as an iterable or IterableDataset to make this function more general 330 | @configurable(from_config=_train_loader_from_config) 331 | def build_detection_train_loader( 332 | dataset, *, proto_dataset, mapper, sampler=None, total_batch_size, aspect_ratio_grouping=True, num_workers=0 333 | ): 334 | """ 335 | Build a dataloader for object detection with some default features. 336 | This interface is experimental. 337 | Args: 338 | dataset (list or torch.utils.data.Dataset): a list of dataset dicts, 339 | or a map-style pytorch dataset. 
They can be obtained by using 340 | :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. 341 | mapper (callable): a callable which takes a sample (dict) from dataset and 342 | returns the format to be consumed by the model. 343 | When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. 344 | sampler (torch.utils.data.sampler.Sampler or None): a sampler that 345 | produces indices to be applied on ``dataset``. 346 | Default to :class:`TrainingSampler`, which coordinates a random shuffle 347 | sequence across all workers. 348 | total_batch_size (int): total batch size across all workers. Batching 349 | simply puts data into a list. 350 | aspect_ratio_grouping (bool): whether to group images with similar 351 | aspect ratio for efficiency. When enabled, it requires each 352 | element in dataset be a dict with keys "width" and "height". 353 | num_workers (int): number of parallel data loading workers 354 | Returns: 355 | torch.utils.data.DataLoader: a dataloader. Each output from it is a 356 | ``list[mapped_element]`` of length ``total_batch_size / num_workers``, 357 | where ``mapped_element`` is produced by the ``mapper``. 358 | """ 359 | if isinstance(dataset, list): 360 | dataset = DatasetFromList(dataset, copy=False) 361 | if mapper is not None: 362 | dataset = MapDataset(dataset, mapper) 363 | if sampler is None: 364 | sampler = TrainingSampler(len(dataset)) 365 | assert isinstance(sampler, torch.utils.data.sampler.Sampler) 366 | return build_batch_data_loader( 367 | dataset, 368 | proto_dataset, 369 | sampler, 370 | total_batch_size, 371 | aspect_ratio_grouping=aspect_ratio_grouping, 372 | num_workers=num_workers, 373 | ) 374 | 375 | 376 | def _test_loader_from_config(cfg, dataset_name, mapper=None): 377 | """ 378 | Uses the given `dataset_name` argument (instead of the names in cfg), because the 379 | standard practice is to evaluate each test set individually (not combining them). 380 | """ 381 | dataset = get_detection_dataset_dicts( 382 | [dataset_name], 383 | filter_empty=False, 384 | proposal_files=[ 385 | cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(dataset_name)] 386 | ] 387 | if cfg.MODEL.LOAD_PROPOSALS 388 | else None, 389 | ) 390 | if mapper is None: 391 | mapper = DatasetMapper(cfg, False) 392 | return {"dataset": dataset, "mapper": mapper, "num_worker": cfg.DATALOADER.NUM_WORKERS} 393 | 394 | 395 | @configurable(from_config=_test_loader_from_config) 396 | def build_detection_test_loader(dataset, *, mapper, num_worker=0): 397 | """ 398 | Similar to `build_detection_train_loader`, but uses a batch size of 1. 399 | This interface is experimental. 400 | Args: 401 | dataset (list or torch.utils.data.Dataset): a list of dataset dicts, 402 | or a map-style pytorch dataset. They can be obtained by using 403 | :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. 404 | mapper (callable): a callable which takes a sample (dict) from dataset 405 | and returns the format to be consumed by the model. 406 | When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. 407 | num_workers (int): number of parallel data loading workers 408 | Returns: 409 | DataLoader: a torch DataLoader, that loads the given detection 410 | dataset, with test-time transformation and batching. 
411 | Examples: 412 | :: 413 | data_loader = build_detection_test_loader( 414 | DatasetRegistry.get("my_test"), 415 | mapper=DatasetMapper(...)) 416 | # or, instantiate with a CfgNode: 417 | data_loader = build_detection_test_loader(cfg, "my_test") 418 | """ 419 | if isinstance(dataset, list): 420 | dataset = DatasetFromList(dataset, copy=False) 421 | if mapper is not None: 422 | dataset = MapDataset(dataset, mapper) 423 | sampler = InferenceSampler(len(dataset)) 424 | # Always use 1 image per worker during inference since this is the 425 | # standard when reporting inference time in papers. 426 | batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, 1, drop_last=False) 427 | data_loader = torch.utils.data.DataLoader( 428 | dataset, 429 | num_workers=num_worker, 430 | batch_sampler=batch_sampler, 431 | collate_fn=trivial_batch_collator, 432 | ) 433 | return data_loader 434 | 435 | 436 | def trivial_batch_collator(batch): 437 | """ 438 | A batch collator that does nothing. 439 | """ 440 | return batch 441 | 442 | 443 | def worker_init_reset_seed(worker_id): 444 | seed_all_rng(np.random.randint(2 ** 31) + worker_id) 445 | --------------------------------------------------------------------------------
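A note on how the metadata in defrcn/data/builtin_meta.py is consumed: _get_coco_instances_meta maps the 80 non-contiguous COCO "thing" ids to contiguous training ids, and _get_coco_fewshot_instances_meta additionally splits them into the 20 novel categories listed in COCO_NOVEL_CATEGORIES and the remaining 60 base categories, each with its own contiguous mapping. The sketch below is not part of the repository; it reproduces that mapping pattern on a made-up three-category list purely for illustration.

# Illustrative only: a toy stand-in for COCO_CATEGORIES / COCO_NOVEL_CATEGORIES.
TOY_CATEGORIES = [
    {"id": 1, "isthing": 1, "name": "person"},
    {"id": 3, "isthing": 1, "name": "car"},   # ids are not contiguous (2 is skipped), as in COCO
    {"id": 17, "isthing": 1, "name": "cat"},
]
TOY_NOVEL_CATEGORIES = [{"id": 17, "isthing": 1, "name": "cat"}]

# Same pattern as _get_coco_instances_meta: dataset ids -> contiguous ids.
thing_ids = [k["id"] for k in TOY_CATEGORIES if k["isthing"] == 1]
thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)}

# Same pattern as _get_coco_fewshot_instances_meta: split novel vs. base by name,
# then build a separate contiguous mapping for the base split.
novel_classes = [k["name"] for k in TOY_NOVEL_CATEGORIES if k["isthing"] == 1]
base_categories = [k for k in TOY_CATEGORIES if k["isthing"] == 1 and k["name"] not in novel_classes]
base_dataset_id_to_contiguous_id = {k["id"]: i for i, k in enumerate(base_categories)}

print(thing_dataset_id_to_contiguous_id)  # {1: 0, 3: 1, 17: 2}
print(base_dataset_id_to_contiguous_id)   # {1: 0, 3: 1}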
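Similarly, a minimal usage sketch for the loader builders in defrcn/dataloader/build.py (again, not part of the repository). Because the builders are wrapped with detectron2's @configurable, they can be called with a config node as the first argument, in which case _train_loader_from_config / _test_loader_from_config fill in the dataset dicts, the DatasetMapper, the sampler and, when cfg.TEST.PCB_ENABLE is set, the prototype dataset wrapped by PCBAspectRatioGroupedDataset. The names loader_sketch, cfg, model and max_iter below are assumptions for illustration; a real run would load cfg from one of the YAML files under configs/ so that the defrcn.data registrations and the extra DeFRCN config keys are in place.

import itertools
from defrcn.dataloader.build import build_detection_train_loader, build_detection_test_loader

def loader_sketch(cfg, model, max_iter=10):
    # Config-driven call: datasets, mapper, sampler and (optionally) the PCB
    # prototype dataset are all resolved from cfg.
    train_loader = build_detection_train_loader(cfg)
    # TrainingSampler yields indices indefinitely, so slice the stream; each
    # batch is a list[dict] of length IMS_PER_BATCH divided by the number of GPUs.
    for batch in itertools.islice(train_loader, max_iter):
        loss_dict = model(batch)

    # Test loaders are built one dataset at a time and always use a batch size of 1.
    test_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0])
    for inputs in test_loader:
        outputs = model(inputs)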