├── .gitignore ├── README.md ├── configs ├── _base_ │ ├── datasets │ │ ├── cifar_fscil.py │ │ ├── cub_fscil.py │ │ └── mini_imagenet_fscil.py │ ├── default_runtime.py │ ├── models │ │ └── resnet_etf.py │ └── schedules │ │ ├── cifar_200e.py │ │ ├── cub_80e.py │ │ └── mini_imagenet_500e.py ├── cifar │ ├── resnet12_etf_bs512_200e_cifar.py │ └── resnet12_etf_bs512_200e_cifar_eval.py ├── cub │ ├── resnet18_etf_bs512_80e_cub.py │ └── resnet18_etf_bs512_80e_cub_eval.py └── mini_imagenet │ ├── resnet12_etf_bs512_500e_miniimagenet.py │ └── resnet12_etf_bs512_500e_miniimagenet_eval.py ├── docker_env └── Dockerfile ├── logs ├── cifar_base.log ├── cifar_inc.log ├── cub_base.log ├── cub_inc.log ├── min_base.log └── min_inc.log ├── mmcls ├── __init__.py ├── apis │ ├── __init__.py │ ├── inference.py │ ├── test.py │ └── train.py ├── core │ ├── __init__.py │ ├── evaluation │ │ ├── __init__.py │ │ ├── eval_hooks.py │ │ ├── eval_metrics.py │ │ ├── mean_ap.py │ │ └── multilabel_eval_metrics.py │ ├── export │ │ ├── __init__.py │ │ └── test.py │ ├── hook │ │ ├── __init__.py │ │ ├── class_num_check_hook.py │ │ ├── lr_updater.py │ │ ├── precise_bn_hook.py │ │ └── wandblogger_hook.py │ ├── optimizers │ │ ├── __init__.py │ │ └── lamb.py │ ├── utils │ │ ├── __init__.py │ │ ├── dist_utils.py │ │ └── misc.py │ └── visualization │ │ ├── __init__.py │ │ └── image.py ├── datasets │ ├── __init__.py │ ├── base_dataset.py │ ├── builder.py │ ├── cifar.py │ ├── cub.py │ ├── custom.py │ ├── dataset_wrappers.py │ ├── imagenet.py │ ├── imagenet21k.py │ ├── mnist.py │ ├── multi_label.py │ ├── pipelines │ │ ├── __init__.py │ │ ├── auto_augment.py │ │ ├── compose.py │ │ ├── formatting.py │ │ ├── loading.py │ │ └── transforms.py │ ├── samplers │ │ ├── __init__.py │ │ ├── distributed_sampler.py │ │ └── repeat_aug.py │ ├── utils.py │ └── voc.py ├── models │ ├── __init__.py │ ├── backbones │ │ ├── __init__.py │ │ ├── alexnet.py │ │ ├── base_backbone.py │ │ ├── conformer.py │ │ ├── convmixer.py │ │ ├── 
convnext.py │ │ ├── cspnet.py │ │ ├── deit.py │ │ ├── densenet.py │ │ ├── efficientnet.py │ │ ├── hrnet.py │ │ ├── lenet.py │ │ ├── mlp_mixer.py │ │ ├── mobilenet_v2.py │ │ ├── mobilenet_v3.py │ │ ├── poolformer.py │ │ ├── regnet.py │ │ ├── repmlp.py │ │ ├── repvgg.py │ │ ├── res2net.py │ │ ├── resnest.py │ │ ├── resnet.py │ │ ├── resnet_cifar.py │ │ ├── resnext.py │ │ ├── seresnet.py │ │ ├── seresnext.py │ │ ├── shufflenet_v1.py │ │ ├── shufflenet_v2.py │ │ ├── swin_transformer.py │ │ ├── t2t_vit.py │ │ ├── timm_backbone.py │ │ ├── tnt.py │ │ ├── twins.py │ │ ├── van.py │ │ ├── vgg.py │ │ └── vision_transformer.py │ ├── builder.py │ ├── classifiers │ │ ├── __init__.py │ │ ├── base.py │ │ └── image.py │ ├── heads │ │ ├── __init__.py │ │ ├── base_head.py │ │ ├── cls_head.py │ │ ├── conformer_head.py │ │ ├── deit_head.py │ │ ├── linear_head.py │ │ ├── multi_label_head.py │ │ ├── multi_label_linear_head.py │ │ ├── stacked_head.py │ │ └── vision_transformer_head.py │ ├── losses │ │ ├── __init__.py │ │ ├── accuracy.py │ │ ├── asymmetric_loss.py │ │ ├── cross_entropy_loss.py │ │ ├── focal_loss.py │ │ ├── label_smooth_loss.py │ │ ├── seesaw_loss.py │ │ └── utils.py │ ├── necks │ │ ├── __init__.py │ │ ├── gap.py │ │ ├── gem.py │ │ └── hr_fuse.py │ └── utils │ │ ├── __init__.py │ │ ├── attention.py │ │ ├── augment │ │ ├── __init__.py │ │ ├── augments.py │ │ ├── builder.py │ │ ├── cutmix.py │ │ ├── identity.py │ │ ├── mixup.py │ │ ├── resizemix.py │ │ └── utils.py │ │ ├── channel_shuffle.py │ │ ├── embed.py │ │ ├── helpers.py │ │ ├── inverted_residual.py │ │ ├── make_divisible.py │ │ ├── position_encoding.py │ │ └── se_layer.py ├── utils │ ├── __init__.py │ ├── collect_env.py │ ├── device.py │ ├── distribution.py │ ├── logger.py │ └── setup_env.py └── version.py ├── mmfewshot ├── __init__.py ├── classification │ ├── __init__.py │ ├── apis │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── test.py │ │ └── train.py │ ├── core │ │ ├── __init__.py │ │ └── evaluation │ │ │ ├── 
__init__.py │ │ │ └── eval_hooks.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── builder.py │ │ ├── cub.py │ │ ├── dataset_wrappers.py │ │ ├── mini_imagenet.py │ │ ├── pipelines │ │ │ ├── __init__.py │ │ │ └── loading.py │ │ ├── tiered_imagenet.py │ │ └── utils.py │ ├── models │ │ ├── __init__.py │ │ ├── backbones │ │ │ ├── __init__.py │ │ │ ├── conv4.py │ │ │ ├── resnet12.py │ │ │ ├── utils.py │ │ │ └── wrn.py │ │ ├── classifiers │ │ │ ├── __init__.py │ │ │ ├── base.py │ │ │ ├── base_finetune.py │ │ │ ├── base_metric.py │ │ │ ├── baseline.py │ │ │ ├── baseline_plus.py │ │ │ ├── maml.py │ │ │ ├── matching_net.py │ │ │ ├── meta_baseline.py │ │ │ ├── neg_margin.py │ │ │ ├── proto_net.py │ │ │ └── relation_net.py │ │ ├── heads │ │ │ ├── __init__.py │ │ │ ├── base_head.py │ │ │ ├── cosine_distance_head.py │ │ │ ├── linear_head.py │ │ │ ├── matching_head.py │ │ │ ├── meta_baseline_head.py │ │ │ ├── neg_margin_head.py │ │ │ ├── prototype_head.py │ │ │ └── relation_head.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── mse_loss.py │ │ │ └── nll_loss.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── maml_module.py │ └── utils │ │ ├── __init__.py │ │ └── meta_test_parallel.py ├── detection │ ├── __init__.py │ ├── apis │ │ ├── __init__.py │ │ ├── inference.py │ │ ├── test.py │ │ └── train.py │ ├── core │ │ ├── __init__.py │ │ ├── evaluation │ │ │ ├── __init__.py │ │ │ ├── eval_hooks.py │ │ │ └── mean_ap.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ └── custom_hook.py │ ├── datasets │ │ ├── __init__.py │ │ ├── base.py │ │ ├── builder.py │ │ ├── coco.py │ │ ├── dataloader_wrappers.py │ │ ├── dataset_wrappers.py │ │ ├── pipelines │ │ │ ├── __init__.py │ │ │ ├── formatting.py │ │ │ └── transforms.py │ │ ├── utils.py │ │ └── voc.py │ └── models │ │ ├── __init__.py │ │ ├── backbones │ │ ├── __init__.py │ │ └── resnet_with_meta_conv.py │ │ ├── builder.py │ │ ├── dense_heads │ │ ├── __init__.py │ │ ├── attention_rpn_head.py │ │ └── two_branch_rpn_head.py │ │ ├── 
detectors │ │ ├── __init__.py │ │ ├── attention_rpn_detector.py │ │ ├── fsce.py │ │ ├── fsdetview.py │ │ ├── meta_rcnn.py │ │ ├── mpsr.py │ │ ├── query_support_detector.py │ │ └── tfa.py │ │ ├── losses │ │ ├── __init__.py │ │ └── supervised_contrastive_loss.py │ │ ├── roi_heads │ │ ├── __init__.py │ │ ├── bbox_heads │ │ │ ├── __init__.py │ │ │ ├── contrastive_bbox_head.py │ │ │ ├── cosine_sim_bbox_head.py │ │ │ ├── meta_bbox_head.py │ │ │ ├── multi_relation_bbox_head.py │ │ │ └── two_branch_bbox_head.py │ │ ├── contrastive_roi_head.py │ │ ├── fsdetview_roi_head.py │ │ ├── meta_rcnn_roi_head.py │ │ ├── multi_relation_roi_head.py │ │ ├── shared_heads │ │ │ ├── __init__.py │ │ │ └── meta_rcnn_res_layer.py │ │ └── two_branch_roi_head.py │ │ └── utils │ │ ├── __init__.py │ │ └── aggregation_layer.py ├── utils │ ├── __init__.py │ ├── collate.py │ ├── collect_env.py │ ├── compat_config.py │ ├── dist_utils.py │ ├── infinite_sampler.py │ ├── local_seed.py │ ├── logger.py │ └── runner.py └── version.py ├── mmfscil ├── __init__.py ├── apis │ ├── __init__.py │ ├── eval_hook.py │ ├── fscil.py │ ├── test_fn.py │ └── train.py ├── augments │ ├── __init__.py │ ├── cutmix.py │ ├── idty.py │ └── mixup.py ├── datasets │ ├── __init__.py │ ├── cifar100.py │ ├── cub.py │ ├── memory.py │ └── mini_imagenet.py └── models │ ├── ETFHead.py │ ├── __init__.py │ ├── classifier.py │ ├── mlp_ffn_neck.py │ └── resnet18.py └── tools ├── dist_train.sh ├── docker.sh ├── fscil.py ├── run_fscil.sh └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 
27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | # *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # data 132 | data/ 133 | 134 | # Macos 135 | .DS_Store 136 | 137 | # Pycharm 138 | .idea/ 139 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cifar_fscil.py: -------------------------------------------------------------------------------- 1 | img_size = 32 2 | _img_resize_size = 36 3 | img_norm_cfg = dict(mean=[129.304, 124.070, 112.434], std=[68.170, 65.392, 70.418], to_rgb=False) 4 | meta_keys = ('filename', 'ori_filename', 'ori_shape', 5 | 'img_shape', 'flip', 'flip_direction', 6 | 'img_norm_cfg', 'cls_id', 'img_id') 7 | 8 | train_pipeline = [ 9 | dict(type='RandomResizedCrop', size=img_size, scale=(0.6, 1.), interpolation='bicubic'), 10 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), 11 | dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), 12 | dict(type='Normalize', **img_norm_cfg), 13 | dict(type='ImageToTensor', keys=['img']), 14 | dict(type='ToTensor', keys=['gt_label']), 15 | dict(type='Collect', keys=['img', 'gt_label'], meta_keys=meta_keys) 16 | ] 17 | 18 | test_pipeline = [ 19 | dict(type='Resize', size=(_img_resize_size, -1), interpolation='bicubic'), 20 | dict(type='CenterCrop', crop_size=img_size), 21 | dict(type='Normalize', **img_norm_cfg), 22 | dict(type='ImageToTensor', keys=['img']), 23 | dict(type='Collect', keys=['img', 'gt_label'], meta_keys=meta_keys) 24 | ] 
25 | 26 | data = dict( 27 | samples_per_gpu=64, 28 | workers_per_gpu=8, 29 | train_dataloader=dict( 30 | persistent_workers=True, 31 | ), 32 | val_dataloader=dict( 33 | persistent_workers=True, 34 | ), 35 | test_dataloader=dict( 36 | persistent_workers=True, 37 | ), 38 | train=dict( 39 | type='RepeatDataset', 40 | times=4, 41 | dataset=dict( 42 | type='CIFAR100FSCILDataset', 43 | data_prefix='/opt/data/cifar', 44 | pipeline=train_pipeline, 45 | num_cls=60, 46 | subset='train', 47 | ) 48 | ), 49 | val=dict( 50 | type='CIFAR100FSCILDataset', 51 | data_prefix='/opt/data/cifar', 52 | pipeline=test_pipeline, 53 | num_cls=60, 54 | subset='test', 55 | ), 56 | test=dict( 57 | type='CIFAR100FSCILDataset', 58 | data_prefix='/opt/data/cifar', 59 | pipeline=test_pipeline, 60 | num_cls=100, 61 | subset='test', 62 | ) 63 | ) 64 | -------------------------------------------------------------------------------- /configs/_base_/datasets/cub_fscil.py: -------------------------------------------------------------------------------- 1 | img_size = 224 2 | _img_resize_size = 256 3 | img_norm_cfg = dict( 4 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 5 | ) 6 | meta_keys = ('filename', 'ori_filename', 'ori_shape', 7 | 'img_shape', 'flip', 'flip_direction', 8 | 'img_norm_cfg', 'cls_id', 'img_id') 9 | 10 | train_pipeline = [ 11 | dict(type='LoadImageFromFile'), 12 | dict(type='RandomResizedCrop', size=img_size), 13 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), 14 | dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), 15 | dict(type='Normalize', **img_norm_cfg), 16 | dict(type='ImageToTensor', keys=['img']), 17 | dict(type='ToTensor', keys=['gt_label']), 18 | dict(type='Collect', keys=['img', 'gt_label'], meta_keys=meta_keys) 19 | ] 20 | 21 | test_pipeline = [ 22 | dict(type='LoadImageFromFile'), 23 | dict(type='Resize', size=(_img_resize_size, -1)), 24 | dict(type='CenterCrop', crop_size=img_size), 25 | 
dict(type='Normalize', **img_norm_cfg), 26 | dict(type='ImageToTensor', keys=['img']), 27 | dict(type='Collect', keys=['img', 'gt_label'], meta_keys=meta_keys) 28 | ] 29 | 30 | data = dict( 31 | samples_per_gpu=64, 32 | workers_per_gpu=8, 33 | train_dataloader=dict( 34 | persistent_workers=True, 35 | ), 36 | val_dataloader=dict( 37 | persistent_workers=True, 38 | ), 39 | test_dataloader=dict( 40 | persistent_workers=True, 41 | ), 42 | train=dict( 43 | type='RepeatDataset', 44 | times=4, 45 | dataset=dict( 46 | type='CUBFSCILDataset', 47 | data_prefix='/opt/data/CUB_200_2011', 48 | pipeline=train_pipeline, 49 | num_cls=100, 50 | subset='train', 51 | ) 52 | ), 53 | val=dict( 54 | type='CUBFSCILDataset', 55 | data_prefix='/opt/data/CUB_200_2011', 56 | pipeline=test_pipeline, 57 | num_cls=100, 58 | subset='test', 59 | ), 60 | test=dict( 61 | type='CUBFSCILDataset', 62 | data_prefix='/opt/data/CUB_200_2011', 63 | pipeline=test_pipeline, 64 | num_cls=200, 65 | subset='test', 66 | ) 67 | ) 68 | -------------------------------------------------------------------------------- /configs/_base_/datasets/mini_imagenet_fscil.py: -------------------------------------------------------------------------------- 1 | img_size = 84 2 | img_norm_cfg = dict( 3 | mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True 4 | ) 5 | meta_keys = ('filename', 'ori_filename', 'ori_shape', 6 | 'img_shape', 'flip', 'flip_direction', 7 | 'img_norm_cfg', 'cls_id', 'img_id') 8 | 9 | train_pipeline = [ 10 | dict(type='LoadImageFromFile'), 11 | dict(type='RandomResizedCrop', size=img_size), 12 | dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'), 13 | dict(type='ColorJitter', brightness=0.4, contrast=0.4, saturation=0.4), 14 | dict(type='Normalize', **img_norm_cfg), 15 | dict(type='ImageToTensor', keys=['img']), 16 | dict(type='ToTensor', keys=['gt_label']), 17 | dict(type='Collect', keys=['img', 'gt_label'], meta_keys=meta_keys) 18 | ] 19 | 20 | test_pipeline = [ 21 | 
dict(type='LoadImageFromFile'), 22 | dict(type='Resize', size=(int(img_size * 1.15), -1)), 23 | dict(type='CenterCrop', crop_size=img_size), 24 | dict(type='Normalize', **img_norm_cfg), 25 | dict(type='ImageToTensor', keys=['img']), 26 | dict(type='Collect', keys=['img', 'gt_label'], meta_keys=meta_keys) 27 | ] 28 | 29 | data = dict( 30 | samples_per_gpu=64, 31 | workers_per_gpu=8, 32 | train_dataloader=dict( 33 | persistent_workers=True, 34 | ), 35 | val_dataloader=dict( 36 | persistent_workers=True, 37 | ), 38 | test_dataloader=dict( 39 | persistent_workers=True, 40 | ), 41 | train=dict( 42 | type='RepeatDataset', 43 | times=10, 44 | dataset=dict( 45 | type='MiniImageNetFSCILDataset', 46 | data_prefix='/opt/data/miniimagenet', 47 | pipeline=train_pipeline, 48 | num_cls=60, 49 | subset='train', 50 | ) 51 | ), 52 | val=dict( 53 | type='MiniImageNetFSCILDataset', 54 | data_prefix='/opt/data/miniimagenet', 55 | pipeline=test_pipeline, 56 | num_cls=60, 57 | subset='test', 58 | ), 59 | test=dict( 60 | type='MiniImageNetFSCILDataset', 61 | data_prefix='/opt/data/miniimagenet', 62 | pipeline=test_pipeline, 63 | num_cls=100, 64 | subset='test', 65 | ) 66 | ) 67 | -------------------------------------------------------------------------------- /configs/_base_/default_runtime.py: -------------------------------------------------------------------------------- 1 | # checkpoint saving 2 | checkpoint_config = dict(interval=1, max_keep_ckpts=2) 3 | evaluation = dict(interval=1, save_best='auto') 4 | log_config = dict( 5 | interval=10, 6 | hooks=[ 7 | dict(type='TextLoggerHook'), 8 | ] 9 | ) 10 | 11 | dist_params = dict(backend='nccl') 12 | log_level = 'INFO' 13 | workflow = [('train', 1)] 14 | 15 | load_from = None 16 | resume_from = None 17 | 18 | # Test configs 19 | mean_neck_feat = True 20 | mean_cur_feat = False 21 | feat_test = False 22 | grad_clip = None 23 | finetune_lr = 0.1 24 | inc_start = 60 25 | inc_end = 100 26 | inc_step = 5 27 | 28 | copy_list = (1, 1, 1, 1, 1, 
1, 1, 1, 1, 1) 29 | step_list = (50, 50, 50, 50, 50, 50, 50, 50, 50, 50) 30 | -------------------------------------------------------------------------------- /configs/_base_/models/resnet_etf.py: -------------------------------------------------------------------------------- 1 | # model settings 2 | model = dict( 3 | type='ImageClassifierCIL', 4 | backbone=dict( 5 | type='ResNet12', 6 | with_avgpool=False, 7 | flatten=False 8 | ), 9 | neck=dict(type='MLPNeck', in_channels=640, out_channels=512), 10 | head=dict( 11 | type='ETFHead', 12 | num_classes=100, 13 | eval_classes=60, 14 | in_channels=512, 15 | loss=dict(type='DRLoss', loss_weight=10.), 16 | topk=(1, 5), 17 | cal_acc=True, 18 | ) 19 | ) 20 | -------------------------------------------------------------------------------- /configs/_base_/schedules/cifar_200e.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict( 3 | type='SGD', lr=0.25, momentum=0.9, weight_decay=0.0005 4 | ) 5 | optimizer_config = dict(grad_clip=None) 6 | lr_config = dict( 7 | policy='CosineAnnealingCooldown', 8 | min_lr=None, 9 | min_lr_ratio=1.e-2, 10 | cool_down_ratio=0.1, 11 | cool_down_time=10, 12 | by_epoch=False, 13 | # warmup 14 | warmup='linear', 15 | warmup_iters=100, 16 | warmup_ratio=0.1, 17 | warmup_by_epoch=False 18 | ) 19 | 20 | runner = dict(type='EpochBasedRunner', max_epochs=50) 21 | -------------------------------------------------------------------------------- /configs/_base_/schedules/cub_80e.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict( 3 | type='SGD', lr=0.025, momentum=0.9, weight_decay=0.0005 4 | ) 5 | optimizer_config = dict(grad_clip=None) 6 | lr_config = dict( 7 | policy='CosineAnnealingCooldown', 8 | min_lr=None, 9 | min_lr_ratio=0.1, 10 | cool_down_ratio=0.1, 11 | cool_down_time=10, 12 | by_epoch=False, 13 | # warmup 14 | warmup='linear', 15 | warmup_iters=100, 16 | 
warmup_ratio=0.1, 17 | warmup_by_epoch=False 18 | ) 19 | 20 | runner = dict(type='EpochBasedRunner', max_epochs=20) 21 | -------------------------------------------------------------------------------- /configs/_base_/schedules/mini_imagenet_500e.py: -------------------------------------------------------------------------------- 1 | # optimizer 2 | optimizer = dict( 3 | type='SGD', lr=0.25, momentum=0.9, weight_decay=0.0005 4 | ) 5 | optimizer_config = dict(grad_clip=None) 6 | # learning policy 7 | lr_config = dict( 8 | policy='step', 9 | gamma=0.25, 10 | warmup='linear', 11 | warmup_iters=3000, 12 | warmup_ratio=0.25, 13 | step=[20, 30, 35, 40, 45] 14 | ) 15 | runner = dict(type='EpochBasedRunner', max_epochs=50) 16 | -------------------------------------------------------------------------------- /configs/cifar/resnet12_etf_bs512_200e_cifar.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/resnet_etf.py', 3 | '../_base_/datasets/cifar_fscil.py', 4 | '../_base_/schedules/cifar_200e.py', 5 | '../_base_/default_runtime.py' 6 | ] 7 | 8 | # model settings 9 | model = dict( 10 | neck=dict(type='MLPFFNNeck', in_channels=640, out_channels=512), 11 | head=dict(type='ETFHead', in_channels=512, with_len=False), 12 | train_cfg=dict(augments=[ 13 | dict(type='BatchMixupTwoLabel', alpha=0.8, num_classes=-1, prob=0.4), 14 | dict(type='BatchCutMixTwoLabel', alpha=1.0, num_classes=-1, prob=0.4), 15 | dict(type='IdentityTwoLabel', num_classes=-1, prob=0.2), 16 | ]), 17 | ) 18 | -------------------------------------------------------------------------------- /configs/cifar/resnet12_etf_bs512_200e_cifar_eval.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/resnet_etf.py', 3 | '../_base_/datasets/cifar_fscil.py', 4 | '../_base_/schedules/cifar_200e.py', 5 | '../_base_/default_runtime.py' 6 | ] 7 | 8 | # model settings 9 | model = dict( 10 | 
mixup=0.5, 11 | mixup_prob=0.75, 12 | neck=dict(type='MLPFFNNeck', in_channels=640, out_channels=512), 13 | head=dict(type='ETFHead', in_channels=512, with_len=True), 14 | ) 15 | 16 | copy_list = (1, 1, 1, 1, 1, 1, 1, 1, None, None) 17 | step_list = (50, 75, 100, 120, 140, 160, 200, 200, None, None) 18 | finetune_lr = 0.25 19 | -------------------------------------------------------------------------------- /configs/cub/resnet18_etf_bs512_80e_cub.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/resnet_etf.py', 3 | '../_base_/datasets/cub_fscil.py', 4 | '../_base_/schedules/cub_80e.py', 5 | '../_base_/default_runtime.py' 6 | ] 7 | 8 | # CUB requires different inc settings 9 | inc_start = 100 10 | inc_end = 200 11 | inc_step = 10 12 | 13 | # model settings 14 | model = dict( 15 | backbone=dict( 16 | _delete_=True, 17 | type='ResNet', 18 | depth=18, 19 | frozen_stages=1, 20 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'), 21 | norm_cfg=dict(type='SyncBN', requires_grad=True), 22 | ), 23 | neck=dict(type='MLPFFNNeck', in_channels=512, out_channels=512), 24 | head=dict( 25 | type='ETFHead', 26 | num_classes=200, 27 | eval_classes=100, 28 | with_len=False, 29 | ) 30 | ) 31 | -------------------------------------------------------------------------------- /configs/cub/resnet18_etf_bs512_80e_cub_eval.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/resnet_etf.py', 3 | '../_base_/datasets/cub_fscil.py', 4 | '../_base_/schedules/cub_80e.py', 5 | '../_base_/default_runtime.py' 6 | ] 7 | 8 | # CUB requires different inc settings 9 | inc_start = 100 10 | inc_end = 200 11 | inc_step = 10 12 | 13 | # model settings 14 | model = dict( 15 | backbone=dict( 16 | _delete_=True, 17 | type='ResNet', 18 | depth=18, 19 | frozen_stages=1, 20 | init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet18'), 21 | 
norm_cfg=dict(type='BN', requires_grad=True), 22 | ), 23 | neck=dict(type='MLPFFNNeck', in_channels=512, out_channels=512), 24 | head=dict( 25 | type='ETFHead', 26 | num_classes=200, 27 | eval_classes=100, 28 | with_len=False, 29 | ) 30 | ) 31 | 32 | copy_list = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10) 33 | step_list = (105, 110, 115, 120, 125, 130, 135, 140, 145, 150) 34 | finetune_lr = 0.05 35 | -------------------------------------------------------------------------------- /configs/mini_imagenet/resnet12_etf_bs512_500e_miniimagenet.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/resnet_etf.py', 3 | '../_base_/datasets/mini_imagenet_fscil.py', 4 | '../_base_/schedules/mini_imagenet_500e.py', 5 | '../_base_/default_runtime.py' 6 | ] 7 | 8 | 9 | # model settings 10 | model = dict( 11 | neck=dict(type='MLPFFNNeck', in_channels=640, out_channels=512), 12 | head=dict(type='ETFHead', in_channels=512, with_len=False), 13 | ) 14 | 15 | -------------------------------------------------------------------------------- /configs/mini_imagenet/resnet12_etf_bs512_500e_miniimagenet_eval.py: -------------------------------------------------------------------------------- 1 | _base_ = [ 2 | '../_base_/models/resnet_etf.py', 3 | '../_base_/datasets/mini_imagenet_fscil.py', 4 | '../_base_/schedules/mini_imagenet_500e.py', 5 | '../_base_/default_runtime.py' 6 | ] 7 | 8 | 9 | # model settings 10 | model = dict( 11 | mixup=0., 12 | neck=dict(type='MLPFFNNeck', in_channels=640, out_channels=512), 13 | head=dict(type='ETFHead', in_channels=512, with_len=False), 14 | ) 15 | 16 | copy_list = (1, 2, 3, 4, 5, 6, 7, 8, None, None) 17 | step_list = (100, 110, 120, 130, 140, 150, 160, 170, None, None) 18 | finetune_lr = 0.025 19 | -------------------------------------------------------------------------------- /docker_env/Dockerfile: -------------------------------------------------------------------------------- 1 | # Pytorch 
1.13 CUDA 11.7 2 | FROM nvcr.io/nvidia/pytorch:22.06-py3 3 | 4 | # install ujson 5 | RUN pip install ujson 6 | 7 | # handle the timezone 8 | RUN apt-get update && DEBIAN_FRONTEND="noninteractive" TZ="PRC" apt-get install -y tzdata \ 9 | && apt-get clean && rm -rf /var/lib/apt/lists/* \ 10 | && unlink /etc/localtime && ln -s /usr/share/zoneinfo/PRC /etc/localtime 11 | 12 | # mmcv 1.6.1 via HTTPS (GitHub no longer serves git://); retry at most 5 times, then fail the build 13 | RUN n=0; until MMCV_WITH_OPS=1 FORCE_CUDA=1 python -m pip install git+https://github.com/open-mmlab/mmcv.git@d409eedc816fccfb1c8d57e5eed5f03bd075f327; do n=$((n+1)); [ "$n" -ge 5 ] && exit 1; sleep 5; done 14 | 15 | # git config 16 | RUN git config --global --add safe.directory /opt/project 17 | 18 | WORKDIR /opt/project 19 | -------------------------------------------------------------------------------- /mmcls/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import warnings 3 | 4 | import mmcv 5 | from packaging.version import parse 6 | 7 | from .version import __version__ 8 | 9 | 10 | def digit_version(version_str: str, length: int = 4): 11 | """Convert a version string into a tuple of integers. 12 | 13 | This method is usually used for comparing two versions. For pre-release 14 | versions: alpha < beta < rc. 15 | 16 | Args: 17 | version_str (str): The version string. 18 | length (int): The maximum number of version levels. Default: 4. 19 | 20 | Returns: 21 | tuple[int]: The version info in digits (integers).
22 | """ 23 | version = parse(version_str) 24 | assert version.release, f'failed to parse version {version_str}' 25 | release = list(version.release) 26 | release = release[:length] 27 | if len(release) < length: 28 | release = release + [0] * (length - len(release)) 29 | if version.is_prerelease: 30 | mapping = {'a': -3, 'b': -2, 'rc': -1} 31 | val = -4 32 | # version.pre can be None 33 | if version.pre: 34 | if version.pre[0] not in mapping: 35 | warnings.warn(f'unknown prerelease version {version.pre[0]}, ' 36 | 'version checking may go wrong') 37 | else: 38 | val = mapping[version.pre[0]] 39 | release.extend([val, version.pre[-1]]) 40 | else: 41 | release.extend([val, 0]) 42 | 43 | elif version.is_postrelease: 44 | release.extend([1, version.post]) 45 | else: 46 | release.extend([0, 0]) 47 | return tuple(release) 48 | 49 | 50 | mmcv_minimum_version = '1.4.2' 51 | mmcv_maximum_version = '1.7.0' 52 | mmcv_version = digit_version(mmcv.__version__) 53 | 54 | 55 | assert (mmcv_version >= digit_version(mmcv_minimum_version) 56 | and mmcv_version <= digit_version(mmcv_maximum_version)), \ 57 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 58 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 59 | 60 | __all__ = ['__version__', 'digit_version'] 61 | -------------------------------------------------------------------------------- /mmcls/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .inference import inference_model, init_model, show_result_pyplot 3 | from .test import multi_gpu_test, single_gpu_test 4 | from .train import init_random_seed, set_random_seed, train_model 5 | 6 | __all__ = [ 7 | 'set_random_seed', 'train_model', 'init_model', 'inference_model', 8 | 'multi_gpu_test', 'single_gpu_test', 'show_result_pyplot', 9 | 'init_random_seed' 10 | ] 11 | -------------------------------------------------------------------------------- /mmcls/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .evaluation import * # noqa: F401, F403 3 | from .hook import * # noqa: F401, F403 4 | from .optimizers import * # noqa: F401, F403 5 | from .utils import * # noqa: F401, F403 6 | -------------------------------------------------------------------------------- /mmcls/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .eval_hooks import DistEvalHook, EvalHook 3 | from .eval_metrics import (calculate_confusion_matrix, f1_score, precision, 4 | precision_recall_f1, recall, support) 5 | from .mean_ap import average_precision, mAP 6 | from .multilabel_eval_metrics import average_performance 7 | 8 | __all__ = [ 9 | 'precision', 'recall', 'f1_score', 'support', 'average_precision', 'mAP', 10 | 'average_performance', 'calculate_confusion_matrix', 'precision_recall_f1', 11 | 'EvalHook', 'DistEvalHook' 12 | ] 13 | -------------------------------------------------------------------------------- /mmcls/core/evaluation/eval_hooks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import os.path as osp 3 | 4 | import torch.distributed as dist 5 | from mmcv.runner import DistEvalHook as BaseDistEvalHook 6 | from mmcv.runner import EvalHook as BaseEvalHook 7 | from torch.nn.modules.batchnorm import _BatchNorm 8 | 9 | 10 | class EvalHook(BaseEvalHook): 11 | """Non-Distributed evaluation hook. 12 | 13 | Comparing with the ``EvalHook`` in MMCV, this hook will save the latest 14 | evaluation results as an attribute for other hooks to use (like 15 | `MMClsWandbHook`). 16 | """ 17 | 18 | def __init__(self, dataloader, **kwargs): 19 | super(EvalHook, self).__init__(dataloader, **kwargs) 20 | self.latest_results = None 21 | 22 | def _do_evaluate(self, runner): 23 | """perform evaluation and save ckpt.""" 24 | results = self.test_fn(runner.model, self.dataloader) 25 | self.latest_results = results 26 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 27 | key_score = self.evaluate(runner, results) 28 | # the key_score may be `None` so it needs to skip the action to save 29 | # the best checkpoint 30 | if self.save_best and key_score: 31 | self._save_ckpt(runner, key_score) 32 | 33 | 34 | class DistEvalHook(BaseDistEvalHook): 35 | """Distributed evaluation hook. 36 | 37 | Comparing with the ``DistEvalHook`` in MMCV, this hook will save the latest 38 | evaluation results as an attribute for other hooks to use (like 39 | `MMClsWandbHook`). 40 | """ 41 | 42 | def __init__(self, dataloader, **kwargs): 43 | super(DistEvalHook, self).__init__(dataloader, **kwargs) 44 | self.latest_results = None 45 | 46 | def _do_evaluate(self, runner): 47 | """perform evaluation and save ckpt.""" 48 | # Synchronization of BatchNorm's buffer (running_mean 49 | # and running_var) is not supported in the DDP of pytorch, 50 | # which may cause the inconsistent performance of models in 51 | # different ranks, so we broadcast BatchNorm's buffers 52 | # of rank 0 to other ranks to avoid this.
53 | if self.broadcast_bn_buffer: 54 | model = runner.model 55 | for name, module in model.named_modules(): 56 | if isinstance(module, 57 | _BatchNorm) and module.track_running_stats: 58 | dist.broadcast(module.running_var, 0) 59 | dist.broadcast(module.running_mean, 0) 60 | 61 | tmpdir = self.tmpdir 62 | if tmpdir is None: 63 | tmpdir = osp.join(runner.work_dir, '.eval_hook') 64 | 65 | results = self.test_fn( 66 | runner.model, 67 | self.dataloader, 68 | tmpdir=tmpdir, 69 | gpu_collect=self.gpu_collect) 70 | self.latest_results = results 71 | if runner.rank == 0: 72 | print('\n') 73 | runner.log_buffer.output['eval_iter_num'] = len(self.dataloader) 74 | key_score = self.evaluate(runner, results) 75 | # the key_score may be `None` so it needs to skip the action to 76 | # save the best checkpoint 77 | if self.save_best and key_score: 78 | self._save_ckpt(runner, key_score) 79 | -------------------------------------------------------------------------------- /mmcls/core/evaluation/mean_ap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def average_precision(pred, target): 7 | r"""Calculate the average precision for a single class. 8 | 9 | AP summarizes a precision-recall curve as the weighted mean of maximum 10 | precisions obtained for any r'>r, where r is the recall: 11 | 12 | .. math:: 13 | \text{AP} = \sum_n (R_n - R_{n-1}) P_n 14 | 15 | Note that no approximation is involved since the curve is piecewise 16 | constant. 17 | 18 | Args: 19 | pred (np.ndarray): The model prediction with shape (N, ). 20 | target (np.ndarray): The target of each prediction with shape (N, ). 21 | 22 | Returns: 23 | float: a single float as average precision value. 
24 | """ 25 | eps = np.finfo(np.float32).eps 26 | 27 | # sort examples 28 | sort_inds = np.argsort(-pred) 29 | sort_target = target[sort_inds] 30 | 31 | # count true positive examples 32 | pos_inds = sort_target == 1 33 | tp = np.cumsum(pos_inds) 34 | total_pos = tp[-1] 35 | 36 | # count not difficult examples 37 | pn_inds = sort_target != -1 38 | pn = np.cumsum(pn_inds) 39 | 40 | tp[np.logical_not(pos_inds)] = 0 41 | precision = tp / np.maximum(pn, eps) 42 | ap = np.sum(precision) / np.maximum(total_pos, eps) 43 | return ap 44 | 45 | 46 | def mAP(pred, target): 47 | """Calculate the mean average precision with respect of classes. 48 | 49 | Args: 50 | pred (torch.Tensor | np.ndarray): The model prediction with shape 51 | (N, C), where C is the number of classes. 52 | target (torch.Tensor | np.ndarray): The target of each prediction with 53 | shape (N, C), where C is the number of classes. 1 stands for 54 | positive examples, 0 stands for negative examples and -1 stands for 55 | difficult examples. 56 | 57 | Returns: 58 | float: A single float as mAP value. 59 | """ 60 | if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor): 61 | pred = pred.detach().cpu().numpy() 62 | target = target.detach().cpu().numpy() 63 | elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)): 64 | raise TypeError('pred and target should both be torch.Tensor or' 65 | 'np.ndarray') 66 | 67 | assert pred.shape == \ 68 | target.shape, 'pred and target should be in the same shape.' 69 | num_classes = pred.shape[1] 70 | ap = np.zeros(num_classes) 71 | for k in range(num_classes): 72 | ap[k] = average_precision(pred[:, k], target[:, k]) 73 | mean_ap = ap.mean() * 100.0 74 | return mean_ap 75 | -------------------------------------------------------------------------------- /mmcls/core/evaluation/multilabel_eval_metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | 7 | 8 | def average_performance(pred, target, thr=None, k=None): 9 | """Calculate CP, CR, CF1, OP, OR, OF1, where C stands for per-class 10 | average, O stands for overall average, P stands for precision, R stands for 11 | recall and F1 stands for F1-score. 12 | 13 | Args: 14 | pred (torch.Tensor | np.ndarray): The model prediction with shape 15 | (N, C), where C is the number of classes. 16 | target (torch.Tensor | np.ndarray): The target of each prediction with 17 | shape (N, C), where C is the number of classes. 1 stands for 18 | positive examples, 0 stands for negative examples and -1 stands for 19 | difficult examples, which are counted as negatives (``target`` itself is not modified). 20 | thr (float): The confidence threshold; labels scoring >= thr are predicted positive. Defaults to None. 21 | k (int): Top-k performance. Note that if thr and k are both given, k 22 | will be ignored. If neither is given, thr=0.5 is used. Defaults to None. 23 | 24 | Returns: 25 | tuple: (CP, CR, CF1, OP, OR, OF1), each expressed as a percentage. 26 | """ 27 | if isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor): 28 | pred = pred.detach().cpu().numpy() 29 | target = target.detach().cpu().numpy() 30 | elif not (isinstance(pred, np.ndarray) and isinstance(target, np.ndarray)): 31 | raise TypeError('pred and target should both be torch.Tensor or ' 32 | 'np.ndarray') 33 | if thr is None and k is None: 34 | thr = 0.5 35 | warnings.warn('Neither thr nor k is given, set thr as 0.5 by ' 36 | 'default.') 37 | elif thr is not None and k is not None: 38 | warnings.warn('Both thr and k are given, use threshold in favor of ' 39 | 'top-k.') 40 | 41 | assert pred.shape == \ 42 | target.shape, 'pred and target should be in the same shape.'
43 | 44 | eps = np.finfo(np.float32).eps 45 | target = np.where(target == -1, 0, target)  # difficult -> negative, without mutating the caller's array 46 | if thr is not None: 47 | # a label is predicted positive if the confidence is no lower than thr 48 | pos_inds = pred >= thr 49 | 50 | else: 51 | # top-k labels will be predicted positive for any example 52 | sort_inds = np.argsort(-pred, axis=1) 53 | sort_inds_ = sort_inds[:, :k] 54 | inds = np.indices(sort_inds_.shape) 55 | pos_inds = np.zeros_like(pred) 56 | pos_inds[inds[0], sort_inds_] = 1 57 | 58 | tp = (pos_inds * target) == 1 59 | fp = (pos_inds * (1 - target)) == 1 60 | fn = ((1 - pos_inds) * target) == 1 61 | 62 | precision_class = tp.sum(axis=0) / np.maximum( 63 | tp.sum(axis=0) + fp.sum(axis=0), eps) 64 | recall_class = tp.sum(axis=0) / np.maximum( 65 | tp.sum(axis=0) + fn.sum(axis=0), eps) 66 | CP = precision_class.mean() * 100.0 67 | CR = recall_class.mean() * 100.0 68 | CF1 = 2 * CP * CR / np.maximum(CP + CR, eps) 69 | OP = tp.sum() / np.maximum(tp.sum() + fp.sum(), eps) * 100.0 70 | OR = tp.sum() / np.maximum(tp.sum() + fn.sum(), eps) * 100.0 71 | OF1 = 2 * OP * OR / np.maximum(OP + OR, eps) 72 | return CP, CR, CF1, OP, OR, OF1 73 | -------------------------------------------------------------------------------- /mmcls/core/export/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .test import ONNXRuntimeClassifier, TensorRTClassifier 3 | 4 | __all__ = ['ONNXRuntimeClassifier', 'TensorRTClassifier'] 5 | -------------------------------------------------------------------------------- /mmcls/core/export/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import warnings 3 | 4 | import numpy as np 5 | import onnxruntime as ort 6 | import torch 7 | 8 | from mmcls.models.classifiers import BaseClassifier 9 | 10 | 11 | class ONNXRuntimeClassifier(BaseClassifier): 12 | """Wrapper for classifier's inference with ONNXRuntime.""" 13 | 14 | def __init__(self, onnx_file, class_names, device_id): 15 | super(ONNXRuntimeClassifier, self).__init__() 16 | sess = ort.InferenceSession(onnx_file) 17 | 18 | providers = ['CPUExecutionProvider'] 19 | options = [{}] 20 | is_cuda_available = ort.get_device() == 'GPU' 21 | if is_cuda_available: 22 | providers.insert(0, 'CUDAExecutionProvider') 23 | options.insert(0, {'device_id': device_id}) 24 | sess.set_providers(providers, options) 25 | 26 | self.sess = sess 27 | self.CLASSES = class_names 28 | self.device_id = device_id 29 | self.io_binding = sess.io_binding() 30 | self.output_names = [_.name for _ in sess.get_outputs()] 31 | self.is_cuda_available = is_cuda_available 32 | 33 | def simple_test(self, img, img_metas, **kwargs): 34 | raise NotImplementedError('This method is not implemented.') 35 | 36 | def extract_feat(self, imgs): 37 | raise NotImplementedError('This method is not implemented.') 38 | 39 | def forward_train(self, imgs, **kwargs): 40 | raise NotImplementedError('This method is not implemented.') 41 | 42 | def forward_test(self, imgs, img_metas, **kwargs): 43 | input_data = imgs 44 | # set io binding for inputs/outputs 45 | device_type = 'cuda' if self.is_cuda_available else 'cpu' 46 | if not self.is_cuda_available: 47 | input_data = input_data.cpu() 48 | self.io_binding.bind_input( 49 | name='input', 50 | device_type=device_type, 51 | device_id=self.device_id, 52 | element_type=np.float32, 53 | shape=input_data.shape, 54 | buffer_ptr=input_data.data_ptr()) 55 | 56 | for name in self.output_names: 57 | self.io_binding.bind_output(name) 58 | # run session to get outputs 59 | self.sess.run_with_iobinding(self.io_binding) 60 | results =
self.io_binding.copy_outputs_to_cpu()[0] 61 | return list(results) 62 | 63 | 64 | class TensorRTClassifier(BaseClassifier): 65 | """Wrapper for classifier's inference with TensorRT.""" 66 | def __init__(self, trt_file, class_names, device_id): 67 | super(TensorRTClassifier, self).__init__() 68 | from mmcv.tensorrt import TRTWraper, load_tensorrt_plugin 69 | try: 70 | load_tensorrt_plugin() 71 | except (ImportError, ModuleNotFoundError): 72 | warnings.warn('If input model has custom op from mmcv, ' 73 | 'you may have to build mmcv with TensorRT from source.') 74 | model = TRTWraper( 75 | trt_file, input_names=['input'], output_names=['probs']) 76 | 77 | self.model = model 78 | self.device_id = device_id 79 | self.CLASSES = class_names 80 | 81 | def simple_test(self, img, img_metas, **kwargs): 82 | raise NotImplementedError('This method is not implemented.') 83 | 84 | def extract_feat(self, imgs): 85 | raise NotImplementedError('This method is not implemented.') 86 | 87 | def forward_train(self, imgs, **kwargs): 88 | raise NotImplementedError('This method is not implemented.') 89 | 90 | def forward_test(self, imgs, img_metas, **kwargs): 91 | input_data = imgs 92 | with torch.cuda.device(self.device_id), torch.no_grad(): 93 | results = self.model({'input': input_data})['probs'] 94 | results = results.detach().cpu().numpy() 95 | 96 | return list(results) 97 | -------------------------------------------------------------------------------- /mmcls/core/hook/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .class_num_check_hook import ClassNumCheckHook 3 | from .lr_updater import CosineAnnealingCooldownLrUpdaterHook 4 | from .precise_bn_hook import PreciseBNHook 5 | from .wandblogger_hook import MMClsWandbHook 6 | 7 | __all__ = [ 8 | 'ClassNumCheckHook', 'PreciseBNHook', 9 | 'CosineAnnealingCooldownLrUpdaterHook', 'MMClsWandbHook' 10 | ] 11 | -------------------------------------------------------------------------------- /mmcls/core/hook/class_num_check_hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved 2 | from mmcv.runner import IterBasedRunner 3 | from mmcv.runner.hooks import HOOKS, Hook 4 | from mmcv.utils import is_seq_of 5 | 6 | 7 | @HOOKS.register_module() 8 | class ClassNumCheckHook(Hook): 9 | """Assert that every module with ``num_classes`` matches ``len(dataset.CLASSES)``.""" 10 | def _check_head(self, runner, dataset): 11 | """Check whether the `num_classes` in head matches the length of 12 | `CLASSES` in `dataset`. 13 | 14 | Args: 15 | runner (obj:`EpochBasedRunner`, `IterBasedRunner`): runner object. 16 | dataset (obj: `BaseDataset`): the dataset to check.
17 | """ 18 | model = runner.model 19 | if dataset.CLASSES is None: 20 | runner.logger.warning( 21 | f'Please set `CLASSES` ' 22 | f'in the {dataset.__class__.__name__} and ' 23 | f'check if it is consistent with the `num_classes` ' 24 | f'of head') 25 | else: 26 | assert is_seq_of(dataset.CLASSES, str), \ 27 | (f'`CLASSES` in {dataset.__class__.__name__} ' 28 | f'should be a tuple of str.') 29 | for name, module in model.named_modules(): 30 | if hasattr(module, 'num_classes'): 31 | assert module.num_classes == len(dataset.CLASSES), \ 32 | (f'The `num_classes` ({module.num_classes}) in ' 33 | f'{module.__class__.__name__} of ' 34 | f'{model.__class__.__name__} does not match ' 35 | f'the length of `CLASSES` (' 36 | f'{len(dataset.CLASSES)}) in ' 37 | f'{dataset.__class__.__name__}') 38 | 39 | def before_train_iter(self, runner): 40 | """Check whether the training dataset is compatible with head. 41 | 42 | Args: 43 | runner (obj: `IterBasedRunner`): Iter based Runner. 44 | """ 45 | if not isinstance(runner, IterBasedRunner): 46 | return 47 | self._check_head(runner, runner.data_loader._dataloader.dataset) 48 | 49 | def before_val_iter(self, runner): 50 | """Check whether the eval dataset is compatible with head. 51 | 52 | Args: 53 | runner (obj:`IterBasedRunner`): Iter based Runner. 54 | """ 55 | if not isinstance(runner, IterBasedRunner): 56 | return 57 | self._check_head(runner, runner.data_loader._dataloader.dataset) 58 | 59 | def before_train_epoch(self, runner): 60 | """Check whether the training dataset is compatible with head. 61 | 62 | Args: 63 | runner (obj:`EpochBasedRunner`): Epoch based Runner. 64 | """ 65 | self._check_head(runner, runner.data_loader.dataset) 66 | 67 | def before_val_epoch(self, runner): 68 | """Check whether the eval dataset is compatible with head. 69 | 70 | Args: 71 | runner (obj:`EpochBasedRunner`): Epoch based Runner.
72 | """ 73 | self._check_head(runner, runner.data_loader.dataset) 74 | -------------------------------------------------------------------------------- /mmcls/core/hook/lr_updater.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from math import cos, pi 3 | 4 | from mmcv.runner.hooks import HOOKS, LrUpdaterHook 5 | 6 | 7 | @HOOKS.register_module() 8 | class CosineAnnealingCooldownLrUpdaterHook(LrUpdaterHook): 9 | """Cosine annealing learning rate scheduler with cooldown. 10 | 11 | Args: 12 | min_lr (float, optional): The minimum learning rate after annealing. 13 | Defaults to None. 14 | min_lr_ratio (float, optional): The minimum learning ratio after 15 | annealing. Defaults to None. 16 | cool_down_ratio (float): During cooldown the LR is ``target_lr * cool_down_ratio``. Defaults to 0.1. 17 | cool_down_time (int): Number of final epochs (or iters, see ``by_epoch``) spent in cooldown. Defaults to 10. 18 | by_epoch (bool): If True, the learning rate changes epoch by epoch. If 19 | False, the learning rate changes iter by iter. Defaults to True. 20 | warmup (string, optional): Type of warmup used. It can be None (use no 21 | warmup), 'constant', 'linear' or 'exp'. Defaults to None. 22 | warmup_iters (int): The number of iterations or epochs that warmup 23 | lasts. Defaults to 0. 24 | warmup_ratio (float): LR used at the beginning of warmup equals to 25 | ``warmup_ratio * initial_lr``. Defaults to 0.1. 26 | warmup_by_epoch (bool): If True, the ``warmup_iters`` 27 | means the number of epochs that warmup lasts, otherwise means the 28 | number of iteration that warmup lasts. Defaults to False. 29 | 30 | Note: 31 | You need to set one and only one of ``min_lr`` and ``min_lr_ratio``.
32 | """ 33 | 34 | def __init__(self, 35 | min_lr=None, 36 | min_lr_ratio=None, 37 | cool_down_ratio=0.1, 38 | cool_down_time=10, 39 | **kwargs): 40 | assert (min_lr is None) ^ (min_lr_ratio is None) 41 | self.min_lr = min_lr 42 | self.min_lr_ratio = min_lr_ratio 43 | self.cool_down_time = cool_down_time 44 | self.cool_down_ratio = cool_down_ratio 45 | super(CosineAnnealingCooldownLrUpdaterHook, self).__init__(**kwargs) 46 | 47 | def get_lr(self, runner, base_lr): 48 | if self.by_epoch: 49 | progress = runner.epoch 50 | max_progress = runner.max_epochs 51 | else: 52 | progress = runner.iter 53 | max_progress = runner.max_iters 54 | 55 | if self.min_lr_ratio is not None: 56 | target_lr = base_lr * self.min_lr_ratio 57 | else: 58 | target_lr = self.min_lr 59 | 60 | if progress > max_progress - self.cool_down_time:  # inside the final cooldown window 61 | return target_lr * self.cool_down_ratio 62 | else: 63 | max_progress = max_progress - self.cool_down_time  # anneal over the pre-cooldown span only 64 | 65 | return annealing_cos(base_lr, target_lr, progress / max_progress) 66 | 67 | 68 | def annealing_cos(start, end, factor, weight=1): 69 | """Calculate annealing cos learning rate. 70 | 71 | Cosine anneal from `weight * start + (1 - weight) * end` to `end` as 72 | percentage goes from 0.0 to 1.0. 73 | 74 | Args: 75 | start (float): The starting learning rate of the cosine annealing. 76 | end (float): The ending learning rate of the cosine annealing. 77 | factor (float): The coefficient of `pi` when calculating the current 78 | percentage. Range from 0.0 to 1.0. 79 | weight (float, optional): The combination factor of `start` and `end` 80 | when calculating the actual starting learning rate. Defaults to 1. 81 | """ 82 | cos_out = cos(pi * factor) + 1 83 | return end + 0.5 * weight * (start - end) * cos_out 84 | -------------------------------------------------------------------------------- /mmcls/core/optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab.
All rights reserved. 2 | from .lamb import Lamb 3 | 4 | __all__ = [ 5 | 'Lamb', 6 | ] 7 | -------------------------------------------------------------------------------- /mmcls/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .dist_utils import DistOptimizerHook, allreduce_grads, sync_random_seed 3 | from .misc import multi_apply 4 | 5 | __all__ = [ 6 | 'allreduce_grads', 'DistOptimizerHook', 'multi_apply', 'sync_random_seed' 7 | ] 8 | -------------------------------------------------------------------------------- /mmcls/core/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections import OrderedDict 3 | 4 | import numpy as np 5 | import torch 6 | import torch.distributed as dist 7 | from mmcv.runner import OptimizerHook, get_dist_info 8 | from torch._utils import (_flatten_dense_tensors, _take_tensors, 9 | _unflatten_dense_tensors) 10 | 11 | # Average ``tensors`` in place across ranks: all_reduce (sum) each flattened bucket, then divide by world_size. 12 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): 13 | if bucket_size_mb > 0: 14 | bucket_size_bytes = bucket_size_mb * 1024 * 1024 15 | buckets = _take_tensors(tensors, bucket_size_bytes) 16 | else: 17 | buckets = OrderedDict()  # no size cap: one bucket per ``tensor.type()`` string 18 | for tensor in tensors: 19 | tp = tensor.type() 20 | if tp not in buckets: 21 | buckets[tp] = [] 22 | buckets[tp].append(tensor) 23 | buckets = buckets.values() 24 | 25 | for bucket in buckets: 26 | flat_tensors = _flatten_dense_tensors(bucket) 27 | dist.all_reduce(flat_tensors) 28 | flat_tensors.div_(world_size) 29 | for tensor, synced in zip( 30 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)): 31 | tensor.copy_(synced) 32 | 33 | # Average the gradients of ``params`` across ranks in place; params without a grad are skipped. 34 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): 35 | grads = [ 36 | param.grad.data for param in params 37 | if param.requires_grad and param.grad is not None 38 | ] 39 | world_size =
dist.get_world_size() 40 | if coalesce: 41 | _allreduce_coalesced(grads, world_size, bucket_size_mb) 42 | else: 43 | for tensor in grads: 44 | dist.all_reduce(tensor.div_(world_size))  # divide first, then sum across ranks 45 | 46 | # NOTE(review): ``coalesce``/``bucket_size_mb`` are stored but not used below, and no all-reduce happens here -- presumably DDP syncs grads; confirm. 47 | class DistOptimizerHook(OptimizerHook): 48 | 49 | def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1): 50 | self.grad_clip = grad_clip 51 | self.coalesce = coalesce 52 | self.bucket_size_mb = bucket_size_mb 53 | 54 | def after_train_iter(self, runner): 55 | runner.optimizer.zero_grad() 56 | runner.outputs['loss'].backward() 57 | if self.grad_clip is not None: 58 | self.clip_grads(runner.model.parameters()) 59 | runner.optimizer.step() 60 | 61 | 62 | def sync_random_seed(seed=None, device='cuda'): 63 | """Make sure different ranks share the same seed. 64 | 65 | All workers must call this function, otherwise it will deadlock. 66 | This method is generally used in `DistributedSampler`, 67 | because the seed should be identical across all processes 68 | in the distributed group. 69 | 70 | In distributed sampling, different ranks should sample non-overlapped 71 | data in the dataset. Therefore, this function is used to make sure that 72 | each rank shuffles the data indices in the same order based 73 | on the same seed. Then different ranks could use different indices 74 | to select non-overlapped data from the same data list. 75 | 76 | Args: 77 | seed (int, Optional): The seed. Defaults to None (draw a random one). 78 | device (str): The device where the seed will be put on. 79 | Defaults to 'cuda'. 80 | 81 | Returns: 82 | int: Seed to be used.
83 | """ 84 | if seed is None: 85 | seed = np.random.randint(2**31)  # < 2**31, so it fits the int32 tensor below 86 | assert isinstance(seed, int) 87 | 88 | rank, world_size = get_dist_info() 89 | 90 | if world_size == 1: 91 | return seed 92 | 93 | if rank == 0: 94 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 95 | else: 96 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 97 | dist.broadcast(random_num, src=0)  # every rank adopts rank 0's seed 98 | return random_num.item() 99 | -------------------------------------------------------------------------------- /mmcls/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from functools import partial 3 | 4 | # Apply ``func`` to each zipped tuple of ``args`` and transpose the per-call result tuples into a tuple of lists. 5 | def multi_apply(func, *args, **kwargs): 6 | pfunc = partial(func, **kwargs) if kwargs else func 7 | map_results = map(pfunc, *args) 8 | return tuple(map(list, zip(*map_results))) 9 | -------------------------------------------------------------------------------- /mmcls/core/visualization/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .image import (BaseFigureContextManager, ImshowInfosContextManager, 3 | color_val_matplotlib, imshow_infos) 4 | 5 | __all__ = [ 6 | 'BaseFigureContextManager', 'ImshowInfosContextManager', 'imshow_infos', 7 | 'color_val_matplotlib' 8 | ] 9 | -------------------------------------------------------------------------------- /mmcls/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from .base_dataset import BaseDataset 3 | from .builder import (DATASETS, PIPELINES, SAMPLERS, build_dataloader, 4 | build_dataset, build_sampler) 5 | from .cifar import CIFAR10, CIFAR100 6 | from .cub import CUB 7 | from .custom import CustomDataset 8 | from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset, 9 | KFoldDataset, RepeatDataset) 10 | from .imagenet import ImageNet 11 | from .imagenet21k import ImageNet21k 12 | from .mnist import MNIST, FashionMNIST 13 | from .multi_label import MultiLabelDataset 14 | from .samplers import DistributedSampler, RepeatAugSampler 15 | from .voc import VOC 16 | 17 | __all__ = [ 18 | 'BaseDataset', 'ImageNet', 'CIFAR10', 'CIFAR100', 'MNIST', 'FashionMNIST', 19 | 'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset', 20 | 'DistributedSampler', 'ConcatDataset', 'RepeatDataset', 21 | 'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS', 22 | 'build_sampler', 'RepeatAugSampler', 'KFoldDataset', 'CUB', 'CustomDataset' 23 | ] 24 | -------------------------------------------------------------------------------- /mmcls/datasets/multi_label.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List 3 | 4 | import numpy as np 5 | 6 | from mmcls.core import average_performance, mAP 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | class MultiLabelDataset(BaseDataset): 11 | """Multi-label Dataset.""" 12 | 13 | def get_cat_ids(self, idx: int) -> List[int]: 14 | """Get category ids by index. 15 | 16 | Args: 17 | idx (int): Index of data. 18 | 19 | Returns: 20 | cat_ids (List[int]): Image categories of specified index.
21 | """ 22 | gt_labels = self.data_infos[idx]['gt_label'] 23 | cat_ids = np.where(gt_labels == 1)[0].tolist()  # positions holding a 1 in the label vector 24 | return cat_ids 25 | 26 | def evaluate(self, 27 | results, 28 | metric='mAP', 29 | metric_options=None, 30 | indices=None,  # optional dataset indices that ``results`` correspond to; gt labels are subset to them 31 | logger=None): 32 | """Evaluate the dataset. 33 | 34 | Args: 35 | results (list): Testing results of the dataset. 36 | metric (str | list[str]): Metrics to be evaluated. 37 | Default value is 'mAP'. Options are 'mAP', 'CP', 'CR', 'CF1', 38 | 'OP', 'OR' and 'OF1'. 39 | metric_options (dict, optional): Options for calculating metrics. 40 | Allowed keys are 'k' and 'thr'. Defaults to None, which (like an empty dict) means ``{'thr': 0.5}``. 41 | logger (logging.Logger | str, optional): Logger used for printing 42 | related information during evaluation. Defaults to None. 43 | 44 | Returns: 45 | dict: evaluation results, keyed by the requested metric names. 46 | """ 47 | if metric_options is None or metric_options == {}: 48 | metric_options = {'thr': 0.5} 49 | 50 | if isinstance(metric, str): 51 | metrics = [metric] 52 | else: 53 | metrics = metric 54 | allowed_metrics = ['mAP', 'CP', 'CR', 'CF1', 'OP', 'OR', 'OF1'] 55 | eval_results = {} 56 | results = np.vstack(results) 57 | gt_labels = self.get_gt_labels() 58 | if indices is not None: 59 | gt_labels = gt_labels[indices] 60 | num_imgs = len(results) 61 | assert len(gt_labels) == num_imgs, 'dataset testing results should '\ 62 | 'be of the same length as gt_labels.'
63 | 64 | invalid_metrics = set(metrics) - set(allowed_metrics) 65 | if len(invalid_metrics) != 0: 66 | raise ValueError(f'metric {invalid_metrics} is not supported.') 67 | 68 | if 'mAP' in metrics: 69 | mAP_value = mAP(results, gt_labels) 70 | eval_results['mAP'] = mAP_value 71 | if len(set(metrics) - {'mAP'}) != 0:  # any non-mAP metric needs average_performance 72 | performance_keys = ['CP', 'CR', 'CF1', 'OP', 'OR', 'OF1'] 73 | performance_values = average_performance(results, gt_labels, 74 | **metric_options) 75 | for k, v in zip(performance_keys, performance_values): 76 | if k in metrics: 77 | eval_results[k] = v 78 | 79 | return eval_results 80 | -------------------------------------------------------------------------------- /mmcls/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .auto_augment import (AutoAugment, AutoContrast, Brightness, 3 | ColorTransform, Contrast, Cutout, Equalize, Invert, 4 | Posterize, RandAugment, Rotate, Sharpness, Shear, 5 | Solarize, SolarizeAdd, Translate) 6 | from .compose import Compose 7 | from .formatting import (Collect, ImageToTensor, ToNumpy, ToPIL, ToTensor, 8 | Transpose, to_tensor) 9 | from .loading import LoadImageFromFile 10 | from .transforms import (CenterCrop, ColorJitter, Lighting, Normalize, Pad, 11 | RandomCrop, RandomErasing, RandomFlip, 12 | RandomGrayscale, RandomResizedCrop, Resize) 13 | 14 | __all__ = [ 15 | 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToPIL', 'ToNumpy', 16 | 'Transpose', 'Collect', 'LoadImageFromFile', 'Resize', 'CenterCrop', 17 | 'RandomFlip', 'Normalize', 'RandomCrop', 'RandomResizedCrop', 18 | 'RandomGrayscale', 'Shear', 'Translate', 'Rotate', 'Invert', 19 | 'ColorTransform', 'Solarize', 'Posterize', 'AutoContrast', 'Equalize', 20 | 'Contrast', 'Brightness', 'Sharpness', 'AutoAugment', 'SolarizeAdd', 21 | 'Cutout', 'RandAugment', 'Lighting', 'ColorJitter', 'RandomErasing', 'Pad' 22 | ] 23 | -------------------------------------------------------------------------------- /mmcls/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved.
-------------------------------------------------------------------------------- /mmcls/datasets/pipelines/compose.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections.abc import Sequence 3 | 4 | from mmcv.utils import build_from_cfg 5 | 6 | from ..builder import PIPELINES 7 | 8 | 9 | @PIPELINES.register_module() 10 | class Compose(object): 11 | """Compose a data pipeline with a sequence of transforms. 12 | 13 | Args: 14 | transforms (list[dict | callable]): 15 | Either config dicts of transforms or transform objects. 16 | """ 17 | 18 | def __init__(self, transforms): 19 | assert isinstance(transforms, Sequence) 20 | self.transforms = [] 21 | for transform in transforms: 22 | if isinstance(transform, dict): 23 | transform = build_from_cfg(transform, PIPELINES) 24 | self.transforms.append(transform) 25 | elif callable(transform): 26 | self.transforms.append(transform) 27 | else: 28 | raise TypeError('transform must be callable or a dict, but got' 29 | f' {type(transform)}') 30 | 31 | def __call__(self, data): 32 | for t in self.transforms: 33 | data = t(data) 34 | if data is None: 35 | return None 36 | return data 37 | 38 | def __repr__(self): 39 | format_string = self.__class__.__name__ + '(' 40 | for t in self.transforms: 41 | format_string += f'\n {t}' 42 | format_string += '\n)' 43 | return format_string 44 | -------------------------------------------------------------------------------- /mmcls/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import os.path as osp 3 | 4 | import mmcv 5 | import numpy as np 6 | 7 | from ..builder import PIPELINES 8 | 9 | 10 | @PIPELINES.register_module() 11 | class LoadImageFromFile(object): 12 | """Load an image from file. 
13 | 14 | Required keys are "img_prefix" and "img_info" (a dict that must contain the 15 | key "filename"). Added or updated keys are "filename", "img", "img_shape", 16 | "ori_shape" (same as `img_shape`) and "img_norm_cfg" (means=0 and stds=1). 17 | 18 | Args: 19 | to_float32 (bool): Whether to convert the loaded image to a float32 20 | numpy array. If set to False, the loaded image is an uint8 array. 21 | Defaults to False. 22 | color_type (str): The flag argument for :func:`mmcv.imfrombytes()`. 23 | Defaults to 'color'. 24 | file_client_args (dict): Arguments to instantiate a FileClient. 25 | See :class:`mmcv.fileio.FileClient` for details. 26 | Defaults to ``dict(backend='disk')``. 27 | """ 28 | 29 | def __init__(self, 30 | to_float32=False, 31 | color_type='color', 32 | file_client_args=dict(backend='disk')): 33 | self.to_float32 = to_float32 34 | self.color_type = color_type 35 | self.file_client_args = file_client_args.copy() 36 | self.file_client = None 37 | 38 | def __call__(self, results): 39 | if self.file_client is None: 40 | self.file_client = mmcv.FileClient(**self.file_client_args) 41 | 42 | if results['img_prefix'] is not None: 43 | filename = osp.join(results['img_prefix'], 44 | results['img_info']['filename']) 45 | else: 46 | filename = results['img_info']['filename'] 47 | 48 | img_bytes = self.file_client.get(filename) 49 | img = mmcv.imfrombytes(img_bytes, flag=self.color_type) 50 | if self.to_float32: 51 | img = img.astype(np.float32) 52 | 53 | results['filename'] = filename 54 | results['ori_filename'] = results['img_info']['filename'] 55 | results['img'] = img 56 | results['img_shape'] = img.shape 57 | results['ori_shape'] = img.shape 58 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 59 | results['img_norm_cfg'] = dict( 60 | mean=np.zeros(num_channels, dtype=np.float32), 61 | std=np.ones(num_channels, dtype=np.float32), 62 | to_rgb=False) 63 | return results 64 | 65 | def __repr__(self): 66 | repr_str = 
(f'{self.__class__.__name__}(' 67 | f'to_float32={self.to_float32}, ' 68 | f"color_type='{self.color_type}', " 69 | f'file_client_args={self.file_client_args})') 70 | return repr_str 71 | -------------------------------------------------------------------------------- /mmcls/datasets/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .distributed_sampler import DistributedSampler 3 | from .repeat_aug import RepeatAugSampler 4 | 5 | __all__ = ('DistributedSampler', 'RepeatAugSampler') 6 | -------------------------------------------------------------------------------- /mmcls/datasets/samplers/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from torch.utils.data import DistributedSampler as _DistributedSampler 4 | 5 | from mmcls.core.utils import sync_random_seed 6 | from mmcls.datasets import SAMPLERS 7 | 8 | 9 | @SAMPLERS.register_module() 10 | class DistributedSampler(_DistributedSampler): 11 | 12 | def __init__(self, 13 | dataset, 14 | num_replicas=None, 15 | rank=None, 16 | shuffle=True, 17 | round_up=True, 18 | seed=0): 19 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 20 | self.shuffle = shuffle 21 | self.round_up = round_up 22 | if self.round_up: 23 | self.total_size = self.num_samples * self.num_replicas 24 | else: 25 | self.total_size = len(self.dataset) 26 | 27 | # In distributed sampling, different ranks should sample 28 | # non-overlapped data in the dataset. Therefore, this function 29 | # is used to make sure that each rank shuffles the data indices 30 | # in the same order based on the same seed. Then different ranks 31 | # could use different indices to select non-overlapped data from the 32 | # same data list. 
33 | self.seed = sync_random_seed(seed) 34 | 35 | def __iter__(self): 36 | # deterministically shuffle based on epoch 37 | if self.shuffle: 38 | g = torch.Generator() 39 | # When :attr:`shuffle=True`, this ensures all replicas 40 | # use a different random ordering for each epoch. 41 | # Otherwise, the next iteration of this sampler will 42 | # yield the same ordering. 43 | g.manual_seed(self.epoch + self.seed) 44 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 45 | else: 46 | indices = torch.arange(len(self.dataset)).tolist() 47 | 48 | # add extra samples to make it evenly divisible 49 | if self.round_up: 50 | indices = ( 51 | indices * 52 | int(self.total_size / len(indices) + 1))[:self.total_size] 53 | assert len(indices) == self.total_size 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | if self.round_up: 58 | assert len(indices) == self.num_samples 59 | 60 | return iter(indices) 61 | -------------------------------------------------------------------------------- /mmcls/datasets/voc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import os.path as osp 3 | import xml.etree.ElementTree as ET 4 | 5 | import mmcv 6 | import numpy as np 7 | 8 | from .builder import DATASETS 9 | from .multi_label import MultiLabelDataset 10 | 11 | 12 | @DATASETS.register_module() 13 | class VOC(MultiLabelDataset): 14 | """`Pascal VOC `_ Dataset.""" 15 | 16 | CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 17 | 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 18 | 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 19 | 'tvmonitor') 20 | 21 | def __init__(self, **kwargs): 22 | super(VOC, self).__init__(**kwargs) 23 | if 'VOC2007' in self.data_prefix: 24 | self.year = 2007 25 | else: 26 | raise ValueError('Cannot infer dataset year from img_prefix.') 27 | 28 | def load_annotations(self): 29 | """Load annotations. 30 | 31 | Returns: 32 | list[dict]: Annotation info from XML file. 33 | """ 34 | data_infos = [] 35 | img_ids = mmcv.list_from_file(self.ann_file) 36 | for img_id in img_ids: 37 | filename = f'JPEGImages/{img_id}.jpg' 38 | xml_path = osp.join(self.data_prefix, 'Annotations', 39 | f'{img_id}.xml') 40 | tree = ET.parse(xml_path) 41 | root = tree.getroot() 42 | labels = [] 43 | labels_difficult = [] 44 | for obj in root.findall('object'): 45 | label_name = obj.find('name').text 46 | # in case customized dataset has wrong labels 47 | # or CLASSES has been override. 48 | if label_name not in self.CLASSES: 49 | continue 50 | label = self.class_to_idx[label_name] 51 | difficult = int(obj.find('difficult').text) 52 | if difficult: 53 | labels_difficult.append(label) 54 | else: 55 | labels.append(label) 56 | 57 | gt_label = np.zeros(len(self.CLASSES)) 58 | # The order cannot be swapped for the case where multiple objects 59 | # of the same kind exist and some are difficult. 
60 | gt_label[labels_difficult] = -1 61 | gt_label[labels] = 1 62 | 63 | info = dict( 64 | img_prefix=self.data_prefix, 65 | img_info=dict(filename=filename), 66 | gt_label=gt_label.astype(np.int8)) 67 | data_infos.append(info) 68 | 69 | return data_infos 70 | -------------------------------------------------------------------------------- /mmcls/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .backbones import * # noqa: F401,F403 3 | from .builder import (BACKBONES, CLASSIFIERS, HEADS, LOSSES, NECKS, 4 | build_backbone, build_classifier, build_head, build_loss, 5 | build_neck) 6 | from .classifiers import * # noqa: F401,F403 7 | from .heads import * # noqa: F401,F403 8 | from .losses import * # noqa: F401,F403 9 | from .necks import * # noqa: F401,F403 10 | 11 | __all__ = [ 12 | 'BACKBONES', 'HEADS', 'NECKS', 'LOSSES', 'CLASSIFIERS', 'build_backbone', 13 | 'build_head', 'build_neck', 'build_loss', 'build_classifier' 14 | ] 15 | -------------------------------------------------------------------------------- /mmcls/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .alexnet import AlexNet 3 | from .conformer import Conformer 4 | from .convmixer import ConvMixer 5 | from .convnext import ConvNeXt 6 | from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt 7 | from .deit import DistilledVisionTransformer 8 | from .densenet import DenseNet 9 | from .efficientnet import EfficientNet 10 | from .hrnet import HRNet 11 | from .lenet import LeNet5 12 | from .mlp_mixer import MlpMixer 13 | from .mobilenet_v2 import MobileNetV2 14 | from .mobilenet_v3 import MobileNetV3 15 | from .poolformer import PoolFormer 16 | from .regnet import RegNet 17 | from .repmlp import RepMLPNet 18 | from .repvgg import RepVGG 19 | from .res2net import Res2Net 20 | from .resnest import ResNeSt 21 | from .resnet import ResNet, ResNetV1c, ResNetV1d 22 | from .resnet_cifar import ResNet_CIFAR 23 | from .resnext import ResNeXt 24 | from .seresnet import SEResNet 25 | from .seresnext import SEResNeXt 26 | from .shufflenet_v1 import ShuffleNetV1 27 | from .shufflenet_v2 import ShuffleNetV2 28 | from .swin_transformer import SwinTransformer 29 | from .t2t_vit import T2T_ViT 30 | from .timm_backbone import TIMMBackbone 31 | from .tnt import TNT 32 | from .twins import PCPVT, SVT 33 | from .van import VAN 34 | from .vgg import VGG 35 | from .vision_transformer import VisionTransformer 36 | 37 | __all__ = [ 38 | 'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d', 39 | 'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1', 40 | 'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer', 41 | 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG', 42 | 'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'PCPVT', 'SVT', 43 | 'EfficientNet', 'ConvNeXt', 'HRNet', 'ResNetV1c', 'ConvMixer', 44 | 'CSPDarkNet', 'CSPResNet', 'CSPResNeXt', 'CSPNet', 'RepMLPNet', 45 | 'PoolFormer', 'DenseNet', 'VAN' 46 | ] 47 | -------------------------------------------------------------------------------- 
/mmcls/models/backbones/alexnet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | 4 | from ..builder import BACKBONES 5 | from .base_backbone import BaseBackbone 6 | 7 | 8 | @BACKBONES.register_module() 9 | class AlexNet(BaseBackbone): 10 | """`AlexNet `_ backbone. 11 | 12 | The input for AlexNet is a 224x224 RGB image. 13 | 14 | Args: 15 | num_classes (int): number of classes for classification. 16 | The default value is -1, which uses the backbone as 17 | a feature extractor without the top classifier. 18 | """ 19 | 20 | def __init__(self, num_classes=-1): 21 | super(AlexNet, self).__init__() 22 | self.num_classes = num_classes 23 | self.features = nn.Sequential( 24 | nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), 25 | nn.ReLU(inplace=True), 26 | nn.MaxPool2d(kernel_size=3, stride=2), 27 | nn.Conv2d(64, 192, kernel_size=5, padding=2), 28 | nn.ReLU(inplace=True), 29 | nn.MaxPool2d(kernel_size=3, stride=2), 30 | nn.Conv2d(192, 384, kernel_size=3, padding=1), 31 | nn.ReLU(inplace=True), 32 | nn.Conv2d(384, 256, kernel_size=3, padding=1), 33 | nn.ReLU(inplace=True), 34 | nn.Conv2d(256, 256, kernel_size=3, padding=1), 35 | nn.ReLU(inplace=True), 36 | nn.MaxPool2d(kernel_size=3, stride=2), 37 | ) 38 | if self.num_classes > 0: 39 | self.classifier = nn.Sequential( 40 | nn.Dropout(), 41 | nn.Linear(256 * 6 * 6, 4096), 42 | nn.ReLU(inplace=True), 43 | nn.Dropout(), 44 | nn.Linear(4096, 4096), 45 | nn.ReLU(inplace=True), 46 | nn.Linear(4096, num_classes), 47 | ) 48 | 49 | def forward(self, x): 50 | 51 | x = self.features(x) 52 | if self.num_classes > 0: 53 | x = x.view(x.size(0), 256 * 6 * 6) 54 | x = self.classifier(x) 55 | 56 | return (x, ) 57 | -------------------------------------------------------------------------------- /mmcls/models/backbones/base_backbone.py: -------------------------------------------------------------------------------- 1 
| # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from mmcv.runner import BaseModule 5 | 6 | 7 | class BaseBackbone(BaseModule, metaclass=ABCMeta): 8 | """Base backbone. 9 | 10 | This class defines the basic functions of a backbone. Any backbone that 11 | inherits this class should at least define its own `forward` function. 12 | """ 13 | 14 | def __init__(self, init_cfg=None): 15 | super(BaseBackbone, self).__init__(init_cfg) 16 | 17 | @abstractmethod 18 | def forward(self, x): 19 | """Forward computation. 20 | 21 | Args: 22 | x (tensor | tuple[tensor]): x could be a Torch.tensor or a tuple of 23 | Torch.tensor, containing input data for forward computation. 24 | """ 25 | pass 26 | 27 | def train(self, mode=True): 28 | """Set module status before forward computation. 29 | 30 | Args: 31 | mode (bool): Whether it is train_mode or test_mode 32 | """ 33 | super(BaseBackbone, self).train(mode) 34 | -------------------------------------------------------------------------------- /mmcls/models/backbones/lenet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | 4 | from ..builder import BACKBONES 5 | from .base_backbone import BaseBackbone 6 | 7 | 8 | @BACKBONES.register_module() 9 | class LeNet5(BaseBackbone): 10 | """`LeNet5 `_ backbone. 11 | 12 | The input for LeNet-5 is a 32×32 grayscale image. 13 | 14 | Args: 15 | num_classes (int): number of classes for classification. 16 | The default value is -1, which uses the backbone as 17 | a feature extractor without the top classifier. 
18 | """ 19 | 20 | def __init__(self, num_classes=-1): 21 | super(LeNet5, self).__init__() 22 | self.num_classes = num_classes 23 | self.features = nn.Sequential( 24 | nn.Conv2d(1, 6, kernel_size=5, stride=1), nn.Tanh(), 25 | nn.AvgPool2d(kernel_size=2), 26 | nn.Conv2d(6, 16, kernel_size=5, stride=1), nn.Tanh(), 27 | nn.AvgPool2d(kernel_size=2), 28 | nn.Conv2d(16, 120, kernel_size=5, stride=1), nn.Tanh()) 29 | if self.num_classes > 0: 30 | self.classifier = nn.Sequential( 31 | nn.Linear(120, 84), 32 | nn.Tanh(), 33 | nn.Linear(84, num_classes), 34 | ) 35 | 36 | def forward(self, x): 37 | 38 | x = self.features(x) 39 | if self.num_classes > 0: 40 | x = self.classifier(x.squeeze()) 41 | 42 | return (x, ) 43 | -------------------------------------------------------------------------------- /mmcls/models/backbones/resnet_cifar.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | from mmcv.cnn import build_conv_layer, build_norm_layer 4 | 5 | from ..builder import BACKBONES 6 | from .resnet import ResNet 7 | 8 | 9 | @BACKBONES.register_module() 10 | class ResNet_CIFAR(ResNet): 11 | """ResNet backbone for CIFAR. 12 | 13 | Compared to standard ResNet, it uses `kernel_size=3` and `stride=1` in 14 | conv1, and does not apply MaxPoolinng after stem. It has been proven to 15 | be more efficient than standard ResNet in other public codebase, e.g., 16 | `https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py`. 17 | 18 | Args: 19 | depth (int): Network depth, from {18, 34, 50, 101, 152}. 20 | in_channels (int): Number of input image channels. Default: 3. 21 | stem_channels (int): Output channels of the stem layer. Default: 64. 22 | base_channels (int): Middle channels of the first stage. Default: 64. 23 | num_stages (int): Stages of the network. Default: 4. 24 | strides (Sequence[int]): Strides of the first block of each stage. 
25 | Default: ``(1, 2, 2, 2)``. 26 | dilations (Sequence[int]): Dilation of each stage. 27 | Default: ``(1, 1, 1, 1)``. 28 | out_indices (Sequence[int]): Output from which stages. If only one 29 | stage is specified, a single tensor (feature map) is returned, 30 | otherwise multiple stages are specified, a tuple of tensors will 31 | be returned. Default: ``(3, )``. 32 | style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two 33 | layer is the 3x3 conv layer, otherwise the stride-two layer is 34 | the first 1x1 conv layer. 35 | deep_stem (bool): This network has specific designed stem, thus it is 36 | asserted to be False. 37 | avg_down (bool): Use AvgPool instead of stride conv when 38 | downsampling in the bottleneck. Default: False. 39 | frozen_stages (int): Stages to be frozen (stop grad and set eval mode). 40 | -1 means not freezing any parameters. Default: -1. 41 | conv_cfg (dict | None): The config dict for conv layers. Default: None. 42 | norm_cfg (dict): The config dict for norm layers. 43 | norm_eval (bool): Whether to set norm layers to eval mode, namely, 44 | freeze running stats (mean and var). Note: Effect on Batch Norm 45 | and its variants only. Default: False. 46 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 47 | memory while slowing down the training speed. Default: False. 48 | zero_init_residual (bool): Whether to use zero init for last norm layer 49 | in resblocks to let them behave as identity. Default: True. 
50 | """ 51 | 52 | def __init__(self, depth, deep_stem=False, **kwargs): 53 | super(ResNet_CIFAR, self).__init__( 54 | depth, deep_stem=deep_stem, **kwargs) 55 | assert not self.deep_stem, 'ResNet_CIFAR do not support deep_stem' 56 | 57 | def _make_stem_layer(self, in_channels, base_channels): 58 | self.conv1 = build_conv_layer( 59 | self.conv_cfg, 60 | in_channels, 61 | base_channels, 62 | kernel_size=3, 63 | stride=1, 64 | padding=1, 65 | bias=False) 66 | self.norm1_name, norm1 = build_norm_layer( 67 | self.norm_cfg, base_channels, postfix=1) 68 | self.add_module(self.norm1_name, norm1) 69 | self.relu = nn.ReLU(inplace=True) 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = self.norm1(x) 74 | x = self.relu(x) 75 | outs = [] 76 | for i, layer_name in enumerate(self.res_layers): 77 | res_layer = getattr(self, layer_name) 78 | x = res_layer(x) 79 | if i in self.out_indices: 80 | outs.append(x) 81 | return tuple(outs) 82 | -------------------------------------------------------------------------------- /mmcls/models/builder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmcv.cnn import MODELS as MMCV_MODELS 3 | from mmcv.cnn.bricks.registry import ATTENTION as MMCV_ATTENTION 4 | from mmcv.utils import Registry 5 | 6 | MODELS = Registry('models', parent=MMCV_MODELS) 7 | 8 | BACKBONES = MODELS 9 | NECKS = MODELS 10 | HEADS = MODELS 11 | LOSSES = MODELS 12 | CLASSIFIERS = MODELS 13 | 14 | ATTENTION = Registry('attention', parent=MMCV_ATTENTION) 15 | 16 | 17 | def build_backbone(cfg): 18 | """Build backbone.""" 19 | return BACKBONES.build(cfg) 20 | 21 | 22 | def build_neck(cfg): 23 | """Build neck.""" 24 | return NECKS.build(cfg) 25 | 26 | 27 | def build_head(cfg): 28 | """Build head.""" 29 | return HEADS.build(cfg) 30 | 31 | 32 | def build_loss(cfg): 33 | """Build loss.""" 34 | return LOSSES.build(cfg) 35 | 36 | 37 | def build_classifier(cfg): 38 | return CLASSIFIERS.build(cfg) 39 | -------------------------------------------------------------------------------- /mmcls/models/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .base import BaseClassifier 3 | from .image import ImageClassifier 4 | 5 | __all__ = ['BaseClassifier', 'ImageClassifier'] 6 | -------------------------------------------------------------------------------- /mmcls/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .cls_head import ClsHead 3 | from .conformer_head import ConformerHead 4 | from .deit_head import DeiTClsHead 5 | from .linear_head import LinearClsHead 6 | from .multi_label_head import MultiLabelClsHead 7 | from .multi_label_linear_head import MultiLabelLinearClsHead 8 | from .stacked_head import StackedLinearClsHead 9 | from .vision_transformer_head import VisionTransformerClsHead 10 | 11 | __all__ = [ 12 | 'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead', 13 | 'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead', 14 | 'ConformerHead' 15 | ] 16 | -------------------------------------------------------------------------------- /mmcls/models/heads/base_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from abc import ABCMeta, abstractmethod 3 | 4 | from mmcv.runner import BaseModule 5 | 6 | 7 | class BaseHead(BaseModule, metaclass=ABCMeta): 8 | """Base head.""" 9 | 10 | def __init__(self, init_cfg=None): 11 | super(BaseHead, self).__init__(init_cfg) 12 | 13 | @abstractmethod 14 | def forward_train(self, x, gt_label, **kwargs): 15 | pass 16 | -------------------------------------------------------------------------------- /mmcls/models/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..builder import HEADS 6 | from .cls_head import ClsHead 7 | 8 | 9 | @HEADS.register_module() 10 | class LinearClsHead(ClsHead): 11 | """Linear classifier head. 12 | 13 | Args: 14 | num_classes (int): Number of categories excluding the background 15 | category. 16 | in_channels (int): Number of channels in the input feature map. 17 | init_cfg (dict | optional): The extra init config of layers. 18 | Defaults to use dict(type='Normal', layer='Linear', std=0.01). 
19 | """ 20 | 21 | def __init__(self, 22 | num_classes, 23 | in_channels, 24 | init_cfg=dict(type='Normal', layer='Linear', std=0.01), 25 | *args, 26 | **kwargs): 27 | super(LinearClsHead, self).__init__(init_cfg=init_cfg, *args, **kwargs) 28 | 29 | self.in_channels = in_channels 30 | self.num_classes = num_classes 31 | 32 | if self.num_classes <= 0: 33 | raise ValueError( 34 | f'num_classes={num_classes} must be a positive integer') 35 | 36 | self.fc = nn.Linear(self.in_channels, self.num_classes) 37 | 38 | def pre_logits(self, x): 39 | if isinstance(x, tuple): 40 | x = x[-1] 41 | return x 42 | 43 | def simple_test(self, x, softmax=True, post_process=True): 44 | """Inference without augmentation. 45 | 46 | Args: 47 | x (tuple[Tensor]): The input features. 48 | Multi-stage inputs are acceptable but only the last stage will 49 | be used to classify. The shape of every item should be 50 | ``(num_samples, in_channels)``. 51 | softmax (bool): Whether to softmax the classification score. 52 | post_process (bool): Whether to do post processing the 53 | inference results. It will convert the output to a list. 54 | 55 | Returns: 56 | Tensor | list: The inference results. 57 | 58 | - If no post processing, the output is a tensor with shape 59 | ``(num_samples, num_classes)``. 60 | - If post processing, the output is a multi-dimentional list of 61 | float and the dimensions are ``(num_samples, num_classes)``. 
62 | """ 63 | x = self.pre_logits(x) 64 | cls_score = self.fc(x) 65 | 66 | if softmax: 67 | pred = ( 68 | F.softmax(cls_score, dim=1) if cls_score is not None else None) 69 | else: 70 | pred = cls_score 71 | 72 | if post_process: 73 | return self.post_process(pred) 74 | else: 75 | return pred 76 | 77 | def forward_train(self, x, gt_label, **kwargs): 78 | x = self.pre_logits(x) 79 | cls_score = self.fc(x) 80 | losses = self.loss(cls_score, gt_label, **kwargs) 81 | return losses 82 | -------------------------------------------------------------------------------- /mmcls/models/heads/multi_label_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | 4 | from ..builder import HEADS, build_loss 5 | from ..utils import is_tracing 6 | from .base_head import BaseHead 7 | 8 | 9 | @HEADS.register_module() 10 | class MultiLabelClsHead(BaseHead): 11 | """Classification head for multilabel task. 12 | 13 | Args: 14 | loss (dict): Config of classification loss. 
15 | """ 16 | 17 | def __init__(self, 18 | loss=dict( 19 | type='CrossEntropyLoss', 20 | use_sigmoid=True, 21 | reduction='mean', 22 | loss_weight=1.0), 23 | init_cfg=None): 24 | super(MultiLabelClsHead, self).__init__(init_cfg=init_cfg) 25 | 26 | assert isinstance(loss, dict) 27 | 28 | self.compute_loss = build_loss(loss) 29 | 30 | def loss(self, cls_score, gt_label): 31 | gt_label = gt_label.type_as(cls_score) 32 | num_samples = len(cls_score) 33 | losses = dict() 34 | 35 | # map difficult examples to positive ones 36 | _gt_label = torch.abs(gt_label) 37 | # compute loss 38 | loss = self.compute_loss(cls_score, _gt_label, avg_factor=num_samples) 39 | losses['loss'] = loss 40 | return losses 41 | 42 | def forward_train(self, cls_score, gt_label, **kwargs): 43 | if isinstance(cls_score, tuple): 44 | cls_score = cls_score[-1] 45 | gt_label = gt_label.type_as(cls_score) 46 | losses = self.loss(cls_score, gt_label, **kwargs) 47 | return losses 48 | 49 | def pre_logits(self, x): 50 | if isinstance(x, tuple): 51 | x = x[-1] 52 | 53 | from mmcls.utils import get_root_logger 54 | logger = get_root_logger() 55 | logger.warning( 56 | 'The input of MultiLabelClsHead should be already logits. ' 57 | 'Please modify the backbone if you want to get pre-logits feature.' 58 | ) 59 | return x 60 | 61 | def simple_test(self, x, sigmoid=True, post_process=True): 62 | """Inference without augmentation. 63 | 64 | Args: 65 | cls_score (tuple[Tensor]): The input classification score logits. 66 | Multi-stage inputs are acceptable but only the last stage will 67 | be used to classify. The shape of every item should be 68 | ``(num_samples, num_classes)``. 69 | sigmoid (bool): Whether to sigmoid the classification score. 70 | post_process (bool): Whether to do post processing the 71 | inference results. It will convert the output to a list. 72 | 73 | Returns: 74 | Tensor | list: The inference results. 
75 | 76 | - If no post processing, the output is a tensor with shape 77 | ``(num_samples, num_classes)``. 78 | - If post processing, the output is a multi-dimentional list of 79 | float and the dimensions are ``(num_samples, num_classes)``. 80 | """ 81 | if isinstance(x, tuple): 82 | x = x[-1] 83 | 84 | if sigmoid: 85 | pred = torch.sigmoid(x) if x is not None else None 86 | else: 87 | pred = x 88 | 89 | if post_process: 90 | return self.post_process(pred) 91 | else: 92 | return pred 93 | 94 | def post_process(self, pred): 95 | on_trace = is_tracing() 96 | if torch.onnx.is_in_onnx_export() or on_trace: 97 | return pred 98 | pred = list(pred.detach().cpu().numpy()) 99 | return pred 100 | -------------------------------------------------------------------------------- /mmcls/models/heads/multi_label_linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..builder import HEADS 6 | from .multi_label_head import MultiLabelClsHead 7 | 8 | 9 | @HEADS.register_module() 10 | class MultiLabelLinearClsHead(MultiLabelClsHead): 11 | """Linear classification head for multilabel task. 12 | 13 | Args: 14 | num_classes (int): Number of categories. 15 | in_channels (int): Number of channels in the input feature map. 16 | loss (dict): Config of classification loss. 17 | init_cfg (dict | optional): The extra init config of layers. 18 | Defaults to use dict(type='Normal', layer='Linear', std=0.01). 
19 | """ 20 | 21 | def __init__(self, 22 | num_classes, 23 | in_channels, 24 | loss=dict( 25 | type='CrossEntropyLoss', 26 | use_sigmoid=True, 27 | reduction='mean', 28 | loss_weight=1.0), 29 | init_cfg=dict(type='Normal', layer='Linear', std=0.01)): 30 | super(MultiLabelLinearClsHead, self).__init__( 31 | loss=loss, init_cfg=init_cfg) 32 | 33 | if num_classes <= 0: 34 | raise ValueError( 35 | f'num_classes={num_classes} must be a positive integer') 36 | 37 | self.in_channels = in_channels 38 | self.num_classes = num_classes 39 | 40 | self.fc = nn.Linear(self.in_channels, self.num_classes) 41 | 42 | def pre_logits(self, x): 43 | if isinstance(x, tuple): 44 | x = x[-1] 45 | return x 46 | 47 | def forward_train(self, x, gt_label, **kwargs): 48 | x = self.pre_logits(x) 49 | gt_label = gt_label.type_as(x) 50 | cls_score = self.fc(x) 51 | losses = self.loss(cls_score, gt_label, **kwargs) 52 | return losses 53 | 54 | def simple_test(self, x, sigmoid=True, post_process=True): 55 | """Inference without augmentation. 56 | 57 | Args: 58 | x (tuple[Tensor]): The input features. 59 | Multi-stage inputs are acceptable but only the last stage will 60 | be used to classify. The shape of every item should be 61 | ``(num_samples, in_channels)``. 62 | sigmoid (bool): Whether to sigmoid the classification score. 63 | post_process (bool): Whether to do post processing the 64 | inference results. It will convert the output to a list. 65 | 66 | Returns: 67 | Tensor | list: The inference results. 68 | 69 | - If no post processing, the output is a tensor with shape 70 | ``(num_samples, num_classes)``. 71 | - If post processing, the output is a multi-dimentional list of 72 | float and the dimensions are ``(num_samples, num_classes)``. 
73 | """ 74 | x = self.pre_logits(x) 75 | cls_score = self.fc(x) 76 | 77 | if sigmoid: 78 | pred = torch.sigmoid(cls_score) if cls_score is not None else None 79 | else: 80 | pred = cls_score 81 | 82 | if post_process: 83 | return self.post_process(pred) 84 | else: 85 | return pred 86 | -------------------------------------------------------------------------------- /mmcls/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .accuracy import Accuracy, accuracy 3 | from .asymmetric_loss import AsymmetricLoss, asymmetric_loss 4 | from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy, 5 | cross_entropy) 6 | from .focal_loss import FocalLoss, sigmoid_focal_loss 7 | from .label_smooth_loss import LabelSmoothLoss 8 | from .seesaw_loss import SeesawLoss 9 | from .utils import (convert_to_one_hot, reduce_loss, weight_reduce_loss, 10 | weighted_loss) 11 | 12 | __all__ = [ 13 | 'accuracy', 'Accuracy', 'asymmetric_loss', 'AsymmetricLoss', 14 | 'cross_entropy', 'binary_cross_entropy', 'CrossEntropyLoss', 'reduce_loss', 15 | 'weight_reduce_loss', 'LabelSmoothLoss', 'weighted_loss', 'FocalLoss', 16 | 'sigmoid_focal_loss', 'convert_to_one_hot', 'SeesawLoss' 17 | ] 18 | -------------------------------------------------------------------------------- /mmcls/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .gap import GlobalAveragePooling 3 | from .gem import GeneralizedMeanPooling 4 | from .hr_fuse import HRFuseScales 5 | 6 | __all__ = ['GlobalAveragePooling', 'GeneralizedMeanPooling', 'HRFuseScales'] 7 | -------------------------------------------------------------------------------- /mmcls/models/necks/gap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. 
All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | 5 | from ..builder import NECKS 6 | 7 | 8 | @NECKS.register_module() 9 | class GlobalAveragePooling(nn.Module): 10 | """Global Average Pooling neck. 11 | 12 | Note that we use `view` to remove extra channel after pooling. We do not 13 | use `squeeze` as it will also remove the batch dimension when the tensor 14 | has a batch dimension of size 1, which can lead to unexpected errors. 15 | 16 | Args: 17 | dim (int): Dimensions of each sample channel, can be one of {1, 2, 3}. 18 | Default: 2 19 | """ 20 | 21 | def __init__(self, dim=2): 22 | super(GlobalAveragePooling, self).__init__() 23 | assert dim in [1, 2, 3], 'GlobalAveragePooling dim only support ' \ 24 | f'{1, 2, 3}, get {dim} instead.' 25 | if dim == 1: 26 | self.gap = nn.AdaptiveAvgPool1d(1) 27 | elif dim == 2: 28 | self.gap = nn.AdaptiveAvgPool2d((1, 1)) 29 | else: 30 | self.gap = nn.AdaptiveAvgPool3d((1, 1, 1)) 31 | 32 | def init_weights(self): 33 | pass 34 | 35 | def forward(self, inputs): 36 | if isinstance(inputs, tuple): 37 | outs = tuple([self.gap(x) for x in inputs]) 38 | outs = tuple( 39 | [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) 40 | elif isinstance(inputs, torch.Tensor): 41 | outs = self.gap(inputs) 42 | outs = outs.view(inputs.size(0), -1) 43 | else: 44 | raise TypeError('neck inputs should be tuple or torch.tensor') 45 | return outs 46 | -------------------------------------------------------------------------------- /mmcls/models/necks/gem.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | from torch import Tensor, nn 4 | from torch.nn import functional as F 5 | from torch.nn.parameter import Parameter 6 | 7 | from ..builder import NECKS 8 | 9 | 10 | def gem(x: Tensor, p: Parameter, eps: float = 1e-6, clamp=True) -> Tensor: 11 | if clamp: 12 | x = x.clamp(min=eps) 13 | return F.avg_pool2d(x.pow(p), (x.size(-2), x.size(-1))).pow(1. / p) 14 | 15 | 16 | @NECKS.register_module() 17 | class GeneralizedMeanPooling(nn.Module): 18 | """Generalized Mean Pooling neck. 19 | 20 | Note that we use `view` to remove extra channel after pooling. We do not 21 | use `squeeze` as it will also remove the batch dimension when the tensor 22 | has a batch dimension of size 1, which can lead to unexpected errors. 23 | 24 | Args: 25 | p (float): Parameter value. 26 | Default: 3. 27 | eps (float): epsilon. 28 | Default: 1e-6 29 | clamp (bool): Use clamp before pooling. 30 | Default: True 31 | """ 32 | 33 | def __init__(self, p=3., eps=1e-6, clamp=True): 34 | assert p >= 1, "'p' must be a value greater then 1" 35 | super(GeneralizedMeanPooling, self).__init__() 36 | self.p = Parameter(torch.ones(1) * p) 37 | self.eps = eps 38 | self.clamp = clamp 39 | 40 | def forward(self, inputs): 41 | if isinstance(inputs, tuple): 42 | outs = tuple([ 43 | gem(x, p=self.p, eps=self.eps, clamp=self.clamp) 44 | for x in inputs 45 | ]) 46 | outs = tuple( 47 | [out.view(x.size(0), -1) for out, x in zip(outs, inputs)]) 48 | elif isinstance(inputs, torch.Tensor): 49 | outs = gem(inputs, p=self.p, eps=self.eps, clamp=self.clamp) 50 | outs = outs.view(inputs.size(0), -1) 51 | else: 52 | raise TypeError('neck inputs should be tuple or torch.tensor') 53 | return outs 54 | -------------------------------------------------------------------------------- /mmcls/models/necks/hr_fuse.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
import torch.nn as nn
from mmcv.cnn.bricks import ConvModule
from mmcv.runner import BaseModule

from ..backbones.resnet import Bottleneck, ResLayer
from ..builder import NECKS


@NECKS.register_module()
class HRFuseScales(BaseModule):
    """Fuse feature map of multiple scales in HRNet.

    Scale ``i`` is first lifted to ``(128, 256, 512, 1024)[i]`` channels by a
    one-``Bottleneck`` ``ResLayer``. Starting from the first scale, the running
    feature is downsampled by a stride-2 3x3 conv and added to the next lifted
    scale. A final 1x1 conv maps the result to ``out_channels``.

    Args:
        in_channels (list[int]): The input channels of all scales. Between
            1 and 4 scales are supported.
        out_channels (int): The channels of fused feature map.
            Defaults to 2048.
        norm_cfg (dict): dictionary to construct norm layers.
            Defaults to ``dict(type='BN', momentum=0.1)``.
        init_cfg (dict | list[dict], optional): Initialization config dict.
            Defaults to ``dict(type='Normal', layer='Linear', std=0.01)``.
    """

    def __init__(self,
                 in_channels,
                 out_channels=2048,
                 norm_cfg=dict(type='BN', momentum=0.1),
                 init_cfg=dict(type='Normal', layer='Linear', std=0.01)):
        super(HRFuseScales, self).__init__(init_cfg=init_cfg)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg

        block_type = Bottleneck
        # Per-scale intermediate channels (named so it does not shadow the
        # ``out_channels`` argument).
        inter_channels = [128, 256, 512, 1024]
        num_scales = len(in_channels)
        assert 0 < num_scales <= len(inter_channels), \
            f'HRFuseScales supports 1 to {len(inter_channels)} scales, ' \
            f'but got {num_scales}.'

        # Increase the channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        increase_layers = []
        for i in range(num_scales):
            increase_layers.append(
                ResLayer(
                    block_type,
                    in_channels=in_channels[i],
                    out_channels=inter_channels[i],
                    num_blocks=1,
                    stride=1,
                ))
        self.increase_layers = nn.ModuleList(increase_layers)

        # Downsample feature maps in each scale.
        downsample_layers = []
        for i in range(num_scales - 1):
            downsample_layers.append(
                ConvModule(
                    in_channels=inter_channels[i],
                    out_channels=inter_channels[i + 1],
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    norm_cfg=self.norm_cfg,
                    bias=False,
                ))
        self.downsample_layers = nn.ModuleList(downsample_layers)

        # The final conv block before final classifier linear layer. Its input
        # is the fused feature of the last scale actually used, so fewer than
        # four scales no longer cause a channel mismatch.
        self.final_layer = ConvModule(
            in_channels=inter_channels[num_scales - 1],
            out_channels=self.out_channels,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            bias=False,
        )

    def forward(self, x):
        """Fuse a tuple of multi-scale features.

        Args:
            x (tuple[Tensor]): One feature map per scale, ordered like
                ``in_channels``.

        Returns:
            tuple[Tensor]: A one-element tuple holding the fused feature map.
        """
        assert isinstance(x, tuple) and len(x) == len(self.in_channels)

        feat = self.increase_layers[0](x[0])
        for downsample, increase, scale_feat in zip(self.downsample_layers,
                                                    self.increase_layers[1:],
                                                    x[1:]):
            feat = downsample(feat) + increase(scale_feat)

        return (self.final_layer(feat), )


# ---- file: /mmcls/models/utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .attention import MultiheadAttention, ShiftWindowMSA
from .augment.augments import Augments
from .channel_shuffle import channel_shuffle
from .embed import (HybridEmbed, PatchEmbed, PatchMerging, resize_pos_embed,
                    resize_relative_position_bias_table)
from .helpers import is_tracing, to_2tuple, to_3tuple, to_4tuple, to_ntuple
from .inverted_residual import InvertedResidual
from .make_divisible import make_divisible
from .position_encoding import ConditionalPositionEncoding
from .se_layer import SELayer

__all__ = [
    'channel_shuffle', 'make_divisible', 'InvertedResidual', 'SELayer',
    'to_ntuple', 'to_2tuple', 'to_3tuple', 'to_4tuple', 'PatchEmbed',
    'PatchMerging', 'HybridEmbed', 'Augments', 'ShiftWindowMSA', 'is_tracing',
    'MultiheadAttention', 'ConditionalPositionEncoding', 'resize_pos_embed',
    'resize_relative_position_bias_table'
]

# ---- file: /mmcls/models/utils/augment/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .augments import Augments
from .cutmix import BatchCutMixLayer
from .identity import Identity
from .mixup import BatchMixupLayer
from .resizemix import BatchResizeMixLayer

__all__ = ('Augments', 'BatchCutMixLayer', 'Identity', 'BatchMixupLayer',
           'BatchResizeMixLayer')

# ---- file: /mmcls/models/utils/augment/augments.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import math
import random

import numpy as np

from .builder import build_augment

# Absolute tolerance used when checking that augmentation probabilities sum
# to 1. Exact float comparison rejects valid configs, e.g.
# ``sum([0.7, 0.2, 0.1]) == 0.9999999999999999``.
_PROB_TOL = 1e-6


class Augments(object):
    """Data augments.

    We implement some data augmentation methods, such as mixup, cutmix.

    Args:
        augments_cfg (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict`):
            Config dict of augments

    Example:
        >>> augments_cfg = [
        ...     dict(type='BatchCutMix', alpha=1., num_classes=10, prob=0.5),
        ...     dict(type='BatchMixup', alpha=1., num_classes=10, prob=0.3)
        ... ]
        >>> augments = Augments(augments_cfg)
        >>> imgs = torch.randn(16, 3, 32, 32)
        >>> label = torch.randint(0, 10, (16, ))
        >>> imgs, label = augments(imgs, label)

    To decide which augmentation within Augments block is used
    the following rule is applied.
    We pick augmentation based on the probabilities. In the example above,
    we decide if we should use BatchCutMix with probability 0.5,
    BatchMixup 0.3. As Identity is not in augments_cfg, we use Identity with
    probability 1 - 0.5 - 0.3 = 0.2.

    Probability sums are checked with a small tolerance so float rounding
    does not reject a valid config. The stored ``augment_probs`` are
    renormalized to sum to 1.
    """

    def __init__(self, augments_cfg):
        super(Augments, self).__init__()

        if isinstance(augments_cfg, dict):
            augments_cfg = [augments_cfg]

        assert len(augments_cfg) > 0, \
            'The length of augments_cfg should be positive.'
        self.augments = [build_augment(cfg) for cfg in augments_cfg]
        self.augment_probs = [aug.prob for aug in self.augments]
        total_prob = sum(self.augment_probs)

        has_identity = any(cfg['type'] == 'Identity' for cfg in augments_cfg)
        if has_identity:
            assert math.isclose(total_prob, 1.0, abs_tol=_PROB_TOL), \
                'The sum of augmentation probabilities should equal to 1,' \
                ' but got {:.2f}'.format(total_prob)
        else:
            assert total_prob <= 1.0 + _PROB_TOL, \
                'The sum of augmentation probabilities should less than or ' \
                'equal to 1, but got {:.2f}'.format(total_prob)
            identity_prob = 1 - total_prob
            # Fill the remaining probability mass with Identity, ignoring a
            # residual that is only float rounding noise.
            if identity_prob > _PROB_TOL:
                num_classes = self.augments[0].num_classes
                self.augments += [
                    build_augment(
                        dict(
                            type='Identity',
                            num_classes=num_classes,
                            prob=identity_prob))
                ]
                self.augment_probs += [identity_prob]

        # Renormalize so ``np.random.RandomState.choice`` (which requires
        # ``p`` to sum to 1 within a tight tolerance) accepts the values.
        total_prob = sum(self.augment_probs)
        self.augment_probs = [prob / total_prob for prob in self.augment_probs]

    def __call__(self, img, gt_label):
        """Apply one augmentation sampled according to ``augment_probs``.

        Returns:
            tuple: The ``(img, gt_label)`` pair produced by the chosen
            augmentation, or the inputs unchanged if there are none.
        """
        if self.augments:
            # A fresh RandomState seeded from Python's ``random`` module.
            random_state = np.random.RandomState(random.randint(0, 2**32 - 1))
            aug = random_state.choice(self.augments, p=self.augment_probs)
            return aug(img, gt_label)
        return img, gt_label


# ---- file: /mmcls/models/utils/augment/builder.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import Registry, build_from_cfg

AUGMENT = Registry('augment')


def build_augment(cfg, default_args=None):
    """Build an augment from ``cfg`` using the ``AUGMENT`` registry."""
    return build_from_cfg(cfg, AUGMENT, default_args)


# ---- file: /mmcls/models/utils/augment/identity.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .builder import AUGMENT
from .utils import one_hot_encoding


@AUGMENT.register_module(name='Identity')
class Identity(object):
    """Change gt_label to one_hot encoding and keep img as the same.

    Args:
        num_classes (int): The number of classes.
        prob (float): Probability of selecting this (identity) augmentation.
            It should be in range [0, 1]. Defaults to 1.0.
    """

    def __init__(self, num_classes, prob=1.0):
        super(Identity, self).__init__()

        assert isinstance(num_classes, int)
        assert isinstance(prob, float) and 0.0 <= prob <= 1.0

        self.num_classes = num_classes
        self.prob = prob

    def one_hot(self, gt_label):
        """Encode ``gt_label`` with ``one_hot_encoding``."""
        return one_hot_encoding(gt_label, self.num_classes)

    def __call__(self, img, gt_label):
        return img, self.one_hot(gt_label)


# ---- file: /mmcls/models/utils/augment/mixup.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import numpy as np
import torch

from .builder import AUGMENT
from .utils import one_hot_encoding


class BaseMixupLayer(object, metaclass=ABCMeta):
    """Base class for MixupLayer.

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number.
        num_classes (int): The number of classes.
        prob (float): MixUp probability. It should be in range [0, 1].
            Defaults to 1.0.
    """

    def __init__(self, alpha, num_classes, prob=1.0):
        super(BaseMixupLayer, self).__init__()

        assert isinstance(alpha, float) and alpha > 0
        assert isinstance(num_classes, int)
        assert isinstance(prob, float) and 0.0 <= prob <= 1.0

        self.alpha = alpha
        self.num_classes = num_classes
        self.prob = prob

    @abstractmethod
    def mixup(self, imgs, gt_label):
        pass


@AUGMENT.register_module(name='BatchMixup')
class BatchMixupLayer(BaseMixupLayer):
    r"""Mixup layer for a batch of data.

    Mixup is a method to reduce the memorization of corrupt labels and
    increase the robustness to adversarial examples. It's
    proposed in `mixup: Beyond Empirical Risk Minimization
    <https://arxiv.org/abs/1710.09412>`_

    This method simply linearly mixes pairs of data and their labels.

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            are in the note.
        num_classes (int): The number of classes.
        prob (float): The probability to execute mixup. It should be in
            range [0, 1]. Defaults to 1.0.

    Note:
        The :math:`\alpha` (``alpha``) determines a random distribution
        :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample
        a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random
        distribution.
    """

    def __init__(self, *args, **kwargs):
        super(BatchMixupLayer, self).__init__(*args, **kwargs)

    def mixup(self, img, gt_label):
        """Mix each sample (and its one-hot label) with a permuted partner."""
        one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size = img.size(0)
        index = torch.randperm(batch_size)

        mixed_img = lam * img + (1 - lam) * img[index, :]
        mixed_gt_label = lam * one_hot_gt_label + (
            1 - lam) * one_hot_gt_label[index, :]

        return mixed_img, mixed_gt_label

    def __call__(self, img, gt_label):
        return self.mixup(img, gt_label)


# ---- file: /mmcls/models/utils/augment/resizemix.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F

from mmcls.models.utils.augment.builder import AUGMENT
from .cutmix import BatchCutMixLayer
from .utils import one_hot_encoding


@AUGMENT.register_module(name='BatchResizeMix')
class BatchResizeMixLayer(BatchCutMixLayer):
    r"""ResizeMix Random Paste layer for a batch of data.

    The ResizeMix will resize an image to a small patch and paste it on another
    image. It's proposed in `ResizeMix: Mixing Data with Preserved Object
    Information and True Labels <https://arxiv.org/abs/2012.11101>`_

    Args:
        alpha (float): Parameters for Beta distribution to generate the
            mixing ratio. It should be a positive number. More details
            can be found in :class:`BatchMixupLayer`.
        num_classes (int): The number of classes.
        lam_min(float): The minimum value of lam. Defaults to 0.1.
        lam_max(float): The maximum value of lam. Defaults to 0.8.
        interpolation (str): algorithm used for upsampling:
            'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' |
            'area'. Defaults to 'bilinear'.
        prob (float): The probability to execute resizemix. It should be in
            range [0, 1]. Defaults to 1.0.
        cutmix_minmax (List[float], optional): The min/max area ratio of the
            patches. If not None, the bounding-box of patches is uniform
            sampled within this ratio range, and the ``alpha`` will be ignored.
            Otherwise, the bounding-box is generated according to the
            ``alpha``. Defaults to None.
        correct_lam (bool): Whether to apply lambda correction when cutmix bbox
            clipped by image borders. Defaults to True
        **kwargs: Any other parameters accepted by :class:`BatchCutMixLayer`.

    Note:
        The :math:`\lambda` (``lam``) is the mixing ratio. It's a random
        variable which follows :math:`Beta(\alpha, \alpha)` and is mapped
        linearly to the range [``lam_min``, ``lam_max``].

        .. math::
            \lambda = Beta(\alpha, \alpha) \cdot
                (\lambda_{max} - \lambda_{min}) + \lambda_{min}

        And the resize ratio of source images is calculated by :math:`\lambda`:

        .. math::
            \text{ratio} = \sqrt{1-\lambda}
    """

    def __init__(self,
                 alpha,
                 num_classes,
                 lam_min: float = 0.1,
                 lam_max: float = 0.8,
                 interpolation='bilinear',
                 prob=1.0,
                 cutmix_minmax=None,
                 correct_lam=True,
                 **kwargs):
        super(BatchResizeMixLayer, self).__init__(
            alpha=alpha,
            num_classes=num_classes,
            prob=prob,
            cutmix_minmax=cutmix_minmax,
            correct_lam=correct_lam,
            **kwargs)
        self.lam_min = lam_min
        self.lam_max = lam_max
        self.interpolation = interpolation

    def cutmix(self, img, gt_label):
        """Paste a resized, batch-permuted copy of ``img`` into a box of it.

        Note that ``img`` is modified in place and also returned.
        """
        one_hot_gt_label = one_hot_encoding(gt_label, self.num_classes)

        lam = np.random.beta(self.alpha, self.alpha)
        # Map the Beta sample from [0, 1] linearly onto [lam_min, lam_max].
        lam = lam * (self.lam_max - self.lam_min) + self.lam_min
        batch_size = img.size(0)
        index = torch.randperm(batch_size)

        (bby1, bby2, bbx1,
         bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam)

        # ``img[index]`` (advanced indexing) is a copy, so the source images
        # are read before the in-place paste below overwrites ``img``.
        img[:, :, bby1:bby2, bbx1:bbx2] = F.interpolate(
            img[index],
            size=(bby2 - bby1, bbx2 - bbx1),
            mode=self.interpolation)
        mixed_gt_label = lam * one_hot_gt_label + (
            1 - lam) * one_hot_gt_label[index, :]
        return img, mixed_gt_label


# ---- file: /mmcls/models/utils/augment/utils.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F


def one_hot_encoding(gt, num_classes):
    """Change gt_label to one_hot encoding.

    If the label tensor has 2 or more dimensions, it is returned without
    encoding.

    Args:
        gt (Tensor): The gt label with shape (N,) or shape (N, *).
        num_classes (int): The number of classes.

    Returns:
        Tensor: One hot gt label.
    """
    if gt.ndim == 1:
        # multi-class classification
        return F.one_hot(gt, num_classes=num_classes)
    else:
        # binary classification
        # example. [[0], [1], [1]]
        # multi-label classification
        # example. [[0, 1, 1], [1, 0, 0], [1, 1, 1]]
        return gt


# ---- file: /mmcls/models/utils/channel_shuffle.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def channel_shuffle(x, groups):
    """Channel Shuffle operation.

    This function enables cross-group information flow for multiple groups
    convolution layers.

    Args:
        x (Tensor): The input tensor of shape (B, C, H, W).
        groups (int): The number of groups to divide the input tensor
            in the channel dimension. ``C`` must be divisible by it.

    Returns:
        Tensor: The output tensor after channel shuffle operation.
    """

    batch_size, num_channels, height, width = x.size()
    assert (num_channels % groups == 0), ('num_channels should be '
                                          'divisible by groups')
    channels_per_group = num_channels // groups

    # (B, G, C/G, H, W) -> swap group and per-group axes -> flatten back.
    x = x.view(batch_size, groups, channels_per_group, height, width)
    x = torch.transpose(x, 1, 2).contiguous()
    x = x.view(batch_size, -1, height, width)

    return x


# ---- file: /mmcls/models/utils/helpers.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import collections.abc
import warnings
from itertools import repeat

import torch
from mmcv.utils import digit_version


def is_tracing() -> bool:
    """Determine whether the model is called during the tracing of code with
    ``torch.jit.trace``."""
    if digit_version(torch.__version__) >= digit_version('1.6.0'):
        on_trace = torch.jit.is_tracing()
        # In PyTorch 1.6, torch.jit.is_tracing has a bug.
        # Refers to https://github.com/pytorch/pytorch/issues/42448
        if isinstance(on_trace, bool):
            return on_trace
        else:
            return torch._C._is_tracing()
    else:
        warnings.warn(
            'torch.jit.is_tracing is only supported after v1.6.0. '
            'Therefore is_tracing returns False automatically. Please '
            'set on_trace manually if you are using trace.', UserWarning)
        return False


# From PyTorch internals
def _ntuple(n):
    """A `to_tuple` function generator.

    It returns a function, this function will repeat the input to a tuple of
    length ``n`` if the input is not an Iterable object, otherwise, return the
    input directly. Note that strings are Iterable and are therefore returned
    unchanged.

    Args:
        n (int): The number of the target length.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple


# ---- file: /mmcls/models/utils/make_divisible.py ----
# Copyright (c) OpenMMLab. All rights reserved.
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
    """Make divisible function.

    This function rounds the channel number to the nearest multiple of
    ``divisor`` (halves round up), but never below ``min_value``. If the
    result is still less than ``min_ratio * value``, one more ``divisor`` is
    added.

    Args:
        value (int): The original channel number.
        divisor (int): The divisor to fully divide the channel number.
        min_value (int, optional): The minimum value of the output channel.
            Default: None, means that the minimum value equal to the divisor.
        min_ratio (float): The minimum ratio of the rounded channel
            number to the original channel number. Default: 0.9.
    Returns:
        int: The modified output channel number
    """

    if min_value is None:
        min_value = divisor
    new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than (1-min_ratio).
    if new_value < min_ratio * value:
        new_value += divisor
    return new_value


# ---- file: /mmcls/models/utils/position_encoding.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.runner.base_module import BaseModule


class ConditionalPositionEncoding(BaseModule):
    """The Conditional Position Encoding (CPE) module.

    The CPE is the implementation of `Conditional Positional Encodings
    for Vision Transformers <https://arxiv.org/abs/2102.10882>`_.

    Args:
        in_channels (int): Number of input channels.
        embed_dims (int): The feature dimension. Default: 768.
        stride (int): Stride of conv layer. Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self, in_channels, embed_dims=768, stride=1, init_cfg=None):
        super(ConditionalPositionEncoding, self).__init__(init_cfg=init_cfg)
        # 3x3 grouped conv with ``groups=embed_dims``; presumably used with
        # ``in_channels == embed_dims`` (depthwise), which the stride-1
        # residual in ``forward`` also requires.
        self.proj = nn.Conv2d(
            in_channels,
            embed_dims,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=True,
            groups=embed_dims)
        self.stride = stride

    def forward(self, x, hw_shape):
        """Apply the conditional position encoding to a token sequence.

        Args:
            x (Tensor): Tokens of shape (B, N, C), where ``N == H * W``.
            hw_shape (tuple[int, int]): The spatial size ``(H, W)``.

        Returns:
            Tensor: Tokens of shape (B, N', embed_dims), where ``N'`` is the
            flattened conv output size. With ``stride == 1`` the input is
            added back as a residual.
        """
        B, N, C = x.shape
        H, W = hw_shape
        feat_token = x
        # convert (B, N, C) to (B, C, H, W)
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W).contiguous()
        if self.stride == 1:
            x = self.proj(cnn_feat) + cnn_feat
        else:
            x = self.proj(cnn_feat)
        x = x.flatten(2).transpose(1, 2)
        return x


# ---- file: /mmcls/models/utils/se_layer.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule

from .make_divisible import make_divisible


class SELayer(BaseModule):
    """Squeeze-and-Excitation Module.

    Global average pooling is followed by two 1x1 ``ConvModule`` layers
    (``channels -> squeeze_channels -> channels``). The result is returned
    directly when ``return_weight`` is True, otherwise multiplied with the
    input.

    Args:
        channels (int): The input (and output) channels of the SE layer.
        squeeze_channels (None or int): The intermediate channel number of
            SElayer. Default: None, means the value of ``squeeze_channels``
            is ``make_divisible(channels // ratio, divisor)``.
        ratio (int): Squeeze ratio in SELayer, the intermediate channel will
            be ``make_divisible(channels // ratio, divisor)``. Only used when
            ``squeeze_channels`` is None. Default: 16.
        divisor(int): The divisor to true divide the channel number. Only
            used when ``squeeze_channels`` is None. Default: 8.
        bias (str | bool): The ``bias`` argument forwarded to both
            ``ConvModule`` layers. Default: 'auto'.
        conv_cfg (None or dict): Config dict for convolution layer. Default:
            None, which means using conv2d.
        act_cfg (dict or tuple[dict]): Config dict for activation layer.
            If act_cfg is a dict, two activation layers will be configured
            by this dict. If act_cfg is a tuple of two dicts, the first
            activation layer will be configured by the first dict and the
            second activation layer will be configured by the second dict.
            Default: (dict(type='ReLU'), dict(type='Sigmoid'))
        return_weight(bool): Whether to return the weight. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 channels,
                 squeeze_channels=None,
                 ratio=16,
                 divisor=8,
                 bias='auto',
                 conv_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid')),
                 return_weight=False,
                 init_cfg=None):
        super(SELayer, self).__init__(init_cfg)
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        if squeeze_channels is None:
            squeeze_channels = make_divisible(channels // ratio, divisor)
        assert isinstance(squeeze_channels, int) and squeeze_channels > 0, \
            '"squeeze_channels" should be a positive integer, but get ' + \
            f'{squeeze_channels} instead.'
        self.return_weight = return_weight
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeeze_channels,
            kernel_size=1,
            stride=1,
            bias=bias,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[0])
        self.conv2 = ConvModule(
            in_channels=squeeze_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            bias=bias,
            conv_cfg=conv_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        out = self.global_avgpool(x)
        out = self.conv1(out)
        out = self.conv2(out)
        if self.return_weight:
            return out
        else:
            return x * out


# ---- file: /mmcls/utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .device import auto_select_device
from .distribution import wrap_distributed_model, wrap_non_distributed_model
from .logger import get_root_logger, load_json_log
from .setup_env import setup_multi_processes

__all__ = [
    'collect_env', 'get_root_logger', 'load_json_log', 'setup_multi_processes',
    'wrap_non_distributed_model', 'wrap_distributed_model',
    'auto_select_device'
]

# ---- file: /mmcls/utils/collect_env.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmcv.utils import collect_env as collect_base_env
from mmcv.utils import get_git_hash

import mmcls


def collect_env():
    """Collect the information of the running environments."""
    env_info = collect_base_env()
    env_info['MMClassification'] = mmcls.__version__ + '+' + get_git_hash()[:7]
    return env_info


if __name__ == '__main__':
    for name, val in collect_env().items():
        print(f'{name}: {val}')

# ---- file: /mmcls/utils/device.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
from mmcv.utils import digit_version


def auto_select_device() -> str:
    """Return the name of the device to run on.

    With mmcv >= 1.6.0 this defers to ``mmcv.device.get_device``; otherwise
    it returns ``'cuda'`` when CUDA is available and ``'cpu'`` if not.
    """
    mmcv_version = digit_version(mmcv.__version__)
    if mmcv_version >= digit_version('1.6.0'):
        from mmcv.device import get_device
        return get_device()
    elif torch.cuda.is_available():
        return 'cuda'
    else:
        return 'cpu'


# ---- file: /mmcls/utils/distribution.py ----
# Copyright (c) OpenMMLab. All rights reserved.


def wrap_non_distributed_model(model, device='cuda', dim=0, *args, **kwargs):
    """Wrap module in non-distributed environment by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDataParallel`.
    - For MPS, wrap as :obj:`mmcv.device.mps.MPSDataParallel`.
    - For CPU & IPU, the model is only moved to CPU and not wrapped.

    Args:
        model(:class:`nn.Module`): model to be parallelized.
        device(str): device type, one of 'cuda', 'cpu', 'ipu' or 'mps'.
            Defaults to 'cuda'.
        dim(int): Dimension used to scatter the data. Defaults to 0.

    Returns:
        model(nn.Module): the model to be parallelized.

    Raises:
        RuntimeError: If ``device`` is not one of the supported types.
    """
    if device == 'cuda':
        from mmcv.parallel import MMDataParallel
        model = MMDataParallel(model.cuda(), dim=dim, *args, **kwargs)
    elif device in ('cpu', 'ipu'):
        model = model.cpu()
    elif device == 'mps':
        from mmcv.device import mps
        model = mps.MPSDataParallel(model.to('mps'), dim=dim, *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model


def wrap_distributed_model(model, device='cuda', *args, **kwargs):
    """Build DistributedDataParallel module by device type.

    - For CUDA, wrap as :obj:`mmcv.parallel.MMDistributedDataParallel`.
    - Other device types are not supported by now.

    Args:
        model(:class:`nn.Module`): module to be parallelized.
        device(str): device type. Only 'cuda' is supported.

    Returns:
        model(:class:`nn.Module`): the module to be parallelized

    Raises:
        RuntimeError: If ``device`` is not 'cuda'.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    if device == 'cuda':
        from mmcv.parallel import MMDistributedDataParallel
        model = MMDistributedDataParallel(model.cuda(), *args, **kwargs)
    else:
        raise RuntimeError(f'Unavailable device "{device}"')

    return model


# ---- file: /mmcls/utils/logger.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import json
import logging
from collections import defaultdict

from mmcv.utils import get_logger


def get_root_logger(log_file=None, log_level=logging.INFO):
    """Get root logger.

    Args:
        log_file (str, optional): File path of log. Defaults to None.
        log_level (int, optional): The level of logger.
            Defaults to :obj:`logging.INFO`.

    Returns:
        :obj:`logging.Logger`: The obtained logger
    """
    return get_logger('mmcls', log_file, log_level)


def load_json_log(json_log):
    """load and convert json_logs to log_dicts.

    Empty lines and lines without an ``epoch`` field are skipped.

    Args:
        json_log (str): The path of the json log file.

    Returns:
        dict[int, dict[str, list]]:
            Key is the epoch, value is a sub dict. The keys in each sub dict
            are different metrics, e.g. memory, bbox_mAP, and the value is a
            list of corresponding values in all iterations in this epoch.

            .. code-block:: python

                # An example output
                {
                    1: {'iter': [100, 200, 300], 'loss': [6.94, 6.73, 6.53]},
                    2: {'iter': [100, 200, 300], 'loss': [6.33, 6.20, 6.07]},
                    ...
                }
    """
    log_dict = dict()
    with open(json_log, 'r') as log_file:
        for line in log_file:
            line = line.strip()
            # ``json.loads`` raises on an empty string, so skip blank lines.
            if not line:
                continue
            log = json.loads(line)
            # skip lines without `epoch` field
            if 'epoch' not in log:
                continue
            epoch = log.pop('epoch')
            if epoch not in log_dict:
                log_dict[epoch] = defaultdict(list)
            for k, v in log.items():
                log_dict[epoch][k].append(v)
    return log_dict


# ---- file: /mmcls/utils/setup_env.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import os
import platform
import warnings

import cv2
import torch.multiprocessing as mp


def setup_multi_processes(cfg):
    """Setup multi-processing environment variables.

    Reads ``mp_start_method`` (default ``'fork'``), ``opencv_num_threads``
    (default 0) and ``data.workers_per_gpu`` from ``cfg``.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can change '
                f'this behavior by changing `mp_start_method` in your config.')
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP and MKL threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    for env_name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
        # The environ check comes first so ``cfg.data`` is only read when the
        # variable is unset, exactly as before.
        if env_name not in os.environ and cfg.data.workers_per_gpu > 1:
            num_threads = 1
            warnings.warn(
                f'Setting {env_name} environment variable for each process '
                f'to be {num_threads} in default, to avoid your system being '
                f'overloaded, please further tune the variable for optimal '
                f'performance in your application as needed.')
            os.environ[env_name] = str(num_threads)


# ---- file: /mmcls/version.py ----
# Copyright (c) OpenMMLab. All rights reserved
import re

__version__ = '0.23.2_71ef7ba'


def parse_version_info(version_str):
    """Parse a version string into a tuple.

    A component that is neither all digits nor contains ``rc`` (e.g. a local
    build suffix like ``'2_71ef7ba'``) contributes its leading integer, if
    any. The rest of that component is dropped.

    Args:
        version_str (str): The version string.
    Returns:
        tuple[int | str]: The version info, e.g., "1.3.0" is parsed into
            (1, 3, 0), "2.0.0rc1" is parsed into (2, 0, 0, 'rc1') and
            "0.23.2_71ef7ba" is parsed into (0, 23, 2).
    """
    version_info = []
    for x in version_str.split('.'):
        if x.isdigit():
            version_info.append(int(x))
        elif x.find('rc') != -1:
            patch_version = x.split('rc')
            version_info.append(int(patch_version[0]))
            version_info.append(f'rc{patch_version[1]}')
        else:
            # Previously such components were silently dropped entirely.
            leading_digits = re.match(r'\d+', x)
            if leading_digits is not None:
                version_info.append(int(leading_digits.group()))
    return tuple(version_info)


version_info = parse_version_info(__version__)

__all__ = ['__version__', 'version_info', 'parse_version_info']

# ---- file: /mmfewshot/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
2 | import mmcls # noqa: F401, F403 3 | import mmcv 4 | # import mmdet # noqa: F401, F403 5 | 6 | from .classification import * # noqa: F401, F403 7 | # from .detection import * # noqa: F401, F403 8 | from .utils import * # noqa: F401, F403 9 | from .version import __version__, short_version 10 | 11 | 12 | def digit_version(version_str): 13 | digit_version_ = [] 14 | for x in version_str.split('.'): 15 | if x.isdigit(): 16 | digit_version_.append(int(x)) 17 | elif x.find('rc') != -1: 18 | patch_version = x.split('rc') 19 | digit_version_.append(int(patch_version[0]) - 1) 20 | digit_version_.append(int(patch_version[1])) 21 | return digit_version_ 22 | 23 | 24 | mmcv_minimum_version = '1.3.12' 25 | # By harbor : 1.6.1 seems ok 26 | mmcv_maximum_version = '1.6.1' 27 | mmcv_version = digit_version(mmcv.__version__) 28 | 29 | 30 | assert (digit_version(mmcv_minimum_version) <= mmcv_version 31 | <= digit_version(mmcv_maximum_version)), \ 32 | f'MMCV=={mmcv.__version__} is used but incompatible. ' \ 33 | f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.' 34 | 35 | # By harbor : No need for mmdet 36 | # mmdet_minimum_version = '2.16.0' 37 | # mmdet_maximum_version = '2.25.0' 38 | # mmdet_version = digit_version(mmdet.__version__) 39 | # 40 | # 41 | # assert (digit_version(mmdet_minimum_version) <= mmdet_version 42 | # <= digit_version(mmdet_maximum_version)), \ 43 | # f'MMDET=={mmdet.__version__} is used but incompatible. ' \ 44 | # f'Please install mmdet>={mmdet_minimum_version},\ 45 | # <={mmdet_maximum_version}.' 46 | 47 | mmcls_minimum_version = '0.15.0' 48 | mmcls_maximum_version = '0.25.0' 49 | mmcls_version = digit_version(mmcls.__version__) 50 | 51 | 52 | assert (digit_version(mmcls_minimum_version) <= mmcls_version 53 | <= digit_version(mmcls_maximum_version)), \ 54 | f'MMCLS=={mmcls.__version__} is used but incompatible. ' \ 55 | f'Please install mmcls>={mmcls_minimum_version},\ 56 | <={mmcls_maximum_version}.' 
57 | 58 | __all__ = ['__version__', 'short_version'] 59 | -------------------------------------------------------------------------------- /mmfewshot/classification/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .apis import * # noqa: F401,F403 3 | from .core import * # noqa: F401,F403 4 | from .datasets import * # noqa: F401,F403 5 | from .models import * # noqa: F401,F403 6 | from .utils import * # noqa: F401, F403 7 | -------------------------------------------------------------------------------- /mmfewshot/classification/apis/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .inference import (inference_classifier, init_classifier, 3 | process_support_images, show_result_pyplot) 4 | from .test import (Z_SCORE, multi_gpu_meta_test, single_gpu_meta_test, 5 | test_single_task) 6 | from .train import train_model 7 | 8 | __all__ = [ 9 | 'train_model', 'test_single_task', 'Z_SCORE', 'single_gpu_meta_test', 10 | 'multi_gpu_meta_test', 'init_classifier', 'process_support_images', 11 | 'inference_classifier', 'show_result_pyplot' 12 | ] 13 | -------------------------------------------------------------------------------- /mmfewshot/classification/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .evaluation import * # noqa: F401, F403 3 | -------------------------------------------------------------------------------- /mmfewshot/classification/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .eval_hooks import DistMetaTestEvalHook, MetaTestEvalHook 3 | 4 | __all__ = ['MetaTestEvalHook', 'DistMetaTestEvalHook'] 5 | -------------------------------------------------------------------------------- /mmfewshot/classification/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcls.datasets.builder import DATASETS, PIPELINES 3 | 4 | from .base import BaseFewShotDataset 5 | from .builder import (build_dataloader, build_dataset, 6 | build_meta_test_dataloader) 7 | from .cub import CUBDataset 8 | from .dataset_wrappers import EpisodicDataset, MetaTestDataset 9 | from .mini_imagenet import MiniImageNetDataset 10 | from .pipelines import LoadImageFromBytes 11 | from .tiered_imagenet import TieredImageNetDataset 12 | from .utils import label_wrapper 13 | 14 | __all__ = [ 15 | 'build_dataloader', 'build_dataset', 'DATASETS', 'PIPELINES', 'CUBDataset', 16 | 'LoadImageFromBytes', 'build_meta_test_dataloader', 'MiniImageNetDataset', 17 | 'TieredImageNetDataset', 'label_wrapper', 'BaseFewShotDataset', 18 | 'EpisodicDataset', 'MetaTestDataset' 19 | ] 20 | -------------------------------------------------------------------------------- /mmfewshot/classification/datasets/pipelines/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .loading import LoadImageFromBytes 3 | 4 | __all__ = [ 5 | 'LoadImageFromBytes', 6 | ] 7 | -------------------------------------------------------------------------------- /mmfewshot/classification/datasets/pipelines/loading.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import os.path as osp 3 | from typing import Dict 4 | 5 | import mmcv 6 | import numpy as np 7 | from mmcls.datasets.builder import PIPELINES 8 | from mmcls.datasets.pipelines import LoadImageFromFile 9 | 10 | 11 | @PIPELINES.register_module() 12 | class LoadImageFromBytes(LoadImageFromFile): 13 | """Load an image from bytes.""" 14 | 15 | def __call__(self, results: Dict) -> Dict: 16 | if self.file_client is None: 17 | self.file_client = mmcv.FileClient(**self.file_client_args) 18 | if results['img_prefix'] is not None: 19 | filename = osp.join(results['img_prefix'], 20 | results['img_info']['filename']) 21 | else: 22 | filename = results['img_info']['filename'] 23 | if results.get('img_bytes', None) is None: 24 | img_bytes = self.file_client.get(filename) 25 | else: 26 | img_bytes = results.pop('img_bytes') 27 | img = mmcv.imfrombytes(img_bytes, flag=self.color_type) 28 | if self.to_float32: 29 | img = img.astype(np.float32) 30 | 31 | results['filename'] = filename 32 | results['ori_filename'] = results['img_info']['filename'] 33 | results['img'] = img 34 | results['img_shape'] = img.shape 35 | results['ori_shape'] = img.shape 36 | num_channels = 1 if len(img.shape) < 3 else img.shape[2] 37 | results['img_norm_cfg'] = dict( 38 | mean=np.zeros(num_channels, dtype=np.float32), 39 | std=np.ones(num_channels, dtype=np.float32), 40 | to_rgb=False) 41 | return results 42 | -------------------------------------------------------------------------------- /mmfewshot/classification/datasets/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import List, Union 3 | 4 | import numpy as np 5 | import torch 6 | from torch import Tensor 7 | 8 | 9 | def label_wrapper(labels: Union[Tensor, np.ndarray, List], 10 | class_ids: List[int]) -> Union[Tensor, np.ndarray, list]: 11 | """Map input labels into range of 0 to numbers of classes-1. 
12 | 13 | It is usually used in the meta testing phase, in which the class ids are 14 | random sampled and discontinuous. 15 | 16 | Args: 17 | labels (Tensor | np.ndarray | list): The labels to be wrapped. 18 | class_ids (list[int]): All class ids of labels. 19 | 20 | Returns: 21 | (Tensor | np.ndarray | list): Same type as the input labels. 22 | """ 23 | class_id_map = {class_id: i for i, class_id in enumerate(class_ids)} 24 | if isinstance(labels, torch.Tensor): 25 | wrapped_labels = torch.tensor( 26 | [class_id_map[label.item()] for label in labels]) 27 | wrapped_labels = wrapped_labels.type_as(labels).to(labels.device) 28 | elif isinstance(labels, np.ndarray): 29 | wrapped_labels = np.array([class_id_map[label] for label in labels]) 30 | wrapped_labels = wrapped_labels.astype(labels.dtype) 31 | elif isinstance(labels, (tuple, list)): 32 | wrapped_labels = [class_id_map[label] for label in labels] 33 | else: 34 | raise TypeError('only support torch.Tensor, np.ndarray and list') 35 | return wrapped_labels 36 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcls.models.builder import * # noqa: F401,F403 3 | 4 | from .backbones import * # noqa: F401,F403 5 | from .classifiers import * # noqa: F401,F403 6 | from .heads import * # noqa: F401,F403 7 | from .losses import * # noqa: F401,F403 8 | from .utils import * # noqa: F401,F403 9 | 10 | __all__ = [] 11 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from mmcls.models.builder import BACKBONES 3 | 4 | from .conv4 import Conv4, ConvNet 5 | from .resnet12 import ResNet12 6 | from .wrn import WideResNet, WRN28x10 7 | 8 | __all__ = [ 9 | 'BACKBONES', 'ResNet12', 'Conv4', 'ConvNet', 'WRN28x10', 'WideResNet' 10 | ] 11 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/backbones/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | from torch.distributions import Bernoulli 7 | 8 | # This part of code is modified from https://github.com/kjunelee/MetaOptNet 9 | 10 | 11 | class DropBlock(nn.Module): 12 | 13 | def __init__(self, block_size: int) -> None: 14 | super().__init__() 15 | self.block_size = block_size 16 | 17 | def forward(self, x: Tensor, gamma: float) -> Tensor: 18 | # Randomly zeroes 2D spatial blocks of the input tensor. 
19 | if self.training: 20 | batch_size, channels, height, width = x.shape 21 | bernoulli = Bernoulli(gamma) 22 | mask = bernoulli.sample( 23 | (batch_size, channels, height - (self.block_size - 1), 24 | width - (self.block_size - 1))) 25 | mask = mask.to(x.device) 26 | block_mask = self._compute_block_mask(mask) 27 | countM = block_mask.size()[0] * block_mask.size( 28 | )[1] * block_mask.size()[2] * block_mask.size()[3] 29 | count_ones = block_mask.sum() 30 | 31 | return block_mask * x * (countM / count_ones) 32 | else: 33 | return x 34 | 35 | def _compute_block_mask(self, mask: Tensor) -> Tensor: 36 | left_padding = int((self.block_size - 1) / 2) 37 | right_padding = int(self.block_size / 2) 38 | 39 | non_zero_idxes = mask.nonzero() 40 | nr_blocks = non_zero_idxes.shape[0] 41 | 42 | offsets = torch.stack([ 43 | torch.arange(self.block_size).view(-1, 1).expand( 44 | self.block_size, self.block_size).reshape(-1), 45 | torch.arange(self.block_size).repeat(self.block_size), 46 | ]).t() 47 | offsets = torch.cat( 48 | (torch.zeros(self.block_size**2, 2).long(), offsets.long()), 1) 49 | offsets = offsets.to(mask.device) 50 | 51 | if nr_blocks > 0: 52 | non_zero_idxes = non_zero_idxes.repeat(self.block_size**2, 1) 53 | offsets = offsets.repeat(nr_blocks, 1).view(-1, 4) 54 | offsets = offsets.long() 55 | 56 | block_idxes = non_zero_idxes + offsets 57 | padded_mask = F.pad( 58 | mask, 59 | (left_padding, right_padding, left_padding, right_padding)) 60 | padded_mask[block_idxes[:, 0], block_idxes[:, 1], 61 | block_idxes[:, 2], block_idxes[:, 3]] = 1. 62 | else: 63 | padded_mask = F.pad( 64 | mask, 65 | (left_padding, right_padding, left_padding, right_padding)) 66 | 67 | block_mask = 1 - padded_mask 68 | return block_mask 69 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. 
All rights reserved. 2 | from mmcls.models.builder import CLASSIFIERS 3 | 4 | from .base_finetune import BaseFinetuneClassifier 5 | from .base_metric import BaseMetricClassifier 6 | from .baseline import Baseline 7 | from .baseline_plus import BaselinePlus 8 | from .maml import MAML 9 | from .matching_net import MatchingNet 10 | from .meta_baseline import MetaBaseline 11 | from .neg_margin import NegMargin 12 | from .proto_net import ProtoNet 13 | from .relation_net import RelationNet 14 | 15 | __all__ = [ 16 | 'CLASSIFIERS', 'BaseFinetuneClassifier', 'BaseMetricClassifier', 17 | 'Baseline', 'BaselinePlus', 'ProtoNet', 'MatchingNet', 'RelationNet', 18 | 'NegMargin', 'MetaBaseline', 'MAML' 19 | ] 20 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Dict 3 | 4 | from mmcls.models.builder import CLASSIFIERS 5 | 6 | from .base_finetune import BaseFinetuneClassifier 7 | 8 | 9 | @CLASSIFIERS.register_module() 10 | class Baseline(BaseFinetuneClassifier): 11 | """Implementation of `Baseline `_. 12 | 13 | Args: 14 | head (dict): Config of classification head for training. 15 | meta_test_head (dict): Config of classification head for meta testing. 16 | the `meta_test_head` only will be built and run in meta testing. 
17 | """ 18 | 19 | def __init__(self, 20 | head: Dict = dict( 21 | type='LinearHead', num_classes=100, in_channels=1600), 22 | meta_test_head: Dict = dict( 23 | type='LinearHead', num_classes=5, in_channels=1600), 24 | *args, 25 | **kwargs) -> None: 26 | super().__init__( 27 | head=head, meta_test_head=meta_test_head, *args, **kwargs) 28 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/baseline_plus.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Dict 3 | 4 | from mmcls.models.builder import CLASSIFIERS 5 | 6 | from .base_finetune import BaseFinetuneClassifier 7 | 8 | 9 | @CLASSIFIERS.register_module() 10 | class BaselinePlus(BaseFinetuneClassifier): 11 | """Implementation of `Baseline++ `_. 12 | 13 | Args: 14 | head (dict): Config of classification head for training. 15 | meta_test_head (dict): Config of classification head for meta testing. 16 | the `meta_test_head` only will be built and run in meta testing. 17 | """ 18 | 19 | def __init__(self, 20 | head: Dict = dict( 21 | type='CosineDistanceHead', 22 | num_classes=100, 23 | in_channels=1600), 24 | meta_test_head: Dict = dict( 25 | type='CosineDistanceHead', 26 | num_classes=5, 27 | in_channels=1600), 28 | *args, 29 | **kwargs) -> None: 30 | super().__init__( 31 | head=head, meta_test_head=meta_test_head, *args, **kwargs) 32 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/matching_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import copy 3 | from typing import Dict 4 | 5 | from mmcls.models.builder import CLASSIFIERS 6 | 7 | from .base_metric import BaseMetricClassifier 8 | 9 | 10 | @CLASSIFIERS.register_module() 11 | class MatchingNet(BaseMetricClassifier): 12 | """Implementation of `MatchingNet `_.""" 13 | 14 | def __init__(self, head: Dict = dict(type='MatchingHead'), *args, 15 | **kwargs) -> None: 16 | self.head_cfg = copy.deepcopy(head) 17 | super().__init__(head=head, *args, **kwargs) 18 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/meta_baseline.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Dict 3 | 4 | from mmcls.models.builder import CLASSIFIERS 5 | 6 | from .base_metric import BaseMetricClassifier 7 | 8 | 9 | @CLASSIFIERS.register_module() 10 | class MetaBaseline(BaseMetricClassifier): 11 | """Implementation of `MetaBaseline `_. 12 | 13 | Args: 14 | head (dict): Config of classification head for training. 15 | """ 16 | 17 | def __init__(self, 18 | head: Dict = dict(type='MetaBaselineHead'), 19 | *args, 20 | **kwargs) -> None: 21 | super().__init__(head=head, *args, **kwargs) 22 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/neg_margin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from typing import Dict 3 | 4 | from mmcls.models.builder import CLASSIFIERS 5 | 6 | from .base_finetune import BaseFinetuneClassifier 7 | 8 | 9 | @CLASSIFIERS.register_module() 10 | class NegMargin(BaseFinetuneClassifier): 11 | """Implementation of `NegMargin `_.""" 12 | 13 | def __init__(self, 14 | head: Dict = dict( 15 | type='NegMarginHead', 16 | metric_type='cosine', 17 | num_classes=100, 18 | in_channels=1600, 19 | margin=-0.02, 20 | temperature=30.0), 21 | meta_test_head: Dict = dict( 22 | type='NegMarginHead', 23 | metric_type='cosine', 24 | num_classes=5, 25 | in_channels=1600, 26 | margin=0.0, 27 | temperature=5.0), 28 | *args, 29 | **kwargs) -> None: 30 | super().__init__( 31 | head=head, meta_test_head=meta_test_head, *args, **kwargs) 32 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/proto_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import copy 3 | from typing import Dict 4 | 5 | from mmcls.models.builder import CLASSIFIERS 6 | 7 | from .base_metric import BaseMetricClassifier 8 | 9 | 10 | @CLASSIFIERS.register_module() 11 | class ProtoNet(BaseMetricClassifier): 12 | """Implementation of `ProtoNet `_.""" 13 | 14 | def __init__(self, 15 | head: Dict = dict(type='PrototypeHead'), 16 | *args, 17 | **kwargs) -> None: 18 | self.head_cfg = copy.deepcopy(head) 19 | super().__init__(head=head, *args, **kwargs) 20 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/classifiers/relation_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import copy 3 | from typing import Dict 4 | 5 | from mmcls.models.builder import CLASSIFIERS 6 | 7 | from .base_metric import BaseMetricClassifier 8 | 9 | 10 | @CLASSIFIERS.register_module() 11 | class RelationNet(BaseMetricClassifier): 12 | """Implementation of `RelationNet `_.""" 13 | 14 | def __init__(self, 15 | head: Dict = dict( 16 | type='RelationHead', 17 | in_channels=64, 18 | feature_size=(19, 19)), 19 | *args, 20 | **kwargs) -> None: 21 | self.head_cfg = copy.deepcopy(head) 22 | super().__init__(head=head, *args, **kwargs) 23 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/heads/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcls.models.builder import HEADS 3 | 4 | from .cosine_distance_head import CosineDistanceHead 5 | from .linear_head import LinearHead 6 | from .matching_head import MatchingHead 7 | from .meta_baseline_head import MetaBaselineHead 8 | from .neg_margin_head import NegMarginHead 9 | from .prototype_head import PrototypeHead 10 | from .relation_head import RelationHead 11 | 12 | __all__ = [ 13 | 'HEADS', 'MetaBaselineHead', 'MatchingHead', 'NegMarginHead', 'LinearHead', 14 | 'CosineDistanceHead', 'PrototypeHead', 'RelationHead' 15 | ] 16 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/heads/base_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from abc import ABCMeta, abstractmethod 3 | from typing import Dict, Tuple 4 | 5 | from mmcls.models.builder import HEADS, build_loss 6 | from mmcls.models.losses import Accuracy 7 | from mmcv.runner import BaseModule 8 | from torch import Tensor 9 | 10 | 11 | @HEADS.register_module() 12 | class BaseFewShotHead(BaseModule, metaclass=ABCMeta): 13 | """Base head for few shot classifier. 14 | 15 | Args: 16 | loss (dict): Training loss. 17 | topk (tuple[int]): Topk metric for computing the accuracy. 18 | cal_acc (bool): Whether to compute the accuracy during training. 19 | Default: False. 20 | """ 21 | 22 | def __init__(self, 23 | loss: Dict = dict(type='CrossEntropyLoss', loss_weight=1.0), 24 | topk: Tuple[int] = (1, ), 25 | cal_acc: bool = False) -> None: 26 | super().__init__() 27 | assert isinstance(loss, dict) 28 | assert isinstance(topk, (int, tuple)) 29 | if isinstance(topk, int): 30 | topk = (topk, ) 31 | for _topk in topk: 32 | assert _topk > 0, 'Top-k should be larger than 0' 33 | self.topk = topk 34 | 35 | self.compute_loss = build_loss(loss) 36 | self.compute_accuracy = Accuracy(topk=self.topk) 37 | self.cal_acc = cal_acc 38 | 39 | def loss(self, cls_score: Tensor, gt_label: Tensor) -> Dict: 40 | """Calculate loss. 41 | 42 | Args: 43 | cls_score (Tensor): The prediction. 44 | gt_label (Tensor): The learning target of the prediction. 45 | 46 | Returns: 47 | Dict: The calculated loss. 
48 | """ 49 | num_samples = len(cls_score) 50 | losses = dict() 51 | # compute loss 52 | loss = self.compute_loss(cls_score, gt_label, avg_factor=num_samples) 53 | if self.cal_acc: 54 | # compute accuracy 55 | acc = self.compute_accuracy(cls_score, gt_label) 56 | assert len(acc) == len(self.topk) 57 | losses['accuracy'] = { 58 | f'top-{k}': a 59 | for k, a in zip(self.topk, acc) 60 | } 61 | losses['loss'] = loss 62 | return losses 63 | 64 | @abstractmethod 65 | def forward_train(self, **kwargs): 66 | """Forward training data.""" 67 | 68 | @abstractmethod 69 | def forward_support(self, x, gt_label, **kwargs): 70 | """Forward support data in meta testing.""" 71 | 72 | @abstractmethod 73 | def forward_query(self, x, **kwargs): 74 | """Forward query data in meta testing.""" 75 | 76 | @abstractmethod 77 | def before_forward_support(self): 78 | """Used in meta testing. 79 | 80 | This function will be called before model forward support data during 81 | meta testing. 82 | """ 83 | 84 | @abstractmethod 85 | def before_forward_query(self): 86 | """Used in meta testing. 87 | 88 | This function will be called before model forward query data during 89 | meta testing. 90 | """ 91 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/heads/linear_head.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Dict, List 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from mmcls.models.builder import HEADS 7 | from torch import Tensor 8 | 9 | from .base_head import BaseFewShotHead 10 | 11 | 12 | @HEADS.register_module() 13 | class LinearHead(BaseFewShotHead): 14 | """Classification head for Baseline. 15 | 16 | Args: 17 | num_classes (int): Number of categories. 18 | in_channels (int): Number of channels in the input feature map. 
19 | """ 20 | 21 | def __init__(self, num_classes: int, in_channels: int, *args, 22 | **kwargs) -> None: 23 | super().__init__(*args, **kwargs) 24 | assert num_classes > 0, f'num_classes={num_classes} ' \ 25 | f'must be a positive integer' 26 | 27 | self.num_classes = num_classes 28 | self.in_channels = in_channels 29 | 30 | self.init_layers() 31 | 32 | def init_layers(self) -> None: 33 | self.fc = nn.Linear(self.in_channels, self.num_classes) 34 | 35 | def forward_train(self, x: Tensor, gt_label: Tensor, **kwargs) -> Dict: 36 | """Forward training data.""" 37 | cls_score = self.fc(x) 38 | losses = self.loss(cls_score, gt_label) 39 | return losses 40 | 41 | def forward_support(self, x: Tensor, gt_label: Tensor, **kwargs) -> Dict: 42 | """Forward support data in meta testing.""" 43 | return self.forward_train(x, gt_label, **kwargs) 44 | 45 | def forward_query(self, x: Tensor, **kwargs) -> List: 46 | """Forward query data in meta testing.""" 47 | cls_score = self.fc(x) 48 | pred = F.softmax(cls_score, dim=1) 49 | pred = list(pred.detach().cpu().numpy()) 50 | return pred 51 | 52 | def before_forward_support(self) -> None: 53 | """Used in meta testing. 54 | 55 | This function will be called before model forward support data during 56 | meta testing. 57 | """ 58 | self.init_layers() 59 | self.train() 60 | 61 | def before_forward_query(self) -> None: 62 | """Used in meta testing. 63 | 64 | This function will be called before model forward query data during 65 | meta testing. 66 | """ 67 | self.eval() 68 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from .mse_loss import MSELoss 3 | from .nll_loss import NLLLoss 4 | 5 | __all__ = ['MSELoss', 'NLLLoss'] 6 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/losses/mse_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Optional, Union 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from mmcls.models.builder import LOSSES 7 | from mmcls.models.losses.utils import weighted_loss 8 | from torch import Tensor 9 | from typing_extensions import Literal 10 | 11 | 12 | @weighted_loss 13 | def mse_loss(pred: Tensor, target: Tensor) -> Tensor: 14 | """Wrapper of mse loss.""" 15 | return F.mse_loss(pred, target, reduction='none') 16 | 17 | 18 | @LOSSES.register_module() 19 | class MSELoss(nn.Module): 20 | """MSELoss. 21 | 22 | Args: 23 | reduction (str): The method that reduces the loss to a 24 | scalar. Options are "none", "mean" and "sum". Default: 'mean'. 25 | loss_weight (float): The weight of the loss. Default: 1.0. 26 | """ 27 | 28 | def __init__(self, 29 | reduction: Literal['none', 'mean', 'sum'] = 'mean', 30 | loss_weight: float = 1.0) -> None: 31 | super().__init__() 32 | self.reduction = reduction 33 | self.loss_weight = loss_weight 34 | 35 | def forward(self, 36 | pred: Tensor, 37 | target: Tensor, 38 | weight: Optional[Tensor] = None, 39 | avg_factor: Optional[Union[float, int]] = None, 40 | reduction_override: str = None) -> Tensor: 41 | """Forward function of loss. 42 | 43 | Args: 44 | pred (Tensor): The prediction with shape (N, *), where * means 45 | any number of additional dimensions. 46 | target (Tensor): The learning target of the prediction 47 | with shape (N, *) same as the input. 48 | weight (Tensor | None): Weight of the loss for each 49 | prediction. Default: None. 
50 | avg_factor (float | int | None): Average factor that is used to 51 | average the loss. Default: None. 52 | reduction_override (str | None): The reduction method used to 53 | override the original reduction method of the loss. 54 | Options are "none", "mean" and "sum". Default: None. 55 | 56 | Returns: 57 | Tensor: The calculated loss 58 | """ 59 | assert reduction_override in (None, 'none', 'mean', 'sum') 60 | reduction = ( 61 | reduction_override if reduction_override else self.reduction) 62 | loss = self.loss_weight * mse_loss( 63 | pred, target, weight, reduction=reduction, avg_factor=avg_factor) 64 | return loss 65 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/losses/nll_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from typing import Optional, Union 3 | 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | from mmcls.models.builder import LOSSES 7 | from mmcls.models.losses.utils import weighted_loss 8 | from torch import Tensor 9 | from typing_extensions import Literal 10 | 11 | 12 | @weighted_loss 13 | def nll_loss(pred: Tensor, target: Tensor) -> Tensor: 14 | """Wrapper of nll loss.""" 15 | return F.nll_loss(pred, target, reduction='none') 16 | 17 | 18 | @LOSSES.register_module() 19 | class NLLLoss(nn.Module): 20 | """NLLLoss. 21 | 22 | Args: 23 | reduction (str): The method that reduces the loss to a 24 | scalar. Options are "none", "mean" and "sum". Default: 'mean'. 25 | loss_weight (float): The weight of the loss. Default: 1.0. 
26 | """ 27 | 28 | def __init__(self, 29 | reduction: Literal['none', 'mean', 'sum'] = 'mean', 30 | loss_weight: float = 1.0): 31 | super().__init__() 32 | self.reduction = reduction 33 | self.loss_weight = loss_weight 34 | 35 | def forward(self, 36 | pred: Tensor, 37 | target: Tensor, 38 | weight: Optional[Tensor] = None, 39 | avg_factor: Optional[Union[float, int]] = None, 40 | reduction_override: Optional[str] = None) -> Tensor: 41 | """Forward function of loss. 42 | 43 | Args: 44 | pred (Tensor): The prediction with shape (N, C). 45 | target (Tensor): The learning target of the prediction. 46 | with shape (N, 1). 47 | weight (Tensor | None): Weight of the loss for each 48 | prediction. Default: None. 49 | avg_factor (float | int | None): Average factor that is used to 50 | average the loss. Default: None. 51 | reduction_override (str | None): The reduction method used to 52 | override the original reduction method of the loss. 53 | Options are "none", "mean" and "sum". Default: None. 54 | 55 | Returns: 56 | Tensor: The calculated loss 57 | """ 58 | assert reduction_override in (None, 'none', 'mean', 'sum') 59 | reduction = ( 60 | reduction_override if reduction_override else self.reduction) 61 | loss = self.loss_weight * nll_loss( 62 | pred, target, weight, reduction=reduction, avg_factor=avg_factor) 63 | return loss 64 | -------------------------------------------------------------------------------- /mmfewshot/classification/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from .maml_module import convert_maml_module 3 | 4 | __all__ = ['convert_maml_module'] 5 | -------------------------------------------------------------------------------- /mmfewshot/classification/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
from .meta_test_parallel import MetaTestParallel

__all__ = ['MetaTestParallel']
# ---- /mmfewshot/classification/utils/meta_test_parallel.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.parallel.scatter_gather import scatter_kwargs


class MetaTestParallel(nn.Module):
    """The MetaTestParallel module that supports DataContainer.

    Note that each task is tested on a single GPU. Thus the data and model
    on different GPUs should be independent. :obj:`MMDistributedDataParallel`
    always automatically synchronizes the grad in different GPUs when doing
    the loss backward, which can not meet the requirements. Thus we simply
    copy the module and wrap it with an :obj:`MetaTestParallel`, which will
    send data to the device of the model.

    MetaTestParallel has two main differences with PyTorch DataParallel:

    - It supports a custom type :class:`DataContainer` which allows
      more flexible control of input data during both GPU and CPU
      inference.
    - It implements three more APIs ``before_meta_test()``,
      ``before_forward_support()`` and ``before_forward_query()``.

    Args:
        module (:class:`nn.Module`): Module to be encapsulated. It must
            expose a ``device`` attribute and, unless that attribute equals
            ``'cpu'``, a ``get_device()`` method.
        dim (int): Dimension used to scatter the data. Defaults to 0.
    """

    def __init__(self, module: nn.Module, dim: int = 0) -> None:
        super().__init__()
        self.dim = dim
        self.module = module
        self.device = self.module.device
        if self.device == 'cpu':
            # NOTE(review): [-1] appears to be mmcv's marker for "keep the
            # data on CPU" in scatter_kwargs; confirm against mmcv.
            self.device_id = [-1]
        else:
            self.device_id = [self.module.get_device()]

    def _scatter_single(self, inputs, kwargs):
        """Scatter ``inputs``/``kwargs`` to ``self.device_id``.

        Only a single target device is used, so the first (and only) shard
        is returned. When both ``inputs`` and ``kwargs`` are empty, empty
        positional and keyword arguments are returned so that the wrapped
        method is still called exactly once.

        Returns:
            tuple[tuple, dict]: Positional and keyword arguments for the
            wrapped module.
        """
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_id)
        if not inputs and not kwargs:
            inputs = ((), )
            kwargs = ({}, )
        return inputs[0], kwargs[0]

    def forward(self, *inputs, **kwargs):
        """Override the original forward function.

        The main difference lies in the CPU inference where the data in
        :class:`DataContainers` will still be gathered.
        """
        inputs, kwargs = self._scatter_single(inputs, kwargs)
        return self.module(*inputs, **kwargs)

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter arguments to ``device_ids`` along ``self.dim``."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def before_meta_test(self, *inputs, **kwargs):
        """Scatter the arguments and call ``module.before_meta_test``.

        Returns whatever the wrapped module returns.
        """
        inputs, kwargs = self._scatter_single(inputs, kwargs)
        return self.module.before_meta_test(*inputs, **kwargs)

    def before_forward_support(self, *inputs, **kwargs):
        """Scatter the arguments and call ``module.before_forward_support``.

        Returns whatever the wrapped module returns.
        """
        inputs, kwargs = self._scatter_single(inputs, kwargs)
        return self.module.before_forward_support(*inputs, **kwargs)

    def before_forward_query(self, *inputs, **kwargs):
        """Scatter the arguments and call ``module.before_forward_query``.

        Returns whatever the wrapped module returns.
        """
        inputs, kwargs = self._scatter_single(inputs, kwargs)
        return self.module.before_forward_query(*inputs, **kwargs)
# ---- /mmfewshot/detection/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .apis import *  # noqa: F401,F403
from .core import *  # noqa: F401,F403
from .datasets import *  # noqa: F401,F403
from .models import *  # noqa: F401,F403
# ---- /mmfewshot/detection/apis/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import (inference_detector, init_detector,
                        process_support_images)
from .test import (multi_gpu_model_init, multi_gpu_test, single_gpu_model_init,
                   single_gpu_test)
from .train import train_detector

__all__ = [
    'train_detector', 'single_gpu_model_init', 'multi_gpu_model_init',
    'single_gpu_test', 'multi_gpu_test', 'inference_detector', 'init_detector',
    'process_support_images'
]
# ---- /mmfewshot/detection/core/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .evaluation import *  # noqa: F401, F403
from .utils import *  # noqa: F401, F403
# ---- /mmfewshot/detection/core/evaluation/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .eval_hooks import QuerySupportDistEvalHook, QuerySupportEvalHook
from .mean_ap import eval_map

__all__ = ['QuerySupportEvalHook', 'QuerySupportDistEvalHook', 'eval_map']
# ---- /mmfewshot/detection/core/utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .custom_hook import ContrastiveLossDecayHook

__all__ = ['ContrastiveLossDecayHook']
# ---- /mmfewshot/detection/core/utils/custom_hook.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

from mmcv.parallel import is_module_wrapper
from mmcv.runner import HOOKS, Hook, Runner


@HOOKS.register_module()
class ContrastiveLossDecayHook(Hook):
    """Step-wise decay of the contrastive loss weight used in FSCE.

    Before every training iteration, the hook counts how many entries of
    ``decay_steps`` the 1-based current iteration has already passed. It
    multiplies ``decay_rate`` once per passed step and hands the result to
    ``roi_head.bbox_head.set_decay_rate`` of the (unwrapped) model.

    Args:
        decay_steps (list[int] | tuple[int]): Iterations after which the
            loss weight is decayed once more.
        decay_rate (float): Multiplicative factor applied for each passed
            step. Default: 0.5.
    """

    def __init__(self,
                 decay_steps: Sequence[int],
                 decay_rate: float = 0.5) -> None:
        is_sequence = isinstance(decay_steps, (list, tuple))
        assert is_sequence, '`decay_steps` should be list or tuple.'
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate

    def before_iter(self, runner: Runner) -> None:
        """Recompute the decay rate and push it into the bbox head."""
        current_iter = runner.iter + 1
        rate = 1.0
        for milestone in self.decay_steps:
            if current_iter > milestone:
                rate *= self.decay_rate
        # Reach through module wrappers (e.g. data-parallel) to the detector.
        model = (runner.model.module
                 if is_module_wrapper(runner.model) else runner.model)
        model.roi_head.bbox_head.set_decay_rate(rate)
# ---- /mmfewshot/detection/datasets/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseFewShotDataset
from .builder import build_dataloader, build_dataset
from .coco import COCO_SPLIT, FewShotCocoDataset
from .dataloader_wrappers import NWayKShotDataloader
from .dataset_wrappers import NWayKShotDataset, QueryAwareDataset
from .pipelines import CropResizeInstance, GenerateMask
from .utils import NumpyEncoder, get_copy_dataset_type
from .voc import VOC_SPLIT, FewShotVOCDataset

__all__ = [
    'build_dataloader', 'build_dataset', 'QueryAwareDataset',
    'NWayKShotDataset', 'NWayKShotDataloader', 'BaseFewShotDataset',
    'FewShotVOCDataset', 'FewShotCocoDataset', 'CropResizeInstance',
    'GenerateMask', 'NumpyEncoder', 'COCO_SPLIT', 'VOC_SPLIT',
    'get_copy_dataset_type'
]
# ---- /mmfewshot/detection/datasets/dataloader_wrappers.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Iterator

from torch.utils.data import DataLoader


class NWayKShotDataloader:
    """A dataloader wrapper.

    It creates an iterator that generates query and support batches
    simultaneously. Each batch contains query data and support data, whose
    lengths are batch_size and (num_support_ways * num_support_shots)
    respectively.

    Args:
        query_data_loader (DataLoader): DataLoader of query dataset.
        support_data_loader (DataLoader): DataLoader of support datasets.
    """

    def __init__(self, query_data_loader: DataLoader,
                 support_data_loader: DataLoader) -> None:
        self.dataset = query_data_loader.dataset
        self.sampler = query_data_loader.sampler
        self.query_data_loader = query_data_loader
        self.support_data_loader = support_data_loader

    def __iter__(self) -> Iterator:
        # if infinite sampler is used, this part of code only run once
        self.query_iter = iter(self.query_data_loader)
        self.support_iter = iter(self.support_data_loader)
        return self

    def __next__(self) -> Dict:
        """Return one query batch together with one support batch.

        Raises:
            StopIteration: When either underlying iterator is exhausted.
        """
        # Use the builtin ``next()``: ``.next()`` is not part of the iterator
        # protocol and PyTorch DataLoader iterators no longer provide it.
        query_data = next(self.query_iter)
        support_data = next(self.support_iter)
        return {'query_data': query_data, 'support_data': support_data}

    def __len__(self) -> int:
        return len(self.query_data_loader)


class TwoBranchDataloader:
    """A dataloader wrapper.

    It creates an iterator that iterates two different dataloaders
    simultaneously. Note that `TwoBranchDataloader` does not support
    `EpochBasedRunner`, and the length of the dataloader is decided by the
    main dataset.

    Args:
        main_data_loader (DataLoader): DataLoader of main dataset.
        auxiliary_data_loader (DataLoader): DataLoader of auxiliary dataset.
    """

    def __init__(self, main_data_loader: DataLoader,
                 auxiliary_data_loader: DataLoader) -> None:
        self.dataset = main_data_loader.dataset
        self.main_data_loader = main_data_loader
        self.auxiliary_data_loader = auxiliary_data_loader

    def __iter__(self) -> Iterator:
        # if infinite sampler is used, this part of code only run once
        self.main_iter = iter(self.main_data_loader)
        self.auxiliary_iter = iter(self.auxiliary_data_loader)
        return self

    def __next__(self) -> Dict:
        # The iterator actually has infinite length: an exhausted branch is
        # restarted. Note that it can NOT be used in `EpochBasedRunner`,
        # because the `EpochBasedRunner` will enumerate the dataloader
        # forever.
        try:
            main_data = next(self.main_iter)
        except StopIteration:
            self.main_iter = iter(self.main_data_loader)
            main_data = next(self.main_iter)
        try:
            auxiliary_data = next(self.auxiliary_iter)
        except StopIteration:
            self.auxiliary_iter = iter(self.auxiliary_data_loader)
            auxiliary_data = next(self.auxiliary_iter)
        return {'main_data': main_data, 'auxiliary_data': auxiliary_data}

    def __len__(self) -> int:
        return len(self.main_data_loader)
# ---- /mmfewshot/detection/datasets/pipelines/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .formatting import MultiImageCollect, MultiImageFormatBundle
from .transforms import (CropInstance, CropResizeInstance, GenerateMask,
                         MultiImageNormalize, MultiImagePad,
                         MultiImageRandomCrop, MultiImageRandomFlip,
                         ResizeToMultiScale)

__all__ = [
    'CropResizeInstance', 'GenerateMask', 'CropInstance', 'ResizeToMultiScale',
    'MultiImageNormalize', 'MultiImageFormatBundle', 'MultiImageCollect',
    'MultiImagePad', 'MultiImageRandomCrop', 'MultiImageRandomFlip'
]
# ---- /mmfewshot/detection/datasets/pipelines/formatting.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

from mmcv.parallel import DataContainer as DC
from mmdet.datasets.builder import PIPELINES
from mmdet.datasets.pipelines import Collect, DefaultFormatBundle


@PIPELINES.register_module()
class MultiImageFormatBundle(DefaultFormatBundle):
    """Apply :class:`DefaultFormatBundle` to every result dict in a list."""

    def __call__(self, results_list: List[Dict]) -> List[Dict]:
        """Transform and format common fields of each results in
        `results_list`.

        Args:
            results_list (list[dict]): List of result dict contains the data
                to convert.

        Returns:
            list[dict]: List of result dict contains the data that is formatted
                with default bundle.
        """
        for results in results_list:
            # The return value of the parent call is discarded, so this
            # relies on DefaultFormatBundle updating ``results`` in place.
            super().__call__(results)
        return results_list


@PIPELINES.register_module()
class MultiImageCollect(Collect):
    """Collect the keys of several result dicts into one scale-suffixed dict.
    """

    def __call__(self, results_list: List[Dict]) -> Dict:
        """Collect all keys of each results in `results_list`.

        The keys in `meta_keys` will be converted to :obj:mmcv.DataContainer.
        A scale suffix will also be added to each key to specify which scale
        the results come from.

        Args:
            results_list (list[dict]): List of result dict contains the data
                to collect.

        Returns:
            dict: The result dict contains the following keys, for every
                ``i`` in ``range(len(results_list))``:

                - `{key}_scale_{i}` for key in `self.keys`
                - `img_metas_scale{i}` (note: no underscore between
                  ``scale`` and ``i``, unlike the keys above), holding a
                  cpu-only DataContainer of the `meta_keys` values.
        """
        data = {}
        for i, results in enumerate(results_list):
            img_meta = {key: results[key] for key in self.meta_keys}
            # NOTE(review): this key lacks the '_' separator used by the
            # other keys. Consumers presumably read it under this exact name,
            # so it is kept as-is; confirm before renaming.
            data[f'img_metas_scale{i}'] = DC(img_meta, cpu_only=True)
            for key in self.keys:
                data[f'{key}_scale_{i}'] = results[key]
        return data
# ---- /mmfewshot/detection/datasets/utils.py ----
# Copyright (c) OpenMMLab.
# All rights reserved.
import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serializes numpy arrays and numpy scalars.

    ``np.ndarray`` values are converted with ``tolist()``. Numpy scalar
    types (``np.generic``, e.g. ``np.int64``, ``np.float32``, ``np.bool_``)
    are converted with ``item()``, which yields the equivalent builtin
    Python value. Any other object is delegated to
    :class:`json.JSONEncoder`, which raises ``TypeError``.
    """

    def default(self, obj: object) -> object:
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # Numpy scalars such as np.int64 / np.float32 are not subclasses
            # of the builtin int/float, so json cannot encode them directly.
            return obj.item()
        return super().default(obj)


def get_copy_dataset_type(dataset_type: str) -> str:
    """Return the copy-dataset type corresponding to ``dataset_type``.

    Args:
        dataset_type (str): Name of a few-shot dataset type.

    Returns:
        str: ``'FewShotVOCCopyDataset'`` for ``FewShotVOCDataset`` /
        ``FewShotVOCDefaultDataset`` and ``'FewShotCocoCopyDataset'`` for
        ``FewShotCocoDataset`` / ``FewShotCocoDefaultDataset``.

    Raises:
        TypeError: If ``dataset_type`` has no copy counterpart.
    """
    copy_dataset_types = {
        'FewShotVOCDataset': 'FewShotVOCCopyDataset',
        'FewShotVOCDefaultDataset': 'FewShotVOCCopyDataset',
        'FewShotCocoDataset': 'FewShotCocoCopyDataset',
        'FewShotCocoDefaultDataset': 'FewShotCocoCopyDataset',
    }
    try:
        return copy_dataset_types[dataset_type]
    except KeyError:
        raise TypeError(f'{dataset_type} '
                        f'not support copy data_infos operation.') from None
# ---- /mmfewshot/detection/models/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
                                  ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
                                  build_head, build_loss, build_neck,
                                  build_roi_extractor, build_shared_head)

from .backbones import *  # noqa: F401,F403
from .builder import build_detector
from .dense_heads import *  # noqa: F401,F403
from .detectors import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .roi_heads import *  # noqa: F401,F403
from .utils import *  # noqa: F401,F403

__all__ = [
    'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
    'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
    'build_shared_head', 'build_head', 'build_loss', 'build_detector'
]
# ---- /mmfewshot/detection/models/backbones/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .resnet_with_meta_conv import ResNetWithMetaConv

__all__ = ['ResNetWithMetaConv']
# ---- /mmfewshot/detection/models/backbones/resnet_with_meta_conv.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

from mmcv.cnn import build_conv_layer
from mmdet.models import ResNet
from mmdet.models.builder import BACKBONES
from torch import Tensor


@BACKBONES.register_module()
class ResNetWithMetaConv(ResNet):
    """ResNet with an extra `meta_conv` stem for metarcnn and fsdetview.

    Image inputs of shape (N, 3, H, W) pass through the regular `conv1`
    stem. Inputs of shape (N, 4, H, W), i.e. image plus mask, pass through
    `meta_conv` instead, which accepts the additional mask channel.
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        # 4 input channels (image + mask) -> 64 output channels; 7x7 kernel,
        # stride 2, padding 3, no bias, built from the ResNet `conv_cfg`.
        in_channels, out_channels = 4, 64
        self.meta_conv = build_conv_layer(
            self.conv_cfg,
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False)

    def forward(self, x: Tensor, use_meta_conv: bool = False) -> Tuple[Tensor]:
        """Run the stem and the residual stages.

        Args:
            x (Tensor): Tensor with shape (N, 3, H, W) from images
                or (N, 4, H, W) from (images + masks).
            use_meta_conv (bool): If True, the first convolution is
                `meta_conv` (expects 4 channels); otherwise it is `conv1`
                (expects 3 channels). Default: False.

        Returns:
            tuple[Tensor]: Outputs of the residual stages whose index is in
                `self.out_indices`, each with shape (N, C, H, W).
        """
        stem_conv = self.meta_conv if use_meta_conv else self.conv1
        x = stem_conv(x)
        x = self.relu(self.norm1(x))
        x = self.maxpool(x)
        selected = []
        for stage_idx, stage_name in enumerate(self.res_layers):
            x = getattr(self, stage_name)(x)
            if stage_idx in self.out_indices:
                selected.append(x)
        return tuple(selected)
# ---- /mmfewshot/detection/models/builder.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmcv.utils import ConfigDict, print_log
from mmdet.models.builder import DETECTORS


def build_detector(cfg: ConfigDict, logger: Optional[object] = None):
    """Build a detector and optionally freeze part of its parameters.

    Args:
        cfg (ConfigDict): Detector config. An optional `frozen_parameters`
            entry (an iterable of name fragments) is popped from `cfg` in
            place, i.e. the caller's config no longer contains it afterwards.
        logger (object | None): Logger forwarded to `print_log`.
            Default: None.

    Returns:
        The detector built by `DETECTORS`, with `init_weights()` already
        called. If `frozen_parameters` was given, every parameter whose
        name CONTAINS (substring match, not prefix match) any of its
        entries gets `requires_grad = False`, and the names of the
        parameters that remain trainable are logged.
    """
    frozen_parameters = cfg.pop('frozen_parameters', None)

    model = DETECTORS.build(cfg)
    model.init_weights()
    if frozen_parameters is None:
        return model

    print_log(f'Frozen parameters: {frozen_parameters}', logger)
    for name, param in model.named_parameters():
        matched = [part for part in frozen_parameters if part in name]
        if matched:
            param.requires_grad = False
        if param.requires_grad:
            print_log(f'Training parameters: {name}', logger)
    return model
# ---- /mmfewshot/detection/models/dense_heads/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .attention_rpn_head import AttentionRPNHead
from .two_branch_rpn_head import TwoBranchRPNHead

__all__ = ['AttentionRPNHead', 'TwoBranchRPNHead']
# ---- /mmfewshot/detection/models/detectors/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .attention_rpn_detector import AttentionRPNDetector
from .fsce import FSCE
from .fsdetview import FSDetView
from .meta_rcnn import MetaRCNN
from .mpsr import MPSR
from .query_support_detector import QuerySupportDetector
from .tfa import TFA

__all__ = [
    'QuerySupportDetector', 'AttentionRPNDetector', 'FSCE', 'FSDetView', 'TFA',
    'MPSR', 'MetaRCNN'
]
# ---- /mmfewshot/detection/models/detectors/fsce.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.two_stage import TwoStageDetector


@DETECTORS.register_module()
class FSCE(TwoStageDetector):
    """Implementation of `FSCE <https://arxiv.org/abs/2103.05950>`_.

    Adds no behavior of its own; it is a registered alias of
    :class:`TwoStageDetector`.
    """
# ---- /mmfewshot/detection/models/detectors/fsdetview.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.builder import DETECTORS

from .meta_rcnn import MetaRCNN


@DETECTORS.register_module()
class FSDetView(MetaRCNN):
    """Implementation of `FSDetView <https://arxiv.org/abs/2007.12107>`_.

    Adds no behavior of its own; it is a registered alias of
    :class:`MetaRCNN`.
    """
# ---- /mmfewshot/detection/models/detectors/tfa.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.two_stage import TwoStageDetector


@DETECTORS.register_module()
class TFA(TwoStageDetector):
    """Implementation of `TFA <https://arxiv.org/abs/2003.06957>`_.

    Adds no behavior of its own; it is a registered alias of
    :class:`TwoStageDetector`.
    """
# ---- /mmfewshot/detection/models/losses/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .supervised_contrastive_loss import SupervisedContrastiveLoss

__all__ = ['SupervisedContrastiveLoss']
# ---- /mmfewshot/detection/models/roi_heads/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .bbox_heads import (ContrastiveBBoxHead, CosineSimBBoxHead,
                         MultiRelationBBoxHead)
from .contrastive_roi_head import ContrastiveRoIHead
from .fsdetview_roi_head import FSDetViewRoIHead
from .meta_rcnn_roi_head import MetaRCNNRoIHead
from .multi_relation_roi_head import MultiRelationRoIHead
from .shared_heads import MetaRCNNResLayer
from .two_branch_roi_head import TwoBranchRoIHead

__all__ = [
    'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
    'ContrastiveRoIHead', 'MultiRelationRoIHead', 'FSDetViewRoIHead',
    'MetaRCNNRoIHead', 'MetaRCNNResLayer', 'TwoBranchRoIHead'
]
# ---- /mmfewshot/detection/models/roi_heads/bbox_heads/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .contrastive_bbox_head import ContrastiveBBoxHead
from .cosine_sim_bbox_head import CosineSimBBoxHead
from .meta_bbox_head import MetaBBoxHead
from .multi_relation_bbox_head import MultiRelationBBoxHead
from .two_branch_bbox_head import TwoBranchBBoxHead

__all__ = [
    'CosineSimBBoxHead', 'ContrastiveBBoxHead', 'MultiRelationBBoxHead',
    'MetaBBoxHead', 'TwoBranchBBoxHead'
]
# ---- /mmfewshot/detection/models/roi_heads/bbox_heads/meta_bbox_head.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from typing import Dict, Optional

import torch
import torch.nn as nn
from mmcv.runner import force_fp32
from mmdet.models.builder import HEADS, build_loss
from mmdet.models.losses import accuracy
from mmdet.models.roi_heads import BBoxHead
from torch import Tensor


@HEADS.register_module()
class MetaBBoxHead(BBoxHead):
    """BBoxHead with meta classification for metarcnn and fsdetview.

    Args:
        num_meta_classes (int): Number of classes for meta classification.
        meta_cls_in_channels (int): Number of support feature channels.
            Default: 2048.
        with_meta_cls_loss (bool): Use meta classification loss. When False,
            `fc_meta`, `meta_cls_loss_weight` and `loss_meta_cls` are not
            created. Default: True.
        meta_cls_loss_weight (float | None): The loss weight of `loss_meta`.
            If None, the loss is scaled by
            ``1 / max(number of positive label weights, 1)``. Default: None.
        loss_meta (dict): Config for meta classification loss.
    """

    def __init__(self,
                 num_meta_classes: int,
                 meta_cls_in_channels: int = 2048,
                 with_meta_cls_loss: bool = True,
                 meta_cls_loss_weight: Optional[float] = None,
                 loss_meta: Dict = dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 *args,
                 **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.with_meta_cls_loss = with_meta_cls_loss
        if with_meta_cls_loss:
            self.fc_meta = nn.Linear(meta_cls_in_channels, num_meta_classes)
            self.meta_cls_loss_weight = meta_cls_loss_weight
            # deepcopy so the shared default dict is never mutated.
            self.loss_meta_cls = build_loss(copy.deepcopy(loss_meta))

    def forward_meta_cls(self, support_feat: Tensor) -> Tensor:
        """Forward function for meta classification.

        Args:
            support_feat (Tensor): Support features whose last dimension is
                `meta_cls_in_channels`, typically with shape (N, C).

        Returns:
            Tensor: Meta classification scores with the last dimension
                replaced by `num_meta_classes`, i.e. (N, num_meta_classes)
                for an (N, C) input.
        """
        meta_cls_score = self.fc_meta(support_feat)
        return meta_cls_score

    @force_fp32(apply_to='meta_cls_score')
    def loss_meta(self,
                  meta_cls_score: Tensor,
                  meta_cls_labels: Tensor,
                  meta_cls_label_weights: Tensor,
                  reduction_override: Optional[str] = None) -> Dict:
        """Meta classification loss.

        Args:
            meta_cls_score (Tensor): Predicted meta classification scores
                 with shape (N, num_meta_classes).
            meta_cls_labels (Tensor): Corresponding class indices with
                shape (N).
            meta_cls_label_weights (Tensor): Meta classification loss weight
                of each sample with shape (N).
            reduction_override (str | None): The reduction method used to
                override the original reduction method of the loss. Options
                are "none", "mean" and "sum". Default: None.

        Returns:
            Dict: `loss_meta_cls` and `meta_acc` when `meta_cls_score` is
                non-empty, otherwise an empty dict.
        """
        losses = dict()
        if self.meta_cls_loss_weight is None:
            # Normalize by the number of samples with a positive weight
            # (at least 1 to avoid dividing by zero).
            loss_weight = 1. / max(
                torch.sum(meta_cls_label_weights > 0).float().item(), 1.)
        else:
            loss_weight = self.meta_cls_loss_weight
        if meta_cls_score.numel() > 0:
            loss_meta_cls_ = self.loss_meta_cls(
                meta_cls_score,
                meta_cls_labels,
                meta_cls_label_weights,
                reduction_override=reduction_override)
            losses['loss_meta_cls'] = loss_meta_cls_ * loss_weight
            losses['meta_acc'] = accuracy(meta_cls_score, meta_cls_labels)
        return losses
# ---- /mmfewshot/detection/models/roi_heads/fsdetview_roi_head.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional

import torch
from mmdet.models.builder import HEADS
from torch import Tensor

from .meta_rcnn_roi_head import MetaRCNNRoIHead


@HEADS.register_module()
class FSDetViewRoIHead(MetaRCNNRoIHead):
    """Roi head for `FSDetView <https://arxiv.org/abs/2007.12107>`_.

    Args:
        aggregation_layer (dict): Config of `aggregation_layer`.
            Default: None.
    """

    def __init__(self,
                 aggregation_layer: Optional[Dict] = None,
                 **kwargs) -> None:
        super().__init__(aggregation_layer=aggregation_layer, **kwargs)

    def _bbox_forward(self, query_roi_feats: Tensor,
                      support_roi_feats: Tensor) -> Dict:
        """Box head forward function used in both training and testing.

        Args:
            query_roi_feats (Tensor): Roi features with shape (N, C).
            support_roi_feats (Tensor): Roi features with shape (1, C).

        Returns:
            dict: A dictionary with `cls_score` and `bbox_pred`.
        """
        # feature aggregation: query as (N, C, 1, 1), support as (1, C, 1, 1)
        roi_feats = self.aggregation_layer(
            query_feat=query_roi_feats.unsqueeze(-1).unsqueeze(-1),
            support_feat=support_roi_feats.view(1, -1, 1, 1))
        roi_feats = torch.cat(roi_feats, dim=1)
        # The raw query features are appended to the aggregated ones.
        roi_feats = torch.cat((roi_feats, query_roi_feats), dim=1)
        cls_score, bbox_pred = self.bbox_head(roi_feats)
        bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
        return bbox_results
# ---- /mmfewshot/detection/models/roi_heads/shared_heads/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .meta_rcnn_res_layer import MetaRCNNResLayer

__all__ = ['MetaRCNNResLayer']
# ---- /mmfewshot/detection/models/roi_heads/shared_heads/meta_rcnn_res_layer.py ----
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmdet.models.builder import SHARED_HEADS
from mmdet.models.roi_heads import ResLayer
from torch import Tensor


@SHARED_HEADS.register_module()
class MetaRCNNResLayer(ResLayer):
    """Shared ResLayer for metarcnn and fsdetview.

    It provides different forward logics for query and support images.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.max_pool = nn.MaxPool2d(2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Forward function for query images.

        Runs the residual layer and averages over the spatial dimensions.

        Args:
            x (Tensor): Features from backbone with shape (N, C, H, W).

        Returns:
            Tensor: Shape of (N, C).
        """
        res_layer = getattr(self, f'layer{self.stage + 1}')
        out = res_layer(x)
        out = out.mean(3).mean(2)
        return out

    def forward_support(self, x: Tensor) -> Tensor:
        """Forward function for support images.

        Applies 2x2 max pooling, the residual layer and a sigmoid, then
        averages over the spatial dimensions.

        Args:
            x (Tensor): Features from backbone with shape (N, C, H, W).

        Returns:
            Tensor: Shape of (N, C).
        """
        x = self.max_pool(x)
        res_layer = getattr(self, f'layer{self.stage + 1}')
        out = res_layer(x)
        out = self.sigmoid(out)
        out = out.mean(3).mean(2)
        return out
# ---- /mmfewshot/detection/models/roi_heads/two_branch_roi_head.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple

import torch
from mmdet.models.builder import HEADS
from mmdet.models.roi_heads import StandardRoIHead
from torch import Tensor


@HEADS.register_module()
class TwoBranchRoIHead(StandardRoIHead):
    """RoI head for `MPSR <https://arxiv.org/abs/2007.09384>`_."""

    def forward_auxiliary_train(self, feats: Tuple[Tensor],
                                gt_labels: List[Tensor]) -> Dict:
        """Forward function and calculate loss for auxiliary data in training.

        Args:
            feats (tuple[Tensor]): List of features at multiple scales, each
                is a 4D-tensor.
            gt_labels (list[Tensor]): List of class indices corresponding
                to each feature. They are concatenated along dim 0 into the
                classification targets (presumably 1D tensors; confirm
                against the data pipeline).

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # bbox head forward and loss
        auxiliary_losses = self._bbox_forward_auxiliary_train(feats, gt_labels)
        return auxiliary_losses

    def _bbox_forward_auxiliary_train(self, feats: Tuple[Tensor],
                                      gt_labels: List[Tensor]) -> Dict:
        """Run forward function and calculate loss for box head in training.

        Args:
            feats (tuple[Tensor]): List of features at multiple scales, each
                is a 4D-tensor.
            gt_labels (list[Tensor]): List of class indices corresponding
                to each feature. They are concatenated along dim 0 into the
                classification targets (presumably 1D tensors; confirm
                against the data pipeline).

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        cls_scores, = self.bbox_head.forward_auxiliary(feats)
        cls_score = torch.cat(cls_scores, dim=0)
        labels = torch.cat(gt_labels, dim=0)
        # Every auxiliary sample contributes with weight 1 (same dtype as
        # labels).
        label_weights = torch.ones_like(labels)
        losses = self.bbox_head.auxiliary_loss(cls_score, labels,
                                               label_weights)

        return losses
# ---- /mmfewshot/detection/models/utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
from .aggregation_layer import *  # noqa: F401,F403
# ---- /mmfewshot/utils/__init__.py ----
# Copyright (c) OpenMMLab. All rights reserved.
2 | from .collate import multi_pipeline_collate_fn 3 | from .compat_config import compat_cfg 4 | from .dist_utils import check_dist_init, sync_random_seed 5 | from .infinite_sampler import (DistributedInfiniteGroupSampler, 6 | DistributedInfiniteSampler, 7 | InfiniteGroupSampler, InfiniteSampler) 8 | from .local_seed import local_numpy_seed 9 | from .logger import get_root_logger 10 | from .runner import InfiniteEpochBasedRunner 11 | 12 | __all__ = [ 13 | 'multi_pipeline_collate_fn', 'local_numpy_seed', 14 | 'InfiniteEpochBasedRunner', 'InfiniteSampler', 'InfiniteGroupSampler', 15 | 'DistributedInfiniteSampler', 'DistributedInfiniteGroupSampler', 16 | 'get_root_logger', 'check_dist_init', 'sync_random_seed', 'compat_cfg' 17 | ] 18 | -------------------------------------------------------------------------------- /mmfewshot/utils/collect_env.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcv.utils import collect_env as collect_basic_env 3 | from mmcv.utils import get_git_hash 4 | 5 | import mmfewshot 6 | 7 | 8 | def collect_env(): 9 | env_info = collect_basic_env() 10 | env_info['MMFewShot'] = ( 11 | mmfewshot.__version__ + '+' + get_git_hash(digits=7)) 12 | return env_info 13 | 14 | 15 | if __name__ == '__main__': 16 | for name, val in collect_env().items(): 17 | print(f'{name}: {val}') 18 | -------------------------------------------------------------------------------- /mmfewshot/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | import torch.distributed as dist 5 | from mmcv.runner import get_dist_info 6 | 7 | 8 | def check_dist_init(): 9 | return dist.is_available() and dist.is_initialized() 10 | 11 | 12 | def sync_random_seed(seed=None, device='cuda'): 13 | """Propagating the seed of rank 0 to all other ranks. 
14 | 15 | Make sure different ranks share the same seed. All workers must call 16 | this function, otherwise it will deadlock. This method is generally used in 17 | `DistributedSampler`, because the seed should be identical across all 18 | processes in the distributed group. 19 | In distributed sampling, different ranks should sample non-overlapped 20 | data in the dataset. Therefore, this function is used to make sure that 21 | each rank shuffles the data indices in the same order based 22 | on the same seed. Then different ranks could use different indices 23 | to select non-overlapped data from the same data list. 24 | Args: 25 | seed (int, Optional): The seed. Default to None. 26 | device (str): The device where the seed will be put on. 27 | Default to 'cuda'. 28 | Returns: 29 | int: Seed to be used. 30 | """ 31 | if seed is None: 32 | seed = np.random.randint(2**31) 33 | assert isinstance(seed, int) 34 | 35 | rank, world_size = get_dist_info() 36 | 37 | if world_size == 1: 38 | return seed 39 | 40 | if rank == 0: 41 | random_num = torch.tensor(seed, dtype=torch.int32, device=device) 42 | else: 43 | random_num = torch.tensor(0, dtype=torch.int32, device=device) 44 | dist.broadcast(random_num, src=0) 45 | return random_num.item() 46 | -------------------------------------------------------------------------------- /mmfewshot/utils/local_seed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from contextlib import contextmanager 3 | from typing import Optional 4 | 5 | import numpy as np 6 | 7 | 8 | @contextmanager 9 | def local_numpy_seed(seed: Optional[int] = None) -> None: 10 | """Run numpy codes with a local random seed. 11 | 12 | If seed is None, the default random state will be used. 
13 | """ 14 | state = np.random.get_state() 15 | if seed is not None: 16 | np.random.seed(seed) 17 | try: 18 | yield 19 | finally: 20 | np.random.set_state(state) 21 | -------------------------------------------------------------------------------- /mmfewshot/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import logging 3 | 4 | from mmcv.utils import get_logger 5 | 6 | 7 | def get_root_logger(log_file=None, log_level=logging.INFO): 8 | return get_logger('mmfewshot', log_file, log_level) 9 | -------------------------------------------------------------------------------- /mmfewshot/utils/runner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import time 3 | 4 | from mmcv.runner import EpochBasedRunner 5 | from mmcv.runner.builder import RUNNERS 6 | from torch.utils.data import DataLoader 7 | 8 | 9 | @RUNNERS.register_module() 10 | class InfiniteEpochBasedRunner(EpochBasedRunner): 11 | """Epoch-based Runner supports dataloader with InfiniteSampler. 12 | 13 | The workers of dataloader will re-initialize, when the iterator of 14 | dataloader is created. InfiniteSampler is designed to avoid these time 15 | consuming operations, since the iterator with InfiniteSampler will never 16 | reach the end. 17 | """ 18 | 19 | def train(self, data_loader: DataLoader, **kwargs) -> None: 20 | self.model.train() 21 | self.mode = 'train' 22 | self.data_loader = data_loader 23 | self._max_iters = self._max_epochs * len(self.data_loader) 24 | self.call_hook('before_train_epoch') 25 | time.sleep(2) # Prevent possible deadlock during epoch transition 26 | 27 | # To reuse the iterator, we only create iterator once and bind it 28 | # with runner. 
In the next epoch, the iterator will be used against 29 | if not hasattr(self, 'data_loader_iter'): 30 | self.data_loader_iter = iter(self.data_loader) 31 | 32 | # The InfiniteSampler will never reach the end, but we set the 33 | # length of InfiniteSampler to the actual length of dataset. 34 | # The length of dataloader is determined by the length of sampler, 35 | # when the sampler is not None. Therefore, we can simply forward the 36 | # whole dataset in a epoch by length of dataloader. 37 | 38 | for i in range(len(self.data_loader)): 39 | data_batch = next(self.data_loader_iter) 40 | self._inner_iter = i 41 | self.call_hook('before_train_iter') 42 | self.run_iter(data_batch, train_mode=True, **kwargs) 43 | self.call_hook('after_train_iter') 44 | self._iter += 1 45 | 46 | self.call_hook('after_train_epoch') 47 | self._epoch += 1 48 | -------------------------------------------------------------------------------- /mmfewshot/version.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Open-MMLab. All rights reserved. 
2 | __version__ = '0.1.0_af4fad5' 3 | short_version = __version__ 4 | 5 | 6 | def parse_version_info(version_str): 7 | version_info_ = [] 8 | for x in version_str.split('.'): 9 | if x.isdigit(): 10 | version_info_.append(int(x)) 11 | elif x.find('rc') != -1: 12 | patch_version = x.split('rc') 13 | version_info_.append(int(patch_version[0])) 14 | version_info_.append(f'rc{patch_version[1]}') 15 | return tuple(version_info_) 16 | 17 | 18 | version_info = parse_version_info(__version__) 19 | -------------------------------------------------------------------------------- /mmfscil/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import * 2 | from .models import * 3 | from .augments import * 4 | -------------------------------------------------------------------------------- /mmfscil/apis/__init__.py: -------------------------------------------------------------------------------- 1 | from .train import train_model 2 | -------------------------------------------------------------------------------- /mmfscil/augments/__init__.py: -------------------------------------------------------------------------------- 1 | from .mixup import BatchMixupLayer 2 | from .idty import Identity 3 | from .cutmix import BatchCutMixLayer 4 | -------------------------------------------------------------------------------- /mmfscil/augments/cutmix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | import torch 4 | 5 | from mmcls.models.utils.augment.builder import AUGMENT 6 | from mmcls.models.utils.augment.cutmix import BaseCutMixLayer 7 | 8 | 9 | @AUGMENT.register_module(name='BatchCutMixTwoLabel') 10 | class BatchCutMixLayer(BaseCutMixLayer): 11 | r"""CutMix layer for a batch of data. 12 | 13 | CutMix is a method to improve the network's generalization capability. 
It's 14 | proposed in `CutMix: Regularization Strategy to Train Strong Classifiers 15 | with Localizable Features ` 16 | 17 | With this method, patches are cut and pasted among training images where 18 | the ground truth labels are also mixed proportionally to the area of the 19 | patches. 20 | 21 | Args: 22 | alpha (float): Parameters for Beta distribution to generate the 23 | mixing ratio. It should be a positive number. More details 24 | can be found in :class:`BatchMixupLayer`. 25 | num_classes (int): The number of classes 26 | prob (float): The probability to execute cutmix. It should be in 27 | range [0, 1]. Defaults to 1.0. 28 | cutmix_minmax (List[float], optional): The min/max area ratio of the 29 | patches. If not None, the bounding-box of patches is uniform 30 | sampled within this ratio range, and the ``alpha`` will be ignored. 31 | Otherwise, the bounding-box is generated according to the 32 | ``alpha``. Defaults to None. 33 | correct_lam (bool): Whether to apply lambda correction when cutmix bbox 34 | clipped by image borders. Defaults to True. 35 | 36 | Note: 37 | If the ``cutmix_minmax`` is None, how to generate the bounding-box of 38 | patches according to the ``alpha``? 39 | 40 | First, generate a :math:`\lambda`, details can be found in 41 | :class:`BatchMixupLayer`. And then, the area ratio of the bounding-box 42 | is calculated by: 43 | 44 | .. 
math:: 45 | \text{ratio} = \sqrt{1-\lambda} 46 | """ 47 | 48 | def __init__(self, *args, **kwargs): 49 | super(BatchCutMixLayer, self).__init__(*args, **kwargs) 50 | 51 | def cutmix(self, img, gt_label): 52 | lam = np.random.beta(self.alpha, self.alpha) 53 | batch_size = img.size(0) 54 | index = torch.randperm(batch_size) 55 | 56 | (bby1, bby2, bbx1, bbx2), lam = self.cutmix_bbox_and_lam(img.shape, lam) 57 | img[:, :, bby1:bby2, bbx1:bbx2] = img[index, :, bby1:bby2, bbx1:bbx2] 58 | return img, lam, gt_label, gt_label[index] 59 | 60 | def __call__(self, img, gt_label): 61 | return self.cutmix(img, gt_label) 62 | -------------------------------------------------------------------------------- /mmfscil/augments/idty.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from mmcls.models.utils.augment.builder import AUGMENT 3 | 4 | 5 | @AUGMENT.register_module(name='IdentityTwoLabel') 6 | class Identity(object): 7 | """Change gt_label to one_hot encoding and keep img as the same. 8 | 9 | Args: 10 | num_classes (int): The number of classes. 11 | prob (float): MixUp probability. It should be in range [0, 1]. 12 | Default to 1.0 13 | """ 14 | 15 | def __init__(self, num_classes, prob=1.0): 16 | super(Identity, self).__init__() 17 | 18 | assert isinstance(num_classes, int) 19 | assert isinstance(prob, float) and 0.0 <= prob <= 1.0 20 | 21 | self.num_classes = num_classes 22 | self.prob = prob 23 | 24 | def __call__(self, img, gt_label): 25 | return img, None, gt_label, None 26 | -------------------------------------------------------------------------------- /mmfscil/augments/mixup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import numpy as np 3 | import torch 4 | 5 | from mmcls.models.utils.augment.builder import AUGMENT 6 | from mmcls.models.utils.augment.mixup import BaseMixupLayer 7 | 8 | 9 | @AUGMENT.register_module(name='BatchMixupTwoLabel') 10 | class BatchMixupLayer(BaseMixupLayer): 11 | r"""Mixup layer for a batch of data. 12 | 13 | Mixup is a method to reduces the memorization of corrupt labels and 14 | increases the robustness to adversarial examples. It's 15 | proposed in `mixup: Beyond Empirical Risk Minimization 16 | ` 17 | 18 | This method simply linearly mix pairs of data and their labels. 19 | 20 | Args: 21 | alpha (float): Parameters for Beta distribution to generate the 22 | mixing ratio. It should be a positive number. More details 23 | are in the note. 24 | num_classes (int): The number of classes. 25 | prob (float): The probability to execute mixup. It should be in 26 | range [0, 1]. Default sto 1.0. 27 | 28 | Note: 29 | The :math:`\alpha` (``alpha``) determines a random distribution 30 | :math:`Beta(\alpha, \alpha)`. For each batch of data, we sample 31 | a mixing ratio (marked as :math:`\lambda`, ``lam``) from the random 32 | distribution. 
33 | """ 34 | 35 | def __init__(self, *args, **kwargs): 36 | super(BatchMixupLayer, self).__init__(*args, **kwargs) 37 | 38 | def mixup(self, img, gt_label): 39 | lam = np.random.beta(self.alpha, self.alpha) 40 | batch_size = img.size(0) 41 | index = torch.randperm(batch_size) 42 | 43 | mixed_img = lam * img + (1 - lam) * img[index, :] 44 | return mixed_img, lam, gt_label, gt_label[index] 45 | 46 | def __call__(self, img, gt_label): 47 | return self.mixup(img, gt_label) 48 | -------------------------------------------------------------------------------- /mmfscil/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .mini_imagenet import MiniImageNetFSCILDataset 2 | from .cub import CUBFSCILDataset 3 | from .cifar100 import CIFAR100FSCILDataset 4 | from .memory import MemoryDataset 5 | -------------------------------------------------------------------------------- /mmfscil/datasets/memory.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | 7 | class MemoryDataset(Dataset): 8 | """MemoryDataset is a dataset that loads features 9 | """ 10 | 11 | def __init__( 12 | self, 13 | feats: torch.Tensor, 14 | labels: torch.Tensor, 15 | ): 16 | self.feats = feats 17 | self.labels = labels 18 | assert len(self.feats) == len(self.labels), "The features and labels are with different sizes." 
19 | 20 | def __len__(self) -> int: 21 | """Return length of the dataset.""" 22 | return len(self.feats) 23 | 24 | def __getitem__(self, idx: int) -> Dict: 25 | return {"feat": self.feats[idx], "gt_label": self.labels[idx]} 26 | -------------------------------------------------------------------------------- /mmfscil/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .classifier import ImageClassifierCIL 2 | from .ETFHead import ETFHead, DRLoss 3 | from .mlp_ffn_neck import MLPFFNNeck 4 | from .resnet18 import ResNet18 5 | -------------------------------------------------------------------------------- /mmfscil/models/mlp_ffn_neck.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | from collections import OrderedDict 3 | 4 | import torch.nn as nn 5 | from mmcv.cnn import build_norm_layer 6 | 7 | from mmcls.models.builder import NECKS 8 | 9 | 10 | @NECKS.register_module() 11 | class MLPFFNNeck(nn.Module): 12 | def __init__(self, in_channels=512, out_channels=512): 13 | super().__init__() 14 | self.avg = nn.AdaptiveAvgPool2d((1, 1)) 15 | self.ln1 = nn.Sequential(OrderedDict([ 16 | ('linear', nn.Linear(in_channels, in_channels * 2)), 17 | ('ln', build_norm_layer(dict(type='LN'), in_channels * 2)[1]), 18 | ('relu', nn.LeakyReLU(0.1)) 19 | ])) 20 | self.ln2 = nn.Sequential(OrderedDict([ 21 | ('linear', nn.Linear(in_channels * 2, in_channels * 2)), 22 | ('ln', build_norm_layer(dict(type='LN'), in_channels * 2)[1]), 23 | ('relu', nn.LeakyReLU(0.1)) 24 | ])) 25 | self.ln3 = nn.Sequential(OrderedDict([ 26 | ('linear', nn.Linear(in_channels * 2, out_channels, bias=False)), 27 | ])) 28 | if in_channels == out_channels: 29 | # self.ffn = nn.Identity() 30 | self.ffn = nn.Sequential(OrderedDict([ 31 | ('proj', nn.Linear(in_channels, out_channels, bias=False)), 32 | ])) 33 | else: 34 | self.ffn = nn.Sequential(OrderedDict([ 35 | 
('proj', nn.Linear(in_channels, out_channels, bias=False)), 36 | ])) 37 | 38 | def init_weights(self): 39 | pass 40 | 41 | def forward(self, inputs): 42 | if isinstance(inputs, tuple): 43 | inputs = inputs[-1] 44 | x = self.avg(inputs) 45 | x = x.view(inputs.size(0), -1) 46 | identity = x 47 | x = self.ln1(x) 48 | x = self.ln2(x) 49 | x = self.ln3(x) 50 | x = x + self.ffn(identity) 51 | return x 52 | -------------------------------------------------------------------------------- /mmfscil/models/resnet18.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | from mmcls.models import BACKBONES 5 | 6 | 7 | class BasicBlock(nn.Module): 8 | expansion = 1 9 | 10 | def __init__(self, in_planes, planes, stride=1): 11 | super(BasicBlock, self).__init__() 12 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 15 | self.bn2 = nn.BatchNorm2d(planes) 16 | 17 | self.shortcut = nn.Sequential() 18 | if stride != 1 or in_planes != self.expansion * planes: 19 | self.shortcut = nn.Sequential( 20 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 21 | nn.BatchNorm2d(self.expansion * planes) 22 | ) 23 | 24 | def forward(self, x): 25 | out = F.relu(self.bn1(self.conv1(x))) 26 | out = self.bn2(self.conv2(out)) 27 | out += self.shortcut(x) 28 | out = F.relu(out) 29 | return out 30 | 31 | 32 | class Bottleneck(nn.Module): 33 | expansion = 4 34 | 35 | def __init__(self, in_planes, planes, stride=1): 36 | super(Bottleneck, self).__init__() 37 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 38 | self.bn1 = nn.BatchNorm2d(planes) 39 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 40 | self.bn2 = 
nn.BatchNorm2d(planes) 41 | self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) 42 | self.bn3 = nn.BatchNorm2d(self.expansion * planes) 43 | 44 | self.shortcut = nn.Sequential() 45 | if stride != 1 or in_planes != self.expansion * planes: 46 | self.shortcut = nn.Sequential( 47 | nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), 48 | nn.BatchNorm2d(self.expansion * planes) 49 | ) 50 | 51 | def forward(self, x): 52 | out = F.relu(self.bn1(self.conv1(x))) 53 | out = F.relu(self.bn2(self.conv2(out))) 54 | out = self.bn3(self.conv3(out)) 55 | out += self.shortcut(x) 56 | out = F.relu(out) 57 | return out 58 | 59 | 60 | @BACKBONES.register_module() 61 | class ResNet18(nn.Module): 62 | def __init__(self, block=BasicBlock, num_blocks=(2, 2, 2, 2), low_dim=512): 63 | super().__init__() 64 | self.in_planes = 64 65 | 66 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) 67 | self.bn1 = nn.BatchNorm2d(64) 68 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 69 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) 70 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 71 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 72 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 73 | # self.fc = nn.Linear(512 * block.expansion, low_dim) 74 | # self.l2norm = Normalize(2) 75 | 76 | def _make_layer(self, block, planes, num_blocks, stride): 77 | strides = [stride] + [1] * (num_blocks - 1) 78 | layers = [] 79 | for stride in strides: 80 | layers.append(block(self.in_planes, planes, stride)) 81 | self.in_planes = planes * block.expansion 82 | return nn.Sequential(*layers) 83 | 84 | def forward(self, x): 85 | out = F.relu(self.bn1(self.conv1(x))) 86 | out = self.layer1(out) 87 | out = self.layer2(out) 88 | out = self.layer3(out) 89 | out = self.layer4(out) 90 | return out 91 | 
-------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | GPUS=$2 5 | NNODES=${NNODES:-1} 6 | NODE_RANK=${NODE_RANK:-0} 7 | PORT=${PORT:-$((29500 + $RANDOM % 29))} 8 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 9 | 10 | 11 | if command -v torchrun &> /dev/null 12 | then 13 | echo "Using torchrun mode." 14 | PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 15 | torchrun \ 16 | --nnodes=$NNODES \ 17 | --node_rank=$NODE_RANK \ 18 | --master_addr=$MASTER_ADDR \ 19 | --master_port=$PORT \ 20 | --nproc_per_node=$GPUS \ 21 | "$(dirname "$0")"/train.py "$CONFIG" --launcher pytorch "${@:3}" 22 | else 23 | echo "Using launch mode." 24 | PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 25 | python -m torch.distributed.launch \ 26 | --nnodes=$NNODES \ 27 | --node_rank=$NODE_RANK \ 28 | --master_addr=$MASTER_ADDR \ 29 | --master_port=$PORT \ 30 | --nproc_per_node=$GPUS \ 31 | "$(dirname "$0")"/train.py "$CONFIG" --launcher pytorch "${@:3}" 32 | fi 33 | -------------------------------------------------------------------------------- /tools/docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | DATALOC=${DATALOC:-`realpath ../datasets`} 4 | LOGLOC=${LOGLOC:-`realpath ../logger`} 5 | IMG=${IMG:-"harbory/openmmlab:2206"} 6 | GPUS=${GPUS:-"all"} 7 | 8 | if [ "${GPUS}" != "all" ]; then 9 | GPUS='"device='${GPUS}'"' 10 | fi 11 | 12 | docker run --gpus ${GPUS} -it --rm --ipc=host --net=host \ 13 | --mount src=$(pwd),target=/opt/project,type=bind \ 14 | --mount src=$DATALOC,target=/opt/data,type=bind \ 15 | --mount src=$LOGLOC,target=/opt/logger,type=bind \ 16 | $IMG 17 | -------------------------------------------------------------------------------- /tools/run_fscil.sh:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | CONFIG=$1 4 | WORKDIR=$2 5 | CKPT=$3 6 | GPUS=$4 7 | NNODES=${NNODES:-1} 8 | NODE_RANK=${NODE_RANK:-0} 9 | PORT=${PORT:-$((29500 + $RANDOM % 29))} 10 | MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"} 11 | 12 | 13 | if command -v torchrun &> /dev/null 14 | then 15 | echo "Using torchrun mode." 16 | PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 17 | torchrun \ 18 | --nnodes=$NNODES \ 19 | --node_rank=$NODE_RANK \ 20 | --master_addr=$MASTER_ADDR \ 21 | --master_port=$PORT \ 22 | --nproc_per_node=$GPUS \ 23 | "$(dirname "$0")"/fscil.py "$CONFIG" "$WORKDIR" "$CKPT" --launcher pytorch "${@:5}" 24 | else 25 | echo "Using launch mode." 26 | PYTHONPATH="$(dirname "$0")/..":$PYTHONPATH OMP_NUM_THREADS=1 MKL_NUM_THREADS=1 \ 27 | python -m torch.distributed.launch \ 28 | --nnodes=$NNODES \ 29 | --node_rank=$NODE_RANK \ 30 | --master_addr=$MASTER_ADDR \ 31 | --master_port=$PORT \ 32 | --nproc_per_node=$GPUS \ 33 | "$(dirname "$0")"/fscil.py "$CONFIG" "$WORKDIR" "$CKPT" --launcher pytorch "${@:5}" 34 | fi 35 | --------------------------------------------------------------------------------