├── README.md ├── architecture.png ├── configs ├── _base_ │ ├── datasets │ │ └── nway_kshot │ │ │ ├── base_coco_ms.py │ │ │ ├── base_voc_ms.py │ │ │ ├── few_shot_coco_ms.py │ │ │ └── few_shot_voc_ms.py │ ├── default_runtime.py │ ├── models │ │ └── faster_rcnn_r50_caffe_c4.py │ └── schedules │ │ └── schedule.py └── fpd │ ├── coco │ ├── fpd_r101_c4_2xb4_coco_10shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_coco_30shot-fine-tuning.py │ └── fpd_r101_c4_2xb4_coco_base-training.py │ ├── fpd_r101_c4.py │ ├── meta-rcnn_r50_c4.py │ └── voc │ ├── split1 │ ├── fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split1_1shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split1_2shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split1_3shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split1_5shot-fine-tuning.py │ └── fpd_r101_c4_2xb4_voc-split1_base-training.py │ ├── split2 │ ├── fpd_r101_c4_2xb4_voc-split2_10shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split2_1shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split2_2shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split2_3shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split2_5shot-fine-tuning.py │ └── fpd_r101_c4_2xb4_voc-split2_base-training.py │ └── split3 │ ├── fpd_r101_c4_2xb4_voc-split3_10shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split3_1shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split3_2shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split3_3shot-fine-tuning.py │ ├── fpd_r101_c4_2xb4_voc-split3_5shot-fine-tuning.py │ └── fpd_r101_c4_2xb4_voc-split3_base-training.py ├── dist_test.sh ├── dist_train.sh ├── fpd ├── __init__.py ├── ffa.py ├── fpd_detector.py ├── fpd_roi_head.py ├── query_support.py ├── transforms.py └── utils.py ├── requirements.txt ├── setup.py ├── test.py └── train.py
/README.md:
--------------------------------------------------------------------------------
1 | ## Introduction
2 | >[**Fine-Grained Prototypes Distillation for Few-Shot Object Detection (AAAI2024)**](https://arxiv.org/pdf/2401.07629.pdf)
3 | >
4 | ![fpd_architecture](architecture.png)
5 | 
6 | This repo is based on [MMFewShot](https://github.com/open-mmlab/mmfewshot).
7 | 
8 | ## Quick Start
9 | ```bash
10 | # create a conda environment
11 | conda create -n fpd python=3.8
12 | conda activate fpd
13 | conda install pytorch==1.12.1 torchvision==0.13.1 torchaudio cudatoolkit=11.3 -c pytorch -c conda-forge
14 | 
15 | # dependencies
16 | pip install openmim
17 | mim install mmcv-full==1.6.0
18 | mim install mmcls==0.25.0
19 | mim install mmdet==2.24.0
20 | pip install -r requirements.txt
21 | 
22 | # install mmfewshot
23 | pip install git+https://github.com/open-mmlab/mmfewshot.git
24 | # or manually download the code, then
25 | # cd mmfewshot
26 | # pip install .
27 | 
28 | # install FPD
29 | python setup.py develop
30 | ```
31 | 
32 | ## Prepare Datasets
33 | Please refer to [mmfewshot/data](https://github.com/open-mmlab/mmfewshot/blob/main/tools/data/README.md)
34 | for the data preparation steps.
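After following those steps, the configs in this repo expect the datasets and few-shot annotations to be reachable from the project root roughly as follows (a sketch inferred from the `data_root` and `ann_cfg` paths in `configs/_base_/datasets/nway_kshot/`; adjust it to your local setup, e.g. with symlinks):

```
FPD/
├── data/
│   ├── coco/                        # standard COCO images and annotations
│   ├── VOCdevkit/
│   │   ├── VOC2007/
│   │   └── VOC2012/
│   └── few_shot_ann/
│       └── coco/
│           └── annotations/         # few-shot train/val annotation files
└── ...
```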
35 | 
36 | ## Results on VOC Dataset
37 | * Base Training
38 | 
39 | | Config | Split | Base AP50 | ckpt |
40 | |:---:|:---:|:---:|:---:|
41 | |[config](configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_base-training.py)|1|79.8|[ckpt](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split1_base-training_iter_20000.pth)|
42 | |[config](configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_base-training.py)|2|80.3|[ckpt](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split2_base-training_iter_20000.pth)|
43 | |[config](configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_base-training.py)|3|80.2|[ckpt](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split3_base-training_iter_20000.pth)|
44 | 
45 | * Few-Shot Fine-tuning
46 | 
47 | | Config | Split | Shot | Novel AP50 | ckpt | log |
48 | |:---:|:---:|:---:|:---:|:---:|:---:|
49 | |[config](configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning.py)|1|10|68.4|[ckpt](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning_iter_2000.pth)|[log](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning.log)|
50 | |[config](configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_10shot-fine-tuning.py)|2|10|53.9|[ckpt](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split2_10shot-fine-tuning_iter_2000.pth)|[log](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split2_10shot-fine-tuning.log)|
51 | |[config](configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_10shot-fine-tuning.py)|3|10|62.9|[ckpt](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split3_10shot-fine-tuning_iter_3200.pth)|[log](https://github.com/wangchen1801/FPD/releases/download/ckpts_fpd/fpd_r101_c4_2xb4_voc-split3_10shot-fine-tuning.log)|
52 | 
53 | ## Results on COCO Dataset
54 | * Base Training
55 | 
56 | | Config | Base mAP | ckpt |
57 | |:---:|:---:|:---:|
58 | |[config](configs/fpd/coco/fpd_r101_c4_2xb4_coco_base-training.py)|36.0|[ckpt](https://github.com/wangchen1801/FPD/releases/download/ckpts/fpd_r101_c4_2xb4_coco_base-training_iter_110000.pth)|
59 | 
60 | * Few-Shot Fine-tuning
61 | 
62 | | Config | Shot | Novel mAP (nAP) | ckpt | log |
63 | |:---:|:---:|:---:|:---:|:---:|
64 | |[config](configs/fpd/coco/fpd_r101_c4_2xb4_coco_30shot-fine-tuning.py)|30|20.1|[ckpt]()|[log](https://github.com/wangchen1801/FPD/releases/download/ckpts/fpd_r101_c4_2xb4_coco_30shot-fine-tuning.log)|
65 | 
66 | ## Evaluation
67 | 
68 | ```bash
69 | # single-GPU test
70 | python test.py ${CONFIG} ${CHECKPOINT} --eval mAP|bbox
71 | 
72 | # multi-GPU test
73 | bash dist_test.sh ${CONFIG} ${CHECKPOINT} ${NUM_GPU} --eval mAP|bbox
74 | ```
75 | 
76 | * For example, test pretrained weights on VOC Split1 10-shot with 2 GPUs:
77 | 
78 | ```bash
79 | bash dist_test.sh \
80 | configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning.py \
81 | ./work_dirs/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning_iter_2000.pth 2 --eval mAP
82 | ```
83 | 
84 | * Test pretrained weights on COCO 30-shot with 2 GPUs:
85 | ```bash
86 | bash dist_test.sh \
87 | configs/fpd/coco/fpd_r101_c4_2xb4_coco_30shot-fine-tuning.py \
88 | ./work_dirs/fpd_r101_c4_2xb4_coco_30shot-fine-tuning/fpd_r101_c4_2xb4_coco_30shot-fine-tuning_iter_18000.pth 2 --eval bbox
89 | ```
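* The released checkpoints from the tables above can also be evaluated with the single-GPU script; the download location below is only a hypothetical example:

```bash
# hypothetical local path; point it at wherever you saved the released checkpoint
python test.py \
configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning.py \
./checkpoints/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning_iter_2000.pth --eval mAP
```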
90 | 
91 | ## Training
92 | ```bash
93 | # single-GPU training
94 | python train.py ${CONFIG}
95 | 
96 | # multi-GPU training
97 | bash dist_train.sh ${CONFIG} ${NUM_GPU}
98 | ```
99 | * Training FPD on the VOC dataset with 2 GPUs:
100 | ```bash
101 | # base training
102 | bash dist_train.sh \
103 | configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_base-training.py 2
104 | 
105 | # few-shot fine-tuning
106 | bash dist_train.sh \
107 | configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning.py 2
108 | ```
109 | * Training FPD on the COCO dataset with 2 GPUs:
110 | ```bash
111 | # base training
112 | bash dist_train.sh \
113 | configs/fpd/coco/fpd_r101_c4_2xb4_coco_base-training.py 2
114 | 
115 | # few-shot fine-tuning
116 | bash dist_train.sh \
117 | configs/fpd/coco/fpd_r101_c4_2xb4_coco_30shot-fine-tuning.py 2
118 | ```
119 | 
120 | ## Citation
121 | If you would like to cite this paper, please use the following BibTeX entry:
122 | ```BibTeX
123 | @InProceedings{wang2024fpd,
124 | title={Fine-Grained Prototypes Distillation for Few-Shot Object Detection},
125 | author={Wang, Zichen and Yang, Bo and Yue, Haonan and Ma, Zhenghao},
126 | booktitle = {Proceedings of the 38th AAAI Conference on Artificial Intelligence (AAAI-24)},
127 | year={2024}
128 | }
129 | ```
130 | 
--------------------------------------------------------------------------------
/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wangchen1801/FPD/19882e25938e3990bf59be060d73f99f598e0eee/architecture.png
--------------------------------------------------------------------------------
/configs/_base_/datasets/nway_kshot/base_coco_ms.py:
--------------------------------------------------------------------------------
1 | # dataset settings 2 | img_norm_cfg = dict( 3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 4 | train_multi_pipelines = dict( 5 | query=[ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | # dict(type='Resize', img_scale=(1000, 600), keep_ratio=True), 9 | dict( 10 | type='Resize', 11 | img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), 12 | (1333, 768), (1333, 800)], 13 | keep_ratio=True, 14 | multiscale_mode='value'), 15 | dict(type='RandomFlip', flip_ratio=0.5), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='Pad', size_divisor=32), 18 | dict(type='DefaultFormatBundle'), 19 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 20 | ], 21 | support=[ 22 | dict(type='LoadImageFromFile'), 23 | dict(type='LoadAnnotations', with_bbox=True), 24 | dict( 25 | type='CropResizeInstanceByRatio', 26 | num_context_pixels=16, 27 | context_ratio=0.07, 28 | target_size=(224, 224)), 29 | # dict( 30 | # type='CropResizeInstance', 31 | # num_context_pixels=16, 32 | # target_size=(224, 224)), 33 | dict(type='Normalize', **img_norm_cfg), 34 | dict(type='GenerateMask', target_size=(224, 224)), 35 | dict(type='RandomFlip', flip_ratio=0.0), 36 | dict(type='DefaultFormatBundle'), 37 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 38 | ]) 39 | test_pipeline = [ 40 | dict(type='LoadImageFromFile'), 41 | dict( 42 | type='MultiScaleFlipAug', 43 | # img_scale=(1000, 600), 44 | img_scale=(1333, 800), 45 | flip=False, 46 | transforms=[ 47 | dict(type='Resize', keep_ratio=True), 48 | dict(type='RandomFlip'), 49 | dict(type='Normalize', **img_norm_cfg), 50 | dict(type='Pad', size_divisor=32), 51 | dict(type='ImageToTensor', keys=['img']),
52 | dict(type='Collect', keys=['img']) 53 | ]) 54 | ] 55 | # classes splits are predefined in FewShotCocoDataset 56 | data_root = 'data/coco/' 57 | data = dict( 58 | samples_per_gpu=4, 59 | workers_per_gpu=2, 60 | train=dict( 61 | type='NWayKShotDataset', 62 | num_support_ways=60, 63 | num_support_shots=1, 64 | one_support_shot_per_image=True, 65 | num_used_support_shots=200, 66 | save_dataset=False, 67 | dataset=dict( 68 | type='FewShotCocoDataset', 69 | ann_cfg=[ 70 | dict( 71 | type='ann_file', 72 | ann_file='data/few_shot_ann/coco/annotations/train.json') 73 | ], 74 | img_prefix=data_root, 75 | multi_pipelines=train_multi_pipelines, 76 | classes='BASE_CLASSES', 77 | instance_wise=False, 78 | dataset_name='query_support_dataset'), 79 | ), 80 | val=dict( 81 | type='FewShotCocoDataset', 82 | ann_cfg=[ 83 | dict( 84 | type='ann_file', 85 | ann_file='data/few_shot_ann/coco/annotations/val.json') 86 | ], 87 | img_prefix=data_root, 88 | pipeline=test_pipeline, 89 | classes='BASE_CLASSES'), 90 | test=dict( 91 | type='FewShotCocoDataset', 92 | ann_cfg=[ 93 | dict( 94 | type='ann_file', 95 | ann_file='data/few_shot_ann/coco/annotations/val.json') 96 | ], 97 | img_prefix=data_root, 98 | pipeline=test_pipeline, 99 | test_mode=True, 100 | classes='BASE_CLASSES'), 101 | model_init=dict( 102 | copy_from_train_dataset=True, 103 | samples_per_gpu=16, 104 | workers_per_gpu=1, 105 | type='FewShotCocoDataset', 106 | ann_cfg=None, 107 | img_prefix=data_root, 108 | pipeline=train_multi_pipelines['support'], 109 | instance_wise=True, 110 | classes='BASE_CLASSES', 111 | dataset_name='model_init_dataset')) 112 | evaluation = dict(interval=20000, metric='bbox', classwise=True) 113 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nway_kshot/base_voc_ms.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | img_norm_cfg = dict( 3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 4 | train_multi_pipelines = dict( 5 | query=[ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | # dict(type='Resize', img_scale=(1000, 600), keep_ratio=True), 9 | dict( 10 | type='Resize', 11 | img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), 12 | (1333, 768), (1333, 800)], 13 | keep_ratio=True, 14 | multiscale_mode='value'), 15 | dict(type='RandomFlip', flip_ratio=0.5), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 19 | ], 20 | support=[ 21 | dict(type='LoadImageFromFile'), 22 | dict(type='LoadAnnotations', with_bbox=True), 23 | dict( 24 | type='CropResizeInstance', 25 | num_context_pixels=16, 26 | target_size=(224, 224)), 27 | dict(type='Normalize', **img_norm_cfg), 28 | dict(type='GenerateMask', target_size=(224, 224)), 29 | dict(type='RandomFlip', flip_ratio=0.0), 30 | dict(type='DefaultFormatBundle'), 31 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 32 | ]) 33 | test_pipeline = [ 34 | dict(type='LoadImageFromFile'), 35 | dict( 36 | type='MultiScaleFlipAug', 37 | # img_scale=(1000, 600), 38 | img_scale=(1333, 800), 39 | flip=False, 40 | transforms=[ 41 | dict(type='Resize', keep_ratio=True), 42 | dict(type='RandomFlip'), 43 | dict(type='Normalize', **img_norm_cfg), 44 | dict(type='ImageToTensor', keys=['img']), 45 | dict(type='Collect', keys=['img']) 46 | ]) 47 | ] 48 | # classes splits are predefined in FewShotVOCDataset 49 | 
data_root = 'data/VOCdevkit/' 50 | data = dict( 51 | samples_per_gpu=4, 52 | workers_per_gpu=2, 53 | train=dict( 54 | type='NWayKShotDataset', 55 | num_support_ways=15, 56 | num_support_shots=1, 57 | one_support_shot_per_image=True, 58 | num_used_support_shots=200, 59 | save_dataset=False, 60 | dataset=dict( 61 | type='FewShotVOCDataset', 62 | ann_cfg=[ 63 | dict( 64 | type='ann_file', 65 | ann_file=data_root + 66 | 'VOC2007/ImageSets/Main/trainval.txt'), 67 | dict( 68 | type='ann_file', 69 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt'), 70 | ], 71 | img_prefix=data_root, 72 | multi_pipelines=train_multi_pipelines, 73 | classes=None, 74 | use_difficult=True, 75 | instance_wise=False, 76 | dataset_name='query_dataset'), 77 | support_dataset=dict( 78 | type='FewShotVOCDataset', 79 | ann_cfg=[ 80 | dict( 81 | type='ann_file', 82 | ann_file=data_root + 83 | 'VOC2007/ImageSets/Main/trainval.txt'), 84 | dict( 85 | type='ann_file', 86 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt'), 87 | ], 88 | img_prefix=data_root, 89 | multi_pipelines=train_multi_pipelines, 90 | classes=None, 91 | use_difficult=False, 92 | instance_wise=False, 93 | dataset_name='support_dataset')), 94 | val=dict( 95 | type='FewShotVOCDataset', 96 | ann_cfg=[ 97 | dict( 98 | type='ann_file', 99 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt'), 100 | ], 101 | img_prefix=data_root, 102 | pipeline=test_pipeline, 103 | classes=None), 104 | test=dict( 105 | type='FewShotVOCDataset', 106 | ann_cfg=[ 107 | dict( 108 | type='ann_file', 109 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt'), 110 | ], 111 | img_prefix=data_root, 112 | pipeline=test_pipeline, 113 | test_mode=True, 114 | classes=None), 115 | model_init=dict( 116 | copy_from_train_dataset=True, 117 | samples_per_gpu=16, 118 | workers_per_gpu=1, 119 | type='FewShotVOCDataset', 120 | ann_cfg=None, 121 | img_prefix=data_root, 122 | pipeline=train_multi_pipelines['support'], 123 | use_difficult=False, 124 | instance_wise=True, 125 | classes=None, 126 | dataset_name='model_init_dataset')) 127 | evaluation = dict(interval=5000, metric='mAP') 128 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nway_kshot/few_shot_coco_ms.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | img_norm_cfg = dict( 3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 4 | train_multi_pipelines = dict( 5 | query=[ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | # dict(type='Resize', img_scale=(1000, 600), keep_ratio=True), 9 | dict( 10 | type='Resize', 11 | img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), 12 | (1333, 768), (1333, 800)], 13 | keep_ratio=True, 14 | multiscale_mode='value'), 15 | dict(type='RandomFlip', flip_ratio=0.5), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='Pad', size_divisor=32), 18 | dict(type='DefaultFormatBundle'), 19 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 20 | ], 21 | support=[ 22 | dict(type='LoadImageFromFile'), 23 | dict(type='LoadAnnotations', with_bbox=True), 24 | dict( 25 | type='CropResizeInstanceByRatio', 26 | num_context_pixels=16, 27 | context_ratio=0.07, 28 | target_size=(224, 224)), 29 | # dict( 30 | # type='CropResizeInstance', 31 | # num_context_pixels=16, 32 | # target_size=(224, 224)), 33 | dict(type='Normalize', **img_norm_cfg), 34 | dict(type='GenerateMask', target_size=(224, 224)), 35 
| dict(type='RandomFlip', flip_ratio=0.0), 36 | dict(type='DefaultFormatBundle'), 37 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 38 | ]) 39 | test_pipeline = [ 40 | dict(type='LoadImageFromFile'), 41 | dict( 42 | type='MultiScaleFlipAug', 43 | # img_scale=(1000, 600), 44 | img_scale=(1333, 800), 45 | flip=False, 46 | transforms=[ 47 | dict(type='Resize', keep_ratio=True), 48 | dict(type='RandomFlip'), 49 | dict(type='Normalize', **img_norm_cfg), 50 | dict(type='Pad', size_divisor=32), 51 | dict(type='ImageToTensor', keys=['img']), 52 | dict(type='Collect', keys=['img']) 53 | ]) 54 | ] 55 | # classes splits are predefined in FewShotCocoDataset 56 | data_root = 'data/coco/' 57 | data = dict( 58 | samples_per_gpu=4, 59 | workers_per_gpu=2, 60 | train=dict( 61 | type='NWayKShotDataset', 62 | num_support_ways=80, 63 | num_support_shots=1, 64 | one_support_shot_per_image=False, 65 | num_used_support_shots=30, 66 | save_dataset=True, 67 | dataset=dict( 68 | type='FewShotCocoDataset', 69 | ann_cfg=[ 70 | dict( 71 | type='ann_file', 72 | ann_file='data/few_shot_ann/coco/annotations/train.json') 73 | ], 74 | img_prefix=data_root, 75 | multi_pipelines=train_multi_pipelines, 76 | classes='ALL_CLASSES', 77 | instance_wise=False, 78 | dataset_name='query_support_dataset'), 79 | ), 80 | 81 | val=dict( 82 | type='FewShotCocoDataset', 83 | ann_cfg=[ 84 | dict( 85 | type='ann_file', 86 | ann_file='data/few_shot_ann/coco/annotations/val.json') 87 | ], 88 | img_prefix=data_root, 89 | pipeline=test_pipeline, 90 | classes='ALL_CLASSES'), 91 | test=dict( 92 | type='FewShotCocoDataset', 93 | ann_cfg=[ 94 | dict( 95 | type='ann_file', 96 | ann_file='data/few_shot_ann/coco/annotations/val.json') 97 | ], 98 | img_prefix=data_root, 99 | pipeline=test_pipeline, 100 | test_mode=True, 101 | classes='ALL_CLASSES'), 102 | model_init=dict( 103 | copy_from_train_dataset=True, 104 | samples_per_gpu=16, 105 | workers_per_gpu=1, 106 | type='FewShotCocoDataset', 107 | ann_cfg=None, 108 | img_prefix=data_root, 109 | pipeline=train_multi_pipelines['support'], 110 | instance_wise=True, 111 | classes='ALL_CLASSES', 112 | dataset_name='model_init_dataset')) 113 | evaluation = dict( 114 | interval=3000, 115 | metric='bbox', 116 | classwise=True, 117 | class_splits=['BASE_CLASSES', 'NOVEL_CLASSES']) 118 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nway_kshot/few_shot_voc_ms.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | img_norm_cfg = dict( 3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 4 | train_multi_pipelines = dict( 5 | query=[ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | # dict(type='Resize', img_scale=(1000, 600), keep_ratio=True), 9 | dict( 10 | type='Resize', 11 | img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), 12 | (1333, 768), (1333, 800)], 13 | keep_ratio=True, 14 | multiscale_mode='value'), 15 | dict(type='RandomFlip', flip_ratio=0.5), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='Pad', size_divisor=32), 18 | dict(type='DefaultFormatBundle'), 19 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 20 | ], 21 | support=[ 22 | dict(type='LoadImageFromFile'), 23 | dict(type='LoadAnnotations', with_bbox=True), 24 | dict( 25 | type='CropResizeInstance', 26 | num_context_pixels=16, 27 | target_size=(224, 224)), 28 | dict(type='Normalize', **img_norm_cfg), 29 | 
dict(type='GenerateMask', target_size=(224, 224)), 30 | dict(type='RandomFlip', flip_ratio=0.0), 31 | dict(type='DefaultFormatBundle'), 32 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 33 | ]) 34 | test_pipeline = [ 35 | dict(type='LoadImageFromFile'), 36 | dict( 37 | type='MultiScaleFlipAug', 38 | # img_scale=(1000, 600), 39 | img_scale=(1333, 800), 40 | flip=False, 41 | transforms=[ 42 | dict(type='Resize', keep_ratio=True), 43 | dict(type='RandomFlip'), 44 | dict(type='Normalize', **img_norm_cfg), 45 | dict(type='ImageToTensor', keys=['img']), 46 | dict(type='Collect', keys=['img']) 47 | ]) 48 | ] 49 | # classes splits are predefined in FewShotVOCDataset 50 | data_root = 'data/VOCdevkit/' 51 | data = dict( 52 | # samples_per_gpu=4, 53 | samples_per_gpu=6, 54 | workers_per_gpu=2, 55 | train=dict( 56 | type='NWayKShotDataset', 57 | num_support_ways=20, 58 | num_support_shots=1, 59 | one_support_shot_per_image=False, 60 | num_used_support_shots=30, 61 | save_dataset=True, 62 | dataset=dict( 63 | type='FewShotVOCDataset', 64 | ann_cfg=[ 65 | dict( 66 | type='ann_file', 67 | ann_file=data_root + 68 | 'VOC2007/ImageSets/Main/trainval.txt'), 69 | dict( 70 | type='ann_file', 71 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt'), 72 | ], 73 | img_prefix=data_root, 74 | multi_pipelines=train_multi_pipelines, 75 | classes=None, 76 | use_difficult=False, 77 | instance_wise=False, 78 | dataset_name='query_support_dataset')), 79 | val=dict( 80 | type='FewShotVOCDataset', 81 | ann_cfg=[ 82 | dict( 83 | type='ann_file', 84 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt'), 85 | ], 86 | img_prefix=data_root, 87 | pipeline=test_pipeline, 88 | classes=None), 89 | test=dict( 90 | type='FewShotVOCDataset', 91 | ann_cfg=[ 92 | dict( 93 | type='ann_file', 94 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt'), 95 | ], 96 | img_prefix=data_root, 97 | pipeline=test_pipeline, 98 | test_mode=True, 99 | classes=None), 100 | model_init=dict( 101 | copy_from_train_dataset=True, 102 | samples_per_gpu=16, 103 | workers_per_gpu=1, 104 | type='FewShotVOCDataset', 105 | ann_cfg=None, 106 | img_prefix=data_root, 107 | pipeline=train_multi_pipelines['support'], 108 | use_difficult=False, 109 | instance_wise=True, 110 | num_novel_shots=None, 111 | classes=None, 112 | dataset_name='model_init_dataset')) 113 | evaluation = dict(interval=3000, metric='mAP', class_splits=None) 114 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=5000) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook') 8 | ]) 9 | # yapf:enable 10 | custom_hooks = [dict(type='NumClassCheckHook')] 11 | 12 | dist_params = dict(backend='nccl') 13 | log_level = 'INFO' 14 | load_from = None 15 | resume_from = None 16 | workflow = [('train', 1)] 17 | use_infinite_sampler = True 18 | # a magical seed works well in most cases for this repo!!! 
19 | # using different seeds might raise some issues about reproducibility 20 | seed = 42 21 | -------------------------------------------------------------------------------- /configs/_base_/models/faster_rcnn_r50_caffe_c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='BN', requires_grad=False) 3 | pretrained = 'open-mmlab://detectron2/resnet50_caffe' 4 | model = dict( 5 | type='FasterRCNN', 6 | pretrained=pretrained, 7 | backbone=dict( 8 | type='ResNet', 9 | depth=50, 10 | num_stages=3, 11 | strides=(1, 2, 2), 12 | dilations=(1, 1, 1), 13 | out_indices=(2, ), 14 | frozen_stages=1, 15 | norm_cfg=norm_cfg, 16 | norm_eval=True, 17 | style='caffe'), 18 | rpn_head=dict( 19 | type='RPNHead', 20 | in_channels=1024, 21 | feat_channels=1024, 22 | anchor_generator=dict( 23 | type='AnchorGenerator', 24 | scales=[2, 4, 8, 16, 32], 25 | ratios=[0.5, 1.0, 2.0], 26 | scale_major=False, 27 | strides=[16]), 28 | bbox_coder=dict( 29 | type='DeltaXYWHBBoxCoder', 30 | target_means=[.0, .0, .0, .0], 31 | target_stds=[1.0, 1.0, 1.0, 1.0]), 32 | loss_cls=dict( 33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 35 | roi_head=dict( 36 | type='StandardRoIHead', 37 | shared_head=dict( 38 | type='ResLayer', 39 | pretrained=pretrained, 40 | depth=50, 41 | stage=3, 42 | stride=2, 43 | dilation=1, 44 | style='caffe', 45 | norm_cfg=norm_cfg, 46 | norm_eval=True), 47 | bbox_roi_extractor=dict( 48 | type='SingleRoIExtractor', 49 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 50 | out_channels=1024, 51 | featmap_strides=[16]), 52 | bbox_head=dict( 53 | type='BBoxHead', 54 | with_avg_pool=True, 55 | roi_feat_size=7, 56 | in_channels=2048, 57 | num_classes=80, 58 | bbox_coder=dict( 59 | type='DeltaXYWHBBoxCoder', 60 | target_means=[0., 0., 0., 0.], 61 | target_stds=[0.1, 0.1, 0.2, 0.2]), 62 | reg_class_agnostic=False, 63 | loss_cls=dict( 64 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 65 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 66 | # model training and testing settings 67 | train_cfg=dict( 68 | rpn=dict( 69 | assigner=dict( 70 | type='MaxIoUAssigner', 71 | pos_iou_thr=0.7, 72 | neg_iou_thr=0.3, 73 | min_pos_iou=0.3, 74 | match_low_quality=True, 75 | ignore_iof_thr=-1), 76 | sampler=dict( 77 | type='RandomSampler', 78 | num=256, 79 | pos_fraction=0.5, 80 | neg_pos_ub=-1, 81 | add_gt_as_proposals=False), 82 | allowed_border=0, 83 | pos_weight=-1, 84 | debug=False), 85 | rpn_proposal=dict( 86 | nms_pre=12000, 87 | max_per_img=2000, 88 | nms=dict(type='nms', iou_threshold=0.7), 89 | min_bbox_size=0), 90 | rcnn=dict( 91 | assigner=dict( 92 | type='MaxIoUAssigner', 93 | pos_iou_thr=0.5, 94 | neg_iou_thr=0.5, 95 | min_pos_iou=0.5, 96 | match_low_quality=False, 97 | ignore_iof_thr=-1), 98 | sampler=dict( 99 | type='RandomSampler', 100 | num=512, 101 | pos_fraction=0.25, 102 | neg_pos_ub=-1, 103 | add_gt_as_proposals=True), 104 | pos_weight=-1, 105 | debug=False)), 106 | test_cfg=dict( 107 | rpn=dict( 108 | nms_pre=6000, 109 | max_per_img=1000, 110 | nms=dict(type='nms', iou_threshold=0.7), 111 | min_bbox_size=0), 112 | rcnn=dict( 113 | score_thr=0.05, 114 | nms=dict(type='nms', iou_threshold=0.5), 115 | max_per_img=100))) 116 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule.py: -------------------------------------------------------------------------------- 1 | 
# optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[60000, 80000]) 11 | runner = dict(type='IterBasedRunner', max_iters=90000) 12 | -------------------------------------------------------------------------------- /configs/fpd/coco/fpd_r101_c4_2xb4_coco_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/datasets/nway_kshot/few_shot_coco_ms.py', 3 | '../../_base_/schedules/schedule.py', '../fpd_r101_c4.py', 4 | '../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotCocoDataset 7 | # FewShotCocoDefaultDataset predefine ann_cfg for model reproducibility 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | num_used_support_shots=10, 12 | dataset=dict( 13 | type='FewShotCocoDefaultDataset', 14 | ann_cfg=[dict(method='MetaRCNN', setting='10SHOT')], 15 | num_novel_shots=10, 16 | num_base_shots=10, 17 | )), 18 | model_init=dict(num_novel_shots=10, num_base_shots=10)) 19 | 20 | evaluation = dict(interval=1000) 21 | checkpoint_config = dict(interval=1000) 22 | optimizer = dict(lr=0.001) 23 | lr_config = dict(warmup_iters=200) 24 | runner = dict(max_iters=10000) 25 | 26 | # load_from = 'path of base training model' 27 | load_from = \ 28 | 'work_dirs/fpd_r101_c4_2xb4_coco_base-training/latest.pth' 29 | 30 | # model settings 31 | model = dict( 32 | with_refine=True, 33 | frozen_parameters=['backbone', 'shared_head'], 34 | roi_head=dict( 35 | bbox_head=dict(num_classes=80, num_meta_classes=80), 36 | novel_class=(0, 1, 2, 3, 4, 5, 6, 8, 14, 15, 16, 17, 18, 19, 39, 56, 57, 58, 60, 62), 37 | num_novel=20, 38 | meta_cls_ratio=1.0, 39 | prototypes_distillation=dict(num_base_cls=60, num_novel=20)), 40 | ) 41 | -------------------------------------------------------------------------------- /configs/fpd/coco/fpd_r101_c4_2xb4_coco_30shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/datasets/nway_kshot/few_shot_coco_ms.py', 3 | '../../_base_/schedules/schedule.py', '../fpd_r101_c4.py', 4 | '../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotCocoDataset 7 | # FewShotCocoDefaultDataset predefine ann_cfg for model reproducibility 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | num_used_support_shots=30, 12 | dataset=dict( 13 | type='FewShotCocoDefaultDataset', 14 | ann_cfg=[dict(method='MetaRCNN', setting='30SHOT')], 15 | num_novel_shots=30, 16 | num_base_shots=30, 17 | )), 18 | model_init=dict(num_novel_shots=30, num_base_shots=30)) 19 | 20 | evaluation = dict(interval=8000) 21 | checkpoint_config = dict(interval=8000) 22 | optimizer = dict(lr=0.001) 23 | lr_config = dict(warmup=None) 24 | runner = dict(max_iters=18000) 25 | 26 | # load_from = 'path of base training model' 27 | load_from = \ 28 | 'work_dirs/fpd_r101_c4_2xb4_coco_base-training/latest.pth' 29 | 30 | model = dict( 31 | with_refine=True, 32 | frozen_parameters=['backbone', 'shared_head'], 33 | roi_head=dict( 34 | bbox_head=dict(num_classes=80, num_meta_classes=80), 35 | novel_class=(0, 1, 2, 3, 4, 5, 6, 8, 14, 15, 16, 17, 18, 19, 39, 56, 57, 58, 60, 62), 36 | num_novel=20, 37 | meta_cls_ratio=1.0, 38 | prototypes_distillation=dict(num_base_cls=60, num_novel=20)), 39 | ) 40 | 41 | 
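# Note on the two COCO fine-tuning configs above: the backbone and shared head are
# kept frozen (`frozen_parameters`), the box head is expanded from the 60 base
# classes to all 80 COCO classes, and `novel_class` lists the indices of the 20
# novel categories so that base and novel classes get separate prototype query
# embeddings in the prototypes distillation module (see `fpd/ffa.py`).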
-------------------------------------------------------------------------------- /configs/fpd/coco/fpd_r101_c4_2xb4_coco_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/datasets/nway_kshot/base_coco_ms.py', 3 | '../../_base_/schedules/schedule.py', '../fpd_r101_c4.py', 4 | '../../_base_/default_runtime.py' 5 | ] 6 | 7 | lr_config = dict(warmup_iters=1000, step=[92000]) 8 | evaluation = dict(interval=110000) 9 | checkpoint_config = dict(interval=55000) 10 | runner = dict(max_iters=110000) 11 | optimizer = dict(lr=0.004) 12 | 13 | # model settings 14 | model = dict( 15 | roi_head=dict(bbox_head=dict(num_classes=60, num_meta_classes=60), 16 | prototypes_distillation=dict(num_base_cls=60), 17 | num_novel=0, 18 | meta_cls_ratio=1.0), 19 | ) 20 | -------------------------------------------------------------------------------- /configs/fpd/fpd_r101_c4.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | './meta-rcnn_r50_c4.py', 3 | ] 4 | pretrained = 'open-mmlab://detectron2/resnet101_caffe' 5 | # model settings 6 | model = dict( 7 | type='FPD', 8 | post_rpn=True, 9 | pretrained=pretrained, 10 | backbone=dict(depth=101), 11 | roi_head=dict( 12 | type='FPDRoIHead', 13 | shared_head=dict(pretrained=pretrained), 14 | bbox_head=dict(num_classes=20, num_meta_classes=20), 15 | novel_class=(15, 16, 17, 18, 19), 16 | prototypes_distillation=dict( 17 | type='PrototypesDistillation', 18 | num_queries=5, dim=1024, num_base_cls=15), 19 | prototypes_assignment=dict( 20 | type='PrototypesAssignment', 21 | dim=1024, num_bg=5), 22 | ) 23 | 24 | ) 25 | -------------------------------------------------------------------------------- /configs/fpd/meta-rcnn_r50_c4.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/faster_rcnn_r50_caffe_c4.py', 3 | ] 4 | # model settings 5 | model = dict( 6 | type='MetaRCNN', 7 | backbone=dict(type='ResNetWithMetaConv', frozen_stages=2), 8 | rpn_head=dict( 9 | feat_channels=512, loss_cls=dict(use_sigmoid=False, loss_weight=1.0)), 10 | roi_head=dict( 11 | type='MetaRCNNRoIHead', 12 | shared_head=dict(type='MetaRCNNResLayer'), 13 | bbox_head=dict( 14 | type='MetaBBoxHead', 15 | with_avg_pool=False, 16 | in_channels=2048, 17 | roi_feat_size=1, 18 | num_classes=80, 19 | num_meta_classes=80, 20 | meta_cls_in_channels=2048, 21 | with_meta_cls_loss=True, 22 | loss_meta=dict( 23 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 24 | loss_bbox=dict(type='SmoothL1Loss', loss_weight=1.0)), 25 | aggregation_layer=dict( 26 | type='AggregationLayer', 27 | aggregator_cfgs=[ 28 | dict( 29 | type='DotProductAggregator', 30 | in_channels=2048, 31 | with_fc=False) 32 | ])), 33 | train_cfg=dict( 34 | rpn=dict( 35 | assigner=dict( 36 | type='MaxIoUAssigner', 37 | pos_iou_thr=0.7, 38 | neg_iou_thr=0.3, 39 | min_pos_iou=0.3, 40 | match_low_quality=True, 41 | ignore_iof_thr=-1), 42 | sampler=dict( 43 | type='RandomSampler', 44 | num=256, 45 | pos_fraction=0.5, 46 | neg_pos_ub=-1, 47 | add_gt_as_proposals=False), 48 | allowed_border=0, 49 | pos_weight=-1, 50 | debug=False), 51 | rpn_proposal=dict( 52 | nms_pre=12000, 53 | max_per_img=2000, 54 | nms=dict(type='nms', iou_threshold=0.7), 55 | min_bbox_size=0), 56 | rcnn=dict( 57 | assigner=dict( 58 | type='MaxIoUAssigner', 59 | pos_iou_thr=0.5, 60 | neg_iou_thr=0.5, 61 | min_pos_iou=0.5, 62 | match_low_quality=False, 63 | ignore_iof_thr=-1), 64 
| sampler=dict( 65 | type='RandomSampler', 66 | num=128, 67 | pos_fraction=0.25, 68 | neg_pos_ub=-1, 69 | add_gt_as_proposals=True), 70 | pos_weight=-1, 71 | debug=False)), 72 | test_cfg=dict( 73 | rpn=dict( 74 | nms_pre=6000, 75 | max_per_img=300, 76 | nms=dict(type='nms', iou_threshold=0.7), 77 | min_bbox_size=0), 78 | rcnn=dict( 79 | score_thr=0.05, 80 | nms=dict(type='nms', iou_threshold=0.3), 81 | max_per_img=100))) 82 | -------------------------------------------------------------------------------- /configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_10SHOT')], 14 | num_novel_shots=10, 15 | num_base_shots=10, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | 22 | evaluation = dict( 23 | interval=600, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 24 | checkpoint_config = dict(interval=600) 25 | optimizer = dict(lr=0.001) 26 | lr_config = dict(warmup=None) 27 | runner = dict(max_iters=3000) 28 | 29 | # load_from = 'path of base training model' 30 | load_from = \ 31 | 'work_dirs/fpd_r101_c4_2xb4_voc-split1_base-training/latest.pth' 32 | # model settings 33 | model = dict( 34 | frozen_parameters=['backbone', 'shared_head'], 35 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 36 | prototypes_distillation=dict(num_novel=5)) 37 | ) 38 | 39 | -------------------------------------------------------------------------------- /configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_1shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
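# Note: compared with the 3/5/10-shot configs, this 1-shot config (like the 2-shot
# one) additionally freezes the RPN conv layers via `frozen_parameters` below.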
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_1SHOT')], 14 | num_novel_shots=1, 15 | num_base_shots=1, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=200, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=200) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=800) 27 | 28 | # load_from = 'path of base training model' 29 | load_from = \ 30 | 'work_dirs/fpd_r101_c4_2xb4_voc-split1_base-training/latest.pth' 31 | 32 | # model settings 33 | model = dict( 34 | frozen_parameters=['backbone', 'shared_head', 'rpn_head.rpn_conv', 'roi_head.rpn_head.rpn_conv'], 35 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 36 | prototypes_distillation=dict(num_novel=5)) 37 | ) 38 | -------------------------------------------------------------------------------- /configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_2shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_2SHOT')], 14 | num_novel_shots=2, 15 | num_base_shots=2, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=300, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=300) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1200) 27 | 28 | # load_from = 'path of base training model' 29 | load_from = \ 30 | 'work_dirs/fpd_r101_c4_2xb4_voc-split1_base-training/latest.pth' 31 | 32 | # model settings 33 | model = dict( 34 | frozen_parameters=['backbone', 'shared_head', 'rpn_head.rpn_conv', 'roi_head.rpn_head.rpn_conv'], 35 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 36 | prototypes_distillation=dict(num_novel=5)) 37 | ) 38 | 39 | -------------------------------------------------------------------------------- /configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_3shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_3SHOT')], 14 | num_novel_shots=3, 15 | num_base_shots=3, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=400, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=400) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1600) 27 | 28 | # load_from = 'path of base training model' 29 | load_from = \ 30 | 'work_dirs/fpd_r101_c4_2xb4_voc-split1_base-training/latest.pth' 31 | 32 | # model settings 33 | model = dict( 34 | frozen_parameters=['backbone', 'shared_head'], 35 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 36 | prototypes_distillation=dict(num_novel=5)) 37 | ) -------------------------------------------------------------------------------- /configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_5shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_5SHOT')], 14 | num_novel_shots=5, 15 | num_base_shots=5, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | 22 | evaluation = dict( 23 | interval=500, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 24 | checkpoint_config = dict(interval=500) 25 | optimizer = dict(lr=0.001) 26 | lr_config = dict(warmup=None) 27 | runner = dict(max_iters=2000) 28 | # load_from = 'path of base training model' 29 | load_from = \ 30 | 'work_dirs/fpd_r101_c4_2xb4_voc-split1_base-training/latest.pth' 31 | 32 | # model settings 33 | model = dict( 34 | frozen_parameters=['backbone', 'shared_head'], 35 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 36 | prototypes_distillation=dict(num_novel=5)) 37 | ) -------------------------------------------------------------------------------- /configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
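# Note: base training uses only the 15 base classes of split 1, so the box head is
# built with num_classes=15 / num_meta_classes=15 and no novel prototypes are
# distilled (num_novel=0); see the model settings at the end of this config.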
8 | data = dict( 9 | train=dict( 10 | save_dataset=False, 11 | dataset=dict(classes='BASE_CLASSES_SPLIT1'), 12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT1')), 13 | val=dict(classes='BASE_CLASSES_SPLIT1'), 14 | test=dict(classes='BASE_CLASSES_SPLIT1'), 15 | model_init=dict(classes='BASE_CLASSES_SPLIT1')) 16 | 17 | lr_config = dict(warmup_iters=500, step=[17000]) 18 | evaluation = dict(interval=20000) 19 | checkpoint_config = dict(interval=20000) 20 | runner = dict(max_iters=20000) 21 | optimizer = dict(lr=0.005) 22 | 23 | model = dict( 24 | roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15), 25 | num_novel=0, meta_cls_ratio=1.0)) 26 | -------------------------------------------------------------------------------- /configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_10SHOT')], 14 | num_novel_shots=10, 15 | num_base_shots=10, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | 22 | evaluation = dict( 23 | interval=600, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 24 | checkpoint_config = dict(interval=600) 25 | optimizer = dict(lr=0.001) 26 | lr_config = dict(warmup=None) 27 | runner = dict(max_iters=3000) 28 | 29 | # load_from = 'path of base training model' 30 | load_from = \ 31 | 'work_dirs/fpd_r101_c4_2xb4_voc-split2_base-training/latest.pth' 32 | 33 | model = dict( 34 | frozen_parameters=['backbone', 'shared_head'], 35 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 36 | prototypes_distillation=dict(num_novel=5)) 37 | ) 38 | -------------------------------------------------------------------------------- /configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_1shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_1SHOT')], 14 | num_novel_shots=1, 15 | num_base_shots=1, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=200, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=200) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=800) 27 | # load_from = 'path of base training model' 28 | load_from = \ 29 | 'work_dirs/fpd_r101_c4_2xb4_voc-split2_base-training/latest.pth' 30 | # model settings 31 | model = dict( 32 | frozen_parameters=['backbone', 'shared_head', 'rpn_head.rpn_conv', 'roi_head.rpn_head.rpn_conv'], 33 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 34 | prototypes_distillation=dict(num_novel=5)) 35 | ) 36 | 37 | -------------------------------------------------------------------------------- /configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_2shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_2SHOT')], 14 | num_novel_shots=2, 15 | num_base_shots=2, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=300, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=300) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1200) 27 | # load_from = 'path of base training model' 28 | load_from = \ 29 | 'work_dirs/fpd_r101_c4_2xb4_voc-split2_base-training/latest.pth' 30 | # model settings 31 | model = dict( 32 | frozen_parameters=['backbone', 'shared_head', 'rpn_head.rpn_conv', 'roi_head.rpn_head.rpn_conv'], 33 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 34 | prototypes_distillation=dict(num_novel=5)) 35 | ) 36 | -------------------------------------------------------------------------------- /configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_3shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_3SHOT')], 14 | num_novel_shots=3, 15 | num_base_shots=3, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=400, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=400) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1600) 27 | 28 | # load_from = 'path of base training model' 29 | load_from = \ 30 | 'work_dirs/fpd_r101_c4_2xb4_voc-split2_base-training/latest.pth' 31 | 32 | # model settings 33 | model = dict( 34 | frozen_parameters=['backbone', 'shared_head'], 35 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 36 | prototypes_distillation=dict(num_novel=5)) 37 | ) 38 | 39 | -------------------------------------------------------------------------------- /configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_5shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_5SHOT')], 14 | num_novel_shots=5, 15 | num_base_shots=5, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | 22 | evaluation = dict( 23 | interval=500, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 24 | checkpoint_config = dict(interval=500) 25 | optimizer = dict(lr=0.001) 26 | lr_config = dict(warmup=None) 27 | runner = dict(max_iters=2000) 28 | 29 | # load_from = 'path of base training model' 30 | load_from = \ 31 | 'work_dirs/fpd_r101_c4_2xb4_voc-split2_base-training/latest.pth' 32 | 33 | # model settings 34 | model = dict( 35 | frozen_parameters=['backbone', 'shared_head'], 36 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 37 | prototypes_distillation=dict(num_novel=5)) 38 | ) 39 | -------------------------------------------------------------------------------- /configs/fpd/voc/split2/fpd_r101_c4_2xb4_voc-split2_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=False, 11 | dataset=dict(classes='BASE_CLASSES_SPLIT2'), 12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT2')), 13 | val=dict(classes='BASE_CLASSES_SPLIT2'), 14 | test=dict(classes='BASE_CLASSES_SPLIT2'), 15 | model_init=dict(classes='BASE_CLASSES_SPLIT2')) 16 | lr_config = dict(warmup_iters=500, step=[17000]) 17 | evaluation = dict(interval=20000) 18 | checkpoint_config = dict(interval=20000) 19 | runner = dict(max_iters=20000) 20 | optimizer = dict(lr=0.005) 21 | 22 | model = dict( 23 | roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15), 24 | num_novel=0, meta_cls_ratio=1.0)) -------------------------------------------------------------------------------- /configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_10SHOT')], 14 | num_novel_shots=10, 15 | num_base_shots=10, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | 22 | evaluation = dict( 23 | interval=600, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 24 | checkpoint_config = dict(interval=600) 25 | optimizer = dict(lr=0.001) 26 | lr_config = dict(warmup=None) 27 | runner = dict(max_iters=3000) 28 | # load_from = 'path of base training model' 29 | load_from = \ 30 | 'work_dirs/fpd_r101_c4_2xb4_voc-split3_base-training/latest.pth' 31 | # model settings 32 | model = dict( 33 | frozen_parameters=['backbone', 'shared_head'], 34 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 35 | prototypes_distillation=dict(num_novel=5)) 36 | ) 37 | -------------------------------------------------------------------------------- /configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_1shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_1SHOT')], 14 | num_novel_shots=1, 15 | num_base_shots=1, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=200, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=200) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=800) 27 | # load_from = 'path of base training model' 28 | load_from = \ 29 | 'work_dirs/fpd_r101_c4_2xb4_voc-split3_base-training/latest.pth' 30 | # model settings 31 | model = dict( 32 | frozen_parameters=['backbone', 'shared_head', 'rpn_head.rpn_conv', 'roi_head.rpn_head.rpn_conv'], 33 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 34 | prototypes_distillation=dict(num_novel=5)) 35 | ) 36 | -------------------------------------------------------------------------------- /configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_2shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_2SHOT')], 14 | num_novel_shots=2, 15 | num_base_shots=2, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=300, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=300) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1200) 27 | # load_from = 'path of base training model' 28 | load_from = \ 29 | 'work_dirs/fpd_r101_c4_2xb4_voc-split3_base-training/latest.pth' 30 | # model settings 31 | model = dict( 32 | frozen_parameters=['backbone', 'shared_head', 'rpn_head.rpn_conv', 'roi_head.rpn_head.rpn_conv'], 33 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 34 | prototypes_distillation=dict(num_novel=5)) 35 | ) 36 | -------------------------------------------------------------------------------- /configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_3shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_3SHOT')], 14 | num_novel_shots=3, 15 | num_base_shots=3, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=400, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=400) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1600) 27 | # load_from = 'path of base training model' 28 | load_from = \ 29 | 'work_dirs/fpd_r101_c4_2xb4_voc-split3_base-training/latest.pth' 30 | # model settings 31 | model = dict( 32 | frozen_parameters=['backbone', 'shared_head'], 33 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 34 | prototypes_distillation=dict(num_novel=5)) 35 | ) 36 | -------------------------------------------------------------------------------- /configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_5shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_5SHOT')], 14 | num_novel_shots=5, 15 | num_base_shots=5, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=500, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=500) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=2000) 27 | # load_from = 'path of base training model' 28 | load_from = \ 29 | 'work_dirs/fpd_r101_c4_2xb4_voc-split3_base-training/latest.pth' 30 | # model settings 31 | model = dict( 32 | frozen_parameters=['backbone', 'shared_head'], 33 | roi_head=dict(num_novel=5, meta_cls_ratio=1.0, 34 | prototypes_distillation=dict(num_novel=5)) 35 | ) 36 | -------------------------------------------------------------------------------- /configs/fpd/voc/split3/fpd_r101_c4_2xb4_voc-split3_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../fpd_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
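# Illustrative note: this base-training run trains on the 15 SPLIT3 base classes
# for 20k iterations at lr=0.005 with a step decay at 17k. The split3
# fine-tuning configs above expect its checkpoint at
#   work_dirs/fpd_r101_c4_2xb4_voc-split3_base-training/latest.pth
# so either keep the default work_dir or adjust their `load_from` paths.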
8 | data = dict( 9 | train=dict( 10 | save_dataset=False, 11 | dataset=dict(classes='BASE_CLASSES_SPLIT3'), 12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT3')), 13 | val=dict(classes='BASE_CLASSES_SPLIT3'), 14 | test=dict(classes='BASE_CLASSES_SPLIT3'), 15 | model_init=dict(classes='BASE_CLASSES_SPLIT3')) 16 | 17 | lr_config = dict(warmup_iters=500, step=[17000]) 18 | evaluation = dict(interval=20000) 19 | checkpoint_config = dict(interval=20000) 20 | runner = dict(max_iters=20000) 21 | optimizer = dict(lr=0.005) 22 | 23 | # model settings 24 | model = dict( 25 | roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15), 26 | num_novel=0, meta_cls_ratio=1.0)) -------------------------------------------------------------------------------- /dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | CONFIG=$1 3 | CHECKPOINT=$2 4 | GPUS=$3 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-29500} 8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 9 | 10 | PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \ 11 | python -m torch.distributed.launch \ 12 | --nnodes=$NNODES \ 13 | --node_rank=$NODE_RANK \ 14 | --master_addr=$MASTER_ADDR \ 15 | --nproc_per_node=$GPUS \ 16 | --master_port=$PORT \ 17 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 18 | -------------------------------------------------------------------------------- /dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-29500} 8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 9 | 10 | PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \ 11 | python -m torch.distributed.launch \ 12 | --nnodes=$NNODES \ 13 | --node_rank=$NODE_RANK \ 14 | --master_addr=$MASTER_ADDR \ 15 | --nproc_per_node=$GPUS \ 16 | --master_port=$PORT \ 17 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 18 | -------------------------------------------------------------------------------- /fpd/__init__.py: -------------------------------------------------------------------------------- 1 | from .ffa import PrototypesDistillation, PrototypesAssignment 2 | from .fpd_roi_head import FPDRoIHead 3 | from .fpd_detector import FPD 4 | from .transforms import CropResizeInstanceByRatio 5 | 6 | __all__ = ['FPD', 'FPDRoIHead', 'PrototypesDistillation', 'PrototypesAssignment'] 7 | -------------------------------------------------------------------------------- /fpd/ffa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from torch import nn 3 | from mmfewshot.detection.models.utils.aggregation_layer import AGGREGATORS 4 | from mmcv.runner import BaseModule 5 | 6 | import math 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | 11 | @AGGREGATORS.register_module() 12 | class PrototypesDistillation(BaseModule): 13 | def __init__(self, num_queries, dim, num_base_cls=15, num_novel=0): 14 | super().__init__() 15 | self.num_queries = num_queries 16 | self.num_base_cls = num_base_cls 17 | self.num_novel = num_novel 18 | self.dim = dim 19 | 20 | k_dim = dim // 4 21 | self.k_dim = k_dim 22 | 23 | self.query_embed = nn.Embedding(num_queries*num_base_cls, k_dim) 24 | self.duplicated = False 25 | if self.num_novel > 0: 26 | self.query_embed_novel = nn.Embedding(num_queries * num_novel, k_dim) 27 | self.w_qk = nn.Linear(dim, k_dim, bias=False) 28 | 29 | def 
forward(self, support_feats, support_gt_labels=None, forward_novel=False, forward_novel_test=False): 30 | """ 31 | Args: 32 | support_feats: Tensor with shape (B, C, H, W). 33 | support_gt_labels: Support gt labels. 34 | forward_novel (bool): Novel classes. 35 | forward_novel_test (bool): Test time. 36 | Returns: 37 | tensor with shape (15, 1024) 38 | """ 39 | 40 | # at the fine-tuning stage, duplicate the most compatible feature queries for the novel classes 41 | # ************************************************************ 42 | if not self.duplicated and self.num_novel > 0 and not forward_novel_test and forward_novel: 43 | with torch.no_grad(): 44 | support_feats_mp = F.max_pool2d(support_feats, kernel_size=2, stride=2) 45 | B, C, H, W = support_feats_mp.shape 46 | k = support_feats_mp.reshape(B, C, H*W).permute(0, 2, 1) # (B, 196, 1024) 47 | k = self.w_qk(k) # (B, 196, 1024) 48 | query_emb = self.query_embed.weight 49 | q = query_emb.unsqueeze(0).repeat(B, 1, 1) 50 | B, Nt, E = q.shape 51 | attn = torch.bmm(q / math.sqrt(E), k.transpose(-2, -1)) 52 | weight = torch.topk(attn, 20, dim=-1)[0].mean(-1) 53 | 54 | drop = 5 55 | top_indices = torch.topk(weight, self.num_queries + drop, dim=-1)[1][:, -self.num_queries:] 56 | top_emb = torch.gather(self.query_embed.weight.unsqueeze(0).expand(B, -1, -1), 1, top_indices.unsqueeze(-1).expand(-1, -1, self.k_dim)) 57 | top_emb = top_emb[torch.sort(support_gt_labels, dim=0)[1]].reshape(self.num_novel*self.num_queries, self.k_dim) 58 | self.query_embed_novel.weight.copy_(top_emb) 59 | self.duplicated = True # set as True once duplicated 60 | # ************************************************************ 61 | 62 | support_feats_mp = F.max_pool2d(support_feats, kernel_size=2, stride=2) 63 | B, C, H, W = support_feats_mp.shape 64 | k = v = support_feats_mp.reshape(B, C, H*W).permute(0, 2, 1) # (B, 196, 1024) 65 | k = self.w_qk(k) # (B, 196, 1024) 66 | 67 | # scaled dot-product attention 68 | if forward_novel: 69 | query_emb = self.query_embed_novel.weight 70 | q = query_emb.reshape(self.num_novel, self.num_queries, query_emb.size(-1)) 71 | else: 72 | query_emb = self.query_embed.weight 73 | q = query_emb.reshape(self.num_base_cls, self.num_queries, query_emb.size(-1)) 74 | q = q[support_gt_labels, ...] # align with support_gt_labels 75 | B, Nt, E = q.shape 76 | attn = torch.bmm(q / math.sqrt(E), k.transpose(-2, -1)) 77 | weight = torch.topk(attn, 20, dim=-1)[0].mean(-1) 78 | prototypes = torch.matmul(attn.softmax(-1), v) # (B, 5, 1024) 79 | 80 | return weight, prototypes 81 | 82 | 83 | @AGGREGATORS.register_module() 84 | class PrototypesAssignment(BaseModule): 85 | def __init__(self, dim, num_bg=5): 86 | super().__init__() 87 | k_dim = dim // 4 88 | self.w_qk = nn.Linear(dim, k_dim, bias=False) 89 | 90 | self.num_bg = num_bg 91 | if self.num_bg > 0: 92 | self.dummy = nn.Parameter(torch.Tensor(self.num_bg, dim)) 93 | nn.init.normal_(self.dummy) 94 | self.linear = nn.Linear(dim, k_dim) 95 | self.gamma = nn.Parameter(torch.tensor(0.)) 96 | 97 | def forward(self, query_feature, prototypes, query_img_metas=None): 98 | """ 99 | Args: 100 | query_feature: Tensor with shape (B, C, H, W) 101 | prototypes: Tensor with shape (num_supp, num_queries, C), 102 | query_img_metas: Visualization. 
103 | Returns: 104 | class-specific query feature: tensor(B, C, H, W) 105 | """ 106 | 107 | B, C, H, W = query_feature.shape 108 | num_supp, num_queries, _ = prototypes.shape 109 | q = query_feature.reshape(B, C, H*W).permute(0, 2, 1) # (B, H*W, 1024) 110 | k = v = prototypes.reshape(num_supp * num_queries, C) 111 | 112 | q = self.w_qk(q) 113 | k = self.w_qk(k) 114 | 115 | if self.num_bg > 0: 116 | dummy_v = torch.zeros((self.num_bg, C), device='cuda') 117 | k = torch.cat([k, self.linear(self.dummy)], dim=0) 118 | v = torch.cat([v, dummy_v], dim=0) 119 | 120 | k = k.unsqueeze(0) 121 | B, Nt, E = q.shape 122 | attn = torch.bmm(q / math.sqrt(E), k.expand(B, -1, -1).transpose(-2, -1)) 123 | attn.div_(0.5) 124 | 125 | out = torch.matmul(attn.softmax(-1), v) # (B, 2850, 1024) 126 | out = out.permute(0, 2, 1).contiguous().view(B, C, H, W) 127 | out = query_feature + self.gamma * out 128 | return out 129 | 130 | -------------------------------------------------------------------------------- /fpd/fpd_detector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | import sys 4 | from abc import ABC 5 | from typing import Dict, List, Optional 6 | 7 | import torch 8 | from mmcv.runner import auto_fp16 9 | from mmcv.utils import ConfigDict 10 | from mmdet.models.builder import DETECTORS 11 | from torch import Tensor 12 | 13 | from .query_support import QuerySupportDetectorFPD 14 | from .utils import TestMixins 15 | 16 | @DETECTORS.register_module() 17 | class FPD(QuerySupportDetectorFPD, TestMixins): 18 | """Implementation of `FPD. `_. 19 | Args: 20 | backbone (dict): Config of the backbone for query data. 21 | neck (dict | None): Config of the neck for query data and 22 | probably for support data. Default: None. 23 | support_backbone (dict | None): Config of the backbone for 24 | support data only. If None, support and query data will 25 | share same backbone. Default: None. 26 | support_neck (dict | None): Config of the neck for support 27 | data only. Default: None. 28 | rpn_head (dict | None): Config of rpn_head. Default: None. 29 | roi_head (dict | None): Config of roi_head. Default: None. 30 | train_cfg (dict | None): Training config. Useless in CenterNet, 31 | but we keep this variable for SingleStageDetector. Default: None. 32 | test_cfg (dict | None): Testing config of CenterNet. Default: None. 33 | pretrained (str | None): model pretrained path. Default: None. 34 | init_cfg (dict | list[dict] | None): Initialization config dict. 
35 | Default: None 36 | """ 37 | 38 | def __init__(self, 39 | backbone: ConfigDict, 40 | neck: Optional[ConfigDict] = None, 41 | support_backbone: Optional[ConfigDict] = None, 42 | support_neck: Optional[ConfigDict] = None, 43 | rpn_head: Optional[ConfigDict] = None, 44 | roi_head: Optional[ConfigDict] = None, 45 | train_cfg: Optional[ConfigDict] = None, 46 | test_cfg: Optional[ConfigDict] = None, 47 | pretrained: Optional[ConfigDict] = None, 48 | init_cfg: Optional[ConfigDict] = None, 49 | post_rpn=True, 50 | with_refine=False, 51 | ) -> None: 52 | super().__init__( 53 | backbone=backbone, 54 | neck=neck, 55 | support_backbone=support_backbone, 56 | support_neck=support_neck, 57 | rpn_head=rpn_head, 58 | roi_head=roi_head, 59 | train_cfg=train_cfg, 60 | test_cfg=test_cfg, 61 | pretrained=pretrained, 62 | init_cfg=init_cfg, 63 | post_rpn=post_rpn) 64 | 65 | self.is_model_init = False 66 | # save support template features for model initialization, 67 | # `_forward_saved_support_dict` used in :func:`forward_model_init`. 68 | self._forward_saved_support_dict = { 69 | 'gt_labels': [], 70 | 'roi_feats': [], 71 | } 72 | # save processed support template features for inference, 73 | # the processed support template features are generated 74 | # in :func:`model_init` 75 | self.inference_support_dict = {} 76 | 77 | self._new_forward_saved_support_dict = { 78 | 'gt_labels': [], 79 | 'weight': [], 80 | 'prototypes': [], # fine-grained prototypes 81 | 'prototypes_novel': [], 82 | } 83 | self.new_inference_support_dict = { 84 | 'weight': {}, 85 | 'prototypes': {}, 86 | 'prototypes_novel': {}, 87 | } 88 | 89 | # refine results for COCO. We do not use it for VOC. 90 | self.with_refine = with_refine 91 | 92 | @auto_fp16(apply_to=('img',)) 93 | def extract_support_feat(self, img): 94 | """Extracting features from support data. 95 | Args: 96 | img (Tensor): Input images of shape (N, C, H, W). 97 | Typically these should be mean centered and std scaled. 98 | Returns: 99 | list[Tensor]: Features of input image, each item with shape 100 | (N, C, H, W). 101 | """ 102 | feats = self.backbone(img, use_meta_conv=True) 103 | if self.support_neck is not None: 104 | feats = self.support_neck(feats) 105 | return feats 106 | 107 | def forward_model_init(self, 108 | img: Tensor, 109 | img_metas: List[Dict], 110 | gt_bboxes: List[Tensor] = None, 111 | gt_labels: List[Tensor] = None, 112 | **kwargs): 113 | """extract and save support features for model initialization. 114 | 115 | Args: 116 | img (Tensor): Input images of shape (N, C, H, W). 117 | Typically these should be mean centered and std scaled. 118 | img_metas (list[dict]): list of image info dict where each dict 119 | has: `img_shape`, `scale_factor`, `flip`, and may also contain 120 | `filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`. 121 | For details on the values of these keys see 122 | :class:`mmdet.datasets.pipelines.Collect`. 123 | gt_bboxes (list[Tensor]): Ground truth bboxes for each image with 124 | shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format. 125 | gt_labels (list[Tensor]): class indices corresponding to each box. 126 | 127 | Returns: 128 | dict: A dict contains following keys: 129 | 130 | - `gt_labels` (Tensor): class indices corresponding to each 131 | feature. 132 | - `res5_rois` (list[Tensor]): roi features of res5 layer. 133 | """ 134 | # `is_model_init` flag will be reset while forwarding new data. 
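# Summary of the flow below (descriptive comment): support images are passed
# through `extract_support_feat`, their RoI features are cached in
# `_forward_saved_support_dict`, and, when the RoI head owns a
# `prototypes_distillation` module, base and novel support features are split
# by `novel_class` ids and distilled into fine-grained prototypes that are
# cached in `_new_forward_saved_support_dict` until `fpd_model_init()` runs.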
135 | self.is_model_init = False 136 | assert len(gt_labels) == img.size(0), \ 137 | 'Support instance have more than two labels' 138 | 139 | feats = self.extract_support_feat(img) 140 | self._forward_saved_support_dict['gt_labels'].extend(gt_labels) 141 | roi_feat = self.roi_head.extract_support_feats(feats) # (16,2048) 142 | self._forward_saved_support_dict['roi_feats'].extend(roi_feat) 143 | 144 | gt_labels_ = torch.cat(gt_labels) 145 | if getattr(self.roi_head, 'prototypes_distillation', False): 146 | if self.roi_head.num_novel != 0: 147 | num_classes = self.roi_head.bbox_head.num_classes 148 | nc = torch.tensor(self.roi_head.novel_class, device='cuda') 149 | n_ids = torch.isin(gt_labels_, nc) 150 | b_ids = torch.logical_not(n_ids) 151 | b_gts = gt_labels_[b_ids] 152 | n_gts = gt_labels_[n_ids] 153 | bc = torch.tensor(list(set(list(range(num_classes))) - set(self.roi_head.novel_class)), device='cuda') 154 | r_b_gts = torch.cat([torch.argwhere(bc == item)[0] for item in b_gts], dim=0) \ 155 | if len(b_gts) else b_gts # relative gt_labels 156 | r_n_gts = torch.cat([torch.argwhere(nc == item)[0] for item in n_gts], dim=0) \ 157 | if len(n_gts) else n_gts # relative gt_labels 158 | base_support_feats = feats[0][b_ids] 159 | novel_support_feats = feats[0][n_ids] 160 | 161 | # prototypes distillation 162 | weight_base, prototypes_base = self.roi_head.prototypes_distillation( 163 | base_support_feats, support_gt_labels=r_b_gts) 164 | weight_novel, prototypes_novel = self.roi_head.prototypes_distillation( 165 | novel_support_feats, support_gt_labels=r_n_gts, forward_novel=True, forward_novel_test=True) 166 | prototypes = torch.cat([prototypes_base, prototypes_novel], dim=0) 167 | prototypes[b_ids] = prototypes_base 168 | prototypes[n_ids] = prototypes_novel 169 | weight = torch.cat([weight_base, weight_novel], dim=0) 170 | weight[b_ids] = weight_base 171 | weight[n_ids] = weight_novel 172 | else: 173 | weight, prototypes = self.roi_head.prototypes_distillation(feats[0], support_gt_labels=gt_labels_) 174 | self._new_forward_saved_support_dict['gt_labels'].extend(gt_labels) 175 | self._new_forward_saved_support_dict['weight'].extend([weight]) 176 | self._new_forward_saved_support_dict['prototypes'].extend([prototypes]) 177 | 178 | return {'gt_labels': gt_labels, 'roi_feat': roi_feat} 179 | 180 | def model_init(self): 181 | pass 182 | 183 | def fpd_model_init(self): 184 | """process the saved support features for model initialization.""" 185 | gt_labels = torch.cat(self._forward_saved_support_dict['gt_labels']) 186 | class_ids = set(gt_labels.data.tolist()) 187 | 188 | roi_feats = torch.cat(self._forward_saved_support_dict['roi_feats']) 189 | self.inference_support_dict.clear() 190 | for class_id in class_ids: 191 | self.inference_support_dict[class_id] = roi_feats[ 192 | gt_labels == class_id].mean([0], True) 193 | self.is_model_init = True 194 | # reset support features buff 195 | for k in self._forward_saved_support_dict.keys(): 196 | self._forward_saved_support_dict[k].clear() 197 | 198 | if getattr(self.roi_head, 'prototypes_distillation', False): 199 | weight = torch.cat(self._new_forward_saved_support_dict['weight']) 200 | prototypes = torch.cat(self._new_forward_saved_support_dict['prototypes']) 201 | for k in self.new_inference_support_dict.keys(): 202 | self.new_inference_support_dict[k].clear() 203 | for class_id in class_ids: 204 | self.new_inference_support_dict['weight'][class_id] = weight[ 205 | gt_labels == class_id].mean([0], True) 206 | 
self.new_inference_support_dict['prototypes'][class_id] = prototypes[ 207 | gt_labels == class_id].mean([0], True) 208 | 209 | ws = 0.5 210 | weighted_sum = True # test time natural integration 211 | if weighted_sum: 212 | prototypes_c = prototypes[gt_labels == class_id] 213 | weight_c = weight[gt_labels == class_id][:, :prototypes_c.size(1)].div(ws).softmax(0).unsqueeze(-1) 214 | prototypes_c = prototypes_c.mul(weight_c).sum(0, True) 215 | self.new_inference_support_dict['prototypes'][class_id] = prototypes_c 216 | 217 | if self.roi_head.num_novel != 0: 218 | nc = torch.tensor(self.roi_head.novel_class, device='cuda') 219 | n_ids = torch.isin(torch.arange(len(class_ids), device='cuda'), nc) 220 | b_ids = torch.logical_not(n_ids) 221 | 222 | # weight = torch.cat( 223 | # [self.new_inference_support_dict['weight'][class_id] for class_id in class_ids], dim=0) 224 | prototypes = torch.cat( 225 | [self.new_inference_support_dict['prototypes'][class_id] for class_id in class_ids], dim=0) 226 | prototypes_base = prototypes[b_ids][:, :self.roi_head.prototypes_distillation.num_queries, :] 227 | prototypes_novel = prototypes[n_ids] # (5, 5, 1024) 228 | # weight_novel = weight[n_ids] # (5, 5) 229 | self.new_inference_support_dict['prototypes'][0] = prototypes_base 230 | self.new_inference_support_dict['prototypes_novel'][0] = prototypes_novel 231 | # prototypes = torch.cat([prototypes_base, prototypes_novel], 0) 232 | else: 233 | prototypes = torch.cat( 234 | [self.new_inference_support_dict['prototypes'][class_id] for class_id in class_ids], dim=0) 235 | self.new_inference_support_dict['prototypes'][0] = prototypes 236 | keys = list(self.new_inference_support_dict['prototypes'].keys()) 237 | for k in keys: 238 | if k != 0: 239 | self.new_inference_support_dict['prototypes'].pop(k) 240 | 241 | for k in self._new_forward_saved_support_dict.keys(): 242 | self._new_forward_saved_support_dict[k].clear() 243 | 244 | def simple_test(self, 245 | img: Tensor, 246 | img_metas: List[Dict], 247 | proposals: Optional[List[Tensor]] = None, 248 | rescale: bool = False): 249 | """Test without augmentation. 250 | Args: 251 | img (Tensor): Input images of shape (N, C, H, W). 252 | Typically these should be mean centered and std scaled. 253 | img_metas (list[dict]): list of image info dict where each dict 254 | has: `img_shape`, `scale_factor`, `flip`, and may also contain 255 | `filename`, `ori_shape`, `pad_shape`, and `img_norm_cfg`. 256 | For details on the values of these keys see 257 | :class:`mmdet.datasets.pipelines.Collect`. 258 | proposals (list[Tensor] | None): override rpn proposals with 259 | custom proposals. Use when `with_rpn` is False. Default: None. 260 | rescale (bool): If True, return boxes in original image space. 261 | Returns: 262 | list[list[np.ndarray]]: BBox results of each image and classes. 263 | The outer list corresponds to each image. The inner list 264 | corresponds to each class. 265 | """ 266 | assert self.with_bbox, 'Bbox head must be implemented.' 267 | assert len(img_metas) == 1, 'Only support single image inference.' 
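# Test-time flow (descriptive comment): once new support data has been
# forwarded, `fpd_model_init()` aggregates the cached per-class support RoI
# features and fine-grained prototypes; both `inference_support_dict` and
# `new_inference_support_dict` are then handed to the RoI head so that
# prototypes assignment can fuse the query features before the (post-)RPN
# and the detection heads run.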
268 | if not self.is_model_init: 269 | # process the saved support features 270 | self.fpd_model_init() 271 | 272 | query_feats = self.extract_feat(img) 273 | 274 | proposal_list = None 275 | if proposals is None: 276 | if not self.post_rpn: 277 | proposal_list = self.rpn_head.simple_test(query_feats, img_metas) 278 | else: 279 | proposal_list = proposals 280 | 281 | bbox_results = self.roi_head.simple_test( 282 | query_feats, 283 | copy.deepcopy(self.inference_support_dict), 284 | copy.deepcopy(self.new_inference_support_dict), 285 | proposal_list, 286 | img_metas, 287 | rescale=rescale) 288 | if self.with_refine: 289 | return self.refine_test(bbox_results, img_metas) 290 | else: 291 | return bbox_results 292 | -------------------------------------------------------------------------------- /fpd/fpd_roi_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | import sys 4 | from typing import Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from mmcv.utils import ConfigDict 9 | from mmdet.core import bbox2result, bbox2roi 10 | from mmdet.models.builder import HEADS 11 | from mmdet.models.roi_heads import StandardRoIHead 12 | from torch import Tensor 13 | 14 | from mmfewshot.detection.models.utils.aggregation_layer import build_aggregator 15 | from mmdet.models.builder import build_head 16 | import torch.nn as nn 17 | 18 | 19 | @HEADS.register_module() 20 | class FPDRoIHead(StandardRoIHead): 21 | """Roi head for `FPD `. 22 | Args: 23 | aggregation_layer (ConfigDict): Config of `aggregation_layer`. 24 | """ 25 | 26 | def __init__(self, 27 | aggregation_layer: Optional[ConfigDict] = None, 28 | prototypes_distillation: Optional[ConfigDict] = None, 29 | prototypes_assignment: Optional[ConfigDict] = None, 30 | num_novel=0, 31 | novel_class=None, 32 | rpn_head_: Optional[ConfigDict] = None, 33 | meta_cls_ratio=1.0, 34 | **kwargs) -> None: 35 | super().__init__(**kwargs) 36 | 37 | assert prototypes_distillation is not None, "missing config of `prototypes_distillation`" 38 | self.prototypes_distillation = build_aggregator(copy.deepcopy(prototypes_distillation)) 39 | assert prototypes_assignment is not None, "missing config of `prototypes_assignment`" 40 | self.prototypes_assignment = build_aggregator(copy.deepcopy(prototypes_assignment)) 41 | self.num_novel = num_novel 42 | self.novel_class = novel_class 43 | self.meta_cls_ratio = meta_cls_ratio 44 | 45 | # RPN after FFA 46 | self.with_rpn = False 47 | if rpn_head_ is not None: 48 | self.with_rpn = True 49 | self.rpn_with_support = False 50 | self.rpn_head = build_head(rpn_head_) 51 | self.rpn_head_ = rpn_head_ 52 | 53 | # Non-Linear Fusion (NLF) 54 | d_model = 2048 55 | self.linear1 = nn.Sequential(nn.Linear(d_model, d_model // 2), nn.ReLU(inplace=True)) 56 | self.linear2 = nn.Sequential(nn.Linear(d_model, d_model // 2), nn.ReLU(inplace=True)) 57 | self.linear4 = nn.Sequential(nn.Linear(d_model * 2, d_model // 2), nn.ReLU(inplace=True)) 58 | self.linear3 = nn.Linear(int(d_model * 2.5), d_model) 59 | 60 | def forward_train(self, 61 | query_feats: List[Tensor], 62 | support_feats: List[Tensor], 63 | proposals: List[Tensor], 64 | query_img_metas: List[Dict], 65 | query_gt_bboxes: List[Tensor], 66 | query_gt_labels: List[Tensor], 67 | support_gt_labels: List[Tensor], 68 | query_gt_bboxes_ignore: Optional[List[Tensor]] = None, 69 | **kwargs) -> Dict: 70 | """Forward function for training. 
71 | Args: 72 | query_feats (list[Tensor]): List of query features, each item 73 | with shape (N, C, H, W). 74 | support_feats (list[Tensor]): List of support features, each item 75 | with shape (N, C, H, W). 76 | proposals (list[Tensor]): List of region proposals with positive 77 | and negative pairs. 78 | query_img_metas (list[dict]): List of query image info dict where 79 | each dict has: 'img_shape', 'scale_factor', 'flip', and may 80 | also contain 'filename', 'ori_shape', 'pad_shape', and 81 | 'img_norm_cfg'. For details on the values of these keys see 82 | `mmdet/datasets/pipelines/formatting.py:Collect`. 83 | query_gt_bboxes (list[Tensor]): Ground truth bboxes for each 84 | query image, each item with shape (num_gts, 4) 85 | in [tl_x, tl_y, br_x, br_y] format. 86 | query_gt_labels (list[Tensor]): Class indices corresponding to 87 | each box of query images, each item with shape (num_gts). 88 | support_gt_labels (list[Tensor]): Class indices corresponding to 89 | each box of support images, each item with shape (1). 90 | query_gt_bboxes_ignore (list[Tensor] | None): Specify which 91 | bounding boxes can be ignored when computing the loss. 92 | Default: None. 93 | Returns: 94 | dict[str, Tensor]: A dictionary of loss components 95 | """ 96 | 97 | # assign gts and sample proposals 98 | sampling_results = [] 99 | if not self.with_rpn: 100 | if self.with_bbox: 101 | num_imgs = len(query_img_metas) 102 | if query_gt_bboxes_ignore is None: 103 | query_gt_bboxes_ignore = [None for _ in range(num_imgs)] 104 | for i in range(num_imgs): # dense detector, bbox assign task 105 | assign_result = self.bbox_assigner.assign( 106 | proposals[i], query_gt_bboxes[i], 107 | query_gt_bboxes_ignore[i], query_gt_labels[i]) 108 | 109 | sampling_result = self.bbox_sampler.sample( 110 | assign_result, 111 | proposals[i], 112 | query_gt_bboxes[i], 113 | query_gt_labels[i], 114 | feats=[lvl_feat[i][None] for lvl_feat in 115 | query_feats]) 116 | sampling_results.append(sampling_result) 117 | 118 | losses = dict() 119 | if self.with_bbox: 120 | bbox_results = self._optimized_bbox_forward_train( 121 | query_feats, support_feats, sampling_results, query_img_metas, 122 | query_gt_bboxes, query_gt_labels, support_gt_labels, query_gt_bboxes_ignore) 123 | if bbox_results is not None: 124 | losses.update(bbox_results['loss_bbox']) 125 | 126 | return losses 127 | 128 | def _optimized_bbox_forward_train(self, query_feats: List[Tensor], 129 | support_feats: List[Tensor], 130 | sampling_results: object, 131 | query_img_metas: List[Dict], 132 | query_gt_bboxes: List[Tensor], 133 | query_gt_labels: List[Tensor], 134 | support_gt_labels: List[Tensor], 135 | query_gt_bboxes_ignore: Optional[List[Tensor]] = None, ) -> Dict: 136 | """Forward function and calculate loss for box head in training. 137 | Args: 138 | query_feats (list[Tensor]): List of query features, each item 139 | with shape (N, C, H, W). 140 | support_feats (list[Tensor]): List of support features, each item 141 | with shape (N, C, H, W). 142 | sampling_results (obj:`SamplingResult`): Sampling results. 143 | query_img_metas (list[dict]): List of query image info dict where 144 | each dict has: 'img_shape', 'scale_factor', 'flip', and may 145 | also contain 'filename', 'ori_shape', 'pad_shape', and 146 | 'img_norm_cfg'. For details on the values of these keys see 147 | `mmdet/datasets/pipelines/formatting.py:Collect`. 
148 | query_gt_bboxes (list[Tensor]): Ground truth bboxes for each query 149 | image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] 150 | format. 151 | query_gt_labels (list[Tensor]): Class indices corresponding to 152 | each box of query images. 153 | support_gt_labels (list[Tensor]): Class indices corresponding to 154 | each box of support images. 155 | Returns: 156 | dict: Predicted results and losses. 157 | """ 158 | if not self.with_rpn: 159 | query_rois = bbox2roi( 160 | [res.bboxes for res in sampling_results]) 161 | len_query_rois = [res.bboxes.size(0) for res in sampling_results] 162 | 163 | bbox_targets = self.bbox_head.get_targets(sampling_results, 164 | query_gt_bboxes, 165 | query_gt_labels, 166 | self.train_cfg) 167 | (labels, label_weights, bbox_targets, bbox_weights) = bbox_targets 168 | 169 | support_gt_labels_ = torch.cat(support_gt_labels) 170 | unique = set(support_gt_labels_.cpu().numpy()) 171 | num_classes = len(unique) 172 | num_supp_shots = support_feats[0].size(0) // num_classes 173 | support_feat = self.extract_support_feats(support_feats)[0] 174 | 175 | # base & novel classes processed separately 176 | if self.num_novel != 0: 177 | nc = torch.tensor(self.novel_class, device='cuda') 178 | n_ids = torch.isin(support_gt_labels_, nc) 179 | b_ids = torch.logical_not(n_ids) 180 | b_gts = support_gt_labels_[b_ids] 181 | n_gts = support_gt_labels_[n_ids] 182 | # bc = torch.tensor(list(set(list(range(num_classes))) - set(nc)), device='cuda') # nc is a tensor! 183 | bc = torch.tensor(list(set(list(range(num_classes))) - set(self.novel_class)), device='cuda') 184 | r_b_gts = torch.cat([torch.argwhere(bc == item)[0] for item in b_gts], dim=0) # relative gt_labels 185 | r_n_gts = torch.cat([torch.argwhere(nc == item)[0] for item in n_gts], dim=0) # relative gt_labels 186 | base_support_feats = support_feats[0][b_ids] 187 | novel_support_feats = support_feats[0][n_ids] 188 | 189 | # prototypes distillation 190 | weight_base, prototypes_base = self.prototypes_distillation(base_support_feats, support_gt_labels=r_b_gts) 191 | weight_novel, prototypes_novel = self.prototypes_distillation(novel_support_feats, forward_novel=True, 192 | support_gt_labels=r_n_gts) 193 | prototypes = torch.cat([prototypes_base, prototypes_novel], 0) 194 | weight = torch.cat([weight_base, weight_novel], 0) 195 | else: 196 | weight, prototypes = self.prototypes_distillation(support_feats[0], support_gt_labels=support_gt_labels_) 197 | 198 | loss_bbox = {'loss_cls': [], 'loss_bbox': [], 'acc': []} 199 | batch_size = len(query_img_metas) 200 | bbox_results = None 201 | 202 | # sampling positive & negative samples (B-CAS) 203 | supp_ids = [] 204 | num_supp_per_im = num_supp_shots + 1 205 | for img_id in range(batch_size): 206 | random_index = np.random.choice( 207 | range(query_gt_labels[img_id].size(0))) 208 | random_query_label = query_gt_labels[img_id][random_index] 209 | supp_id = [] 210 | for i in range(support_feats[0].size(0)): 211 | if support_gt_labels[i] == random_query_label: 212 | supp_id.append(i) 213 | while len(supp_id) < num_supp_per_im: 214 | supp_id.append(np.random.choice(range(support_feats[0].size(0)))) 215 | supp_ids.append(supp_id) 216 | 217 | # prototypes assignment 218 | supp_order = [] 219 | for k in range(num_supp_per_im): 220 | supp_order += [supp_ids[img_id][k] for img_id in range(batch_size)] 221 | fused_feats = self.prototypes_assignment(query_feats[0], prototypes) 222 | 223 | # ************ POST RPN ***************** 224 | if self.with_rpn: 225 | if self.with_rpn: 226 | 
proposal_cfg = self.train_cfg.get('rpn_proposal', self.rpn_head_.test_cfg) 227 | if self.rpn_with_support: # False 228 | raise NotImplementedError 229 | else: 230 | rpn_losses, proposal_list = self.rpn_head.forward_train( 231 | [fused_feats], 232 | copy.deepcopy(query_img_metas), 233 | copy.deepcopy(query_gt_bboxes), 234 | gt_labels=None, 235 | gt_bboxes_ignore=query_gt_bboxes_ignore, 236 | proposal_cfg=proposal_cfg) 237 | proposals = proposal_list 238 | if self.with_bbox: 239 | num_imgs = len(query_img_metas) 240 | if query_gt_bboxes_ignore is None: 241 | query_gt_bboxes_ignore = [None for _ in range(num_imgs)] 242 | for i in range(num_imgs): 243 | assign_result = self.bbox_assigner.assign( 244 | proposals[i], query_gt_bboxes[i], 245 | query_gt_bboxes_ignore[i], query_gt_labels[i]) 246 | sampling_result = self.bbox_sampler.sample( 247 | assign_result, 248 | proposals[i], 249 | query_gt_bboxes[i], 250 | query_gt_labels[i], 251 | feats=[lvl_feat[i][None] for lvl_feat in 252 | query_feats]) 253 | sampling_results.append(sampling_result) 254 | query_rois = bbox2roi( 255 | [res.bboxes for res in sampling_results]) 256 | len_query_rois = [res.bboxes.size(0) for res in sampling_results] 257 | bbox_targets = self.bbox_head.get_targets(sampling_results, 258 | query_gt_bboxes, 259 | query_gt_labels, 260 | self.train_cfg) 261 | (labels, label_weights, bbox_targets, bbox_weights) = bbox_targets 262 | # ************************************** 263 | 264 | query_roi_feats = self.extract_query_roi_feat([fused_feats], query_rois) 265 | rpt_prototype = support_feat[supp_order] 266 | 267 | # roi_align, reslayer4, bbox_head, loss_func 268 | for k in range(num_supp_per_im): 269 | start, end = k * batch_size, (k + 1) * batch_size 270 | prototype = [item.unsqueeze(0).expand(len_query_roi, -1) for item, len_query_roi in 271 | zip(rpt_prototype[start:end], len_query_rois)] 272 | prototype = torch.concat(prototype) # (512, 2048) 273 | 274 | # Non-Linear Fusion (NLF) 275 | agg1 = self.linear1(query_roi_feats * prototype) 276 | agg2 = self.linear2(query_roi_feats - prototype) 277 | agg3 = self.linear4(torch.cat([query_roi_feats, prototype], dim=-1)) 278 | agg = self.linear3( 279 | torch.cat([agg1, agg2, agg3, query_roi_feats], dim=-1) 280 | ) 281 | bbox_results = self._bbox_forward_without_agg(agg) 282 | 283 | single_loss_bbox = self.bbox_head.loss( 284 | bbox_results['cls_score'], bbox_results['bbox_pred'], 285 | query_rois, labels, 286 | label_weights, bbox_targets, 287 | bbox_weights) 288 | for key in single_loss_bbox.keys(): 289 | loss_bbox[key].append(single_loss_bbox[key]) 290 | if bbox_results is not None: 291 | for key in loss_bbox.keys(): 292 | if key == 'acc': 293 | loss_bbox[key] = torch.cat(loss_bbox['acc']).mean() 294 | else: 295 | loss_bbox[key] = torch.stack( 296 | loss_bbox[key]).sum() / (num_supp_per_im / 2) # / batch_size 297 | 298 | # meta classification loss 299 | if self.bbox_head.with_meta_cls_loss: 300 | meta_cls_labels = torch.cat(support_gt_labels) 301 | meta_cls_score = self.bbox_head.forward_meta_cls(support_feat) 302 | loss_meta_cls = self.bbox_head.loss_meta( 303 | meta_cls_score, meta_cls_labels, 304 | torch.ones_like(meta_cls_labels)) 305 | loss_meta_cls['loss_meta_cls'] = loss_meta_cls['loss_meta_cls'] * self.meta_cls_ratio 306 | loss_bbox.update(loss_meta_cls) 307 | 308 | bbox_results.update(loss_bbox=loss_bbox) 309 | if self.with_rpn: 310 | bbox_results['loss_bbox'].update(rpn_losses) 311 | return bbox_results 312 | 313 | def extract_query_roi_feat(self, feats: List[Tensor], 314 | 
rois: Tensor) -> Tensor: 315 | """Extracting query BBOX features, which is used in both training and 316 | testing. 317 | Args: 318 | feats (list[Tensor]): List of query features, each item 319 | with shape (N, C, H, W). 320 | rois (Tensor): shape with (bs*128, 5). 321 | Returns: 322 | Tensor: RoI features with shape (N, C). 323 | """ 324 | roi_feats = self.bbox_roi_extractor( 325 | feats[:self.bbox_roi_extractor.num_inputs], rois) 326 | if self.with_shared_head: 327 | roi_feats = self.shared_head(roi_feats) 328 | return roi_feats 329 | 330 | def extract_support_feats(self, feats: List[Tensor]) -> List[Tensor]: 331 | """Forward support features through shared layers. 332 | Args: 333 | feats (list[Tensor]): List of support features, each item 334 | with shape (N, C, H, W). 335 | Returns: 336 | list[Tensor]: List of support features, each item 337 | with shape (N, C). 338 | """ 339 | out = [] 340 | if self.with_shared_head: 341 | for lvl in range(len(feats)): 342 | out.append(self.shared_head.forward_support(feats[lvl])) 343 | else: 344 | out = feats 345 | return out 346 | 347 | def _bbox_forward(self, query_roi_feats: Tensor, 348 | support_roi_feats: Tensor) -> Dict: 349 | """Box head forward function used in both training and testing. 350 | 351 | Args: 352 | query_roi_feats (Tensor): Query roi features with shape (N, C). 353 | support_roi_feats (Tensor): Support features with shape (1, C). 354 | 355 | Returns: 356 | dict: A dictionary of predicted results. 357 | """ 358 | # feature aggregation 359 | roi_feats = self.aggregation_layer( 360 | query_feat=query_roi_feats.unsqueeze(-1).unsqueeze(-1), 361 | support_feat=support_roi_feats.view(1, -1, 1, 1))[0] 362 | 363 | cls_score, bbox_pred = self.bbox_head( 364 | roi_feats.squeeze(-1).squeeze(-1)) 365 | bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) 366 | return bbox_results 367 | 368 | def _bbox_forward_without_agg(self, query_roi_feats: Tensor) -> Dict: 369 | """Box head forward function used in both training and testing. 370 | Args: 371 | query_roi_feats (Tensor): Query roi features with shape (N, C). 372 | Returns: 373 | dict: A dictionary of predicted results. 374 | """ 375 | cls_score, bbox_pred = self.bbox_head(query_roi_feats) 376 | bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) 377 | return bbox_results 378 | 379 | def simple_test(self, 380 | query_feats: List[Tensor], 381 | support_feats_dict: Dict, 382 | new_support_feats_dict: Dict, 383 | proposal_list: List[Tensor], 384 | query_img_metas: List[Dict], 385 | rescale: bool = False) -> List[List[np.ndarray]]: 386 | """Test without augmentation. 387 | Args: 388 | query_feats (list[Tensor]): Features of query image, 389 | each item with shape (N, C, H, W). 390 | support_feats_dict (dict[int, Tensor]) Dict of support features 391 | used for inference only, each key is the class id and value is 392 | the support template features with shape (1, C). 393 | new_support_feats_dict: {'prototype': {cls_id: }, ..}. 394 | proposal_list (list[Tensors]): list of region proposals. 395 | query_img_metas (list[dict]): list of image info dict where each 396 | dict has: `img_shape`, `scale_factor`, `flip`, and may also 397 | contain `filename`, `ori_shape`, `pad_shape`, and 398 | `img_norm_cfg`. For details on the values of these keys see 399 | :class:`mmdet.datasets.pipelines.Collect`. 400 | rescale (bool): Whether to rescale the results. Default: False. 401 | Returns: 402 | list[list[np.ndarray]]: BBox results of each image and classes. 
403 | The outer list corresponds to each image. The inner list 404 | corresponds to each class. 405 | """ 406 | assert self.with_bbox, 'Bbox head must be implemented.' 407 | det_bboxes, det_labels = self.simple_test_bboxes( 408 | query_feats, 409 | support_feats_dict, 410 | new_support_feats_dict, 411 | query_img_metas, 412 | proposal_list, 413 | self.test_cfg, 414 | rescale=rescale) 415 | bbox_results = [ 416 | bbox2result(det_bboxes[i], det_labels[i], 417 | self.bbox_head.num_classes) 418 | for i in range(len(det_bboxes)) 419 | ] 420 | 421 | return bbox_results 422 | 423 | def simple_test_bboxes( 424 | self, 425 | query_feats: List[Tensor], 426 | support_feats_dict: Dict, 427 | new_support_feats_dict: Dict, 428 | query_img_metas: List[Dict], 429 | proposals: List[Tensor], 430 | rcnn_test_cfg: ConfigDict, 431 | rescale: bool = False) -> Tuple[List[Tensor], List[Tensor]]: 432 | """Test only det bboxes without augmentation. 433 | Args: 434 | query_feats (list[Tensor]): Features of query image, 435 | each item with shape (N, C, H, W). 436 | support_feats_dict (dict[int, Tensor]) Dict of support features 437 | used for inference only, each key is the class id and value is 438 | the support template features with shape (1, C). 439 | new_support_feats_dict: 440 | query_img_metas (list[dict]): list of image info dict where each 441 | dict has: `img_shape`, `scale_factor`, `flip`, and may also 442 | contain `filename`, `ori_shape`, `pad_shape`, and 443 | `img_norm_cfg`. For details on the values of these keys see 444 | :class:`mmdet.datasets.pipelines.Collect`. 445 | proposals (list[Tensor]): Region proposals. 446 | rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. 447 | rescale (bool): If True, return boxes in original image space. 448 | Default: False. 449 | Returns: 450 | tuple[list[Tensor], list[Tensor]]: Each tensor in first list 451 | with shape (num_boxes, 4) and with shape (num_boxes, ) 452 | in second list. The length of both lists should be equal 453 | to batch_size. 
454 | """ 455 | img_shapes = tuple(meta['img_shape'] for meta in query_img_metas) 456 | scale_factors = tuple(meta['scale_factor'] for meta in query_img_metas) 457 | 458 | if not self.with_rpn: 459 | rois = bbox2roi(proposals) 460 | num_rois = rois.size(0) 461 | 462 | cls_scores_dict, bbox_preds_dict = {}, {} 463 | num_classes = len(support_feats_dict) 464 | support_feat = torch.cat([support_feats_dict[i] for i in range(num_classes)]) 465 | prototypes = new_support_feats_dict['prototypes'][0] 466 | if self.num_novel != 0: 467 | prototypes_novel = new_support_feats_dict['prototypes_novel'][0] # (5, 5, 1024) 468 | prototypes = torch.cat([prototypes, prototypes_novel], 0) 469 | 470 | # prototypes assignment 471 | fused_feats = self.prototypes_assignment(query_feats[0], prototypes, query_img_metas) 472 | 473 | # **** POST RPN **** 474 | if self.with_rpn: 475 | proposals = self.rpn_head.simple_test([fused_feats], query_img_metas) 476 | rois = bbox2roi(proposals) 477 | num_rois = rois.size(0) 478 | # ************* 479 | 480 | query_roi_feats = self.extract_query_roi_feat([fused_feats], rois) 481 | query_roi_feats = query_roi_feats.repeat(support_feat.size(0), 1) 482 | rpt_support_feat = torch.cat([item.unsqueeze(0).expand(num_rois, -1) for item in support_feat]) 483 | 484 | # Non-Linear Fusion (NLF) 485 | agg1 = self.linear1(query_roi_feats * rpt_support_feat) 486 | agg2 = self.linear2(query_roi_feats - rpt_support_feat) 487 | agg3 = self.linear4(torch.cat([query_roi_feats, rpt_support_feat], dim=-1)) 488 | agg = self.linear3( 489 | torch.cat([agg1, agg2, agg3, query_roi_feats], dim=-1) 490 | ) 491 | bbox_results = self._bbox_forward_without_agg(agg) 492 | 493 | for class_id in support_feats_dict.keys(): 494 | cls_scores_dict[class_id] = \ 495 | bbox_results['cls_score'][class_id * num_rois:(class_id + 1) * num_rois, class_id:class_id + 1] 496 | bbox_preds_dict[class_id] = \ 497 | bbox_results['bbox_pred'][class_id * num_rois:(class_id + 1) * num_rois, 498 | class_id * 4:(class_id + 1) * 4] 499 | # the official code use the first class background score as final 500 | # background score, while this code use average of all classes' 501 | # background scores instead. 
502 | if cls_scores_dict.get(num_classes, None) is None: 503 | cls_scores_dict[num_classes] = \ 504 | bbox_results['cls_score'][class_id * num_rois:(class_id + 1) * num_rois, -1:] 505 | else: 506 | cls_scores_dict[num_classes] += \ 507 | bbox_results['cls_score'][class_id * num_rois:(class_id + 1) * num_rois, -1:] 508 | cls_scores_dict[num_classes] /= len(support_feats_dict.keys()) # 509 | 510 | cls_scores = [ 511 | cls_scores_dict[i] if i in cls_scores_dict.keys() else 512 | torch.zeros_like(cls_scores_dict[list(cls_scores_dict.keys())[0]]) 513 | for i in range(num_classes + 1) 514 | ] 515 | bbox_preds = [ 516 | bbox_preds_dict[i] if i in bbox_preds_dict.keys() else 517 | torch.zeros_like(bbox_preds_dict[list(bbox_preds_dict.keys())[0]]) 518 | for i in range(num_classes) 519 | ] 520 | cls_score = torch.cat(cls_scores, dim=1) # tensor(141,21) 521 | bbox_pred = torch.cat(bbox_preds, dim=1) # tensor(141,80) 522 | 523 | # split batch bbox prediction back to each image 524 | num_proposals_per_img = tuple( 525 | len(p) for p in proposals) 526 | rois = rois.split(num_proposals_per_img, 0) 527 | cls_score = cls_score.split(num_proposals_per_img, 0) 528 | bbox_pred = bbox_pred.split(num_proposals_per_img, 0) 529 | 530 | # apply bbox post-processing to each image individually 531 | det_bboxes = [] 532 | det_labels = [] 533 | for i in range(len(proposals)): 534 | det_bbox, det_label = self.bbox_head.get_bboxes( 535 | rois[i], 536 | cls_score[i], 537 | bbox_pred[i], 538 | img_shapes[i], 539 | scale_factors[i], 540 | rescale=rescale, 541 | cfg=rcnn_test_cfg) 542 | det_bboxes.append(det_bbox) 543 | det_labels.append(det_label) 544 | return det_bboxes, det_labels 545 | -------------------------------------------------------------------------------- /fpd/query_support.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | import sys 4 | from abc import abstractmethod 5 | from typing import Dict, List, Optional, Union 6 | 7 | from mmcv.runner import auto_fp16 8 | from mmcv.utils import ConfigDict 9 | from mmdet.models.builder import (DETECTORS, build_backbone, build_head, 10 | build_neck) 11 | from mmdet.models.detectors import BaseDetector 12 | from torch import Tensor 13 | from typing_extensions import Literal 14 | 15 | 16 | @DETECTORS.register_module() 17 | class QuerySupportDetectorFPD(BaseDetector): 18 | """Base class for two-stage detectors in query-support fashion. 19 | 20 | Query-support detectors typically consisting of a region 21 | proposal network and a task-specific regression head. There are 22 | two pipelines for query and support data respectively. 23 | 24 | Args: 25 | backbone (dict): Config of the backbone for query data. 26 | neck (dict | None): Config of the neck for query data and 27 | probably for support data. Default: None. 28 | support_backbone (dict | None): Config of the backbone for 29 | support data only. If None, support and query data will 30 | share same backbone. Default: None. 31 | support_neck (dict | None): Config of the neck for support 32 | data only. Default: None. 33 | rpn_head (dict | None): Config of rpn_head. Default: None. 34 | roi_head (dict | None): Config of roi_head. Default: None. 35 | train_cfg (dict | None): Training config. Useless in CenterNet, 36 | but we keep this variable for SingleStageDetector. Default: None. 37 | test_cfg (dict | None): Testing config of CenterNet. Default: None. 38 | pretrained (str | None): model pretrained path. Default: None. 
39 | init_cfg (dict | list[dict] | None): Initialization config dict. 40 | Default: None 41 | """ 42 | 43 | def __init__(self, 44 | backbone: ConfigDict, 45 | neck: Optional[ConfigDict] = None, 46 | support_backbone: Optional[ConfigDict] = None, 47 | support_neck: Optional[ConfigDict] = None, 48 | rpn_head: Optional[ConfigDict] = None, 49 | roi_head: Optional[ConfigDict] = None, 50 | train_cfg: Optional[ConfigDict] = None, 51 | test_cfg: Optional[ConfigDict] = None, 52 | pretrained: Optional[ConfigDict] = None, 53 | init_cfg: Optional[ConfigDict] = None, 54 | post_rpn=False) -> None: 55 | super().__init__(init_cfg) 56 | backbone.pretrained = pretrained 57 | self.backbone = build_backbone(backbone) 58 | self.neck = build_neck(neck) if neck is not None else None 59 | # if `support_backbone` is None, then support and query pipeline will 60 | # share same backbone. 61 | self.support_backbone = build_backbone( 62 | support_backbone 63 | ) if support_backbone is not None else self.backbone 64 | # support neck only forward support data. 65 | self.support_neck = build_neck( 66 | support_neck) if support_neck is not None else None 67 | assert roi_head is not None, 'missing config of roi_head' 68 | # when rpn with aggregation neck, the input of rpn will consist of 69 | # query and support data. otherwise the input of rpn only 70 | # has query data. 71 | self.with_rpn = False 72 | self.rpn_with_support = False 73 | self.post_rpn = post_rpn 74 | if rpn_head is not None: 75 | # self.with_rpn = True 76 | if rpn_head.get('aggregation_layer', None) is not None: 77 | self.rpn_with_support = True 78 | rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None 79 | rpn_head_ = copy.deepcopy(rpn_head) 80 | rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn) 81 | 82 | if not self.post_rpn: 83 | self.with_rpn = True 84 | self.rpn_head = build_head( 85 | rpn_head_) 86 | 87 | if roi_head is not None: 88 | # update train and test cfg here for now 89 | rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None 90 | roi_head.update(train_cfg=rcnn_train_cfg) 91 | roi_head.update(test_cfg=test_cfg.rcnn) 92 | roi_head.pretrained = pretrained 93 | 94 | if self.post_rpn: 95 | roi_head.update(rpn_head_=rpn_head_) 96 | self.roi_head = build_head(roi_head) 97 | 98 | self.train_cfg = train_cfg 99 | self.test_cfg = test_cfg 100 | 101 | @auto_fp16(apply_to=('img',)) 102 | def extract_query_feat(self, img: Tensor) -> List[Tensor]: 103 | """Extract features of query data. 104 | 105 | Args: 106 | img (Tensor): Input images of shape (N, C, H, W). 107 | Typically these should be mean centered and std scaled. 108 | 109 | Returns: 110 | list[Tensor]: Features of support images, each item with shape 111 | (N, C, H, W). 112 | """ 113 | feats = self.backbone(img) 114 | if self.with_neck: 115 | feats = self.neck(feats) 116 | return feats 117 | 118 | def extract_feat(self, img: Tensor) -> List[Tensor]: 119 | """Extract features of query data. 120 | 121 | Args: 122 | img (Tensor): Input images of shape (N, C, H, W). 123 | Typically these should be mean centered and std scaled. 124 | 125 | Returns: 126 | list[Tensor]: Features of query images. 
127 | """ 128 | return self.extract_query_feat(img) 129 | 130 | @abstractmethod 131 | def extract_support_feat(self, img: Tensor): 132 | """Extract features of support data.""" 133 | raise NotImplementedError 134 | 135 | @auto_fp16(apply_to=('img',)) 136 | def forward(self, 137 | query_data: Optional[Dict] = None, 138 | support_data: Optional[Dict] = None, 139 | img: Optional[List[Tensor]] = None, 140 | img_metas: Optional[List[Dict]] = None, 141 | mode: Literal['train', 'model_init', 'test'] = 'train', 142 | **kwargs) -> Dict: 143 | """Calls one of (:func:`forward_train`, :func:`forward_test` and 144 | :func:`forward_model_init`) according to the `mode`. The inputs 145 | of forward function would change with the `mode`. 146 | 147 | - When `mode` is 'train', the input will be query and support data 148 | for training. 149 | 150 | - When `mode` is 'model_init', the input will be support template 151 | data at least including (img, img_metas). 152 | 153 | - When `mode` is 'test', the input will be test data at least 154 | including (img, img_metas). 155 | 156 | Args: 157 | query_data (dict): Used for :func:`forward_train`. Dict of 158 | query data and data info where each dict has: `img`, 159 | `img_metas`, `gt_bboxes`, `gt_labels`, `gt_bboxes_ignore`. 160 | Default: None. 161 | support_data (dict): Used for :func:`forward_train`. Dict of 162 | support data and data info dict where each dict has: `img`, 163 | `img_metas`, `gt_bboxes`, `gt_labels`, `gt_bboxes_ignore`. 164 | Default: None. 165 | img (list[Tensor]): Used for func:`forward_test` or 166 | :func:`forward_model_init`. List of tensors of shape 167 | (1, C, H, W). Typically these should be mean centered 168 | and std scaled. Default: None. 169 | img_metas (list[dict]): Used for func:`forward_test` or 170 | :func:`forward_model_init`. List of image info dict 171 | where each dict has: `img_shape`, `scale_factor`, `flip`, 172 | and may also contain `filename`, `ori_shape`, `pad_shape`, 173 | and `img_norm_cfg`. For details on the values of these keys, 174 | see :class:`mmdet.datasets.pipelines.Collect`. Default: None. 175 | mode (str): Indicate which function to call. Options are 'train', 176 | 'model_init' and 'test'. Default: 'train'. 177 | """ 178 | if mode == 'train': 179 | return self.forward_train(query_data, support_data, **kwargs) 180 | elif mode == 'model_init': 181 | return self.forward_model_init(img, img_metas, **kwargs) 182 | elif mode == 'test': 183 | return self.forward_test(img, img_metas, **kwargs) 184 | else: 185 | raise ValueError( 186 | f'invalid forward mode {mode}, ' 187 | f'only support `train`, `model_init` and `test` now') 188 | 189 | def train_step(self, data: Dict, optimizer: Union[object, Dict]) -> Dict: 190 | """The iteration step during training. 191 | 192 | This method defines an iteration step during training, except for the 193 | back propagation and optimizer updating, which are done in an optimizer 194 | hook. Note that in some complicated cases or models, the whole process 195 | including back propagation and optimizer updating is also defined in 196 | this method, such as GAN. For most of query-support detectors, the 197 | batch size denote the batch size of query data. 198 | 199 | Args: 200 | data (dict): The output of dataloader. 201 | optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of 202 | runner is passed to ``train_step()``. This argument is unused 203 | and reserved. 
204 | 205 | Returns: 206 | dict: It should contain at least 3 keys: ``loss``, ``log_vars``, 207 | ``num_samples``. 208 | 209 | - ``loss`` is a tensor for back propagation, which can be a 210 | weighted sum of multiple losses. 211 | - ``log_vars`` contains all the variables to be sent to the 212 | logger. 213 | - ``num_samples`` indicates the batch size (when the model is 214 | DDP, it means the batch size on each GPU), which is used for 215 | averaging the logs. 216 | """ 217 | losses = self(**data) 218 | loss, log_vars = self._parse_losses(losses) 219 | 220 | # For most of query-support detectors, the batch size denote the 221 | # batch size of query data. 222 | outputs = dict( 223 | loss=loss, 224 | log_vars=log_vars, 225 | num_samples=len(data['query_data']['img_metas'])) 226 | 227 | return outputs 228 | 229 | def val_step(self, 230 | data: Dict, 231 | optimizer: Optional[Union[object, Dict]] = None) -> Dict: 232 | """The iteration step during validation. 233 | 234 | This method shares the same signature as :func:`train_step`, but used 235 | during val epochs. Note that the evaluation after training epochs is 236 | not implemented with this method, but an evaluation hook. 237 | """ 238 | losses = self(**data) 239 | loss, log_vars = self._parse_losses(losses) 240 | 241 | # For most of query-support detectors, the batch size denote the 242 | # batch size of query data. 243 | outputs = dict( 244 | loss=loss, 245 | log_vars=log_vars, 246 | num_samples=len(data['query_data']['img_metas'])) 247 | 248 | return outputs 249 | 250 | def forward_train(self, 251 | query_data: Dict, 252 | support_data: Dict, 253 | proposals: Optional[List] = None, 254 | **kwargs) -> Dict: 255 | """Forward function for training. 256 | Args: 257 | query_data (dict): In most cases, dict of query data contains: 258 | `img`, `img_metas`, `gt_bboxes`, `gt_labels`, 259 | `gt_bboxes_ignore`. 260 | support_data (dict): In most cases, dict of support data contains: 261 | `img`, `img_metas`, `gt_bboxes`, `gt_labels`, 262 | `gt_bboxes_ignore`. 263 | proposals (list): Override rpn proposals with custom proposals. 264 | Use when `with_rpn` is False. Default: None. 
265 | 266 | Returns: 267 | dict[str, Tensor]: a dictionary of loss components 268 | """ 269 | query_img = query_data['img'] 270 | support_img = support_data['img'] 271 | query_feats = self.extract_query_feat(query_img) 272 | support_feats = self.extract_support_feat(support_img) 273 | 274 | losses = dict() 275 | if self.post_rpn: 276 | proposal_list = None 277 | if not self.post_rpn: 278 | # RPN forward and loss 279 | if self.with_rpn: 280 | proposal_cfg = self.train_cfg.get('rpn_proposal', 281 | self.test_cfg.rpn) 282 | if self.rpn_with_support: 283 | rpn_losses, proposal_list = self.rpn_head.forward_train( 284 | query_feats, 285 | support_feats, 286 | query_img_metas=query_data['img_metas'], 287 | query_gt_bboxes=query_data['gt_bboxes'], 288 | query_gt_labels=None, 289 | query_gt_bboxes_ignore=query_data.get( 290 | 'gt_bboxes_ignore', None), 291 | support_img_metas=support_data['img_metas'], 292 | support_gt_bboxes=support_data['gt_bboxes'], 293 | support_gt_labels=support_data['gt_labels'], 294 | support_gt_bboxes_ignore=support_data.get( 295 | 'gt_bboxes_ignore', None), 296 | proposal_cfg=proposal_cfg) 297 | else: 298 | rpn_losses, proposal_list = self.rpn_head.forward_train( 299 | query_feats, 300 | copy.deepcopy(query_data['img_metas']), 301 | copy.deepcopy(query_data['gt_bboxes']), 302 | gt_labels=None, 303 | gt_bboxes_ignore=copy.deepcopy( 304 | query_data.get('gt_bboxes_ignore', None)), 305 | proposal_cfg=proposal_cfg) 306 | losses.update(rpn_losses) 307 | else: 308 | proposal_list = proposals 309 | 310 | roi_losses = self.roi_head.forward_train( 311 | query_feats, 312 | support_feats, 313 | proposals=proposal_list, 314 | query_img_metas=query_data['img_metas'], 315 | query_gt_bboxes=query_data['gt_bboxes'], 316 | query_gt_labels=query_data['gt_labels'], 317 | query_gt_bboxes_ignore=query_data.get('gt_bboxes_ignore', None), 318 | support_img_metas=support_data['img_metas'], 319 | support_gt_bboxes=support_data['gt_bboxes'], 320 | support_gt_labels=support_data['gt_labels'], 321 | support_gt_bboxes_ignore=support_data.get('gt_bboxes_ignore', 322 | None), 323 | **kwargs) 324 | losses.update(roi_losses) 325 | return losses 326 | 327 | def forward_test(self, imgs, img_metas, **kwargs): 328 | """ 329 | Args: 330 | imgs (List[Tensor]): the outer list indicates test-time 331 | augmentations and inner Tensor should have a shape NxCxHxW, 332 | which contains all images in the batch. 333 | img_metas (List[List[dict]]): the outer list indicates test-time 334 | augs (multiscale, flip, etc.) and the inner list indicates 335 | images in a batch. 336 | """ 337 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: 338 | if not isinstance(var, list): 339 | raise TypeError(f'{name} must be a list, but got {type(var)}') 340 | 341 | num_augs = len(imgs) 342 | if num_augs != len(img_metas): 343 | raise ValueError(f'num of augmentations ({len(imgs)}) ' 344 | f'!= num of image meta ({len(img_metas)})') 345 | 346 | # NOTE the batched image size information may be useful, e.g. 347 | # in DETR, this is needed for the construction of masks, which is 348 | # then used for the transformer_head. 349 | for img, img_meta in zip(imgs, img_metas): 350 | batch_size = len(img_meta) 351 | for img_id in range(batch_size): 352 | img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:]) 353 | 354 | if num_augs == 1: 355 | # proposals (List[List[Tensor]]): the outer list indicates 356 | # test-time augs (multiscale, flip, etc.) and the inner list 357 | # indicates images in a batch. 
358 | # The Tensor should have a shape Px4, where P is the number of 359 | # proposals. 360 | if 'proposals' in kwargs: 361 | kwargs['proposals'] = kwargs['proposals'][0] 362 | return self.simple_test(imgs[0], img_metas[0], **kwargs) 363 | else: 364 | assert imgs[0].size(0) == 1, 'aug test does not support ' \ 365 | 'inference with batch size ' \ 366 | f'{imgs[0].size(0)}' 367 | assert 'proposals' not in kwargs 368 | return self.aug_test(imgs, img_metas, **kwargs) 369 | 370 | def simple_test(self, 371 | img: Tensor, 372 | img_metas: List[Dict], 373 | proposals: Optional[List[Tensor]] = None, 374 | rescale: bool = False): 375 | """Test without augmentation.""" 376 | raise NotImplementedError 377 | 378 | def aug_test(self, **kwargs): 379 | """Test with augmentation.""" 380 | raise NotImplementedError 381 | 382 | def forward_model_init(self, 383 | img: Tensor, 384 | img_metas: List[Dict], 385 | gt_bboxes: List[Tensor] = None, 386 | gt_labels: List[Tensor] = None, 387 | **kwargs): 388 | """Extract and save support features for model initialization.""" 389 | raise NotImplementedError 390 | 391 | @abstractmethod 392 | def model_init(self, **kwargs): 393 | """Process the saved support features for model initialization.""" 394 | raise NotImplementedError 395 | -------------------------------------------------------------------------------- /fpd/transforms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | import math 4 | from typing import Dict, List, Tuple 5 | 6 | import mmcv 7 | import numpy as np 8 | from mmdet.datasets import PIPELINES 9 | # from mmdet.datasets.pipelines import (Normalize, Pad, RandomCrop, RandomFlip, 10 | # Resize) 11 | 12 | 13 | 14 | # new 15 | @PIPELINES.register_module() 16 | class CropResizeInstanceByRatio: 17 | """Crop and resize an instance according to its bbox from the image. 18 | 19 | Args: 20 | num_context_pixels (int): Padding pixels around the instance. Default: 16. 21 | target_size (tuple[int, int]): Resize the cropped instance to the target size. 22 | Default: (320, 320). 23 | """ 24 | 25 | def __init__( 26 | self, 27 | num_context_pixels: int = 16, 28 | context_ratio: float = None, 29 | target_size: Tuple[int, int] = (320, 320) 30 | ) -> None: 31 | assert isinstance(num_context_pixels, int) 32 | assert len(target_size) == 2, 'target_size must have length 2' 33 | self.num_context_pixels = num_context_pixels 34 | self.target_size = target_size 35 | 36 | self.context_ratio = context_ratio 37 | 38 | def __call__(self, results: Dict) -> Dict: 39 | """Crop the instance from the image according to its gt bbox and 40 | resize it to the target size. 41 | 42 | Args: 43 | results (dict): Result dict from loading pipeline. 44 | 45 | Returns: 46 | dict: Cropped and resized instance results.
47 | """ 48 | img = results['img'] 49 | gt_bbox = results['gt_bboxes'] 50 | img_h, img_w = img.shape[:2] # h, w 51 | x1, y1, x2, y2 = list(map(int, gt_bbox.tolist()[0])) 52 | 53 | # new 54 | if self.context_ratio is not None: 55 | gt_w, gt_h = x2 - x1, y2 - y1 56 | delta_w = max(0, self.target_size[1] - gt_w) 57 | delta_h = max(0, self.target_size[0] - gt_h) 58 | self.num_context_pixels = (gt_w + gt_h) * self.context_ratio + 0.04 * (delta_h + delta_w) 59 | self.num_context_pixels = int(self.num_context_pixels) 60 | 61 | bbox_w = x2 - x1 62 | bbox_h = y2 - y1 63 | t_x1, t_y1, t_x2, t_y2 = 0, 0, bbox_w, bbox_h 64 | 65 | if bbox_w >= bbox_h: 66 | crop_x1 = x1 - self.num_context_pixels 67 | crop_x2 = x2 + self.num_context_pixels 68 | # t_x1 and t_x2 will change when crop context or overflow 69 | t_x1 = t_x1 + self.num_context_pixels 70 | t_x2 = t_x1 + bbox_w 71 | if crop_x1 < 0: 72 | t_x1 = t_x1 + crop_x1 73 | t_x2 = t_x1 + bbox_w 74 | crop_x1 = 0 75 | if crop_x2 > img_w: 76 | crop_x2 = img_w 77 | 78 | short_size = bbox_h 79 | long_size = crop_x2 - crop_x1 80 | y_center = int((y2 + y1) / 2) # math.ceil((y2 + y1) / 2) 81 | crop_y1 = int( 82 | y_center - 83 | (long_size / 2)) # int(y_center - math.ceil(long_size / 2)) 84 | crop_y2 = int( 85 | y_center + 86 | (long_size / 2)) # int(y_center + math.floor(long_size / 2)) 87 | 88 | # t_y1 and t_y2 will change when crop context or overflow 89 | t_y1 = t_y1 + math.ceil((long_size - short_size) / 2) 90 | t_y2 = t_y1 + bbox_h 91 | 92 | if crop_y1 < 0: 93 | t_y1 = t_y1 + crop_y1 94 | t_y2 = t_y1 + bbox_h 95 | crop_y1 = 0 96 | if crop_y2 > img_h: 97 | crop_y2 = img_h 98 | 99 | crop_short_size = crop_y2 - crop_y1 100 | crop_long_size = crop_x2 - crop_x1 101 | 102 | square = np.zeros((crop_long_size, crop_long_size, 3), 103 | dtype=np.uint8) 104 | delta = int( 105 | (crop_long_size - crop_short_size) / 106 | 2) # int(math.ceil((crop_long_size - crop_short_size) / 2)) 107 | square_y1 = delta 108 | square_y2 = delta + crop_short_size 109 | 110 | t_y1 = t_y1 + delta 111 | t_y2 = t_y2 + delta 112 | 113 | crop_box = img[crop_y1:crop_y2, crop_x1:crop_x2, :] 114 | square[square_y1:square_y2, :, :] = crop_box 115 | else: 116 | crop_y1 = y1 - self.num_context_pixels 117 | crop_y2 = y2 + self.num_context_pixels 118 | 119 | # t_y1 and t_y2 will change when crop context or overflow 120 | t_y1 = t_y1 + self.num_context_pixels 121 | t_y2 = t_y1 + bbox_h 122 | if crop_y1 < 0: 123 | t_y1 = t_y1 + crop_y1 124 | t_y2 = t_y1 + bbox_h 125 | crop_y1 = 0 126 | if crop_y2 > img_h: 127 | crop_y2 = img_h 128 | 129 | short_size = bbox_w 130 | long_size = crop_y2 - crop_y1 131 | x_center = int((x2 + x1) / 2) # math.ceil((x2 + x1) / 2) 132 | crop_x1 = int( 133 | x_center - 134 | (long_size / 2)) # int(x_center - math.ceil(long_size / 2)) 135 | crop_x2 = int( 136 | x_center + 137 | (long_size / 2)) # int(x_center + math.floor(long_size / 2)) 138 | 139 | # t_x1 and t_x2 will change when crop context or overflow 140 | t_x1 = t_x1 + math.ceil((long_size - short_size) / 2) 141 | t_x2 = t_x1 + bbox_w 142 | if crop_x1 < 0: 143 | t_x1 = t_x1 + crop_x1 144 | t_x2 = t_x1 + bbox_w 145 | crop_x1 = 0 146 | if crop_x2 > img_w: 147 | crop_x2 = img_w 148 | 149 | crop_short_size = crop_x2 - crop_x1 150 | crop_long_size = crop_y2 - crop_y1 151 | square = np.zeros((crop_long_size, crop_long_size, 3), 152 | dtype=np.uint8) 153 | delta = int( 154 | (crop_long_size - crop_short_size) / 155 | 2) # int(math.ceil((crop_long_size - crop_short_size) / 2)) 156 | square_x1 = delta 157 | square_x2 = delta + 
crop_short_size 158 | 159 | t_x1 = t_x1 + delta 160 | t_x2 = t_x2 + delta 161 | crop_box = img[crop_y1:crop_y2, crop_x1:crop_x2, :] 162 | square[:, square_x1:square_x2, :] = crop_box # this is the correct way to do the padding 163 | 164 | square = square.astype(np.float32, copy=False) 165 | square, square_scale = mmcv.imrescale( 166 | square, self.target_size, return_scale=True, backend='cv2') 167 | 168 | square = square.astype(np.uint8) 169 | 170 | t_x1 = int(t_x1 * square_scale) 171 | t_y1 = int(t_y1 * square_scale) 172 | t_x2 = int(t_x2 * square_scale) 173 | t_y2 = int(t_y2 * square_scale) 174 | results['img'] = square 175 | results['img_shape'] = square.shape 176 | results['gt_bboxes'] = np.array([[t_x1, t_y1, t_x2, 177 | t_y2]]).astype(np.float32) 178 | 179 | return results 180 | 181 | def __repr__(self) -> str: 182 | return self.__class__.__name__ + \ 183 | f'(num_context_pixels={self.num_context_pixels},' \ 184 | f' target_size={self.target_size})' 185 | -------------------------------------------------------------------------------- /fpd/utils.py: -------------------------------------------------------------------------------- 1 | """Copied from VFA (https://github.com/csuhan/VFA/).""" 2 | import numpy as np 3 | import torch 4 | import torch.nn.functional as F 5 | from PIL import Image 6 | from torchvision import transforms as trans 7 | 8 | from mmfewshot.detection.datasets.coco import COCO_SPLIT 9 | 10 | 11 | class PCB: 12 | def __init__(self, class_names, model="RN101", templates="a photo of a {}"): 13 | super().__init__() 14 | self.device = torch.cuda.current_device() 15 | 16 | # image transforms 17 | self.expand_ratio = 0.1 18 | self.trans = trans.Compose([ 19 | trans.Resize([224, 224], interpolation=3), 20 | trans.ToTensor(), 21 | trans.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 22 | 23 | # CLIP configs 24 | # import clip 25 | from CLIP import clip 26 | self.class_names = class_names 27 | self.clip, _ = clip.load(model, device=self.device) 28 | self.prompts = clip.tokenize([ 29 | templates.format(cls_name) 30 | for cls_name in self.class_names 31 | ]).to(self.device) 32 | with torch.no_grad(): 33 | text_features = self.clip.encode_text(self.prompts) 34 | self.text_features = F.normalize(text_features, dim=-1, p=2) 35 | 36 | def load_image_by_box(self, img_path, boxes): 37 | image = Image.open(img_path).convert("RGB") 38 | image_list = [] 39 | for box in boxes: 40 | x1, y1, x2, y2 = box 41 | h, w = y2 - y1, x2 - x1 42 | x1 = max(0, x1 - w * self.expand_ratio) 43 | y1 = max(0, y1 - h * self.expand_ratio) 44 | x2 = x2 + w * self.expand_ratio 45 | y2 = y2 + h * self.expand_ratio 46 | sub_image = image.crop((int(x1), int(y1), int(x2), int(y2))) 47 | sub_image = self.trans(sub_image).to(self.device) 48 | image_list.append(sub_image) 49 | return torch.stack(image_list) 50 | 51 | @torch.no_grad() 52 | def __call__(self, img_path, boxes): 53 | images = self.load_image_by_box(img_path, boxes) 54 | 55 | image_features = self.clip.encode_image(images) 56 | image_features = F.normalize(image_features, dim=-1, p=2) 57 | logit_scale = self.clip.logit_scale.exp() 58 | logits_per_image = logit_scale * image_features @ self.text_features.t() 59 | return logits_per_image.softmax(dim=-1) 60 | 61 | 62 | class TestMixins: 63 | def __init__(self): 64 | self.pcb = None 65 | 66 | def refine_test(self, results, img_metas): 67 | if not hasattr(self, 'pcb'): 68 | self.pcb = PCB(COCO_SPLIT['ALL_CLASSES'], model='ViT-B/32') 69 | # ids excluded from score refinement on COCO 70 | self.exclude_ids = [7, 9, 10, 11, 12, 13, 20,
21, 22, 23, 24, 25, 26, 27, 28, 29, 71 | 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 41, 42, 43, 44, 45, 72 | 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 61, 63, 64, 65, 73 | 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79] 74 | 75 | boxes_list, scores_list, labels_list = [], [], [] 76 | for cls_id, result in enumerate(results[0]): 77 | if len(result) == 0: 78 | continue 79 | boxes_list.append(result[:, :4]) 80 | scores_list.append(result[:, 4]) 81 | labels_list.append([cls_id] * len(result)) 82 | 83 | if len(boxes_list) == 0: 84 | return results 85 | 86 | boxes_list = np.concatenate(boxes_list, axis=0) 87 | scores_list = np.concatenate(scores_list, axis=0) 88 | labels_list = np.concatenate(labels_list, axis=0) 89 | 90 | logits = self.pcb(img_metas[0]['filename'], boxes_list) 91 | 92 | for i, prob in enumerate(logits): 93 | if labels_list[i] not in self.exclude_ids: 94 | # print('average') 95 | # print(scores_list[i], logits[i, labels_list[i]]) # single value 96 | scores_list[i] = scores_list[i] * 0.5 + logits[i, labels_list[i]] * 0.5 97 | 98 | j = 0 99 | for i in range(len(results[0])): 100 | num_dets = len(results[0][i]) 101 | if num_dets == 0: 102 | continue 103 | for k in range(num_dets): 104 | results[0][i][k, 4] = scores_list[j] 105 | j += 1 106 | 107 | return results -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | yapf==0.40.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import find_packages, setup 3 | 4 | setup( 5 | name='fpd', 6 | version='1.0', 7 | author='wangzc', 8 | url="https://github.com/wangchen1801/FPD", 9 | description="Code for 'Fine-Grained Prototypes Distillation for Few-Shot Object Detection(FPD).'", 10 | packages=find_packages(exclude=('configs', 'data', 'work_dirs')), 11 | # install_requires=['clip@git+ssh://git@github.com/openai/CLIP.git'], 12 | ) 13 | 14 | 15 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
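# Note (editor's addition): the `import fpd` below is what registers the custom
# FPD modules defined in the fpd/ package (e.g. the detector, RoI head and the
# support-branch transform) with the mmdet/mmfewshot registries, so that the
# config files can refer to them by name.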
2 | import argparse 3 | import os 4 | import warnings 5 | 6 | import mmcv 7 | import torch 8 | from mmcv import Config, DictAction 9 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 10 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, 11 | wrap_fp16_model) 12 | 13 | from mmfewshot.detection.datasets import (build_dataloader, build_dataset, 14 | get_copy_dataset_type) 15 | from mmfewshot.detection.models import build_detector 16 | from mmfewshot.utils import compat_cfg 17 | 18 | import fpd 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser( 22 | description='MMFewShot test (and eval) a model') 23 | parser.add_argument('config', help='test config file path') 24 | parser.add_argument('checkpoint', help='checkpoint file') 25 | parser.add_argument('--out', help='output result file in pickle format') 26 | parser.add_argument( 27 | '--eval', 28 | type=str, 29 | nargs='+', 30 | default='mAP', 31 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",' 32 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') 33 | parser.add_argument( 34 | '--gpu-ids', 35 | type=int, 36 | nargs='+', 37 | help='(Deprecated, please use --gpu-id) ids of gpus to use ' 38 | '(only applicable to non-distributed training)') 39 | parser.add_argument( 40 | '--gpu-id', 41 | type=int, 42 | default=0, 43 | help='id of gpu to use ' 44 | '(only applicable to non-distributed testing)') 45 | parser.add_argument('--show', action='store_true', help='show results') 46 | parser.add_argument( 47 | '--show-dir', help='directory where painted images will be saved') 48 | parser.add_argument( 49 | '--show-score-thr', 50 | type=float, 51 | default=0.3, 52 | help='score threshold (default: 0.3)') 53 | parser.add_argument( 54 | '--gpu-collect', 55 | action='store_true', 56 | help='whether to use gpu to collect results.') 57 | parser.add_argument( 58 | '--tmpdir', 59 | help='tmp directory used for collecting results from multiple ' 60 | 'workers, available when gpu-collect is not specified') 61 | parser.add_argument( 62 | '--cfg-options', 63 | nargs='+', 64 | action=DictAction, 65 | help='override some settings in the used config, the key-value pair ' 66 | 'in xxx=yyy format will be merged into config file. If the value to ' 67 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 68 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 69 | 'Note that the quotation marks are necessary and that no white space ' 70 | 'is allowed.') 71 | parser.add_argument( 72 | '--options', 73 | nargs='+', 74 | action=DictAction, 75 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 76 | 'format will be kwargs for dataset.evaluate() function (deprecate), ' 77 | 'change to --eval-options instead.') 78 | parser.add_argument( 79 | '--eval-options', 80 | nargs='+', 81 | action=DictAction, 82 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 83 | 'format will be kwargs for dataset.evaluate() function') 84 | parser.add_argument( 85 | '--launcher', 86 | choices=['none', 'pytorch', 'slurm', 'mpi'], 87 | default='none', 88 | help='job launcher') 89 | parser.add_argument('--local_rank', type=int, default=0) 90 | args = parser.parse_args() 91 | if 'LOCAL_RANK' not in os.environ: 92 | os.environ['LOCAL_RANK'] = str(args.local_rank) 93 | 94 | if args.options and args.eval_options: 95 | raise ValueError( 96 | '--options and --eval-options cannot be both ' 97 | 'specified, --options is deprecated in favor of --eval-options') 98 | if args.options: 99 | warnings.warn('--options is deprecated in favor of --eval-options') 100 | args.eval_options = args.options 101 | args.cfg_options = args.options 102 | return args 103 | 104 | 105 | def main(): 106 | args = parse_args() 107 | 108 | assert args.out or args.eval or args.show \ 109 | or args.show_dir, ( 110 | 'Please specify at least one operation (save/eval/show the ' 111 | 'results / save the results) with the argument "--out", "--eval"', 112 | '"--show" or "--show-dir"') 113 | 114 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 115 | raise ValueError('The output file must be a pkl file.') 116 | 117 | cfg = Config.fromfile(args.config) 118 | cfg = compat_cfg(cfg) 119 | 120 | if args.cfg_options is not None: 121 | cfg.merge_from_dict(args.cfg_options) 122 | 123 | # import modules from string list. 124 | if cfg.get('custom_imports', None): 125 | from mmcv.utils import import_modules_from_strings 126 | import_modules_from_strings(**cfg['custom_imports']) 127 | # set cudnn_benchmark 128 | if cfg.get('cudnn_benchmark', False): 129 | torch.backends.cudnn.benchmark = True 130 | cfg.model.pretrained = None 131 | 132 | if args.gpu_ids is not None: 133 | cfg.gpu_ids = args.gpu_ids[0:1] 134 | warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' 135 | 'Because we only support single GPU mode in ' 136 | 'non-distributed testing. Use the first GPU ' 137 | 'in `gpu_ids` now.') 138 | else: 139 | cfg.gpu_ids = [args.gpu_id] 140 | # init distributed env first, since logger depends on the dist info. 
141 | if args.launcher == 'none': 142 | distributed = False 143 | else: 144 | distributed = True 145 | init_dist(args.launcher, **cfg.dist_params) 146 | 147 | # build the dataloader 148 | dataset = build_dataset(cfg.data.test) 149 | 150 | test_dataloader_default_args = dict( 151 | samples_per_gpu=1, workers_per_gpu=2, dist=distributed, shuffle=False) 152 | # update the overall dataloader (for train, val and test) settings 153 | test_loader_cfg = { 154 | **test_dataloader_default_args, 155 | **cfg.data.get('test_dataloader', {}) 156 | } 157 | 158 | # currently only single-image testing is supported 159 | assert test_loader_cfg['samples_per_gpu'] == 1, \ 160 | 'currently only single-image testing is supported' 161 | data_loader = build_dataloader(dataset, **test_loader_cfg) 162 | 163 | # pop frozen_parameters 164 | cfg.model.pop('frozen_parameters', None) 165 | 166 | # build the model and load checkpoint 167 | cfg.model.train_cfg = None 168 | model = build_detector(cfg.model) 169 | 170 | fp16_cfg = cfg.get('fp16', None) 171 | if fp16_cfg is not None: 172 | wrap_fp16_model(model) 173 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') 174 | # old versions did not save class info in checkpoints, this workaround is 175 | # for backward compatibility 176 | if 'CLASSES' in checkpoint.get('meta', {}): 177 | model.CLASSES = checkpoint['meta']['CLASSES'] 178 | else: 179 | model.CLASSES = dataset.CLASSES 180 | 181 | # for meta-learning methods which require a support template dataset 182 | # for model initialization. 183 | if cfg.data.get('model_init', None) is not None: 184 | cfg.data.model_init.pop('copy_from_train_dataset') 185 | model_init_samples_per_gpu = cfg.data.model_init.pop( 186 | 'samples_per_gpu', 1) 187 | model_init_workers_per_gpu = cfg.data.model_init.pop( 188 | 'workers_per_gpu', 1) 189 | if cfg.data.model_init.get('ann_cfg', None) is None: 190 | assert checkpoint['meta'].get('model_init_ann_cfg', 191 | None) is not None 192 | cfg.data.model_init.type = \ 193 | get_copy_dataset_type(cfg.data.model_init.type) 194 | cfg.data.model_init.ann_cfg = \ 195 | checkpoint['meta']['model_init_ann_cfg'] 196 | model_init_dataset = build_dataset(cfg.data.model_init) 197 | # disable dist so that all ranks get the same data 198 | model_init_dataloader = build_dataloader( 199 | model_init_dataset, 200 | samples_per_gpu=model_init_samples_per_gpu, 201 | workers_per_gpu=model_init_workers_per_gpu, 202 | dist=False, 203 | shuffle=False) 204 | 205 | if not distributed: 206 | # Please use MMCV >= 1.4.4 for CPU testing!
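# Note (editor's addition): when the config provides `data.model_init`, the
# support templates are first forwarded through the detector
# (`single_gpu_model_init` / `multi_gpu_model_init` below) so that the saved
# support features are processed by `model_init()` before testing runs on the
# query images; see `forward_model_init` / `model_init` in fpd_detector.py.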
207 | model = MMDataParallel(model, device_ids=cfg.gpu_ids) 208 | show_kwargs = dict(show_score_thr=args.show_score_thr) 209 | if cfg.data.get('model_init', None) is not None: 210 | from mmfewshot.detection.apis import (single_gpu_model_init, 211 | single_gpu_test) 212 | single_gpu_model_init(model, model_init_dataloader) 213 | else: 214 | from mmdet.apis.test import single_gpu_test 215 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, 216 | **show_kwargs) 217 | else: 218 | model = MMDistributedDataParallel( 219 | model.cuda(), 220 | device_ids=[torch.cuda.current_device()], 221 | broadcast_buffers=False) 222 | if cfg.data.get('model_init', None) is not None: 223 | from mmfewshot.detection.apis import (multi_gpu_model_init, 224 | multi_gpu_test) 225 | multi_gpu_model_init(model, model_init_dataloader) 226 | else: 227 | from mmdet.apis.test import multi_gpu_test 228 | outputs = multi_gpu_test( 229 | model, 230 | data_loader, 231 | args.tmpdir, 232 | args.gpu_collect, 233 | ) 234 | 235 | rank, _ = get_dist_info() 236 | if rank == 0: 237 | if args.out: 238 | print(f'\nwriting results to {args.out}') 239 | mmcv.dump(outputs, args.out) 240 | kwargs = {} if args.eval_options is None else args.eval_options 241 | if args.eval: 242 | eval_kwargs = cfg.get('evaluation', {}).copy() 243 | # hard-code way to remove EvalHook args 244 | for key in [ 245 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 246 | 'rule' 247 | ]: 248 | eval_kwargs.pop(key, None) 249 | eval_kwargs.update(dict(metric=args.eval, **kwargs)) 250 | print(dataset.evaluate(outputs, **eval_kwargs)) 251 | 252 | 253 | if __name__ == '__main__': 254 | main() 255 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
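# Example usage (editor's sketch, not taken from the repo's docs): a base-training
# run can be launched with the config files under configs/fpd/, e.g.
#   python train.py configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_base-training.py
# or, assuming dist_train.sh follows the usual `CONFIG NUM_GPU` convention,
#   bash dist_train.sh configs/fpd/voc/split1/fpd_r101_c4_2xb4_voc-split1_base-training.py 2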
2 | import argparse 3 | import copy 4 | import os 5 | import os.path as osp 6 | import sys 7 | import time 8 | import warnings 9 | 10 | import cv2 11 | import mmcv 12 | import torch 13 | from mmcv import Config, DictAction 14 | from mmcv.runner import get_dist_info, init_dist, set_random_seed 15 | from mmcv.utils import get_git_hash 16 | from mmdet.utils import collect_env 17 | 18 | import mmfewshot # noqa: F401, F403 19 | from mmfewshot import __version__ 20 | from mmfewshot.detection.apis import train_detector 21 | from mmfewshot.detection.datasets import build_dataset 22 | from mmfewshot.detection.models import build_detector 23 | from mmfewshot.utils import get_root_logger 24 | 25 | import fpd 26 | 27 | 28 | # disable multithreading to avoid system being overloaded 29 | cv2.setNumThreads(0) 30 | os.environ['OMP_NUM_THREADS'] = '1' 31 | os.environ['MKL_NUM_THREADS'] = '1' 32 | 33 | # os.environ['CUDA_LAUNCH_BLOCKING'] = '1' # for debugging, time consuming 34 | 35 | def parse_args(): 36 | parser = argparse.ArgumentParser(description='Train a FewShot model') 37 | parser.add_argument('config', help='train config file path') 38 | parser.add_argument( 39 | '--work-dir', help='the directory to save logs and models') 40 | parser.add_argument( 41 | '--resume-from', help='the checkpoint file to resume from') 42 | parser.add_argument( 43 | '--no-validate', 44 | action='store_true', 45 | help='whether not to evaluate the checkpoint during training') 46 | group_gpus = parser.add_mutually_exclusive_group() 47 | group_gpus.add_argument( 48 | '--gpus', 49 | type=int, 50 | help='number of gpus to use ' 51 | '(only applicable to non-distributed training)') 52 | group_gpus.add_argument( 53 | '--gpu-ids', 54 | type=int, 55 | nargs='+', 56 | help='(Deprecated, please use --gpu-id) ids of gpus to use ' 57 | '(only applicable to non-distributed training)') 58 | parser.add_argument( 59 | '--gpu-id', 60 | type=int, 61 | default=0, 62 | help='id of gpu to use ' 63 | '(only applicable to non-distributed testing)') 64 | parser.add_argument('--seed', type=int, default=None, help='random seed') 65 | parser.add_argument( 66 | '--deterministic', 67 | action='store_true', 68 | help='whether to set deterministic options for CUDNN backend.') 69 | parser.add_argument( 70 | '--options', 71 | nargs='+', 72 | action=DictAction, 73 | help='override some settings in the used config, the key-value pair ' 74 | 'in xxx=yyy format will be merged into config file (deprecate), ' 75 | 'change to --cfg-options instead.') 76 | parser.add_argument( 77 | '--cfg-options', 78 | nargs='+', 79 | action=DictAction, 80 | help='override some settings in the used config, the key-value pair ' 81 | 'in xxx=yyy format will be merged into config file. If the value ' 82 | 'to be overwritten is a list, it should be like key="[a,b]" or ' 83 | 'key=a,b It also allows nested list/tuple values, e.g. 
' 84 | 'key="[(a,b),(c,d)]" Note that the quotation marks are necessary ' 85 | 'and that no white space is allowed.') 86 | parser.add_argument( 87 | '--launcher', 88 | choices=['none', 'pytorch', 'slurm', 'mpi'], 89 | default='none', 90 | help='job launcher') 91 | parser.add_argument('--local_rank', type=int, default=0) 92 | args = parser.parse_args() 93 | if 'LOCAL_RANK' not in os.environ: 94 | os.environ['LOCAL_RANK'] = str(args.local_rank) 95 | 96 | if args.options and args.cfg_options: 97 | raise ValueError( 98 | '--options and --cfg-options cannot be both ' 99 | 'specified, --options is deprecated in favor of --cfg-options') 100 | if args.options: 101 | warnings.warn('--options is deprecated in favor of --cfg-options') 102 | args.cfg_options = args.options 103 | 104 | return args 105 | 106 | 107 | def main(): 108 | args = parse_args() 109 | 110 | cfg = Config.fromfile(args.config) 111 | if args.cfg_options is not None: 112 | cfg.merge_from_dict(args.cfg_options) 113 | 114 | # import modules from string list. 115 | if cfg.get('custom_imports', None): 116 | from mmcv.utils import import_modules_from_strings 117 | import_modules_from_strings(**cfg['custom_imports']) 118 | # set cudnn_benchmark 119 | if cfg.get('cudnn_benchmark', False): 120 | torch.backends.cudnn.benchmark = True 121 | 122 | # work_dir is determined in this priority: CLI > segment in file > filename 123 | if args.work_dir is not None: 124 | # update configs according to CLI args if args.work_dir is not None 125 | cfg.work_dir = args.work_dir 126 | elif cfg.get('work_dir', None) is None: 127 | # use config filename as default work_dir if cfg.work_dir is None 128 | cfg.work_dir = osp.join('./work_dirs', 129 | osp.splitext(osp.basename(args.config))[0]) 130 | if args.resume_from is not None: 131 | cfg.resume_from = args.resume_from 132 | if args.gpus is not None: 133 | cfg.gpu_ids = range(1) 134 | warnings.warn('`--gpus` is deprecated because we only support ' 135 | 'single GPU mode in non-distributed training. ' 136 | 'Use `gpus=1` now.') 137 | if args.gpu_ids is not None: 138 | cfg.gpu_ids = args.gpu_ids[0:1] 139 | warnings.warn('`--gpu-ids` is deprecated, please use `--gpu-id`. ' 140 | 'Because we only support single GPU mode in ' 141 | 'non-distributed training. Use the first GPU ' 142 | 'in `gpu_ids` now.') 143 | if args.gpus is None and args.gpu_ids is None: 144 | cfg.gpu_ids = [args.gpu_id] 145 | 146 | # init distributed env first, since logger depends on the dist info. 
147 | if args.launcher == 'none': 148 | distributed = False 149 | rank = 0 150 | else: 151 | distributed = True 152 | init_dist(args.launcher, **cfg.dist_params) 153 | rank, world_size = get_dist_info() 154 | # re-set gpu_ids for distributed training mode 155 | cfg.gpu_ids = range(world_size) 156 | # create work_dir 157 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 158 | # dump config 159 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 160 | # init the logger before other steps 161 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 162 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 163 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 164 | 165 | # init the meta dict to record some important information such as 166 | # environment info and seed, which will be logged 167 | meta = dict() 168 | # log env info 169 | env_info_dict = collect_env() 170 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 171 | dash_line = '-' * 60 + '\n' 172 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 173 | dash_line) 174 | meta['env_info'] = env_info 175 | meta['config'] = cfg.pretty_text 176 | # log some basic info 177 | logger.info(f'Distributed training: {distributed}') 178 | logger.info(f'Config:\n{cfg.pretty_text}') 179 | 180 | # set random seeds 181 | if args.seed is not None: 182 | seed = args.seed 183 | elif cfg.seed is not None: 184 | seed = cfg.seed 185 | elif distributed: 186 | seed = 0 187 | warnings.warn(f'When using DistributedDataParallel, each rank would ' 188 | f'otherwise initialize a different random seed, causing ' 189 | f'different random behavior on each rank. In the few-shot ' 190 | f'setting, novel shots may be generated by random sampling; ' 191 | f'if the ranks do not share the same seed, each rank will ' 192 | f'sample different data, which leads to UNFAIR data usage. ' 193 | f'Therefore, the seed is set to {seed} by default.') 194 | else: 195 | seed = None 196 | 197 | if seed is not None: 198 | logger.info(f'Set random seed to {seed}, ' 199 | f'deterministic: {args.deterministic}') 200 | set_random_seed(seed, deterministic=args.deterministic) 201 | meta['seed'] = seed 202 | meta['exp_name'] = osp.basename(args.config) 203 | 204 | # build_detector will do three things: building the model, 205 | # initializing weights and freezing parameters (optional). 206 | model = build_detector(cfg.model, logger=logger) 207 | 208 | # build_dataset will do two things: building the dataset 209 | # and saving the dataset into a json file (optional).
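# Note (editor's addition): with the few-shot configs in this repo, the train
# dataset built here is an N-way K-shot dataset (configs/_base_/datasets/nway_kshot),
# which supplies the `query_data` and `support_data` batches consumed by the
# detector's `forward_train` shown earlier in fpd_detector.py.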
210 | datasets = [ 211 | build_dataset( 212 | cfg.data.train, 213 | rank=rank, 214 | work_dir=cfg.work_dir, 215 | timestamp=timestamp) 216 | ] 217 | 218 | if len(cfg.workflow) == 2: 219 | val_dataset = copy.deepcopy(cfg.data.val) 220 | val_dataset.pipeline = cfg.data.train.pipeline 221 | datasets.append(build_dataset(val_dataset)) 222 | if cfg.checkpoint_config is not None: 223 | # save mmfewshot version, config file content and class names in 224 | # checkpoints as meta data 225 | cfg.checkpoint_config.meta = dict( 226 | mmfewshot_version=__version__ + get_git_hash()[:7], 227 | CLASSES=datasets[0].CLASSES) 228 | # add an attribute for visualization convenience 229 | model.CLASSES = datasets[0].CLASSES 230 | 231 | train_detector( 232 | model, 233 | datasets, 234 | cfg, 235 | distributed=distributed, 236 | validate=(not args.no_validate), 237 | timestamp=timestamp, 238 | meta=meta) 239 | 240 | 241 | if __name__ == '__main__': 242 | main() 243 | --------------------------------------------------------------------------------