├── .gitignore
├── README.md
├── configs
├── _base_
│ ├── datasets
│ │ ├── cifar_fscil.py
│ │ ├── cub_fscil.py
│ │ └── mini_imagenet_fscil.py
│ ├── default_runtime.py
│ ├── models
│ │ └── resnet_etf.py
│ └── schedules
│ │ ├── cifar_200e.py
│ │ ├── cub_80e.py
│ │ └── mini_imagenet_500e.py
├── cifar
│ ├── resnet12_etf_bs512_200e_cifar.py
│ └── resnet12_etf_bs512_200e_cifar_eval.py
├── cub
│ ├── resnet18_etf_bs512_80e_cub.py
│ └── resnet18_etf_bs512_80e_cub_eval.py
└── mini_imagenet
│ ├── resnet12_etf_bs512_500e_miniimagenet.py
│ └── resnet12_etf_bs512_500e_miniimagenet_eval.py
├── docker_env
└── Dockerfile
├── logs
├── cifar_base.log
├── cifar_inc.log
├── cub_base.log
├── cub_inc.log
├── min_base.log
└── min_inc.log
├── mmcls
├── __init__.py
├── apis
│ ├── __init__.py
│ ├── inference.py
│ ├── test.py
│ └── train.py
├── core
│ ├── __init__.py
│ ├── evaluation
│ │ ├── __init__.py
│ │ ├── eval_hooks.py
│ │ ├── eval_metrics.py
│ │ ├── mean_ap.py
│ │ └── multilabel_eval_metrics.py
│ ├── export
│ │ ├── __init__.py
│ │ └── test.py
│ ├── hook
│ │ ├── __init__.py
│ │ ├── class_num_check_hook.py
│ │ ├── lr_updater.py
│ │ ├── precise_bn_hook.py
│ │ └── wandblogger_hook.py
│ ├── optimizers
│ │ ├── __init__.py
│ │ └── lamb.py
│ ├── utils
│ │ ├── __init__.py
│ │ ├── dist_utils.py
│ │ └── misc.py
│ └── visualization
│ │ ├── __init__.py
│ │ └── image.py
├── datasets
│ ├── __init__.py
│ ├── base_dataset.py
│ ├── builder.py
│ ├── cifar.py
│ ├── cub.py
│ ├── custom.py
│ ├── dataset_wrappers.py
│ ├── imagenet.py
│ ├── imagenet21k.py
│ ├── mnist.py
│ ├── multi_label.py
│ ├── pipelines
│ │ ├── __init__.py
│ │ ├── auto_augment.py
│ │ ├── compose.py
│ │ ├── formatting.py
│ │ ├── loading.py
│ │ └── transforms.py
│ ├── samplers
│ │ ├── __init__.py
│ │ ├── distributed_sampler.py
│ │ └── repeat_aug.py
│ ├── utils.py
│ └── voc.py
├── models
│ ├── __init__.py
│ ├── backbones
│ │ ├── __init__.py
│ │ ├── alexnet.py
│ │ ├── base_backbone.py
│ │ ├── conformer.py
│ │ ├── convmixer.py
│ │ ├── convnext.py
│ │ ├── cspnet.py
│ │ ├── deit.py
│ │ ├── densenet.py
│ │ ├── efficientnet.py
│ │ ├── hrnet.py
│ │ ├── lenet.py
│ │ ├── mlp_mixer.py
│ │ ├── mobilenet_v2.py
│ │ ├── mobilenet_v3.py
│ │ ├── poolformer.py
│ │ ├── regnet.py
│ │ ├── repmlp.py
│ │ ├── repvgg.py
│ │ ├── res2net.py
│ │ ├── resnest.py
│ │ ├── resnet.py
│ │ ├── resnet_cifar.py
│ │ ├── resnext.py
│ │ ├── seresnet.py
│ │ ├── seresnext.py
│ │ ├── shufflenet_v1.py
│ │ ├── shufflenet_v2.py
│ │ ├── swin_transformer.py
│ │ ├── t2t_vit.py
│ │ ├── timm_backbone.py
│ │ ├── tnt.py
│ │ ├── twins.py
│ │ ├── van.py
│ │ ├── vgg.py
│ │ └── vision_transformer.py
│ ├── builder.py
│ ├── classifiers
│ │ ├── __init__.py
│ │ ├── base.py
│ │ └── image.py
│ ├── heads
│ │ ├── __init__.py
│ │ ├── base_head.py
│ │ ├── cls_head.py
│ │ ├── conformer_head.py
│ │ ├── deit_head.py
│ │ ├── linear_head.py
│ │ ├── multi_label_head.py
│ │ ├── multi_label_linear_head.py
│ │ ├── stacked_head.py
│ │ └── vision_transformer_head.py
│ ├── losses
│ │ ├── __init__.py
│ │ ├── accuracy.py
│ │ ├── asymmetric_loss.py
│ │ ├── cross_entropy_loss.py
│ │ ├── focal_loss.py
│ │ ├── label_smooth_loss.py
│ │ ├── seesaw_loss.py
│ │ └── utils.py
│ ├── necks
│ │ ├── __init__.py
│ │ ├── gap.py
│ │ ├── gem.py
│ │ └── hr_fuse.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── attention.py
│ │ ├── augment
│ │ ├── __init__.py
│ │ ├── augments.py
│ │ ├── builder.py
│ │ ├── cutmix.py
│ │ ├── identity.py
│ │ ├── mixup.py
│ │ ├── resizemix.py
│ │ └── utils.py
│ │ ├── channel_shuffle.py
│ │ ├── embed.py
│ │ ├── helpers.py
│ │ ├── inverted_residual.py
│ │ ├── make_divisible.py
│ │ ├── position_encoding.py
│ │ └── se_layer.py
├── utils
│ ├── __init__.py
│ ├── collect_env.py
│ ├── device.py
│ ├── distribution.py
│ ├── logger.py
│ └── setup_env.py
└── version.py
├── mmfewshot
├── __init__.py
├── classification
│ ├── __init__.py
│ ├── apis
│ │ ├── __init__.py
│ │ ├── inference.py
│ │ ├── test.py
│ │ └── train.py
│ ├── core
│ │ ├── __init__.py
│ │ └── evaluation
│ │ │ ├── __init__.py
│ │ │ └── eval_hooks.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── builder.py
│ │ ├── cub.py
│ │ ├── dataset_wrappers.py
│ │ ├── mini_imagenet.py
│ │ ├── pipelines
│ │ │ ├── __init__.py
│ │ │ └── loading.py
│ │ ├── tiered_imagenet.py
│ │ └── utils.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── backbones
│ │ │ ├── __init__.py
│ │ │ ├── conv4.py
│ │ │ ├── resnet12.py
│ │ │ ├── utils.py
│ │ │ └── wrn.py
│ │ ├── classifiers
│ │ │ ├── __init__.py
│ │ │ ├── base.py
│ │ │ ├── base_finetune.py
│ │ │ ├── base_metric.py
│ │ │ ├── baseline.py
│ │ │ ├── baseline_plus.py
│ │ │ ├── maml.py
│ │ │ ├── matching_net.py
│ │ │ ├── meta_baseline.py
│ │ │ ├── neg_margin.py
│ │ │ ├── proto_net.py
│ │ │ └── relation_net.py
│ │ ├── heads
│ │ │ ├── __init__.py
│ │ │ ├── base_head.py
│ │ │ ├── cosine_distance_head.py
│ │ │ ├── linear_head.py
│ │ │ ├── matching_head.py
│ │ │ ├── meta_baseline_head.py
│ │ │ ├── neg_margin_head.py
│ │ │ ├── prototype_head.py
│ │ │ └── relation_head.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── mse_loss.py
│ │ │ └── nll_loss.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ └── maml_module.py
│ └── utils
│ │ ├── __init__.py
│ │ └── meta_test_parallel.py
├── detection
│ ├── __init__.py
│ ├── apis
│ │ ├── __init__.py
│ │ ├── inference.py
│ │ ├── test.py
│ │ └── train.py
│ ├── core
│ │ ├── __init__.py
│ │ ├── evaluation
│ │ │ ├── __init__.py
│ │ │ ├── eval_hooks.py
│ │ │ └── mean_ap.py
│ │ └── utils
│ │ │ ├── __init__.py
│ │ │ └── custom_hook.py
│ ├── datasets
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── builder.py
│ │ ├── coco.py
│ │ ├── dataloader_wrappers.py
│ │ ├── dataset_wrappers.py
│ │ ├── pipelines
│ │ │ ├── __init__.py
│ │ │ ├── formatting.py
│ │ │ └── transforms.py
│ │ ├── utils.py
│ │ └── voc.py
│ └── models
│ │ ├── __init__.py
│ │ ├── backbones
│ │ ├── __init__.py
│ │ └── resnet_with_meta_conv.py
│ │ ├── builder.py
│ │ ├── dense_heads
│ │ ├── __init__.py
│ │ ├── attention_rpn_head.py
│ │ └── two_branch_rpn_head.py
│ │ ├── detectors
│ │ ├── __init__.py
│ │ ├── attention_rpn_detector.py
│ │ ├── fsce.py
│ │ ├── fsdetview.py
│ │ ├── meta_rcnn.py
│ │ ├── mpsr.py
│ │ ├── query_support_detector.py
│ │ └── tfa.py
│ │ ├── losses
│ │ ├── __init__.py
│ │ └── supervised_contrastive_loss.py
│ │ ├── roi_heads
│ │ ├── __init__.py
│ │ ├── bbox_heads
│ │ │ ├── __init__.py
│ │ │ ├── contrastive_bbox_head.py
│ │ │ ├── cosine_sim_bbox_head.py
│ │ │ ├── meta_bbox_head.py
│ │ │ ├── multi_relation_bbox_head.py
│ │ │ └── two_branch_bbox_head.py
│ │ ├── contrastive_roi_head.py
│ │ ├── fsdetview_roi_head.py
│ │ ├── meta_rcnn_roi_head.py
│ │ ├── multi_relation_roi_head.py
│ │ ├── shared_heads
│ │ │ ├── __init__.py
│ │ │ └── meta_rcnn_res_layer.py
│ │ └── two_branch_roi_head.py
│ │ └── utils
│ │ ├── __init__.py
│ │ └── aggregation_layer.py
├── utils
│ ├── __init__.py
│ ├── collate.py
│ ├── collect_env.py
│ ├── compat_config.py
│ ├── dist_utils.py
│ ├── infinite_sampler.py
│ ├── local_seed.py
│ ├── logger.py
│ └── runner.py
└── version.py
├── mmfscil
├── __init__.py
├── apis
│ ├── __init__.py
│ ├── eval_hook.py
│ ├── fscil.py
│ ├── test_fn.py
│ └── train.py
├── augments
│ ├── __init__.py
│ ├── cutmix.py
│ ├── idty.py
│ └── mixup.py
├── datasets
│ ├── __init__.py
│ ├── cifar100.py
│ ├── cub.py
│ ├── memory.py
│ └── mini_imagenet.py
└── models
│ ├── ETFHead.py
│ ├── __init__.py
│ ├── classifier.py
│ ├── mlp_ffn_neck.py
│ └── resnet18.py
└── tools
├── dist_train.sh
├── docker.sh
├── fscil.py
├── run_fscil.sh
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | # *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
131 | # data
132 | data/
133 |
134 | # Macos
135 | .DS_Store
136 |
137 | # Pycharm
138 | .idea/
139 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cifar_fscil.py:
--------------------------------------------------------------------------------
# Data settings for the CIFAR FSCIL experiments (pulled in through `_base_`
# by the configs under configs/cifar/).
img_size = 32
_img_resize_size = 36
img_norm_cfg = {
    'mean': [129.304, 124.070, 112.434],
    'std': [68.170, 65.392, 70.418],
    'to_rgb': False,
}
# Meta information collected alongside every sample.
meta_keys = (
    'filename', 'ori_filename', 'ori_shape',
    'img_shape', 'flip', 'flip_direction',
    'img_norm_cfg', 'cls_id', 'img_id',
)

# Training-time augmentation followed by normalisation and tensor packing.
# There is no image-loading step in this pipeline.
train_pipeline = [
    {
        'type': 'RandomResizedCrop',
        'size': img_size,
        'scale': (0.6, 1.0),
        'interpolation': 'bicubic',
    },
    {'type': 'RandomFlip', 'flip_prob': 0.5, 'direction': 'horizontal'},
    {
        'type': 'ColorJitter',
        'brightness': 0.4,
        'contrast': 0.4,
        'saturation': 0.4,
    },
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'ImageToTensor', 'keys': ['img']},
    {'type': 'ToTensor', 'keys': ['gt_label']},
    {'type': 'Collect', 'keys': ['img', 'gt_label'], 'meta_keys': meta_keys},
]

# Deterministic evaluation pipeline: resize to 36, then centre-crop to 32.
test_pipeline = [
    {
        'type': 'Resize',
        'size': (_img_resize_size, -1),
        'interpolation': 'bicubic',
    },
    {'type': 'CenterCrop', 'crop_size': img_size},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'ImageToTensor', 'keys': ['img']},
    {'type': 'Collect', 'keys': ['img', 'gt_label'], 'meta_keys': meta_keys},
]

data = {
    'samples_per_gpu': 64,
    'workers_per_gpu': 8,
    'train_dataloader': {'persistent_workers': True},
    'val_dataloader': {'persistent_workers': True},
    'test_dataloader': {'persistent_workers': True},
    # The training set is wrapped so that it is repeated 4 times.
    'train': {
        'type': 'RepeatDataset',
        'times': 4,
        'dataset': {
            'type': 'CIFAR100FSCILDataset',
            'data_prefix': '/opt/data/cifar',
            'pipeline': train_pipeline,
            'num_cls': 60,
            'subset': 'train',
        },
    },
    # Validation is restricted to 60 classes, testing covers all 100.
    'val': {
        'type': 'CIFAR100FSCILDataset',
        'data_prefix': '/opt/data/cifar',
        'pipeline': test_pipeline,
        'num_cls': 60,
        'subset': 'test',
    },
    'test': {
        'type': 'CIFAR100FSCILDataset',
        'data_prefix': '/opt/data/cifar',
        'pipeline': test_pipeline,
        'num_cls': 100,
        'subset': 'test',
    },
}
64 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/cub_fscil.py:
--------------------------------------------------------------------------------
# Data settings for the CUB-200-2011 FSCIL experiments (pulled in through
# `_base_` by the configs under configs/cub/).
img_size = 224
_img_resize_size = 256
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}
# Meta information collected alongside every sample.
meta_keys = (
    'filename', 'ori_filename', 'ori_shape',
    'img_shape', 'flip', 'flip_direction',
    'img_norm_cfg', 'cls_id', 'img_id',
)

# Training pipeline: load from file, augment, normalise, pack into tensors.
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'RandomResizedCrop', 'size': img_size},
    {'type': 'RandomFlip', 'flip_prob': 0.5, 'direction': 'horizontal'},
    {
        'type': 'ColorJitter',
        'brightness': 0.4,
        'contrast': 0.4,
        'saturation': 0.4,
    },
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'ImageToTensor', 'keys': ['img']},
    {'type': 'ToTensor', 'keys': ['gt_label']},
    {'type': 'Collect', 'keys': ['img', 'gt_label'], 'meta_keys': meta_keys},
]

# Evaluation pipeline: resize to 256, then centre-crop to 224.
test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'Resize', 'size': (_img_resize_size, -1)},
    {'type': 'CenterCrop', 'crop_size': img_size},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'ImageToTensor', 'keys': ['img']},
    {'type': 'Collect', 'keys': ['img', 'gt_label'], 'meta_keys': meta_keys},
]

data = {
    'samples_per_gpu': 64,
    'workers_per_gpu': 8,
    'train_dataloader': {'persistent_workers': True},
    'val_dataloader': {'persistent_workers': True},
    'test_dataloader': {'persistent_workers': True},
    # The training set is wrapped so that it is repeated 4 times.
    'train': {
        'type': 'RepeatDataset',
        'times': 4,
        'dataset': {
            'type': 'CUBFSCILDataset',
            'data_prefix': '/opt/data/CUB_200_2011',
            'pipeline': train_pipeline,
            'num_cls': 100,
            'subset': 'train',
        },
    },
    # Validation is restricted to 100 classes, testing covers all 200.
    'val': {
        'type': 'CUBFSCILDataset',
        'data_prefix': '/opt/data/CUB_200_2011',
        'pipeline': test_pipeline,
        'num_cls': 100,
        'subset': 'test',
    },
    'test': {
        'type': 'CUBFSCILDataset',
        'data_prefix': '/opt/data/CUB_200_2011',
        'pipeline': test_pipeline,
        'num_cls': 200,
        'subset': 'test',
    },
}
68 |
--------------------------------------------------------------------------------
/configs/_base_/datasets/mini_imagenet_fscil.py:
--------------------------------------------------------------------------------
# Data settings for the mini-ImageNet FSCIL experiments (pulled in through
# `_base_` by the configs under configs/mini_imagenet/).
img_size = 84
img_norm_cfg = {
    'mean': [123.675, 116.28, 103.53],
    'std': [58.395, 57.12, 57.375],
    'to_rgb': True,
}
# Meta information collected alongside every sample.
meta_keys = (
    'filename', 'ori_filename', 'ori_shape',
    'img_shape', 'flip', 'flip_direction',
    'img_norm_cfg', 'cls_id', 'img_id',
)

# Training pipeline: load from file, augment, normalise, pack into tensors.
train_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'RandomResizedCrop', 'size': img_size},
    {'type': 'RandomFlip', 'flip_prob': 0.5, 'direction': 'horizontal'},
    {
        'type': 'ColorJitter',
        'brightness': 0.4,
        'contrast': 0.4,
        'saturation': 0.4,
    },
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'ImageToTensor', 'keys': ['img']},
    {'type': 'ToTensor', 'keys': ['gt_label']},
    {'type': 'Collect', 'keys': ['img', 'gt_label'], 'meta_keys': meta_keys},
]

# Evaluation pipeline: resize to 1.15x the crop size (truncated to an int),
# then centre-crop back to `img_size`.
test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {'type': 'Resize', 'size': (int(img_size * 1.15), -1)},
    {'type': 'CenterCrop', 'crop_size': img_size},
    {'type': 'Normalize', **img_norm_cfg},
    {'type': 'ImageToTensor', 'keys': ['img']},
    {'type': 'Collect', 'keys': ['img', 'gt_label'], 'meta_keys': meta_keys},
]

data = {
    'samples_per_gpu': 64,
    'workers_per_gpu': 8,
    'train_dataloader': {'persistent_workers': True},
    'val_dataloader': {'persistent_workers': True},
    'test_dataloader': {'persistent_workers': True},
    # The training set is wrapped so that it is repeated 10 times.
    'train': {
        'type': 'RepeatDataset',
        'times': 10,
        'dataset': {
            'type': 'MiniImageNetFSCILDataset',
            'data_prefix': '/opt/data/miniimagenet',
            'pipeline': train_pipeline,
            'num_cls': 60,
            'subset': 'train',
        },
    },
    # Validation is restricted to 60 classes, testing covers all 100.
    'val': {
        'type': 'MiniImageNetFSCILDataset',
        'data_prefix': '/opt/data/miniimagenet',
        'pipeline': test_pipeline,
        'num_cls': 60,
        'subset': 'test',
    },
    'test': {
        'type': 'MiniImageNetFSCILDataset',
        'data_prefix': '/opt/data/miniimagenet',
        'pipeline': test_pipeline,
        'num_cls': 100,
        'subset': 'test',
    },
}
67 |
--------------------------------------------------------------------------------
/configs/_base_/default_runtime.py:
--------------------------------------------------------------------------------
# Runtime defaults shared by every experiment config.

# Checkpointing: save every interval and keep at most two checkpoints.
checkpoint_config = {'interval': 1, 'max_keep_ckpts': 2}
evaluation = {'interval': 1, 'save_best': 'auto'}
log_config = {
    'interval': 10,
    'hooks': [
        {'type': 'TextLoggerHook'},
    ],
}

dist_params = {'backend': 'nccl'}
log_level = 'INFO'
workflow = [('train', 1)]

load_from = None
resume_from = None

# Test configs
mean_neck_feat = True
mean_cur_feat = False
feat_test = False
grad_clip = None
finetune_lr = 0.1
# Class-index range and step of the incremental stage; the CUB configs
# override these three values.
inc_start = 60
inc_end = 100
inc_step = 5

# Ten-entry defaults that the *_eval configs override.
# NOTE(review): presumably one entry per incremental session - confirm
# against the code that reads `copy_list` / `step_list`.
copy_list = (1,) * 10
step_list = (50,) * 10
30 |
--------------------------------------------------------------------------------
/configs/_base_/models/resnet_etf.py:
--------------------------------------------------------------------------------
# Base model settings: ResNet12 backbone -> MLP neck -> ETF head.
# The experiment configs override the neck type and, for CUB, the backbone
# and the class counts.
model = {
    'type': 'ImageClassifierCIL',
    'backbone': {
        'type': 'ResNet12',
        'with_avgpool': False,
        'flatten': False,
    },
    'neck': {'type': 'MLPNeck', 'in_channels': 640, 'out_channels': 512},
    'head': {
        'type': 'ETFHead',
        'num_classes': 100,
        'eval_classes': 60,
        'in_channels': 512,
        'loss': {'type': 'DRLoss', 'loss_weight': 10.0},
        'topk': (1, 5),
        'cal_acc': True,
    },
}
20 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/cifar_200e.py:
--------------------------------------------------------------------------------
# Optimizer: SGD with momentum and weight decay.
optimizer = {
    'type': 'SGD',
    'lr': 0.25,
    'momentum': 0.9,
    'weight_decay': 0.0005,
}
optimizer_config = {'grad_clip': None}

# Learning-rate policy, updated per iteration (`by_epoch=False`).
lr_config = {
    'policy': 'CosineAnnealingCooldown',
    'min_lr': None,
    'min_lr_ratio': 0.01,
    'cool_down_ratio': 0.1,
    'cool_down_time': 10,
    'by_epoch': False,
    # linear warmup over the first 100 iterations
    'warmup': 'linear',
    'warmup_iters': 100,
    'warmup_ratio': 0.1,
    'warmup_by_epoch': False,
}

# NOTE(review): the file name says 200e while max_epochs is 50; the CIFAR
# dataset config wraps the training set in RepeatDataset(times=4), which
# presumably accounts for the difference - confirm.
runner = {'type': 'EpochBasedRunner', 'max_epochs': 50}
21 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/cub_80e.py:
--------------------------------------------------------------------------------
# Optimizer: SGD with momentum and weight decay.
optimizer = {
    'type': 'SGD',
    'lr': 0.025,
    'momentum': 0.9,
    'weight_decay': 0.0005,
}
optimizer_config = {'grad_clip': None}

# Learning-rate policy, updated per iteration (`by_epoch=False`).
lr_config = {
    'policy': 'CosineAnnealingCooldown',
    'min_lr': None,
    'min_lr_ratio': 0.1,
    'cool_down_ratio': 0.1,
    'cool_down_time': 10,
    'by_epoch': False,
    # linear warmup over the first 100 iterations
    'warmup': 'linear',
    'warmup_iters': 100,
    'warmup_ratio': 0.1,
    'warmup_by_epoch': False,
}

# NOTE(review): the file name says 80e while max_epochs is 20; the CUB
# dataset config wraps the training set in RepeatDataset(times=4), which
# presumably accounts for the difference - confirm.
runner = {'type': 'EpochBasedRunner', 'max_epochs': 20}
21 |
--------------------------------------------------------------------------------
/configs/_base_/schedules/mini_imagenet_500e.py:
--------------------------------------------------------------------------------
# Optimizer: SGD with momentum and weight decay.
optimizer = {
    'type': 'SGD',
    'lr': 0.25,
    'momentum': 0.9,
    'weight_decay': 0.0005,
}
optimizer_config = {'grad_clip': None}

# Learning policy: step decay (x0.25 at each listed milestone) with a
# linear warmup of 3000 iterations.
lr_config = {
    'policy': 'step',
    'gamma': 0.25,
    'warmup': 'linear',
    'warmup_iters': 3000,
    'warmup_ratio': 0.25,
    'step': [20, 30, 35, 40, 45],
}

# NOTE(review): the file name says 500e while max_epochs is 50; the
# mini-ImageNet dataset config wraps the training set in
# RepeatDataset(times=10), which presumably accounts for the difference -
# confirm.
runner = {'type': 'EpochBasedRunner', 'max_epochs': 50}
16 |
--------------------------------------------------------------------------------
/configs/cifar/resnet12_etf_bs512_200e_cifar.py:
--------------------------------------------------------------------------------
# CIFAR training config: base model + CIFAR data + CIFAR schedule + runtime.
_base_ = [
    '../_base_/models/resnet_etf.py',
    '../_base_/datasets/cifar_fscil.py',
    '../_base_/schedules/cifar_200e.py',
    '../_base_/default_runtime.py',
]

# Model overrides: swap in the MLP-FFN neck and add batch-level augmentations.
# The three augmentation probabilities (0.4 / 0.4 / 0.2) sum to 1.
model = {
    'neck': {'type': 'MLPFFNNeck', 'in_channels': 640, 'out_channels': 512},
    'head': {'type': 'ETFHead', 'in_channels': 512, 'with_len': False},
    'train_cfg': {
        'augments': [
            {
                'type': 'BatchMixupTwoLabel',
                'alpha': 0.8,
                'num_classes': -1,
                'prob': 0.4,
            },
            {
                'type': 'BatchCutMixTwoLabel',
                'alpha': 1.0,
                'num_classes': -1,
                'prob': 0.4,
            },
            {'type': 'IdentityTwoLabel', 'num_classes': -1, 'prob': 0.2},
        ],
    },
}
18 |
--------------------------------------------------------------------------------
/configs/cifar/resnet12_etf_bs512_200e_cifar_eval.py:
--------------------------------------------------------------------------------
# CIFAR evaluation config (same bases as the training config).
_base_ = [
    '../_base_/models/resnet_etf.py',
    '../_base_/datasets/cifar_fscil.py',
    '../_base_/schedules/cifar_200e.py',
    '../_base_/default_runtime.py',
]

# Model overrides; unlike the training config the head uses `with_len=True`.
model = {
    'mixup': 0.5,
    'mixup_prob': 0.75,
    'neck': {'type': 'MLPFFNNeck', 'in_channels': 640, 'out_channels': 512},
    'head': {'type': 'ETFHead', 'in_channels': 512, 'with_len': True},
}

# Overrides of the runtime defaults; the last two entries are left as None.
# NOTE(review): presumably one entry per incremental session - confirm
# against the code that reads these tuples.
copy_list = (1, 1, 1, 1, 1, 1, 1, 1, None, None)
step_list = (50, 75, 100, 120, 140, 160, 200, 200, None, None)
finetune_lr = 0.25
19 |
--------------------------------------------------------------------------------
/configs/cub/resnet18_etf_bs512_80e_cub.py:
--------------------------------------------------------------------------------
# CUB training config: base model + CUB data + CUB schedule + runtime.
_base_ = [
    '../_base_/models/resnet_etf.py',
    '../_base_/datasets/cub_fscil.py',
    '../_base_/schedules/cub_80e.py',
    '../_base_/default_runtime.py',
]

# CUB requires different inc settings
inc_start = 100
inc_end = 200
inc_step = 10

# Model overrides: `_delete_=True` discards the inherited backbone settings
# so that a pretrained ResNet-18 with SyncBN replaces the base ResNet12.
model = {
    'backbone': {
        '_delete_': True,
        'type': 'ResNet',
        'depth': 18,
        'frozen_stages': 1,
        'init_cfg': {
            'type': 'Pretrained',
            'checkpoint': 'torchvision://resnet18',
        },
        'norm_cfg': {'type': 'SyncBN', 'requires_grad': True},
    },
    'neck': {'type': 'MLPFFNNeck', 'in_channels': 512, 'out_channels': 512},
    'head': {
        'type': 'ETFHead',
        'num_classes': 200,
        'eval_classes': 100,
        'with_len': False,
    },
}
31 |
--------------------------------------------------------------------------------
/configs/cub/resnet18_etf_bs512_80e_cub_eval.py:
--------------------------------------------------------------------------------
# CUB evaluation config (same bases as the training config).
_base_ = [
    '../_base_/models/resnet_etf.py',
    '../_base_/datasets/cub_fscil.py',
    '../_base_/schedules/cub_80e.py',
    '../_base_/default_runtime.py',
]

# CUB requires different inc settings
inc_start = 100
inc_end = 200
inc_step = 10

# Model overrides: `_delete_=True` discards the inherited backbone settings.
# This config uses plain BN where the CUB training config uses SyncBN.
model = {
    'backbone': {
        '_delete_': True,
        'type': 'ResNet',
        'depth': 18,
        'frozen_stages': 1,
        'init_cfg': {
            'type': 'Pretrained',
            'checkpoint': 'torchvision://resnet18',
        },
        'norm_cfg': {'type': 'BN', 'requires_grad': True},
    },
    'neck': {'type': 'MLPFFNNeck', 'in_channels': 512, 'out_channels': 512},
    'head': {
        'type': 'ETFHead',
        'num_classes': 200,
        'eval_classes': 100,
        'with_len': False,
    },
}

# Overrides of the runtime defaults, ten entries each.
# NOTE(review): presumably one entry per incremental session - confirm
# against the code that reads these tuples.
copy_list = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
step_list = (105, 110, 115, 120, 125, 130, 135, 140, 145, 150)
finetune_lr = 0.05
35 |
--------------------------------------------------------------------------------
/configs/mini_imagenet/resnet12_etf_bs512_500e_miniimagenet.py:
--------------------------------------------------------------------------------
# mini-ImageNet training config: base model + data + schedule + runtime.
_base_ = [
    '../_base_/models/resnet_etf.py',
    '../_base_/datasets/mini_imagenet_fscil.py',
    '../_base_/schedules/mini_imagenet_500e.py',
    '../_base_/default_runtime.py',
]


# Model overrides: swap in the MLP-FFN neck and set `with_len=False`.
model = {
    'neck': {'type': 'MLPFFNNeck', 'in_channels': 640, 'out_channels': 512},
    'head': {'type': 'ETFHead', 'in_channels': 512, 'with_len': False},
}
14 |
15 |
--------------------------------------------------------------------------------
/configs/mini_imagenet/resnet12_etf_bs512_500e_miniimagenet_eval.py:
--------------------------------------------------------------------------------
# mini-ImageNet evaluation config (same bases as the training config).
_base_ = [
    '../_base_/models/resnet_etf.py',
    '../_base_/datasets/mini_imagenet_fscil.py',
    '../_base_/schedules/mini_imagenet_500e.py',
    '../_base_/default_runtime.py',
]


# Model overrides; `mixup` is set to zero here.
model = {
    'mixup': 0.0,
    'neck': {'type': 'MLPFFNNeck', 'in_channels': 640, 'out_channels': 512},
    'head': {'type': 'ETFHead', 'in_channels': 512, 'with_len': False},
}

# Overrides of the runtime defaults; the last two entries are left as None.
# NOTE(review): presumably one entry per incremental session - confirm
# against the code that reads these tuples.
copy_list = (1, 2, 3, 4, 5, 6, 7, 8, None, None)
step_list = (100, 110, 120, 130, 140, 150, 160, 170, None, None)
finetune_lr = 0.025
19 |
--------------------------------------------------------------------------------
/docker_env/Dockerfile:
--------------------------------------------------------------------------------
# Pytorch 1.13 CUDA 11.7
FROM nvcr.io/nvidia/pytorch:22.06-py3

# install ujson
RUN pip install ujson

# handle the timezone
# `-y` keeps apt-get from stopping at a confirmation prompt, which cannot be
# answered during a non-interactive image build.
RUN apt-get update && DEBIAN_FRONTEND="noninteractive" TZ="PRC" apt-get install -y tzdata \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    && unlink /etc/localtime && ln -s /usr/share/zoneinfo/PRC /etc/localtime

# mmcv : 1.6.1
# Build mmcv (with its compiled ops) from a pinned commit. The install is
# retried until it succeeds, so the URL must be reachable: GitHub no longer
# serves the unauthenticated git:// protocol, which would make this loop spin
# forever, hence https://.
RUN until MMCV_WITH_OPS=1 FORCE_CUDA=1 python -m pip install git+https://github.com/open-mmlab/mmcv.git@d409eedc816fccfb1c8d57e5eed5f03bd075f327; do sleep 0.1; done

# git config
# Mark the project mount point as a safe directory for git.
RUN git config --global --add safe.directory /opt/project

WORKDIR /opt/project
19 |
--------------------------------------------------------------------------------
/mmcls/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import mmcv
5 | from packaging.version import parse
6 |
7 | from .version import __version__
8 |
9 |
def digit_version(version_str: str, length: int = 4):
    """Turn a version string into a tuple of integers that can be compared.

    The release part is cut or zero-padded to ``length`` entries and two
    more entries are appended to order pre-/post-releases around the final
    release: alpha < beta < rc < final < post.

    Args:
        version_str (str): The version string.
        length (int): The maximum number of version levels. Default: 4.

    Returns:
        tuple[int]: The version info in digits (integers).
    """
    parsed = parse(version_str)
    assert parsed.release, f'failed to parse version {version_str}'

    # Keep at most `length` release levels and zero-fill the missing ones
    # (multiplying a list by a non-positive count adds nothing).
    digits = list(parsed.release)[:length]
    digits = digits + [0] * (length - len(digits))

    if parsed.is_prerelease:
        stage_rank = {'a': -3, 'b': -2, 'rc': -1}
        pre = parsed.pre
        if not pre:
            # version.pre can be None; such versions get the lowest rank
            suffix = [-4, 0]
        else:
            tag = pre[0]
            if tag in stage_rank:
                rank = stage_rank[tag]
            else:
                warnings.warn(f'unknown prerelease version {tag}, '
                              'version checking may go wrong')
                rank = -4
            suffix = [rank, pre[-1]]
    elif parsed.is_postrelease:
        suffix = [1, parsed.post]
    else:
        suffix = [0, 0]

    return tuple(digits + suffix)
48 |
49 |
# Inclusive range of MMCV versions this package accepts.
mmcv_minimum_version = '1.4.2'
mmcv_maximum_version = '1.7.0'
mmcv_version = digit_version(mmcv.__version__)


# Refuse to import when the installed MMCV falls outside that range.
assert (digit_version(mmcv_minimum_version) <= mmcv_version
        <= digit_version(mmcv_maximum_version)), (
    f'MMCV=={mmcv.__version__} is used but incompatible. '
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.')

__all__ = ['__version__', 'digit_version']
61 |
--------------------------------------------------------------------------------
/mmcls/apis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .inference import inference_model, init_model, show_result_pyplot
3 | from .test import multi_gpu_test, single_gpu_test
4 | from .train import init_random_seed, set_random_seed, train_model
5 |
6 | __all__ = [
7 | 'set_random_seed', 'train_model', 'init_model', 'inference_model',
8 | 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot',
9 | 'init_random_seed'
10 | ]
11 |
--------------------------------------------------------------------------------
/mmcls/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .evaluation import * # noqa: F401, F403
3 | from .hook import * # noqa: F401, F403
4 | from .optimizers import * # noqa: F401, F403
5 | from .utils import * # noqa: F401, F403
6 |
--------------------------------------------------------------------------------
/mmcls/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .eval_hooks import DistEvalHook, EvalHook
3 | from .eval_metrics import (calculate_confusion_matrix, f1_score, precision,
4 | precision_recall_f1, recall, support)
5 | from .mean_ap import average_precision, mAP
6 | from .multilabel_eval_metrics import average_performance
7 |
8 | __all__ = [
9 | 'precision', 'recall', 'f1_score', 'support', 'average_precision', 'mAP',
10 | 'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1',
11 | 'EvalHook', 'DistEvalHook'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmcls/core/evaluation/eval_hooks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | import torch.distributed as dist
5 | from mmcv.runner import DistEvalHook as BaseDistEvalHook
6 | from mmcv.runner import EvalHook as BaseEvalHook
7 | from torch.nn.modules.batchnorm import _BatchNorm
8 |
9 |
class EvalHook(BaseEvalHook):
    """Single-process evaluation hook.

    Behaves like MMCV's ``EvalHook`` but additionally keeps the results of
    the most recent evaluation on ``self.latest_results`` so that other
    hooks (like `MMClsWandbHook`) can read them.
    """

    def __init__(self, dataloader, **kwargs):
        super().__init__(dataloader, **kwargs)
        # Filled in by every call to `_do_evaluate`.
        self.latest_results = None

    def _do_evaluate(self, runner):
        """Evaluate the model and save the best checkpoint when requested."""
        outputs = self.test_fn(runner.model, self.dataloader)
        self.latest_results = outputs
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        score = self.evaluate(runner, outputs)
        # `evaluate` may return `None`; in that case (or when best-checkpoint
        # saving is disabled) there is nothing to save.
        if not (self.save_best and score):
            return
        self._save_ckpt(runner, score)
33 |
class DistEvalHook(BaseDistEvalHook):
    """Distributed evaluation hook.

    Behaves like MMCV's ``DistEvalHook`` but additionally keeps the results
    of the most recent evaluation on ``self.latest_results`` so that other
    hooks (like `MMClsWandbHook`) can read them.
    """

    def __init__(self, dataloader, **kwargs):
        super().__init__(dataloader, **kwargs)
        # Filled in by every call to `_do_evaluate`.
        self.latest_results = None

    def _do_evaluate(self, runner):
        """Evaluate on all ranks; log and save the best checkpoint on rank 0."""
        if self.broadcast_bn_buffer:
            # DDP does not synchronise BatchNorm buffers (running_mean and
            # running_var), so ranks may disagree at evaluation time. Rank 0's
            # buffers are broadcast to every other rank; the order of the two
            # broadcasts must be the same on all ranks.
            for _, layer in runner.model.named_modules():
                if not isinstance(layer, _BatchNorm):
                    continue
                if not layer.track_running_stats:
                    continue
                dist.broadcast(layer.running_var, 0)
                dist.broadcast(layer.running_mean, 0)

        # Directory used to collect results when none was configured.
        collect_dir = self.tmpdir
        if collect_dir is None:
            collect_dir = osp.join(runner.work_dir, '.eval_hook')

        outputs = self.test_fn(
            runner.model,
            self.dataloader,
            tmpdir=collect_dir,
            gpu_collect=self.gpu_collect)
        self.latest_results = outputs

        # Only rank 0 scores the results and writes checkpoints.
        if runner.rank != 0:
            return
        print('\n')
        runner.log_buffer.output['eval_iter_num'] = len(self.dataloader)
        score = self.evaluate(runner, outputs)
        # `evaluate` may return `None`; in that case (or when best-checkpoint
        # saving is disabled) there is nothing to save.
        if self.save_best and score:
            self._save_ckpt(runner, score)
79 |
--------------------------------------------------------------------------------
/mmcls/core/evaluation/mean_ap.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 |
5 |
def average_precision(pred, target):
    r"""Compute the average precision of one class.

    AP summarizes a precision-recall curve as the weighted mean of maximum
    precisions obtained for any r'>r, where r is the recall:

    .. math::
        \text{AP} = \sum_n (R_n - R_{n-1}) P_n

    Note that no approximation is involved since the curve is piecewise
    constant.

    Args:
        pred (np.ndarray): The model prediction with shape (N, ).
        target (np.ndarray): The target of each prediction with shape (N, ).
            Entries equal to 1 are counted as positives and entries equal to
            -1 are excluded from the precision denominator.

    Returns:
        float: a single float as average precision value.
    """
    eps = np.finfo(np.float32).eps

    # Rank the targets by descending prediction score.
    ranked_target = target[np.argsort(-pred)]

    # Running number of positives seen so far, and the overall positive count.
    is_positive = ranked_target == 1
    running_pos = np.cumsum(is_positive)
    num_pos = running_pos[-1]

    # Running number of examples that are not marked as difficult (-1).
    running_valid = np.cumsum(ranked_target != -1)

    # Precision only contributes at the ranks that hold a positive example.
    hits = np.where(is_positive, running_pos, 0)
    precision = hits / np.maximum(running_valid, eps)
    return np.sum(precision) / np.maximum(num_pos, eps)
44 |
45 |
def mAP(pred, target):
    """Calculate the mean average precision with respect of classes.

    Args:
        pred (torch.Tensor | np.ndarray): The model prediction with shape
            (N, C), where C is the number of classes.
        target (torch.Tensor | np.ndarray): The target of each prediction with
            shape (N, C), where C is the number of classes. 1 stands for
            positive examples, 0 stands for negative examples and -1 stands for
            difficult examples.

    Returns:
        float: A single float as mAP value (in percent).

    Raises:
        TypeError: If ``pred`` and ``target`` are not both ``torch.Tensor``
            or both ``np.ndarray``.
    """
    if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
        pred = pred.detach().cpu().numpy()
        target = target.detach().cpu().numpy()
    elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)):
        # The trailing space matters: the two literals are concatenated.
        raise TypeError('pred and target should both be torch.Tensor or '
                        'np.ndarray')

    assert pred.shape == \
        target.shape, 'pred and target should be in the same shape.'
    num_classes = pred.shape[1]
    # Per-class AP, computed column by column.
    ap = np.zeros(num_classes)
    for k in range(num_classes):
        ap[k] = average_precision(pred[:, k], target[:, k])
    mean_ap = ap.mean() * 100.0
    return mean_ap
75 |
--------------------------------------------------------------------------------
/mmcls/core/evaluation/multilabel_eval_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import numpy as np
5 | import torch
6 |
7 |
def average_performance(pred, target, thr=None, k=None):
    """Calculate CP, CR, CF1, OP, OR, OF1, where C stands for per-class
    average, O stands for overall average, P stands for precision, R stands for
    recall and F1 stands for F1-score.

    Args:
        pred (torch.Tensor | np.ndarray): The model prediction with shape
            (N, C), where C is the number of classes.
        target (torch.Tensor | np.ndarray): The target of each prediction with
            shape (N, C), where C is the number of classes. 1 stands for
            positive examples, 0 stands for negative examples and -1 stands for
            difficult examples. The input is not modified.
        thr (float): The confidence threshold. Defaults to None.
        k (int): Top-k performance. Note that if thr and k are both given, k
            will be ignored. Defaults to None.

    Returns:
        tuple: (CP, CR, CF1, OP, OR, OF1)

    Raises:
        TypeError: If ``pred`` and ``target`` are not both ``torch.Tensor``
            or both ``np.ndarray``.
    """
    if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor):
        pred = pred.detach().cpu().numpy()
        target = target.detach().cpu().numpy()
    elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)):
        # The trailing space matters: the two literals are concatenated.
        raise TypeError('pred and target should both be torch.Tensor or '
                        'np.ndarray')
    if thr is None and k is None:
        thr = 0.5
        warnings.warn('Neither thr nor k is given, set thr as 0.5 by '
                      'default.')
    elif thr is not None and k is not None:
        warnings.warn('Both thr and k are given, use threshold in favor of '
                      'top-k.')

    assert pred.shape == \
        target.shape, 'pred and target should be in the same shape.'

    eps = np.finfo(np.float32).eps
    # Difficult examples (-1) are counted as negatives. Work on a copy so the
    # caller's array (or a CPU tensor sharing its memory) is left untouched.
    target = target.copy()
    target[target == -1] = 0
    if thr is not None:
        # a label is predicted positive if the confidence is no lower than thr
        pos_inds = pred >= thr
    else:
        # top-k labels will be predicted positive for any example
        sort_inds = np.argsort(-pred, axis=1)
        sort_inds_ = sort_inds[:, :k]
        inds = np.indices(sort_inds_.shape)
        pos_inds = np.zeros_like(pred)
        pos_inds[inds[0], sort_inds_] = 1

    tp = (pos_inds * target) == 1
    fp = (pos_inds * (1 - target)) == 1
    fn = ((1 - pos_inds) * target) == 1

    tp_per_class = tp.sum(axis=0)
    precision_class = tp_per_class / np.maximum(
        tp_per_class + fp.sum(axis=0), eps)
    recall_class = tp_per_class / np.maximum(
        tp_per_class + fn.sum(axis=0), eps)
    CP = precision_class.mean() * 100.0
    CR = recall_class.mean() * 100.0
    CF1 = 2 * CP * CR / np.maximum(CP + CR, eps)
    OP = tp.sum() / np.maximum(tp.sum() + fp.sum(), eps) * 100.0
    OR = tp.sum() / np.maximum(tp.sum() + fn.sum(), eps) * 100.0
    OF1 = 2 * OP * OR / np.maximum(OP + OR, eps)
    return CP, CR, CF1, OP, OR, OF1
73 |
--------------------------------------------------------------------------------
/mmcls/core/export/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .test import ONNXRuntimeClassifier, TensorRTClassifier
3 |
4 | __all__ = ['ONNXRuntimeClassifier', 'TensorRTClassifier']
5 |
--------------------------------------------------------------------------------
/mmcls/core/export/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings
3 |
4 | import numpy as np
5 | import onnxruntime as ort
6 | import torch
7 |
8 | from mmcls.models.classifiers import BaseClassifier
9 |
10 |
class ONNXRuntimeClassifier(BaseClassifier):
    """Wrapper for classifier's inference with ONNXRuntime.

    Args:
        onnx_file: The model handed to ``ort.InferenceSession``.
        class_names: Stored as ``self.CLASSES``.
        device_id: Device index used for the CUDA execution provider and for
            the input binding.
    """

    def __init__(self, onnx_file, class_names, device_id):
        super(ONNXRuntimeClassifier, self).__init__()
        sess = ort.InferenceSession(onnx_file)

        # CPU is always available; the CUDA provider is put first when
        # ONNX Runtime reports a GPU device.
        providers = ['CPUExecutionProvider']
        options = [{}]
        is_cuda_available = ort.get_device() == 'GPU'
        if is_cuda_available:
            providers.insert(0, 'CUDAExecutionProvider')
            options.insert(0, {'device_id': device_id})
        sess.set_providers(providers, options)

        self.sess = sess
        self.CLASSES = class_names
        self.device_id = device_id
        self.io_binding = sess.io_binding()
        self.output_names = [_.name for _ in sess.get_outputs()]
        self.is_cuda_available = is_cuda_available

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def forward_test(self, imgs, img_metas, **kwargs):
        """Run the ONNX session on ``imgs`` through IO binding.

        Args:
            imgs: Input tensor; it is bound as ``float32`` data under the
                input name ``'input'``.
            img_metas: Unused here.

        Returns:
            list: The first session output, copied to CPU and split along
            its first axis.
        """
        input_data = imgs
        # set io binding for inputs/outputs
        device_type = 'cuda' if self.is_cuda_available else 'cpu'
        if not self.is_cuda_available:
            input_data = input_data.cpu()
        # The binding passes a raw pointer together with the logical shape,
        # so the memory must be densely laid out. ``contiguous()`` is a no-op
        # for tensors that already are.
        input_data = input_data.contiguous()
        self.io_binding.bind_input(
            name='input',
            device_type=device_type,
            device_id=self.device_id,
            element_type=np.float32,
            shape=input_data.shape,
            buffer_ptr=input_data.data_ptr())

        for name in self.output_names:
            self.io_binding.bind_output(name)
        # run session to get outputs
        self.sess.run_with_iobinding(self.io_binding)
        results = self.io_binding.copy_outputs_to_cpu()[0]
        return list(results)
62 |
63 |
class TensorRTClassifier(BaseClassifier):
    """Wrapper for classifier's inference with TensorRT.

    Args:
        trt_file: The engine handed to ``TRTWraper``.
        class_names: Stored as ``self.CLASSES``.
        device_id: CUDA device index used as the current device during
            inference.
    """

    def __init__(self, trt_file, class_names, device_id):
        super(TensorRTClassifier, self).__init__()
        # Imported lazily so the module can be imported without the TensorRT
        # extension of mmcv.
        from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin
        try:
            load_tensorrt_plugin()
        except (ImportError, ModuleNotFoundError):
            # Adjacent literals instead of a backslash continuation inside
            # the string, which would embed the source indentation in the
            # message.
            warnings.warn('If input model has custom op from mmcv, '
                          'you may have to build mmcv with TensorRT from '
                          'source.')
        model = TRTWraper(
            trt_file, input_names=['input'], output_names=['probs'])

        self.model = model
        self.device_id = device_id
        self.CLASSES = class_names

    def simple_test(self, img, img_metas, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def extract_feat(self, imgs):
        raise NotImplementedError('This method is not implemented.')

    def forward_train(self, imgs, **kwargs):
        raise NotImplementedError('This method is not implemented.')

    def forward_test(self, imgs, img_metas, **kwargs):
        """Run the TensorRT engine on ``imgs``.

        Args:
            imgs: Input passed to the engine under the name ``'input'``.
            img_metas: Unused here.

        Returns:
            list: The ``'probs'`` output converted to a numpy array and split
            along its first axis.
        """
        input_data = imgs
        with torch.cuda.device(self.device_id), torch.no_grad():
            results = self.model({'input': input_data})['probs']
        results = results.detach().cpu().numpy()

        return list(results)
97 |
--------------------------------------------------------------------------------
/mmcls/core/hook/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .class_num_check_hook import ClassNumCheckHook
3 | from .lr_updater import CosineAnnealingCooldownLrUpdaterHook
4 | from .precise_bn_hook import PreciseBNHook
5 | from .wandblogger_hook import MMClsWandbHook
6 |
7 | __all__ = [
8 | 'ClassNumCheckHook', 'PreciseBNHook',
9 | 'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook'
10 | ]
11 |
--------------------------------------------------------------------------------
/mmcls/core/hook/class_num_check_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved
2 | from mmcv.runner import IterBasedRunner
3 | from mmcv.runner.hooks import HOOKS, Hook
4 | from mmcv.utils import is_seq_of
5 |
6 |
@HOOKS.register_module()
class ClassNumCheckHook(Hook):
    """Hook that checks ``num_classes`` of the model against the dataset.

    Before each train/val epoch, and before each train/val iteration when the
    runner is an ``IterBasedRunner``, every sub-module exposing a
    ``num_classes`` attribute is compared with ``len(dataset.CLASSES)``.
    """

    def _check_head(self, runner, dataset):
        """Check whether the `num_classes` in head matches the length of
        `CLASSES` in `dataset`.

        Args:
            runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object.
            dataset (obj: `BaseDataset`): the dataset to check.

        Raises:
            AssertionError: If ``dataset.CLASSES`` is not a sequence of str,
                or if a module's ``num_classes`` differs from its length.
        """
        model = runner.model
        dataset_name = dataset.__class__.__name__
        if dataset.CLASSES is None:
            # Nothing to compare against; only remind the user.
            runner.logger.warning(
                f'Please set `CLASSES` '
                f'in the {dataset_name} and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
        else:
            assert is_seq_of(dataset.CLASSES, str), \
                (f'`CLASSES` in {dataset_name} '
                 f'should be a tuple of str.')
            for module in model.modules():
                if hasattr(module, 'num_classes'):
                    assert module.num_classes == len(dataset.CLASSES), \
                        (f'The `num_classes` ({module.num_classes}) in '
                         f'{module.__class__.__name__} of '
                         f'{model.__class__.__name__} does not matches '
                         f'the length of `CLASSES` '
                         f'({len(dataset.CLASSES)}) in '
                         f'{dataset_name}')

    def before_train_iter(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj: `IterBasedRunner`): Iter based Runner.
        """
        # Epoch based runners are handled by ``before_train_epoch``.
        if not isinstance(runner, IterBasedRunner):
            return
        self._check_head(runner, runner.data_loader._dataloader.dataset)

    def before_val_iter(self, runner):
        """Check whether the eval dataset is compatible with head.

        Args:
            runner (obj:`IterBasedRunner`): Iter based Runner.
        """
        # Epoch based runners are handled by ``before_val_epoch``.
        if not isinstance(runner, IterBasedRunner):
            return
        self._check_head(runner, runner.data_loader._dataloader.dataset)

    def before_train_epoch(self, runner):
        """Check whether the training dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner, runner.data_loader.dataset)

    def before_val_epoch(self, runner):
        """Check whether the eval dataset is compatible with head.

        Args:
            runner (obj:`EpochBasedRunner`): Epoch based Runner.
        """
        self._check_head(runner, runner.data_loader.dataset)
74 |
--------------------------------------------------------------------------------
/mmcls/core/hook/lr_updater.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from math import cos, pi
3 |
4 | from mmcv.runner.hooks import HOOKS, LrUpdaterHook
5 |
6 |
@HOOKS.register_module()
class CosineAnnealingCooldownLrUpdaterHook(LrUpdaterHook):
    """Cosine annealing learning rate scheduler followed by a cooldown phase.

    The learning rate is annealed from the base value to the target value
    over ``max_progress - cool_down_time`` steps; afterwards it is held at
    ``target_lr * cool_down_ratio``.

    Args:
        min_lr (float, optional): The minimum learning rate after annealing.
            Defaults to None.
        min_lr_ratio (float, optional): The minimum learning ratio after
            annealing. Defaults to None.
        cool_down_ratio (float): The cooldown ratio. Defaults to 0.1.
        cool_down_time (int): The cooldown time. Defaults to 10.
        by_epoch (bool): If True, the learning rate changes epoch by epoch. If
            False, the learning rate changes iter by iter. Defaults to True.
        warmup (string, optional): Type of warmup used. It can be None (use no
            warmup), 'constant', 'linear' or 'exp'. Defaults to None.
        warmup_iters (int): The number of iterations or epochs that warmup
            lasts. Defaults to 0.
        warmup_ratio (float): LR used at the beginning of warmup equals to
            ``warmup_ratio * initial_lr``. Defaults to 0.1.
        warmup_by_epoch (bool): If True, the ``warmup_iters``
            means the number of epochs that warmup lasts, otherwise means the
            number of iteration that warmup lasts. Defaults to False.

    Note:
        You need to set one and only one of ``min_lr`` and ``min_lr_ratio``.
    """

    def __init__(self,
                 min_lr=None,
                 min_lr_ratio=None,
                 cool_down_ratio=0.1,
                 cool_down_time=10,
                 **kwargs):
        # Exactly one of the two ways to specify the target LR is allowed.
        assert (min_lr is None) ^ (min_lr_ratio is None)
        self.min_lr = min_lr
        self.min_lr_ratio = min_lr_ratio
        self.cool_down_time = cool_down_time
        self.cool_down_ratio = cool_down_ratio
        super().__init__(**kwargs)

    def get_lr(self, runner, base_lr):
        """Return the learning rate for the runner's current progress."""
        if self.by_epoch:
            current, total = runner.epoch, runner.max_epochs
        else:
            current, total = runner.iter, runner.max_iters

        if self.min_lr_ratio is None:
            target_lr = self.min_lr
        else:
            target_lr = base_lr * self.min_lr_ratio

        # Length of the annealing phase; what remains is the cooldown.
        anneal_steps = total - self.cool_down_time
        if current > anneal_steps:
            return target_lr * self.cool_down_ratio

        return annealing_cos(base_lr, target_lr, current / anneal_steps)
66 |
67 |
def annealing_cos(start, end, factor, weight=1):
    """Compute a cosine-annealed learning rate.

    Cosine anneal from `weight * start + (1 - weight) * end` to `end` as
    percentage goes from 0.0 to 1.0.

    Args:
        start (float): The starting learning rate of the cosine annealing.
        end (float): The ending learning rate of the cosine annealing.
        factor (float): The coefficient of `pi` when calculating the current
            percentage. Range from 0.0 to 1.0.
        weight (float, optional): The combination factor of `start` and `end`
            when calculating the actual starting learning rate. Default to 1.

    Returns:
        float: The annealed learning rate.
    """
    lr_span = start - end
    # Goes from 2 at factor 0 down to 0 at factor 1.
    cos_term = 1 + cos(factor * pi)
    return end + 0.5 * weight * lr_span * cos_term
84 |
--------------------------------------------------------------------------------
/mmcls/core/optimizers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .lamb import Lamb
3 |
4 | __all__ = [
5 | 'Lamb',
6 | ]
7 |
--------------------------------------------------------------------------------
/mmcls/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .dist_utils import DistOptimizerHook, allreduce_grads, sync_random_seed
3 | from .misc import multi_apply
4 |
5 | __all__ = [
6 | 'allreduce_grads', 'DistOptimizerHook', 'multi_apply', 'sync_random_seed'
7 | ]
8 |
--------------------------------------------------------------------------------
/mmcls/core/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from collections import OrderedDict
3 |
4 | import numpy as np
5 | import torch
6 | import torch.distributed as dist
7 | from mmcv.runner import OptimizerHook, get_dist_info
8 | from torch._utils import (_flatten_dense_tensors, _take_tensors,
9 | _unflatten_dense_tensors)
10 |
11 |
12 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
13 | if bucket_size_mb > 0:
14 | bucket_size_bytes = bucket_size_mb * 1024 * 1024
15 | buckets = _take_tensors(tensors, bucket_size_bytes)
16 | else:
17 | buckets = OrderedDict()
18 | for tensor in tensors:
19 | tp = tensor.type()
20 | if tp not in buckets:
21 | buckets[tp] = []
22 | buckets[tp].append(tensor)
23 | buckets = buckets.values()
24 |
25 | for bucket in buckets:
26 | flat_tensors = _flatten_dense_tensors(bucket)
27 | dist.all_reduce(flat_tensors)
28 | flat_tensors.div_(world_size)
29 | for tensor, synced in zip(
30 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
31 | tensor.copy_(synced)
32 |
33 |
def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
    """Average the gradients of ``params`` across all processes.

    Args:
        params: Iterable of parameters; only those that require grad and
            currently hold a gradient take part.
        coalesce (bool): Whether to all-reduce the gradients in flattened
            buckets instead of one by one. Defaults to True.
        bucket_size_mb (int): Bucket size forwarded to the coalesced
            all-reduce. Defaults to -1.
    """
    grads = []
    for param in params:
        if param.requires_grad and param.grad is not None:
            grads.append(param.grad.data)

    world_size = dist.get_world_size()
    if not coalesce:
        # Pre-divide each gradient so the summing all-reduce yields the mean.
        for grad in grads:
            dist.all_reduce(grad.div_(world_size))
        return
    _allreduce_coalesced(grads, world_size, bucket_size_mb)
45 |
46 |
class DistOptimizerHook(OptimizerHook):
    """Optimizer hook that runs backward, optional grad clipping and a step.

    Args:
        grad_clip (dict, optional): When not None, gradients are clipped via
            ``self.clip_grads`` before the optimizer step. Defaults to None.
        coalesce (bool): Stored on the hook; not used by
            ``after_train_iter``. Defaults to True.
        bucket_size_mb (int): Stored on the hook; not used by
            ``after_train_iter``. Defaults to -1.
    """

    def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
        # Let the base class set up its own state as well, so attributes it
        # defines are present on this hook.
        super().__init__(grad_clip=grad_clip)
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb

    def after_train_iter(self, runner):
        """Back-propagate the loss in ``runner.outputs`` and step."""
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        if self.grad_clip is not None:
            self.clip_grads(runner.model.parameters())
        runner.optimizer.step()
60 |
61 |
def sync_random_seed(seed=None, device='cuda'):
    """Return a seed that is identical on every rank.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.

    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. When None, a random one is drawn.
            Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    # A single process has nobody to agree with.
    if world_size == 1:
        return seed

    # Rank 0 contributes its seed; the other ranks provide a placeholder
    # that the broadcast overwrites.
    payload = seed if rank == 0 else 0
    random_num = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
99 |
--------------------------------------------------------------------------------
/mmcls/core/utils/misc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from functools import partial
3 |
4 |
def multi_apply(func, *args, **kwargs):
    """Apply ``func`` element-wise and regroup its outputs by position.

    Args:
        func (callable): Called once per aligned group of items from
            ``args``; expected to return an iterable of results.
        *args: Iterables whose items are passed positionally to ``func``.
        **kwargs: Keyword arguments bound to every call of ``func``.

    Returns:
        tuple[list]: One list per output position of ``func``.
    """
    call = func
    if kwargs:
        call = partial(func, **kwargs)
    per_item_outputs = map(call, *args)
    # Transpose: from one tuple per item to one list per output slot.
    return tuple([*column] for column in zip(*per_item_outputs))
9 |
--------------------------------------------------------------------------------
/mmcls/core/visualization/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .image import (BaseFigureContextManager, ImshowInfosContextManager,
3 | color_val_matplotlib, imshow_infos)
4 |
5 | __all__ = [
6 | 'BaseFigureContextManager', 'ImshowInfosContextManager', 'imshow_infos',
7 | 'color_val_matplotlib'
8 | ]
9 |
--------------------------------------------------------------------------------
/mmcls/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_dataset import BaseDataset
3 | from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader,
4 | build_dataset, build_sampler)
5 | from .cifar import CIFAR10, CIFAR100
6 | from .cub import CUB
7 | from .custom import CustomDataset
8 | from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
9 | KFoldDataset, RepeatDataset)
10 | from .imagenet import ImageNet
11 | from .imagenet21k import ImageNet21k
12 | from .mnist import MNIST, FashionMNIST
13 | from .multi_label import MultiLabelDataset
14 | from .samplers import DistributedSampler, RepeatAugSampler
15 | from .voc import VOC
16 |
17 | __all__ = [
18 | 'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST',
19 | 'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
20 | 'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
21 | 'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
22 | 'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset'
23 | ]
24 |
--------------------------------------------------------------------------------
/mmcls/datasets/multi_label.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import List
3 |
4 | import numpy as np
5 |
6 | from mmcls.core import average_performance, mAP
7 | from .base_dataset import BaseDataset
8 |
9 |
class MultiLabelDataset(BaseDataset):
    """Multi-label Dataset."""

    def get_cat_ids(self, idx: int) -> List[int]:
        """Return the category ids of one sample.

        Args:
            idx (int): Index of data.

        Returns:
            cat_ids (List[int]): Image categories of specified index.
        """
        labels = self.data_infos[idx]['gt_label']
        # Positions along the first axis where the label equals 1.
        return np.nonzero(labels == 1)[0].tolist()

    def evaluate(self,
                 results,
                 metric='mAP',
                 metric_options=None,
                 indices=None,
                 logger=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated.
                Default value is 'mAP'. Options are 'mAP', 'CP', 'CR', 'CF1',
                'OP', 'OR' and 'OF1'.
            metric_options (dict, optional): Options for calculating metrics.
                Allowed keys are 'k' and 'thr'. When None or empty,
                ``{'thr': 0.5}`` is used. Defaults to None
            indices (optional): When given, used to index the ground-truth
                labels so they line up with ``results``. Defaults to None.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Defaults to None.

        Returns:
            dict: evaluation results

        Raises:
            ValueError: If an unsupported metric is requested.
        """
        if metric_options is None or metric_options == {}:
            metric_options = {'thr': 0.5}

        metrics = [metric] if isinstance(metric, str) else metric
        performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1']
        allowed_metrics = ['mAP'] + performance_keys

        results = np.vstack(results)
        gt_labels = self.get_gt_labels()
        if indices is not None:
            gt_labels = gt_labels[indices]
        assert len(gt_labels) == len(results), \
            'dataset testing results should be of the same length as ' \
            'gt_labels.'

        invalid_metrics = set(metrics) - set(allowed_metrics)
        if invalid_metrics:
            raise ValueError(f'metric {invalid_metrics} is not supported.')

        eval_results = {}
        if 'mAP' in metrics:
            eval_results['mAP'] = mAP(results, gt_labels)
        # The remaining metrics all come from one shared computation.
        if set(metrics) - {'mAP'}:
            performance_values = average_performance(results, gt_labels,
                                                     **metric_options)
            for key, value in zip(performance_keys, performance_values):
                if key in metrics:
                    eval_results[key] = value

        return eval_results
80 |
--------------------------------------------------------------------------------
/mmcls/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .auto_augment import (AutoAugment, AutoContrast, Brightness,
3 | ColorTransform, Contrast, Cutout, Equalize, Invert,
4 | Posterize, RandAugment, Rotate, Sharpness, Shear,
5 | Solarize, SolarizeAdd, Translate)
6 | from .compose import Compose
7 | from .formatting import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor,
8 | Transpose, to_tensor)
9 | from .loading import LoadImageFromFile
10 | from .transforms import (CenterCrop, ColorJitter, Lighting, Normalize, Pad,
11 | RandomCrop, RandomErasing, RandomFlip,
12 | RandomGrayscale, RandomResizedCrop, Resize)
13 |
14 | __all__ = [
15 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy',
16 | 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop',
17 | 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop',
18 | 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert',
19 | 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize',
20 | 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd',
21 | 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad'
22 | ]
23 |
--------------------------------------------------------------------------------
/mmcls/datasets/pipelines/compose.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from collections.abc import Sequence
3 |
4 | from mmcv.utils import build_from_cfg
5 |
6 | from ..builder import PIPELINES
7 |
8 |
@PIPELINES.register_module()
class Compose(object):
    """Chain a sequence of transforms into one data pipeline.

    Args:
        transforms (list[dict | callable]):
            Either config dicts of transforms or transform objects.
    """

    def __init__(self, transforms):
        assert isinstance(transforms, Sequence)
        self.transforms = []
        for item in transforms:
            if isinstance(item, dict):
                # Config dicts are built through the PIPELINES registry.
                self.transforms.append(build_from_cfg(item, PIPELINES))
                continue
            if not callable(item):
                raise TypeError('transform must be callable or a dict, but got'
                                f' {type(item)}')
            self.transforms.append(item)

    def __call__(self, data):
        """Apply the transforms in order; stop early on a ``None`` result."""
        for transform in self.transforms:
            data = transform(data)
            if data is None:
                return None
        return data

    def __repr__(self):
        pieces = [self.__class__.__name__ + '(']
        pieces.extend(f'\n {t}' for t in self.transforms)
        pieces.append('\n)')
        return ''.join(pieces)
44 |
--------------------------------------------------------------------------------
/mmcls/datasets/pipelines/loading.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 |
4 | import mmcv
5 | import numpy as np
6 |
7 | from ..builder import PIPELINES
8 |
9 |
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Load an image from file.

    Required keys are "img_prefix" and "img_info" (a dict that must contain the
    key "filename"). Added or updated keys are "filename", "ori_filename",
    "img", "img_shape", "ori_shape" (same as `img_shape`) and "img_norm_cfg"
    (means=0 and stds=1).

    Args:
        to_float32 (bool): Whether to convert the loaded image to a float32
            numpy array. If set to False, the loaded image is an uint8 array.
            Defaults to False.
        color_type (str): The flag argument for :func:`mmcv.imfrombytes()`.
            Defaults to 'color'.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk')):
        self.to_float32 = to_float32
        self.color_type = color_type
        # Copied so the shared default dict is never handed out.
        self.file_client_args = file_client_args.copy()
        # Created on first use in ``__call__``.
        self.file_client = None

    def __call__(self, results):
        """Read the image referenced by ``results`` and record its metadata."""
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        prefix = results['img_prefix']
        if prefix is None:
            filename = results['img_info']['filename']
        else:
            filename = osp.join(prefix, results['img_info']['filename'])

        raw = self.file_client.get(filename)
        img = mmcv.imfrombytes(raw, flag=self.color_type)
        if self.to_float32:
            img = img.astype(np.float32)

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        # Images without a channel axis count as single-channel.
        num_channels = img.shape[2] if len(img.shape) >= 3 else 1
        # Identity normalization: zero mean, unit std, no channel swap.
        results['img_norm_cfg'] = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)
        return results

    def __repr__(self):
        fields = ', '.join([
            f'to_float32={self.to_float32}',
            f"color_type='{self.color_type}'",
            f'file_client_args={self.file_client_args}',
        ])
        return f'{self.__class__.__name__}({fields})'
71 |
--------------------------------------------------------------------------------
/mmcls/datasets/samplers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .distributed_sampler import DistributedSampler
3 | from .repeat_aug import RepeatAugSampler
4 |
5 | __all__ = ('DistributedSampler', 'RepeatAugSampler')
6 |
--------------------------------------------------------------------------------
/mmcls/datasets/samplers/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from torch.utils.data import DistributedSampler as _DistributedSampler
4 |
5 | from mmcls.core.utils import sync_random_seed
6 | from mmcls.datasets import SAMPLERS
7 |
8 |
@SAMPLERS.register_module()
class DistributedSampler(_DistributedSampler):
    """Distributed sampler with optional shuffling and optional round-up.

    Args:
        dataset: Dataset to sample from; only ``len(dataset)`` is used here.
        num_replicas (int, optional): Forwarded to the torch base sampler.
        rank (int, optional): Forwarded to the torch base sampler.
        shuffle (bool): Whether to permute the indices every epoch.
            Defaults to True.
        round_up (bool): Whether to repeat indices so that every rank gets
            ``num_samples`` items. Defaults to True.
        seed (int): Seed passed through ``sync_random_seed``. Defaults to 0.
    """

    def __init__(self,
                 dataset,
                 num_replicas=None,
                 rank=None,
                 shuffle=True,
                 round_up=True,
                 seed=0):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle
        self.round_up = round_up
        # With round-up the index list is padded to a multiple of the number
        # of replicas; otherwise exactly one pass over the dataset is split.
        self.total_size = (
            self.num_samples * self.num_replicas
            if self.round_up else len(self.dataset))

        # Every rank must shuffle with the same seed so that the strided
        # slices taken in ``__iter__`` come from one common permutation and
        # therefore select non-overlapping data.
        self.seed = sync_random_seed(seed)

    def __iter__(self):
        """Yield the dataset indices assigned to this rank."""
        dataset_len = len(self.dataset)
        if self.shuffle:
            # Seeding with ``epoch + seed`` gives a new, but reproducible,
            # ordering for every epoch.
            generator = torch.Generator()
            generator.manual_seed(self.epoch + self.seed)
            order = torch.randperm(dataset_len, generator=generator).tolist()
        else:
            order = torch.arange(dataset_len).tolist()

        if self.round_up:
            # Tile the index list and cut it to ``total_size`` so it divides
            # evenly among the replicas.
            repeats = int(self.total_size / len(order) + 1)
            order = (order * repeats)[:self.total_size]
            assert len(order) == self.total_size

        # Strided slice: rank r takes positions r, r + num_replicas, ...
        shard = order[self.rank:self.total_size:self.num_replicas]
        if self.round_up:
            assert len(shard) == self.num_samples

        return iter(shard)
61 |
--------------------------------------------------------------------------------
/mmcls/datasets/voc.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 | import xml.etree.ElementTree as ET
4 |
5 | import mmcv
6 | import numpy as np
7 |
8 | from .builder import DATASETS
9 | from .multi_label import MultiLabelDataset
10 |
11 |
@DATASETS.register_module()
class VOC(MultiLabelDataset):
    """Pascal VOC dataset for multi-label classification.

    Only data whose ``data_prefix`` contains ``'VOC2007'`` is accepted; any
    other prefix raises ``ValueError`` during construction.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        super(VOC, self).__init__(**kwargs)
        # The dataset year is inferred purely from the prefix string.
        if 'VOC2007' in self.data_prefix:
            self.year = 2007
        else:
            raise ValueError('Cannot infer dataset year from img_prefix.')

    def load_annotations(self):
        """Load annotations.

        Reads the image ids listed in ``self.ann_file`` and, for each id,
        parses ``<data_prefix>/Annotations/<id>.xml``.

        Returns:
            list[dict]: One dict per image with ``img_prefix``, ``img_info``
            (holding the relative ``filename``) and ``gt_label``, an int8
            vector over ``CLASSES`` where 1 marks a present class, -1 a class
            present only as difficult objects, and 0 an absent class.
        """
        data_infos = []
        img_ids = mmcv.list_from_file(self.ann_file)
        for img_id in img_ids:
            filename = f'JPEGImages/{img_id}.jpg'
            xml_path = osp.join(self.data_prefix, 'Annotations',
                                f'{img_id}.xml')
            tree = ET.parse(xml_path)
            root = tree.getroot()
            labels = []
            labels_difficult = []
            for obj in root.findall('object'):
                label_name = obj.find('name').text
                # in case customized dataset has wrong labels
                # or CLASSES has been override.
                if label_name not in self.CLASSES:
                    continue
                label = self.class_to_idx[label_name]
                # ``find`` returns None when the tag is absent; treat such
                # objects as non-difficult instead of failing on ``.text``.
                difficult_node = obj.find('difficult')
                difficult = (0 if difficult_node is None else int(
                    difficult_node.text))
                if difficult:
                    labels_difficult.append(label)
                else:
                    labels.append(label)

            gt_label = np.zeros(len(self.CLASSES))
            # The order cannot be swapped for the case where multiple objects
            # of the same kind exist and some are difficult.
            gt_label[labels_difficult] = -1
            gt_label[labels] = 1

            info = dict(
                img_prefix=self.data_prefix,
                img_info=dict(filename=filename),
                gt_label=gt_label.astype(np.int8))
            data_infos.append(info)

        return data_infos
70 |
--------------------------------------------------------------------------------
/mmcls/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .backbones import * # noqa: F401,F403
3 | from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS,
4 | build_backbone, build_classifier, build_head, build_loss,
5 | build_neck)
6 | from .classifiers import * # noqa: F401,F403
7 | from .heads import * # noqa: F401,F403
8 | from .losses import * # noqa: F401,F403
9 | from .necks import * # noqa: F401,F403
10 |
11 | __all__ = [
12 | 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone',
13 | 'build_head', 'build_neck', 'build_loss', 'build_classifier'
14 | ]
15 |
--------------------------------------------------------------------------------
/mmcls/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .alexnet import AlexNet
3 | from .conformer import Conformer
4 | from .convmixer import ConvMixer
5 | from .convnext import ConvNeXt
6 | from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
7 | from .deit import DistilledVisionTransformer
8 | from .densenet import DenseNet
9 | from .efficientnet import EfficientNet
10 | from .hrnet import HRNet
11 | from .lenet import LeNet5
12 | from .mlp_mixer import MlpMixer
13 | from .mobilenet_v2 import MobileNetV2
14 | from .mobilenet_v3 import MobileNetV3
15 | from .poolformer import PoolFormer
16 | from .regnet import RegNet
17 | from .repmlp import RepMLPNet
18 | from .repvgg import RepVGG
19 | from .res2net import Res2Net
20 | from .resnest import ResNeSt
21 | from .resnet import ResNet, ResNetV1c, ResNetV1d
22 | from .resnet_cifar import ResNet_CIFAR
23 | from .resnext import ResNeXt
24 | from .seresnet import SEResNet
25 | from .seresnext import SEResNeXt
26 | from .shufflenet_v1 import ShuffleNetV1
27 | from .shufflenet_v2 import ShuffleNetV2
28 | from .swin_transformer import SwinTransformer
29 | from .t2t_vit import T2T_ViT
30 | from .timm_backbone import TIMMBackbone
31 | from .tnt import TNT
32 | from .twins import PCPVT, SVT
33 | from .van import VAN
34 | from .vgg import VGG
35 | from .vision_transformer import VisionTransformer
36 |
37 | __all__ = [
38 | 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
39 | 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
40 | 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
41 | 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
42 | 'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT',
43 | 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer',
44 | 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet',
45 | 'PoolFormer', 'DenseNet', 'VAN'
46 | ]
47 |
--------------------------------------------------------------------------------
/mmcls/models/backbones/alexnet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 |
4 | from ..builder import BACKBONES
5 | from .base_backbone import BaseBackbone
6 |
7 |
@BACKBONES.register_module()
class AlexNet(BaseBackbone):
    """AlexNet backbone.

    The input for AlexNet is a 224x224 RGB image.

    Args:
        num_classes (int): number of classes for classification.
            The default value is -1, which uses the backbone as
            a feature extractor without the top classifier.
    """

    def __init__(self, num_classes=-1):
        super().__init__()
        self.num_classes = num_classes

        # Convolutional trunk: five conv layers with three max-pool stages.
        conv_layers = [
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        ]
        self.features = nn.Sequential(*conv_layers)

        # The fully connected classifier exists only when a positive class
        # count is requested.
        if self.num_classes > 0:
            fc_layers = [
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            ]
            self.classifier = nn.Sequential(*fc_layers)

    def forward(self, x):
        """Return a 1-tuple holding the feature map or the class scores."""
        feats = self.features(x)
        if self.num_classes <= 0:
            return (feats, )

        # Flatten to (batch, 256 * 6 * 6), the classifier's input width.
        flat = feats.view(feats.size(0), 256 * 6 * 6)
        return (self.classifier(flat), )
57 |
--------------------------------------------------------------------------------
/mmcls/models/backbones/base_backbone.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 |
4 | from mmcv.runner import BaseModule
5 |
6 |
class BaseBackbone(BaseModule, metaclass=ABCMeta):
    """Abstract base class for backbones.

    Subclasses must implement ``forward``; everything else is inherited from
    ``BaseModule``.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)

    @abstractmethod
    def forward(self, x):
        """Forward computation.

        Args:
            x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of
                Torch.tensor, containing input data for forward computation.
        """

    def train(self, mode=True):
        """Switch the module between training and evaluation mode.

        Delegates to the parent ``train``; this override itself has no
        ``return`` statement.

        Args:
            mode (bool): Whether it is train_mode or test_mode
        """
        super().train(mode)
34 |
--------------------------------------------------------------------------------
/mmcls/models/backbones/lenet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 |
4 | from ..builder import BACKBONES
5 | from .base_backbone import BaseBackbone
6 |
7 |
@BACKBONES.register_module()
class LeNet5(BaseBackbone):
    """LeNet5 backbone.

    The input for LeNet-5 is a 32×32 grayscale image.

    Args:
        num_classes (int): number of classes for classification.
            The default value is -1, which uses the backbone as
            a feature extractor without the top classifier.
    """

    def __init__(self, num_classes=-1):
        super(LeNet5, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, stride=1), nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1), nn.Tanh(),
            nn.AvgPool2d(kernel_size=2),
            nn.Conv2d(16, 120, kernel_size=5, stride=1), nn.Tanh())
        # The classifier is only built when a positive class count is given.
        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Linear(120, 84),
                nn.Tanh(),
                nn.Linear(84, num_classes),
            )

    def forward(self, x):
        """Return a 1-tuple holding the feature map or the class scores."""
        x = self.features(x)
        if self.num_classes > 0:
            # Flatten everything except the batch axis. A bare ``squeeze()``
            # would also drop the batch axis when the batch size is 1 and
            # yield scores of shape (num_classes,) instead of
            # (1, num_classes).
            x = self.classifier(x.flatten(1))

        return (x, )
43 |
--------------------------------------------------------------------------------
/mmcls/models/backbones/resnet_cifar.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.cnn import build_conv_layer, build_norm_layer
4 |
5 | from ..builder import BACKBONES
6 | from .resnet import ResNet
7 |
8 |
@BACKBONES.register_module()
class ResNet_CIFAR(ResNet):
    """ResNet backbone for CIFAR.

    Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in
    conv1, and does not apply MaxPooling after the stem. It has been proven
    to be more efficient than standard ResNet in other public codebase, e.g.,
    `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`.

    Args:
        depth (int): Network depth, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input image channels. Default: 3.
        stem_channels (int): Output channels of the stem layer. Default: 64.
        base_channels (int): Middle channels of the first stage. Default: 64.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. ``forward``
            always returns a tuple with one tensor per selected stage.
            Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the 3x3 conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer.
        deep_stem (bool): This network has specific designed stem, thus it is
            asserted to be False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: True.
    """

    def __init__(self, depth, deep_stem=False, **kwargs):
        super().__init__(depth, deep_stem=deep_stem, **kwargs)
        # The check runs after the parent constructor has stored the flag.
        assert not self.deep_stem, 'ResNet_CIFAR do not support deep_stem'

    def _make_stem_layer(self, in_channels, base_channels):
        """Build the CIFAR stem: a single 3x3, stride-1 conv + norm + ReLU."""
        self.conv1 = build_conv_layer(
            self.conv_cfg,
            in_channels,
            base_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False)
        norm_name, norm_layer = build_norm_layer(
            self.norm_cfg, base_channels, postfix=1)
        # Register the norm layer under the name chosen by mmcv and remember
        # that name for later lookup.
        self.norm1_name = norm_name
        self.add_module(norm_name, norm_layer)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run the stem and all residual stages.

        Returns:
            tuple: Outputs of the stages whose index is in
            ``self.out_indices``, in stage order.
        """
        x = self.relu(self.norm1(self.conv1(x)))

        selected = []
        for stage_idx, stage_name in enumerate(self.res_layers):
            x = getattr(self, stage_name)(x)
            if stage_idx in self.out_indices:
                selected.append(x)
        return tuple(selected)
82 |
--------------------------------------------------------------------------------
/mmcls/models/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.cnn import MODELS as MMCV_MODELS
3 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION
4 | from mmcv.utils import Registry
5 |
# Registry for mmcls model components, created with mmcv's MODELS registry
# as its parent.
MODELS = Registry('models', parent=MMCV_MODELS)

# Every component category below is an alias of the same MODELS registry
# object, so backbones, necks, heads, losses and classifiers all register
# into (and are built from) one shared registry.
BACKBONES = MODELS
NECKS = MODELS
HEADS = MODELS
LOSSES = MODELS
CLASSIFIERS = MODELS

# Separate registry for attention modules, with mmcv's ATTENTION registry
# as its parent.
ATTENTION = Registry('attention', parent=MMCV_ATTENTION)
15 |
16 |
def build_backbone(cfg):
    """Build backbone.

    Args:
        cfg: Config handed unchanged to ``BACKBONES.build``; presumably a
            dict whose ``type`` names a registered module (TODO confirm
            against mmcv ``Registry.build``).

    Returns:
        The object returned by ``BACKBONES.build(cfg)``.
    """
    return BACKBONES.build(cfg)
20 |
21 |
def build_neck(cfg):
    """Build neck.

    Args:
        cfg: Config handed unchanged to ``NECKS.build``; presumably a dict
            whose ``type`` names a registered module (TODO confirm against
            mmcv ``Registry.build``).

    Returns:
        The object returned by ``NECKS.build(cfg)``.
    """
    return NECKS.build(cfg)
25 |
26 |
def build_head(cfg):
    """Build head.

    Args:
        cfg: Config handed unchanged to ``HEADS.build``; presumably a dict
            whose ``type`` names a registered module (TODO confirm against
            mmcv ``Registry.build``).

    Returns:
        The object returned by ``HEADS.build(cfg)``.
    """
    return HEADS.build(cfg)
30 |
31 |
def build_loss(cfg):
    """Build loss.

    Args:
        cfg: Config handed unchanged to ``LOSSES.build``; presumably a dict
            whose ``type`` names a registered module (TODO confirm against
            mmcv ``Registry.build``).

    Returns:
        The object returned by ``LOSSES.build(cfg)``.
    """
    return LOSSES.build(cfg)
35 |
36 |
def build_classifier(cfg):
    """Build classifier.

    Args:
        cfg: Config handed unchanged to ``CLASSIFIERS.build``; presumably a
            dict whose ``type`` names a registered module (TODO confirm
            against mmcv ``Registry.build``).

    Returns:
        The object returned by ``CLASSIFIERS.build(cfg)``.
    """
    return CLASSIFIERS.build(cfg)
39 |
--------------------------------------------------------------------------------
/mmcls/models/classifiers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base import BaseClassifier
3 | from .image import ImageClassifier
4 |
5 | __all__ = ['BaseClassifier', 'ImageClassifier']
6 |
--------------------------------------------------------------------------------
/mmcls/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .cls_head import ClsHead
3 | from .conformer_head import ConformerHead
4 | from .deit_head import DeiTClsHead
5 | from .linear_head import LinearClsHead
6 | from .multi_label_head import MultiLabelClsHead
7 | from .multi_label_linear_head import MultiLabelLinearClsHead
8 | from .stacked_head import StackedLinearClsHead
9 | from .vision_transformer_head import VisionTransformerClsHead
10 |
11 | __all__ = [
12 | 'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
13 | 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
14 | 'ConformerHead'
15 | ]
16 |
--------------------------------------------------------------------------------
/mmcls/models/heads/base_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 |
4 | from mmcv.runner import BaseModule
5 |
6 |
class BaseHead(BaseModule, metaclass=ABCMeta):
    """Abstract base class for heads.

    Subclasses must implement ``forward_train``.
    """

    def __init__(self, init_cfg=None):
        super().__init__(init_cfg)

    @abstractmethod
    def forward_train(self, x, gt_label, **kwargs):
        """Training-time forward pass; must be provided by subclasses."""
16 |
--------------------------------------------------------------------------------
/mmcls/models/heads/linear_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ..builder import HEADS
6 | from .cls_head import ClsHead
7 |
8 |
@HEADS.register_module()
class LinearClsHead(ClsHead):
    """Linear classifier head.

    Args:
        num_classes (int): Number of categories excluding the background
            category.
        in_channels (int): Number of channels in the input feature map.
        init_cfg (dict | optional): The extra init config of layers.
            Defaults to use dict(type='Normal', layer='Linear', std=0.01).
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01),
                 *args,
                 **kwargs):
        super().__init__(*args, init_cfg=init_cfg, **kwargs)

        self.in_channels = in_channels
        self.num_classes = num_classes

        # Reject a non-positive class count before the layer is created.
        if self.num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.fc = nn.Linear(self.in_channels, self.num_classes)

    def pre_logits(self, x):
        """Return the last element of a tuple input, or ``x`` unchanged."""
        return x[-1] if isinstance(x, tuple) else x

    def simple_test(self, x, softmax=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[Tensor]): The input features.
                Multi-stage inputs are acceptable but only the last stage will
                be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.
            softmax (bool): Whether to softmax the classification score.
            post_process (bool): Whether to do post processing the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.

                - If no post processing, the output is a tensor with shape
                  ``(num_samples, num_classes)``.
                - If post processing, the output is a multi-dimensional list of
                  float and the dimensions are ``(num_samples, num_classes)``.
        """
        cls_score = self.fc(self.pre_logits(x))

        if not softmax:
            pred = cls_score
        elif cls_score is None:
            pred = None
        else:
            pred = F.softmax(cls_score, dim=1)

        if not post_process:
            return pred
        return self.post_process(pred)

    def forward_train(self, x, gt_label, **kwargs):
        """Compute the losses for ``x`` against ``gt_label``."""
        cls_score = self.fc(self.pre_logits(x))
        return self.loss(cls_score, gt_label, **kwargs)
82 |
--------------------------------------------------------------------------------
/mmcls/models/heads/multi_label_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 | from ..builder import HEADS, build_loss
5 | from ..utils import is_tracing
6 | from .base_head import BaseHead
7 |
8 |
@HEADS.register_module()
class MultiLabelClsHead(BaseHead):
    """Classification head for multilabel task.

    Args:
        loss (dict): Config of classification loss.
        init_cfg (dict, optional): Passed to the base head. Defaults to None.
    """

    def __init__(self,
                 loss=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=1.0),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)

        assert isinstance(loss, dict)

        self.compute_loss = build_loss(loss)

    def loss(self, cls_score, gt_label):
        """Return ``{'loss': ...}`` for the given scores and labels."""
        target = gt_label.type_as(cls_score)
        # Taking the absolute value maps difficult examples (-1) to
        # positive ones (1).
        target = torch.abs(target)
        value = self.compute_loss(
            cls_score, target, avg_factor=len(cls_score))
        return dict(loss=value)

    def forward_train(self, cls_score, gt_label, **kwargs):
        """Compute losses; only the last item of a tuple input is used."""
        # NOTE(review): ``**kwargs`` is forwarded to ``loss``, whose
        # signature in this class accepts no extra keywords.
        if isinstance(cls_score, tuple):
            cls_score = cls_score[-1]
        target = gt_label.type_as(cls_score)
        return self.loss(cls_score, target, **kwargs)

    def pre_logits(self, x):
        """Return the last stage of ``x`` and warn that it is already logits."""
        if isinstance(x, tuple):
            x = x[-1]

        from mmcls.utils import get_root_logger
        get_root_logger().warning(
            'The input of MultiLabelClsHead should be already logits. '
            'Please modify the backbone if you want to get pre-logits feature.'
        )
        return x

    def simple_test(self, x, sigmoid=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[Tensor]): The input classification score logits.
                Multi-stage inputs are acceptable but only the last stage will
                be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            sigmoid (bool): Whether to sigmoid the classification score.
            post_process (bool): Whether to do post processing the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.

                - If no post processing, the output is a tensor with shape
                  ``(num_samples, num_classes)``.
                - If post processing, the output is a multi-dimensional list of
                  float and the dimensions are ``(num_samples, num_classes)``.
        """
        if isinstance(x, tuple):
            x = x[-1]

        if not sigmoid:
            pred = x
        elif x is None:
            pred = None
        else:
            pred = torch.sigmoid(x)

        if not post_process:
            return pred
        return self.post_process(pred)

    def post_process(self, pred):
        """Convert ``pred`` to a list of numpy rows unless exporting/tracing."""
        tracing = is_tracing()
        # During ONNX export or tracing the tensor is returned untouched.
        if torch.onnx.is_in_onnx_export() or tracing:
            return pred
        return list(pred.detach().cpu().numpy())
100 |
--------------------------------------------------------------------------------
/mmcls/models/heads/multi_label_linear_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ..builder import HEADS
6 | from .multi_label_head import MultiLabelClsHead
7 |
8 |
@HEADS.register_module()
class MultiLabelLinearClsHead(MultiLabelClsHead):
    """Linear classification head for multilabel task.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
        loss (dict): Config of classification loss.
        init_cfg (dict | optional): The extra init config of layers.
            Defaults to use dict(type='Normal', layer='Linear', std=0.01).
    """

    def __init__(self,
                 num_classes,
                 in_channels,
                 loss=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     reduction='mean',
                     loss_weight=1.0),
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
        super().__init__(loss=loss, init_cfg=init_cfg)

        # Validate before any attribute of this subclass is set.
        if num_classes <= 0:
            raise ValueError(
                f'num_classes={num_classes} must be a positive integer')

        self.in_channels = in_channels
        self.num_classes = num_classes

        self.fc = nn.Linear(self.in_channels, self.num_classes)

    def pre_logits(self, x):
        """Return the last element of a tuple input, or ``x`` unchanged."""
        return x[-1] if isinstance(x, tuple) else x

    def forward_train(self, x, gt_label, **kwargs):
        """Compute the losses for features ``x`` against ``gt_label``."""
        feats = self.pre_logits(x)
        target = gt_label.type_as(feats)
        cls_score = self.fc(feats)
        return self.loss(cls_score, target, **kwargs)

    def simple_test(self, x, sigmoid=True, post_process=True):
        """Inference without augmentation.

        Args:
            x (tuple[Tensor]): The input features.
                Multi-stage inputs are acceptable but only the last stage will
                be used to classify. The shape of every item should be
                ``(num_samples, in_channels)``.
            sigmoid (bool): Whether to sigmoid the classification score.
            post_process (bool): Whether to do post processing the
                inference results. It will convert the output to a list.

        Returns:
            Tensor | list: The inference results.

                - If no post processing, the output is a tensor with shape
                  ``(num_samples, num_classes)``.
                - If post processing, the output is a multi-dimensional list of
                  float and the dimensions are ``(num_samples, num_classes)``.
        """
        cls_score = self.fc(self.pre_logits(x))

        if not sigmoid:
            pred = cls_score
        elif cls_score is None:
            pred = None
        else:
            pred = torch.sigmoid(cls_score)

        if not post_process:
            return pred
        return self.post_process(pred)
86 |
--------------------------------------------------------------------------------
/mmcls/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .accuracy import Accuracy, accuracy
3 | from .asymmetric_loss import AsymmetricLoss, asymmetric_loss
4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
5 | cross_entropy)
6 | from .focal_loss import FocalLoss, sigmoid_focal_loss
7 | from .label_smooth_loss import LabelSmoothLoss
8 | from .seesaw_loss import SeesawLoss
9 | from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss,
10 | weighted_loss)
11 |
12 | __all__ = [
13 | 'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss',
14 | 'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
15 | 'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss',
16 | 'sigmoid_focal_loss', 'convert_to_one_hot', 'SeesawLoss'
17 | ]
18 |
--------------------------------------------------------------------------------
/mmcls/models/necks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .gap import GlobalAveragePooling
3 | from .gem import GeneralizedMeanPooling
4 | from .hr_fuse import HRFuseScales
5 |
6 | __all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales']
7 |
--------------------------------------------------------------------------------
/mmcls/models/necks/gap.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 |
5 | from ..builder import NECKS
6 |
7 |
@NECKS.register_module()
class GlobalAveragePooling(nn.Module):
    """Global Average Pooling neck.

    Note that we use `view` to remove extra channel after pooling. We do not
    use `squeeze` as it will also remove the batch dimension when the tensor
    has a batch dimension of size 1, which can lead to unexpected errors.

    Args:
        dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}.
            Default: 2
    """

    def __init__(self, dim=2):
        super().__init__()
        assert dim in [1, 2, 3], 'GlobalAveragePooling dim only support ' \
            f'{1, 2, 3}, get {dim} instead.'
        # Pick the adaptive pooling layer matching the number of pooled axes.
        pool_cls, output_size = {
            1: (nn.AdaptiveAvgPool1d, 1),
            2: (nn.AdaptiveAvgPool2d, (1, 1)),
            3: (nn.AdaptiveAvgPool3d, (1, 1, 1)),
        }[dim]
        self.gap = pool_cls(output_size)

    def init_weights(self):
        """Nothing to initialise."""

    def forward(self, inputs):
        """Pool a tensor, or each tensor of a tuple, to ``(batch, -1)``.

        Raises:
            TypeError: If ``inputs`` is neither a tuple nor a tensor.
        """
        if isinstance(inputs, tuple):
            pooled = tuple(self.gap(x) for x in inputs)
            return tuple(
                out.view(x.size(0), -1) for out, x in zip(pooled, inputs))
        if isinstance(inputs, torch.Tensor):
            return self.gap(inputs).view(inputs.size(0), -1)
        raise TypeError('neck inputs should be tuple or torch.tensor')
46 |
--------------------------------------------------------------------------------
/mmcls/models/necks/gem.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from torch import Tensor, nn
4 | from torch.nn import functional as F
5 | from torch.nn.parameter import Parameter
6 |
7 | from ..builder import NECKS
8 |
9 |
def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor:
    """Generalized-mean pool ``x`` over its last two dimensions.

    Args:
        x (Tensor): Input whose last two dimensions are pooled.
        p (Parameter): Exponent applied before averaging; its reciprocal is
            applied afterwards.
        eps (float): Lower bound used when ``clamp`` is set. Default: 1e-6
        clamp (bool): Whether to clamp ``x`` to at least ``eps`` first.
            Default: True

    Returns:
        Tensor: Pooled tensor with the last two dimensions reduced to size 1.
    """
    feats = x.clamp(min=eps) if clamp else x
    # One pooling window covering the whole extent of the last two axes.
    window = (feats.size(-2), feats.size(-1))
    powered_mean = F.avg_pool2d(feats.pow(p), window)
    return powered_mean.pow(1. / p)
14 |
15 |
@NECKS.register_module()
class GeneralizedMeanPooling(nn.Module):
    """Generalized Mean Pooling neck.

    Note that we use `view` to remove extra channel after pooling. We do not
    use `squeeze` as it will also remove the batch dimension when the tensor
    has a batch dimension of size 1, which can lead to unexpected errors.

    Args:
        p (float): Initial value of the learnable exponent; must be >= 1.
            Default: 3.
        eps (float): epsilon.
            Default: 1e-6
        clamp (bool): Use clamp before pooling.
            Default: True
    """

    def __init__(self, p=3., eps=1e-6, clamp=True):
        assert p >= 1, "'p' must be a value greater then 1"
        super().__init__()
        # Stored as a one-element Parameter, so it is part of the module's
        # parameters.
        self.p = Parameter(torch.ones(1) * p)
        self.eps = eps
        self.clamp = clamp

    def forward(self, inputs):
        """Pool a tensor, or each tensor of a tuple, to ``(batch, -1)``.

        Raises:
            TypeError: If ``inputs`` is neither a tuple nor a tensor.
        """
        if isinstance(inputs, tuple):
            pooled = tuple(
                gem(x, p=self.p, eps=self.eps, clamp=self.clamp)
                for x in inputs)
            return tuple(
                out.view(x.size(0), -1) for out, x in zip(pooled, inputs))
        if isinstance(inputs, torch.Tensor):
            pooled = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp)
            return pooled.view(inputs.size(0), -1)
        raise TypeError('neck inputs should be tuple or torch.tensor')
54 |
--------------------------------------------------------------------------------
/mmcls/models/necks/hr_fuse.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.cnn.bricks import ConvModule
4 | from mmcv.runner import BaseModule
5 |
6 | from ..backbones.resnet import Bottleneck, ResLayer
7 | from ..builder import NECKS
8 |
9 |
@NECKS.register_module()
class HRFuseScales(BaseModule):
    """Fuse feature map of multiple scales in HRNet.

    Args:
        in_channels (list[int]): The input channels of all scales. At most
            four scales are supported, one per internal stage width
            (128, 256, 512, 1024).
        out_channels (int): The channels of fused feature map.
            Defaults to 2048.
        norm_cfg (dict): dictionary to construct norm layers.
            Defaults to ``dict(type='BN', momentum=0.1)``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels=2048,
                 norm_cfg=dict(type='BN', momentum=0.1),
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
        super(HRFuseScales, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg

        block_type = Bottleneck
        stage_channels = [128, 256, 512, 1024]
        num_scales = len(in_channels)
        assert 0 < num_scales <= len(stage_channels), \
            f'HRFuseScales supports 1 to {len(stage_channels)} scales, ' \
            f'but got {num_scales}.'

        # Increase the channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        increase_layers = []
        for i in range(num_scales):
            increase_layers.append(
                ResLayer(
                    block_type,
                    in_channels=in_channels[i],
                    out_channels=stage_channels[i],
                    num_blocks=1,
                    stride=1,
                ))
        self.increase_layers = nn.ModuleList(increase_layers)

        # Downsample feature maps in each scale.
        downsample_layers = []
        for i in range(num_scales - 1):
            downsample_layers.append(
                ConvModule(
                    in_channels=stage_channels[i],
                    out_channels=stage_channels[i + 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    bias=False,
                ))
        self.downsample_layers = nn.ModuleList(downsample_layers)

        # The final conv block before final classifier linear layer.
        # Its input width is that of the last stage actually used, so the
        # layer also matches the fused feature when fewer than four scales
        # are given (with four scales this is stage_channels[3], as before).
        self.final_layer = ConvModule(
            in_channels=stage_channels[num_scales - 1],
            out_channels=self.out_channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            bias=False,
        )

    def forward(self, x):
        """Fuse the scales progressively and return a 1-tuple.

        Args:
            x (tuple): One feature map per entry of ``in_channels``.

        Returns:
            tuple: A single-element tuple with the fused feature map.
        """
        assert isinstance(x, tuple) and len(x) == len(self.in_channels)

        feat = self.increase_layers[0](x[0])
        for i in range(len(self.downsample_layers)):
            # Downsample the running feature and add the next scale's
            # channel-increased feature.
            feat = self.downsample_layers[i](feat) + \
                self.increase_layers[i + 1](x[i + 1])

        return (self.final_layer(feat), )
84 |
--------------------------------------------------------------------------------
/mmcls/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .attention import MultiheadAttention, ShiftWindowMSA
3 | from .augment.augments import Augments
4 | from .channel_shuffle import channel_shuffle
5 | from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
6 | resize_relative_position_bias_table)
7 | from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
8 | from .inverted_residual import InvertedResidual
9 | from .make_divisible import make_divisible
10 | from .position_encoding import ConditionalPositionEncoding
11 | from .se_layer import SELayer
12 |
13 | __all__ = [
14 | 'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
15 | 'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
16 | 'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
17 | 'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed',
18 | 'resize_relative_position_bias_table'
19 | ]
20 |
--------------------------------------------------------------------------------
/mmcls/models/utils/augment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .augments import Augments
3 | from .cutmix import BatchCutMixLayer
4 | from .identity import Identity
5 | from .mixup import BatchMixupLayer
6 | from .resizemix import BatchResizeMixLayer
7 |
8 | __all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
9 | 'BatchResizeMixLayer')
10 |
--------------------------------------------------------------------------------
/mmcls/models/utils/augment/augments.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import random
3 |
4 | import numpy as np
5 |
6 | from .builder import build_augment
7 |
8 |
class Augments(object):
    """Data augments.

    We implement some data augmentation methods, such as mixup, cutmix.

    Args:
        augments_cfg (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
            Config dict of augments

    Example:
        >>> augments_cfg = [
                dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
                dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
            ]
        >>> augments = Augments(augments_cfg)
        >>> imgs = torch.randn(16, 3, 32, 32)
        >>> label = torch.randint(0, 10, (16, ))
        >>> imgs, label = augments(imgs, label)

    To decide which augmentation within Augments block is used
    the following rule is applied.
    We pick augmentation based on the probabilities. In the example above,
    we decide if we should use BatchCutMix with probability 0.5,
    BatchMixup 0.3. As Identity is not in augments_cfg, we use Identity with
    probability 1 - 0.5 - 0.3 = 0.2.
    """

    def __init__(self, augments_cfg):
        super(Augments, self).__init__()

        if isinstance(augments_cfg, dict):
            augments_cfg = [augments_cfg]

        assert len(augments_cfg) > 0, \
            'The length of augments_cfg should be positive.'
        self.augments = [build_augment(cfg) for cfg in augments_cfg]
        self.augment_probs = [aug.prob for aug in self.augments]

        # Probabilities are floats, so sums such as 0.6 + 0.3 + 0.1 are not
        # exactly 1.0. Compare with a small tolerance instead of ``==``.
        # The tolerance is kept below the one ``RandomState.choice`` applies
        # to ``p`` so that accepted configs also sample without error.
        tol = 1e-8
        total_prob = sum(self.augment_probs)

        has_identity = any(cfg['type'] == 'Identity' for cfg in augments_cfg)
        if has_identity:
            assert abs(total_prob - 1.0) <= tol,\
                'The sum of augmentation probabilities should equal to 1,' \
                ' but got {:.2f}'.format(total_prob)
        else:
            assert total_prob <= 1.0 + tol,\
                'The sum of augmentation probabilities should less than or ' \
                'equal to 1, but got {:.2f}'.format(total_prob)
            # Whatever probability mass is left goes to an implicit Identity.
            identity_prob = 1 - total_prob
            if identity_prob > 0:
                num_classes = self.augments[0].num_classes
                self.augments += [
                    build_augment(
                        dict(
                            type='Identity',
                            num_classes=num_classes,
                            prob=identity_prob))
                ]
                self.augment_probs += [identity_prob]

    def __call__(self, img, gt_label):
        """Apply one augmentation drawn according to ``augment_probs``."""
        if self.augments:
            # Seed a fresh NumPy state from Python's ``random`` module.
            random_state = np.random.RandomState(random.randint(0, 2**32 - 1))
            aug = random_state.choice(self.augments, p=self.augment_probs)
            return aug(img, gt_label)
        return img, gt_label
74 |
--------------------------------------------------------------------------------
/mmcls/models/utils/augment/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import Registry, build_from_cfg
3 |
# Registry in which the batch augmentation layers register themselves.
AUGMENT = Registry('augment')


def build_augment(cfg, default_args=None):
    """Build an augmentation object from ``cfg`` using the AUGMENT registry.

    Args:
        cfg (dict): Config of the augmentation to build.
        default_args (dict, optional): Passed through to ``build_from_cfg``.

    Returns:
        The object returned by ``build_from_cfg``.
    """
    augment = build_from_cfg(cfg, AUGMENT, default_args)
    return augment
9 |
--------------------------------------------------------------------------------
/mmcls/models/utils/augment/identity.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .builder import AUGMENT
3 | from .utils import one_hot_encoding
4 |
5 |
@AUGMENT.register_module(name='Identity')
class Identity(object):
    """Leave the images untouched and one-hot encode the labels.

    Args:
        num_classes (int): The number of classes.
        prob (float): Probability of selecting this augmentation. It should
            be in range [0, 1]. Default to 1.0
    """

    def __init__(self, num_classes, prob=1.0):
        super(Identity, self).__init__()

        assert isinstance(num_classes, int)
        # ``prob`` has to be a float lying inside [0, 1].
        assert isinstance(prob, float)
        assert 0.0 <= prob <= 1.0

        self.prob = prob
        self.num_classes = num_classes

    def one_hot(self, gt_label):
        """Encode ``gt_label`` with ``one_hot_encoding``."""
        encoded = one_hot_encoding(gt_label, self.num_classes)
        return encoded

    def __call__(self, img, gt_label):
        """Return ``img`` as is together with the encoded label."""
        encoded_label = self.one_hot(gt_label)
        return img, encoded_label
30 |
--------------------------------------------------------------------------------
/mmcls/models/utils/augment/mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 |
4 | import numpy as np
5 | import torch
6 |
7 | from .builder import AUGMENT
8 | from .utils import one_hot_encoding
9 |
10 |
class BaseMixupLayer(object, metaclass=ABCMeta):
    """Abstract base for mixup-style layers.

    It validates and stores the shared hyper-parameters; subclasses supply
    the actual :meth:`mixup` implementation.

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number.
        num_classes (int): The number of classes.
        prob (float): MixUp probability. It should be in range [0, 1].
            Default to 1.0
    """

    def __init__(self, alpha, num_classes, prob=1.0):
        super(BaseMixupLayer, self).__init__()

        # Validate in the order alpha -> num_classes -> prob.
        assert isinstance(alpha, float) and alpha > 0
        assert isinstance(num_classes, int)
        assert isinstance(prob, float) and 0.0 <= prob <= 1.0

        self.alpha, self.num_classes, self.prob = alpha, num_classes, prob

    @abstractmethod
    def mixup(self, imgs, gt_label):
        """Mix a batch of images and labels; implemented by subclasses."""
36 |
37 |
@AUGMENT.register_module(name='BatchMixup')
class BatchMixupLayer(BaseMixupLayer):
    r"""Mixup layer for a batch of data.

    Mixup is a method to reduce the memorization of corrupt labels and
    increase the robustness to adversarial examples. It's
    proposed in `mixup: Beyond Empirical Risk Minimization
    <https://arxiv.org/abs/1710.09412>`_

    This method simply linearly mixes pairs of data and their labels.

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            are in the note.
        num_classes (int): The number of classes.
        prob (float): The probability to execute mixup. It should be in
            range [0, 1]. Defaults to 1.0.

    Note:
        The :math:`\alpha` (``alpha``) determines a random distribution
        :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
        a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
        distribution.
    """

    def __init__(self, *args, **kwargs):
        super(BatchMixupLayer, self).__init__(*args, **kwargs)

    def mixup(self, img, gt_label):
        """Blend each sample with a randomly chosen partner of the batch.

        Returns:
            tuple: The mixed images and the mixed (soft) labels.
        """
        labels = one_hot_encoding(gt_label, self.num_classes)
        lam = np.random.beta(self.alpha, self.alpha)
        # One random permutation picks the partner of every sample.
        partner = torch.randperm(img.size(0))

        blended_img = lam * img + (1 - lam) * img[partner, :]
        blended_label = lam * labels + (1 - lam) * labels[partner, :]
        return blended_img, blended_label

    def __call__(self, img, gt_label):
        return self.mixup(img, gt_label)
81 |
--------------------------------------------------------------------------------
/mmcls/models/utils/augment/resizemix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | import torch.nn.functional as F
5 |
6 | from mmcls.models.utils.augment.builder import AUGMENT
7 | from .cutmix import BatchCutMixLayer
8 | from .utils import one_hot_encoding
9 |
10 |
@AUGMENT.register_module(name='BatchResizeMix')
class BatchResizeMixLayer(BatchCutMixLayer):
    r"""ResizeMix Random Paste layer for a batch of data.

    The ResizeMix will resize an image to a small patch and paste it on another
    image. It's proposed in `ResizeMix: Mixing Data with Preserved Object
    Information and True Labels <https://arxiv.org/abs/2012.11101>`_

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            can be found in :class:`BatchMixupLayer`.
        num_classes (int): The number of classes.
        lam_min(float): The minimum value of lam. Defaults to 0.1.
        lam_max(float): The maximum value of lam. Defaults to 0.8.
        interpolation (str): algorithm used for upsampling:
            'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' |
            'area'. Default to 'bilinear'.
        prob (float): The probability to execute resizemix. It should be in
            range [0, 1]. Defaults to 1.0.
        cutmix_minmax (List[float], optional): The min/max area ratio of the
            patches. If not None, the bounding-box of patches is uniform
            sampled within this ratio range, and the ``alpha`` will be ignored.
            Otherwise, the bounding-box is generated according to the
            ``alpha``. Defaults to None.
        correct_lam (bool): Whether to apply lambda correction when cutmix bbox
            clipped by image borders. Defaults to True
        **kwargs: Any other parameters accepted by :class:`BatchCutMixLayer`.

    Note:
        The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
        variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
        to the range [``lam_min``, ``lam_max``].

        .. math::
            \lambda = Beta(\alpha, \alpha) \cdot
                (\lambda_{max} - \lambda_{min}) + \lambda_{min}

        And the resize ratio of source images is calculated by :math:`\lambda`:

        .. math::
            \text{ratio} = \sqrt{1-\lambda}
    """

    def __init__(self,
                 alpha,
                 num_classes,
                 lam_min: float = 0.1,
                 lam_max: float = 0.8,
                 interpolation='bilinear',
                 prob=1.0,
                 cutmix_minmax=None,
                 correct_lam=True,
                 **kwargs):
        super(BatchResizeMixLayer, self).__init__(
            alpha=alpha,
            num_classes=num_classes,
            prob=prob,
            cutmix_minmax=cutmix_minmax,
            correct_lam=correct_lam,
            **kwargs)
        self.interpolation = interpolation
        self.lam_min, self.lam_max = lam_min, lam_max

    def cutmix(self, img, gt_label):
        """Paste resized partner images into ``img`` and mix the labels.

        Note that ``img`` is modified in place and also returned.
        """
        labels = one_hot_encoding(gt_label, self.num_classes)

        # Sample from Beta(alpha, alpha), then rescale to [lam_min, lam_max].
        sampled = np.random.beta(self.alpha, self.alpha)
        lam = sampled * (self.lam_max - self.lam_min) + self.lam_min
        partner = torch.randperm(img.size(0))

        box, lam = self.cutmix_bbox_and_lam(img.shape, lam)
        top, bottom, left, right = box

        # Shrink every partner image to the box size and paste it in place.
        patch = F.interpolate(
            img[partner],
            size=(bottom - top, right - left),
            mode=self.interpolation)
        img[:, :, top:bottom, left:right] = patch

        mixed_label = lam * labels + (1 - lam) * labels[partner, :]
        return img, mixed_label
94 |
--------------------------------------------------------------------------------
/mmcls/models/utils/augment/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn.functional as F
3 |
4 |
def one_hot_encoding(gt, num_classes):
    """Convert a 1-D label tensor to one-hot form.

    A tensor with any other number of dimensions is returned unchanged.

    Args:
        gt (Tensor): The gt label, either 1-D with shape (N,) or already
            expanded to more dimensions.
        num_classes (int): The number of classes.

    Return:
        Tensor: One hot gt label.
    """
    if gt.ndim != 1:
        # Not a flat label vector, e.g. binary labels [[0], [1], [1]] or
        # multi-label targets [[0, 1, 1], [1, 0, 0], [1, 1, 1]]: keep as is.
        return gt
    # Multi-class classification: expand the class indices.
    return F.one_hot(gt, num_classes=num_classes)
25 |
--------------------------------------------------------------------------------
/mmcls/models/utils/channel_shuffle.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 |
4 |
def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    The channels are split into ``groups`` groups and interleaved, which
    lets information flow across the groups of grouped convolutions.

    Args:
        x (Tensor): The input tensor; it is unpacked as
            (batch, channels, height, width).
        groups (int): The number of groups to divide the input tensor
            in the channel dimension.

    Returns:
        Tensor: The output tensor after channel shuffle operation.
    """
    n, c, h, w = x.size()
    assert (c % groups == 0), ('num_channels should be '
                               'divisible by groups')
    per_group = c // groups

    # (n, groups, per_group, h, w) -> swap the two channel axes -> flatten.
    grouped = x.view(n, groups, per_group, h, w)
    swapped = torch.transpose(grouped, 1, 2).contiguous()
    return swapped.view(n, -1, h, w)
30 |
--------------------------------------------------------------------------------
/mmcls/models/utils/helpers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import collections.abc
3 | import warnings
4 | from itertools import repeat
5 |
6 | import torch
7 | from mmcv.utils import digit_version
8 |
9 |
def is_tracing() -> bool:
    """Report whether the code currently runs under ``torch.jit.trace``.

    For PyTorch older than 1.6.0 a warning is issued and False is returned.
    """
    if digit_version(torch.__version__) < digit_version('1.6.0'):
        warnings.warn(
            'torch.jit.is_tracing is only supported after v1.6.0. '
            'Therefore is_tracing returns False automatically. Please '
            'set on_trace manually if you are using trace.', UserWarning)
        return False

    on_trace = torch.jit.is_tracing()
    # In PyTorch 1.6, torch.jit.is_tracing has a bug.
    # Refers to https://github.com/pytorch/pytorch/issues/42448
    if not isinstance(on_trace, bool):
        return torch._C._is_tracing()
    return on_trace
27 |
28 |
29 | # From PyTorch internals
def _ntuple(n):
    """Create a ``to_tuple`` converter for tuples of length ``n``.

    The returned function hands any Iterable input back untouched and
    repeats every other input ``n`` times inside a tuple.

    Args:
        n (int): The number of the target length.
    """

    def parse(x):
        if not isinstance(x, collections.abc.Iterable):
            return (x, ) * n
        return x

    return parse


# Ready-made converters for the common lengths.
to_1tuple, to_2tuple = _ntuple(1), _ntuple(2)
to_3tuple, to_4tuple = _ntuple(3), _ntuple(4)
# Public alias of the factory itself.
to_ntuple = _ntuple
54 |
--------------------------------------------------------------------------------
/mmcls/models/utils/make_divisible.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Make divisible function.

    The channel number is rounded to a nearby multiple of ``divisor``
    (computed as ``int(value + divisor / 2) // divisor * divisor``) and
    kept at or above ``min_value``. When the result falls below
    ``min_ratio * value``, one more ``divisor`` is added.

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int, optional): The minimum value of the output channel.
            Default: None, means that the minimum value equal to the divisor.
        min_ratio (float): The minimum ratio of the rounded channel
            number to the original channel number. Default: 0.9.
    Returns:
        int: The modified output channel number
    """
    lower_bound = divisor if min_value is None else min_value
    rounded = int(value + divisor / 2) // divisor * divisor
    new_value = max(lower_bound, rounded)
    # Make sure that round down does not go down by more than (1-min_ratio).
    if new_value < min_ratio * value:
        new_value += divisor
    return new_value
26 |
--------------------------------------------------------------------------------
/mmcls/models/utils/position_encoding.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.runner.base_module import BaseModule
4 |
5 |
class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): Initialization config handed to
            ``BaseModule``. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg)
        self.stride = stride
        # 3x3 convolution with ``groups=embed_dims``.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)

    def forward(self, x, hw_shape):
        """Add the convolutional position encoding to a token sequence.

        Args:
            x (Tensor): Tokens, unpacked as (B, N, C).
            hw_shape (tuple): The (H, W) layout used to reshape the tokens.

        Returns:
            Tensor: Tokens flattened back to a (B, N', C') sequence.
        """
        batch, _, channels = x.shape
        height, width = hw_shape
        # convert (B, N, C) to (B, C, H, W)
        cnn_feat = x.transpose(1, 2).view(batch, channels, height,
                                          width).contiguous()
        if self.stride == 1:
            # The residual is only added when the resolution is preserved.
            encoded = self.proj(cnn_feat) + cnn_feat
        else:
            encoded = self.proj(cnn_feat)
        return encoded.flatten(2).transpose(1, 2)
42 |
--------------------------------------------------------------------------------
/mmcls/models/utils/se_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 | import torch.nn as nn
4 | from mmcv.cnn import ConvModule
5 | from mmcv.runner import BaseModule
6 |
7 | from .make_divisible import make_divisible
8 |
9 |
class SELayer(BaseModule):
    """Squeeze-and-Excitation Module.

    Args:
        channels (int): The input (and output) channels of the SE layer.
        squeeze_channels (None or int): The intermediate channel number of
            SElayer. Default: None, means the value of ``squeeze_channels``
            is ``make_divisible(channels // ratio, divisor)``.
        ratio (int): Squeeze ratio in SELayer, the intermediate channel will
            be ``make_divisible(channels // ratio, divisor)``. Only used when
            ``squeeze_channels`` is None. Default: 16.
        divisor(int): The divisor to true divide the channel number. Only
            used when ``squeeze_channels`` is None. Default: 8.
        bias (bool | str): Passed as ``bias`` to both ``ConvModule`` layers.
            Default: 'auto'.
        conv_cfg (None or dict): Config dict for convolution layer. Default:
            None, which means using conv2d.
        return_weight(bool): Whether to return the weight. Default: False.
        act_cfg (dict or Sequence[dict]): Config dict for activation layer.
            If act_cfg is a dict, two activation layers will be configurated
            by this dict. If act_cfg is a sequence of dicts, the first
            activation layer will be configurated by the first dict and the
            second activation layer will be configurated by the second dict.
            Default: (dict(type='ReLU'), dict(type='Sigmoid'))
        init_cfg (dict, optional): Initialization config handed to
            ``BaseModule``. Default: None.
    """

    def __init__(self,
                 channels,
                 squeeze_channels=None,
                 ratio=16,
                 divisor=8,
                 bias='auto',
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
                 return_weight=False,
                 init_cfg=None):
        super(SELayer, self).__init__(init_cfg)
        # A single dict configures both activation layers.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)
        squeeze_act, excite_act = act_cfg[0], act_cfg[1]

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        if squeeze_channels is None:
            squeeze_channels = make_divisible(channels // ratio, divisor)
        assert isinstance(squeeze_channels, int) and squeeze_channels > 0, \
            '"squeeze_channels" should be a positive integer, but get ' + \
            f'{squeeze_channels} instead.'
        self.return_weight = return_weight

        # Settings shared by the two 1x1 convolutions.
        shared = dict(kernel_size=1, stride=1, bias=bias, conv_cfg=conv_cfg)
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeeze_channels,
            act_cfg=squeeze_act,
            **shared)
        self.conv2 = ConvModule(
            in_channels=squeeze_channels,
            out_channels=channels,
            act_cfg=excite_act,
            **shared)

    def forward(self, x):
        """Return the channel weights, or ``x`` scaled by them."""
        weight = self.conv2(self.conv1(self.global_avgpool(x)))
        if self.return_weight:
            return weight
        return x * weight
81 |
--------------------------------------------------------------------------------
/mmcls/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .collect_env import collect_env
3 | from .device import auto_select_device
4 | from .distribution import wrap_distributed_model, wrap_non_distributed_model
5 | from .logger import get_root_logger, load_json_log
6 | from .setup_env import setup_multi_processes
7 |
8 | __all__ = [
9 | 'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
10 | 'wrap_non_distributed_model', 'wrap_distributed_model',
11 | 'auto_select_device'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmcls/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import collect_env as collect_base_env
3 | from mmcv.utils import get_git_hash
4 |
5 | import mmcls
6 |
7 |
def collect_env():
    """Collect the information of the running environments.

    The base information from mmcv is extended with an
    ``'MMClassification'`` entry made of the package version and the first
    seven characters of the git hash.
    """
    info = collect_base_env()
    version_tag = mmcls.__version__ + '+' + get_git_hash()[:7]
    info['MMClassification'] = version_tag
    return info


if __name__ == '__main__':
    # Print one "name: value" line per collected entry.
    for key, value in collect_env().items():
        print(f'{key}: {value}')
18 |
--------------------------------------------------------------------------------
/mmcls/utils/device.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcv
3 | import torch
4 | from mmcv.utils import digit_version
5 |
6 |
def auto_select_device() -> str:
    """Pick the device string to run on.

    With mmcv >= 1.6.0 the choice is delegated to
    ``mmcv.device.get_device``; otherwise ``'cuda'`` is returned when CUDA
    is available and ``'cpu'`` when it is not.
    """
    if digit_version(mmcv.__version__) >= digit_version('1.6.0'):
        # Imported lazily: the module is only looked up for new enough mmcv.
        from mmcv.device import get_device
        return get_device()
    return 'cuda' if torch.cuda.is_available() else 'cpu'
16 |
--------------------------------------------------------------------------------
/mmcls/utils/distribution.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 |
3 |
def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
    """Wrap module in non-distributed environment by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
    - For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
    - For CPU & IPU, the model is moved to CPU and not wrapped.

    Args:
        model(:class:`nn.Module`): model to be parallelized.
        device(str): device type: cuda, cpu, ipu or mps. Defaults to cuda.
        dim(int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        model(nn.Module): the model to be parallelized.

    Raises:
        RuntimeError: If ``device`` is none of the supported types.
    """
    if device == 'cuda':
        from mmcv.parallel import MMDataParallel
        return MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
    if device == 'cpu' or device == 'ipu':
        return model.cpu()
    if device == 'mps':
        from mmcv.device import mps
        return mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
    raise RuntimeError(f'Unavailable device "{device}"')
33 |
34 |
def wrap_distributed_model(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
    - Other device types are not supported by now.

    Args:
        model(:class:`nn.Module`): module to be parallelized.
        device(str): device type; only cuda is accepted.

    Returns:
        model(:class:`nn.Module`): the module to be parallelized

    Raises:
        RuntimeError: If ``device`` is not ``'cuda'``.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    if device != 'cuda':
        raise RuntimeError(f'Unavailable device "{device}"')

    from mmcv.parallel import MMDistributedDataParallel
    return MMDistributedDataParallel(model.cuda(), *args, **kwargs)
59 |
--------------------------------------------------------------------------------
/mmcls/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import json
3 | import logging
4 | from collections import defaultdict
5 |
6 | from mmcv.utils import get_logger
7 |
8 |
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get the logger named ``'mmcls'``.

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to :obj:`logging.INFO`.

    Returns:
        :obj:`logging.Logger`: The obtained logger
    """
    logger_name = 'mmcls'
    root_logger = get_logger(logger_name, log_file, log_level)
    return root_logger
21 |
22 |
def load_json_log(json_log):
    """load and convert json_logs to log_dicts.

    Every non-blank line of the file is parsed as one JSON object. Blank
    lines and records without an ``epoch`` field are skipped.

    Args:
        json_log (str): The path of the json log file.

    Returns:
        dict[int, dict[str, list]]:
            Key is the epoch, value is a sub dict. The keys in each sub dict
            are different metrics, e.g. memory, bbox_mAP, and the value is a
            list of corresponding values in all iterations in this epoch.

            .. code-block:: python

                # An example output
                {
                    1: {'iter': [100, 200, 300], 'loss': [6.94, 6.73, 6.53]},
                    2: {'iter': [100, 200, 300], 'loss': [6.33, 6.20, 6.07]},
                    ...
                }
    """
    log_dict = dict()
    with open(json_log, 'r') as log_file:
        for line in log_file:
            line = line.strip()
            # A blank line is not valid JSON; skip it instead of crashing.
            if not line:
                continue
            log = json.loads(line)
            # skip lines without `epoch` field
            if 'epoch' not in log:
                continue
            epoch = log.pop('epoch')
            if epoch not in log_dict:
                log_dict[epoch] = defaultdict(list)
            for k, v in log.items():
                log_dict[epoch][k].append(v)
    return log_dict
57 |
--------------------------------------------------------------------------------
/mmcls/utils/setup_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os
3 | import platform
4 | import warnings
5 |
6 | import cv2
7 | import torch.multiprocessing as mp
8 |
9 |
def setup_multi_processes(cfg):
    """Setup multi-processing environment variables.

    Reads ``mp_start_method`` and ``opencv_num_threads`` from ``cfg`` (with
    defaults ``'fork'`` and ``0``) and, when ``cfg.data.workers_per_gpu`` is
    larger than 1, sets ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` to 1
    unless they are already present in the environment.

    Args:
        cfg: The config; it must provide ``get`` and ``data.workers_per_gpu``.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can change '
                f'this behavior by changing `mp_start_method` in your config.')
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP threads, then MKL threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    num_threads = 1
    for env_var in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
        # A value already set by the user is left untouched.
        if env_var not in os.environ and cfg.data.workers_per_gpu > 1:
            warnings.warn(
                f'Setting {env_var} environment variable for each process '
                f'to be {num_threads} in default, to avoid your system being '
                f'overloaded, please further tune the variable for optimal '
                f'performance in your application as needed.')
            os.environ[env_var] = str(num_threads)
48 |
--------------------------------------------------------------------------------
/mmcls/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved
2 |
# Package version string; parsed into `version_info` further down.
# NOTE(review): the part after `_` looks like a commit hash -- confirm.
__version__ = '0.23.2_71ef7ba'
4 |
5 |
def parse_version_info(version_str):
    """Parse a version string into a tuple.

    The string is split on ``'.'`` and each component is handled as follows:

    - all digits: appended as ``int``;
    - contains ``'rc'``: the part before ``'rc'`` is appended as ``int``
      and ``'rc<suffix>'`` is appended as ``str``;
    - anything else: the leading run of digits (if any) is appended as
      ``int`` and the remainder is ignored, so a local suffix such as
      ``'2_71ef7ba'`` keeps its numeric part instead of being dropped.

    Args:
        version_str (str): The version string.
    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
            (1, 3, 0), "2.0.0rc1" is parsed into (2, 0, 0, 'rc1') and
            "0.23.2_71ef7ba" is parsed into (0, 23, 2).
    """
    version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            version_info.append(int(x))
        elif 'rc' in x:
            patch_version = x.split('rc')
            version_info.append(int(patch_version[0]))
            version_info.append(f'rc{patch_version[1]}')
        else:
            # keep the numeric prefix of components like '2_71ef7ba'
            leading_digits = x[:len(x) - len(x.lstrip('0123456789'))]
            if leading_digits:
                version_info.append(int(leading_digits))
    return tuple(version_info)
24 |
25 |
# Tuple form of `__version__`, computed once when the module is imported.
version_info = parse_version_info(__version__)

# Names exported by `from <this module> import *`.
__all__ = ['__version__', 'version_info', 'parse_version_info']
29 |
--------------------------------------------------------------------------------
/mmfewshot/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcls # noqa: F401, F403
3 | import mmcv
4 | # import mmdet # noqa: F401, F403
5 |
6 | from .classification import * # noqa: F401, F403
7 | # from .detection import * # noqa: F401, F403
8 | from .utils import * # noqa: F401, F403
9 | from .version import __version__, short_version
10 |
11 |
def digit_version(version_str):
    """Convert a version string into a list of ints for range comparison.

    The string is split on ``'.'`` and each component is handled as follows:

    - all digits: appended as ``int``;
    - contains ``'rc'``: ``int(before_rc) - 1`` is appended, followed by the
      release-candidate number, so that e.g. ``'2.0.0rc1'`` sorts below
      ``'2.0.0'``;
    - anything else: the leading run of digits (if any) is appended and the
      remainder is ignored, so a component with a local suffix such as
      ``'2_71ef7ba'`` keeps its numeric part instead of being dropped.

    Args:
        version_str (str): The version string.

    Returns:
        list[int]: Comparable numeric representation of the version.
    """
    digit_version_ = []
    for x in version_str.split('.'):
        if x.isdigit():
            digit_version_.append(int(x))
        elif 'rc' in x:
            patch_version = x.split('rc')
            digit_version_.append(int(patch_version[0]) - 1)
            digit_version_.append(int(patch_version[1]))
        else:
            # keep the numeric prefix of components like '2_71ef7ba'
            leading_digits = x[:len(x) - len(x.lstrip('0123456789'))]
            if leading_digits:
                digit_version_.append(int(leading_digits))
    return digit_version_
22 |
23 |
# Supported MMCV version range; both bounds are inclusive (see the `<=`
# comparisons in the assert below).
mmcv_minimum_version = '1.3.12'
# By harbor : 1.6.1 seems ok
mmcv_maximum_version = '1.6.1'
mmcv_version = digit_version(mmcv.__version__)


# Evaluated when the package is imported: an out-of-range MMCV raises
# AssertionError with the message below.
assert (digit_version(mmcv_minimum_version) <= mmcv_version
        <= digit_version(mmcv_maximum_version)), \
    f'MMCV=={mmcv.__version__} is used but incompatible. ' \
    f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'

# By harbor : No need for mmdet
# mmdet_minimum_version = '2.16.0'
# mmdet_maximum_version = '2.25.0'
# mmdet_version = digit_version(mmdet.__version__)
#
#
# assert (digit_version(mmdet_minimum_version) <= mmdet_version
# <= digit_version(mmdet_maximum_version)), \
# f'MMDET=={mmdet.__version__} is used but incompatible. ' \
# f'Please install mmdet>={mmdet_minimum_version},\
# <={mmdet_maximum_version}.'

# Supported MMCLS version range; both bounds are inclusive.
mmcls_minimum_version = '0.15.0'
mmcls_maximum_version = '0.25.0'
mmcls_version = digit_version(mmcls.__version__)


# Same import-time check as for MMCV, applied to the imported mmcls package.
assert (digit_version(mmcls_minimum_version) <= mmcls_version
        <= digit_version(mmcls_maximum_version)), \
    f'MMCLS=={mmcls.__version__} is used but incompatible. ' \
    f'Please install mmcls>={mmcls_minimum_version},\
<={mmcls_maximum_version}.'

# Names exported by `from mmfewshot import *`.
__all__ = ['__version__', 'short_version']
59 |
--------------------------------------------------------------------------------
/mmfewshot/classification/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .apis import * # noqa: F401,F403
3 | from .core import * # noqa: F401,F403
4 | from .datasets import * # noqa: F401,F403
5 | from .models import * # noqa: F401,F403
6 | from .utils import * # noqa: F401, F403
7 |
--------------------------------------------------------------------------------
/mmfewshot/classification/apis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .inference import (inference_classifier, init_classifier,
3 | process_support_images, show_result_pyplot)
4 | from .test import (Z_SCORE, multi_gpu_meta_test, single_gpu_meta_test,
5 | test_single_task)
6 | from .train import train_model
7 |
# Public API of this package: the training, meta-testing and inference
# helpers imported above from `.train`, `.test` and `.inference`.
__all__ = [
    'train_model', 'test_single_task', 'Z_SCORE', 'single_gpu_meta_test',
    'multi_gpu_meta_test', 'init_classifier', 'process_support_images',
    'inference_classifier', 'show_result_pyplot'
]
13 |
--------------------------------------------------------------------------------
/mmfewshot/classification/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .evaluation import * # noqa: F401, F403
3 |
--------------------------------------------------------------------------------
/mmfewshot/classification/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .eval_hooks import DistMetaTestEvalHook, MetaTestEvalHook
3 |
# Public API of this package: the two evaluation hooks imported from
# `.eval_hooks`.
__all__ = ['MetaTestEvalHook', 'DistMetaTestEvalHook']
5 |
--------------------------------------------------------------------------------
/mmfewshot/classification/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcls.datasets.builder import DATASETS, PIPELINES
3 |
4 | from .base import BaseFewShotDataset
5 | from .builder import (build_dataloader, build_dataset,
6 | build_meta_test_dataloader)
7 | from .cub import CUBDataset
8 | from .dataset_wrappers import EpisodicDataset, MetaTestDataset
9 | from .mini_imagenet import MiniImageNetDataset
10 | from .pipelines import LoadImageFromBytes
11 | from .tiered_imagenet import TieredImageNetDataset
12 | from .utils import label_wrapper
13 |
# Public API of this package. `DATASETS` and `PIPELINES` are the mmcls
# registries imported above and re-exported here.
__all__ = [
    'build_dataloader', 'build_dataset', 'DATASETS', 'PIPELINES', 'CUBDataset',
    'LoadImageFromBytes', 'build_meta_test_dataloader', 'MiniImageNetDataset',
    'TieredImageNetDataset', 'label_wrapper', 'BaseFewShotDataset',
    'EpisodicDataset', 'MetaTestDataset'
]
20 |
--------------------------------------------------------------------------------
/mmfewshot/classification/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .loading import LoadImageFromBytes
3 |
# Public API of this package: the single pipeline step from `.loading`.
__all__ = [
    'LoadImageFromBytes',
]
7 |
--------------------------------------------------------------------------------
/mmfewshot/classification/datasets/pipelines/loading.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import os.path as osp
3 | from typing import Dict
4 |
5 | import mmcv
6 | import numpy as np
7 | from mmcls.datasets.builder import PIPELINES
8 | from mmcls.datasets.pipelines import LoadImageFromFile
9 |
10 |
@PIPELINES.register_module()
class LoadImageFromBytes(LoadImageFromFile):
    """Pipeline step that decodes an image from raw bytes.

    The bytes are taken from ``results['img_bytes']`` when that entry exists
    and is not ``None`` (the entry is then removed from ``results``).
    Otherwise they are fetched through the file client using the resolved
    filename.
    """

    def __call__(self, results: Dict) -> Dict:
        """Decode the image and store it and its meta info in ``results``.

        Args:
            results (dict): Must contain ``img_prefix`` and
                ``img_info['filename']``; may contain ``img_bytes``.

        Returns:
            dict: The same ``results`` object with ``filename``,
            ``ori_filename``, ``img``, ``img_shape``, ``ori_shape`` and
            ``img_norm_cfg`` filled in.
        """
        # the file client is created lazily on first use
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        prefix = results['img_prefix']
        relative_name = results['img_info']['filename']
        filename = (
            relative_name if prefix is None else osp.join(
                prefix, relative_name))

        raw_bytes = results.get('img_bytes', None)
        if raw_bytes is None:
            raw_bytes = self.file_client.get(filename)
        else:
            # bytes were supplied in-memory; drop them from the results dict
            results.pop('img_bytes')

        img = mmcv.imfrombytes(raw_bytes, flag=self.color_type)
        if self.to_float32:
            img = img.astype(np.float32)

        # arrays with fewer than three dimensions count as single-channel
        num_channels = img.shape[2] if len(img.shape) >= 3 else 1
        identity_norm_cfg = dict(
            mean=np.zeros(num_channels, dtype=np.float32),
            std=np.ones(num_channels, dtype=np.float32),
            to_rgb=False)

        results['filename'] = filename
        results['ori_filename'] = relative_name
        results['img'] = img
        results['img_shape'] = img.shape
        results['ori_shape'] = img.shape
        results['img_norm_cfg'] = identity_norm_cfg
        return results
42 |
--------------------------------------------------------------------------------
/mmfewshot/classification/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import List, Union
3 |
4 | import numpy as np
5 | import torch
6 | from torch import Tensor
7 |
8 |
def label_wrapper(labels: Union[Tensor, np.ndarray, List],
                  class_ids: List[int]) -> Union[Tensor, np.ndarray, list]:
    """Re-index labels by their position in ``class_ids``.

    Every label is replaced by the index at which it occurs in
    ``class_ids``, so the output lies in ``0 .. len(class_ids) - 1``. This
    is useful in meta testing, where the sampled class ids are arbitrary
    and not contiguous.

    Args:
        labels (Tensor | np.ndarray | list): The labels to be wrapped.
            A tuple is accepted as well and yields a list.
        class_ids (list[int]): All class ids of labels.

    Returns:
        (Tensor | np.ndarray | list): Tensors and arrays keep their dtype
        (and device for tensors); tuples and lists yield a list.

    Raises:
        TypeError: If ``labels`` is none of the supported container types.
    """
    to_index = {}
    for position, class_id in enumerate(class_ids):
        to_index[class_id] = position

    if isinstance(labels, torch.Tensor):
        remapped = torch.tensor([to_index[item.item()] for item in labels])
        return remapped.type_as(labels).to(labels.device)

    if isinstance(labels, np.ndarray):
        remapped = np.array([to_index[item] for item in labels])
        return remapped.astype(labels.dtype)

    if isinstance(labels, (tuple, list)):
        return [to_index[item] for item in labels]

    raise TypeError('only support torch.Tensor, np.ndarray and list')
36 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcls.models.builder import * # noqa: F401,F403
3 |
4 | from .backbones import * # noqa: F401,F403
5 | from .classifiers import * # noqa: F401,F403
6 | from .heads import * # noqa: F401,F403
7 | from .losses import * # noqa: F401,F403
8 | from .utils import * # noqa: F401,F403
9 |
# With an empty `__all__`, `from <this package> import *` exports no names.
# NOTE(review): the sub-modules are star-imported above -- confirm that an
# empty export list is intended.
__all__ = []
11 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcls.models.builder import BACKBONES
3 |
4 | from .conv4 import Conv4, ConvNet
5 | from .resnet12 import ResNet12
6 | from .wrn import WideResNet, WRN28x10
7 |
# Public API of this package. `BACKBONES` is the mmcls registry imported
# above and re-exported here.
__all__ = [
    'BACKBONES', 'ResNet12', 'Conv4', 'ConvNet', 'WRN28x10', 'WideResNet'
]
11 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/backbones/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch import Tensor
6 | from torch.distributions import Bernoulli
7 |
8 | # This part of code is modified from https://github.com/kjunelee/MetaOptNet
9 |
10 |
class DropBlock(nn.Module):
    """Zero out square spatial blocks of a 4-D feature map during training.

    Seed positions are drawn from a Bernoulli distribution; every seed
    removes a ``block_size x block_size`` patch. The surviving activations
    are rescaled by ``total_elements / kept_elements``. In eval mode the
    input is returned unchanged.

    Args:
        block_size (int): Side length of the dropped square blocks.
    """

    def __init__(self, block_size: int) -> None:
        super().__init__()
        self.block_size = block_size

    def forward(self, x: Tensor, gamma: float) -> Tensor:
        """Apply block dropping to ``x``.

        Args:
            x (Tensor): Input that unpacks into
                ``(batch_size, channels, height, width)``.
            gamma (float): Bernoulli probability of a position becoming a
                block seed.

        Returns:
            Tensor: ``x`` itself in eval mode, otherwise the masked and
            rescaled input.
        """
        # Randomly zeroes 2D spatial blocks of the input tensor.
        if not self.training:
            return x

        batch_size, channels, height, width = x.shape
        # seeds are sampled on a grid shrunk by block_size - 1 so that a
        # block anchored at any seed stays inside the feature map
        seed_mask = Bernoulli(gamma).sample(
            (batch_size, channels, height - (self.block_size - 1),
             width - (self.block_size - 1)))
        seed_mask = seed_mask.to(x.device)
        block_mask = self._compute_block_mask(seed_mask)

        total = block_mask.numel()
        # when every position is dropped the kept count is 0; clamp it so
        # the result is all zeros instead of NaN (0 * inf)
        kept = block_mask.sum().clamp(min=1)
        return block_mask * x * (total / kept)

    def _compute_block_mask(self, mask: Tensor) -> Tensor:
        """Expand seed positions into a keep-mask of full blocks.

        Args:
            mask (Tensor): 4-D seed mask; non-zero entries mark the top-left
                corner of a block to drop.

        Returns:
            Tensor: Mask padded by ``block_size - 1`` in each spatial
            dimension, holding 0 inside dropped blocks and 1 elsewhere.
        """
        left_padding = (self.block_size - 1) // 2
        right_padding = self.block_size // 2
        padded_mask = F.pad(
            mask, (left_padding, right_padding, left_padding, right_padding))

        non_zero_idxes = mask.nonzero()
        nr_blocks = non_zero_idxes.shape[0]

        if nr_blocks > 0:
            block_range = torch.arange(self.block_size, device=mask.device)
            # all (dh, dw) displacements inside one block, row-major
            spatial_offsets = torch.stack([
                block_range.view(-1, 1).expand(
                    self.block_size, self.block_size).reshape(-1),
                block_range.repeat(self.block_size),
            ]).t()
            # batch and channel indices are never shifted
            offsets = torch.cat(
                (torch.zeros(
                    self.block_size**2,
                    2,
                    dtype=torch.long,
                    device=mask.device), spatial_offsets.long()), 1)

            # Pair EVERY seed with EVERY offset via broadcasting. The
            # previous code tiled both tensors with `repeat`, which pairs
            # them cyclically and misses combinations whenever the number
            # of seeds shares a factor with block_size**2.
            block_idxes = (non_zero_idxes.unsqueeze(1) +
                           offsets.unsqueeze(0)).reshape(-1, 4)
            padded_mask[block_idxes[:, 0], block_idxes[:, 1],
                        block_idxes[:, 2], block_idxes[:, 3]] = 1.

        block_mask = 1 - padded_mask
        return block_mask
69 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcls.models.builder import CLASSIFIERS
3 |
4 | from .base_finetune import BaseFinetuneClassifier
5 | from .base_metric import BaseMetricClassifier
6 | from .baseline import Baseline
7 | from .baseline_plus import BaselinePlus
8 | from .maml import MAML
9 | from .matching_net import MatchingNet
10 | from .meta_baseline import MetaBaseline
11 | from .neg_margin import NegMargin
12 | from .proto_net import ProtoNet
13 | from .relation_net import RelationNet
14 |
# Public API of this package. `CLASSIFIERS` is the mmcls registry imported
# above and re-exported here.
__all__ = [
    'CLASSIFIERS', 'BaseFinetuneClassifier', 'BaseMetricClassifier',
    'Baseline', 'BaselinePlus', 'ProtoNet', 'MatchingNet', 'RelationNet',
    'NegMargin', 'MetaBaseline', 'MAML'
]
20 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict
3 |
4 | from mmcls.models.builder import CLASSIFIERS
5 |
6 | from .base_finetune import BaseFinetuneClassifier
7 |
8 |
@CLASSIFIERS.register_module()
class Baseline(BaseFinetuneClassifier):
    """Baseline few-shot classifier.

    This class only supplies default head configs (both ``LinearHead``);
    everything else is handled by ``BaseFinetuneClassifier``.

    Args:
        head (dict): Config of classification head for training.
        meta_test_head (dict): Config of classification head for meta
            testing. The `meta_test_head` will only be built and run in
            meta testing.
    """

    def __init__(self,
                 head: Dict = {
                     'type': 'LinearHead',
                     'num_classes': 100,
                     'in_channels': 1600
                 },
                 meta_test_head: Dict = {
                     'type': 'LinearHead',
                     'num_classes': 5,
                     'in_channels': 1600
                 },
                 *args,
                 **kwargs) -> None:
        # forward both head configs by keyword, the rest untouched
        super().__init__(
            *args, head=head, meta_test_head=meta_test_head, **kwargs)
28 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/baseline_plus.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict
3 |
4 | from mmcls.models.builder import CLASSIFIERS
5 |
6 | from .base_finetune import BaseFinetuneClassifier
7 |
8 |
@CLASSIFIERS.register_module()
class BaselinePlus(BaseFinetuneClassifier):
    """Baseline++ few-shot classifier.

    This class only supplies default head configs (both
    ``CosineDistanceHead``); everything else is handled by
    ``BaseFinetuneClassifier``.

    Args:
        head (dict): Config of classification head for training.
        meta_test_head (dict): Config of classification head for meta
            testing. The `meta_test_head` will only be built and run in
            meta testing.
    """

    def __init__(self,
                 head: Dict = {
                     'type': 'CosineDistanceHead',
                     'num_classes': 100,
                     'in_channels': 1600
                 },
                 meta_test_head: Dict = {
                     'type': 'CosineDistanceHead',
                     'num_classes': 5,
                     'in_channels': 1600
                 },
                 *args,
                 **kwargs) -> None:
        # forward both head configs by keyword, the rest untouched
        super().__init__(
            *args, head=head, meta_test_head=meta_test_head, **kwargs)
32 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/matching_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | from typing import Dict
4 |
5 | from mmcls.models.builder import CLASSIFIERS
6 |
7 | from .base_metric import BaseMetricClassifier
8 |
9 |
@CLASSIFIERS.register_module()
class MatchingNet(BaseMetricClassifier):
    """MatchingNet few-shot classifier.

    Args:
        head (dict): Config of the classification head. Defaults to a
            ``MatchingHead``.
    """

    def __init__(self,
                 head: Dict = {'type': 'MatchingHead'},
                 *args,
                 **kwargs) -> None:
        # keep an independent copy of the head config on the instance
        # before handing the original to the parent constructor
        self.head_cfg = copy.deepcopy(head)
        super().__init__(*args, head=head, **kwargs)
18 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/meta_baseline.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict
3 |
4 | from mmcls.models.builder import CLASSIFIERS
5 |
6 | from .base_metric import BaseMetricClassifier
7 |
8 |
@CLASSIFIERS.register_module()
class MetaBaseline(BaseMetricClassifier):
    """MetaBaseline few-shot classifier.

    Args:
        head (dict): Config of classification head for training. Defaults
            to a ``MetaBaselineHead``.
    """

    def __init__(self,
                 head: Dict = {'type': 'MetaBaselineHead'},
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, head=head, **kwargs)
22 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/neg_margin.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict
3 |
4 | from mmcls.models.builder import CLASSIFIERS
5 |
6 | from .base_finetune import BaseFinetuneClassifier
7 |
8 |
@CLASSIFIERS.register_module()
class NegMargin(BaseFinetuneClassifier):
    """NegMargin few-shot classifier.

    This class only supplies default ``NegMarginHead`` configs for training
    and meta testing; everything else is handled by
    ``BaseFinetuneClassifier``.

    Args:
        head (dict): Config of classification head for training.
        meta_test_head (dict): Config of classification head for meta
            testing.
    """

    def __init__(self,
                 head: Dict = {
                     'type': 'NegMarginHead',
                     'metric_type': 'cosine',
                     'num_classes': 100,
                     'in_channels': 1600,
                     'margin': -0.02,
                     'temperature': 30.0
                 },
                 meta_test_head: Dict = {
                     'type': 'NegMarginHead',
                     'metric_type': 'cosine',
                     'num_classes': 5,
                     'in_channels': 1600,
                     'margin': 0.0,
                     'temperature': 5.0
                 },
                 *args,
                 **kwargs) -> None:
        # forward both head configs by keyword, the rest untouched
        super().__init__(
            *args, head=head, meta_test_head=meta_test_head, **kwargs)
32 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/proto_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | from typing import Dict
4 |
5 | from mmcls.models.builder import CLASSIFIERS
6 |
7 | from .base_metric import BaseMetricClassifier
8 |
9 |
@CLASSIFIERS.register_module()
class ProtoNet(BaseMetricClassifier):
    """ProtoNet few-shot classifier.

    Args:
        head (dict): Config of the classification head. Defaults to a
            ``PrototypeHead``.
    """

    def __init__(self,
                 head: Dict = {'type': 'PrototypeHead'},
                 *args,
                 **kwargs) -> None:
        # keep an independent copy of the head config on the instance
        # before handing the original to the parent constructor
        self.head_cfg = copy.deepcopy(head)
        super().__init__(*args, head=head, **kwargs)
20 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/classifiers/relation_net.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | from typing import Dict
4 |
5 | from mmcls.models.builder import CLASSIFIERS
6 |
7 | from .base_metric import BaseMetricClassifier
8 |
9 |
@CLASSIFIERS.register_module()
class RelationNet(BaseMetricClassifier):
    """RelationNet few-shot classifier.

    Args:
        head (dict): Config of the classification head. Defaults to a
            ``RelationHead`` with ``in_channels=64`` and
            ``feature_size=(19, 19)``.
    """

    def __init__(self,
                 head: Dict = {
                     'type': 'RelationHead',
                     'in_channels': 64,
                     'feature_size': (19, 19)
                 },
                 *args,
                 **kwargs) -> None:
        # keep an independent copy of the head config on the instance
        # before handing the original to the parent constructor
        self.head_cfg = copy.deepcopy(head)
        super().__init__(*args, head=head, **kwargs)
23 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcls.models.builder import HEADS
3 |
4 | from .cosine_distance_head import CosineDistanceHead
5 | from .linear_head import LinearHead
6 | from .matching_head import MatchingHead
7 | from .meta_baseline_head import MetaBaselineHead
8 | from .neg_margin_head import NegMarginHead
9 | from .prototype_head import PrototypeHead
10 | from .relation_head import RelationHead
11 |
# Public API of this package. `HEADS` is the mmcls registry imported above
# and re-exported here.
__all__ = [
    'HEADS', 'MetaBaselineHead', 'MatchingHead', 'NegMarginHead', 'LinearHead',
    'CosineDistanceHead', 'PrototypeHead', 'RelationHead'
]
16 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/heads/base_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from abc import ABCMeta, abstractmethod
3 | from typing import Dict, Tuple
4 |
5 | from mmcls.models.builder import HEADS, build_loss
6 | from mmcls.models.losses import Accuracy
7 | from mmcv.runner import BaseModule
8 | from torch import Tensor
9 |
10 |
@HEADS.register_module()
class BaseFewShotHead(BaseModule, metaclass=ABCMeta):
    """Abstract base class of the few shot classification heads.

    It builds the loss and accuracy modules and provides :meth:`loss`;
    subclasses implement the training and meta-testing forward hooks.

    Args:
        loss (dict): Training loss.
        topk (tuple[int]): Topk metric for computing the accuracy. A single
            int is accepted and wrapped into a tuple.
        cal_acc (bool): Whether to compute the accuracy during training.
            Default: False.
    """

    def __init__(self,
                 loss: Dict = dict(type='CrossEntropyLoss', loss_weight=1.0),
                 topk: Tuple[int] = (1, ),
                 cal_acc: bool = False) -> None:
        super().__init__()
        assert isinstance(loss, dict)
        assert isinstance(topk, (int, tuple))
        # normalise a bare int into a one-element tuple
        topk = (topk, ) if isinstance(topk, int) else topk
        assert all(k > 0 for k in topk), 'Top-k should be larger than 0'

        self.topk = topk
        self.compute_loss = build_loss(loss)
        self.compute_accuracy = Accuracy(topk=self.topk)
        self.cal_acc = cal_acc

    def loss(self, cls_score: Tensor, gt_label: Tensor) -> Dict:
        """Compute the loss and, if enabled, the top-k accuracy.

        Args:
            cls_score (Tensor): The prediction.
            gt_label (Tensor): The learning target of the prediction.

        Returns:
            Dict: Holds ``'loss'`` and, when ``cal_acc`` is set, an
            ``'accuracy'`` dict keyed by ``'top-<k>'``.
        """
        # the loss is averaged over the number of predictions
        loss_value = self.compute_loss(
            cls_score, gt_label, avg_factor=len(cls_score))

        outputs = {}
        if self.cal_acc:
            acc = self.compute_accuracy(cls_score, gt_label)
            assert len(acc) == len(self.topk)
            outputs['accuracy'] = dict(
                (f'top-{k}', a) for k, a in zip(self.topk, acc))
        outputs['loss'] = loss_value
        return outputs

    @abstractmethod
    def forward_train(self, **kwargs):
        """Forward training data."""

    @abstractmethod
    def forward_support(self, x, gt_label, **kwargs):
        """Forward support data in meta testing."""

    @abstractmethod
    def forward_query(self, x, **kwargs):
        """Forward query data in meta testing."""

    @abstractmethod
    def before_forward_support(self):
        """Hook for meta testing.

        Called before the model forwards support data during meta testing.
        """

    @abstractmethod
    def before_forward_query(self):
        """Hook for meta testing.

        Called before the model forwards query data during meta testing.
        """
91 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/heads/linear_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, List
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from mmcls.models.builder import HEADS
7 | from torch import Tensor
8 |
9 | from .base_head import BaseFewShotHead
10 |
11 |
@HEADS.register_module()
class LinearHead(BaseFewShotHead):
    """Classification head for Baseline: a single fully connected layer.

    Args:
        num_classes (int): Number of categories.
        in_channels (int): Number of channels in the input feature map.
    """

    def __init__(self, num_classes: int, in_channels: int, *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        assert num_classes > 0, (
            f'num_classes={num_classes} must be a positive integer')

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.init_layers()

    def init_layers(self) -> None:
        """(Re)create the linear classifier with fresh parameters."""
        self.fc = nn.Linear(self.in_channels, self.num_classes)

    def forward_train(self, x: Tensor, gt_label: Tensor, **kwargs) -> Dict:
        """Forward training data and return the loss dict."""
        return self.loss(self.fc(x), gt_label)

    def forward_support(self, x: Tensor, gt_label: Tensor, **kwargs) -> Dict:
        """Forward support data in meta testing (same path as training)."""
        return self.forward_train(x, gt_label, **kwargs)

    def forward_query(self, x: Tensor, **kwargs) -> List:
        """Forward query data in meta testing.

        Returns:
            list: One numpy array of softmax scores per query sample.
        """
        probabilities = F.softmax(self.fc(x), dim=1)
        return list(probabilities.detach().cpu().numpy())

    def before_forward_support(self) -> None:
        """Hook for meta testing.

        Called before the model forwards support data during meta testing:
        the linear layer is re-created and the head is put in train mode.
        """
        self.init_layers()
        self.train()

    def before_forward_query(self) -> None:
        """Hook for meta testing.

        Called before the model forwards query data during meta testing:
        the head is put in eval mode.
        """
        self.eval()
68 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .mse_loss import MSELoss
3 | from .nll_loss import NLLLoss
4 |
# Public API of this package: the two loss modules imported above.
__all__ = ['MSELoss', 'NLLLoss']
6 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/losses/mse_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Optional, Union
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from mmcls.models.builder import LOSSES
7 | from mmcls.models.losses.utils import weighted_loss
8 | from torch import Tensor
9 | from typing_extensions import Literal
10 |
11 |
@weighted_loss
def mse_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Return the unreduced (element-wise) mean squared error.

    Reduction is left to the ``weighted_loss`` decorator.
    """
    elementwise = F.mse_loss(pred, target, reduction='none')
    return elementwise
16 |
17 |
@LOSSES.register_module()
class MSELoss(nn.Module):
    """MSELoss.

    Args:
        reduction (str): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum". Default: 'mean'.
        loss_weight (float): The weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 reduction: Literal['none', 'mean', 'sum'] = 'mean',
                 loss_weight: float = 1.0) -> None:
        super().__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[Union[float, int]] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Forward function of loss.

        Args:
            pred (Tensor): The prediction with shape (N, *), where * means
                any number of additional dimensions.
            target (Tensor): The learning target of the prediction
                with shape (N, *) same as the input.
            weight (Tensor | None): Weight of the loss for each
                prediction. Default: None.
            avg_factor (float | int | None): Average factor that is used to
                average the loss. Default: None.
            reduction_override (str | None): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum". Default: None.

        Returns:
            Tensor: The calculated loss
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # a given override takes precedence over the configured reduction
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss = self.loss_weight * mse_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss
65 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/losses/nll_loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Optional, Union
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from mmcls.models.builder import LOSSES
7 | from mmcls.models.losses.utils import weighted_loss
8 | from torch import Tensor
9 | from typing_extensions import Literal
10 |
11 |
@weighted_loss
def nll_loss(pred: Tensor, target: Tensor) -> Tensor:
    """Compute the unreduced negative log likelihood loss.

    Thin wrapper around :func:`torch.nn.functional.nll_loss` that always
    requests ``reduction='none'``. Weighting and reduction are presumably
    applied by the ``weighted_loss`` decorator -- confirm against its
    definition.
    """
    unreduced_loss = F.nll_loss(pred, target, reduction='none')
    return unreduced_loss
16 |
17 |
@LOSSES.register_module()
class NLLLoss(nn.Module):
    """Negative log likelihood loss module.

    Calls the functional ``nll_loss`` wrapper and multiplies its result by
    ``loss_weight``.

    Args:
        reduction (str): The method that reduces the loss to a
            scalar. Options are "none", "mean" and "sum". Default: 'mean'.
        loss_weight (float): The weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 reduction: Literal['none', 'mean', 'sum'] = 'mean',
                 loss_weight: float = 1.0):
        super().__init__()
        self.loss_weight = loss_weight
        self.reduction = reduction

    def forward(self,
                pred: Tensor,
                target: Tensor,
                weight: Optional[Tensor] = None,
                avg_factor: Optional[Union[float, int]] = None,
                reduction_override: Optional[str] = None) -> Tensor:
        """Compute the weighted NLL loss.

        Args:
            pred (Tensor): The prediction with shape (N, C).
            target (Tensor): The learning target of the prediction,
                documented as shape (N, 1) -- TODO confirm against the
                shape accepted by ``torch.nn.functional.nll_loss``.
            weight (Tensor | None): Weight of the loss for each
                prediction. Default: None.
            avg_factor (float | int | None): Average factor that is used to
                average the loss. Default: None.
            reduction_override (str | None): The reduction method used to
                override the original reduction method of the loss.
                Options are "none", "mean" and "sum". Default: None.

        Returns:
            Tensor: The calculated loss, scaled by ``loss_weight``.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        # None (or any falsy override) keeps the configured reduction.
        effective_reduction = reduction_override or self.reduction
        unscaled = nll_loss(
            pred,
            target,
            weight,
            reduction=effective_reduction,
            avg_factor=avg_factor)
        return self.loss_weight * unscaled
64 |
--------------------------------------------------------------------------------
/mmfewshot/classification/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .maml_module import convert_maml_module
3 |
4 | __all__ = ['convert_maml_module']
5 |
--------------------------------------------------------------------------------
/mmfewshot/classification/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .meta_test_parallel import MetaTestParallel
3 |
4 | __all__ = ['MetaTestParallel']
5 |
--------------------------------------------------------------------------------
/mmfewshot/classification/utils/meta_test_parallel.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmcv.parallel.scatter_gather import scatter_kwargs
4 |
5 |
class MetaTestParallel(nn.Module):
    """The MetaTestParallel module that supports DataContainer.

    Note that each task is tested on a single GPU, so the data and model
    on different GPUs should be independent. :obj:`MMDistributedDataParallel`
    always automatically synchronizes the grad in different GPUs when doing
    the loss backward, which can not meet the requirements. Thus we simply
    copy the module and wrap it with a :obj:`MetaTestParallel`, which will
    send data to the device of the model.

    MetaTestParallel has two main differences with PyTorch DataParallel:

    - It supports a custom type :class:`DataContainer` which allows
      more flexible control of input data during both GPU and CPU
      inference.
    - It implements three more APIs ``before_meta_test()``,
      ``before_forward_support()`` and ``before_forward_query()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated. It must
            expose a ``device`` attribute and, when not on CPU, a
            ``get_device()`` method.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, module: nn.Module, dim: int = 0) -> None:
        super().__init__()
        self.dim = dim
        self.module = module
        self.device = self.module.device
        # -1 is the device id used for the CPU case.
        self.device_id = ([-1] if self.device == 'cpu' else
                          [self.module.get_device()])

    def _scatter_single(self, inputs, kwargs):
        """Scatter to the single target device and unwrap its arguments.

        Returns a ``(args, kwargs)`` pair; both are empty when scatter
        produced nothing at all.
        """
        scattered_args, scattered_kwargs = self.scatter(
            inputs, kwargs, self.device_id)
        if not scattered_args and not scattered_kwargs:
            return (), {}
        return scattered_args[0], scattered_kwargs[0]

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in the CPU inference where the data in
        :class:`DataContainers` will still be gathered.
        """
        call_args, call_kwargs = self._scatter_single(inputs, kwargs)
        return self.module(*call_args, **call_kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter ``inputs`` and ``kwargs`` along ``self.dim``."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def before_meta_test(self, *inputs, **kwargs) -> None:
        """Scatter the arguments and call ``module.before_meta_test``."""
        call_args, call_kwargs = self._scatter_single(inputs, kwargs)
        return self.module.before_meta_test(*call_args, **call_kwargs)

    def before_forward_support(self, *inputs, **kwargs) -> None:
        """Scatter the arguments and call ``module.before_forward_support``."""
        call_args, call_kwargs = self._scatter_single(inputs, kwargs)
        return self.module.before_forward_support(*call_args, **call_kwargs)

    def before_forward_query(self, *inputs, **kwargs) -> None:
        """Scatter the arguments and call ``module.before_forward_query``."""
        call_args, call_kwargs = self._scatter_single(inputs, kwargs)
        return self.module.before_forward_query(*call_args, **call_kwargs)
75 |
--------------------------------------------------------------------------------
/mmfewshot/detection/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .apis import * # noqa: F401,F403
3 | from .core import * # noqa: F401,F403
4 | from .datasets import * # noqa: F401,F403
5 | from .models import * # noqa: F401,F403
6 |
--------------------------------------------------------------------------------
/mmfewshot/detection/apis/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .inference import (inference_detector, init_detector,
3 | process_support_images)
4 | from .test import (multi_gpu_model_init, multi_gpu_test, single_gpu_model_init,
5 | single_gpu_test)
6 | from .train import train_detector
7 |
8 | __all__ = [
9 | 'train_detector', 'single_gpu_model_init', 'multi_gpu_model_init',
10 | 'single_gpu_test', 'multi_gpu_test', 'inference_detector', 'init_detector',
11 | 'process_support_images'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmfewshot/detection/core/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .evaluation import * # noqa: F401, F403
3 | from .utils import * # noqa: F401, F403
4 |
--------------------------------------------------------------------------------
/mmfewshot/detection/core/evaluation/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .eval_hooks import QuerySupportDistEvalHook, QuerySupportEvalHook
3 | from .mean_ap import eval_map
4 |
5 | __all__ = ['QuerySupportEvalHook', 'QuerySupportDistEvalHook', 'eval_map']
6 |
--------------------------------------------------------------------------------
/mmfewshot/detection/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .custom_hook import ContrastiveLossDecayHook
3 |
4 | __all__ = ['ContrastiveLossDecayHook']
5 |
--------------------------------------------------------------------------------
/mmfewshot/detection/core/utils/custom_hook.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Sequence
3 |
4 | from mmcv.parallel import is_module_wrapper
5 | from mmcv.runner import HOOKS, Hook, Runner
6 |
7 |
@HOOKS.register_module()
class ContrastiveLossDecayHook(Hook):
    """Hook for contrast loss weight decay used in FSCE.

    Before every iteration the hook recomputes the cumulative decay factor
    and pushes it to ``roi_head.bbox_head`` through ``set_decay_rate``.

    Args:
        decay_steps (list[int] | tuple[int]): Each item in the list is
            the step to decay the loss weight.
        decay_rate (float): Decay rate. Default: 0.5.
    """

    def __init__(self,
                 decay_steps: Sequence[int],
                 decay_rate: float = 0.5) -> None:
        assert isinstance(
            decay_steps,
            (list, tuple)), '`decay_steps` should be list or tuple.'
        self.decay_rate = decay_rate
        self.decay_steps = decay_steps

    def before_iter(self, runner: Runner) -> None:
        """Update the decay rate of the bbox head for the coming iteration.

        Args:
            runner (Runner): The runner driving training; its ``iter`` and
                ``model`` attributes are read.
        """
        # `runner.iter` is offset by one before it is compared to the steps.
        current_iter = runner.iter + 1
        # Multiply the rate once for every decay step already passed.
        accumulated_rate = 1.0
        for _ in (step for step in self.decay_steps if current_iter > step):
            accumulated_rate *= self.decay_rate
        # Unwrap parallel wrappers to reach the detector itself.
        detector = (
            runner.model.module
            if is_module_wrapper(runner.model) else runner.model)
        detector.roi_head.bbox_head.set_decay_rate(accumulated_rate)
39 |
--------------------------------------------------------------------------------
/mmfewshot/detection/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base import BaseFewShotDataset
3 | from .builder import build_dataloader, build_dataset
4 | from .coco import COCO_SPLIT, FewShotCocoDataset
5 | from .dataloader_wrappers import NWayKShotDataloader
6 | from .dataset_wrappers import NWayKShotDataset, QueryAwareDataset
7 | from .pipelines import CropResizeInstance, GenerateMask
8 | from .utils import NumpyEncoder, get_copy_dataset_type
9 | from .voc import VOC_SPLIT, FewShotVOCDataset
10 |
11 | __all__ = [
12 | 'build_dataloader', 'build_dataset', 'QueryAwareDataset',
13 | 'NWayKShotDataset', 'NWayKShotDataloader', 'BaseFewShotDataset',
14 | 'FewShotVOCDataset', 'FewShotCocoDataset', 'CropResizeInstance',
15 | 'GenerateMask', 'NumpyEncoder', 'COCO_SPLIT', 'VOC_SPLIT',
16 | 'get_copy_dataset_type'
17 | ]
18 |
--------------------------------------------------------------------------------
/mmfewshot/detection/datasets/dataloader_wrappers.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, Iterator
3 |
4 | from torch.utils.data import DataLoader
5 |
6 |
class NWayKShotDataloader:
    """A dataloader wrapper.

    It creates an iterator that generates query and support batches
    simultaneously. Each batch contains query data and support data, and
    the lengths are batch_size and (num_support_ways * num_support_shots)
    respectively.

    Args:
        query_data_loader (DataLoader): DataLoader of query dataset
        support_data_loader (DataLoader): DataLoader of support datasets.
    """

    def __init__(self, query_data_loader: DataLoader,
                 support_data_loader: DataLoader) -> None:
        self.dataset = query_data_loader.dataset
        self.sampler = query_data_loader.sampler
        self.query_data_loader = query_data_loader
        self.support_data_loader = support_data_loader

    def __iter__(self) -> Iterator:
        # if infinite sampler is used, this part of code only run once
        self.query_iter = iter(self.query_data_loader)
        self.support_iter = iter(self.support_data_loader)
        return self

    def __next__(self) -> Dict:
        """Return the next pair of query and support batches.

        Raises:
            StopIteration: When either underlying iterator is exhausted.
        """
        # Use the builtin `next()`: the iterator protocol only guarantees
        # `__next__`, and a `.next()` method is not available on Python 3
        # iterators in general.
        query_data = next(self.query_iter)
        support_data = next(self.support_iter)
        return {'query_data': query_data, 'support_data': support_data}

    def __len__(self) -> int:
        """Length of the wrapper, taken from the query dataloader."""
        return len(self.query_data_loader)
40 |
41 |
class TwoBranchDataloader:
    """A dataloader wrapper.

    It creates an iterator that walks two different dataloaders
    simultaneously. Note that `TwoBranchDataloader` does not support
    `EpochBasedRunner`, and the length of the dataloader is decided by the
    main dataset.

    Args:
        main_data_loader (DataLoader): DataLoader of main dataset.
        auxiliary_data_loader (DataLoader): DataLoader of auxiliary dataset.
    """

    def __init__(self, main_data_loader: DataLoader,
                 auxiliary_data_loader: DataLoader) -> None:
        self.dataset = main_data_loader.dataset
        self.main_data_loader = main_data_loader
        self.auxiliary_data_loader = auxiliary_data_loader

    def __iter__(self) -> Iterator:
        # if infinite sampler is used, this part of code only run once
        self.main_iter = iter(self.main_data_loader)
        self.auxiliary_iter = iter(self.auxiliary_data_loader)
        return self

    def _next_with_restart(self, iter_attr: str, loader: DataLoader):
        """Fetch the next batch, restarting the iterator once exhausted."""
        try:
            return next(getattr(self, iter_attr))
        except StopIteration:
            restarted = iter(loader)
            setattr(self, iter_attr, restarted)
            return next(restarted)

    def __next__(self) -> Dict:
        # The iterator actually has infinite length. Note that it can NOT
        # be used in `EpochBasedRunner`, because the `EpochBasedRunner` will
        # enumerate the dataloader forever.
        main_batch = self._next_with_restart('main_iter',
                                             self.main_data_loader)
        auxiliary_batch = self._next_with_restart('auxiliary_iter',
                                                  self.auxiliary_data_loader)
        return {'main_data': main_batch, 'auxiliary_data': auxiliary_batch}

    def __len__(self) -> int:
        """Length of the wrapper, taken from the main dataloader."""
        return len(self.main_data_loader)
84 |
--------------------------------------------------------------------------------
/mmfewshot/detection/datasets/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .formatting import MultiImageCollect, MultiImageFormatBundle
3 | from .transforms import (CropInstance, CropResizeInstance, GenerateMask,
4 | MultiImageNormalize, MultiImagePad,
5 | MultiImageRandomCrop, MultiImageRandomFlip,
6 | ResizeToMultiScale)
7 |
8 | __all__ = [
9 | 'CropResizeInstance', 'GenerateMask', 'CropInstance', 'ResizeToMultiScale',
10 | 'MultiImageNormalize', 'MultiImageFormatBundle', 'MultiImageCollect',
11 | 'MultiImagePad', 'MultiImageRandomCrop', 'MultiImageRandomFlip'
12 | ]
13 |
--------------------------------------------------------------------------------
/mmfewshot/detection/datasets/pipelines/formatting.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, List
3 |
4 | from mmcv.parallel import DataContainer as DC
5 | from mmdet.datasets.builder import PIPELINES
6 | from mmdet.datasets.pipelines import Collect, DefaultFormatBundle
7 |
8 |
@PIPELINES.register_module()
class MultiImageFormatBundle(DefaultFormatBundle):
    """Apply the parent format bundle to every result dict in a list."""

    def __call__(self, results_list: List[Dict]) -> List[Dict]:
        """Transform and format common fields of each results in
        `results_list`.

        The parent ``__call__`` is invoked once per item and its return
        value is discarded; the same list object is handed back.

        Args:
            results_list (list[dict]): List of result dict contains the data
                to convert.

        Returns:
            list[dict]: List of result dict contains the data that is formatted
                with default bundle.
        """
        format_single = super().__call__
        for single_results in results_list:
            format_single(single_results)
        return results_list
27 |
28 |
@PIPELINES.register_module()
class MultiImageCollect(Collect):
    """Collect keys from a list of result dicts into one flat dict."""

    def __call__(self, results_list: List[Dict]) -> Dict:
        """Collect all keys of each results in `results_list`.

        The keys in `meta_keys` will be converted to :obj:mmcv.DataContainer.
        A scale suffix is also added to each key to tell which entry of
        `results_list` the value came from.

        Args:
            results_list (list[dict]): List of result dict contains the data
                to collect.

        Returns:
            dict: The result dict contains the following keys, where ``i``
            is the index of the item in `results_list`

            - `{key}_scale_{i}` for every key in `self.keys`
            - `img_metas_scale{i}` (no underscore before the index)
        """
        collected = {}
        for scale_idx, single_results in enumerate(results_list):
            meta = {name: single_results[name] for name in self.meta_keys}
            collected[f'img_metas_scale{scale_idx}'] = DC(meta, cpu_only=True)
            collected.update({
                f'{key}_scale_{scale_idx}': single_results[key]
                for key in self.keys
            })
        return collected
56 |
--------------------------------------------------------------------------------
/mmfewshot/detection/datasets/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import json
3 |
4 | import numpy as np
5 |
6 |
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that can serialize numpy arrays and numpy scalars.

    Arrays are converted with ``tolist()``; numpy integer, floating and
    boolean scalars are converted to the matching Python builtin. Any other
    unsupported object is passed to the base class.
    """

    def default(self, obj: object) -> object:
        """Return a JSON-serializable replacement for ``obj``.

        Raises:
            TypeError: Raised by the base class when ``obj`` is neither a
                numpy array nor a supported numpy scalar.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Explicit builtin conversions guarantee a plain Python value.
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        return super().default(obj)
14 |
15 |
def get_copy_dataset_type(dataset_type: str) -> str:
    """Return corresponding copy dataset type.

    Args:
        dataset_type (str): Name of a VOC or COCO few shot dataset type.

    Returns:
        str: Name of the matching copy dataset type.

    Raises:
        TypeError: If ``dataset_type`` has no copy counterpart.
    """
    voc_types = ('FewShotVOCDataset', 'FewShotVOCDefaultDataset')
    coco_types = ('FewShotCocoDataset', 'FewShotCocoDefaultDataset')
    if dataset_type in voc_types:
        return 'FewShotVOCCopyDataset'
    if dataset_type in coco_types:
        return 'FewShotCocoCopyDataset'
    raise TypeError(f'{dataset_type} '
                    f'not support copy data_infos operation.')
27 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
3 | ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
4 | build_head, build_loss, build_neck,
5 | build_roi_extractor, build_shared_head)
6 |
7 | from .backbones import * # noqa: F401,F403
8 | from .builder import build_detector
9 | from .dense_heads import * # noqa: F401,F403
10 | from .detectors import * # noqa: F401,F403
11 | from .losses import * # noqa: F401,F403
12 | from .roi_heads import * # noqa: F401,F403
13 | from .utils import * # noqa: F401,F403
14 |
15 | __all__ = [
16 | 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
17 | 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
18 | 'build_shared_head', 'build_head', 'build_loss', 'build_detector'
19 | ]
20 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .resnet_with_meta_conv import ResNetWithMetaConv
3 |
4 | __all__ = ['ResNetWithMetaConv']
5 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/backbones/resnet_with_meta_conv.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Tuple
3 |
4 | from mmcv.cnn import build_conv_layer
5 | from mmdet.models import ResNet
6 | from mmdet.models.builder import BACKBONES
7 | from torch import Tensor
8 |
9 |
@BACKBONES.register_module()
class ResNetWithMetaConv(ResNet):
    """ResNet with `meta_conv` to handle different inputs in metarcnn and
    fsdetview.

    When input with shape (N, 3, H, W) from images, the network will use
    `conv1` as regular ResNet. When input with shape (N, 4, H, W) from (image +
    mask) the network will replace `conv1` with `meta_conv` to handle
    additional channel.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # Alternative first convolution: 4 input channels, 64 output
        # channels, built from the same conv config as the ResNet.
        meta_conv_kwargs = dict(
            kernel_size=7, stride=2, padding=3, bias=False)
        self.meta_conv = build_conv_layer(self.conv_cfg, 4, 64,
                                          **meta_conv_kwargs)

    def forward(self, x: Tensor, use_meta_conv: bool = False) -> Tuple[Tensor]:
        """Forward function.

        Only the first convolution differs between the two modes; the
        norm, activation, pooling and residual stages are shared.

        Args:
            x (Tensor): Tensor with shape (N, 3, H, W) from images
                or (N, 4, H, W) from (images + masks).
            use_meta_conv (bool): If set True, forward input tensor with
                `meta_conv` which require tensor with shape (N, 4, H, W).
                Otherwise, forward input tensor with `conv1` which require
                tensor with shape (N, 3, H, W). Default: False.

        Returns:
            tuple[Tensor]: Tuple of features, each item with
                shape (N, C, H, W).
        """
        first_conv = self.meta_conv if use_meta_conv else self.conv1
        feat = self.maxpool(self.relu(self.norm1(first_conv(x))))
        collected = []
        for stage_idx, layer_name in enumerate(self.res_layers):
            feat = getattr(self, layer_name)(feat)
            # Keep only the stages requested through `out_indices`.
            if stage_idx in self.out_indices:
                collected.append(feat)
        return tuple(collected)
66 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/builder.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Optional
3 |
4 | from mmcv.utils import ConfigDict, print_log
5 | from mmdet.models.builder import DETECTORS
6 |
7 |
def build_detector(cfg: ConfigDict, logger: Optional[object] = None):
    """Build a detector and optionally freeze part of its parameters.

    Args:
        cfg (ConfigDict): Detector config. The optional key
            ``frozen_parameters`` is popped from it before building.
        logger (object | None): Logger passed to ``print_log``.
            Default: None.

    Returns:
        The built detector with ``init_weights()`` already called.
    """
    # get the prefix of fixed parameters
    frozen_prefixes = cfg.pop('frozen_parameters', None)

    detector = DETECTORS.build(cfg)
    detector.init_weights()
    if frozen_prefixes is None:
        return detector

    # freeze parameters by prefix; note the match is a substring test on
    # the parameter name, not a strict prefix check
    print_log(f'Frozen parameters: {frozen_prefixes}', logger)
    for param_name, param in detector.named_parameters():
        if any(prefix in param_name for prefix in frozen_prefixes):
            param.requires_grad = False
        if param.requires_grad:
            print_log(f'Training parameters: {param_name}', logger)
    return detector
25 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/dense_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .attention_rpn_head import AttentionRPNHead
3 | from .two_branch_rpn_head import TwoBranchRPNHead
4 |
5 | __all__ = ['AttentionRPNHead', 'TwoBranchRPNHead']
6 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/detectors/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .attention_rpn_detector import AttentionRPNDetector
3 | from .fsce import FSCE
4 | from .fsdetview import FSDetView
5 | from .meta_rcnn import MetaRCNN
6 | from .mpsr import MPSR
7 | from .query_support_detector import QuerySupportDetector
8 | from .tfa import TFA
9 |
10 | __all__ = [
11 | 'QuerySupportDetector', 'AttentionRPNDetector', 'FSCE', 'FSDetView', 'TFA',
12 | 'MPSR', 'MetaRCNN'
13 | ]
14 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/detectors/fsce.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmdet.models.builder import DETECTORS
3 | from mmdet.models.detectors.two_stage import TwoStageDetector
4 |
5 |
@DETECTORS.register_module()
class FSCE(TwoStageDetector):
    """Implementation of FSCE.

    The class adds no behavior of its own: everything is inherited from
    :class:`TwoStageDetector`, and the decorator registers the class in
    ``DETECTORS``.

    NOTE(review): the paper link that followed the name in this docstring
    appears to have been lost; restore it from the upstream project.
    """
9 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/detectors/fsdetview.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmdet.models.builder import DETECTORS
3 |
4 | from .meta_rcnn import MetaRCNN
5 |
6 |
@DETECTORS.register_module()
class FSDetView(MetaRCNN):
    """Implementation of FSDetView.

    The class adds no behavior of its own: everything is inherited from
    :class:`MetaRCNN`, and the decorator registers the class in
    ``DETECTORS``.

    NOTE(review): the paper link that followed the name in this docstring
    appears to have been lost; restore it from the upstream project.
    """
10 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/detectors/tfa.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmdet.models.builder import DETECTORS
3 | from mmdet.models.detectors.two_stage import TwoStageDetector
4 |
5 |
@DETECTORS.register_module()
class TFA(TwoStageDetector):
    """Implementation of TFA.

    The class adds no behavior of its own: everything is inherited from
    :class:`TwoStageDetector`, and the decorator registers the class in
    ``DETECTORS``.

    NOTE(review): the paper link that followed the name in this docstring
    appears to have been lost; restore it from the upstream project.
    """
9 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .supervised_contrastive_loss import SupervisedContrastiveLoss
3 |
4 | __all__ = ['SupervisedContrastiveLoss']
5 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/roi_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .bbox_heads import (ContrastiveBBoxHead, CosineSimBBoxHead,
3 | MultiRelationBBoxHead)
4 | from .contrastive_roi_head import ContrastiveRoIHead
5 | from .fsdetview_roi_head import FSDetViewRoIHead
6 | from .meta_rcnn_roi_head import MetaRCNNRoIHead
7 | from .multi_relation_roi_head import MultiRelationRoIHead
8 | from .shared_heads import MetaRCNNResLayer
9 | from .two_branch_roi_head import TwoBranchRoIHead
10 |
11 | __all__ = [
12 | 'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
13 | 'ContrastiveRoIHead', 'MultiRelationRoIHead', 'FSDetViewRoIHead',
14 | 'MetaRCNNRoIHead', 'MetaRCNNResLayer', 'TwoBranchRoIHead'
15 | ]
16 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/roi_heads/bbox_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .contrastive_bbox_head import ContrastiveBBoxHead
3 | from .cosine_sim_bbox_head import CosineSimBBoxHead
4 | from .meta_bbox_head import MetaBBoxHead
5 | from .multi_relation_bbox_head import MultiRelationBBoxHead
6 | from .two_branch_bbox_head import TwoBranchBBoxHead
7 |
8 | __all__ = [
9 | 'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
10 | 'MetaBBoxHead', 'TwoBranchBBoxHead'
11 | ]
12 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/roi_heads/bbox_heads/meta_bbox_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import copy
3 | from typing import Dict, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | from mmcv.runner import force_fp32
8 | from mmdet.models.builder import HEADS, build_loss
9 | from mmdet.models.losses import accuracy
10 | from mmdet.models.roi_heads import BBoxHead
11 | from torch import Tensor
12 |
13 |
@HEADS.register_module()
class MetaBBoxHead(BBoxHead):
    """BBoxHead with meta classification for metarcnn and fsdetview.

    On top of the parent head this class owns a linear layer ``fc_meta``
    and a loss module ``loss_meta_cls`` used for meta classification.

    Args:
        num_meta_classes (int): Number of classes for meta classification.
        meta_cls_in_channels (int): Number of support feature channels.
            Default: 2048.
        with_meta_cls_loss (bool): Use meta classification loss.
            Default: True.
        meta_cls_loss_weight (float | None): The loss weight of `loss_meta`.
            When None, the weight is computed per call from the number of
            positive label weights. Default: None.
        loss_meta (dict): Config for meta classification loss.
    """

    def __init__(self,
                 num_meta_classes: int,
                 meta_cls_in_channels: int = 2048,
                 with_meta_cls_loss: bool = True,
                 meta_cls_loss_weight: Optional[float] = None,
                 loss_meta: Dict = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_meta_cls_loss = with_meta_cls_loss
        if with_meta_cls_loss:
            self.fc_meta = nn.Linear(meta_cls_in_channels, num_meta_classes)
            self.meta_cls_loss_weight = meta_cls_loss_weight
            # Deep copy so the shared default dict is never handed to
            # `build_loss` directly.
            self.loss_meta_cls = build_loss(copy.deepcopy(loss_meta))

    def forward_meta_cls(self, support_feat: Tensor) -> Tensor:
        """Forward function for meta classification.

        Args:
            support_feat (Tensor): Shape of (N, C, H, W).

        Returns:
            Tensor: Box scores with shape of (N, num_meta_classes, H, W).

        NOTE(review): ``fc_meta`` is an ``nn.Linear``, which acts on the
        last dimension of its input, so the documented 4-D shapes look
        inconsistent -- presumably the input is (N, C) and the output
        (N, num_meta_classes); confirm against the callers.
        """
        meta_cls_score = self.fc_meta(support_feat)
        return meta_cls_score

    @force_fp32(apply_to='meta_cls_score')
    def loss_meta(self,
                  meta_cls_score: Tensor,
                  meta_cls_labels: Tensor,
                  meta_cls_label_weights: Tensor,
                  reduction_override: Optional[str] = None) -> Dict:
        """Meta classification loss.

        Args:
            meta_cls_score (Tensor): Predicted meta classification scores
                 with shape (N, num_meta_classes).
            meta_cls_labels (Tensor): Corresponding class indices with
                shape (N).
            meta_cls_label_weights (Tensor): Meta classification loss weight
                of each sample with shape (N).
            reduction_override (str | None): The reduction method used to
                override the original reduction method of the loss. Options
                are "none", "mean" and "sum". Default: None.

        Returns:
            Dict: The calculated loss. Contains ``loss_meta_cls`` and
            ``meta_acc`` when ``meta_cls_score`` is not empty, otherwise
            the dict is empty.
        """
        losses = dict()
        if self.meta_cls_loss_weight is None:
            # Default weight: reciprocal of the number of samples with a
            # positive label weight, clamped so it never divides by zero.
            loss_weight = 1. / max(
                torch.sum(meta_cls_label_weights > 0).float().item(), 1.)
        else:
            loss_weight = self.meta_cls_loss_weight
        if meta_cls_score.numel() > 0:
            loss_meta_cls_ = self.loss_meta_cls(
                meta_cls_score,
                meta_cls_labels,
                meta_cls_label_weights,
                reduction_override=reduction_override)
            losses['loss_meta_cls'] = loss_meta_cls_ * loss_weight
            losses['meta_acc'] = accuracy(meta_cls_score, meta_cls_labels)
        return losses
95 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/roi_heads/fsdetview_roi_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, Optional
3 |
4 | import torch
5 | from mmdet.models.builder import HEADS
6 | from torch import Tensor
7 |
8 | from .meta_rcnn_roi_head import MetaRCNNRoIHead
9 |
10 |
@HEADS.register_module()
class FSDetViewRoIHead(MetaRCNNRoIHead):
    """Roi head for FSDetView.

    Args:
        aggregation_layer (dict): Config of `aggregation_layer`.
            Default: None.
    """

    def __init__(self,
                 aggregation_layer: Optional[Dict] = None,
                 **kwargs) -> None:
        super().__init__(aggregation_layer=aggregation_layer, **kwargs)

    def _bbox_forward(self, query_roi_feats: Tensor,
                      support_roi_feats: Tensor) -> Dict:
        """Box head forward function used in both training and testing.

        Args:
            query_roi_feats (Tensor): Roi features with shape (N, C).
            support_roi_feats (Tensor): Roi features with shape (1, C).

        Returns:
            dict: A dictionary of predicted results with keys
            ``cls_score`` and ``bbox_pred``.
        """
        # feature aggregation: both inputs are given two trailing
        # singleton dimensions before entering the aggregation layer
        query_4d = query_roi_feats.unsqueeze(-1).unsqueeze(-1)
        support_4d = support_roi_feats.view(1, -1, 1, 1)
        aggregated = self.aggregation_layer(
            query_feat=query_4d, support_feat=support_4d)
        merged_feats = torch.cat(aggregated, dim=1)
        # the raw query features are appended along the channel dimension
        merged_feats = torch.cat((merged_feats, query_roi_feats), dim=1)
        cls_score, bbox_pred = self.bbox_head(merged_feats)
        return dict(cls_score=cls_score, bbox_pred=bbox_pred)
45 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/roi_heads/shared_heads/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .meta_rcnn_res_layer import MetaRCNNResLayer
3 |
4 | __all__ = ['MetaRCNNResLayer']
5 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/roi_heads/shared_heads/meta_rcnn_res_layer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch.nn as nn
3 | from mmdet.models.builder import SHARED_HEADS
4 | from mmdet.models.roi_heads import ResLayer
5 | from torch import Tensor
6 |
7 |
@SHARED_HEADS.register_module()
class MetaRCNNResLayer(ResLayer):
    """Shared resLayer for metarcnn and fsdetview.

    Query and support images go through the same residual stage but are
    pre- and post-processed differently, see :meth:`forward` and
    :meth:`forward_support`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Both modules are only used on the support path.
        self.max_pool = nn.MaxPool2d(2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for query images.

        Args:
            x (Tensor): Features from backbone with shape (N, C, H, W).

        Returns:
            Tensor: Shape of (N, C).
        """
        stage_layer = getattr(self, f'layer{self.stage + 1}')
        feats = stage_layer(x)
        # Average over the last spatial dim first, then the remaining one.
        return feats.mean(3).mean(2)

    def forward_support(self, x: Tensor) -> Tensor:
        """Forward function for support images.

        Compared with :meth:`forward`, the input is max-pooled before the
        residual stage and a sigmoid is applied before spatial averaging.

        Args:
            x (Tensor): Features from backbone with shape (N, C, H, W).

        Returns:
            Tensor: Shape of (N, C).
        """
        pooled = self.max_pool(x)
        stage_layer = getattr(self, f'layer{self.stage + 1}')
        gated = self.sigmoid(stage_layer(pooled))
        return gated.mean(3).mean(2)
49 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/roi_heads/two_branch_roi_head.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from typing import Dict, List, Tuple
3 |
4 | import torch
5 | from mmdet.models.builder import HEADS
6 | from mmdet.models.roi_heads import StandardRoIHead
7 | from torch import Tensor
8 |
9 |
@HEADS.register_module()
class TwoBranchRoIHead(StandardRoIHead):
    """RoI head for MPSR."""

    def forward_auxiliary_train(self, feats: Tuple[Tensor],
                                gt_labels: List[Tensor]) -> Dict:
        """Forward function and calculate loss for auxiliary data in training.

        Args:
            feats (tuple[Tensor]): List of features at multiple scales, each
                is a 4D-tensor.
            gt_labels (list[Tensor]): List of class indices corresponding
                to each features.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # Only the box head contributes to the auxiliary loss.
        return self._bbox_forward_auxiliary_train(feats, gt_labels)

    def _bbox_forward_auxiliary_train(self, feats: Tuple[Tensor],
                                      gt_labels: List[Tensor]) -> Dict:
        """Run forward function and calculate loss for box head in training.

        Args:
            feats (tuple[Tensor]): List of features at multiple scales, each
                is a 4D-tensor.
            gt_labels (list[Tensor]): List of class indices corresponding
                to each features.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # `forward_auxiliary` is expected to return a one-element sequence;
        # the unpacking fails loudly otherwise.
        (score_chunks,) = self.bbox_head.forward_auxiliary(feats)
        cls_score = torch.cat(score_chunks, dim=0)
        labels = torch.cat(gt_labels, dim=0)
        # Every sample is weighted equally.
        label_weights = torch.ones_like(labels)
        return self.bbox_head.auxiliary_loss(cls_score, labels,
                                             label_weights)
52 |
--------------------------------------------------------------------------------
/mmfewshot/detection/models/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .aggregation_layer import * # noqa: F401,F403
3 |
--------------------------------------------------------------------------------
/mmfewshot/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .collate import multi_pipeline_collate_fn
3 | from .compat_config import compat_cfg
4 | from .dist_utils import check_dist_init, sync_random_seed
5 | from .infinite_sampler import (DistributedInfiniteGroupSampler,
6 | DistributedInfiniteSampler,
7 | InfiniteGroupSampler, InfiniteSampler)
8 | from .local_seed import local_numpy_seed
9 | from .logger import get_root_logger
10 | from .runner import InfiniteEpochBasedRunner
11 |
12 | __all__ = [
13 | 'multi_pipeline_collate_fn', 'local_numpy_seed',
14 | 'InfiniteEpochBasedRunner', 'InfiniteSampler', 'InfiniteGroupSampler',
15 | 'DistributedInfiniteSampler', 'DistributedInfiniteGroupSampler',
16 | 'get_root_logger', 'check_dist_init', 'sync_random_seed', 'compat_cfg'
17 | ]
18 |
--------------------------------------------------------------------------------
/mmfewshot/utils/collect_env.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcv.utils import collect_env as collect_basic_env
3 | from mmcv.utils import get_git_hash
4 |
5 | import mmfewshot
6 |
7 |
def collect_env():
    """Collect environment information and add the MMFewShot version.

    Returns:
        dict: The mapping produced by mmcv's ``collect_env`` plus an
            ``'MMFewShot'`` entry of the form ``<version>+<git hash>``.
    """
    env_info = collect_basic_env()
    version_tag = f'{mmfewshot.__version__}+{get_git_hash(digits=7)}'
    env_info['MMFewShot'] = version_tag
    return env_info
13 |
14 |
if __name__ == '__main__':
    # Print one "name: value" line per collected environment entry.
    for key, value in collect_env().items():
        print(f'{key}: {value}')
18 |
--------------------------------------------------------------------------------
/mmfewshot/utils/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 | import torch.distributed as dist
5 | from mmcv.runner import get_dist_info
6 |
7 |
def check_dist_init():
    """Tell whether ``torch.distributed`` is available and initialized."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
10 |
11 |
def sync_random_seed(seed=None, device='cuda'):
    """Propagate the seed of rank 0 to all other ranks.

    All workers must call this function, otherwise the broadcast will
    deadlock. It is typically used by distributed samplers: every rank has
    to shuffle the data indices in the same order (same seed) so that the
    ranks can then pick non-overlapping slices of the same index list.

    Args:
        seed (int, Optional): The seed. When None, a random seed below
            ``2**31`` is drawn from ``np.random``. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    # Nothing to synchronize when running in a single process.
    if world_size == 1:
        return seed

    # Rank 0 sends its seed; the other ranks provide a placeholder that the
    # broadcast overwrites.
    payload = seed if rank == 0 else 0
    seed_tensor = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(seed_tensor, src=0)
    return seed_tensor.item()
46 |
--------------------------------------------------------------------------------
/mmfewshot/utils/local_seed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from contextlib import contextmanager
3 | from typing import Optional
4 |
5 | import numpy as np
6 |
7 |
@contextmanager
def local_numpy_seed(seed: Optional[int] = None) -> None:
    """Run numpy codes with a local random seed.

    The global ``np.random`` state is captured on entry and put back on
    exit, also when the body raises. If seed is None the body runs on the
    current global state (which is still restored afterwards).
    """
    saved_state = np.random.get_state()
    try:
        if seed is not None:
            np.random.seed(seed)
        yield
    finally:
        np.random.set_state(saved_state)
21 |
--------------------------------------------------------------------------------
/mmfewshot/utils/logger.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import logging
3 |
4 | from mmcv.utils import get_logger
5 |
6 |
def get_root_logger(log_file=None, log_level=logging.INFO):
    """Return the logger named ``'mmfewshot'`` obtained from mmcv.

    Args:
        log_file: Forwarded to mmcv's ``get_logger``. Default: None.
        log_level: Forwarded to mmcv's ``get_logger``.
            Default: ``logging.INFO``.
    """
    logger = get_logger('mmfewshot', log_file, log_level)
    return logger
9 |
--------------------------------------------------------------------------------
/mmfewshot/utils/runner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import time
3 |
4 | from mmcv.runner import EpochBasedRunner
5 | from mmcv.runner.builder import RUNNERS
6 | from torch.utils.data import DataLoader
7 |
8 |
@RUNNERS.register_module()
class InfiniteEpochBasedRunner(EpochBasedRunner):
    """Epoch-based Runner supports dataloader with InfiniteSampler.

    The workers of dataloader will re-initialize, when the iterator of
    dataloader is created. InfiniteSampler is designed to avoid these time
    consuming operations, since the iterator with InfiniteSampler will never
    reach the end.
    """

    def train(self, data_loader: DataLoader, **kwargs) -> None:
        """Train for one epoch, reusing a single dataloader iterator.

        Args:
            data_loader (DataLoader): Loader to draw training batches from.
            **kwargs: Forwarded to ``run_iter``.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        # The iterator is created only once and kept on the runner, so the
        # following epochs keep pulling from the same iterator again.
        if not hasattr(self, 'data_loader_iter'):
            self.data_loader_iter = iter(self.data_loader)

        # The InfiniteSampler never reaches its end, but it reports the
        # actual dataset length. Since the dataloader length follows the
        # sampler length, iterating `len(self.data_loader)` times walks
        # through the whole dataset once per epoch.
        for step in range(len(self.data_loader)):
            data_batch = next(self.data_loader_iter)
            self._inner_iter = step
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1
48 |
--------------------------------------------------------------------------------
/mmfewshot/version.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Open-MMLab. All rights reserved.
2 | __version__ = '0.1.0_af4fad5'
3 | short_version = __version__
4 |
5 |
def parse_version_info(version_str):
    """Turn a dotted version string into a tuple.

    Purely numeric components become ints. A component containing ``rc``
    yields two entries: the int before ``rc`` and the string ``'rc<suffix>'``.
    Any other component is silently dropped.

    Args:
        version_str (str): Version string such as ``'1.2.0rc1'``.

    Returns:
        tuple: The parsed version components.
    """
    parsed = []
    for component in version_str.split('.'):
        if component.isdigit():
            parsed.append(int(component))
            continue
        if 'rc' in component:
            pieces = component.split('rc')
            parsed.extend((int(pieces[0]), f'rc{pieces[1]}'))
    return tuple(parsed)
16 |
17 |
18 | version_info = parse_version_info(__version__)
19 |
--------------------------------------------------------------------------------
/mmfscil/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
2 | from .models import *
3 | from .augments import *
4 |
--------------------------------------------------------------------------------
/mmfscil/apis/__init__.py:
--------------------------------------------------------------------------------
1 | from .train import train_model
2 |
--------------------------------------------------------------------------------
/mmfscil/augments/__init__.py:
--------------------------------------------------------------------------------
1 | from .mixup import BatchMixupLayer
2 | from .idty import Identity
3 | from .cutmix import BatchCutMixLayer
4 |
--------------------------------------------------------------------------------
/mmfscil/augments/cutmix.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 |
5 | from mmcls.models.utils.augment.builder import AUGMENT
6 | from mmcls.models.utils.augment.cutmix import BaseCutMixLayer
7 |
8 |
@AUGMENT.register_module(name='BatchCutMixTwoLabel')
class BatchCutMixLayer(BaseCutMixLayer):
    r"""CutMix layer for a batch of data.

    CutMix is a method to improve the network's generalization capability.
    It is proposed in "CutMix: Regularization Strategy to Train Strong
    Classifiers with Localizable Features".

    With this method, patches are cut and pasted among training images.
    Unlike a layer that mixes the labels itself, this variant returns both
    label sets together with the mixing ratio (see :meth:`cutmix`).

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            can be found in :class:`BatchMixupLayer`.
        num_classes (int): The number of classes
        prob (float): The probability to execute cutmix. It should be in
            range [0, 1]. Defaults to 1.0.
        cutmix_minmax (List[float], optional): The min/max area ratio of the
            patches. If not None, the bounding-box of patches is uniform
            sampled within this ratio range, and the ``alpha`` will be ignored.
            Otherwise, the bounding-box is generated according to the
            ``alpha``. Defaults to None.
        correct_lam (bool): Whether to apply lambda correction when cutmix bbox
            clipped by image borders. Defaults to True.

    Note:
        If the ``cutmix_minmax`` is None, how to generate the bounding-box of
        patches according to the ``alpha``?

        First, generate a :math:`\lambda`, details can be found in
        :class:`BatchMixupLayer`. And then, the area ratio of the bounding-box
        is calculated by:

        .. math::
            \text{ratio} = \sqrt{1-\lambda}
    """

    def __init__(self, *args, **kwargs):
        # All arguments are handled by the base class.
        super().__init__(*args, **kwargs)

    def cutmix(self, img, gt_label):
        """Paste a patch from a shuffled copy of the batch into ``img``.

        ``img`` is modified in place.

        Returns:
            tuple: ``(img, lam, gt_label, gt_label[perm])`` where ``perm`` is
                the random batch permutation the patches were taken from.
        """
        lam = np.random.beta(self.alpha, self.alpha)
        perm = torch.randperm(img.size(0))

        # The helper returns the patch box and the (possibly adjusted) ratio.
        box, lam = self.cutmix_bbox_and_lam(img.shape, lam)
        top, bottom, left, right = box
        patch = img[perm, :, top:bottom, left:right]
        img[:, :, top:bottom, left:right] = patch
        return img, lam, gt_label, gt_label[perm]

    def __call__(self, img, gt_label):
        return self.cutmix(img, gt_label)
62 |
--------------------------------------------------------------------------------
/mmfscil/augments/idty.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from mmcls.models.utils.augment.builder import AUGMENT
3 |
4 |
@AUGMENT.register_module(name='IdentityTwoLabel')
class Identity:
    """Pass-through augment that leaves both image and label untouched.

    Calling it returns ``(img, None, gt_label, None)``; ``num_classes`` and
    ``prob`` are validated and stored but not used by ``__call__``.

    Args:
        num_classes (int): The number of classes.
        prob (float): MixUp probability. It should be in range [0, 1].
            Default to 1.0
    """

    def __init__(self, num_classes, prob=1.0):
        super().__init__()

        assert isinstance(num_classes, int)
        assert isinstance(prob, float) and 0.0 <= prob <= 1.0

        self.num_classes, self.prob = num_classes, prob

    def __call__(self, img, gt_label):
        # No mixing happens: the ratio and the second label slot are None.
        lam, second_label = None, None
        return img, lam, gt_label, second_label
26 |
--------------------------------------------------------------------------------
/mmfscil/augments/mixup.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 | import torch
4 |
5 | from mmcls.models.utils.augment.builder import AUGMENT
6 | from mmcls.models.utils.augment.mixup import BaseMixupLayer
7 |
8 |
@AUGMENT.register_module(name='BatchMixupTwoLabel')
class BatchMixupLayer(BaseMixupLayer):
    r"""Mixup layer for a batch of data.

    Mixup is a method to reduce the memorization of corrupt labels and
    increase the robustness to adversarial examples. It is proposed in
    "mixup: Beyond Empirical Risk Minimization".

    The images of a batch are linearly mixed with a shuffled copy of the
    same batch. The labels are not mixed here: both label sets are returned
    together with the mixing ratio (see :meth:`mixup`).

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            are in the note.
        num_classes (int): The number of classes.
        prob (float): The probability to execute mixup. It should be in
            range [0, 1]. Defaults to 1.0.

    Note:
        The :math:`\alpha` (``alpha``) determines a random distribution
        :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
        a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
        distribution.
    """

    def __init__(self, *args, **kwargs):
        # All arguments are handled by the base class.
        super().__init__(*args, **kwargs)

    def mixup(self, img, gt_label):
        """Blend each image with another image of the same batch.

        Returns:
            tuple: ``(mixed_img, lam, gt_label, gt_label[perm])`` where
                ``perm`` is the random batch permutation used for mixing.
        """
        lam = np.random.beta(self.alpha, self.alpha)
        perm = torch.randperm(img.size(0))

        shuffled = img[perm, :]
        mixed_img = lam * img + (1 - lam) * shuffled
        return mixed_img, lam, gt_label, gt_label[perm]

    def __call__(self, img, gt_label):
        return self.mixup(img, gt_label)
48 |
--------------------------------------------------------------------------------
/mmfscil/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .mini_imagenet import MiniImageNetFSCILDataset
2 | from .cub import CUBFSCILDataset
3 | from .cifar100 import CIFAR100FSCILDataset
4 | from .memory import MemoryDataset
5 |
--------------------------------------------------------------------------------
/mmfscil/datasets/memory.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import torch
4 | from torch.utils.data import Dataset
5 |
6 |
class MemoryDataset(Dataset):
    """Dataset serving precomputed features and labels held in memory.

    Args:
        feats (torch.Tensor): Features, indexed along the first dimension.
        labels (torch.Tensor): Labels; must have the same length as
            ``feats``.
    """

    def __init__(self, feats: torch.Tensor, labels: torch.Tensor):
        self.feats, self.labels = feats, labels
        sizes_match = len(self.feats) == len(self.labels)
        assert sizes_match, "The features and labels are with different sizes."

    def __len__(self) -> int:
        """Return length of the dataset."""
        return len(self.feats)

    def __getitem__(self, idx: int) -> Dict:
        """Return the feature/label pair stored at ``idx``."""
        sample = {"feat": self.feats[idx], "gt_label": self.labels[idx]}
        return sample
26 |
--------------------------------------------------------------------------------
/mmfscil/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .classifier import ImageClassifierCIL
2 | from .ETFHead import ETFHead, DRLoss
3 | from .mlp_ffn_neck import MLPFFNNeck
4 | from .resnet18 import ResNet18
5 |
--------------------------------------------------------------------------------
/mmfscil/models/mlp_ffn_neck.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from collections import OrderedDict
3 |
4 | import torch.nn as nn
5 | from mmcv.cnn import build_norm_layer
6 |
7 | from mmcls.models.builder import NECKS
8 |
9 |
@NECKS.register_module()
class MLPFFNNeck(nn.Module):
    """Neck made of a three-stage MLP plus a linear skip projection.

    The input feature map is globally average-pooled and flattened. The
    pooled vector goes through ``ln1`` -> ``ln2`` -> ``ln3`` and, in
    parallel, through the bias-free projection ``ffn``; both results are
    summed.

    Args:
        in_channels (int): Channels of the incoming feature map.
            Default: 512.
        out_channels (int): Channels of the returned feature vector.
            Default: 512.
    """

    def __init__(self, in_channels=512, out_channels=512):
        super().__init__()
        hidden_channels = in_channels * 2
        self.avg = nn.AdaptiveAvgPool2d((1, 1))
        self.ln1 = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(in_channels, hidden_channels)),
            ('ln', build_norm_layer(dict(type='LN'), hidden_channels)[1]),
            ('relu', nn.LeakyReLU(0.1))
        ]))
        self.ln2 = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(hidden_channels, hidden_channels)),
            ('ln', build_norm_layer(dict(type='LN'), hidden_channels)[1]),
            ('relu', nn.LeakyReLU(0.1))
        ]))
        self.ln3 = nn.Sequential(OrderedDict([
            ('linear', nn.Linear(hidden_channels, out_channels, bias=False)),
        ]))
        # The skip path is a learned projection even when in_channels equals
        # out_channels; the module name `ffn.proj` is part of the state dict
        # and therefore kept as is.
        self.ffn = nn.Sequential(OrderedDict([
            ('proj', nn.Linear(in_channels, out_channels, bias=False)),
        ]))

    def init_weights(self):
        """No custom initialization; the layers keep their defaults."""
        pass

    def forward(self, inputs):
        """Pool, flatten and project the (last) input feature map.

        Args:
            inputs (Tensor | tuple[Tensor]): Feature map, or a tuple of
                feature maps of which only the last one is used.

        Returns:
            Tensor: ``ln3(ln2(ln1(x))) + ffn(x)`` with ``x`` the pooled and
                flattened input.
        """
        if isinstance(inputs, tuple):
            inputs = inputs[-1]
        pooled = self.avg(inputs).view(inputs.size(0), -1)
        mlp_out = self.ln3(self.ln2(self.ln1(pooled)))
        return mlp_out + self.ffn(pooled)
52 |
--------------------------------------------------------------------------------
/mmfscil/models/resnet18.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 | from mmcls.models import BACKBONES
5 |
6 |
class BasicBlock(nn.Module):
    """Residual block with two 3x3 convolutions.

    Args:
        in_planes (int): Number of input channels.
        planes (int): Number of channels of both convolutions; the block
            outputs ``expansion * planes`` channels.
        stride (int): Stride of the first convolution. Default: 1.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1,
            bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # A 1x1 conv + BN projection is only needed when the residual branch
        # changes resolution or channel count; otherwise the shortcut is an
        # empty (identity) Sequential.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        return F.relu(out)
30 |
31 |
class Bottleneck(nn.Module):
    """Residual bottleneck block (1x1 -> 3x3 -> 1x1 convolutions).

    Args:
        in_planes (int): Number of input channels.
        planes (int): Number of channels of the first two convolutions; the
            block outputs ``expansion * planes`` channels.
        stride (int): Stride of the 3x3 convolution. Default: 1.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=stride, padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # A 1x1 conv + BN projection is only needed when the residual branch
        # changes resolution or channel count; otherwise the shortcut is an
        # empty (identity) Sequential.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes, out_planes, kernel_size=1, stride=stride,
                    bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(residual(x) + shortcut(x))``."""
        out = self.conv1(x)
        out = F.relu(self.bn1(out))
        out = self.conv2(out)
        out = F.relu(self.bn2(out))
        out = self.conv3(out)
        out = self.bn3(out)
        out += self.shortcut(x)
        return F.relu(out)
58 |
59 |
@BACKBONES.register_module()
class ResNet18(nn.Module):
    """ResNet feature extractor with a 3x3 stride-1 stem and four stages.

    No pooling or classifier is applied; ``forward`` returns the feature map
    of the last stage.

    Args:
        block (type): Residual block class; it must expose an ``expansion``
            attribute. Default: ``BasicBlock``.
        num_blocks (Sequence[int]): Number of blocks in each of the four
            stages. Default: (2, 2, 2, 2).
        low_dim (int): Accepted but not used by this class. Default: 512.
    """

    def __init__(self, block=BasicBlock, num_blocks=(2, 2, 2, 2), low_dim=512):
        super().__init__()
        # Running channel count, advanced by `_make_layer`.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (planes, stride) of layer1..layer4; only the first stage keeps the
        # input resolution.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(
                block, planes, num_blocks[idx], stride=stride)
            setattr(self, f'layer{idx + 1}', stage)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        blocks = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            blocks.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return the feature map produced by the last stage."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        return out
91 |
--------------------------------------------------------------------------------
/tools/dist_train.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch tools/train.py with one process per GPU.
#
# Usage: dist_train.sh CONFIG GPUS [extra train.py arguments...]
# Optional environment: NNODES, NODE_RANK, PORT, MASTER_ADDR.

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
# Default port is picked at random from 29500..29528.
PORT=${PORT:-$((29500 + $RANDOM % 29))}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Directory holding this script; quoted so paths with spaces keep working.
SCRIPT_DIR="$(dirname "$0")"

if command -v torchrun &> /dev/null
then
  echo "Using torchrun mode."
  PYTHONPATH="$SCRIPT_DIR/..:$PYTHONPATH" OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
    torchrun \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --master_port="$PORT" \
    --nproc_per_node="$GPUS" \
    "$SCRIPT_DIR/train.py" "$CONFIG" --launcher pytorch "${@:3}"
else
  echo "Using launch mode."
  PYTHONPATH="$SCRIPT_DIR/..:$PYTHONPATH" OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
    python -m torch.distributed.launch \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --master_port="$PORT" \
    --nproc_per_node="$GPUS" \
    "$SCRIPT_DIR/train.py" "$CONFIG" --launcher pytorch "${@:3}"
fi
33 |
--------------------------------------------------------------------------------
/tools/docker.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Start an interactive container with the project, data and log directories
# bind-mounted.
#
# Optional environment: DATALOC, LOGLOC, IMG, GPUS ("all" or a device list).

DATALOC=${DATALOC:-$(realpath ../datasets)}
LOGLOC=${LOGLOC:-$(realpath ../logger)}
IMG=${IMG:-"harbory/openmmlab:2206"}
GPUS=${GPUS:-"all"}

# A device list is wrapped in literal double quotes ("device=...") before it
# is handed to --gpus.
if [ "${GPUS}" != "all" ]; then
    GPUS='"device='${GPUS}'"'
fi

docker run --gpus "${GPUS}" -it --rm --ipc=host --net=host \
    --mount "src=$(pwd),target=/opt/project,type=bind" \
    --mount "src=$DATALOC,target=/opt/data,type=bind" \
    --mount "src=$LOGLOC,target=/opt/logger,type=bind" \
    "$IMG"
17 |
--------------------------------------------------------------------------------
/tools/run_fscil.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Launch tools/fscil.py with one process per GPU.
#
# Usage: run_fscil.sh CONFIG WORKDIR CKPT GPUS [extra fscil.py arguments...]
# Optional environment: NNODES, NODE_RANK, PORT, MASTER_ADDR.

CONFIG=$1
WORKDIR=$2
CKPT=$3
GPUS=$4
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
# Default port is picked at random from 29500..29528.
PORT=${PORT:-$((29500 + $RANDOM % 29))}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

# Directory holding this script; quoted so paths with spaces keep working.
SCRIPT_DIR="$(dirname "$0")"

if command -v torchrun &> /dev/null
then
  echo "Using torchrun mode."
  PYTHONPATH="$SCRIPT_DIR/..:$PYTHONPATH" OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
    torchrun \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --master_port="$PORT" \
    --nproc_per_node="$GPUS" \
    "$SCRIPT_DIR/fscil.py" "$CONFIG" "$WORKDIR" "$CKPT" --launcher pytorch "${@:5}"
else
  echo "Using launch mode."
  PYTHONPATH="$SCRIPT_DIR/..:$PYTHONPATH" OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \
    python -m torch.distributed.launch \
    --nnodes="$NNODES" \
    --node_rank="$NODE_RANK" \
    --master_addr="$MASTER_ADDR" \
    --master_port="$PORT" \
    --nproc_per_node="$GPUS" \
    "$SCRIPT_DIR/fscil.py" "$CONFIG" "$WORKDIR" "$CKPT" --launcher pytorch "${@:5}"
fi
35 |
--------------------------------------------------------------------------------