├── .gitignore
├── README.md
├── VFA.png
├── configs
│   ├── _base_
│   │   ├── datasets
│   │   │   └── nway_kshot
│   │   │       ├── base_coco_ms.py
│   │   │       ├── base_voc_ms.py
│   │   │       ├── few_shot_coco_ms.py
│   │   │       └── few_shot_voc_ms.py
│   │   ├── default_runtime.py
│   │   ├── models
│   │   │   └── faster_rcnn_r50_caffe_c4.py
│   │   └── schedules
│   │       └── schedule.py
│   └── vfa
│       ├── coco
│       │   ├── vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py
│       │   ├── vfa_r101_c4_8xb4_coco_30shot-fine-tuning.py
│       │   └── vfa_r101_c4_8xb4_coco_base-training.py
│       ├── meta-rcnn_r50_c4.py
│       ├── vfa_r101_c4.py
│       └── voc
│           ├── vfa_split1
│           │   ├── vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py
│           │   └── vfa_r101_c4_8xb4_voc-split1_base-training.py
│           ├── vfa_split2
│           │   ├── vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning.py
│           │   ├── vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning.py
│           │   └── vfa_r101_c4_8xb4_voc-split2_base-training.py
│           └── vfa_split3
│               ├── vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py
│               ├── vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning.py
│               ├── vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning.py
│               ├── vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning.py
│               ├── vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning.py
│               └── vfa_r101_c4_8xb4_voc-split3_base-training.py
├── dist_test.sh
├── dist_train.sh
├── setup.py
├── test.py
├── train.py
└── vfa
    ├── __init__.py
    ├── utils.py
    ├── vfa_bbox_head.py
    ├── vfa_detector.py
    └── vfa_roi_head.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/en/_build/
68 | docs/zh_cn/_build/
69 |
70 | # PyBuilder
71 | target/
72 |
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 |
76 | # pyenv
77 | .python-version
78 |
79 | # celery beat schedule file
80 | celerybeat-schedule
81 |
82 | # SageMath parsed files
83 | *.sage.py
84 |
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 |
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 |
98 | # Rope project settings
99 | .ropeproject
100 |
101 | # mkdocs documentation
102 | /site
103 |
104 | # mypy
105 | .mypy_cache/
106 |
107 | /data/
108 | /data
109 | .vscode
110 | .idea
111 | .DS_Store
112 |
113 | # custom
114 | *.pkl
115 | *.pkl.json
116 | *.log.json
117 | work_dirs/
118 | mmfewshot/.mim
119 |
120 | # Pytorch
121 | *.pth
122 | *.py~
123 | *.sh~
124 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## VFA
2 |
3 |
4 |
5 | > **Few-Shot Object Detection via Variational Feature Aggregation (AAAI 2023)**
6 | > [Jiaming Han](https://csuhan.com), [Yuqiang Ren](https://github.com/Anymake), [Jian Ding](https://dingjiansw101.github.io), [Ke Yan](https://scholar.google.com.hk/citations?user=vWstgn0AAAAJ), [Gui-Song Xia](http://www.captain-whu.com/xia_En.html).
7 | > [arXiv preprint](https://arxiv.org/abs/2301.13411).
8 |
9 | Our code is based on [mmfewshot](https://github.com/open-mmlab/mmfewshot).
10 |
11 | ### Setup
12 |
13 | * **Installation**
14 |
15 | Here is a from-scratch setup script.
16 |
17 | ```bash
18 | conda create -n vfa python=3.8 -y
19 | conda activate vfa
20 |
21 | conda install pytorch==1.7.1 torchvision==0.8.2 cudatoolkit=11.0 -c pytorch
22 |
23 | pip install openmim
24 | mim install mmcv-full==1.3.12
25 |
26 | # install mmclassification and mmdetection
27 | mim install mmcls==0.15.0
28 | mim install mmdet==2.16.0
29 |
30 | # install mmfewshot
31 | mim install mmfewshot==0.1.0
32 |
33 | # install VFA
34 | python setup.py develop
35 |
36 | ```
37 |
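To verify the setup, an optional sanity check is to print the installed versions, which should match the pins above:

```bash
python -c "import mmcv, mmcls, mmdet, mmfewshot; \
  print(mmcv.__version__, mmcls.__version__, mmdet.__version__, mmfewshot.__version__)"
# expected output: 1.3.12 0.15.0 2.16.0 0.1.0
```
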
38 | * **Prepare Datasets**
39 |
40 | Please refer to mmfewshot's [detection data preparation](https://github.com/open-mmlab/mmfewshot/blob/main/tools/data/README.md).
41 |
42 |
43 | ### Model Zoo
44 |
45 | All pretrained models are available from the [GitHub release](https://github.com/csuhan/VFA/releases/tag/v1.0.0).
46 |
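For example, to fine-tune from the released split-1 base-training weights, you can download the checkpoint to the path that the split-1 fine-tuning configs reference via `load_from`:

```bash
# Place the released checkpoint where the fine-tuning configs expect it.
mkdir -p work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training
wget -O work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth \
  https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_base-training_iter_18000.pth
```
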
47 | #### Results on PASCAL VOC dataset
48 |
49 | * **Base Training**
50 |
51 | | Split | Base AP50 | config | ckpt |
52 | |-------|-----------|-----------------------------------------------------------------------------------|------|
53 | | 1 | 78.6 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_base-training_iter_18000.pth) |
54 | | 2 | 79.5 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_base-training_iter_18000.pth) |
55 | | 3 | 79.8 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_base-training_iter_18000.pth) |
56 |
57 | * **Few Shot Fine-tuning**
58 |
59 | | Split | Shot | nAP50 | config | ckpt |
60 | |-------|------|-------|----------------------------------------------------------------------------------------|----------|
61 | | 1 | 1 | 57.5 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning_iter_400.pth) |
62 | | 1 | 2 | 65.0 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning_iter_800.pth) |
63 | | 1 | 3 | 64.3 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning_iter_1200.pth) |
64 | | 1 | 5 | 67.1 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning_iter_1600.pth) |
65 | | 1 | 10 | 67.4 | [config](configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning_iter_2000.pth) |
66 | | 2 | 1 | 40.8 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning_iter_400.pth) |
67 | | 2 | 2 | 45.9 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning_iter_800.pth) |
68 | | 2 | 3 | 51.1 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning_iter_1200.pth) |
69 | | 2 | 5 | 51.8 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning_iter_1600.pth) |
70 | | 2 | 10 | 51.8 | [config](configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning_iter_2000.pth) |
71 | | 3 | 1 | 49.0 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning_iter_400.pth) |
72 | | 3 | 2 | 54.9 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning_iter_800.pth) |
73 | | 3 | 3 | 56.6 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning_iter_1200.pth) |
74 | | 3 | 5 | 59.0 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning_iter_1600.pth) |
75 | | 3 | 10 | 58.5 | [config](configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning_iter_2000.pth) |
76 |
77 |
78 | #### Results on COCO dataset
79 |
80 | * **Base Training**
81 |
82 | | Base mAP | config | ckpt |
83 | |----------|-------------------------------------------------------------------|----------|
84 | | 36.0 | [config](configs/vfa/coco/vfa_r101_c4_8xb4_coco_base-training.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_coco_base-training_iter_110000.pth) |
85 |
86 | * **Few Shot Fine-tuning**
87 |
88 | | Shot | nAP | config | ckpt |
89 | |------|------|------------------------------------------------------------------------|----------|
90 | | 10 | 16.8 | [config](configs/vfa/coco/vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_coco_10shot-fine-tuning_iter_10000.pth) |
91 | | 30 | 19.5 | [config](configs/vfa/coco/vfa_r101_c4_8xb4_coco_30shot-fine-tuning.py) | [ckpt](https://github.com/csuhan/VFA/releases/download/v1.0.0/vfa_r101_c4_8xb4_coco_30shot-fine-tuning_iter_20000.pth) |
92 |
93 |
94 | ### Train and Test
95 |
96 | * **Testing**
97 |
98 | ```bash
99 | # single-GPU test
100 | python test.py ${CONFIG} ${CHECKPOINT} --eval mAP|bbox
101 |
102 | # multi-GPU test
103 | bash dist_test.sh ${CONFIG} ${CHECKPOINT} ${NUM_GPU} --eval mAP|bbox
104 | ```
105 |
106 | For example (use `mAP` for VOC and `bbox` for COCO):
107 | * To test VFA on VOC split1 1-shot with a single GPU, run:
108 | ```bash
109 | python test.py configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py \
110 | work_dirs/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning/iter_400.pth \
111 | --eval mAP
112 | ```
113 |
114 | * To test VFA on COCO 10-shot with 8 GPUs, run:
115 | ```bash
116 | bash dist_test.sh configs/vfa/coco/vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py \
117 | work_dirs/vfa_r101_c4_8xb4_coco_10shot-fine-tuning/iter_10000.pth \
118 | 8 --eval bbox
119 | ```
120 |
121 | * **Training**
122 |
123 | ```bash
124 | # single-GPU training
125 | python train.py ${CONFIG}
126 |
127 | # multi-GPU training
128 | bash dist_train.sh ${CONFIG} ${NUM_GPU}
129 | ```
130 |
131 | For example, to train VFA on VOC:
132 | ```bash
133 | # Stage I: base training on all three splits.
134 | for s in 1 2 3; do bash dist_train.sh configs/vfa/voc/vfa_split${s}/vfa_r101_c4_8xb4_voc-split${s}_base-training.py 8; done
135 |
136 | # Stage II: few-shot fine-tuning on all splits and shots.
137 | voc_config_dir=configs/vfa/voc/
138 | for split in 1 2 3; do
139 | for shot in 1 2 3 5 10; do
140 | config_path=${voc_config_dir}/vfa_split${split}/vfa_r101_c4_8xb4_voc-split${split}_${shot}shot-fine-tuning.py
141 | echo $config_path
142 | bash dist_train.sh $config_path 8
143 | done
144 | done
145 | ```
146 |
147 | **Note**: All our configs and models are trained with 8 GPUs. If you use a different number of GPUs, adjust the learning rate or batch size accordingly, e.g. following the linear scaling rule.
148 |
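For instance, assuming `train.py` follows mmfewshot's training script and accepts `--cfg-options` (extra arguments are passed through by `dist_train.sh`), a 4-GPU run could scale the base-training learning rate down accordingly; treat this as a sketch, not a tested recipe:

```bash
# Hypothetical 4-GPU base training: halve the 8-GPU learning rate (0.02 -> 0.01)
# following the linear scaling rule.
bash dist_train.sh configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_base-training.py 4 \
    --cfg-options optimizer.lr=0.01
```
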
149 |
150 | ### Citation
151 |
152 | If you find our work useful for your research, please consider citing:
153 |
154 | ```BibTeX
155 | @InProceedings{han2023vfa,
156 | title = {Few-Shot Object Detection via Variational Feature Aggregation},
157 | author = {Han, Jiaming and Ren, Yuqiang and Ding, Jian and Yan, Ke and Xia, Gui-Song},
158 | booktitle = {Proceedings of the 37th AAAI Conference on Artificial Intelligence (AAAI-23)},
159 | year = {2023}
160 | }
161 | ```
--------------------------------------------------------------------------------
/VFA.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csuhan/VFA/e35411eb22b4fc48b524debe58dc7c09be2bf9a6/VFA.png
--------------------------------------------------------------------------------
/configs/_base_/datasets/nway_kshot/base_coco_ms.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | img_norm_cfg = dict(
3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
4 | train_multi_pipelines = dict(
5 | query=[
6 | dict(type='LoadImageFromFile'),
7 | dict(type='LoadAnnotations', with_bbox=True),
8 | dict(
9 | type='Resize',
10 | img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
11 | (1333, 768), (1333, 800)],
12 | keep_ratio=True,
13 | multiscale_mode='value'),
14 | dict(type='RandomFlip', flip_ratio=0.5),
15 | dict(type='Normalize', **img_norm_cfg),
16 | dict(type='Pad', size_divisor=32),
17 | dict(type='DefaultFormatBundle'),
18 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
19 | ],
20 | support=[
21 | dict(type='LoadImageFromFile'),
22 | dict(type='LoadAnnotations', with_bbox=True),
23 | dict(type='Normalize', **img_norm_cfg),
24 | dict(type='GenerateMask', target_size=(224, 224)),
25 | dict(type='RandomFlip', flip_ratio=0.0),
26 | dict(type='DefaultFormatBundle'),
27 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
28 | ])
29 | test_pipeline = [
30 | dict(type='LoadImageFromFile'),
31 | dict(
32 | type='MultiScaleFlipAug',
33 | img_scale=(1333, 800),
34 | flip=False,
35 | transforms=[
36 | dict(type='Resize', keep_ratio=True),
37 | dict(type='RandomFlip'),
38 | dict(type='Normalize', **img_norm_cfg),
39 | dict(type='Pad', size_divisor=32),
40 | dict(type='ImageToTensor', keys=['img']),
41 | dict(type='Collect', keys=['img'])
42 | ])
43 | ]
44 | # classes splits are predefined in FewShotCocoDataset
45 | data_root = 'data/coco/'
46 | data = dict(
47 | samples_per_gpu=2,
48 | workers_per_gpu=2,
49 | train=dict(
50 | type='NWayKShotDataset',
51 | num_support_ways=60,
52 | num_support_shots=1,
53 | one_support_shot_per_image=True,
54 | num_used_support_shots=200,
55 | save_dataset=False,
56 | dataset=dict(
57 | type='FewShotCocoDataset',
58 | ann_cfg=[
59 | dict(
60 | type='ann_file',
61 | ann_file='data/few_shot_ann/coco/annotations/train.json')
62 | ],
63 | img_prefix=data_root,
64 | multi_pipelines=train_multi_pipelines,
65 | classes='BASE_CLASSES',
66 | instance_wise=False,
67 | dataset_name='query_support_dataset')),
68 | val=dict(
69 | type='FewShotCocoDataset',
70 | ann_cfg=[
71 | dict(
72 | type='ann_file',
73 | ann_file='data/few_shot_ann/coco/annotations/val.json')
74 | ],
75 | img_prefix=data_root,
76 | pipeline=test_pipeline,
77 | classes='BASE_CLASSES'),
78 | test=dict(
79 | type='FewShotCocoDataset',
80 | ann_cfg=[
81 | dict(
82 | type='ann_file',
83 | ann_file='data/few_shot_ann/coco/annotations/val.json')
84 | ],
85 | img_prefix=data_root,
86 | pipeline=test_pipeline,
87 | test_mode=True,
88 | classes='BASE_CLASSES'),
89 | model_init=dict(
90 | copy_from_train_dataset=True,
91 | samples_per_gpu=16,
92 | workers_per_gpu=1,
93 | type='FewShotCocoDataset',
94 | ann_cfg=None,
95 | img_prefix=data_root,
96 | pipeline=train_multi_pipelines['support'],
97 | instance_wise=True,
98 | classes='BASE_CLASSES',
99 | dataset_name='model_init_dataset'))
100 | evaluation = dict(interval=20000, metric='bbox', classwise=True)
101 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/nway_kshot/base_voc_ms.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | img_norm_cfg = dict(
3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
4 | train_multi_pipelines = dict(
5 | query=[
6 | dict(type='LoadImageFromFile'),
7 | dict(type='LoadAnnotations', with_bbox=True),
8 | dict(
9 | type='Resize',
10 | img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576),
11 | (1333, 608), (1333, 640), (1333, 672), (1333, 704),
12 | (1333, 736), (1333, 768), (1333, 800)],
13 | keep_ratio=True,
14 | multiscale_mode='value'),
15 | dict(type='RandomFlip', flip_ratio=0.5),
16 | dict(type='Normalize', **img_norm_cfg),
17 | dict(type='Pad', size_divisor=32),
18 | dict(type='DefaultFormatBundle'),
19 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
20 | ],
21 | support=[
22 | dict(type='LoadImageFromFile'),
23 | dict(type='LoadAnnotations', with_bbox=True),
24 | dict(type='Normalize', **img_norm_cfg),
25 | dict(type='GenerateMask', target_size=(224, 224)),
26 | dict(type='RandomFlip', flip_ratio=0.0),
27 | dict(type='DefaultFormatBundle'),
28 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
29 | ])
30 | test_pipeline = [
31 | dict(type='LoadImageFromFile'),
32 | dict(
33 | type='MultiScaleFlipAug',
34 | img_scale=(1333, 800),
35 | flip=False,
36 | transforms=[
37 | dict(type='Resize', keep_ratio=True),
38 | dict(type='RandomFlip'),
39 | dict(type='Normalize', **img_norm_cfg),
40 | dict(type='Pad', size_divisor=32),
41 | dict(type='ImageToTensor', keys=['img']),
42 | dict(type='Collect', keys=['img'])
43 | ])
44 | ]
45 | # classes splits are predefined in FewShotVOCDataset
46 | data_root = 'data/VOCdevkit/'
47 | data = dict(
48 | samples_per_gpu=4,
49 | workers_per_gpu=4,
50 | train=dict(
51 | type='NWayKShotDataset',
52 | num_support_ways=15,
53 | num_support_shots=1,
54 | one_support_shot_per_image=True,
55 | num_used_support_shots=200,
56 | save_dataset=False,
57 | dataset=dict(
58 | type='FewShotVOCDataset',
59 | ann_cfg=[
60 | dict(
61 | type='ann_file',
62 | ann_file=data_root +
63 | 'VOC2007/ImageSets/Main/trainval.txt'),
64 | dict(
65 | type='ann_file',
66 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt')
67 | ],
68 | img_prefix=data_root,
69 | multi_pipelines=train_multi_pipelines,
70 | classes=None,
71 | use_difficult=True,
72 | instance_wise=False,
73 | dataset_name='query_dataset'),
74 | support_dataset=dict(
75 | type='FewShotVOCDataset',
76 | ann_cfg=[
77 | dict(
78 | type='ann_file',
79 | ann_file=data_root +
80 | 'VOC2007/ImageSets/Main/trainval.txt'),
81 | dict(
82 | type='ann_file',
83 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt')
84 | ],
85 | img_prefix=data_root,
86 | multi_pipelines=train_multi_pipelines,
87 | classes=None,
88 | use_difficult=False,
89 | instance_wise=False,
90 | dataset_name='support_dataset')),
91 | val=dict(
92 | type='FewShotVOCDataset',
93 | ann_cfg=[
94 | dict(
95 | type='ann_file',
96 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt')
97 | ],
98 | img_prefix=data_root,
99 | pipeline=test_pipeline,
100 | classes=None),
101 | test=dict(
102 | type='FewShotVOCDataset',
103 | ann_cfg=[
104 | dict(
105 | type='ann_file',
106 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt')
107 | ],
108 | img_prefix=data_root,
109 | pipeline=test_pipeline,
110 | test_mode=True,
111 | classes=None),
112 | model_init=dict(
113 | copy_from_train_dataset=True,
114 | samples_per_gpu=16,
115 | workers_per_gpu=1,
116 | type='FewShotVOCDataset',
117 | ann_cfg=None,
118 | img_prefix=data_root,
119 | pipeline=train_multi_pipelines['support'],
120 | use_difficult=False,
121 | instance_wise=True,
122 | classes=None,
123 | dataset_name='model_init_dataset'))
124 | evaluation = dict(interval=5000, metric='mAP')
125 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/nway_kshot/few_shot_coco_ms.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | img_norm_cfg = dict(
3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
4 | train_multi_pipelines = dict(
5 | query=[
6 | dict(type='LoadImageFromFile'),
7 | dict(type='LoadAnnotations', with_bbox=True),
8 | dict(
9 | type='Resize',
10 | img_scale=[(1333, 640), (1333, 672), (1333, 704), (1333, 736),
11 | (1333, 768), (1333, 800)],
12 | keep_ratio=True,
13 | multiscale_mode='value'),
14 | dict(type='RandomFlip', flip_ratio=0.5),
15 | dict(type='Normalize', **img_norm_cfg),
16 | dict(type='Pad', size_divisor=32),
17 | dict(type='DefaultFormatBundle'),
18 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
19 | ],
20 | support=[
21 | dict(type='LoadImageFromFile'),
22 | dict(type='LoadAnnotations', with_bbox=True),
23 | dict(type='Normalize', **img_norm_cfg),
24 | dict(type='GenerateMask', target_size=(224, 224)),
25 | dict(type='RandomFlip', flip_ratio=0.0),
26 | dict(type='DefaultFormatBundle'),
27 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
28 | ])
29 | test_pipeline = [
30 | dict(type='LoadImageFromFile'),
31 | dict(
32 | type='MultiScaleFlipAug',
33 | img_scale=(1333, 800),
34 | flip=False,
35 | transforms=[
36 | dict(type='Resize', keep_ratio=True),
37 | dict(type='RandomFlip'),
38 | dict(type='Normalize', **img_norm_cfg),
39 | dict(type='Pad', size_divisor=32),
40 | dict(type='ImageToTensor', keys=['img']),
41 | dict(type='Collect', keys=['img'])
42 | ])
43 | ]
44 | # classes splits are predefined in FewShotCocoDataset
45 | data_root = 'data/coco/'
46 | data = dict(
47 | samples_per_gpu=2,
48 | workers_per_gpu=2,
49 | train=dict(
50 | type='NWayKShotDataset',
51 | num_support_ways=80,
52 | num_support_shots=1,
53 | one_support_shot_per_image=False,
54 | num_used_support_shots=30,
55 | save_dataset=True,
56 | dataset=dict(
57 | type='FewShotCocoDataset',
58 | ann_cfg=[
59 | dict(
60 | type='ann_file',
61 | ann_file='data/few_shot_ann/coco/annotations/train.json')
62 | ],
63 | img_prefix=data_root,
64 | multi_pipelines=train_multi_pipelines,
65 | classes='ALL_CLASSES',
66 | instance_wise=False,
67 | dataset_name='query_support_dataset')),
68 | val=dict(
69 | type='FewShotCocoDataset',
70 | ann_cfg=[
71 | dict(
72 | type='ann_file',
73 | ann_file='data/few_shot_ann/coco/annotations/val.json')
74 | ],
75 | img_prefix=data_root,
76 | pipeline=test_pipeline,
77 | classes='ALL_CLASSES'),
78 | test=dict(
79 | type='FewShotCocoDataset',
80 | ann_cfg=[
81 | dict(
82 | type='ann_file',
83 | ann_file='data/few_shot_ann/coco/annotations/val.json')
84 | ],
85 | img_prefix=data_root,
86 | pipeline=test_pipeline,
87 | test_mode=True,
88 | classes='ALL_CLASSES'),
89 | model_init=dict(
90 | copy_from_train_dataset=True,
91 | samples_per_gpu=16,
92 | workers_per_gpu=1,
93 | type='FewShotCocoDataset',
94 | ann_cfg=None,
95 | img_prefix=data_root,
96 | pipeline=train_multi_pipelines['support'],
97 | instance_wise=True,
98 | classes='ALL_CLASSES',
99 | dataset_name='model_init_dataset'))
100 | evaluation = dict(
101 | interval=3000,
102 | metric='bbox',
103 | classwise=True,
104 | class_splits=['BASE_CLASSES', 'NOVEL_CLASSES'])
105 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/nway_kshot/few_shot_voc_ms.py:
--------------------------------------------------------------------------------
1 | # dataset settings
2 | img_norm_cfg = dict(
3 | mean=[103.530, 116.280, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
4 | train_multi_pipelines = dict(
5 | query=[
6 | dict(type='LoadImageFromFile'),
7 | dict(type='LoadAnnotations', with_bbox=True),
8 | dict(
9 | type='Resize',
10 | img_scale=[(1333, 480), (1333, 512), (1333, 544), (1333, 576),
11 | (1333, 608), (1333, 640), (1333, 672), (1333, 704),
12 | (1333, 736), (1333, 768), (1333, 800)],
13 | keep_ratio=True,
14 | multiscale_mode='value'),
15 | dict(type='RandomFlip', flip_ratio=0.5),
16 | dict(type='Normalize', **img_norm_cfg),
17 | dict(type='Pad', size_divisor=32),
18 | dict(type='DefaultFormatBundle'),
19 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
20 | ],
21 | support=[
22 | dict(type='LoadImageFromFile'),
23 | dict(type='LoadAnnotations', with_bbox=True),
24 | dict(type='Normalize', **img_norm_cfg),
25 | dict(type='GenerateMask', target_size=(224, 224)),
26 | dict(type='RandomFlip', flip_ratio=0.0),
27 | dict(type='DefaultFormatBundle'),
28 | dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
29 | ])
30 | test_pipeline = [
31 | dict(type='LoadImageFromFile'),
32 | dict(
33 | type='MultiScaleFlipAug',
34 | img_scale=(1333, 800),
35 | flip=False,
36 | transforms=[
37 | dict(type='Resize', keep_ratio=True),
38 | dict(type='RandomFlip'),
39 | dict(type='Normalize', **img_norm_cfg),
40 | dict(type='Pad', size_divisor=32),
41 | dict(type='ImageToTensor', keys=['img']),
42 | dict(type='Collect', keys=['img'])
43 | ])
44 | ]
45 | # classes splits are predefined in FewShotVOCDataset
46 | data_root = 'data/VOCdevkit/'
47 | data = dict(
48 | samples_per_gpu=4,
49 | workers_per_gpu=4,
50 | train=dict(
51 | type='NWayKShotDataset',
52 | num_support_ways=20,
53 | num_support_shots=1,
54 | one_support_shot_per_image=False,
55 | num_used_support_shots=30,
56 | save_dataset=True,
57 | dataset=dict(
58 | type='FewShotVOCDataset',
59 | ann_cfg=[
60 | dict(
61 | type='ann_file',
62 | ann_file=data_root +
63 | 'VOC2007/ImageSets/Main/trainval.txt'),
64 | dict(
65 | type='ann_file',
66 | ann_file=data_root + 'VOC2012/ImageSets/Main/trainval.txt')
67 | ],
68 | img_prefix=data_root,
69 | multi_pipelines=train_multi_pipelines,
70 | classes=None,
71 | use_difficult=False,
72 | instance_wise=False,
73 | dataset_name='query_support_dataset')),
74 | val=dict(
75 | type='FewShotVOCDataset',
76 | ann_cfg=[
77 | dict(
78 | type='ann_file',
79 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt')
80 | ],
81 | img_prefix=data_root,
82 | pipeline=test_pipeline,
83 | classes=None),
84 | test=dict(
85 | type='FewShotVOCDataset',
86 | ann_cfg=[
87 | dict(
88 | type='ann_file',
89 | ann_file=data_root + 'VOC2007/ImageSets/Main/test.txt')
90 | ],
91 | img_prefix=data_root,
92 | pipeline=test_pipeline,
93 | test_mode=True,
94 | classes=None),
95 | model_init=dict(
96 | copy_from_train_dataset=True,
97 | samples_per_gpu=16,
98 | workers_per_gpu=1,
99 | type='FewShotVOCDataset',
100 | ann_cfg=None,
101 | img_prefix=data_root,
102 | pipeline=train_multi_pipelines['support'],
103 | use_difficult=False,
104 | instance_wise=True,
105 | num_novel_shots=None,
106 | classes=None,
107 | dataset_name='model_init_dataset'))
108 | evaluation = dict(interval=3000, metric='mAP', class_splits=None)
109 |
--------------------------------------------------------------------------------
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
1 | checkpoint_config = dict(interval=5000)
2 | # yapf:disable
3 | log_config = dict(
4 | interval=50,
5 | hooks=[
6 | dict(type='TextLoggerHook'),
7 | # dict(type='TensorboardLoggerHook')
8 | ])
9 | # yapf:enable
10 | custom_hooks = [dict(type='NumClassCheckHook')]
11 |
12 | dist_params = dict(backend='nccl')
13 | log_level = 'INFO'
14 | load_from = None
15 | resume_from = None
16 | workflow = [('train', 1)]
17 | use_infinite_sampler = True
18 | # a magical seed works well in most cases for this repo!!!
19 | # using different seeds might raise some issues about reproducibility
20 | seed = 42
21 |
--------------------------------------------------------------------------------
/configs/_base_/models/faster_rcnn_r50_caffe_c4.py:
--------------------------------------------------------------------------------
1 | # model settings
2 | norm_cfg = dict(type='BN', requires_grad=False)
3 | pretrained = 'open-mmlab://detectron2/resnet50_caffe'
4 | model = dict(
5 | type='FasterRCNN',
6 | pretrained=pretrained,
7 | backbone=dict(
8 | type='ResNet',
9 | depth=50,
10 | num_stages=3,
11 | strides=(1, 2, 2),
12 | dilations=(1, 1, 1),
13 | out_indices=(2, ),
14 | frozen_stages=1,
15 | norm_cfg=norm_cfg,
16 | norm_eval=True,
17 | style='caffe'),
18 | rpn_head=dict(
19 | type='RPNHead',
20 | in_channels=1024,
21 | feat_channels=1024,
22 | anchor_generator=dict(
23 | type='AnchorGenerator',
24 | scales=[2, 4, 8, 16, 32],
25 | ratios=[0.5, 1.0, 2.0],
26 | scale_major=False,
27 | strides=[16]),
28 | bbox_coder=dict(
29 | type='DeltaXYWHBBoxCoder',
30 | target_means=[.0, .0, .0, .0],
31 | target_stds=[1.0, 1.0, 1.0, 1.0]),
32 | loss_cls=dict(
33 | type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
34 | loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
35 | roi_head=dict(
36 | type='StandardRoIHead',
37 | shared_head=dict(
38 | type='ResLayer',
39 | pretrained=pretrained,
40 | depth=50,
41 | stage=3,
42 | stride=2,
43 | dilation=1,
44 | style='caffe',
45 | norm_cfg=norm_cfg,
46 | norm_eval=True),
47 | bbox_roi_extractor=dict(
48 | type='SingleRoIExtractor',
49 | roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
50 | out_channels=1024,
51 | featmap_strides=[16]),
52 | bbox_head=dict(
53 | type='BBoxHead',
54 | with_avg_pool=True,
55 | roi_feat_size=7,
56 | in_channels=2048,
57 | num_classes=80,
58 | bbox_coder=dict(
59 | type='DeltaXYWHBBoxCoder',
60 | target_means=[0., 0., 0., 0.],
61 | target_stds=[0.1, 0.1, 0.2, 0.2]),
62 | reg_class_agnostic=False,
63 | loss_cls=dict(
64 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
65 | loss_bbox=dict(type='L1Loss', loss_weight=1.0))),
66 | # model training and testing settings
67 | train_cfg=dict(
68 | rpn=dict(
69 | assigner=dict(
70 | type='MaxIoUAssigner',
71 | pos_iou_thr=0.7,
72 | neg_iou_thr=0.3,
73 | min_pos_iou=0.3,
74 | match_low_quality=True,
75 | ignore_iof_thr=-1),
76 | sampler=dict(
77 | type='RandomSampler',
78 | num=256,
79 | pos_fraction=0.5,
80 | neg_pos_ub=-1,
81 | add_gt_as_proposals=False),
82 | allowed_border=0,
83 | pos_weight=-1,
84 | debug=False),
85 | rpn_proposal=dict(
86 | nms_pre=12000,
87 | max_per_img=2000,
88 | nms=dict(type='nms', iou_threshold=0.7),
89 | min_bbox_size=0),
90 | rcnn=dict(
91 | assigner=dict(
92 | type='MaxIoUAssigner',
93 | pos_iou_thr=0.5,
94 | neg_iou_thr=0.5,
95 | min_pos_iou=0.5,
96 | match_low_quality=False,
97 | ignore_iof_thr=-1),
98 | sampler=dict(
99 | type='RandomSampler',
100 | num=512,
101 | pos_fraction=0.25,
102 | neg_pos_ub=-1,
103 | add_gt_as_proposals=True),
104 | pos_weight=-1,
105 | debug=False)),
106 | test_cfg=dict(
107 | rpn=dict(
108 | nms_pre=6000,
109 | max_per_img=1000,
110 | nms=dict(type='nms', iou_threshold=0.7),
111 | min_bbox_size=0),
112 | rcnn=dict(
113 | score_thr=0.05,
114 | nms=dict(type='nms', iou_threshold=0.5),
115 | max_per_img=100)))
116 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/schedule.py:
--------------------------------------------------------------------------------
1 | # optimizer
2 | optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
3 | optimizer_config = dict(grad_clip=None)
4 | # learning policy
5 | lr_config = dict(
6 | policy='step',
7 | warmup='linear',
8 | warmup_iters=500,
9 | warmup_ratio=0.001,
10 | step=[60000, 80000])
11 | runner = dict(type='IterBasedRunner', max_iters=90000)
12 |
--------------------------------------------------------------------------------
/configs/vfa/coco/vfa_r101_c4_8xb4_coco_10shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/datasets/nway_kshot/few_shot_coco_ms.py',
3 | '../../_base_/schedules/schedule.py', '../vfa_r101_c4.py',
4 | '../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotCocoDataset
7 | # FewShotCocoDefaultDataset predefine ann_cfg for model reproducibility
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | num_used_support_shots=10,
12 | dataset=dict(
13 | type='FewShotCocoDefaultDataset',
14 | ann_cfg=[dict(method='MetaRCNN', setting='10SHOT')],
15 | num_novel_shots=10,
16 | num_base_shots=10,
17 | )),
18 | model_init=dict(num_novel_shots=10, num_base_shots=10))
19 | evaluation = dict(interval=10000)
20 | checkpoint_config = dict(interval=10000)
21 | optimizer = dict(lr=0.001)
22 | lr_config = dict(warmup=None, step=[10000])
23 | runner = dict(max_iters=10000)
24 | # load_from = 'path of base training model'
25 | load_from = 'work_dirs/vfa_r101_c4_8xb4_coco_base-training/latest.pth'
26 | # model settings
27 | model = dict(
28 | roi_head=dict(bbox_head=dict(num_classes=80, num_meta_classes=80)),
29 | with_refine=True,
30 | frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head.rpn_conv',
32 | ])
33 |
34 | # iter 10000
35 | # OrderedDict([('BASE_CLASSES bbox_mAP', 0.309), ('BASE_CLASSES bbox_mAP_50', 0.509), ('BASE_CLASSES bbox_mAP_75', 0.33), ('BASE_CLASSES bbox_mAP_s', 0.165), ('BASE_CLASSES bbox_mAP_m', 0.351), ('BASE_CLASSES bbox_mAP_l', 0.442), ('BASE_CLASSES bbox_mAP_copypaste', '0.309 0.509 0.330 0.165 0.351 0.442'), ('NOVEL_CLASSES bbox_mAP', 0.168), ('NOVEL_CLASSES bbox_mAP_50', 0.363), ('NOVEL_CLASSES bbox_mAP_75', 0.136), ('NOVEL_CLASSES bbox_mAP_s', 0.074), ('NOVEL_CLASSES bbox_mAP_m', 0.176), ('NOVEL_CLASSES bbox_mAP_l', 0.253), ('NOVEL_CLASSES bbox_mAP_copypaste', '0.168 0.363 0.136 0.074 0.176 0.253'), ('bbox_mAP', 0.273), ('bbox_mAP_50', 0.473), ('bbox_mAP_75', 0.282), ('bbox_mAP_s', 0.142), ('bbox_mAP_m', 0.308), ('bbox_mAP_l', 0.395), ('bbox_mAP_copypaste', '0.273 0.473 0.282 0.142 0.308 0.395')])
--------------------------------------------------------------------------------
/configs/vfa/coco/vfa_r101_c4_8xb4_coco_30shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/datasets/nway_kshot/few_shot_coco_ms.py',
3 | '../../_base_/schedules/schedule.py', '../vfa_r101_c4.py',
4 | '../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotCocoDataset
7 | # FewShotCocoDefaultDataset predefine ann_cfg for model reproducibility
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | num_used_support_shots=30,
12 | dataset=dict(
13 | type='FewShotCocoDefaultDataset',
14 | ann_cfg=[dict(method='MetaRCNN', setting='30SHOT')],
15 | num_novel_shots=30,
16 | num_base_shots=30,
17 | )),
18 | model_init=dict(num_novel_shots=30, num_base_shots=30))
19 | evaluation = dict(interval=20000)
20 | checkpoint_config = dict(interval=20000)
21 | optimizer = dict(lr=0.001)
22 | lr_config = dict(warmup=None, step=[20000])
23 | runner = dict(max_iters=20000)
24 | # load_from = 'path of base training model'
25 | load_from = 'work_dirs/vfa_r101_c4_8xb4_coco_base-training/latest.pth'
26 | # model settings
27 | model = dict(
28 | roi_head=dict(bbox_head=dict(num_classes=80, num_meta_classes=80)),
29 | with_refine=True,
30 | frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head.rpn_conv',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/coco/vfa_r101_c4_8xb4_coco_base-training.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../_base_/datasets/nway_kshot/base_coco_ms.py',
3 | '../../_base_/schedules/schedule.py', '../vfa_r101_c4.py',
4 | '../../_base_/default_runtime.py'
5 | ]
6 | lr_config = dict(warmup_iters=1000, step=[85000, 100000])
7 | evaluation = dict(interval=10000)
8 | checkpoint_config = dict(interval=10000)
9 | runner = dict(max_iters=110000)
10 | optimizer = dict(lr=0.005)
11 | # model settings
12 | model = dict(roi_head=dict(bbox_head=dict(num_classes=60, num_meta_classes=60)))
--------------------------------------------------------------------------------
/configs/vfa/meta-rcnn_r50_c4.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../_base_/models/faster_rcnn_r50_caffe_c4.py',
3 | ]
4 | # model settings
5 | model = dict(
6 | type='MetaRCNN',
7 | backbone=dict(type='ResNetWithMetaConv', frozen_stages=2),
8 | rpn_head=dict(
9 | feat_channels=512, loss_cls=dict(use_sigmoid=False, loss_weight=1.0)),
10 | roi_head=dict(
11 | type='MetaRCNNRoIHead',
12 | shared_head=dict(type='MetaRCNNResLayer'),
13 | bbox_head=dict(
14 | type='MetaBBoxHead',
15 | with_avg_pool=False,
16 | in_channels=2048,
17 | roi_feat_size=1,
18 | num_classes=80,
19 | num_meta_classes=80,
20 | meta_cls_in_channels=2048,
21 | with_meta_cls_loss=True,
22 | loss_meta=dict(
23 | type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
24 | loss_bbox=dict(type='SmoothL1Loss', loss_weight=1.0)),
25 | aggregation_layer=dict(
26 | type='AggregationLayer',
27 | aggregator_cfgs=[
28 | dict(
29 | type='DotProductAggregator',
30 | in_channels=2048,
31 | with_fc=False)
32 | ])),
33 | train_cfg=dict(
34 | rpn=dict(
35 | assigner=dict(
36 | type='MaxIoUAssigner',
37 | pos_iou_thr=0.7,
38 | neg_iou_thr=0.3,
39 | min_pos_iou=0.3,
40 | match_low_quality=True,
41 | ignore_iof_thr=-1),
42 | sampler=dict(
43 | type='RandomSampler',
44 | num=256,
45 | pos_fraction=0.5,
46 | neg_pos_ub=-1,
47 | add_gt_as_proposals=False),
48 | allowed_border=0,
49 | pos_weight=-1,
50 | debug=False),
51 | rpn_proposal=dict(
52 | nms_pre=12000,
53 | max_per_img=2000,
54 | nms=dict(type='nms', iou_threshold=0.7),
55 | min_bbox_size=0),
56 | rcnn=dict(
57 | assigner=dict(
58 | type='MaxIoUAssigner',
59 | pos_iou_thr=0.5,
60 | neg_iou_thr=0.5,
61 | min_pos_iou=0.5,
62 | match_low_quality=False,
63 | ignore_iof_thr=-1),
64 | sampler=dict(
65 | type='RandomSampler',
66 | num=128,
67 | pos_fraction=0.25,
68 | neg_pos_ub=-1,
69 | add_gt_as_proposals=True),
70 | pos_weight=-1,
71 | debug=False)),
72 | test_cfg=dict(
73 | rpn=dict(
74 | nms_pre=6000,
75 | max_per_img=300,
76 | nms=dict(type='nms', iou_threshold=0.7),
77 | min_bbox_size=0),
78 | rcnn=dict(
79 | score_thr=0.05,
80 | nms=dict(type='nms', iou_threshold=0.3),
81 | max_per_img=100)))
82 |
--------------------------------------------------------------------------------
/configs/vfa/vfa_r101_c4.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | './meta-rcnn_r50_c4.py',
3 | ]
4 |
5 | custom_imports = dict(
6 | imports=[
7 | 'vfa.vfa_detector',
8 | 'vfa.vfa_roi_head',
9 | 'vfa.vfa_bbox_head'],
10 | allow_failed_imports=False)
11 |
12 | pretrained = 'open-mmlab://detectron2/resnet101_caffe'
13 | # model settings
14 | model = dict(
15 | type='VFA',
16 | pretrained=pretrained,
17 | backbone=dict(depth=101),
18 | roi_head=dict(
19 | type='VFARoIHead',
20 | shared_head=dict(pretrained=pretrained),
21 | bbox_head=dict(
22 | type='VFABBoxHead', num_classes=20, num_meta_classes=20)))
23 |
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_10shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_10SHOT')],
14 | num_novel_shots=10,
15 | num_base_shots=10,
16 | classes='ALL_CLASSES_SPLIT1',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT1'),
19 | test=dict(classes='ALL_CLASSES_SPLIT1'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT1'))
21 | evaluation = dict(
22 | interval=2000, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
23 | checkpoint_config = dict(interval=2000)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=2000)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_1shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_1SHOT')],
14 | num_novel_shots=1,
15 | num_base_shots=1,
16 | classes='ALL_CLASSES_SPLIT1',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT1'),
19 | test=dict(classes='ALL_CLASSES_SPLIT1'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT1'))
21 | evaluation = dict(
22 | interval=400, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
23 | checkpoint_config = dict(interval=400)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=400)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_2shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_2SHOT')],
14 | num_novel_shots=2,
15 | num_base_shots=2,
16 | classes='ALL_CLASSES_SPLIT1',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT1'),
19 | test=dict(classes='ALL_CLASSES_SPLIT1'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT1'))
21 | evaluation = dict(
22 | interval=800, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
23 | checkpoint_config = dict(interval=800)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=800)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_3shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_3SHOT')],
14 | num_novel_shots=3,
15 | num_base_shots=3,
16 | classes='ALL_CLASSES_SPLIT1',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT1'),
19 | test=dict(classes='ALL_CLASSES_SPLIT1'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT1'))
21 | evaluation = dict(
22 | interval=1200, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
23 | checkpoint_config = dict(interval=1200)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=1200)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_5shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT1_5SHOT')],
14 | num_novel_shots=5,
15 | num_base_shots=5,
16 | classes='ALL_CLASSES_SPLIT1',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT1'),
19 | test=dict(classes='ALL_CLASSES_SPLIT1'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT1'))
21 | evaluation = dict(
22 | interval=1600, class_splits=['BASE_CLASSES_SPLIT1', 'NOVEL_CLASSES_SPLIT1'])
23 | checkpoint_config = dict(interval=1600)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=1600)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split1_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split1/vfa_r101_c4_8xb4_voc-split1_base-training.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=False,
11 | dataset=dict(classes='BASE_CLASSES_SPLIT1'),
12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT1')),
13 | val=dict(classes='BASE_CLASSES_SPLIT1'),
14 | test=dict(classes='BASE_CLASSES_SPLIT1'),
15 | model_init=dict(classes='BASE_CLASSES_SPLIT1'))
16 | lr_config = dict(warmup_iters=100, step=[12000, 16000])
17 | evaluation = dict(interval=3000)
18 | checkpoint_config = dict(interval=3000)
19 | runner = dict(max_iters=18000)
20 | optimizer = dict(lr=0.02)
21 | # model settings
22 | model = dict(roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15)))
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_10shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_10SHOT')],
14 | num_novel_shots=10,
15 | num_base_shots=10,
16 | classes='ALL_CLASSES_SPLIT2',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT2'),
19 | test=dict(classes='ALL_CLASSES_SPLIT2'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT2'))
21 | evaluation = dict(
22 | interval=2000, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
23 | checkpoint_config = dict(interval=2000)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=2000)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_1shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_1SHOT')],
14 | num_novel_shots=1,
15 | num_base_shots=1,
16 | classes='ALL_CLASSES_SPLIT2',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT2'),
19 | test=dict(classes='ALL_CLASSES_SPLIT2'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT2'))
21 | evaluation = dict(
22 | interval=400, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
23 | checkpoint_config = dict(interval=400)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=400)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_2shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_2SHOT')],
14 | num_novel_shots=2,
15 | num_base_shots=2,
16 | classes='ALL_CLASSES_SPLIT2',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT2'),
19 | test=dict(classes='ALL_CLASSES_SPLIT2'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT2'))
21 | evaluation = dict(
22 | interval=800, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
23 | checkpoint_config = dict(interval=800)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=800)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_3shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_3SHOT')],
14 | num_novel_shots=3,
15 | num_base_shots=3,
16 | classes='ALL_CLASSES_SPLIT2',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT2'),
19 | test=dict(classes='ALL_CLASSES_SPLIT2'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT2'))
21 | evaluation = dict(
22 | interval=1200, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
23 | checkpoint_config = dict(interval=1200)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=1200)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_5shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT2_5SHOT')],
14 | num_novel_shots=5,
15 | num_base_shots=5,
16 | classes='ALL_CLASSES_SPLIT2',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT2'),
19 | test=dict(classes='ALL_CLASSES_SPLIT2'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT2'))
21 | evaluation = dict(
22 | interval=1600, class_splits=['BASE_CLASSES_SPLIT2', 'NOVEL_CLASSES_SPLIT2'])
23 | checkpoint_config = dict(interval=1600)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=1600)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split2_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split2/vfa_r101_c4_8xb4_voc-split2_base-training.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # classes splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefine ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=False,
11 | dataset=dict(classes='BASE_CLASSES_SPLIT2'),
12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT2')),
13 | val=dict(classes='BASE_CLASSES_SPLIT2'),
14 | test=dict(classes='BASE_CLASSES_SPLIT2'),
15 | model_init=dict(classes='BASE_CLASSES_SPLIT2'))
16 | lr_config = dict(warmup_iters=100, step=[12000, 16000])
17 | evaluation = dict(interval=3000)
18 | checkpoint_config = dict(interval=3000)
19 | runner = dict(max_iters=18000)
20 | optimizer = dict(lr=0.02)
21 | # model settings
22 | model = dict(roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15)))
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # class splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_10SHOT')],
14 | num_novel_shots=10,
15 | num_base_shots=10,
16 | classes='ALL_CLASSES_SPLIT3',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT3'),
19 | test=dict(classes='ALL_CLASSES_SPLIT3'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT3'))
21 | evaluation = dict(
22 | interval=2000, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
23 | checkpoint_config = dict(interval=2000)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=2000)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # class splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_1SHOT')],
14 | num_novel_shots=1,
15 | num_base_shots=1,
16 | classes='ALL_CLASSES_SPLIT3',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT3'),
19 | test=dict(classes='ALL_CLASSES_SPLIT3'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT3'))
21 | evaluation = dict(
22 | interval=400, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
23 | checkpoint_config = dict(interval=400)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=400)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_2shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # class splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_2SHOT')],
14 | num_novel_shots=2,
15 | num_base_shots=2,
16 | classes='ALL_CLASSES_SPLIT3',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT3'),
19 | test=dict(classes='ALL_CLASSES_SPLIT3'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT3'))
21 | evaluation = dict(
22 | interval=800, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
23 | checkpoint_config = dict(interval=800)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=800)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_3shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # class splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_3SHOT')],
14 | num_novel_shots=3,
15 | num_base_shots=3,
16 | classes='ALL_CLASSES_SPLIT3',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT3'),
19 | test=dict(classes='ALL_CLASSES_SPLIT3'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT3'))
21 | evaluation = dict(
22 | interval=1200, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
23 | checkpoint_config = dict(interval=1200)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=1200)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_5shot-fine-tuning.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/few_shot_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # class splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=True,
11 | dataset=dict(
12 | type='FewShotVOCDefaultDataset',
13 | ann_cfg=[dict(method='MetaRCNN', setting='SPLIT3_5SHOT')],
14 | num_novel_shots=5,
15 | num_base_shots=5,
16 | classes='ALL_CLASSES_SPLIT3',
17 | )),
18 | val=dict(classes='ALL_CLASSES_SPLIT3'),
19 | test=dict(classes='ALL_CLASSES_SPLIT3'),
20 | model_init=dict(classes='ALL_CLASSES_SPLIT3'))
21 | evaluation = dict(
22 | interval=1600, class_splits=['BASE_CLASSES_SPLIT3', 'NOVEL_CLASSES_SPLIT3'])
23 | checkpoint_config = dict(interval=1600)
24 | optimizer = dict(lr=0.001)
25 | lr_config = dict(warmup=None)
26 | runner = dict(max_iters=1600)
27 | load_from = 'work_dirs/vfa_r101_c4_8xb4_voc-split3_base-training/iter_18000.pth'
28 |
29 | # model settings
30 | model = dict(frozen_parameters=[
31 | 'backbone', 'shared_head', 'aggregation_layer', 'rpn_head',
32 | ])
--------------------------------------------------------------------------------
/configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_base-training.py:
--------------------------------------------------------------------------------
1 | _base_ = [
2 | '../../../_base_/datasets/nway_kshot/base_voc_ms.py',
3 | '../../../_base_/schedules/schedule.py', '../../vfa_r101_c4.py',
4 | '../../../_base_/default_runtime.py'
5 | ]
6 | # class splits are predefined in FewShotVOCDataset
7 | # FewShotVOCDefaultDataset predefines ann_cfg for model reproducibility.
8 | data = dict(
9 | train=dict(
10 | save_dataset=False,
11 | dataset=dict(classes='BASE_CLASSES_SPLIT3'),
12 | support_dataset=dict(classes='BASE_CLASSES_SPLIT3')),
13 | val=dict(classes='BASE_CLASSES_SPLIT3'),
14 | test=dict(classes='BASE_CLASSES_SPLIT3'),
15 | model_init=dict(classes='BASE_CLASSES_SPLIT3'))
16 | lr_config = dict(warmup_iters=100, step=[12000, 16000])
17 | evaluation = dict(interval=3000)
18 | checkpoint_config = dict(interval=3000)
19 | runner = dict(max_iters=18000)
20 | optimizer = dict(lr=0.02)
21 | # model settings
22 | model = dict(roi_head=dict(bbox_head=dict(num_classes=15, num_meta_classes=15)))
--------------------------------------------------------------------------------
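Pattern note: across the fine-tuning configs above, only the shot count varies the schedule. The learning rate is fixed at 0.001 with warmup disabled, training starts from the 18000-iteration base checkpoint (base training itself runs at lr 0.02), the backbone, shared head, aggregation layer and RPN stay frozen, and the evaluation interval, checkpoint interval and max_iters share one value per setting. Read off the configs above:

    # shots -> max_iters (== evaluation/checkpoint interval) when fine-tuning
    FINETUNE_ITERS = {1: 400, 2: 800, 3: 1200, 5: 1600, 10: 2000}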
/dist_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | CHECKPOINT=$2
5 | GPUS=$3
6 | PORT=${PORT:-29500}
7 |
8 | PYTHONPATH="$(dirname "$0")/../../":$PYTHONPATH \
9 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
10 | $(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
11 |
--------------------------------------------------------------------------------
/dist_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | CONFIG=$1
4 | GPUS=$2
5 | PORT=${PORT:-29500}
6 |
7 | PYTHONPATH="$(dirname "$0")/../../":$PYTHONPATH \
8 | python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
10 |
--------------------------------------------------------------------------------
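For reference, a typical end-to-end invocation of the two launcher scripts (paths illustrative; any config/checkpoint pair works, and PORT can be overridden via the environment):

    # base training on VOC split 3 with 8 GPUs, then 10-shot fine-tuning
    bash dist_train.sh configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_base-training.py 8
    bash dist_train.sh configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py 8

    # distributed evaluation of the fine-tuned checkpoint ("latest.pth" is the
    # symlink mmcv's checkpoint hook keeps pointing at the newest checkpoint)
    bash dist_test.sh configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py \
        work_dirs/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning/latest.pth 8 --eval mAP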
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | from setuptools import setup, find_packages
3 |
4 | setup(
5 | name="vfa",
6 |     version="0.1",
7 | author="csuhan",
8 | url="https://github.com/csuhan/VFA",
9 | description="Codebase for few-shot object detection",
10 | python_requires=">=3.6",
11 | packages=find_packages(exclude=('configs', 'data', 'work_dirs')),
12 | install_requires=[
13 |         'clip @ git+https://github.com/openai/CLIP.git'  # https avoids needing GitHub SSH keys
14 | ],
15 | )
--------------------------------------------------------------------------------
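install_requires only pins the CLIP dependency; mmcv, mmdet and mmfewshot are assumed to be installed separately. A development install of the package itself is then:

    pip install -e .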
/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import os
4 | import warnings
5 |
6 | import mmcv
7 | import torch
8 | from mmcv import Config, DictAction
9 | from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
10 | from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
11 | wrap_fp16_model)
12 |
13 | from mmfewshot.detection.datasets import (build_dataloader, build_dataset,
14 | get_copy_dataset_type)
15 | from mmfewshot.detection.models import build_detector
16 |
17 |
18 | def parse_args():
19 | parser = argparse.ArgumentParser(
20 | description='MMFewShot test (and eval) a model')
21 | parser.add_argument('config', help='test config file path')
22 | parser.add_argument('checkpoint', help='checkpoint file')
23 | parser.add_argument('--out', help='output result file in pickle format')
24 | parser.add_argument(
25 | '--eval',
26 | type=str,
27 | nargs='+',
28 | help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
29 | ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
30 | parser.add_argument('--show', action='store_true', help='show results')
31 | parser.add_argument(
32 | '--show-dir', help='directory where painted images will be saved')
33 | parser.add_argument(
34 | '--show-score-thr',
35 | type=float,
36 | default=0.3,
37 | help='score threshold (default: 0.3)')
38 | parser.add_argument(
39 | '--gpu-collect',
40 | action='store_true',
41 | help='whether to use gpu to collect results.')
42 | parser.add_argument(
43 | '--tmpdir',
44 | help='tmp directory used for collecting results from multiple '
45 | 'workers, available when gpu-collect is not specified')
46 | parser.add_argument(
47 | '--cfg-options',
48 | nargs='+',
49 | action=DictAction,
50 | help='override some settings in the used config, the key-value pair '
51 | 'in xxx=yyy format will be merged into config file. If the value to '
52 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
53 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
54 | 'Note that the quotation marks are necessary and that no white space '
55 | 'is allowed.')
56 | parser.add_argument(
57 | '--options',
58 | nargs='+',
59 | action=DictAction,
60 | help='custom options for evaluation, the key-value pair in xxx=yyy '
61 | 'format will be kwargs for dataset.evaluate() function (deprecate), '
62 | 'change to --eval-options instead.')
63 | parser.add_argument(
64 | '--eval-options',
65 | nargs='+',
66 | action=DictAction,
67 | help='custom options for evaluation, the key-value pair in xxx=yyy '
68 | 'format will be kwargs for dataset.evaluate() function')
69 | parser.add_argument(
70 | '--launcher',
71 | choices=['none', 'pytorch', 'slurm', 'mpi'],
72 | default='none',
73 | help='job launcher')
74 | parser.add_argument('--local_rank', type=int, default=0)
75 | args = parser.parse_args()
76 | if 'LOCAL_RANK' not in os.environ:
77 | os.environ['LOCAL_RANK'] = str(args.local_rank)
78 |
79 | if args.options and args.eval_options:
80 | raise ValueError(
81 | '--options and --eval-options cannot be both '
82 | 'specified, --options is deprecated in favor of --eval-options')
83 | if args.options:
84 | warnings.warn('--options is deprecated in favor of --eval-options')
85 | args.eval_options = args.options
86 |         # note: --options maps to --eval-options only; it must not touch cfg_options
87 | return args
88 |
89 |
90 | def main():
91 | args = parse_args()
92 |
93 |     assert args.out or args.eval or args.show \
94 |         or args.show_dir, (
95 |             'Please specify at least one operation (save/eval/show the '
96 |             'results) with the argument "--out", "--eval", '
97 |             '"--show" or "--show-dir"')
98 |
99 | if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
100 | raise ValueError('The output file must be a pkl file.')
101 |
102 | cfg = Config.fromfile(args.config)
103 |
104 | if args.cfg_options is not None:
105 | cfg.merge_from_dict(args.cfg_options)
106 |
107 | # import modules from string list.
108 | if cfg.get('custom_imports', None):
109 | from mmcv.utils import import_modules_from_strings
110 | import_modules_from_strings(**cfg['custom_imports'])
111 | # set cudnn_benchmark
112 | if cfg.get('cudnn_benchmark', False):
113 | torch.backends.cudnn.benchmark = True
114 | cfg.model.pretrained = None
115 |
116 |     # currently only single-image testing is supported
117 |     samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
118 |     assert samples_per_gpu == 1, 'currently only single-image testing is supported'
119 |
120 | # init distributed env first, since logger depends on the dist info.
121 | if args.launcher == 'none':
122 | distributed = False
123 | else:
124 | distributed = True
125 | init_dist(args.launcher, **cfg.dist_params)
126 |
127 | # build the dataloader
128 | dataset = build_dataset(cfg.data.test)
129 | data_loader = build_dataloader(
130 | dataset,
131 | samples_per_gpu=samples_per_gpu,
132 | workers_per_gpu=cfg.data.workers_per_gpu,
133 | dist=distributed,
134 | shuffle=False)
135 |
136 | # pop frozen_parameters
137 | cfg.model.pop('frozen_parameters', None)
138 |
139 | # build the model and load checkpoint
140 | cfg.model.train_cfg = None
141 | model = build_detector(cfg.model)
142 |
143 | fp16_cfg = cfg.get('fp16', None)
144 | if fp16_cfg is not None:
145 | wrap_fp16_model(model)
146 | checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
147 |     # old versions did not save class info in checkpoints; this workaround
148 |     # is for backward compatibility
149 | if 'CLASSES' in checkpoint.get('meta', {}):
150 | model.CLASSES = checkpoint['meta']['CLASSES']
151 | else:
152 | model.CLASSES = dataset.CLASSES
153 |
154 | # for meta-learning methods which require support template dataset
155 | # for model initialization.
156 | if cfg.data.get('model_init', None) is not None:
157 | cfg.data.model_init.pop('copy_from_train_dataset')
158 | model_init_samples_per_gpu = cfg.data.model_init.pop(
159 | 'samples_per_gpu', 1)
160 | model_init_workers_per_gpu = cfg.data.model_init.pop(
161 | 'workers_per_gpu', 1)
162 | if cfg.data.model_init.get('ann_cfg', None) is None:
163 | assert checkpoint['meta'].get('model_init_ann_cfg',
164 | None) is not None
165 | cfg.data.model_init.type = \
166 | get_copy_dataset_type(cfg.data.model_init.type)
167 | cfg.data.model_init.ann_cfg = \
168 | checkpoint['meta']['model_init_ann_cfg']
169 | model_init_dataset = build_dataset(cfg.data.model_init)
170 |         # disable dist so that all ranks get the same data
171 | model_init_dataloader = build_dataloader(
172 | model_init_dataset,
173 | samples_per_gpu=model_init_samples_per_gpu,
174 | workers_per_gpu=model_init_workers_per_gpu,
175 | dist=False,
176 | shuffle=False)
177 |
178 | if cfg.data.get('support_init', None) is not None:
179 | model_init_dataset = build_dataset(cfg.data.support_init)
180 |         # disable dist so that all ranks get the same data
181 | model_init_dataloader = build_dataloader(
182 | model_init_dataset,
183 | samples_per_gpu=1,
184 | workers_per_gpu=1,
185 | dist=False,
186 | shuffle=False)
187 |
188 | if not distributed:
189 | model = MMDataParallel(model, device_ids=[0])
190 | show_kwargs = dict(show_score_thr=args.show_score_thr)
191 | if (cfg.data.get('model_init', None) is not None) or (cfg.data.get('support_init', None) is not None):
192 | from mmfewshot.detection.apis import (single_gpu_model_init,
193 | single_gpu_test)
194 | single_gpu_model_init(model, model_init_dataloader)
195 | else:
196 | from mmdet.apis.test import single_gpu_test
197 | outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
198 | **show_kwargs)
199 | else:
200 | model = MMDistributedDataParallel(
201 | model.cuda(),
202 | device_ids=[torch.cuda.current_device()],
203 | broadcast_buffers=False)
204 | if (cfg.data.get('model_init', None) is not None) or (cfg.data.get('support_init', None) is not None):
205 | from mmfewshot.detection.apis import (multi_gpu_model_init,
206 | multi_gpu_test)
207 | multi_gpu_model_init(model, model_init_dataloader)
208 | else:
209 | from mmdet.apis.test import multi_gpu_test
210 | outputs = multi_gpu_test(
211 | model,
212 | data_loader,
213 | args.tmpdir,
214 | args.gpu_collect,
215 | )
216 |
217 | rank, _ = get_dist_info()
218 | if rank == 0:
219 | if args.out:
220 | print(f'\nwriting results to {args.out}')
221 | mmcv.dump(outputs, args.out)
222 | kwargs = {} if args.eval_options is None else args.eval_options
223 | if args.eval:
224 | eval_kwargs = cfg.get('evaluation', {}).copy()
225 |             # hard-coded way to remove EvalHook args
226 | for key in [
227 | 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
228 | 'rule'
229 | ]:
230 | eval_kwargs.pop(key, None)
231 | eval_kwargs.update(dict(metric=args.eval, **kwargs))
232 | print(dataset.evaluate(outputs, **eval_kwargs))
233 |
234 |
235 | if __name__ == '__main__':
236 | main()
237 |
--------------------------------------------------------------------------------
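A minimal single-GPU invocation of the script above, with illustrative paths (per the argument parser, --out must end in .pkl and VOC uses the "mAP"/"recall" metrics):

    python test.py configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning.py \
        work_dirs/vfa_r101_c4_8xb4_voc-split3_10shot-fine-tuning/latest.pth \
        --eval mAP --out results.pkl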
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import argparse
3 | import copy
4 | import os
5 | import os.path as osp
6 | import multiprocessing as mp
7 | import platform
8 | import time
9 | import warnings
10 |
11 | import cv2
12 | import mmcv
13 | import torch
14 | from mmcv import Config, DictAction
15 | from mmcv.runner import get_dist_info, init_dist, set_random_seed
16 | from mmcv.utils import get_git_hash
17 | from mmdet.utils import collect_env
18 |
19 | import mmfewshot # noqa: F401, F403
20 | from mmfewshot import __version__
21 | # from mmfewshot.detection.apis import train_detector
22 | from mmfewshot.detection.apis.train import train_detector
23 | from mmfewshot.detection.datasets import build_dataset
24 | from mmfewshot.detection.models import build_detector
25 | from mmfewshot.utils import get_root_logger
26 |
27 | def parse_args():
28 | parser = argparse.ArgumentParser(description='Train a FewShot model')
29 | parser.add_argument('config', help='train config file path')
30 | parser.add_argument(
31 | '--work-dir', help='the directory to save logs and models')
32 | parser.add_argument(
33 | '--resume-from', help='the checkpoint file to resume from')
34 | parser.add_argument(
35 | '--no-validate',
36 | action='store_true',
37 | help='whether not to evaluate the checkpoint during training')
38 | group_gpus = parser.add_mutually_exclusive_group()
39 | group_gpus.add_argument(
40 | '--gpus',
41 | type=int,
42 | help='number of gpus to use '
43 | '(only applicable to non-distributed training)')
44 | group_gpus.add_argument(
45 | '--gpu-ids',
46 | type=int,
47 | nargs='+',
48 | help='ids of gpus to use '
49 | '(only applicable to non-distributed training)')
50 | parser.add_argument('--seed', type=int, default=None, help='random seed')
51 | parser.add_argument(
52 | '--deterministic',
53 | action='store_true',
54 | help='whether to set deterministic options for CUDNN backend.')
55 | parser.add_argument(
56 | '--options',
57 | nargs='+',
58 | action=DictAction,
59 | help='override some settings in the used config, the key-value pair '
60 | 'in xxx=yyy format will be merged into config file (deprecate), '
61 | 'change to --cfg-options instead.')
62 | parser.add_argument(
63 | '--cfg-options',
64 | nargs='+',
65 | action=DictAction,
66 | help='override some settings in the used config, the key-value pair '
67 | 'in xxx=yyy format will be merged into config file. If the value '
68 | 'to be overwritten is a list, it should be like key="[a,b]" or '
69 | 'key=a,b It also allows nested list/tuple values, e.g. '
70 | 'key="[(a,b),(c,d)]" Note that the quotation marks are necessary '
71 | 'and that no white space is allowed.')
72 | parser.add_argument(
73 | '--launcher',
74 | choices=['none', 'pytorch', 'slurm', 'mpi'],
75 | default='none',
76 | help='job launcher')
77 | parser.add_argument('--local_rank', type=int, default=0)
78 | args = parser.parse_args()
79 | if 'LOCAL_RANK' not in os.environ:
80 | os.environ['LOCAL_RANK'] = str(args.local_rank)
81 |
82 | if args.options and args.cfg_options:
83 | raise ValueError(
84 | '--options and --cfg-options cannot be both '
85 | 'specified, --options is deprecated in favor of --cfg-options')
86 | if args.options:
87 | warnings.warn('--options is deprecated in favor of --cfg-options')
88 | args.cfg_options = args.options
89 |
90 | return args
91 |
92 |
93 | def setup_multi_processes(cfg):
94 |     # set the multi-process start method to `fork` to speed up training
95 | if platform.system() != 'Windows':
96 | mp_start_method = cfg.get('mp_start_method', 'fork')
97 | mp.set_start_method(mp_start_method)
98 |
99 |     # disable opencv multithreading to avoid overloading the system
100 | opencv_num_threads = cfg.get('opencv_num_threads', 0)
101 | cv2.setNumThreads(opencv_num_threads)
102 |
103 | # setup OMP threads
104 |     # adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
105 | if ('OMP_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1):
106 | omp_num_threads = 1
107 | warnings.warn(
108 | f'Setting OMP_NUM_THREADS environment variable for each process '
109 | f'to be {omp_num_threads} in default, to avoid your system being '
110 | f'overloaded, please further tune the variable for optimal '
111 | f'performance in your application as needed.')
112 | os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
113 |
114 | # setup MKL threads
115 | if 'MKL_NUM_THREADS' not in os.environ and cfg.data.workers_per_gpu > 1:
116 | mkl_num_threads = 1
117 | warnings.warn(
118 | f'Setting MKL_NUM_THREADS environment variable for each process '
119 | f'to be {mkl_num_threads} in default, to avoid your system being '
120 | f'overloaded, please further tune the variable for optimal '
121 | f'performance in your application as needed.')
122 | os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
123 |
124 | def main():
125 | args = parse_args()
126 |
127 | cfg = Config.fromfile(args.config)
128 |
129 | if args.cfg_options is not None:
130 | cfg.merge_from_dict(args.cfg_options)
131 |
132 | # set multi-process settings
133 | setup_multi_processes(cfg)
134 |
135 | # import modules from string list.
136 | if cfg.get('custom_imports', None):
137 | from mmcv.utils import import_modules_from_strings
138 | import_modules_from_strings(**cfg['custom_imports'])
139 | # set cudnn_benchmark
140 | if cfg.get('cudnn_benchmark', False):
141 | torch.backends.cudnn.benchmark = True
142 |
143 | # work_dir is determined in this priority: CLI > segment in file > filename
144 | if args.work_dir is not None:
145 | # update configs according to CLI args if args.work_dir is not None
146 | cfg.work_dir = args.work_dir
147 | elif cfg.get('work_dir', None) is None:
148 | # use config filename as default work_dir if cfg.work_dir is None
149 | cfg.work_dir = osp.join('./work_dirs',
150 | osp.splitext(osp.basename(args.config))[0])
151 | if args.resume_from is not None:
152 | cfg.resume_from = args.resume_from
153 | if args.gpu_ids is not None:
154 | cfg.gpu_ids = args.gpu_ids
155 | else:
156 | cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
157 |
158 | # init distributed env first, since logger depends on the dist info.
159 | if args.launcher == 'none':
160 | distributed = False
161 | rank, world_size = get_dist_info()
162 | else:
163 | distributed = True
164 | init_dist(args.launcher, **cfg.dist_params)
165 | rank, world_size = get_dist_info()
166 | # re-set gpu_ids with distributed training mode
167 | cfg.gpu_ids = range(world_size)
168 | # create work_dir
169 | mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
170 | # dump config
171 | cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
172 | # init the logger before other steps
173 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
174 | log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
175 | logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
176 |
177 | # init the meta dict to record some important information such as
178 | # environment info and seed, which will be logged
179 | meta = dict()
180 | # log env info
181 | env_info_dict = collect_env()
182 | env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
183 | dash_line = '-' * 60 + '\n'
184 | logger.info('Environment info:\n' + dash_line + env_info + '\n' +
185 | dash_line)
186 | meta['env_info'] = env_info
187 | meta['config'] = cfg.pretty_text
188 | # log some basic info
189 | logger.info(f'Distributed training: {distributed}')
190 | logger.info(f'Config:\n{cfg.pretty_text}')
191 |
192 | # set random seeds
193 | if args.seed is not None:
194 | seed = args.seed
195 | elif cfg.seed is not None:
196 | seed = cfg.seed
197 | elif distributed:
198 | seed = 0
199 |         warnings.warn(f'When using DistributedDataParallel, each rank '
200 |                       f'initializes a different random seed, which causes '
201 |                       f'different random behavior per rank. In the few-shot '
202 |                       f'setting, novel shots may be generated by random '
203 |                       f'sampling; if ranks use different seeds, they will '
204 |                       f'sample different data, causing UNFAIR data usage. '
205 |                       f'Therefore, seed is set to {seed} by default.')
206 | else:
207 | seed = None
208 |
209 | if seed is not None:
210 | logger.info(f'Set random seed to {seed}, '
211 | f'deterministic: {args.deterministic}')
212 | set_random_seed(seed, deterministic=args.deterministic)
213 | meta['seed'] = seed
214 | meta['exp_name'] = osp.basename(args.config)
215 |
216 |     # build_detector does three things: builds the model, initializes
217 |     # weights, and (optionally) freezes parameters.
218 | model = build_detector(cfg.model, logger=logger)
219 |     # build_dataset does two things: builds the dataset and (optionally)
220 |     # saves it into a json file.
221 | datasets = [
222 | build_dataset(
223 | cfg.data.train,
224 | rank=rank,
225 | work_dir=cfg.work_dir,
226 | timestamp=timestamp)
227 | ]
228 |
229 | if len(cfg.workflow) == 2:
230 | val_dataset = copy.deepcopy(cfg.data.val)
231 | val_dataset.pipeline = cfg.data.train.pipeline
232 | datasets.append(build_dataset(val_dataset))
233 | if cfg.checkpoint_config is not None:
234 | # save mmfewshot version, config file content and class names in
235 | # checkpoints as meta data
236 | cfg.checkpoint_config.meta = dict(
237 | mmfewshot_version=__version__ + get_git_hash()[:7],
238 | CLASSES=datasets[0].CLASSES)
239 | # add an attribute for visualization convenience
240 | model.CLASSES = datasets[0].CLASSES
241 | train_detector(
242 | model,
243 | datasets,
244 | cfg,
245 | distributed=distributed,
246 | validate=(not args.no_validate),
247 | timestamp=timestamp,
248 | meta=meta)
249 |
250 |
251 | if __name__ == '__main__':
252 | main()
253 |
--------------------------------------------------------------------------------
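Note on seeding: per the logic above, distributed runs fall back to seed 0 so that all ranks sample the same few-shot data, while single-GPU runs are unseeded unless a seed is given. For reproducible shot sampling one would pass it explicitly, e.g.:

    python train.py configs/vfa/voc/vfa_split3/vfa_r101_c4_8xb4_voc-split3_1shot-fine-tuning.py \
        --seed 0 --deterministic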
/vfa/__init__.py:
--------------------------------------------------------------------------------
1 | from .vfa_detector import VFA
2 | from .vfa_roi_head import VFARoIHead
3 | from .vfa_bbox_head import VFABBoxHead
4 |
5 | __all__ = ['VFA', 'VFARoIHead', 'VFABBoxHead']
6 |
--------------------------------------------------------------------------------
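Both train.py and test.py honor cfg.custom_imports, so a config can pull in these registrations without modifying the scripts. A minimal sketch (assuming the repo root is on PYTHONPATH):

    custom_imports = dict(imports=['vfa'], allow_failed_imports=False)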
/vfa/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from PIL import Image
5 | from torchvision import transforms as trans
6 |
7 | from mmfewshot.detection.datasets.coco import COCO_SPLIT
8 |
9 |
10 | class PCB:
11 | def __init__(self, class_names, model="RN101", templates="a photo of a {}"):
12 | super().__init__()
13 | self.device = torch.cuda.current_device()
14 |
15 | # image transforms
16 | self.expand_ratio = 0.1
17 | self.trans = trans.Compose([
18 |             trans.Resize([224, 224], interpolation=3),  # 3 == PIL BICUBIC
19 | trans.ToTensor(),
20 | trans.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
21 |
22 | # CLIP configs
23 | import clip
24 | self.class_names = class_names
25 | self.clip, _ = clip.load(model, device=self.device)
26 | self.prompts = clip.tokenize([
27 | templates.format(cls_name)
28 | for cls_name in self.class_names
29 | ]).to(self.device)
30 | with torch.no_grad():
31 | text_features = self.clip.encode_text(self.prompts)
32 | self.text_features = F.normalize(text_features, dim=-1, p=2)
33 |
34 |
35 | def load_image_by_box(self, img_path, boxes):
36 | image = Image.open(img_path).convert("RGB")
37 | image_list = []
38 | for box in boxes:
39 | x1, y1, x2, y2 = box
40 | h, w = y2-y1, x2-x1
41 | x1 = max(0, x1 - w*self.expand_ratio)
42 | y1 = max(0, y1 - h*self.expand_ratio)
43 | x2 = x2 + w*self.expand_ratio
44 | y2 = y2 + h*self.expand_ratio
45 | sub_image = image.crop((int(x1), int(y1), int(x2), int(y2)))
46 | sub_image = self.trans(sub_image).to(self.device)
47 | image_list.append(sub_image)
48 | return torch.stack(image_list)
49 |
50 | @torch.no_grad()
51 | def __call__(self, img_path, boxes):
52 | images = self.load_image_by_box(img_path, boxes)
53 |
54 | image_features = self.clip.encode_image(images)
55 | image_features = F.normalize(image_features, dim=-1, p=2)
56 | logit_scale = self.clip.logit_scale.exp()
57 | logits_per_image = logit_scale * image_features @ self.text_features.t()
58 | return logits_per_image.softmax(dim=-1)
59 |
60 |
61 | class TestMixins:
62 | def __init__(self):
63 | self.pcb = None
64 |
65 | def refine_test(self, results, img_metas):
66 | if not hasattr(self, 'pcb'):
67 | self.pcb = PCB(COCO_SPLIT['ALL_CLASSES'], model='ViT-B/32')
68 |             # excluded class ids (left unrefined) for COCO
69 | self.exclude_ids = [7, 9, 10, 11, 12, 13, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29,
70 | 30, 31, 32, 33, 34, 35, 36, 37, 38, 40, 41, 42, 43, 44, 45,
71 | 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 59, 61, 63, 64, 65,
72 | 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79]
73 |
74 | boxes_list, scores_list, labels_list = [], [], []
75 | for cls_id, result in enumerate(results[0]):
76 | if len(result) == 0:
77 | continue
78 | boxes_list.append(result[:, :4])
79 | scores_list.append(result[:, 4])
80 | labels_list.append([cls_id] * len(result))
81 |
82 | if len(boxes_list) == 0:
83 | return results
84 |
85 | boxes_list = np.concatenate(boxes_list, axis=0)
86 | scores_list = np.concatenate(scores_list, axis=0)
87 | labels_list = np.concatenate(labels_list, axis=0)
88 |
89 | logits = self.pcb(img_metas[0]['filename'], boxes_list)
90 |
91 |         for i, prob in enumerate(logits):
92 |             if labels_list[i] not in self.exclude_ids:
93 |                 scores_list[i] = scores_list[i] * 0.5 + prob[labels_list[i]] * 0.5
94 |
95 | j = 0
96 | for i in range(len(results[0])):
97 | num_dets = len(results[0][i])
98 | if num_dets == 0:
99 | continue
100 | for k in range(num_dets):
101 | results[0][i][k, 4] = scores_list[j]
102 | j += 1
103 |
104 | return results
--------------------------------------------------------------------------------
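The refinement above averages each detection's score with the CLIP posterior of its predicted class, skipping any class listed in exclude_ids. A minimal sketch of that fusion rule in isolation, with made-up inputs:

    import numpy as np

    det_scores = np.array([0.9, 0.4])    # detector confidences per box
    clip_probs = np.array([[0.7, 0.3],   # CLIP softmax per box over classes
                           [0.2, 0.8]])
    labels = np.array([0, 1])            # predicted class id per box
    exclude_ids = {1}                    # classes left untouched

    for i, lab in enumerate(labels):
        if lab not in exclude_ids:
            det_scores[i] = 0.5 * det_scores[i] + 0.5 * clip_probs[i, lab]
    print(det_scores)                    # [0.8 0.4]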
/vfa/vfa_bbox_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from mmcv.runner import auto_fp16
4 | from mmdet.models.builder import HEADS
5 |
6 | from mmfewshot.detection.models.roi_heads.bbox_heads.meta_bbox_head import MetaBBoxHead
7 |
8 |
9 | @HEADS.register_module()
10 | class VFABBoxHead(MetaBBoxHead):
11 |
12 | @auto_fp16()
13 | def forward(self, x_agg, x_query):
14 | if self.with_avg_pool:
15 | if x_agg.numel() > 0:
16 | x_agg = self.avg_pool(x_agg)
17 | x_agg = x_agg.view(x_agg.size(0), -1)
18 | else:
19 |                 # avg_pool does not support empty tensors,
20 |                 # so use torch.mean instead
21 | x_agg = torch.mean(x_agg, dim=(-1, -2))
22 | if x_query.numel() > 0:
23 | x_query = self.avg_pool(x_query)
24 | x_query = x_query.view(x_query.size(0), -1)
25 | else:
26 | x_query = torch.mean(x_query, dim=(-1, -2))
27 | cls_score = self.fc_cls(x_agg) if self.with_cls else None
28 | bbox_pred = self.fc_reg(x_query) if self.with_reg else None
29 | return cls_score, bbox_pred
30 |
--------------------------------------------------------------------------------
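The numel() branches above exist because the pooling layer raises on zero-sized inputs (e.g. an image that contributed no RoIs), while torch.mean over the spatial dims handles them; a quick demonstration:

    import torch

    x = torch.empty(0, 2048, 7, 7)                 # no RoIs for this image
    pooled = torch.mean(x, dim=(-1, -2))           # fine: shape (0, 2048)
    print(pooled.view(pooled.size(0), -1).shape)   # torch.Size([0, 2048])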
/vfa/vfa_detector.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from typing import Dict, List, Optional
3 |
4 | from mmdet.models.builder import DETECTORS
5 | from mmfewshot.detection.models import MetaRCNN
6 |
7 | from .utils import TestMixins
8 |
9 |
10 | @DETECTORS.register_module()
11 | class VFA(MetaRCNN, TestMixins):
12 | def __init__(self, *args, with_refine=False, **kwargs):
13 | super().__init__(*args, **kwargs)
14 |         # refine results with PCB for COCO; not used for VOC
15 | self.with_refine = with_refine
16 |
17 | def forward_train(self,
18 | query_data: Dict,
19 | support_data: Dict,
20 | proposals: Optional[List] = None,
21 | **kwargs) -> Dict:
22 | """Forward function for training.
23 |
24 | Args:
25 | query_data (dict): In most cases, dict of query data contains:
26 | `img`, `img_metas`, `gt_bboxes`, `gt_labels`,
27 | `gt_bboxes_ignore`.
28 | support_data (dict): In most cases, dict of support data contains:
29 | `img`, `img_metas`, `gt_bboxes`, `gt_labels`,
30 | `gt_bboxes_ignore`.
31 | proposals (list): Override rpn proposals with custom proposals.
32 | Use when `with_rpn` is False. Default: None.
33 |
34 | Returns:
35 | dict[str, Tensor]: a dictionary of loss components
36 | """
37 | query_img = query_data['img']
38 | support_img = support_data['img']
39 | query_feats = self.extract_query_feat(query_img)
40 |
41 | # stop gradient at RPN
42 | query_feats_rpn = [x.detach() for x in query_feats]
43 | query_feats_rcnn = query_feats
44 |
45 | support_feats = self.extract_support_feat(support_img)
46 |
47 | losses = dict()
48 |
49 | # RPN forward and loss
50 | if self.with_rpn:
51 | proposal_cfg = self.train_cfg.get('rpn_proposal',
52 | self.test_cfg.rpn)
53 | if self.rpn_with_support:
54 | rpn_losses, proposal_list = self.rpn_head.forward_train(
55 | query_feats_rpn,
56 | support_feats,
57 | query_img_metas=query_data['img_metas'],
58 | query_gt_bboxes=query_data['gt_bboxes'],
59 | query_gt_labels=None,
60 | query_gt_bboxes_ignore=query_data.get(
61 | 'gt_bboxes_ignore', None),
62 | support_img_metas=support_data['img_metas'],
63 | support_gt_bboxes=support_data['gt_bboxes'],
64 | support_gt_labels=support_data['gt_labels'],
65 | support_gt_bboxes_ignore=support_data.get(
66 | 'gt_bboxes_ignore', None),
67 | proposal_cfg=proposal_cfg)
68 | else:
69 | rpn_losses, proposal_list = self.rpn_head.forward_train(
70 | query_feats_rpn,
71 | copy.deepcopy(query_data['img_metas']),
72 | copy.deepcopy(query_data['gt_bboxes']),
73 | gt_labels=None,
74 | gt_bboxes_ignore=copy.deepcopy(
75 | query_data.get('gt_bboxes_ignore', None)),
76 | proposal_cfg=proposal_cfg)
77 | losses.update(rpn_losses)
78 | else:
79 | proposal_list = proposals
80 |
81 | roi_losses = self.roi_head.forward_train(
82 | query_feats_rcnn,
83 | support_feats,
84 | proposals=proposal_list,
85 | query_img_metas=query_data['img_metas'],
86 | query_gt_bboxes=query_data['gt_bboxes'],
87 | query_gt_labels=query_data['gt_labels'],
88 | query_gt_bboxes_ignore=query_data.get('gt_bboxes_ignore', None),
89 | support_img_metas=support_data['img_metas'],
90 | support_gt_bboxes=support_data['gt_bboxes'],
91 | support_gt_labels=support_data['gt_labels'],
92 | support_gt_bboxes_ignore=support_data.get('gt_bboxes_ignore',
93 | None),
94 | **kwargs)
95 | losses.update(roi_losses)
96 |
97 | return losses
98 |
99 |     def simple_test(self, img, img_metas, proposals=None, rescale=False):
100 | bbox_results = super().simple_test(img, img_metas, proposals, rescale)
101 | if self.with_refine:
102 | return self.refine_test(bbox_results, img_metas)
103 | else:
104 | return bbox_results
105 |
106 |
107 |
--------------------------------------------------------------------------------
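Since with_refine is a plain constructor flag, it can be toggled from a config rather than in code; a sketch for a COCO setting (the repo's COCO configs may already set it):

    model = dict(type='VFA', with_refine=True)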
/vfa/vfa_roi_head.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | from torch import Tensor
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from mmcv.utils import ConfigDict
9 | from mmdet.core import bbox2roi
10 | from mmdet.models.builder import HEADS
11 | from mmfewshot.detection.models.roi_heads.meta_rcnn_roi_head import MetaRCNNRoIHead
12 |
13 |
14 | class VAE(nn.Module):
15 |
16 | def __init__(self,
17 | in_channels: int,
18 | latent_dim: int,
19 | hidden_dim: int) -> None:
20 | super(VAE, self).__init__()
21 |
22 | self.latent_dim = latent_dim
23 |
24 | self.encoder = nn.Sequential(
25 | nn.Linear(in_channels, hidden_dim),
26 | nn.BatchNorm1d(hidden_dim),
27 | nn.LeakyReLU()
28 | )
29 | self.fc_mu = nn.Linear(hidden_dim, latent_dim)
30 | self.fc_var = nn.Linear(hidden_dim, latent_dim)
31 |
32 | self.decoder_input = nn.Linear(latent_dim, hidden_dim)
33 |
34 | self.decoder = nn.Sequential(
35 | nn.Linear(hidden_dim, in_channels),
36 | nn.BatchNorm1d(in_channels),
37 | nn.Sigmoid()
38 | )
39 |
40 | def encode(self, input: Tensor) -> List[Tensor]:
41 | result = self.encoder(input)
42 |
43 | mu = self.fc_mu(result)
44 | log_var = self.fc_var(result)
45 |
46 | return [mu, log_var]
47 |
48 | def decode(self, z: Tensor) -> Tensor:
49 |
50 | z = self.decoder_input(z)
51 | z_out = self.decoder(z)
52 | return z_out
53 |
54 |     def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tuple[Tensor, Tensor]:
55 |         std = torch.exp(0.5 * logvar)
56 |         eps = torch.randn_like(std)
57 |         return eps * std + mu, std + mu  # sampled z, deterministic mu + std
58 |
59 | def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
60 | mu, log_var = self.encode(input)
61 | z, z_inv = self.reparameterize(mu, log_var)
62 | z_out = self.decode(z)
63 |
64 | return [z_out, z_inv, input, mu, log_var]
65 |
66 | def loss_function(self, input, rec, mu, log_var, kld_weight=0.00025) -> dict:
67 | recons_loss = F.mse_loss(rec, input)
68 |
69 | kld_loss = torch.mean(-0.5 * torch.sum(1 +
70 | log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
71 |
72 | loss = recons_loss + kld_weight * kld_loss
73 |
74 | return {'loss_vae': loss}
75 |
76 |
77 | @HEADS.register_module()
78 | class VFARoIHead(MetaRCNNRoIHead):
79 |
80 |     def __init__(self, vae_dim=2048, *args, **kwargs) -> None:
81 |         super().__init__(*args, **kwargs)
82 |
83 | self.vae = VAE(vae_dim, vae_dim, vae_dim)
84 |
85 | def _bbox_forward_train(self, query_feats: List[Tensor],
86 | support_feats: List[Tensor],
87 | sampling_results: object,
88 | query_img_metas: List[Dict],
89 | query_gt_bboxes: List[Tensor],
90 | query_gt_labels: List[Tensor],
91 | support_gt_labels: List[Tensor]) -> Dict:
92 | """Forward function and calculate loss for box head in training.
93 |
94 | Args:
95 | query_feats (list[Tensor]): List of query features, each item
96 | with shape (N, C, H, W).
97 | support_feats (list[Tensor]): List of support features, each item
98 | with shape (N, C, H, W).
99 | sampling_results (obj:`SamplingResult`): Sampling results.
100 | query_img_metas (list[dict]): List of query image info dict where
101 | each dict has: 'img_shape', 'scale_factor', 'flip', and may
102 | also contain 'filename', 'ori_shape', 'pad_shape', and
103 | 'img_norm_cfg'. For details on the values of these keys see
104 | `mmdet/datasets/pipelines/formatting.py:Collect`.
105 | query_gt_bboxes (list[Tensor]): Ground truth bboxes for each query
106 | image with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y]
107 | format.
108 | query_gt_labels (list[Tensor]): Class indices corresponding to
109 | each box of query images.
110 | support_gt_labels (list[Tensor]): Class indices corresponding to
111 | each box of support images.
112 |
113 | Returns:
114 | dict: Predicted results and losses.
115 | """
116 | query_rois = bbox2roi([res.bboxes for res in sampling_results])
117 | query_roi_feats = self.extract_query_roi_feat(query_feats, query_rois)
118 | support_feat = self.extract_support_feats(support_feats)[0]
119 | support_feat_rec, support_feat_inv, _, mu, log_var = self.vae(
120 | support_feat)
121 |
122 | bbox_targets = self.bbox_head.get_targets(sampling_results,
123 | query_gt_bboxes,
124 | query_gt_labels,
125 | self.train_cfg)
126 | (labels, label_weights, bbox_targets, bbox_weights) = bbox_targets
127 | loss_bbox = {'loss_cls': [], 'loss_bbox': [], 'acc': []}
128 | batch_size = len(query_img_metas)
129 |         num_samples_per_img = query_roi_feats.size(0) // batch_size
130 |         bbox_results = None
131 |         for img_id in range(batch_size):
132 |             start = img_id * num_samples_per_img
133 |             end = (img_id + 1) * num_samples_per_img
134 |             # class-agnostic aggregation: all RoIs of this image are paired
135 |             # with the support feature of one randomly chosen support class
136 |             # (an alternative is to sample the class from
137 |             #  query_gt_labels[img_id] instead)
138 | random_index = np.random.choice(
139 | range(len(support_gt_labels)))
140 | random_query_label = support_gt_labels[random_index]
141 | for i in range(support_feat.size(0)):
142 | if support_gt_labels[i] == random_query_label:
143 | bbox_results = self._bbox_forward(
144 | query_roi_feats[start:end],
145 | support_feat_inv[i].sigmoid().unsqueeze(0))
146 | single_loss_bbox = self.bbox_head.loss(
147 | bbox_results['cls_score'], bbox_results['bbox_pred'],
148 | query_rois[start:end], labels[start:end],
149 | label_weights[start:end], bbox_targets[start:end],
150 | bbox_weights[start:end])
151 | for key in single_loss_bbox.keys():
152 | loss_bbox[key].append(single_loss_bbox[key])
153 | if bbox_results is not None:
154 | for key in loss_bbox.keys():
155 | if key == 'acc':
156 | loss_bbox[key] = torch.cat(loss_bbox['acc']).mean()
157 | else:
158 | loss_bbox[key] = torch.stack(
159 | loss_bbox[key]).sum() / batch_size
160 |
161 | # meta classification loss
162 | if self.bbox_head.with_meta_cls_loss:
163 | # input support feature classification
164 | meta_cls_score = self.bbox_head.forward_meta_cls(support_feat_rec)
165 | meta_cls_labels = torch.cat(support_gt_labels)
166 | loss_meta_cls = self.bbox_head.loss_meta(
167 | meta_cls_score, meta_cls_labels,
168 | torch.ones_like(meta_cls_labels))
169 | loss_bbox.update(loss_meta_cls)
170 |
171 | loss_vae = self.vae.loss_function(
172 | support_feat, support_feat_rec, mu, log_var)
173 | loss_bbox.update(loss_vae)
174 |
175 | bbox_results.update(loss_bbox=loss_bbox)
176 | return bbox_results
177 |
178 | def _bbox_forward(self, query_roi_feats: Tensor,
179 | support_roi_feats: Tensor) -> Dict:
180 | """Box head forward function used in both training and testing.
181 |
182 | Args:
183 | query_roi_feats (Tensor): Query roi features with shape (N, C).
184 | support_roi_feats (Tensor): Support features with shape (1, C).
185 |
186 | Returns:
187 | dict: A dictionary of predicted results.
188 | """
189 | # feature aggregation
190 | roi_feats = self.aggregation_layer(
191 | query_feat=query_roi_feats.unsqueeze(-1).unsqueeze(-1),
192 | support_feat=support_roi_feats.view(1, -1, 1, 1))[0]
193 | cls_score, bbox_pred = self.bbox_head(
194 | roi_feats.squeeze(-1).squeeze(-1), query_roi_feats)
195 | bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
196 | return bbox_results
197 |
198 | def simple_test_bboxes(
199 | self,
200 | query_feats: List[Tensor],
201 | support_feats_dict: Dict,
202 | query_img_metas: List[Dict],
203 | proposals: List[Tensor],
204 | rcnn_test_cfg: ConfigDict,
205 | rescale: bool = False) -> Tuple[List[Tensor], List[Tensor]]:
206 | """Test only det bboxes without augmentation.
207 |
208 | Args:
209 | query_feats (list[Tensor]): Features of query image,
210 | each item with shape (N, C, H, W).
211 |             support_feats_dict (dict[int, Tensor]): Dict of support features
212 | used for inference only, each key is the class id and value is
213 | the support template features with shape (1, C).
214 | query_img_metas (list[dict]): list of image info dict where each
215 | dict has: `img_shape`, `scale_factor`, `flip`, and may also
216 | contain `filename`, `ori_shape`, `pad_shape`, and
217 | `img_norm_cfg`. For details on the values of these keys see
218 | :class:`mmdet.datasets.pipelines.Collect`.
219 | proposals (list[Tensor]): Region proposals.
220 | rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
221 | rescale (bool): If True, return boxes in original image space.
222 | Default: False.
223 |
224 | Returns:
225 |             tuple[list[Tensor], list[Tensor]]: Tensors in the first list
226 |                 have shape (num_boxes, 4); tensors in the second list have
227 |                 shape (num_boxes, ). The length of both lists equals
228 |                 batch_size.
229 | """
230 | img_shapes = tuple(meta['img_shape'] for meta in query_img_metas)
231 | scale_factors = tuple(meta['scale_factor'] for meta in query_img_metas)
232 |
233 | rois = bbox2roi(proposals)
234 |
235 | query_roi_feats = self.extract_query_roi_feat(query_feats, rois)
236 | cls_scores_dict, bbox_preds_dict = {}, {}
237 | num_classes = self.bbox_head.num_classes
238 | for class_id in support_feats_dict.keys():
239 | support_feat = support_feats_dict[class_id]
240 | support_feat_rec, support_feat_inv, _, mu, log_var = self.vae(
241 | support_feat)
242 | bbox_results = self._bbox_forward(
243 | query_roi_feats, support_feat_inv.sigmoid())
244 | cls_scores_dict[class_id] = \
245 | bbox_results['cls_score'][:, class_id:class_id + 1]
246 | bbox_preds_dict[class_id] = \
247 | bbox_results['bbox_pred'][:, class_id * 4:(class_id + 1) * 4]
248 |             # the official code uses the first class's background score as
249 |             # the final background score, while this code averages the
250 |             # background scores over all classes instead.
251 | if cls_scores_dict.get(num_classes, None) is None:
252 | cls_scores_dict[num_classes] = \
253 | bbox_results['cls_score'][:, -1:]
254 | else:
255 | cls_scores_dict[num_classes] += \
256 | bbox_results['cls_score'][:, -1:]
257 | cls_scores_dict[num_classes] /= len(support_feats_dict.keys())
258 | cls_scores = [
259 | cls_scores_dict[i] if i in cls_scores_dict.keys() else
260 | torch.zeros_like(cls_scores_dict[list(cls_scores_dict.keys())[0]])
261 | for i in range(num_classes + 1)
262 | ]
263 | bbox_preds = [
264 | bbox_preds_dict[i] if i in bbox_preds_dict.keys() else
265 | torch.zeros_like(bbox_preds_dict[list(bbox_preds_dict.keys())[0]])
266 | for i in range(num_classes)
267 | ]
268 | cls_score = torch.cat(cls_scores, dim=1)
269 | bbox_pred = torch.cat(bbox_preds, dim=1)
270 |
271 | # split batch bbox prediction back to each image
272 | num_proposals_per_img = tuple(len(p) for p in proposals)
273 | rois = rois.split(num_proposals_per_img, 0)
274 | cls_score = cls_score.split(num_proposals_per_img, 0)
275 | bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
276 |
277 | # apply bbox post-processing to each image individually
278 | det_bboxes = []
279 | det_labels = []
280 | for i in range(len(proposals)):
281 | det_bbox, det_label = self.bbox_head.get_bboxes(
282 | rois[i],
283 | cls_score[i],
284 | bbox_pred[i],
285 | img_shapes[i],
286 | scale_factors[i],
287 | rescale=rescale,
288 | cfg=rcnn_test_cfg)
289 | det_bboxes.append(det_bbox)
290 | det_labels.append(det_label)
291 | return det_bboxes, det_labels
292 |
--------------------------------------------------------------------------------
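For orientation, a quick smoke test of the VAE above with the default sizes (VFARoIHead uses vae_dim=2048 for all three dimensions; importing the module requires mmdet/mmfewshot to be installed): forward returns [reconstruction, z_inv, input, mu, log_var], and loss_function adds an MSE reconstruction term to a KL term scaled by kld_weight.

    import torch
    from vfa.vfa_roi_head import VAE

    vae = VAE(in_channels=2048, latent_dim=2048, hidden_dim=2048)
    feat = torch.rand(4, 2048)              # fake support features in [0, 1]
    rec, z_inv, _, mu, log_var = vae(feat)
    loss = vae.loss_function(feat, rec, mu, log_var)['loss_vae']
    print(rec.shape, z_inv.shape, float(loss))   # (4, 2048), (4, 2048), scalar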