├── .gitignore ├── README.md ├── VFA.png ├── configs ├── _base_ │ ├── datasets │ │ └── nway_kshot │ │ │ ├── base_coco_ms.py │ │ │ ├── base_voc_ms.py │ │ │ ├── few_shot_coco_ms.py │ │ │ └── few_shot_voc_ms.py │ ├── default_runtime.py │ ├── models │ │ └── faster_rcnn_r50_caffe_c4.py │ └── schedules │ │ └── schedule.py └── vfa │ ├── coco │ ├── vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_coco_30shot-fine-tuning.py │ └── vfa_r101_c4_8xb4_coco_base-training.py │ ├── meta-rcnn_r50_c4.py │ ├── vfa_r101_c4.py │ └── voc │ ├── vfa_split1 │ ├── vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py │ └── vfa_r101_c4_8xb4_voc-split1_base-training.py │ ├── vfa_split2 │ ├── vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning.py │ └── vfa_r101_c4_8xb4_voc-split2_base-training.py │ └── vfa_split3 │ ├── vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning.py │ ├── vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning.py │ └── vfa_r101_c4_8xb4_voc-split3_base-training.py ├── dist_test.sh ├── dist_train.sh ├── setup.py ├── test.py ├── train.py └── vfa ├── __init__.py ├── utils.py ├── vfa_bbox_head.py ├── vfa_detector.py └── vfa_roi_head.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/en/_build/ 68 | docs/zh_cn/_build/ 69 | 70 | # PyBuilder 71 | target/ 72 | 73 | # Jupyter Notebook 74 | .ipynb_checkpoints 75 | 76 | # pyenv 77 | .python-version 78 | 79 | # celery beat schedule file 80 | celerybeat-schedule 81 | 82 | # SageMath parsed files 83 | *.sage.py 84 | 85 | # Environments 86 | .env 87 | .venv 88 | env/ 89 | venv/ 90 | ENV/ 91 | env.bak/ 92 | venv.bak/ 93 | 94 | # Spyder project settings 95 | .spyderproject 96 | .spyproject 97 | 98 | # Rope project settings 99 | .ropeproject 100 | 101 | # mkdocs documentation 102 | /site 103 | 104 | # mypy 105 | .mypy_cache/ 106 | 107 | /data/ 108 | /data 109 | .vscode 110 | .idea 111 | .DS_Store 112 | 113 | # custom 114 | *.pkl 115 | *.pkl.json 116 | *.log.json 117 | work_dirs/ 118 | mmfewshot/.mim 119 | 120 | # Pytorch 121 | *.pth 122 | *.py~ 123 | *.sh~ 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## VFA 2 | 3 | 4 | 5 | > **Few-Shot Object Detection via Variational Feature Aggregation (AAAI2023)**
6 | > [Jiaming Han](https://csuhan.com), [Yuqiang Ren](https://github.com/Anymake), [Jian Ding](https://dingjiansw101.github.io), [Ke Yan](https://scholar.google.com.hk/citations?user=vWstgn0AAAAJ), [Gui-Song Xia](http://www.captain-whu.com/xia_En.html).
7 | > [arXiv preprint](https://arxiv.org/abs/2301.13411).
8 | 
9 | Our code is based on [mmfewshot](https://github.com/open-mmlab/mmfewshot).
10 | 
11 | ### Setup
12 | 
13 | * **Installation**
14 | 
15 | Here is a from-scratch setup script.
16 | 
17 | ```bash
18 | conda create -n vfa python=3.8 -y
19 | conda activate vfa
20 | 
21 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=11.0 -c pytorch
22 | 
23 | pip install openmim
24 | mim install mmcv-full==1.3.12
25 | 
26 | # install mmclassification and mmdetection
27 | mim install mmcls==0.15.0
28 | mim install mmdet==2.16.0
29 | 
30 | # install mmfewshot
31 | mim install mmfewshot==0.1.0
32 | 
33 | # install VFA
34 | python setup.py develop
35 | 
36 | ```
37 | 
38 | * **Prepare Datasets**
39 | 
40 | Please refer to mmfewshot's [detection data preparation](https://github.com/open-mmlab/mmfewshot/blob/main/tools/data/README.md).
41 | 
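For quick orientation, the dataset configs under `configs/_base_/datasets/nway_kshot/` (reproduced later in this file) hard-code the relative paths below, so after preparation your `data/` directory should look roughly as follows. This is a sketch reconstructed from the paths those configs reference; the mmfewshot guide above is the authoritative reference.

```
data/
├── coco/                     # COCO images (img_prefix of the COCO configs)
├── few_shot_ann/
│   └── coco/annotations/     # train.json / val.json referenced by the COCO configs
└── VOCdevkit/
    ├── VOC2007/              # ImageSets/Main/trainval.txt and test.txt are used
    └── VOC2012/              # ImageSets/Main/trainval.txt is used
```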
42 | 
43 | ### Model Zoo
44 | 
45 | All pretrained models can be found in the [GitHub release](https://github.com/csuhan/VFA/releases/tag/v1.0.0).
46 | 
47 | #### Results on PASCAL VOC dataset
48 | 
49 | * **Base Training**
50 | 
51 | | Split | Base AP50 | config | ckpt |
52 | |-------|-----------|-----------------------------------------------------------------------------------|------|
53 | | 1 | 78.6 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_base-training_iter_18000.pth) |
54 | | 2 | 79.5 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_base-training_iter_18000.pth) |
55 | | 3 | 79.8 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_base-training_iter_18000.pth) |
56 | 
57 | * **Few Shot Fine-tuning**
58 | 
59 | | Split | Shot | nAP50 | config | ckpt |
60 | |-------|------|-------|----------------------------------------------------------------------------------------|----------|
61 | | 1 | 1 | 57.5 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning_iter_400.pth) |
62 | | 1 | 2 | 65.0 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning_iter_800.pth) |
63 | | 1 | 3 | 64.3 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning_iter_1200.pth) |
64 | | 1 | 5 | 67.1 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning_iter_1600.pth) |
65 | | 1 | 10 | 67.4 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning_iter_2000.pth) |
66 | | 2 | 1 | 40.8 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning_iter_400.pth) |
67 | | 2 | 2 | 45.9 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning_iter_800.pth) |
68 | | 2 | 3 | 51.1 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning_iter_1200.pth) |
69 | | 2 | 5 | 51.8 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning_iter_1600.pth) |
70 | | 2 | 10 | 51.8 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning_iter_2000.pth) |
71 | | 3 | 1 | 49.0 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning_iter_400.pth) |
72 | | 3 | 2 | 54.9 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning_iter_800.pth) |
73 | | 3 | 3 | 56.6 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning_iter_1200.pth) |
74 | | 3 | 5 | 59.0 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning_iter_1600.pth) |
75 | | 3 | 10 | 58.5 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning_iter_2000.pth) |
76 | 
77 | 
78 | #### Results on COCO dataset
79 | 
80 | * **Base Training**
81 | 
82 | | Base mAP | config | ckpt |
83 | |----------|-------------------------------------------------------------------|----------|
84 | | 36.0 | [config](configs/vfa/coco/vfa_r101_c4_8xb4_coco_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_coco_base-training_iter_110000.pth) |
85 | 
86 | * **Few Shot Fine-tuning**
87 | 
88 | | Shot | nAP | config | ckpt |
89 | |------|------|------------------------------------------------------------------------|----------|
90 | | 10 | 16.8 | [config](configs/vfa/coco/vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_coco_10shot-fine-tuning_iter_10000.pth) |
91 | | 30 | 19.5 | [config](configs/vfa/coco/vfa_r101_c4_8xb4_coco_30shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_coco_30shot-fine-tuning_iter_20000.pth) |
92 | 
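To sanity-check a downloaded checkpoint against the tables above without retraining, you can evaluate it directly with `test.py` (usage is detailed in the Train and Test section below). A minimal sketch, assuming the VOC split1 1-shot checkpoint from the table is downloaded into the working directory:

```bash
# Evaluate the released VOC split1 1-shot checkpoint (nAP50 should be close to 57.5).
wget https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning_iter_400.pth
python test.py configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py \
    vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning_iter_400.pth \
    --eval mAP
```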
93 | 
94 | ### Train and Test
95 | 
96 | * **Testing**
97 | 
98 | ```bash
99 | # single-gpu test
100 | python test.py ${CONFIG} ${CHECKPOINT} --eval mAP|bbox
101 | 
102 | # multi-gpu test
103 | bash dist_test.sh ${CONFIG} ${CHECKPOINT} ${NUM_GPU} --eval mAP|bbox
104 | ```
105 | 
106 | Use `--eval mAP` for VOC and `--eval bbox` for COCO. For example:
107 | * To test VFA on VOC split1 1-shot with a single GPU, run:
108 | ```bash
109 | python test.py configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py \
110 |     work_dirs/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning/iter_400.pth \
111 |     --eval mAP
112 | ```
113 | 
114 | * To test VFA on COCO 10-shot with 8 GPUs, run:
115 | ```bash
116 | bash dist_test.sh configs/vfa/coco/vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py \
117 |     work_dirs/vfa_r101_c4_8xb4_coco_10shot-fine-tuning/iter_10000.pth \
118 |     8 --eval bbox
119 | ```
120 | 
121 | * **Training**
122 | 
123 | ```bash
124 | # single-gpu training
125 | python train.py ${CONFIG}
126 | 
127 | # multi-gpu training
128 | bash dist_train.sh ${CONFIG} ${NUM_GPU}
129 | ```
130 | 
131 | For example, to train VFA on VOC:
132 | ```bash
133 | # Stage I: base training on each split.
134 | for split in 1 2 3; do bash dist_train.sh configs/vfa/voc/vfa_split${split}/vfa_r101_c4_8xb4_voc-split${split}_base-training.py 8; done
135 | 
136 | # Stage II: few-shot fine-tuning on all splits and shots.
137 | voc_config_dir=configs/vfa/voc/
138 | for split in 1 2 3; do
139 |     for shot in 1 2 3 5 10; do
140 |         config_path=${voc_config_dir}/vfa_split${split}/vfa_r101_c4_8xb4_voc-split${split}_${shot}shot-fine-tuning.py
141 |         echo $config_path
142 |         bash dist_train.sh $config_path 8
143 |     done
144 | done
145 | ```
146 | 
147 | **Note**: All our configs and models are trained with 8 GPUs. If you train with fewer or more GPUs, adjust the learning rate or the per-GPU batch size accordingly (e.g., following the linear scaling rule, halve the learning rate when training on 4 GPUs instead of 8).
148 | 
149 | 
150 | ### Citation
151 | 
152 | If you find our work useful for your research, please consider citing:
153 | 
154 | ```BibTeX
155 | @InProceedings{han2023vfa,
156 |   title = {Few-Shot Object Detection via Variational Feature Aggregation},
157 |   author = {Han, Jiaming and Ren, Yuqiang and Ding, Jian and Yan, Ke and Xia, Gui-Song},
158 |   booktitle = {Proceedings of the 37th AAAI Conference on Artificial Intelligence (AAAI-23)},
159 |   year = {2023}
160 | }
161 | ```
--------------------------------------------------------------------------------
/VFA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuhan/VFA/e35411eb22b4fc48b524debe58dc7c09be2bf9a6/VFA.png
--------------------------------------------------------------------------------
/configs/_base_/datasets/nway_kshot/base_coco_ms.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | img_norm_cfg = dict(
3 |     mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
4 | train_multi_pipelines = dict(
5 |     query=[
6 |         dict(type='LoadImageFromFile'),
7 |         dict(type='LoadAnnotations', with_bbox=True),
8 |         dict(
9 |             type='Resize',
10 |             img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
11 |                        (1333, 768), (1333, 800)],
12 |             keep_ratio=True,
13 |             multiscale_mode='value'),
14 |         dict(type='RandomFlip', flip_ratio=0.5),
15 |         dict(type='Normalize', **img_norm_cfg),
16 |         dict(type='Pad', size_divisor=32),
17 |         dict(type='DefaultFormatBundle'),
18 |         dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
19 |     ],
20 |     support=[
21 |         dict(type='LoadImageFromFile'),
22 |         dict(type='LoadAnnotations', with_bbox=True),
23 |         dict(type='Normalize', **img_norm_cfg),
24 |         dict(type='GenerateMask', target_size=(224, 224)),
25 |         dict(type='RandomFlip', flip_ratio=0.0),
26 |         dict(type='DefaultFormatBundle'),
27 |         dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
28 |     ])
29 | test_pipeline = [
30 |     dict(type='LoadImageFromFile'),
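    # mmdet's test-time augmentation wrapper, configured here for plain
    # single-scale testing: one scale (1333, 800) and flip=False.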
31 | dict( 32 | type='MultiScaleFlipAug', 33 | img_scale=(1333, 800), 34 | flip=False, 35 | transforms=[ 36 | dict(type='Resize', keep_ratio=True), 37 | dict(type='RandomFlip'), 38 | dict(type='Normalize', **img_norm_cfg), 39 | dict(type='Pad', size_divisor=32), 40 | dict(type='ImageToTensor', keys=['img']), 41 | dict(type='Collect', keys=['img']) 42 | ]) 43 | ] 44 | # classes splits are predefined in FewShotCocoDataset 45 | data_root = 'data/coco/' 46 | data = dict( 47 | samples_per_gpu=2, 48 | workers_per_gpu=2, 49 | train=dict( 50 | type='NWayKShotDataset', 51 | num_support_ways=60, 52 | num_support_shots=1, 53 | one_support_shot_per_image=True, 54 | num_used_support_shots=200, 55 | save_dataset=False, 56 | dataset=dict( 57 | type='FewShotCocoDataset', 58 | ann_cfg=[ 59 | dict( 60 | type='ann_file', 61 | ann_file='data/few_shot_ann/coco/annotations/train.json') 62 | ], 63 | img_prefix=data_root, 64 | multi_pipelines=train_multi_pipelines, 65 | classes='BASE_CLASSES', 66 | instance_wise=False, 67 | dataset_name='query_support_dataset')), 68 | val=dict( 69 | type='FewShotCocoDataset', 70 | ann_cfg=[ 71 | dict( 72 | type='ann_file', 73 | ann_file='data/few_shot_ann/coco/annotations/val.json') 74 | ], 75 | img_prefix=data_root, 76 | pipeline=test_pipeline, 77 | classes='BASE_CLASSES'), 78 | test=dict( 79 | type='FewShotCocoDataset', 80 | ann_cfg=[ 81 | dict( 82 | type='ann_file', 83 | ann_file='data/few_shot_ann/coco/annotations/val.json') 84 | ], 85 | img_prefix=data_root, 86 | pipeline=test_pipeline, 87 | test_mode=True, 88 | classes='BASE_CLASSES'), 89 | model_init=dict( 90 | copy_from_train_dataset=True, 91 | samples_per_gpu=16, 92 | workers_per_gpu=1, 93 | type='FewShotCocoDataset', 94 | ann_cfg=None, 95 | img_prefix=data_root, 96 | pipeline=train_multi_pipelines['support'], 97 | instance_wise=True, 98 | classes='BASE_CLASSES', 99 | dataset_name='model_init_dataset')) 100 | evaluation = dict(interval=20000, metric='bbox', classwise=True) 101 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nway_kshot/base_voc_ms.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | img_norm_cfg = dict( 3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 4 | train_multi_pipelines = dict( 5 | query=[ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict( 9 | type='Resize', 10 | img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), 11 | (1333, 608), (1333, 640), (1333, 672), (1333, 704), 12 | (1333, 736), (1333, 768), (1333, 800)], 13 | keep_ratio=True, 14 | multiscale_mode='value'), 15 | dict(type='RandomFlip', flip_ratio=0.5), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='Pad', size_divisor=32), 18 | dict(type='DefaultFormatBundle'), 19 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 20 | ], 21 | support=[ 22 | dict(type='LoadImageFromFile'), 23 | dict(type='LoadAnnotations', with_bbox=True), 24 | dict(type='Normalize', **img_norm_cfg), 25 | dict(type='GenerateMask', target_size=(224, 224)), 26 | dict(type='RandomFlip', flip_ratio=0.0), 27 | dict(type='DefaultFormatBundle'), 28 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 29 | ]) 30 | test_pipeline = [ 31 | dict(type='LoadImageFromFile'), 32 | dict( 33 | type='MultiScaleFlipAug', 34 | img_scale=(1333, 800), 35 | flip=False, 36 | transforms=[ 37 | dict(type='Resize', keep_ratio=True), 38 | dict(type='RandomFlip'), 39 | 
dict(type='Normalize', **img_norm_cfg), 40 | dict(type='Pad', size_divisor=32), 41 | dict(type='ImageToTensor', keys=['img']), 42 | dict(type='Collect', keys=['img']) 43 | ]) 44 | ] 45 | # classes splits are predefined in FewShotVOCDataset 46 | data_root = 'data/VOCdevkit/' 47 | data = dict( 48 | samples_per_gpu=4, 49 | workers_per_gpu=4, 50 | train=dict( 51 | type='NWayKShotDataset', 52 | num_support_ways=15, 53 | num_support_shots=1, 54 | one_support_shot_per_image=True, 55 | num_used_support_shots=200, 56 | save_dataset=False, 57 | dataset=dict( 58 | type='FewShotVOCDataset', 59 | ann_cfg=[ 60 | dict( 61 | type='ann_file', 62 | ann_file=data_root + 63 | 'VOC2007/ImageSets/Main/trainval.txt'), 64 | dict( 65 | type='ann_file', 66 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt') 67 | ], 68 | img_prefix=data_root, 69 | multi_pipelines=train_multi_pipelines, 70 | classes=None, 71 | use_difficult=True, 72 | instance_wise=False, 73 | dataset_name='query_dataset'), 74 | support_dataset=dict( 75 | type='FewShotVOCDataset', 76 | ann_cfg=[ 77 | dict( 78 | type='ann_file', 79 | ann_file=data_root + 80 | 'VOC2007/ImageSets/Main/trainval.txt'), 81 | dict( 82 | type='ann_file', 83 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt') 84 | ], 85 | img_prefix=data_root, 86 | multi_pipelines=train_multi_pipelines, 87 | classes=None, 88 | use_difficult=False, 89 | instance_wise=False, 90 | dataset_name='support_dataset')), 91 | val=dict( 92 | type='FewShotVOCDataset', 93 | ann_cfg=[ 94 | dict( 95 | type='ann_file', 96 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt') 97 | ], 98 | img_prefix=data_root, 99 | pipeline=test_pipeline, 100 | classes=None), 101 | test=dict( 102 | type='FewShotVOCDataset', 103 | ann_cfg=[ 104 | dict( 105 | type='ann_file', 106 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt') 107 | ], 108 | img_prefix=data_root, 109 | pipeline=test_pipeline, 110 | test_mode=True, 111 | classes=None), 112 | model_init=dict( 113 | copy_from_train_dataset=True, 114 | samples_per_gpu=16, 115 | workers_per_gpu=1, 116 | type='FewShotVOCDataset', 117 | ann_cfg=None, 118 | img_prefix=data_root, 119 | pipeline=train_multi_pipelines['support'], 120 | use_difficult=False, 121 | instance_wise=True, 122 | classes=None, 123 | dataset_name='model_init_dataset')) 124 | evaluation = dict(interval=5000, metric='mAP') 125 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nway_kshot/few_shot_coco_ms.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | img_norm_cfg = dict( 3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 4 | train_multi_pipelines = dict( 5 | query=[ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict( 9 | type='Resize', 10 | img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736), 11 | (1333, 768), (1333, 800)], 12 | keep_ratio=True, 13 | multiscale_mode='value'), 14 | dict(type='RandomFlip', flip_ratio=0.5), 15 | dict(type='Normalize', **img_norm_cfg), 16 | dict(type='Pad', size_divisor=32), 17 | dict(type='DefaultFormatBundle'), 18 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 19 | ], 20 | support=[ 21 | dict(type='LoadImageFromFile'), 22 | dict(type='LoadAnnotations', with_bbox=True), 23 | dict(type='Normalize', **img_norm_cfg), 24 | dict(type='GenerateMask', target_size=(224, 224)), 25 | dict(type='RandomFlip', flip_ratio=0.0), 26 | 
dict(type='DefaultFormatBundle'), 27 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 28 | ]) 29 | test_pipeline = [ 30 | dict(type='LoadImageFromFile'), 31 | dict( 32 | type='MultiScaleFlipAug', 33 | img_scale=(1333, 800), 34 | flip=False, 35 | transforms=[ 36 | dict(type='Resize', keep_ratio=True), 37 | dict(type='RandomFlip'), 38 | dict(type='Normalize', **img_norm_cfg), 39 | dict(type='Pad', size_divisor=32), 40 | dict(type='ImageToTensor', keys=['img']), 41 | dict(type='Collect', keys=['img']) 42 | ]) 43 | ] 44 | # classes splits are predefined in FewShotCocoDataset 45 | data_root = 'data/coco/' 46 | data = dict( 47 | samples_per_gpu=2, 48 | workers_per_gpu=2, 49 | train=dict( 50 | type='NWayKShotDataset', 51 | num_support_ways=80, 52 | num_support_shots=1, 53 | one_support_shot_per_image=False, 54 | num_used_support_shots=30, 55 | save_dataset=True, 56 | dataset=dict( 57 | type='FewShotCocoDataset', 58 | ann_cfg=[ 59 | dict( 60 | type='ann_file', 61 | ann_file='data/few_shot_ann/coco/annotations/train.json') 62 | ], 63 | img_prefix=data_root, 64 | multi_pipelines=train_multi_pipelines, 65 | classes='ALL_CLASSES', 66 | instance_wise=False, 67 | dataset_name='query_support_dataset')), 68 | val=dict( 69 | type='FewShotCocoDataset', 70 | ann_cfg=[ 71 | dict( 72 | type='ann_file', 73 | ann_file='data/few_shot_ann/coco/annotations/val.json') 74 | ], 75 | img_prefix=data_root, 76 | pipeline=test_pipeline, 77 | classes='ALL_CLASSES'), 78 | test=dict( 79 | type='FewShotCocoDataset', 80 | ann_cfg=[ 81 | dict( 82 | type='ann_file', 83 | ann_file='data/few_shot_ann/coco/annotations/val.json') 84 | ], 85 | img_prefix=data_root, 86 | pipeline=test_pipeline, 87 | test_mode=True, 88 | classes='ALL_CLASSES'), 89 | model_init=dict( 90 | copy_from_train_dataset=True, 91 | samples_per_gpu=16, 92 | workers_per_gpu=1, 93 | type='FewShotCocoDataset', 94 | ann_cfg=None, 95 | img_prefix=data_root, 96 | pipeline=train_multi_pipelines['support'], 97 | instance_wise=True, 98 | classes='ALL_CLASSES', 99 | dataset_name='model_init_dataset')) 100 | evaluation = dict( 101 | interval=3000, 102 | metric='bbox', 103 | classwise=True, 104 | class_splits=['BASE_CLASSES', 'NOVEL_CLASSES']) 105 | -------------------------------------------------------------------------------- /configs/_base_/datasets/nway_kshot/few_shot_voc_ms.py: -------------------------------------------------------------------------------- 1 | # dataset settings 2 | img_norm_cfg = dict( 3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False) 4 | train_multi_pipelines = dict( 5 | query=[ 6 | dict(type='LoadImageFromFile'), 7 | dict(type='LoadAnnotations', with_bbox=True), 8 | dict( 9 | type='Resize', 10 | img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576), 11 | (1333, 608), (1333, 640), (1333, 672), (1333, 704), 12 | (1333, 736), (1333, 768), (1333, 800)], 13 | keep_ratio=True, 14 | multiscale_mode='value'), 15 | dict(type='RandomFlip', flip_ratio=0.5), 16 | dict(type='Normalize', **img_norm_cfg), 17 | dict(type='Pad', size_divisor=32), 18 | dict(type='DefaultFormatBundle'), 19 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 20 | ], 21 | support=[ 22 | dict(type='LoadImageFromFile'), 23 | dict(type='LoadAnnotations', with_bbox=True), 24 | dict(type='Normalize', **img_norm_cfg), 25 | dict(type='GenerateMask', target_size=(224, 224)), 26 | dict(type='RandomFlip', flip_ratio=0.0), 27 | dict(type='DefaultFormatBundle'), 28 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']) 29 | ]) 30 | 
test_pipeline = [ 31 | dict(type='LoadImageFromFile'), 32 | dict( 33 | type='MultiScaleFlipAug', 34 | img_scale=(1333, 800), 35 | flip=False, 36 | transforms=[ 37 | dict(type='Resize', keep_ratio=True), 38 | dict(type='RandomFlip'), 39 | dict(type='Normalize', **img_norm_cfg), 40 | dict(type='Pad', size_divisor=32), 41 | dict(type='ImageToTensor', keys=['img']), 42 | dict(type='Collect', keys=['img']) 43 | ]) 44 | ] 45 | # classes splits are predefined in FewShotVOCDataset 46 | data_root = 'data/VOCdevkit/' 47 | data = dict( 48 | samples_per_gpu=4, 49 | workers_per_gpu=4, 50 | train=dict( 51 | type='NWayKShotDataset', 52 | num_support_ways=20, 53 | num_support_shots=1, 54 | one_support_shot_per_image=False, 55 | num_used_support_shots=30, 56 | save_dataset=True, 57 | dataset=dict( 58 | type='FewShotVOCDataset', 59 | ann_cfg=[ 60 | dict( 61 | type='ann_file', 62 | ann_file=data_root + 63 | 'VOC2007/ImageSets/Main/trainval.txt'), 64 | dict( 65 | type='ann_file', 66 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt') 67 | ], 68 | img_prefix=data_root, 69 | multi_pipelines=train_multi_pipelines, 70 | classes=None, 71 | use_difficult=False, 72 | instance_wise=False, 73 | dataset_name='query_support_dataset')), 74 | val=dict( 75 | type='FewShotVOCDataset', 76 | ann_cfg=[ 77 | dict( 78 | type='ann_file', 79 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt') 80 | ], 81 | img_prefix=data_root, 82 | pipeline=test_pipeline, 83 | classes=None), 84 | test=dict( 85 | type='FewShotVOCDataset', 86 | ann_cfg=[ 87 | dict( 88 | type='ann_file', 89 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt') 90 | ], 91 | img_prefix=data_root, 92 | pipeline=test_pipeline, 93 | test_mode=True, 94 | classes=None), 95 | model_init=dict( 96 | copy_from_train_dataset=True, 97 | samples_per_gpu=16, 98 | workers_per_gpu=1, 99 | type='FewShotVOCDataset', 100 | ann_cfg=None, 101 | img_prefix=data_root, 102 | pipeline=train_multi_pipelines['support'], 103 | use_difficult=False, 104 | instance_wise=True, 105 | num_novel_shots=None, 106 | classes=None, 107 | dataset_name='model_init_dataset')) 108 | evaluation = dict(interval=3000, metric='mAP', class_splits=None) 109 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | checkpoint_config = dict(interval=5000) 2 | # yapf:disable 3 | log_config = dict( 4 | interval=50, 5 | hooks=[ 6 | dict(type='TextLoggerHook'), 7 | # dict(type='TensorboardLoggerHook') 8 | ]) 9 | # yapf:enable 10 | custom_hooks = [dict(type='NumClassCheckHook')] 11 | 12 | dist_params = dict(backend='nccl') 13 | log_level = 'INFO' 14 | load_from = None 15 | resume_from = None 16 | workflow = [('train', 1)] 17 | use_infinite_sampler = True 18 | # a magical seed works well in most cases for this repo!!! 
19 | # using different seeds might raise some issues about reproducibility 20 | seed = 42 21 | -------------------------------------------------------------------------------- /configs/_base_/models/faster_rcnn_r50_caffe_c4.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | norm_cfg = dict(type='BN', requires_grad=False) 3 | pretrained = 'open-mmlab://detectron2/resnet50_caffe' 4 | model = dict( 5 | type='FasterRCNN', 6 | pretrained=pretrained, 7 | backbone=dict( 8 | type='ResNet', 9 | depth=50, 10 | num_stages=3, 11 | strides=(1, 2, 2), 12 | dilations=(1, 1, 1), 13 | out_indices=(2, ), 14 | frozen_stages=1, 15 | norm_cfg=norm_cfg, 16 | norm_eval=True, 17 | style='caffe'), 18 | rpn_head=dict( 19 | type='RPNHead', 20 | in_channels=1024, 21 | feat_channels=1024, 22 | anchor_generator=dict( 23 | type='AnchorGenerator', 24 | scales=[2, 4, 8, 16, 32], 25 | ratios=[0.5, 1.0, 2.0], 26 | scale_major=False, 27 | strides=[16]), 28 | bbox_coder=dict( 29 | type='DeltaXYWHBBoxCoder', 30 | target_means=[.0, .0, .0, .0], 31 | target_stds=[1.0, 1.0, 1.0, 1.0]), 32 | loss_cls=dict( 33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0), 34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)), 35 | roi_head=dict( 36 | type='StandardRoIHead', 37 | shared_head=dict( 38 | type='ResLayer', 39 | pretrained=pretrained, 40 | depth=50, 41 | stage=3, 42 | stride=2, 43 | dilation=1, 44 | style='caffe', 45 | norm_cfg=norm_cfg, 46 | norm_eval=True), 47 | bbox_roi_extractor=dict( 48 | type='SingleRoIExtractor', 49 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0), 50 | out_channels=1024, 51 | featmap_strides=[16]), 52 | bbox_head=dict( 53 | type='BBoxHead', 54 | with_avg_pool=True, 55 | roi_feat_size=7, 56 | in_channels=2048, 57 | num_classes=80, 58 | bbox_coder=dict( 59 | type='DeltaXYWHBBoxCoder', 60 | target_means=[0., 0., 0., 0.], 61 | target_stds=[0.1, 0.1, 0.2, 0.2]), 62 | reg_class_agnostic=False, 63 | loss_cls=dict( 64 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 65 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))), 66 | # model training and testing settings 67 | train_cfg=dict( 68 | rpn=dict( 69 | assigner=dict( 70 | type='MaxIoUAssigner', 71 | pos_iou_thr=0.7, 72 | neg_iou_thr=0.3, 73 | min_pos_iou=0.3, 74 | match_low_quality=True, 75 | ignore_iof_thr=-1), 76 | sampler=dict( 77 | type='RandomSampler', 78 | num=256, 79 | pos_fraction=0.5, 80 | neg_pos_ub=-1, 81 | add_gt_as_proposals=False), 82 | allowed_border=0, 83 | pos_weight=-1, 84 | debug=False), 85 | rpn_proposal=dict( 86 | nms_pre=12000, 87 | max_per_img=2000, 88 | nms=dict(type='nms', iou_threshold=0.7), 89 | min_bbox_size=0), 90 | rcnn=dict( 91 | assigner=dict( 92 | type='MaxIoUAssigner', 93 | pos_iou_thr=0.5, 94 | neg_iou_thr=0.5, 95 | min_pos_iou=0.5, 96 | match_low_quality=False, 97 | ignore_iof_thr=-1), 98 | sampler=dict( 99 | type='RandomSampler', 100 | num=512, 101 | pos_fraction=0.25, 102 | neg_pos_ub=-1, 103 | add_gt_as_proposals=True), 104 | pos_weight=-1, 105 | debug=False)), 106 | test_cfg=dict( 107 | rpn=dict( 108 | nms_pre=6000, 109 | max_per_img=1000, 110 | nms=dict(type='nms', iou_threshold=0.7), 111 | min_bbox_size=0), 112 | rcnn=dict( 113 | score_thr=0.05, 114 | nms=dict(type='nms', iou_threshold=0.5), 115 | max_per_img=100))) 116 | -------------------------------------------------------------------------------- /configs/_base_/schedules/schedule.py: -------------------------------------------------------------------------------- 1 | 
# optimizer 2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001) 3 | optimizer_config = dict(grad_clip=None) 4 | # learning policy 5 | lr_config = dict( 6 | policy='step', 7 | warmup='linear', 8 | warmup_iters=500, 9 | warmup_ratio=0.001, 10 | step=[60000, 80000]) 11 | runner = dict(type='IterBasedRunner', max_iters=90000) 12 | -------------------------------------------------------------------------------- /configs/vfa/coco/vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/datasets/nway_kshot/few_shot_coco_ms.py', 3 | '../../_base_/schedules/schedule.py', '../vfa_r101_c4.py', 4 | '../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotCocoDataset 7 | # FewShotCocoDefaultDataset predefine ann_cfg for model reproducibility 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | num_used_support_shots=10, 12 | dataset=dict( 13 | type='FewShotCocoDefaultDataset', 14 | ann_cfg=[dict(method='MetaRCNN', setting='10SHOT')], 15 | num_novel_shots=10, 16 | num_base_shots=10, 17 | )), 18 | model_init=dict(num_novel_shots=10, num_base_shots=10)) 19 | evaluation = dict(interval=10000) 20 | checkpoint_config = dict(interval=10000) 21 | optimizer = dict(lr=0.001) 22 | lr_config = dict(warmup=None, step=[10000]) 23 | runner = dict(max_iters=10000) 24 | # load_from = 'path of base training model' 25 | load_from = 'work_dirs/vfa_r101_c4_8xb4_coco_base-training/latest.pth' 26 | # model settings 27 | model = dict( 28 | roi_head=dict(bbox_head=dict(num_classes=80, num_meta_classes=80)), 29 | with_refine=True, 30 | frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head.rpn_conv', 32 | ]) 33 | 34 | # iter 10000 35 | # OrderedDict([('BASE_CLASSES bbox_mAP', 0.309), ('BASE_CLASSES bbox_mAP_50', 0.509), ('BASE_CLASSES bbox_mAP_75', 0.33), ('BASE_CLASSES bbox_mAP_s', 0.165), ('BASE_CLASSES bbox_mAP_m', 0.351), ('BASE_CLASSES bbox_mAP_l', 0.442), ('BASE_CLASSES bbox_mAP_copypaste', '0.309 0.509 0.330 0.165 0.351 0.442'), ('NOVEL_CLASSES bbox_mAP', 0.168), ('NOVEL_CLASSES bbox_mAP_50', 0.363), ('NOVEL_CLASSES bbox_mAP_75', 0.136), ('NOVEL_CLASSES bbox_mAP_s', 0.074), ('NOVEL_CLASSES bbox_mAP_m', 0.176), ('NOVEL_CLASSES bbox_mAP_l', 0.253), ('NOVEL_CLASSES bbox_mAP_copypaste', '0.168 0.363 0.136 0.074 0.176 0.253'), ('bbox_mAP', 0.273), ('bbox_mAP_50', 0.473), ('bbox_mAP_75', 0.282), ('bbox_mAP_s', 0.142), ('bbox_mAP_m', 0.308), ('bbox_mAP_l', 0.395), ('bbox_mAP_copypaste', '0.273 0.473 0.282 0.142 0.308 0.395')]) -------------------------------------------------------------------------------- /configs/vfa/coco/vfa_r101_c4_8xb4_coco_30shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/datasets/nway_kshot/few_shot_coco_ms.py', 3 | '../../_base_/schedules/schedule.py', '../vfa_r101_c4.py', 4 | '../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotCocoDataset 7 | # FewShotCocoDefaultDataset predefine ann_cfg for model reproducibility 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | num_used_support_shots=30, 12 | dataset=dict( 13 | type='FewShotCocoDefaultDataset', 14 | ann_cfg=[dict(method='MetaRCNN', setting='30SHOT')], 15 | num_novel_shots=30, 16 | num_base_shots=30, 17 | )), 18 | model_init=dict(num_novel_shots=30, num_base_shots=30)) 19 | evaluation = dict(interval=20000) 20 | checkpoint_config = 
dict(interval=20000) 21 | optimizer = dict(lr=0.001) 22 | lr_config = dict(warmup=None, step=[20000]) 23 | runner = dict(max_iters=20000) 24 | # load_from = 'path of base training model' 25 | load_from = 'work_dirs/vfa_r101_c4_8xb4_coco_base-training/latest.pth' 26 | # model settings 27 | model = dict( 28 | roi_head=dict(bbox_head=dict(num_classes=80, num_meta_classes=80)), 29 | with_refine=True, 30 | frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head.rpn_conv', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/coco/vfa_r101_c4_8xb4_coco_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../_base_/datasets/nway_kshot/base_coco_ms.py', 3 | '../../_base_/schedules/schedule.py', '../vfa_r101_c4.py', 4 | '../../_base_/default_runtime.py' 5 | ] 6 | lr_config = dict(warmup_iters=1000, step=[85000, 100000]) 7 | evaluation = dict(interval=10000) 8 | checkpoint_config = dict(interval=10000) 9 | runner = dict(max_iters=110000) 10 | optimizer = dict(lr=0.005) 11 | # model settings 12 | model = dict(roi_head=dict(bbox_head=dict(num_classes=60, num_meta_classes=60))) -------------------------------------------------------------------------------- /configs/vfa/meta-rcnn_r50_c4.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/faster_rcnn_r50_caffe_c4.py', 3 | ] 4 | # model settings 5 | model = dict( 6 | type='MetaRCNN', 7 | backbone=dict(type='ResNetWithMetaConv', frozen_stages=2), 8 | rpn_head=dict( 9 | feat_channels=512, loss_cls=dict(use_sigmoid=False, loss_weight=1.0)), 10 | roi_head=dict( 11 | type='MetaRCNNRoIHead', 12 | shared_head=dict(type='MetaRCNNResLayer'), 13 | bbox_head=dict( 14 | type='MetaBBoxHead', 15 | with_avg_pool=False, 16 | in_channels=2048, 17 | roi_feat_size=1, 18 | num_classes=80, 19 | num_meta_classes=80, 20 | meta_cls_in_channels=2048, 21 | with_meta_cls_loss=True, 22 | loss_meta=dict( 23 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), 24 | loss_bbox=dict(type='SmoothL1Loss', loss_weight=1.0)), 25 | aggregation_layer=dict( 26 | type='AggregationLayer', 27 | aggregator_cfgs=[ 28 | dict( 29 | type='DotProductAggregator', 30 | in_channels=2048, 31 | with_fc=False) 32 | ])), 33 | train_cfg=dict( 34 | rpn=dict( 35 | assigner=dict( 36 | type='MaxIoUAssigner', 37 | pos_iou_thr=0.7, 38 | neg_iou_thr=0.3, 39 | min_pos_iou=0.3, 40 | match_low_quality=True, 41 | ignore_iof_thr=-1), 42 | sampler=dict( 43 | type='RandomSampler', 44 | num=256, 45 | pos_fraction=0.5, 46 | neg_pos_ub=-1, 47 | add_gt_as_proposals=False), 48 | allowed_border=0, 49 | pos_weight=-1, 50 | debug=False), 51 | rpn_proposal=dict( 52 | nms_pre=12000, 53 | max_per_img=2000, 54 | nms=dict(type='nms', iou_threshold=0.7), 55 | min_bbox_size=0), 56 | rcnn=dict( 57 | assigner=dict( 58 | type='MaxIoUAssigner', 59 | pos_iou_thr=0.5, 60 | neg_iou_thr=0.5, 61 | min_pos_iou=0.5, 62 | match_low_quality=False, 63 | ignore_iof_thr=-1), 64 | sampler=dict( 65 | type='RandomSampler', 66 | num=128, 67 | pos_fraction=0.25, 68 | neg_pos_ub=-1, 69 | add_gt_as_proposals=True), 70 | pos_weight=-1, 71 | debug=False)), 72 | test_cfg=dict( 73 | rpn=dict( 74 | nms_pre=6000, 75 | max_per_img=300, 76 | nms=dict(type='nms', iou_threshold=0.7), 77 | min_bbox_size=0), 78 | rcnn=dict( 79 | score_thr=0.05, 80 | nms=dict(type='nms', iou_threshold=0.3), 81 | max_per_img=100))) 82 | 
-------------------------------------------------------------------------------- /configs/vfa/vfa_r101_c4.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | './meta-rcnn_r50_c4.py', 3 | ] 4 | 5 | custom_imports = dict( 6 | imports=[ 7 | 'vfa.vfa_detector', 8 | 'vfa.vfa_roi_head', 9 | 'vfa.vfa_bbox_head'], 10 | allow_failed_imports=False) 11 | 12 | pretrained = 'open-mmlab://detectron2/resnet101_caffe' 13 | # model settings 14 | model = dict( 15 | type='VFA', 16 | pretrained=pretrained, 17 | backbone=dict(depth=101), 18 | roi_head=dict( 19 | type='VFARoIHead', 20 | shared_head=dict(pretrained=pretrained), 21 | bbox_head=dict( 22 | type='VFABBoxHead', num_classes=20, num_meta_classes=20))) 23 | -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_10SHOT')], 14 | num_novel_shots=10, 15 | num_base_shots=10, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=2000, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=2000) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=2000) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
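# Note: all VOC fine-tuning configs below follow this same pattern: K-shot
# annotations for all classes, a short fixed-lr (1e-3) schedule whose max_iters
# grows with K, and a frozen backbone/shared_head/aggregation_layer/rpn_head.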
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_1SHOT')], 14 | num_novel_shots=1, 15 | num_base_shots=1, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=400, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=400) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=400) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_2SHOT')], 14 | num_novel_shots=2, 15 | num_base_shots=2, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=800, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=800) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=800) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_3SHOT')], 14 | num_novel_shots=3, 15 | num_base_shots=3, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=1200, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=1200) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1200) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_5SHOT')], 14 | num_novel_shots=5, 15 | num_base_shots=5, 16 | classes='ALL_CLASSES_SPLIT1', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT1'), 19 | test=dict(classes='ALL_CLASSES_SPLIT1'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT1')) 21 | evaluation = dict( 22 | interval=1600, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1']) 23 | checkpoint_config = dict(interval=1600) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1600) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=False, 11 | dataset=dict(classes='BASE_CLASSES_SPLIT1'), 12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT1')), 13 | val=dict(classes='BASE_CLASSES_SPLIT1'), 14 | test=dict(classes='BASE_CLASSES_SPLIT1'), 15 | model_init=dict(classes='BASE_CLASSES_SPLIT1')) 16 | lr_config = dict(warmup_iters=100, step=[12000, 16000]) 17 | evaluation = dict(interval=3000) 18 | checkpoint_config = dict(interval=3000) 19 | runner = dict(max_iters=18000) 20 | optimizer = dict(lr=0.02) 21 | # model settings 22 | model = dict(roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15))) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_10SHOT')], 14 | num_novel_shots=10, 15 | num_base_shots=10, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=2000, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=2000) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=2000) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_1SHOT')], 14 | num_novel_shots=1, 15 | num_base_shots=1, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=400, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=400) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=400) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_2SHOT')], 14 | num_novel_shots=2, 15 | num_base_shots=2, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=800, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=800) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=800) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_3SHOT')], 14 | num_novel_shots=3, 15 | num_base_shots=3, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=1200, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=1200) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1200) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_5SHOT')], 14 | num_novel_shots=5, 15 | num_base_shots=5, 16 | classes='ALL_CLASSES_SPLIT2', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT2'), 19 | test=dict(classes='ALL_CLASSES_SPLIT2'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT2')) 21 | evaluation = dict( 22 | interval=1600, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2']) 23 | checkpoint_config = dict(interval=1600) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1600) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=False, 11 | dataset=dict(classes='BASE_CLASSES_SPLIT2'), 12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT2')), 13 | val=dict(classes='BASE_CLASSES_SPLIT2'), 14 | test=dict(classes='BASE_CLASSES_SPLIT2'), 15 | model_init=dict(classes='BASE_CLASSES_SPLIT2')) 16 | lr_config = dict(warmup_iters=100, step=[12000, 16000]) 17 | evaluation = dict(interval=3000) 18 | checkpoint_config = dict(interval=3000) 19 | runner = dict(max_iters=18000) 20 | optimizer = dict(lr=0.02) 21 | # model settings 22 | model = dict(roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15))) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_10SHOT')], 14 | num_novel_shots=10, 15 | num_base_shots=10, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=2000, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=2000) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=2000) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # classes splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility. 
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_1SHOT')], 14 | num_novel_shots=1, 15 | num_base_shots=1, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=400, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=400) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=400) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # class splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_2SHOT')], 14 | num_novel_shots=2, 15 | num_base_shots=2, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=800, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=800) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=800) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # class splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_3SHOT')], 14 | num_novel_shots=3, 15 | num_base_shots=3, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=1200, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=1200) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1200) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # class splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility. 8 | data = dict( 9 | train=dict( 10 | save_dataset=True, 11 | dataset=dict( 12 | type='FewShotVOCDefaultDataset', 13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_5SHOT')], 14 | num_novel_shots=5, 15 | num_base_shots=5, 16 | classes='ALL_CLASSES_SPLIT3', 17 | )), 18 | val=dict(classes='ALL_CLASSES_SPLIT3'), 19 | test=dict(classes='ALL_CLASSES_SPLIT3'), 20 | model_init=dict(classes='ALL_CLASSES_SPLIT3')) 21 | evaluation = dict( 22 | interval=1600, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3']) 23 | checkpoint_config = dict(interval=1600) 24 | optimizer = dict(lr=0.001) 25 | lr_config = dict(warmup=None) 26 | runner = dict(max_iters=1600) 27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth' 28 | 29 | # model settings 30 | model = dict(frozen_parameters=[ 31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head', 32 | ]) -------------------------------------------------------------------------------- /configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_base-training.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py', 3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py', 4 | '../../../_base_/default_runtime.py' 5 | ] 6 | # class splits are predefined in FewShotVOCDataset 7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict( 9 | train=dict( 10 | save_dataset=False, 11 | dataset=dict(classes='BASE_CLASSES_SPLIT3'), 12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT3')), 13 | val=dict(classes='BASE_CLASSES_SPLIT3'), 14 | test=dict(classes='BASE_CLASSES_SPLIT3'), 15 | model_init=dict(classes='BASE_CLASSES_SPLIT3')) 16 | lr_config = dict(warmup_iters=100, step=[12000, 16000]) 17 | evaluation = dict(interval=3000) 18 | checkpoint_config = dict(interval=3000) 19 | runner = dict(max_iters=18000) 20 | optimizer = dict(lr=0.02) 21 | # model settings 22 | model = dict(roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15))) -------------------------------------------------------------------------------- /dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | CHECKPOINT=$2 5 | GPUS=$3 6 | PORT=${PORT:-29500} 7 | 8 | PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \ 9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4} 11 | -------------------------------------------------------------------------------- /dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | PORT=${PORT:-29500} 6 | 7 | PYTHONPATH="$(dirname $0)/../../":$PYTHONPATH \ 8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | from setuptools import setup, find_packages 3 | 4 | setup( 5 | name="vfa", 6 | version="0.1", 7 | author="csuhan", 8 | url="https://github.com/csuhan/VFA", 9 | description="Codebase for few-shot object detection", 10 | python_requires=">=3.6", 11 | packages=find_packages(exclude=('configs', 'data', 'work_dirs')), 12 | install_requires=[ 13 | 'clip@git+https://github.com/openai/CLIP.git' 14 | ], 15 | ) -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
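# Example usage (paths are illustrative): single-GPU evaluation of a fine-tuned
# model with
#   python test.py configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning.py \
#       work_dirs/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning/iter_1200.pth --eval mAP
# Multi-GPU evaluation goes through dist_test.sh above; the checkpoint name here
# assumes the default checkpoint_config interval of that config.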
2 | import argparse 3 | import os 4 | import warnings 5 | 6 | import mmcv 7 | import torch 8 | from mmcv import Config, DictAction 9 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel 10 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint, 11 | wrap_fp16_model) 12 | 13 | from mmfewshot.detection.datasets import (build_dataloader, build_dataset, 14 | get_copy_dataset_type) 15 | from mmfewshot.detection.models import build_detector 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser( 20 | description='MMFewShot test (and eval) a model') 21 | parser.add_argument('config', help='test config file path') 22 | parser.add_argument('checkpoint', help='checkpoint file') 23 | parser.add_argument('--out', help='output result file in pickle format') 24 | parser.add_argument( 25 | '--eval', 26 | type=str, 27 | nargs='+', 28 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",' 29 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC') 30 | parser.add_argument('--show', action='store_true', help='show results') 31 | parser.add_argument( 32 | '--show-dir', help='directory where painted images will be saved') 33 | parser.add_argument( 34 | '--show-score-thr', 35 | type=float, 36 | default=0.3, 37 | help='score threshold (default: 0.3)') 38 | parser.add_argument( 39 | '--gpu-collect', 40 | action='store_true', 41 | help='whether to use gpu to collect results.') 42 | parser.add_argument( 43 | '--tmpdir', 44 | help='tmp directory used for collecting results from multiple ' 45 | 'workers, available when gpu-collect is not specified') 46 | parser.add_argument( 47 | '--cfg-options', 48 | nargs='+', 49 | action=DictAction, 50 | help='override some settings in the used config, the key-value pair ' 51 | 'in xxx=yyy format will be merged into config file. If the value to ' 52 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 53 | 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' 54 | 'Note that the quotation marks are necessary and that no white space ' 55 | 'is allowed.') 56 | parser.add_argument( 57 | '--options', 58 | nargs='+', 59 | action=DictAction, 60 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 61 | 'format will be kwargs for dataset.evaluate() function (deprecate), ' 62 | 'change to --eval-options instead.') 63 | parser.add_argument( 64 | '--eval-options', 65 | nargs='+', 66 | action=DictAction, 67 | help='custom options for evaluation, the key-value pair in xxx=yyy ' 68 | 'format will be kwargs for dataset.evaluate() function') 69 | parser.add_argument( 70 | '--launcher', 71 | choices=['none', 'pytorch', 'slurm', 'mpi'], 72 | default='none', 73 | help='job launcher') 74 | parser.add_argument('--local_rank', type=int, default=0) 75 | args = parser.parse_args() 76 | if 'LOCAL_RANK' not in os.environ: 77 | os.environ['LOCAL_RANK'] = str(args.local_rank) 78 | 79 | if args.options and args.eval_options: 80 | raise ValueError( 81 | '--options and --eval-options cannot be both ' 82 | 'specified, --options is deprecated in favor of --eval-options') 83 | if args.options: 84 | warnings.warn('--options is deprecated in favor of --eval-options') 85 | args.eval_options = args.options 86 | args.cfg_options = args.options 87 | return args 88 | 89 | 90 | def main(): 91 | args = parse_args() 92 | 93 | assert args.out or args.eval or args.show \ 94 | or args.show_dir, ( 95 | 'Please specify at least one operation (save/eval/show the ' 96 | 'results / save the results) with the argument "--out", "--eval"', 97 | '"--show" or "--show-dir"') 98 | 99 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')): 100 | raise ValueError('The output file must be a pkl file.') 101 | 102 | cfg = Config.fromfile(args.config) 103 | 104 | if args.cfg_options is not None: 105 | cfg.merge_from_dict(args.cfg_options) 106 | 107 | # import modules from string list. 108 | if cfg.get('custom_imports', None): 109 | from mmcv.utils import import_modules_from_strings 110 | import_modules_from_strings(**cfg['custom_imports']) 111 | # set cudnn_benchmark 112 | if cfg.get('cudnn_benchmark', False): 113 | torch.backends.cudnn.benchmark = True 114 | cfg.model.pretrained = None 115 | 116 | # currently only support single images testing 117 | samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1) 118 | assert samples_per_gpu == 1, 'currently only support single images testing' 119 | 120 | # init distributed env first, since logger depends on the dist info. 
121 | if args.launcher == 'none': 122 | distributed = False 123 | else: 124 | distributed = True 125 | init_dist(args.launcher, **cfg.dist_params) 126 | 127 | # build the dataloader 128 | dataset = build_dataset(cfg.data.test) 129 | data_loader = build_dataloader( 130 | dataset, 131 | samples_per_gpu=samples_per_gpu, 132 | workers_per_gpu=cfg.data.workers_per_gpu, 133 | dist=distributed, 134 | shuffle=False) 135 | 136 | # pop frozen_parameters 137 | cfg.model.pop('frozen_parameters', None) 138 | 139 | # build the model and load checkpoint 140 | cfg.model.train_cfg = None 141 | model = build_detector(cfg.model) 142 | 143 | fp16_cfg = cfg.get('fp16', None) 144 | if fp16_cfg is not None: 145 | wrap_fp16_model(model) 146 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu') 147 | # old versions did not save class info in checkpoints, this workaround is 148 | # for backward compatibility 149 | if 'CLASSES' in checkpoint.get('meta', {}): 150 | model.CLASSES = checkpoint['meta']['CLASSES'] 151 | else: 152 | model.CLASSES = dataset.CLASSES 153 | 154 | # for meta-learning methods which require a support template dataset 155 | # for model initialization. 156 | if cfg.data.get('model_init', None) is not None: 157 | cfg.data.model_init.pop('copy_from_train_dataset') 158 | model_init_samples_per_gpu = cfg.data.model_init.pop( 159 | 'samples_per_gpu', 1) 160 | model_init_workers_per_gpu = cfg.data.model_init.pop( 161 | 'workers_per_gpu', 1) 162 | if cfg.data.model_init.get('ann_cfg', None) is None: 163 | assert checkpoint['meta'].get('model_init_ann_cfg', 164 | None) is not None 165 | cfg.data.model_init.type = \ 166 | get_copy_dataset_type(cfg.data.model_init.type) 167 | cfg.data.model_init.ann_cfg = \ 168 | checkpoint['meta']['model_init_ann_cfg'] 169 | model_init_dataset = build_dataset(cfg.data.model_init) 170 | # disable dist so that all ranks get the same data 171 | model_init_dataloader = build_dataloader( 172 | model_init_dataset, 173 | samples_per_gpu=model_init_samples_per_gpu, 174 | workers_per_gpu=model_init_workers_per_gpu, 175 | dist=False, 176 | shuffle=False) 177 | 178 | if cfg.data.get('support_init', None) is not None: 179 | model_init_dataset = build_dataset(cfg.data.support_init) 180 | # disable dist so that all ranks get the same data 181 | model_init_dataloader = build_dataloader( 182 | model_init_dataset, 183 | samples_per_gpu=1, 184 | workers_per_gpu=1, 185 | dist=False, 186 | shuffle=False) 187 | 188 | if not distributed: 189 | model = MMDataParallel(model, device_ids=[0]) 190 | show_kwargs = dict(show_score_thr=args.show_score_thr) 191 | if (cfg.data.get('model_init', None) is not None) or (cfg.data.get('support_init', None) is not None): 192 | from mmfewshot.detection.apis import (single_gpu_model_init, 193 | single_gpu_test) 194 | single_gpu_model_init(model, model_init_dataloader) 195 | else: 196 | from mmdet.apis.test import single_gpu_test 197 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir, 198 | **show_kwargs) 199 | else: 200 | model = MMDistributedDataParallel( 201 | model.cuda(), 202 | device_ids=[torch.cuda.current_device()], 203 | broadcast_buffers=False) 204 | if (cfg.data.get('model_init', None) is not None) or (cfg.data.get('support_init', None) is not None): 205 | from mmfewshot.detection.apis import (multi_gpu_model_init, 206 | multi_gpu_test) 207 | multi_gpu_model_init(model, model_init_dataloader) 208 | else: 209 | from mmdet.apis.test import multi_gpu_test 210 | outputs = multi_gpu_test( 211 | model, 212 | data_loader, 213 |
args.tmpdir, 214 | args.gpu_collect, 215 | ) 216 | 217 | rank, _ = get_dist_info() 218 | if rank == 0: 219 | if args.out: 220 | print(f'\nwriting results to {args.out}') 221 | mmcv.dump(outputs, args.out) 222 | kwargs = {} if args.eval_options is None else args.eval_options 223 | if args.eval: 224 | eval_kwargs = cfg.get('evaluation', {}).copy() 225 | # hard-coded way to remove EvalHook args 226 | for key in [ 227 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best', 228 | 'rule' 229 | ]: 230 | eval_kwargs.pop(key, None) 231 | eval_kwargs.update(dict(metric=args.eval, **kwargs)) 232 | print(dataset.evaluate(outputs, **eval_kwargs)) 233 | 234 | 235 | if __name__ == '__main__': 236 | main() 237 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import argparse 3 | import copy 4 | import os 5 | import os.path as osp 6 | import multiprocessing as mp 7 | import platform 8 | import time 9 | import warnings 10 | 11 | import cv2 12 | import mmcv 13 | import torch 14 | from mmcv import Config, DictAction 15 | from mmcv.runner import get_dist_info, init_dist, set_random_seed 16 | from mmcv.utils import get_git_hash 17 | from mmdet.utils import collect_env 18 | 19 | import mmfewshot # noqa: F401, F403 20 | from mmfewshot import __version__ 21 | # from mmfewshot.detection.apis import train_detector 22 | from mmfewshot.detection.apis.train import train_detector 23 | from mmfewshot.detection.datasets import build_dataset 24 | from mmfewshot.detection.models import build_detector 25 | from mmfewshot.utils import get_root_logger 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser(description='Train a FewShot model') 29 | parser.add_argument('config', help='train config file path') 30 | parser.add_argument( 31 | '--work-dir', help='the directory to save logs and models') 32 | parser.add_argument( 33 | '--resume-from', help='the checkpoint file to resume from') 34 | parser.add_argument( 35 | '--no-validate', 36 | action='store_true', 37 | help='whether not to evaluate the checkpoint during training') 38 | group_gpus = parser.add_mutually_exclusive_group() 39 | group_gpus.add_argument( 40 | '--gpus', 41 | type=int, 42 | help='number of gpus to use ' 43 | '(only applicable to non-distributed training)') 44 | group_gpus.add_argument( 45 | '--gpu-ids', 46 | type=int, 47 | nargs='+', 48 | help='ids of gpus to use ' 49 | '(only applicable to non-distributed training)') 50 | parser.add_argument('--seed', type=int, default=None, help='random seed') 51 | parser.add_argument( 52 | '--deterministic', 53 | action='store_true', 54 | help='whether to set deterministic options for CUDNN backend.') 55 | parser.add_argument( 56 | '--options', 57 | nargs='+', 58 | action=DictAction, 59 | help='override some settings in the used config, the key-value pair ' 60 | 'in xxx=yyy format will be merged into config file (deprecated), ' 61 | 'change to --cfg-options instead.') 62 | parser.add_argument( 63 | '--cfg-options', 64 | nargs='+', 65 | action=DictAction, 66 | help='override some settings in the used config, the key-value pair ' 67 | 'in xxx=yyy format will be merged into config file. If the value ' 68 | 'to be overwritten is a list, it should be like key="[a,b]" or ' 69 | 'key=a,b It also allows nested list/tuple values, e.g. 
' 70 | 'key="[(a,b),(c,d)]" Note that the quotation marks are necessary ' 71 | 'and that no white space is allowed.') 72 | parser.add_argument( 73 | '--launcher', 74 | choices=['none', 'pytorch', 'slurm', 'mpi'], 75 | default='none', 76 | help='job launcher') 77 | parser.add_argument('--local_rank', type=int, default=0) 78 | args = parser.parse_args() 79 | if 'LOCAL_RANK' not in os.environ: 80 | os.environ['LOCAL_RANK'] = str(args.local_rank) 81 | 82 | if args.options and args.cfg_options: 83 | raise ValueError( 84 | '--options and --cfg-options cannot be both ' 85 | 'specified, --options is deprecated in favor of --cfg-options') 86 | if args.options: 87 | warnings.warn('--options is deprecated in favor of --cfg-options') 88 | args.cfg_options = args.options 89 | 90 | return args 91 | 92 | 93 | def setup_multi_processes(cfg): 94 | # set multi-process start method as `fork` to speed up the training 95 | if platform.system() != 'Windows': 96 | mp_start_method = cfg.get('mp_start_method', 'fork') 97 | mp.set_start_method(mp_start_method) 98 | 99 | # disable opencv multithreading to avoid system being overloaded 100 | opencv_num_threads = cfg.get('opencv_num_threads', 0) 101 | cv2.setNumThreads(opencv_num_threads) 102 | 103 | # setup OMP threads 104 | # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa 105 | if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1): 106 | omp_num_threads = 1 107 | warnings.warn( 108 | f'Setting OMP_NUM_THREADS environment variable for each process ' 109 | f'to be {omp_num_threads} in default, to avoid your system being ' 110 | f'overloaded, please further tune the variable for optimal ' 111 | f'performance in your application as needed.') 112 | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) 113 | 114 | # setup MKL threads 115 | if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1: 116 | mkl_num_threads = 1 117 | warnings.warn( 118 | f'Setting MKL_NUM_THREADS environment variable for each process ' 119 | f'to be {mkl_num_threads} in default, to avoid your system being ' 120 | f'overloaded, please further tune the variable for optimal ' 121 | f'performance in your application as needed.') 122 | os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) 123 | 124 | def main(): 125 | args = parse_args() 126 | 127 | cfg = Config.fromfile(args.config) 128 | 129 | if args.cfg_options is not None: 130 | cfg.merge_from_dict(args.cfg_options) 131 | 132 | # set multi-process settings 133 | setup_multi_processes(cfg) 134 | 135 | # import modules from string list. 
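    # e.g. a config can register this repo's modules with a block such as
    # `custom_imports = dict(imports=['vfa'], allow_failed_imports=False)`
    # (illustrative; some such import of the `vfa` package is needed so that
    # VFA, VFARoIHead and VFABBoxHead register with the mmdet registries).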
136 | if cfg.get('custom_imports', None): 137 | from mmcv.utils import import_modules_from_strings 138 | import_modules_from_strings(**cfg['custom_imports']) 139 | # set cudnn_benchmark 140 | if cfg.get('cudnn_benchmark', False): 141 | torch.backends.cudnn.benchmark = True 142 | 143 | # work_dir is determined in this priority: CLI > segment in file > filename 144 | if args.work_dir is not None: 145 | # update configs according to CLI args if args.work_dir is not None 146 | cfg.work_dir = args.work_dir 147 | elif cfg.get('work_dir', None) is None: 148 | # use config filename as default work_dir if cfg.work_dir is None 149 | cfg.work_dir = osp.join('./work_dirs', 150 | osp.splitext(osp.basename(args.config))[0]) 151 | if args.resume_from is not None: 152 | cfg.resume_from = args.resume_from 153 | if args.gpu_ids is not None: 154 | cfg.gpu_ids = args.gpu_ids 155 | else: 156 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus) 157 | 158 | # init distributed env first, since logger depends on the dist info. 159 | if args.launcher == 'none': 160 | distributed = False 161 | rank, world_size = get_dist_info() 162 | else: 163 | distributed = True 164 | init_dist(args.launcher, **cfg.dist_params) 165 | rank, world_size = get_dist_info() 166 | # re-set gpu_ids with distributed training mode 167 | cfg.gpu_ids = range(world_size) 168 | # create work_dir 169 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) 170 | # dump config 171 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config))) 172 | # init the logger before other steps 173 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 174 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log') 175 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level) 176 | 177 | # init the meta dict to record some important information such as 178 | # environment info and seed, which will be logged 179 | meta = dict() 180 | # log env info 181 | env_info_dict = collect_env() 182 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()]) 183 | dash_line = '-' * 60 + '\n' 184 | logger.info('Environment info:\n' + dash_line + env_info + '\n' + 185 | dash_line) 186 | meta['env_info'] = env_info 187 | meta['config'] = cfg.pretty_text 188 | # log some basic info 189 | logger.info(f'Distributed training: {distributed}') 190 | logger.info(f'Config:\n{cfg.pretty_text}') 191 | 192 | # set random seeds 193 | if args.seed is not None: 194 | seed = args.seed 195 | elif cfg.seed is not None: 196 | seed = cfg.seed 197 | elif distributed: 198 | seed = 0 199 | warnings.warn(f'When using DistributedDataParallel, each rank ' 200 | f'initializes a different random seed, which causes ' 201 | f'different random behavior per rank. In the few-shot ' 202 | f'setting, novel shots may be generated by random ' 203 | f'sampling; if ranks do not share the same seed, each ' 204 | f'rank samples different data, causing UNFAIR data ' 205 | f'usage. Therefore, seed defaults to {seed}.') 206 | else: 207 | seed = None 208 | 209 | if seed is not None: 210 | logger.info(f'Set random seed to {seed}, ' 211 | f'deterministic: {args.deterministic}') 212 | set_random_seed(seed, deterministic=args.deterministic) 213 | meta['seed'] = seed 214 | meta['exp_name'] = osp.basename(args.config) 215 | 216 | # build_detector does three things: building the model, 217 | # initializing weights and freezing parameters (optional).
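    # The fine-tuning configs set model.frozen_parameters to the name prefixes
    # 'backbone', 'shared_head', 'aggregation_layer' and 'rpn_head'; mmfewshot's
    # build_detector is expected to set requires_grad=False on parameters whose
    # names match these prefixes, so fine-tuning only updates the remaining
    # modules (e.g. the bbox head and the VAE in VFARoIHead).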
218 | model = build_detector(cfg.model, logger=logger) 219 | # build_dataset will do two things, including building dataset 220 | # and saving dataset into json file (optional). 221 | datasets = [ 222 | build_dataset( 223 | cfg.data.train, 224 | rank=rank, 225 | work_dir=cfg.work_dir, 226 | timestamp=timestamp) 227 | ] 228 | 229 | if len(cfg.workflow) == 2: 230 | val_dataset = copy.deepcopy(cfg.data.val) 231 | val_dataset.pipeline = cfg.data.train.pipeline 232 | datasets.append(build_dataset(val_dataset)) 233 | if cfg.checkpoint_config is not None: 234 | # save mmfewshot version, config file content and class names in 235 | # checkpoints as meta data 236 | cfg.checkpoint_config.meta = dict( 237 | mmfewshot_version=__version__ + get_git_hash()[:7], 238 | CLASSES=datasets[0].CLASSES) 239 | # add an attribute for visualization convenience 240 | model.CLASSES = datasets[0].CLASSES 241 | train_detector( 242 | model, 243 | datasets, 244 | cfg, 245 | distributed=distributed, 246 | validate=(not args.no_validate), 247 | timestamp=timestamp, 248 | meta=meta) 249 | 250 | 251 | if __name__ == '__main__': 252 | main() 253 | -------------------------------------------------------------------------------- /vfa/__init__.py: -------------------------------------------------------------------------------- 1 | from .vfa_detector import VFA 2 | from .vfa_roi_head import VFARoIHead 3 | from .vfa_bbox_head import VFABBoxHead 4 | 5 | __all__ = ['VFA', 'VFARoIHead', 'VFABBoxHead'] 6 | -------------------------------------------------------------------------------- /vfa/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from PIL import Image 5 | from torchvision import transforms as trans 6 | 7 | from mmfewshot.detection.datasets.coco import COCO_SPLIT 8 | 9 | 10 | class PCB: 11 | def __init__(self, class_names, model="RN101", templates="a photo of a {}"): 12 | super().__init__() 13 | self.device = torch.cuda.current_device() 14 | 15 | # image transforms 16 | self.expand_ratio = 0.1 17 | self.trans = trans.Compose([ 18 | trans.Resize([224, 224], interpolation=3), 19 | trans.ToTensor(), 20 | trans.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 21 | 22 | # CLIP configs 23 | import clip 24 | self.class_names = class_names 25 | self.clip, _ = clip.load(model, device=self.device) 26 | self.prompts = clip.tokenize([ 27 | templates.format(cls_name) 28 | for cls_name in self.class_names 29 | ]).to(self.device) 30 | with torch.no_grad(): 31 | text_features = self.clip.encode_text(self.prompts) 32 | self.text_features = F.normalize(text_features, dim=-1, p=2) 33 | 34 | 35 | def load_image_by_box(self, img_path, boxes): 36 | image = Image.open(img_path).convert("RGB") 37 | image_list = [] 38 | for box in boxes: 39 | x1, y1, x2, y2 = box 40 | h, w = y2-y1, x2-x1 41 | x1 = max(0, x1 - w*self.expand_ratio) 42 | y1 = max(0, y1 - h*self.expand_ratio) 43 | x2 = x2 + w*self.expand_ratio 44 | y2 = y2 + h*self.expand_ratio 45 | sub_image = image.crop((int(x1), int(y1), int(x2), int(y2))) 46 | sub_image = self.trans(sub_image).to(self.device) 47 | image_list.append(sub_image) 48 | return torch.stack(image_list) 49 | 50 | @torch.no_grad() 51 | def __call__(self, img_path, boxes): 52 | images = self.load_image_by_box(img_path, boxes) 53 | 54 | image_features = self.clip.encode_image(images) 55 | image_features = F.normalize(image_features, dim=-1, p=2) 56 | logit_scale = self.clip.logit_scale.exp() 
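        # standard CLIP zero-shot scoring: temperature-scaled cosine similarities
        # between the L2-normalized crop and text embeddings; the softmax below
        # turns them into per-class probabilities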
57 | logits_per_image = logit_scale * image_features @ self.text_features.t() 58 | return logits_per_image.softmax(dim=-1) 59 | 60 | 61 | class TestMixins: 62 | def __init__(self): 63 | self.pcb = None 64 | 65 | def refine_test(self, results, img_metas): 66 | if not hasattr(self, 'pcb'): 67 | self.pcb = PCB(COCO_SPLIT['ALL_CLASSES'], model='ViT-B/32') 68 | # excluded ids for COCO 69 | self.exclude_ids = [7, 9, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 70 | 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 41, 42, 43, 44, 45, 71 | 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 61, 63, 64, 65, 72 | 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79] 73 | 74 | boxes_list, scores_list, labels_list = [], [], [] 75 | for cls_id, result in enumerate(results[0]): 76 | if len(result) == 0: 77 | continue 78 | boxes_list.append(result[:, :4]) 79 | scores_list.append(result[:, 4]) 80 | labels_list.append([cls_id] * len(result)) 81 | 82 | if len(boxes_list) == 0: 83 | return results 84 | 85 | boxes_list = np.concatenate(boxes_list, axis=0) 86 | scores_list = np.concatenate(scores_list, axis=0) 87 | labels_list = np.concatenate(labels_list, axis=0) 88 | 89 | logits = self.pcb(img_metas[0]['filename'], boxes_list) 90 | 91 | for i, prob in enumerate(logits): 92 | if labels_list[i] not in self.exclude_ids: 93 | scores_list[i] = scores_list[i] * 0.5 + prob[labels_list[i]] * 0.5 94 | 95 | j = 0 96 | for i in range(len(results[0])): 97 | num_dets = len(results[0][i]) 98 | if num_dets == 0: 99 | continue 100 | for k in range(num_dets): 101 | results[0][i][k, 4] = scores_list[j] 102 | j += 1 103 | 104 | return results -------------------------------------------------------------------------------- /vfa/vfa_bbox_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from mmcv.runner import auto_fp16 4 | from mmdet.models.builder import HEADS 5 | 6 | from mmfewshot.detection.models.roi_heads.bbox_heads.meta_bbox_head import MetaBBoxHead 7 | 8 | 9 | @HEADS.register_module() 10 | class VFABBoxHead(MetaBBoxHead): 11 | 12 | @auto_fp16() 13 | def forward(self, x_agg, x_query): 14 | if self.with_avg_pool: 15 | if x_agg.numel() > 0: 16 | x_agg = self.avg_pool(x_agg) 17 | x_agg = x_agg.view(x_agg.size(0), -1) 18 | else: 19 | # avg_pool does not support empty tensors, 20 | # so use torch.mean instead 21 | x_agg = torch.mean(x_agg, dim=(-1, -2)) 22 | if x_query.numel() > 0: 23 | x_query = self.avg_pool(x_query) 24 | x_query = x_query.view(x_query.size(0), -1) 25 | else: 26 | x_query = torch.mean(x_query, dim=(-1, -2)) 27 | cls_score = self.fc_cls(x_agg) if self.with_cls else None 28 | bbox_pred = self.fc_reg(x_query) if self.with_reg else None 29 | return cls_score, bbox_pred 30 | -------------------------------------------------------------------------------- /vfa/vfa_detector.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from typing import Dict, List, Optional 3 | 4 | from mmdet.models.builder import DETECTORS 5 | from mmfewshot.detection.models import MetaRCNN 6 | 7 | from .utils import TestMixins 8 | 9 | 10 | @DETECTORS.register_module() 11 | class VFA(MetaRCNN, TestMixins): 12 | def __init__(self, *args, with_refine=False, **kwargs): 13 | super().__init__(*args, **kwargs) 14 | # refine results for COCO. We do not use it for VOC.
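        # (see TestMixins.refine_test in utils.py) when enabled, scores of boxes
        # whose class id is not in exclude_ids, i.e. novel classes, are averaged
        # 50/50 with the CLIP class probabilities produced by PCB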
15 | self.with_refine = with_refine 16 | 17 | def forward_train(self, 18 | query_data: Dict, 19 | support_data: Dict, 20 | proposals: Optional[List] = None, 21 | **kwargs) -> Dict: 22 | """Forward function for training. 23 | 24 | Args: 25 | query_data (dict): In most cases, dict of query data contains: 26 | `img`, `img_metas`, `gt_bboxes`, `gt_labels`, 27 | `gt_bboxes_ignore`. 28 | support_data (dict): In most cases, dict of support data contains: 29 | `img`, `img_metas`, `gt_bboxes`, `gt_labels`, 30 | `gt_bboxes_ignore`. 31 | proposals (list): Override rpn proposals with custom proposals. 32 | Use when `with_rpn` is False. Default: None. 33 | 34 | Returns: 35 | dict[str, Tensor]: a dictionary of loss components 36 | """ 37 | query_img = query_data['img'] 38 | support_img = support_data['img'] 39 | query_feats = self.extract_query_feat(query_img) 40 | 41 | # stop gradient at RPN 42 | query_feats_rpn = [x.detach() for x in query_feats] 43 | query_feats_rcnn = query_feats 44 | 45 | support_feats = self.extract_support_feat(support_img) 46 | 47 | losses = dict() 48 | 49 | # RPN forward and loss 50 | if self.with_rpn: 51 | proposal_cfg = self.train_cfg.get('rpn_proposal', 52 | self.test_cfg.rpn) 53 | if self.rpn_with_support: 54 | rpn_losses, proposal_list = self.rpn_head.forward_train( 55 | query_feats_rpn, 56 | support_feats, 57 | query_img_metas=query_data['img_metas'], 58 | query_gt_bboxes=query_data['gt_bboxes'], 59 | query_gt_labels=None, 60 | query_gt_bboxes_ignore=query_data.get( 61 | 'gt_bboxes_ignore', None), 62 | support_img_metas=support_data['img_metas'], 63 | support_gt_bboxes=support_data['gt_bboxes'], 64 | support_gt_labels=support_data['gt_labels'], 65 | support_gt_bboxes_ignore=support_data.get( 66 | 'gt_bboxes_ignore', None), 67 | proposal_cfg=proposal_cfg) 68 | else: 69 | rpn_losses, proposal_list = self.rpn_head.forward_train( 70 | query_feats_rpn, 71 | copy.deepcopy(query_data['img_metas']), 72 | copy.deepcopy(query_data['gt_bboxes']), 73 | gt_labels=None, 74 | gt_bboxes_ignore=copy.deepcopy( 75 | query_data.get('gt_bboxes_ignore', None)), 76 | proposal_cfg=proposal_cfg) 77 | losses.update(rpn_losses) 78 | else: 79 | proposal_list = proposals 80 | 81 | roi_losses = self.roi_head.forward_train( 82 | query_feats_rcnn, 83 | support_feats, 84 | proposals=proposal_list, 85 | query_img_metas=query_data['img_metas'], 86 | query_gt_bboxes=query_data['gt_bboxes'], 87 | query_gt_labels=query_data['gt_labels'], 88 | query_gt_bboxes_ignore=query_data.get('gt_bboxes_ignore', None), 89 | support_img_metas=support_data['img_metas'], 90 | support_gt_bboxes=support_data['gt_bboxes'], 91 | support_gt_labels=support_data['gt_labels'], 92 | support_gt_bboxes_ignore=support_data.get('gt_bboxes_ignore', 93 | None), 94 | **kwargs) 95 | losses.update(roi_losses) 96 | 97 | return losses 98 | 99 | def simple_test(self, img, img_metas, proposals = None, rescale = False): 100 | bbox_results = super().simple_test(img, img_metas, proposals, rescale) 101 | if self.with_refine: 102 | return self.refine_test(bbox_results, img_metas) 103 | else: 104 | return bbox_results 105 | 106 | 107 | -------------------------------------------------------------------------------- /vfa/vfa_roi_head.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch import Tensor 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from mmcv.utils import ConfigDict 9 | from 
mmdet.core import bbox2roi 10 | from mmdet.models.builder import HEADS 11 | from mmfewshot.detection.models.roi_heads.meta_rcnn_roi_head import MetaRCNNRoIHead 12 | 13 | 14 | class VAE(nn.Module): 15 | 16 | def __init__(self, 17 | in_channels: int, 18 | latent_dim: int, 19 | hidden_dim: int) -> None: 20 | super(VAE, self).__init__() 21 | 22 | self.latent_dim = latent_dim 23 | 24 | self.encoder = nn.Sequential( 25 | nn.Linear(in_channels, hidden_dim), 26 | nn.BatchNorm1d(hidden_dim), 27 | nn.LeakyReLU() 28 | ) 29 | self.fc_mu = nn.Linear(hidden_dim, latent_dim) 30 | self.fc_var = nn.Linear(hidden_dim, latent_dim) 31 | 32 | self.decoder_input = nn.Linear(latent_dim, hidden_dim) 33 | 34 | self.decoder = nn.Sequential( 35 | nn.Linear(hidden_dim, in_channels), 36 | nn.BatchNorm1d(in_channels), 37 | nn.Sigmoid() 38 | ) 39 | 40 | def encode(self, input: Tensor) -> List[Tensor]: 41 | result = self.encoder(input) 42 | 43 | mu = self.fc_mu(result) 44 | log_var = self.fc_var(result) 45 | 46 | return [mu, log_var] 47 | 48 | def decode(self, z: Tensor) -> Tensor: 49 | 50 | z = self.decoder_input(z) 51 | z_out = self.decoder(z) 52 | return z_out 53 | 54 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tuple[Tensor, Tensor]: 55 | std = torch.exp(0.5 * logvar) 56 | eps = torch.randn_like(std) 57 | return eps * std + mu, std + mu  # sampled z and the deterministic latent mu + std (z_inv) 58 | 59 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]: 60 | mu, log_var = self.encode(input) 61 | z, z_inv = self.reparameterize(mu, log_var) 62 | z_out = self.decode(z) 63 | 64 | return [z_out, z_inv, input, mu, log_var] 65 | 66 | def loss_function(self, input, rec, mu, log_var, kld_weight=0.00025) -> dict: 67 | recons_loss = F.mse_loss(rec, input) 68 | 69 | kld_loss = torch.mean(-0.5 * torch.sum(1 + 70 | log_var - mu ** 2 - log_var.exp(), dim=1), dim=0) 71 | 72 | loss = recons_loss + kld_weight * kld_loss 73 | 74 | return {'loss_vae': loss} 75 | 76 | 77 | @HEADS.register_module() 78 | class VFARoIHead(MetaRCNNRoIHead): 79 | 80 | def __init__(self, vae_dim=2048, *args, **kwargs) -> None: 81 | super().__init__(*args, **kwargs) 82 | 83 | self.vae = VAE(vae_dim, vae_dim, vae_dim) 84 | 85 | def _bbox_forward_train(self, query_feats: List[Tensor], 86 | support_feats: List[Tensor], 87 | sampling_results: object, 88 | query_img_metas: List[Dict], 89 | query_gt_bboxes: List[Tensor], 90 | query_gt_labels: List[Tensor], 91 | support_gt_labels: List[Tensor]) -> Dict: 92 | """Forward function and loss computation for the box head in training. 93 | 94 | Args: 95 | query_feats (list[Tensor]): List of query features, each item 96 | with shape (N, C, H, W). 97 | support_feats (list[Tensor]): List of support features, each item 98 | with shape (N, C, H, W). 99 | sampling_results (obj:`SamplingResult`): Sampling results. 100 | query_img_metas (list[dict]): List of query image info dict where 101 | each dict has: 'img_shape', 'scale_factor', 'flip', and may 102 | also contain 'filename', 'ori_shape', 'pad_shape', and 103 | 'img_norm_cfg'. For details on the values of these keys see 104 | `mmdet/datasets/pipelines/formatting.py:Collect`. 105 | query_gt_bboxes (list[Tensor]): Ground truth bboxes for each query 106 | image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] 107 | format. 108 | query_gt_labels (list[Tensor]): Class indices corresponding to 109 | each box of query images. 110 | support_gt_labels (list[Tensor]): Class indices corresponding to 111 | each box of support images. 112 | 113 | Returns: 114 | dict: Predicted results and losses.
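
        Note:
            The support feature goes through the VAE: the deterministic latent
            ``mu + std`` (``support_feat_inv``) is sigmoid-activated and used
            as the class prototype for aggregation, while the reconstruction
            ``support_feat_rec`` feeds the meta-classification and VAE losses.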
115 | """ 116 | query_rois = bbox2roi([res.bboxes for res in sampling_results]) 117 | query_roi_feats = self.extract_query_roi_feat(query_feats, query_rois) 118 | support_feat = self.extract_support_feats(support_feats)[0] 119 | support_feat_rec, support_feat_inv, _, mu, log_var = self.vae( 120 | support_feat) 121 | 122 | bbox_targets = self.bbox_head.get_targets(sampling_results, 123 | query_gt_bboxes, 124 | query_gt_labels, 125 | self.train_cfg) 126 | (labels, label_weights, bbox_targets, bbox_weights) = bbox_targets 127 | loss_bbox = {'loss_cls': [], 'loss_bbox': [], 'acc': []} 128 | batch_size = len(query_img_metas) 129 | num_sample_per_imge = query_roi_feats.size(0) // batch_size 130 | bbox_results = None 131 | for img_id in range(batch_size): 132 | start = img_id * num_sample_per_imge 133 | end = (img_id + 1) * num_sample_per_imge 134 | # class agnostic aggregation 135 | # random_index = np.random.choice( 136 | # range(query_gt_labels[img_id].size(0))) 137 | # random_query_label = query_gt_labels[img_id][random_index] 138 | random_index = np.random.choice( 139 | range(len(support_gt_labels))) 140 | random_query_label = support_gt_labels[random_index] 141 | for i in range(support_feat.size(0)): 142 | if support_gt_labels[i] == random_query_label: 143 | bbox_results = self._bbox_forward( 144 | query_roi_feats[start:end], 145 | support_feat_inv[i].sigmoid().unsqueeze(0)) 146 | single_loss_bbox = self.bbox_head.loss( 147 | bbox_results['cls_score'], bbox_results['bbox_pred'], 148 | query_rois[start:end], labels[start:end], 149 | label_weights[start:end], bbox_targets[start:end], 150 | bbox_weights[start:end]) 151 | for key in single_loss_bbox.keys(): 152 | loss_bbox[key].append(single_loss_bbox[key]) 153 | if bbox_results is not None: 154 | for key in loss_bbox.keys(): 155 | if key == 'acc': 156 | loss_bbox[key] = torch.cat(loss_bbox['acc']).mean() 157 | else: 158 | loss_bbox[key] = torch.stack( 159 | loss_bbox[key]).sum() / batch_size 160 | 161 | # meta classification loss 162 | if self.bbox_head.with_meta_cls_loss: 163 | # input support feature classification 164 | meta_cls_score = self.bbox_head.forward_meta_cls(support_feat_rec) 165 | meta_cls_labels = torch.cat(support_gt_labels) 166 | loss_meta_cls = self.bbox_head.loss_meta( 167 | meta_cls_score, meta_cls_labels, 168 | torch.ones_like(meta_cls_labels)) 169 | loss_bbox.update(loss_meta_cls) 170 | 171 | loss_vae = self.vae.loss_function( 172 | support_feat, support_feat_rec, mu, log_var) 173 | loss_bbox.update(loss_vae) 174 | 175 | bbox_results.update(loss_bbox=loss_bbox) 176 | return bbox_results 177 | 178 | def _bbox_forward(self, query_roi_feats: Tensor, 179 | support_roi_feats: Tensor) -> Dict: 180 | """Box head forward function used in both training and testing. 181 | 182 | Args: 183 | query_roi_feats (Tensor): Query roi features with shape (N, C). 184 | support_roi_feats (Tensor): Support features with shape (1, C). 185 | 186 | Returns: 187 | dict: A dictionary of predicted results. 
188 | """ 189 | # feature aggregation 190 | roi_feats = self.aggregation_layer( 191 | query_feat=query_roi_feats.unsqueeze(-1).unsqueeze(-1), 192 | support_feat=support_roi_feats.view(1, -1, 1, 1))[0] 193 | cls_score, bbox_pred = self.bbox_head( 194 | roi_feats.squeeze(-1).squeeze(-1), query_roi_feats) 195 | bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) 196 | return bbox_results 197 | 198 | def simple_test_bboxes( 199 | self, 200 | query_feats: List[Tensor], 201 | support_feats_dict: Dict, 202 | query_img_metas: List[Dict], 203 | proposals: List[Tensor], 204 | rcnn_test_cfg: ConfigDict, 205 | rescale: bool = False) -> Tuple[List[Tensor], List[Tensor]]: 206 | """Test only det bboxes without augmentation. 207 | 208 | Args: 209 | query_feats (list[Tensor]): Features of query image, 210 | each item with shape (N, C, H, W). 211 | support_feats_dict (dict[int, Tensor]) Dict of support features 212 | used for inference only, each key is the class id and value is 213 | the support template features with shape (1, C). 214 | query_img_metas (list[dict]): list of image info dict where each 215 | dict has: `img_shape`, `scale_factor`, `flip`, and may also 216 | contain `filename`, `ori_shape`, `pad_shape`, and 217 | `img_norm_cfg`. For details on the values of these keys see 218 | :class:`mmdet.datasets.pipelines.Collect`. 219 | proposals (list[Tensor]): Region proposals. 220 | rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN. 221 | rescale (bool): If True, return boxes in original image space. 222 | Default: False. 223 | 224 | Returns: 225 | tuple[list[Tensor], list[Tensor]]: Each tensor in first list 226 | with shape (num_boxes, 4) and with shape (num_boxes, ) 227 | in second list. The length of both lists should be equal 228 | to batch_size. 229 | """ 230 | img_shapes = tuple(meta['img_shape'] for meta in query_img_metas) 231 | scale_factors = tuple(meta['scale_factor'] for meta in query_img_metas) 232 | 233 | rois = bbox2roi(proposals) 234 | 235 | query_roi_feats = self.extract_query_roi_feat(query_feats, rois) 236 | cls_scores_dict, bbox_preds_dict = {}, {} 237 | num_classes = self.bbox_head.num_classes 238 | for class_id in support_feats_dict.keys(): 239 | support_feat = support_feats_dict[class_id] 240 | support_feat_rec, support_feat_inv, _, mu, log_var = self.vae( 241 | support_feat) 242 | bbox_results = self._bbox_forward( 243 | query_roi_feats, support_feat_inv.sigmoid()) 244 | cls_scores_dict[class_id] = \ 245 | bbox_results['cls_score'][:, class_id:class_id + 1] 246 | bbox_preds_dict[class_id] = \ 247 | bbox_results['bbox_pred'][:, class_id * 4:(class_id + 1) * 4] 248 | # the official code use the first class background score as final 249 | # background score, while this code use average of all classes' 250 | # background scores instead. 
251 | if cls_scores_dict.get(num_classes, None) is None: 252 | cls_scores_dict[num_classes] = \ 253 | bbox_results['cls_score'][:, -1:] 254 | else: 255 | cls_scores_dict[num_classes] += \ 256 | bbox_results['cls_score'][:, -1:] 257 | cls_scores_dict[num_classes] /= len(support_feats_dict.keys()) 258 | cls_scores = [ 259 | cls_scores_dict[i] if i in cls_scores_dict.keys() else 260 | torch.zeros_like(cls_scores_dict[list(cls_scores_dict.keys())[0]]) 261 | for i in range(num_classes + 1) 262 | ] 263 | bbox_preds = [ 264 | bbox_preds_dict[i] if i in bbox_preds_dict.keys() else 265 | torch.zeros_like(bbox_preds_dict[list(bbox_preds_dict.keys())[0]]) 266 | for i in range(num_classes) 267 | ] 268 | cls_score = torch.cat(cls_scores, dim=1) 269 | bbox_pred = torch.cat(bbox_preds, dim=1) 270 | 271 | # split batch bbox prediction back to each image 272 | num_proposals_per_img = tuple(len(p) for p in proposals) 273 | rois = rois.split(num_proposals_per_img, 0) 274 | cls_score = cls_score.split(num_proposals_per_img, 0) 275 | bbox_pred = bbox_pred.split(num_proposals_per_img, 0) 276 | 277 | # apply bbox post-processing to each image individually 278 | det_bboxes = [] 279 | det_labels = [] 280 | for i in range(len(proposals)): 281 | det_bbox, det_label = self.bbox_head.get_bboxes( 282 | rois[i], 283 | cls_score[i], 284 | bbox_pred[i], 285 | img_shapes[i], 286 | scale_factors[i], 287 | rescale=rescale, 288 | cfg=rcnn_test_cfg) 289 | det_bboxes.append(det_bbox) 290 | det_labels.append(det_label) 291 | return det_bboxes, det_labels 292 | --------------------------------------------------------------------------------