├── .gitignore
├── README.md
├── city_cfgs
│   └── mgan_50_65.py
├── data
│   └── annotations
│       └── val.json
├── eval
│   ├── coco.py
│   ├── eval_MR_multisetup.py
│   ├── eval_demo.py
│   ├── results.txt
│   └── val_gt.json
├── mmdet
│   ├── __init__.py
│   ├── apis
│   │   ├── __init__.py
│   │   ├── env.py
│   │   └── inference.py
│   ├── core
│   │   ├── __init__.py
│   │   ├── anchor
│   │   │   ├── __init__.py
│   │   │   ├── anchor_generator.py
│   │   │   ├── anchor_target.py
│   │   │   └── guided_anchor_target.py
│   │   ├── bbox
│   │   │   ├── __init__.py
│   │   │   ├── assign_sampling.py
│   │   │   ├── assigners
│   │   │   │   ├── __init__.py
│   │   │   │   ├── approx_max_iou_assigner.py
│   │   │   │   ├── assign_result.py
│   │   │   │   ├── base_assigner.py
│   │   │   │   └── max_iou_assigner.py
│   │   │   ├── bbox_target.py
│   │   │   ├── geometry.py
│   │   │   ├── samplers
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_sampler.py
│   │   │   │   ├── combined_sampler.py
│   │   │   │   ├── instance_balanced_pos_sampler.py
│   │   │   │   ├── iou_balanced_neg_sampler.py
│   │   │   │   ├── ohem_sampler.py
│   │   │   │   ├── pseudo_sampler.py
│   │   │   │   ├── random_sampler.py
│   │   │   │   └── sampling_result.py
│   │   │   └── transforms.py
│   │   ├── evaluation
│   │   │   ├── __init__.py
│   │   │   ├── bbox_overlaps.py
│   │   │   ├── class_names.py
│   │   │   ├── coco_utils.py
│   │   │   ├── eval_hooks.py
│   │   │   ├── mean_ap.py
│   │   │   └── recall.py
│   │   ├── fp16
│   │   │   ├── __init__.py
│   │   │   ├── decorators.py
│   │   │   ├── hooks.py
│   │   │   └── utils.py
│   │   ├── mask
│   │   │   ├── __init__.py
│   │   │   ├── mask_target.py
│   │   │   └── utils.py
│   │   ├── post_processing
│   │   │   ├── __init__.py
│   │   │   ├── bbox_nms.py
│   │   │   └── merge_augs.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── dist_utils.py
│   │       └── misc.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── builder.py
│   │   ├── city.py
│   │   ├── custom.py
│   │   ├── dataset_wrappers.py
│   │   ├── extra_aug.py
│   │   ├── loader
│   │   │   ├── __init__.py
│   │   │   ├── build_loader.py
│   │   │   └── sampler.py
│   │   ├── registry.py
│   │   ├── transforms.py
│   │   └── utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── anchor_heads
│   │   │   ├── __init__.py
│   │   │   ├── anchor_head.py
│   │   │   └── rpn_head.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   └── vgg.py
│   │   ├── bbox_heads
│   │   │   ├── __init__.py
│   │   │   ├── bbox_head.py
│   │   │   └── convfc_bbox_head.py
│   │   ├── builder.py
│   │   ├── detectors
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── mgan.py
│   │   │   └── test_mixins.py
│   │   ├── mgan_heads
│   │   │   ├── __init__.py
│   │   │   └── mgan_head.py
│   │   ├── necks
│   │   │   ├── __init__.py
│   │   │   ├── fpn.py
│   │   │   └── hrfpn.py
│   │   ├── registry.py
│   │   ├── roi_extractors
│   │   │   ├── __init__.py
│   │   │   └── single_level.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── conv_module.py
│   │       ├── conv_ws.py
│   │       ├── norm.py
│   │       ├── scale.py
│   │       └── weight_init.py
│   ├── ops
│   │   ├── __init__.py
│   │   ├── nms
│   │   │   ├── __init__.py
│   │   │   ├── nms_wrapper.py
│   │   │   └── src
│   │   │       ├── nms_cpu.cpp
│   │   │       ├── nms_cuda.cpp
│   │   │       ├── nms_kernel.cu
│   │   │       ├── soft_nms_cpu.cpp
│   │   │       └── soft_nms_cpu.pyx
│   │   ├── roi_align
│   │   │   ├── __init__.py
│   │   │   ├── functions
│   │   │   │   ├── __init__.py
│   │   │   │   └── roi_align.py
│   │   │   ├── gradcheck.py
│   │   │   ├── modules
│   │   │   │   ├── __init__.py
│   │   │   │   └── roi_align.py
│   │   │   └── src
│   │   │       ├── roi_align_cuda.cpp
│   │   │       └── roi_align_kernel.cu
│   │   └── roi_pool
│   │       ├── __init__.py
│   │       ├── functions
│   │       │   ├── __init__.py
│   │       │   └── roi_pool.py
│   │       ├── gradcheck.py
│   │       ├── modules
│   │       │   ├── __init__.py
│   │       │   └── roi_pool.py
│   │       └── src
│   │           ├── roi_pool_cuda.cpp
│   │           └── roi_pool_kernel.cu
│   ├── utils
│   │   ├── __init__.py
│   │   └── registry.py
│   └── version.py
├── setup.py
└── tools
    ├── analyze_logs.py
    ├── coco_eval.py
    ├── dist_test.sh
    ├── dist_train.sh
    ├── slurm_test.sh
    ├── slurm_train.sh
    ├── test.py
    └── upgrade_model_version.py

/.gitignore:
--------------------------------------------------------------------------------
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
/models/
.idea/
#/result/
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Mask-Guided Attention Network for Occluded Pedestrian Detection

Pedestrian detection framework as detailed in the [arXiv report](https://arxiv.org/abs/1910.06160), accepted to ICCV 2019.

## Installation
Our MGAN is based on [mmdetection](https://github.com/open-mmlab/mmdetection). Please check [INSTALL.md](https://github.com/open-mmlab/mmdetection/blob/master/docs/INSTALL.md) for installation instructions.

## Datasets
Download the [CityScapes dataset](https://www.cityscapes-dataset.com/) and put it in the `data` folder.
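The test configuration (`city_cfgs/mgan_50_65.py`, included below) reads annotations and images from the following locations, so the unpacked dataset should look like:
```
data
├── annotations
│   └── val.json
└── Cityscapes
    └── leftImg8bit_trainvaltest
        └── leftImg8bit
```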
## Testing
The following command tests the model on a single GPU:
```
python tools/test.py city_cfgs/mgan_50_65.py models/50_65.pth --out result/50_65.pkl
```

## Eval
The following command evaluates the results on CityPersons:
```
python eval/eval_demo.py
```

## Results
| R (MR%) | HO (MR%) | Download |
|:----:|:----:|:----:|
| 11.0 | 50.3 | [Google Drive](https://drive.google.com/file/d/1gww2UZDLlE76JFA80LoR37OTHxokhaii/view?usp=sharing) / [Baidu Yun](https://pan.baidu.com/s/1q68cjZZyH4lqNjy9nv588Q) (extraction code: zq93) |

R and HO are the average miss rates (lower is better) on the Reasonable and heavy-occlusion (`Reasonable_occ=heavy`) setups, matching the two entries in `eval/results.txt`.
--------------------------------------------------------------------------------
/city_cfgs/mgan_50_65.py:
--------------------------------------------------------------------------------
# model settings
model = dict(
    type='MGAN',
    pretrained='modelzoo://vgg16',
    backbone=dict(
        type='VGG',
        depth=16,
        frozen_stages=1),
    neck=None,
    rpn_head=dict(
        type='RPNHead',
        in_channels=512,
        feat_channels=512,
        anchor_scales=[4., 5.4, 7.2, 9.8, 13.2, 17.9, 24.2, 33.0, 44.1, 59.6, 80.0],
        anchor_ratios=[2.44],
        anchor_strides=[8],
        target_means=[.0, .0, .0, .0],
        target_stds=[1.0, 1.0, 1.0, 1.0],
    ),
    bbox_roi_extractor=dict(
        type='SingleRoIExtractor',
        roi_layer=dict(type='RoIAlign', out_size=7, sample_num=2),
        out_channels=512,
        featmap_strides=[8]),
    mgan_head=dict(
        type='MGANHead'),
    bbox_head=dict(
        type='SharedFCBBoxHead',
        num_fcs=2,
        in_channels=512,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=2,
        target_means=[0., 0., 0., 0.],
        target_stds=[0.1, 0.1, 0.2, 0.2],
        reg_class_agnostic=False,
    )
)
test_cfg = dict(
    rpn=dict(
        nms_across_levels=False,
        nms_pre=12000,
        nms_post=2000,
        max_num=2000,
        nms_thr=0.7,
        min_bbox_size=0),
    rcnn=dict(
        score_thr=0.0, nms=dict(type='nms', iou_thr=0.5), max_per_img=100)
)
# dataset settings
dataset_type = 'CityDataset'
data_root = 'data/'
img_norm_cfg = dict(mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
data = dict(
    imgs_per_gpu=2,
    workers_per_gpu=2,
    test=dict(
        type=dataset_type,
        ann_file=data_root + 'annotations/val.json',
        img_prefix=data_root + 'Cityscapes/leftImg8bit_trainvaltest/leftImg8bit',
        img_scale=(2048, 1024),
        img_norm_cfg=img_norm_cfg,
        size_divisor=32,
        flip_ratio=0,
        with_mask=False,
        with_label=False,
        test_mode=True)
)
# yapf:enable
# runtime settings
total_epochs = 12
dist_params = dict(backend='nccl')
log_level = 'INFO'
work_dir = '../work_dirs/mgan_50_65'
load_from = None
resume_from = None
workflow = [('train', 1)]
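For reference, a minimal sketch of how this config is consumed, mirroring what `init_detector` in `mmdet/apis/inference.py` does internally; the checkpoint location `models/50_65.pth` is an assumption taken from the README command:
```python
# Minimal sketch, assuming mmcv/mmdet are installed and the checkpoint from the
# README has been downloaded to models/50_65.pth (assumed path).
import mmcv
from mmdet.apis import init_detector

cfg_file = 'city_cfgs/mgan_50_65.py'
cfg = mmcv.Config.fromfile(cfg_file)            # configs are plain Python files
print(cfg.model.type, cfg.data.test.img_scale)  # -> MGAN (2048, 1024)

model = init_detector(cfg_file, 'models/50_65.pth', device='cuda:0')
```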
--------------------------------------------------------------------------------
/eval/eval_demo.py:
--------------------------------------------------------------------------------
from coco import COCO
from eval_MR_multisetup import COCOeval

annType = 'bbox'  # specify type here
print('Running demo for *%s* results.' % (annType))

# initialize COCO ground truth api
annFile = 'eval/val_gt.json'
# initialize COCO detections api
resFile = 'result/50_65.pkl.json'

# run evaluation for the two setups (Reasonable and Reasonable_occ=heavy)
res_file = open("eval/results.txt", "w")
for id_setup in range(0, 2):
    cocoGt = COCO(annFile)
    cocoDt = cocoGt.loadRes(resFile)
    imgIds = sorted(cocoGt.getImgIds())
    cocoEval = COCOeval(cocoGt, cocoDt, annType)
    cocoEval.params.imgIds = imgIds
    cocoEval.evaluate(id_setup)
    cocoEval.accumulate()
    cocoEval.summarize(id_setup, res_file)

res_file.close()
--------------------------------------------------------------------------------
/eval/results.txt:
--------------------------------------------------------------------------------
Average Miss Rate (MR) @ Reasonable [ IoU=0.50 | height=[50:10000000000] | visibility=[0.65:10000000000.00] ] = 11.02%
Average Miss Rate (MR) @ Reasonable_occ=heavy [ IoU=0.50 | height=[50:10000000000] | visibility=[0.20:0.65] ] = 50.34%
--------------------------------------------------------------------------------
/mmdet/__init__.py:
--------------------------------------------------------------------------------
from .version import __version__, short_version

__all__ = ['__version__', 'short_version']
--------------------------------------------------------------------------------
/mmdet/apis/__init__.py:
--------------------------------------------------------------------------------
from .env import init_dist, get_root_logger, set_random_seed
from .inference import init_detector, inference_detector, show_result

__all__ = [
    'init_dist', 'get_root_logger', 'set_random_seed',
    'init_detector', 'inference_detector', 'show_result'
]
--------------------------------------------------------------------------------
/mmdet/apis/env.py:
--------------------------------------------------------------------------------
import logging
import os
import random
import subprocess

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from mmcv.runner import get_dist_info


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'mpi':
        _init_dist_mpi(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError('Invalid launcher type: {}'.format(launcher))


def _init_dist_pytorch(backend, **kwargs):
    # TODO: use local_rank instead of rank % num_gpus
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_mpi(backend, **kwargs):
    raise NotImplementedError


def _init_dist_slurm(backend, port=29500, **kwargs):
    proc_id = int(os.environ['SLURM_PROCID'])
    ntasks = int(os.environ['SLURM_NTASKS'])
    node_list = os.environ['SLURM_NODELIST']
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(proc_id % num_gpus)
    addr = subprocess.getoutput(
        'scontrol show hostname {} | head -n1'.format(node_list))
    os.environ['MASTER_PORT'] = str(port)
    os.environ['MASTER_ADDR'] = addr
    os.environ['WORLD_SIZE'] = str(ntasks)
    os.environ['RANK'] = str(proc_id)
    dist.init_process_group(backend=backend)


def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def get_root_logger(log_level=logging.INFO):
    logger = logging.getLogger()
    if not logger.hasHandlers():
        logging.basicConfig(
            format='%(asctime)s - %(levelname)s - %(message)s',
            level=log_level)
    rank, _ = get_dist_info()
    if rank != 0:
        logger.setLevel('ERROR')
    return logger
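A minimal sketch of how these helpers compose for a single-process run; no distributed launcher is involved, so `init_dist` is not called, and the seed value is arbitrary:
```python
# Illustrative single-GPU usage of the helpers above (not part of the repo's tools).
from mmdet.apis import get_root_logger, set_random_seed

set_random_seed(0)           # arbitrary seed, fixes python/numpy/torch RNGs
logger = get_root_logger()   # rank-aware root logger (non-zero ranks log errors only)
logger.info('random seed fixed, starting single-GPU inference')
```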
--------------------------------------------------------------------------------
/mmdet/apis/inference.py:
--------------------------------------------------------------------------------
import warnings

import mmcv
import numpy as np
import pycocotools.mask as maskUtils
import torch
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import to_tensor
from mmdet.datasets.transforms import ImageTransform
from mmdet.models import build_detector


def init_detector(config, checkpoint=None, device='cuda:0'):
    """Initialize a detector from config file.

    Args:
        config (str or :obj:`mmcv.Config`): Config file path or the config
            object.
        checkpoint (str, optional): Checkpoint path. If left as None, the
            model will not load any weights.

    Returns:
        nn.Module: The constructed detector.
    """
    if isinstance(config, str):
        config = mmcv.Config.fromfile(config)
    elif not isinstance(config, mmcv.Config):
        raise TypeError('config must be a filename or Config object, '
                        'but got {}'.format(type(config)))
    config.model.pretrained = None
    model = build_detector(config.model, test_cfg=config.test_cfg)
    if checkpoint is not None:
        checkpoint = load_checkpoint(model, checkpoint)
        if 'CLASSES' in checkpoint['meta']:
            model.CLASSES = checkpoint['meta']['CLASSES']
        else:
            warnings.warn('Class names are not saved in the checkpoint\'s '
                          'meta data, use COCO classes by default.')
            model.CLASSES = get_classes('coco')
    model.cfg = config  # save the config in the model for convenience
    model.to(device)
    model.eval()
    return model


def inference_detector(model, imgs):
    """Inference image(s) with the detector.

    Args:
        model (nn.Module): The loaded detector.
        imgs (str/ndarray or list[str/ndarray]): Either image files or loaded
            images.

    Returns:
        If imgs is a list, a generator will be returned, otherwise the
        detection results will be returned directly.
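
    Examples:
        >>> # Illustrative only: the checkpoint path follows the README and
        >>> # 'demo.jpg' is a placeholder image path.
        >>> model = init_detector('city_cfgs/mgan_50_65.py', 'models/50_65.pth')
        >>> result = inference_detector(model, 'demo.jpg')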
59 | """ 60 | cfg = model.cfg 61 | img_transform = ImageTransform( 62 | size_divisor=cfg.data.test.size_divisor, **cfg.img_norm_cfg) 63 | 64 | device = next(model.parameters()).device # model device 65 | if not isinstance(imgs, list): 66 | return _inference_single(model, imgs, img_transform, device) 67 | else: 68 | return _inference_generator(model, imgs, img_transform, device) 69 | 70 | 71 | def _prepare_data(img, img_transform, cfg, device): 72 | ori_shape = img.shape 73 | img, img_shape, pad_shape, scale_factor = img_transform( 74 | img, 75 | scale=cfg.data.test.img_scale, 76 | keep_ratio=cfg.data.test.get('resize_keep_ratio', True)) 77 | img = to_tensor(img).to(device).unsqueeze(0) 78 | img_meta = [ 79 | dict( 80 | ori_shape=ori_shape, 81 | img_shape=img_shape, 82 | pad_shape=pad_shape, 83 | scale_factor=scale_factor, 84 | flip=False) 85 | ] 86 | return dict(img=[img], img_meta=[img_meta]) 87 | 88 | 89 | def _inference_single(model, img, img_transform, device): 90 | img = mmcv.imread(img) 91 | data = _prepare_data(img, img_transform, model.cfg, device) 92 | with torch.no_grad(): 93 | result = model(return_loss=False, rescale=True, **data) 94 | return result 95 | 96 | 97 | def _inference_generator(model, imgs, img_transform, device): 98 | for img in imgs: 99 | yield _inference_single(model, img, img_transform, device) 100 | 101 | 102 | # TODO: merge this method with the one in BaseDetector 103 | def show_result(img, 104 | result, 105 | class_names, 106 | score_thr=0.3, 107 | wait_time=0, 108 | out_file=None): 109 | """Visualize the detection results on the image. 110 | 111 | Args: 112 | img (str or np.ndarray): Image filename or loaded image. 113 | result (tuple[list] or list): The detection result, can be either 114 | (bbox, segm) or just bbox. 115 | class_names (list[str] or tuple[str]): A list of class names. 116 | score_thr (float): The threshold to visualize the bboxes and masks. 117 | wait_time (int): Value of waitKey param. 118 | out_file (str, optional): If specified, the visualization result will 119 | be written to the out file instead of shown in a window. 
120 | """ 121 | assert isinstance(class_names, (tuple, list)) 122 | img = mmcv.imread(img) 123 | if isinstance(result, tuple): 124 | bbox_result, segm_result = result 125 | else: 126 | bbox_result, segm_result = result, None 127 | bboxes = np.vstack(bbox_result) 128 | # draw segmentation masks 129 | if segm_result is not None: 130 | segms = mmcv.concat_list(segm_result) 131 | inds = np.where(bboxes[:, -1] > score_thr)[0] 132 | for i in inds: 133 | color_mask = np.random.randint(0, 256, (1, 3), dtype=np.uint8) 134 | mask = maskUtils.decode(segms[i]).astype(np.bool) 135 | img[mask] = img[mask] * 0.5 + color_mask * 0.5 136 | # draw bounding boxes 137 | labels = [ 138 | np.full(bbox.shape[0], i, dtype=np.int32) 139 | for i, bbox in enumerate(bbox_result) 140 | ] 141 | labels = np.concatenate(labels) 142 | mmcv.imshow_det_bboxes( 143 | img.copy(), 144 | bboxes, 145 | labels, 146 | class_names=class_names, 147 | score_thr=score_thr, 148 | show=out_file is None, 149 | wait_time=wait_time, 150 | out_file=out_file) 151 | -------------------------------------------------------------------------------- /mmdet/core/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor import * # noqa: F401, F403 2 | from .bbox import * # noqa: F401, F403 3 | from .evaluation import * # noqa: F401, F403 4 | from .fp16 import * # noqa: F401, F403 5 | from .mask import * # noqa: F401, F403 6 | from .post_processing import * # noqa: F401, F403 7 | from .utils import * # noqa: F401, F403 8 | -------------------------------------------------------------------------------- /mmdet/core/anchor/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor_generator import AnchorGenerator 2 | from .anchor_target import anchor_target, anchor_inside_flags 3 | from .guided_anchor_target import ga_loc_target, ga_shape_target 4 | 5 | __all__ = [ 6 | 'AnchorGenerator', 'anchor_target', 'anchor_inside_flags', 'ga_loc_target', 7 | 'ga_shape_target' 8 | ] 9 | -------------------------------------------------------------------------------- /mmdet/core/anchor/anchor_generator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AnchorGenerator(object): 5 | 6 | def __init__(self, base_size, scales, ratios, scale_major=True, ctr=None): 7 | self.base_size = base_size 8 | self.scales = torch.Tensor(scales) 9 | self.ratios = torch.Tensor(ratios) 10 | self.scale_major = scale_major 11 | self.ctr = ctr 12 | self.base_anchors = self.gen_base_anchors() 13 | 14 | @property 15 | def num_base_anchors(self): 16 | return self.base_anchors.size(0) 17 | 18 | def gen_base_anchors(self): 19 | w = self.base_size 20 | h = self.base_size 21 | if self.ctr is None: 22 | x_ctr = 0.5 * (w - 1) 23 | y_ctr = 0.5 * (h - 1) 24 | else: 25 | x_ctr, y_ctr = self.ctr 26 | 27 | h_ratios = torch.sqrt(self.ratios) 28 | w_ratios = 1 / h_ratios 29 | if self.scale_major: 30 | ws = (w * w_ratios[:, None] * self.scales[None, :]).view(-1) 31 | hs = (h * h_ratios[:, None] * self.scales[None, :]).view(-1) 32 | else: 33 | ws = (w * self.scales[:, None] * w_ratios[None, :]).view(-1) 34 | hs = (h * self.scales[:, None] * h_ratios[None, :]).view(-1) 35 | 36 | base_anchors = torch.stack( 37 | [ 38 | x_ctr - 0.5 * (ws - 1), y_ctr - 0.5 * (hs - 1), 39 | x_ctr + 0.5 * (ws - 1), y_ctr + 0.5 * (hs - 1) 40 | ], 41 | dim=-1).round() 42 | 43 | return base_anchors 44 | 45 | def _meshgrid(self, x, y, row_major=True): 46 | xx = 
x.repeat(len(y)) 47 | yy = y.view(-1, 1).repeat(1, len(x)).view(-1) 48 | if row_major: 49 | return xx, yy 50 | else: 51 | return yy, xx 52 | 53 | def grid_anchors(self, featmap_size, stride=16, device='cuda'): 54 | base_anchors = self.base_anchors.to(device) 55 | 56 | feat_h, feat_w = featmap_size 57 | shift_x = torch.arange(0, feat_w, device=device) * stride 58 | shift_y = torch.arange(0, feat_h, device=device) * stride 59 | shift_xx, shift_yy = self._meshgrid(shift_x, shift_y) 60 | shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1) 61 | shifts = shifts.type_as(base_anchors) 62 | # first feat_w elements correspond to the first row of shifts 63 | # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get 64 | # shifted anchors (K, A, 4), reshape to (K*A, 4) 65 | 66 | all_anchors = base_anchors[None, :, :] + shifts[:, None, :] 67 | all_anchors = all_anchors.view(-1, 4) 68 | # first A rows correspond to A anchors of (0, 0) in feature map, 69 | # then (0, 1), (0, 2), ... 70 | return all_anchors 71 | 72 | def valid_flags(self, featmap_size, valid_size, device='cuda'): 73 | feat_h, feat_w = featmap_size 74 | valid_h, valid_w = valid_size 75 | assert valid_h <= feat_h and valid_w <= feat_w 76 | valid_x = torch.zeros(feat_w, dtype=torch.uint8, device=device) 77 | valid_y = torch.zeros(feat_h, dtype=torch.uint8, device=device) 78 | valid_x[:valid_w] = 1 79 | valid_y[:valid_h] = 1 80 | valid_xx, valid_yy = self._meshgrid(valid_x, valid_y) 81 | valid = valid_xx & valid_yy 82 | valid = valid[:, None].expand( 83 | valid.size(0), self.num_base_anchors).contiguous().view(-1) 84 | return valid 85 | -------------------------------------------------------------------------------- /mmdet/core/bbox/__init__.py: -------------------------------------------------------------------------------- 1 | from .geometry import bbox_overlaps 2 | from .assigners import BaseAssigner, MaxIoUAssigner, AssignResult 3 | from .samplers import (BaseSampler, PseudoSampler, RandomSampler, 4 | InstanceBalancedPosSampler, IoUBalancedNegSampler, 5 | CombinedSampler, SamplingResult) 6 | from .assign_sampling import build_assigner, build_sampler, assign_and_sample 7 | from .transforms import (bbox2delta, delta2bbox, bbox_flip, bbox_mapping, 8 | bbox_mapping_back, bbox2roi, roi2bbox, bbox2result, 9 | distance2bbox) 10 | from .bbox_target import bbox_target 11 | 12 | __all__ = [ 13 | 'bbox_overlaps', 'BaseAssigner', 'MaxIoUAssigner', 'AssignResult', 14 | 'BaseSampler', 'PseudoSampler', 'RandomSampler', 15 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 16 | 'SamplingResult', 'build_assigner', 'build_sampler', 'assign_and_sample', 17 | 'bbox2delta', 'delta2bbox', 'bbox_flip', 'bbox_mapping', 18 | 'bbox_mapping_back', 'bbox2roi', 'roi2bbox', 'bbox2result', 19 | 'distance2bbox', 'bbox_target' 20 | ] 21 | -------------------------------------------------------------------------------- /mmdet/core/bbox/assign_sampling.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | from . 
import assigners, samplers


def build_assigner(cfg, **kwargs):
    if isinstance(cfg, assigners.BaseAssigner):
        return cfg
    elif isinstance(cfg, dict):
        return mmcv.runner.obj_from_dict(cfg, assigners, default_args=kwargs)
    else:
        raise TypeError('Invalid type {} for building an assigner'.format(
            type(cfg)))


def build_sampler(cfg, **kwargs):
    if isinstance(cfg, samplers.BaseSampler):
        return cfg
    elif isinstance(cfg, dict):
        return mmcv.runner.obj_from_dict(cfg, samplers, default_args=kwargs)
    else:
        raise TypeError('Invalid type {} for building a sampler'.format(
            type(cfg)))


def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
    bbox_assigner = build_assigner(cfg.assigner)
    bbox_sampler = build_sampler(cfg.sampler)
    assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
                                         gt_labels)
    sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
                                          gt_labels)
    return assign_result, sampling_result
--------------------------------------------------------------------------------
/mmdet/core/bbox/assigners/__init__.py:
--------------------------------------------------------------------------------
from .base_assigner import BaseAssigner
from .max_iou_assigner import MaxIoUAssigner
from .approx_max_iou_assigner import ApproxMaxIoUAssigner
from .assign_result import AssignResult

__all__ = [
    'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult'
]
--------------------------------------------------------------------------------
/mmdet/core/bbox/assigners/approx_max_iou_assigner.py:
--------------------------------------------------------------------------------
import torch

from .max_iou_assigner import MaxIoUAssigner
from ..geometry import bbox_overlaps


class ApproxMaxIoUAssigner(MaxIoUAssigner):
    """Assign a corresponding gt bbox or background to each bbox.

    Each proposal will be assigned with `-1`, `0`, or a positive integer
    indicating the ground truth index.

    - -1: don't care
    - 0: negative sample, no assigned gt
    - positive integer: positive sample, index (1-based) of assigned gt

    Args:
        pos_iou_thr (float): IoU threshold for positive bboxes.
        neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
        min_pos_iou (float): Minimum IoU for a bbox to be considered as a
            positive bbox. Positive samples can have an IoU smaller than
            pos_iou_thr due to the 5th step (assign max IoU sample to each gt).
        gt_max_assign_all (bool): Whether to assign all bboxes with the same
            highest overlap with some gt to that gt.
        ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
            `gt_bboxes_ignore` is specified). Negative values mean not
            ignoring any bboxes.
        ignore_wrt_candidates (bool): Whether to compute the IoF between
            `bboxes` and `gt_bboxes_ignore`, or the contrary.
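
    Example:
        >>> # Illustrative thresholds only; the repo does not prescribe
        >>> # specific values here.
        >>> assigner = ApproxMaxIoUAssigner(pos_iou_thr=0.7, neg_iou_thr=0.3)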
30 | """ 31 | 32 | def __init__(self, 33 | pos_iou_thr, 34 | neg_iou_thr, 35 | min_pos_iou=.0, 36 | gt_max_assign_all=True, 37 | ignore_iof_thr=-1, 38 | ignore_wrt_candidates=True): 39 | self.pos_iou_thr = pos_iou_thr 40 | self.neg_iou_thr = neg_iou_thr 41 | self.min_pos_iou = min_pos_iou 42 | self.gt_max_assign_all = gt_max_assign_all 43 | self.ignore_iof_thr = ignore_iof_thr 44 | self.ignore_wrt_candidates = ignore_wrt_candidates 45 | 46 | def assign(self, 47 | approxs, 48 | squares, 49 | approxs_per_octave, 50 | gt_bboxes, 51 | gt_bboxes_ignore=None, 52 | gt_labels=None): 53 | """Assign gt to approxs. 54 | 55 | This method assign a gt bbox to each group of approxs (bboxes), 56 | each group of approxs is represent by a base approx (bbox) and 57 | will be assigned with -1, 0, or a positive number. 58 | -1 means don't care, 0 means negative sample, 59 | positive number is the index (1-based) of assigned gt. 60 | The assignment is done in following steps, the order matters. 61 | 62 | 1. assign every bbox to -1 63 | 2. use the max IoU of each group of approxs to assign 64 | 2. assign proposals whose iou with all gts < neg_iou_thr to 0 65 | 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr, 66 | assign it to that bbox 67 | 4. for each gt bbox, assign its nearest proposals (may be more than 68 | one) to itself 69 | 70 | Args: 71 | approxs (Tensor): Bounding boxes to be assigned, 72 | shape(approxs_per_octave*n, 4). 73 | squares (Tensor): Base Bounding boxes to be assigned, 74 | shape(n, 4). 75 | approxs_per_octave (int): number of approxs per octave 76 | gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4). 77 | gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are 78 | labelled as `ignored`, e.g., crowd boxes in COCO. 79 | gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ). 80 | 81 | Returns: 82 | :obj:`AssignResult`: The assign result. 
83 | """ 84 | 85 | if squares.shape[0] == 0 or gt_bboxes.shape[0] == 0: 86 | raise ValueError('No gt or approxs') 87 | num_squares = squares.size(0) 88 | num_gts = gt_bboxes.size(0) 89 | # re-organize anchors by approxs_per_octave x num_squares 90 | approxs = torch.transpose( 91 | approxs.view(num_squares, approxs_per_octave, 4), 0, 92 | 1).contiguous().view(-1, 4) 93 | all_overlaps = bbox_overlaps(approxs, gt_bboxes) 94 | 95 | overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares, 96 | num_gts).max(dim=0) 97 | overlaps = torch.transpose(overlaps, 0, 1) 98 | 99 | bboxes = squares[:, :4] 100 | 101 | if (self.ignore_iof_thr > 0) and (gt_bboxes_ignore is not None) and ( 102 | gt_bboxes_ignore.numel() > 0): 103 | if self.ignore_wrt_candidates: 104 | ignore_overlaps = bbox_overlaps(bboxes, 105 | gt_bboxes_ignore, 106 | mode='iof') 107 | ignore_max_overlaps, _ = ignore_overlaps.max(dim=1) 108 | else: 109 | ignore_overlaps = bbox_overlaps(gt_bboxes_ignore, 110 | bboxes, 111 | mode='iof') 112 | ignore_max_overlaps, _ = ignore_overlaps.max(dim=0) 113 | overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1 114 | 115 | assign_result = self.assign_wrt_overlaps(overlaps, gt_labels) 116 | return assign_result 117 | -------------------------------------------------------------------------------- /mmdet/core/bbox/assigners/assign_result.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class AssignResult(object): 5 | 6 | def __init__(self, num_gts, gt_inds, max_overlaps, labels=None): 7 | self.num_gts = num_gts 8 | self.gt_inds = gt_inds 9 | self.max_overlaps = max_overlaps 10 | self.labels = labels 11 | 12 | def add_gt_(self, gt_labels): 13 | self_inds = torch.arange( 14 | 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) 15 | self.gt_inds = torch.cat([self_inds, self.gt_inds]) 16 | self.max_overlaps = torch.cat( 17 | [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps]) 18 | if self.labels is not None: 19 | self.labels = torch.cat([gt_labels, self.labels]) 20 | -------------------------------------------------------------------------------- /mmdet/core/bbox/assigners/base_assigner.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | 4 | class BaseAssigner(metaclass=ABCMeta): 5 | 6 | @abstractmethod 7 | def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None): 8 | pass 9 | -------------------------------------------------------------------------------- /mmdet/core/bbox/bbox_target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .transforms import bbox2delta 4 | from ..utils import multi_apply 5 | 6 | 7 | def bbox_target(pos_bboxes_list, 8 | neg_bboxes_list, 9 | pos_gt_bboxes_list, 10 | pos_gt_labels_list, 11 | cfg, 12 | reg_classes=1, 13 | target_means=[.0, .0, .0, .0], 14 | target_stds=[1.0, 1.0, 1.0, 1.0], 15 | concat=True): 16 | labels, label_weights, bbox_targets, bbox_weights = multi_apply( 17 | bbox_target_single, 18 | pos_bboxes_list, 19 | neg_bboxes_list, 20 | pos_gt_bboxes_list, 21 | pos_gt_labels_list, 22 | cfg=cfg, 23 | reg_classes=reg_classes, 24 | target_means=target_means, 25 | target_stds=target_stds) 26 | 27 | if concat: 28 | labels = torch.cat(labels, 0) 29 | label_weights = torch.cat(label_weights, 0) 30 | bbox_targets = torch.cat(bbox_targets, 0) 31 | bbox_weights = torch.cat(bbox_weights, 0) 32 | return labels, 
label_weights, bbox_targets, bbox_weights 33 | 34 | 35 | def bbox_target_single(pos_bboxes, 36 | neg_bboxes, 37 | pos_gt_bboxes, 38 | pos_gt_labels, 39 | cfg, 40 | reg_classes=1, 41 | target_means=[.0, .0, .0, .0], 42 | target_stds=[1.0, 1.0, 1.0, 1.0]): 43 | num_pos = pos_bboxes.size(0) 44 | num_neg = neg_bboxes.size(0) 45 | num_samples = num_pos + num_neg 46 | labels = pos_bboxes.new_zeros(num_samples, dtype=torch.long) 47 | label_weights = pos_bboxes.new_zeros(num_samples) 48 | bbox_targets = pos_bboxes.new_zeros(num_samples, 4) 49 | bbox_weights = pos_bboxes.new_zeros(num_samples, 4) 50 | if num_pos > 0: 51 | labels[:num_pos] = pos_gt_labels 52 | pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight 53 | label_weights[:num_pos] = pos_weight 54 | pos_bbox_targets = bbox2delta(pos_bboxes, pos_gt_bboxes, target_means, 55 | target_stds) 56 | bbox_targets[:num_pos, :] = pos_bbox_targets 57 | bbox_weights[:num_pos, :] = 1 58 | if num_neg > 0: 59 | label_weights[-num_neg:] = 1.0 60 | 61 | return labels, label_weights, bbox_targets, bbox_weights 62 | 63 | 64 | def expand_target(bbox_targets, bbox_weights, labels, num_classes): 65 | bbox_targets_expand = bbox_targets.new_zeros((bbox_targets.size(0), 66 | 4 * num_classes)) 67 | bbox_weights_expand = bbox_weights.new_zeros((bbox_weights.size(0), 68 | 4 * num_classes)) 69 | for i in torch.nonzero(labels > 0).squeeze(-1): 70 | start, end = labels[i] * 4, (labels[i] + 1) * 4 71 | bbox_targets_expand[i, start:end] = bbox_targets[i, :] 72 | bbox_weights_expand[i, start:end] = bbox_weights[i, :] 73 | return bbox_targets_expand, bbox_weights_expand 74 | -------------------------------------------------------------------------------- /mmdet/core/bbox/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False): 5 | """Calculate overlap between two set of bboxes. 6 | 7 | If ``is_aligned`` is ``False``, then calculate the ious between each bbox 8 | of bboxes1 and bboxes2, otherwise the ious between each aligned pair of 9 | bboxes1 and bboxes2. 10 | 11 | Args: 12 | bboxes1 (Tensor): shape (m, 4) 13 | bboxes2 (Tensor): shape (n, 4), if is_aligned is ``True``, then m and n 14 | must be equal. 15 | mode (str): "iou" (intersection over union) or iof (intersection over 16 | foreground). 
17 | 18 | Returns: 19 | ious(Tensor): shape (m, n) if is_aligned == False else shape (m, 1) 20 | """ 21 | 22 | assert mode in ['iou', 'iof'] 23 | 24 | rows = bboxes1.size(0) 25 | cols = bboxes2.size(0) 26 | if is_aligned: 27 | assert rows == cols 28 | 29 | if rows * cols == 0: 30 | return bboxes1.new(rows, 1) if is_aligned else bboxes1.new(rows, cols) 31 | 32 | if is_aligned: 33 | lt = torch.max(bboxes1[:, :2], bboxes2[:, :2]) # [rows, 2] 34 | rb = torch.min(bboxes1[:, 2:], bboxes2[:, 2:]) # [rows, 2] 35 | 36 | wh = (rb - lt + 1).clamp(min=0) # [rows, 2] 37 | overlap = wh[:, 0] * wh[:, 1] 38 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( 39 | bboxes1[:, 3] - bboxes1[:, 1] + 1) 40 | 41 | if mode == 'iou': 42 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( 43 | bboxes2[:, 3] - bboxes2[:, 1] + 1) 44 | ious = overlap / (area1 + area2 - overlap) 45 | else: 46 | ious = overlap / area1 47 | else: 48 | lt = torch.max(bboxes1[:, None, :2], bboxes2[:, :2]) # [rows, cols, 2] 49 | rb = torch.min(bboxes1[:, None, 2:], bboxes2[:, 2:]) # [rows, cols, 2] 50 | 51 | wh = (rb - lt + 1).clamp(min=0) # [rows, cols, 2] 52 | overlap = wh[:, :, 0] * wh[:, :, 1] 53 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( 54 | bboxes1[:, 3] - bboxes1[:, 1] + 1) 55 | 56 | if mode == 'iou': 57 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( 58 | bboxes2[:, 3] - bboxes2[:, 1] + 1) 59 | ious = overlap / (area1[:, None] + area2 - overlap) 60 | else: 61 | ious = overlap / (area1[:, None]) 62 | 63 | return ious 64 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_sampler import BaseSampler 2 | from .pseudo_sampler import PseudoSampler 3 | from .random_sampler import RandomSampler 4 | from .instance_balanced_pos_sampler import InstanceBalancedPosSampler 5 | from .iou_balanced_neg_sampler import IoUBalancedNegSampler 6 | from .combined_sampler import CombinedSampler 7 | from .ohem_sampler import OHEMSampler 8 | from .sampling_result import SamplingResult 9 | 10 | __all__ = [ 11 | 'BaseSampler', 'PseudoSampler', 'RandomSampler', 12 | 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler', 13 | 'OHEMSampler', 'SamplingResult' 14 | ] 15 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/base_sampler.py: -------------------------------------------------------------------------------- 1 | from abc import ABCMeta, abstractmethod 2 | 3 | import torch 4 | 5 | from .sampling_result import SamplingResult 6 | 7 | 8 | class BaseSampler(metaclass=ABCMeta): 9 | 10 | def __init__(self, 11 | num, 12 | pos_fraction, 13 | neg_pos_ub=-1, 14 | add_gt_as_proposals=True, 15 | **kwargs): 16 | self.num = num 17 | self.pos_fraction = pos_fraction 18 | self.neg_pos_ub = neg_pos_ub 19 | self.add_gt_as_proposals = add_gt_as_proposals 20 | self.pos_sampler = self 21 | self.neg_sampler = self 22 | 23 | @abstractmethod 24 | def _sample_pos(self, assign_result, num_expected, **kwargs): 25 | pass 26 | 27 | @abstractmethod 28 | def _sample_neg(self, assign_result, num_expected, **kwargs): 29 | pass 30 | 31 | def sample(self, 32 | assign_result, 33 | bboxes, 34 | gt_bboxes, 35 | gt_labels=None, 36 | **kwargs): 37 | """Sample positive and negative bboxes. 38 | 39 | This is a simple implementation of bbox sampling given candidates, 40 | assigning results and ground truth bboxes. 
41 | 42 | Args: 43 | assign_result (:obj:`AssignResult`): Bbox assigning results. 44 | bboxes (Tensor): Boxes to be sampled from. 45 | gt_bboxes (Tensor): Ground truth bboxes. 46 | gt_labels (Tensor, optional): Class labels of ground truth bboxes. 47 | 48 | Returns: 49 | :obj:`SamplingResult`: Sampling result. 50 | """ 51 | bboxes = bboxes[:, :4] 52 | 53 | gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8) 54 | if self.add_gt_as_proposals: 55 | bboxes = torch.cat([gt_bboxes, bboxes], dim=0) 56 | assign_result.add_gt_(gt_labels) 57 | gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) 58 | gt_flags = torch.cat([gt_ones, gt_flags]) 59 | 60 | num_expected_pos = int(self.num * self.pos_fraction) 61 | pos_inds = self.pos_sampler._sample_pos( 62 | assign_result, num_expected_pos, bboxes=bboxes, **kwargs) 63 | # We found that sampled indices have duplicated items occasionally. 64 | # (may be a bug of PyTorch) 65 | pos_inds = pos_inds.unique() 66 | num_sampled_pos = pos_inds.numel() 67 | num_expected_neg = self.num - num_sampled_pos 68 | if self.neg_pos_ub >= 0: 69 | _pos = max(1, num_sampled_pos) 70 | neg_upper_bound = int(self.neg_pos_ub * _pos) 71 | if num_expected_neg > neg_upper_bound: 72 | num_expected_neg = neg_upper_bound 73 | neg_inds = self.neg_sampler._sample_neg( 74 | assign_result, num_expected_neg, bboxes=bboxes, **kwargs) 75 | neg_inds = neg_inds.unique() 76 | 77 | return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, 78 | assign_result, gt_flags) 79 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/combined_sampler.py: -------------------------------------------------------------------------------- 1 | from .base_sampler import BaseSampler 2 | from ..assign_sampling import build_sampler 3 | 4 | 5 | class CombinedSampler(BaseSampler): 6 | 7 | def __init__(self, pos_sampler, neg_sampler, **kwargs): 8 | super(CombinedSampler, self).__init__(**kwargs) 9 | self.pos_sampler = build_sampler(pos_sampler, **kwargs) 10 | self.neg_sampler = build_sampler(neg_sampler, **kwargs) 11 | 12 | def _sample_pos(self, **kwargs): 13 | raise NotImplementedError 14 | 15 | def _sample_neg(self, **kwargs): 16 | raise NotImplementedError 17 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .random_sampler import RandomSampler 5 | 6 | 7 | class InstanceBalancedPosSampler(RandomSampler): 8 | 9 | def _sample_pos(self, assign_result, num_expected, **kwargs): 10 | pos_inds = torch.nonzero(assign_result.gt_inds > 0) 11 | if pos_inds.numel() != 0: 12 | pos_inds = pos_inds.squeeze(1) 13 | if pos_inds.numel() <= num_expected: 14 | return pos_inds 15 | else: 16 | unique_gt_inds = assign_result.gt_inds[pos_inds].unique() 17 | num_gts = len(unique_gt_inds) 18 | num_per_gt = int(round(num_expected / float(num_gts)) + 1) 19 | sampled_inds = [] 20 | for i in unique_gt_inds: 21 | inds = torch.nonzero(assign_result.gt_inds == i.item()) 22 | if inds.numel() != 0: 23 | inds = inds.squeeze(1) 24 | else: 25 | continue 26 | if len(inds) > num_per_gt: 27 | inds = self.random_choice(inds, num_per_gt) 28 | sampled_inds.append(inds) 29 | sampled_inds = torch.cat(sampled_inds) 30 | if len(sampled_inds) < num_expected: 31 | num_extra = num_expected - len(sampled_inds) 32 | extra_inds = np.array( 33 | 
list(set(pos_inds.cpu()) - set(sampled_inds.cpu()))) 34 | if len(extra_inds) > num_extra: 35 | extra_inds = self.random_choice(extra_inds, num_extra) 36 | extra_inds = torch.from_numpy(extra_inds).to( 37 | assign_result.gt_inds.device).long() 38 | sampled_inds = torch.cat([sampled_inds, extra_inds]) 39 | elif len(sampled_inds) > num_expected: 40 | sampled_inds = self.random_choice(sampled_inds, num_expected) 41 | return sampled_inds 42 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .random_sampler import RandomSampler 5 | 6 | 7 | class IoUBalancedNegSampler(RandomSampler): 8 | """IoU Balanced Sampling 9 | 10 | arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019) 11 | 12 | Sampling proposals according to their IoU. `floor_fraction` of needed RoIs 13 | are sampled from proposals whose IoU are lower than `floor_thr` randomly. 14 | The others are sampled from proposals whose IoU are higher than 15 | `floor_thr`. These proposals are sampled from some bins evenly, which are 16 | split by `num_bins` via IoU evenly. 17 | 18 | Args: 19 | num (int): number of proposals. 20 | pos_fraction (float): fraction of positive proposals. 21 | floor_thr (float): threshold (minimum) IoU for IoU balanced sampling, 22 | set to -1 if all using IoU balanced sampling. 23 | floor_fraction (float): sampling fraction of proposals under floor_thr. 24 | num_bins (int): number of bins in IoU balanced sampling. 25 | """ 26 | 27 | def __init__(self, 28 | num, 29 | pos_fraction, 30 | floor_thr=-1, 31 | floor_fraction=0, 32 | num_bins=3, 33 | **kwargs): 34 | super(IoUBalancedNegSampler, self).__init__(num, pos_fraction, 35 | **kwargs) 36 | assert floor_thr >= 0 or floor_thr == -1 37 | assert 0 <= floor_fraction <= 1 38 | assert num_bins >= 1 39 | 40 | self.floor_thr = floor_thr 41 | self.floor_fraction = floor_fraction 42 | self.num_bins = num_bins 43 | 44 | def sample_via_interval(self, max_overlaps, full_set, num_expected): 45 | max_iou = max_overlaps.max() 46 | iou_interval = (max_iou - self.floor_thr) / self.num_bins 47 | per_num_expected = int(num_expected / self.num_bins) 48 | 49 | sampled_inds = [] 50 | for i in range(self.num_bins): 51 | start_iou = self.floor_thr + i * iou_interval 52 | end_iou = self.floor_thr + (i + 1) * iou_interval 53 | tmp_set = set( 54 | np.where( 55 | np.logical_and(max_overlaps >= start_iou, 56 | max_overlaps < end_iou))[0]) 57 | tmp_inds = list(tmp_set & full_set) 58 | if len(tmp_inds) > per_num_expected: 59 | tmp_sampled_set = self.random_choice(tmp_inds, 60 | per_num_expected) 61 | else: 62 | tmp_sampled_set = np.array(tmp_inds, dtype=np.int) 63 | sampled_inds.append(tmp_sampled_set) 64 | 65 | sampled_inds = np.concatenate(sampled_inds) 66 | if len(sampled_inds) < num_expected: 67 | num_extra = num_expected - len(sampled_inds) 68 | extra_inds = np.array(list(full_set - set(sampled_inds))) 69 | if len(extra_inds) > num_extra: 70 | extra_inds = self.random_choice(extra_inds, num_extra) 71 | sampled_inds = np.concatenate([sampled_inds, extra_inds]) 72 | 73 | return sampled_inds 74 | 75 | def _sample_neg(self, assign_result, num_expected, **kwargs): 76 | neg_inds = torch.nonzero(assign_result.gt_inds == 0) 77 | if neg_inds.numel() != 0: 78 | neg_inds = neg_inds.squeeze(1) 79 | if len(neg_inds) <= num_expected: 80 | return neg_inds 81 | else: 82 | max_overlaps = 
assign_result.max_overlaps.cpu().numpy() 83 | # balance sampling for negative samples 84 | neg_set = set(neg_inds.cpu().numpy()) 85 | 86 | if self.floor_thr > 0: 87 | floor_set = set( 88 | np.where( 89 | np.logical_and(max_overlaps >= 0, 90 | max_overlaps < self.floor_thr))[0]) 91 | iou_sampling_set = set( 92 | np.where(max_overlaps >= self.floor_thr)[0]) 93 | elif self.floor_thr == 0: 94 | floor_set = set(np.where(max_overlaps == 0)[0]) 95 | iou_sampling_set = set( 96 | np.where(max_overlaps > self.floor_thr)[0]) 97 | else: 98 | floor_set = set() 99 | iou_sampling_set = set( 100 | np.where(max_overlaps > self.floor_thr)[0]) 101 | 102 | floor_neg_inds = list(floor_set & neg_set) 103 | iou_sampling_neg_inds = list(iou_sampling_set & neg_set) 104 | num_expected_iou_sampling = int(num_expected * 105 | (1 - self.floor_fraction)) 106 | if len(iou_sampling_neg_inds) > num_expected_iou_sampling: 107 | if self.num_bins >= 2: 108 | iou_sampled_inds = self.sample_via_interval( 109 | max_overlaps, set(iou_sampling_neg_inds), 110 | num_expected_iou_sampling) 111 | else: 112 | iou_sampled_inds = self.random_choice( 113 | iou_sampling_neg_inds, num_expected_iou_sampling) 114 | else: 115 | iou_sampled_inds = np.array( 116 | iou_sampling_neg_inds, dtype=np.int) 117 | num_expected_floor = num_expected - len(iou_sampled_inds) 118 | if len(floor_neg_inds) > num_expected_floor: 119 | sampled_floor_inds = self.random_choice( 120 | floor_neg_inds, num_expected_floor) 121 | else: 122 | sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int) 123 | sampled_inds = np.concatenate( 124 | (sampled_floor_inds, iou_sampled_inds)) 125 | if len(sampled_inds) < num_expected: 126 | num_extra = num_expected - len(sampled_inds) 127 | extra_inds = np.array(list(neg_set - set(sampled_inds))) 128 | if len(extra_inds) > num_extra: 129 | extra_inds = self.random_choice(extra_inds, num_extra) 130 | sampled_inds = np.concatenate((sampled_inds, extra_inds)) 131 | sampled_inds = torch.from_numpy(sampled_inds).long().to( 132 | assign_result.gt_inds.device) 133 | return sampled_inds 134 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/ohem_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_sampler import BaseSampler 4 | from ..transforms import bbox2roi 5 | 6 | 7 | class OHEMSampler(BaseSampler): 8 | 9 | def __init__(self, 10 | num, 11 | pos_fraction, 12 | context, 13 | neg_pos_ub=-1, 14 | add_gt_as_proposals=True, 15 | **kwargs): 16 | super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub, 17 | add_gt_as_proposals) 18 | if not hasattr(context, 'num_stages'): 19 | self.bbox_roi_extractor = context.bbox_roi_extractor 20 | self.bbox_head = context.bbox_head 21 | else: 22 | self.bbox_roi_extractor = context.bbox_roi_extractor[ 23 | context.current_stage] 24 | self.bbox_head = context.bbox_head[context.current_stage] 25 | 26 | def hard_mining(self, inds, num_expected, bboxes, labels, feats): 27 | with torch.no_grad(): 28 | rois = bbox2roi([bboxes]) 29 | bbox_feats = self.bbox_roi_extractor( 30 | feats[:self.bbox_roi_extractor.num_inputs], rois) 31 | cls_score, _ = self.bbox_head(bbox_feats) 32 | loss = self.bbox_head.loss( 33 | cls_score=cls_score, 34 | bbox_pred=None, 35 | labels=labels, 36 | label_weights=cls_score.new_ones(cls_score.size(0)), 37 | bbox_targets=None, 38 | bbox_weights=None, 39 | reduction_override='none')['loss_cls'] 40 | _, topk_loss_inds = loss.topk(num_expected) 41 | 
return inds[topk_loss_inds] 42 | 43 | def _sample_pos(self, 44 | assign_result, 45 | num_expected, 46 | bboxes=None, 47 | feats=None, 48 | **kwargs): 49 | # Sample some hard positive samples 50 | pos_inds = torch.nonzero(assign_result.gt_inds > 0) 51 | if pos_inds.numel() != 0: 52 | pos_inds = pos_inds.squeeze(1) 53 | if pos_inds.numel() <= num_expected: 54 | return pos_inds 55 | else: 56 | return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds], 57 | assign_result.labels[pos_inds], feats) 58 | 59 | def _sample_neg(self, 60 | assign_result, 61 | num_expected, 62 | bboxes=None, 63 | feats=None, 64 | **kwargs): 65 | # Sample some hard negative samples 66 | neg_inds = torch.nonzero(assign_result.gt_inds == 0) 67 | if neg_inds.numel() != 0: 68 | neg_inds = neg_inds.squeeze(1) 69 | if len(neg_inds) <= num_expected: 70 | return neg_inds 71 | else: 72 | return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds], 73 | assign_result.labels[neg_inds], feats) 74 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/pseudo_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .base_sampler import BaseSampler 4 | from .sampling_result import SamplingResult 5 | 6 | 7 | class PseudoSampler(BaseSampler): 8 | 9 | def __init__(self, **kwargs): 10 | pass 11 | 12 | def _sample_pos(self, **kwargs): 13 | raise NotImplementedError 14 | 15 | def _sample_neg(self, **kwargs): 16 | raise NotImplementedError 17 | 18 | def sample(self, assign_result, bboxes, gt_bboxes, **kwargs): 19 | pos_inds = torch.nonzero( 20 | assign_result.gt_inds > 0).squeeze(-1).unique() 21 | neg_inds = torch.nonzero( 22 | assign_result.gt_inds == 0).squeeze(-1).unique() 23 | gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8) 24 | sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, 25 | assign_result, gt_flags) 26 | return sampling_result 27 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/random_sampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from .base_sampler import BaseSampler 5 | 6 | 7 | class RandomSampler(BaseSampler): 8 | 9 | def __init__(self, 10 | num, 11 | pos_fraction, 12 | neg_pos_ub=-1, 13 | add_gt_as_proposals=True, 14 | **kwargs): 15 | super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub, 16 | add_gt_as_proposals) 17 | 18 | @staticmethod 19 | def random_choice(gallery, num): 20 | """Random select some elements from the gallery. 21 | 22 | It seems that Pytorch's implementation is slower than numpy so we use 23 | numpy to randperm the indices. 
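
        Example:
            >>> # Illustrative: sample 2 random elements from a made-up
            >>> # gallery tensor; the output order is random.
            >>> gallery = torch.arange(6)
            >>> RandomSampler.random_choice(gallery, 2)  # e.g. tensor([4, 1])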
24 | """ 25 | assert len(gallery) >= num 26 | if isinstance(gallery, list): 27 | gallery = np.array(gallery) 28 | cands = np.arange(len(gallery)) 29 | np.random.shuffle(cands) 30 | rand_inds = cands[:num] 31 | if not isinstance(gallery, np.ndarray): 32 | rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) 33 | return gallery[rand_inds] 34 | 35 | def _sample_pos(self, assign_result, num_expected, **kwargs): 36 | """Randomly sample some positive samples.""" 37 | pos_inds = torch.nonzero(assign_result.gt_inds > 0) 38 | if pos_inds.numel() != 0: 39 | pos_inds = pos_inds.squeeze(1) 40 | if pos_inds.numel() <= num_expected: 41 | return pos_inds 42 | else: 43 | return self.random_choice(pos_inds, num_expected) 44 | 45 | def _sample_neg(self, assign_result, num_expected, **kwargs): 46 | """Randomly sample some negative samples.""" 47 | neg_inds = torch.nonzero(assign_result.gt_inds == 0) 48 | if neg_inds.numel() != 0: 49 | neg_inds = neg_inds.squeeze(1) 50 | if len(neg_inds) <= num_expected: 51 | return neg_inds 52 | else: 53 | return self.random_choice(neg_inds, num_expected) 54 | -------------------------------------------------------------------------------- /mmdet/core/bbox/samplers/sampling_result.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SamplingResult(object): 5 | 6 | def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, 7 | gt_flags): 8 | self.pos_inds = pos_inds 9 | self.neg_inds = neg_inds 10 | self.pos_bboxes = bboxes[pos_inds] 11 | self.neg_bboxes = bboxes[neg_inds] 12 | self.pos_is_gt = gt_flags[pos_inds] 13 | 14 | self.num_gts = gt_bboxes.shape[0] 15 | self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 16 | self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :] 17 | if assign_result.labels is not None: 18 | self.pos_gt_labels = assign_result.labels[pos_inds] 19 | else: 20 | self.pos_gt_labels = None 21 | 22 | @property 23 | def bboxes(self): 24 | return torch.cat([self.pos_bboxes, self.neg_bboxes]) 25 | -------------------------------------------------------------------------------- /mmdet/core/bbox/transforms.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def bbox2delta(proposals, gt, means=[0, 0, 0, 0], stds=[1, 1, 1, 1]): 7 | assert proposals.size() == gt.size() 8 | 9 | proposals = proposals.float() 10 | gt = gt.float() 11 | px = (proposals[..., 0] + proposals[..., 2]) * 0.5 12 | py = (proposals[..., 1] + proposals[..., 3]) * 0.5 13 | pw = proposals[..., 2] - proposals[..., 0] + 1.0 14 | ph = proposals[..., 3] - proposals[..., 1] + 1.0 15 | 16 | gx = (gt[..., 0] + gt[..., 2]) * 0.5 17 | gy = (gt[..., 1] + gt[..., 3]) * 0.5 18 | gw = gt[..., 2] - gt[..., 0] + 1.0 19 | gh = gt[..., 3] - gt[..., 1] + 1.0 20 | 21 | dx = (gx - px) / pw 22 | dy = (gy - py) / ph 23 | dw = torch.log(gw / pw) 24 | dh = torch.log(gh / ph) 25 | deltas = torch.stack([dx, dy, dw, dh], dim=-1) 26 | 27 | means = deltas.new_tensor(means).unsqueeze(0) 28 | stds = deltas.new_tensor(stds).unsqueeze(0) 29 | deltas = deltas.sub_(means).div_(stds) 30 | 31 | return deltas 32 | 33 | 34 | def delta2bbox(rois, 35 | deltas, 36 | means=[0, 0, 0, 0], 37 | stds=[1, 1, 1, 1], 38 | max_shape=None, 39 | wh_ratio_clip=16 / 1000): 40 | means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) 41 | stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) 42 | denorm_deltas = deltas * stds + 
means 43 | dx = denorm_deltas[:, 0::4] 44 | dy = denorm_deltas[:, 1::4] 45 | dw = denorm_deltas[:, 2::4] 46 | dh = denorm_deltas[:, 3::4] 47 | max_ratio = np.abs(np.log(wh_ratio_clip)) 48 | dw = dw.clamp(min=-max_ratio, max=max_ratio) 49 | dh = dh.clamp(min=-max_ratio, max=max_ratio) 50 | px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) 51 | py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) 52 | pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw) 53 | ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh) 54 | gw = pw * dw.exp() 55 | gh = ph * dh.exp() 56 | gx = torch.addcmul(px, 1, pw, dx) # gx = px + pw * dx 57 | gy = torch.addcmul(py, 1, ph, dy) # gy = py + ph * dy 58 | x1 = gx - gw * 0.5 + 0.5 59 | y1 = gy - gh * 0.5 + 0.5 60 | x2 = gx + gw * 0.5 - 0.5 61 | y2 = gy + gh * 0.5 - 0.5 62 | if max_shape is not None: 63 | x1 = x1.clamp(min=0, max=max_shape[1] - 1) 64 | y1 = y1.clamp(min=0, max=max_shape[0] - 1) 65 | x2 = x2.clamp(min=0, max=max_shape[1] - 1) 66 | y2 = y2.clamp(min=0, max=max_shape[0] - 1) 67 | bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) 68 | return bboxes 69 | 70 | 71 | def bbox_flip(bboxes, img_shape): 72 | """Flip bboxes horizontally. 73 | 74 | Args: 75 | bboxes(Tensor or ndarray): Shape (..., 4*k) 76 | img_shape(tuple): Image shape. 77 | 78 | Returns: 79 | Same type as `bboxes`: Flipped bboxes. 80 | """ 81 | if isinstance(bboxes, torch.Tensor): 82 | assert bboxes.shape[-1] % 4 == 0 83 | flipped = bboxes.clone() 84 | flipped[:, 0::4] = img_shape[1] - bboxes[:, 2::4] - 1 85 | flipped[:, 2::4] = img_shape[1] - bboxes[:, 0::4] - 1 86 | return flipped 87 | elif isinstance(bboxes, np.ndarray): 88 | return mmcv.bbox_flip(bboxes, img_shape) 89 | 90 | 91 | def bbox_mapping(bboxes, img_shape, scale_factor, flip): 92 | """Map bboxes from the original image scale to testing scale""" 93 | new_bboxes = bboxes * scale_factor 94 | if flip: 95 | new_bboxes = bbox_flip(new_bboxes, img_shape) 96 | return new_bboxes 97 | 98 | 99 | def bbox_mapping_back(bboxes, img_shape, scale_factor, flip): 100 | """Map bboxes from testing scale to original image scale""" 101 | new_bboxes = bbox_flip(bboxes, img_shape) if flip else bboxes 102 | new_bboxes = new_bboxes / scale_factor 103 | return new_bboxes 104 | 105 | 106 | def bbox2roi(bbox_list): 107 | """Convert a list of bboxes to roi format. 108 | 109 | Args: 110 | bbox_list (list[Tensor]): a list of bboxes corresponding to a batch 111 | of images. 112 | 113 | Returns: 114 | Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2] 115 | """ 116 | rois_list = [] 117 | for img_id, bboxes in enumerate(bbox_list): 118 | if bboxes.size(0) > 0: 119 | img_inds = bboxes.new_full((bboxes.size(0), 1), img_id) 120 | rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1) 121 | else: 122 | rois = bboxes.new_zeros((0, 5)) 123 | rois_list.append(rois) 124 | rois = torch.cat(rois_list, 0) 125 | return rois 126 | 127 | 128 | def roi2bbox(rois): 129 | bbox_list = [] 130 | img_ids = torch.unique(rois[:, 0].cpu(), sorted=True) 131 | for img_id in img_ids: 132 | inds = (rois[:, 0] == img_id.item()) 133 | bbox = rois[inds, 1:] 134 | bbox_list.append(bbox) 135 | return bbox_list 136 | 137 | 138 | def bbox2result(bboxes, labels, num_classes): 139 | """Convert detection results to a list of numpy arrays. 
140 | 141 | Args: 142 | bboxes (Tensor): shape (n, 5) 143 | labels (Tensor): shape (n, ) 144 | num_classes (int): class number, including background class 145 | 146 | Returns: 147 | list(ndarray): bbox results of each class 148 | """ 149 | if bboxes.shape[0] == 0: 150 | return [ 151 | np.zeros((0, 5), dtype=np.float32) for i in range(num_classes - 1) 152 | ] 153 | else: 154 | bboxes = bboxes.cpu().numpy() 155 | labels = labels.cpu().numpy() 156 | return [bboxes[labels == i, :] for i in range(num_classes - 1)] 157 | 158 | 159 | def distance2bbox(points, distance, max_shape=None): 160 | """Decode distance prediction to bounding box. 161 | 162 | Args: 163 | points (Tensor): Shape (n, 2), [x, y]. 164 | distance (Tensor): Distance from the given point to 4 165 | boundaries (left, top, right, bottom). 166 | max_shape (tuple): Shape of the image. 167 | 168 | Returns: 169 | Tensor: Decoded bboxes. 170 | """ 171 | x1 = points[:, 0] - distance[:, 0] 172 | y1 = points[:, 1] - distance[:, 1] 173 | x2 = points[:, 0] + distance[:, 2] 174 | y2 = points[:, 1] + distance[:, 3] 175 | if max_shape is not None: 176 | x1 = x1.clamp(min=0, max=max_shape[1] - 1) 177 | y1 = y1.clamp(min=0, max=max_shape[0] - 1) 178 | x2 = x2.clamp(min=0, max=max_shape[1] - 1) 179 | y2 = y2.clamp(min=0, max=max_shape[0] - 1) 180 | return torch.stack([x1, y1, x2, y2], -1) 181 | -------------------------------------------------------------------------------- /mmdet/core/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .class_names import (voc_classes, imagenet_det_classes, 2 | imagenet_vid_classes, coco_classes, dataset_aliases, 3 | get_classes) 4 | from .coco_utils import coco_eval, fast_eval_recall, results2json 5 | from .eval_hooks import (DistEvalHook, DistEvalmAPHook, CocoDistEvalRecallHook, 6 | CocoDistEvalmAPHook) 7 | from .mean_ap import average_precision, eval_map, print_map_summary 8 | from .recall import (eval_recalls, print_recall_summary, plot_num_recall, 9 | plot_iou_recall) 10 | 11 | __all__ = [ 12 | 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes', 13 | 'coco_classes', 'dataset_aliases', 'get_classes', 'coco_eval', 14 | 'fast_eval_recall', 'results2json', 'DistEvalHook', 'DistEvalmAPHook', 15 | 'CocoDistEvalRecallHook', 'CocoDistEvalmAPHook', 'average_precision', 16 | 'eval_map', 'print_map_summary', 'eval_recalls', 'print_recall_summary', 17 | 'plot_num_recall', 'plot_iou_recall' 18 | ] 19 | -------------------------------------------------------------------------------- /mmdet/core/evaluation/bbox_overlaps.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def bbox_overlaps(bboxes1, bboxes2, mode='iou'): 5 | """Calculate the ious between each bbox of bboxes1 and bboxes2. 
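A hand-checked example (values chosen for illustration; note the +1 pixel convention in the area terms below):

    >>> b1 = np.array([[0., 0., 9., 9.]])    # 10 x 10 box, area 100
    >>> b2 = np.array([[5., 5., 14., 14.]])  # 10 x 10 box, intersection 5 x 5 = 25
    >>> bbox_overlaps(b1, b2)
    >>> # -> array([[0.14285715]], dtype=float32), i.e. 25 / (100 + 100 - 25)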
6 | 7 | Args: 8 | bboxes1(ndarray): shape (n, 4) 9 | bboxes2(ndarray): shape (k, 4) 10 | mode(str): iou (intersection over union) or iof (intersection 11 | over foreground) 12 | 13 | Returns: 14 | ious(ndarray): shape (n, k) 15 | """ 16 | 17 | assert mode in ['iou', 'iof'] 18 | 19 | bboxes1 = bboxes1.astype(np.float32) 20 | bboxes2 = bboxes2.astype(np.float32) 21 | rows = bboxes1.shape[0] 22 | cols = bboxes2.shape[0] 23 | ious = np.zeros((rows, cols), dtype=np.float32) 24 | if rows * cols == 0: 25 | return ious 26 | exchange = False 27 | if bboxes1.shape[0] > bboxes2.shape[0]: 28 | bboxes1, bboxes2 = bboxes2, bboxes1 29 | ious = np.zeros((cols, rows), dtype=np.float32) 30 | exchange = True 31 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * ( 32 | bboxes1[:, 3] - bboxes1[:, 1] + 1) 33 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * ( 34 | bboxes2[:, 3] - bboxes2[:, 1] + 1) 35 | for i in range(bboxes1.shape[0]): 36 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 37 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 38 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 39 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 40 | overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum( 41 | y_end - y_start + 1, 0) 42 | if mode == 'iou': 43 | union = area1[i] + area2 - overlap 44 | else: 45 | union = area1[i] if not exchange else area2 46 | ious[i, :] = overlap / union 47 | if exchange: 48 | ious = ious.T 49 | return ious 50 | -------------------------------------------------------------------------------- /mmdet/core/evaluation/class_names.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | 4 | def wider_face_classes(): 5 | return ['face'] 6 | 7 | 8 | def voc_classes(): 9 | return [ 10 | 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 11 | 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 12 | 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor' 13 | ] 14 | 15 | 16 | def imagenet_det_classes(): 17 | return [ 18 | 'accordion', 'airplane', 'ant', 'antelope', 'apple', 'armadillo', 19 | 'artichoke', 'axe', 'baby_bed', 'backpack', 'bagel', 'balance_beam', 20 | 'banana', 'band_aid', 'banjo', 'baseball', 'basketball', 'bathing_cap', 21 | 'beaker', 'bear', 'bee', 'bell_pepper', 'bench', 'bicycle', 'binder', 22 | 'bird', 'bookshelf', 'bow_tie', 'bow', 'bowl', 'brassiere', 'burrito', 23 | 'bus', 'butterfly', 'camel', 'can_opener', 'car', 'cart', 'cattle', 24 | 'cello', 'centipede', 'chain_saw', 'chair', 'chime', 'cocktail_shaker', 25 | 'coffee_maker', 'computer_keyboard', 'computer_mouse', 'corkscrew', 26 | 'cream', 'croquet_ball', 'crutch', 'cucumber', 'cup_or_mug', 'diaper', 27 | 'digital_clock', 'dishwasher', 'dog', 'domestic_cat', 'dragonfly', 28 | 'drum', 'dumbbell', 'electric_fan', 'elephant', 'face_powder', 'fig', 29 | 'filing_cabinet', 'flower_pot', 'flute', 'fox', 'french_horn', 'frog', 30 | 'frying_pan', 'giant_panda', 'goldfish', 'golf_ball', 'golfcart', 31 | 'guacamole', 'guitar', 'hair_dryer', 'hair_spray', 'hamburger', 32 | 'hammer', 'hamster', 'harmonica', 'harp', 'hat_with_a_wide_brim', 33 | 'head_cabbage', 'helmet', 'hippopotamus', 'horizontal_bar', 'horse', 34 | 'hotdog', 'iPod', 'isopod', 'jellyfish', 'koala_bear', 'ladle', 35 | 'ladybug', 'lamp', 'laptop', 'lemon', 'lion', 'lipstick', 'lizard', 36 | 'lobster', 'maillot', 'maraca', 'microphone', 'microwave', 'milk_can', 37 | 'miniskirt', 'monkey', 'motorcycle', 'mushroom', 'nail', 'neck_brace', 38 | 'oboe', 'orange', 'otter', 
'pencil_box', 'pencil_sharpener', 'perfume', 39 | 'person', 'piano', 'pineapple', 'ping-pong_ball', 'pitcher', 'pizza', 40 | 'plastic_bag', 'plate_rack', 'pomegranate', 'popsicle', 'porcupine', 41 | 'power_drill', 'pretzel', 'printer', 'puck', 'punching_bag', 'purse', 42 | 'rabbit', 'racket', 'ray', 'red_panda', 'refrigerator', 43 | 'remote_control', 'rubber_eraser', 'rugby_ball', 'ruler', 44 | 'salt_or_pepper_shaker', 'saxophone', 'scorpion', 'screwdriver', 45 | 'seal', 'sheep', 'ski', 'skunk', 'snail', 'snake', 'snowmobile', 46 | 'snowplow', 'soap_dispenser', 'soccer_ball', 'sofa', 'spatula', 47 | 'squirrel', 'starfish', 'stethoscope', 'stove', 'strainer', 48 | 'strawberry', 'stretcher', 'sunglasses', 'swimming_trunks', 'swine', 49 | 'syringe', 'table', 'tape_player', 'tennis_ball', 'tick', 'tie', 50 | 'tiger', 'toaster', 'traffic_light', 'train', 'trombone', 'trumpet', 51 | 'turtle', 'tv_or_monitor', 'unicycle', 'vacuum', 'violin', 52 | 'volleyball', 'waffle_iron', 'washer', 'water_bottle', 'watercraft', 53 | 'whale', 'wine_bottle', 'zebra' 54 | ] 55 | 56 | 57 | def imagenet_vid_classes(): 58 | return [ 59 | 'airplane', 'antelope', 'bear', 'bicycle', 'bird', 'bus', 'car', 60 | 'cattle', 'dog', 'domestic_cat', 'elephant', 'fox', 'giant_panda', 61 | 'hamster', 'horse', 'lion', 'lizard', 'monkey', 'motorcycle', 'rabbit', 62 | 'red_panda', 'sheep', 'snake', 'squirrel', 'tiger', 'train', 'turtle', 63 | 'watercraft', 'whale', 'zebra' 64 | ] 65 | 66 | 67 | def coco_classes(): 68 | return [ 69 | 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 70 | 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign', 71 | 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 72 | 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 73 | 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 74 | 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard', 75 | 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork', 76 | 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 77 | 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair', 78 | 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 79 | 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave', 80 | 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 81 | 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush' 82 | ] 83 | 84 | 85 | dataset_aliases = { 86 | 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'], 87 | 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'], 88 | 'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'], 89 | 'coco': ['coco', 'mscoco', 'ms_coco'], 90 | 'wider_face': ['WIDERFaceDataset', 'wider_face', 'WDIERFace'] 91 | } 92 | 93 | 94 | def get_classes(dataset): 95 | """Get class names of a dataset.""" 96 | alias2name = {} 97 | for name, aliases in dataset_aliases.items(): 98 | for alias in aliases: 99 | alias2name[alias] = name 100 | 101 | if mmcv.is_str(dataset): 102 | if dataset in alias2name: 103 | labels = eval(alias2name[dataset] + '_classes()') 104 | else: 105 | raise ValueError('Unrecognized dataset: {}'.format(dataset)) 106 | else: 107 | raise TypeError('dataset must a str, but got {}'.format(type(dataset))) 108 | return labels 109 | -------------------------------------------------------------------------------- /mmdet/core/evaluation/recall.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from terminaltables import 
AsciiTable 3 | 4 | from .bbox_overlaps import bbox_overlaps 5 | 6 | 7 | def _recalls(all_ious, proposal_nums, thrs): 8 | 9 | img_num = all_ious.shape[0] 10 | total_gt_num = sum([ious.shape[0] for ious in all_ious]) 11 | 12 | _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32) 13 | for k, proposal_num in enumerate(proposal_nums): 14 | tmp_ious = np.zeros(0) 15 | for i in range(img_num): 16 | ious = all_ious[i][:, :proposal_num].copy() 17 | gt_ious = np.zeros((ious.shape[0])) 18 | if ious.size == 0: 19 | tmp_ious = np.hstack((tmp_ious, gt_ious)) 20 | continue 21 | for j in range(ious.shape[0]): 22 | gt_max_overlaps = ious.argmax(axis=1) 23 | max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps] 24 | gt_idx = max_ious.argmax() 25 | gt_ious[j] = max_ious[gt_idx] 26 | box_idx = gt_max_overlaps[gt_idx] 27 | ious[gt_idx, :] = -1 28 | ious[:, box_idx] = -1 29 | tmp_ious = np.hstack((tmp_ious, gt_ious)) 30 | _ious[k, :] = tmp_ious 31 | 32 | _ious = np.fliplr(np.sort(_ious, axis=1)) 33 | recalls = np.zeros((proposal_nums.size, thrs.size)) 34 | for i, thr in enumerate(thrs): 35 | recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num) 36 | 37 | return recalls 38 | 39 | 40 | def set_recall_param(proposal_nums, iou_thrs): 41 | """Check proposal_nums and iou_thrs and set correct format. 42 | """ 43 | if isinstance(proposal_nums, list): 44 | _proposal_nums = np.array(proposal_nums) 45 | elif isinstance(proposal_nums, int): 46 | _proposal_nums = np.array([proposal_nums]) 47 | else: 48 | _proposal_nums = proposal_nums 49 | 50 | if iou_thrs is None: 51 | _iou_thrs = np.array([0.5]) 52 | elif isinstance(iou_thrs, list): 53 | _iou_thrs = np.array(iou_thrs) 54 | elif isinstance(iou_thrs, float): 55 | _iou_thrs = np.array([iou_thrs]) 56 | else: 57 | _iou_thrs = iou_thrs 58 | 59 | return _proposal_nums, _iou_thrs 60 | 61 | 62 | def eval_recalls(gts, 63 | proposals, 64 | proposal_nums=None, 65 | iou_thrs=None, 66 | print_summary=True): 67 | """Calculate recalls. 68 | 69 | Args: 70 | gts(list or ndarray): a list of arrays of shape (n, 4) 71 | proposals(list or ndarray): a list of arrays of shape (k, 4) or (k, 5) 72 | proposal_nums(int or list of int or ndarray): top N proposals 73 | thrs(float or list or ndarray): iou thresholds 74 | 75 | Returns: 76 | ndarray: recalls of different ious and proposal nums 77 | """ 78 | 79 | img_num = len(gts) 80 | assert img_num == len(proposals) 81 | 82 | proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs) 83 | 84 | all_ious = [] 85 | for i in range(img_num): 86 | if proposals[i].ndim == 2 and proposals[i].shape[1] == 5: 87 | scores = proposals[i][:, 4] 88 | sort_idx = np.argsort(scores)[::-1] 89 | img_proposal = proposals[i][sort_idx, :] 90 | else: 91 | img_proposal = proposals[i] 92 | prop_num = min(img_proposal.shape[0], proposal_nums[-1]) 93 | if gts[i] is None or gts[i].shape[0] == 0: 94 | ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32) 95 | else: 96 | ious = bbox_overlaps(gts[i], img_proposal[:prop_num, :4]) 97 | all_ious.append(ious) 98 | all_ious = np.array(all_ious) 99 | recalls = _recalls(all_ious, proposal_nums, iou_thrs) 100 | if print_summary: 101 | print_recall_summary(recalls, proposal_nums, iou_thrs) 102 | return recalls 103 | 104 | 105 | def print_recall_summary(recalls, 106 | proposal_nums, 107 | iou_thrs, 108 | row_idxs=None, 109 | col_idxs=None): 110 | """Print recalls in a table. 
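Typical use is together with `eval_recalls` above (a sketch; `gts` and `proposals` are assumed to be per-image ndarrays of boxes):

    >>> recalls = eval_recalls(gts, proposals, proposal_nums=[100, 300, 1000],
    ...                        iou_thrs=[0.5, 0.7], print_summary=False)
    >>> print_recall_summary(recalls, [100, 300, 1000], [0.5, 0.7])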
111 | 112 | Args: 113 | recalls(ndarray): calculated from `bbox_recalls` 114 | proposal_nums(ndarray or list): top N proposals 115 | iou_thrs(ndarray or list): iou thresholds 116 | row_idxs(ndarray): which rows(proposal nums) to print 117 | col_idxs(ndarray): which cols(iou thresholds) to print 118 | """ 119 | proposal_nums = np.array(proposal_nums, dtype=np.int32) 120 | iou_thrs = np.array(iou_thrs) 121 | if row_idxs is None: 122 | row_idxs = np.arange(proposal_nums.size) 123 | if col_idxs is None: 124 | col_idxs = np.arange(iou_thrs.size) 125 | row_header = [''] + iou_thrs[col_idxs].tolist() 126 | table_data = [row_header] 127 | for i, num in enumerate(proposal_nums[row_idxs]): 128 | row = [ 129 | '{:.3f}'.format(val) 130 | for val in recalls[row_idxs[i], col_idxs].tolist() 131 | ] 132 | row.insert(0, num) 133 | table_data.append(row) 134 | table = AsciiTable(table_data) 135 | print(table.table) 136 | 137 | 138 | def plot_num_recall(recalls, proposal_nums): 139 | """Plot Proposal_num-Recalls curve. 140 | 141 | Args: 142 | recalls(ndarray or list): shape (k,) 143 | proposal_nums(ndarray or list): same shape as `recalls` 144 | """ 145 | if isinstance(proposal_nums, np.ndarray): 146 | _proposal_nums = proposal_nums.tolist() 147 | else: 148 | _proposal_nums = proposal_nums 149 | if isinstance(recalls, np.ndarray): 150 | _recalls = recalls.tolist() 151 | else: 152 | _recalls = recalls 153 | 154 | import matplotlib.pyplot as plt 155 | f = plt.figure() 156 | plt.plot([0] + _proposal_nums, [0] + _recalls) 157 | plt.xlabel('Proposal num') 158 | plt.ylabel('Recall') 159 | plt.axis([0, proposal_nums.max(), 0, 1]) 160 | f.show() 161 | 162 | 163 | def plot_iou_recall(recalls, iou_thrs): 164 | """Plot IoU-Recalls curve. 165 | 166 | Args: 167 | recalls(ndarray or list): shape (k,) 168 | iou_thrs(ndarray or list): same shape as `recalls` 169 | """ 170 | if isinstance(iou_thrs, np.ndarray): 171 | _iou_thrs = iou_thrs.tolist() 172 | else: 173 | _iou_thrs = iou_thrs 174 | if isinstance(recalls, np.ndarray): 175 | _recalls = recalls.tolist() 176 | else: 177 | _recalls = recalls 178 | 179 | import matplotlib.pyplot as plt 180 | f = plt.figure() 181 | plt.plot(_iou_thrs + [1.0], _recalls + [0.]) 182 | plt.xlabel('IoU') 183 | plt.ylabel('Recall') 184 | plt.axis([iou_thrs.min(), 1, 0, 1]) 185 | f.show() 186 | -------------------------------------------------------------------------------- /mmdet/core/fp16/__init__.py: -------------------------------------------------------------------------------- 1 | from .decorators import auto_fp16, force_fp32 2 | from .hooks import Fp16OptimizerHook, wrap_fp16_model 3 | 4 | __all__ = ['auto_fp16', 'force_fp32', 'Fp16OptimizerHook', 'wrap_fp16_model'] 5 | -------------------------------------------------------------------------------- /mmdet/core/fp16/decorators.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from inspect import getfullargspec 3 | 4 | import torch 5 | 6 | from .utils import cast_tensor_type 7 | 8 | 9 | def auto_fp16(apply_to=None, out_fp32=False): 10 | """Decorator to enable fp16 training automatically. 11 | 12 | This decorator is useful when you write custom modules and want to support 13 | mixed precision training. If inputs arguments are fp32 tensors, they will 14 | be converted to fp16 automatically. Arguments other than fp32 tensors are 15 | ignored. 16 | 17 | Args: 18 | apply_to (Iterable, optional): The argument names to be converted. 19 | `None` indicates all arguments. 
20 | out_fp32 (bool): Whether to convert the output back to fp32. 21 | 22 | :Example: 23 | 24 | class MyModule1(nn.Module) 25 | 26 | # Convert x and y to fp16 27 | @auto_fp16() 28 | def forward(self, x, y): 29 | pass 30 | 31 | class MyModule2(nn.Module): 32 | 33 | # convert pred to fp16 34 | @auto_fp16(apply_to=('pred', )) 35 | def do_something(self, pred, others): 36 | pass 37 | """ 38 | 39 | def auto_fp16_wrapper(old_func): 40 | 41 | @functools.wraps(old_func) 42 | def new_func(*args, **kwargs): 43 | # check if the module has set the attribute `fp16_enabled`, if not, 44 | # just fallback to the original method. 45 | if not isinstance(args[0], torch.nn.Module): 46 | raise TypeError('@auto_fp16 can only be used to decorate the ' 47 | 'method of nn.Module') 48 | if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): 49 | return old_func(*args, **kwargs) 50 | # get the arg spec of the decorated method 51 | args_info = getfullargspec(old_func) 52 | # get the argument names to be casted 53 | args_to_cast = args_info.args if apply_to is None else apply_to 54 | # convert the args that need to be processed 55 | new_args = [] 56 | # NOTE: default args are not taken into consideration 57 | if args: 58 | arg_names = args_info.args[:len(args)] 59 | for i, arg_name in enumerate(arg_names): 60 | if arg_name in args_to_cast: 61 | new_args.append( 62 | cast_tensor_type(args[i], torch.float, torch.half)) 63 | else: 64 | new_args.append(args[i]) 65 | # convert the kwargs that need to be processed 66 | new_kwargs = {} 67 | if kwargs: 68 | for arg_name, arg_value in kwargs.items(): 69 | if arg_name in args_to_cast: 70 | new_kwargs[arg_name] = cast_tensor_type( 71 | arg_value, torch.float, torch.half) 72 | else: 73 | new_kwargs[arg_name] = arg_value 74 | # apply converted arguments to the decorated method 75 | output = old_func(*new_args, **new_kwargs) 76 | # cast the results back to fp32 if necessary 77 | if out_fp32: 78 | output = cast_tensor_type(output, torch.half, torch.float) 79 | return output 80 | 81 | return new_func 82 | 83 | return auto_fp16_wrapper 84 | 85 | 86 | def force_fp32(apply_to=None, out_fp16=False): 87 | """Decorator to convert input arguments to fp32 in force. 88 | 89 | This decorator is useful when you write custom modules and want to support 90 | mixed precision training. If there are some inputs that must be processed 91 | in fp32 mode, then this decorator can handle it. If inputs arguments are 92 | fp16 tensors, they will be converted to fp32 automatically. Arguments other 93 | than fp16 tensors are ignored. 94 | 95 | Args: 96 | apply_to (Iterable, optional): The argument names to be converted. 97 | `None` indicates all arguments. 98 | out_fp16 (bool): Whether to convert the output back to fp16. 99 | 100 | :Example: 101 | 102 | class MyModule1(nn.Module) 103 | 104 | # Convert x and y to fp32 105 | @force_fp32() 106 | def loss(self, x, y): 107 | pass 108 | 109 | class MyModule2(nn.Module): 110 | 111 | # convert pred to fp32 112 | @force_fp32(apply_to=('pred', )) 113 | def post_process(self, pred, others): 114 | pass 115 | """ 116 | 117 | def force_fp32_wrapper(old_func): 118 | 119 | @functools.wraps(old_func) 120 | def new_func(*args, **kwargs): 121 | # check if the module has set the attribute `fp16_enabled`, if not, 122 | # just fallback to the original method. 
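            # (a module opts in by defining an `fp16_enabled` attribute;
            # `wrap_fp16_model` in hooks.py sets it to True when the model
            # is converted to half precision)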
123 | if not isinstance(args[0], torch.nn.Module): 124 | raise TypeError('@force_fp32 can only be used to decorate the ' 125 | 'method of nn.Module') 126 | if not (hasattr(args[0], 'fp16_enabled') and args[0].fp16_enabled): 127 | return old_func(*args, **kwargs) 128 | # get the arg spec of the decorated method 129 | args_info = getfullargspec(old_func) 130 | # get the argument names to be casted 131 | args_to_cast = args_info.args if apply_to is None else apply_to 132 | # convert the args that need to be processed 133 | new_args = [] 134 | if args: 135 | arg_names = args_info.args[:len(args)] 136 | for i, arg_name in enumerate(arg_names): 137 | if arg_name in args_to_cast: 138 | new_args.append( 139 | cast_tensor_type(args[i], torch.half, torch.float)) 140 | else: 141 | new_args.append(args[i]) 142 | # convert the kwargs that need to be processed 143 | new_kwargs = dict() 144 | if kwargs: 145 | for arg_name, arg_value in kwargs.items(): 146 | if arg_name in args_to_cast: 147 | new_kwargs[arg_name] = cast_tensor_type( 148 | arg_value, torch.half, torch.float) 149 | else: 150 | new_kwargs[arg_name] = arg_value 151 | # apply converted arguments to the decorated method 152 | output = old_func(*new_args, **new_kwargs) 153 | # cast the results back to fp16 if necessary 154 | if out_fp16: 155 | output = cast_tensor_type(output, torch.float, torch.half) 156 | return output 157 | 158 | return new_func 159 | 160 | return force_fp32_wrapper 161 | -------------------------------------------------------------------------------- /mmdet/core/fp16/hooks.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | from mmcv.runner import OptimizerHook 5 | 6 | from .utils import cast_tensor_type 7 | from ..utils.dist_utils import allreduce_grads 8 | 9 | 10 | class Fp16OptimizerHook(OptimizerHook): 11 | """FP16 optimizer hook. 12 | 13 | The steps of the fp16 optimizer are as follows. 14 | 1. Scale the loss value. 15 | 2. Back-propagate through the fp16 model. 16 | 3. Copy gradients from the fp16 model to the fp32 weight copy. 17 | 4. Update the fp32 weights. 18 | 5. Copy the updated parameters from the fp32 weight copy back to the fp16 model. 19 | 20 | Refer to https://arxiv.org/abs/1710.03740 for more details. 21 | 22 | Args: 23 | loss_scale (float): Scale factor multiplied with loss. 
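        Example (a sketch; the hyper-parameter values are illustrative, not prescribed by this repo):

            >>> optimizer_config = Fp16OptimizerHook(
            ...     grad_clip=dict(max_norm=35, norm_type=2),
            ...     loss_scale=512.,
            ...     distributed=True)
            >>> # registered on the mmcv Runner in place of the plain OptimizerHook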
24 | """ 25 | 26 | def __init__(self, 27 | grad_clip=None, 28 | coalesce=True, 29 | bucket_size_mb=-1, 30 | loss_scale=512., 31 | distributed=True): 32 | self.grad_clip = grad_clip 33 | self.coalesce = coalesce 34 | self.bucket_size_mb = bucket_size_mb 35 | self.loss_scale = loss_scale 36 | self.distributed = distributed 37 | 38 | def before_run(self, runner): 39 | # keep a copy of fp32 weights 40 | runner.optimizer.param_groups = copy.deepcopy( 41 | runner.optimizer.param_groups) 42 | # convert model to fp16 43 | wrap_fp16_model(runner.model) 44 | 45 | def copy_grads_to_fp32(self, fp16_net, fp32_weights): 46 | """Copy gradients from fp16 model to fp32 weight copy.""" 47 | for fp32_param, fp16_param in zip(fp32_weights, fp16_net.parameters()): 48 | if fp16_param.grad is not None: 49 | if fp32_param.grad is None: 50 | fp32_param.grad = fp32_param.data.new(fp32_param.size()) 51 | fp32_param.grad.copy_(fp16_param.grad) 52 | 53 | def copy_params_to_fp16(self, fp16_net, fp32_weights): 54 | """Copy updated params from fp32 weight copy to fp16 model.""" 55 | for fp16_param, fp32_param in zip(fp16_net.parameters(), fp32_weights): 56 | fp16_param.data.copy_(fp32_param.data) 57 | 58 | def after_train_iter(self, runner): 59 | # clear grads of last iteration 60 | runner.model.zero_grad() 61 | runner.optimizer.zero_grad() 62 | # scale the loss value 63 | scaled_loss = runner.outputs['loss'] * self.loss_scale 64 | scaled_loss.backward() 65 | # copy fp16 grads in the model to fp32 params in the optimizer 66 | fp32_weights = [] 67 | for param_group in runner.optimizer.param_groups: 68 | fp32_weights += param_group['params'] 69 | self.copy_grads_to_fp32(runner.model, fp32_weights) 70 | # allreduce grads 71 | if self.distributed: 72 | allreduce_grads(fp32_weights, self.coalesce, self.bucket_size_mb) 73 | # scale the gradients back 74 | for param in fp32_weights: 75 | if param.grad is not None: 76 | param.grad.div_(self.loss_scale) 77 | if self.grad_clip is not None: 78 | self.clip_grads(fp32_weights) 79 | # update fp32 params 80 | runner.optimizer.step() 81 | # copy fp32 params to the fp16 model 82 | self.copy_params_to_fp16(runner.model, fp32_weights) 83 | 84 | 85 | def wrap_fp16_model(model): 86 | # convert model to fp16 87 | model.half() 88 | # patch the normalization layers to make it work in fp32 mode 89 | patch_norm_fp32(model) 90 | # set `fp16_enabled` flag 91 | for m in model.modules(): 92 | if hasattr(m, 'fp16_enabled'): 93 | m.fp16_enabled = True 94 | 95 | 96 | def patch_norm_fp32(module): 97 | if isinstance(module, (nn.modules.batchnorm._BatchNorm, nn.GroupNorm)): 98 | module.float() 99 | module.forward = patch_forward_method(module.forward, torch.half, 100 | torch.float) 101 | for child in module.children(): 102 | patch_norm_fp32(child) 103 | return module 104 | 105 | 106 | def patch_forward_method(func, src_type, dst_type, convert_output=True): 107 | """Patch the forward method of a module. 108 | 109 | Args: 110 | func (callable): The original forward method. 111 | src_type (torch.dtype): Type of input arguments to be converted from. 112 | dst_type (torch.dtype): Type of input arguments to be converted to. 113 | convert_output (bool): Whether to convert the output back to src_type. 114 | 115 | Returns: 116 | callable: The patched forward method. 
117 | """ 118 | 119 | def new_forward(*args, **kwargs): 120 | output = func(*cast_tensor_type(args, src_type, dst_type), 121 | **cast_tensor_type(kwargs, src_type, dst_type)) 122 | if convert_output: 123 | output = cast_tensor_type(output, dst_type, src_type) 124 | return output 125 | 126 | return new_forward 127 | -------------------------------------------------------------------------------- /mmdet/core/fp16/utils.py: -------------------------------------------------------------------------------- 1 | from collections import abc 2 | 3 | import numpy as np 4 | import torch 5 | 6 | 7 | def cast_tensor_type(inputs, src_type, dst_type): 8 | if isinstance(inputs, torch.Tensor): 9 | return inputs.to(dst_type) 10 | elif isinstance(inputs, str): 11 | return inputs 12 | elif isinstance(inputs, np.ndarray): 13 | return inputs 14 | elif isinstance(inputs, abc.Mapping): 15 | return type(inputs)({ 16 | k: cast_tensor_type(v, src_type, dst_type) 17 | for k, v in inputs.items() 18 | }) 19 | elif isinstance(inputs, abc.Iterable): 20 | return type(inputs)( 21 | cast_tensor_type(item, src_type, dst_type) for item in inputs) 22 | else: 23 | return inputs 24 | -------------------------------------------------------------------------------- /mmdet/core/mask/__init__.py: -------------------------------------------------------------------------------- 1 | from .utils import split_combined_polys 2 | from .mask_target import mask_target 3 | 4 | __all__ = ['split_combined_polys', 'mask_target'] 5 | -------------------------------------------------------------------------------- /mmdet/core/mask/mask_target.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import mmcv 4 | 5 | 6 | def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list, 7 | cfg): 8 | cfg_list = [cfg for _ in range(len(pos_proposals_list))] 9 | mask_targets = map(mask_target_single, pos_proposals_list, 10 | pos_assigned_gt_inds_list, gt_masks_list, cfg_list) 11 | mask_targets = torch.cat(list(mask_targets)) 12 | return mask_targets 13 | 14 | 15 | def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg): 16 | mask_size = cfg.mask_size 17 | num_pos = pos_proposals.size(0) 18 | mask_targets = [] 19 | if num_pos > 0: 20 | proposals_np = pos_proposals.cpu().numpy() 21 | pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy() 22 | for i in range(num_pos): 23 | gt_mask = gt_masks[pos_assigned_gt_inds[i]] 24 | bbox = proposals_np[i, :].astype(np.int32) 25 | x1, y1, x2, y2 = bbox 26 | w = np.maximum(x2 - x1 + 1, 1) 27 | h = np.maximum(y2 - y1 + 1, 1) 28 | # mask is uint8 both before and after resizing 29 | target = mmcv.imresize(gt_mask[y1:y1 + h, x1:x1 + w], 30 | (mask_size, mask_size)) 31 | mask_targets.append(target) 32 | mask_targets = torch.from_numpy(np.stack(mask_targets)).float().to( 33 | pos_proposals.device) 34 | else: 35 | mask_targets = pos_proposals.new_zeros((0, mask_size, mask_size)) 36 | return mask_targets 37 | -------------------------------------------------------------------------------- /mmdet/core/mask/utils.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | 3 | 4 | def split_combined_polys(polys, poly_lens, polys_per_mask): 5 | """Split the combined 1-D polys into masks. 6 | 7 | A mask is represented as a list of polys, and a poly is represented as 8 | a 1-D array. In dataset, all masks are concatenated into a single 1-D 9 | tensor. 
Here we need to split the tensor into original representations. 10 | 11 | Args: 12 | polys (list): a list (length = image num) of 1-D tensors 13 | poly_lens (list): a list (length = image num) of poly length 14 | polys_per_mask (list): a list (length = image num) of poly number 15 | of each mask 16 | 17 | Returns: 18 | list: a list (length = image num) of list (length = mask num) of 19 | list (length = poly num) of numpy array 20 | """ 21 | mask_polys_list = [] 22 | for img_id in range(len(polys)): 23 | polys_single = polys[img_id] 24 | polys_lens_single = poly_lens[img_id].tolist() 25 | polys_per_mask_single = polys_per_mask[img_id].tolist() 26 | 27 | split_polys = mmcv.slice_list(polys_single, polys_lens_single) 28 | mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single) 29 | mask_polys_list.append(mask_polys) 30 | return mask_polys_list 31 | -------------------------------------------------------------------------------- /mmdet/core/post_processing/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox_nms import multiclass_nms 2 | from .merge_augs import (merge_aug_proposals, merge_aug_bboxes, 3 | merge_aug_scores, merge_aug_masks) 4 | 5 | __all__ = [ 6 | 'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes', 7 | 'merge_aug_scores', 'merge_aug_masks' 8 | ] 9 | -------------------------------------------------------------------------------- /mmdet/core/post_processing/bbox_nms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from mmdet.ops.nms import nms_wrapper 4 | 5 | 6 | def multiclass_nms(multi_bboxes, 7 | multi_scores, 8 | score_thr, 9 | nms_cfg, 10 | max_num=-1, 11 | score_factors=None): 12 | """NMS for multi-class bboxes. 13 | 14 | Args: 15 | multi_bboxes (Tensor): shape (n, #class*4) or (n, 4) 16 | multi_scores (Tensor): shape (n, #class) 17 | score_thr (float): bbox threshold, bboxes with scores lower than it 18 | will not be considered. 19 | nms_cfg (dict): NMS op config; the `type` key selects an op from nms_wrapper ('nms' by default) and the remaining keys are passed to that op, e.g. dict(type='nms', iou_thr=0.5) 20 | max_num (int): if there are more than max_num bboxes after NMS, 21 | only top max_num will be kept. 22 | score_factors (Tensor): The factors multiplied to scores before 23 | applying NMS 24 | 25 | Returns: 26 | tuple: (bboxes, labels), tensors of shape (k, 5) and (k, ). Labels 27 | are 0-based. 
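    Example (shapes and thresholds are illustrative):

        >>> # n = 1000 proposals, 2 score columns (background + pedestrian)
        >>> dets, labels = multiclass_nms(multi_bboxes, multi_scores,
        ...                               score_thr=0.05,
        ...                               nms_cfg=dict(type='nms', iou_thr=0.5),
        ...                               max_num=100)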
28 | """ 29 | num_classes = multi_scores.shape[1] 30 | bboxes, labels = [], [] 31 | nms_cfg_ = nms_cfg.copy() 32 | nms_type = nms_cfg_.pop('type', 'nms') 33 | nms_op = getattr(nms_wrapper, nms_type) 34 | for i in range(1, num_classes): 35 | cls_inds = multi_scores[:, i] > score_thr 36 | if not cls_inds.any(): 37 | continue 38 | # get bboxes and scores of this class 39 | if multi_bboxes.shape[1] == 4: 40 | _bboxes = multi_bboxes[cls_inds, :] 41 | else: 42 | _bboxes = multi_bboxes[cls_inds, i * 4:(i + 1) * 4] 43 | _scores = multi_scores[cls_inds, i] 44 | if score_factors is not None: 45 | _scores *= score_factors[cls_inds] 46 | cls_dets = torch.cat([_bboxes, _scores[:, None]], dim=1) 47 | cls_dets, _ = nms_op(cls_dets, **nms_cfg_) 48 | cls_labels = multi_bboxes.new_full( 49 | (cls_dets.shape[0], ), i - 1, dtype=torch.long) 50 | bboxes.append(cls_dets) 51 | labels.append(cls_labels) 52 | if bboxes: 53 | bboxes = torch.cat(bboxes) 54 | labels = torch.cat(labels) 55 | if bboxes.shape[0] > max_num: 56 | _, inds = bboxes[:, -1].sort(descending=True) 57 | inds = inds[:max_num] 58 | bboxes = bboxes[inds] 59 | labels = labels[inds] 60 | else: 61 | bboxes = multi_bboxes.new_zeros((0, 5)) 62 | labels = multi_bboxes.new_zeros((0, ), dtype=torch.long) 63 | 64 | return bboxes, labels 65 | -------------------------------------------------------------------------------- /mmdet/core/post_processing/merge_augs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import numpy as np 4 | 5 | from mmdet.ops import nms 6 | from ..bbox import bbox_mapping_back 7 | 8 | 9 | def merge_aug_proposals(aug_proposals, img_metas, rpn_test_cfg): 10 | """Merge augmented proposals (multiscale, flip, etc.) 11 | 12 | Args: 13 | aug_proposals (list[Tensor]): proposals from different testing 14 | schemes, shape (n, 5). Note that they are not rescaled to the 15 | original image size. 16 | img_metas (list[dict]): image info including "shape_scale" and "flip". 17 | rpn_test_cfg (dict): rpn test config. 18 | 19 | Returns: 20 | Tensor: shape (n, 4), proposals corresponding to original image scale. 21 | """ 22 | recovered_proposals = [] 23 | for proposals, img_info in zip(aug_proposals, img_metas): 24 | img_shape = img_info['img_shape'] 25 | scale_factor = img_info['scale_factor'] 26 | flip = img_info['flip'] 27 | _proposals = proposals.clone() 28 | _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape, 29 | scale_factor, flip) 30 | recovered_proposals.append(_proposals) 31 | aug_proposals = torch.cat(recovered_proposals, dim=0) 32 | merged_proposals, _ = nms(aug_proposals, rpn_test_cfg.nms_thr) 33 | scores = merged_proposals[:, 4] 34 | _, order = scores.sort(0, descending=True) 35 | num = min(rpn_test_cfg.max_num, merged_proposals.shape[0]) 36 | order = order[:num] 37 | merged_proposals = merged_proposals[order, :] 38 | return merged_proposals 39 | 40 | 41 | def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg): 42 | """Merge augmented detection bboxes and scores. 43 | 44 | Args: 45 | aug_bboxes (list[Tensor]): shape (n, 4*#class) 46 | aug_scores (list[Tensor] or None): shape (n, #class) 47 | img_shapes (list[Tensor]): shape (3, ). 48 | rcnn_test_cfg (dict): rcnn test config. 
49 | 50 | Returns: 51 | tuple: (bboxes, scores) 52 | """ 53 | recovered_bboxes = [] 54 | for bboxes, img_info in zip(aug_bboxes, img_metas): 55 | img_shape = img_info[0]['img_shape'] 56 | scale_factor = img_info[0]['scale_factor'] 57 | flip = img_info[0]['flip'] 58 | bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip) 59 | recovered_bboxes.append(bboxes) 60 | bboxes = torch.stack(recovered_bboxes).mean(dim=0) 61 | if aug_scores is None: 62 | return bboxes 63 | else: 64 | scores = torch.stack(aug_scores).mean(dim=0) 65 | return bboxes, scores 66 | 67 | 68 | def merge_aug_scores(aug_scores): 69 | """Merge augmented bbox scores.""" 70 | if isinstance(aug_scores[0], torch.Tensor): 71 | return torch.mean(torch.stack(aug_scores), dim=0) 72 | else: 73 | return np.mean(aug_scores, axis=0) 74 | 75 | 76 | def merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg, weights=None): 77 | """Merge augmented mask prediction. 78 | 79 | Args: 80 | aug_masks (list[ndarray]): shape (n, #class, h, w) 81 | img_shapes (list[ndarray]): shape (3, ). 82 | rcnn_test_cfg (dict): rcnn test config. 83 | 84 | Returns: 85 | tuple: (bboxes, scores) 86 | """ 87 | recovered_masks = [ 88 | mask if not img_info[0]['flip'] else mask[..., ::-1] 89 | for mask, img_info in zip(aug_masks, img_metas) 90 | ] 91 | if weights is None: 92 | merged_masks = np.mean(recovered_masks, axis=0) 93 | else: 94 | merged_masks = np.average( 95 | np.array(recovered_masks), axis=0, weights=np.array(weights)) 96 | return merged_masks 97 | -------------------------------------------------------------------------------- /mmdet/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dist_utils import allreduce_grads, DistOptimizerHook 2 | from .misc import tensor2imgs, unmap, multi_apply 3 | 4 | __all__ = [ 5 | 'allreduce_grads', 'DistOptimizerHook', 'tensor2imgs', 'unmap', 6 | 'multi_apply' 7 | ] 8 | -------------------------------------------------------------------------------- /mmdet/core/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | import torch.distributed as dist 4 | from torch._utils import (_flatten_dense_tensors, _unflatten_dense_tensors, 5 | _take_tensors) 6 | from mmcv.runner import OptimizerHook 7 | 8 | 9 | def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): 10 | if bucket_size_mb > 0: 11 | bucket_size_bytes = bucket_size_mb * 1024 * 1024 12 | buckets = _take_tensors(tensors, bucket_size_bytes) 13 | else: 14 | buckets = OrderedDict() 15 | for tensor in tensors: 16 | tp = tensor.type() 17 | if tp not in buckets: 18 | buckets[tp] = [] 19 | buckets[tp].append(tensor) 20 | buckets = buckets.values() 21 | 22 | for bucket in buckets: 23 | flat_tensors = _flatten_dense_tensors(bucket) 24 | dist.all_reduce(flat_tensors) 25 | flat_tensors.div_(world_size) 26 | for tensor, synced in zip( 27 | bucket, _unflatten_dense_tensors(flat_tensors, bucket)): 28 | tensor.copy_(synced) 29 | 30 | 31 | def allreduce_grads(params, coalesce=True, bucket_size_mb=-1): 32 | grads = [ 33 | param.grad.data for param in params 34 | if param.requires_grad and param.grad is not None 35 | ] 36 | world_size = dist.get_world_size() 37 | if coalesce: 38 | _allreduce_coalesced(grads, world_size, bucket_size_mb) 39 | else: 40 | for tensor in grads: 41 | dist.all_reduce(tensor.div_(world_size)) 42 | 43 | 44 | class DistOptimizerHook(OptimizerHook): 45 | 46 | def __init__(self, grad_clip=None, 
coalesce=True, bucket_size_mb=-1): 47 | self.grad_clip = grad_clip 48 | self.coalesce = coalesce 49 | self.bucket_size_mb = bucket_size_mb 50 | 51 | def after_train_iter(self, runner): 52 | runner.optimizer.zero_grad() 53 | runner.outputs['loss'].backward() 54 | allreduce_grads(runner.model.parameters(), self.coalesce, 55 | self.bucket_size_mb) 56 | if self.grad_clip is not None: 57 | self.clip_grads(runner.model.parameters()) 58 | runner.optimizer.step() 59 | -------------------------------------------------------------------------------- /mmdet/core/utils/misc.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | 3 | import mmcv 4 | import numpy as np 5 | from six.moves import map, zip 6 | 7 | 8 | def tensor2imgs(tensor, mean=(0, 0, 0), std=(1, 1, 1), to_rgb=True): 9 | num_imgs = tensor.size(0) 10 | mean = np.array(mean, dtype=np.float32) 11 | std = np.array(std, dtype=np.float32) 12 | imgs = [] 13 | for img_id in range(num_imgs): 14 | img = tensor[img_id, ...].cpu().numpy().transpose(1, 2, 0) 15 | img = mmcv.imdenormalize( 16 | img, mean, std, to_bgr=to_rgb).astype(np.uint8) 17 | imgs.append(np.ascontiguousarray(img)) 18 | return imgs 19 | 20 | 21 | def multi_apply(func, *args, **kwargs): 22 | pfunc = partial(func, **kwargs) if kwargs else func 23 | map_results = map(pfunc, *args) 24 | return tuple(map(list, zip(*map_results))) 25 | 26 | 27 | def unmap(data, count, inds, fill=0): 28 | """ Unmap a subset of item (data) back to the original set of items (of 29 | size count) """ 30 | if data.dim() == 1: 31 | ret = data.new_full((count, ), fill) 32 | ret[inds] = data 33 | else: 34 | new_size = (count, ) + data.size()[1:] 35 | ret = data.new_full(new_size, fill) 36 | ret[inds, :] = data 37 | return ret 38 | -------------------------------------------------------------------------------- /mmdet/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom import CustomDataset 2 | from .loader import GroupSampler, DistributedGroupSampler, build_dataloader 3 | from .utils import to_tensor, random_scale, show_ann 4 | from .dataset_wrappers import ConcatDataset, RepeatDataset 5 | from .extra_aug import ExtraAugmentation 6 | from .registry import DATASETS 7 | from .builder import build_dataset 8 | from .city import CityDataset 9 | 10 | __all__ = [ 11 | 'CustomDataset', 'GroupSampler', 12 | 'DistributedGroupSampler', 'build_dataloader', 'to_tensor', 'random_scale', 13 | 'show_ann', 'ConcatDataset', 'RepeatDataset', 'ExtraAugmentation', 'DATASETS', 'build_dataset', 'CityDataset', 14 | ] 15 | -------------------------------------------------------------------------------- /mmdet/datasets/builder.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | from mmdet.utils import build_from_cfg 4 | from .dataset_wrappers import ConcatDataset, RepeatDataset 5 | from .registry import DATASETS 6 | 7 | 8 | def _concat_dataset(cfg): 9 | ann_files = cfg['ann_file'] 10 | img_prefixes = cfg.get('img_prefix', None) 11 | seg_prefixes = cfg.get('seg_prefixes', None) 12 | proposal_files = cfg.get('proposal_file', None) 13 | 14 | datasets = [] 15 | num_dset = len(ann_files) 16 | for i in range(num_dset): 17 | data_cfg = copy.deepcopy(cfg) 18 | data_cfg['ann_file'] = ann_files[i] 19 | if isinstance(img_prefixes, (list, tuple)): 20 | data_cfg['img_prefix'] = img_prefixes[i] 21 | if isinstance(seg_prefixes, (list, tuple)): 22 | 
data_cfg['seg_prefix'] = seg_prefixes[i] 23 | if isinstance(proposal_files, (list, tuple)): 24 | data_cfg['proposal_file'] = proposal_files[i] 25 | datasets.append(build_dataset(data_cfg)) 26 | 27 | return ConcatDataset(datasets) 28 | 29 | 30 | def build_dataset(cfg): 31 | if cfg['type'] == 'RepeatDataset': 32 | dataset = RepeatDataset(build_dataset(cfg['dataset']), cfg['times']) 33 | elif isinstance(cfg['ann_file'], (list, tuple)): 34 | dataset = _concat_dataset(cfg) 35 | else: 36 | dataset = build_from_cfg(cfg, DATASETS) 37 | 38 | return dataset 39 | -------------------------------------------------------------------------------- /mmdet/datasets/city.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from pycocotools.coco import COCO 4 | 5 | from .custom import CustomDataset 6 | from .registry import DATASETS 7 | 8 | 9 | @DATASETS.register_module 10 | class CityDataset(CustomDataset): 11 | 12 | CLASSES = ('pedestrian') 13 | 14 | def load_annotations(self, ann_file): 15 | self.coco = COCO(ann_file) 16 | self.cat_ids = self.coco.getCatIds() 17 | self.cat2label = { 18 | cat_id: i + 1 19 | for i, cat_id in enumerate(self.cat_ids) 20 | } 21 | self.img_ids = self.coco.getImgIds() 22 | img_infos = [] 23 | for i in self.img_ids: 24 | info = self.coco.loadImgs([i])[0] 25 | info['filename'] = os.path.join(info['mode'], info['city_name'], info['file_name']) 26 | img_infos.append(info) 27 | return img_infos 28 | 29 | def get_ann_info(self, idx): 30 | img_id = self.img_infos[idx]['id'] 31 | ann_ids = self.coco.getAnnIds(imgIds=[img_id]) 32 | ann_info = self.coco.loadAnns(ann_ids) 33 | return self._parse_ann_info(ann_info, self.with_mask) 34 | 35 | def _filter_imgs(self, min_size=32): 36 | """Filter images too small or without ground truths.""" 37 | valid_inds = [] 38 | ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values()) 39 | for i, img_info in enumerate(self.img_infos): 40 | # if self.img_ids[i] not in ids_with_ann: 41 | # continue 42 | if min(img_info['width'], img_info['height']) >= min_size: 43 | valid_inds.append(i) 44 | return valid_inds 45 | 46 | def _parse_ann_info(self, ann_info, with_mask=True): 47 | """Parse bbox and mask annotation. 48 | 49 | Args: 50 | ann_info (list[dict]): Annotation info of an image. 51 | with_mask (bool): Whether to parse mask annotations. 52 | 53 | Returns: 54 | dict: A dict containing the following keys: bboxes, bboxes_ignore, 55 | labels, masks, mask_polys, poly_lens. 56 | """ 57 | gt_bboxes = [] 58 | gt_labels = [] 59 | gt_bboxes_ignore = [] 60 | # Two formats are provided. 61 | # 1. mask: a binary map of the same size of the image. 62 | # 2. polys: each mask consists of one or several polys, each poly is a 63 | # list of float. 
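        # e.g. a triangle is the flat list [x0, y0, x1, y1, x2, y2]; COCO-style
        # 'segmentation' entries follow this flattened [x, y, x, y, ...] layout,
        # which is why the filter below keeps polys with len(p) >= 6.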
64 | if with_mask: 65 | gt_masks = [] 66 | gt_mask_polys = [] 67 | gt_poly_lens = [] 68 | for i, ann in enumerate(ann_info): 69 | if ann.get('ignore', False): 70 | continue 71 | x1, y1, w, h = ann['bbox'] 72 | if ann['area'] <= 0 or w < 1 or h < 1: 73 | continue 74 | bbox = [x1, y1, x1 + w - 1, y1 + h - 1] 75 | if ann['iscrowd']: 76 | gt_bboxes_ignore.append(bbox) 77 | else: 78 | gt_bboxes.append(bbox) 79 | gt_labels.append(self.cat2label[ann['category_id']]) 80 | if with_mask: 81 | gt_masks.append(self.coco.annToMask(ann)) 82 | mask_polys = [ 83 | p for p in ann['segmentation'] if len(p) >= 6 84 | ] # valid polygons have >= 3 points (6 coordinates) 85 | poly_lens = [len(p) for p in mask_polys] 86 | gt_mask_polys.append(mask_polys) 87 | gt_poly_lens.extend(poly_lens) 88 | if gt_bboxes: 89 | gt_bboxes = np.array(gt_bboxes, dtype=np.float32) 90 | gt_labels = np.array(gt_labels, dtype=np.int64) 91 | else: 92 | gt_bboxes = np.zeros((0, 4), dtype=np.float32) 93 | gt_labels = np.array([], dtype=np.int64) 94 | 95 | if gt_bboxes_ignore: 96 | gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32) 97 | else: 98 | gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32) 99 | 100 | ann = dict( 101 | bboxes=gt_bboxes, labels=gt_labels, bboxes_ignore=gt_bboxes_ignore) 102 | 103 | if with_mask: 104 | ann['masks'] = gt_masks 105 | # poly format is not used in the current implementation 106 | ann['mask_polys'] = gt_mask_polys 107 | ann['poly_lens'] = gt_poly_lens 108 | return ann 109 | -------------------------------------------------------------------------------- /mmdet/datasets/dataset_wrappers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data.dataset import ConcatDataset as _ConcatDataset 3 | 4 | from .registry import DATASETS 5 | 6 | 7 | @DATASETS.register_module 8 | class ConcatDataset(_ConcatDataset): 9 | """A wrapper of concatenated dataset. 10 | 11 | Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but 12 | concat the group flag for image aspect ratio. 13 | 14 | Args: 15 | datasets (list[:obj:`Dataset`]): A list of datasets. 16 | """ 17 | 18 | def __init__(self, datasets): 19 | super(ConcatDataset, self).__init__(datasets) 20 | self.CLASSES = datasets[0].CLASSES 21 | if hasattr(datasets[0], 'flag'): 22 | flags = [] 23 | for i in range(0, len(datasets)): 24 | flags.append(datasets[i].flag) 25 | self.flag = np.concatenate(flags) 26 | 27 | 28 | @DATASETS.register_module 29 | class RepeatDataset(object): 30 | """A wrapper of repeated dataset. 31 | 32 | The length of repeated dataset will be `times` larger than the original 33 | dataset. This is useful when the data loading time is long but the dataset 34 | is small. Using RepeatDataset can reduce the data loading time between 35 | epochs. 36 | 37 | Args: 38 | dataset (:obj:`Dataset`): The dataset to be repeated. 39 | times (int): Repeat times. 
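    Example (config-style sketch; `times` and the inner dataset cfg are placeholders consumed by `build_dataset` in builder.py):

        data = dict(
            train=dict(
                type='RepeatDataset',
                times=8,
                dataset=dict(type='CityDataset', ...)))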
40 | """ 41 | 42 | def __init__(self, dataset, times): 43 | self.dataset = dataset 44 | self.times = times 45 | self.CLASSES = dataset.CLASSES 46 | if hasattr(self.dataset, 'flag'): 47 | self.flag = np.tile(self.dataset.flag, times) 48 | 49 | self._ori_len = len(self.dataset) 50 | 51 | def __getitem__(self, idx): 52 | return self.dataset[idx % self._ori_len] 53 | 54 | def __len__(self): 55 | return self.times * self._ori_len 56 | -------------------------------------------------------------------------------- /mmdet/datasets/extra_aug.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | from numpy import random 4 | 5 | from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps 6 | 7 | 8 | class PhotoMetricDistortion(object): 9 | 10 | def __init__(self, 11 | brightness_delta=32, 12 | contrast_range=(0.5, 1.5), 13 | saturation_range=(0.5, 1.5), 14 | hue_delta=18): 15 | self.brightness_delta = brightness_delta 16 | self.contrast_lower, self.contrast_upper = contrast_range 17 | self.saturation_lower, self.saturation_upper = saturation_range 18 | self.hue_delta = hue_delta 19 | 20 | def __call__(self, img, boxes, labels): 21 | # random brightness 22 | if random.randint(2): 23 | delta = random.uniform(-self.brightness_delta, 24 | self.brightness_delta) 25 | img += delta 26 | 27 | # mode == 0 --> do random contrast first 28 | # mode == 1 --> do random contrast last 29 | mode = random.randint(2) 30 | if mode == 1: 31 | if random.randint(2): 32 | alpha = random.uniform(self.contrast_lower, 33 | self.contrast_upper) 34 | img *= alpha 35 | 36 | # convert color from BGR to HSV 37 | img = mmcv.bgr2hsv(img) 38 | 39 | # random saturation 40 | if random.randint(2): 41 | img[..., 1] *= random.uniform(self.saturation_lower, 42 | self.saturation_upper) 43 | 44 | # random hue 45 | if random.randint(2): 46 | img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta) 47 | img[..., 0][img[..., 0] > 360] -= 360 48 | img[..., 0][img[..., 0] < 0] += 360 49 | 50 | # convert color from HSV to BGR 51 | img = mmcv.hsv2bgr(img) 52 | 53 | # random contrast 54 | if mode == 0: 55 | if random.randint(2): 56 | alpha = random.uniform(self.contrast_lower, 57 | self.contrast_upper) 58 | img *= alpha 59 | 60 | # randomly swap channels 61 | if random.randint(2): 62 | img = img[..., random.permutation(3)] 63 | 64 | return img, boxes, labels 65 | 66 | 67 | class Expand(object): 68 | 69 | def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)): 70 | if to_rgb: 71 | self.mean = mean[::-1] 72 | else: 73 | self.mean = mean 74 | self.min_ratio, self.max_ratio = ratio_range 75 | 76 | def __call__(self, img, boxes, labels): 77 | if random.randint(2): 78 | return img, boxes, labels 79 | 80 | h, w, c = img.shape 81 | ratio = random.uniform(self.min_ratio, self.max_ratio) 82 | expand_img = np.full((int(h * ratio), int(w * ratio), c), 83 | self.mean).astype(img.dtype) 84 | left = int(random.uniform(0, w * ratio - w)) 85 | top = int(random.uniform(0, h * ratio - h)) 86 | expand_img[top:top + h, left:left + w] = img 87 | img = expand_img 88 | boxes += np.tile((left, top), 2) 89 | return img, boxes, labels 90 | 91 | 92 | class RandomCrop(object): 93 | 94 | def __init__(self, min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3): 95 | # 1: return ori img 96 | self.sample_mode = (1, *min_ious, 0) 97 | self.min_crop_size = min_crop_size 98 | 99 | def __call__(self, img, boxes, labels): 100 | h, w, c = img.shape 101 | while True: 102 | mode = 
random.choice(self.sample_mode) 103 | if mode == 1: 104 | return img, boxes, labels 105 | 106 | min_iou = mode 107 | for i in range(50): 108 | new_w = random.uniform(self.min_crop_size * w, w) 109 | new_h = random.uniform(self.min_crop_size * h, h) 110 | 111 | # h / w in [0.5, 2] 112 | if new_h / new_w < 0.5 or new_h / new_w > 2: 113 | continue 114 | 115 | left = random.uniform(w - new_w) 116 | top = random.uniform(h - new_h) 117 | 118 | patch = np.array((int(left), int(top), int(left + new_w), 119 | int(top + new_h))) 120 | overlaps = bbox_overlaps( 121 | patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1) 122 | if overlaps.min() < min_iou: 123 | continue 124 | 125 | # center of boxes should inside the crop img 126 | center = (boxes[:, :2] + boxes[:, 2:]) / 2 127 | mask = (center[:, 0] > patch[0]) * ( 128 | center[:, 1] > patch[1]) * (center[:, 0] < patch[2]) * ( 129 | center[:, 1] < patch[3]) 130 | if not mask.any(): 131 | continue 132 | boxes = boxes[mask] 133 | labels = labels[mask] 134 | 135 | # adjust boxes 136 | img = img[patch[1]:patch[3], patch[0]:patch[2]] 137 | boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:]) 138 | boxes[:, :2] = boxes[:, :2].clip(min=patch[:2]) 139 | boxes -= np.tile(patch[:2], 2) 140 | 141 | return img, boxes, labels 142 | 143 | 144 | class ExtraAugmentation(object): 145 | 146 | def __init__(self, 147 | photo_metric_distortion=None, 148 | expand=None, 149 | random_crop=None): 150 | self.transforms = [] 151 | if photo_metric_distortion is not None: 152 | self.transforms.append( 153 | PhotoMetricDistortion(**photo_metric_distortion)) 154 | if expand is not None: 155 | self.transforms.append(Expand(**expand)) 156 | if random_crop is not None: 157 | self.transforms.append(RandomCrop(**random_crop)) 158 | 159 | def __call__(self, img, boxes, labels): 160 | img = img.astype(np.float32) 161 | for transform in self.transforms: 162 | img, boxes, labels = transform(img, boxes, labels) 163 | return img, boxes, labels 164 | -------------------------------------------------------------------------------- /mmdet/datasets/loader/__init__.py: -------------------------------------------------------------------------------- 1 | from .build_loader import build_dataloader 2 | from .sampler import GroupSampler, DistributedGroupSampler 3 | 4 | __all__ = ['GroupSampler', 'DistributedGroupSampler', 'build_dataloader'] 5 | -------------------------------------------------------------------------------- /mmdet/datasets/loader/build_loader.py: -------------------------------------------------------------------------------- 1 | import platform 2 | from functools import partial 3 | 4 | from mmcv.runner import get_dist_info 5 | from mmcv.parallel import collate 6 | from torch.utils.data import DataLoader 7 | 8 | from .sampler import GroupSampler, DistributedGroupSampler, DistributedSampler 9 | 10 | if platform.system() != 'Windows': 11 | # https://github.com/pytorch/pytorch/issues/973 12 | import resource 13 | rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) 14 | resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) 15 | 16 | 17 | def build_dataloader(dataset, 18 | imgs_per_gpu, 19 | workers_per_gpu, 20 | num_gpus=1, 21 | dist=True, 22 | **kwargs): 23 | shuffle = kwargs.get('shuffle', True) 24 | if dist: 25 | rank, world_size = get_dist_info() 26 | if shuffle: 27 | sampler = DistributedGroupSampler(dataset, imgs_per_gpu, 28 | world_size, rank) 29 | else: 30 | sampler = DistributedSampler( 31 | dataset, world_size, rank, shuffle=False) 32 | batch_size = imgs_per_gpu 33 | 
num_workers = workers_per_gpu 34 | else: 35 | sampler = GroupSampler(dataset, imgs_per_gpu) if shuffle else None 36 | batch_size = num_gpus * imgs_per_gpu 37 | num_workers = num_gpus * workers_per_gpu 38 | 39 | data_loader = DataLoader( 40 | dataset, 41 | batch_size=batch_size, 42 | sampler=sampler, 43 | num_workers=num_workers, 44 | collate_fn=partial(collate, samples_per_gpu=imgs_per_gpu), 45 | pin_memory=False, 46 | **kwargs) 47 | 48 | return data_loader 49 | -------------------------------------------------------------------------------- /mmdet/datasets/loader/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import math 4 | import torch 5 | import numpy as np 6 | 7 | from mmcv.runner.utils import get_dist_info 8 | from torch.utils.data import Sampler 9 | from torch.utils.data import DistributedSampler as _DistributedSampler 10 | 11 | 12 | class DistributedSampler(_DistributedSampler): 13 | 14 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 15 | super().__init__(dataset, num_replicas=num_replicas, rank=rank) 16 | self.shuffle = shuffle 17 | 18 | def __iter__(self): 19 | # deterministically shuffle based on epoch 20 | if self.shuffle: 21 | g = torch.Generator() 22 | g.manual_seed(self.epoch) 23 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 24 | else: 25 | indices = torch.arange(len(self.dataset)).tolist() 26 | 27 | # add extra samples to make it evenly divisible 28 | indices += indices[:(self.total_size - len(indices))] 29 | assert len(indices) == self.total_size 30 | 31 | # subsample 32 | indices = indices[self.rank:self.total_size:self.num_replicas] 33 | assert len(indices) == self.num_samples 34 | 35 | return iter(indices) 36 | 37 | 38 | class GroupSampler(Sampler): 39 | 40 | def __init__(self, dataset, samples_per_gpu=1): 41 | assert hasattr(dataset, 'flag') 42 | self.dataset = dataset 43 | self.samples_per_gpu = samples_per_gpu 44 | self.flag = dataset.flag.astype(np.int64) 45 | self.group_sizes = np.bincount(self.flag) 46 | self.num_samples = 0 47 | for i, size in enumerate(self.group_sizes): 48 | self.num_samples += int(np.ceil( 49 | size / self.samples_per_gpu)) * self.samples_per_gpu 50 | 51 | def __iter__(self): 52 | indices = [] 53 | for i, size in enumerate(self.group_sizes): 54 | if size == 0: 55 | continue 56 | indice = np.where(self.flag == i)[0] 57 | assert len(indice) == size 58 | np.random.shuffle(indice) 59 | num_extra = int(np.ceil(size / self.samples_per_gpu) 60 | ) * self.samples_per_gpu - len(indice) 61 | indice = np.concatenate([indice, indice[:num_extra]]) 62 | indices.append(indice) 63 | indices = np.concatenate(indices) 64 | indices = [ 65 | indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu] 66 | for i in np.random.permutation( 67 | range(len(indices) // self.samples_per_gpu)) 68 | ] 69 | indices = np.concatenate(indices) 70 | indices = indices.astype(np.int64).tolist() 71 | assert len(indices) == self.num_samples 72 | return iter(indices) 73 | 74 | def __len__(self): 75 | return self.num_samples 76 | 77 | 78 | class DistributedGroupSampler(Sampler): 79 | """Sampler that restricts data loading to a subset of the dataset. 80 | It is especially useful in conjunction with 81 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 82 | process can pass a DistributedSampler instance as a DataLoader sampler, 83 | and load a subset of the original dataset that is exclusive to it. 84 | .. 
note:: 85 | Dataset is assumed to be of constant size. 86 | Arguments: 87 | dataset: Dataset used for sampling. 88 | num_replicas (optional): Number of processes participating in 89 | distributed training. 90 | rank (optional): Rank of the current process within num_replicas. 91 | """ 92 | 93 | def __init__(self, 94 | dataset, 95 | samples_per_gpu=1, 96 | num_replicas=None, 97 | rank=None): 98 | _rank, _num_replicas = get_dist_info() 99 | if num_replicas is None: 100 | num_replicas = _num_replicas 101 | if rank is None: 102 | rank = _rank 103 | self.dataset = dataset 104 | self.samples_per_gpu = samples_per_gpu 105 | self.num_replicas = num_replicas 106 | self.rank = rank 107 | self.epoch = 0 108 | 109 | assert hasattr(self.dataset, 'flag') 110 | self.flag = self.dataset.flag 111 | self.group_sizes = np.bincount(self.flag) 112 | 113 | self.num_samples = 0 114 | for i, j in enumerate(self.group_sizes): 115 | self.num_samples += int( 116 | math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu / 117 | self.num_replicas)) * self.samples_per_gpu 118 | self.total_size = self.num_samples * self.num_replicas 119 | 120 | def __iter__(self): 121 | # deterministically shuffle based on epoch 122 | g = torch.Generator() 123 | g.manual_seed(self.epoch) 124 | 125 | indices = [] 126 | for i, size in enumerate(self.group_sizes): 127 | if size > 0: 128 | indice = np.where(self.flag == i)[0] 129 | assert len(indice) == size 130 | indice = indice[list(torch.randperm(int(size), 131 | generator=g))].tolist() 132 | extra = int( 133 | math.ceil( 134 | size * 1.0 / self.samples_per_gpu / self.num_replicas) 135 | ) * self.samples_per_gpu * self.num_replicas - len(indice) 136 | indice += indice[:extra] 137 | indices += indice 138 | 139 | assert len(indices) == self.total_size 140 | 141 | indices = [ 142 | indices[j] for i in list( 143 | torch.randperm( 144 | len(indices) // self.samples_per_gpu, generator=g)) 145 | for j in range(i * self.samples_per_gpu, (i + 1) * 146 | self.samples_per_gpu) 147 | ] 148 | 149 | # subsample 150 | offset = self.num_samples * self.rank 151 | indices = indices[offset:offset + self.num_samples] 152 | assert len(indices) == self.num_samples 153 | 154 | return iter(indices) 155 | 156 | def __len__(self): 157 | return self.num_samples 158 | 159 | def set_epoch(self, epoch): 160 | self.epoch = epoch 161 | -------------------------------------------------------------------------------- /mmdet/datasets/registry.py: -------------------------------------------------------------------------------- 1 | from mmdet.utils import Registry 2 | 3 | DATASETS = Registry('dataset') 4 | -------------------------------------------------------------------------------- /mmdet/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | import mmcv 2 | import numpy as np 3 | import torch 4 | 5 | __all__ = [ 6 | 'ImageTransform', 'BboxTransform', 'MaskTransform', 'SegMapTransform', 7 | 'Numpy2Tensor' 8 | ] 9 | 10 | 11 | class ImageTransform(object): 12 | """Preprocess an image. 13 | 14 | 1. rescale the image to expected size 15 | 2. normalize the image 16 | 3. flip the image (if needed) 17 | 4. pad the image (if needed) 18 | 5. 
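    Example (hypothetical values; assumes mmcv is installed)::

        import numpy as np
        t = ImageTransform(mean=(123.675, 116.28, 103.53),
                           std=(58.395, 57.12, 57.375),
                           to_rgb=True, size_divisor=32)
        img = np.random.randint(0, 255, (768, 1024, 3), dtype=np.uint8)
        out, img_shape, pad_shape, scale_factor = t(
            img, scale=(1333, 800), flip=False, keep_ratio=True)
        # out is channel-first; pad_shape dims are multiples of 32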
transpose to (c, h, w) 19 | """ 20 | 21 | def __init__(self, 22 | mean=(0, 0, 0), 23 | std=(1, 1, 1), 24 | to_rgb=True, 25 | size_divisor=None): 26 | self.mean = np.array(mean, dtype=np.float32) 27 | self.std = np.array(std, dtype=np.float32) 28 | self.to_rgb = to_rgb 29 | self.size_divisor = size_divisor 30 | 31 | def __call__(self, img, scale, flip=False, keep_ratio=True): 32 | if keep_ratio: 33 | img, scale_factor = mmcv.imrescale(img, scale, return_scale=True) 34 | else: 35 | img, w_scale, h_scale = mmcv.imresize( 36 | img, scale, return_scale=True) 37 | scale_factor = np.array( 38 | [w_scale, h_scale, w_scale, h_scale], dtype=np.float32) 39 | img_shape = img.shape 40 | img = mmcv.imnormalize(img, self.mean, self.std, self.to_rgb) 41 | if flip: 42 | img = mmcv.imflip(img) 43 | if self.size_divisor is not None: 44 | img = mmcv.impad_to_multiple(img, self.size_divisor) 45 | pad_shape = img.shape 46 | else: 47 | pad_shape = img_shape 48 | img = img.transpose(2, 0, 1) 49 | return img, img_shape, pad_shape, scale_factor 50 | 51 | 52 | def bbox_flip(bboxes, img_shape): 53 | """Flip bboxes horizontally. 54 | 55 | Args: 56 | bboxes(ndarray): shape (..., 4*k) 57 | img_shape(tuple): (height, width) 58 | """ 59 | assert bboxes.shape[-1] % 4 == 0 60 | w = img_shape[1] 61 | flipped = bboxes.copy() 62 | flipped[..., 0::4] = w - bboxes[..., 2::4] - 1 63 | flipped[..., 2::4] = w - bboxes[..., 0::4] - 1 64 | return flipped 65 | 66 | 67 | class BboxTransform(object): 68 | """Preprocess gt bboxes. 69 | 70 | 1. rescale bboxes according to image size 71 | 2. flip bboxes (if needed) 72 | 3. pad the first dimension to `max_num_gts` 73 | """ 74 | 75 | def __init__(self, max_num_gts=None): 76 | self.max_num_gts = max_num_gts 77 | 78 | def __call__(self, bboxes, img_shape, scale_factor, flip=False): 79 | gt_bboxes = bboxes * scale_factor 80 | if flip: 81 | gt_bboxes = bbox_flip(gt_bboxes, img_shape) 82 | gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) 83 | gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) 84 | if self.max_num_gts is None: 85 | return gt_bboxes 86 | else: 87 | num_gts = gt_bboxes.shape[0] 88 | padded_bboxes = np.zeros((self.max_num_gts, 4), dtype=np.float32) 89 | padded_bboxes[:num_gts, :] = gt_bboxes 90 | return padded_bboxes 91 | 92 | class BboxTransformNoClip(object): 93 | """Preprocess gt bboxes. 94 | 95 | 1. rescale bboxes according to image size 96 | 2. flip bboxes (if needed) 97 | 3. pad the first dimension to `max_num_gts` 98 | """ 99 | 100 | def __init__(self, max_num_gts=None): 101 | self.max_num_gts = max_num_gts 102 | 103 | def __call__(self, bboxes, img_shape, scale_factor, flip=False): 104 | gt_bboxes = bboxes * scale_factor 105 | if flip: 106 | gt_bboxes = bbox_flip(gt_bboxes, img_shape) 107 | # gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) 108 | # gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) 109 | if self.max_num_gts is None: 110 | return gt_bboxes 111 | else: 112 | num_gts = gt_bboxes.shape[0] 113 | padded_bboxes = np.zeros((self.max_num_gts, 4), dtype=np.float32) 114 | padded_bboxes[:num_gts, :] = gt_bboxes 115 | return padded_bboxes 116 | 117 | 118 | 119 | class MaskTransform(object): 120 | """Preprocess masks. 121 | 122 | 1. resize masks to expected size and stack to a single array 123 | 2. flip the masks (if needed) 124 | 3. 
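    Example (hypothetical inputs; assumes mmcv is installed)::

        import numpy as np
        mt = MaskTransform()
        masks = [np.ones((100, 80), dtype=np.uint8)] * 2
        padded = mt(masks, pad_shape=(224, 224, 3), scale_factor=2.0)
        # each mask is rescaled x2 to (200, 160), padded to (224, 224),
        # then stacked into a (2, 224, 224) array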
pad the masks (if needed) 125 | """ 126 | 127 | def __call__(self, masks, pad_shape, scale_factor, flip=False): 128 | masks = [ 129 | mmcv.imrescale(mask, scale_factor, interpolation='nearest') 130 | for mask in masks 131 | ] 132 | if flip: 133 | masks = [mask[:, ::-1] for mask in masks] 134 | padded_masks = [ 135 | mmcv.impad(mask, pad_shape[:2], pad_val=0) for mask in masks 136 | ] 137 | padded_masks = np.stack(padded_masks, axis=0) 138 | return padded_masks 139 | 140 | 141 | class SegMapTransform(object): 142 | """Preprocess semantic segmentation maps. 143 | 144 | 1. rescale the segmentation map to expected size 145 | 2. flip the image (if needed) 146 | 3. pad the image (if needed) 147 | """ 148 | 149 | def __init__(self, size_divisor=None): 150 | self.size_divisor = size_divisor 151 | 152 | def __call__(self, img, scale, flip=False, keep_ratio=True): 153 | if keep_ratio: 154 | img = mmcv.imrescale(img, scale, interpolation='nearest') 155 | else: 156 | img = mmcv.imresize(img, scale, interpolation='nearest') 157 | if flip: 158 | img = mmcv.imflip(img) 159 | if self.size_divisor is not None: 160 | img = mmcv.impad_to_multiple(img, self.size_divisor) 161 | return img 162 | 163 | 164 | class Numpy2Tensor(object): 165 | 166 | def __init__(self): 167 | pass 168 | 169 | def __call__(self, *args): 170 | if len(args) == 1: 171 | return torch.from_numpy(args[0]) 172 | else: 173 | return tuple([torch.from_numpy(np.array(array)) for array in args]) 174 | -------------------------------------------------------------------------------- /mmdet/datasets/utils.py: -------------------------------------------------------------------------------- 1 | from collections.abc import Sequence 2 | 3 | import matplotlib.pyplot as plt 4 | import mmcv 5 | import numpy as np 6 | import torch 7 | 8 | 9 | def to_tensor(data): 10 | """Convert objects of various python types to :obj:`torch.Tensor`. 11 | 12 | Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`, 13 | :class:`Sequence`, :class:`int` and :class:`float`. 14 | """ 15 | if isinstance(data, torch.Tensor): 16 | return data 17 | elif isinstance(data, np.ndarray): 18 | return torch.from_numpy(data) 19 | elif isinstance(data, Sequence) and not mmcv.is_str(data): 20 | return torch.tensor(data) 21 | elif isinstance(data, int): 22 | return torch.LongTensor([data]) 23 | elif isinstance(data, float): 24 | return torch.FloatTensor([data]) 25 | else: 26 | raise TypeError('type {} cannot be converted to tensor.'.format( 27 | type(data))) 28 | 29 | 30 | def random_scale(img_scales, mode='range'):  31 | """Randomly select a scale from a list of scales or scale ranges. 32 | 33 | Args: 34 | img_scales (list[tuple]): Image scale or scale range. 35 | mode (str): "range" or "value". 36 | 37 | Returns: 38 | tuple: Sampled image scale.
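    Example (hypothetical scales)::

        # 'range': long/short edges are sampled independently between
        # the endpoints of the two given tuples
        s = random_scale([(1333, 800), (1333, 640)], mode='range')
        assert max(s) == 1333 and 640 <= min(s) <= 800
        # 'value': one of the given tuples is returned verbatim
        s = random_scale([(1333, 800), (1333, 640)], mode='value')
        assert s in [(1333, 800), (1333, 640)]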
39 | """ 40 | num_scales = len(img_scales) 41 | if num_scales == 1: # fixed scale is specified 42 | img_scale = img_scales[0] 43 | elif num_scales == 2: # randomly sample a scale 44 | if mode == 'range': 45 | img_scale_long = [max(s) for s in img_scales] 46 | img_scale_short = [min(s) for s in img_scales] 47 | long_edge = np.random.randint( 48 | min(img_scale_long), 49 | max(img_scale_long) + 1) 50 | short_edge = np.random.randint( 51 | min(img_scale_short), 52 | max(img_scale_short) + 1) 53 | img_scale = (long_edge, short_edge) 54 | elif mode == 'value': 55 | img_scale = img_scales[np.random.randint(num_scales)] 56 | else: 57 | if mode != 'value': 58 | raise ValueError( 59 | 'Only "value" mode supports more than 2 image scales') 60 | img_scale = img_scales[np.random.randint(num_scales)] 61 | return img_scale 62 | 63 | 64 | def show_ann(coco, img, ann_info): 65 | plt.imshow(mmcv.bgr2rgb(img)) 66 | plt.axis('off') 67 | coco.showAnns(ann_info) 68 | plt.show() 69 | -------------------------------------------------------------------------------- /mmdet/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .backbones import * # noqa: F401,F403 2 | from .necks import * # noqa: F401,F403 3 | from .roi_extractors import * # noqa: F401,F403 4 | from .anchor_heads import * # noqa: F401,F403 5 | from .bbox_heads import * # noqa: F401,F403 6 | from .detectors import * # noqa: F401,F403 7 | from .mgan_heads import * 8 | from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS, 9 | LOSSES, DETECTORS) 10 | from .builder import (build_backbone, build_neck, build_roi_extractor, 11 | build_shared_head, build_head, build_loss, 12 | build_detector) 13 | 14 | __all__ = [ 15 | 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES', 16 | 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor', 17 | 'build_shared_head', 'build_head', 'build_loss', 'build_detector' 18 | ] 19 | -------------------------------------------------------------------------------- /mmdet/models/anchor_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .anchor_head import AnchorHead 2 | from .rpn_head import RPNHead 3 | 4 | __all__ = [ 5 | 'AnchorHead', 'RPNHead', 6 | ] 7 | -------------------------------------------------------------------------------- /mmdet/models/anchor_heads/rpn_head.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from mmcv.cnn import normal_init 5 | 6 | from mmdet.core import delta2bbox 7 | from mmdet.ops import nms 8 | from .anchor_head import AnchorHead 9 | from ..registry import HEADS 10 | 11 | 12 | @HEADS.register_module 13 | class RPNHead(AnchorHead): 14 | 15 | def __init__(self, in_channels, **kwargs): 16 | super(RPNHead, self).__init__(2, in_channels, **kwargs) 17 | 18 | def _init_layers(self): 19 | self.rpn_conv = nn.Conv2d( 20 | self.in_channels, self.feat_channels, 3, padding=1) 21 | self.rpn_cls = nn.Conv2d(self.feat_channels, 22 | self.num_anchors * self.cls_out_channels, 1) 23 | self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1) 24 | 25 | def init_weights(self): 26 | normal_init(self.rpn_conv, std=0.01) 27 | normal_init(self.rpn_cls, std=0.01) 28 | normal_init(self.rpn_reg, std=0.01) 29 | 30 | def forward_single(self, x): 31 | x = self.rpn_conv(x) 32 | x = F.relu(x, inplace=True) 33 | rpn_cls_score = 
self.rpn_cls(x) 34 | rpn_bbox_pred = self.rpn_reg(x) 35 | return rpn_cls_score, rpn_bbox_pred 36 | 37 | def get_bboxes_single(self, 38 | cls_scores, 39 | bbox_preds, 40 | mlvl_anchors, 41 | img_shape, 42 | scale_factor, 43 | cfg, 44 | rescale=False): 45 | mlvl_proposals = [] 46 | for idx in range(len(cls_scores)): 47 | rpn_cls_score = cls_scores[idx] 48 | rpn_bbox_pred = bbox_preds[idx] 49 | assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:] 50 | anchors = mlvl_anchors[idx] 51 | rpn_cls_score = rpn_cls_score.permute(1, 2, 0) 52 | 53 | rpn_cls_score = rpn_cls_score.reshape(-1) 54 | scores = rpn_cls_score.sigmoid() 55 | 56 | rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4) 57 | if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre: 58 | _, topk_inds = scores.topk(cfg.nms_pre) 59 | rpn_bbox_pred = rpn_bbox_pred[topk_inds, :] 60 | anchors = anchors[topk_inds, :] 61 | scores = scores[topk_inds] 62 | proposals = delta2bbox(anchors, rpn_bbox_pred, self.target_means, 63 | self.target_stds, img_shape) 64 | if cfg.min_bbox_size > 0: 65 | w = proposals[:, 2] - proposals[:, 0] + 1 66 | h = proposals[:, 3] - proposals[:, 1] + 1 67 | valid_inds = torch.nonzero((w >= cfg.min_bbox_size) & 68 | (h >= cfg.min_bbox_size)).squeeze() 69 | proposals = proposals[valid_inds, :] 70 | scores = scores[valid_inds] 71 | proposals = torch.cat([proposals, scores.unsqueeze(-1)], dim=-1) 72 | proposals, _ = nms(proposals, cfg.nms_thr) 73 | proposals = proposals[:cfg.nms_post, :] 74 | mlvl_proposals.append(proposals) 75 | proposals = torch.cat(mlvl_proposals, 0) 76 | if cfg.nms_across_levels: 77 | proposals, _ = nms(proposals, cfg.nms_thr) 78 | proposals = proposals[:cfg.max_num, :] 79 | else: 80 | scores = proposals[:, 4] 81 | num = min(cfg.max_num, proposals.shape[0]) 82 | _, topk_inds = scores.topk(num) 83 | proposals = proposals[topk_inds, :] 84 | return proposals 85 | -------------------------------------------------------------------------------- /mmdet/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .vgg import VGG 2 | 3 | __all__ = ['VGG'] 4 | -------------------------------------------------------------------------------- /mmdet/models/backbones/vgg.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.nn as nn 4 | from mmcv.cnn import (VGG, constant_init, kaiming_init, 5 | normal_init) 6 | 7 | from mmcv.runner import load_checkpoint 8 | from ..registry import BACKBONES 9 | 10 | 11 | @BACKBONES.register_module 12 | class VGG(VGG): 13 | def __init__(self, 14 | depth=16, 15 | with_last_pool=False, 16 | ceil_mode=True, 17 | frozen_stages=-1, 18 | ): 19 | super(VGG, self).__init__( 20 | depth, 21 | with_last_pool=with_last_pool, 22 | ceil_mode=ceil_mode, 23 | frozen_stages=frozen_stages, 24 | ) 25 | 26 | def init_weights(self, pretrained=None): 27 | if isinstance(pretrained, str): 28 | logger = logging.getLogger() 29 | load_checkpoint(self, pretrained, strict=False, logger=logger) 30 | elif pretrained is None: 31 | for m in self.features.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | kaiming_init(m) 34 | elif isinstance(m, nn.BatchNorm2d): 35 | constant_init(m, 1) 36 | elif isinstance(m, nn.Linear): 37 | normal_init(m, std=0.01) 38 | else: 39 | raise TypeError('pretrained must be a str or None') 40 | 41 | def forward(self, x): 42 | # remove the pool4 43 | for layer in self.features[:23]: 44 | x = layer(x) 45 | for layer in self.features[24:]: 46 | x = 
layer(x) 47 | return tuple([x]) 48 | 49 | 50 | -------------------------------------------------------------------------------- /mmdet/models/bbox_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .bbox_head import BBoxHead 2 | from .convfc_bbox_head import ConvFCBBoxHead, SharedFCBBoxHead 3 | 4 | __all__ = ['BBoxHead', 'ConvFCBBoxHead', 'SharedFCBBoxHead'] 5 | -------------------------------------------------------------------------------- /mmdet/models/builder.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from mmdet.utils import build_from_cfg 4 | from .registry import (BACKBONES, NECKS, ROI_EXTRACTORS, SHARED_HEADS, HEADS, 5 | LOSSES, DETECTORS) 6 | 7 | 8 | def build(cfg, registry, default_args=None): 9 | if isinstance(cfg, list): 10 | modules = [ 11 | build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg 12 | ] 13 | return nn.Sequential(*modules) 14 | else: 15 | return build_from_cfg(cfg, registry, default_args) 16 | 17 | 18 | def build_backbone(cfg): 19 | return build(cfg, BACKBONES) 20 | 21 | 22 | def build_neck(cfg): 23 | return build(cfg, NECKS) 24 | 25 | 26 | def build_roi_extractor(cfg): 27 | return build(cfg, ROI_EXTRACTORS) 28 | 29 | 30 | def build_shared_head(cfg): 31 | return build(cfg, SHARED_HEADS) 32 | 33 | 34 | def build_head(cfg): 35 | return build(cfg, HEADS) 36 | 37 | 38 | def build_loss(cfg): 39 | return build(cfg, LOSSES) 40 | 41 | 42 | def build_detector(cfg, train_cfg=None, test_cfg=None): 43 | return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg)) 44 | -------------------------------------------------------------------------------- /mmdet/models/detectors/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseDetector 2 | from .mgan import MGAN 3 | 4 | __all__ = [ 5 | 'BaseDetector', 'MGAN', 6 | ] 7 | -------------------------------------------------------------------------------- /mmdet/models/detectors/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from abc import ABCMeta, abstractmethod 3 | 4 | import mmcv 5 | import numpy as np 6 | import torch.nn as nn 7 | import pycocotools.mask as maskUtils 8 | 9 | from mmdet.core import tensor2imgs, get_classes, auto_fp16 10 | 11 | 12 | class BaseDetector(nn.Module): 13 | """Base class for detectors""" 14 | 15 | __metaclass__ = ABCMeta 16 | 17 | def __init__(self): 18 | super(BaseDetector, self).__init__() 19 | self.fp16_enabled = False 20 | 21 | @property 22 | def with_neck(self): 23 | return hasattr(self, 'neck') and self.neck is not None 24 | 25 | @property 26 | def with_shared_head(self): 27 | return hasattr(self, 'shared_head') and self.shared_head is not None 28 | 29 | @property 30 | def with_bbox(self): 31 | return hasattr(self, 'bbox_head') and self.bbox_head is not None 32 | 33 | @property 34 | def with_mask(self): 35 | return hasattr(self, 'mask_head') and self.mask_head is not None 36 | 37 | @property 38 | def with_mgan(self): 39 | return hasattr(self, 'mgan_head') and self.mgan_head is not None 40 | 41 | 42 | 43 | @abstractmethod 44 | def extract_feat(self, imgs): 45 | pass 46 | 47 | def extract_feats(self, imgs): 48 | assert isinstance(imgs, list) 49 | for img in imgs: 50 | yield self.extract_feat(img) 51 | 52 | @abstractmethod 53 | def forward_train(self, imgs, img_metas, **kwargs): 54 | pass 55 | 56 | @abstractmethod 57 | 
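    # (Dispatch sketch, added for clarity: `forward` below routes on
    #  `return_loss` --
    #    model(img, img_meta, return_loss=True, **targets)  -> forward_train
    #    model([img], [img_meta], return_loss=False)        -> forward_test,
    #  which checks one image per GPU and calls simple_test / aug_test.)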
def simple_test(self, img, img_meta, **kwargs): 58 | pass 59 | 60 | @abstractmethod 61 | def aug_test(self, imgs, img_metas, **kwargs): 62 | pass 63 | 64 | def init_weights(self, pretrained=None): 65 | if pretrained is not None: 66 | logger = logging.getLogger() 67 | logger.info('load model from: {}'.format(pretrained)) 68 | 69 | def forward_test(self, imgs, img_metas, **kwargs): 70 | for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]: 71 | if not isinstance(var, list): 72 | raise TypeError('{} must be a list, but got {}'.format( 73 | name, type(var))) 74 | 75 | num_augs = len(imgs) 76 | if num_augs != len(img_metas): 77 | raise ValueError( 78 | 'num of augmentations ({}) != num of image meta ({})'.format( 79 | len(imgs), len(img_metas))) 80 | # TODO: remove the restriction of imgs_per_gpu == 1 when prepared 81 | imgs_per_gpu = imgs[0].size(0) 82 | assert imgs_per_gpu == 1 83 | 84 | if num_augs == 1: 85 | return self.simple_test(imgs[0], img_metas[0], **kwargs) 86 | else: 87 | return self.aug_test(imgs, img_metas, **kwargs) 88 | 89 | @auto_fp16(apply_to=('img', )) 90 | def forward(self, img, img_meta, return_loss=True, **kwargs): 91 | if return_loss: 92 | return self.forward_train(img, img_meta, **kwargs) 93 | else: 94 | return self.forward_test(img, img_meta, **kwargs) 95 | 96 | def show_result(self, 97 | data, 98 | result, 99 | img_norm_cfg, 100 | dataset=None, 101 | score_thr=0.3): 102 | if isinstance(result, tuple): 103 | bbox_result, segm_result = result 104 | else: 105 | bbox_result, segm_result = result, None 106 | 107 | img_tensor = data['img'][0] 108 | img_metas = data['img_meta'][0].data[0] 109 | imgs = tensor2imgs(img_tensor, **img_norm_cfg) 110 | assert len(imgs) == len(img_metas) 111 | 112 | if dataset is None: 113 | class_names = self.CLASSES 114 | elif isinstance(dataset, str): 115 | class_names = get_classes(dataset) 116 | elif isinstance(dataset, (list, tuple)): 117 | class_names = dataset 118 | else: 119 | raise TypeError( 120 | 'dataset must be a valid dataset name or a sequence' 121 | ' of class names, not {}'.format(type(dataset))) 122 | 123 | for img, img_meta in zip(imgs, img_metas): 124 | h, w, _ = img_meta['img_shape'] 125 | img_show = img[:h, :w, :] 126 | 127 | bboxes = np.vstack(bbox_result) 128 | # draw segmentation masks 129 | if segm_result is not None: 130 | segms = mmcv.concat_list(segm_result) 131 | inds = np.where(bboxes[:, -1] > score_thr)[0] 132 | for i in inds: 133 | color_mask = np.random.randint( 134 | 0, 256, (1, 3), dtype=np.uint8) 135 | mask = maskUtils.decode(segms[i]).astype(np.bool) 136 | img_show[mask] = img_show[mask] * 0.5 + color_mask * 0.5 137 | # draw bounding boxes 138 | labels = [ 139 | np.full(bbox.shape[0], i, dtype=np.int32) 140 | for i, bbox in enumerate(bbox_result) 141 | ] 142 | labels = np.concatenate(labels) 143 | mmcv.imshow_det_bboxes( 144 | img_show, 145 | bboxes, 146 | labels, 147 | class_names=class_names, 148 | score_thr=score_thr) 149 | -------------------------------------------------------------------------------- /mmdet/models/detectors/mgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .base import BaseDetector 5 | from .test_mixins import RPNTestMixin, BBoxTestMixin 6 | from .. 
import builder 7 | from ..registry import DETECTORS 8 | from mmdet.core import bbox2roi, bbox2result, build_assigner, build_sampler 9 | 10 | 11 | @DETECTORS.register_module 12 | class MGAN(BaseDetector, RPNTestMixin, BBoxTestMixin): 13 | 14 | def __init__(self, 15 | backbone, 16 | neck=None, 17 | shared_head=None, 18 | rpn_head=None, 19 | bbox_roi_extractor=None, 20 | bbox_head=None, 21 | mgan_head=None, 22 | train_cfg=None, 23 | test_cfg=None, 24 | pretrained=None): 25 | super(MGAN, self).__init__() 26 | self.backbone = builder.build_backbone(backbone) 27 | 28 | if neck is not None: 29 | self.neck = builder.build_neck(neck) 30 | 31 | if shared_head is not None: 32 | self.shared_head = builder.build_shared_head(shared_head) 33 | 34 | if rpn_head is not None: 35 | self.rpn_head = builder.build_head(rpn_head) 36 | 37 | if bbox_head is not None: 38 | self.bbox_roi_extractor = builder.build_roi_extractor( 39 | bbox_roi_extractor) 40 | self.bbox_head = builder.build_head(bbox_head) 41 | 42 | if mgan_head is not None: 43 | self.mgan_head = builder.build_head(mgan_head) 44 | 45 | self.train_cfg = train_cfg 46 | self.test_cfg = test_cfg 47 | 48 | self.init_weights(pretrained=pretrained) 49 | 50 | @property 51 | def with_rpn(self): 52 | return hasattr(self, 'rpn_head') and self.rpn_head is not None 53 | 54 | def init_weights(self, pretrained=None): 55 | super(MGAN, self).init_weights(pretrained) 56 | self.backbone.init_weights(pretrained=pretrained) 57 | if self.with_neck: 58 | if isinstance(self.neck, nn.Sequential): 59 | for m in self.neck: 60 | m.init_weights() 61 | else: 62 | self.neck.init_weights() 63 | if self.with_shared_head: 64 | self.shared_head.init_weights(pretrained=pretrained) 65 | if self.with_rpn: 66 | self.rpn_head.init_weights() 67 | if self.with_bbox: 68 | self.bbox_roi_extractor.init_weights() 69 | self.bbox_head.init_weights() 70 | 71 | 72 | def extract_feat(self, img): 73 | x = self.backbone(img) 74 | if self.with_neck: 75 | x = self.neck(x) 76 | return x 77 | 78 | def simple_test(self, img, img_meta, proposals=None, rescale=False): 79 | """Test without augmentation.""" 80 | assert self.with_bbox, "Bbox head must be implemented." 81 | 82 | x = self.extract_feat(img) 83 | 84 | proposal_list = self.simple_test_rpn( 85 | x, img_meta, self.test_cfg.rpn) if proposals is None else proposals 86 | 87 | det_bboxes, det_labels = self.simple_test_bboxes_mgan( 88 | x, img_meta, proposal_list, self.test_cfg.rcnn, rescale=rescale) 89 | bbox_results = bbox2result(det_bboxes, det_labels, 90 | self.bbox_head.num_classes) 91 | 92 | if not self.with_mask: 93 | return bbox_results 94 | else: 95 | segm_results = self.simple_test_mask( 96 | x, img_meta, det_bboxes, det_labels, rescale=rescale) 97 | return bbox_results, segm_results 98 | 99 | def aug_test(self, imgs, img_metas, rescale=False): 100 | """Test with augmentations. 101 | 102 | If rescale is False, then returned bboxes and masks will fit the scale 103 | of imgs[0]. 
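    A shape-level sketch of the mask-guided attention applied to RoI
    features before the bbox head at test time (random stand-in tensors;
    this mirrors MGANHead.forward in mmdet/models/mgan_heads/mgan_head.py
    rather than reproducing it)::

        import torch
        roi_feats = torch.randn(8, 512, 7, 7)    # 8 RoIs from the extractor
        logits = torch.randn(8, 1, 7, 7)         # 1x1-conv mask logits
        attended = logits.sigmoid() * roi_feats  # visibility re-weighting
        assert attended.shape == roi_feats.shape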
104 | """ 105 | # recompute feats to save memory 106 | proposal_list = self.aug_test_rpn( 107 | self.extract_feats(imgs), img_metas, self.test_cfg.rpn) 108 | det_bboxes, det_labels = self.aug_test_bboxes( 109 | self.extract_feats(imgs), img_metas, proposal_list, 110 | self.test_cfg.rcnn) 111 | 112 | if rescale: 113 | _det_bboxes = det_bboxes 114 | else: 115 | _det_bboxes = det_bboxes.clone() 116 | _det_bboxes[:, :4] *= img_metas[0][0]['scale_factor'] 117 | bbox_results = bbox2result(_det_bboxes, det_labels, 118 | self.bbox_head.num_classes) 119 | 120 | # det_bboxes always keep the original scale 121 | if self.with_mask: 122 | segm_results = self.aug_test_mask( 123 | self.extract_feats(imgs), img_metas, det_bboxes, det_labels) 124 | return bbox_results, segm_results 125 | else: 126 | return bbox_results 127 | -------------------------------------------------------------------------------- /mmdet/models/detectors/test_mixins.py: -------------------------------------------------------------------------------- 1 | from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_proposals, 2 | merge_aug_bboxes, merge_aug_masks, multiclass_nms) 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | class RPNTestMixin(object): 7 | def simple_test_rpn(self, x, img_meta, rpn_test_cfg): 8 | rpn_outs = self.rpn_head(x) 9 | proposal_inputs = rpn_outs + (img_meta, rpn_test_cfg) 10 | proposal_list = self.rpn_head.get_bboxes(*proposal_inputs) 11 | return proposal_list 12 | 13 | class BBoxTestMixin(object): 14 | def simple_test_bboxes_mgan(self, 15 | x, 16 | img_meta, 17 | proposals, 18 | rcnn_test_cfg, 19 | rescale=False): 20 | """Test only det bboxes without augmentation.""" 21 | rois = bbox2roi(proposals) 22 | roi_feats = self.bbox_roi_extractor( 23 | x[:len(self.bbox_roi_extractor.featmap_strides)], rois) 24 | if self.with_shared_head: 25 | roi_feats = self.shared_head(roi_feats) 26 | roi_feats = self.mgan_head(roi_feats) 27 | cls_score, bbox_pred = self.bbox_head(roi_feats) 28 | img_shape = img_meta[0]['img_shape'] 29 | scale_factor = img_meta[0]['scale_factor'] 30 | det_bboxes, det_labels = self.bbox_head.get_det_bboxes( 31 | rois, 32 | cls_score, 33 | bbox_pred, 34 | img_shape, 35 | scale_factor, 36 | rescale=rescale, 37 | cfg=rcnn_test_cfg) 38 | return det_bboxes, det_labels 39 | -------------------------------------------------------------------------------- /mmdet/models/mgan_heads/__init__.py: -------------------------------------------------------------------------------- 1 | from .mgan_head import MGANHead 2 | 3 | __all__ = ['MGANHead'] 4 | -------------------------------------------------------------------------------- /mmdet/models/mgan_heads/mgan_head.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ..registry import HEADS 4 | from ..utils import ConvModule 5 | from mmdet.core import auto_fp16 6 | 7 | 8 | @HEADS.register_module 9 | class MGANHead(nn.Module): 10 | 11 | def __init__(self, 12 | num_convs=2, 13 | roi_feat_size=7, 14 | in_channels=512, 15 | conv_out_channels=512, 16 | conv_cfg=None, 17 | norm_cfg=None): 18 | super(MGANHead, self).__init__() 19 | self.num_convs = num_convs 20 | self.roi_feat_size = roi_feat_size 21 | self.in_channels = in_channels 22 | self.conv_out_channels = conv_out_channels 23 | 24 | self.conv_cfg = conv_cfg 25 | self.norm_cfg = norm_cfg 26 | self.fp16_enabled = False 27 | 28 | self.convs = nn.ModuleList() 29 | for i in range(self.num_convs): 30 | in_channels = ( 31 | self.in_channels 
if i == 0 else self.conv_out_channels) 32 | self.convs.append( 33 | ConvModule( 34 | in_channels, 35 | self.conv_out_channels, 36 | 3, 37 | padding=1, 38 | conv_cfg=conv_cfg, 39 | norm_cfg=norm_cfg)) 40 | logits_in_channel = self.conv_out_channels 41 | self.conv_logits = nn.Conv2d(logits_in_channel, 1, 1) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.debug_imgs = None 44 | 45 | @auto_fp16() 46 | def forward(self, x): 47 | feat = x 48 | for conv in self.convs: 49 | x = conv(x) 50 | out = self.conv_logits(x).sigmoid() * feat 51 | return out 52 | 53 | 54 | -------------------------------------------------------------------------------- /mmdet/models/necks/__init__.py: -------------------------------------------------------------------------------- 1 | from .fpn import FPN 2 | from .hrfpn import HRFPN 3 | 4 | __all__ = ['FPN', 'HRFPN'] 5 | -------------------------------------------------------------------------------- /mmdet/models/necks/fpn.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from mmcv.cnn import xavier_init 4 | 5 | from mmdet.core import auto_fp16 6 | from ..registry import NECKS 7 | from ..utils import ConvModule 8 | 9 | 10 | @NECKS.register_module 11 | class FPN(nn.Module): 12 | 13 | def __init__(self, 14 | in_channels, 15 | out_channels, 16 | num_outs, 17 | start_level=0, 18 | end_level=-1, 19 | add_extra_convs=False, 20 | extra_convs_on_inputs=True, 21 | relu_before_extra_convs=False, 22 | conv_cfg=None, 23 | norm_cfg=None, 24 | activation=None): 25 | super(FPN, self).__init__() 26 | assert isinstance(in_channels, list) 27 | self.in_channels = in_channels 28 | self.out_channels = out_channels 29 | self.num_ins = len(in_channels) 30 | self.num_outs = num_outs 31 | self.activation = activation 32 | self.relu_before_extra_convs = relu_before_extra_convs 33 | self.fp16_enabled = False 34 | 35 | if end_level == -1: 36 | self.backbone_end_level = self.num_ins 37 | assert num_outs >= self.num_ins - start_level 38 | else: 39 | # if end_level < inputs, no extra level is allowed 40 | self.backbone_end_level = end_level 41 | assert end_level <= len(in_channels) 42 | assert num_outs == end_level - start_level 43 | self.start_level = start_level 44 | self.end_level = end_level 45 | self.add_extra_convs = add_extra_convs 46 | self.extra_convs_on_inputs = extra_convs_on_inputs 47 | 48 | self.lateral_convs = nn.ModuleList() 49 | self.fpn_convs = nn.ModuleList() 50 | 51 | for i in range(self.start_level, self.backbone_end_level): 52 | l_conv = ConvModule( 53 | in_channels[i], 54 | out_channels, 55 | 1, 56 | conv_cfg=conv_cfg, 57 | norm_cfg=norm_cfg, 58 | activation=self.activation, 59 | inplace=False) 60 | fpn_conv = ConvModule( 61 | out_channels, 62 | out_channels, 63 | 3, 64 | padding=1, 65 | conv_cfg=conv_cfg, 66 | norm_cfg=norm_cfg, 67 | activation=self.activation, 68 | inplace=False) 69 | 70 | self.lateral_convs.append(l_conv) 71 | self.fpn_convs.append(fpn_conv) 72 | 73 | # add extra conv layers (e.g., RetinaNet) 74 | extra_levels = num_outs - self.backbone_end_level + self.start_level 75 | if add_extra_convs and extra_levels >= 1: 76 | for i in range(extra_levels): 77 | if i == 0 and self.extra_convs_on_inputs: 78 | in_channels = self.in_channels[self.backbone_end_level - 1] 79 | else: 80 | in_channels = out_channels 81 | extra_fpn_conv = ConvModule( 82 | in_channels, 83 | out_channels, 84 | 3, 85 | stride=2, 86 | padding=1, 87 | conv_cfg=conv_cfg, 88 | norm_cfg=norm_cfg, 89 | 
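                    # (Top-down sketch, hypothetical 2-level case, see
                    #  forward() below:
                    #    laterals = [P3 (1,256,100,168), P4 (1,256,50,84)]
                    #    laterals[0] += F.interpolate(laterals[1],
                    #                                 scale_factor=2,
                    #                                 mode='nearest')
                    #  i.e. each coarser level is upsampled x2 and summed
                    #  into the next finer lateral before the 3x3 fpn_convs.)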
activation=self.activation, 90 | inplace=False) 91 | self.fpn_convs.append(extra_fpn_conv) 92 | 93 | # default init_weights for conv(msra) and norm in ConvModule 94 | def init_weights(self): 95 | for m in self.modules(): 96 | if isinstance(m, nn.Conv2d): 97 | xavier_init(m, distribution='uniform') 98 | 99 | @auto_fp16() 100 | def forward(self, inputs): 101 | assert len(inputs) == len(self.in_channels) 102 | 103 | # build laterals 104 | laterals = [ 105 | lateral_conv(inputs[i + self.start_level]) 106 | for i, lateral_conv in enumerate(self.lateral_convs) 107 | ] 108 | 109 | # build top-down path 110 | used_backbone_levels = len(laterals) 111 | for i in range(used_backbone_levels - 1, 0, -1): 112 | laterals[i - 1] += F.interpolate( 113 | laterals[i], scale_factor=2, mode='nearest') 114 | 115 | # build outputs 116 | # part 1: from original levels 117 | outs = [ 118 | self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels) 119 | ] 120 | # part 2: add extra levels 121 | if self.num_outs > len(outs): 122 | # use max pool to get more levels on top of outputs 123 | # (e.g., Faster R-CNN, Mask R-CNN) 124 | if not self.add_extra_convs: 125 | for i in range(self.num_outs - used_backbone_levels): 126 | outs.append(F.max_pool2d(outs[-1], 1, stride=2)) 127 | # add conv layers on top of original feature maps (RetinaNet) 128 | else: 129 | if self.extra_convs_on_inputs: 130 | orig = inputs[self.backbone_end_level - 1] 131 | outs.append(self.fpn_convs[used_backbone_levels](orig)) 132 | else: 133 | outs.append(self.fpn_convs[used_backbone_levels](outs[-1])) 134 | for i in range(used_backbone_levels + 1, self.num_outs): 135 | if self.relu_before_extra_convs: 136 | outs.append(self.fpn_convs[i](F.relu(outs[-1]))) 137 | else: 138 | outs.append(self.fpn_convs[i](outs[-1])) 139 | return tuple(outs) 140 | -------------------------------------------------------------------------------- /mmdet/models/necks/hrfpn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.utils.checkpoint import checkpoint 5 | from mmcv.cnn.weight_init import caffe2_xavier_init 6 | 7 | from ..utils import ConvModule 8 | from ..registry import NECKS 9 | 10 | 11 | @NECKS.register_module 12 | class HRFPN(nn.Module): 13 | """HRFPN (High Resolution Feature Pyrmamids) 14 | 15 | arXiv: https://arxiv.org/abs/1904.04514 16 | 17 | Args: 18 | in_channels (list): number of channels for each branch. 19 | out_channels (int): output channels of feature pyramids. 20 | num_outs (int): number of output stages. 21 | pooling_type (str): pooling for generating feature pyramids 22 | from {MAX, AVG}. 23 | conv_cfg (dict): dictionary to construct and config conv layer. 24 | norm_cfg (dict): dictionary to construct and config norm layer. 25 | with_cp (bool): Use checkpoint or not. Using checkpoint will save some 26 | memory while slowing down the training speed. 
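    A tensor-level sketch of the fusion (random stand-ins; channel widths
    are hypothetical HRNet-like values, not from this repo)::

        import torch
        import torch.nn.functional as F
        ins = [torch.randn(1, c, 64 // 2**i, 64 // 2**i)
               for i, c in enumerate([32, 64, 128, 256])]
        ups = [ins[0]] + [
            F.interpolate(ins[i], scale_factor=2**i, mode='bilinear')
            for i in range(1, 4)
        ]
        fused = torch.cat(ups, dim=1)   # (1, 480, 64, 64)
        # a 1x1 reduction conv then repeated pooling builds the pyramid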
27 | """ 28 | 29 | def __init__(self, 30 | in_channels, 31 | out_channels, 32 | num_outs=5, 33 | pooling_type='AVG', 34 | conv_cfg=None, 35 | norm_cfg=None, 36 | with_cp=False): 37 | super(HRFPN, self).__init__() 38 | assert isinstance(in_channels, list) 39 | self.in_channels = in_channels 40 | self.out_channels = out_channels 41 | self.num_ins = len(in_channels) 42 | self.num_outs = num_outs 43 | self.with_cp = with_cp 44 | self.conv_cfg = conv_cfg 45 | self.norm_cfg = norm_cfg 46 | 47 | self.reduction_conv = ConvModule( 48 | sum(in_channels), 49 | out_channels, 50 | kernel_size=1, 51 | conv_cfg=self.conv_cfg, 52 | activation=None) 53 | 54 | self.fpn_convs = nn.ModuleList() 55 | for i in range(self.num_outs): 56 | self.fpn_convs.append( 57 | ConvModule( 58 | out_channels, 59 | out_channels, 60 | kernel_size=3, 61 | padding=1, 62 | conv_cfg=self.conv_cfg, 63 | activation=None)) 64 | 65 | if pooling_type == 'MAX': 66 | self.pooling = F.max_pool2d 67 | else: 68 | self.pooling = F.avg_pool2d 69 | 70 | def init_weights(self): 71 | for m in self.modules(): 72 | if isinstance(m, nn.Conv2d): 73 | caffe2_xavier_init(m) 74 | 75 | def forward(self, inputs): 76 | assert len(inputs) == self.num_ins 77 | outs = [inputs[0]] 78 | for i in range(1, self.num_ins): 79 | outs.append( 80 | F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear')) 81 | out = torch.cat(outs, dim=1) 82 | if out.requires_grad and self.with_cp: 83 | out = checkpoint(self.reduction_conv, out) 84 | else: 85 | out = self.reduction_conv(out) 86 | outs = [out] 87 | for i in range(1, self.num_outs): 88 | outs.append(self.pooling(out, kernel_size=2**i, stride=2**i)) 89 | outputs = [] 90 | 91 | for i in range(self.num_outs): 92 | if outs[i].requires_grad and self.with_cp: 93 | tmp_out = checkpoint(self.fpn_convs[i], outs[i]) 94 | else: 95 | tmp_out = self.fpn_convs[i](outs[i]) 96 | outputs.append(tmp_out) 97 | return tuple(outputs) 98 | -------------------------------------------------------------------------------- /mmdet/models/registry.py: -------------------------------------------------------------------------------- 1 | from mmdet.utils import Registry 2 | 3 | BACKBONES = Registry('backbone') 4 | NECKS = Registry('neck') 5 | ROI_EXTRACTORS = Registry('roi_extractor') 6 | SHARED_HEADS = Registry('shared_head') 7 | HEADS = Registry('head') 8 | LOSSES = Registry('loss') 9 | DETECTORS = Registry('detector') 10 | -------------------------------------------------------------------------------- /mmdet/models/roi_extractors/__init__.py: -------------------------------------------------------------------------------- 1 | from .single_level import SingleRoIExtractor 2 | 3 | __all__ = ['SingleRoIExtractor'] 4 | -------------------------------------------------------------------------------- /mmdet/models/roi_extractors/single_level.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from mmdet import ops 7 | from mmdet.core import force_fp32 8 | from ..registry import ROI_EXTRACTORS 9 | 10 | 11 | @ROI_EXTRACTORS.register_module 12 | class SingleRoIExtractor(nn.Module): 13 | """Extract RoI features from a single level feature map. 14 | 15 | If there are mulitple input feature levels, each RoI is mapped to a level 16 | according to its scale. 17 | 18 | Args: 19 | roi_layer (dict): Specify RoI layer type and arguments. 20 | out_channels (int): Output channels of RoI layers. 
21 | featmap_strides (int): Strides of input feature maps. 22 | finest_scale (int): Scale threshold of mapping to level 0. 23 | """ 24 | 25 | def __init__(self, 26 | roi_layer, 27 | out_channels, 28 | featmap_strides, 29 | finest_scale=56): 30 | super(SingleRoIExtractor, self).__init__() 31 | self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides) 32 | self.out_channels = out_channels 33 | self.featmap_strides = featmap_strides 34 | self.finest_scale = finest_scale 35 | self.fp16_enabled = False 36 | 37 | @property 38 | def num_inputs(self): 39 | """int: Input feature map levels.""" 40 | return len(self.featmap_strides) 41 | 42 | def init_weights(self): 43 | pass 44 | 45 | def build_roi_layers(self, layer_cfg, featmap_strides): 46 | cfg = layer_cfg.copy() 47 | layer_type = cfg.pop('type') 48 | assert hasattr(ops, layer_type) 49 | layer_cls = getattr(ops, layer_type) 50 | roi_layers = nn.ModuleList( 51 | [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides]) 52 | return roi_layers 53 | 54 | def map_roi_levels(self, rois, num_levels): 55 | """Map rois to corresponding feature levels by scales. 56 | 57 | - scale < finest_scale * 2: level 0 58 | - finest_scale * 2 <= scale < finest_scale * 4: level 1 59 | - finest_scale * 4 <= scale < finest_scale * 8: level 2 60 | - scale >= finest_scale * 8: level 3 61 | 62 | Args: 63 | rois (Tensor): Input RoIs, shape (k, 5). 64 | num_levels (int): Total level number. 65 | 66 | Returns: 67 | Tensor: Level index (0-based) of each RoI, shape (k, ) 68 | """ 69 | scale = torch.sqrt( 70 | (rois[:, 3] - rois[:, 1] + 1) * (rois[:, 4] - rois[:, 2] + 1)) 71 | target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6)) 72 | target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long() 73 | return target_lvls 74 | 75 | @force_fp32(apply_to=('feats',), out_fp16=True) 76 | def forward(self, feats, rois): 77 | if len(feats) == 1: 78 | return self.roi_layers[0](feats[0], rois) 79 | 80 | out_size = self.roi_layers[0].out_size 81 | num_levels = len(feats) 82 | target_lvls = self.map_roi_levels(rois, num_levels) 83 | roi_feats = feats[0].new_zeros(rois.size()[0], self.out_channels, 84 | out_size, out_size) 85 | for i in range(num_levels): 86 | inds = target_lvls == i 87 | if inds.any(): 88 | rois_ = rois[inds, :] 89 | roi_feats_t = self.roi_layers[i](feats[i], rois_) 90 | roi_feats[inds] += roi_feats_t 91 | return roi_feats 92 | -------------------------------------------------------------------------------- /mmdet/models/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .conv_ws import conv_ws_2d, ConvWS2d 2 | from .conv_module import build_conv_layer, ConvModule 3 | from .norm import build_norm_layer 4 | from .scale import Scale 5 | from .weight_init import (xavier_init, normal_init, uniform_init, kaiming_init, 6 | bias_init_with_prob) 7 | 8 | __all__ = [ 9 | 'conv_ws_2d', 'ConvWS2d', 'build_conv_layer', 'ConvModule', 10 | 'build_norm_layer', 'xavier_init', 'normal_init', 'uniform_init', 11 | 'kaiming_init', 'bias_init_with_prob', 'Scale' 12 | ] 13 | -------------------------------------------------------------------------------- /mmdet/models/utils/conv_module.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch.nn as nn 4 | from mmcv.cnn import kaiming_init, constant_init 5 | 6 | from .conv_ws import ConvWS2d 7 | from .norm import build_norm_layer 8 | 9 | conv_cfg = { 10 | 'Conv': nn.Conv2d, 11 | 
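    # (Dispatch sketch: build_conv_layer below pops 'type' from a cfg dict
    #  and instantiates the mapped class with the remaining keys, e.g.
    #      build_conv_layer(dict(type='ConvWS'), 16, 32, 3, padding=1)
    #  yields a weight-standardized 3x3 conv; cfg=None falls back to 'Conv'.)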
'ConvWS': ConvWS2d, 12 | # TODO: octave conv 13 | } 14 | 15 | 16 | def build_conv_layer(cfg, *args, **kwargs): 17 | """ Build convolution layer 18 | 19 | Args: 20 | cfg (None or dict): cfg should contain: 21 | type (str): identify conv layer type. 22 | layer args: args needed to instantiate a conv layer. 23 | 24 | Returns: 25 | layer (nn.Module): created conv layer 26 | """ 27 | if cfg is None: 28 | cfg_ = dict(type='Conv') 29 | else: 30 | assert isinstance(cfg, dict) and 'type' in cfg 31 | cfg_ = cfg.copy() 32 | 33 | layer_type = cfg_.pop('type') 34 | if layer_type not in conv_cfg: 35 | raise KeyError('Unrecognized norm type {}'.format(layer_type)) 36 | else: 37 | conv_layer = conv_cfg[layer_type] 38 | 39 | layer = conv_layer(*args, **kwargs, **cfg_) 40 | 41 | return layer 42 | 43 | 44 | class ConvModule(nn.Module): 45 | """Conv-Norm-Activation block. 46 | 47 | Args: 48 | in_channels (int): Same as nn.Conv2d. 49 | out_channels (int): Same as nn.Conv2d. 50 | kernel_size (int or tuple[int]): Same as nn.Conv2d. 51 | stride (int or tuple[int]): Same as nn.Conv2d. 52 | padding (int or tuple[int]): Same as nn.Conv2d. 53 | dilation (int or tuple[int]): Same as nn.Conv2d. 54 | groups (int): Same as nn.Conv2d. 55 | bias (bool or str): If specified as `auto`, it will be decided by the 56 | norm_cfg. Bias will be set as True if norm_cfg is None, otherwise 57 | False. 58 | conv_cfg (dict): Config dict for convolution layer. 59 | norm_cfg (dict): Config dict for normalization layer. 60 | activation (str or None): Activation type, "ReLU" by default. 61 | inplace (bool): Whether to use inplace mode for activation. 62 | activate_last (bool): Whether to apply the activation layer in the 63 | last. (Do not use this flag since the behavior and api may be 64 | changed in the future.) 65 | """ 66 | 67 | def __init__(self, 68 | in_channels, 69 | out_channels, 70 | kernel_size, 71 | stride=1, 72 | padding=0, 73 | dilation=1, 74 | groups=1, 75 | bias='auto', 76 | conv_cfg=None, 77 | norm_cfg=None, 78 | activation='relu', 79 | inplace=True, 80 | activate_last=True): 81 | super(ConvModule, self).__init__() 82 | assert conv_cfg is None or isinstance(conv_cfg, dict) 83 | assert norm_cfg is None or isinstance(norm_cfg, dict) 84 | self.conv_cfg = conv_cfg 85 | self.norm_cfg = norm_cfg 86 | self.activation = activation 87 | self.inplace = inplace 88 | self.activate_last = activate_last 89 | 90 | self.with_norm = norm_cfg is not None 91 | self.with_activatation = activation is not None 92 | # if the conv layer is before a norm layer, bias is unnecessary. 
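        # e.g. (hypothetical configs):
        #   ConvModule(16, 16, 3, norm_cfg=dict(type='BN')).conv.bias is None
        #   ConvModule(16, 16, 3).conv.bias is a learnable parameter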
93 | if bias == 'auto': 94 | bias = False if self.with_norm else True 95 | self.with_bias = bias 96 | 97 | if self.with_norm and self.with_bias: 98 | warnings.warn('ConvModule has norm and bias at the same time') 99 | 100 | # build convolution layer 101 | self.conv = build_conv_layer( 102 | conv_cfg, 103 | in_channels, 104 | out_channels, 105 | kernel_size, 106 | stride=stride, 107 | padding=padding, 108 | dilation=dilation, 109 | groups=groups, 110 | bias=bias) 111 | # export the attributes of self.conv to a higher level for convenience 112 | self.in_channels = self.conv.in_channels 113 | self.out_channels = self.conv.out_channels 114 | self.kernel_size = self.conv.kernel_size 115 | self.stride = self.conv.stride 116 | self.padding = self.conv.padding 117 | self.dilation = self.conv.dilation 118 | self.transposed = self.conv.transposed 119 | self.output_padding = self.conv.output_padding 120 | self.groups = self.conv.groups 121 | 122 | # build normalization layers 123 | if self.with_norm: 124 | norm_channels = out_channels if self.activate_last else in_channels 125 | self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels) 126 | self.add_module(self.norm_name, norm) 127 | 128 | # build activation layer 129 | if self.with_activatation: 130 | if self.activation not in ['relu']: 131 | raise ValueError('{} is currently not supported.'.format( 132 | self.activation)) 133 | if self.activation == 'relu': 134 | self.activate = nn.ReLU(inplace=inplace) 135 | 136 | # Use msra init by default 137 | self.init_weights() 138 | 139 | @property 140 | def norm(self): 141 | return getattr(self, self.norm_name) 142 | 143 | def init_weights(self): 144 | nonlinearity = 'relu' if self.activation is None else self.activation 145 | kaiming_init(self.conv, nonlinearity=nonlinearity) 146 | if self.with_norm: 147 | constant_init(self.norm, 1, bias=0) 148 | 149 | def forward(self, x, activate=True, norm=True): 150 | if self.activate_last: 151 | x = self.conv(x) 152 | if norm and self.with_norm: 153 | x = self.norm(x) 154 | if activate and self.with_activatation: 155 | x = self.activate(x) 156 | else: 157 | # WARN: this may be removed or modified 158 | if norm and self.with_norm: 159 | x = self.norm(x) 160 | if activate and self.with_activatation: 161 | x = self.activate(x) 162 | x = self.conv(x) 163 | return x 164 | -------------------------------------------------------------------------------- /mmdet/models/utils/conv_ws.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | def conv_ws_2d(input, 6 | weight, 7 | bias=None, 8 | stride=1, 9 | padding=0, 10 | dilation=1, 11 | groups=1, 12 | eps=1e-5): 13 | c_in = weight.size(0) 14 | weight_flat = weight.view(c_in, -1) 15 | mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1) 16 | std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1) 17 | weight = (weight - mean) / (std + eps) 18 | return F.conv2d(input, weight, bias, stride, padding, dilation, groups) 19 | 20 | 21 | class ConvWS2d(nn.Conv2d): 22 | 23 | def __init__(self, 24 | in_channels, 25 | out_channels, 26 | kernel_size, 27 | stride=1, 28 | padding=0, 29 | dilation=1, 30 | groups=1, 31 | bias=True, 32 | eps=1e-5): 33 | super(ConvWS2d, self).__init__( 34 | in_channels, 35 | out_channels, 36 | kernel_size, 37 | stride=stride, 38 | padding=padding, 39 | dilation=dilation, 40 | groups=groups, 41 | bias=bias) 42 | self.eps = eps 43 | 44 | def forward(self, x): 45 | return conv_ws_2d(x, 
self.weight, self.bias, self.stride, self.padding, 46 | self.dilation, self.groups, self.eps) 47 | -------------------------------------------------------------------------------- /mmdet/models/utils/norm.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | norm_cfg = { 4 | # format: layer_type: (abbreviation, module) 5 | 'BN': ('bn', nn.BatchNorm2d), 6 | 'SyncBN': ('bn', nn.SyncBatchNorm), 7 | 'GN': ('gn', nn.GroupNorm), 8 | # and potentially 'SN' 9 | } 10 | 11 | 12 | def build_norm_layer(cfg, num_features, postfix=''): 13 | """ Build normalization layer 14 | 15 | Args: 16 | cfg (dict): cfg should contain: 17 | type (str): identify norm layer type. 18 | layer args: args needed to instantiate a norm layer. 19 | requires_grad (bool): [optional] whether stop gradient updates 20 | num_features (int): number of channels from input. 21 | postfix (int, str): appended into norm abbreviation to 22 | create named layer. 23 | 24 | Returns: 25 | name (str): abbreviation + postfix 26 | layer (nn.Module): created norm layer 27 | """ 28 | assert isinstance(cfg, dict) and 'type' in cfg 29 | cfg_ = cfg.copy() 30 | 31 | layer_type = cfg_.pop('type') 32 | if layer_type not in norm_cfg: 33 | raise KeyError('Unrecognized norm type {}'.format(layer_type)) 34 | else: 35 | abbr, norm_layer = norm_cfg[layer_type] 36 | if norm_layer is None: 37 | raise NotImplementedError 38 | 39 | assert isinstance(postfix, (int, str)) 40 | name = abbr + str(postfix) 41 | 42 | requires_grad = cfg_.pop('requires_grad', True) 43 | cfg_.setdefault('eps', 1e-5) 44 | if layer_type != 'GN': 45 | layer = norm_layer(num_features, **cfg_) 46 | if layer_type == 'SyncBN': 47 | layer._specify_ddp_gpu_num(1) 48 | else: 49 | assert 'num_groups' in cfg_ 50 | layer = norm_layer(num_channels=num_features, **cfg_) 51 | 52 | for param in layer.parameters(): 53 | param.requires_grad = requires_grad 54 | 55 | return name, layer 56 | -------------------------------------------------------------------------------- /mmdet/models/utils/scale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Scale(nn.Module): 6 | 7 | def __init__(self, scale=1.0): 8 | super(Scale, self).__init__() 9 | self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float)) 10 | 11 | def forward(self, x): 12 | return x * self.scale 13 | -------------------------------------------------------------------------------- /mmdet/models/utils/weight_init.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | 4 | 5 | def xavier_init(module, gain=1, bias=0, distribution='normal'): 6 | assert distribution in ['uniform', 'normal'] 7 | if distribution == 'uniform': 8 | nn.init.xavier_uniform_(module.weight, gain=gain) 9 | else: 10 | nn.init.xavier_normal_(module.weight, gain=gain) 11 | if hasattr(module, 'bias'): 12 | nn.init.constant_(module.bias, bias) 13 | 14 | 15 | def normal_init(module, mean=0, std=1, bias=0): 16 | nn.init.normal_(module.weight, mean, std) 17 | if hasattr(module, 'bias'): 18 | nn.init.constant_(module.bias, bias) 19 | 20 | 21 | def uniform_init(module, a=0, b=1, bias=0): 22 | nn.init.uniform_(module.weight, a, b) 23 | if hasattr(module, 'bias'): 24 | nn.init.constant_(module.bias, bias) 25 | 26 | 27 | def kaiming_init(module, 28 | mode='fan_out', 29 | nonlinearity='relu', 30 | bias=0, 31 | distribution='normal'): 32 | assert 
distribution in ['uniform', 'normal'] 33 | if distribution == 'uniform': 34 | nn.init.kaiming_uniform_( 35 | module.weight, mode=mode, nonlinearity=nonlinearity) 36 | else: 37 | nn.init.kaiming_normal_( 38 | module.weight, mode=mode, nonlinearity=nonlinearity) 39 | if hasattr(module, 'bias'): 40 | nn.init.constant_(module.bias, bias) 41 | 42 | 43 | def bias_init_with_prob(prior_prob): 44 | """ initialize conv/fc bias value according to giving probablity""" 45 | bias_init = float(-np.log((1 - prior_prob) / prior_prob)) 46 | return bias_init 47 | -------------------------------------------------------------------------------- /mmdet/ops/__init__.py: -------------------------------------------------------------------------------- 1 | from .nms import nms, soft_nms 2 | from .roi_align import RoIAlign, roi_align 3 | from .roi_pool import RoIPool, roi_pool 4 | 5 | 6 | __all__ = [ 7 | 'nms', 'soft_nms', 'RoIAlign', 'roi_align', 'RoIPool', 'roi_pool', 8 | ] 9 | -------------------------------------------------------------------------------- /mmdet/ops/nms/__init__.py: -------------------------------------------------------------------------------- 1 | from .nms_wrapper import nms, soft_nms 2 | 3 | __all__ = ['nms', 'soft_nms'] 4 | -------------------------------------------------------------------------------- /mmdet/ops/nms/nms_wrapper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from . import nms_cuda, nms_cpu 5 | from .soft_nms_cpu import soft_nms_cpu 6 | 7 | 8 | def nms(dets, iou_thr, device_id=None): 9 | """Dispatch to either CPU or GPU NMS implementations. 10 | 11 | The input can be either a torch tensor or numpy array. GPU NMS will be used 12 | if the input is a gpu tensor or device_id is specified, otherwise CPU NMS 13 | will be used. The returned type will always be the same as inputs. 14 | 15 | Arguments: 16 | dets (torch.Tensor or np.ndarray): bboxes with scores. 17 | iou_thr (float): IoU threshold for NMS. 18 | device_id (int, optional): when `dets` is a numpy array, if `device_id` 19 | is None, then cpu nms is used, otherwise gpu_nms will be used. 20 | 21 | Returns: 22 | tuple: kept bboxes and indice, which is always the same data type as 23 | the input. 
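    Example (hypothetical boxes with scores in the last column; assumes
    the compiled CPU extension is available)::

        import numpy as np
        dets = np.array([[10, 10, 50, 50, 0.9],
                         [12, 12, 52, 52, 0.8],
                         [100, 100, 140, 140, 0.7]], dtype=np.float32)
        kept, inds = nms(dets, iou_thr=0.5)
        # boxes 0 and 1 overlap with IoU ~ 0.83, so box 1 is suppressed
        # and inds == [0, 2]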
24 | """ 25 | # convert dets (tensor or numpy array) to tensor 26 | if isinstance(dets, torch.Tensor): 27 | is_numpy = False 28 | dets_th = dets 29 | elif isinstance(dets, np.ndarray): 30 | is_numpy = True 31 | device = 'cpu' if device_id is None else 'cuda:{}'.format(device_id) 32 | dets_th = torch.from_numpy(dets).to(device) 33 | else: 34 | raise TypeError( 35 | 'dets must be either a Tensor or numpy array, but got {}'.format( 36 | type(dets))) 37 | 38 | # execute cpu or cuda nms 39 | if dets_th.shape[0] == 0: 40 | inds = dets_th.new_zeros(0, dtype=torch.long) 41 | else: 42 | if dets_th.is_cuda: 43 | inds = nms_cuda.nms(dets_th, iou_thr) 44 | else: 45 | inds = nms_cpu.nms(dets_th, iou_thr) 46 | 47 | if is_numpy: 48 | inds = inds.cpu().numpy() 49 | return dets[inds, :], inds 50 | 51 | 52 | def soft_nms(dets, iou_thr, method='linear', sigma=0.5, min_score=1e-3): 53 | if isinstance(dets, torch.Tensor): 54 | is_tensor = True 55 | dets_np = dets.detach().cpu().numpy() 56 | elif isinstance(dets, np.ndarray): 57 | is_tensor = False 58 | dets_np = dets 59 | else: 60 | raise TypeError( 61 | 'dets must be either a Tensor or numpy array, but got {}'.format( 62 | type(dets))) 63 | 64 | method_codes = {'linear': 1, 'gaussian': 2} 65 | if method not in method_codes: 66 | raise ValueError('Invalid method for SoftNMS: {}'.format(method)) 67 | new_dets, inds = soft_nms_cpu( 68 | dets_np, 69 | iou_thr, 70 | method=method_codes[method], 71 | sigma=sigma, 72 | min_score=min_score) 73 | 74 | if is_tensor: 75 | return dets.new_tensor(new_dets), dets.new_tensor( 76 | inds, dtype=torch.long) 77 | else: 78 | return new_dets.astype(np.float32), inds.astype(np.int64) 79 | -------------------------------------------------------------------------------- /mmdet/ops/nms/src/nms_cpu.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 
2 | #include <torch/extension.h>
3 | 
4 | template <typename scalar_t>
5 | at::Tensor nms_cpu_kernel(const at::Tensor& dets, const float threshold) {
6 |   AT_ASSERTM(!dets.type().is_cuda(), "dets must be a CPU tensor");
7 | 
8 |   if (dets.numel() == 0) {
9 |     return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
10 |   }
11 | 
12 |   auto x1_t = dets.select(1, 0).contiguous();
13 |   auto y1_t = dets.select(1, 1).contiguous();
14 |   auto x2_t = dets.select(1, 2).contiguous();
15 |   auto y2_t = dets.select(1, 3).contiguous();
16 |   auto scores = dets.select(1, 4).contiguous();
17 | 
18 |   at::Tensor areas_t = (x2_t - x1_t + 1) * (y2_t - y1_t + 1);
19 | 
20 |   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
21 | 
22 |   auto ndets = dets.size(0);
23 |   at::Tensor suppressed_t =
24 |       at::zeros({ndets}, dets.options().dtype(at::kByte).device(at::kCPU));
25 | 
26 |   auto suppressed = suppressed_t.data<uint8_t>();
27 |   auto order = order_t.data<int64_t>();
28 |   auto x1 = x1_t.data<scalar_t>();
29 |   auto y1 = y1_t.data<scalar_t>();
30 |   auto x2 = x2_t.data<scalar_t>();
31 |   auto y2 = y2_t.data<scalar_t>();
32 |   auto areas = areas_t.data<scalar_t>();
33 | 
34 |   for (int64_t _i = 0; _i < ndets; _i++) {
35 |     auto i = order[_i];
36 |     if (suppressed[i] == 1) continue;
37 |     auto ix1 = x1[i];
38 |     auto iy1 = y1[i];
39 |     auto ix2 = x2[i];
40 |     auto iy2 = y2[i];
41 |     auto iarea = areas[i];
42 | 
43 |     for (int64_t _j = _i + 1; _j < ndets; _j++) {
44 |       auto j = order[_j];
45 |       if (suppressed[j] == 1) continue;
46 |       auto xx1 = std::max(ix1, x1[j]);
47 |       auto yy1 = std::max(iy1, y1[j]);
48 |       auto xx2 = std::min(ix2, x2[j]);
49 |       auto yy2 = std::min(iy2, y2[j]);
50 | 
51 |       auto w = std::max(static_cast<scalar_t>(0), xx2 - xx1 + 1);
52 |       auto h = std::max(static_cast<scalar_t>(0), yy2 - yy1 + 1);
53 |       auto inter = w * h;
54 |       auto ovr = inter / (iarea + areas[j] - inter);
55 |       if (ovr >= threshold) suppressed[j] = 1;
56 |     }
57 |   }
58 |   return at::nonzero(suppressed_t == 0).squeeze(1);
59 | }
60 | 
61 | at::Tensor nms(const at::Tensor& dets, const float threshold) {
62 |   at::Tensor result;
63 |   AT_DISPATCH_FLOATING_TYPES(dets.type(), "nms", [&] {
64 |     result = nms_cpu_kernel<scalar_t>(dets, threshold);
65 |   });
66 |   return result;
67 | }
68 | 
69 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
70 |   m.def("nms", &nms, "non-maximum suppression");
71 | }
-------------------------------------------------------------------------------- /mmdet/ops/nms/src/nms_cuda.cpp: --------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #include <torch/extension.h>
3 | 
4 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
5 | 
6 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh);
7 | 
8 | at::Tensor nms(const at::Tensor& dets, const float threshold) {
9 |   CHECK_CUDA(dets);
10 |   if (dets.numel() == 0)
11 |     return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
12 |   return nms_cuda(dets, threshold);
13 | }
14 | 
15 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
16 |   m.def("nms", &nms, "non-maximum suppression");
17 | }
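For orientation, here is a minimal usage sketch of the `nms` and `soft_nms` wrappers from `nms_wrapper.py` above. It assumes the compiled extensions are importable (e.g. after building via `setup.py`); the box coordinates and scores are made-up illustration data.

```python
import numpy as np
from mmdet.ops.nms import nms, soft_nms

# dets: N x 5 array of [x1, y1, x2, y2, score]; values are illustrative.
dets = np.array([[10., 10., 50., 60., 0.9],
                 [12., 12., 52., 62., 0.8],
                 [100., 100., 150., 160., 0.7]], dtype=np.float32)

# Hard NMS; numpy in, numpy out (CPU path since device_id is None).
kept, inds = nms(dets, iou_thr=0.5)

# Soft-NMS decays overlapping scores instead of discarding boxes outright.
new_dets, new_inds = soft_nms(dets, iou_thr=0.5, method='linear', min_score=1e-3)
```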
-------------------------------------------------------------------------------- /mmdet/ops/nms/src/nms_kernel.cu: --------------------------------------------------------------------------------
1 | // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2 | #include <ATen/ATen.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 | 
5 | #include <THC/THC.h>
6 | #include <THC/THCDeviceUtils.cuh>
7 | 
8 | #include <vector>
9 | #include <iostream>
10 | 
11 | int const threadsPerBlock = sizeof(unsigned long long) * 8;
12 | 
13 | __device__ inline float devIoU(float const * const a, float const * const b) {
14 |   float left = max(a[0], b[0]), right = min(a[2], b[2]);
15 |   float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
16 |   float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
17 |   float interS = width * height;
18 |   float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
19 |   float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
20 |   return interS / (Sa + Sb - interS);
21 | }
22 | 
23 | __global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh,
24 |                            const float *dev_boxes, unsigned long long *dev_mask) {
25 |   const int row_start = blockIdx.y;
26 |   const int col_start = blockIdx.x;
27 | 
28 |   // if (row_start > col_start) return;
29 | 
30 |   const int row_size =
31 |       min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
32 |   const int col_size =
33 |       min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
34 | 
35 |   __shared__ float block_boxes[threadsPerBlock * 5];
36 |   if (threadIdx.x < col_size) {
37 |     block_boxes[threadIdx.x * 5 + 0] =
38 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0];
39 |     block_boxes[threadIdx.x * 5 + 1] =
40 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1];
41 |     block_boxes[threadIdx.x * 5 + 2] =
42 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2];
43 |     block_boxes[threadIdx.x * 5 + 3] =
44 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3];
45 |     block_boxes[threadIdx.x * 5 + 4] =
46 |         dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4];
47 |   }
48 |   __syncthreads();
49 | 
50 |   if (threadIdx.x < row_size) {
51 |     const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
52 |     const float *cur_box = dev_boxes + cur_box_idx * 5;
53 |     int i = 0;
54 |     unsigned long long t = 0;
55 |     int start = 0;
56 |     if (row_start == col_start) {
57 |       start = threadIdx.x + 1;
58 |     }
59 |     for (i = start; i < col_size; i++) {
60 |       if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) {
61 |         t |= 1ULL << i;
62 |       }
63 |     }
64 |     const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
65 |     dev_mask[cur_box_idx * col_blocks + col_start] = t;
66 |   }
67 | }
68 | 
69 | // boxes is a N x 5 tensor
70 | at::Tensor nms_cuda(const at::Tensor boxes, float nms_overlap_thresh) {
71 |   using scalar_t = float;
72 |   AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
73 |   auto scores = boxes.select(1, 4);
74 |   auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
75 |   auto boxes_sorted = boxes.index_select(0, order_t);
76 | 
77 |   int boxes_num = boxes.size(0);
78 | 
79 |   const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
80 | 
81 |   scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
82 | 
83 |   THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
84 | 
85 |   unsigned long long* mask_dev = NULL;
86 |   //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
87 |   //                      boxes_num * col_blocks * sizeof(unsigned long long)));
88 | 
89 |   mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
90 | 
91 |   dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
92 |               THCCeilDiv(boxes_num, threadsPerBlock));
93 |   dim3 threads(threadsPerBlock);
94 |   nms_kernel<<<blocks, threads>>>(boxes_num,
95 |                                   nms_overlap_thresh,
96 |                                   boxes_dev,
97 |                                   mask_dev);
98 | 
99 |   std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
100 |   THCudaCheck(cudaMemcpy(&mask_host[0],
101 |                          mask_dev,
102 |                          sizeof(unsigned long long) * boxes_num * col_blocks,
103 |                          cudaMemcpyDeviceToHost));
104 | 
105 |   std::vector<unsigned long long> remv(col_blocks);
106 |   memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
107 | 
108 |   at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
109 |   int64_t* keep_out = keep.data<int64_t>();
110 | 
111 |   int num_to_keep = 0;
112 |   for (int i = 0; i < boxes_num; i++) {
113 |     int nblock = i / threadsPerBlock;
114 |     int inblock = i % threadsPerBlock;
115 | 
116 |     if (!(remv[nblock] & (1ULL << inblock))) {
117 |       keep_out[num_to_keep++] = i;
118 |       unsigned long long *p = &mask_host[0] + i * col_blocks;
119 |       for (int j = nblock; j < col_blocks; j++) {
120 |         remv[j] |= p[j];
121 |       }
122 |     }
123 |   }
124 | 
125 |   THCudaFree(state, mask_dev);
126 |   // TODO improve this part
127 |   return std::get<0>(order_t.index({
128 |       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
129 |           order_t.device(), keep.scalar_type())
130 |   }).sort(0, false));
131 | }
-------------------------------------------------------------------------------- /mmdet/ops/nms/src/soft_nms_cpu.pyx: --------------------------------------------------------------------------------
1 | # ----------------------------------------------------------
2 | # Soft-NMS: Improving Object Detection With One Line of Code
3 | # Copyright (c) University of Maryland, College Park
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Navaneeth Bodla and Bharat Singh
6 | # Modified by Kai Chen
7 | # ----------------------------------------------------------
8 | 
9 | # cython: language_level=3, boundscheck=False
10 | 
11 | import numpy as np
12 | cimport numpy as np
13 | 
14 | 
15 | cdef inline np.float32_t max(np.float32_t a, np.float32_t b):
16 |     return a if a >= b else b
17 | 
18 | cdef inline np.float32_t min(np.float32_t a, np.float32_t b):
19 |     return a if a <= b else b
20 | 
21 | 
22 | def soft_nms_cpu(
23 |     np.ndarray[float, ndim=2] boxes_in,
24 |     float iou_thr,
25 |     unsigned int method=1,
26 |     float sigma=0.5,
27 |     float min_score=0.001,
28 | ):
29 |     boxes = boxes_in.copy()
30 |     cdef unsigned int N = boxes.shape[0]
31 |     cdef float iw, ih, box_area
32 |     cdef float ua
33 |     cdef int pos = 0
34 |     cdef float maxscore = 0
35 |     cdef int maxpos = 0
36 |     cdef float x1, x2, y1, y2, tx1, tx2, ty1, ty2, ts, area, weight, ov
37 |     inds = np.arange(N)
38 | 
39 |     for i in range(N):
40 |         maxscore = boxes[i, 4]
41 |         maxpos = i
42 | 
43 |         tx1 = boxes[i, 0]
44 |         ty1 = boxes[i, 1]
45 |         tx2 = boxes[i, 2]
46 |         ty2 = boxes[i, 3]
47 |         ts = boxes[i, 4]
48 |         ti = inds[i]
49 | 
50 |         pos = i + 1
51 |         # get max box
52 |         while pos < N:
53 |             if maxscore < boxes[pos, 4]:
54 |                 maxscore = boxes[pos, 4]
55 |                 maxpos = pos
56 |             pos = pos + 1
57 | 
58 |         # add max box as a detection
59 |         boxes[i, 0] = boxes[maxpos, 0]
60 |         boxes[i, 1] = boxes[maxpos, 1]
61 |         boxes[i, 2] = boxes[maxpos, 2]
62 |         boxes[i, 3] = boxes[maxpos, 3]
63 |         boxes[i, 4] = boxes[maxpos, 4]
64 |         inds[i] = inds[maxpos]
65 | 
66 |         # swap ith box with position of max box
67 |         boxes[maxpos, 0] = tx1
68 |         boxes[maxpos, 1] = ty1
69 |         boxes[maxpos, 2] = tx2
70 |         boxes[maxpos, 3] = ty2
71 |         boxes[maxpos, 4] = ts
72 |         inds[maxpos] = ti
73 | 
74 |         tx1 = boxes[i, 0]
75 |         ty1 = boxes[i, 1]
76 |         tx2 = boxes[i, 2]
77 |         ty2 = boxes[i, 3]
78 |         ts = boxes[i, 4]
79 | 
80 |         pos = i + 1
81 |         # NMS iterations, note that N changes if detection boxes fall below
82 |         #
threshold 83 | while pos < N: 84 | x1 = boxes[pos, 0] 85 | y1 = boxes[pos, 1] 86 | x2 = boxes[pos, 2] 87 | y2 = boxes[pos, 3] 88 | s = boxes[pos, 4] 89 | 90 | area = (x2 - x1 + 1) * (y2 - y1 + 1) 91 | iw = (min(tx2, x2) - max(tx1, x1) + 1) 92 | if iw > 0: 93 | ih = (min(ty2, y2) - max(ty1, y1) + 1) 94 | if ih > 0: 95 | ua = float((tx2 - tx1 + 1) * (ty2 - ty1 + 1) + area - iw * ih) 96 | ov = iw * ih / ua # iou between max box and detection box 97 | 98 | if method == 1: # linear 99 | if ov > iou_thr: 100 | weight = 1 - ov 101 | else: 102 | weight = 1 103 | elif method == 2: # gaussian 104 | weight = np.exp(-(ov * ov) / sigma) 105 | else: # original NMS 106 | if ov > iou_thr: 107 | weight = 0 108 | else: 109 | weight = 1 110 | 111 | boxes[pos, 4] = weight * boxes[pos, 4] 112 | 113 | # if box score falls below threshold, discard the box by 114 | # swapping with last box update N 115 | if boxes[pos, 4] < min_score: 116 | boxes[pos, 0] = boxes[N-1, 0] 117 | boxes[pos, 1] = boxes[N-1, 1] 118 | boxes[pos, 2] = boxes[N-1, 2] 119 | boxes[pos, 3] = boxes[N-1, 3] 120 | boxes[pos, 4] = boxes[N-1, 4] 121 | inds[pos] = inds[N - 1] 122 | N = N - 1 123 | pos = pos - 1 124 | 125 | pos = pos + 1 126 | 127 | return boxes[:N], inds[:N] 128 | -------------------------------------------------------------------------------- /mmdet/ops/roi_align/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions.roi_align import roi_align 2 | from .modules.roi_align import RoIAlign 3 | 4 | __all__ = ['roi_align', 'RoIAlign'] 5 | -------------------------------------------------------------------------------- /mmdet/ops/roi_align/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Leotju/MGAN/4f2c621715b3e277d3f33d39af6764250367b205/mmdet/ops/roi_align/functions/__init__.py -------------------------------------------------------------------------------- /mmdet/ops/roi_align/functions/roi_align.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function 2 | from torch.nn.modules.utils import _pair 3 | 4 | from .. 
import roi_align_cuda 5 | 6 | 7 | class RoIAlignFunction(Function): 8 | 9 | @staticmethod 10 | def forward(ctx, features, rois, out_size, spatial_scale, sample_num=0): 11 | out_h, out_w = _pair(out_size) 12 | assert isinstance(out_h, int) and isinstance(out_w, int) 13 | ctx.spatial_scale = spatial_scale 14 | ctx.sample_num = sample_num 15 | ctx.save_for_backward(rois) 16 | ctx.feature_size = features.size() 17 | 18 | batch_size, num_channels, data_height, data_width = features.size() 19 | num_rois = rois.size(0) 20 | 21 | output = features.new_zeros(num_rois, num_channels, out_h, out_w) 22 | if features.is_cuda: 23 | roi_align_cuda.forward(features, rois, out_h, out_w, spatial_scale, 24 | sample_num, output) 25 | else: 26 | raise NotImplementedError 27 | 28 | return output 29 | 30 | @staticmethod 31 | def backward(ctx, grad_output): 32 | feature_size = ctx.feature_size 33 | spatial_scale = ctx.spatial_scale 34 | sample_num = ctx.sample_num 35 | rois = ctx.saved_tensors[0] 36 | assert (feature_size is not None and grad_output.is_cuda) 37 | 38 | batch_size, num_channels, data_height, data_width = feature_size 39 | out_w = grad_output.size(3) 40 | out_h = grad_output.size(2) 41 | 42 | grad_input = grad_rois = None 43 | if ctx.needs_input_grad[0]: 44 | grad_input = rois.new_zeros(batch_size, num_channels, data_height, 45 | data_width) 46 | roi_align_cuda.backward(grad_output.contiguous(), rois, out_h, 47 | out_w, spatial_scale, sample_num, 48 | grad_input) 49 | 50 | return grad_input, grad_rois, None, None, None 51 | 52 | 53 | roi_align = RoIAlignFunction.apply 54 | -------------------------------------------------------------------------------- /mmdet/ops/roi_align/gradcheck.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch.autograd import gradcheck 4 | 5 | import os.path as osp 6 | import sys 7 | sys.path.append(osp.abspath(osp.join(__file__, '../../'))) 8 | from roi_align import RoIAlign # noqa: E402 9 | 10 | feat_size = 15 11 | spatial_scale = 1.0 / 8 12 | img_size = feat_size / spatial_scale 13 | num_imgs = 2 14 | num_rois = 20 15 | 16 | batch_ind = np.random.randint(num_imgs, size=(num_rois, 1)) 17 | rois = np.random.rand(num_rois, 4) * img_size * 0.5 18 | rois[:, 2:] += img_size * 0.5 19 | rois = np.hstack((batch_ind, rois)) 20 | 21 | feat = torch.randn( 22 | num_imgs, 16, feat_size, feat_size, requires_grad=True, device='cuda:0') 23 | rois = torch.from_numpy(rois).float().cuda() 24 | inputs = (feat, rois) 25 | print('Gradcheck for roi align...') 26 | test = gradcheck(RoIAlign(3, spatial_scale), inputs, atol=1e-3, eps=1e-3) 27 | print(test) 28 | test = gradcheck(RoIAlign(3, spatial_scale, 2), inputs, atol=1e-3, eps=1e-3) 29 | print(test) 30 | -------------------------------------------------------------------------------- /mmdet/ops/roi_align/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Leotju/MGAN/4f2c621715b3e277d3f33d39af6764250367b205/mmdet/ops/roi_align/modules/__init__.py -------------------------------------------------------------------------------- /mmdet/ops/roi_align/modules/roi_align.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.nn.modules.utils import _pair 3 | 4 | from ..functions.roi_align import roi_align 5 | 6 | 7 | class RoIAlign(nn.Module): 8 | 9 | def __init__(self, 10 | out_size, 11 | spatial_scale, 12 | 
sample_num=0,
13 |                  use_torchvision=False):
14 |         super(RoIAlign, self).__init__()
15 | 
16 |         self.out_size = out_size
17 |         self.spatial_scale = float(spatial_scale)
18 |         self.sample_num = int(sample_num)
19 |         self.use_torchvision = use_torchvision
20 | 
21 |     def forward(self, features, rois):
22 |         if self.use_torchvision:
23 |             from torchvision.ops import roi_align as tv_roi_align
24 |             return tv_roi_align(features, rois, _pair(self.out_size),
25 |                                 self.spatial_scale, self.sample_num)
26 |         else:
27 |             return roi_align(features, rois, self.out_size, self.spatial_scale,
28 |                              self.sample_num)
29 | 
-------------------------------------------------------------------------------- /mmdet/ops/roi_align/src/roi_align_cuda.cpp: --------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <cmath>
4 | #include <vector>
5 | 
6 | int ROIAlignForwardLaucher(const at::Tensor features, const at::Tensor rois,
7 |                            const float spatial_scale, const int sample_num,
8 |                            const int channels, const int height,
9 |                            const int width, const int num_rois,
10 |                            const int pooled_height, const int pooled_width,
11 |                            at::Tensor output);
12 | 
13 | int ROIAlignBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
14 |                             const float spatial_scale, const int sample_num,
15 |                             const int channels, const int height,
16 |                             const int width, const int num_rois,
17 |                             const int pooled_height, const int pooled_width,
18 |                             at::Tensor bottom_grad);
19 | 
20 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
21 | #define CHECK_CONTIGUOUS(x) \
22 |   AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
23 | #define CHECK_INPUT(x) \
24 |   CHECK_CUDA(x);       \
25 |   CHECK_CONTIGUOUS(x)
26 | 
27 | int roi_align_forward_cuda(at::Tensor features, at::Tensor rois,
28 |                            int pooled_height, int pooled_width,
29 |                            float spatial_scale, int sample_num,
30 |                            at::Tensor output) {
31 |   CHECK_INPUT(features);
32 |   CHECK_INPUT(rois);
33 |   CHECK_INPUT(output);
34 | 
35 |   // Number of ROIs
36 |   int num_rois = rois.size(0);
37 |   int size_rois = rois.size(1);
38 | 
39 |   if (size_rois != 5) {
40 |     printf("wrong roi size\n");
41 |     return 0;
42 |   }
43 | 
44 |   int num_channels = features.size(1);
45 |   int data_height = features.size(2);
46 |   int data_width = features.size(3);
47 | 
48 |   ROIAlignForwardLaucher(features, rois, spatial_scale, sample_num,
49 |                          num_channels, data_height, data_width, num_rois,
50 |                          pooled_height, pooled_width, output);
51 | 
52 |   return 1;
53 | }
54 | 
55 | int roi_align_backward_cuda(at::Tensor top_grad, at::Tensor rois,
56 |                             int pooled_height, int pooled_width,
57 |                             float spatial_scale, int sample_num,
58 |                             at::Tensor bottom_grad) {
59 |   CHECK_INPUT(top_grad);
60 |   CHECK_INPUT(rois);
61 |   CHECK_INPUT(bottom_grad);
62 | 
63 |   // Number of ROIs
64 |   int num_rois = rois.size(0);
65 |   int size_rois = rois.size(1);
66 |   if (size_rois != 5) {
67 |     printf("wrong roi size\n");
68 |     return 0;
69 |   }
70 | 
71 |   int num_channels = bottom_grad.size(1);
72 |   int data_height = bottom_grad.size(2);
73 |   int data_width = bottom_grad.size(3);
74 | 
75 |   ROIAlignBackwardLaucher(top_grad, rois, spatial_scale, sample_num,
76 |                           num_channels, data_height, data_width, num_rois,
77 |                           pooled_height, pooled_width, bottom_grad);
78 | 
79 |   return 1;
80 | }
81 | 
82 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
83 |   m.def("forward", &roi_align_forward_cuda, "Roi_Align forward (CUDA)");
84 |   m.def("backward", &roi_align_backward_cuda, "Roi_Align backward (CUDA)");
85 | }
86 | 
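A minimal usage sketch for the `RoIAlign` module defined above, assuming the CUDA extension has been built and a GPU is available. Each RoI row is `[batch_idx, x1, y1, x2, y2]` in input-image coordinates; the feature map and RoI values are made-up.

```python
import torch
from mmdet.ops import RoIAlign

# 2 images, 16 channels, 32x32 feature map (stride-8 features assumed).
feats = torch.randn(2, 16, 32, 32, device='cuda')
rois = torch.tensor([[0., 4., 4., 20., 28.],
                     [1., 8., 8., 24., 24.]], device='cuda')

align = RoIAlign(out_size=7, spatial_scale=1.0 / 8, sample_num=2)
out = align(feats, rois)  # shape: (2, 16, 7, 7)
```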
-------------------------------------------------------------------------------- /mmdet/ops/roi_pool/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions.roi_pool import roi_pool 2 | from .modules.roi_pool import RoIPool 3 | 4 | __all__ = ['roi_pool', 'RoIPool'] 5 | -------------------------------------------------------------------------------- /mmdet/ops/roi_pool/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Leotju/MGAN/4f2c621715b3e277d3f33d39af6764250367b205/mmdet/ops/roi_pool/functions/__init__.py -------------------------------------------------------------------------------- /mmdet/ops/roi_pool/functions/roi_pool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Function 3 | from torch.nn.modules.utils import _pair 4 | 5 | from .. import roi_pool_cuda 6 | 7 | 8 | class RoIPoolFunction(Function): 9 | 10 | @staticmethod 11 | def forward(ctx, features, rois, out_size, spatial_scale): 12 | assert features.is_cuda 13 | out_h, out_w = _pair(out_size) 14 | assert isinstance(out_h, int) and isinstance(out_w, int) 15 | ctx.save_for_backward(rois) 16 | num_channels = features.size(1) 17 | num_rois = rois.size(0) 18 | out_size = (num_rois, num_channels, out_h, out_w) 19 | output = features.new_zeros(out_size) 20 | argmax = features.new_zeros(out_size, dtype=torch.int) 21 | roi_pool_cuda.forward(features, rois, out_h, out_w, spatial_scale, 22 | output, argmax) 23 | ctx.spatial_scale = spatial_scale 24 | ctx.feature_size = features.size() 25 | ctx.argmax = argmax 26 | 27 | return output 28 | 29 | @staticmethod 30 | def backward(ctx, grad_output): 31 | assert grad_output.is_cuda 32 | spatial_scale = ctx.spatial_scale 33 | feature_size = ctx.feature_size 34 | argmax = ctx.argmax 35 | rois = ctx.saved_tensors[0] 36 | assert feature_size is not None 37 | 38 | grad_input = grad_rois = None 39 | if ctx.needs_input_grad[0]: 40 | grad_input = grad_output.new_zeros(feature_size) 41 | roi_pool_cuda.backward(grad_output.contiguous(), rois, argmax, 42 | spatial_scale, grad_input) 43 | 44 | return grad_input, grad_rois, None, None 45 | 46 | 47 | roi_pool = RoIPoolFunction.apply 48 | -------------------------------------------------------------------------------- /mmdet/ops/roi_pool/gradcheck.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import gradcheck 3 | 4 | import os.path as osp 5 | import sys 6 | sys.path.append(osp.abspath(osp.join(__file__, '../../'))) 7 | from roi_pool import RoIPool # noqa: E402 8 | 9 | feat = torch.randn(4, 16, 15, 15, requires_grad=True).cuda() 10 | rois = torch.Tensor([[0, 0, 0, 50, 50], [0, 10, 30, 43, 55], 11 | [1, 67, 40, 110, 120]]).cuda() 12 | inputs = (feat, rois) 13 | print('Gradcheck for roi pooling...') 14 | test = gradcheck(RoIPool(4, 1.0 / 8), inputs, eps=1e-5, atol=1e-3) 15 | print(test) 16 | -------------------------------------------------------------------------------- /mmdet/ops/roi_pool/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Leotju/MGAN/4f2c621715b3e277d3f33d39af6764250367b205/mmdet/ops/roi_pool/modules/__init__.py -------------------------------------------------------------------------------- /mmdet/ops/roi_pool/modules/roi_pool.py: 
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | from torch.nn.modules.utils import _pair
3 | 
4 | from ..functions.roi_pool import roi_pool
5 | 
6 | 
7 | class RoIPool(nn.Module):
8 | 
9 |     def __init__(self, out_size, spatial_scale, use_torchvision=False):
10 |         super(RoIPool, self).__init__()
11 | 
12 |         self.out_size = out_size
13 |         self.spatial_scale = float(spatial_scale)
14 |         self.use_torchvision = use_torchvision
15 | 
16 |     def forward(self, features, rois):
17 |         if self.use_torchvision:
18 |             from torchvision.ops import roi_pool as tv_roi_pool
19 |             return tv_roi_pool(features, rois, _pair(self.out_size),
20 |                                self.spatial_scale)
21 |         else:
22 |             return roi_pool(features, rois, self.out_size, self.spatial_scale)
23 | 
-------------------------------------------------------------------------------- /mmdet/ops/roi_pool/src/roi_pool_cuda.cpp: --------------------------------------------------------------------------------
1 | #include <torch/extension.h>
2 | 
3 | #include <cmath>
4 | #include <vector>
5 | 
6 | int ROIPoolForwardLaucher(const at::Tensor features, const at::Tensor rois,
7 |                           const float spatial_scale, const int channels,
8 |                           const int height, const int width, const int num_rois,
9 |                           const int pooled_h, const int pooled_w,
10 |                           at::Tensor output, at::Tensor argmax);
11 | 
12 | int ROIPoolBackwardLaucher(const at::Tensor top_grad, const at::Tensor rois,
13 |                            const at::Tensor argmax, const float spatial_scale,
14 |                            const int batch_size, const int channels,
15 |                            const int height, const int width,
16 |                            const int num_rois, const int pooled_h,
17 |                            const int pooled_w, at::Tensor bottom_grad);
18 | 
19 | #define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
20 | #define CHECK_CONTIGUOUS(x) \
21 |   AT_CHECK(x.is_contiguous(), #x, " must be contiguous ")
22 | #define CHECK_INPUT(x) \
23 |   CHECK_CUDA(x);       \
24 |   CHECK_CONTIGUOUS(x)
25 | 
26 | int roi_pooling_forward_cuda(at::Tensor features, at::Tensor rois,
27 |                              int pooled_height, int pooled_width,
28 |                              float spatial_scale, at::Tensor output,
29 |                              at::Tensor argmax) {
30 |   CHECK_INPUT(features);
31 |   CHECK_INPUT(rois);
32 |   CHECK_INPUT(output);
33 |   CHECK_INPUT(argmax);
34 | 
35 |   // Number of ROIs
36 |   int num_rois = rois.size(0);
37 |   int size_rois = rois.size(1);
38 | 
39 |   if (size_rois != 5) {
40 |     printf("wrong roi size\n");
41 |     return 0;
42 |   }
43 | 
44 |   int channels = features.size(1);
45 |   int height = features.size(2);
46 |   int width = features.size(3);
47 | 
48 |   ROIPoolForwardLaucher(features, rois, spatial_scale, channels, height, width,
49 |                         num_rois, pooled_height, pooled_width, output, argmax);
50 | 
51 |   return 1;
52 | }
53 | 
54 | int roi_pooling_backward_cuda(at::Tensor top_grad, at::Tensor rois,
55 |                               at::Tensor argmax, float spatial_scale,
56 |                               at::Tensor bottom_grad) {
57 |   CHECK_INPUT(top_grad);
58 |   CHECK_INPUT(rois);
59 |   CHECK_INPUT(argmax);
60 |   CHECK_INPUT(bottom_grad);
61 | 
62 |   int pooled_height = top_grad.size(2);
63 |   int pooled_width = top_grad.size(3);
64 |   int num_rois = rois.size(0);
65 |   int size_rois = rois.size(1);
66 | 
67 |   if (size_rois != 5) {
68 |     printf("wrong roi size\n");
69 |     return 0;
70 |   }
71 |   int batch_size = bottom_grad.size(0);
72 |   int channels = bottom_grad.size(1);
73 |   int height = bottom_grad.size(2);
74 |   int width = bottom_grad.size(3);
75 | 
76 |   ROIPoolBackwardLaucher(top_grad, rois, argmax, spatial_scale, batch_size,
77 |                          channels, height, width, num_rois, pooled_height,
78 |                          pooled_width, bottom_grad);
79 | 
80 |   return 1;
81 | }
82 | 
83 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
84 |   m.def("forward", &roi_pooling_forward_cuda, "Roi_Pooling forward (CUDA)");
85 |   m.def("backward", &roi_pooling_backward_cuda, "Roi_Pooling backward (CUDA)");
86 | }
87 | 
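`RoIPool` is the companion op to `RoIAlign`: it snaps bin boundaries to the feature grid and max-pools each bin (recording an argmax map for the backward pass), whereas `RoIAlign` samples bilinearly. A short sketch under the same assumptions as the `RoIAlign` example above:

```python
import torch
from mmdet.ops import RoIPool

feats = torch.randn(2, 16, 32, 32, device='cuda')
rois = torch.tensor([[0., 4., 4., 20., 28.]], device='cuda')  # [batch_idx, x1, y1, x2, y2]

pool = RoIPool(out_size=7, spatial_scale=1.0 / 8)
out = pool(feats, rois)  # shape: (1, 16, 7, 7)
```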
-------------------------------------------------------------------------------- /mmdet/utils/__init__.py: --------------------------------------------------------------------------------
1 | from .registry import Registry, build_from_cfg
2 | 
3 | __all__ = ['Registry', 'build_from_cfg']
4 | 
-------------------------------------------------------------------------------- /mmdet/utils/registry.py: --------------------------------------------------------------------------------
1 | import inspect
2 | 
3 | import mmcv
4 | 
5 | 
6 | class Registry(object):
7 | 
8 |     def __init__(self, name):
9 |         self._name = name
10 |         self._module_dict = dict()
11 | 
12 |     def __repr__(self):
13 |         format_str = self.__class__.__name__ + '(name={}, items={})'.format(
14 |             self._name, list(self._module_dict.keys()))
15 |         return format_str
16 | 
17 |     @property
18 |     def name(self):
19 |         return self._name
20 | 
21 |     @property
22 |     def module_dict(self):
23 |         return self._module_dict
24 | 
25 |     def get(self, key):
26 |         return self._module_dict.get(key, None)
27 | 
28 |     def _register_module(self, module_class):
29 |         """Register a module class.
30 | 
31 |         Args:
32 |             module_class (type): Class to be registered.
33 |         """
34 |         if not inspect.isclass(module_class):
35 |             raise TypeError('module must be a class, but got {}'.format(
36 |                 type(module_class)))
37 |         module_name = module_class.__name__
38 |         if module_name in self._module_dict:
39 |             raise KeyError('{} is already registered in {}'.format(
40 |                 module_name, self.name))
41 |         self._module_dict[module_name] = module_class
42 | 
43 |     def register_module(self, cls):
44 |         self._register_module(cls)
45 |         return cls
46 | 
47 | 
48 | def build_from_cfg(cfg, registry, default_args=None):
49 |     """Build a module from config dict.
50 | 
51 |     Args:
52 |         cfg (dict): Config dict. It should at least contain the key "type".
53 |         registry (:obj:`Registry`): The registry to search the type from.
54 |         default_args (dict, optional): Default initialization arguments.
55 | 
56 |     Returns:
57 |         obj: The constructed object.
58 | """ 59 | assert isinstance(cfg, dict) and 'type' in cfg 60 | assert isinstance(default_args, dict) or default_args is None 61 | args = cfg.copy() 62 | obj_type = args.pop('type') 63 | if mmcv.is_str(obj_type): 64 | obj_type = registry.get(obj_type) 65 | if obj_type is None: 66 | raise KeyError('{} is not in the {} registry'.format( 67 | obj_type, registry.name)) 68 | elif not inspect.isclass(obj_type): 69 | raise TypeError('type must be a str or valid type, but got {}'.format( 70 | type(obj_type))) 71 | if default_args is not None: 72 | for name, value in default_args.items(): 73 | args.setdefault(name, value) 74 | return obj_type(**args) 75 | -------------------------------------------------------------------------------- /mmdet/version.py: -------------------------------------------------------------------------------- 1 | # GENERATED VERSION FILE 2 | # TIME: Sat Oct 12 11:09:31 2019 3 | 4 | __version__ = '0.6.0+7eec2c0' 5 | short_version = '0.6.0' 6 | -------------------------------------------------------------------------------- /tools/coco_eval.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | from mmdet.core import coco_eval 4 | 5 | 6 | def main(): 7 | parser = ArgumentParser(description='COCO Evaluation') 8 | parser.add_argument('result', help='result file path') 9 | parser.add_argument('--ann', help='annotation file path') 10 | parser.add_argument( 11 | '--types', 12 | type=str, 13 | nargs='+', 14 | choices=['proposal_fast', 'proposal', 'bbox', 'segm', 'keypoint'], 15 | default=['bbox'], 16 | help='result types') 17 | parser.add_argument( 18 | '--max-dets', 19 | type=int, 20 | nargs='+', 21 | default=[100, 300, 1000], 22 | help='proposal numbers, only used for recall evaluation') 23 | args = parser.parse_args() 24 | coco_eval(args.result, args.types, args.ann, args.max_dets) 25 | 26 | 27 | if __name__ == '__main__': 28 | main() 29 | -------------------------------------------------------------------------------- /tools/dist_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON=${PYTHON:-"python"} 4 | 5 | CONFIG=$1 6 | GPUS=$2 7 | 8 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \ 9 | $(dirname "$0")/test.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /tools/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON=${PYTHON:-"python"} 4 | 5 | CONFIG=$1 6 | GPUS=$2 7 | 8 | $PYTHON -m torch.distributed.launch --nproc_per_node=$GPUS \ 9 | $(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3} 10 | -------------------------------------------------------------------------------- /tools/slurm_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | CHECKPOINT=$4 9 | GPUS=${GPUS:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | PY_ARGS=${@:5} 13 | SRUN_ARGS=${SRUN_ARGS:-""} 14 | 15 | srun -p ${PARTITION} \ 16 | --job-name=${JOB_NAME} \ 17 | --gres=gpu:${GPUS_PER_NODE} \ 18 | --ntasks=${GPUS} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=${CPUS_PER_TASK} \ 21 | --kill-on-bad-exit=1 \ 22 | ${SRUN_ARGS} \ 23 | python -u tools/test.py ${CONFIG} ${CHECKPOINT} --launcher="slurm" 
${PY_ARGS} 24 | -------------------------------------------------------------------------------- /tools/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | 5 | PARTITION=$1 6 | JOB_NAME=$2 7 | CONFIG=$3 8 | WORK_DIR=$4 9 | GPUS=${5:-8} 10 | GPUS_PER_NODE=${GPUS_PER_NODE:-8} 11 | CPUS_PER_TASK=${CPUS_PER_TASK:-5} 12 | SRUN_ARGS=${SRUN_ARGS:-""} 13 | PY_ARGS=${PY_ARGS:-"--validate"} 14 | 15 | srun -p ${PARTITION} \ 16 | --job-name=${JOB_NAME} \ 17 | --gres=gpu:${GPUS_PER_NODE} \ 18 | --ntasks=${GPUS} \ 19 | --ntasks-per-node=${GPUS_PER_NODE} \ 20 | --cpus-per-task=${CPUS_PER_TASK} \ 21 | --kill-on-bad-exit=1 \ 22 | ${SRUN_ARGS} \ 23 | python -u tools/train.py ${CONFIG} --work_dir=${WORK_DIR} --launcher="slurm" ${PY_ARGS} 24 | -------------------------------------------------------------------------------- /tools/upgrade_model_version.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import re 3 | from collections import OrderedDict 4 | 5 | import torch 6 | 7 | 8 | def convert(in_file, out_file): 9 | """Convert keys in checkpoints. 10 | 11 | There can be some breaking changes during the development of mmdetection, 12 | and this tool is used for upgrading checkpoints trained with old versions 13 | to the latest one. 14 | """ 15 | checkpoint = torch.load(in_file) 16 | in_state_dict = checkpoint.pop('state_dict') 17 | out_state_dict = OrderedDict() 18 | for key, val in in_state_dict.items(): 19 | # Use ConvModule instead of nn.Conv2d in RetinaNet 20 | # cls_convs.0.weight -> cls_convs.0.conv.weight 21 | m = re.search(r'(cls_convs|reg_convs).\d.(weight|bias)', key) 22 | if m is not None: 23 | param = m.groups()[1] 24 | new_key = key.replace(param, 'conv.{}'.format(param)) 25 | out_state_dict[new_key] = val 26 | continue 27 | 28 | out_state_dict[key] = val 29 | checkpoint['state_dict'] = out_state_dict 30 | torch.save(checkpoint, out_file) 31 | 32 | 33 | def main(): 34 | parser = argparse.ArgumentParser(description='Upgrade model version') 35 | parser.add_argument('in_file', help='input checkpoint file') 36 | parser.add_argument('out_file', help='output checkpoint file') 37 | args = parser.parse_args() 38 | convert(args.in_file, args.out_file) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | --------------------------------------------------------------------------------
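To close, an illustrative sketch of how the `Registry`/`build_from_cfg` pair in `mmdet/utils/registry.py` above is meant to be used. The `BACKBONES` registry instance and the `TinyVGG` class are hypothetical stand-ins for this sketch, not part of the repo.

```python
from mmdet.utils import Registry, build_from_cfg

BACKBONES = Registry('backbone')  # hypothetical registry for this sketch

@BACKBONES.register_module
class TinyVGG(object):  # hypothetical module class
    def __init__(self, depth, pretrained=None):
        self.depth = depth
        self.pretrained = pretrained

# 'type' selects the registered class by name; remaining keys become kwargs.
cfg = dict(type='TinyVGG', depth=11)
model = build_from_cfg(cfg, BACKBONES, default_args=dict(pretrained=None))
assert isinstance(model, TinyVGG) and model.depth == 11
```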